In [11]:
import torch

from model.vocabulary import vocabulary_answers as target_vocabulary
from model.vocabulary import vocabulary_expressions as source_vocabulary

import json

In [12]:
from model.model import Model

In [13]:
config_path = "./configs/default.json"
print(f"Using configuration: {config_path}")

# A context manager closes the handle even when json.load raises (e.g. on
# malformed JSON); the previous explicit open()/close() pair leaked it then.
with open(config_path, "r") as file:
    config = json.load(file)

Using configuration: ./configs/default.json


# Load model from checkpoint

In [14]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [18]:
# Value carried over from training; presumably it must match the value the
# checkpoint was trained with. TODO: confirm, or read it from the config.
MAX_SEQ_LENGTH = 171

model = Model(
    source_vocab_size=len(source_vocabulary),
    source_embedding_size=config["embedding_size"],
    target_vocab_size=len(target_vocabulary),
    target_embedding_size=config["embedding_size"],
    encoding_size=config["rnn_hidden_size"],
    # The decoder's start token belongs to the *target* vocabulary. This was
    # previously taken from source_vocabulary, which is only correct if both
    # vocabularies happen to place their begin-of-sequence token at the same index.
    target_bos_index=target_vocabulary.begin_seq_index,
    max_seq_length=MAX_SEQ_LENGTH,
).to(device)

# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load("model.pth", map_location=device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [19]:
# Switch off training-only behaviour (such as the decoder's Dropout layer)
# before inference; train(False) is equivalent to eval() and also returns the module.
model.train(False)

Model(
  (encoder): Encoder(
    (source_embedding): Embedding(45, 64, padding_idx=0)
    (birnn): GRU(64, 64, batch_first=True, bidirectional=True)
  )
  (decoder): Decoder(
    (target_embedding): Embedding(26, 64, padding_idx=0)
    (gru_cell): GRUCell(192, 128)
    (hidden_map): Linear(in_features=128, out_features=128, bias=True)
    (classifier): Linear(in_features=256, out_features=26, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
)

# Helper functions

In [20]:
def to_indices(scores):
    """Return the index of the largest score along dimension 1 of ``scores``.

    For a 2-D score matrix this yields one index per row. On ties the first
    maximal position is returned.
    """
    return scores.argmax(dim=1)

def sentence_from_indices(indices, vocab, strict=True):
    """Convert a sequence of token indices into a space-separated string.

    Args:
        indices: iterable of token indices, either scalar tensors (e.g. the
            elements of a 1-D index tensor) or plain Python ints.
        vocab: vocabulary exposing ``begin_seq_index``, ``end_seq_index`` and
            ``getToken(index)``.
        strict: when True, begin-of-sequence indices are skipped and decoding
            stops at the first end-of-sequence index (which is not emitted).
            When False, every index is translated verbatim.

    Returns:
        The decoded tokens joined by single spaces ("" for empty input).
    """
    tokens = []
    for index in indices:
        # int() accepts both 0-d tensors and plain ints; .item() only
        # worked for tensor elements.
        index = int(index)
        if strict:
            if index == vocab.begin_seq_index:
                continue
            if index == vocab.end_seq_index:
                break
        tokens.append(vocab.getToken(index))
    return " ".join(tokens)

# Inference

In [24]:
# n^2 test
test_expression = ["#", "/", "0", "0", "0"]

# The same expression is repeated to fill a batch of config["batch_size"] rows.
# The inputs must live on the same device as the model; otherwise this cell
# fails with a device-mismatch error whenever CUDA is available.
test_tensor = torch.tensor(
    [source_vocabulary.vectorize(test_expression) for _ in range(config["batch_size"])],
    dtype=torch.int32,
    device=device,
)

# One length per row. Every row is the same vectorized expression, so each length
# equals the row width. Kept on the CPU as a LongTensor, as before (sequence-packing
# utilities require CPU lengths, in case the encoder uses them; unverified).
test_lengths = torch.full((test_tensor.size(0),), test_tensor.size(1), dtype=torch.long)

# Pure inference: disable autograd so no computation graph is recorded.
with torch.no_grad():
    test_pred = model(test_tensor, test_lengths, target_sequence=None)

print(f"Test expression: {test_expression}")
print(f"Predicted shape: {test_pred.shape}")
print(f"Predicted value: {sentence_from_indices(to_indices(test_pred[0]), target_vocabulary)}")

Test expression: ['#', '/', '0', '0', '0']
Predicted shape: torch.Size([8, 95, 26])
Predicted value: TT_ZERO TT_INTEGER TT_INTEGER TT_INTEGER TT_SQRT TT_MINUS TT_SQRT TT_MULTIPLY TT_PI TT_INTEGER TT_POW TT_PLUS TT_INTEGER TT_INTEGER TT_RATIONAL TT_MULTIPLY TT_PLUS TT_INTEGER TT_INTEGER TT_LOG TT_MULTIPLY TT_MINUS TT_INTEGER TT_INTEGER TT_LOG TT_MULTIPLY TT_MINUS TT_INTEGER TT_INTEGER TT_LOG TT_MULTIPLY TT_PLUS TT_INTEGER TT_INTEGER TT_LOG TT_MULTIPLY TT_PLUS TT_INTEGER TT_INTEGER TT_SQRT TT_MULTIPLY TT_INTEGER TT_INTEGER TT_SQRT TT_MINUS TT_LOG TT_MULTIPLY TT_PLUS TT_INTEGER TT_DIVIDE
