In [None]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns
from seqgen.model import rnn
from seqgen.vocabulary import *
from seqgen.model import transformer, embedding
from seqgen.datasets.sequences import *
from seqgen.datasets.realdata import RealSequencesDataset

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [None]:
# Run on the GPU when at least one CUDA device is visible; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.device_count() else "cpu"
print("Device", device)

In [None]:
# --- Data ---
use_real_dataset = False  # True -> RealSequencesDataset, False -> SyntheticSequenceDataset
batch_size = 256          # handed to the dataset constructor
max_length = 50           # handed to the dataset constructor; also bounds the loss/prediction loops

# --- Model (forwarded to the transformer constructor) ---
num_layers = 3
embedding_dim = 128
heads = 8
dropout = 0

# --- Optimisation ---
lr = 1e-2                 # learning rate given to the Adam optimizer

In [None]:
# Input and output vocabularies, each built from its token list.
# NOTE(review): `vocab_file` presumably names a pickle cache — confirm against Vocabulary.
vocab_in = Vocabulary(
    vocab_filename="seqgen/vocab_in.txt",
    vocab_file="vocab_in.pkl",
)
vocab_out = Vocabulary(
    vocab_filename="seqgen/vocab_out.txt",
    vocab_file="vocab_out.pkl",
)

# Pick the data source according to the `use_real_dataset` switch.
if not use_real_dataset:
    dataset = SyntheticSequenceDataset(
        vocab_in,
        vocab_out,
        max_length,
        batch_size,
        continue_prob=0.95,
        additional_eos=True,
        device=device,
    )
else:
    # NOTE(review): the real dataset is given max_length-1 while the synthetic one
    # gets max_length — confirm both yield sequences of the same length.
    dataset = RealSequencesDataset(
        filename="data/train/label.txt",
        vocab_in=vocab_in,
        vocab_out=vocab_out,
        max_length=max_length-1,
        batch_size=batch_size,
        device=device,
    )

In [None]:
# Fetch one batch and display the shape of each of its three tensors.
input_seqs, coordinates, target_seqs = dataset[0]
tuple(t.shape for t in (input_seqs, coordinates, target_seqs))

In [None]:
# First example of the batch: the input without its last token, then the target
# without its last token and without its first token (the two slices used as
# decoder input and as labels in the training cell).
print(input_seqs[0, :-1], target_seqs[0, :-1], target_seqs[0, 1:], sep="\n")

In [None]:
def permutate_tokens(input_seq, sos_token=0, eos_token=1):
    """Return an index tensor that shuffles the tokens strictly between SOS and EOS.

    Everything up to and including the first SOS token, and everything from the
    first EOS token to the end of the sequence, keeps its position. Only the
    indices in between are randomly permuted.

    Args:
        input_seq: 1-D tensor of token ids.
        sos_token: id of the start-of-sequence token (default 0).
        eos_token: id of the end-of-sequence token (default 1).

    Returns:
        1-D int64 tensor with the same length as `input_seq`; use it as
        `input_seq[idx]` (and on any aligned tensor) to apply the permutation.

    Raises:
        ValueError: if SOS or EOS is missing, or the first EOS does not come
            after the first SOS.
    """
    tokens = input_seq.tolist()
    # First positions of the SOS and EOS tokens (list.index raises ValueError if absent)
    sos_idx = tokens.index(sos_token)
    eos_idx = tokens.index(eos_token)
    if eos_idx <= sos_idx:
        raise ValueError("EOS token must appear after the SOS token")

    head = torch.arange(0, sos_idx + 1)
    # Shuffle only the positions strictly between SOS and EOS
    middle = torch.randperm(eos_idx - sos_idx - 1) + sos_idx + 1
    # The tail runs to the real end of the sequence rather than to a global max_length+1,
    # so the result always has exactly one index per input position.
    tail = torch.arange(eos_idx, len(tokens))
    return torch.cat([head, middle, tail])

In [None]:
# Vocabulary sizes next to the largest token id present in the current batch.
len(vocab_in), len(vocab_out), input_seqs[:, :-1].max(), target_seqs[:, :-1].max()

# The Transformer

In [None]:
load_from_checkpoint = False
checkpoint_file = "transformer_temp2.pt"

# Token id used for padding: ignored by the model (src/trg pad index) and by the loss.
pad_idx = 2

# Transformer model
model = transformer.PyTorchTransformerDecoder(
    encoder_embedding_type=embedding.EmbeddingType.COORDS_DIRECT,
    src_vocab_size=len(vocab_in),
    trg_vocab_size=len(vocab_out),
    embedding_dim=embedding_dim,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
    src_pad_idx=pad_idx,
    trg_pad_idx=pad_idx,
    device=device
).to(device)

# Optimizer over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# StepLR documents step_size as an int (number of scheduler steps between decays).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# Loss function; padding positions do not contribute.
# NOTE(review): NLLLoss expects log-probabilities — assumes the model's final layer
# is a log-softmax; confirm in seqgen.model.transformer.
criterion = torch.nn.NLLLoss(ignore_index=pad_idx)

# Optionally restore model and optimizer state from a checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Single forward pass on the current batch; every tensor is passed without its
# last position along dimension 1.
output = model(
    input_seqs[:, :-1],
    target_seqs[:, :-1],
    coordinates[:, :-1],
)

In [None]:
# Highest-scoring class (value and index) along the last dimension of the model output
topv, topi = torch.topk(output, k=1, dim=2)
tuple(t.shape for t in (output, topi, topv))

In [None]:
# Average per-position loss of the forward pass above.
# Output position i is scored against target token i+1, matching the training loop:
# the decoder was fed target_seqs[:, :-1], so token i is its *input* at position i.
# Clamp to the labels that actually exist so the shifted index cannot run off the end.
num_positions = min(output.size(1), target_seqs.size(1) - 1)

# Start from a tensor so `.item()` below works even if every position is skipped.
loss = torch.zeros((), device=output.device)
for i in range(num_positions):
    _loss = criterion(output[:, i, :], target_seqs[:, i + 1])
    # The loss is NaN when every target at this position is the ignored pad index.
    if not _loss.isnan():
        loss = loss + _loss
loss.item() / max(num_positions, 1)

In [None]:
# Vocabulary sizes versus the largest input/target token id in the batch.
(
    len(vocab_in),
    len(vocab_out),
    input_seqs[:, :-1].max(),
    target_seqs[:, :-1].max(),
)

# Training

In [None]:
history = []      # summed per-position loss of every iteration
accuracies = []   # token-level accuracy (non-pad targets only) of every iteration

# Print a progress line every `print_every` iterations
print_every = 10

for epoch in range(10000):
    # Set gradients of all model parameters to zero
    optimizer.zero_grad()

    ##############################
    #    TRANSFORMER TRAINING    #
    ##############################

    # Get a batch of training data
    input_seqs, coordinates, target_seqs = dataset[0]

    # Run the input sequences through the model
    output = model(input_seqs[:, :-1], target_seqs[:, :-1], coordinates[:, :-1])

    # Accumulators for this iteration
    loss = torch.zeros((), device=device)
    correct_tokens = 0
    total_tokens = 0

    # Iterate over sequence positions: output position i is scored against target token i+1
    for i in range(max_length-1):
        step_output = output[:, i, :]
        step_target = target_seqs[:, i+1]

        _loss = criterion(step_output, step_target)
        # NaN means every target at this position is the ignored pad index (2)
        if _loss.isnan():
            continue
        loss = loss + _loss

        # Accuracy counts only non-pad targets. argmax over the class dimension keeps
        # the batch dimension intact even for a batch of one.
        mask = step_target != 2
        predicted = step_output.argmax(dim=-1)
        correct_tokens = correct_tokens + (predicted[mask] == step_target[mask]).sum()
        total_tokens = total_tokens + mask.sum()

    # Fraction of non-pad target tokens predicted correctly in this batch
    total_tokens = int(total_tokens)
    accuracy = int(correct_tokens) / total_tokens if total_tokens else 0.0

    history.append(loss.item())
    accuracies.append(accuracy)

    if not epoch % print_every:
        # Average over the entries actually available (only one at epoch 0)
        recent = accuracies[-print_every:]
        _accuracy = sum(recent) / len(recent)
        # Use a separate name so the `lr` hyperparameter is not overwritten.
        # NOTE(review): scheduler.step() is never called in this loop, so this value stays
        # at the initial learning rate. With step_size=1 and gamma=0.95, stepping every
        # iteration would shrink the LR extremely fast — choose a schedule before enabling it.
        current_lr = scheduler.get_last_lr()[0]
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "LR", current_lr, "ACCURACY", _accuracy)

    ######################
    #   WEIGHTS UPDATE   #
    ######################

    # Compute gradient
    loss.backward()

    # Update the model weights
    optimizer.step()

#### Save model history

In [None]:
import pickle
from datetime import datetime

# Timestamp used in both output filenames
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Training curves together with the hyperparameters of this run
model_data = dict(
    history=history,
    accuracy=accuracies,
    lr=lr,
    num_layers=num_layers,
    embedding_dim=embedding_dim,
    batch_size=batch_size,
    max_length=max_length,
    heads=heads,
    dropout=dropout,
)

# Checkpoint: last epoch index, model/optimizer state, last loss and hyperparameters
checkpoint_state = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    loss=loss,
    embedding_dim=embedding_dim,
    batch_size=batch_size,
    max_length=max_length,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
)
torch.save(checkpoint_state, "transformer_" + date_time + ".pt")

# Pickle the training history next to the checkpoint
with open("training_" + date_time + '.pkl', 'wb') as history_file:
    pickle.dump(model_data, history_file)

## Make predictions

We run our input sequences through the model and get output sequences. Then we decode the output sequences with the Vocabulary class to obtain the final LaTeX code.

In [None]:
def predict(input_seqs, coordinates, target_seqs):
    """Run the global `model` once and return the top-scoring class at every position.

    Args:
        input_seqs: encoder token ids, passed to the model as its first argument.
        coordinates: coordinates aligned with `input_seqs`, passed as the third argument.
        target_seqs: decoder input token ids, passed as the second argument.

    Returns:
        Tensor of predicted class indices: the model output with its class
        dimension (dim 2) reduced to the arg-max index.
    """
    # All inputs are moved to the global `device`; gradients are not needed for inference.
    with torch.no_grad():
        output = model(input_seqs.to(device), target_seqs.to(device), coordinates.to(device))
        # Get the predicted classes of the model
        _, topi = output.topk(1, dim=2)

        # Drop only the singleton top-k dimension; a bare squeeze() would also remove
        # the batch (or position) dimension whenever it happens to have size 1.
        return topi.squeeze(-1)
    
def predict_sequentially(input_seqs, coordinates):
    """Greedy autoregressive decoding: feed the model its own predictions step by step.

    The decoder input starts as all zeros (token id 0 is SOS). After each step
    the token predicted at position i is stored as the result for position i and
    also written to decoder-input position i+1. This is the same one-token shift
    used in training, where output position i is scored against target token i+1.

    Args:
        input_seqs: encoder token ids of shape (batch, length).
        coordinates: coordinates aligned with `input_seqs`.

    Returns:
        int64 tensor of shape (batch, length - 1) with the predicted token ids;
        positions beyond the decoded steps stay 0.
    """
    batch, width = input_seqs.size(0), input_seqs.size(1) - 1
    decoder_input = torch.zeros((batch, width), dtype=torch.int64)
    prediction = torch.zeros((batch, width), dtype=torch.int64)

    # Never decode more steps than the buffers can hold
    for i in range(min(max_length - 1, width)):
        # reshape keeps the (batch, position) layout even if predict() squeezed a size-1 dim
        output = predict(input_seqs, coordinates, decoder_input).reshape(batch, -1)
        prediction[:, i] = output[:, i]
        # The token predicted at step i becomes the decoder input of step i+1
        if i + 1 < width:
            decoder_input[:, i + 1] = output[:, i]
    return prediction

In [None]:
# Decode the whole current batch step by step and show the result's shape
prediction = predict_sequentially(input_seqs=input_seqs, coordinates=coordinates)
prediction.size()

In [None]:
# Decode one randomly chosen example of the batch. Here the model is fed the
# ground-truth target_seqs as decoder input (see predict), so its output at
# position k lines up with target token k+1.
# (`random` is already imported in the notebook's first cell.)
vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions = predict(input_seqs, coordinates, target_seqs)

i = random.randrange(predictions.size(0))
for _label, _vocab, _tokens in (
    ("MODEL INPUT", vocab_in, input_seqs[i, 1:]),
    ("MODEL OUTPUT", vocab_out, predictions[i, :-1]),
    ("TARGET OUTPUT", vocab_out, target_seqs[i, 1:]),
):
    print(_label, _vocab.decode_sequence(_tokens.cpu().numpy()))

In [None]:
# Turn the chosen prediction into a single string, leaving out every '<end>' token
decoded_tokens = vocab_out.decode_sequence(predictions[i].cpu().numpy())
prediction = "".join(token for token in decoded_tokens if token != '<end>')
print("MODEL OUTPUT", prediction)

In [None]:
# Step-by-step predictions for the first three examples of the batch
predict_sequentially(input_seqs=input_seqs[:3], coordinates=coordinates[:3])

In [None]:
# Targets of the same three examples, without their first token, for comparison
target_seqs[:3, 1:]

## Prediction for permutated sequences

In [None]:
def generate_permutated_batch(input_seq, coordinates, num_permutations=5):
    """Build a batch of randomly permuted copies of one sequence and its coordinates.

    Each row uses a fresh index permutation from `permutate_tokens`, applied to
    both the tokens and the coordinates so the two stay aligned.

    Args:
        input_seq: 1-D tensor of token ids.
        coordinates: 2-D tensor whose first dimension is aligned with `input_seq`.
        num_permutations: number of permuted copies to generate (default 5).

    Returns:
        Tuple `(seqs, coords)`: `seqs` is int64 of shape
        (num_permutations, len(input_seq)); `coords` has shape
        (num_permutations, coordinates.size(0), coordinates.size(1)) in torch's
        default floating dtype.
    """
    seqs = torch.zeros((num_permutations, input_seq.size(0)), dtype=torch.int64)
    coords = torch.zeros((num_permutations, coordinates.size(0), coordinates.size(1)))
    for row in range(num_permutations):
        # One permutation per row, shared by the tokens and their coordinates
        idx_permutated = permutate_tokens(input_seq)
        seqs[row] = input_seq[idx_permutated]
        coords[row] = coordinates[idx_permutated]
    return seqs, coords

In [None]:
# Permuted variants of the first example in the batch
input_permutated, coords_permutated = generate_permutated_batch(
    input_seq=input_seqs[0],
    coordinates=coordinates[0],
)
input_permutated

In [None]:
# Step-by-step predictions for every permuted variant
predict_sequentially(input_seqs=input_permutated, coordinates=coords_permutated)

In [None]:
# Target of the first example, without its first token, for comparison
target_seqs[0][1:]