In [10]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
from torchvision import transforms

import math
import random
import matplotlib.pyplot as plt

MULTI-HEAD SELF-ATTENTION (MHSA)
 - self-attention computes relationships between elements in a sequence — like how patch A attends to patch B
 - instead of computing just one set of attention weights, MHSA splits the embedding dimension evenly across multiple “heads”; each head computes its own attention weights and can therefore learn different relationships
 - they all look at the same input, but each sees different patterns
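 - Each head looks at all patches, but can learn to focus on different visual cues (the assignments below are only an illustration — what each head actually attends to emerges during training and is not fixed in advance):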

        | Head 1 | → center of object
        | Head 2 | → edges
        | Head 3 | → object boundaries
        | Head 4 | → spatial layout
        | ... |
        | Head 12| → fine-grained textures

 - That’s the multi-head magic: parallel perspectives.

In [13]:

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a sequence of token (patch) embeddings.

    The input is projected to queries, keys and values; the embedding axis is
    divided evenly among ``num_heads`` heads; scaled dot-product attention is
    evaluated independently inside every head; finally the per-head results
    are concatenated and passed through one more linear layer.

    Args:
        dim: Embedding size of each token. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        AssertionError: If ``dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads  # slice of the embedding owned by one head

        assert dim == num_heads * self.head_dim, "dim must be divisible by num_heads"

        # Four learnable dim -> dim projections (weight + bias each).
        # query / key / value map the input into their respective spaces;
        # out mixes the concatenated head outputs back together.
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch_size, seq_length, dim) to (batch_size, num_heads, seq_length, head_dim)."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor expected to have shape (batch_size, seq_length, dim).

        Returns:
            Tensor of shape (batch_size, seq_length, dim).
        """
        batch_size = x.size(0)

        # Project the input, then give every head its own slice of the embedding.
        queries = self._split_heads(self.query(x), batch_size)
        keys = self._split_heads(self.key(x), batch_size)
        values = self._split_heads(self.value(x), batch_size)

        # Scaled dot-product attention, evaluated per head.
        scale = math.sqrt(self.head_dim)
        scores = (queries @ keys.transpose(-2, -1)) / scale  # (batch_size, num_heads, seq_length, seq_length)
        weights = scores.softmax(dim=-1)                     # each row sums to 1 over the attended tokens
        context = weights @ values                           # (batch_size, num_heads, seq_length, head_dim)

        # Put the head axis back next to head_dim and fuse the two into dim.
        merged = context.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.dim)

        # Final linear layer over the concatenated heads.
        return self.out(merged)

In [14]:
# Example usage: a shape-only smoke test of MultiHeadAttention.
# seq_length is the number of tokens (or patches) in each sample.
# dim is the embedding dimension of each token (or patch); it has to match the
# `dim` the module is built with and be divisible by `num_heads` (128 / 8 = 16 per head).

x = torch.randn(1, 64, 128)  # Random input of shape (batch_size, seq_length, dim); no seed is set in this cell, so values differ between runs
multi_head_attn = MultiHeadAttention(dim=128, num_heads=8)
output = multi_head_attn(x)
print(output.shape)  # Self-attention keeps the input shape, so this should be (1, 64, 128)

torch.Size([1, 64, 128])
