In [1]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns
from seqgen.model import seq2seq_lstm
from seqgen.vocabulary import *

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [2]:
if torch.cuda.device_count():
    device="cuda"
else:
    device="cpu"
print("Device", device)

Device cpu


  return torch._C._cuda_getDeviceCount() > 0


In [6]:
lr = 1e-2
num_layers=1
embedding_dim=16
hidden_size=16
batch_size=32
max_length=10
positional_encoding=True
bidirectional=True

In [7]:
features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, device=device, continue_prob=0.999, swap_times=10)
input_seqs = torch.tensor(features[:, :, 0]).to(torch.int64)
coordinates = torch.tensor(features[:, :, 1:])
positions_coords = seq2seq_lstm.get_coordinate_encoding(coordinates, max_length=max_length, d=embedding_dim, device=device)
positions_targets = seq2seq_lstm.get_position_encoding(max_length, embedding_dim, device=device).repeat(batch_size, 1, 1)

  input_seqs = torch.tensor(features[:, :, 0]).to(torch.int64)
  coordinates = torch.tensor(features[:, :, 1:])


In [8]:
features.shape, input_seqs.shape, coordinates.shape, target_seqs.shape, positions_coords.shape, positions_targets.shape

(torch.Size([32, 10, 5]),
 torch.Size([32, 10]),
 torch.Size([32, 10, 4]),
 torch.Size([32, 10]),
 torch.Size([32, 10, 16]),
 torch.Size([32, 10, 16]))

# The Encoder

In [9]:
load_from_checkpoint = False
checkpoint_file = "model_2023-01-15_09-17-53.pt"

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

encoder = seq2seq_lstm.EncoderGRUPosEnc(vocab_size=len(vocab_in), embedding_dim=embedding_dim, num_layers=num_layers, max_length=max_length, hidden_size=hidden_size, bidirectional=bidirectional, pos_encoding=positional_encoding).to(features.device)
decoder = seq2seq_lstm.DecoderGRUAttention(embedding_dim=embedding_dim, num_layers=num_layers, max_length=max_length, hidden_size=hidden_size, vocab_size=len(vocab_out), bidirectional=bidirectional, pos_encoding=positional_encoding).to(features.device)

# Initialize optimizer for encoder and decoder
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)

# Loss function
criterion = torch.nn.NLLLoss()

# Load model weights from checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file)
    encoder.load_state_dict(checkpoint['encoder_model_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_model_state_dict'])
    encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
    decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])
    num_layers = checkpoint['num_layers']
    embedding_dim = checkpoint['embedding_dim']
    hidden_size = checkpoint['hidden_size']
    bidirectional = checkpoint['bidirectional']



In [10]:
# Initialize the encoder hidden state and cell state with zeros
hn = encoder.initHidden(input_seqs.shape[0], device=features.device)
cn = encoder.initHidden(input_seqs.shape[0], device=features.device)

_hidden_size = hidden_size * 2 if bidirectional else hidden_size
encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

# Iterate over the sequence words and run every word through the encoder
for i in range(input_seqs.shape[1]):
    # Run the i-th word of the input sequence through the encoder.
    # As a result we will get the prediction (output), the hidden state and the cell state.
    # The hidden state and cell state will be used as inputs in the next round
    print(f"Run word {i+1} of all {input_seqs.shape[0]} sequences through the encoder")
    output, hn = encoder(
        input_seqs[:, i].unsqueeze(dim=1),
        coordinates[:, i],
        positions_coords[:, i:i+1],
        hn
    )
    encoder_outputs[:, i:i+1, :] = output
    encoder_hidden_states[:, i, :] = seq2seq_lstm.concat_hidden_states(hn)

Run word 1 of all 32 sequences through the encoder
Run word 2 of all 32 sequences through the encoder
Run word 3 of all 32 sequences through the encoder
Run word 4 of all 32 sequences through the encoder
Run word 5 of all 32 sequences through the encoder
Run word 6 of all 32 sequences through the encoder
Run word 7 of all 32 sequences through the encoder
Run word 8 of all 32 sequences through the encoder
Run word 9 of all 32 sequences through the encoder
Run word 10 of all 32 sequences through the encoder


In [11]:
output.shape, hn.shape, cn.shape, encoder_hidden_states.shape, encoder_outputs.shape

(torch.Size([32, 1, 32]),
 torch.Size([2, 32, 16]),
 torch.Size([2, 32, 16]),
 torch.Size([32, 10, 32]),
 torch.Size([32, 10, 32]))

# The Decoder

In [12]:
loss = 0

# Iterate over words of target sequence and run words through the decoder.
# This will produce a prediction for the next word in the sequence
for i in range(0, target_seqs.size(1)):
    print(f"Run word {i+1} through decoder", hn.shape, encoder_hidden_states.shape)
    output, hn, attention = decoder(
        x=target_seqs[:, i].unsqueeze(dim=1),
        coordinates=coordinates[:, i],
        annotations=encoder_outputs,
        position=positions_targets[:, i:i+1],
        hidden=hn
    )
    loss += criterion(output.squeeze(), target_seqs[:, i])

print("LOSS", loss.item() / max_length)

Run word 1 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 2 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 3 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 4 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 5 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 6 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 7 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 8 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 9 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
Run word 10 through decoder torch.Size([2, 32, 16]) torch.Size([32, 10, 32])
LOSS 3.2254169464111326


# Training

In [None]:
history = []
accuracies = []

for epoch in range(50000):
    # With a certain chance present the model the true predictions
    # instead of its own predictions in the next iteration
    use_teacher_forcing_prob = 0.5
    use_teacher_forcing = random.random() < use_teacher_forcing_prob
    
    # Get a batch of trianing data
    features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, continue_prob=0.999, device=device, swap_times=max_length)
    features = features.to(device)
    target_seqs = target_seqs.to(device)
    input_seqs = torch.tensor(features[:, :, 0]).to(torch.int64)
    coordinates = torch.tensor(features[:, :, 1:])
    positions_coords = seq2seq_lstm.get_coordinate_encoding(coordinates, max_length=max_length, d=embedding_dim, device=device)
    positions_targets = seq2seq_lstm.get_position_encoding(max_length, embedding_dim, device=device).repeat(batch_size, 1, 1)

    # Initialize the encoder hidden state and cell state with zeros
    hn = encoder.initHidden(input_seqs.shape[0], device=features.device)
    
    # Initialize encoder outputs tensor
    last_n_states = 2 if bidirectional else 1
    _hidden_size = hidden_size * 2 if bidirectional else hidden_size
    encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
    encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)
    
    # Set gradients of all model parameters to zero
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Initialize loss
    loss = 0
    
    ####################
    #     ENCODING     #
    ####################

    # Iterate over the sequence words and run every word through the encoder
    for i in range(input_seqs.shape[1]):
        # Run the i-th word of the input sequence through the encoder.
        # As a result we will get the prediction (output), the hidden state (hn).
        # The hidden state and cell state will be used as inputs in the next round
        output, hn = encoder(
            input_seqs[:, i].unsqueeze(dim=1),
            coordinates[:, i],
            positions_coords[:, i:i+1],
            hn
        )
        # Save encoder outputs and states for current word
        encoder_outputs[:, i:i+1, :] = output
        encoder_hidden_states[:, i, :] = seq2seq_lstm.concat_hidden_states(hn)

    ####################
    #     DECODING     #
    ####################
    
    accuracy = 0.0

    # The first words that we be presented to the model is the '<start>' token
    prediction = target_seqs[:, 0]
    
    # Iterate over words of target sequence and run words through the decoder.
    # This will produce a prediction for the next word in the sequence
    for i in range(1, target_seqs.size(1)):
        # Run word i through decoder and get word i+1 and the new hidden state as outputs
        if use_teacher_forcing:
            output, hn, attention = decoder(
                x=target_seqs[:, i-1].unsqueeze(dim=1),
                coordinates=coordinates[:, i-1],
                annotations=encoder_outputs,
                position=positions_targets[:, i:i+1],
                hidden=hn
            )
            # Get the predicted classes of the model
            topv, topi = output.topk(1)
        else:
            output, hn, attention = decoder(
                x=prediction.unsqueeze(dim=1),
                coordinates=coordinates[:, i-1],
                annotations=encoder_outputs,
                position=positions_targets[:, i:i+1],
                hidden=hn
            )
            # Get the predicted classes of the model
            topv, topi = output.topk(1)
            prediction = topi.squeeze()    
        loss += criterion(output.squeeze(), target_seqs[:, i])
        accuracy += float((topi.squeeze() == target_seqs[:, i]).sum() / (target_seqs.size(0)*(target_seqs.size(1)-1)))
    
    history.append(loss.item())
    accuracies.append(accuracy)
    
    print_every = 100
    if not epoch % print_every:
        _accuracy = sum(accuracies[-print_every:]) / print_every
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "ACCURACY", _accuracy)

    # Compute gradient
    loss.backward()
    accuracy = 0.0

    # Update weights of encoder and decoder
    encoder_optimizer.step()
    decoder_optimizer.step()

  input_seqs = torch.tensor(features[:, :, 0]).to(torch.int64)
  coordinates = torch.tensor(features[:, :, 1:])


LOSS after epoch 0 2.925399398803711 ACCURACY 0.00031250000232830645
LOSS after epoch 100 1.6385673522949218 ACCURACY 0.285520836240612
LOSS after epoch 200 1.2141572952270507 ACCURACY 0.5580902794282884
LOSS after epoch 300 1.1330288887023925 ACCURACY 0.6374305572733283
LOSS after epoch 400 1.089541244506836 ACCURACY 0.6720486124418676
LOSS after epoch 500 1.0537545204162597 ACCURACY 0.6924305563420057
LOSS after epoch 600 1.0178559303283692 ACCURACY 0.7072222233563662
LOSS after epoch 700 0.5338333129882813 ACCURACY 0.7455555554106832
LOSS after epoch 800 0.8128387451171875 ACCURACY 0.7469791673123837
LOSS after epoch 900 0.4649209499359131 ACCURACY 0.7863888899981976
LOSS after epoch 1000 0.8619023323059082 ACCURACY 0.7694444453157484
LOSS after epoch 1100 0.4248810768127441 ACCURACY 0.8015972237288952
LOSS after epoch 1200 0.7515764236450195 ACCURACY 0.794965279456228
LOSS after epoch 1300 0.48764815330505373 ACCURACY 0.8137500003352761
LOSS after epoch 1400 0.3805121660232544 ACCU

#### Save model history

In [None]:
import pickle
from datetime import datetime

model_data = {
    "history": history,
    "lr": lr,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length
}

now = datetime.now() # current date and time
date_time = now.strftime("%Y-%m-%d_%H-%M-%S")

torch.save({
    'epoch': epoch,
    'encoder_model_state_dict': encoder.state_dict(),
    'decoder_model_state_dict': decoder.state_dict(),
    'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
    'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
    'loss': loss,
    "history": history,
    "lr": lr,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "bidirectional": bidirectional,
}, "model_" + date_time + ".pt")


with open("training_" + date_time + '.pkl', 'wb') as f:
    pickle.dump(model_data, f)

## Make predictions

We run our input sequences through the model and get output seuences. Then we decode the output sequences with the Vocabulary class and get our final latex code.

In [None]:
def predict(input_seqs, coordinates, target_seqs):
    vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
    vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

    predictions = torch.zeros(target_seqs.shape)
    attention_matrix = torch.zeros((input_seqs.shape[0], input_seqs.shape[1], input_seqs.shape[1]))

    with torch.no_grad():
        # Initialize the encoder hidden state and cell state with zeros
        hn = encoder.initHidden(input_seqs.shape[0], device=features.device)
        
        # Initialize the encoder hidden state and cell state with 
        last_n_states = 2 if bidirectional else 1
        _hidden_size = hidden_size * 2 if bidirectional else hidden_size
        encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
        encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

        # Iterate over the sequence words and run every word through the encoder
        for i in range(input_seqs.size(1)):
            output, hn = encoder(
                input_seqs[:, i].unsqueeze(dim=1),
                coordinates[:, i],
                positions_coords[:, i:i+1],
                hn
            )
            encoder_outputs[:, i:i+1, :] = output
            encoder_hidden_states[:, i, :] = seq2seq_lstm.concat_hidden_states(hn)
        
        # Predict tokens of the target sequence by running the hidden state through
        # the decoder
        for i in range(0, target_seqs.size(1)):
            output, hn, attention = decoder(
                x=target_seqs[:, i-1].unsqueeze(dim=1),
                coordinates=coordinates[:, i-1],
                annotations=encoder_outputs,
                position=positions_targets[:, i:i+1],
                hidden=hn
            )
            # Select the indices of the most likely tokens
            predicted_char = torch.argmax(output, dim=1)
            predictions[:, i] = torch.argmax(output, dim=1).squeeze()
            attention_matrix[:, :, i:i+1] = attention
        
        return predictions, attention_matrix

In [None]:
prediction, attention_matrix = predict(input_seqs, coordinates, target_seqs)
prediction.shape, attention_matrix.shape

In [None]:
plt.imshow(attention_matrix[random.randint(0, prediction.size(0)-1)])

In [None]:
in_swapped = g.random_swap(input_seqs[0], i=2).unsqueeze(dim=0)
coords_swapped = g.random_swap(coordinates[0], i=2).unsqueeze(dim=0)
prediction_swapped = predict(in_swapped, coords_swapped, target_seqs[0:1])

In [None]:
input_seqs[0:1] == in_swapped

In [None]:
prediction == prediction_swapped

In [None]:
# Pick random sequence and its prediction from the model
import random

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions, attention_matrix = predict(input_seqs, coordinates, target_seqs)

i = random.randint(0, predictions.size(0))
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i].cpu().numpy()))
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i][1:].cpu().numpy()))

In [None]:
prediction = vocab_out.decode_sequence(predictions[i].cpu().numpy())
prediction = list(filter(lambda x: x != '<end>', prediction))
prediction = "".join(prediction)
print("MODEL OUTPUT", prediction)

In [None]:
sns.heatmap(attention_matrix[0])