# Model Inference

In [None]:
import torch
import seqgen.seq_gen as g
import random
from seqgen.model import seq2seq_lstm
from seqgen.vocabulary import *

%load_ext autoreload
%autoreload 2

In [None]:
# Prefer the GPU whenever PyTorch reports at least one CUDA device.
device = "cuda" if torch.cuda.device_count() else "cpu"
print("Device", device)

In [None]:
# Build one synthetic batch on `device`. Channel 0 of `features` is used as the
# token ids fed to the encoder; the remaining channels are the coordinates.
features, target_seqs = g.generate_synthetic_training_data(100, max_length=25, swap_times=0, device=device)
# Slice the existing tensor directly instead of wrapping it in the legacy
# torch.Tensor(...) constructor, which only accepts tensors matching the
# default CPU float type and therefore raises when `features` is on CUDA.
input_seqs = features[:, :, 0].to(torch.int64)
coordinates = features[:, :, 1:]

In [None]:
# Model hyperparameters. They must match the run that produced the checkpoint,
# because load_state_dict rejects mismatched parameter shapes.
num_layers=2
embedding_dim = 100
hidden_size=100
batch_size=100
max_length=25
bidirectional=True

load_from_checkpoint = True
checkpoint_file = "model_2022-12-25_11-55-53.pt"

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

# Build both halves of the seq2seq model on the same device as the data.
encoder = seq2seq_lstm.EncoderRNN(
    vocab_size=len(vocab_in),
    embedding_dim=embedding_dim,
    num_layers=num_layers,
    hidden_size=hidden_size,
    bidirectional=bidirectional,
).to(features.device)
decoder = seq2seq_lstm.DecoderRNN(
    embedding_dim=embedding_dim,
    num_layers=num_layers,
    hidden_size=hidden_size,
    vocab_size=len(vocab_out),
    bidirectional=bidirectional,
).to(features.device)

# Load model weights from checkpoint
if load_from_checkpoint:
    # map_location remaps stored tensors onto the data's device, so a
    # checkpoint saved on a GPU can still be restored on a CPU-only machine.
    checkpoint = torch.load(checkpoint_file, map_location=features.device)
    encoder.load_state_dict(checkpoint['encoder_model_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_model_state_dict'])

# This notebook only runs inference: disable any training-only layer behaviour
# (e.g. dropout) the modules may contain.
encoder.eval()
decoder.eval()

In [None]:
def predict(input_seqs, coordinates, target_seqs, teacher_forcing=False,
            encoder_model=None, decoder_model=None):
    """Greedily decode a batch with the seq2seq encoder/decoder.

    The encoder consumes ``input_seqs`` one position at a time. Its final
    (hidden, cell) state seeds the decoder, which then produces one token per
    position of ``target_seqs``.

    Args:
        input_seqs: int64 token ids indexed as ``[batch, position]``.
        coordinates: per-position features indexed as ``[batch, position, ...]``;
            column ``i`` is passed to encoder step ``i`` and decoder step ``i``.
        target_seqs: int64 token ids indexed as ``[batch, position]``. Its shape
            fixes the number of decoding steps. Without teacher forcing only its
            first column is fed to the model.
        teacher_forcing: if True, feed the ground-truth ``target_seqs[:, i]`` at
            every decoder step. If False (default), feed ``target_seqs[:, 0]`` at
            step 0 and the model's own previous prediction afterwards, so the
            outputs reflect real inference rather than copying the targets.
        encoder_model: encoder module to use; defaults to the global ``encoder``.
        decoder_model: decoder module to use; defaults to the global ``decoder``.

    Returns:
        An int64 tensor shaped like ``target_seqs`` on ``target_seqs.device``.
        Column ``i`` holds the argmax token emitted by decoder step ``i``.
    """
    enc = encoder if encoder_model is None else encoder_model
    dec = decoder if decoder_model is None else decoder_model

    predictions = torch.zeros(target_seqs.shape, dtype=torch.int64, device=target_seqs.device)

    with torch.no_grad():
        # Initialize the encoder hidden state and cell state with zeros
        hn = enc.initHidden(input_seqs.shape[0], device=input_seqs.device)
        cn = enc.initHidden(input_seqs.shape[0], device=input_seqs.device)

        # Run every input position through the encoder; only the final
        # (hn, cn) is carried over to the decoder.
        for i in range(input_seqs.shape[1]):
            _, (hn, cn) = enc(
                input_seqs[:, i].unsqueeze(dim=1),
                coordinates[:, i],
                (hn, cn)
            )

        # Decode one position at a time starting from the encoder state.
        for i in range(target_seqs.size(1)):
            if teacher_forcing or i == 0:
                # Step 0 always starts from the first target column
                # (presumably a start token — confirm against seqgen).
                step_tokens = target_seqs[:, i]
            else:
                # Autoregressive: feed back the previous greedy prediction.
                step_tokens = predictions[:, i - 1]
            output, (hn, cn) = dec(
                step_tokens.unsqueeze(dim=1),
                coordinates[:, i],
                (hn, cn)
            )
            # Vocabulary scores are on dim 2. Squeeze only the step axis so a
            # batch of one keeps its batch dimension.
            predictions[:, i] = torch.argmax(output, dim=2).squeeze(dim=1)

    return predictions

In [None]:
# Pick random sequence and its prediction from the model
import random

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions = predict(input_seqs, coordinates, target_seqs)

# Choose one sample of the batch uniformly at random and show input, target
# and model output side by side. The input and target drop their first
# position (presumably a start token — confirm); predictions are shown whole.
i = random.randrange(predictions.size(0))
sample_input = input_seqs[i][1:].cpu().numpy()
sample_target = target_seqs[i][1:].cpu().numpy()
sample_output = predictions[i].cpu().numpy()
print("MODEL INPUT", vocab_in.decode_sequence(sample_input))
print("TARGET OUTPUT", vocab_out.decode_sequence(sample_target))
print("MODEL OUTPUT", vocab_out.decode_sequence(sample_output))