In [None]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns
from seqgen.model import seq2seq_lstm
from seqgen.vocabulary import *
from seqgen.model import transformer

# Turn on autograd anomaly detection for the whole session.
# NOTE(review): per the PyTorch docs this is a debugging aid that adds overhead to
# every backward pass — confirm it is still wanted before long training runs.
torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [None]:
# Use the GPU when at least one CUDA device is reported, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.device_count() else "cpu"
print("Device", device)

In [None]:
# --- Optimisation ---
lr = 1e-3          # learning rate handed to the Adam optimizer
batch_size = 4     # number of sequences generated per batch

# --- Data ---
max_length = 100   # sequence length requested from the data generator

# --- Model architecture ---
embedding_dim = 64
num_layers = 1
heads = 4
dropout = 0

In [None]:
# Generate one synthetic batch (no swaps) to inspect shapes and smoke-test the model.
features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, device=device, continue_prob=0.999, swap_times=0)

# `features` is already a tensor, so copy it with detach().clone() rather than
# torch.tensor(tensor), which emits PyTorch's copy-construct UserWarning.
# Index 0 of the last axis is used as the token ids; the remaining entries are
# handed to the coordinate encoder.
input_seqs = features[:, :, 0].detach().clone().to(torch.int64)
coordinates = features[:, :, 1:].detach().clone()

positions_coords = seq2seq_lstm.get_coordinate_encoding(coordinates, max_length=max_length, d=embedding_dim, device=device)
positions_targets = seq2seq_lstm.get_position_encoding(max_length, embedding_dim, device=device).repeat(batch_size, 1, 1)

In [None]:
# Shapes of every tensor prepared above, in the order they were created.
tuple(
    tensor.shape
    for tensor in (features, input_seqs, coordinates, target_seqs, positions_coords, positions_targets)
)

# The Transformer

In [None]:
# Set to True to resume from `checkpoint_file` instead of starting from fresh weights.
load_from_checkpoint = False
# NOTE(review): the save cell in this notebook writes "transformer_<timestamp>.pt";
# this name uses a "model_" prefix — confirm it points at an existing checkpoint.
checkpoint_file = "model_2023-01-15_09-17-53.pt"

# Input / output vocabularies; their sizes determine the model's vocabulary dimensions.
vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

# Transformer model
model = transformer.Transformer(
    src_vocab_size=len(vocab_in),
    trg_vocab_size=len(vocab_out),
    embedding_dim=embedding_dim,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
    # NOTE(review): 1e10 is a float far outside any vocabulary index; presumably a
    # sentinel so that no token is ever treated as padding — confirm against
    # transformer.Transformer.
    src_pad_idx=1e10,
    trg_pad_idx=1e10,
    device=device
).to(device)

# Single optimizer over all parameters of the model
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Load model weights from checkpoint
if load_from_checkpoint:
    # map_location remaps the stored tensors onto the current device, so a checkpoint
    # written on a GPU machine can also be restored on a CPU-only one.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Run the input token sequences through the model; the target sequences are passed
# as the model's second argument.
# NOTE(review): the unshifted target is fed in — whether the model shifts/masks it
# internally cannot be seen from this notebook; confirm there is no label leakage.
output = model(input_seqs, target_seqs)

In [None]:
# Keep only the single best-scoring entry (k=1) along axis 2 of the model output,
# then show the shapes of the raw output, the winning indices and their scores.
topv, topi = torch.topk(output, k=1, dim=2)
tuple(tensor.shape for tensor in (output, topi, topv))

In [None]:
# Add up the cross-entropy of every sequence position (left to right, starting
# from 0.0), then display the average loss per position.
loss = sum(
    (criterion(output[:, step, :], target_seqs[:, step]) for step in range(max_length)),
    0.0,
)
loss.item() / max_length

# Training

In [None]:
# Per-iteration records: summed loss over all positions, and token-level accuracy.
history = []
accuracies = []

# Progress is printed every `print_every` iterations (constant, so set once up front).
print_every = 1

# Make sure layers such as dropout are in training mode even if the model was
# switched to eval mode by an earlier cell.
model.train()

for epoch in range(25):
    # Get a fresh batch of training data (each "epoch" is one newly generated batch)
    features, target_seqs = g.generate_synthetic_training_data(batch_size, max_length=max_length, continue_prob=0.999, device=device, swap_times=max_length)
    features = features.to(device)
    target_seqs = target_seqs.to(device)

    # `features` is a tensor, so copy with detach().clone() instead of
    # torch.tensor(tensor), which warns about copy-constructing on every iteration.
    input_seqs = features[:, :, 0].detach().clone().to(torch.int64)
    coordinates = features[:, :, 1:].detach().clone()

    # NOTE(review): these encodings are computed but not passed to `model` in this
    # cell — confirm whether they are still needed.
    positions_coords = seq2seq_lstm.get_coordinate_encoding(coordinates, max_length=max_length, d=embedding_dim, device=device)
    positions_targets = seq2seq_lstm.get_position_encoding(max_length, embedding_dim, device=device).repeat(batch_size, 1, 1)

    # Set gradients of all model parameters to zero
    optimizer.zero_grad()

    #####################
    #    TRANSFORMER    #
    #####################

    # Run the input sequences through the model
    output = model(input_seqs, target_seqs)

    # Iterate over sequence positions: sum the loss and count correct predictions
    loss = 0
    num_correct = 0
    for i in range(max_length):
        loss += criterion(output[:, i, :], target_seqs[:, i])
        predicted = output[:, i, :].argmax(dim=-1)
        num_correct += (predicted == target_seqs[:, i]).sum()

    # Fraction of all target tokens predicted correctly. Counting as integers and
    # dividing once avoids accumulating rounding error from per-position divisions.
    accuracy = float(num_correct) / (target_seqs.size(0) * target_seqs.size(1))

    history.append(loss.item())
    accuracies.append(accuracy)

    if epoch % print_every == 0:
        mean_accuracy = sum(accuracies[-print_every:]) / print_every
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "ACCURACY", mean_accuracy)

    # Compute gradient
    loss.backward()

    # Update the model weights
    optimizer.step()

#### Save model history

In [None]:
import pickle
from datetime import datetime

# Loss curve and hyperparameters of this run. Defined once and reused for both
# output files so the two cannot drift apart.
model_data = {
    "history": history,
    "lr": lr,
    "num_layers": num_layers,
    "embedding_dim": embedding_dim,
    "batch_size": batch_size,
    "max_length": max_length,
    "heads": heads,
    "dropout": dropout,
}

# One timestamp shared by both file names so a checkpoint and its history can be paired.
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Full checkpoint: run metadata plus everything needed to resume training.
torch.save({
    **model_data,
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Store a detached CPU copy so the file does not carry the autograd graph or
    # tie the stored loss to the training device.
    'loss': loss.detach().cpu(),
}, "transformer_" + date_time + ".pt")

# Lightweight copy of the metadata alone, readable without loading the model weights.
with open("training_" + date_time + '.pkl', 'wb') as f:
    pickle.dump(model_data, f)

## Make predictions

We run our input sequences through the model and get output sequences. Then we decode the output sequences with the Vocabulary class to obtain the final LaTeX code.

In [None]:
def predict(input_seqs, coordinates, target_seqs):
    """Return the highest-scoring class index for every position of every sequence.

    Uses the notebook-level ``model``. Gradients are disabled and the model is put
    into eval mode for the duration of the call; its previous train/eval mode is
    restored afterwards.

    Args:
        input_seqs: Passed to ``model`` as its first argument.
        coordinates: Unused; kept so existing calls keep working.
        target_seqs: Passed to ``model`` as its second argument.
            NOTE(review): the model therefore sees the target while predicting —
            confirm this is the intended evaluation setup.

    Returns:
        Tensor of class indices shaped like the first two axes of the model output.
        Only the trailing top-k axis is removed, so the batch axis survives even
        for a batch of one.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(input_seqs, target_seqs)
            # Get the predicted classes of the model
            _, topi = output.topk(1, dim=2)
    finally:
        model.train(was_training)

    # squeeze(-1) rather than squeeze(): a bare squeeze would also drop a batch or
    # sequence axis of length 1 and break indexing by the caller.
    return topi.squeeze(-1)

In [None]:
# Predict on the most recent batch and show the shape of the result.
# Note: the name `prediction` is reassigned further down to hold a decoded string.
prediction = predict(input_seqs, coordinates, target_seqs)
prediction.shape

In [None]:
# Pick a random sequence from the current batch and compare the model's output
# with the target. `random`, `vocab_in` and `vocab_out` come from the import and
# model-setup cells, so they are not re-imported / re-read from disk here.
predictions = predict(input_seqs, coordinates, target_seqs)

# Index of the sampled sequence; reused by the next cell.
i = random.randrange(predictions.size(0))
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i].cpu().numpy()))
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i].cpu().numpy()))

In [None]:
# Decode the sampled prediction, drop every '<end>' token and glue the remaining
# tokens together into one string.
prediction = "".join(
    token
    for token in vocab_out.decode_sequence(predictions[i].cpu().numpy())
    if token != '<end>'
)
print("MODEL OUTPUT", prediction)

In [None]:
# Plot the first entry of the attention matrix as a heatmap.
# FIXME(review): `attention_matrix` is not defined in any cell visible in this
# notebook, so this cell raises NameError on Restart & Run All — confirm where it
# should come from.
sns.heatmap(attention_matrix[0])