<a href="https://colab.research.google.com/github/Priyanshu-Naik/Gen_AI/blob/main/Multihead_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Scaled Dot-Product Attention Function**

Q, K, V are queries, keys, values derived from the same source in self-attention.

It returns the values, i.e. the attention-weighted sum of V for each position and head, together with the attention weights.

Softmax ensures the attention weights sum to 1.

When a mask is supplied, irrelevant positions (such as future tokens or padding) receive large negative values in the logits, so their attention weights are effectively 0 after the softmax.

**Multi-Head Attention Class**

Every step mimics the original Transformer:

Project to QKV,

Reshape for multiple heads,

Split into Q, K, V,

Compute attention,

Concatenate heads,

Linear output.


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., seq_q, d_k).
        k: Key tensor of shape (..., seq_k, d_k).
        v: Value tensor of shape (..., seq_k, d_v).
        mask: Optional additive mask that is broadcast-compatible with the
            (..., seq_q, seq_k) attention logits. Positions to be ignored
            should hold a large negative value (e.g. ``-inf``) and the
            remaining positions 0.

    Returns:
        A ``(values, attention)`` tuple: ``attention`` holds the softmax
        weights over the key axis and ``values`` is ``attention @ v``.
    """
    d_k = q.size(-1)

    # Scale the logits by sqrt(d_k) before the softmax.
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)

    if mask is not None:
        if torch.is_tensor(mask):
            # Keep the logits' dtype so the result still matches v's dtype.
            mask = mask.to(scaled.dtype)
        # Out-of-place add: an in-place `+=` raises when the mask's broadcast
        # shape is larger than the logits (e.g. a batched mask applied to
        # unbatched q/k/v).
        scaled = scaled + mask

    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a single input sequence.

    The input is projected to queries, keys and values for all heads with one
    linear layer, attention is computed independently per head, and the
    concatenated head outputs go through a final linear projection.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        d_model: Total width of the attention output (across all heads).
        num_heads: Number of attention heads; must evenly divide ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        # Fail early: otherwise the mismatch only surfaces as an opaque
        # reshape error inside forward().
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.input_dim = input_dim
        self.head_dim = d_model // num_heads

        # For efficiency, compute Q, K, V for all heads at once with a single linear layer
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        # Final projection, combines all heads' outputs
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim).
            mask: Optional additive mask passed through to
                ``scaled_dot_product_attention``; it is combined with logits
                of shape (batch, num_heads, seq_len, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.size()

        # Step 1: Project x into concatenated q, k, v for ALL heads at once
        qkv = self.qkv_layer(x)

        # Step 2: Reshape into (batch, seq_len, num_heads, 3 * head_dim)
        qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim)

        # Step 3: Split the last dimension into q, k, v
        q, k, v = qkv.chunk(3, dim=-1)  # Each: (batch, seq_len, num_heads, head_dim)

        # Step 4: Permute q, k, v to (batch, num_heads, seq_len, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # Step 5: Apply scaled dot-product attention (the weights are not returned)
        values, _ = scaled_dot_product_attention(q, k, v, mask)

        # Step 6: Merge the heads (permute back before reshape)
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, seq_len, self.num_heads * self.head_dim)

        # Step 7: Final linear layer
        return self.linear_layer(values)

# Demo: run a random batch through the module and check the output shape.
input_dim = 1024
d_model = 512
num_heads = 8
batch_size = 30
seq_len = 5

x = torch.randn((batch_size, seq_len, input_dim))
multihead_attn = MultiHeadAttention(input_dim, d_model, num_heads)
# Call the module itself rather than .forward() so that any registered
# hooks are run by nn.Module.__call__.
output = multihead_attn(x)
print(output.shape)

torch.Size([30, 5, 512])
