## Transformer from Scratch

In [2]:
import torch
import torch.nn as nn
import math

### Prep 1.: Embedding

In [26]:
class InputEmbeddings(nn.Module):
    """Token-id to vector lookup, scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size, self.d_model = vocab_size, d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Embed the integer ids in ``x`` and multiply by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

### Prep 2.: Positional Encoding

In [35]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    A table ``pe`` of shape ``(1, max_seq_length, d_model)`` is precomputed
    and stored as a (non-trainable) buffer. Even columns hold
    ``sin(pos * w_i)`` and odd columns hold ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Width of the embeddings the encoding is added to. Both even
            and odd values are supported.
        max_seq_length: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros((max_seq_length, d_model))
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_length, d_model)``; ``seq_length``
        is expected to be at most ``max_seq_length``.
        """
        return x + self.pe[:, :x.size(1)]

### Model 1.: Attention Mechanism Layer (MultiHeadAttention Layer)

In [36]:
import torch.nn.functional as F

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``num_heads`` chunks of
    width ``head_dim = d_model // num_heads``, attended per head, merged back
    to width ``d_model`` and passed through an output projection.

    Args:
        d_model: Width of the input and output vectors.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, head_dim)``."""
        seq_length = x.size(1)
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        """Apply scaled dot-product attention independently per head.

        ``query``, ``key`` and ``value`` are ``(batch, heads, seq, head_dim)``.
        When given, ``mask`` must broadcast against the
        ``(batch, heads, query_len, key_len)`` score tensor; positions where
        ``mask == 0`` are excluded. A query row whose keys are all masked
        yields NaN, because softmax over only ``-inf`` is undefined.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # The last axis of `scores` indexes key positions, so each query's
        # weights sum to 1 over the keys it may attend to.
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        """Reshape ``(batch, heads, seq, head_dim)`` back to ``(batch, seq, d_model)``."""
        # Axis 1 is the head axis at this point; the sequence length is axis 2.
        seq_length = x.size(2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.reshape(batch_size, seq_length, self.d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Returns a tensor with the batch and sequence sizes of ``query`` and a
        last dimension of ``d_model``.
        """
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        attended = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attended, batch_size)
        return self.output_linear(output)

### Model 2.: Transformer Encoder

In [38]:
class FeedForwardSubLayer(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output width.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand ``x`` to ``d_ff``, apply ReLU, and project back to ``d_model``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

In [39]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each post-normed.

    Both sub-layers follow the same pattern: apply the sub-layer, run dropout
    on its output, add the residual input, then apply ``LayerNorm``.

    Args:
        d_model: Width of the token representations.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability shared by both residual branches.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model=d_model, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(normalized_shape=d_model)
        self.norm2 = nn.LayerNorm(normalized_shape=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, src_mask):
        """Run ``x`` through the block; ``src_mask`` is passed to self-attention."""
        # Sub-layer 1: self-attention (query, key and value are all `x`).
        attended = self.self_attn(x, x, x, mask=src_mask)
        after_attention = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.ff_sublayer(after_attention)
        return self.norm2(after_attention + self.dropout(transformed))

In [43]:
class TransformerEncoder(nn.Module):
    """Encoder stack: scaled embeddings + positional encoding + N encoder layers.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Width of the token representations.
        n_layers: Number of stacked ``EncoderLayer`` blocks.
        num_heads: Attention heads per block.
        d_ff: Hidden width of each block's feed-forward sub-layer.
        dropout: Dropout probability used inside each block.
        max_seq_length: Number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, d_model, n_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoder = PositionalEncoding(d_model, max_seq_length)
        self.encoder_blocks = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(n_layers)
        )

    def forward(self, x, src_mask):
        """Embed token ids ``x`` and pass them through every encoder block in order."""
        hidden = self.positional_encoder(self.embedding(x))
        for block in self.encoder_blocks:
            hidden = block(hidden, src_mask)
        return hidden

In [44]:
class ClassifierHead(nn.Module):
    """Linear classification head that returns class probabilities.

    Args:
        d_model: Width of the incoming feature vectors.
        num_classes: Number of output classes.
    """

    def __init__(self, d_model, num_classes):
        super().__init__()
        self.classifier_nn = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Project ``x`` to ``num_classes`` logits and normalize them.

        Returns a tensor shaped like ``x`` except that the last dimension is
        ``num_classes``; values along that dimension sum to 1.
        """
        logits = self.classifier_nn(x)
        # `dim` must be explicit: the deprecated implicit choice picks dim 0
        # for 3-D inputs, which would normalize across the batch instead of
        # across classes.
        return F.softmax(logits, dim=-1)

In [45]:
# Hyperparameters for the demo encoder-plus-classifier model.
vocab_size = 256
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.2
seq_length = 256  # used as the positional-encoding table length
num_classes = 2

# Build the encoder stack and the classification head that sits on top of it.
transformer_encoder = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    n_layers=num_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
    max_seq_length=seq_length,
)
classifier = ClassifierHead(d_model=d_model, num_classes=num_classes)

### Model 3.: Transformer Decoder

In [None]:
# Causal (lower-triangular) boolean mask of shape (1, seq_length, seq_length):
# entry [0, i, j] is True exactly when j <= i, i.e. position i may attend to
# itself and to earlier positions only.
tgt_mask = torch.tril(torch.ones(1, seq_length, seq_length)).bool()