In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# Multi-Head Attention Implementation

## Step 2: Scaled Dot-Product Attention Function

This is the core of self-attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Where:
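- **Q, K, V** are the queries, keys, and values; in self-attention all three are linear projections of the same input sequence
- **$d_k$** is the dimension of each key vector (the per-head size); dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension
- **The result** is, for each position and head, a weighted sum of the value vectors
- **Softmax** (taken over the keys) ensures each query's attention weights are non-negative and sum to 1
- **Masking** adds large negative values (e.g. $-\infty$) to the logits of positions that must be ignored (such as future tokens or padding), so their attention weights after softmax are 0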


In [29]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        d_model: Total width of the attention output. It is split evenly
            across heads, so it must be divisible by ``num_heads``.
        num_heads: Number of attention heads (must be positive).

    ``forward`` returns ``(out, attention_mat)`` where ``out`` has shape
    ``(batch, seq, d_model)`` and ``attention_mat`` has shape
    ``(batch, num_heads, seq, seq)``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        # Validate up front: otherwise a bad config only fails later with an
        # opaque reshape error inside forward().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim  # input embedding size
        self.d_model = d_model  # output embedding size
        self.num_heads = num_heads  # number of attention heads
        self.head_dim = d_model // num_heads  # per-head size of q, k and v
        # One linear layer produces q, k and v for every head at once.
        self.qkv_layer = nn.Linear(input_dim, d_model * 3)
        # Output projection applied after the heads are concatenated.
        self.linear_layer = nn.Linear(d_model, d_model)

    def dot_product(self, q, k, v, mask=None):
        """Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v.

        Args:
            q, k: Query/key tensors whose last dimension is ``d_k``; the last
                two dimensions of ``k`` are transposed for the score product.
            v: Value tensor; its second-to-last dimension must match ``k``'s.
            mask: Optional additive mask, added in place to the scores (e.g. 0
                for allowed positions, ``-inf`` for blocked ones); must
                broadcast to the score tensor's shape.

        Returns:
            ``(values, attention_mat)``: the attention-weighted values and the
            softmax weights taken over the last (key) dimension.
        """
        d_k = q.size(-1)
        # Scale the raw scores by 1/sqrt(d_k) so their magnitude does not grow
        # with the key dimension.
        scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
        if mask is not None:
            scaled += mask
        attention_mat = F.softmax(scaled, dim=-1)
        values = torch.matmul(attention_mat, v)
        return values, attention_mat

    def softmax(self, values, axis=None):
        """NumPy softmax helper (``forward`` uses ``F.softmax`` instead).

        Args:
            values: Array-like of scores.
            axis: Axis to normalise over. ``None`` (the default) normalises
                over all elements, as this helper always did.

        Returns:
            ``np.ndarray`` of the same shape whose entries along ``axis`` sum
            to 1.
        """
        values = np.asarray(values)
        if values.size == 0:
            # np.max rejects empty input; exp of an empty array is the
            # (empty) result the unshifted formula produced.
            return np.exp(values)
        # Subtracting the max is mathematically a no-op but stops np.exp from
        # overflowing to inf (which would turn the result into nan).
        exp_values = np.exp(values - np.max(values, axis=axis, keepdims=True))
        return exp_values / np.sum(exp_values, axis=axis, keepdims=True)

    # Every step mimics the original Transformer:
    # Project to QKV,
    # Reshape for multiple heads,
    # Split into Q, K, V,
    # Compute attention,
    # Concatenate heads,
    # Linear output.
    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq, input_dim)``.
            mask: Optional additive mask that broadcasts to
                ``(batch, num_heads, seq, seq)``.

        Returns:
            ``(out, attention_mat)`` with shapes ``(batch, seq, d_model)`` and
            ``(batch, num_heads, seq, seq)``.
        """
        batch_size, sequence_length, _ = x.size()
        # Step 1: project x into concatenated q, k, v for ALL heads at once:
        # (batch, seq, 3 * d_model) -> (batch, seq, heads, 3 * head_dim).
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        # Put heads ahead of the sequence axis: (batch, heads, seq, 3 * head_dim).
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        # Attention per head: values is (batch, heads, seq, head_dim).
        values, attention_mat = self.dot_product(q, k, v, mask)
        # Move the sequence axis back in front of the heads before merging them;
        # reshaping (batch, heads, seq, head_dim) directly would mix different
        # positions into the same output row.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        out = self.linear_layer(values)
        return out, attention_mat


In [30]:
# ======================
# ✅ Test Case
# ======================
torch.manual_seed(0)  # reproducible random input and weight initialisation

batch_size = 2
sequence_length = 5
input_dim = 16
d_model = 32
num_heads = 4

x = torch.randn((batch_size, sequence_length, input_dim))
model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself rather than .forward() so nn.Module hooks run.
output, attn = model(x)

print("Output attention shape:", output.shape)
print("Attention matrix shape:", attn.shape)

# Pin the expected shapes instead of leaving them in comments.
assert output.shape == (batch_size, sequence_length, d_model)
assert attn.shape == (batch_size, num_heads, sequence_length, sequence_length)
# Each query's attention weights form a distribution over the keys.
assert torch.allclose(
    attn.sum(dim=-1), torch.ones(batch_size, num_heads, sequence_length), atol=1e-5
)

Output attention shape: torch.Size([2, 5, 32])
Attention matrix shape: torch.Size([2, 4, 5, 5])
