In [21]:
import torch

import sympy as sp

from model.vocabulary import vocabulary_answers as target_vocabulary
from model.vocabulary import vocabulary_expressions as source_vocabulary
from model.equation_interpreter import Equation
from model.tokens import Token

import json

In [22]:
# Show both vocabularies (target/answers first, then source/expressions), one per line.
print(target_vocabulary, source_vocabulary, sep="\n")

<Vocabulary(size=26)>
<Vocabulary(size=45)>


In [23]:
from model.model import Model

In [24]:
config_path = "./configs/default.json"
print(f"Using configuration: {config_path}")

# The context manager closes the file even if json.load raises on a malformed config.
with open(config_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

Using configuration: ./configs/default.json


# Load model from checkpoint

In [25]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [26]:
# Passed to Model as `max_seq_length` — presumably the maximum decoded length; TODO confirm
# and consider moving it into the config file alongside the other hyperparameters.
MAX_SEQ_LENGTH = 185
CHECKPOINT_PATH = "model_3.pth"

model = Model(
    source_vocab_size=len(source_vocabulary), source_embedding_size=config["embedding_size"],
    target_vocab_size=len(target_vocabulary), target_embedding_size=config["embedding_size"],
    encoding_size=config["rnn_hidden_size"], target_bos_index=target_vocabulary.begin_seq_index,
    max_seq_length=MAX_SEQ_LENGTH,
).to(device)

# Map the stored tensors onto the same device the model was moved to.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code — only load
# trusted checkpoints (or pass weights_only=True on PyTorch versions that support it).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [27]:
# Switch to evaluation mode so layers such as the decoder's Dropout become inactive.
# `train(mode=False)` is exactly what `eval()` does, and it returns the module so the
# cell still displays the model structure.
model.train(mode=False)

Model(
  (encoder): Encoder(
    (source_embedding): Embedding(45, 64, padding_idx=41)
    (birnn): GRU(64, 256, batch_first=True, bidirectional=True)
  )
  (decoder): Decoder(
    (target_embedding): Embedding(26, 64, padding_idx=22)
    (gru_cell): GRUCell(576, 512)
    (hidden_map): Linear(in_features=512, out_features=512, bias=True)
    (classifier): Linear(in_features=1024, out_features=26, bias=True)
    (dropout): Dropout(p=0.3, inplace=True)
  )
)

# Helper functions

In [28]:
def to_indices(scores):
    """Return the arg-max index of ``scores`` along dimension 1.

    Applied to a single batch element of the model output (e.g. a
    (steps, vocab) score matrix), this gives the highest-scoring token id
    at every step.
    """
    return torch.max(scores, dim=1).indices

def sentence_from_indices(indices, vocab, strict=True):
    """Decode a sequence of vocabulary indices into a space-separated string.

    Args:
        indices: Iterable of token indices, either 0-d tensor elements (for
            example from iterating a 1-D LongTensor) or plain Python ints.
        vocab: Vocabulary exposing ``begin_seq_index``, ``end_seq_index`` and
            ``getToken(index)``.
        strict: If True, skip begin-of-sequence tokens and stop at the first
            end-of-sequence token. If False, decode every index as-is.

    Returns:
        The decoded tokens joined by single spaces ("" if nothing was decoded).
    """
    out = []
    for index in indices:
        # Tensor (and numpy) scalars need .item(); plain ints are used directly.
        index = index.item() if hasattr(index, "item") else index
        if strict and index == vocab.begin_seq_index:
            continue
        if strict and index == vocab.end_seq_index:
            break
        out.append(vocab.getToken(index))
    return " ".join(out)

# Inference

In [29]:
# Smoke test: decode a single hand-written expression.
# NOTE(review): the saved output of this cell shows a different expression
# (['#', '/', '-2', '-2']) than the one below — re-run from a fresh kernel so output and code agree.
test_expression = ["#", "/", "0", "0"]
batch_size = config["batch_size"]

# Replicate the expression to fill one batch. The input must live on the same device
# as the model's parameters, otherwise a GPU run fails with a device mismatch.
test_tensor = torch.tensor(
    [source_vocabulary.vectorize(test_expression) for _ in range(batch_size)],
    dtype=torch.long,
    device=device,
)
# Every row has the same length. Kept on the CPU — NOTE(review): if Model packs its
# inputs, pack_padded_sequence requires CPU lengths; confirm against Model.
test_lengths = torch.full((batch_size,), test_tensor.size(1), dtype=torch.long)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    test_pred = model(test_tensor, test_lengths, target_sequence=None)

# All rows are identical, so decoding the first batch element is enough.
prediction = sentence_from_indices(to_indices(test_pred[0]), target_vocabulary)

print(f"Test expression: {test_expression}")
print(f"Predicted shape: {test_pred.shape}")
print(f"Predicted value: {prediction}")

Test expression: ['#', '/', '-2', '-2']
Predicted shape: torch.Size([8, 4, 26])
Predicted value: TT_INTEGER TT_INTEGER TT_DIVIDE


In [30]:
# Turn the predicted token-type names back into Token objects and interpret
# them as a postfix expression.
token_list = list(map(Token, prediction.split(" ")))
predicted_equation = Equation(token_list, notation="postfix")
predicted_equation.getMathmetaicalNotation()

'(Z/Z)'

In [31]:
# SymPy expects Python's `**` for powers, so translate any `^` before parsing.
# NOTE(review): parse_expr evaluates its input with eval(); only pass it strings built
# by the interpreter, never untrusted text. The notation uses placeholder symbols
# (e.g. '(Z/Z)'), which SymPy may simplify away — Z/Z becomes 1.
sp.parse_expr(
    predicted_equation.getMathmetaicalNotation().replace("^", "**")
)

1