In [1]:
import json

import torch

from data import ScanDataset,ScanAugmentedDataset,MTDataset,SCAN_collate
from SymbolicOperator import SymbolicOperator

In [2]:
# Load the vocabulary lookup tables from disk. Later cells index this dict with
# the keys 'in_token_to_idx', 'out_token_to_idx' and 'out_idx_to_token'.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's
# default locale encoding.
with open('vocab.json', 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)

In [3]:
# Vocabulary sizes are taken from the lengths of the lookup tables:
# the token->index table on the input side, the index->token table on the
# output side.
in_vocab_size, out_vocab_size = (
    len(vocab['in_token_to_idx']),
    len(vocab['out_idx_to_token']),
)

In [4]:
# Build the model from the input and output vocabulary sizes.
model = SymbolicOperator(in_vocab_size, out_vocab_size)

In [5]:
# Restore saved parameters into the model, mapping all tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
checkpoint_path = 'checkpoints/b043faab3a8d43d2a7ceb3ee510cd2a2'
checkpoint_state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint_state)

# Switch to evaluation mode; as the cell's last expression this also displays
# the module's repr.
model.eval()

SymbolicOperator(
  (attention): Attention()
  (gate_embedding): Embedding(16, 1)
  (program_embedding): Embedding(16, 200)
  (primitive_embedding): Embedding(16, 200)
  (gate_linear): Linear(in_features=128, out_features=1, bias=True)
  (executor_rnn_cell): GRUCell(1, 384)
  (out_linear): Linear(in_features=200, out_features=8, bias=True)
)

In [9]:
# Encode one instruction/action pair and run it through the model.
instruction_text = '<SOS> jump twice and walk twice <EOS>'.split()
action_text = '<SOS> I_JUMP I_JUMP I_WALK I_WALK <EOS>'.split()

# Look up each token's vocabulary index as a plain int. A token missing from
# the vocabulary raises KeyError here.
instruction = [int(vocab['in_token_to_idx'][token]) for token in instruction_text]
action = [int(vocab['out_token_to_idx'][token]) for token in action_text]

# Wrapping each index list in an outer list adds a leading dimension of size 1,
# so each tensor holds a single sequence.
instructions = torch.tensor([instruction])
actions = torch.tensor([action])
print(instructions)

# Inference only: no gradients are needed.
with torch.no_grad():
    output, true_actions = model(instructions, actions)

# Keep only the first (and only) sequence of each result.
true_actions = true_actions[0]
# NOTE(review): argmax over dim 0 of output[0] is presumably the argmax over
# the output-vocabulary dimension, giving one predicted index per step --
# confirm against the model's forward().
predicted_actions = output[0].argmax(0)
print(predicted_actions)
print(true_actions)

tensor([[ 0,  1,  4,  5, 12,  4,  8]])
tensor([2, 2, 2, 2, 3])
tensor([2, 2, 4, 4, 3])


In [7]:
# Decode model.scratch_history into plain per-word, per-step records of the
# form [gate, read, write, scratch]:
#   gate    -- Python float taken from element [0][0] of the step's first tensor
#   read    -- step[1][0][0] cast to long
#   write   -- step[2][0][0] cast to long
#   scratch -- argmax over the last dimension of step[3], first row only
# NOTE(review): the [0] indexing presumably selects the single batch element --
# confirm against the model.
scratch_history = [
    [
        [
            step[0][0][0].item(),
            step[1][0][0].long(),
            step[2][0][0].long(),
            step[3].argmax(-1)[0],
        ]
        for step in word_steps
    ]
    for word_steps in model.scratch_history
]

# Print the trace word by word. History entry i is paired with
# instruction_text[i + 1], i.e. the leading '<SOS>' token is skipped.
for i, steps in enumerate(scratch_history):
    word = instruction_text[i + 1]
    print('seeing', word)
    for gate, read, write, scratch_pad in steps:
        # Translate the scratch-pad indices back to output tokens; the
        # vocabulary table is keyed by the index rendered as a string.
        tokens = [
            vocab['out_idx_to_token'][str(slot.item())]
            for slot in scratch_pad
        ]
        # When the gate equals 1 the current word is shown as the thing read;
        # otherwise the read vector itself is shown.
        print('read:', word if gate == 1 else read)
        print('write:', write)
        print(tokens)
    print()

seeing jump
read: jump
write: tensor([1, 0, 0, 0, 0])
['<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
read: jump
write: tensor([1, 0, 0, 0, 0])
['I_JUMP', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
read: jump
write: tensor([1, 0, 0, 0, 0])
['I_JUMP', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
read: jump
write: tensor([0, 1, 0, 0, 0])
['I_JUMP', 'I_JUMP', '<EOS>', '<EOS>', '<EOS>']
read: jump
write: tensor([0, 1, 0, 0, 0])
['I_JUMP', 'I_JUMP', '<EOS>', '<EOS>', '<EOS>']

seeing twice
read: tensor([1, 0, 0, 0, 0])
write: tensor([0, 1, 0, 0, 0])
['I_JUMP', 'I_JUMP', '<EOS>', '<EOS>', '<EOS>']
read: tensor([1, 0, 0, 0, 0])
write: tensor([0, 0, 1, 0, 0])
['I_JUMP', 'I_JUMP', '<EOS>', '<EOS>', '<EOS>']
read: tensor([1, 0, 0, 0, 0])
write: tensor([0, 0, 1, 0, 0])
['I_JUMP', 'I_JUMP', 'I_JUMP', '<EOS>', '<EOS>']
read: tensor([1, 0, 0, 0, 0])
write: tensor([0, 0, 1, 0, 0])
['I_JUMP', 'I_JUMP', 'I_JUMP', '<EOS>', '<EOS>']
read: tensor([1, 0, 0, 0, 0])
write: tensor([0, 0, 1, 0, 0])
['I_JUMP', 'I_JUMP', 'I_JUMP'

In [8]:
# Display the raw (undecoded) first entry of model.scratch_history as the
# cell's output.
model.scratch_history[0]

[[tensor([[1.]]),
  tensor([[[1., 0., 0., 0., 0.]]]),
  tensor([[[1., 0., 0., 0., 0.]]]),
  tensor([[[-4.3914, -4.4068,  2.8510,  3.4906, -1.8638, -3.9698, -3.7215,
            -3.3059],
           [-5.1195, -5.1162, -2.4968,  7.8235, -0.1476, -4.6568, -4.3110,
            -2.7688],
           [-5.2455, -5.2389, -3.4220,  8.5731,  0.1493, -4.7756, -4.4130,
            -2.6759],
           [-5.2466, -5.2400, -3.4303,  8.5799,  0.1519, -4.7767, -4.4139,
            -2.6751],
           [-5.2466, -5.2400, -3.4304,  8.5799,  0.1520, -4.7767, -4.4139,
            -2.6751]]])],
 [tensor([[1.]]),
  tensor([[[1., 0., 0., 0., 0.]]]),
  tensor([[[1., 0., 0., 0., 0.]]]),
  tensor([[[-3.8433, -3.8727,  6.8773,  0.2284, -3.1558, -3.4526, -3.2776,
            -3.7102],
           [-5.1164, -5.1131, -2.4741,  7.8051, -0.1549, -4.6539, -4.3085,
            -2.7711],
           [-5.2455, -5.2389, -3.4220,  8.5731,  0.1493, -4.7756, -4.4130,
            -2.6759],
           [-5.2466, -5.2400, -3.4303,  