In [1]:
import torch
from data import CharactersDataset
from model import Transformer
from config import TRANSFORMER_CONFIG

In [2]:
# Held-out "test" split. shuffle=False makes the DataLoader iterate the
# dataset sequentially, so batch order (and therefore the order of any
# collected wrong cases) is the same on every run.
batch_size = 32
test_dataset = CharactersDataset("test")
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [3]:
# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

model_path = "checkpoints/cosmic-plasma-3/latest_ckpt.pt"
# map_location places every stored tensor directly on `device`.
# NOTE(review): depending on the torch version / `weights_only` default,
# torch.load may unpickle arbitrary objects — only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device)
# Vocabulary saved with the checkpoint; it is indexed by integer token id and
# yields the string pieces used when decoding token tensors back to text.
checkpoint_inverse_vocabulary = checkpoint["config"]["inverse_vocabulary"]

Using device: cuda


In [5]:

# Load the model: build the architecture, then restore the trained weights.
# NOTE(review): the architecture comes from the *current* TRANSFORMER_CONFIG,
# while the checkpoint also carries its own "config" entry. load_state_dict
# (strict by default) raises on missing/unexpected keys or shape mismatches,
# but settings that don't affect parameter shapes could differ silently —
# confirm the two configs agree.
model = Transformer(TRANSFORMER_CONFIG)
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.to(device)
print(f"Model loaded from {model_path}")

Model loaded from checkpoints/cosmic-plasma-3/latest_ckpt.pt


In [22]:
# Exact-match evaluation on the test split: a sequence counts as correct only
# if every predicted position equals the target, padding positions included.
# NOTE(review): `tgt_shifted` is presumably the right-shifted ground-truth
# target fed to the decoder, so this measures teacher-forced accuracy rather
# than free-running autoregressive decoding — confirm against model.py.
model.eval()
total, correct = 0, 0
# One dict per mismatched sequence with keys "src", "tgt", "predictions",
# each holding a 1-D CPU tensor of token ids.
wrong_cases = []
with torch.no_grad():
    for batch in test_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        tgt_shifted = batch["tgt_shifted"].to(device)
        logits = model(src, tgt_shifted)
        # argmax over the last axis (assumed to be the vocabulary axis) gives
        # token ids, assumed shape (batch_size, sequence_length) like `tgt`.
        predictions = torch.argmax(logits, dim=-1)

        # Per-sequence match, computed once and reused for both the accuracy
        # count and the wrong-case collection.
        sequence_correct = torch.all(torch.eq(predictions, tgt), dim=1)
        correct += sequence_correct.sum().item()
        total += len(src)

        # Copy mismatched rows to the CPU: storing device views would keep
        # whole batch tensors alive in GPU memory for the entire test set, and
        # the later per-token `.item()` calls would each force a device sync.
        wrong_mask = ~sequence_correct
        for src_row, tgt_row, pred_row in zip(
            src[wrong_mask].cpu(),
            tgt[wrong_mask].cpu(),
            predictions[wrong_mask].cpu(),
        ):
            wrong_cases.append(
                {
                    "src": src_row,
                    "tgt": tgt_row,
                    "predictions": pred_row,
                }
            )

# Report the metric this cell computes; guard against an empty test split.
accuracy = correct / total if total else float("nan")
print(f"Exact-match accuracy: {accuracy:.4f} ({correct}/{total}); wrong cases: {len(wrong_cases)}")

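# Sequence length analysis

Decode each misclassified test sequence back to text and compare source, target and prediction side by side (e.g. to see whether errors concentrate on particular sequence lengths or repeated characters).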

In [26]:
def decode_tokens(token_ids, inverse_vocabulary):
    """Decode a 1-D tensor of token ids into a single string.

    Args:
        token_ids: 1-D integer tensor; `.tolist()` yields plain Python ints.
        inverse_vocabulary: container indexed by integer token id that yields
            the string piece for each id.

    Returns:
        The pieces concatenated with no separator, so special tokens (e.g. the
        padding token) appear verbatim in the result.
    """
    return "".join(inverse_vocabulary[token_id] for token_id in token_ids.tolist())


# Print every misclassified sequence. Distinct *_text names avoid overwriting
# the `src` / `tgt` / `predictions` tensors left over from the evaluation cell.
for wrong_case in wrong_cases:
    src_text = decode_tokens(wrong_case["src"], checkpoint_inverse_vocabulary)
    tgt_text = decode_tokens(wrong_case["tgt"], checkpoint_inverse_vocabulary)
    predictions_text = decode_tokens(wrong_case["predictions"], checkpoint_inverse_vocabulary)
    print(f"src: {src_text}")
    print(f"tgt: {tgt_text}")
    print(f"predictions: {predictions_text}")
    print()

src: <s>CCAA<e>PADPADPADPADPADPADPADPADPADPAD
tgt: AACC<e>PADPADPADPADPADPADPADPADPADPADPAD
predictions: AACCCPADPADPADPADPADPADPADPADPADPADPAD

src: <s>ABAABACAACAAAA<e>
tgt: AAAAAAAAAABBCC<e>PAD
predictions: AAAAAAAAABBBCC<e>PAD

