In [26]:
import json
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from new_transformer import Transformer
import torch.nn as nn
from torch.optim import AdamW
import tqdm
import torch.nn.functional as F
import numpy as np

In [27]:
class BPE_Tokenizer:
    def _init_(self):
        self.vocab = set()
        self.token_to_index = {}
        self.index_to_token = {}

    @staticmethod
    def get_stats(vocab):
        pairs = {}
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pair = (symbols[i], symbols[i+1])
                if pair in pairs:
                    pairs[pair] += freq
                else:
                    pairs[pair] = freq
        return pairs

    @staticmethod
    def merge_vocab(pair, v_in):
        v_out = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)
        for word in v_in:
            w_out = word.replace(bigram, replacement)
            v_out[w_out] = v_in[word]
        return v_out

    @staticmethod
    def get_vocab(text):
        vocab = {}
        for word in text.split():
            word = ' '.join(list(word)) +  ' </w>'
            if word in vocab:
                vocab[word] += 1
            else:
                vocab[word] = 1
        return vocab

    def train_until(self, text, vocab_size):
        vocab = self.get_vocab(text)
        self.vocab = set(word for word in vocab for word in word.split())
        # Check if the initial vocabulary size is less than the desired size
        while len(self.vocab) < vocab_size:
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.vocab = set(word for word in vocab for word in word.split())

        # Add special tokens
        self.vocab.add('</u>')
        self.vocab.add('[]')

        self.build_index()

    def train(self, text, num_merges):
            vocab = self.get_vocab(text)
            for i in range(num_merges):
                pairs = self.get_stats(vocab)
                if not pairs:
                    break
                best = max(pairs, key=pairs.get)
                vocab = self.merge_vocab(best, vocab)

            self.vocab = set(word for word in vocab for word in word.split())

            # Add special tokens
            self.vocab.add('</u>')

            self.build_index()

    def build_index(self):
        # existing code
        self.token_to_index = {token: index for index, token in enumerate(self.vocab)}
        self.index_to_token = {index: token for token, index in self.token_to_index.items()}

    def tokenize(self, text):
        tokens = []
        for word in text.split():
            subwords = self.get_subwords(word + '</w>')
            tokens.extend(self.token_to_index.get(sw, self.token_to_index['</u>']) for sw in subwords)
        return tokens

    def get_subwords(self, word):
        subwords = []
        while word:
            subword = self.find_longest_subword(word)
            if subword is None:
                subwords.append('</u>')
                break
            subwords.append(subword)
            word = word[len(subword):]
        return subwords

    def find_longest_subword(self, word):
        for i in range(len(word), 0, -1):
            if word[:i] in self.vocab:
                return word[:i]
        return None

    def detokenize(self, token_ids):
        words = []
        current_word = ''
        for token_id in token_ids:
            token = self.index_to_token.get(token_id, '</u>')
            if token == '</w>':
                words.append(current_word)
                current_word = ''
            else:
                current_word += token
        words.append(current_word)
        return ' '.join(words).replace('</w>', ' ')
    
    def add_special_tokens(self, special_tokens):
        for token in special_tokens:
            self.vocab.add(token)
        self.build_index()

    def save_vocab(self, file_path):
        with open(file_path, 'w') as f:
            json.dump(self.token_to_index, f)

    def load_vocab(self, file_path):
        with open(file_path, 'r') as f:
            self.token_to_index = json.load(f)
            self.index_to_token = {int(index): token for token, index in self.token_to_index.items()}
            self.vocab = set(self.token_to_index.keys())
    
    def __len__(self):
        return len(self.token_to_index)

In [28]:
import datasets
import tqdm as tqdm
import pandas as pd
dataset = datasets.load_dataset("opus_books", "en-fr")
en_content = ''.join([dataset["train"][i]["translation"]["en"] for i in range(dataset.num_rows["train"])])
fr_content = ''.join([dataset["train"][i]["translation"]["fr"] for i in range(dataset.num_rows["train"])])
global_content = en_content + fr_content

tokenizer = BPE_Tokenizer()

num_merges =  90
untils = []
total_tokens = []
avg_token_len = []
avg_token_std = []

def clean(text: str) -> str:
    text = text.replace('.', ' . ')
    text = text.replace(',', ' , ')
    text = text.replace('!', ' ! ')
    text = text.replace('?', ' ? ')
    text = text.replace(':', ' : ')
    text = text.replace(';', ' ; ')
    text = text.replace(')', '')
    text = text.replace('(', '')
    text = text.replace('@', '')
    text = text.replace('|', '')
    text = text.replace(']', '')
    text = text.replace('[', '')
    text = text.replace('~', '')
    text = text.replace('^', '')
    text = text.replace('<', '')
    text = text.replace('>', '')
    text = text.replace('&', '')
    text = text.replace('{', '')
    text = text.replace('}', '')
    text = text.replace('+', '')
    # text = text.replace('-', '')
    text = text.replace('tititi', '')
    text = text.replace('orerer', '')
    text = text.replace('errero', '')
    text = text.replace('\u007f', '')
    text = text.replace('_', '')
    text = text.replace('%', '')
    text = text.replace('$', '')
    text = text.replace('\\', '')
    text = text.replace('=', '')
    text = text.replace('#', '')
    text = text.replace(';', '')
    text = text.replace(':', '')
    text = text.encode("ascii", errors="ignore").decode()
    return text


global_content = clean(global_content)
print("text length:", len(global_content))
print("Training...")
tokenizer.train_until(global_content, 1200)
# tokenizer.load_vocab('token_to_index.json')

print("Vocab:", tokenizer.token_to_index)
print("vocab size:", len(tokenizer))


tokenizer.add_special_tokens(['<PAD></w>', '<SOS></w>','<EOS></w>'])


# Saving token_to_index mapping
tokenizer.save_vocab('en-fr.json')

text length: 31353581
Training...
Vocab: {'loren': 0, 'sthe</w>': 1, 'silis</w>': 2, 'x</w>': 3, 'ored</w>': 4, 'entiles</w>': 5, 'rento</w>': 6, 'iline</w>': 7, 'renou': 8, 'chis</w>': 9, 'atin': 10, 'ano</w>': 11, 'lilo': 12, 'oreth': 13, 'esun</w>': 14, 'al</w>': 15, 'aroi': 16, 'ly</w>': 17, 'ori</w>': 18, 'pout</w>': 19, 'ereti': 20, 'orel</w>': 21, 'norent</w>': 22, 'aporis</w>': 23, 'poreu': 24, 'arenti': 25, 'poinon</w>': 26, 'loron': 27, 'rine</w>': 28, 'atine</w>': 29, 'areri': 30, 'orine</w>': 31, 'orent</w>': 32, 'arente</w>': 33, 'entil</w>': 34, 'alo': 35, 'one</w>': 36, 'dinous</w>': 37, 'he</w>': 38, 'our</w>': 39, 'relon': 40, ',</w>': 41, 'erec': 42, 'loit</w>': 43, 'areth</w>': 44, 'nono': 45, 'reles</w>': 46, 'eles</w>': 47, 'arech</w>': 48, 'gh': 49, 'w': 50, 'recom': 51, 'sun</w>': 52, 'ail</w>': 53, 'rec': 54, 'ecom': 55, '/': 56, 'ghis</w>': 57, 'esil</w>': 58, 'areti': 59, 'y</w>': 60, 'ne</w>': 61, 'enof</w>': 62, 'nor': 63, 'dit</w>': 64, 'ari</w>': 65, 'orim

In [29]:
# Tokenizing
token_ids = tokenizer.tokenize("<PAD> <SOS> Je suis jeune <EOS>")
print("Token IDs:", token_ids)
print("num Tokens:", len(token_ids))

# Detokenizing
detokenized_text = tokenizer.detokenize(token_ids)
print("Detokenized Text:", detokenized_text)

Token IDs: [10, 645, 179, 315, 1207, 130, 973, 248, 399]
num Tokens: 9
Detokenized Text: <PAD> <SOS> Je suis jeune <EOS> 


In [30]:
tokenizer.tokenize("<SOS>")

[645]

In [31]:
class BilingualDataset(Dataset):
    def __init__(self, dataset, tokenizer: BPE_Tokenizer, seq_len) -> None:
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        
        self.sos_idx = torch.tensor(self.tokenizer.tokenize("<SOS>"), dtype = torch.int64)
        self.eos_idx = torch.tensor(self.tokenizer.tokenize("<EOS>"), dtype = torch.int64)
        self.pad_idx = torch.tensor(self.tokenizer.tokenize("<PAD>"), dtype = torch.int64)
        
    def _causal_mask(self, seq_len: int) -> torch.Tensor:
        mask = torch.ones(1, seq_len, seq_len, dtype=torch.bool)
        mask = torch.tril(mask, diagonal=0)
        return mask
        
    def __len__(self) -> int:
        return len(self.dataset)
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        src = torch.tensor(self.tokenizer.tokenize(clean(self.dataset[idx]["en"])), dtype = torch.int64)
        tgt = torch.tensor(self.tokenizer.tokenize(clean(self.dataset[idx]["fr"])), dtype = torch.int64)
        if self.seq_len - len(src) - 2 < 0 or self.seq_len - len(tgt) - 1 < 0:
            src = src[:self.seq_len - 2]
            tgt = tgt[:self.seq_len - 1]
        enc_num_pad = self.seq_len - len(src) - 2
        dec_num_pad = self.seq_len - len(tgt) - 1
        input_src = torch.cat([self.sos_idx, src, self.eos_idx, self.pad_idx.repeat(enc_num_pad)])
        input_tgt = torch.cat([self.sos_idx, tgt, self.pad_idx.repeat(dec_num_pad)])
        input_label = torch.cat([tgt,self.eos_idx, self.pad_idx.repeat(dec_num_pad)])
        return (
            input_src, 
            input_tgt, 
            input_label, 
            (input_src!=self.pad_idx).unsqueeze(0).unsqueeze(0).int() == 1,
            (input_tgt!=self.pad_idx).unsqueeze(0).unsqueeze(0).int() & self._causal_mask(self.seq_len) == 1,
            )

In [32]:
dataset = [x for x in dataset["train"]["translation"] if len(tokenizer.tokenize(x["en"])) < 100 and len(tokenizer.tokenize(x["fr"]))  < 100]

In [33]:
train_set = dataset[:int(len(dataset)*0.8)]
val_set = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_set = dataset[int(len(dataset)*0.9):]

In [34]:
# max_seq_len = 0
# sum = 0
# for i in range(len(train_set)):
#     src = train_set[i]["en"]
#     tgt = train_set[i]["fr"]
#     max_seq_len = max(max_seq_len, len(tokenizer.tokenize(src)), len(tokenizer.tokenize(tgt)))
#     sum+=len(tokenizer.tokenize(src))
# mean = sum/len(train_set)
# max_seq_len, mean

In [35]:
dataset_train = BilingualDataset(train_set, tokenizer, 100)
val_set = BilingualDataset(val_set, tokenizer, 100)
test_set = BilingualDataset(test_set, tokenizer, 100)

In [36]:
len(dataset_train), len(val_set), len(test_set)

(83537, 10442, 10443)

In [37]:
input_src, input_tgt, input_label, src_mask, tgt_mask = dataset_train[0]

In [38]:
input_tgt.shape, input_src.shape, input_label.shape, src_mask.shape, tgt_mask.shape

(torch.Size([100]),
 torch.Size([100]),
 torch.Size([100]),
 torch.Size([1, 1, 100]),
 torch.Size([1, 100, 100]))

In [39]:
device = torch.device("cuda")

In [40]:
model = Transformer(
    vocab_size=len(tokenizer),
    n_head=6,
    embed_size=600,
    context_length=100,
    dropout=0.1,
    num_layers=6,
    device=device,
)
model = model.to(device)

In [41]:
batch_size = 32
num_epochs = 5
lr = 3e-4

In [42]:
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True)

In [43]:
optimizer = AdamW(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer, T_max=lr)

In [44]:
def gen(model: nn.Module, sentence: str, max_len: int, vocab: BPE_Tokenizer, device: torch.device):
    model.eval()
    sos_token = torch.tensor(vocab.tokenize("<SOS>"), dtype=torch.int64).to(device)
    eos_token = torch.tensor(vocab.tokenize("<EOS>"), dtype=torch.int64).to(device)
    pad_token = torch.tensor(vocab.tokenize("<PAD>"), dtype=torch.int64).to(device)
    # print(sos_token, eos_token, pad_token)
    
    src_input = torch.cat([sos_token, torch.tensor(vocab.tokenize(sentence), dtype=torch.int64).to(device), eos_token, pad_token.repeat(max_len - len(vocab.tokenize(sentence)) - 2)])
    src_mask = (src_input != pad_token).unsqueeze(0).int() == 1
    
    tgt_input = sos_token
    while tgt_input[-1] != eos_token and len(tgt_input) < max_len:
        tgt_mask = dataset_train._causal_mask(tgt_input.shape[0]) == 1
        src_input, tgt_input, src_mask, tgt_mask = src_input.to(device), tgt_input.to(device), src_mask.to(device), tgt_mask.to(device)
        # print(src_input.unsqueeze(0).shape, tgt_input.unsqueeze(0).shape, src_mask.unsqueeze(0).shape, tgt_mask.unsqueeze(0).shape)
        # print(src_input.unsqueeze(0).dtype, tgt_input.unsqueeze(0).dtype, src_mask.unsqueeze(0).dtype, tgt_mask.unsqueeze(0).dtype)

        logits = model(src_input.unsqueeze(0), tgt_input.unsqueeze(0), src_mask.unsqueeze(0), tgt_mask.unsqueeze(0))
        pred = F.softmax(logits, dim=-1)
        # print(pred.shape)
        # print(pred[:, -1, :].argmax(dim=-1))
        # next_token = torch.multinomial(pred[:,-1,:], num_samples=1)
        tgt_input = torch.cat([tgt_input, pred[:,-1,:].argmax(dim = -1).to(device)])
        # print(tgt_input)
    print(vocab.detokenize(tgt_input.tolist()))
gen(model, "I am a student", 100, tokenizer, device)

<SOS> rimetin rorionorous norinIt cinorinase "arin reche o ast erinorelirime reroorelirime esting ast erinorelirime lois eremares asiliast erinorelirime lois eremares asiliast erinorelirime lois oreces enomerinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares ailinatile come arfpo arin oit ais lored eshais lored 


In [45]:
min_valid_loss = np.inf
for _ in range(num_epochs):
    model.train()
    with tqdm.tqdm(enumerate(train_loader)) as pbar:
        for idx, (src, tgt, label, src_mask, tgt_mask) in pbar:
            src = src.to(device)
            tgt = tgt.to(device)
            label = label.to(device)
            src_mask = src_mask.to(device)
            tgt_mask = tgt_mask.to(device)
            output = model(src, tgt, src_mask, tgt_mask)
            B, T, C = output.shape
            if idx%1000 == 0:
                print(tokenizer.detokenize(src[0].tolist()))
                print(tokenizer.detokenize(output.argmax(dim=-1)[0].tolist()))
                print(tokenizer.detokenize(label[0].tolist()))
                print("test génération: I am a student ->")
                gen(model, "I am a student", 100, tokenizer, device)
                print("\n\n")
            loss = F.cross_entropy(output.view(B * T, C), label.view(B * T), ignore_index=tokenizer.tokenize("[PAD]")[0])
            acc = (output.argmax(dim=-1) == label).float().mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            pbar.set_description(f"Epoch {_} | Loss {loss.item():.3f} | Acc {acc.item():.3f}")
            # break
    valid_loss = 0
    model.eval()
    with tqdm.tqdm(val_loader) as pbar:
        with torch.no_grad():
            for src, tgt, label, src_mask, tgt_mask in pbar:
                src = src.to(device)
                tgt = tgt.to(device)
                label = label.to(device)
                src_mask = src_mask.to(device)
                tgt_mask = tgt_mask.to(device)
                
                output = model(src, tgt, src_mask, tgt_mask)
                loss = F.cross_entropy(output.view(B * T, C), label.view(B * T), ignore_index=tokenizer.tokenize("[PAD]")[0])
                valid_loss += loss.item()*src.shape[0]
                pbar.set_description(f"Epoch {_} | Loss {loss.item():.3f}")
    print(f'Epoch {_+1}Validation Loss: {valid_loss / len(val_set)}')
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), 'saved_model_tokenizer_3.pth')

0it [00:00, ?it/s]

<SOS> All his life seemed to pass before his eyes . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
rimece res ilorPath pou eth alereshentil dinou enomaret eting sthe ach ron enomthe tin sufenometing shis tin pou rento aret aret aret rento rento aret rento red rento aret aret red aret aret aret aret pou aret aret rento aret aret pou pou esto pou aret rento aret pou rento rento aret aret aret aret rento rento aret aret aret aret aret rento aret aret aret aret rento rento oun aret aret aret aret rento rento aret tin aret aret aret rento aret aret rento rento rento aret aret aret 
Toute sa vie

Epoch 0 | Loss 7.463 | Acc 0.000: : 1it [00:00,  1.33it/s]

<SOS> rimetin rorionorous norinIt cinorinase "arin reche o ast erinorelirime reroorelirime esting ast erinorelirime lois eremares asiliast erinorelirime lois eremares asiliast erinorelirime lois oreces enomerinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares asiliast erinorelilis ares ailinatile come arfpo arin oit ais lored eshais lored 





Epoch 0 | Loss 3.299 | Acc 0.560: : 3it [00:01,  2.71it/s]

Epoch 0 | Loss 1.310 | Acc 0.685: : 1000it [04:28,  3.69it/s]

<SOS> No crewmen . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Nin i meume . m'aui tin. <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <

Epoch 0 | Loss 1.068 | Acc 0.729: : 2000it [08:59,  3.69it/s]

<SOS> It was of interest to me to hear these men , who were spending their lives in fighting against our neighbours , discussing their character and ways . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
I'vendqurand mstrt , es rr, es mommes , qui me ensaient teur pie , prtre , tpix s , et psant utes de pouttre de de dts s . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Je prenais grand intrt couter ces hommes , qui passaient leur vie combattre nos voisins , en discuter le caractre et les mthodes . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
test génération: I am a student ->
<SOS> 

Epoch 0 | Loss 1.005 | Acc 0.752: : 2610it [11:44,  3.70it/s]
Epoch 0 | Loss 0.897: 100%|██████████| 326/326 [00:28<00:00, 11.43it/s]


Epoch 1Validation Loss: 1.0352328117266854
Validation Loss Decreased(inf--->10809.901020) 	 Saving The Model


0it [00:00, ?it/s]

<SOS> "My Lord ! " cried dArtagnan , enlightened by a sudden idea , "my Lord ! Pardon me , monsieur , but you are not--" <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Mon d , s'cria d'Artagnan , s i ent , 'un fde , rite , mord , sendon, mais ceur , ais m--que vous touez v. <EOS> . <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Milord ! s'cria d'Artagnan illumin d'une ide subite , Milord ! pardon , monsieur mais est-ce que vous seriez . . . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
test génération: I am a stud

Epoch 1 | Loss 1.057 | Acc 0.742: : 1000it [04:30,  3.70it/s]

<SOS> "Discussing the gipsy , I daresay . " <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Je le us, 'il peut d, la cuhmienne , <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
--

Epoch 1 | Loss 0.960 | Acc 0.759: : 2000it [09:00,  3.70it/s]

<SOS> The engineer examined this black granite . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
L'ingnieur ebscva cgrand de . r . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD

Epoch 1 | Loss 0.938 | Acc 0.764: : 2610it [11:45,  3.70it/s]
Epoch 1 | Loss 0.880: 100%|██████████| 326/326 [00:28<00:00, 11.45it/s]


Epoch 2Validation Loss: 0.9074117751487757
Validation Loss Decreased(10809.901020--->9475.193756) 	 Saving The Model


0it [00:00, ?it/s]

<SOS> "I fear so . " <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Jh bien . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>

Epoch 2 | Loss 0.921 | Acc 0.762: : 1000it [04:30,  3.70it/s]

<SOS> The men had done all that men could do . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Pa mises ers avaient t ait tout ce quon il s avvaient . aire . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <

Epoch 2 | Loss 0.814 | Acc 0.791: : 2000it [09:00,  3.70it/s]

<SOS> "Beat , is he ? " answered Belcher . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-omyez-vous ,  rBelcher . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 

Epoch 2 | Loss 0.631 | Acc 0.834: : 2610it [11:45,  3.70it/s]
Epoch 2 | Loss 0.684: 100%|██████████| 326/326 [00:28<00:00, 11.38it/s]


Epoch 3Validation Loss: 0.847887329981259
Validation Loss Decreased(9475.193756--->8853.639500) 	 Saving The Model


0it [00:00, ?it/s]

<SOS> Then , when the operation was over , we burned every trace of our stay on that islet , which if I could have , I'd have blown up . " <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Cenis , qu'avbration , ait re ine , si mav et t 'ouu , ait te la rav, la tre prosage , r lethb, nous amai rais pait poutant nopnous l'esais du . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Puis , l'opration termine , le feu a dtruit toute trace de notre passage sur cet lot que j'aurais fait sauter , si je l'avais pu . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
test génération: I am a student ->
<SOS> Je suis un co

Epoch 3 | Loss 0.677 | Acc 0.822: : 1000it [04:30,  3.70it/s]

<SOS> Ces rcits , ces rumeurs , prirent bientt un corps et , a force detre confirms cent fois , firent comprendre ce qui en tait . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
That cw, ware nbns , ook , ch banc, sop, whthere coudwly ned , conalfoudwly ded , whntil theiwpluve . em elves . to a meeorit. at. <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
These tales and rumours took substance and shape , and were corroborated and re-corroborated , until they resolved themselves into a definite name . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
test génération: I am a student ->
<SOS> Je suis un cuisine <EOS> 





Epoch 3 | Loss 0.713 | Acc 0.817: : 2000it [09:00,  3.70it/s]

<SOS> As for the mother , he could not tell . . . He gave me long explanations as to the one friend of the family . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Quant aumre , il ne 'ffirma pen de l ne ponnait lonlongues amxplications aume des x foul des i des la famille . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Quant la mre , il naffirmait rien Il me donna de longues explications comme au seul ami de la famille . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 

Epoch 3 | Loss 0.687 | Acc 0.827: : 2610it [11:44,  3.70it/s]
Epoch 3 | Loss 0.829: 100%|██████████| 326/326 [00:28<00:00, 11.41it/s]


Epoch 4Validation Loss: 0.823269912918462
Validation Loss Decreased(8853.639500--->8596.584431) 	 Saving The Model


0it [00:00, ?it/s]

<SOS> And he looked at her fixedly , while in his hand he held two long papers that he slid between his nails . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
En , se ttdra dorxement dpout ce trdil on lonain il eux longuil liers quil laisait srsser sre ses mgles . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Et il la considrait fixement , tout en tenant sa main deux longs papiers quil faisait glisser entre ses ongles . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <

Epoch 4 | Loss 0.664 | Acc 0.825: : 1000it [04:30,  3.70it/s]

<SOS> A definite drowsiness overcame us . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Une fotaine inomme lennv'enpara de nous . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PA

Epoch 4 | Loss 0.651 | Acc 0.831: : 2000it [09:00,  3.70it/s]

<SOS> What sperm whales you're handing us ! <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Quelle sachalots de vhalres ! <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <

Epoch 4 | Loss 0.795 | Acc 0.790: : 2610it [11:45,  3.70it/s]
Epoch 4 | Loss 0.794: 100%|██████████| 326/326 [00:28<00:00, 11.37it/s]


Epoch 5Validation Loss: 0.806047816770649
Validation Loss Decreased(8596.584431--->8416.751303) 	 Saving The Model


In [46]:
device = torch.device("cpu")
model = Transformer(
    vocab_size=len(tokenizer),
    n_head=6,
    embed_size=600,
    context_length=100,
    dropout=0.1,
    num_layers=6,
    device=device,
)
model.load_state_dict(torch.load('saved_model_tokenizer_3.pth'))
model

Transformer(
  (encoder): Encoder(
    (embed): Embedding(1210, 600)
    (pos_encoding): SinusoidEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x Block(
        (qkv): Linear(in_features=600, out_features=1800, bias=False)
        (mha): MultiHeadAttention(
          (fc_out): Linear(in_features=600, out_features=600, bias=False)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (fc_dropout): Dropout(p=0.1, inplace=False)
        )
        (ffwd): FeedForward(
          (0): Linear(in_features=600, out_features=2400, bias=False)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2400, out_features=600, bias=False)
          (3): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (decoder): Decoder(
    (embed): Embedding(1210, 600)
    (pos_encoding)

In [47]:

def gen(model: nn.Module, sentence: str, max_len: int, vocab: BPE_Tokenizer, device: torch.device):
    model.eval()
    sos_token = torch.tensor(vocab.tokenize("<SOS>"), dtype=torch.int64).to(device)
    eos_token = torch.tensor(vocab.tokenize("<EOS>"), dtype=torch.int64).to(device)
    pad_token = torch.tensor(vocab.tokenize("<PAD>"), dtype=torch.int64).to(device)
    print(sos_token, eos_token, pad_token)
    
    src_input = torch.cat([sos_token, torch.tensor(vocab.tokenize(sentence), dtype=torch.int64).to(device), eos_token, pad_token.repeat(max_len - len(vocab.tokenize(sentence)) - 2)])
    src_mask = (src_input != pad_token).unsqueeze(0).int() == 1
    
    tgt_input = sos_token
    while tgt_input[-1] != eos_token and len(tgt_input) < max_len:
        tgt_mask = dataset_train._causal_mask(tgt_input.shape[0]) == 1
        src_input, tgt_input, src_mask, tgt_mask = src_input.to(device), tgt_input.to(device), src_mask.to(device), tgt_mask.to(device)
        print(src_input.unsqueeze(0).shape, tgt_input.unsqueeze(0).shape, src_mask.unsqueeze(0).shape, tgt_mask.unsqueeze(0).shape)
        print(src_input.unsqueeze(0).dtype, tgt_input.unsqueeze(0).dtype, src_mask.unsqueeze(0).dtype, tgt_mask.unsqueeze(0).dtype)

        logits = model(src_input.unsqueeze(0), tgt_input.unsqueeze(0), src_mask.unsqueeze(0), tgt_mask.unsqueeze(0))
        pred = F.softmax(logits, dim=-1)
        print(pred.shape)
        print(pred[:, -1, :].argmax(dim=-1))
        # next_token = torch.multinomial(pred[:,-1,:], num_samples=1)
        tgt_input = torch.cat([tgt_input, pred[:,-1,:].argmax(dim = -1).to(device)])
        print(tgt_input)
    print(vocab.detokenize(tgt_input.tolist()))
gen(model, "I am a student", 100, tokenizer, device)

tensor([645]) tensor([399]) tensor([10])
torch.Size([1, 100]) torch.Size([1, 1]) torch.Size([1, 1, 100]) torch.Size([1, 1, 1, 1])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 1, 1210])
tensor([179])
tensor([645, 179])
torch.Size([1, 100]) torch.Size([1, 2]) torch.Size([1, 1, 100]) torch.Size([1, 1, 2, 2])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 2, 1210])
tensor([315])
tensor([645, 179, 315])
torch.Size([1, 100]) torch.Size([1, 3]) torch.Size([1, 1, 100]) torch.Size([1, 1, 3, 3])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 3, 1210])
tensor([1207])
tensor([ 645,  179,  315, 1207])
torch.Size([1, 100]) torch.Size([1, 4]) torch.Size([1, 1, 100]) torch.Size([1, 1, 4, 4])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 4, 1210])
tensor([130])
tensor([ 645,  179,  315, 1207,  130])
torch.Size([1, 100]) torch.Size([1, 5]) torch.Size([1, 1, 100]) torch.Size([1, 1, 5, 5])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([