In [1]:
import json
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from new_transformer import Transformer
import torch.nn as nn
from torch.optim import AdamW
import tqdm
import torch.nn.functional as F
import numpy as np

In [2]:
class BPE_Tokenizer:
    def _init_(self):
        self.vocab = set()
        self.token_to_index = {}
        self.index_to_token = {}

    @staticmethod
    def get_stats(vocab):
        pairs = {}
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pair = (symbols[i], symbols[i+1])
                if pair in pairs:
                    pairs[pair] += freq
                else:
                    pairs[pair] = freq
        return pairs

    @staticmethod
    def merge_vocab(pair, v_in):
        v_out = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)
        for word in v_in:
            w_out = word.replace(bigram, replacement)
            v_out[w_out] = v_in[word]
        return v_out

    @staticmethod
    def get_vocab(text):
        vocab = {}
        for word in text.split():
            word = ' '.join(list(word)) +  ' </w>'
            if word in vocab:
                vocab[word] += 1
            else:
                vocab[word] = 1
        return vocab

    def train_until(self, text, vocab_size):
        vocab = self.get_vocab(text)
        self.vocab = set(word for word in vocab for word in word.split())
        # Check if the initial vocabulary size is less than the desired size
        while len(self.vocab) < vocab_size:
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.vocab = set(word for word in vocab for word in word.split())

        # Add special tokens
        self.vocab.add('</u>')
        self.vocab.add('[]')

        self.build_index()

    def train(self, text, num_merges):
            vocab = self.get_vocab(text)
            for i in range(num_merges):
                pairs = self.get_stats(vocab)
                if not pairs:
                    break
                best = max(pairs, key=pairs.get)
                vocab = self.merge_vocab(best, vocab)

            self.vocab = set(word for word in vocab for word in word.split())

            # Add special tokens
            self.vocab.add('</u>')

            self.build_index()

    def build_index(self):
        # existing code
        self.token_to_index = {token: index for index, token in enumerate(self.vocab)}
        self.index_to_token = {index: token for token, index in self.token_to_index.items()}

    def tokenize(self, text):
        tokens = []
        for word in text.split():
            subwords = self.get_subwords(word + '</w>')
            tokens.extend(self.token_to_index.get(sw, self.token_to_index['</u>']) for sw in subwords)
        return tokens

    def get_subwords(self, word):
        subwords = []
        while word:
            subword = self.find_longest_subword(word)
            if subword is None:
                subwords.append('</u>')
                break
            subwords.append(subword)
            word = word[len(subword):]
        return subwords

    def find_longest_subword(self, word):
        for i in range(len(word), 0, -1):
            if word[:i] in self.vocab:
                return word[:i]
        return None

    def detokenize(self, token_ids):
        words = []
        current_word = ''
        for token_id in token_ids:
            token = self.index_to_token.get(token_id, '</u>')
            if token == '</w>':
                words.append(current_word)
                current_word = ''
            else:
                current_word += token
        words.append(current_word)
        return ' '.join(words).replace('</w>', ' ')
    
    def add_special_tokens(self, special_tokens):
        for token in special_tokens:
            self.vocab.add(token)
        self.build_index()

    def save_vocab(self, file_path):
        with open(file_path, 'w') as f:
            json.dump(self.token_to_index, f)

    def load_vocab(self, file_path):
        with open(file_path, 'r') as f:
            self.token_to_index = json.load(f)
            self.index_to_token = {int(index): token for token, index in self.token_to_index.items()}
            self.vocab = set(self.token_to_index.keys())
    
    def __len__(self):
        return len(self.token_to_index)

In [3]:
import datasets
import tqdm as tqdm
import pandas as pd
dataset = datasets.load_dataset("opus_books", "en-fr")
en_content = ''.join([dataset["train"][i]["translation"]["en"] for i in range(dataset.num_rows["train"])])
fr_content = ''.join([dataset["train"][i]["translation"]["fr"] for i in range(dataset.num_rows["train"])])
global_content = en_content + fr_content

tokenizer = BPE_Tokenizer()

num_merges =  90
untils = []
total_tokens = []
avg_token_len = []
avg_token_std = []

def clean(text: str) -> str:
    text = text.replace('.', ' . ')
    text = text.replace(',', ' , ')
    text = text.replace('!', ' ! ')
    text = text.replace('?', ' ? ')
    text = text.replace(':', ' : ')
    text = text.replace(';', ' ; ')
    text = text.replace(')', '')
    text = text.replace('(', '')
    text = text.replace('@', '')
    text = text.replace('|', '')
    text = text.replace(']', '')
    text = text.replace('[', '')
    text = text.replace('~', '')
    text = text.replace('^', '')
    text = text.replace('<', '')
    text = text.replace('>', '')
    text = text.replace('&', '')
    text = text.replace('{', '')
    text = text.replace('}', '')
    text = text.replace('+', '')
    # text = text.replace('-', '')
    text = text.replace('tititi', '')
    text = text.replace('orerer', '')
    text = text.replace('errero', '')
    text = text.replace('\u007f', '')
    text = text.replace('_', '')
    text = text.replace('%', '')
    text = text.replace('$', '')
    text = text.replace('\\', '')
    text = text.replace('=', '')
    text = text.replace('#', '')
    text = text.replace(';', '')
    text = text.replace(':', '')
    text = text.encode("ascii", errors="ignore").decode()
    return text


global_content = clean(global_content)
print("text length:", len(global_content))
print("Training...")
tokenizer.train_until(global_content, 1200)
# tokenizer.load_vocab('token_to_index.json')

print("Vocab:", tokenizer.token_to_index)
print("vocab size:", len(tokenizer))


tokenizer.add_special_tokens(['<PAD></w>', '<SOS></w>','<EOS></w>'])


# Saving token_to_index mapping
tokenizer.save_vocab('en-fr.json')

  from .autonotebook import tqdm as notebook_tqdm


text length: 31353581
Training...
Vocab: {'lorin': 0, 'reu': 1, 'oreti': 2, 'renou</w>': 3, 'lone</w>': 4, 'Y': 5, 'aril': 6, 'lut</w>': 7, 'apou': 8, 'nor</w>': 9, 'aloi': 10, 'rentime</w>': 11, 'E': 12, 'poin': 13, 'eron</w>': 14, 'po</w>': 15, 'aloi</w>': 16, 'arerent</w>': 17, 'noron': 18, 'arer': 19, 'sing</w>': 20, 'erev': 21, 'eto</w>': 22, 'ronon': 23, 'poing</w>': 24, 'dily</w>': 25, 'loro': 26, 'eroi</w>': 27, 'orom': 28, 'apori': 29, 'ili</w>': 30, 'x</w>': 31, 'f</w>': 32, 'enti': 33, 'orero': 34, 'esin</w>': 35, 'elor</w>': 36, 'lorer</w>': 37, 'aporous</w>': 38, 'ating</w>': 39, 'oreli': 40, 'areron': 41, 'ronon</w>': 42, 'orec': 43, 'etis</w>': 44, 'anon': 45, 'lin': 46, 'ois</w>': 47, 'ald</w>': 48, 'ti</w>': 49, 'etit</w>': 50, 'arom': 51, 'lilu': 52, 're</w>': 53, 'wily</w>': 54, 'loin': 55, 'rori': 56, 'stilit</w>': 57, 'elous</w>': 58, 'entis</w>': 59, 'erois</w>': 60, 'oun</w>': 61, 'sili': 62, 'rech': 63, 'atil</w>': 64, 'eren': 65, 'estily</w>': 66, 'atil': 67, '

In [4]:
# Tokenizing
token_ids = tokenizer.tokenize("<PAD> <SOS> Je suis jeune <EOS>")
print("Token IDs:", token_ids)
print("num Tokens:", len(token_ids))

# Detokenizing
detokenized_text = tokenizer.detokenize(token_ids)
print("Detokenized Text:", detokenized_text)

Token IDs: [1130, 455, 713, 803, 644, 754, 1166, 198, 38]
num Tokens: 9
Detokenized Text: <PAD> <SOS> Je suis jeune <EOS> 


In [5]:
tokenizer.tokenize("<SOS>")

[455]

In [6]:
class BilingualDataset(Dataset):
    def __init__(self, dataset, tokenizer: BPE_Tokenizer, seq_len) -> None:
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        
        self.sos_idx = torch.tensor(self.tokenizer.tokenize("<SOS>"), dtype = torch.int64)
        self.eos_idx = torch.tensor(self.tokenizer.tokenize("<EOS>"), dtype = torch.int64)
        self.pad_idx = torch.tensor(self.tokenizer.tokenize("<PAD>"), dtype = torch.int64)
        
    def _causal_mask(self, seq_len: int) -> torch.Tensor:
        mask = torch.ones(1, seq_len, seq_len, dtype=torch.bool)
        mask = torch.tril(mask, diagonal=0)
        return mask
        
    def __len__(self) -> int:
        return len(self.dataset)
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        src = torch.tensor(self.tokenizer.tokenize(clean(self.dataset[idx]["en"])), dtype = torch.int64)
        tgt = torch.tensor(self.tokenizer.tokenize(clean(self.dataset[idx]["fr"])), dtype = torch.int64)
        if self.seq_len - len(src) - 2 < 0 or self.seq_len - len(tgt) - 1 < 0:
            src = src[:self.seq_len - 2]
            tgt = tgt[:self.seq_len - 1]
        enc_num_pad = self.seq_len - len(src) - 2
        dec_num_pad = self.seq_len - len(tgt) - 1
        input_src = torch.cat([self.sos_idx, src, self.eos_idx, self.pad_idx.repeat(enc_num_pad)])
        input_tgt = torch.cat([self.sos_idx, tgt, self.pad_idx.repeat(dec_num_pad)])
        input_label = torch.cat([tgt,self.eos_idx, self.pad_idx.repeat(dec_num_pad)])
        return (
            input_src, 
            input_tgt, 
            input_label, 
            (input_src!=self.pad_idx).unsqueeze(0).unsqueeze(0).int() == 1,
            (input_tgt!=self.pad_idx).unsqueeze(0).unsqueeze(0).int() & self._causal_mask(self.seq_len) == 1,
            )

In [7]:
dataset = [x for x in dataset["train"]["translation"] if len(tokenizer.tokenize(x["en"])) < 100 and len(tokenizer.tokenize(x["fr"]))  < 100]

In [8]:
train_set = dataset[:int(len(dataset)*0.8)]
val_set = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_set = dataset[int(len(dataset)*0.9):]

In [9]:
# max_seq_len = 0
# sum = 0
# for i in range(len(train_set)):
#     src = train_set[i]["en"]
#     tgt = train_set[i]["fr"]
#     max_seq_len = max(max_seq_len, len(tokenizer.tokenize(src)), len(tokenizer.tokenize(tgt)))
#     sum+=len(tokenizer.tokenize(src))
# mean = sum/len(train_set)
# max_seq_len, mean

In [10]:
dataset_train = BilingualDataset(train_set, tokenizer, 100)
val_set = BilingualDataset(val_set, tokenizer, 100)
test_set = BilingualDataset(test_set, tokenizer, 100)

In [11]:
len(dataset_train), len(val_set), len(test_set)

(83537, 10442, 10443)

In [12]:
input_src, input_tgt, input_label, src_mask, tgt_mask = dataset_train[0]

In [13]:
input_tgt.shape, input_src.shape, input_label.shape, src_mask.shape, tgt_mask.shape

(torch.Size([100]),
 torch.Size([100]),
 torch.Size([100]),
 torch.Size([1, 1, 100]),
 torch.Size([1, 100, 100]))

In [14]:
device = torch.device("cuda")

In [15]:
model = Transformer(
    vocab_size=len(tokenizer),
    n_head=6,
    embed_size=600,
    context_length=100,
    dropout=0.1,
    num_layers=6,
    device=device,
)
model = model.to(device)

In [16]:
batch_size = 32
num_epochs = 10
lr = 3e-4

In [17]:
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True)

In [18]:
optimizer = AdamW(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer, T_max=lr)

In [19]:
min_valid_loss = np.inf
for _ in range(num_epochs):
    model.train()
    with tqdm.tqdm(enumerate(train_loader)) as pbar:
        for idx, (src, tgt, label, src_mask, tgt_mask) in pbar:
            src = src.to(device)
            tgt = tgt.to(device)
            label = label.to(device)
            src_mask = src_mask.to(device)
            tgt_mask = tgt_mask.to(device)
            output = model(src, tgt, src_mask, tgt_mask)
            B, T, C = output.shape
            if idx%1000 == 0:
                print(tokenizer.detokenize(src[0].tolist()))
                print(tokenizer.detokenize(output.argmax(dim=-1)[0].tolist()))
                print(tokenizer.detokenize(label[0].tolist()))
            loss = F.cross_entropy(output.view(B * T, C), label.view(B * T), ignore_index=tokenizer.tokenize("[PAD]")[0])
            acc = (output.argmax(dim=-1) == label).float().mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            pbar.set_description(f"Epoch {_} | Loss {loss.item():.3f} | Acc {acc.item():.3f}")
            # break
    valid_loss = 0
    model.eval()
    with tqdm.tqdm(val_loader) as pbar:
        with torch.no_grad():
            for src, tgt, label, src_mask, tgt_mask in pbar:
                src = src.to(device)
                tgt = tgt.to(device)
                label = label.to(device)
                src_mask = src_mask.to(device)
                tgt_mask = tgt_mask.to(device)
                
                output = model(src, tgt, src_mask, tgt_mask)
                loss = F.cross_entropy(output.view(B * T, C), label.view(B * T), ignore_index=tokenizer.tokenize("[PAD]")[0])
                valid_loss += loss.item()*src.shape[0]
                pbar.set_description(f"Epoch {_} | Loss {loss.item():.3f}")
    print(f'Epoch {_+1}Validation Loss: {valid_loss / len(val_set)}')
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), 'saved_model_tokenizer_3.pth')

0it [00:00, ?it/s]

<SOS> "The smile is very well , " said he , catching instantly the passing expression "but speak too . " <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
8porelenteloin`erestabronoRtilit aly renouare lily ino 2elon lorowine rilatharely eloinrenononering eroiapour `sis sis astis sinore wily entinacerile orrour onorietimrenouastily renourenoustily renourenourenouwte renourenourenourenourenourenourenouagrenourenourenourenourenourenourenourenourenourenourenourenoureli renourenouagagrenourenouagrenourenourenouapore apore tilorenouwrenouagrenourenourenourenourenourenourenourenourenourenoustily 
-- Voil un sourire qui me plat , dit-il , mais cela ne suffit pas parlez . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>

Epoch 0 | Loss 5.258 | Acc 0.535: : 2it [00:00,  2.99it/s]

Epoch 0 | Loss 1.163 | Acc 0.729: : 1001it [04:30,  3.65it/s]

<SOS> I drank for forgetfulness , and when I woke next day I was beside the count . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
JJe nis , mvrer , et je and je ne paperille ais , ceucant ain , je 'tais p'la c.  mme . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
"Je bus pour oublier , et quand je me rveillai le lendemain , j'tais dans le lit du comte . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <P

Epoch 0 | Loss 1.434 | Acc 0.667: : 2001it [09:04,  3.65it/s]

<SOS> All these crumbling masses were covered with an enamel polished by the action of underground fires , and they glistened under the stream of electric light from our beacon . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Ehtes ces coses , 'eres abl, et urertes , 'Acode sde oules 'accon de cur, ouvaivs es , et secemda aiaient , tctrcde cus de 'tueues de  camc. <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Toutes ces masses dsagrges , recouvertes d'un mail poli sous l'action des feux souterrains , resplendissaient au contact des jets lectriques du fanal . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 0 | Loss 1.015 | Acc 0.753: : 2610it [11:51,  3.67it/s]
Epoch 0 | Loss 1.093: 100%|██████████| 326/326 [00:29<00:00, 11.24it/s]


Epoch 1Validation Loss: 1.060110277564241
Validation Loss Decreased(inf--->11069.671518) 	 Saving The Model


Epoch 1 | Loss 1.018 | Acc 0.746: : 1it [00:00,  3.30it/s]

<SOS> Pain is always by the side of joy , the spondee by the dactyl . Master , I must relate to you the history of the Barbeau mansion . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Je memleur de la out jours la aumla Mue , je Pod ,  rs de  Pocre s de <EOS>  est ari re , je est aut que je vous ai sla ette mistoire .  mce . eny on y  . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
La douleur est toujours ct de la joie , le sponde auprs du dactyle . Mon matre , il faut que je vous conte cette histoire du logis Barbeau . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 1 | Loss 0.970 | Acc 0.751: : 1001it [04:34,  3.65it/s]

<SOS> Cyrus Harding and his companions could not understand it . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Cyrus Smith et ne es compagnons ne pardrent ent . es compagdre . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Cyrus Smith et ses compagnons rega

Epoch 1 | Loss 0.782 | Acc 0.805: : 2001it [09:08,  3.65it/s]

<SOS> "My friends , " I said , "we're in a serious predicament , but I'm counting on your courage and energy . " <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-a amis , dit s-je , mmtuation , prand e , et on je suprendr votre amourage , vr motre amegie . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Mes amis , dis-je , la situation est grave , mais je compte sur votre courage et sur votre nergie . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>

Epoch 1 | Loss 1.051 | Acc 0.736: : 2610it [11:54,  3.65it/s]
Epoch 1 | Loss 0.925: 100%|██████████| 326/326 [00:28<00:00, 11.44it/s]


Epoch 2Validation Loss: 0.9462678233820505
Validation Loss Decreased(11069.671518--->9880.928612) 	 Saving The Model


Epoch 2 | Loss 0.806 | Acc 0.795: : 1it [00:00,  3.35it/s]

<SOS> They have always been together , and according to his account he has been a very lonely man with only her as a companion , so that the thought of losing her was really terrible to him . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Cl s ne t poujours peru qucemble , stporehxcstence de ans eude aire , elque t'ltrzer , clusonection e , sclusre , pvait ponn psi t'bler . out rible . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Ils ont toujours vcu ensemble et il a men une existence solitaire quelle seule gayait la perspective de la perdre ne pouvait donc que lui sembler terrible . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 2 | Loss 1.089 | Acc 0.725: : 1001it [04:34,  3.65it/s]

<SOS> It is true , that while I worked , she would idle and I thought to myself , "If you and I were destined to live always together , cousin , we would commence matters on a different footing . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> 
Se vais s tonc vbsge , vipoter ,  si , ien que je ssible , metiaites , je mves s ons , vethprit , aire le , et je vertmicomaleux . vourre . que porraer . 'comrettes . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Je fus donc oblige de supporter aussi bien que possible les plaintes et les lamentations de cet esprit faible , et je fis de mon mieux pour coudre et emballer ses toilettes . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 2 | Loss 0.891 | Acc 0.776: : 2001it [09:08,  3.65it/s]

<SOS> He was making his blood too thick by going to sleep every evening after dinner . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Il n'arsait ssait sdoig de ous fre dinir , erque joir , rs sdner . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Il spaississait le sang sendormir chaque soir aprs le dner . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PA

Epoch 2 | Loss 0.863 | Acc 0.780: : 2610it [11:54,  3.65it/s]
Epoch 2 | Loss 1.011: 100%|██████████| 326/326 [00:28<00:00, 11.44it/s]


Epoch 3Validation Loss: 0.9017884768281625
Validation Loss Decreased(9880.928612--->9416.475275) 	 Saving The Model


Epoch 3 | Loss 0.954 | Acc 0.752: : 1it [00:00,  3.20it/s]

<SOS> The minutes passed very slowly fifteen were counted before the library- door again opened . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Les cemps tasut dgde inart de 'homre , 'inril, ous c'il tendait , vrir , porte . la voubliothque . cin . et r le . ngoal  . int . ar la voule . ogre . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Le temps parut long un quart d'heure s'coula sans qu'on entendt ouvrir la porte de la bibliothque enfin , Mlle Ingram revint par la salle manger . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 3 | Loss 0.773 | Acc 0.798: : 1001it [04:34,  3.65it/s]

<SOS> "Valuable ? " returned Pencroft . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Eouuisx ? rpondit Pencroff . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PA

Epoch 3 | Loss 0.891 | Acc 0.775: : 2001it [09:08,  3.65it/s]

<SOS> Our wagonette had topped a rise and in front of us rose the huge expanse of the moor , mottled with gnarled and craggy cairns and tors . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Notre poiture avait pteint la caut de lmo, ans ant nous , ait dre tporetde , et ltsde cled es et tves et de cois tagenous ame . pess les . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Notre voiture avait atteint le haut de la cte devant nous stendait la lande , parseme de pics coniques et de monts-joie en dentelles . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 3 | Loss 0.830 | Acc 0.796: : 2610it [11:55,  3.65it/s]
Epoch 3 | Loss 0.773: 100%|██████████| 326/326 [00:28<00:00, 11.43it/s]


Epoch 4Validation Loss: 0.8718183086740858
Validation Loss Decreased(9416.475275--->9103.526779) 	 Saving The Model


Epoch 4 | Loss 0.716 | Acc 0.816: : 1it [00:00,  3.09it/s]

<SOS> M de Treville approved of the resolution he had adopted , and assured him that if on the morrow he did not appear , he himself would undertake to find him , let him be where he might . <EOS> <PAD> <PAD> <PAD> 
L . de Trville , prova qutuolution , il il avait pris se , et il 'insra que , il il vendemain , av'avait pas enprou , il se 'tait sien slurouver . il i , il arut out . il averait . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
M . de Trville approuva la rsolution qu'il avait prise , et l'assura que , si le lendemain il n'avait pas reparu , il saurait bien le retrouver , lui , partout o il serait . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 


Epoch 4 | Loss 0.784 | Acc 0.799: : 1001it [04:34,  3.64it/s]

<SOS> Then he rose , and paced slowly up and down the room , his chin sunk upon his breast . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Ae sque il il se teva , et se lit en rienter su'ement et miece et fas ssa sumte . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Lorsqu'il se releva , il se mit arpenter lentement la pice en baissant la tte . <EOS> <PAD> <PAD> <PAD>

Epoch 4 | Loss 0.884 | Acc 0.771: : 2001it [09:08,  3.65it/s]

<SOS> Quick , now ! ' <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Allons , cocheus -nous ! <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD

Epoch 4 | Loss 0.724 | Acc 0.810: : 2610it [11:55,  3.65it/s]
Epoch 4 | Loss 0.759: 100%|██████████| 326/326 [00:28<00:00, 11.45it/s]


Epoch 5Validation Loss: 0.8573653238117867
Validation Loss Decreased(9103.526779--->8952.608711) 	 Saving The Model


Epoch 5 | Loss 0.714 | Acc 0.818: : 1it [00:00,  3.29it/s]

<SOS> Next day , the Marquis took Julien to a lonely mansion , at some distance from Paris . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Pe lendemain , Jmarquis , vit sit Julien , ineeau de col . sez longn . Paris . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
Le lendemain , le marquis conduisit Julien un chteau isol assez loign de Paris . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <

Epoch 5 | Loss 0.735 | Acc 0.808: : 1001it [04:34,  3.65it/s]

<SOS> It is a perpetual "I love you . " <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-est une pvous aime . ersuuel . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PA

Epoch 5 | Loss 0.912 | Acc 0.762: : 2001it [09:08,  3.65it/s]

<SOS> "What then , Die ? " he replied , maintaining a marble immobility of feature . <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Qt  bien , quareua ! il prit t-il , se tervant mmimporobile de marbre . il x bien ! <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> 
-- Eh bien ! Diana , reprit-il en conservant la mme immobilit de marbre , eh bien ! <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD

Epoch 5 | Loss 0.729 | Acc 0.812: : 2322it [10:36,  3.65it/s]


KeyboardInterrupt: 

In [20]:
device = torch.device("cpu")
model = Transformer(
    vocab_size=len(tokenizer),
    n_head=6,
    embed_size=600,
    context_length=100,
    dropout=0.1,
    num_layers=6,
    device=device,
)
model.load_state_dict(torch.load('saved_model_tokenizer_3.pth'))
model

Transformer(
  (encoder): Encoder(
    (embed): Embedding(1210, 600)
    (pos_encoding): SinusoidEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x Block(
        (qkv): Linear(in_features=600, out_features=1800, bias=False)
        (mha): MultiHeadAttention(
          (fc_out): Linear(in_features=600, out_features=600, bias=False)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (fc_dropout): Dropout(p=0.1, inplace=False)
        )
        (ffwd): FeedForward(
          (0): Linear(in_features=600, out_features=2400, bias=False)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2400, out_features=600, bias=False)
          (3): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (decoder): Decoder(
    (embed): Embedding(1210, 600)
    (pos_encoding)

In [21]:

def gen(model: nn.Module, sentence: str, max_len: int, vocab: BPE_Tokenizer, device: torch.device):
    model.eval()
    sos_token = torch.tensor(vocab.tokenize("<SOS>"), dtype=torch.int64).to(device)
    eos_token = torch.tensor(vocab.tokenize("<EOS>"), dtype=torch.int64).to(device)
    pad_token = torch.tensor(vocab.tokenize("<PAD>"), dtype=torch.int64).to(device)
    print(sos_token, eos_token, pad_token)
    
    src_input = torch.cat([sos_token, torch.tensor(vocab.tokenize(sentence), dtype=torch.int64).to(device), eos_token, pad_token.repeat(max_len - len(vocab.tokenize(sentence)) - 2)])
    src_mask = (src_input != pad_token).unsqueeze(0).int() == 1
    
    tgt_input = sos_token
    while tgt_input[-1] != eos_token and len(tgt_input) < max_len:
        tgt_mask = dataset_train._causal_mask(tgt_input.shape[0]) == 1
        src_input, tgt_input, src_mask, tgt_mask = src_input.to(device), tgt_input.to(device), src_mask.to(device), tgt_mask.to(device)
        print(src_input.unsqueeze(0).shape, tgt_input.unsqueeze(0).shape, src_mask.unsqueeze(0).shape, tgt_mask.unsqueeze(0).shape)
        print(src_input.unsqueeze(0).dtype, tgt_input.unsqueeze(0).dtype, src_mask.unsqueeze(0).dtype, tgt_mask.unsqueeze(0).dtype)

        logits = model(src_input.unsqueeze(0), tgt_input.unsqueeze(0), src_mask.unsqueeze(0), tgt_mask.unsqueeze(0))
        pred = F.softmax(logits, dim=-1)
        print(pred.shape)
        print(pred[:, -1, :].argmax(dim=-1))
        # next_token = torch.multinomial(pred[:,-1,:], num_samples=1)
        tgt_input = torch.cat([tgt_input, pred[:,-1,:].argmax(dim = -1).to(device)])
        print(tgt_input)
    print(vocab.detokenize(tgt_input.tolist()))
gen(model, "I am a student", 100, tokenizer, device)

tensor([455]) tensor([38]) tensor([1130])
torch.Size([1, 100]) torch.Size([1, 1]) torch.Size([1, 1, 100]) torch.Size([1, 1, 1, 1])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 1, 1210])
tensor([713])
tensor([455, 713])
torch.Size([1, 100]) torch.Size([1, 2]) torch.Size([1, 1, 100]) torch.Size([1, 1, 2, 2])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 2, 1210])
tensor([803])
tensor([455, 713, 803])
torch.Size([1, 100]) torch.Size([1, 3]) torch.Size([1, 1, 100]) torch.Size([1, 1, 3, 3])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 3, 1210])
tensor([644])
tensor([455, 713, 803, 644])
torch.Size([1, 100]) torch.Size([1, 4]) torch.Size([1, 1, 100]) torch.Size([1, 1, 4, 4])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 4, 1210])
tensor([754])
tensor([455, 713, 803, 644, 754])
torch.Size([1, 100]) torch.Size([1, 5]) torch.Size([1, 1, 100]) torch.Size([1, 1, 5, 5])
torch.int64 torch.int64 torch.bool torch.bool
torch.Size([1, 5, 121