## Transformer from Scratch

In [2]:
import torch
import torch.nn as nn
import math

### Prep 1.: Embedding

In [None]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Return the embeddings of the token ids in ``x`` times ``sqrt(d_model)``.

        The scaling is meant to bring the embedding values and the sinusoidal
        positional encoding (bounded to [-1, 1] by sin/cos) onto comparable
        scales before the two are summed.

        NOTE(review): the earlier comment here assumed nn.Embedding is
        initialised with std ``d_model ** -0.5``; PyTorch's default is
        N(0, 1) -- confirm which initialisation is intended.
        """
        scale = math.sqrt(self.d_model)
        embedded = self.embedding(x)
        return embedded * scale

### Prep 2.: Positional Encoding

In [4]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is precomputed once and stored as a
    non-trainable buffer named ``pe`` with shape (1, max_seq_length, d_model).

    Args:
        d_model: Embedding width. Both even and odd widths are supported.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine half must drop the last frequency to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, seq_length, d_model).
        """
        return x + self.pe[:, :x.size(1)]

### Model 1.: Attention Mechanism Layer (MultiHeadAttention Layer) + Feed Forward Layer

In [5]:
import torch.nn.functional as F

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only surfaces
        # later as an opaque reshape error inside split_heads.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
        seq_length = x.size(1)
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        """Return the attention-weighted sum of ``value``.

        ``query``, ``key`` and ``value`` are shaped
        (batch, num_heads, seq, head_dim). When given, ``mask`` must broadcast
        against the (batch, num_heads, query_len, key_len) score tensor;
        entries equal to 0 are excluded from attention.

        NOTE: a query row whose mask entries are all 0 becomes all ``-inf``
        and softmax yields NaN for that row.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # The last axis of ``scores`` indexes key positions, so each query
        # position gets a probability distribution over the keys.
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        """Merge (batch, num_heads, seq, head_dim) back into (batch, seq, d_model)."""
        seq_length = x.size(2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.reshape(batch_size, seq_length, self.d_model)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, merge heads, apply the output projection.

        The batch size is taken from ``query``; ``key`` and ``value`` are
        reshaped with that same batch size.
        """
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        # compute_attention returns the attended values, not the weights.
        attended = self.compute_attention(query, key, value, mask)
        merged = self.combine_heads(attended, batch_size)
        return self.output_linear(merged)

In [7]:
class FeedForwardSubLayer(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand ``x`` to ``d_ff``, apply ReLU, and project back to ``d_model``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

### Model 2.: Transformer Encoder

In [8]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each followed by
    dropout, a residual connection and layer normalisation (post-norm).

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, src_mask):
        """Run both sub-layers on ``x``; ``src_mask`` is passed to self-attention."""
        # Sub-layer 1: self-attention (query = key = value = x).
        attended = self.self_attn(x, x, x, mask=src_mask)
        after_attention = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network.
        transformed = self.ff_sublayer(after_attention)
        return self.norm2(after_attention + self.dropout(transformed))

In [9]:
class TransformerEncoder(nn.Module):
    """Stack of encoder layers on top of scaled embeddings and positional encoding.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Feature width.
        n_layers: Number of ``EncoderLayer`` blocks.
        num_heads: Attention heads per block.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability used inside each block.
        max_seq_length: Number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, d_model, n_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size=vocab_size, d_model=d_model)
        self.positional_encoder = PositionalEncoding(d_model=d_model, max_seq_length=max_seq_length)
        blocks = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(n_layers)]
        self.encoder_blocks = nn.ModuleList(blocks)

    def forward(self, x, src_mask):
        """Embed the token ids ``x`` and pass them through every encoder block in order."""
        hidden = self.positional_encoder(self.embedding(x))
        for block in self.encoder_blocks:
            hidden = block(hidden, src_mask)
        return hidden

In [10]:
class ClassifierHead(nn.Module):
    """Linear classification head that returns class probabilities.

    Args:
        d_model: Width of the incoming features.
        num_classes: Number of output classes.
    """

    def __init__(self, d_model, num_classes):
        super().__init__()
        self.classifier_nn = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return softmax probabilities over the class (last) dimension of ``x``."""
        logits = self.classifier_nn(x)
        # ``dim`` must be explicit: omitting it is deprecated, and for 3-D
        # input the implicit choice is dim 0, which would normalise across the
        # batch instead of across the classes.
        return F.softmax(logits, dim=-1)

In [11]:
# Hyperparameters for the encoder-only classification demo.
vocab_size = 256
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.2
seq_length = 256
num_classes = 2

# Keyword arguments make the mapping explicit: the constructors in this
# notebook do not all share the same positional order.
transformer_encoder = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    n_layers=num_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
    max_seq_length=seq_length,
)
classifier = ClassifierHead(d_model=d_model, num_classes=num_classes)

### Model 3.: Transformer Decoder

In [12]:
# All-ones tensor of shape (1, seq_length, seq_length): the raw material for the causal mask.
torch.ones(1, seq_length, seq_length)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [13]:
# Keep only the entries strictly above the main diagonal (diagonal=1), i.e. column > row.
torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)

tensor([[[0., 1., 1.,  ..., 1., 1., 1.],
         [0., 0., 1.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [14]:
# Causal mask of shape (1, seq_length, seq_length): True on and below the main
# diagonal, False above it, so a position cannot attend to later positions.
tgt_mask = torch.tril(torch.ones(1, seq_length, seq_length)).bool()

In [15]:
class DecoderLayer(nn.Module):
    """Decoder-only block: masked self-attention then feed-forward, each
    followed by dropout, a residual connection and layer normalisation.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, tgt_mask):
        """Run both sub-layers on ``x``; ``tgt_mask`` is passed to self-attention."""
        attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attended))
        x = self.norm2(x + self.dropout(self.ff_sublayer(x)))
        return x

In [None]:
class TransformerDecoder(nn.Module):
    """Decoder-only transformer that returns log-probabilities over the vocabulary.

    Args:
        vocab_size: Size of the token vocabulary (also the output width).
        d_model: Feature width.
        num_layers: Number of ``DecoderLayer`` blocks.
        num_heads: Attention heads per block.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability used inside each block.
        max_seq_length: Number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        blocks = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_mask):
        """Embed the token ids ``x``, apply every block, and project to the vocabulary."""
        hidden = self.positional_encoding(self.embedding(x))
        for block in self.layers:
            hidden = block(hidden, tgt_mask)
        logits = self.fc(hidden)
        # Log-probabilities are numerically more stable than probabilities
        # close to 0; they map [0, 1] onto (-inf, 0].
        return F.log_softmax(logits, dim=-1)

In [17]:
# Dimensions for the decoder demo. vocab_size is rebound here (it was 256 for the encoder demo).
max_seq_length = 256
batch_size = 1
vocab_size = 1000
# Random token ids in [0, vocab_size) with shape (batch_size, max_seq_length); displayed only, not stored.
torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

tensor([[ 22, 777, 568, 184, 616, 880,  42, 991, 208, 305, 118, 906,  84, 667,
         430, 836, 281, 645, 465, 830, 586, 511,  30, 717, 543, 619, 362, 336,
         255, 567,  31, 744, 563, 181, 839, 557, 457, 191, 556, 584, 777, 965,
          14, 137, 818, 118, 381, 491, 702,  90, 329, 188, 637,  93, 996,  59,
           6, 636,  63, 721, 160, 402, 737, 304, 677, 294, 656, 291, 594, 507,
         873, 823, 630, 938, 972, 955, 265, 762, 259, 736, 971, 596, 651, 132,
         640, 381, 981, 825,  92, 164, 563, 361, 925,  31, 334, 169, 316, 239,
         346, 948, 297, 643, 217, 422, 743,  45, 434, 812, 153, 262, 864, 900,
         677, 329, 927, 381, 927, 331, 416, 418, 837, 214, 898, 255, 227, 829,
         445, 573, 511, 414, 716, 867,   5, 431, 539, 727, 671, 678,  72, 548,
          67, 266, 317, 507, 776, 113, 566,  27, 869, 749,  67,  89, 995, 788,
         383, 160, 862, 426, 668, 906, 831, 435, 838, 992,  49, 525, 822, 509,
         191, 813, 809, 662, 533, 310, 135, 276, 787

In [18]:
# Decoder-only demo: a batch of two random token sequences.
max_seq_length = 256
batch_size = 2
vocab_size = 1000
input_tokens = torch.randint(low=0, high=vocab_size, size=(batch_size, max_seq_length))

transformer_decoder = TransformerDecoder(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
    max_seq_length=max_seq_length,
)
output = transformer_decoder(input_tokens, tgt_mask)

In [27]:
# Inspect the random token ids used as model input.
print(input_tokens)

tensor([[958, 611, 587, 976, 133, 193, 186, 617, 401, 119, 290,  58, 179, 635,
         421,  90, 253, 819, 869, 303, 375, 250, 687, 979, 596, 745, 829, 304,
         167, 131, 303, 189, 362, 272, 922, 524, 586,  76,   2, 427, 229, 893,
         637, 250, 658, 925, 880, 383, 536, 884, 798, 145, 735, 841, 419, 407,
         854, 728, 361, 763, 383, 740, 257, 397, 686, 393, 174, 876, 742, 641,
         444, 383, 470, 110,  16,  48, 608, 698, 622, 774, 199, 833, 422, 696,
         499, 959, 887, 909, 567, 525, 100, 507, 833, 927, 635, 907, 630, 719,
          15, 755, 285, 606, 538,  75, 464, 746, 231, 840, 213, 662, 296, 260,
         229, 884, 433, 795, 365, 558, 279, 668, 815, 292, 498, 834, 590, 243,
         858, 217, 473, 783, 399, 473, 213, 122, 951, 514, 940, 419,  56,  77,
         524, 392, 932, 704, 477, 228, 588, 351, 292, 771, 776, 753, 440, 463,
         168,  49, 975, 599, 594, 995, 924, 737, 325, 829, 258, 714, 572, 400,
           7, 892, 747, 249, 210, 212, 352, 746, 560

In [20]:
# torch.randn(size=(2, 256, 8, 64)).view((2, 256, 512)) # for debugging reshapes

### Model 4.: Encoder-decoder Transformer

In [21]:
class DecoderLayer(nn.Module):
    """Encoder-decoder block: masked self-attention, cross-attention over ``y``,
    then feed-forward. Each sub-layer is followed by dropout, a residual
    connection and its own layer normalisation.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, y, tgt_mask, cross_mask):
        """Run the three sub-layers on ``x``.

        ``x`` supplies the queries throughout; ``y`` supplies the keys and
        values of the cross-attention. ``tgt_mask`` is passed to
        self-attention and ``cross_mask`` to cross-attention.
        """
        # Sub-layer 1: self-attention over x.
        attended_self = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attended_self))

        # Sub-layer 2: cross-attention, queries from x, keys/values from y.
        attended_cross = self.cross_attn(x, y, y, cross_mask)
        x = self.norm2(x + self.dropout(attended_cross))

        # Sub-layer 3: feed-forward network.
        transformed = self.ff_sublayer(x)
        return self.norm3(x + self.dropout(transformed))

In [None]:
class TransformerDecoder(nn.Module):
    """Decoder with cross-attention that returns log-probabilities over the vocabulary.

    Args:
        vocab_size: Size of the token vocabulary (also the output width).
        d_model: Feature width.
        num_layers: Number of ``DecoderLayer`` blocks.
        num_heads: Attention heads per block.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability used inside each block.
        max_seq_length: Number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        blocks = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, y, tgt_mask, cross_mask):
        """Embed the token ids ``x``, apply every block with ``y`` as the
        cross-attention input, and project to the vocabulary."""
        hidden = self.positional_encoding(self.embedding(x))
        for block in self.layers:
            hidden = block(hidden, y, tgt_mask, cross_mask)
        logits = self.fc(hidden)
        # Log-probabilities are numerically more stable than probabilities
        # close to 0; they map [0, 1] onto (-inf, 0].
        return F.log_softmax(logits, dim=-1)

In [23]:
class Transformer(nn.Module):
    """Encoder-decoder transformer built from ``TransformerEncoder`` and ``TransformerDecoder``.

    Note the positional order of the constructor arguments: it differs from
    the order used by the encoder and decoder constructors.

    Args:
        vocab_size: Vocabulary size shared by encoder and decoder.
        d_model: Feature width.
        num_heads: Attention heads per layer.
        num_layers: Number of layers in both the encoder and the decoder.
        d_ff: Hidden width of the feed-forward sub-layers.
        max_seq_length: Number of positions in the positional-encoding tables.
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

    def forward(self, x, src_mask, tgt_mask, cross_mask, tgt=None):
        """Encode ``x`` and decode against the encoder output.

        Args:
            x: Source token ids given to the encoder.
            src_mask: Mask passed to the encoder.
            tgt_mask: Mask passed to the decoder's self-attention.
            cross_mask: Mask passed to the decoder's cross-attention.
            tgt: Optional target token ids given to the decoder. When omitted,
                the decoder receives ``x`` itself, as it did before this
                parameter existed.

        Returns:
            Whatever the decoder returns for the chosen decoder input.
        """
        encoder_output = self.encoder(x, src_mask)
        # Without a separate target the decoder could only ever see the source
        # tokens; ``tgt`` allows a distinct target sequence to be supplied.
        decoder_input = x if tgt is None else tgt
        return self.decoder(decoder_input, encoder_output, tgt_mask, cross_mask)

In [25]:
def generate_padding_mask(sequence, pad_token=0):
    """Return a boolean mask that is False wherever ``sequence`` equals ``pad_token``.

    Two singleton dimensions are inserted after the first one, so a
    (batch, seq) input yields a (batch, 1, 1, seq) mask.
    """
    keep = sequence.ne(pad_token)
    return keep[:, None, None]

# sequence = torch.tensor([2, 6, 30, 120, 0, 0, 0, 0])
# Padding mask derived from the demo tokens (token id 0 is treated as padding).
src_mask = generate_padding_mask(input_tokens)
# The same mask object is reused for the decoder's cross-attention.
cross_mask = src_mask

In [26]:
# Keyword arguments guard against the positional order of Transformer, which
# differs from the order used by the encoder/decoder constructors.
transformer = Transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout,
)
outputs = transformer(input_tokens, src_mask, tgt_mask, cross_mask)
print(outputs)
print(outputs.shape)

tensor([[[-7.2952, -7.7172, -6.4120,  ..., -8.0577, -6.7412, -6.9631],
         [-8.0978, -5.6856, -6.9766,  ..., -7.3987, -5.9954, -7.1549],
         [-6.8227, -6.9482, -6.8970,  ..., -6.4872, -6.5135, -6.8972],
         ...,
         [-7.2849, -7.2086, -6.8851,  ..., -8.2023, -7.0849, -7.5703],
         [-6.5633, -6.9625, -7.1137,  ..., -7.2809, -6.3619, -6.8622],
         [-7.9815, -6.7311, -6.9914,  ..., -7.6957, -5.9318, -7.9336]],

        [[-7.9445, -7.5784, -7.1699,  ..., -7.4563, -7.4804, -7.3523],
         [-8.1715, -6.8377, -6.5618,  ..., -6.8760, -6.6384, -5.8102],
         [-8.4128, -7.2745, -7.4224,  ..., -7.3137, -6.4476, -6.8861],
         ...,
         [-6.9692, -6.8992, -7.4177,  ..., -7.3548, -7.0166, -6.5395],
         [-6.2935, -7.0444, -7.1079,  ..., -7.5451, -6.3904, -7.3940],
         [-7.0938, -7.0121, -7.0166,  ..., -7.9626, -6.6922, -8.1125]]],
       grad_fn=<LogSoftmaxBackward0>)
torch.Size([2, 256, 1000])


### Tokenizer

In [None]:
### 259 possible tokens:
# 256 single-byte values (extended ASCII) + <UNK> + <SOS> + <EOS>

### Training Loop