In [1]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns
from seqgen.model import seq2seq_lstm
from seqgen.vocabulary import *
from seqgen.model import transformer
from seqgen.datasets.sequences import *

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [2]:
if torch.cuda.device_count():
    device="cuda"
else:
    device="cpu"
print("Device", device)

Device cpu


In [3]:
lr=1e-3
num_layers=3
embedding_dim=32
batch_size=64
max_length=50
heads=8
dropout=0

In [4]:
vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

dataset = SyntheticSequenceDataset(vocab_in, vocab_out, max_length, batch_size, continue_prob=0.99, device="cpu")

In [5]:
input_seqs, coordinates, target_seqs = dataset[0]
coordinate_encoding = seq2seq_lstm.get_coordinate_encoding(coordinates, d=embedding_dim, max_length=max_length)
input_seqs.shape, coordinates.shape, target_seqs.shape, coordinate_encoding.shape

  input_seqs = torch.tensor(features[:, :, 0]).to(torch.int64)
  coordinates = torch.tensor(features[:, :, 1:])


(torch.Size([64, 51]),
 torch.Size([64, 51, 4]),
 torch.Size([64, 51]),
 torch.Size([64, 51, 32]))

In [6]:
print(input_seqs[0, :-1])
print(target_seqs[0, :-1])
print(target_seqs[0, 1:])

tensor([ 0, 16, 12,  4, 16,  8,  8, 15,  8,  8,  3, 11, 14,  4,  8, 10, 12, 16,
        12,  6,  7,  5, 11, 13,  3,  6,  5, 10, 11, 11,  6,  4,  5, 15, 11,  6,
        11, 14,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1])
tensor([ 0,  3, 10,  4, 11, 13,  8, 17,  8,  8,  3, 11, 11, 11,  8, 17, 18, 12,
        18,  6,  7,  5,  6,  5, 12, 18,  5,  6, 11,  8,  6, 10, 14, 12, 11,  4,
         4, 14,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1])
tensor([ 3, 10,  4, 11, 13,  8, 17,  8,  8,  3, 11, 11, 11,  8, 17, 18, 12, 18,
         6,  7,  5,  6,  5, 12, 18,  5,  6, 11,  8,  6, 10, 14, 12, 11,  4,  4,
        14,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1])


In [7]:
def permutate_tokens(input_seq):
    # Get the first index where tensor has an SOS or EOS token
    sos_idx = list(input_seq).index(0)
    eos_idx = list(input_seq).index(1)
    # permutate all elements that are not SOS or EOS
    idx_permuted = torch.cat([torch.arange(0, sos_idx+1), (torch.randperm(eos_idx - sos_idx - 1) + sos_idx+1), torch.arange(eos_idx, max_length+1)])
    return idx_permuted

# The Transformer

In [8]:
load_from_checkpoint = False
checkpoint_file = "model_2023-01-15_09-17-53.pt"

# Transformer model
model = transformer.Transformer(
    src_vocab_size=len(vocab_in),
    trg_vocab_size=len(vocab_out),
    embedding_dim=embedding_dim,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
    src_pad_idx=1e10,
    trg_pad_idx=1e10,
    device=device
).to(device)

# Initialize optimizer for encoder and decoder
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Load model weights from checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [9]:
# Run the feature sequences through the model
output = model(input_seqs, target_seqs, coordinate_encoding)

In [10]:
# Get the predicted classes of the model
topv, topi = output.topk(1, dim=2)
output.shape, topi.shape, topv.shape

(torch.Size([64, 51, 25]), torch.Size([64, 51, 1]), torch.Size([64, 51, 1]))

In [11]:
loss = 0.0
for i in range(max_length):
    loss += criterion(output[:, i, :], target_seqs[:, i])
loss.item() / max_length

3.6018032836914062

# Training

In [None]:
history = []
accuracies = []

for epoch in range(50000):    
    # Get a batch of training data
    input_seqs, coordinates, target_seqs = dataset[0]
    positions_coords = seq2seq_lstm.get_coordinate_encoding(coordinates, max_length=max_length, d=embedding_dim, device=device)
    positions_targets = seq2seq_lstm.get_position_encoding(max_length, embedding_dim, device=device).repeat(batch_size, 1, 1)
    
    # Set gradients of all model parameters to zero
    optimizer.zero_grad()

    # Initialize loss
    loss = 0
    accuracy = 0.0

    #####################
    #    TRANSFORMER    #
    #####################
    
    # Run the input sequences through the model
    output = model(input_seqs[:, :-1], target_seqs[:, :-1], positions_coords[:, :-1])
    
    # Iterate over sequence positions to compute the loss
    for i in range(max_length-1):
        # Get the predicted classes of the model
        topv, topi = output[:, i, :].topk(1)
        loss += criterion(output[:, i, :], target_seqs[:, i+1])
        accuracy += float((topi.squeeze() == target_seqs[:, i+1]).sum() / (target_seqs.size(0)*(target_seqs.size(1)-2)))
    
    history.append(loss.item())
    accuracies.append(accuracy)
    
    print_every = 10
    if not epoch % print_every:
        _accuracy = sum(accuracies[-print_every:]) / print_every
        lr = scheduler.get_last_lr()[0]
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "LR", lr, "ACCURACY", _accuracy)

    # Compute gradient
    loss.backward()
    accuracy = 0.0

    # Update weights of encoder and decoder
    optimizer.step()

LOSS after epoch 0 3.3296122831456803 LR 0.001 ACCURACY 0.003730867357808165
LOSS after epoch 10 2.6482301599839153 LR 0.001 ACCURACY 0.17551020530227107
LOSS after epoch 20 2.262874827665441 LR 0.001 ACCURACY 0.2623724506585859
LOSS after epoch 30 2.156153510598575 LR 0.001 ACCURACY 0.24626913435931783
LOSS after epoch 40 2.1638811896829044 LR 0.001 ACCURACY 0.2733737262664363
LOSS after epoch 50 1.933803633147595 LR 0.001 ACCURACY 0.26594387901131994
LOSS after epoch 60 2.082317875880821 LR 0.001 ACCURACY 0.267442603391828
LOSS after epoch 70 2.086745617436428 LR 0.001 ACCURACY 0.256664542274666
LOSS after epoch 80 1.9946737850413603 LR 0.001 ACCURACY 0.26575255258067043
LOSS after epoch 90 2.0696295943914675 LR 0.001 ACCURACY 0.2607142875553109
LOSS after epoch 100 1.9815061980602788 LR 0.001 ACCURACY 0.28431122513720763
LOSS after epoch 110 1.9188772463331036 LR 0.001 ACCURACY 0.28737244994263167
LOSS after epoch 120 2.127329059675628 LR 0.001 ACCURACY 0.2806760215375107
LOSS after

#### Save model history

In [None]:
import pickle
from datetime import datetime

model_data = {
    "history": history,
    "lr": lr,
    "num_layers": num_layers,
    "embedding_dim": embedding_dim,
    "batch_size": batch_size,
    "max_length": max_length,
    "heads": heads,
    "dropout": dropout,
}

now = datetime.now() # current date and time
date_time = now.strftime("%Y-%m-%d_%H-%M-%S")

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    "history": history,
    "lr": lr,
    "embedding_dim": embedding_dim,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "heads": heads,
    "dropout": dropout,
}, "transformer_" + date_time + ".pt")


with open("training_" + date_time + '.pkl', 'wb') as f:
    pickle.dump(model_data, f)

## Make predictions

We run our input sequences through the model and get output seuences. Then we decode the output sequences with the Vocabulary class and get our final latex code.

In [None]:
def predict(input_seqs, coordinates, target_seqs):
    vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
    vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

    with torch.no_grad():
        coordinate_encoding = seq2seq_lstm.get_coordinate_encoding(coordinates, d=embedding_dim, max_length=max_length)
        output = model(input_seqs, target_seqs, coordinate_encoding)
        # Get the predicted classes of the model
        topv, topi = output.topk(1, dim=2)
        
        return topi.squeeze()
    
def predict_sequentially(input_seqs, coordinates):
    prediction = torch.zeros((input_seqs.size(0), input_seqs.size(1)-1)).to(torch.int64)
    for i in range(max_length-1):
        output = predict(input_seqs, coordinates, prediction)
        prediction[:, i] = output[:, i]
    return prediction

In [None]:
prediction = predict_sequentially(input_seqs, coordinates)
prediction.shape

In [None]:
# Pick random sequence and its prediction from the model
import random

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions = predict(input_seqs, coordinates, target_seqs)

i = random.randint(0, predictions.size(0)-1)
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i, 1:].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i, :-1].cpu().numpy()))
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i, 1:].cpu().numpy()))

In [None]:
prediction = vocab_out.decode_sequence(predictions[i].cpu().numpy())
prediction = list(filter(lambda x: x != '<end>', prediction))
prediction = "".join(prediction)
print("MODEL OUTPUT", prediction)

In [None]:
predict_sequentially(input_seqs[0:3], coordinates[0:3])

In [None]:
target_seqs[0:3, 1:]

## Prediction for permutated sequences

In [None]:
def generate_permutated_batch(input_seq, coordinates):
    seqs = torch.zeros((5, input_seq.size(0))).to(torch.int64)
    coords = torch.zeros((5, coordinates.size(0), coordinates.size(1)))
    for i in range(5):
        idx_permutated = permutate_tokens(input_seq)
        seqs[i, :] = input_seq[idx_permutated]
        coords[i, :] = coordinates[idx_permutated]
    return seqs, coords

In [None]:
input_permutated, coords_permutated = generate_permutated_batch(input_seqs[0], coordinates[0])
input_permutated

In [None]:
predict_sequentially(input_permutated, coords_permutated)

In [None]:
target_seqs[0, 1:]