In [1]:
import datasets
from __future__ import annotations
import numpy as np

import builtins
import json
import torch
from torch.utils.data import Dataset, DataLoader
from new_transformer import Transformer
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Vocab:
    def __init__(self, content: str) -> None:
        self.content = content
        self.chars = sorted(list(set(self.content)))
        self.chars+=["[UNK]", "[PAD]", "[SOS]", "[EOS]"]
        self.stoi = {char:i for i, char in enumerate(self.chars)}
        self.itos = {i:char for i, char in enumerate(self.chars)}
        stoi_json = json.dumps(self.stoi, indent=4)
        itos_json = json.dumps(self.itos, indent=4)
        with open("stoi.json", "w") as outfile:
            outfile.write(stoi_json)
        with open("itos.json", "w") as outfile:
            outfile.write(itos_json)

    def __len__(self) -> int:
        return len(self.stoi)

    def __getitem__(self, idx: str | int) -> int | str:
        match type(idx):
            case builtins.str: return self.stoi[idx]
            case builtins.int: return self.itos[idx]

    def encode(self, chars: str) -> list[int]:
        encoded = []
        while chars != "":
            if chars[0] == "[" and len(chars) > 5 and chars[4] == "]" and chars[:5] != "[...]":
                encoded.append(self.stoi[chars[:5]])
                chars = chars[5:]
            elif chars[0] == "[" and len(chars) == 5 and chars[4] == "]" and chars[:] != "[...]":
                encoded.append(self.stoi[chars])
                chars = ""
            else:
                encoded.append(self.stoi[chars[0]])
                chars = chars[1:]
        return encoded
    
    def decode(self, idxs: list[int]) -> str:
        return ''.join(self[idx] for idx in idxs)

In [3]:
dataset = datasets.load_dataset("opus_books", "en-fr")

In [4]:
en_content = ''.join([dataset["train"][i]["translation"]["en"] for i in range(dataset.num_rows["train"])])
fr_content = ''.join([dataset["train"][i]["translation"]["fr"] for i in range(dataset.num_rows["train"])])
global_content = en_content + fr_content
en_vocab = Vocab(en_content)
fr_vocab = Vocab(fr_content)
global_vocab = Vocab(global_content)

In [5]:
en_vocab.encode("[UNK][PAD][SOS][EOS]")

[143, 144, 145, 146]

In [6]:
class BilingualDataset(Dataset):
    def __init__(self, dataset, src_vocab: Vocab, tgt_vocab: Vocab, seq_len) -> None:
        self.dataset = dataset
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.seq_len = seq_len
        
        self.sos_idx = torch.tensor(self.tgt_vocab.encode("[SOS]"), dtype = torch.int64)
        self.eos_idx = torch.tensor(self.tgt_vocab.encode("[EOS]"), dtype = torch.int64)
        self.pad_idx = torch.tensor(self.tgt_vocab.encode("[PAD]"), dtype = torch.int64)
        self.unk_idx = torch.tensor(self.tgt_vocab.encode("[UNK]"), dtype = torch.int64)
        
    def _causal_mask(self, seq_len: int) -> torch.Tensor:
        mask = torch.ones(1, seq_len, seq_len, dtype=torch.bool)
        mask = torch.tril(mask, diagonal=0)
        return mask
        
    def __len__(self) -> int:
        return len(self.dataset)
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        src = torch.tensor(self.src_vocab.encode(self.dataset[idx]["en"]), dtype = torch.int64)
        tgt = torch.tensor(self.tgt_vocab.encode(self.dataset[idx]["fr"]), dtype = torch.int64)
        if self.seq_len - len(src) - 2 < 0 or self.seq_len - len(tgt) - 1 < 0:
            src = src[:self.seq_len - 2]
            tgt = tgt[:self.seq_len - 1]
        enc_num_pad = self.seq_len - len(src) - 2
        dec_num_pad = self.seq_len - len(tgt) - 1
        input_src = torch.cat([self.sos_idx, src, self.eos_idx, self.pad_idx.repeat(enc_num_pad)])
        input_tgt = torch.cat([self.sos_idx, tgt, self.pad_idx.repeat(dec_num_pad)])
        input_label = torch.cat([tgt,self.eos_idx, self.pad_idx.repeat(dec_num_pad)])
        return (
            input_src, 
            input_tgt, 
            input_label, 
            (input_src!=self.pad_idx).unsqueeze(0).unsqueeze(0).int() == 1,
            (input_tgt!=self.pad_idx).unsqueeze(0).unsqueeze(0).int() & self._causal_mask(self.seq_len) == 1,
            )

In [7]:
dataset = [x for x in dataset["train"]["translation"] if len(x["en"]) < 200 and len(x["fr"]) < 200]

In [8]:
len(dataset)

105312

In [9]:
dataset

[{'en': 'The Wanderer', 'fr': 'Le grand Meaulnes'},
 {'en': 'Alain-Fournier', 'fr': 'Alain-Fournier'},
 {'en': 'First Part', 'fr': 'PREMIÈRE PARTIE'},
 {'en': 'I', 'fr': 'CHAPITRE PREMIER'},
 {'en': 'THE BOARDER', 'fr': 'LE PENSIONNAIRE'},
 {'en': 'He arrived at our home on a Sunday of November, 189-.',
  'fr': 'Il arriva chez nous un dimanche de novembre 189-…'},
 {'en': "I still say 'our home,' although the house no longer belongs to us.",
  'fr': 'Je continue à dire « chez nous », bien que la maison ne nous appartienne plus.'},
 {'en': 'We left that part of the country nearly fifteen years ago and shall certainly never go back to it.',
  'fr': 'Nous avons quitté le pays depuis bientôt quinze ans et nous n’y reviendrons certainement jamais.'},
 {'en': "We were living in the building of the Higher Elementary Classes at Sainte-Agathe's School.",
  'fr': 'Nous habitions les bâtiments du Cours Supérieur de Sainte-Agathe.'},
 {'en': "My father, whom I used to call M. Seurel as did other p

In [10]:
train_set = dataset[:int(len(dataset)*0.8)]
val_set = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_set = dataset[int(len(dataset)*0.9):]

In [11]:
train_set[:10]

[{'en': 'The Wanderer', 'fr': 'Le grand Meaulnes'},
 {'en': 'Alain-Fournier', 'fr': 'Alain-Fournier'},
 {'en': 'First Part', 'fr': 'PREMIÈRE PARTIE'},
 {'en': 'I', 'fr': 'CHAPITRE PREMIER'},
 {'en': 'THE BOARDER', 'fr': 'LE PENSIONNAIRE'},
 {'en': 'He arrived at our home on a Sunday of November, 189-.',
  'fr': 'Il arriva chez nous un dimanche de novembre 189-…'},
 {'en': "I still say 'our home,' although the house no longer belongs to us.",
  'fr': 'Je continue à dire « chez nous », bien que la maison ne nous appartienne plus.'},
 {'en': 'We left that part of the country nearly fifteen years ago and shall certainly never go back to it.',
  'fr': 'Nous avons quitté le pays depuis bientôt quinze ans et nous n’y reviendrons certainement jamais.'},
 {'en': "We were living in the building of the Higher Elementary Classes at Sainte-Agathe's School.",
  'fr': 'Nous habitions les bâtiments du Cours Supérieur de Sainte-Agathe.'},
 {'en': "My father, whom I used to call M. Seurel as did other p

In [12]:
max_seq_len = 0
sum = 0
for i in range(len(train_set)):
    src = train_set[i]["en"]
    tgt = train_set[i]["fr"]
    max_seq_len = max(max_seq_len, len(en_vocab.encode(src)), len(fr_vocab.encode(tgt)))
    sum+=len(en_vocab.encode(src))
mean = sum/len(train_set)

In [13]:
max_seq_len, mean

(199, 83.84421180073355)

In [14]:
dataset_train = BilingualDataset(train_set, global_vocab, global_vocab, 200)
val_set = BilingualDataset(val_set, global_vocab, global_vocab, 200)
test_set = BilingualDataset(test_set, global_vocab, global_vocab, 200)

In [15]:
len(dataset_train), len(val_set), len(test_set)

(84249, 10531, 10532)

In [16]:
input_src, input_tgt, input_label, src_mask, tgt_mask = dataset_train[0]

In [17]:
input_tgt.shape, input_src.shape, input_label.shape, src_mask.shape, tgt_mask.shape

(torch.Size([200]),
 torch.Size([200]),
 torch.Size([200]),
 torch.Size([1, 1, 200]),
 torch.Size([1, 200, 200]))

In [18]:
input_tgt

tensor([161,  43,  66,   0,  68,  79,  62,  75,  65,   0,  44,  66,  62,  82,
         73,  75,  66,  80, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 1

In [19]:
input_label

tensor([ 43,  66,   0,  68,  79,  62,  75,  65,   0,  44,  66,  62,  82,  73,
         75,  66,  80, 162, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160, 1

In [20]:
src_mask[0,0]

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [21]:
tgt_mask.shape

torch.Size([1, 200, 200])

In [22]:
tgt_mask[0,0]

tensor([ True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [23]:
device = torch.device("cuda")

In [24]:
len(fr_vocab)

154

In [25]:
model = Transformer(
    vocab_size=len(global_vocab),
    n_head=8,
    embed_size=496,
    context_length=200,
    dropout=0.1,
    num_layers=6,
    device=device,
)

In [26]:
print(f"Footprint      {f'{(model.num_parameters + model.num_buffers) * 32 * 1.25e-10:.2f} GB':>12}")

Footprint           0.15 GB


In [27]:
batch_size = 32
num_epochs = 10
lr = 3e-4

In [28]:
model = model.to(device)

In [29]:
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True)

In [30]:
len(fr_vocab.stoi)

154

In [31]:
global_vocab.encode("[PAD]")[0]

160

In [32]:
optimizer = AdamW(model.parameters(), lr=lr)
scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=len(train_loader)*num_epochs, pct_start=0.1)

In [33]:
min_valid_loss = np.inf
for _ in range(num_epochs):
    model.train()
    with tqdm(enumerate(train_loader)) as pbar:
        for idx, (src, tgt, label, src_mask, tgt_mask) in pbar:
            src = src.to(device)
            tgt = tgt.to(device)
            label = label.to(device)
            src_mask = src_mask.to(device)
            tgt_mask = tgt_mask.to(device)
            output = model(src, tgt, src_mask, tgt_mask)
            B, T, C = output.shape
            if idx%1000 == 0:
                print(global_vocab.decode(output.argmax(dim=-1)[0].tolist()), global_vocab.decode(label[0].tolist()))
            loss = F.cross_entropy(output.view(B * T, C), label.view(B * T), ignore_index=global_vocab.encode("[PAD]")[0])
            acc = (output.argmax(dim=-1) == label).float().mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            pbar.set_description(f"Epoch {_} | Loss {loss.item():.3f} | Acc {acc.item():.3f}")
            # break
    valid_loss = 0
    model.eval()
    with tqdm(val_loader) as pbar:
        with torch.no_grad():
            for src, tgt, label, src_mask, tgt_mask in pbar:
                src = src.to(device)
                tgt = tgt.to(device)
                label = label.to(device)
                src_mask = src_mask.to(device)
                tgt_mask = tgt_mask.to(device)
                
                output = model(src, tgt, src_mask, tgt_mask)
                loss = F.cross_entropy(output.view(B * T, C), label.view(B * T), ignore_index=global_vocab.encode("[PAD]")[0])
                valid_loss += loss.item()*src.shape[0]
                pbar.set_description(f"Epoch {_} | Loss {loss.item():.3f}")
    print(f'Epoch {_+1}Validation Loss: {valid_loss / len(val_set)}')
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), 'saved_model.pth')

0it [00:00, ?it/s]

Ï+—Âî5ι?АÏâiι?:5ûÏε?ûοO?$êАÎt«Ï?r&κκκκκκκÏκκÏκκκκκκκκκκκκCCκκκCκκκκκκκκ&&κκCCκκκCκκκκκκ&κC&κCκ&κκκκ&κκκκCκκ&κκκ&κκκCκκκκκκκoκCκκκκκκκ&&κκ&κκκκκκκκκκκCκ&κκκκκκκ&κCκCκκκκκκκκκκκκ&CκûκC&κ&κ&ÏκκCκκκ&κκκκκ Est-ce parce qu’il est imprimé ?[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][P

Epoch 0 | Loss 4.448 | Acc 0.065: : 5it [00:02,  2.37it/s]

Epoch 0 | Loss 2.179 | Acc 0.155: : 1000it [07:06,  2.33it/s]

Je na    te  doesarts dai  utde mo due j' da tis deaittite  daur dei d  dous doun[EOS]d  d      i[EOS]                                                                                                           Je descendis prévenir Gaston de ce que je venais d'arranger pour lui et pour moi. Il accepta.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]


Epoch 0 | Loss 1.602 | Acc 0.235: : 2000it [14:15,  2.33it/s]

Il ae ltiluac uempantiln lartiee[EOS]                                                                                                                                                                        Il salua avec respect et partit.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PA

Epoch 0 | Loss 1.472 | Acc 0.248: : 2632it [18:46,  2.34it/s]
Epoch 0 | Loss 1.490: 100%|██████████| 329/329 [00:42<00:00,  7.68it/s]


Epoch 1Validation Loss: 1.549833245386528
Validation Loss Decreased(inf--->16321.293907) 	 Saving The Model


0it [00:00, ?it/s]

Mais elle ntait pejertér[EOS]  t l  sl  - t   l    [EOS] t  t   t  t t t  stl [EOS][EOS]  st       l ttt  - -    tt  s         st lt  ttt-  l tt t t -       s   tt   - t   tt  t   - s    t st - stttst [EOS] t  t s -   [EOS][EOS] Mais elle était déserte.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][P

Epoch 1 | Loss 1.363 | Acc 0.272: : 1000it [07:09,  2.33it/s]

Lean aaraa dur out deabpos de  capiss daiiléaités dui lensrntin de coeleotiirue de cNautilus_  eor as dtedemeentsde   c ca cain de capitaine Nemo dt de ars dor los pulesvetions.drrsonnesles [EOS][EOS][EOS][EOS]  [EOS] [EOS][EOS] J'en parle surtout d'après les cartes manuscrites que contenait la bibliothèque du _Nautilus_, cartes évidemment dues à la main du capitaine Nemo et levées sur ses observations personnelles.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]


Epoch 1 | Loss 1.286 | Acc 0.228: : 2000it [14:18,  2.33it/s]

-- Cansieur,noe pnt paunt  répondit ln memeutique.[EOS][EOS][EOS][EOS]-.[EOS]-[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS].[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS].[EOS][EOS][EOS][EOS]-[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]-[EOS]-[EOS][EOS][EOS][EOS][EOS]-[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS].[EOS]-[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]-[EOS]-[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS] -- Monsieur n’y est point, répondit un domestique.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][

Epoch 1 | Loss 1.254 | Acc 0.296: : 2632it [18:50,  2.33it/s]
Epoch 1 | Loss 1.329: 100%|██████████| 329/329 [00:42<00:00,  7.69it/s]


Epoch 2Validation Loss: 1.3264510075304088
Validation Loss Decreased(16321.293907--->13968.855560) 	 Saving The Model


0it [00:00, ?it/s]

- ui, jon, jéprit-il,.,,,,,..,,.,.,..,.,.,.,.,.,,, ,... ,.,..,,,,,,..,,,,,,.,, .,.,, ,...,.,,.,,, ..,,..,., ,.. ,,.,,,,..,,,,....,,,,....,.,...,,,, ,,,, .,,,.,,,.,,.,,.,,, ,,, ,.., ,.,,,.,.,,,,,.,.,.. «Oui, moi! reprit-il.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PA

Epoch 2 | Loss 1.235 | Acc 0.288: : 1000it [07:09,  2.33it/s]

M Mais ll ne soersa a luec lnnoreeée ![EOS]  u  eo o    ou      o s s  u  uo  o o euo  e    oe o    u so   oe  ss   o  u  s      u o  uo    s a aeoo  ou s u     os o    o   o  u o u  a  uo       uo  uu a  – Mais il me chassera avec ignominie ![EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][P

Epoch 2 | Loss 1.131 | Acc 0.270: : 2000it [14:18,  2.33it/s]

Mais il neitit dens le phnr dùleure.[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS] Puis il sortit dans la cour obscure.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][P

Epoch 2 | Loss 1.154 | Acc 0.270: : 2632it [18:49,  2.33it/s]
Epoch 2 | Loss 1.261: 100%|██████████| 329/329 [00:42<00:00,  7.68it/s]


Epoch 3Validation Loss: 1.2302547129184678
Validation Loss Decreased(13968.855560--->12955.812382) 	 Saving The Model


0it [00:00, ?it/s]

-eacae svsmon,neau  eossieurs, seatée svsmen!neuu [EOS].ts.....t....s.s.t...t.....s......s..t...........s.sss....-..s..s...-..s..ss.ss.ss..s.......t.s.......s..ssts.s......s...s...s.......-...t.......t.s. L'épée au fourreau, messieurs! l'épée au fourreau![EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]

Epoch 3 | Loss 1.089 | Acc 0.272: : 1000it [07:09,  2.33it/s]

Nllons, nepomrheztou,[EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS]»[EOS]»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS]»[EOS][EOS]»[EOS][EOS][EOS][EOS][EOS]»»»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS]»[EOS][EOS][EOS][EOS]»[EOS][EOS]»»[EOS][EOS][EOS][EOS][EOS]»[EOS][EOS]»[EOS][EOS]»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS]»[EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS] Allons, recouche-toi.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]

Epoch 3 | Loss 1.079 | Acc 0.323: : 2000it [14:18,  2.33it/s]

Epamis dothoina t de  pirr  et Porthos learrênhait de somps àn pemps àu' ques mauns de monrqacheset aelni de léfispéir [EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS][EOS] Aramis mâchonnait des vers, et Porthos s'arrachait de temps en temps quelques poils de moustache en signe de désespoir.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][

Epoch 3 | Loss 1.112 | Acc 0.274: : 2632it [18:50,  2.33it/s]
Epoch 3 | Loss 1.104: 100%|██████████| 329/329 [00:42<00:00,  7.69it/s]


Epoch 4Validation Loss: 1.182872290081951
Validation Loss Decreased(12955.812382--->12456.828087) 	 Saving The Model


0it [00:00, ?it/s]

Auent au  caussons  ils soéfiquèient doutours dotre pmmirabion, euend ious aoivrînions d lravers les clrteaux dù erte,des ponoets de laur cieiduuetrque.[EOS][EOS]..t...[EOS].....[EOS]..[EOS]...[EOS]t[EOS][EOS]..[EOS]...t..tt.[EOS] ..t.ttt.[EOS] Quant aux poissons, ils provoquaient toujours notre admiration, quand nous surprenions à travers les panneaux ouverts les secrets de leur vie aquatique.[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]


Epoch 4 | Loss 1.025 | Acc 0.342: : 1000it [07:09,  2.33it/s]

DontellDolDlul.[EOS][EOS]l[EOS].[EOS].[EOS]....[EOS][EOS][EOS][EOS][EOS].l[EOS]...[EOS][EOS].....[EOS]l.[EOS].[EOS]....t...[EOS][EOS]...l..[EOS]..h....[EOS].[EOS][EOS][EOS].[EOS].l......o[EOS]...ss....l[EOS][EOS]l.[EOS].[EOS][EOS][EOS][EOS].[EOS][EOS]......[EOS][EOS]h...[EOS].s.[EOS]..l.[EOS]..[EOS]...[EOS].[EOS][EOS][EOS]...[EOS].t.[EOS]..[EOS].[EOS][EOS]...[EOS]..h[EOS]...l..[EOS]...[EOS].lh[EOS][EOS][EOS]...l.l[EOS]l....l Daniel De Foë[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PA

Epoch 4 | Loss 1.127 | Acc 0.290: : 2000it [14:18,  2.33it/s]

Pr, qourquoi noulait-el donc lveir de tafint?[EOS]..........-..........--.......-.....................-............-..........-.--.................................................-.....-..-............... Or, pourquoi voulait-il donc avoir cet agent?[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]

Epoch 4 | Loss 1.047 | Acc 0.263: : 2216it [15:51,  2.33it/s]


KeyboardInterrupt: 

In [36]:
def gen(model: nn.Module, sentence: str, max_len: int, vocab: Vocab, device: torch.device):
    model.eval()
    sos_token = torch.tensor(vocab.encode("[SOS]"), dtype=torch.int64).to(device)
    eos_token = torch.tensor(vocab.encode("[EOS]"), dtype=torch.int64).to(device)
    pad_token = torch.tensor(vocab.encode("[PAD]"), dtype=torch.int64).to(device)
    
    src_input = torch.cat([sos_token, torch.tensor(vocab.encode(sentence), dtype=torch.int64).to(device), eos_token, pad_token.repeat(max_len - len(vocab.encode(sentence)) - 2)])
    src_mask = (src_input != pad_token).unsqueeze(0).int() == 1
    
    tgt_input = sos_token
    while tgt_input[-1] != eos_token and len(tgt_input) < max_len:
        tgt_mask = dataset_train._causal_mask(tgt_input.shape[0]) == 1
        src_input, tgt_input, src_mask, tgt_mask = src_input.to(device), tgt_input.to(device), src_mask.to(device), tgt_mask.to(device)
        print(src_input.shape, tgt_input.shape, src_mask.shape, tgt_mask.shape)
        logits = model(src_input.unsqueeze(0), tgt_input.unsqueeze(0), src_mask.unsqueeze(0), tgt_mask.unsqueeze(0))
        pred = F.softmax(logits, dim=-1)
        print(pred.shape)
        print(pred[:, -1, :].argmax(dim=-1))
        # next_token = torch.multinomial(pred[:,-1,:], num_samples=1)
        tgt_input = torch.cat([tgt_input, pred[:,-1,:].argmax(dim = -1).to(device)])
        print(tgt)
    print(vocab.decode(tgt_input.tolist()))
gen(model, "I am a Student", 200, global_vocab, device)

torch.Size([200]) torch.Size([1]) torch.Size([1, 200]) torch.Size([1, 1, 1])
torch.Size([1, 1, 163])
tensor([41], device='cuda:0')
tensor([[161,  43,  66,  ..., 160, 160, 160],
        [161,  46,  75,  ..., 160, 160, 160],
        [161,  51,  69,  ..., 160, 160, 160],
        ...,
        [161,  36,  80,  ..., 160, 160, 160],
        [161,  93,   0,  ..., 160, 160, 160],
        [161,  43,  66,  ..., 160, 160, 160]], device='cuda:0')
torch.Size([200]) torch.Size([2]) torch.Size([1, 200]) torch.Size([1, 2, 2])
torch.Size([1, 2, 163])
tensor([66], device='cuda:0')
tensor([[161,  43,  66,  ..., 160, 160, 160],
        [161,  46,  75,  ..., 160, 160, 160],
        [161,  51,  69,  ..., 160, 160, 160],
        ...,
        [161,  36,  80,  ..., 160, 160, 160],
        [161,  93,   0,  ..., 160, 160, 160],
        [161,  43,  66,  ..., 160, 160, 160]], device='cuda:0')
torch.Size([200]) torch.Size([3]) torch.Size([1, 200]) torch.Size([1, 3, 3])
torch.Size([1, 3, 163])
tensor([0], device='cud

In [None]:
model = model.to(device)
model = torch.load("saved_model.pth")
torch.onnx.export(
    model.cpu(),
    torch.zeros((1, 200,), dtype=torch.long),
    'harry_potter.onnx',
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size', 1: 'context'}, 'output': {0: 'batch_size', 1: 'context'}},
)