In [1]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns
from seqgen.model import rnn, embedding, attention
from seqgen.vocabulary import *
from seqgen.preprocess import *
from seqgen.datasets.sequences import *
from seqgen.datasets.realdata import RealSequencesDataset

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [2]:
if torch.cuda.device_count():
    device="cuda"
else:
    device="cpu"
print("Device", device)

Device cuda


In [3]:
lr = 1e-2

cell_type=rnn.CellType.LSTM
encoder_embedding_type=embedding.EmbeddingType.COORDS_DIRECT
decoder_embedding_type=embedding.EmbeddingType.POS_SUBSPACE
attention_type=attention.AttentionType.DOT

use_real_dataset=True
num_layers=3
embedding_dim=64
hidden_size=64
batch_size=128
max_length=50
bidirectional=True

In [4]:
vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

if use_real_dataset:
    dataset = RealSequencesDataset(filename="data/train/label.txt", vocab_in=vocab_in, vocab_out=vocab_out, max_length=max_length-2, batch_size=batch_size, device=device)
else:
    dataset = SyntheticSequenceDataset(vocab_in, vocab_out, max_length, batch_size, continue_prob=0.95, additional_eos=True, device=device)
    
positions = torch.arange(max_length).repeat(batch_size, 1).to(device)

input_seqs, coordinates, target_seqs = dataset[0]
input_seqs.shape, coordinates.shape, target_seqs.shape, positions.shape

(torch.Size([128, 50]),
 torch.Size([128, 50, 4]),
 torch.Size([128, 50]),
 torch.Size([128, 50]))

# The Encoder

In [5]:
load_from_checkpoint = False
checkpoint_file = "model_2023-01-15_09-17-53.pt"

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

encoder = rnn.RecurrentEncoder(
    cell_type=cell_type,
    embedding_type=encoder_embedding_type,
    vocab_size=len(vocab_in),
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
    max_length=max_length,
    num_layers=num_layers,
    dropout=0.1,
    bidirectional=bidirectional,
    device=device
).to(device)

decoder = rnn.RecurrentAttentionDecoder(
    cell_type=cell_type,
    embedding_type=decoder_embedding_type,
    attention_type=attention_type,
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
    vocab_size=len(vocab_out),
    max_length=max_length,
    batch_size=batch_size,
    num_layers=num_layers,
    dropout=0.1,
    bidirectional=bidirectional,
    device=device
).to(device)

# Initialize optimizer for encoder and decoder
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)

# Loss function
criterion = torch.nn.NLLLoss()

# Load model weights from checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file)
    cell_type=checkpoint['cell_type']
    attention_type=checkpoint['attention_type']
    encoder_embedding_type=checkpoint['encoder_embedding_type']
    decoder_embedding_typecheckpoint['decoder_embedding_type']
    encoder.load_state_dict(checkpoint['encoder_model_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_model_state_dict'])
    encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
    decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])
    num_layers = checkpoint['num_layers']
    embedding_dim = checkpoint['embedding_dim']
    hidden_size = checkpoint['hidden_size']
    bidirectional = checkpoint['bidirectional']

In [6]:
# Initialize the encoder hidden state and cell state with zeros
hn = encoder.initHidden(input_seqs.shape[0], device=dataset.device)
cn = encoder.initHidden(input_seqs.shape[0], device=dataset.device)
hidden = (hn, cn) if cell_type == rnn.CellType.LSTM else hn

_hidden_size = hidden_size * 2 if bidirectional else hidden_size
encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

# Iterate over the sequence words and run every word through the encoder
for i in range(input_seqs.shape[1]):
    # Run the i-th word of the input sequence through the encoder.
    # As a result we will get the prediction (output), the hidden state and the cell state.
    # The hidden state and cell state will be used as inputs in the next round
    print(f"Run word {i+1} of all {input_seqs.shape[0]} sequences through the encoder")
    output, hidden = encoder(
        input_seqs[:, i].unsqueeze(dim=1),
        coordinates[:, i],
        hidden
    )
    encoder_outputs[:, i:i+1, :] = output

Run word 1 of all 128 sequences through the encoder
Run word 2 of all 128 sequences through the encoder
Run word 3 of all 128 sequences through the encoder
Run word 4 of all 128 sequences through the encoder
Run word 5 of all 128 sequences through the encoder
Run word 6 of all 128 sequences through the encoder
Run word 7 of all 128 sequences through the encoder
Run word 8 of all 128 sequences through the encoder
Run word 9 of all 128 sequences through the encoder
Run word 10 of all 128 sequences through the encoder
Run word 11 of all 128 sequences through the encoder
Run word 12 of all 128 sequences through the encoder
Run word 13 of all 128 sequences through the encoder
Run word 14 of all 128 sequences through the encoder
Run word 15 of all 128 sequences through the encoder
Run word 16 of all 128 sequences through the encoder
Run word 17 of all 128 sequences through the encoder
Run word 18 of all 128 sequences through the encoder
Run word 19 of all 128 sequences through the encoder
Ru

In [7]:
output.shape, hn.shape, cn.shape, encoder_outputs.shape

(torch.Size([128, 1, 128]),
 torch.Size([6, 128, 64]),
 torch.Size([6, 128, 64]),
 torch.Size([128, 50, 128]))

# The Decoder

In [8]:
loss = 0

# Iterate over words of target sequence and run words through the decoder.
# This will produce a prediction for the next word in the sequence
for i in range(0, target_seqs.size(1)):
    print(f"Run word {i+1} through decoder", hn[0].shape if cell_type == rnn.CellType.LSTM else hn.shape, encoder_hidden_states.shape)
    output, hn, _ = decoder(
        x=target_seqs[:, i].unsqueeze(dim=1),
        positions=positions[:, i:i+1],
        annotations=encoder_outputs,
        hidden=hidden
    )
    loss += criterion(output.squeeze(), target_seqs[:, i])

print("LOSS", loss.item() / max_length)

Run word 1 through decoder torch.Size([128, 64]) torch.Size([128, 50, 384])
Run word 2 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 3 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 4 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 5 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 6 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 7 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 8 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 9 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 10 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 11 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 12 through decoder torch.Size([6, 128, 64]) torch.Size([128, 50, 384])
Run word 13 through decoder torch.Size([6, 128, 64])

# Training

In [9]:
history = []
accuracies = []

for epoch in range(10000):
    # With a certain chance present the model the true predictions
    # instead of its own predictions in the next iteration
    use_teacher_forcing_prob = 0.5
    use_teacher_forcing = random.random() < use_teacher_forcing_prob
    
    # Get a batch of trianing data
    input_seqs, coordinates, target_seqs = dataset[0]

    # Initialize the encoder hidden state and cell state with zeros
    hn = encoder.initHidden(input_seqs.shape[0], device=dataset.device)
    cn = encoder.initHidden(input_seqs.shape[0], device=dataset.device)
    hidden = (hn, cn) if cell_type == rnn.CellType.LSTM else hn
    
    # Initialize encoder outputs tensor
    last_n_states = 2 if bidirectional else 1
    _hidden_size = hidden_size * 2 if bidirectional else hidden_size
    encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
    encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)
    
    # Set gradients of all model parameters to zero
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Initialize loss
    loss = 0
    
    ####################
    #     ENCODING     #
    ####################

    # Iterate over the sequence words and run every word through the encoder
    for i in range(input_seqs.shape[1]):
        # Run the i-th word of the input sequence through the encoder.
        # As a result we will get the prediction (output), the hidden state (hn).
        # The hidden state and cell state will be used as inputs in the next round
        output, hidden = encoder(
            x=input_seqs[:, i].unsqueeze(dim=1),
            coordinates=coordinates[:, i],
            hidden=hidden
        )
        # Save encoder outputs and states for current word
        encoder_outputs[:, i:i+1, :] = output
        encoder_hidden_states[:, i, :] = rnn.concat_hidden_states(hn)

    ####################
    #     DECODING     #
    ####################
    
    accuracy = 0.0

    # The first words that we be presented to the model is the '<start>' token
    prediction = target_seqs[:, 0]
    
    # Iterate over words of target sequence and run words through the decoder.
    # This will produce a prediction for the next word in the sequence
    for i in range(1, target_seqs.size(1)):
        # Run word i through decoder and get word i+1 and the new hidden state as outputs
        if use_teacher_forcing:
            output, hidden, _ = decoder(
                x=target_seqs[:, i-1].unsqueeze(dim=1),
                positions=positions[:, i-1:i],
                annotations=encoder_outputs,
                hidden=hidden
            )
            # Get the predicted classes of the model
            topv, topi = output.topk(1)
        else:
            output, hidden, _ = decoder(
                x=prediction.unsqueeze(dim=1),
                positions=positions[:, i-1:i],
                annotations=encoder_outputs,
                hidden=hidden
            )
            # Get the predicted classes of the model
            topv, topi = output.topk(1)
            prediction = topi.squeeze()    
        loss += criterion(output.squeeze(), target_seqs[:, i])
        accuracy += float((topi.squeeze() == target_seqs[:, i]).sum() / (target_seqs.size(0)*(target_seqs.size(1)-1)))
    
    history.append(loss.item())
    accuracies.append(accuracy)
    
    print_every = 100
    if not epoch % print_every:
        _accuracy = sum(accuracies[-print_every:]) / print_every
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "ACCURACY", _accuracy)

    # Compute gradient
    loss.backward()
    accuracy = 0.0

    # Update weights of encoder and decoder
    encoder_optimizer.step()
    decoder_optimizer.step()

LOSS after epoch 0 5.054454040527344 ACCURACY 0.0
LOSS after epoch 100 1.6105517578125 ACCURACY 0.6921938659199804
LOSS after epoch 200 1.3887765502929688 ACCURACY 0.7354830872791354
LOSS after epoch 300 1.3115267944335938 ACCURACY 0.7491103193652816
LOSS after epoch 400 0.9655345916748047 ACCURACY 0.7547098092571832
LOSS after epoch 500 1.2592845153808594 ACCURACY 0.7733721177605912
LOSS after epoch 600 0.9983010864257813 ACCURACY 0.7725334690650925
LOSS after epoch 700 0.8760892486572266 ACCURACY 0.8149585323780775
LOSS after epoch 800 0.6421528625488281 ACCURACY 0.8523341688793152
LOSS after epoch 900 0.21059833526611327 ACCURACY 0.8777471140213311
LOSS after epoch 1000 0.17809185028076172 ACCURACY 0.8989269617851824
LOSS after epoch 1100 0.44169475555419924 ACCURACY 0.9109406745340675
LOSS after epoch 1200 0.1563927173614502 ACCURACY 0.9226897191163153
LOSS after epoch 1300 0.43213146209716796 ACCURACY 0.9333466071821749
LOSS after epoch 1400 0.29591222763061525 ACCURACY 0.93519928

KeyboardInterrupt: 

#### Save model history

In [None]:
import pickle
from datetime import datetime

model_data = {
    "history": history,
    "lr": lr,
    "cell_type": cell_type,
    "encoder_embedding_type": encoder_embedding_type,
    "decoder_embedding_type": decoder_embedding_type,
    "attention_type": attention_type,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length
}

now = datetime.now() # current date and time
date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{cell_type}_{num_layers}layers_encemb-{encoder_embedding_type}_decemb-{decoder_embedding_type}_attn-{attention_type}"

torch.save({
    'epoch': epoch,
    'encoder_model_state_dict': encoder.state_dict(),
    'decoder_model_state_dict': decoder.state_dict(),
    'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
    'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
    'loss': loss,
    "history": history,
    "lr": lr,
    "cell_type": cell_type,
    "encoder_embedding_type": encoder_embedding_type,
    "decoder_embedding_type": decoder_embedding_type,
    "attention_type": attention_type,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "bidirectional": bidirectional,
}, filename + ".pt")


with open(filename + '.pkl', 'wb') as f:
    pickle.dump(model_data, f)
    
print(str(date_time), "Saved model: " + filename)

## Make predictions

We run our input sequences through the model and get output seuences. Then we decode the output sequences with the Vocabulary class and get our final latex code.

In [None]:
def predict(input_seqs, coordinates, target_seqs):
    vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
    vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

    predictions = torch.zeros(target_seqs.shape)
    attention_matrix = torch.zeros((input_seqs.shape[0], input_seqs.shape[1], input_seqs.shape[1]))

    with torch.no_grad():
        # Initialize the encoder hidden state and cell state with zeros
        hn = encoder.initHidden(input_seqs.shape[0], device=dataset.device)
        
        # Initialize the encoder hidden state and cell state with 
        last_n_states = 2 if bidirectional else 1
        _hidden_size = hidden_size * 2 if bidirectional else hidden_size
        encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
        encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

        # Iterate over the sequence words and run every word through the encoder
        for i in range(input_seqs.size(1)):
            output, hn = encoder(
                input_seqs[:, i].unsqueeze(dim=1),
                coordinates[:, i],
                hn
            )
            encoder_outputs[:, i:i+1, :] = output
            encoder_hidden_states[:, i, :] = rnn.concat_hidden_states(hn)
        
        # Predict tokens of the target sequence by running the hidden state through
        # the decoder
        for i in range(0, target_seqs.size(1)):
            output, hn, attention = decoder(
                x=target_seqs[:, i-1].unsqueeze(dim=1),
                coordinates=coordinates[:, i-1],
                annotations=encoder_outputs,
                hidden=hn
            )
            # Select the indices of the most likely tokens
            predicted_char = torch.argmax(output, dim=1)
            predictions[:, i] = torch.argmax(output, dim=1).squeeze()
            attention_matrix[:, :, i:i+1] = attention
        
        return predictions, attention_matrix

In [None]:
prediction, attention_matrix = predict(input_seqs, coordinates, target_seqs)
prediction.shape, attention_matrix.shape, input_seqs.shape

In [None]:
idx = random.randint(0, prediction.size(0)-1)
seq_in = vocab_in.decode_sequence(input_seqs[idx].cpu().numpy())
seq_out = vocab_out.decode_sequence(predictions[idx].cpu().numpy())
fig, ax = plt.subplots(1,1)
ax.matshow(attention_matrix[idx], cmap='bone')
ax.set_xticklabels([seq_out[j] for j in range(prediction.size(1))], rotation=45)
ax.set_yticklabels([seq_in[j] for j in range(prediction.size(1))])
#ax.tick_params(labelsize=15)
ax.set(xlabel='Output Sequence', ylabel='Input Sequence')
ax.xaxis.set_major_locator(plt.MaxNLocator(prediction.size(1)))
ax.yaxis.set_major_locator(plt.MaxNLocator(prediction.size(1)))

In [None]:
in_swapped = g.random_swap(input_seqs[0], i=2).unsqueeze(dim=0)
coords_swapped = g.random_swap(coordinates[0], i=2).unsqueeze(dim=0)
prediction_swapped = predict(in_swapped, coords_swapped, target_seqs[0:1])

In [None]:
input_seqs[0:1] == in_swapped

In [None]:
prediction == prediction_swapped

In [None]:
# Pick random sequence and its prediction from the model
import random

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions, attention_matrix = predict(input_seqs, coordinates, target_seqs)

i = random.randint(0, predictions.size(0))
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i].cpu().numpy()))
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i][1:].cpu().numpy()))

In [None]:
prediction = vocab_out.decode_sequence(predictions[i].cpu().numpy())
prediction = list(filter(lambda x: x != '<end>', prediction))
prediction = "".join(prediction)
print("MODEL OUTPUT", prediction)