# Understanding Transformers using PyTorch
https://www.geeksforgeeks.org/deep-learning/transformer-using-pytorch/

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

## MultiHeadAttention

In [34]:
class MultiHeadAttention(nn.Module):
    """Building blocks of multi-head attention (step-by-step tutorial version).

    Holds the Q/K/V/output projection layers and provides the scaled
    dot-product attention core plus the helper that splits the embedding
    dimension into heads. This version does not define ``forward`` or a
    head-combining helper.

    Args:
        d_model: Embedding size (e.g. 512). Must be divisible by ``num_heads``.
        num_heads: Number of attention heads (e.g. 8).

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model  # embedding size (e.g. 512)
        self.num_heads = num_heads  # number of attention heads (e.g. 8)

        # Every head gets an equal slice of the embedding (e.g. 512 // 8 = 64).
        self.d_k = d_model // num_heads
        # Tutorial output: shows the per-head size when the module is built.
        print(f"Number of dimensions each head gets: {self.d_k}")

        # Projections that create Q (Query), K (Key), V (Value) from the
        # input embeddings.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output layer: combines all attention heads back into one vector.
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q @ K^T / sqrt(d_k)) @ V.

        Steps:
            - The dot product of Q and K measures their similarity.
            - Division by sqrt(d_k) prevents extremely large values, which
              stabilizes training.
            - An optional mask hides positions (padding mask, or causal /
              future-token mask).
            - Softmax converts the scores into probabilities, e.g.
                tensor([-0.8058, -0.9375,  1.2299,  0.2358, -1.0952,
                         0.0997,  0.8335,  2.3506, -0.3834,  0.1132]) ---->
                tensor([0.0207, 0.0182, 0.1587, 0.0587, 0.0155,
                        0.0512, 0.1067, 0.4867, 0.0316, 0.0519])
            - The probabilities are multiplied with V.

        Args:
            Q, K, V: Query/key/value tensors whose last two dimensions are
                (seq_len, d_k); typically (batch, heads, seq_len, d_k) as
                produced by ``split_heads``.
            mask: Optional tensor broadcastable to the score tensor.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            The attention-weighted values, same leading dimensions as ``Q``.
        """
        # Transposing the last two dims of K yields pairwise scores of shape
        # (..., seq_len_q, seq_len_k), e.g. (batch, heads, seq_len, seq_len).
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Tutorial output: shape of the score tensor.
        print(attn_scores.shape)

        # Masked positions get a very negative score, so softmax assigns them
        # (practically) zero probability.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # Convert scores into probabilities. The scores (not the not-yet
        # defined probabilities) are what must be normalised here.
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Weight the values by the attention probabilities.
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        """Split the embedding dimension into separate attention heads.

        Input shape:  (batch_size, seq_length, d_model)
        Output shape: (batch_size, num_heads, seq_length, d_k)

        This lets attention run in parallel across heads.
        """
        batch_size, seq_length, d_model = x.size()  # example: (32, 512, 512)
        # example result: (32, 8, 512, 64)
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    
        

In [35]:
# Build the attention module: 512-dim embeddings split over 8 heads,
# so each head works on 512 // 8 = 64 dimensions (printed by __init__).
model = MultiHeadAttention(d_model=512, num_heads=8)

Number of dimensions each head gets: 64


In [36]:
# Random dummy input shaped (batch_size=32, seq_length=512, d_model=512).
S = torch.randn(32, 512, 512)

In [39]:
# Sanity check: splitting into heads should give
# (batch_size, num_heads, seq_length, d_k) = (32, 8, 512, 64).
model.split_heads(S).shape

torch.Size([32, 8, 512, 64])

In [25]:
# Tiny integer tensor (values drawn from 1..9) used to illustrate what
# transpose does to the dimensions.
X = torch.randint(1, 10, (1, 2, 3))
X.shape

torch.Size([1, 2, 3])

In [26]:
# Show the tensor's values before transposing.
X

tensor([[[2, 3, 2],
         [9, 9, 8]]])

In [27]:
# Swap dimensions 1 and 2: shape (1, 2, 3) becomes (1, 3, 2).
X.transpose(1, 2)

tensor([[[2, 9],
         [3, 9],
         [2, 8]]])

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention layer.

    Projects the inputs to queries, keys and values, splits them into
    ``num_heads`` heads of size ``d_k = d_model // num_heads``, runs scaled
    dot-product attention per head, then merges the heads and applies an
    output projection.

    Args:
        d_model: Embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Query, key and value projections, followed by the output projection
        # that mixes the concatenated heads.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return softmax(Q @ K^T / sqrt(d_k)) @ V.

        Positions where ``mask == 0`` receive a score of -1e9 before the
        softmax, so they get (practically) zero attention weight.
        """
        scale = math.sqrt(self.d_k)
        scores = (Q @ K.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        return weights @ V

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        n_batch, n_tokens, _ = x.size()
        per_head = x.view(n_batch, n_tokens, self.num_heads, self.d_k)
        # Move the head axis in front of the sequence axis.
        return per_head.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq, d_k) back into (batch, seq, d_model)."""
        n_batch, _, n_tokens, _ = x.size()
        # The transposed tensor is non-contiguous; make it contiguous so that
        # view() can merge the head and d_k axes.
        token_major = x.transpose(1, 2).contiguous()
        return token_major.view(n_batch, n_tokens, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention to the query, key and value inputs.

        Args:
            Q, K, V: Input tensors of shape (batch, seq, d_model).
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, query seq, d_model).
        """
        queries = self.split_heads(self.W_q(Q))
        keys = self.split_heads(self.W_k(K))
        values = self.split_heads(self.W_v(V))

        attended = self.scaled_dot_product_attention(queries, keys, values, mask)
        merged = self.combine_heads(attended)
        return self.W_o(merged)