## Transformer from Scratch

In [29]:
import math

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.init as init
from tqdm import tqdm

### Prep 1.: Embedding

In [2]:
class InputEmbeddings(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        # Multiplicação com math.sqrt(d_model) pois Positional Encoder tem valores iniciais entre -1 e 1 devido a sin e cos.
        # Essa multiplicação escala os valores da inicialização do nn.Embedding para próximo da escala do Positional Encoder.
        # Inicialização do nn.Embedding é normal com média 0 e standard deviation embedding_dim ** -0.5
        return self.embedding(x) * math.sqrt(self.d_model)

### Prep 2.: Positional Encoding

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

### Model 1.: Attention Mechanism Layer (MultiHeadAttention Layer) + Feed Forward Layer

In [4]:
import torch.nn.functional as F

**bias=False explanation**

For certain types of layers, such as transformers and convolutional layers, including a bias term is unnecessary and adds unnecessary overhead to the model.

The reason for this is that these layers are typically followed by a normalization layer, such as Batch Normalization or Layer Normalization. These normalization layers center the data at mean=0 (and std=1), effectively removing any bias.

Therefore, it is common practice to omit the bias term in transformers and convolutional layers that are preceded by a normalization layer.

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # Entra x com shape (batch_size, seq_length, d_model)
        seq_length = x.size(1)
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        # Sai y com shape (batch_size, num_heads, seq_length, head_dim)
        return x.permute(0, 2, 1, 3)
    
    def compute_attention(self, query, key, value, mask=None):
        # Shape de query, key, value (batch_size, num_heads, seq_length, head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1) # dim -1 significa softmax computada ao longo da dimensão head_dim, que é um chunk de embedding.
        return torch.matmul(attention_weights, value)
    
    def combine_heads(self, x, batch_size):
        seq_length = x.size(2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.reshape(batch_size, seq_length, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        
        attention_weights = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attention_weights, batch_size)
        return self.output_linear(output)

In [6]:
class FeedForwardSubLayer(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

### Model 2.: Transformer Encoder

In [7]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        attn_output = self.self_attn(x, x, x, mask=src_mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [8]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size=vocab_size, d_model=d_model)
        self.positional_encoder = PositionalEncoding(d_model=d_model, max_seq_length=max_seq_length)
        self.encoder_blocks = nn.ModuleList(
            [
                EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(n_layers) # noqa E501
            ]
        )

    def forward(self, x, src_mask):
        x = self.embedding(x)
        x = self.positional_encoder(x)
        for layer in self.encoder_blocks:
            x = layer(x, src_mask)
        return x

In [9]:
class ClassifierHead(nn.Module):
    def __init__(self, d_model, num_classes):
        super().__init__()
        self.classifier_nn = nn.Linear(d_model, num_classes)

    def forward(self, x):
        logits = self.classifier_nn(x)
        return F.softmax(logits)

In [10]:
vocab_size = 256
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.2
seq_length = 256
num_classes = 2

transformer_encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, seq_length)
classifier = ClassifierHead(d_model, num_classes)

### Model 3.: Transformer Decoder

In [11]:
torch.ones(1, seq_length, seq_length)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [12]:
torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)

tensor([[[0., 1., 1.,  ..., 1., 1., 1.],
         [0., 0., 1.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [13]:
tgt_mask = (1 - torch.triu(
  torch.ones(1, seq_length, seq_length), diagonal=1)
).bool()

In [14]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [15]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        x = self.fc(x)
        # Log de probabilidades para computação mais rápida e maior estabilidade com probas perto de 0.
        # Mapeia [0, 1] para (-inf, 0]
        return F.log_softmax(x, dim=-1)

In [17]:
max_seq_length = 256
batch_size = 1
vocab_size = 1000
torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

tensor([[ 22, 777, 568, 184, 616, 880,  42, 991, 208, 305, 118, 906,  84, 667,
         430, 836, 281, 645, 465, 830, 586, 511,  30, 717, 543, 619, 362, 336,
         255, 567,  31, 744, 563, 181, 839, 557, 457, 191, 556, 584, 777, 965,
          14, 137, 818, 118, 381, 491, 702,  90, 329, 188, 637,  93, 996,  59,
           6, 636,  63, 721, 160, 402, 737, 304, 677, 294, 656, 291, 594, 507,
         873, 823, 630, 938, 972, 955, 265, 762, 259, 736, 971, 596, 651, 132,
         640, 381, 981, 825,  92, 164, 563, 361, 925,  31, 334, 169, 316, 239,
         346, 948, 297, 643, 217, 422, 743,  45, 434, 812, 153, 262, 864, 900,
         677, 329, 927, 381, 927, 331, 416, 418, 837, 214, 898, 255, 227, 829,
         445, 573, 511, 414, 716, 867,   5, 431, 539, 727, 671, 678,  72, 548,
          67, 266, 317, 507, 776, 113, 566,  27, 869, 749,  67,  89, 995, 788,
         383, 160, 862, 426, 668, 906, 831, 435, 838, 992,  49, 525, 822, 509,
         191, 813, 809, 662, 533, 310, 135, 276, 787

In [16]:
max_seq_length = 256
batch_size = 2
vocab_size = 1000
input_tokens = torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

transformer_decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)   
output = transformer_decoder(input_tokens, tgt_mask)

In [17]:
print(input_tokens)

tensor([[ 57, 145, 314,  78, 183, 113, 407, 586, 535, 885, 925, 587, 816,  37,
          55, 252,  66, 988, 593, 906, 248, 802, 770, 274, 645, 344, 320, 297,
         471, 294, 371, 241, 347, 304, 181,  50, 636,  22, 489, 439, 534, 588,
          26, 190, 113, 788, 138,  73, 686, 332, 215, 204, 586, 198, 443, 395,
         999, 216, 793, 212, 881, 363, 592,  90, 370,  33,  51, 638, 483, 558,
         729, 349, 886, 729,  34, 883, 652, 656, 170, 521, 276, 407, 695, 803,
          12, 893, 641, 612, 315, 483, 400, 611, 555, 293, 858, 158, 923, 402,
         614, 860, 409,  37, 564, 998,  22, 337,  60, 146, 728, 210, 684, 342,
         227, 722, 147, 990,  95, 418, 666, 917, 644, 440, 611, 803, 626, 298,
         990, 534, 937, 691,  16, 659, 798,  46, 199, 301, 503, 525, 418, 209,
         985, 981, 292, 457, 605,  72, 953, 105, 937, 685, 472, 343, 659, 150,
         306,   7, 350, 460, 349, 602, 314, 822, 474,  57, 295, 274, 142, 213,
         719, 272, 287, 497, 436, 992,  90, 989, 272

In [18]:
print(output)

tensor([[[-6.7642, -6.8077, -7.8105,  ..., -7.3003, -7.4120, -7.1858],
         [-7.0911, -7.6265, -7.7117,  ..., -6.6925, -6.7127, -7.6584],
         [-7.0995, -7.4291, -6.8647,  ..., -7.2756, -7.0006, -7.9697],
         ...,
         [-7.1360, -6.5465, -7.3965,  ..., -6.7170, -7.5379, -7.3314],
         [-6.5048, -7.2928, -7.6898,  ..., -6.2444, -7.6415, -7.2333],
         [-7.2406, -7.5268, -7.3894,  ..., -6.7142, -7.4443, -8.0571]],

        [[-5.9346, -6.9170, -6.5690,  ..., -7.1437, -6.1088, -7.4729],
         [-6.5090, -6.5014, -6.1054,  ..., -7.6661, -7.3899, -8.0944],
         [-8.3925, -7.2747, -6.2694,  ..., -7.6565, -7.4415, -7.6147],
         ...,
         [-7.3333, -7.1162, -7.4921,  ..., -7.1177, -7.4886, -7.1390],
         [-6.5008, -6.4315, -6.2267,  ..., -6.7472, -7.8710, -7.2734],
         [-6.8413, -6.8020, -7.1416,  ..., -6.7609, -7.6003, -6.8107]]],
       grad_fn=<LogSoftmaxBackward0>)


In [20]:
# torch.randn(size=(2, 256, 8, 64)).view((2, 256, 512)) # for debugging reshapes

### Model 4.: Encoder-decoder Transformer

In [21]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, tgt_mask, cross_mask):
        self_attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        cross_attn_output = self.cross_attn(x, y, y, cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [None]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, y, tgt_mask, cross_mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, y, tgt_mask, cross_mask)
        x = self.fc(x)
        # Log de probabilidades para computação mais rápida e maior estabilidade com probas perto de 0.
        # Mapeia [0, 1] para (-inf, 0]
        return F.log_softmax(x, dim=-1)

In [23]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

    def forward(self, x, src_mask, tgt_mask, cross_mask):
        encoder_output = self.encoder(x, src_mask)
        decoder_output = self.decoder(x, encoder_output, tgt_mask, cross_mask)
        return decoder_output

In [25]:
def generate_padding_mask(sequence, pad_token=0):
    # Mask out padding tokens (assumes pad_token is 0)
    return (sequence != pad_token).unsqueeze(1).unsqueeze(2)

# sequence = torch.tensor([2, 6, 30, 120, 0, 0, 0, 0])
src_mask = generate_padding_mask(input_tokens)
cross_mask = src_mask

In [26]:
transformer = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
outputs = transformer(input_tokens, src_mask, tgt_mask, cross_mask)
print(outputs)
print(outputs.shape)

tensor([[[-7.2952, -7.7172, -6.4120,  ..., -8.0577, -6.7412, -6.9631],
         [-8.0978, -5.6856, -6.9766,  ..., -7.3987, -5.9954, -7.1549],
         [-6.8227, -6.9482, -6.8970,  ..., -6.4872, -6.5135, -6.8972],
         ...,
         [-7.2849, -7.2086, -6.8851,  ..., -8.2023, -7.0849, -7.5703],
         [-6.5633, -6.9625, -7.1137,  ..., -7.2809, -6.3619, -6.8622],
         [-7.9815, -6.7311, -6.9914,  ..., -7.6957, -5.9318, -7.9336]],

        [[-7.9445, -7.5784, -7.1699,  ..., -7.4563, -7.4804, -7.3523],
         [-8.1715, -6.8377, -6.5618,  ..., -6.8760, -6.6384, -5.8102],
         [-8.4128, -7.2745, -7.4224,  ..., -7.3137, -6.4476, -6.8861],
         ...,
         [-6.9692, -6.8992, -7.4177,  ..., -7.3548, -7.0166, -6.5395],
         [-6.2935, -7.0444, -7.1079,  ..., -7.5451, -6.3904, -7.3940],
         [-7.0938, -7.0121, -7.0166,  ..., -7.9626, -6.6922, -8.1125]]],
       grad_fn=<LogSoftmaxBackward0>)
torch.Size([2, 256, 1000])


### Tokenizer

In [63]:
### 259 tokens possíveis:
# ASCII + <UNK> + <SOS> + <EOS> + <PAD>

class TokenizerChar:
    def __init__(self):
        self.chr_to_idx = {chr(v): v for v in range(1, 257)}
        self.chr_to_idx['<SOS>'] = 257
        self.chr_to_idx['<EOS>'] = 258
        self.chr_to_idx['<PAD>'] = 0
        self.chr_to_idx['<UNK>'] = 259

        self.idx_to_chr = {v: k for k, v in self.chr_to_idx.items()}

        self.vocab_size = len(self.chr_to_idx.keys())

    def encode(self, char):
        if char in self.chr_to_idx.keys():
            return self.chr_to_idx[char]
        else:
            return 259
    
    def decode(self, token_idx):
        return self.idx_to_chr[token_idx]
    
    def sos_token(self):
        return '<SOS>'
    
    def sos_token_idx(self):
        return self.chr_to_idx['<SOS>']

    def eos_token(self):
        return '<EOS>'
    
    def eos_token_idx(self):
        return self.chr_to_idx['<EOS>']
    
    def pad_token(self):
        return '<PAD>'
    
    def pad_token_idx(self):
        return self.chr_to_idx['<PAD>']
    
    def get_vocab_size(self):
        return self.vocab_size


### Dataset

In [64]:
with open('dataset_text/dialogs.txt') as file:
    print(len(file.read().split('\n')))

3725


In [65]:
lista = [1, 2, 3, 4]

In [66]:
lista[:-1]

[1, 2, 3]

In [67]:
lista[1:]

[2, 3, 4]

In [86]:
from torch.utils.data import Dataset


class DatasetDialogs(Dataset):
    def __init__(self, dataset_path, sentence_length):
        self.dataset_path = dataset_path
        self.sentence_length = sentence_length
        self.tokenizer = TokenizerChar()

    def __len__(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return num_of_sentences
    
    def get_shape(self):
        with open(self.dataset_path, 'r') as dataset:
            num_of_sentences = len(dataset.read().split('\n'))
        return (num_of_sentences, self.sentence_length)

    def __getitem__(self, line_idx):
        with open(self.dataset_path, 'r') as dataset:
            selected_sentence = dataset.read().split('\n')[line_idx]
            if len(selected_sentence) < self.sentence_length:
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                pad_length = self.sentence_length - len(input_tokens)
                pad_tokens = [self.tokenizer.pad_token_idx()] * pad_length
                input_tokens += pad_tokens
                input_tokens.append(self.tokenizer.eos_token_idx())
            else:
                selected_sentence = selected_sentence[:self.sentence_length - 1]
                input_tokens = [self.tokenizer.sos_token_idx()] + [self.tokenizer.encode(char) for char in selected_sentence]
                input_tokens[self.sentence_length - 1] = self.tokenizer.eos_token_idx()
            try:
                x = torch.tensor(input_tokens[:-1])
                y = torch.tensor(input_tokens[1:])
            except RuntimeError as e:
                print(e)
                print(f"Input tokens: {input_tokens}")
                raise e
        return x, y

In [87]:
dataset = DatasetDialogs('dataset_text/dialogs.txt', 256)

In [88]:
chr(104)

'h'

In [89]:
dataset[0]

(tensor([257, 104, 105,  44,  32, 104, 111, 119,  32,  97, 114, 101,  32, 121,
         111, 117,  32, 100, 111, 105, 110, 103,  63,   9, 105,  39, 109,  32,
         102, 105, 110, 101,  46,  32, 104, 111, 119,  32,  97,  98, 111, 117,
         116,  32, 121, 111, 117, 114, 115, 101, 108, 102,  63,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0

### Training Loop

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sequence_length = 150
batch_size = 32
dataset_train = DatasetDialogs('dataset_text/dialogs_train.txt', sequence_length)
dataset_test = DatasetDialogs('dataset_text/dialogs_test.txt', sequence_length)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
vocab_size = dataset_train.tokenizer.get_vocab_size()

d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_length = sequence_length
model = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

tgt_mask = (1 - torch.triu(
  torch.ones(1, sequence_length, sequence_length), diagonal=1)
).bool()

def init_weights(module):
    if isinstance(module, (nn.Linear)):
        init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
model.apply(init_weights)

optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 2
n_batches = int(dataset_train.__len__() // batch_size)

print("Starting model training...")
for epoch in range(n_epochs):
    print(f"Epoch: {epoch + 1}")
    avg_loss = 0
    model.train()
    for batch_idx, batch in enumerate(tqdm(dataloader_train, total=n_batches)):
        x, y = batch
        outputs = model(x, tgt_mask)
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
        # print(f"Current training loss: {loss.item()}")
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    torch.save(model.state_dict(), f'model_checkpoints/model_checkpoint_{epoch+1}.pth')

    avg_loss /= (batch_idx + 1)
    print(f"Average epoch training loss: {avg_loss}")
    print(f"Last batch training loss: {loss}")

    model.eval()
    avg_loss = 0
    for batch_idx, batch in enumerate(dataloader_test):
        x, y = batch
        outputs = model(x, tgt_mask)
        loss = loss_fn(outputs.view(-1, vocab_size), y.view(-1))
        avg_loss += loss.item()
    
    avg_loss /= (batch_idx + 1)
    print(f"Epoch validation loss: {avg_loss}")
    

Starting model training...
Epoch: 1


  7%|▋         | 8/107 [00:15<03:09,  1.92s/it]

### Text Generation / Model Prediction

In [None]:
def predict(start_text):
    model.eval()
    tokenizer = TokenizerChar()
    input_tokens = torch.tensor([tokenizer.chr_to_idx[token] for token in start_text])

    current_text = start_text
    for _ in range(sequence_length - len(start_text)):
        outputs = model(input_tokens, tgt_mask)
        outputs_list = outputs.view(-1, vocab_size)
        predicted_token_idx = torch.multinomial(outputs_list, num_samples=1)
        current_text += tokenizer.decode(predicted_token_idx)

    final_text = current_text
    print('Texto preditado: ', final_text)
    return final_text

In [None]:
predict('hi, how are you doing?')