In [1]:
import torch
import seqgen.seq_gen as g
import random

%load_ext autoreload
%autoreload 2

In [17]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device", device)

# Seed the RNGs used in this notebook so that the synthetic batches, the weight
# initialisation and the teacher-forcing coin flips are reproducible.
# NOTE(review): if seqgen.seq_gen draws from numpy's RNG, seed that as well.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# model hyperparameters
lr = 1e-2
num_layers = 3
embedding_dim = 128
hidden_size = 128
bidirectional = True

# Synthetic data generator hyperparameters
batch_size = 8
max_length = 20

Device cpu


In [18]:
# Draw one synthetic batch. Column 0 of the last feature axis is used as the
# input token ids; the remaining columns are used as per-token coordinates.
features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, device=device, swap_times=0)
# Slice the existing tensor instead of going through the legacy torch.Tensor()
# constructor, which only accepts CPU tensors and fails when device == "cuda".
input_seqs = features[:, :, 0].to(torch.int64)
coordinates = features[:, :, 1:].float()

In [19]:
# Show the shapes of the raw batch and of the tensors sliced out of it
tuple(t.shape for t in (features, input_seqs, coordinates, target_seqs))

(torch.Size([8, 20, 5]),
 torch.Size([8, 20]),
 torch.Size([8, 20, 4]),
 torch.Size([8, 20]))

# The Encoder

In [20]:
from seqgen.model import seq2seq_lstm
from seqgen.vocabulary import *

In [21]:
load_from_checkpoint = False
checkpoint_file = "model_2022-12-27_07-07-47.pt"

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

# Read the checkpoint *before* building the models: the saved architecture
# hyperparameters must be applied first, otherwise the stored state dicts would
# not fit the freshly constructed modules.
checkpoint = None
if load_from_checkpoint:
    # map_location lets a checkpoint written on one device be restored on another
    checkpoint = torch.load(checkpoint_file, map_location=device)
    num_layers = checkpoint['num_layers']
    embedding_dim = checkpoint['embedding_dim']
    hidden_size = checkpoint['hidden_size']
    bidirectional = checkpoint['bidirectional']

encoder = seq2seq_lstm.EncoderRNN(vocab_size=len(vocab_in), embedding_dim=embedding_dim, num_layers=num_layers, max_length=max_length, hidden_size=hidden_size, bidirectional=bidirectional, pos_encoding=False, device=device).to(device)
decoder = seq2seq_lstm.AttnDecoderRNN(embedding_dim=embedding_dim, num_layers=num_layers, hidden_size=hidden_size, vocab_size=len(vocab_out), max_length=max_length, bidirectional=bidirectional, pos_encoding=False, device=device).to(device)

# Encoding of this batch's coordinates (width d=embedding_dim), passed to the
# models as their "position" input
positions = seq2seq_lstm.get_coordinate_encoding(coordinates, d=embedding_dim, device=device)

# Initialize optimizer for encoder and decoder
encoder_optimizer = torch.optim.SGD(encoder.parameters(), lr=lr)
decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=lr)

# Loss function (NLLLoss expects log-probabilities; presumably the decoder ends
# in log_softmax — confirm)
criterion = torch.nn.NLLLoss()

# Restore model and optimizer weights now that matching modules exist
if checkpoint is not None:
    encoder.load_state_dict(checkpoint['encoder_model_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_model_state_dict'])
    encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
    decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])

In [22]:
# Initialize the encoder hidden state and cell state with zeros
hn = encoder.initHidden(input_seqs.shape[0], device=features.device)
cn = encoder.initHidden(input_seqs.shape[0], device=features.device)

_hidden_size = hidden_size * 2 if bidirectional else hidden_size
encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

# Iterate over the sequence words and run every word through the encoder
for i in range(input_seqs.shape[1]):
    # Run the i-th word of the input sequence through the encoder.
    # As a result we will get the prediction (output), the hidden state and the cell state.
    # The hidden state and cell state will be used as inputs in the next round
    print(f"Run word {i+1} of all {input_seqs.shape[0]} sequences through the encoder")
    output, (hn, cn) = encoder(input_seqs[:, i].unsqueeze(dim=1), coordinates[:, i], positions[:, i:i+1], (hn, cn))
    encoder_outputs[:, i:i+1, :] = output
    encoder_hidden_states[:, i, :] = seq2seq_lstm.concat_hidden_states(hn)

Run word 1 of all 8 sequences through the encoder
Run word 2 of all 8 sequences through the encoder
Run word 3 of all 8 sequences through the encoder
Run word 4 of all 8 sequences through the encoder
Run word 5 of all 8 sequences through the encoder
Run word 6 of all 8 sequences through the encoder
Run word 7 of all 8 sequences through the encoder
Run word 8 of all 8 sequences through the encoder
Run word 9 of all 8 sequences through the encoder
Run word 10 of all 8 sequences through the encoder
Run word 11 of all 8 sequences through the encoder
Run word 12 of all 8 sequences through the encoder
Run word 13 of all 8 sequences through the encoder
Run word 14 of all 8 sequences through the encoder
Run word 15 of all 8 sequences through the encoder
Run word 16 of all 8 sequences through the encoder
Run word 17 of all 8 sequences through the encoder
Run word 18 of all 8 sequences through the encoder
Run word 19 of all 8 sequences through the encoder
Run word 20 of all 8 sequences through t

In [23]:
# Shapes after the encoder pass, plus the width obtained by appending the
# coordinates to every encoder output
(
    *(t.shape for t in (output, hn, cn, encoder_hidden_states, encoder_outputs, coordinates)),
    torch.cat([encoder_outputs, coordinates], dim=2).shape,
)

(torch.Size([8, 1, 256]),
 torch.Size([6, 8, 128]),
 torch.Size([6, 8, 128]),
 torch.Size([8, 20, 768]),
 torch.Size([8, 20, 256]),
 torch.Size([8, 20, 4]),
 torch.Size([8, 20, 260]))

In [15]:
# Concatenate hidden state vectors from multiple hidden state layers
hn_cat = seq2seq_lstm.concat_hidden_states(hn)
hn_cat.shape

torch.Size([8, 384])

## Test: hidden-state shapes

In [16]:
# Compare the final hidden state with the stacked per-step encoder states
(hn.size(), encoder_hidden_states.size())

(torch.Size([3, 8, 128]), torch.Size([8, 20, 384]))

In [11]:
# Combine the encoder annotations with the final hidden state (presumably one
# combined vector per input position — see the printed shape)
h_st = seq2seq_lstm.combine_encoder_annotations_and_hidden_state(hn, encoder_hidden_states)
h_st.size()

torch.Size([8, 20, 768])

# The Decoder

In [12]:
# Smoke-test one teacher-forced pass through the decoder.
# As in the training loop, the decoder receives token i-1 and is scored against
# token i. Feeding token i and scoring against token i would only ask the model
# to copy its input.
loss = 0

# Start from the encoder's final states without overwriting hn/cn, so this cell
# can be re-run without re-running the encoder cell.
hn_dec, cn_dec = hn, cn

for i in range(1, target_seqs.size(1)):
    print(f"Run word {i} through decoder")
    output, (hn_dec, cn_dec), attn_weights = decoder(
        x=target_seqs[:, i-1].unsqueeze(dim=1),
        coordinates=coordinates[:, i-1],
        position=positions[:, i-1:i],
        hidden=(hn_dec, cn_dec),
        # NOTE(review): this is the all-layers concatenation built in the
        # exploratory encoder cell, while the training loop passes last-layer
        # states only; the shape error seen here may come from that mismatch —
        # confirm which width the decoder expects.
        encoder_outputs=encoder_hidden_states
    )
    loss += criterion(output.squeeze(), target_seqs[:, i])

# Average over the number of scored tokens (token 0, '<start>', is never a target)
print("LOSS", loss.item() / (target_seqs.size(1) - 1))
output.shape, hn_dec.shape, cn_dec.shape, attn_weights.shape

Run word 1 through decoder


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x512 and 256x128)

# Training

In [None]:
history = []

# Probability of feeding the ground-truth previous token to the decoder
# (teacher forcing) instead of its own previous prediction; drawn once per batch.
use_teacher_forcing_prob = 0.5

# Number of directions whose last-layer states form one encoder annotation, and
# the resulting annotation width.
num_directions = 2 if bidirectional else 1
_hidden_size = hidden_size * num_directions

for epoch in range(200000):
    use_teacher_forcing = random.random() < use_teacher_forcing_prob

    # Get a batch of training data
    features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, continue_prob=0.99, device=device, swap_times=0)
    features = features.to(device)
    target_seqs = target_seqs.to(device)
    # Slice directly: the legacy torch.Tensor() constructor rejects CUDA tensors
    input_seqs = features[:, :, 0].to(torch.int64)
    coordinates = features[:, :, 1:].float()
    # The coordinate encoding is derived from this batch's coordinates, so it
    # must be recomputed for every batch rather than reusing a stale one.
    positions = seq2seq_lstm.get_coordinate_encoding(coordinates, d=embedding_dim, device=device)

    # Initialize the encoder hidden state and cell state with zeros
    hn_enc = encoder.initHidden(input_seqs.shape[0], device=features.device)
    cn_enc = encoder.initHidden(input_seqs.shape[0], device=features.device)

    # Per-step encoder outputs and last-layer hidden states (the attention memory)
    encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)
    encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

    # Set gradients of all model parameters to zero
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    loss = 0

    ####################
    #     ENCODING     #
    ####################

    # Run the input one token position at a time; (hn_enc, cn_enc) carry the
    # recurrent state from one position to the next.
    for i in range(input_seqs.shape[1]):
        output, (hn_enc, cn_enc) = encoder(
            input_seqs[:, i].unsqueeze(dim=1),
            coordinates[:, i],
            positions[:, i:i+1],
            (hn_enc, cn_enc)
        )
        encoder_outputs[:, i:i+1, :] = output
        # Concatenate the last layer's state of every direction. With a
        # bidirectional encoder hn_enc[-1] alone is only hidden_size wide and
        # cannot fill the _hidden_size-wide slot.
        # Assumes the nn.LSTM layout (num_layers * num_directions, batch, hidden)
        # with the last layer's directions at the end, consistent with the
        # (6, 8, 128) hn shape printed for the default hyperparameters.
        encoder_hidden_states[:, i, :] = torch.cat(tuple(hn_enc[-num_directions:]), dim=1)

    ####################
    #     DECODING     #
    ####################

    # The first token presented to the decoder is the '<start>' token
    prediction = target_seqs[:, 0]

    # The initial hidden state of the decoder is the final hidden state of the encoder
    hn_dec, cn_dec = hn_enc, cn_enc

    # Step i receives token i-1 (ground truth or the model's own guess) and is
    # scored against token i.
    for i in range(1, target_seqs.size(1)):
        decoder_input = target_seqs[:, i-1] if use_teacher_forcing else prediction
        output, (hn_dec, cn_dec), attn_weights = decoder(
            decoder_input.unsqueeze(dim=1),
            coordinates[:, i-1],
            positions[:, i-1:i],
            (hn_dec, cn_dec),
            encoder_outputs=encoder_hidden_states
        )
        if not use_teacher_forcing:
            # Greedy choice; reshape(-1) keeps the batch axis even when batch_size == 1
            prediction = output.topk(1)[1].reshape(-1)
        loss += criterion(output.squeeze(), target_seqs[:, i])

    loss_value = loss.item()
    history.append(loss_value)
    if not epoch % 100:
        # Average over the size(1) - 1 scored tokens
        print(f"LOSS after epoch {epoch}", loss_value / (target_seqs.size(1) - 1))

    # Compute gradient
    loss.backward()

    # Update weights of encoder and decoder
    encoder_optimizer.step()
    decoder_optimizer.step()

#### Save model checkpoint and training history

In [None]:
from datetime import datetime

# Timestamp that gives every checkpoint file a unique name
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Persist model and optimizer state together with the training history and all
# hyperparameters needed to rebuild the same architecture when reloading.
torch.save({
    'epoch': epoch,
    'encoder_model_state_dict': encoder.state_dict(),
    'decoder_model_state_dict': decoder.state_dict(),
    'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
    'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
    # Store a plain float rather than the loss tensor and its autograd graph
    'loss': float(loss),
    "history": history,
    "lr": lr,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "bidirectional": bidirectional,
}, "model_attn_" + date_time + ".pt")

## Make predictions

We run our input sequences through the model and get output sequences. We then decode the output sequences with the `Vocabulary` class to obtain the final LaTeX code.

In [None]:
def predict(input_seqs, coordinates, target_seqs, teacher_forcing=True):
    """Run a batch through the trained encoder/decoder and return predicted token ids.

    The inputs are encoded one position at a time, exactly as in the training
    loop, and the decoder attends over the encoder's last-layer states.

    Args:
        input_seqs: int64 tensor of input token ids, shape (batch, seq_len).
        coordinates: per-token coordinates, indexed as coordinates[:, i].
        target_seqs: target token ids, shape (batch, target_len). Column 0 is
            the first decoder input.
        teacher_forcing: if True (the original behaviour), target token i is
            fed to the decoder at step i. If False, the decoder is fed its own
            previous prediction (greedy decoding) and only column 0 of
            target_seqs is used.

    Returns:
        int64 tensor shaped like target_seqs. Column i holds the decoder's
        prediction at step i, i.e. its guess for target token i + 1.
    """
    run_device = input_seqs.device
    n_seqs, in_len = input_seqs.shape
    num_directions = 2 if bidirectional else 1

    # Coordinate encoding for *these* inputs; the global `positions` belongs to another batch
    seq_positions = seq2seq_lstm.get_coordinate_encoding(coordinates, d=embedding_dim, device=run_device)

    predictions = torch.zeros(target_seqs.shape, dtype=torch.int64, device=run_device)
    encoder_hidden_states = torch.zeros((n_seqs, in_len, hidden_size * num_directions), device=run_device)

    with torch.no_grad():
        # Initialize the encoder hidden state and cell state with zeros
        hn = encoder.initHidden(n_seqs, device=run_device)
        cn = encoder.initHidden(n_seqs, device=run_device)

        for i in range(in_len):
            _, (hn, cn) = encoder(
                input_seqs[:, i].unsqueeze(dim=1),
                coordinates[:, i],
                seq_positions[:, i:i+1],
                (hn, cn)
            )
            # Last layer's state for every direction, as built in the training loop
            # (assumes the nn.LSTM h_n layout with the last layer at the end)
            encoder_hidden_states[:, i, :] = torch.cat(tuple(hn[-num_directions:]), dim=1)

        token = target_seqs[:, 0]
        for i in range(target_seqs.size(1)):
            if teacher_forcing:
                token = target_seqs[:, i]
            output, (hn, cn), _ = decoder(
                token.unsqueeze(dim=1),
                coordinates[:, i],
                seq_positions[:, i:i+1],
                (hn, cn),
                encoder_outputs=encoder_hidden_states
            )
            # Most likely token per sequence; reshape(-1) keeps the batch axis for batch size 1
            token = torch.argmax(output, dim=-1).reshape(-1)
            predictions[:, i] = token

    return predictions

In [None]:
# Predict only the first sequence of the batch (the :1 slices keep the batch axis)
prediction = predict(input_seqs[:1], coordinates[:1], target_seqs[:1])
(input_seqs[:1], prediction)

In [None]:
# Swap tokens in the first input sequence (and in its coordinates) and predict again.
# NOTE(review): random_swap is called separately for tokens and coordinates; if it
# picks its swap partner at random, the two may not be permuted identically — verify.
in_swapped, coords_swapped = (
    g.random_swap(seq[0], i=2).unsqueeze(dim=0) for seq in (input_seqs, coordinates)
)
prediction_swapped = predict(in_swapped, coords_swapped, target_seqs[0:1])
(in_swapped, prediction_swapped)

In [None]:
# Element-wise mask of input positions left unchanged by the swap
torch.eq(input_seqs[0:1], in_swapped)

In [None]:
# Element-wise mask of predicted tokens that survived the input swap
torch.eq(prediction, prediction_swapped)

In [None]:
# Pick a random sequence of the batch and compare the model's prediction with the target
predictions = predict(input_seqs, coordinates, target_seqs)

# randrange excludes its upper bound; randint(0, n) could return n and index past the batch
i = random.randrange(predictions.size(0))
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i].cpu().numpy()))
# The target is shifted by one ([1:]) because each decoder step predicts the next token
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i][1:].cpu().numpy()))

In [None]:
# Decode the sampled prediction, drop every '<end>' token and join the rest into one string
prediction = "".join(
    token
    for token in vocab_out.decode_sequence(predictions[i].cpu().numpy())
    if token != '<end>'
)
print("MODEL OUTPUT", prediction)