Scaled Dot-Product Attention

In [130]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [131]:
# Toy dimensions for the walkthrough: one sequence of four tokens.
batch_size = 1
seq_len = 4
input_dim = 512  # width of each incoming token vector
d_model = 512    # width used inside the attention block
# Random stand-in for a batch of embedded tokens.
x = torch.randn(batch_size, seq_len, input_dim)

In [132]:
# Confirm the input layout: (batch_size, seq_len, input_dim).
x.shape

torch.Size([1, 4, 512])

In [133]:
# One fused projection produces the query, key and value features in a single matmul.
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [134]:
# Project every token to 3 * d_model features; q, k and v are packed along the last axis.
qkv = qkv_layer(x)

In [135]:
# Carve the packed projection into heads; each head keeps its own q, k and v slice,
# hence 3 * head_dim features per head.
num_heads = 8
head_dim = d_model // num_heads
qkv = torch.reshape(qkv, (batch_size, seq_len, num_heads, 3 * head_dim))

In [136]:
# Swap the sequence and head axes so the layout becomes (batch, heads, seq, 3 * head_dim).
qkv = qkv.transpose(1, 2)
qkv.shape

torch.Size([1, 8, 4, 192])

In [137]:
# Cut the last axis into three equal parts: queries, keys and values.
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)

In [138]:
# Query-key similarity for every pair of positions, divided by sqrt(d_k).
d_k = q.shape[-1]
scaled = (q @ k.transpose(-1, -2)) / math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [139]:
# Additive look-ahead mask: -inf strictly above the diagonal, 0 on and below it,
# so a position cannot attend to later positions once the mask is added to the scores.
mask = torch.triu(torch.full(scaled.shape, float('-inf')), diagonal=1)
mask

tensor([[[[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]],

         [[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]]]])

In [140]:
# Normalise the masked scores over the key axis; -inf entries receive zero weight.
attention = torch.softmax(scaled + mask, dim=-1)

In [141]:
# Each position's output is the attention-weighted sum of the value vectors.
value = attention @ v
value.shape

torch.Size([1, 8, 4, 64])

<h1>Multi Head Attention</h1>

In [142]:
def scaled_dot_product(q,k,v,mask=None):
    """Scaled dot-product attention.

    Args:
        q: queries, shape ``(..., q_len, d_k)``.
        k: keys, shape ``(..., kv_len, d_k)``.
        v: values, shape ``(..., kv_len, d_v)``.
        mask: optional tensor broadcastable to the ``(..., q_len, kv_len)`` scores.
            A boolean mask is a keep-mask (``True`` = may attend, ``False`` =
            blocked); any other dtype is treated as an additive mask
            (``0`` = keep, ``-inf`` = blocked).

    Returns:
        ``(values, attention)``: the weighted values and the attention weights.

    Note:
        A query row whose keys are all blocked yields NaN weights, because the
        softmax is taken over a row of ``-inf``.
    """
    d_k = q.size(-1)
    scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        if mask.dtype == torch.bool:
            # Adding a boolean mask would only shift scores by 0/1 and mask nothing;
            # blocked positions must be set to -inf before the softmax.
            scaled = scaled.masked_fill(~mask, float('-inf'))
        else:
            # Out-of-place add so the mask may broadcast freely against the scores.
            scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
    

class multiheadAttention(nn.Module):
    """Multi-head self-attention built on one fused query/key/value projection.

    Args:
        d_model: feature width of the input and of the output.
        num_heads: number of heads; each head works on ``d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model, self.num_heads = d_model, num_heads
        self.head_dims = d_model // num_heads
        # Single projection producing q, k and v for every head at once.
        self.qkv_layer = nn.Linear(in_features=d_model, out_features=3 * d_model)
        # Output projection applied once the heads are merged back together.
        self.liner_layer = nn.Linear(in_features=d_model, out_features=d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` (a 3-D tensor: batch, sequence, features).

        ``mask`` is handed unchanged to ``scaled_dot_product``. The result has
        shape ``(batch, sequence, d_model)``.
        """
        n_batch, n_tokens, _ = x.size()
        per_head_shape = (n_batch, n_tokens, self.num_heads, 3 * self.head_dims)
        # (batch, seq, heads, 3*head) -> (batch, heads, seq, 3*head)
        packed = self.qkv_layer(x).reshape(per_head_shape).permute(0, 2, 1, 3)
        q, k, v = torch.chunk(packed, 3, dim=-1)
        values, _ = scaled_dot_product(q, k, v, mask)
        # Bring the heads back next to the feature axis and flatten them into d_model.
        merged = values.transpose(1, 2).contiguous().view(n_batch, n_tokens, self.d_model)
        return self.liner_layer(merged)
        


In [143]:
# Smoke-test the attention module on a larger random batch.
batch_size = 30
seq_len = 5
x = torch.randn((batch_size, seq_len, input_dim))
model = multiheadAttention(512, 8)
# Call the module itself rather than .forward() so any registered hooks run too.
out = model(x)

In [144]:
# The module keeps the input layout: (batch_size, seq_len, d_model).
out.shape

torch.Size([30, 5, 512])

<h1>Positional Encoding</h1>

In [145]:
import torch 
import torch.nn as nn

# Module-level settings for the positional-encoding section.
# NOTE(review): no later visible cell reads this module-level max_seq_len;
# the classes below take their own max_seq_len / max_len arguments.
max_seq_len=10
d_model = 512

In [146]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is stored as the non-trainable buffer
    ``pe`` with shape ``(1, max_seq_len, d_model)``.

    Args:
        d_model: feature width of the embeddings (even or odd).
        max_seq_len: number of positions precomputed in the table.
        dropout: dropout probability applied to the summed result.
    """

    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))
        pe = torch.zeros(1, max_seq_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed column than even-indexed
        # ones, so trim the angles to fit instead of failing on a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])``; the table is sliced to ``x.size(1)`` positions."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

<h1>Pointwise feedforward network</h1>

In [147]:
class PointwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: Linear(d_in -> dff), ReLU, Linear(dff -> d_in).

    Args:
        d_in: input and output feature width.
        dff: width of the hidden layer.
    """

    def __init__(self, d_in, dff):
        super().__init__()
        self.din, self.dff = d_in, dff
        self.linear1 = nn.Linear(in_features=d_in, out_features=dff)
        self.linear2 = nn.Linear(in_features=dff, out_features=d_in)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand to ``dff`` features, apply ReLU, then project back to ``d_in``."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


<h1>Encoder Layer</h1>

In [148]:
class encoderLayer(nn.Module):
    """One encoder layer: self-attention and a feed-forward block, each followed by
    a residual connection and layer normalisation.

    Args:
        d_model: feature width of the layer's input and output.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward block.
    """

    def __init__(self, d_model,num_heads,d_ff,):
        super().__init__()
        self.attention = multiheadAttention(d_model,num_heads)
        self.feedForward = PointwiseFeedForward(d_model,d_ff)
        self.layerNorm1 = nn.LayerNorm(d_model)
        self.layerNorm2 = nn.LayerNorm(d_model)

    def forward(self,x,mask=None):
        """Run the layer on ``x``.

        ``mask`` is passed on to the self-attention module; ``None`` (the
        default) means unmasked attention. Previously the argument was accepted
        but silently dropped.
        """
        attention_out = self.attention(x, mask=mask)
        x = self.layerNorm1(x + attention_out)
        feedforward = self.feedForward(x)
        x = self.layerNorm2(feedforward + x)
        return x

 

<h1>Complete encoder with stacking</h1>

In [149]:
class Encoder(nn.Module):
    """Encoder stack: scaled token embedding, positional encoding, ``num_layer``
    encoder layers and a final layer normalisation.

    Args:
        num_layer: number of stacked encoder layers.
        d_model: embedding / model width.
        num_heads: attention heads per layer.
        d_ff: hidden width of each feed-forward block.
        input_vocab_size: size of the source vocabulary.
        max_seq_len: number of positions covered by the positional encoding.
        dropout: dropout probability used by the positional encoding.
    """

    def __init__(self, num_layer,d_model,num_heads,d_ff,input_vocab_size,max_seq_len,dropout):
        super().__init__()
        self.d_model = d_model
        self.num_layer = num_layer
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.input_vocab_size = input_vocab_size
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.embedding = nn.Embedding(self.input_vocab_size,self.d_model)
        self.pos = PositionalEncoding(self.d_model,self.max_seq_len,self.dropout)
        self.layers = nn.ModuleList([encoderLayer(self.d_model,self.num_heads,self.d_ff) for _ in range(self.num_layer)])
        self.norm = nn.LayerNorm(self.d_model)

    def forward(self,x,mask=None):
        """Encode a batch of token ids.

        Args:
            x: integer token ids to embed.
            mask: optional attention mask passed to every layer as ``mask=``;
                defaults to ``None`` so existing ``encoder(x)`` calls keep working.

        Returns:
            The layer-normalised output of the last encoder layer.
        """
        # Embeddings are scaled by sqrt(d_model) before the positional encoding is added.
        x = self.embedding(x)*math.sqrt(self.d_model)
        x = self.pos(x)
        for layer in self.layers:
            x = layer(x, mask=mask)

        return self.norm(x)


<h1>Decoder Stack</h1>

In [150]:
class CrossMultiHeadAttention(nn.Module):
    """Multi-head cross-attention: queries come from one sequence, keys and values
    from another, each through its own linear projection.

    Args:
        d_model: feature width of both inputs and of the output.
        num_heads: number of heads; each works on ``d_model // num_heads`` features.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.d_model, self.num_heads = d_model, num_heads
        self.head_dims = d_model // num_heads

        # Separate projections: the query reads the decoder stream, key/value the encoder stream.
        self.w_q = nn.Linear(in_features=d_model, out_features=d_model)
        self.w_k = nn.Linear(in_features=d_model, out_features=d_model)
        self.w_v = nn.Linear(in_features=d_model, out_features=d_model)

        # Output projection followed by dropout.
        self.linear_layer = nn.Linear(in_features=d_model, out_features=d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor, batch_size, length):
        """Reshape ``(batch, length, d_model)`` into ``(batch, heads, length, head_dims)``."""
        per_head = tensor.reshape(batch_size, length, self.num_heads, self.head_dims)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, y, x, mask=None):
        """Attend from ``y`` (query source) over ``x`` (key/value source).

        ``mask`` is handed unchanged to ``scaled_dot_product``. The result has
        shape ``(batch, y.size(1), d_model)``.
        """
        batch_size = y.size(0)
        seq_len_q, seq_len_kv = y.size(1), x.size(1)

        # Project, then split every projection into heads.
        q = self._split_heads(self.w_q(y), batch_size, seq_len_q)
        k = self._split_heads(self.w_k(x), batch_size, seq_len_kv)
        v = self._split_heads(self.w_v(x), batch_size, seq_len_kv)

        values, _ = scaled_dot_product(q, k, v, mask)

        # Merge the heads back into d_model features, project, and regularise.
        merged = values.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        return self.dropout(self.linear_layer(merged))
        

In [151]:
class DecoderLayer(nn.Module):
    """A single layer of the Transformer Decoder.

    Three sub-blocks are applied in order (self-attention, cross-attention over
    the encoder output, feed-forward), each followed by dropout, a residual
    connection and layer normalisation.

    Args:
        d_model: feature width of the decoder stream.
        ffn_hidden: hidden width of the feed-forward block.
        num_heads: number of attention heads.
        drop_prob: dropout probability after each sub-block.
    """
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        self.self_attention = multiheadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.cross_attention = CrossMultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PointwiseFeedForward(d_in=d_model,dff=ffn_hidden)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        """Run one decoder layer.

        Args:
            x: encoder output (source of keys and values for cross-attention).
            y: decoder stream from the previous layer (source of queries).
            self_attention_mask: mask for the decoder self-attention.
            cross_attention_mask: mask for the cross-attention.

        Returns:
            The updated decoder stream, same shape as ``y``.
        """
        # x is from encoder, y is from previous decoder layer
        residual_y = y
        y = self.self_attention(y, mask=self_attention_mask)
        y = self.dropout1(y)
        y = self.norm1(y + residual_y)

        residual_y = y
        # CrossMultiHeadAttention.forward takes (query_source, key_value_source): the
        # decoder stream must come first. Passing (x, y) made the encoder the query,
        # so the output had the source length and only added to the residual when
        # source and target lengths happened to match.
        y = self.cross_attention(y, x, mask=cross_attention_mask)
        y = self.dropout2(y)
        y = self.norm2(y + residual_y)

        residual_y = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y + residual_y)
        return y


In [152]:
class Decoder(nn.Module):
    """Decoder stack: scaled target embedding, positional encoding, then
    ``num_layers`` decoder layers applied in sequence.

    Args:
        d_model: embedding / model width.
        ffn_hidden: hidden width of each feed-forward block.
        num_heads: attention heads per layer.
        drop_prob: dropout probability.
        num_layers: number of stacked decoder layers.
        vocab_size: size of the target vocabulary.
        max_len: number of positions covered by the positional encoding.
    """
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers, vocab_size, max_len):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len, drop_prob)
        stack = []
        for _ in range(num_layers):
            stack.append(DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob))
        self.layers = nn.ModuleList(stack)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        """Decode target ids ``y`` against the encoder output ``x``; both masks are
        handed to every layer unchanged."""
        scale = math.sqrt(self.d_model)
        hidden = self.positional_encoding(self.embedding(y) * scale)
        for decoder_layer in self.layers:
            hidden = decoder_layer(x, hidden, self_attention_mask, cross_attention_mask)
        return hidden

In [163]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final linear projection to target-vocabulary logits.

    Args:
        d_model: model width.
        ffn_hidden: hidden width of the feed-forward blocks.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers and of decoder layers.
        max_len: number of positions covered by the positional encodings.
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size.
        src_pad_idx: token id treated as padding in the source.
        tgt_pad_idx: token id treated as padding in the target.
        device: stored on the instance as ``self.device``; masks are now built on
            the device of the tensors they are derived from.
        drop_prob: dropout probability.
    """
    def __init__(self, d_model, ffn_hidden, num_heads,  num_layers, max_len,
                 src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx, device,drop_prob=0.1,):
        super().__init__()
        self.encoder = Encoder(d_model=d_model,num_heads=num_heads,num_layer=num_layers,input_vocab_size=src_vocab_size,dropout=drop_prob,max_seq_len=max_len,d_ff=ffn_hidden)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, tgt_vocab_size, max_len)
        self.linear = nn.Linear(d_model, tgt_vocab_size)
        self.device = device
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx

    def make_src_mask(self, src):
        """Boolean keep-mask for the source: ``True`` where the token is not padding.

        ``src`` has shape ``(batch_size, src_len)``; the result has shape
        ``(batch_size, 1, 1, src_len)`` and lives on ``src``'s device.
        """
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        """Boolean keep-mask for the target, combining padding and look-ahead masking.

        ``tgt`` has shape ``(batch_size, tgt_len)``; the result has shape
        ``(batch_size, 1, tgt_len, tgt_len)`` and lives on ``tgt``'s device.
        """
        tgt_len = tgt.shape[1]
        # Padding mask shape: (batch_size, 1, 1, tgt_len)
        tgt_pad_mask = (tgt != self.tgt_pad_idx).unsqueeze(1).unsqueeze(2)
        # Look-ahead mask shape: (tgt_len, tgt_len). It is built on tgt's own device so
        # the `&` below cannot fail when the stored self.device differs from the input.
        tgt_lookahead_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
        # A key is visible only if it is not padding AND not in the future.
        return tgt_pad_mask & tgt_lookahead_mask

    def forward(self, src, tgt):
        """Return logits of shape ``(batch_size, tgt_len, tgt_vocab_size)`` for token-id
        batches ``src`` and ``tgt``."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        # NOTE(review): the encoder is called without src_mask, so padded source
        # positions still take part in encoder self-attention; the mask is only used
        # for the decoder's cross-attention.
        enc_src = self.encoder(src)
        dec_output = self.decoder(enc_src, tgt, tgt_mask, src_mask)

        output = self.linear(dec_output)
        return output

In [None]:
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)

In [155]:
# Show which device get_device() selects on this machine.
get_device()

device(type='cuda')

In [162]:
if __name__ == '__main__':
    # End-to-end smoke test: build the full model and check the output shape.
    device = 'cpu'

    # Model parameters
    d_model = 512
    num_heads = 8
    num_layers = 8
    d_ff = 2048
    max_len = 100
    dropout = 0.1
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    src_pad_idx = 0
    tgt_pad_idx = 0

    # Instantiate the model. `dropout` is now actually passed on; previously it was
    # defined but unused and the model fell back to its own default.
    model = Transformer(
        d_model=d_model,
        ffn_hidden=d_ff,
        num_heads=num_heads,
        num_layers=num_layers,
        max_len=max_len,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        src_pad_idx=src_pad_idx,
        tgt_pad_idx=tgt_pad_idx,
        device=device,
        drop_prob=dropout,
    ).to(device)
    # Only a shape check follows, so switch dropout off.
    model.eval()

    # Create dummy data (token id 0 is reserved for padding, so sample from 1 upwards)
    batch_size = 32
    src_len = 60
    tgt_len = 60

    src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device)
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len)).to(device)

    # Add some padding
    src[0, -15:] = src_pad_idx
    tgt[0, -10:] = tgt_pad_idx

    # Forward pass without building an autograd graph; no gradients are needed here.
    with torch.no_grad():
        output = model(src, tgt)

    print("Shape of final Transformer output:", output.shape)
    assert output.shape == (batch_size, tgt_len, tgt_vocab_size)
    print("\nSuccessfully built and tested the complete Transformer model!")

Shape of final Transformer output: torch.Size([32, 60, 10000])

Successfully built and tested the complete Transformer model!
