In [44]:
import torch
import seqgen.seq_gen as g
import random
import matplotlib.pyplot as plt
import seaborn as sns
from seqgen.model import rnn
from seqgen.vocabulary import *
from seqgen.model import transformer, embedding
from seqgen.datasets.sequences import *
from seqgen.datasets.realdata import RealSequencesDataset

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
if torch.cuda.device_count():
    device="cuda"
else:
    device="cpu"
print("Device", device)

Device cuda


In [46]:
use_real_dataset=False
lr=1e-5
num_layers=6
embedding_dim=128
batch_size=128
max_length=50
heads=8
dropout=0

In [47]:
vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt", vocab_file="vocab_in.pkl")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt", vocab_file="vocab_out.pkl")

if use_real_dataset:
    dataset = RealSequencesDataset(filename="data/train/label.txt", vocab_in=vocab_in, vocab_out=vocab_out, max_length=max_length-1, batch_size=batch_size, device=device)
else:
    dataset = SyntheticSequenceDataset(vocab_in, vocab_out, max_length, batch_size, continue_prob=0.95, additional_eos=True, device=device)

In [82]:
input_seqs, coordinates, target_seqs = dataset[0]
input_seqs.shape, coordinates.shape, target_seqs.shape

(torch.Size([128, 51]), torch.Size([128, 51, 4]), torch.Size([128, 51]))

In [83]:
print(input_seqs[0, :-1])
print(target_seqs[0, :-1])
print(target_seqs[0, 1:])

tensor([  0,  47,  31,  17,  36, 104,  91, 110,  16,   9,   7, 109,  52,  86,
         81, 104,  45, 112,  47,  74,  88,   4,  31, 104,   1,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2], device='cuda:0')
tensor([ 0, 59, 63, 85, 20, 74, 20, 97, 48, 20, 87, 23, 70, 20, 78, 20, 59, 53,
        23, 98, 65,  3, 96, 23, 78, 23, 75, 78, 20, 16, 66, 20, 83, 86, 65,  1,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
       device='cuda:0')
tensor([59, 63, 85, 20, 74, 20, 97, 48, 20, 87, 23, 70, 20, 78, 20, 59, 53, 23,
        98, 65,  3, 96, 23, 78, 23, 75, 78, 20, 16, 66, 20, 83, 86, 65,  1,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
       device='cuda:0')


In [7]:
def permutate_tokens(input_seq):
    # Get the first index where tensor has an SOS or EOS token
    sos_idx = list(input_seq).index(0)
    eos_idx = list(input_seq).index(1)
    # permutate all elements that are not SOS or EOS
    idx_permuted = torch.cat([torch.arange(0, sos_idx+1), (torch.randperm(eos_idx - sos_idx - 1) + sos_idx+1), torch.arange(eos_idx, max_length+1)])
    return idx_permuted

In [8]:
len(vocab_in), len(vocab_out), torch.max(input_seqs[:, :-1]), torch.max(target_seqs[:, :-1])

(113, 180, tensor(112, device='cuda:0'), tensor(98, device='cuda:0'))

# The Transformer

In [9]:
load_from_checkpoint = False
checkpoint_file = "transformer_temp2.pt"

# Transformer model
model = transformer.PyTorchTransformer(
    encoder_embedding_type=embedding.EmbeddingType.COORDS_DIRECT,
    src_vocab_size=len(vocab_in),
    trg_vocab_size=len(vocab_out),
    embedding_dim=embedding_dim,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
    src_pad_idx=2,
    trg_pad_idx=2,
    device=device
).to(device)

# Initialize optimizer for encoder and decoder
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

# Loss function
criterion = torch.nn.NLLLoss(ignore_index=2)

# Load model weights from checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [10]:
# Run the feature sequences through the model
output = model(input_seqs[:, :-1], target_seqs[:, :-1], coordinates[:, :-1])

In [11]:
# Get the predicted classes of the model
topv, topi = output.topk(1, dim=2)
output.shape, topi.shape, topv.shape

(torch.Size([128, 50, 180]),
 torch.Size([128, 50, 1]),
 torch.Size([128, 50, 1]))

In [12]:
loss = 0.0
for i in range(max_length):
    _loss = criterion(output[:, i, :], target_seqs[:, i])
    if not _loss.isnan():
        loss += _loss
loss.item() / max_length

5.302947998046875

In [13]:
len(vocab_in), len(vocab_out), torch.max(input_seqs[:, :-1]), torch.max(target_seqs[:, :-1])

(113, 180, tensor(112, device='cuda:0'), tensor(98, device='cuda:0'))

# Training

In [None]:
history = []
accuracies = []

for epoch in range(10000):
    # Set gradients of all model parameters to zero
    optimizer.zero_grad()

    # Initialize loss
    loss = torch.tensor(0.0).to(device)
    accuracy = 0.0

    ##############################
    #    TRANSFORMER TRAINING    #
    ############################## 
    
    # Get a batch of training data
    input_seqs, coordinates, target_seqs = dataset[0]
    
    # Run the input sequences through the model
    output = model(input_seqs[:, :-1], target_seqs[:, :-1], coordinates[:, :-1])
    
    # Iterate over sequence positions to compute the loss
    for i in range(max_length-1):
        # Get the predicted classes of the model
        topv, topi = output[:, i, :].topk(1)
        _loss = criterion(output[:, i, :], target_seqs[:, i+1])
        if not _loss.isnan():
            loss += _loss
            mask = target_seqs[:, i+1] != 2
            accuracy += float((topi.squeeze()[mask] == target_seqs[mask, i+1]).sum() / (target_seqs[mask].size(0)*(target_seqs[mask].size(1)-2)))
    
    history.append(loss.item())
    accuracies.append(accuracy)
    
    print_every = 100
    if not epoch % print_every:
        _accuracy = sum(accuracies[-print_every:]) / print_every
        lr = scheduler.get_last_lr()[0]
        print(f"LOSS after epoch {epoch}", loss.item() / (target_seqs.size(1)), "LR", lr, "ACCURACY", _accuracy)

    ######################
    #   WEIGHTS UPDATE   #
    ######################
    
    # Compute gradient
    loss.backward()
    accuracy = 0.0

    # Update weights of encoder and decoder
    optimizer.step()

LOSS after epoch 0 2.95256072399663 LR 1e-05 ACCURACY 0.001987399298232049
LOSS after epoch 100 2.9609802844477633 LR 1e-05 ACCURACY 0.19256766066071576
LOSS after epoch 200 2.9139323515050553 LR 1e-05 ACCURACY 0.1921432506706333
LOSS after epoch 300 2.9232258516199447 LR 1e-05 ACCURACY 0.19410451809002552
LOSS after epoch 400 2.9300839293236827 LR 1e-05 ACCURACY 0.19282348835258745
LOSS after epoch 500 2.936972524605545 LR 1e-05 ACCURACY 0.19322228737641126
LOSS after epoch 600 2.947541779162837 LR 1e-05 ACCURACY 0.193400424978463
LOSS after epoch 700 2.9065234614353552 LR 1e-05 ACCURACY 0.1936121210947749
LOSS after epoch 800 2.9167974135454964 LR 1e-05 ACCURACY 0.19291146828676575
LOSS after epoch 900 2.9228045893650427 LR 1e-05 ACCURACY 0.19194898122630547
LOSS after epoch 1000 2.896685431985294 LR 1e-05 ACCURACY 0.1933470781723736
LOSS after epoch 1100 2.9403088139552698 LR 1e-05 ACCURACY 0.19183330473548266
LOSS after epoch 1200 2.8698446236404718 LR 1e-05 ACCURACY 0.193313890320

#### Save model history

In [15]:
import pickle
from datetime import datetime

model_data = {
    'model_type': type(model),
    "history": history,
    "accuracy": accuracies,
    "lr": lr,
    "num_layers": num_layers,
    "embedding_dim": embedding_dim,
    "batch_size": batch_size,
    "max_length": max_length,
    "heads": heads,
    "dropout": dropout,
}

now = datetime.now() # current date and time
date_time = now.strftime("%Y-%m-%d_%H-%M-%S")

torch.save({
    'model_type': type(model),
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    "embedding_dim": embedding_dim,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "heads": heads,
    "dropout": dropout,
}, "transformer_" + date_time + ".pt")


with open("training_" + date_time + '.pkl', 'wb') as f:
    pickle.dump(model_data, f)

## Make predictions

We run our input sequences through the model and get output seuences. Then we decode the output sequences with the Vocabulary class and get our final latex code.

In [16]:
def predict(input_seqs, coordinates, target_seqs):
    vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
    vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

    with torch.no_grad():
        output = model(input_seqs.to(device), target_seqs.to(device), coordinates.to(device))
        # Get the predicted classes of the model
        topv, topi = output.topk(1, dim=2)
        
        return topi.squeeze()
    
def predict_sequentially(input_seqs, coordinates):
    prediction = torch.zeros((input_seqs.size(0), input_seqs.size(1)-1)).to(torch.int64)
    for i in range(max_length-1):
        output = predict(input_seqs, coordinates, prediction)
        prediction[:, i] = output[:, i]
    return prediction

In [17]:
prediction = predict_sequentially(input_seqs, coordinates)
prediction.shape

torch.Size([128, 50])

In [18]:
# Pick random sequence and its prediction from the model
import random

vocab_in = Vocabulary(vocab_filename="seqgen/vocab_in.txt")
vocab_out = Vocabulary(vocab_filename="seqgen/vocab_out.txt")

predictions = predict(input_seqs, coordinates, target_seqs)

i = random.randint(0, predictions.size(0)-1)
print("MODEL INPUT", vocab_in.decode_sequence(input_seqs[i, 1:].cpu().numpy()))
print("MODEL OUTPUT", vocab_out.decode_sequence(predictions[i, :-1].cpu().numpy()))
print("TARGET OUTPUT", vocab_out.decode_sequence(target_seqs[i, 1:].cpu().numpy()))

MODEL INPUT ['<unk>', '<unk>', '<unk>', 'q', 'b', '<unk>', '<unk>', 'U', 'H', 't', 'v', 'T', 'i', 'w', 'I', 'a', 'd', 'H', 'q', 'R', '<end>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']
MODEL OUTPUT ['/', '^', '_', '\\cdot', '^', '/', '^', '^', '^', '/', '^', 'q', '^', '\\cdot', '^', '/', '^', '/', '^', '^', '/', '^', '/', '^', '^', '\\cdot', '_', '/', '^', '^', '\\cdot', '_', '/', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^', '^']
TARGET OUTPUT ['I', 'R', '_', 'v', '_', 'H', 'q', 'T', '_', '\\cdot', '^', 't', '^', '\\cdot', '^', 'U', '^', 'q', '/', '^', 'w', '_', 'H', '/', '_', 'i', '^', '-', 'a', '^', 'd', '_', 'b', '<end>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>

In [19]:
prediction = vocab_out.decode_sequence(predictions[i].cpu().numpy())
prediction = list(filter(lambda x: x != '<end>', prediction))
prediction = "".join(prediction)
print("MODEL OUTPUT", prediction)

MODEL OUTPUT /^_\cdot^/^^^/^q^\cdot^/^/^^/^/^^\cdot_/^^\cdot_/^^^^^^^^^^^^^^^^^^


In [20]:
predict_sequentially(input_seqs[0:3], coordinates[0:3])

tensor([[60, 60, 60, 60, 20, 60, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0],
        [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0],
        [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0]])

In [21]:
target_seqs[0:3, 1:]

tensor([[93,  5, 88, 62, 23, 95, 78, 20, 53, 81, 20, 88, 10, 20, 60, 60, 20,  8,
         23, 74, 72, 71, 65, 33, 80,  8, 20, 59, 49, 92,  1,  2,  2,  2,  2,  2,
          2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [75, 56, 20, 58, 57, 20,  5, 55, 81, 20, 83, 48, 91, 95, 91, 23, 86, 23,
         76, 48, 23, 88, 52, 20, 69, 65, 92, 63, 20, 52, 23,  8, 20, 96, 96, 90,
         23, 16, 23, 75,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [58, 23, 11, 23, 93, 20, 90, 87, 23, 57, 65, 20, 32, 20, 94, 92, 23, 60,
         88, 69, 83, 23, 53, 94, 20, 79, 54, 20, 63, 20, 66, 94, 85, 20, 96, 20,
         54, 20, 59, 87, 71, 58, 23, 70, 80,  5, 23, 33,  1,  2]],
       device='cuda:0')

## Prediction for permutated sequences

In [22]:
def generate_permutated_batch(input_seq, coordinates):
    seqs = torch.zeros((5, input_seq.size(0))).to(torch.int64)
    coords = torch.zeros((5, coordinates.size(0), coordinates.size(1)))
    for i in range(5):
        idx_permutated = permutate_tokens(input_seq)
        seqs[i, :] = input_seq[idx_permutated]
        coords[i, :] = coordinates[idx_permutated]
    return seqs, coords

In [23]:
input_permutated, coords_permutated = generate_permutated_batch(input_seqs[0], coordinates[0])
input_permutated

tensor([[  0, 108,  44,  47, 107,  36, 104,  34,  38,  93,  91,  70,   6,  44,
          34,  31,  85,  13,  30,  93,  99,  79,   2,  53,   1,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2],
        [  0,  34, 108, 104,  44,  85,  30,  13,  93,  44,  34,  31,  93,  91,
         107,  47,  38,  99,   6,  70,   2,  53,  79,  36,   1,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2],
        [  0,  30,  47,   6,  70,  93,  36,  13,  34,  91,  79,  38,  93,  34,
         108,   2,  44, 104, 107,  99,  53,  85,  31,  44,   1,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2],
        [  0,  93,  99,  34,  30,  79,  34,  13, 104,  70,  44,  38, 107,  93,
           2,  31,  53,  85,  44,   6,  91,  

In [24]:
predict_sequentially(input_permutated, coords_permutated)

tensor([[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0],
        [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0],
        [20, 20, 20, 20, 20, 60, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0],
        [60, 20, 60, 60, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
         20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  0],
        [20, 60, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
       

In [25]:
target_seqs[0, 1:]

tensor([93,  5, 88, 62, 23, 95, 78, 20, 53, 81, 20, 88, 10, 20, 60, 60, 20,  8,
        23, 74, 72, 71, 65, 33, 80,  8, 20, 59, 49, 92,  1,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
       device='cuda:0')