# Silly implementation of a Transformer
An educational implementation of the Transformer architecture.

In [17]:
import torch
import torch.nn as nn
import math
import copy

### Attention Head
This class implements a single scaled dot-product attention head, as described in the paper *Attention Is All You Need* (Vaswani et al., 2017); the rest of the code follows the same paper.

In [2]:
class Head(nn.Module):
    """A single scaled dot-product attention head.

    Queries, keys and values are each projected from ``model_dim`` to
    ``head_size`` features by their own linear layer, then combined as
    ``softmax(Q @ K^T / sqrt(head_size)) @ V``.

    Args:
        model_dim: size of the last dimension of the inputs.
        head_size: size of the projected queries/keys/values.
    """

    def __init__(self, model_dim, head_size):
        super().__init__()
        # Creation order (k, q, v) is kept so seeded initialisation and
        # state_dict ordering are unchanged.
        self.k = nn.Linear(model_dim, head_size)
        self.q = nn.Linear(model_dim, head_size)
        self.v = nn.Linear(model_dim, head_size)

        self.head_size = head_size

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q, k, v: inputs whose last dimension is ``model_dim``; the
                second-to-last dimension is used as the sequence axis
                (it is the axis swapped by ``transpose(-2, -1)``).
            mask: optional mask broadcastable to the score tensor
                ``(..., len_q, len_k)``. A 2-D mask gets a leading axis.
                Two encodings are accepted:

                * additive: a floating-point mask (any float dtype) that
                  contains at least one ``inf`` is added to the scores
                  (0 keeps a position, ``-inf`` blocks it);
                * keep-mask: anything else (bool, integer, or float 0/1) --
                  positions where the mask equals 0 are set to -1e9 before
                  the softmax. Note a float mask with no ``inf`` at all is
                  treated as a keep-mask.

        Returns:
            ``(output, attention_weights)``: the weighted sum of projected
            values (last dim ``head_size``) and the softmax weights over the
            key axis.
        """
        # Project inputs to queries, keys, and values
        q = self.q(q)
        k = self.k(k)
        v = self.v(v)

        # Scaled dot-product scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_size)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)

            # Test for *any* floating dtype, not just float32: a float64/float16
            # additive mask would otherwise reach the keep-mask branch, where
            # its 0 entries (the allowed positions) would be masked out.
            if torch.is_floating_point(mask) and torch.isinf(mask).any():
                scores = scores + mask.to(scores.dtype)
            else:
                scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax over the key axis, then weight the values
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, v)

        return output, attention_weights

### Multi-Head Attention

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from independent ``Head`` modules.

    Each of the ``num_heads`` heads projects to ``model_dim // num_heads``
    features. Their outputs are concatenated on the last axis and mixed by a
    final ``model_dim -> model_dim`` linear layer.
    """

    def __init__(self, model_dim, num_heads):
        super().__init__()
        assert model_dim % num_heads == 0, "d_model must be divisible by num_heads"

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_size = model_dim // num_heads

        self.heads = nn.ModuleList(
            [Head(model_dim, self.head_size) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(model_dim, model_dim)

    def forward(self, query, key, value, mask=None):
        """Run every head on the same inputs and mask.

        Returns:
            ``(output, attention_weights)``: the linear projection of the
            concatenated head outputs, and a list holding each head's
            attention-weight tensor in head order.
        """
        per_head = [head(query, key, value, mask) for head in self.heads]
        attention_weights = [weights for _, weights in per_head]
        merged = torch.cat([head_out for head_out, _ in per_head], dim=-1)
        return self.output_linear(merged), attention_weights


### Feed Forward

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied at every position.

    Computes ``linear2(dropout(relu(linear1(x))))``: expand to ``d_ff``
    hidden units, then project back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

### Positional Encoding

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even feature indices hold ``sin(pos * 10000^(-2i/d_model))`` and odd
    indices the matching ``cos``. The table is precomputed for
    ``max_seq_length`` positions and registered as a buffer, so it moves with
    ``.to(device)`` and is stored in the state dict but is not trained.

    Args:
        d_model: feature size (odd values are supported).
        max_seq_length: number of positions precomputed.
        dropout: dropout probability applied after adding the encodings.
    """

    def __init__(self, d_model, max_seq_length=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even ones,
        # so only the first d_model // 2 frequencies feed the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading singleton axis so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``.

        Axis 1 of ``x`` is used as the sequence axis -- presumably ``x`` is
        (batch, seq_len, d_model); confirm with callers. ``x.size(1)`` must
        not exceed ``max_seq_length``.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

### LayerNorm

In [11]:
class LayerNorm(nn.Module):
    """Normalise over the last dimension, then apply a learnable scale and shift.

    Uses ``Tensor.std`` with its default (unbiased) estimator and adds ``eps``
    to the standard deviation rather than inside a square root, so results
    differ slightly from ``torch.nn.LayerNorm``. With ``features == 1`` the
    unbiased std is undefined and the output is NaN.

    Args:
        features: size of the last dimension (length of gamma/beta).
        eps: added to the std to avoid division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.gamma * centered / denom + self.beta

### Sublayer Connection (pre-norm residual: x + dropout(sublayer(norm(x))))

In [15]:
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The input is normalised *before* it enters the sublayer; the residual
    adds back the un-normalised input, and the sum itself is not normalised.

    Args:
        size: feature size passed to the LayerNorm.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable on tensors) with a residual connection.

        The sublayer's output must be addable to ``x``.
        """
        residual = x
        transformed = sublayer(self.norm(x))
        return residual + self.dropout(transformed)

### Encoder Layer

In [14]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a position-wise feed-forward net.

    Each of the two stages is wrapped in its own pre-norm residual
    ``SublayerConnection``.
    """

    def __init__(self, model_dim, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(model_dim, num_heads)
        self.feed_forward = PositionwiseFeedForward(model_dim, d_ff, dropout)
        self.sublayer = nn.ModuleList([SublayerConnection(model_dim, dropout) for _ in range(2)])
        # Exposed so a stacking container can size its final LayerNorm.
        self.size = model_dim

    def forward(self, x, mask=None):
        attn_block, ff_block = self.sublayer

        def attend(h):
            # Keep only the attended values; per-head weights are discarded.
            attended, _ = self.self_attn(h, h, h, mask)
            return attended

        x = attn_block(x, attend)
        return ff_block(x, self.feed_forward)

### Decoder Layer

In [16]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each stage is wrapped in a pre-norm residual ``SublayerConnection``.
    Self-attention uses ``tgt_mask``. Cross-attention queries come from the
    decoder stream; keys and values come from ``memory`` and are masked
    with ``src_mask``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayer = nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(3)])
        # Exposed so a stacking container can size its final LayerNorm.
        self.size = d_model

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        self_block, cross_block, ff_block = self.sublayer
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask)[0])
        x = cross_block(x, lambda h: self.cross_attn(h, memory, memory, src_mask)[0])
        return ff_block(x, self.feed_forward)

### Encoder

In [18]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` encoder layers followed by a final LayerNorm.

    Every entry is a ``copy.deepcopy`` of ``layer``, so the layers have
    separate parameters that all start equal to ``layer``'s. ``layer`` must
    expose a ``size`` attribute, which sizes the final norm.
    """

    def __init__(self, layer, num_layers):
        super().__init__()
        clones = [copy.deepcopy(layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask=None):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

### Decoder

In [19]:
class Decoder(nn.Module):
    """A stack of ``num_layers`` decoder layers followed by a final LayerNorm.

    Each entry is an independent ``copy.deepcopy`` of ``layer``; ``layer.size``
    sizes the final norm. Every layer receives the same ``memory`` and masks.
    """

    def __init__(self, layer, num_layers):
        super().__init__()
        clones = [copy.deepcopy(layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

### Transformer

In [28]:
class Transformer(nn.Module):
    """
    Complete encoder-decoder Transformer model.

    Source and target token ids are embedded, scaled by sqrt(d_model), given
    sinusoidal position encodings, and passed through ``num_layers`` encoder
    and decoder layers. ``generator`` then projects the result to
    ``tgt_vocab_size`` unnormalised logits.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 d_ff=2048, num_layers=6, dropout=0.1, max_seq_length=5000):
        super().__init__()
        self.d_model = d_model

        # Embeddings; a single positional-encoding module serves both sides.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length, dropout)

        # Prototype layers; Encoder/Decoder deep-copy them num_layers times.
        encoder_layer = EncoderLayer(d_model, num_heads, d_ff, dropout)
        decoder_layer = DecoderLayer(d_model, num_heads, d_ff, dropout)

        self.encoder = Encoder(encoder_layer, num_layers)
        self.decoder = Decoder(decoder_layer, num_layers)

        # Final projection to target-vocabulary logits (no softmax applied).
        self.generator = nn.Linear(d_model, tgt_vocab_size)

        self._init_parameters()

    def _init_parameters(self):
        """
        Re-initialise every parameter with more than one dimension (weight
        matrices and embedding tables) using Xavier-uniform; 1-D parameters
        (biases, LayerNorm gamma/beta) keep their defaults.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask=None):
        """
        Encode source token ids into the memory consumed by ``decode``.
        """
        # Scale embeddings by sqrt(d_model) before adding position encodings
        src_embedded = self.src_embedding(src) * math.sqrt(self.d_model)
        src_embedded = self.positional_encoding(src_embedded)
        return self.encoder(src_embedded, src_mask)

    def decode(self, tgt, memory, src_mask=None, tgt_mask=None):
        """
        Decode target token ids while attending to the encoder ``memory``.
        """
        # Scale embeddings by sqrt(d_model) before adding position encodings
        tgt_embedded = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_embedded = self.positional_encoding(tgt_embedded)
        return self.decoder(tgt_embedded, memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Take in and process source and target sequences.

        Args:
            src: source sequence [batch_size, src_seq_len]
            tgt: target sequence [batch_size, tgt_seq_len]
            src_mask: mask for source sequence [batch_size, 1, src_seq_len] or broadcastable
            tgt_mask: mask for target sequence [batch_size, tgt_seq_len, tgt_seq_len] or broadcastable

        Returns:
            unnormalised logits [batch_size, tgt_seq_len, tgt_vocab_size];
            apply softmax (or pass straight to CrossEntropyLoss) as needed.
        """
        memory = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.generator(decoder_output)

    def generate_square_subsequent_mask(self, sz, device=None):
        """
        Generate a square causal (additive) mask for the decoder.

        Position (i, j) is 0.0 when j <= i and float('-inf') when j > i, so a
        query can only attend to itself and earlier positions.

        Args:
            sz: sequence length.
            device: optional device for the mask (defaults to torch's default,
                i.e. CPU unless configured otherwise).

        Returns:
            float32 tensor of shape [1, sz, sz] (leading batch axis so it
            broadcasts against other masks).
        """
        mask = torch.triu(
            torch.full((sz, sz), float('-inf'), dtype=torch.float, device=device),
            diagonal=1,
        )
        return mask.unsqueeze(0)

    def create_pad_mask(self, matrix, pad_idx):
        """
        Create a keep-mask hiding padding tokens.

        Returns a float tensor of shape [batch_size, 1, seq_len] that is 1.0
        for tokens different from ``pad_idx`` and 0.0 for padding tokens.
        """
        mask = (matrix != pad_idx).float()
        # Insert the query axis so the mask broadcasts over attention scores.
        return mask.unsqueeze(1)

In [31]:
def train_transformer(model, src_vocab_size=10000, tgt_vocab_size=10000):
    """Run one optimisation step of ``model`` on a random synthetic batch.

    A smoke test for the training pipeline. Source and target ids are drawn
    from ``[1, vocab_size)``, so none equal the padding id 0. The target is
    shifted by one (teacher forcing) and a single Adam step is taken.

    Args:
        model: a ``Transformer`` (anything exposing ``create_pad_mask``,
            ``generate_square_subsequent_mask`` and the same ``forward``).
        src_vocab_size: upper bound (exclusive) for random source ids.
        tgt_vocab_size: upper bound (exclusive) for random target ids; must
            equal the model's output vocabulary size for the loss reshape.

    Returns:
        ``(model, loss)``: the same model object and the scalar loss tensor.
    """
    batch_size = 32
    max_len = 100

    # The model may have been left in eval mode (inference_example() calls
    # model.eval()); a training step needs dropout active.
    model.train()
    device = next(model.parameters()).device

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.0001,
        betas=(0.9, 0.98),
        eps=1e-9
    )
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    src = torch.randint(1, src_vocab_size, (batch_size, max_len), device=device)
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, max_len), device=device)

    # Teacher forcing: feed tokens [0, n-1) and predict tokens [1, n).
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    tgt_len = tgt_in.size(1)

    src_mask = model.create_pad_mask(src, 0)

    # Merge padding and causal masks into one additive mask. Valid padding
    # entries are 1.0 and pads -inf; the element-wise minimum with the causal
    # mask (0 / -inf) gives 0 only where both allow attention.
    tgt_padding_mask = model.create_pad_mask(tgt_in, 0)
    tgt_look_ahead_mask = model.generate_square_subsequent_mask(tgt_len).to(device)
    expanded_padding_mask = tgt_padding_mask.expand(-1, tgt_len, -1)
    inf_padding_mask = expanded_padding_mask.float().masked_fill(expanded_padding_mask == 0, float('-inf'))
    tgt_mask = torch.minimum(inf_padding_mask, tgt_look_ahead_mask)

    output = model(src, tgt_in, src_mask, tgt_mask)

    loss = criterion(
        output.contiguous().view(-1, tgt_vocab_size),
        tgt_out.contiguous().view(-1)
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return model, loss


def inference_example(model, src, max_len=100, start_symbol=2, end_symbol=3, pad_idx=0):
    """Greedy autoregressive decoding of a single source sequence.

    Args:
        model: a Transformer-like model exposing ``create_pad_mask``,
            ``encode``, ``decode``, ``generator`` and
            ``generate_square_subsequent_mask``.
        src: source token ids. The loop calls ``.item()`` on the per-step
            argmax, so the batch size must be 1.
        max_len: maximum length of the returned sequence, including the
            start token.
        start_symbol: token id placed at the start of the decoded sequence.
        end_symbol: decoding stops right after this id is generated
            (default 3, the value previously hard-coded).
        pad_idx: id treated as padding when masking ``src`` (default 0).

    Returns:
        LongTensor of shape (1, <= max_len): the start token followed by the
        generated ids, including ``end_symbol`` if it was produced.

    Side effect: puts ``model`` in eval mode and leaves it there.
    """
    model.eval()

    # No gradients are needed for decoding. Without no_grad every step would
    # record an autograd graph for the whole (growing) decoder pass.
    with torch.no_grad():
        src_mask = model.create_pad_mask(src, pad_idx)
        encoder_output = model.encode(src, src_mask)

        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=src.device)

        for _ in range(max_len - 1):
            tgt_mask = model.generate_square_subsequent_mask(ys.size(1)).to(src.device)

            out = model.decode(ys, encoder_output, src_mask, tgt_mask)
            # Logits for the last position only; greedy pick of the best id.
            logits = model.generator(out[:, -1])
            next_word = logits.argmax(dim=1).item()

            next_token = torch.full((1, 1), next_word, dtype=torch.long, device=src.device)
            ys = torch.cat([ys, next_token], dim=1)

            if next_word == end_symbol:
                break

    return ys


def training_loop(model, train_dataloader, valid_dataloader, n_epochs, src_vocab_size, tgt_vocab_size,
                  checkpoint_path='best_transformer_model.pt'):
    """Train ``model`` with teacher forcing and checkpoint the best validation loss.

    Each epoch runs one pass over ``train_dataloader`` (Adam, gradient-norm
    clipping at 1.0), then one pass over ``valid_dataloader`` without
    gradients. ``ReduceLROnPlateau`` is stepped on the validation loss, and
    the state dict is saved whenever the validation loss improves.

    Args:
        model: a ``Transformer``-like model (see ``train_transformer``).
        train_dataloader, valid_dataloader: iterables yielding ``(src, tgt)``
            batches of token ids; id 0 is treated as padding. Batches are
            moved to the device of the model's parameters.
        n_epochs: number of epochs.
        src_vocab_size: unused; kept for interface compatibility.
        tgt_vocab_size: model output vocabulary size (used to reshape logits).
        checkpoint_path: file the best ``state_dict`` is written to.

    Returns:
        The trained model (same object).

    Losses are averaged over the number of batches actually seen, which also
    works for loaders without ``__len__``. An empty loader yields a ``nan``
    average; with no validation batches the scheduler is not stepped and no
    checkpoint is written.
    """
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.0001,
        betas=(0.9, 0.98),
        eps=1e-9
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    device = next(model.parameters()).device

    def batch_loss(src, tgt):
        """Forward one batch with teacher forcing and return its loss tensor."""
        src, tgt = src.to(device), tgt.to(device)
        tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
        tgt_len = tgt_in.size(1)

        src_mask = model.create_pad_mask(src, 0)

        # Additive target mask: padding entries become -inf (valid ones 1.0),
        # and the minimum with the causal 0/-inf mask leaves 0 only where
        # both allow attention.
        tgt_padding_mask = model.create_pad_mask(tgt_in, 0)
        tgt_look_ahead_mask = model.generate_square_subsequent_mask(tgt_len).to(device)
        expanded_padding_mask = tgt_padding_mask.expand(-1, tgt_len, -1)
        inf_padding_mask = expanded_padding_mask.float().masked_fill(expanded_padding_mask == 0, float('-inf'))
        tgt_mask = torch.minimum(inf_padding_mask, tgt_look_ahead_mask)

        output = model(src, tgt_in, src_mask, tgt_mask)
        return criterion(
            output.contiguous().view(-1, tgt_vocab_size),
            tgt_out.contiguous().view(-1)
        )

    best_valid_loss = float('inf')

    for epoch in range(n_epochs):
        model.train()
        train_total = 0.0
        n_train_batches = 0

        for batch_idx, (src, tgt) in enumerate(train_dataloader):
            optimizer.zero_grad()
            loss = batch_loss(src, tgt)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_total += loss.item()
            n_train_batches += 1

            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Train loss: {loss.item():.4f}')

        model.eval()
        valid_total = 0.0
        n_valid_batches = 0

        with torch.no_grad():
            for src, tgt in valid_dataloader:
                valid_total += batch_loss(src, tgt).item()
                n_valid_batches += 1

        train_loss = train_total / n_train_batches if n_train_batches else float('nan')
        valid_loss = valid_total / n_valid_batches if n_valid_batches else float('nan')

        if n_valid_batches:
            scheduler.step(valid_loss)

        print(f'Epoch: {epoch+1}, Train loss: {train_loss:.4f}, Valid loss: {valid_loss:.4f}')

        # nan never compares smaller, so an empty validation pass never saves.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Best model saved with validation loss: {best_valid_loss:.4f}')

    return model