In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from dataset.npz_dataset import NPZSequencesDataset
from torch.utils.data import DataLoader
from models.embedding import *
from models.xlstm import XTransformer


torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

In [None]:
# Use the GPU whenever at least one CUDA device is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.device_count() else "cpu"
print("Device", device)

In [None]:
# --- Optimisation schedule ---
n_epochs = 50
lr = 1e-4
batch_size = 1024

# --- Model / sequence shape ---
num_layers = 4
factor = 2
embedding_dim = 64
max_length = 20

# NOTE(review): `heads` and `dropout` are not passed to XTransformer in the model cell;
# confirm whether the model is meant to use them.
heads = 4
dropout = 0.1

In [None]:
# Raw token-id arrays, used here only to report vocabulary sizes
# (the model takes its sizes from the dataset objects instead).
# `with` closes the NpzFile handle that np.load keeps open for .npz archives.
with np.load("data/small_vocab_en.npz") as npz_en:
    sequences_en = npz_en["data"]
with np.load("data/small_vocab_fr.npz") as npz_fr:
    sequences_fr = npz_fr["data"]

# A vocabulary holding ids 0..max_id has max_id + 1 entries.
# NOTE(review): assumes token ids start at 0 — confirm against dataset_train.vocab_in_size/vocab_out_size.
vocab_size_en = int(sequences_en.max()) + 1
vocab_size_fr = int(sequences_fr.max()) + 1

In [None]:
DATA_EN = "data/small_vocab_en.npz"
DATA_FR = "data/small_vocab_fr.npz"


def _make_loader(split, train):
    """Build the NPZSequencesDataset for `split` and wrap it in a DataLoader.

    Only the training split is shuffled and trimmed to full batches. Val/test keep every
    sample in a fixed order, so evaluation covers the whole split and is reproducible.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = NPZSequencesDataset(DATA_EN, DATA_FR, split=split, max_length=max_length)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=train, drop_last=train)
    return dataset, loader


dataset_train, dataloader_train = _make_loader('train', train=True)
dataset_val, dataloader_val = _make_loader('val', train=False)
dataset_test, dataloader_test = _make_loader('test', train=False)

# Pull one training batch to sanity-check shapes and vocabulary sizes.
input_seqs, target_seqs = next(iter(dataloader_train))
input_seqs = input_seqs.to(device)
target_seqs = target_seqs.to(torch.long).to(device)
input_seqs.shape, target_seqs.shape, dataset_train.vocab_in_size, dataset_train.vocab_out_size

In [None]:
load_from_checkpoint = False
checkpoint_file = "transformer_temp2.pt"

# Target token id excluded from the loss.
# NOTE(review): 2 is assumed to be the dataset's padding id (the accuracy masks use the same value) — confirm.
pad_idx = 2

# Transformer model. Source and decoder inputs are both sliced with [:, :-1] before the
# forward pass, hence max_length - 1 positions.
model = XTransformer(
    embedding_type=EmbeddingType.POS_LEARNED,
    src_vocab_size=dataset_train.vocab_in_size,
    tgt_vocab_size=dataset_train.vocab_out_size,
    config_layers='m',
    embedding_dim=embedding_dim,
    max_length=max_length-1,
    num_layers=num_layers,
    factor=factor,
    device=device
).to(device)
# NOTE(review): `heads` and `dropout` from the config cell are not passed to the model.

# One optimizer over all encoder and decoder parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Multiply the LR by 0.95 each time scheduler.step() is called (once per epoch in the training loop).
# step_size must be an int number of steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# NLLLoss expects log-probabilities.
# NOTE(review): assumes XTransformer's output ends in log_softmax — confirm; use CrossEntropyLoss if it returns raw logits.
criterion = torch.nn.NLLLoss(ignore_index=pad_idx)

# Load model weights from checkpoint
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Resume the LR decay too when the checkpoint carries it; otherwise it restarts from `lr`.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

In [None]:
# Teacher-forced forward pass on the sample batch; both the encoder input and the
# decoder input drop their final position.
src_batch = input_seqs[:, :-1]
tgt_batch = target_seqs[:, :-1]
output = model(src_batch, tgt_batch)

In [None]:
# Top-1 along the class axis (dim 2): best scores in `topv`, predicted token ids in `topi`.
topv, topi = torch.topk(output, k=1, dim=2)
(output.shape, topi.shape, topv.shape)

In [None]:
# Mean per-position loss on the sample batch.
# The decoder was fed target_seqs[:, :-1], so output[:, i] predicts the NEXT token
# target_seqs[:, i + 1] — the same alignment the training loop uses.
loss = torch.zeros((), device=output.device)
n_scored = 0
for i in range(output.size(1)):
    target_i = target_seqs[:, i + 1]
    # Skip all-padding positions: NLLLoss averaged over zero tokens returns NaN.
    if (target_i != criterion.ignore_index).any():
        loss = loss + criterion(output[:, i, :], target_i)
        n_scored += 1
loss.item() / max(n_scored, 1)

In [None]:
history = []      # summed per-position loss of every training batch
accuracies = []   # token-level accuracy (non-padding targets) of every training batch
print_every = 1
# Token id excluded from both loss and accuracy: the one the criterion already ignores.
pad_idx = criterion.ignore_index

for epoch in range(n_epochs):
    ##############################
    #    TRANSFORMER TRAINING    #
    ##############################
    # The evaluation cell calls model.eval(); re-enable training behaviour (dropout etc.)
    # so re-running this cell after evaluation still trains correctly.
    model.train()

    for b, (input_seqs, target_seqs) in enumerate(dataloader_train):
        # Set gradients of all model parameters to zero
        optimizer.zero_grad()

        input_seqs = input_seqs.to(device)
        target_seqs = target_seqs.to(torch.long).to(device)

        # Teacher forcing: the decoder reads target_seqs[:, :-1], so output[:, i] predicts target_seqs[:, i + 1].
        output = model(input_seqs[:, :-1], target_seqs[:, :-1])

        loss = torch.zeros((), device=device)
        n_positions = 0   # positions that contributed to the loss
        n_correct = 0     # correctly predicted non-padding tokens
        n_tokens = 0      # non-padding target tokens
        for i in range(output.size(1)):
            target_i = target_seqs[:, i + 1]
            mask = target_i != pad_idx
            # An all-padding position has nothing to learn from (and NLLLoss would return NaN for it).
            if not mask.any():
                continue
            scores_i = output[:, i, :]
            loss = loss + criterion(scores_i, target_i)
            n_positions += 1
            # argmax(dim=-1) keeps the batch dimension even when the batch has a single row.
            predicted_i = scores_i.argmax(dim=-1)
            n_correct += int((predicted_i[mask] == target_i[mask]).sum())
            n_tokens += int(mask.sum())

        accuracy = n_correct / max(n_tokens, 1)
        history.append(loss.item())
        accuracies.append(accuracy)

        if not epoch % print_every:
            recent = accuracies[-print_every:]
            _accuracy = sum(recent) / len(recent)
            # Own name so the configured `lr` is not overwritten by the decayed value.
            current_lr = scheduler.get_last_lr()[0]
            print(f"LOSS after epoch {epoch} Batch [{b+1}/{len(dataloader_train)}]", loss.item() / max(n_positions, 1), "LR", current_lr, "ACCURACY", _accuracy)

        ######################
        #   WEIGHTS UPDATE   #
        ######################
        # Without any scored position the loss has no graph and backward() would raise.
        if n_positions:
            loss.backward()
            optimizer.step()

    # Adjust the learning rate once per epoch
    scheduler.step()

In [None]:
model.eval()  # Set the model to evaluation mode
batch_accuracies = []
# Token id excluded from accuracy: the one the criterion ignores.
pad_idx = criterion.ignore_index
total_correct = 0
total_tokens = 0

with torch.no_grad():
    for input_seqs, target_seqs in dataloader_test:
        # Move batch data to the device
        input_seqs = input_seqs.to(device)
        target_seqs = target_seqs.to(torch.long).to(device)

        # Same forward call as in training, so output[:, i] predicts target_seqs[:, i + 1].
        # NOTE(review): this is teacher-forced accuracy; free-running (autoregressive) decoding
        # would give a stricter number if XTransformer supports it — confirm before switching.
        output = model(input_seqs[:, :-1], target_seqs[:, :-1])

        predictions = output.argmax(dim=-1)
        next_tokens = target_seqs[:, 1:1 + output.size(1)]
        mask = next_tokens != pad_idx

        batch_correct = int((predictions[mask] == next_tokens[mask]).sum())
        batch_tokens = int(mask.sum())
        accuracy = batch_correct / max(batch_tokens, 1)
        batch_accuracies.append(accuracy)
        total_correct += batch_correct
        total_tokens += batch_tokens
        print("ACC", accuracy)

# Token-level accuracy over the whole test split. A mean of batch means would
# over-weight a smaller final batch.
mean_accuracy = total_correct / max(total_tokens, 1)

# Print the accuracy
print(f"Accuracy on the test dataset: {mean_accuracy:.4f}")