🧠 What is Multi-Head vs Single-Head?  

Single-Head: one set of weight matrices (W_q, W_k, W_v) is used for all attention, so the layer produces a single attention pattern.

Multi-Head: multiple independent sets of (W_q, W_k, W_v) are used, and their results are concatenated. Each "head" learns to focus on different patterns.

If you apply attention over the whole space using a single projection, then all information must be compressed into a single set of attention scores.

Multi-head attention instead splits the embedding into num_heads subspaces, and each subspace gets its own Q, K, V projection.

Each head independently attends over the sequence using only its subspace.

You combine the outputs (via concat + linear layer) to get richer representations.

Heads don't interact with each other during their attention computation: each head computes its own scores and applies its own softmax.

They combine only afterwards: the outputs of all heads are concatenated and projected with W_o to produce the final output.
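The independence of the heads can be checked directly. The sketch below is only an illustration with made-up sizes (T, D, h are assumptions, not values from these notes): it computes each head in an explicit loop over column slices, then compares the result with the batched einsum form. The heads only mix at the final W_o projection.

```python
import numpy as np

def softmax(x, axis=-1):
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, D, h = 5, 8, 2          # illustrative sizes: seq_len, embed_dim, num_heads
d = D // h                 # head_dim

x = rng.normal(size=(T, D))
W_q, W_k, W_v, W_o = (rng.normal(size=(D, D)) * 0.1 for _ in range(4))
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# One head at a time: head i only ever reads columns [i*d, (i+1)*d).
head_outputs = []
for i in range(h):
    cols = slice(i * d, (i + 1) * d)
    scores = Q[:, cols] @ K[:, cols].T / np.sqrt(d)   # (T, T), this head only
    weights = softmax(scores, axis=-1)                # softmax is per head
    head_outputs.append(weights @ V[:, cols])         # (T, d)

concat = np.concatenate(head_outputs, axis=1)         # (T, D)
output_loop = concat @ W_o                            # heads mix only here

# The same computation, batched over heads with einsum.
Qh, Kh, Vh = (M.reshape(T, h, d) for M in (Q, K, V))
scores_b = np.einsum('thd,shd->ths', Qh, Kh) / np.sqrt(d)
weights_b = softmax(scores_b, axis=2)
output_batched = np.einsum('ths,shd->thd', weights_b, Vh).reshape(T, D) @ W_o

heads_match = np.allclose(output_loop, output_batched)
print(heads_match)
```

If the two outputs agree, no head read another head's queries, keys, or values before the concatenation.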

Each head uses smaller matrices: with head_dim = embed_dim / num_heads, each head's projection has only head_dim output columns.

All heads are computed in parallel (very GPU-friendly).
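To make "smaller matrices" concrete, here is a minimal sketch under assumed sizes (T, D, h are illustrative). It views one (D, D) projection matrix as h per-head matrices of shape (D, d), then confirms that projecting with each small matrix gives the same result as one large matrix multiply followed by a reshape, which is why all heads can run in parallel.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, h = 4, 12, 3         # illustrative sizes: seq_len, embed_dim, num_heads
d = D // h                 # head_dim

x = rng.normal(size=(T, D))
W_q = rng.normal(size=(D, D)) * 0.1

# View the (D, D) matrix as h per-head matrices of shape (D, d).
# W_q_heads[i] is the column block W_q[:, i*d:(i+1)*d].
W_q_heads = W_q.reshape(D, h, d).transpose(1, 0, 2)   # (h, D, d)

params_per_head = W_q_heads[0].size   # D * d
params_total = W_q.size               # D * D

# One big matmul + reshape vs. h separate small matmuls.
Q_full = (x @ W_q).reshape(T, h, d)
Q_per_head = np.stack([x @ W_q_heads[i] for i in range(h)], axis=1)  # (T, h, d)

same_projection = np.allclose(Q_full, Q_per_head)
print(params_per_head, params_total, same_projection)
```

The total parameter count is unchanged relative to a single (D, D) projection; only the way the columns are grouped differs.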

In CNNs:

The receptive field grows slowly: a token sees distant tokens only after many layers.

In attention:

Every token attends to all tokens in the same layer.

This gives a global receptive field in a single step.

Token at position 1 can immediately focus on position 99 if needed.
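The gap can be put into numbers with a rough sketch. It assumes a plain stack of stride-1, undilated 1D convolutions with an odd, centered kernel and no pooling (real CNNs often use dilation or downsampling to grow the receptive field faster). The function names are hypothetical helpers written for this illustration.

```python
import numpy as np

def conv_layers_needed(distance, kernel_size=3):
    """Layers of stride-1, undilated 1D convolution needed before one
    position can be influenced by an input `distance` steps away."""
    reach_per_layer = (kernel_size - 1) // 2   # positions gained per side
    return -(-distance // reach_per_layer)     # ceiling division

def simulate_conv_layers(src, dst, seq_len, kernel_size=3):
    """Count layers empirically by spreading an influence mask from
    index `src` until it reaches index `dst`."""
    influenced = np.zeros(seq_len)
    influenced[src] = 1.0
    layers = 0
    while influenced[dst] == 0:
        spread = np.convolve(influenced, np.ones(kernel_size), mode='same')
        influenced = (spread > 0).astype(float)
        layers += 1
    return layers

# Position 1 attending to position 99 (0-based indices 0 and 98).
layers_formula = conv_layers_needed(distance=99 - 1, kernel_size=3)
layers_simulated = simulate_conv_layers(src=98, dst=0, seq_len=100, kernel_size=3)
layers_attention = 1   # every token attends to all tokens in the same layer

print(layers_formula, layers_simulated, layers_attention)
```

Under these assumptions a kernel of size 3 needs 98 stacked layers to connect the two positions, while a single attention layer connects them directly.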



import numpy as np

def softmax(x, axis=-1):
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

def dsoftmax(softmax_out, grad_out):
    s = softmax_out
    return s * (grad_out - np.sum(grad_out * s, axis=-1, keepdims=True))

class MultiHeadSelfAttention:
    def __init__(self, embed_dim, num_heads):
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" # Each head gets a portion of total embedding
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # Each head works on a subspace of dimension head_dim

        # Parameters
        self.W_q = np.random.randn(embed_dim, embed_dim) * 0.1  # Exactly Same as Single Attention
        self.W_k = np.random.randn(embed_dim, embed_dim) * 0.1
        self.W_v = np.random.randn(embed_dim, embed_dim) * 0.1
        self.W_o = np.random.randn(embed_dim, embed_dim) * 0.1

        # Gradients
        self.dW_q = np.zeros_like(self.W_q)    # Also Same as Single Attention
        self.dW_k = np.zeros_like(self.W_k)
        self.dW_v = np.zeros_like(self.W_v)
        self.dW_o = np.zeros_like(self.W_o)

    def forward(self, x):
        self.x = x
        T, D = x.shape  # T = seq_len, D = embed_dim

        self.Q = x @ self.W_q  # Same as Single Attention
        self.K = x @ self.W_k
        self.V = x @ self.W_v

        # Split the embedding dimension into heads: (T, D) -> (T, h, d),
        # where h = num_heads and d = head_dim. In the einsum strings below,
        # t indexes queries (target positions) and s indexes keys/values
        # (source positions); in self-attention both run over the same T tokens.
        self.Q_ = self.Q.reshape(T, self.num_heads, self.head_dim)  # queries, indexed (t, h, d)
        self.K_ = self.K.reshape(T, self.num_heads, self.head_dim)  # keys, indexed (s, h, d)
        self.V_ = self.V.reshape(T, self.num_heads, self.head_dim)  # values, indexed (s, h, d)

        # Scaled dot product of every query with every key, separately per head.
        # einsum writes the operation with index labels; result shape is (T, num_heads, T).
        self.scores = np.einsum('thd,shd->ths', self.Q_, self.K_) / np.sqrt(self.head_dim)
        # For each token t and head h, compute an attention distribution over all tokens s (including itself).
        self.weights = softmax(self.scores, axis=2)
        # Weighted sum of the value vectors, per head: shape (T, num_heads, head_dim).
        self.attn_output = np.einsum('ths,shd->thd', self.weights, self.V_)
        self.concat_output = self.attn_output.reshape(T, D)
        self.output = self.concat_output @ self.W_o

        return self.output

    def backward(self, d_out):
        T, D = self.x.shape

        # dOutput wrt W_o and concat_output
        self.dW_o += self.concat_output.T @ d_out
        d_concat = d_out @ self.W_o.T  # (T, D)

        # Back to per-head output
        d_attn_output = d_concat.reshape(T, self.num_heads, self.head_dim)

        # Gradients through attn_output[t,h,d] = sum_s weights[t,h,s] * V_[s,h,d]
        d_weights = np.einsum('thd,shd->ths', d_attn_output, self.V_)
        dV_ = np.einsum('ths,thd->shd', self.weights, d_attn_output)  # sum over queries t

        # Backprop through softmax
        dscores = dsoftmax(self.weights, d_weights) / np.sqrt(self.head_dim)

        dQ_ = np.einsum('ths,shd->thd', dscores, self.K_)
        dK_ = np.einsum('ths,thd->shd', dscores, self.Q_)

        # Reshape back to (T, D)
        dQ = dQ_.reshape(T, D)
        dK = dK_.reshape(T, D)
        dV = dV_.reshape(T, D)

        self.dW_q += self.x.T @ dQ
        self.dW_k += self.x.T @ dK
        self.dW_v += self.x.T @ dV

        dx_q = dQ @ self.W_q.T
        dx_k = dK @ self.W_k.T
        dx_v = dV @ self.W_v.T

        dx = dx_q + dx_k + dx_v
        return dx

    def zero_grad(self):
        self.dW_q.fill(0)
        self.dW_k.fill(0)
        self.dW_v.fill(0)
        self.dW_o.fill(0)
