## Transformer from Scratch

In [1]:
import math

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.init as init
from tqdm import tqdm

### Prep 1.: Embedding

In [2]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup scaled by sqrt(d_model).

    Args:
        vocab_size: number of rows in the embedding table (valid token indices are
            0 .. vocab_size - 1).
        d_model: embedding width.

    ``forward(x)`` takes integer token indices and returns their embeddings
    multiplied by ``sqrt(d_model)``, as in "Attention Is All You Need" (sec. 3.4).
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        # The sinusoidal positional encodings added afterwards lie in [-1, 1]; the
        # sqrt(d_model) factor sets the relative scale of token vs. position signal.
        # NOTE: PyTorch initialises nn.Embedding weights from N(0, 1) by default.
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

### Prep 2.: Positional Encoding

In [3]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i),
    with w_i = 10000 ** (-2i / d_model). The table is precomputed for
    ``max_seq_length`` positions and registered as a buffer of shape
    (1, max_seq_length, d_model), so it follows ``.to(device)`` and is stored in
    the state_dict without being trained.

    Args:
        d_model: embedding width (odd values are supported).
        max_seq_length: number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per sin column: ceil(d_model / 2) values.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 cos columns; for odd d_model that is one fewer
        # than len(div_term), which previously raised a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Assumes x is (batch, seq_len, d_model) with seq_len <= max_seq_length.
        """
        return x + self.pe[:, :x.size(1)]

### Model 1.: Attention Mechanism Layer (MultiHeadAttention Layer) + Feed Forward Layer

In [4]:
import torch.nn.functional as F

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    query/key/value are projected by bias-free linear layers, split into
    ``num_heads`` heads of width ``head_dim = d_model // num_heads``, attended with
    softmax(Q K^T / sqrt(head_dim)) V per head, re-concatenated and passed through
    ``output_linear``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise num_heads * head_dim != d_model and split_heads would fail
            # later with an opaque reshape error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """(batch_size, seq_length, d_model) -> (batch_size, num_heads, seq_length, head_dim)."""
        seq_length = x.size(1)
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        """Scaled dot-product attention over the last two dimensions.

        Args:
            query: (batch_size, num_heads, q_len, head_dim).
            key, value: (batch_size, num_heads, k_len, head_dim).
            mask: optional tensor broadcastable to (batch_size, num_heads, q_len, k_len);
                entries equal to 0 / False are excluded from attention.

        Returns:
            Tensor of shape (batch_size, num_heads, q_len, head_dim).
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # NOTE: a query row whose keys are all masked becomes all -inf, and the
            # softmax below then produces NaN for that row.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Softmax over the last dim, i.e. over key positions, so each query's
        # attention weights sum to 1 (not over head_dim).
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        """(batch_size, num_heads, seq_length, head_dim) -> (batch_size, seq_length, d_model)."""
        seq_length = x.size(2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.reshape(batch_size, seq_length, self.d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Assumes inputs are (batch, seq_len, d_model); key and value must share a
        sequence length, which may differ from the query's (cross-attention).
        Returns a tensor with the query's (batch, seq_len, d_model) shape.
        """
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        # Per-head attention *output* (weights are internal to compute_attention).
        attended = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attended, batch_size)
        return self.output_linear(output)

In [6]:
class FeedForwardSubLayer(nn.Module):
    """Position-wise feed-forward network: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    Applied independently to every position of the last dimension.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Submodule names and registration order are kept: they define the
        # state_dict keys used by saved checkpoints.
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

### Model 2.: Transformer Encoder

In [7]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder block.

    x -> LayerNorm(x + Dropout(SelfAttention(x))) -> LayerNorm(h + Dropout(FeedForward(h))).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        """Apply the block; ``src_mask`` is passed unchanged to self-attention."""
        attended = self.dropout(self.self_attn(x, x, x, mask=src_mask))
        after_attention = self.norm1(x + attended)
        transformed = self.dropout(self.ff_sublayer(after_attention))
        return self.norm2(after_attention + transformed)

In [8]:
class TransformerEncoder(nn.Module):
    """Encoder stack: token embedding, sinusoidal positions, then ``n_layers`` EncoderLayers.

    ``forward(x, src_mask)`` takes token indices and returns the final hidden states
    (one d_model vector per position).
    """

    def __init__(self, vocab_size, d_model, n_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size=vocab_size, d_model=d_model)
        self.positional_encoder = PositionalEncoding(d_model=d_model, max_seq_length=max_seq_length)
        blocks = []
        for _ in range(n_layers):
            blocks.append(EncoderLayer(d_model, num_heads, d_ff, dropout))
        self.encoder_blocks = nn.ModuleList(blocks)

    def forward(self, x, src_mask):
        hidden = self.positional_encoder(self.embedding(x))
        for block in self.encoder_blocks:
            hidden = block(hidden, src_mask)
        return hidden

In [9]:
class ClassifierHead(nn.Module):
    """Linear projection to ``num_classes`` followed by a softmax over the class axis.

    Returns class probabilities whose last dimension sums to 1. If this head is
    trained with nn.CrossEntropyLoss, use the pre-softmax logits instead, because
    that loss applies log-softmax itself.
    """

    def __init__(self, d_model, num_classes):
        super().__init__()
        self.classifier_nn = nn.Linear(d_model, num_classes)

    def forward(self, x):
        logits = self.classifier_nn(x)
        # Explicit dim: F.softmax without dim is deprecated and, for 3-D input such
        # as (batch, seq, d_model), implicitly normalises over dim 0 (the batch).
        return F.softmax(logits, dim=-1)

In [10]:
# Hyperparameters for the encoder + classifier-head demo.
# NOTE(review): later cells reassign several of these globals (e.g. vocab_size,
# dropout), so results depend on the order in which cells are run.
vocab_size = 256
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.2
seq_length = 256
num_classes = 2

transformer_encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, seq_length)
# NOTE(review): neither transformer_encoder nor classifier is used by later cells shown here.
classifier = ClassifierHead(d_model, num_classes)

### Model 3.: Transformer Decoder

In [11]:
# Demo: an all-ones (1, seq_length, seq_length) tensor, the starting point for the causal mask.
torch.ones(1, seq_length, seq_length)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [12]:
# Demo: triu(..., diagonal=1) keeps only the entries strictly above the main diagonal,
# i.e. the "future" key positions that each query position must not attend to.
torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)

tensor([[[0., 1., 1.,  ..., 1., 1., 1.],
         [0., 0., 1.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [13]:
# Causal (look-ahead) mask: True on and below the diagonal, False above it.
# MultiHeadAttention masks entries where mask == 0, so query i only sees keys j <= i.
tgt_mask = (1 - torch.triu(
  torch.ones(1, seq_length, seq_length), diagonal=1)
).bool()

In [14]:
class DecoderLayer(nn.Module):
    """Decoder-only Transformer block (post-norm, no cross-attention).

    Masked self-attention and then a feed-forward network, each wrapped as
    LayerNorm(x + Dropout(sublayer(x))).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, tgt_mask):
        """``tgt_mask`` (typically the causal mask) is passed to self-attention."""
        residual_steps = (
            (self.norm1, lambda h: self.self_attn(h, h, h, tgt_mask)),
            (self.norm2, self.ff_sublayer),
        )
        for norm, sublayer in residual_steps:
            x = norm(x + self.dropout(sublayer(x)))
        return x

In [15]:
class TransformerDecoder(nn.Module):
    """Decoder-only (GPT-style) Transformer language model.

    Token embedding -> sinusoidal positional encoding -> ``num_layers`` DecoderLayer
    blocks (masked self-attention + feed-forward) -> linear projection to
    ``vocab_size``.

    Args:
        vocab_size: number of token indices (embedding rows and output classes).
        d_model, num_heads, d_ff, dropout: forwarded to every DecoderLayer.
        num_layers: number of stacked DecoderLayer blocks.
        max_seq_length: number of positions supported by the positional encoding.

    NOTE(review): DecoderLayer is looked up when the model is constructed; if the
    cross-attention DecoderLayer cell has run by then, forward() will fail.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        # Zero-argument super(), consistent with the other modules; the explicit
        # super(TransformerDecoder, self) resolves the *global* name, which is
        # rebound when the class is redefined later in this notebook.
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_mask):
        """Return next-token logits for every position.

        Args:
            x: integer token indices, presumably (batch, seq_len).
            tgt_mask: self-attention mask passed to every layer (e.g. causal mask).

        Returns:
            Raw, unnormalised logits with last dimension ``vocab_size``. They are NOT
            log-probabilities: nn.CrossEntropyLoss and
            torch.distributions.Categorical(logits=...) normalise them internally.
            Use F.log_softmax(logits, dim=-1) if log-probabilities are needed.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        logits = self.fc(x)
        return logits

In [16]:
# Demo: a random batch of token indices drawn uniformly from [0, vocab_size).
max_seq_length = 256
batch_size = 1
vocab_size = 1000
torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

tensor([[359, 162, 723, 140, 969, 208, 259,  89, 650, 765, 962,  85, 291, 573,
         888,  84, 109, 621, 204, 842,  17, 363, 192,  26, 976, 416,  33, 980,
         528, 204,  88, 379, 752, 368,  31, 586, 287, 579, 567, 731, 144, 614,
         326, 277, 222, 295, 259, 236, 646, 104, 373, 530, 166, 674,  47, 748,
         653, 883, 614, 409, 114, 233,   5, 863, 441, 368, 255, 457, 644,  54,
         765, 431,  31, 538, 171, 237, 191, 587, 879, 992, 537, 315, 990,  79,
         338, 275, 992, 675,  64, 342, 908, 652, 235, 734, 867, 543, 345, 202,
         462, 330, 763, 367, 283, 228, 479, 996,  82, 395, 557, 289, 373, 926,
         167, 681, 336,   2, 338, 627, 917,  37, 211,  55, 619, 194, 739, 701,
         267,  49, 779, 426, 162, 762, 346,  79, 505, 360, 852, 358, 375, 547,
         288, 457,  44, 339, 748, 727, 145, 946, 667, 946, 299, 298, 263, 883,
          65, 891, 324, 355, 217, 758, 132, 591, 659, 483, 195,  45, 681, 263,
         531, 707, 529, 668, 556, 898, 190, 424, 928

In [17]:
# Forward a random batch through the decoder-only model; `output` holds per-position
# scores over the vocabulary, shape (batch_size, max_seq_length, vocab_size).
# tgt_mask from the causal-mask cell was built for seq_length = 256, matching max_seq_length.
max_seq_length = 256
batch_size = 2
vocab_size = 1000
input_tokens = torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

transformer_decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)   
output = transformer_decoder(input_tokens, tgt_mask)

In [18]:
# Show the random token batch that was fed to the decoder.
print(input_tokens)

tensor([[ 78,  50,   3, 643, 193, 693, 708, 164, 527, 685, 273, 485, 865, 445,
         449, 207,  17, 785, 786, 931, 451, 281, 580, 488, 362, 499, 870, 890,
         121, 730, 567, 568, 792, 900, 844,  31, 326, 712, 720, 312, 906, 114,
         914,  46, 393, 136, 713, 212, 928, 536, 499, 454, 810, 333, 196, 252,
         950, 907, 301, 480, 236, 271, 558, 167, 243, 386, 182, 744, 741, 282,
         545, 742, 626, 449, 441, 132, 128, 497, 450, 716, 537, 195, 496, 288,
         195, 785, 486, 196, 207,  77, 939, 899, 393, 597, 518,  93, 797,  36,
         601, 286, 368,  50, 169, 570,  54, 359, 370, 705, 107, 140,  96, 188,
          41, 720, 373, 469, 828, 298, 292, 414, 649, 808, 948, 892, 429, 868,
          95, 339, 274,  65, 753, 823, 914, 430, 623, 947, 659, 862, 550, 282,
         544, 849, 763, 122, 543, 511, 260, 728, 582,  68,  20,  74, 706, 467,
         143, 868,   2,  62, 436, 592, 147,   6, 975, 517, 175, 419, 131, 337,
         748, 702, 136, 473, 386, 972, 950, 309,  42

In [19]:
# Decoder-only output: per-position scores over the vocabulary.
print(output)

tensor([[[ 0.4554, -0.1481, -0.0099,  ..., -0.2201, -0.1905, -0.1033],
         [ 0.3121, -1.0453, -0.5668,  ...,  0.0268, -0.9447,  0.0488],
         [ 0.2442,  0.4744, -0.4564,  ...,  0.4239,  0.6048,  0.1509],
         ...,
         [-0.2265,  0.7925,  0.2404,  ..., -1.1820,  0.3571,  0.0399],
         [ 0.5802,  0.6000,  0.4702,  ..., -1.1512, -0.7085,  0.4183],
         [-0.3865,  0.5531, -0.1136,  ...,  0.8278,  0.4734,  1.0359]],

        [[-0.4036, -0.3295,  0.0918,  ...,  0.0761,  0.1926, -0.0362],
         [-0.6071,  0.3132, -0.1897,  ..., -0.5080,  0.2858,  0.0533],
         [-0.5049,  0.8557,  0.3396,  ...,  0.7998,  0.0855,  0.1365],
         ...,
         [-0.1669, -0.4839,  0.4137,  ...,  0.7798, -0.2559, -0.1794],
         [-0.0708,  0.1759, -0.9030,  ...,  0.2859,  0.3821,  0.0327],
         [-0.1576, -0.1774,  0.2091,  ...,  0.3869,  0.1064, -0.3301]]],
       grad_fn=<ViewBackward0>)


In [20]:
# NOTE(review): leftover debugging snippet (checks a (2, 256, 8, 64) -> (2, 256, 512)
# reshape of split heads); safe to delete.
# torch.randn(size=(2, 256, 8, 64)).view((2, 256, 512)) # for debugging reshapes

### Model 4.: Encoder-decoder Transformer

In [21]:
class DecoderLayer(nn.Module):
    """Encoder-decoder Transformer block (post-norm).

    Three sub-layers, each wrapped as LayerNorm(x + Dropout(sublayer(x))):
    masked self-attention over ``x``; cross-attention with ``x`` as queries and
    ``y`` (e.g. encoder output) as keys/values; and a feed-forward network.

    NOTE(review): this redefinition shadows the decoder-only DecoderLayer defined
    earlier. A decoder-only TransformerDecoder constructed after this cell runs
    builds these layers and calls them as layer(x, tgt_mask), which fails.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, tgt_mask, cross_mask):
        """Apply the block.

        Args:
            x: decoder hidden states (queries).
            y: encoder output used as keys/values for cross-attention.
            tgt_mask: mask for decoder self-attention.
            cross_mask: mask for cross-attention over ``y``.
        """
        residual_steps = (
            (self.norm1, lambda h: self.self_attn(h, h, h, tgt_mask)),
            (self.norm2, lambda h: self.cross_attn(h, y, y, cross_mask)),
            (self.norm3, self.ff_sublayer),
        )
        for norm, sublayer in residual_steps:
            x = norm(x + self.dropout(sublayer(x)))
        return x

In [None]:
class TransformerDecoder(nn.Module):
    """Decoder half of the encoder-decoder Transformer.

    Embeds target tokens, adds sinusoidal positions, runs ``num_layers``
    cross-attention DecoderLayers against encoder output ``y``, and returns
    log-probabilities over ``vocab_size``.

    NOTE(review): this redefinition shadows the decoder-only TransformerDecoder,
    whose forward is (x, tgt_mask). The training cell later builds
    `TransformerDecoder(...)` and calls `model(x, mask)` with two arguments, so
    under Restart & Run All it would get this class and raise a TypeError.
    Consider giving the two decoders different names.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        stacked = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stacked)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, y, tgt_mask, cross_mask):
        hidden = self.positional_encoding(self.embedding(x))
        for layer in self.layers:
            hidden = layer(hidden, y, tgt_mask, cross_mask)
        # log-softmax over the vocabulary: log-probabilities in (-inf, 0], computed
        # in a numerically stable way (better than log(softmax(...)) near 0).
        return F.log_softmax(self.fc(hidden), dim=-1)

In [23]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from TransformerEncoder and TransformerDecoder.

    NOTE: the constructor's argument order (num_heads before num_layers,
    max_seq_length before dropout) differs from TransformerEncoder/Decoder; it is
    kept unchanged for backward compatibility.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

    def forward(self, x, src_mask, tgt_mask, cross_mask, tgt=None):
        """Encode ``x`` and decode ``tgt`` against the encoder output.

        Args:
            x: source token indices fed to the encoder.
            src_mask: mask for encoder self-attention.
            tgt_mask: mask for decoder self-attention (typically causal).
            cross_mask: mask for decoder-to-encoder cross-attention.
            tgt: target token indices fed to the decoder. Defaults to ``x`` (the
                previous behaviour, which only suits reconstruction-style demos);
                real sequence-to-sequence use should pass the shifted target.

        Returns:
            The decoder's output for ``tgt``.
        """
        encoder_output = self.encoder(x, src_mask)
        decoder_input = x if tgt is None else tgt
        decoder_output = self.decoder(decoder_input, encoder_output, tgt_mask, cross_mask)
        return decoder_output

In [23]:
def generate_padding_mask(sequence, pad_token=0):
    """Build a key-padding mask for MultiHeadAttention.

    Args:
        sequence: tensor of token indices, typically (batch, seq_len).
        pad_token: index treated as padding (default 0).

    Returns:
        Bool tensor that is True for real tokens and False for padding, with two
        singleton axes inserted after dim 0. For (batch, seq_len) input this is
        (batch, 1, 1, seq_len), which broadcasts over heads and query positions.
    """
    keep = sequence != pad_token
    keep = keep.unsqueeze(1)
    return keep.unsqueeze(2)

# Padding mask for the random demo batch (token 0 counts as padding). It is reused as the
# cross-attention mask so decoder queries ignore padded encoder positions.
# sequence = torch.tensor([2, 6, 30, 120, 0, 0, 0, 0])
src_mask = generate_padding_mask(input_tokens)
cross_mask = src_mask

In [26]:
# Run the encoder-decoder Transformer on the random demo batch.
# Note: Transformer's constructor order is (vocab_size, d_model, num_heads, num_layers,
# d_ff, max_seq_length, dropout), unlike TransformerEncoder/TransformerDecoder.
transformer = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
outputs = transformer(input_tokens, src_mask, tgt_mask, cross_mask)
print(outputs)
print(outputs.shape)

tensor([[[-7.2952, -7.7172, -6.4120,  ..., -8.0577, -6.7412, -6.9631],
         [-8.0978, -5.6856, -6.9766,  ..., -7.3987, -5.9954, -7.1549],
         [-6.8227, -6.9482, -6.8970,  ..., -6.4872, -6.5135, -6.8972],
         ...,
         [-7.2849, -7.2086, -6.8851,  ..., -8.2023, -7.0849, -7.5703],
         [-6.5633, -6.9625, -7.1137,  ..., -7.2809, -6.3619, -6.8622],
         [-7.9815, -6.7311, -6.9914,  ..., -7.6957, -5.9318, -7.9336]],

        [[-7.9445, -7.5784, -7.1699,  ..., -7.4563, -7.4804, -7.3523],
         [-8.1715, -6.8377, -6.5618,  ..., -6.8760, -6.6384, -5.8102],
         [-8.4128, -7.2745, -7.4224,  ..., -7.3137, -6.4476, -6.8861],
         ...,
         [-6.9692, -6.8992, -7.4177,  ..., -7.3548, -7.0166, -6.5395],
         [-6.2935, -7.0444, -7.1079,  ..., -7.5451, -6.3904, -7.3940],
         [-7.0938, -7.0121, -7.0166,  ..., -7.9626, -6.6922, -8.1125]]],
       grad_fn=<LogSoftmaxBackward0>)
torch.Size([2, 256, 1000])


### Tokenizer

In [21]:
### Character-level vocabulary of 260 tokens:
#   0 -> <PAD>, 1..256 -> chr(1)..chr(256), 257 -> <SOS>, 258 -> <EOS>, 259 -> <UNK>.
# (Code points 1..256 cover Latin-1 plus U+0100, not just 7-bit ASCII.)

class TokenizerChar:
    """Maps single characters to integer token indices and back.

    A character with code point 1..256 is encoded as its own code point; the
    special tokens use the indices listed above, and any other input encodes to
    <UNK>. ``vocab_size`` (260) is the number of distinct indices.
    """

    def __init__(self):
        self.chr_to_idx = {chr(v): v for v in range(1, 257)}
        self.chr_to_idx['<SOS>'] = 257
        self.chr_to_idx['<EOS>'] = 258
        self.chr_to_idx['<PAD>'] = 0
        self.chr_to_idx['<UNK>'] = 259

        self.idx_to_chr = {v: k for k, v in self.chr_to_idx.items()}

        self.vocab_size = len(self.chr_to_idx)

    def encode(self, char):
        """Return the index for ``char``; anything outside the vocabulary maps to <UNK>."""
        # Look up <UNK> instead of hard-coding 259 so the two cannot drift apart.
        return self.chr_to_idx.get(char, self.unk_token_idx())

    def decode(self, token_idx):
        """Return the character / special-token string for ``token_idx`` (KeyError if unknown)."""
        return self.idx_to_chr[token_idx]

    def sos_token(self):
        return '<SOS>'

    def sos_token_idx(self):
        return self.chr_to_idx['<SOS>']

    def eos_token(self):
        return '<EOS>'

    def eos_token_idx(self):
        return self.chr_to_idx['<EOS>']

    def pad_token(self):
        return '<PAD>'

    def pad_token_idx(self):
        return self.chr_to_idx['<PAD>']

    def unk_token(self):
        return '<UNK>'

    def unk_token_idx(self):
        return self.chr_to_idx['<UNK>']

    def get_vocab_size(self):
        return self.vocab_size


### Dataset

In [22]:
# Count the lines of the raw dialogs file (DatasetDialogs uses one example per line).
# If the file ends with '\n', split('\n') also yields and counts a trailing empty string.
with open('dataset_text/dialogs.txt') as file:
    print(len(file.read().split('\n')))

3725


In [23]:
# Toy list to check the slicing used to build (x, y) pairs in DatasetDialogs.
lista = [1, 2, 3, 4]

In [24]:
# All but the last element -> model input x.
lista[:-1]

[1, 2, 3]

In [25]:
# All but the first element -> target y (x shifted left by one).
lista[1:]

[2, 3, 4]

In [26]:
from torch.utils.data import Dataset


class DatasetDialogs(Dataset):
    """Character-level next-token dataset with one example per line of a text file.

    Each line is tokenized as ``[<SOS>] + chars + [<EOS>]`` and right-padded with
    <PAD> to ``sentence_length + 1`` indices. Lines with at least
    ``sentence_length`` characters are cut to ``sentence_length`` characters, and
    their last character is replaced by <EOS>. An item is
    ``(x, y) = (tokens[:-1], tokens[1:])``: ``y`` is ``x`` shifted left by one.

    Args:
        dataset_path: path to the text file (lines are split on '\\n').
        sentence_length: length of each returned ``x`` / ``y`` tensor.
    """

    def __init__(self, dataset_path, sentence_length):
        self.dataset_path = dataset_path
        self.sentence_length = sentence_length
        self.tokenizer = TokenizerChar()
        # Lines are read lazily once and cached. Previously every __len__ /
        # __getitem__ call re-read and re-split the whole file (O(N) per item).
        self._lines = None

    def _read_lines(self):
        """Return the file's lines, reading the file only on first use."""
        lines = getattr(self, '_lines', None)
        if lines is None:
            with open(self.dataset_path, 'r') as dataset:
                lines = dataset.read().split('\n')
            self._lines = lines
        return lines

    def __len__(self):
        # A trailing newline in the file yields a final empty line, which is counted.
        return len(self._read_lines())

    def get_shape(self):
        return (len(self._read_lines()), self.sentence_length)

    def _encode_sentence(self, sentence):
        """Tokenize one line into exactly ``sentence_length + 1`` indices."""
        tok = self.tokenizer
        if len(sentence) < self.sentence_length:
            input_tokens = [tok.sos_token_idx()] + [tok.encode(char) for char in sentence]
            input_tokens.append(tok.eos_token_idx())
            input_tokens += [tok.pad_token_idx()] * (self.sentence_length + 1 - len(input_tokens))
        else:
            # Lines of length == or > sentence_length: keep the first
            # sentence_length characters, then overwrite the last with <EOS>.
            truncated = sentence[:self.sentence_length]
            input_tokens = [tok.sos_token_idx()] + [tok.encode(char) for char in truncated]
            input_tokens[-1] = tok.eos_token_idx()
        assert len(input_tokens) == self.sentence_length + 1, f"Lista de índices de tokens não possui mesmo tamanho que 'sentence_length'! len(input_tokens): {len(input_tokens)} - self.sentence_length: {self.sentence_length}"
        return input_tokens

    def __getitem__(self, line_idx):
        input_tokens = self._encode_sentence(self._read_lines()[line_idx])
        x = torch.tensor(input_tokens[:-1])
        y = torch.tensor(input_tokens[1:])
        return x, y

In [27]:
# Dataset over the full dialogs file (200-token examples); used below for inspection
# and for its TokenizerChar in predict().
dataset = DatasetDialogs('dataset_text/dialogs.txt', 200)

In [28]:
# Sanity check: token index 104 is 'h' (TokenizerChar maps characters to their code points).
chr(104)

'h'

In [29]:
# Inspect one (x, y) pair: x starts with <SOS>=257; the text is followed by <EOS>=258 and <PAD>=0.
dataset[0]

(tensor([257, 104, 105,  44,  32, 104, 111, 119,  32,  97, 114, 101,  32, 121,
         111, 117,  32, 100, 111, 105, 110, 103,  63,   9, 105,  39, 109,  32,
         102, 105, 110, 101,  46,  32, 104, 111, 119,  32,  97,  98, 111, 117,
         116,  32, 121, 111, 117, 114, 115, 101, 108, 102,  63, 258,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0

### Training Loop

In [30]:
import torch
# Check whether a CUDA device is visible to PyTorch.
torch.cuda.is_available()

True

In [31]:
# Smoke test: allocate a tensor on the default CUDA device (raises if CUDA is unavailable).
torch.zeros(1).cuda()

tensor([0.], device='cuda:0')

In [32]:
# Train the decoder-only character model (next-character prediction) and evaluate
# on the held-out split after every epoch.
# NOTE(review): this cell expects the decoder-only TransformerDecoder / DecoderLayer
# (forward(x, tgt_mask)). Both names are redefined with cross-attention variants in
# the encoder-decoder section, so under Restart & Run All `model(x, causal_mask)`
# fails; re-run the decoder-only definitions first, or rename the encoder-decoder classes.
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
sequence_length = 200
batch_size = 32
dataset_train = DatasetDialogs('dataset_text/dialogs_train.txt', sequence_length)
dataset_test = DatasetDialogs('dataset_text/dialogs_test.txt', sequence_length)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
vocab_size = dataset_train.tokenizer.get_vocab_size()

d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_length = sequence_length
model = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
model.to(device)

# Causal mask: True on/below the diagonal. Moved to the device once, not per batch.
tgt_mask = (1 - torch.triu(
  torch.ones(1, sequence_length, sequence_length), diagonal=1)
).bool()
causal_mask = tgt_mask.to(device)

def init_weights(module):
    """Kaiming-normal (fan_in, ReLU gain) weights and zero bias for every nn.Linear."""
    if isinstance(module, (nn.Linear)):
        init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
model.apply(init_weights)

optimizer = Adam(model.parameters(), lr=1e-4)
# Index 0 is <PAD>: padded target positions do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
n_epochs = 50
# len(DataLoader) also counts the final partial batch (floor division did not).
n_batches = len(dataloader_train)
# torch.save does not create missing directories.
os.makedirs('model_checkpoints', exist_ok=True)

print("Starting model training...")
for epoch in range(n_epochs):
    print(f"Epoch: {epoch + 1}")
    avg_loss = 0
    model.train()
    for batch_idx, batch in enumerate(tqdm(dataloader_train, total=n_batches)):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, causal_mask)
        loss = loss_fn(outputs.reshape(-1, vocab_size), y.reshape(-1))
        avg_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), f'model_checkpoints/model_checkpoint_{epoch+1}.pth')

    avg_loss /= (batch_idx + 1)
    print(f"Average epoch training loss: {avg_loss}")
    print(f"Last batch training loss: {loss}")

    model.eval()
    avg_loss = 0
    # Validation needs no autograd graph; no_grad saves memory and time.
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader_test):
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            outputs = model(x, causal_mask)
            loss = loss_fn(outputs.reshape(-1, vocab_size), y.reshape(-1))
            avg_loss += loss.item()

    avg_loss /= (batch_idx + 1)
    print(f"Epoch validation loss: {avg_loss}")
    

cuda
Starting model training...
Epoch: 1


100%|██████████| 107/107 [00:15<00:00,  7.09it/s]


Average epoch training loss: 3.251866808561521
Last batch training loss: 3.145986795425415
Epoch validation loss: 3.067007899284363
Epoch: 2


100%|██████████| 107/107 [00:14<00:00,  7.22it/s]


Average epoch training loss: 3.1171553469149864
Last batch training loss: 3.0974152088165283
Epoch validation loss: 3.0519294500350953
Epoch: 3


100%|██████████| 107/107 [00:15<00:00,  7.09it/s]


Average epoch training loss: 3.0806717471541645
Last batch training loss: 3.0185256004333496
Epoch validation loss: 2.96907114982605
Epoch: 4


100%|██████████| 107/107 [00:16<00:00,  6.58it/s]


Average epoch training loss: 2.8311039643866995
Last batch training loss: 2.5665242671966553
Epoch validation loss: 2.554374623298645
Epoch: 5


100%|██████████| 107/107 [00:16<00:00,  6.48it/s]


Average epoch training loss: 2.482934430380848
Last batch training loss: 2.382699728012085
Epoch validation loss: 2.376488709449768
Epoch: 6


100%|██████████| 107/107 [00:16<00:00,  6.44it/s]


Average epoch training loss: 2.355214990188028
Last batch training loss: 2.2573342323303223
Epoch validation loss: 2.2929923295974732
Epoch: 7


100%|██████████| 107/107 [00:16<00:00,  6.52it/s]


Average epoch training loss: 2.270636551848082
Last batch training loss: 2.3014845848083496
Epoch validation loss: 2.20705885887146
Epoch: 8


100%|██████████| 107/107 [00:16<00:00,  6.37it/s]


Average epoch training loss: 2.195249873901082
Last batch training loss: 2.226848840713501
Epoch validation loss: 2.156653356552124
Epoch: 9


100%|██████████| 107/107 [00:16<00:00,  6.33it/s]


Average epoch training loss: 2.128161200853152
Last batch training loss: 2.1957952976226807
Epoch validation loss: 2.111274528503418
Epoch: 10


100%|██████████| 107/107 [00:16<00:00,  6.39it/s]


Average epoch training loss: 2.0698700118287703
Last batch training loss: 2.0564870834350586
Epoch validation loss: 2.0541015625
Epoch: 11


100%|██████████| 107/107 [00:16<00:00,  6.39it/s]


Average epoch training loss: 2.014673773373399
Last batch training loss: 2.03955078125
Epoch validation loss: 2.0065454125404356
Epoch: 12


100%|██████████| 107/107 [00:16<00:00,  6.39it/s]


Average epoch training loss: 1.9675393672747032
Last batch training loss: 2.0010063648223877
Epoch validation loss: 1.9909847497940063
Epoch: 13


100%|██████████| 107/107 [00:17<00:00,  6.17it/s]


Average epoch training loss: 1.9318829685728127
Last batch training loss: 1.9237276315689087
Epoch validation loss: 1.9282773375511169
Epoch: 14


100%|██████████| 107/107 [00:17<00:00,  6.18it/s]


Average epoch training loss: 1.8895205816375875
Last batch training loss: 1.8973420858383179
Epoch validation loss: 1.9031854629516602
Epoch: 15


100%|██████████| 107/107 [00:17<00:00,  5.96it/s]


Average epoch training loss: 1.8533958107511574
Last batch training loss: 1.820874571800232
Epoch validation loss: 1.8847570300102234
Epoch: 16


100%|██████████| 107/107 [00:17<00:00,  6.11it/s]


Average epoch training loss: 1.8202093661388503
Last batch training loss: 1.8321703672409058
Epoch validation loss: 1.8510371446609497
Epoch: 17


100%|██████████| 107/107 [00:17<00:00,  6.26it/s]


Average epoch training loss: 1.7874014533568765
Last batch training loss: 1.8062423467636108
Epoch validation loss: 1.820282793045044
Epoch: 18


100%|██████████| 107/107 [00:17<00:00,  6.18it/s]


Average epoch training loss: 1.7613285791094058
Last batch training loss: 1.7036370038986206
Epoch validation loss: 1.8241514563560486
Epoch: 19


100%|██████████| 107/107 [00:18<00:00,  5.92it/s]


Average epoch training loss: 1.7362059221089443
Last batch training loss: 1.7391769886016846
Epoch validation loss: 1.806203830242157
Epoch: 20


100%|██████████| 107/107 [00:17<00:00,  6.19it/s]


Average epoch training loss: 1.7072206301109814
Last batch training loss: 1.6881014108657837
Epoch validation loss: 1.7744478464126587
Epoch: 21


100%|██████████| 107/107 [00:17<00:00,  6.22it/s]


Average epoch training loss: 1.6828452560389153
Last batch training loss: 1.758082389831543
Epoch validation loss: 1.7688653826713563
Epoch: 22


100%|██████████| 107/107 [00:17<00:00,  6.09it/s]


Average epoch training loss: 1.6612236009579953
Last batch training loss: 1.6727275848388672
Epoch validation loss: 1.7479785323143004
Epoch: 23


100%|██████████| 107/107 [00:17<00:00,  6.19it/s]


Average epoch training loss: 1.6375378726798797
Last batch training loss: 1.6986958980560303
Epoch validation loss: 1.7425859689712524
Epoch: 24


100%|██████████| 107/107 [00:17<00:00,  6.12it/s]


Average epoch training loss: 1.621959065722528
Last batch training loss: 1.6411521434783936
Epoch validation loss: 1.7306982636451722
Epoch: 25


100%|██████████| 107/107 [00:17<00:00,  6.03it/s]


Average epoch training loss: 1.5988906432535046
Last batch training loss: 1.6227463483810425
Epoch validation loss: 1.7257278919219972
Epoch: 26


100%|██████████| 107/107 [00:17<00:00,  6.10it/s]


Average epoch training loss: 1.5746300989222304
Last batch training loss: 1.5865010023117065
Epoch validation loss: 1.7132383704185485
Epoch: 27


100%|██████████| 107/107 [00:17<00:00,  6.21it/s]


Average epoch training loss: 1.5535674941873996
Last batch training loss: 1.5394823551177979
Epoch validation loss: 1.7039262175559997
Epoch: 28


100%|██████████| 107/107 [00:17<00:00,  5.98it/s]


Average epoch training loss: 1.537041128238785
Last batch training loss: 1.5974630117416382
Epoch validation loss: 1.690328884124756
Epoch: 29


100%|██████████| 107/107 [00:18<00:00,  5.80it/s]


Average epoch training loss: 1.5283951202285624
Last batch training loss: 1.458520531654358
Epoch validation loss: 1.6743595957756043
Epoch: 30


100%|██████████| 107/107 [00:18<00:00,  5.88it/s]


Average epoch training loss: 1.5086469204626352
Last batch training loss: 1.4589226245880127
Epoch validation loss: 1.675321877002716
Epoch: 31


100%|██████████| 107/107 [00:17<00:00,  6.03it/s]


Average epoch training loss: 1.4902234913032746
Last batch training loss: 1.4634512662887573
Epoch validation loss: 1.6883554935455323
Epoch: 32


100%|██████████| 107/107 [00:17<00:00,  5.98it/s]


Average epoch training loss: 1.477960580977324
Last batch training loss: 1.4394700527191162
Epoch validation loss: 1.6527459740638732
Epoch: 33


100%|██████████| 107/107 [00:17<00:00,  6.01it/s]


Average epoch training loss: 1.455547853050945
Last batch training loss: 1.413841962814331
Epoch validation loss: 1.6618195295333862
Epoch: 34


100%|██████████| 107/107 [00:17<00:00,  5.98it/s]


Average epoch training loss: 1.438943726994167
Last batch training loss: 1.388890027999878
Epoch validation loss: 1.6741887092590333
Epoch: 35


100%|██████████| 107/107 [00:18<00:00,  5.92it/s]


Average epoch training loss: 1.4230147109967526
Last batch training loss: 1.403852105140686
Epoch validation loss: 1.652684998512268
Epoch: 36


100%|██████████| 107/107 [00:17<00:00,  6.04it/s]


Average epoch training loss: 1.4066343062391906
Last batch training loss: 1.342697262763977
Epoch validation loss: 1.649178671836853
Epoch: 37


100%|██████████| 107/107 [00:18<00:00,  5.89it/s]


Average epoch training loss: 1.389585601949246
Last batch training loss: 1.4803358316421509
Epoch validation loss: 1.6254518628120422
Epoch: 38


100%|██████████| 107/107 [00:17<00:00,  6.11it/s]


Average epoch training loss: 1.3755330292978019
Last batch training loss: 1.3236342668533325
Epoch validation loss: 1.6303140997886658
Epoch: 39


100%|██████████| 107/107 [00:17<00:00,  5.97it/s]


Average epoch training loss: 1.3635048142103392
Last batch training loss: 1.3215587139129639
Epoch validation loss: 1.6374784469604493
Epoch: 40


100%|██████████| 107/107 [00:17<00:00,  5.98it/s]


Average epoch training loss: 1.3515681819381
Last batch training loss: 1.2559921741485596
Epoch validation loss: 1.65574107170105
Epoch: 41


100%|██████████| 107/107 [00:18<00:00,  5.90it/s]


Average epoch training loss: 1.3339243726195575
Last batch training loss: 1.422647476196289
Epoch validation loss: 1.6396844387054443
Epoch: 42


100%|██████████| 107/107 [00:18<00:00,  5.91it/s]


Average epoch training loss: 1.3211579512212879
Last batch training loss: 1.3035699129104614
Epoch validation loss: 1.6542369842529296
Epoch: 43


100%|██████████| 107/107 [00:17<00:00,  5.95it/s]


Average epoch training loss: 1.305188597919785
Last batch training loss: 1.337019443511963
Epoch validation loss: 1.6395858883857728
Epoch: 44


100%|██████████| 107/107 [00:18<00:00,  5.94it/s]


Average epoch training loss: 1.2915287563733966
Last batch training loss: 1.291664958000183
Epoch validation loss: 1.634751522541046
Epoch: 45


100%|██████████| 107/107 [00:17<00:00,  6.04it/s]


Average epoch training loss: 1.2770134065752832
Last batch training loss: 1.3442552089691162
Epoch validation loss: 1.6456668972969055
Epoch: 46


100%|██████████| 107/107 [00:17<00:00,  5.99it/s]


Average epoch training loss: 1.2654094807455474
Last batch training loss: 1.372725009918213
Epoch validation loss: 1.6311949372291565
Epoch: 47


100%|██████████| 107/107 [00:17<00:00,  6.00it/s]


Average epoch training loss: 1.2568734249221944
Last batch training loss: 1.229806900024414
Epoch validation loss: 1.6474329233169556
Epoch: 48


100%|██████████| 107/107 [00:17<00:00,  5.97it/s]


Average epoch training loss: 1.241687081684576
Last batch training loss: 1.2651147842407227
Epoch validation loss: 1.636780023574829
Epoch: 49


100%|██████████| 107/107 [00:17<00:00,  5.95it/s]


Average epoch training loss: 1.2305150577955157
Last batch training loss: 1.2156800031661987
Epoch validation loss: 1.6359745740890503
Epoch: 50


100%|██████████| 107/107 [00:18<00:00,  5.91it/s]


Average epoch training loss: 1.2168593150432978
Last batch training loss: 1.1891847848892212
Epoch validation loss: 1.6346149086952209


### Text Generation / Model Prediction

In [33]:
def pad_sequence_to_length(sequence, sequence_length, pad_token_idx):
    """
    Right-pad (or truncate) a token sequence to exactly ``sequence_length`` items.

    Args:
        sequence (list): Token indices to pad or truncate. It is NOT modified.
        sequence_length (int): Desired length.
        pad_token_idx (int): Index appended as padding.

    Returns:
        list: A new list of length ``sequence_length``.
    """
    # Build a new list; the previous `sequence += ...` extended the caller's list in place.
    padded = list(sequence[:sequence_length])
    padded.extend([pad_token_idx] * (sequence_length - len(padded)))
    return padded

In [34]:
import torch

def predict(start_text, model, tokenizer, sequence_length=200, temperature=1.0):
    """Autoregressively extend ``start_text`` by sampling one character at a time.

    The prompt is encoded as ``[<SOS>] + chars`` and padded to ``sequence_length``.
    At each step, the model output at the last filled position is divided by
    ``temperature`` and used as the logits of a categorical distribution. The
    sampled token is written into the next free position. Generation stops at
    <EOS> or when the context is full.

    Args:
        start_text: prompt string.
        model: decoder called as ``model(tokens, tgt_mask=mask)``. Its output at
            position p is treated as unnormalised scores for the token at p + 1.
        tokenizer: TokenizerChar-like object.
        sequence_length: context length; should match the model's max_seq_length.
        temperature: must be > 0; lower values make sampling greedier.

    Returns:
        The prompt followed by the generated characters (also printed).

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model.eval()
    device = next(model.parameters()).device

    # Position 0 holds <SOS>, so the prompt occupies positions 1..len(start_text).
    sequence = [tokenizer.sos_token_idx()] + [tokenizer.encode(c) for c in start_text]
    first_free = min(len(sequence), sequence_length)
    sequence = pad_sequence_to_length(sequence, sequence_length, tokenizer.pad_token_idx())
    input_tokens = torch.tensor(sequence, dtype=torch.long).unsqueeze(0).to(device)
    # The causal mask is the same at every step, so build it once.
    tgt_mask = torch.tril(torch.ones(sequence_length, sequence_length, device=device)).bool()

    current_text = start_text

    with torch.no_grad():
        # Start at the first free slot (len(start_text) + 1 because of <SOS>). The old
        # loop started at len(start_text), overwriting the last prompt character
        # (or <SOS> for an empty prompt) in the model's context.
        for i in range(first_free, sequence_length):
            outputs = model(input_tokens, tgt_mask=tgt_mask)
            # The output at position i - 1 predicts the token at position i.
            log_probs = outputs[0, i - 1] / temperature

            predicted_token_idx = torch.distributions.Categorical(logits=log_probs).sample().item()

            if predicted_token_idx == tokenizer.eos_token_idx():
                break

            current_text += tokenizer.decode(predicted_token_idx)
            input_tokens[0, i] = predicted_token_idx

    print('Texto predito:', current_text)
    return current_text


In [102]:
# Sample a continuation at low temperature (0.2, close to greedy) using `dataset`'s
# TokenizerChar, the same character vocabulary the training datasets use.
predict('i am doing fine,  ', model=model, tokenizer=dataset.tokenizer, temperature=0.2)

Texto predito: i am doing fine,  but i don't have every day.	i'll bet the best a lot of the bad.


"i am doing fine,  but i don't have every day.\ti'll bet the best a lot of the bad."

In [88]:
# Another low-temperature sample. In the dialogs data a tab separates a line's two turns,
# so the model tends to emit '\t' before the reply.
predict('hello, are you happy? ', model=model, tokenizer=dataset.tokenizer, temperature=0.2)

Texto predito: hello, are you happy? 	i don't know.


"hello, are you happy? \ti don't know."