In [1]:
from pathlib import Path
from tokenizer.BPE import Tokenizer
import numpy as np
from typing import Optional
import yaml

import torch
import torch.nn as nn

from model import Transformer

from torch.utils.data import DataLoader
from utils.dataset import NeuralTranslationDataset

# DEBUG
from utils.vis import plot_two_runs

from matplotlib import pyplot as plt

In [2]:
# load tokenizer
tokenizer = Tokenizer(compute_vocab=False, 
                      max_vocab_size=37_005,
                      corpus_source='wmt',
                      vocab_dest_file=Path('./data/dest/wmt_37k_tokens.yaml'))

# device
device = torch.device('cuda')

# model
transformer = Transformer(pre_trained=True, yaml_path='./logs/1748304942.yaml')

# test set
data_test = NeuralTranslationDataset(subset='test') 

# dataloaders
# -train
test_loader = DataLoader(data_test, batch_size=2,
                         shuffle=False, num_workers=4, pin_memory=True)

for batch in test_loader:
    break

🔄 Loading weights from: /eagle/projects/argonne_tpc/siebenschuh/attention_from_scratch/data/checkpoints/1748304942_epoch20.pth
✅ Model loaded from: ./logs/1748304942.yaml
✅ Weights loaded from: /eagle/projects/argonne_tpc/siebenschuh/attention_from_scratch/data/checkpoints/1748304942_epoch20.pth


In [3]:
"".join(tokenizer.decode(batch['src_ids'].detach().cpu()[0,:].tolist()))

'<BOS>urspruenglich war die schulhofsanierung sogar schon in den jahren geplant doch hohe unplanmaeige ausgaben brachten eine verschiebung<EOS><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>'

In [4]:
"".join(tokenizer.decode(batch['tgt_ids'].detach().cpu()[0,:].tolist()))

'<BOS>the school yard renovation was originally planned back in however high unplanned expenses meant that the work had to be pushed back<EOS><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>'

In [5]:
pred = transformer.generate(src_ids=batch['src_ids'].to(torch.device('cuda')), 
                            L_max=64, 
                            eos_idx=tokenizer.token_vocab['<EOS>'],
                            temperature=0.2)

In [6]:
transformer(**batch)

tensor([[[ 0.0112, -0.2074,  0.1559,  ...,  0.0985,  0.0603, -0.1851],
         [ 0.0125, -0.2080,  0.1565,  ...,  0.0996,  0.0598, -0.1854],
         [ 0.0161, -0.2101,  0.1581,  ...,  0.1013,  0.0570, -0.1859],
         ...,
         [ 0.0247, -0.2057,  0.1445,  ...,  0.0953,  0.0553, -0.1962],
         [ 0.0244, -0.2054,  0.1444,  ...,  0.0951,  0.0551, -0.1967],
         [ 0.0239, -0.2056,  0.1446,  ...,  0.0949,  0.0556, -0.1976]],

        [[-0.0635, -0.1616,  0.1301,  ...,  0.0974,  0.0014, -0.2508],
         [-0.0616, -0.1619,  0.1303,  ...,  0.0986,  0.0007, -0.2506],
         [-0.0555, -0.1627,  0.1313,  ...,  0.0955, -0.0025, -0.2492],
         ...,
         [-0.0500, -0.1651,  0.1187,  ...,  0.0949, -0.0084, -0.2497],
         [-0.0503, -0.1651,  0.1186,  ...,  0.0946, -0.0084, -0.2501],
         [-0.0507, -0.1653,  0.1193,  ...,  0.0946, -0.0080, -0.2508]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [7]:
transformer.tgt_mask.detach().cpu().numpy()

array([[[[False,  True,  True, ...,  True,  True,  True],
         [False, False,  True, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         ...,
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True]]],


       [[[False,  True,  True, ...,  True,  True,  True],
         [False, False,  True, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         ...,
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True]]]])

In [8]:
transformer.X_after.detach().cpu().numpy()

array([[    1,     1,    96, 21788, 29894,   195, 30469,  9877,  1451,
        18117,  1041, 32726,   181,  8012,   125,  5932,  4107,    74,
        32719,   181, 25402,   428,  6239,    24,   315,    99,  2534,
         5100,   151,   178,  2050,  9993,  8012,     2,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0],
       [    1,     1, 24816,  2749,  2527,   220,   178, 12678, 29747,
         7276,     8,  6220,   173,  7591,  2422, 30722, 18571,    61,
            2,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     

In [None]:
batch['tgt_ids'][0]