## Transformer from Scratch

In [1]:
import math

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.init as init
from tqdm import tqdm

### Prep 1.: Embedding

In [2]:
class InputEmbeddings(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        # Multiplicação com math.sqrt(d_model) pois Positional Encoder tem valores iniciais entre -1 e 1 devido a sin e cos.
        # Essa multiplicação escala os valores da inicialização do nn.Embedding para próximo da escala do Positional Encoder.
        # Inicialização do nn.Embedding é normal com média 0 e standard deviation embedding_dim ** -0.5
        return self.embedding(x) * math.sqrt(self.d_model)

### Prep 2.: Positional Encoding

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

### Model 1.: Attention Mechanism Layer (MultiHeadAttention Layer) + Feed Forward Layer

In [4]:
import torch.nn.functional as F

**bias=False explanation**

For certain types of layers, such as transformers and convolutional layers, including a bias term is unnecessary and adds unnecessary overhead to the model.

The reason for this is that these layers are typically followed by a normalization layer, such as Batch Normalization or Layer Normalization. These normalization layers center the data at mean=0 (and std=1), effectively removing any bias.

Therefore, it is common practice to omit the bias term in transformers and convolutional layers that are preceded by a normalization layer.

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # Entra x com shape (batch_size, seq_length, d_model)
        seq_length = x.size(1)
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        # Sai y com shape (batch_size, num_heads, seq_length, head_dim)
        return x.permute(0, 2, 1, 3)
    
    def compute_attention(self, query, key, value, mask=None):
        # Shape de query, key, value (batch_size, num_heads, seq_length, head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1) # dim -1 significa softmax computada ao longo da dimensão head_dim, que é um chunk de embedding.
        return torch.matmul(attention_weights, value)
    
    def combine_heads(self, x, batch_size):
        seq_length = x.size(2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.reshape(batch_size, seq_length, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        
        attention_weights = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attention_weights, batch_size)
        return self.output_linear(output)

In [6]:
class FeedForwardSubLayer(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

### Model 2.: Transformer Encoder

In [7]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        attn_output = self.self_attn(x, x, x, mask=src_mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [8]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size=vocab_size, d_model=d_model)
        self.positional_encoder = PositionalEncoding(d_model=d_model, max_seq_length=max_seq_length)
        self.encoder_blocks = nn.ModuleList(
            [
                EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(n_layers) # noqa E501
            ]
        )

    def forward(self, x, src_mask):
        x = self.embedding(x)
        x = self.positional_encoder(x)
        for layer in self.encoder_blocks:
            x = layer(x, src_mask)
        return x

In [9]:
class ClassifierHead(nn.Module):
    def __init__(self, d_model, num_classes):
        super().__init__()
        self.classifier_nn = nn.Linear(d_model, num_classes)

    def forward(self, x):
        logits = self.classifier_nn(x)
        return F.softmax(logits)

In [10]:
vocab_size = 256
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.2
seq_length = 256
num_classes = 2

transformer_encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, seq_length)
classifier = ClassifierHead(d_model, num_classes)

### Model 3.: Transformer Decoder

In [11]:
torch.ones(1, seq_length, seq_length)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [12]:
torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)

tensor([[[0., 1., 1.,  ..., 1., 1., 1.],
         [0., 0., 1.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [13]:
tgt_mask = (1 - torch.triu(
  torch.ones(1, seq_length, seq_length), diagonal=1)
).bool()

In [14]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [15]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        x = self.fc(x)
        # Log de probabilidades para computação mais rápida e maior estabilidade com probas perto de 0.
        # Mapeia [0, 1] para (-inf, 0]
        return F.log_softmax(x, dim=-1)

In [16]:
max_seq_length = 256
batch_size = 1
vocab_size = 1000
torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

tensor([[457, 812, 385, 732, 773, 206, 769, 487,  73, 416, 620, 975, 457,  78,
         520, 937, 527, 399,  70, 748, 889, 848, 751, 639, 854, 194, 349, 856,
         844, 445, 280, 960, 463, 209, 295, 898, 939, 587, 718, 145, 989, 267,
         336, 619, 852, 244, 960, 150, 805, 105, 525, 132, 372, 743,  17, 806,
         763, 611, 564, 843, 467, 848, 446, 490, 723, 537,  75, 268, 798, 338,
         127, 565,  60, 117, 784, 717, 671, 817,  85, 455, 739, 973, 626, 364,
         409,  51, 529, 233, 451,  40, 893, 658, 723, 679, 570, 482, 972, 385,
         579, 488, 793, 486, 748, 272, 464, 830, 336, 486, 717, 297, 982, 315,
         711,  63, 848, 434, 873, 677, 602, 711,  35, 229, 235, 298, 667, 854,
         205, 307, 954, 709, 409,  37, 759, 603, 790, 645, 564, 760, 386, 302,
         649,  88, 739, 485, 971, 822, 181, 271,  56,  41, 484, 121, 204, 594,
         285, 799, 122, 180, 947, 367, 443,  49, 680, 947, 311, 293, 821, 353,
         910, 516,  30, 976, 578, 951, 893, 630, 152

In [17]:
max_seq_length = 256
batch_size = 2
vocab_size = 1000
input_tokens = torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

transformer_decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)   
output = transformer_decoder(input_tokens, tgt_mask)

In [18]:
print(input_tokens)

tensor([[536, 662, 181, 601, 254, 640, 329, 121, 629, 260, 467, 115, 594, 491,
         408, 209, 767, 377, 456, 882, 142, 285, 964,  83, 397,  49, 952, 613,
         504, 180, 520, 379, 423, 514, 757, 532, 536, 502,  28, 246, 227, 892,
         546, 599, 337, 421, 465, 642, 221, 824, 736, 993, 998, 654, 129, 671,
         564, 592, 771, 725, 286, 123, 752, 981,  57, 759, 877, 294, 849, 181,
         504,  37, 602, 908, 551,  22, 673, 252, 926, 666, 565, 718,  58, 336,
         277, 128,  16, 512,  36, 834, 249, 214, 447, 635, 730, 554, 899, 771,
         716, 454, 342, 339, 139, 487, 701, 404, 654, 829, 601, 625, 160, 405,
         944, 543, 269, 179,  30, 804,  45, 379, 292, 757, 754, 300, 326, 147,
         628,  50, 274, 639,  63, 903, 992, 497, 930, 633, 752, 313, 472, 920,
         925, 608, 992, 769,  82, 346, 575, 698, 722, 115, 695, 576,  18, 124,
         818, 822, 352, 402, 769, 722, 833, 485, 409, 273, 950,  14, 487,  80,
         710, 975, 445, 849, 970, 433, 288, 849, 232

In [19]:
print(output)

tensor([[[-6.4395, -6.9373, -7.8817,  ..., -7.6871, -6.9264, -6.9265],
         [-7.3706, -6.5637, -8.5538,  ..., -8.2631, -6.3927, -6.7559],
         [-6.8448, -7.1833, -7.7301,  ..., -7.5778, -7.2273, -5.6265],
         ...,
         [-8.1037, -7.1378, -6.7672,  ..., -7.9400, -6.4561, -7.3698],
         [-7.0581, -6.8538, -7.3437,  ..., -6.4154, -7.7382, -6.6302],
         [-7.4199, -6.4425, -8.1437,  ..., -5.8654, -6.9702, -6.9465]],

        [[-6.7747, -7.1154, -8.1558,  ..., -7.1334, -7.1462, -7.1406],
         [-7.0469, -6.4632, -7.1400,  ..., -6.8733, -7.1143, -7.3796],
         [-7.3536, -6.4732, -7.4698,  ..., -7.4088, -6.4714, -7.1459],
         ...,
         [-7.0343, -7.0456, -7.1654,  ..., -7.0362, -7.5145, -6.4881],
         [-7.2127, -6.8117, -6.9803,  ..., -7.0661, -7.0490, -8.3353],
         [-6.1795, -7.5397, -7.5181,  ..., -7.4484, -7.6458, -7.4995]]],
       grad_fn=<LogSoftmaxBackward0>)


In [20]:
# torch.randn(size=(2, 256, 8, 64)).view((2, 256, 512)) # for debugging reshapes

### Model 4.: Encoder-decoder Transformer

In [21]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, tgt_mask, cross_mask):
        self_attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        cross_attn_output = self.cross_attn(x, y, y, cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [None]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, y, tgt_mask, cross_mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, y, tgt_mask, cross_mask)
        x = self.fc(x)
        # Log de probabilidades para computação mais rápida e maior estabilidade com probas perto de 0.
        # Mapeia [0, 1] para (-inf, 0]
        return F.log_softmax(x, dim=-1)

In [23]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

    def forward(self, x, src_mask, tgt_mask, cross_mask):
        encoder_output = self.encoder(x, src_mask)
        decoder_output = self.decoder(x, encoder_output, tgt_mask, cross_mask)
        return decoder_output

In [23]:
def generate_padding_mask(sequence, pad_token=0):
    # Mask out padding tokens (assumes pad_token is 0)
    return (sequence != pad_token).unsqueeze(1).unsqueeze(2)

# sequence = torch.tensor([2, 6, 30, 120, 0, 0, 0, 0])
src_mask = generate_padding_mask(input_tokens)
cross_mask = src_mask

In [26]:
transformer = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
outputs = transformer(input_tokens, src_mask, tgt_mask, cross_mask)
print(outputs)
print(outputs.shape)

tensor([[[-7.2952, -7.7172, -6.4120,  ..., -8.0577, -6.7412, -6.9631],
         [-8.0978, -5.6856, -6.9766,  ..., -7.3987, -5.9954, -7.1549],
         [-6.8227, -6.9482, -6.8970,  ..., -6.4872, -6.5135, -6.8972],
         ...,
         [-7.2849, -7.2086, -6.8851,  ..., -8.2023, -7.0849, -7.5703],
         [-6.5633, -6.9625, -7.1137,  ..., -7.2809, -6.3619, -6.8622],
         [-7.9815, -6.7311, -6.9914,  ..., -7.6957, -5.9318, -7.9336]],

        [[-7.9445, -7.5784, -7.1699,  ..., -7.4563, -7.4804, -7.3523],
         [-8.1715, -6.8377, -6.5618,  ..., -6.8760, -6.6384, -5.8102],
         [-8.4128, -7.2745, -7.4224,  ..., -7.3137, -6.4476, -6.8861],
         ...,
         [-6.9692, -6.8992, -7.4177,  ..., -7.3548, -7.0166, -6.5395],
         [-6.2935, -7.0444, -7.1079,  ..., -7.5451, -6.3904, -7.3940],
         [-7.0938, -7.0121, -7.0166,  ..., -7.9626, -6.6922, -8.1125]]],
       grad_fn=<LogSoftmaxBackward0>)
torch.Size([2, 256, 1000])


### Tokenizer

In [20]:
### 259 tokens possíveis:
# ASCII + <UNK> + <SOS> + <EOS> + <PAD>

class TokenizerChar:
    def __init__(self):
        self.chr_to_idx = {chr(v): v for v in range(1, 257)}
        self.chr_to_idx['<SOS>'] = 257
        self.chr_to_idx['<EOS>'] = 258
        self.chr_to_idx['<PAD>'] = 0
        self.chr_to_idx['<UNK>'] = 259

        self.idx_to_chr = {v: k for k, v in self.chr_to_idx.items()}

        self.vocab_size = len(self.chr_to_idx.keys())

    def encode(self, char):
        if char in self.chr_to_idx.keys():
            return self.chr_to_idx[char]
        else:
            return 259
    
    def decode(self, token_idx):
        return self.idx_to_chr[token_idx]
    
    def sos_token(self):
        return '<SOS>'
    
    def sos_token_idx(self):
        return self.chr_to_idx['<SOS>']

    def eos_token(self):
        return '<EOS>'
    
    def eos_token_idx(self):
        return self.chr_to_idx['<EOS>']
    
    def pad_token(self):
        return '<PAD>'
    
    def pad_token_idx(self):
        return self.chr_to_idx['<PAD>']
    
    def get_vocab_size(self):
        return self.vocab_size


### Dataset

In [21]:
with open('dataset_text/dialogs.txt') as file:
    print(len(file.read().split('\n')))

3725


In [22]:
lista = [1, 2, 3, 4]

In [23]:
lista[:-1]

[1, 2, 3]

In [24]:
lista[1:]

[2, 3, 4]

In [25]:
from torch.utils.data import Dataset


class DatasetDialogs(Dataset):
    def __init__(self, dataset_path, sentence_length):
        self.dataset_path = dataset_path
        self.sentence_length = sentence_length
        self.tokenizer = TokenizerChar()

    def __len__(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return num_of_sentences
    
    def get_shape(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return (num_of_sentences, self.sentence_length)

    def __getitem__(self, line_idx):
        with open(self.dataset_path, 'r') as dataset:
            selected_sentence = dataset.read().split('\n')[line_idx]
            if len(selected_sentence) < self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens.append(self.tokenizer.eos_token_idx())
                pad_length = self.sentence_length - len(input_tokens) + 1
                pad_tokens = [self.tokenizer.pad_token_idx()] * pad_length
                input_tokens += pad_tokens
            elif len(selected_sentence) == self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[-1] = self.tokenizer.eos_token_idx()
            elif len(selected_sentence) > self.sentence_length:
                selected_sentence = selected_sentence[:self.sentence_length]
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[-1] = self.tokenizer.eos_token_idx()
            # print(f'{len(input_tokens)} - {self.sentence_length}') # debug only
            assert len(input_tokens) == self.sentence_length + 1, f"Lista de índices de tokens não possui mesmo tamanho que 'sentence_length'! len(input_tokens): {len(input_tokens)} - self.sentence_length: {self.sentence_length}"
            try:
                x = torch.tensor(input_tokens[:-1])
                y = torch.tensor(input_tokens[1:])
            except RuntimeError as e:
                print(e)
                print(f"Input tokens: {input_tokens}")
                raise e
        return x, y

In [26]:
dataset = DatasetDialogs('dataset_text/dialogs.txt', 150)

In [27]:
chr(104)

'h'

In [28]:
dataset[0]

(tensor([257, 104, 105,  44,  32, 104, 111, 119,  32,  97, 114, 101,  32, 121,
         111, 117,  32, 100, 111, 105, 110, 103,  63,   9, 105,  39, 109,  32,
         102, 105, 110, 101,  46,  32, 104, 111, 119,  32,  97,  98, 111, 117,
         116,  32, 121, 111, 117, 114, 115, 101, 108, 102,  63, 258,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0]),
 tensor([104, 105,  44,  32, 104, 111, 119,  32,  97, 114, 101,  32, 121, 111,
         117,  32, 100, 111, 105, 110, 103,  63,   9, 105,  39, 109,  

### Training Loop

In [29]:
import torch
torch.cuda.is_available()

True

In [30]:
torch.zeros(1).cuda()

tensor([0.], device='cuda:0')

In [52]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
sequence_length = 50
batch_size = 128
dataset_train = DatasetDialogs('dataset_text/dialogs_train.txt', sequence_length)
dataset_test = DatasetDialogs('dataset_text/dialogs_test.txt', sequence_length)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
vocab_size = dataset_train.tokenizer.get_vocab_size()

d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_length = sequence_length
model = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
model.to(device)

tgt_mask = (1 - torch.triu(
  torch.ones(1, sequence_length, sequence_length), diagonal=1)
).bool()

def init_weights(module):
    if isinstance(module, (nn.Linear)):
        init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
model.apply(init_weights)

optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 200
n_batches = int(dataset_train.__len__() // batch_size)

print("Starting model training...")
for epoch in range(n_epochs):
    print(f"Epoch: {epoch + 1}")
    avg_loss = 0
    model.train()
    for batch_idx, batch in enumerate(tqdm(dataloader_train, total=n_batches)):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, tgt_mask.to(device))
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
        # print(f"Current training loss: {loss.item()}")
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    torch.save(model.state_dict(), f'model_checkpoints/model_checkpoint_{epoch+1}.pth')

    avg_loss /= (batch_idx + 1)
    print(f"Average epoch training loss: {avg_loss}")
    print(f"Last batch training loss: {loss}")

    model.eval()
    avg_loss = 0
    for batch_idx, batch in enumerate(dataloader_test):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        outputs = model(x, tgt_mask.to(device))
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
    
    avg_loss /= (batch_idx + 1)
    print(f"Epoch validation loss: {avg_loss}")
    

cuda
Starting model training...
Epoch: 1


27it [00:04,  5.86it/s]                        


Average epoch training loss: 3.631200737423367
Last batch training loss: 3.270343780517578
Epoch validation loss: 3.1462790966033936
Epoch: 2


27it [00:04,  6.16it/s]                        


Average epoch training loss: 3.217517040393971
Last batch training loss: 3.2045228481292725
Epoch validation loss: 3.1064573923746743
Epoch: 3


27it [00:04,  6.02it/s]                        


Average epoch training loss: 3.175163295533922
Last batch training loss: 3.1233768463134766
Epoch validation loss: 3.1050404707590737
Epoch: 4


27it [00:04,  6.04it/s]                        


Average epoch training loss: 3.082420304969505
Last batch training loss: 3.025723934173584
Epoch validation loss: 2.9546426932017007
Epoch: 5


27it [00:04,  6.11it/s]                        


Average epoch training loss: 2.984009283560294
Last batch training loss: 2.988713502883911
Epoch validation loss: 2.918907403945923
Epoch: 6


27it [00:04,  6.08it/s]                        


Average epoch training loss: 2.9456180643152305
Last batch training loss: 2.936941623687744
Epoch validation loss: 2.9018630981445312
Epoch: 7


27it [00:04,  6.20it/s]                        


Average epoch training loss: 2.928481667130082
Last batch training loss: 2.8524909019470215
Epoch validation loss: 2.8910324573516846
Epoch: 8


27it [00:04,  6.22it/s]                        


Average epoch training loss: 2.9222613793832286
Last batch training loss: 2.9172606468200684
Epoch validation loss: 2.8777031103769937
Epoch: 9


27it [00:04,  6.25it/s]                        


Average epoch training loss: 2.902684485470807
Last batch training loss: 2.928664207458496
Epoch validation loss: 2.864769617716471
Epoch: 10


27it [00:04,  6.18it/s]                        


Average epoch training loss: 2.850275666625411
Last batch training loss: 2.8455817699432373
Epoch validation loss: 2.7999914487202964
Epoch: 11


27it [00:04,  6.14it/s]                        


Average epoch training loss: 2.769008486359208
Last batch training loss: 2.6587822437286377
Epoch validation loss: 2.6993583838144937
Epoch: 12


27it [00:04,  6.14it/s]                        


Average epoch training loss: 2.6406862205929227
Last batch training loss: 2.498523473739624
Epoch validation loss: 2.498384952545166
Epoch: 13


27it [00:04,  6.06it/s]                        


Average epoch training loss: 2.476541245425189
Last batch training loss: 2.3914785385131836
Epoch validation loss: 2.377758582433065
Epoch: 14


27it [00:04,  6.08it/s]                        


Average epoch training loss: 2.3826850431936757
Last batch training loss: 2.348893642425537
Epoch validation loss: 2.3375622431437173
Epoch: 15


27it [00:04,  6.03it/s]                        


Average epoch training loss: 2.309968347902651
Last batch training loss: 2.271655797958374
Epoch validation loss: 2.2653963565826416
Epoch: 16


27it [00:04,  5.99it/s]                        


Average epoch training loss: 2.2571882053657815
Last batch training loss: 2.281708002090454
Epoch validation loss: 2.2151970863342285
Epoch: 17


27it [00:04,  5.99it/s]                        


Average epoch training loss: 2.2098090295438415
Last batch training loss: 2.162015676498413
Epoch validation loss: 2.177963892618815
Epoch: 18


27it [00:04,  5.94it/s]                        


Average epoch training loss: 2.167961261890553
Last batch training loss: 2.1280324459075928
Epoch validation loss: 2.1476951440175376
Epoch: 19


27it [00:04,  5.94it/s]                        


Average epoch training loss: 2.1311797477580883
Last batch training loss: 2.1202449798583984
Epoch validation loss: 2.1128693421681723
Epoch: 20


27it [00:04,  5.95it/s]                        


Average epoch training loss: 2.101306509088587
Last batch training loss: 2.0806212425231934
Epoch validation loss: 2.053911566734314
Epoch: 21


27it [00:04,  5.99it/s]                        


Average epoch training loss: 2.069378066945959
Last batch training loss: 2.065927743911743
Epoch validation loss: 2.0545667807261148
Epoch: 22


27it [00:04,  5.90it/s]                        


Average epoch training loss: 2.034712645742628
Last batch training loss: 1.9972240924835205
Epoch validation loss: 2.015332341194153
Epoch: 23


27it [00:04,  5.79it/s]                        


Average epoch training loss: 2.0081084260234126
Last batch training loss: 2.0452616214752197
Epoch validation loss: 2.0050973097483316
Epoch: 24


27it [00:05,  5.12it/s]                        


Average epoch training loss: 1.9797711416527077
Last batch training loss: 1.9808369874954224
Epoch validation loss: 2.010214924812317
Epoch: 25


27it [00:04,  5.82it/s]                        


Average epoch training loss: 1.9528468361607305
Last batch training loss: 1.9018155336380005
Epoch validation loss: 1.9482027689615886
Epoch: 26


27it [00:04,  5.80it/s]                        


Average epoch training loss: 1.927673794605114
Last batch training loss: 1.9534002542495728
Epoch validation loss: 1.923122485478719
Epoch: 27


27it [00:04,  5.71it/s]                        


Average epoch training loss: 1.895802546430517
Last batch training loss: 1.8328042030334473
Epoch validation loss: 1.9270611604054768
Epoch: 28


27it [00:04,  5.73it/s]                        


Average epoch training loss: 1.873426565417537
Last batch training loss: 1.9212485551834106
Epoch validation loss: 1.8790709177652996
Epoch: 29


27it [00:04,  5.65it/s]                        


Average epoch training loss: 1.852707478735182
Last batch training loss: 1.818827509880066
Epoch validation loss: 1.8550440470377605
Epoch: 30


27it [00:04,  5.73it/s]                        


Average epoch training loss: 1.83295589906198
Last batch training loss: 1.8228161334991455
Epoch validation loss: 1.8376917441685994
Epoch: 31


27it [00:04,  5.64it/s]                        


Average epoch training loss: 1.8171954464029383
Last batch training loss: 1.7663995027542114
Epoch validation loss: 1.8179478645324707
Epoch: 32


27it [00:04,  5.66it/s]                        


Average epoch training loss: 1.7952421638700697
Last batch training loss: 1.7802155017852783
Epoch validation loss: 1.8215490976969402
Epoch: 33


27it [00:04,  5.42it/s]                        


Average epoch training loss: 1.7775881599496912
Last batch training loss: 1.8237154483795166
Epoch validation loss: 1.7881699403127034
Epoch: 34


27it [00:04,  5.58it/s]                        


Average epoch training loss: 1.759861327983715
Last batch training loss: 1.7586067914962769
Epoch validation loss: 1.7881935437520344
Epoch: 35


27it [00:04,  5.50it/s]                        


Average epoch training loss: 1.737604233953688
Last batch training loss: 1.7373161315917969
Epoch validation loss: 1.782143235206604
Epoch: 36


27it [00:04,  5.57it/s]                        


Average epoch training loss: 1.7194279917964228
Last batch training loss: 1.7526527643203735
Epoch validation loss: 1.7731731335322063
Epoch: 37


27it [00:04,  5.63it/s]                        


Average epoch training loss: 1.7031026195596766
Last batch training loss: 1.6737476587295532
Epoch validation loss: 1.7555227677027385
Epoch: 38


27it [00:04,  5.57it/s]                        


Average epoch training loss: 1.689928999653569
Last batch training loss: 1.6390647888183594
Epoch validation loss: 1.7617863416671753
Epoch: 39


27it [00:04,  5.57it/s]                        


Average epoch training loss: 1.6756687694125705
Last batch training loss: 1.70615816116333
Epoch validation loss: 1.7353893518447876
Epoch: 40


27it [00:04,  5.59it/s]                        


Average epoch training loss: 1.6543612921679463
Last batch training loss: 1.6276583671569824
Epoch validation loss: 1.7260348002115886
Epoch: 41


27it [00:04,  5.56it/s]                        


Average epoch training loss: 1.6444137493769329
Last batch training loss: 1.6910130977630615
Epoch validation loss: 1.7140486637751262
Epoch: 42


27it [00:04,  5.58it/s]                        


Average epoch training loss: 1.6273349479392722
Last batch training loss: 1.563935399055481
Epoch validation loss: 1.7223403453826904
Epoch: 43


27it [00:04,  5.54it/s]                        


Average epoch training loss: 1.6165356238683064
Last batch training loss: 1.6425926685333252
Epoch validation loss: 1.711678147315979
Epoch: 44


27it [00:04,  5.50it/s]                        


Average epoch training loss: 1.6038325671796445
Last batch training loss: 1.5557818412780762
Epoch validation loss: 1.7050451040267944
Epoch: 45


27it [00:04,  5.55it/s]                        


Average epoch training loss: 1.5910468278107819
Last batch training loss: 1.6560511589050293
Epoch validation loss: 1.7146590153376262
Epoch: 46


27it [00:04,  5.60it/s]                        


Average epoch training loss: 1.5790089148062247
Last batch training loss: 1.6223037242889404
Epoch validation loss: 1.688338001569112
Epoch: 47


27it [00:04,  5.49it/s]                        


Average epoch training loss: 1.5682607977478593
Last batch training loss: 1.6014686822891235
Epoch validation loss: 1.6525832414627075
Epoch: 48


27it [00:04,  5.53it/s]                        


Average epoch training loss: 1.553159616611622
Last batch training loss: 1.6111876964569092
Epoch validation loss: 1.64750341574351
Epoch: 49


27it [00:04,  5.50it/s]                        


Average epoch training loss: 1.544550237832246
Last batch training loss: 1.554160475730896
Epoch validation loss: 1.696431557337443
Epoch: 50


27it [00:04,  5.50it/s]                        


Average epoch training loss: 1.5303838738688715
Last batch training loss: 1.5169649124145508
Epoch validation loss: 1.6544324954350789
Epoch: 51


27it [00:04,  5.55it/s]                        


Average epoch training loss: 1.5195457935333252
Last batch training loss: 1.5137351751327515
Epoch validation loss: 1.6491927703221638
Epoch: 52


27it [00:05,  5.31it/s]                        


Average epoch training loss: 1.5105449314470645
Last batch training loss: 1.5221378803253174
Epoch validation loss: 1.6753324270248413
Epoch: 53


27it [00:04,  5.47it/s]                        


Average epoch training loss: 1.4994141569844
Last batch training loss: 1.5175645351409912
Epoch validation loss: 1.638737440109253
Epoch: 54


27it [00:04,  5.46it/s]                        


Average epoch training loss: 1.4879678222868178
Last batch training loss: 1.49513578414917
Epoch validation loss: 1.6585108041763306
Epoch: 55


27it [00:04,  5.45it/s]                        


Average epoch training loss: 1.478119479285346
Last batch training loss: 1.5074702501296997
Epoch validation loss: 1.6279289325078328
Epoch: 56


27it [00:04,  5.44it/s]                        


Average epoch training loss: 1.4677707265924524
Last batch training loss: 1.4769420623779297
Epoch validation loss: 1.6254883607228596
Epoch: 57


27it [00:05,  5.32it/s]                        


Average epoch training loss: 1.4574999544355605
Last batch training loss: 1.4607396125793457
Epoch validation loss: 1.6360883315404255
Epoch: 58


27it [00:04,  5.52it/s]                        


Average epoch training loss: 1.4490914388939187
Last batch training loss: 1.4908528327941895
Epoch validation loss: 1.6372916301091511
Epoch: 59


27it [00:04,  5.52it/s]                        


Average epoch training loss: 1.4388559880080047
Last batch training loss: 1.4644023180007935
Epoch validation loss: 1.620025356610616
Epoch: 60


27it [00:04,  5.44it/s]                        


Average epoch training loss: 1.4305531625394468
Last batch training loss: 1.4532864093780518
Epoch validation loss: 1.610959490140279
Epoch: 61


27it [00:04,  5.49it/s]                        


Average epoch training loss: 1.4227148515206796
Last batch training loss: 1.3885654211044312
Epoch validation loss: 1.6150290568669636
Epoch: 62


27it [00:04,  5.52it/s]                        


Average epoch training loss: 1.413178390926785
Last batch training loss: 1.413625717163086
Epoch validation loss: 1.6164149045944214
Epoch: 63


27it [00:05,  5.39it/s]                        


Average epoch training loss: 1.4018245449772588
Last batch training loss: 1.3981595039367676
Epoch validation loss: 1.6117057800292969
Epoch: 64


27it [00:04,  5.48it/s]                        


Average epoch training loss: 1.3944140098713063
Last batch training loss: 1.3668015003204346
Epoch validation loss: 1.6126211086908977
Epoch: 65


27it [00:04,  5.45it/s]                        


Average epoch training loss: 1.3892123169369168
Last batch training loss: 1.3807634115219116
Epoch validation loss: 1.605349858601888
Epoch: 66


27it [00:04,  5.41it/s]                        


Average epoch training loss: 1.3848288279992562
Last batch training loss: 1.3423805236816406
Epoch validation loss: 1.6323504050572712
Epoch: 67


27it [00:04,  5.43it/s]                        


Average epoch training loss: 1.3733824447349265
Last batch training loss: 1.3335895538330078
Epoch validation loss: 1.6177396774291992
Epoch: 68


27it [00:04,  5.44it/s]                        


Average epoch training loss: 1.3674358041198165
Last batch training loss: 1.3418118953704834
Epoch validation loss: 1.6410496632258098
Epoch: 69


27it [00:04,  5.42it/s]                        


Average epoch training loss: 1.356656131920991
Last batch training loss: 1.3129456043243408
Epoch validation loss: 1.5834746360778809
Epoch: 70


27it [00:04,  5.46it/s]                        


Average epoch training loss: 1.3480860701313726
Last batch training loss: 1.3881807327270508
Epoch validation loss: 1.6246823867162068
Epoch: 71


27it [00:04,  5.42it/s]                        


Average epoch training loss: 1.3405673503875732
Last batch training loss: 1.4086859226226807
Epoch validation loss: 1.5967312256495159
Epoch: 72


27it [00:04,  5.49it/s]                        


Average epoch training loss: 1.3298241827223036
Last batch training loss: 1.3538657426834106
Epoch validation loss: 1.5959595441818237
Epoch: 73


27it [00:04,  5.43it/s]                        


Average epoch training loss: 1.323707898457845
Last batch training loss: 1.3019362688064575
Epoch validation loss: 1.5884768168131511
Epoch: 74


27it [00:04,  5.42it/s]                        


Average epoch training loss: 1.314359916581048
Last batch training loss: 1.3022477626800537
Epoch validation loss: 1.5967770020167034
Epoch: 75


27it [00:04,  5.41it/s]                        


Average epoch training loss: 1.3060209265461675
Last batch training loss: 1.359614610671997
Epoch validation loss: 1.5930280288060505
Epoch: 76


27it [00:05,  5.36it/s]                        


Average epoch training loss: 1.2972608451490049
Last batch training loss: 1.2875182628631592
Epoch validation loss: 1.5771726369857788
Epoch: 77


27it [00:04,  5.41it/s]                        


Average epoch training loss: 1.2951776186625164
Last batch training loss: 1.3188188076019287
Epoch validation loss: 1.6046090523401897
Epoch: 78


27it [00:04,  5.42it/s]                        


Average epoch training loss: 1.2844246670051858
Last batch training loss: 1.2545149326324463
Epoch validation loss: 1.5549076000849407
Epoch: 79


27it [00:05,  5.36it/s]                        


Average epoch training loss: 1.2797812400040802
Last batch training loss: 1.2759819030761719
Epoch validation loss: 1.6080012321472168
Epoch: 80


27it [00:05,  5.36it/s]                        


Average epoch training loss: 1.272619229775888
Last batch training loss: 1.2526865005493164
Epoch validation loss: 1.612975279490153
Epoch: 81


27it [00:04,  5.46it/s]                        


Average epoch training loss: 1.2675558107870597
Last batch training loss: 1.2637858390808105
Epoch validation loss: 1.55881929397583
Epoch: 82


27it [00:05,  5.33it/s]                        


Average epoch training loss: 1.263860238922967
Last batch training loss: 1.2898036241531372
Epoch validation loss: 1.5581261316935222
Epoch: 83


27it [00:04,  5.44it/s]                        


Average epoch training loss: 1.2565745380189683
Last batch training loss: 1.296628475189209
Epoch validation loss: 1.5414034128189087
Epoch: 84


27it [00:04,  5.48it/s]                        


Average epoch training loss: 1.245595106372127
Last batch training loss: 1.249220371246338
Epoch validation loss: 1.5753037134806316
Epoch: 85


27it [00:05,  5.39it/s]                        


Average epoch training loss: 1.2464852686281558
Last batch training loss: 1.2879194021224976
Epoch validation loss: 1.5498493512471516
Epoch: 86


27it [00:05,  5.29it/s]                        


Average epoch training loss: 1.2363834425255105
Last batch training loss: 1.234053134918213
Epoch validation loss: 1.5461691617965698
Epoch: 87


27it [00:04,  5.44it/s]                        


Average epoch training loss: 1.2318586729190968
Last batch training loss: 1.2544015645980835
Epoch validation loss: 1.5648193359375
Epoch: 88


27it [00:04,  5.47it/s]                        


Average epoch training loss: 1.2208391781206485
Last batch training loss: 1.2091922760009766
Epoch validation loss: 1.5951104958852131
Epoch: 89


27it [00:05,  5.38it/s]                        


Average epoch training loss: 1.2167167089603566
Last batch training loss: 1.1496130228042603
Epoch validation loss: 1.5892661809921265
Epoch: 90


27it [00:05,  5.35it/s]                        


Average epoch training loss: 1.208462746055038
Last batch training loss: 1.2461235523223877
Epoch validation loss: 1.591265122095744
Epoch: 91


27it [00:05,  5.38it/s]                        


Average epoch training loss: 1.2054061271526195
Last batch training loss: 1.1908705234527588
Epoch validation loss: 1.575020710627238
Epoch: 92


27it [00:04,  5.42it/s]                        


Average epoch training loss: 1.2009868886735704
Last batch training loss: 1.1815342903137207
Epoch validation loss: 1.5645435253779094
Epoch: 93


27it [00:04,  5.47it/s]                        


Average epoch training loss: 1.1960563483061615
Last batch training loss: 1.1775708198547363
Epoch validation loss: 1.5679397980372112
Epoch: 94


27it [00:05,  5.38it/s]                        


Average epoch training loss: 1.1885741419262357
Last batch training loss: 1.159967064857483
Epoch validation loss: 1.5706898768742878
Epoch: 95


27it [00:05,  5.35it/s]                        


Average epoch training loss: 1.1793474753697712
Last batch training loss: 1.1615978479385376
Epoch validation loss: 1.6003543138504028
Epoch: 96


27it [00:04,  5.46it/s]                        


Average epoch training loss: 1.1711841071093525
Last batch training loss: 1.1668510437011719
Epoch validation loss: 1.5762832164764404
Epoch: 97


27it [00:05,  5.36it/s]                        


Average epoch training loss: 1.1641585517812658
Last batch training loss: 1.162588119506836
Epoch validation loss: 1.5604037841161091
Epoch: 98


27it [00:05,  5.34it/s]                        


Average epoch training loss: 1.1556663115819295
Last batch training loss: 1.1466517448425293
Epoch validation loss: 1.5728081464767456
Epoch: 99


 96%|█████████▌| 25/26 [00:04<00:00,  5.14it/s]


KeyboardInterrupt: 

### Text Generation / Model Prediction

In [53]:
def pad_sequence_to_length(sequence, sequence_length, pad_token_idx):
    """
    Completa a sequência com tokens de padding até atingir o comprimento desejado.

    Args:
        sequence (list): Sequência de tokens (índices) a ser completada.
        sequence_length (int): Comprimento desejado da sequência.
        pad_token_idx (int): Índice do token de padding.

    Returns:
        list: Sequência completada com tokens de padding.
    """
    if len(sequence) < sequence_length:
        # Adiciona tokens de padding no final da sequência
        sequence += [pad_token_idx] * (sequence_length - len(sequence))
    elif len(sequence) > sequence_length:
        # Trunca a sequência se for maior que o comprimento desejado
        sequence = sequence[:sequence_length]
    return sequence

In [73]:
import torch

def predict(start_text, model, tokenizer, sequence_length=50, temperature=1.0):
    model.eval()
    device = next(model.parameters()).device

    # Tokeniza e preenche a sequência
    sequence = [tokenizer.sos_token_idx()] + [tokenizer.encode(c) for c in start_text]
    sequence = pad_sequence_to_length(sequence, sequence_length, tokenizer.pad_token_idx())
    input_tokens = torch.tensor(sequence, dtype=torch.long).unsqueeze(0).to(device)

    current_text = start_text

    with torch.no_grad():
        for i in range(len(start_text), sequence_length):
            tgt_mask = torch.tril(torch.ones(sequence_length, sequence_length)).to(device).bool()
            outputs = model(input_tokens, tgt_mask=tgt_mask)
            log_probs = outputs[0, i - 1] / temperature

            predicted_token_idx = torch.distributions.Categorical(logits=log_probs).sample().item()

            if predicted_token_idx == tokenizer.eos_token_idx():
                break

            current_text += tokenizer.decode(predicted_token_idx)
            input_tokens[0, i] = predicted_token_idx

    print('Texto predito:', current_text)
    return current_text


In [74]:
predict('hi, how are you ', model=model, tokenizer=dataset.tokenizer, temperature=0.05)

Texto predito: hi, how are you  going to do?	i'm going to change 


"hi, how are you  going to do?\ti'm going to change "