In [37]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the Kaggle input directory tree.
input_files = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_files:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [38]:
import torch
import math
import torch.nn as nn

In [39]:
class InputEmbedding(nn.Module):
    """Token-id embedding lookup multiplied by sqrt(d_model).

    Args:
        d_model: embedding width (size of the trailing output dimension).
        vocab_size: number of distinct token ids the table holds.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=d_model)
        # Constant multiplier, computed once instead of on every forward call.
        self.scale = math.sqrt(d_model)

    def forward(self, x):
        """Look up embeddings for integer ids ``x`` and rescale them."""
        looked_up = self.embedding(x)
        return looked_up.mul(self.scale)

In [40]:
class PositionEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to embeddings, then applies dropout.

    Encoding ("Attention Is All You Need", section 3.5), for k = 0, 1, ...:
        PE[pos, 2k]   = sin(pos / 10000^(2k / d))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / d))
    Each sine/cosine pair shares a single frequency.

    Args:
        d_embed_model: embedding width d. Odd widths are supported; the last
            column then holds only a sine term.
        max_seq_length: number of positions precomputed in the table.
        drop_out: dropout probability applied after the addition.
    """

    def __init__(self, d_embed_model, max_seq_length, drop_out=0.1):
        super().__init__()
        self.embed_dim = d_embed_model
        self.max_seq = max_seq_length
        self.dropout = nn.Dropout(drop_out)

        # Positions as a column vector: (max_seq_length, 1).
        position = torch.arange(max_seq_length, dtype=torch.float32).unsqueeze(1)
        # One inverse frequency per even column index 2k: 10000^(-2k / d).
        even_idx = torch.arange(0, d_embed_model, 2, dtype=torch.float32)
        inv_freq = torch.exp(even_idx * (-math.log(10000.0) / d_embed_model))
        angles = position * inv_freq  # (max_seq_length, ceil(d / 2))

        pe = torch.zeros(max_seq_length, d_embed_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_embed_model // 2])
        # Leading batch axis so the table broadcasts over (batch, seq, d).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add encodings for the first ``x.shape[1]`` positions and apply dropout.

        Expects ``x`` shaped (batch, seq_len, d_embed_model) with
        seq_len <= max_seq_length.
        """
        # `pe` is a buffer, not a Parameter, so it never receives gradients.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)

In [41]:
class TransformerInputEmbedding(nn.Module):
    """Scaled token embedding followed by sinusoidal position encoding.

    Args:
        d_model: embedding width.
        vocab_size: number of token ids.
        max_seq_len: longest sequence covered by the position-encoding table.
        dropout: dropout probability, handed to PositionEncoding.
    """

    def __init__(self, d_model, vocab_size, max_seq_len, dropout=0.1):
        super().__init__()
        self.input_embedding = InputEmbedding(d_model, vocab_size)
        self.position_encoding = PositionEncoding(d_model, max_seq_len, dropout)

    def forward(self, x):
        """Embed token ids ``x`` and add position information."""
        return self.position_encoding(self.input_embedding(x))

In [42]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention without a causal mask.

    Used for encoder self-attention (query = key = value) and for decoder
    cross-attention (query from the decoder, key/value from the encoder), so
    the query and key/value sequence lengths may differ.

    Args:
        d_model: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        assert d_model % num_heads == 0, \
            "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Linear projections for Query, Key, Value, and the output.
        self.W_query = nn.Linear(d_model, d_model, bias=False)
        self.W_key = nn.Linear(d_model, d_model, bias=False)
        self.W_value = nn.Linear(d_model, d_model, bias=False)
        self.out_prj = nn.Linear(d_model, d_model, bias=False)

        # Dropout on the attention weights for regularization.
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) -> (batch, num_heads, length, head_dim)."""
        batch_size, length, _ = t.shape
        return t.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: (batch, q_len, d_model).
            key, value: (batch, kv_len, d_model); kv_len may differ from q_len.
            mask: optional bool tensor broadcastable to
                (batch, num_heads, q_len, kv_len). True marks key positions
                that must NOT be attended to (e.g. padding). A query row whose
                keys are all masked yields NaN weights.

        Returns:
            (batch, q_len, d_model) tensor.
        """
        batch_size, q_len, _ = query.shape

        # Each input is projected and split with its OWN length, so
        # cross-attention works when kv_len != q_len.
        queries = self._split_heads(self.W_query(query))
        keys = self._split_heads(self.W_key(key))
        values = self._split_heads(self.W_value(value))

        # Attention scores: (batch, num_heads, q_len, kv_len).
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Context: (batch, num_heads, q_len, head_dim) -> (batch, q_len, d_model).
        context_vec = torch.matmul(attn_weights, values)
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)

        return self.out_prj(context_vec)


In [43]:
class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with learnable gain and bias.

    Each feature vector is shifted to zero mean and divided by
    sqrt(biased_variance + 1e-5), then transformed as ``scale * x + shift``.

    Args:
        emb_dim: size of the last (normalized) dimension.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        # Per-feature gain (starts at 1) and bias (starts at 0).
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, as torch.nn.LayerNorm uses.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        denom = torch.sqrt(sigma2 + self.eps)
        standardized = (x - mu) / denom
        return self.scale * standardized + self.shift

In [44]:
class GELU(nn.Module):
    """GELU activation, tanh approximation:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    """

    def forward(self, x):
        coeff = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = coeff * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

In [45]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        dropout: dropout probability applied to the hidden activations.
        d_ff: hidden width. Defaults to 4 * d_model.
    """

    def __init__(self, d_model, dropout, d_ff=None):
        super().__init__()
        self.d_model = d_model
        self.d_ff = 4 * d_model if d_ff is None else d_ff
        # Keep the 3-module Sequential so parameter names (layers.0 / layers.2)
        # match state_dicts saved with the original layout.
        self.layers = nn.Sequential(
            nn.Linear(self.d_model, self.d_ff),
            GELU(),
            nn.Linear(self.d_ff, self.d_model)
        )
        # Dropout lives outside the Sequential for the same reason.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        expand, activation, project = self.layers
        return project(self.dropout(activation(expand(x))))

In [46]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    Args:
        d_model: feature width for the internal LayerNormalization.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(d_model)

    def forward(self, x, sublayer):
        """Run the callable ``sublayer`` on the normalized input and add the skip path."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

In [47]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention sublayer, then feed-forward sublayer.

    Each sublayer runs inside a ResidualConnection, which normalizes the input
    before the sublayer. The residual result is then normalized again by
    layer_norm1 / layer_norm2.

    NOTE(review): pre-norm residuals plus an extra post-norm apply LayerNorm
    twice per sublayer. That matches neither the standard pre-LN nor post-LN
    layout. It is kept as-is to preserve behavior; confirm whether intended.
    """

    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.encoder_attention = MultiheadAttention(d_model, num_heads, dropout)
        self.residual1 = ResidualConnection(d_model, dropout)
        self.feed_forward = FeedForward(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)
        self.layer_norm1 = LayerNormalization(d_model)
        self.layer_norm2 = LayerNormalization(d_model)

    def _self_attend(self, h):
        # Self-attention: the same tensor serves as query, key and value.
        return self.encoder_attention(h, h, h)

    def forward(self, x):
        """Return the layer output; presumably x is (batch, seq, d_model) — confirm."""
        attended = self.layer_norm1(self.residual1(x, self._self_attend))
        return self.layer_norm2(self.residual2(attended, self.feed_forward))


In [48]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` EncoderBlocks followed by a final LayerNormalization.

    Args:
        d_model: model width.
        num_heads: attention heads per block.
        dropout: dropout probability used inside each block.
        num_layers: number of stacked EncoderBlocks.
    """

    def __init__(self, d_model, num_heads, dropout, num_layers):
        super().__init__()
        blocks = [EncoderBlock(d_model, num_heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)

In [49]:
class MaskedMultiheadAttention(nn.Module):
    """Causal multi-head self-attention.

    Position t may attend only to positions <= t. The look-ahead mask is
    precomputed for ``context_length`` positions and cropped to the actual
    sequence length on each call.

    Args:
        d_model: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        context_length: maximum sequence length supported by the mask.
    """

    def __init__(self, d_model, num_heads, dropout, context_length):
        super().__init__()
        assert d_model % num_heads == 0, \
            "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Projections for Query, Key, Value and the merged-head output.
        self.W_query = nn.Linear(d_model, d_model, bias=False)
        self.W_key = nn.Linear(d_model, d_model, bias=False)
        self.W_value = nn.Linear(d_model, d_model, bias=False)
        self.out_prj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        # Float buffer holding 1.0 strictly above the diagonal (future positions).
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future)

    def _split_heads(self, t, batch, tokens):
        """(batch, tokens, d_model) -> (batch, num_heads, tokens, head_dim)."""
        return t.view(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Causally attend over ``x`` of shape (batch, tokens, d_model)."""
        batch, tokens, _ = x.shape

        k = self._split_heads(self.W_key(x), batch, tokens)
        q = self._split_heads(self.W_query(x), batch, tokens)
        v = self._split_heads(self.W_value(x), batch, tokens)

        # Raw scores per head: (batch, num_heads, tokens, tokens).
        scores = q @ k.transpose(-2, -1)
        blocked = self.mask.bool()[:tokens, :tokens]
        scores = scores.masked_fill(blocked, -torch.inf)

        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # Merge heads back: (batch, tokens, num_heads * head_dim == d_model).
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, tokens, self.d_model)
        return self.out_prj(merged)

In [50]:
class DecoderBlock(nn.Module):
    """One decoder layer with three sublayers, applied in order:

    1. causal self-attention,
    2. cross-attention over the encoder output,
    3. feed-forward.

    Each sublayer runs inside a pre-norm ResidualConnection, and the result is
    normalized again by layernorm1..3.

    NOTE(review): this applies LayerNorm twice per sublayer, matching neither
    the standard pre-LN nor post-LN layout. Kept as-is to preserve behavior;
    confirm whether intended.
    """

    def __init__(self, d_model, num_heads, dropout, context_length):
        super().__init__()
        self.masked_attention = MaskedMultiheadAttention(d_model, num_heads, dropout, context_length)
        self.encoder_attention = MultiheadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, dropout)
        self.residual1 = ResidualConnection(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)
        self.residual3 = ResidualConnection(d_model, dropout)
        self.layernorm1 = LayerNormalization(d_model)
        self.layernorm2 = LayerNormalization(d_model)
        self.layernorm3 = LayerNormalization(d_model)

    def forward(self, x, encoder_output):
        def cross_attend(h):
            # Queries come from the decoder stream; keys/values from the encoder.
            return self.encoder_attention(h, encoder_output, encoder_output)

        h = self.layernorm1(self.residual1(x, self.masked_attention))
        h = self.layernorm2(self.residual2(h, cross_attend))
        return self.layernorm3(self.residual3(h, self.feed_forward))


In [51]:
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` DecoderBlocks followed by a final LayerNormalization.

    Args:
        d_model: model width.
        num_heads: attention heads per block.
        dropout: dropout probability used inside each block.
        context_length: maximum target length for the causal mask.
        num_layers: number of stacked DecoderBlocks.
    """

    def __init__(self, d_model, num_heads, dropout, context_length, num_layers):
        super().__init__()
        blocks = [DecoderBlock(d_model, num_heads, dropout, context_length)
                  for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNormalization(d_model)

    def forward(self, x, encoder_output):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output)
        return self.norm(hidden)

In [52]:
class ProjectionLayer(nn.Module):
    """Affine map from model width to vocabulary logits (no softmax applied).

    Args:
        d_model: input feature width.
        vocab_size: number of output logits.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits

In [53]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source and target token ids to target-vocab logits.

    Args:
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size (and number of output logits).
        d_model: model width.
        num_heads: attention heads per layer.
        dropout: dropout probability used throughout.
        num_encoder_layers: number of encoder blocks.
        num_decoder_layers: number of decoder blocks.
        max_seq_len: sizes both position-encoding tables and the decoder's
            causal mask, so it bounds the source and target lengths.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads,
                 dropout, num_encoder_layers, num_decoder_layers, max_seq_len):
        super().__init__()
        self.src_embedding = TransformerInputEmbedding(d_model, src_vocab_size, max_seq_len, dropout)
        self.tgt_embedding = TransformerInputEmbedding(d_model, tgt_vocab_size, max_seq_len, dropout)
        self.encoder = TransformerEncoder(d_model, num_heads, dropout, num_encoder_layers)
        self.decoder = TransformerDecoder(d_model, num_heads, dropout, max_seq_len, num_decoder_layers)
        self.fc_out = ProjectionLayer(d_model, tgt_vocab_size)

    def forward(self, src, tgt, apply_softmax=False):
        """Return logits, or probabilities when ``apply_softmax`` is True, over the last axis."""
        # Call order (src emb, tgt emb, encoder, decoder) fixes the order in
        # which dropout draws random numbers; keep it stable.
        src_emb = self.src_embedding(src)
        tgt_emb = self.tgt_embedding(tgt)
        memory = self.encoder(src_emb)
        logits = self.fc_out(self.decoder(tgt_emb, memory))
        return nn.functional.softmax(logits, dim=-1) if apply_softmax else logits

In [59]:
# Smoke test: build a Transformer over a tiny vocabulary and run one forward
# pass on two toy source/target sequences.
torch.manual_seed(42)  # fixes weight init and dropout so reruns are reproducible

src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length = 12
src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

# Token ids must index into the embedding tables; lengths must fit max_seq_len.
assert int(src.max()) < src_vocab_size and int(target.max()) < target_vocab_size
assert src.shape[1] <= seq_length and target.shape[1] <= seq_length

model = Transformer(src_vocab_size=src_vocab_size,
                    tgt_vocab_size=target_vocab_size,
                    d_model=512, num_heads=8,
                    dropout=0.1,
                    num_encoder_layers=num_layers,
                    num_decoder_layers=num_layers,
                    max_seq_len=seq_length)
output = model(src, target)

# One logit vector over the target vocabulary per target position.
assert output.shape == (target.shape[0], target.shape[1], target_vocab_size)
print(output.shape)
print(model)

torch.Size([2, 12, 11])
Transformer(
  (src_embedding): TransformerInputEmbedding(
    (input_embedding): InputEmbedding(
      (embedding): Embedding(11, 512)
    )
    (position_encoding): PositionEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (tgt_embedding): TransformerInputEmbedding(
    (input_embedding): InputEmbedding(
      (embedding): Embedding(11, 512)
    )
    (position_encoding): PositionEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (encoder_attention): MultiheadAttention(
          (W_query): Linear(in_features=512, out_features=512, bias=False)
          (W_key): Linear(in_features=512, out_features=512, bias=False)
          (W_value): Linear(in_features=512, out_features=512, bias=False)
          (out_prj): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
       