## Transformer From Scratch

### Transformer Decoder

#### Next Token Prediction Training

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.init as init
from tqdm import tqdm

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "Vetor de embedding precisa ser divisivel pelo número de cabeças da camada de atenção!"
        self.head_dim = d_model // num_heads
        self.d_model, self.num_heads = d_model, num_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, encoder_output=None):
        # Entra Q, K, V com dimensão (batch_size, sequence_length, d_model)
        # Reshape para (batch_size, sequence_length, num_heads, d_model)
        # Reordering para (batch_size, num_heads, sequence_length, d_model)
        if encoder_output is None:
            x = torch.reshape(x, shape=(x.shape[0], x.shape[1], self.num_heads, self.head_dim)) #.contiguous()
            x = x.permute(0, 2, 1, 3)
        else:
            raise NotImplementedError("Modelo ainda não compatível com Encoder.")
        return x

    def compute_attention_scores(self, q_linear_out, k_linear_out, v_linear_out, mask=None):
        qk_dot_product = torch.matmul(q_linear_out, k_linear_out.transpose(2, 3)) / self.head_dim ** 0.5

        if mask is not None:
            qk_dot_product = qk_dot_product.masked_fill(mask == 0, float('-inf'))

        attn_scores = nn.functional.softmax(qk_dot_product, dim=-1)
        attn_weighted_v = torch.matmul(attn_scores, v_linear_out)

        return attn_weighted_v


    def combine_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        return torch.reshape(x, shape=(x.shape[0], x.shape[1], int(x.shape[2] * x.shape[3])))

    def forward(self, x, mask):
        q_linear_out = self.split_heads(self.q(x))
        k_linear_out = self.split_heads(self.k(x))
        v_linear_out = self.split_heads(self.v(x))
        
        attn_weighted_v = self.compute_attention_scores(q_linear_out, k_linear_out, v_linear_out, mask=mask)
        attn_weighted_v = self.combine_heads(attn_weighted_v)
        return self.output_linear(attn_weighted_v)

In [4]:
matrix_1 = torch.rand(1, 8, 512, 10)
matrix_2 = torch.rand(1, 8, 512, 10)
print(torch.matmul(matrix_1, matrix_2.transpose(-1, -2)).shape)

torch.Size([1, 8, 512, 512])


In [5]:
class FeedForwardSubLayer(nn.Module):
    def __init__(self, d_model, hidden_size):
        super().__init__()
        self.ff_1 = nn.Linear(d_model, hidden_size)
        self.ff_2 = nn.Linear(hidden_size, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.ff_2(self.relu(self.ff_1(x)))

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.feed_forward = FeedForwardSubLayer(d_model, hidden_size)
        self.mha = MultiHeadAttention(d_model, num_heads) # nn.MultiheadAttention()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, tgt_mask):
        x = self.norm_1(x + self.dropout(self.mha(x, mask=tgt_mask)))
        x = self.norm_2(x + self.dropout(self.feed_forward(x)))
        return x

In [7]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, max_sequence_length, n_layers, hidden_size, num_heads, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=0)
        self.pe = PositionalEncoding(d_model, max_sequence_length)
        self.layers = nn.ModuleList(
            [DecoderBlock(d_model, hidden_size, num_heads, dropout) for _ in range(n_layers)]
        )
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_mask):
        x = self.embedding(x) # torch.nn.functional.pad(x, pad=(0, tgt_mask.shape[1] - x.shape[1]))
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        out = self.output_layer(x)
        return out

In [8]:
### 259 tokens possíveis:
# ASCII + <UNK> + <SOS> + <EOS> + <PAD>

class TokenizerChar:
    def __init__(self):
        self.chr_to_idx = {chr(v): v for v in range(1, 257)}
        self.chr_to_idx['<SOS>'] = 257
        self.chr_to_idx['<EOS>'] = 258
        self.chr_to_idx['<PAD>'] = 0
        self.chr_to_idx['<UNK>'] = 259

        self.idx_to_chr = {v: k for k, v in self.chr_to_idx.items()}

        self.vocab_size = len(self.chr_to_idx.keys())

    def encode(self, char):
        if char in self.chr_to_idx.keys():
            return self.chr_to_idx[char]
        else:
            return 259
    
    def decode(self, token_idx):
        return self.idx_to_chr[token_idx]
    
    def sos_token(self):
        return '<SOS>'
    
    def sos_token_idx(self):
        return self.chr_to_idx['<SOS>']

    def eos_token(self):
        return '<EOS>'
    
    def eos_token_idx(self):
        return self.chr_to_idx['<EOS>']
    
    def pad_token(self):
        return '<PAD>'
    
    def pad_token_idx(self):
        return self.chr_to_idx['<PAD>']
    
    def get_vocab_size(self):
        return self.vocab_size


In [9]:
from torch.utils.data import Dataset


class DatasetDialogs(Dataset):
    def __init__(self, dataset_path, sentence_length):
        self.dataset_path = dataset_path
        self.sentence_length = sentence_length
        self.tokenizer = TokenizerChar()

    def __len__(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return num_of_sentences
    
    def get_shape(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return (num_of_sentences, self.sentence_length)

    def __getitem__(self, line_idx):
        with open(self.dataset_path, 'r') as dataset:
            selected_sentence = dataset.read().split('\n')[line_idx]
            if len(selected_sentence) < self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens.append(self.tokenizer.eos_token_idx())
                pad_length = self.sentence_length - len(input_tokens) + 1
                pad_tokens = [self.tokenizer.pad_token_idx()] * pad_length
                input_tokens += pad_tokens
            elif len(selected_sentence) == self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[-1] = self.tokenizer.eos_token_idx()
            elif len(selected_sentence) > self.sentence_length:
                selected_sentence = selected_sentence[:self.sentence_length]
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[-1] = self.tokenizer.eos_token_idx()
            # print(f'{len(input_tokens)} - {self.sentence_length}') # debug only
            assert len(input_tokens) == self.sentence_length + 1, f"Lista de índices de tokens não possui mesmo tamanho que 'sentence_length'! len(input_tokens): {len(input_tokens)} - self.sentence_length: {self.sentence_length}"
            try:
                x = torch.tensor(input_tokens[:-1])
                y = torch.tensor(input_tokens[1:])
            except RuntimeError as e:
                print(e)
                print(f"Input tokens: {input_tokens}")
                raise e
        return x, y

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
sequence_length = 100
batch_size = 32
dataset_train = DatasetDialogs('dataset_text/dialogs_train.txt', sequence_length)
dataset_test = DatasetDialogs('dataset_text/dialogs_test.txt', sequence_length)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=False)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
vocab_size = dataset_train.tokenizer.get_vocab_size()

d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_length = sequence_length
# model = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
model = TransformerDecoder(vocab_size, d_model, max_seq_length, num_layers, d_ff, num_heads, dropout=0.1)
model.to(device)

tgt_mask = (1 - torch.triu(
  torch.ones(1, sequence_length, sequence_length), diagonal=1)
).bool()

def init_weights(module):
    if isinstance(module, (nn.Linear)):
        init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
model.apply(init_weights)

optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
n_epochs = 1000
n_batches = int(dataset_train.__len__() // batch_size)

print("Starting model training...")
for epoch in range(n_epochs):
    print(f"Epoch: {epoch + 1}")
    avg_loss = 0
    model.train()
    for batch_idx, batch in enumerate(tqdm(dataloader_train, total=n_batches)):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, tgt_mask.to(device))
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), f'model_checkpoints/model_checkpoint_{epoch+1}.pth')

    avg_loss /= (batch_idx + 1)
    print(f"Average epoch training loss: {avg_loss}")
    print(f"Last batch training loss: {loss}")

    model.eval()
    avg_loss = 0
    for batch_idx, batch in enumerate(dataloader_test):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, tgt_mask.to(device))
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
    
    avg_loss /= (batch_idx + 1)
    print(f"Epoch validation loss: {avg_loss}")
    

cuda
Starting model training...
Epoch: 1


100%|██████████| 107/107 [00:08<00:00, 12.15it/s]


Average epoch training loss: 3.2458659711285174
Last batch training loss: 3.150761604309082
Epoch validation loss: 3.0379201650619505
Epoch: 2


100%|██████████| 107/107 [00:09<00:00, 11.45it/s]


Average epoch training loss: 3.088669195353428
Last batch training loss: 3.1086926460266113
Epoch validation loss: 3.025111770629883
Epoch: 3


100%|██████████| 107/107 [00:09<00:00, 11.66it/s]


Average epoch training loss: 3.0608464134073703
Last batch training loss: 3.0924878120422363
Epoch validation loss: 2.996799182891846
Epoch: 4


100%|██████████| 107/107 [00:09<00:00, 11.64it/s]


Average epoch training loss: 3.0424999388578895
Last batch training loss: 3.0851898193359375
Epoch validation loss: 2.973800849914551
Epoch: 5


100%|██████████| 107/107 [00:09<00:00, 11.66it/s]


Average epoch training loss: 3.0222182184736304
Last batch training loss: 3.0729527473449707
Epoch validation loss: 2.970201849937439
Epoch: 6


100%|██████████| 107/107 [00:09<00:00, 11.53it/s]


Average epoch training loss: 2.932748362282726
Last batch training loss: 2.852943181991577
Epoch validation loss: 2.7321972370147707
Epoch: 7


100%|██████████| 107/107 [00:09<00:00, 11.49it/s]


Average epoch training loss: 2.666991068938068
Last batch training loss: 2.6629786491394043
Epoch validation loss: 2.5363375902175904
Epoch: 8


100%|██████████| 107/107 [00:09<00:00, 11.34it/s]


Average epoch training loss: 2.512644720968799
Last batch training loss: 2.5197994709014893
Epoch validation loss: 2.4078354120254515
Epoch: 9


100%|██████████| 107/107 [00:09<00:00, 11.48it/s]


Average epoch training loss: 2.3865823322367445
Last batch training loss: 2.471691370010376
Epoch validation loss: 2.319689416885376
Epoch: 10


100%|██████████| 107/107 [00:09<00:00, 11.21it/s]


Average epoch training loss: 2.3129742680308976
Last batch training loss: 2.3988847732543945
Epoch validation loss: 2.2517327070236206
Epoch: 11


100%|██████████| 107/107 [00:09<00:00, 11.32it/s]


Average epoch training loss: 2.252376291239373
Last batch training loss: 2.3383140563964844
Epoch validation loss: 2.195799446105957
Epoch: 12


100%|██████████| 107/107 [00:09<00:00, 11.29it/s]


Average epoch training loss: 2.199827740125567
Last batch training loss: 2.2986490726470947
Epoch validation loss: 2.14315779209137
Epoch: 13


100%|██████████| 107/107 [00:09<00:00, 11.37it/s]


Average epoch training loss: 2.147576248534372
Last batch training loss: 2.2622218132019043
Epoch validation loss: 2.1139885425567626
Epoch: 14


100%|██████████| 107/107 [00:09<00:00, 11.12it/s]


Average epoch training loss: 2.096043537710315
Last batch training loss: 2.2164206504821777
Epoch validation loss: 2.0595064997673034
Epoch: 15


100%|██████████| 107/107 [00:09<00:00, 11.06it/s]


Average epoch training loss: 2.0494329249747447
Last batch training loss: 2.1771974563598633
Epoch validation loss: 2.0188956260681152
Epoch: 16


100%|██████████| 107/107 [00:09<00:00, 11.32it/s]


Average epoch training loss: 2.00349745906402
Last batch training loss: 2.120504856109619
Epoch validation loss: 1.9882047176361084
Epoch: 17


100%|██████████| 107/107 [00:09<00:00, 11.32it/s]


Average epoch training loss: 1.964719863695519
Last batch training loss: 2.0716373920440674
Epoch validation loss: 1.938533365726471
Epoch: 18


100%|██████████| 107/107 [00:09<00:00, 11.17it/s]


Average epoch training loss: 1.9175561711052869
Last batch training loss: 2.0601110458374023
Epoch validation loss: 1.9062323093414306
Epoch: 19


100%|██████████| 107/107 [00:09<00:00, 11.11it/s]


Average epoch training loss: 1.8793466057732842
Last batch training loss: 2.014859199523926
Epoch validation loss: 1.8777598142623901
Epoch: 20


100%|██████████| 107/107 [00:09<00:00, 11.37it/s]


Average epoch training loss: 1.845056807883432
Last batch training loss: 1.9886656999588013
Epoch validation loss: 1.860313594341278
Epoch: 21


100%|██████████| 107/107 [00:09<00:00, 10.96it/s]


Average epoch training loss: 1.8116143317980187
Last batch training loss: 1.9381994009017944
Epoch validation loss: 1.8316502690315246
Epoch: 22


100%|██████████| 107/107 [00:09<00:00, 11.32it/s]


Average epoch training loss: 1.784659465896749
Last batch training loss: 1.9564028978347778
Epoch validation loss: 1.944465208053589
Epoch: 23


100%|██████████| 107/107 [00:09<00:00, 11.36it/s]


Average epoch training loss: 1.7664202930771302
Last batch training loss: 1.8975753784179688
Epoch validation loss: 1.7936704635620118
Epoch: 24


100%|██████████| 107/107 [00:09<00:00, 11.04it/s]


Average epoch training loss: 1.7267583441511494
Last batch training loss: 1.8565219640731812
Epoch validation loss: 1.7711947441101075
Epoch: 25


100%|██████████| 107/107 [00:09<00:00, 10.93it/s]


Average epoch training loss: 1.7010868598367566
Last batch training loss: 1.847628116607666
Epoch validation loss: 1.7567464351654052
Epoch: 26


100%|██████████| 107/107 [00:09<00:00, 10.99it/s]


Average epoch training loss: 1.6789818155431302
Last batch training loss: 1.8206658363342285
Epoch validation loss: 1.7410956263542174
Epoch: 27


100%|██████████| 107/107 [00:09<00:00, 11.23it/s]


Average epoch training loss: 1.6565852688851757
Last batch training loss: 1.808974027633667
Epoch validation loss: 1.729894745349884
Epoch: 28


100%|██████████| 107/107 [00:09<00:00, 11.19it/s]


Average epoch training loss: 1.6319976410019064
Last batch training loss: 1.7555654048919678
Epoch validation loss: 1.7009750366210938
Epoch: 29


100%|██████████| 107/107 [00:09<00:00, 11.09it/s]


Average epoch training loss: 1.6091624509508364
Last batch training loss: 1.7583696842193604
Epoch validation loss: 1.6934565424919128
Epoch: 30


100%|██████████| 107/107 [00:09<00:00, 11.02it/s]


Average epoch training loss: 1.5874448557880436
Last batch training loss: 1.735521674156189
Epoch validation loss: 1.6804898381233215
Epoch: 31


100%|██████████| 107/107 [00:09<00:00, 11.01it/s]


Average epoch training loss: 1.5665514992776317
Last batch training loss: 1.6902135610580444
Epoch validation loss: 1.6643911480903626
Epoch: 32


100%|██████████| 107/107 [00:09<00:00, 11.24it/s]


Average epoch training loss: 1.5462976304170126
Last batch training loss: 1.6780424118041992
Epoch validation loss: 1.6749812364578247
Epoch: 33


100%|██████████| 107/107 [00:09<00:00, 11.29it/s]


Average epoch training loss: 1.5285231354080628
Last batch training loss: 1.6480828523635864
Epoch validation loss: 1.6590741753578186
Epoch: 34


100%|██████████| 107/107 [00:09<00:00, 11.08it/s]


Average epoch training loss: 1.5112976671379303
Last batch training loss: 1.62749445438385
Epoch validation loss: 1.649367332458496
Epoch: 35


100%|██████████| 107/107 [00:09<00:00, 11.21it/s]


Average epoch training loss: 1.4938189515443605
Last batch training loss: 1.7022149562835693
Epoch validation loss: 1.654877483844757
Epoch: 36


100%|██████████| 107/107 [00:09<00:00, 11.31it/s]


Average epoch training loss: 1.4967309604181307
Last batch training loss: 1.6345971822738647
Epoch validation loss: 1.632383406162262
Epoch: 37


100%|██████████| 107/107 [00:09<00:00, 11.24it/s]


Average epoch training loss: 1.4632136030731915
Last batch training loss: 1.6077629327774048
Epoch validation loss: 1.6326369643211365
Epoch: 38


100%|██████████| 107/107 [00:09<00:00, 11.24it/s]


Average epoch training loss: 1.442929135304745
Last batch training loss: 1.548840880393982
Epoch validation loss: 1.6244669437408448
Epoch: 39


100%|██████████| 107/107 [00:09<00:00, 11.15it/s]


Average epoch training loss: 1.4245066297388522
Last batch training loss: 1.5536494255065918
Epoch validation loss: 1.6215940356254577
Epoch: 40


100%|██████████| 107/107 [00:09<00:00, 10.97it/s]


Average epoch training loss: 1.4106627000826542
Last batch training loss: 1.5440913438796997
Epoch validation loss: 1.6096951007843017
Epoch: 41


100%|██████████| 107/107 [00:09<00:00, 11.11it/s]


Average epoch training loss: 1.3932047057374615
Last batch training loss: 1.5212568044662476
Epoch validation loss: 1.60868399143219
Epoch: 42


100%|██████████| 107/107 [00:09<00:00, 11.27it/s]


Average epoch training loss: 1.37683048649369
Last batch training loss: 1.5147464275360107
Epoch validation loss: 1.6037568926811219
Epoch: 43


100%|██████████| 107/107 [00:09<00:00, 11.03it/s]


Average epoch training loss: 1.3612065448939243
Last batch training loss: 1.5233310461044312
Epoch validation loss: 1.5925783514976501
Epoch: 44


100%|██████████| 107/107 [00:09<00:00, 11.23it/s]


Average epoch training loss: 1.3464526188707797
Last batch training loss: 1.4708175659179688
Epoch validation loss: 1.5972877144813538
Epoch: 45


100%|██████████| 107/107 [00:09<00:00, 11.01it/s]


Average epoch training loss: 1.3358519205423158
Last batch training loss: 1.4717708826065063
Epoch validation loss: 1.5900057315826417
Epoch: 46


100%|██████████| 107/107 [00:09<00:00, 11.13it/s]


Average epoch training loss: 1.3191616797001562
Last batch training loss: 1.4579603672027588
Epoch validation loss: 1.586282217502594
Epoch: 47


100%|██████████| 107/107 [00:09<00:00, 10.92it/s]


Average epoch training loss: 1.309049936098473
Last batch training loss: 1.4663121700286865
Epoch validation loss: 1.5907786011695861
Epoch: 48


100%|██████████| 107/107 [00:09<00:00, 10.88it/s]


Average epoch training loss: 1.296562619298418
Last batch training loss: 1.4515047073364258
Epoch validation loss: 1.5823691010475158
Epoch: 49


100%|██████████| 107/107 [00:09<00:00, 11.04it/s]


Average epoch training loss: 1.2845385993752525
Last batch training loss: 1.449761986732483
Epoch validation loss: 1.578969132900238
Epoch: 50


100%|██████████| 107/107 [00:09<00:00, 11.31it/s]


Average epoch training loss: 1.2715448597881283
Last batch training loss: 1.4007861614227295
Epoch validation loss: 1.5822626590728759
Epoch: 51


100%|██████████| 107/107 [00:09<00:00, 10.97it/s]


Average epoch training loss: 1.2615067685875938
Last batch training loss: 1.3797236680984497
Epoch validation loss: 1.587569797039032
Epoch: 52


100%|██████████| 107/107 [00:09<00:00, 11.17it/s]


Average epoch training loss: 1.2477462319570167
Last batch training loss: 1.3961025476455688
Epoch validation loss: 1.5691036820411681
Epoch: 53


100%|██████████| 107/107 [00:09<00:00, 10.82it/s]


Average epoch training loss: 1.2374013400523463
Last batch training loss: 1.3613771200180054
Epoch validation loss: 1.5745491623878478
Epoch: 54


100%|██████████| 107/107 [00:09<00:00, 11.14it/s]


Average epoch training loss: 1.228640384206148
Last batch training loss: 1.3515383005142212
Epoch validation loss: 1.56990807056427
Epoch: 55


100%|██████████| 107/107 [00:09<00:00, 10.79it/s]


Average epoch training loss: 1.2157401654207818
Last batch training loss: 1.3534702062606812
Epoch validation loss: 1.5710762739181519
Epoch: 56


100%|██████████| 107/107 [00:09<00:00, 11.00it/s]


Average epoch training loss: 1.2058849802641112
Last batch training loss: 1.327820897102356
Epoch validation loss: 1.5699811100959777
Epoch: 57


100%|██████████| 107/107 [00:09<00:00, 11.07it/s]


Average epoch training loss: 1.1964834811531495
Last batch training loss: 1.3077750205993652
Epoch validation loss: 1.5718865156173707
Epoch: 58


100%|██████████| 107/107 [00:09<00:00, 11.05it/s]


Average epoch training loss: 1.1847823256644132
Last batch training loss: 1.3305028676986694
Epoch validation loss: 1.5733693838119507
Epoch: 59


100%|██████████| 107/107 [00:09<00:00, 11.12it/s]


Average epoch training loss: 1.174315342836291
Last batch training loss: 1.325385332107544
Epoch validation loss: 1.5941571116447448
Epoch: 60


100%|██████████| 107/107 [00:09<00:00, 11.00it/s]


Average epoch training loss: 1.16554704447773
Last batch training loss: 1.2949358224868774
Epoch validation loss: 1.5810581922531128
Epoch: 61


100%|██████████| 107/107 [00:09<00:00, 11.15it/s]


Average epoch training loss: 1.1545147366612871
Last batch training loss: 1.3013356924057007
Epoch validation loss: 1.5801969289779663
Epoch: 62


100%|██████████| 107/107 [00:09<00:00, 11.09it/s]


Average epoch training loss: 1.1427246064783256
Last batch training loss: 1.3038700819015503
Epoch validation loss: 1.5873427033424377
Epoch: 63


100%|██████████| 107/107 [00:10<00:00, 10.61it/s]


Average epoch training loss: 1.1345342921319408
Last batch training loss: 1.2674890756607056
Epoch validation loss: 1.5723849654197692
Epoch: 64


100%|██████████| 107/107 [00:09<00:00, 11.45it/s]


Average epoch training loss: 1.1261909069301925
Last batch training loss: 1.294823169708252
Epoch validation loss: 1.5837682485580444
Epoch: 65


100%|██████████| 107/107 [00:32<00:00,  3.31it/s]


Average epoch training loss: 1.115262067763605
Last batch training loss: 1.2534153461456299
Epoch validation loss: 1.5940488219261169
Epoch: 66


100%|██████████| 107/107 [00:08<00:00, 11.89it/s]


Average epoch training loss: 1.1089761268312686
Last batch training loss: 1.2592666149139404
Epoch validation loss: 1.6035747528076172
Epoch: 67


100%|██████████| 107/107 [00:09<00:00, 11.45it/s]


Average epoch training loss: 1.0957032380817093
Last batch training loss: 1.245103120803833
Epoch validation loss: 1.605346941947937
Epoch: 68


100%|██████████| 107/107 [00:08<00:00, 12.12it/s]


Average epoch training loss: 1.0872199212279274
Last batch training loss: 1.22137451171875
Epoch validation loss: 1.615861749649048
Epoch: 69


100%|██████████| 107/107 [00:09<00:00, 11.81it/s]


Average epoch training loss: 1.0785248368699973
Last batch training loss: 1.2311725616455078
Epoch validation loss: 1.5979690194129943
Epoch: 70


100%|██████████| 107/107 [00:09<00:00, 11.66it/s]


Average epoch training loss: 1.068748805567483
Last batch training loss: 1.1996521949768066
Epoch validation loss: 1.6271481394767762
Epoch: 71


100%|██████████| 107/107 [00:09<00:00, 11.66it/s]


Average epoch training loss: 1.0677339016834153
Last batch training loss: 1.1906386613845825
Epoch validation loss: 1.6217171669006347
Epoch: 72


100%|██████████| 107/107 [00:09<00:00, 11.22it/s]


Average epoch training loss: 1.0547696101331265
Last batch training loss: 1.1826151609420776
Epoch validation loss: 1.6223968505859374
Epoch: 73


100%|██████████| 107/107 [00:09<00:00, 11.32it/s]


Average epoch training loss: 1.0432972412243067
Last batch training loss: 1.1508007049560547
Epoch validation loss: 1.636875605583191
Epoch: 74


100%|██████████| 107/107 [00:09<00:00, 11.25it/s]


Average epoch training loss: 1.036450216146273
Last batch training loss: 1.1800448894500732
Epoch validation loss: 1.646957552433014
Epoch: 75


100%|██████████| 107/107 [00:09<00:00, 11.30it/s]


Average epoch training loss: 1.0314067139803806
Last batch training loss: 1.1739141941070557
Epoch validation loss: 1.6343891382217408
Epoch: 76


100%|██████████| 107/107 [00:09<00:00, 11.27it/s]


Average epoch training loss: 1.0173688759313566
Last batch training loss: 1.1600217819213867
Epoch validation loss: 1.6435920476913453
Epoch: 77


100%|██████████| 107/107 [00:09<00:00, 11.27it/s]


Average epoch training loss: 1.0094965551501123
Last batch training loss: 1.1341676712036133
Epoch validation loss: 1.6332300305366516
Epoch: 78


100%|██████████| 107/107 [00:09<00:00, 11.15it/s]


Average epoch training loss: 0.9997802181778667
Last batch training loss: 1.1270945072174072
Epoch validation loss: 1.6557953834533692
Epoch: 79


100%|██████████| 107/107 [00:09<00:00, 11.18it/s]


Average epoch training loss: 0.991533123444174
Last batch training loss: 1.1401190757751465
Epoch validation loss: 1.6463997602462768
Epoch: 80


100%|██████████| 107/107 [00:09<00:00, 11.17it/s]


Average epoch training loss: 0.9850439579687386
Last batch training loss: 1.1340643167495728
Epoch validation loss: 1.6631471514701843
Epoch: 81


100%|██████████| 107/107 [00:09<00:00, 11.16it/s]


Average epoch training loss: 0.9739979310570476
Last batch training loss: 1.1288819313049316
Epoch validation loss: 1.6711835861206055
Epoch: 82


100%|██████████| 107/107 [00:09<00:00, 10.95it/s]


Average epoch training loss: 0.9662019593693386
Last batch training loss: 1.1021065711975098
Epoch validation loss: 1.66290602684021
Epoch: 83


100%|██████████| 107/107 [00:09<00:00, 11.15it/s]


Average epoch training loss: 0.9613927936999598
Last batch training loss: 1.101820468902588
Epoch validation loss: 1.6916552782058716
Epoch: 84


100%|██████████| 107/107 [00:09<00:00, 11.09it/s]


Average epoch training loss: 0.9503331641170466
Last batch training loss: 1.0719692707061768
Epoch validation loss: 1.6885416507720947
Epoch: 85


100%|██████████| 107/107 [00:09<00:00, 10.76it/s]


Average epoch training loss: 0.9413766554582899
Last batch training loss: 1.0649158954620361
Epoch validation loss: 1.697167456150055
Epoch: 86


100%|██████████| 107/107 [00:09<00:00, 11.05it/s]


Average epoch training loss: 0.9333159723014475
Last batch training loss: 1.0621840953826904
Epoch validation loss: 1.7126036524772643
Epoch: 87


100%|██████████| 107/107 [00:09<00:00, 11.06it/s]


Average epoch training loss: 0.9245340389626049
Last batch training loss: 1.0588531494140625
Epoch validation loss: 1.7094740986824035
Epoch: 88


100%|██████████| 107/107 [00:09<00:00, 10.93it/s]


Average epoch training loss: 0.9195045505728677
Last batch training loss: 1.0685832500457764
Epoch validation loss: 1.7194653153419495
Epoch: 89


100%|██████████| 107/107 [00:09<00:00, 11.08it/s]


Average epoch training loss: 0.9098524495820018
Last batch training loss: 1.0496901273727417
Epoch validation loss: 1.712940263748169
Epoch: 90


100%|██████████| 107/107 [00:09<00:00, 10.95it/s]


Average epoch training loss: 0.899928160916979
Last batch training loss: 1.045318841934204
Epoch validation loss: 1.7195160508155822
Epoch: 91


100%|██████████| 107/107 [00:09<00:00, 11.12it/s]


Average epoch training loss: 0.892315376584775
Last batch training loss: 1.047141671180725
Epoch validation loss: 1.7188438773155212
Epoch: 92


100%|██████████| 107/107 [00:09<00:00, 10.85it/s]


Average epoch training loss: 0.8866919713599659
Last batch training loss: 1.0178881883621216
Epoch validation loss: 1.7256455898284913
Epoch: 93


100%|██████████| 107/107 [00:09<00:00, 11.19it/s]


Average epoch training loss: 0.8779502828544545
Last batch training loss: 1.0082974433898926
Epoch validation loss: 1.7364740133285523
Epoch: 94


100%|██████████| 107/107 [00:09<00:00, 11.10it/s]


Average epoch training loss: 0.8693443598034226
Last batch training loss: 1.0093934535980225
Epoch validation loss: 1.749776840209961
Epoch: 95


100%|██████████| 107/107 [00:09<00:00, 10.95it/s]


Average epoch training loss: 0.8668881649168853
Last batch training loss: 0.9993721842765808
Epoch validation loss: 1.7725111722946167
Epoch: 96


100%|██████████| 107/107 [00:09<00:00, 11.16it/s]


Average epoch training loss: 0.8690439479373325
Last batch training loss: 1.0303822755813599
Epoch validation loss: 1.7347704410552978
Epoch: 97


100%|██████████| 107/107 [00:09<00:00, 11.01it/s]


Average epoch training loss: 0.855561314342178
Last batch training loss: 0.9780593514442444
Epoch validation loss: 1.7509577751159668
Epoch: 98


100%|██████████| 107/107 [00:09<00:00, 11.03it/s]


Average epoch training loss: 0.8498220677687743
Last batch training loss: 0.9802083373069763
Epoch validation loss: 1.7680776953697204
Epoch: 99


100%|██████████| 107/107 [00:09<00:00, 11.04it/s]


Average epoch training loss: 0.8390663249470364
Last batch training loss: 0.984528124332428
Epoch validation loss: 1.767014217376709
Epoch: 100


100%|██████████| 107/107 [00:09<00:00, 11.09it/s]


Average epoch training loss: 0.8303277876889594
Last batch training loss: 0.9499963521957397
Epoch validation loss: 1.7654722571372985
Epoch: 101


100%|██████████| 107/107 [00:09<00:00, 11.14it/s]


Average epoch training loss: 0.8242397024252704
Last batch training loss: 0.9591460227966309
Epoch validation loss: 1.7953692078590393
Epoch: 102


100%|██████████| 107/107 [00:09<00:00, 10.91it/s]


Average epoch training loss: 0.8196438556519624
Last batch training loss: 0.944384753704071
Epoch validation loss: 1.81068754196167
Epoch: 103


100%|██████████| 107/107 [00:09<00:00, 11.10it/s]


Average epoch training loss: 0.8141950079213793
Last batch training loss: 0.9567021131515503
Epoch validation loss: 1.8090829849243164
Epoch: 104


100%|██████████| 107/107 [00:09<00:00, 10.90it/s]


Average epoch training loss: 0.8044288219692551
Last batch training loss: 0.9204676151275635
Epoch validation loss: 1.812583899497986
Epoch: 105


 58%|█████▊    | 62/107 [00:05<00:04, 10.45it/s]


KeyboardInterrupt: 

In [11]:
def make_tgt_mask(sequence_length, device):
    tgt_mask = torch.tril(torch.ones(sequence_length, sequence_length, dtype=torch.bool)).to(device)
    return tgt_mask

In [None]:
import torch

def predict_fernando(start_text, model, tokenizer, max_sequence_length=50, temperature=1.0):
    model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    sequence = [tokenizer.sos_token_idx()] + [tokenizer.encode(c) for c in start_text]

    input_tokens = torch.tensor(sequence, dtype=torch.long).unsqueeze(0).to(device) # acrescenta dimensão batch_size=1
    current_text = start_text

    with torch.no_grad():
        for _ in range(len(start_text), max_sequence_length):
            outputs = model(input_tokens, tgt_mask=make_tgt_mask(input_tokens.shape[1], device))
            log_probs = outputs[0, -1] / temperature

            predicted_token_idx = torch.distributions.Categorical(logits=log_probs).sample().item()

            if predicted_token_idx == tokenizer.eos_token_idx():
                break

            current_text += tokenizer.decode(predicted_token_idx)
            input_tokens = torch.cat((input_tokens, torch.tensor([[predicted_token_idx]]).to(device)), dim=1)

    print('Texto predito:', current_text)
    return current_text


In [12]:
def pad_sequence_to_length(sequence, sequence_length, pad_token_idx):
    """
    Completa a sequência com tokens de padding até atingir o comprimento desejado.

    Args:
        sequence (list): Sequência de tokens (índices) a ser completada.
        sequence_length (int): Comprimento desejado da sequência.
        pad_token_idx (int): Índice do token de padding.

    Returns:
        list: Sequência completada com tokens de padding.
    """
    if len(sequence) < sequence_length:
        # Adiciona tokens de padding no final da sequência
        sequence += [pad_token_idx] * (sequence_length - len(sequence))
    elif len(sequence) > sequence_length:
        # Trunca a sequência se for maior que o comprimento desejado
        sequence = sequence[:sequence_length]
    return sequence

In [None]:
import torch

def predict_assisted(start_text, model, tokenizer, max_sequence_length=50, temperature=1.0):
    model.eval()
    device = next(model.parameters()).device

    # Tokeniza e preenche a sequência
    sequence = [tokenizer.sos_token_idx()] + [tokenizer.encode(c) for c in start_text]
    sequence = pad_sequence_to_length(sequence, sequence_length, tokenizer.pad_token_idx())
    input_tokens = torch.tensor(sequence, dtype=torch.long).unsqueeze(0).to(device)

    current_text = start_text

    with torch.no_grad():
        for i in range(len(start_text), sequence_length):
            tgt_mask = torch.tril(torch.ones(sequence_length, sequence_length)).to(device).bool()
            outputs = model(input_tokens, tgt_mask=tgt_mask)
            log_probs = outputs[0, i - 1] / temperature

            predicted_token_idx = torch.distributions.Categorical(logits=log_probs).sample().item()

            if predicted_token_idx == tokenizer.eos_token_idx():
                break

            current_text += tokenizer.decode(predicted_token_idx)
            input_tokens[0, i] = predicted_token_idx
            # if i == 20:
            #     print(input_tokens)

    print('Texto predito:', current_text)
    return current_text


In [45]:
dataset = DatasetDialogs('dataset_text/dialogs.txt', 100)

In [None]:
predict_fernando("you're kidding.	no, ", model=model, tokenizer=dataset.tokenizer, max_sequence_length=50, temperature=1)

Texto predito: you're kidding.	no, you don't smell radid anymore.


"you're kidding.\tno, you don't smell radid anymore."

In [None]:
predict_fernando("you're kidding.	no, ", model=model, tokenizer=dataset.tokenizer, max_sequence_length=50, temperature=1)

Texto predito: you're kidding.	no, that's a good idea.


"you're kidding.\tno, that's a good idea."

In [None]:
predict_assisted("you're kidding.	no, ", model=model, tokenizer=dataset.tokenizer, temperature=0.2)

Texto predito: you're kidding.	no,  why don't you wash?


"you're kidding.\tno,  why don't you wash?"

In [None]:
predict_assisted("i'm doing well. how about you? ", model=model, tokenizer=dataset.tokenizer, temperature=1)

Texto predito: i'm doing well. how about you?  i'm going to the movies.


"i'm doing well. how about you?  i'm going to the movies."