In [1]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [2]:
if torch.cuda.device_count():
    device="cuda"
else:
    device="cpu"
print("Device", device)

Device cpu


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
features, target_seqs = g.generate_synthetic_training_data(8, max_length=10, device=device, continue_prob=0.997, swap_times=0)
input_seqs = torch.Tensor(features[:, :, 0]).to(torch.int64)
coordinates = torch.Tensor(features[:, :, 1:])

In [4]:
features.shape, input_seqs.shape, coordinates.shape, target_seqs.shape

(torch.Size([8, 10, 5]),
 torch.Size([8, 10]),
 torch.Size([8, 10, 4]),
 torch.Size([8, 10]))

# The Encoder

In [5]:
from seqgen.model import seq2seq_lstm
from seqgen.vocabulary import *

In [6]:
lr = 1e-2
num_layers=1
embedding_dim = 32
hidden_size=32
batch_size=8
max_length=10
bidirectional=True

load_from_checkpoint = False
checkpoint_file = "model_len25_biy_layers3.pt"

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

encoder = seq2seq_lstm.EncoderRNN(vocab_size=len(vocab_in), embedding_dim=embedding_dim, num_layers=num_layers, max_length=max_length, hidden_size=hidden_size, bidirectional=bidirectional, pos_encoding=False).to(features.device)
#attn = seq2seq_lstm.AdditiveAttention(hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, max_length=max_length).to(features.device)
decoder = seq2seq_lstm.DecoderRNN2(embedding_dim=embedding_dim, num_layers=num_layers, max_length=max_length, hidden_size=hidden_size, vocab_size=len(vocab_out), bidirectional=bidirectional, pos_encoding=False).to(features.device)

# Initialize optimizer for encoder and decoder
encoder_optimizer = torch.optim.SGD(encoder.parameters(), lr=lr)
decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=lr)
#attn_optimizer = torch.optim.SGD(attn.parameters(), lr=lr)
#positions = seq2seq_lstm.get_position_encoding(max_length, embedding_dim, device=device)
positions = seq2seq_lstm.get_coordinate_encoding(coordinates, d=embedding_dim, device=device)

# Loss function
criterion = torch.nn.NLLLoss()

# Load model weights from checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file)
    encoder.load_state_dict(checkpoint['encoder_model_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_model_state_dict'])
    encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
    decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])
    num_layers = checkpoint['num_layers']
    embedding_dim = checkpoint['embedding_dim']
    hidden_size = checkpoint['hidden_size']
    bidirectional = checkpoint['bidirectional']



In [7]:
# Initialize the encoder hidden state and cell state with zeros
hn = encoder.initHidden(input_seqs.shape[0], device=features.device)
cn = encoder.initHidden(input_seqs.shape[0], device=features.device)

_hidden_size = hidden_size * 2 if bidirectional else hidden_size
encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)

# Iterate over the sequence words and run every word through the encoder
for i in range(input_seqs.shape[1]):
    # Run the i-th word of the input sequence through the encoder.
    # As a result we will get the prediction (output), the hidden state and the cell state.
    # The hidden state and cell state will be used as inputs in the next round
    print(f"Run word {i+1} of all {input_seqs.shape[0]} sequences through the encoder")
    output, (hn, cn) = encoder(input_seqs[:, i].unsqueeze(dim=1), coordinates[:, i], positions[:, i:i+1], (hn, cn))
    encoder_outputs[:, i:i+1, :] = output
    encoder_hidden_states[:, i, :] = seq2seq_lstm.concat_hidden_states(hn)

Run word 1 of all 8 sequences through the encoder
Run word 2 of all 8 sequences through the encoder
Run word 3 of all 8 sequences through the encoder
Run word 4 of all 8 sequences through the encoder
Run word 5 of all 8 sequences through the encoder
Run word 6 of all 8 sequences through the encoder
Run word 7 of all 8 sequences through the encoder
Run word 8 of all 8 sequences through the encoder
Run word 9 of all 8 sequences through the encoder
Run word 10 of all 8 sequences through the encoder


In [8]:
output.shape, hn.shape, cn.shape, encoder_hidden_states.shape, encoder_outputs.shape

(torch.Size([8, 1, 64]),
 torch.Size([2, 8, 32]),
 torch.Size([2, 8, 32]),
 torch.Size([8, 10, 64]),
 torch.Size([8, 10, 64]))

# The Decoder

In [9]:
loss = 0

# Iterate over words of target sequence and run words through the decoder.
# This will produce a prediction for the next word in the sequence
for i in range(0, target_seqs.size(1)):
    print(f"Run word {i+1} through decoder", hn.shape, encoder_hidden_states.shape)
    output, (hn, cn), attention = decoder(
        x=target_seqs[:, i].unsqueeze(dim=1),
        coordinates=coordinates[:, i],
        annotations=encoder_hidden_states,
        position=positions[:, i:i+1],
        hidden=(hn, cn)
    )
    loss += criterion(output.squeeze(), target_seqs[:, i])

print("LOSS", loss.item() / max_length)

Run word 1 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 2 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 3 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 4 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 5 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 6 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 7 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 8 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 9 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
Run word 10 through decoder torch.Size([2, 8, 32]) torch.Size([8, 10, 64])
LOSS 3.26245231628418


# Attention

In [10]:
#context_vector, attention = attn(hn, encoder_hidden_states, logging=True)
#context_vector.shape, attention.shape

# Training

In [None]:
history = []
accuracies = []

for epoch in range(100000):
    # With a certain chance present the model the true predictions
    # instead of its own predictions in the next iteration
    use_teacher_forcing_prob = 0.5
    use_teacher_forcing = random.random() < use_teacher_forcing_prob
    
    # Get a batch of trianing data
    features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, continue_prob=0.99, device=device, swap_times=0)
    features = features.to(device)
    target_seqs = target_seqs.to(device)
    input_seqs = torch.Tensor(features[:, :, 0]).to(torch.int64)
    coordinates = torch.Tensor(features[:, :, 1:])

    # Initialize the encoder hidden state and cell state with zeros
    hn_enc = encoder.initHidden(input_seqs.shape[0], device=features.device)
    cn_enc = encoder.initHidden(input_seqs.shape[0], device=features.device)
    
    # Initialize encoder outputs tensor
    last_n_states = 2 if bidirectional else 1
    _hidden_size = hidden_size * 2 if bidirectional else hidden_size
    encoder_hidden_states = torch.zeros((batch_size, max_length, _hidden_size*num_layers)).to(device)
    encoder_outputs = torch.zeros((batch_size, max_length, _hidden_size)).to(device)
    
    # Set gradients of all model parameters to zero
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    #attn_optimizer.zero_grad()

    # Initialize loss
    loss = 0
    
    ####################
    #     ENCODING     #
    ####################

    # Iterate over the sequence words and run every word through the encoder
    for i in range(input_seqs.shape[1]):
        # Run the i-th word of the input sequence through the encoder.
        # As a result we will get the prediction (output), the hidden state (hn) and the cell state (cn).
        # The hidden state and cell state will be used as inputs in the next round
        output, (hn_enc, cn_enc) = encoder(
            input_seqs[:, i].unsqueeze(dim=1),
            coordinates[:, i],
            positions[:, i:i+1],
            (hn_enc, cn_enc)
        )
        # Save encoder outputs and states for current word
        encoder_outputs[:, i:i+1, :] = output
        encoder_hidden_states[:, i, :] = seq2seq_lstm.concat_hidden_states(hn)

    ####################
    #     DECODING     #
    ####################
    
    accuracy = 0.0

    # The first words that we be presented to the model is the '<start>' token
    prediction = target_seqs[:, 0]
    
    # The initial hidden state of the decoder is the final hidden state of the encoder
    hn_dec, cn_dec = hn_enc, cn_enc
    
    # Iterate over words of target sequence and run words through the decoder.
    # This will produce a prediction for the next word in the sequence
    for i in range(1, target_seqs.size(1)):
        # Run word i through decoder and get word i+1 and the new hidden state as outputs
        if use_teacher_forcing:
            output, (hn_dec, cn_dec), attention = decoder(
                x=prediction.unsqueeze(dim=1),
                coordinates=coordinates[:, i-1],
                annotations=encoder_hidden_states,
                position=positions[:, i:i+1],
                hidden=(hn_dec, cn_dec)
            )
        else:
            output, (hn_dec, cn_dec), attention = decoder(
                x=prediction.unsqueeze(dim=1),
                coordinates=coordinates[:, i-1],
                annotations=encoder_hidden_states,
                position=positions[:, i:i+1],
                hidden=(hn_dec, cn_dec)
            )

            # Get the predicted classes of the model
            topv, topi = output.topk(1)
            prediction = topi.squeeze()    
        loss += criterion(output.squeeze(), target_seqs[:, i])
        accuracy += float((prediction == target_seqs[:, i]).sum() / (target_seqs.size(0)*target_seqs.size(1)))
    
    history.append(loss.item())
    accuracies.append(accuracy)
    
    print_every = 10
    if not epoch % print_every:
        _accuracy = sum(accuracies[-print_every:]) / print_every
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "ACCURACY", _accuracy)

    # Compute gradient
    loss.backward()
    accuracy = 0.0

    # Update weights of encoder and decoder
    encoder_optimizer.step()
    decoder_optimizer.step()
    #attn_optimizer.step()

LOSS after epoch 0 2.9097002029418944 ACCURACY 0.0037500000558793544
LOSS after epoch 10 2.826010894775391 ACCURACY 0.027500000409781934
LOSS after epoch 20 2.7654678344726564 ACCURACY 0.05250000115483999
LOSS after epoch 30 2.6238496780395506 ACCURACY 0.06875000102445483
LOSS after epoch 40 2.6758474349975585 ACCURACY 0.06750000100582838
LOSS after epoch 50 2.6237009048461912 ACCURACY 0.027500000409781934
LOSS after epoch 60 2.601553535461426 ACCURACY 0.05125000076368451
LOSS after epoch 70 2.4524723052978517 ACCURACY 0.06625000098720193
LOSS after epoch 80 2.543402671813965 ACCURACY 0.05625000083819032
LOSS after epoch 90 2.545836067199707 ACCURACY 0.04625000068917871
LOSS after epoch 100 2.5187891006469725 ACCURACY 0.052500000782310964
LOSS after epoch 110 2.491815948486328 ACCURACY 0.05625000083819032
LOSS after epoch 120 2.487213706970215 ACCURACY 0.06625000098720193
LOSS after epoch 130 2.4760505676269533 ACCURACY 0.07125000106170773
LOSS after epoch 140 2.323046875 ACCURACY 0.07

#### Save model history

In [None]:
import pickle
from datetime import datetime

model_data = {
    "history": history,
    "lr": lr,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length
}

now = datetime.now() # current date and time
date_time = now.strftime("%Y-%m-%d_%H-%M-%S")

torch.save({
    'epoch': epoch,
    'encoder_model_state_dict': encoder.state_dict(),
    'decoder_model_state_dict': decoder.state_dict(),
    'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
    'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
    'loss': loss,
    "history": history,
    "lr": lr,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "bidirectional": bidirectional,
}, "model_" + date_time + ".pt")


with open("training_" + date_time + '.pkl', 'wb') as f:
    pickle.dump(model_data, f)

## Make predictions

We run our input sequences through the model and get output seuences. Then we decode the output sequences with the Vocabulary class and get our final latex code.

In [None]:
def predict(input_seqs, coordinates, target_seqs):
    vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
    vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

    predictions = torch.zeros(target_seqs.shape)
    attention_matrix = torch.zeros((input_seqs.shape[0], input_seqs.shape[1], input_seqs.shape[1]))

    with torch.no_grad():
        # Initialize the encoder hidden state and cell state with zeros
        hn = encoder.initHidden(input_seqs.shape[0], device=features.device)
        cn = encoder.initHidden(input_seqs.shape[0], device=features.device)
        
        _hidden_size = hidden_size * 2 if bidirectional else hidden_size
        last_n_states = 2 if bidirectional else 1
        encoder_hidden_states = torch.zeros((input_seqs.shape[0], max_length, _hidden_size)).to(device)
        encoder_outputs = torch.zeros((input_seqs.shape[0], max_length, _hidden_size)).to(device)

        # Iterate over the sequence words and run every word through the encoder
        for i in range(input_seqs.size(1)):
            output, (hn, cn) = encoder(
                input_seqs[:, i].unsqueeze(dim=1),
                coordinates[:, i],
                positions[:, i:i+1],
                (hn, cn)
            )
            encoder_outputs[:, i:i+1, :] = output
            encoder_hidden_states[:, i:i+1, :] = seq2seq_lstm.concat_hidden_states(hn[-last_n_states:]).unsqueeze(dim=1)

        # Predict tokens of the target sequence by running the hidden state through
        # the decoder
        for i in range(0, target_seqs.size(1)):
            output, (hn, cn), attention = decoder(
                target_seqs[:, i].unsqueeze(dim=1),
                coordinates[:, i],
                encoder_hidden_states,
                positions[:, i:i+1],
                (hn, cn)
            )
            # Select the indices of the most likely tokens
            predicted_char = torch.argmax(output, dim=2)
            predictions[:, i] = torch.argmax(output, dim=2).squeeze()
            attention_matrix[:, :, i:i+1] = attention
        
        return predictions, attention_matrix

In [None]:
prediction, attention_matrix = predict(input_seqs[0:1], coordinates[0:1], target_seqs[0:1])
prediction.shape, attention_matrix.shape

In [None]:
plt.imshow(attention_matrix[0])

In [None]:
in_swapped = g.random_swap(input_seqs[0], i=2).unsqueeze(dim=0)
coords_swapped = g.random_swap(coordinates[0], i=2).unsqueeze(dim=0)
prediction_swapped = predict(in_swapped, coords_swapped, target_seqs[0:1])

In [None]:
input_seqs[0:1] == in_swapped

In [None]:
prediction == prediction_swapped

In [None]:
# Pick random sequence and its prediction from the model
import random

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions, attention_matrix = predict(input_seqs, coordinates, target_seqs)

i = random.randint(0, predictions.size(0))
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i].cpu().numpy()))
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i][1:].cpu().numpy()))

In [None]:
prediction = vocab_out.decode_sequence(predictions[i].cpu().numpy())
prediction = list(filter(lambda x: x != '<end>', prediction))
prediction = "".join(prediction)
print("MODEL OUTPUT", prediction)

In [None]:
sns.heatmap(attention_matrix[0])