# Multi-Head Attention
**with weight split**

In [53]:
import torch
import torch.nn as nn

## Input Embedding

In [54]:
# Batch / sequence layout
NUM_BATCH = 1     # a single sequence, for simplicity for now
NUM_TOKENS = 6    # tokens per sequence
CONTEXT_SIZE = 6  # side length of the causal mask built further down

# Feature sizes
D_IN = 4          # size of each input embedding vector
D_OUT = 4         # final output dimension across all heads, not per head
NUM_HEADS = 2     # D_OUT is split evenly across this many heads

In [55]:
# Seed the global RNG before the first random op so that "Restart Kernel -> Run All"
# yields the same embeddings (and the same randomly initialised layers further down).
# NOTE: outputs saved from earlier, unseeded runs may not match after re-running.
torch.manual_seed(42)

# One random embedding vector of size D_IN per token: (NUM_BATCH, NUM_TOKENS, D_IN).
input_embeddings = torch.randn(NUM_BATCH, NUM_TOKENS, D_IN)

print(f"Input Embeddings:\n{input_embeddings}\nShape: {input_embeddings.shape}\n")

Input Embeddings:
tensor([[[-1.4208,  0.5254,  0.0145, -1.0520],
         [-1.3243, -0.2664,  2.0347, -0.2241],
         [-0.6319,  0.8668,  0.3893, -0.0230],
         [ 0.7506,  0.0992, -0.2964,  0.4428],
         [ 0.5634, -0.8032, -0.3909,  0.4640],
         [ 1.2728,  0.6727, -0.2302, -0.7652]]])
Shape: torch.Size([1, 6, 4])



## Projection Weights (Wq, Wk, Wv)

In [56]:
# Three independent bias-free linear projections (D_IN -> D_OUT), built in the
# order query, key, value so the random initialisation order stays the same.
Wq, Wk, Wv = (nn.Linear(D_IN, D_OUT, bias=False) for _ in range(3))

print(f"Wq:\n{Wq.weight}\nShape: {Wq.weight.shape}\n")
print(f"Wk:\n{Wk.weight}\nShape: {Wk.weight.shape}\n")
print(f"Wv:\n{Wv.weight}\nShape: {Wv.weight.shape}\n")

Wq:
Parameter containing:
tensor([[ 0.1118, -0.4581, -0.2810, -0.0504],
        [ 0.4673,  0.0952, -0.1767,  0.3812],
        [-0.1484, -0.0739,  0.4602,  0.3879],
        [ 0.2999,  0.3496, -0.2486, -0.0946]], requires_grad=True)
Shape: torch.Size([4, 4])

Wk:
Parameter containing:
tensor([[-0.3001, -0.4434, -0.0820, -0.1382],
        [ 0.4897, -0.3804,  0.1761,  0.3176],
        [ 0.4293, -0.1651, -0.0434,  0.4178],
        [ 0.2808,  0.0722,  0.2693,  0.4605]], requires_grad=True)
Shape: torch.Size([4, 4])

Wv:
Parameter containing:
tensor([[-0.1242, -0.0797,  0.1834, -0.4971],
        [-0.4361, -0.2864,  0.2599,  0.0864],
        [ 0.3415, -0.1870,  0.0523,  0.2140],
        [ 0.0108,  0.3178, -0.4126,  0.1115]], requires_grad=True)
Shape: torch.Size([4, 4])



## Q,K,V Tensors

In [57]:
# Project the same embeddings through each layer:
# (NUM_BATCH, NUM_TOKENS, D_IN) -> (NUM_BATCH, NUM_TOKENS, D_OUT) for Q, K and V.
Q, K, V = (projection(input_embeddings) for projection in (Wq, Wk, Wv))

print(f"Queries:\n{Q}\nShape: {Q.shape}\n")
print(f"Keys:\n{K}\nShape: {K.shape}\n")
print(f"Values:\n{V}\nShape: {V.shape}\n")

Queries:
tensor([[[-0.3506, -1.0175, -0.2295, -0.1465],
         [-0.5866, -1.0891,  1.0656, -0.9749],
         [-0.5760, -0.2903,  0.1999,  0.0190],
         [ 0.0995,  0.5813, -0.0833,  0.2916],
         [ 0.5175,  0.4328, -0.0241, -0.0586],
         [-0.0626,  0.4078, -0.6414,  0.7465]]], grad_fn=<UnsafeViewBackward0>)
Shape: torch.Size([1, 6, 4])

Keys:
tensor([[[ 0.3377, -1.2272, -1.1369, -0.8416],
         [ 0.3797, -0.2600, -0.7065,  0.0536],
         [-0.2234, -0.5780, -0.4410, -0.0206],
         [-0.3062,  0.4182,  0.5037,  0.3420],
         [ 0.1550,  0.6600,  0.5854,  0.2086],
         [-0.5557,  0.0838,  0.1256, -0.0084]]], grad_fn=<UnsafeViewBackward0>)
Shape: torch.Size([1, 6, 4])

Values:
tensor([[[ 0.6602,  0.3820, -0.8079,  0.0283],
         [ 0.6702,  1.1632, -0.3440, -0.9635],
         [ 0.0922,  0.1265, -0.3625,  0.1055],
         [-0.3756, -0.3945,  0.3171,  0.2113],
         [-0.3083, -0.0772,  0.4215, -0.0361],
         [ 0.1264, -0.8737,  0.1331,  0.2373]]], gra

## Head Dimension

In [58]:
# Every head receives an equal slice of the D_OUT features, so the split must be
# exact; otherwise the later view() into (NUM_HEADS, HEAD_DIM) cannot succeed.
assert D_OUT % NUM_HEADS == 0, "D_OUT must be divisible by NUM_HEADS"

HEAD_DIM = D_OUT // NUM_HEADS

print(f"Head Dimension: {HEAD_DIM}\n")

Head Dimension: 2



## Reshape Q,K,V

In [59]:
# Split the last axis into heads without copying:
# (NUM_BATCH, NUM_TOKENS, D_OUT) -> (NUM_BATCH, NUM_TOKENS, NUM_HEADS, HEAD_DIM)
Q, K, V = (
    tensor.view(NUM_BATCH, NUM_TOKENS, NUM_HEADS, HEAD_DIM)
    for tensor in (Q, K, V)
)

print(f"Reshaped Queries:\n{Q}\nShape: {Q.shape}\n")
print(f"Reshaped Keys:\n{K}\nShape: {K.shape}\n")
print(f"Reshaped Values:\n{V}\nShape: {V.shape}\n")

Reshaped Queries:
tensor([[[[-0.3506, -1.0175],
          [-0.2295, -0.1465]],

         [[-0.5866, -1.0891],
          [ 1.0656, -0.9749]],

         [[-0.5760, -0.2903],
          [ 0.1999,  0.0190]],

         [[ 0.0995,  0.5813],
          [-0.0833,  0.2916]],

         [[ 0.5175,  0.4328],
          [-0.0241, -0.0586]],

         [[-0.0626,  0.4078],
          [-0.6414,  0.7465]]]], grad_fn=<ViewBackward0>)
Shape: torch.Size([1, 6, 2, 2])

Reshaped Keys:
tensor([[[[ 0.3377, -1.2272],
          [-1.1369, -0.8416]],

         [[ 0.3797, -0.2600],
          [-0.7065,  0.0536]],

         [[-0.2234, -0.5780],
          [-0.4410, -0.0206]],

         [[-0.3062,  0.4182],
          [ 0.5037,  0.3420]],

         [[ 0.1550,  0.6600],
          [ 0.5854,  0.2086]],

         [[-0.5557,  0.0838],
          [ 0.1256, -0.0084]]]], grad_fn=<ViewBackward0>)
Shape: torch.Size([1, 6, 2, 2])

Reshaped Values:
tensor([[[[ 0.6602,  0.3820],
          [-0.8079,  0.0283]],

         [[ 0.6702,  1.163

## Regroup Q,K,V into per Attention Head

In [60]:
# Swap the token and head axes so each head owns its own (tokens x features) matrix:
# (NUM_BATCH, NUM_TOKENS, NUM_HEADS, HEAD_DIM) -> (NUM_BATCH, NUM_HEADS, NUM_TOKENS, HEAD_DIM)
Q, K, V = (tensor.transpose(1, 2) for tensor in (Q, K, V))

print(f"Regrouped Queries:\n{Q}\nShape: {Q.shape}\n")
print(f"Regrouped Keys:\n{K}\nShape: {K.shape}\n")
print(f"Regrouped Values:\n{V}\nShape: {V.shape}\n")

Regrouped Queries:
tensor([[[[-0.3506, -1.0175],
          [-0.5866, -1.0891],
          [-0.5760, -0.2903],
          [ 0.0995,  0.5813],
          [ 0.5175,  0.4328],
          [-0.0626,  0.4078]],

         [[-0.2295, -0.1465],
          [ 1.0656, -0.9749],
          [ 0.1999,  0.0190],
          [-0.0833,  0.2916],
          [-0.0241, -0.0586],
          [-0.6414,  0.7465]]]], grad_fn=<TransposeBackward0>)
Shape: torch.Size([1, 2, 6, 2])

Regrouped Keys:
tensor([[[[ 0.3377, -1.2272],
          [ 0.3797, -0.2600],
          [-0.2234, -0.5780],
          [-0.3062,  0.4182],
          [ 0.1550,  0.6600],
          [-0.5557,  0.0838]],

         [[-1.1369, -0.8416],
          [-0.7065,  0.0536],
          [-0.4410, -0.0206],
          [ 0.5037,  0.3420],
          [ 0.5854,  0.2086],
          [ 0.1256, -0.0084]]]], grad_fn=<TransposeBackward0>)
Shape: torch.Size([1, 2, 6, 2])

Regrouped Values:
tensor([[[[ 0.6602,  0.3820],
          [ 0.6702,  1.1632],
          [ 0.0922,  0.1265],
 

## Attention Score

In [61]:
# Dot product of every query with every key, computed independently per head:
# (B, H, T, Dhead) x (B, H, Dhead, T) -> (B, H, T, T)
attention_scores = torch.matmul(Q, K.transpose(-2, -1))

print(f"Attention Scores:\n{attention_scores}\nShape: {attention_scores.shape}\n")

Attention Scores:
tensor([[[[ 1.1303,  0.1314,  0.6665, -0.3182, -0.7259,  0.1096],
          [ 1.1384,  0.0604,  0.7605, -0.2759, -0.8097,  0.2347],
          [ 0.1617, -0.1433,  0.2965,  0.0550, -0.2809,  0.2958],
          [-0.6798, -0.1134, -0.3582,  0.2127,  0.3991, -0.0065],
          [-0.3563,  0.0840, -0.3658,  0.0226,  0.3658, -0.2513],
          [-0.5216, -0.1298, -0.2217,  0.1897,  0.2594,  0.0690]],

         [[ 0.3842,  0.1543,  0.1042, -0.1657, -0.1649, -0.0276],
          [-0.3910, -0.8052, -0.4498,  0.2033,  0.4204,  0.1420],
          [-0.2432, -0.1402, -0.0885,  0.1072,  0.1210,  0.0249],
          [-0.1507,  0.0745,  0.0307,  0.0578,  0.0120, -0.0129],
          [ 0.0767,  0.0139,  0.0118, -0.0322, -0.0263, -0.0025],
          [ 0.1009,  0.4932,  0.2674, -0.0677, -0.2197, -0.0868]]]],
       grad_fn=<UnsafeViewBackward0>)
Shape: torch.Size([1, 2, 6, 6])



## Negative Infinity Masking

In [62]:
# Ones strictly above the main diagonal mark the "future" positions to hide;
# the diagonal and everything below it stay zero (visible).
mask = torch.ones(CONTEXT_SIZE, CONTEXT_SIZE).triu(diagonal=1)

print(f"Mask:\n{mask}\nShape: {mask.shape}\n")

Mask:
tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
Shape: torch.Size([6, 6])



In [63]:
# The mask is CONTEXT_SIZE x CONTEXT_SIZE but the scores are NUM_TOKENS x NUM_TOKENS,
# so crop the mask to the actual sequence length before applying it. Without the crop
# this only works when NUM_TOKENS == CONTEXT_SIZE. The 2-D mask broadcasts over the
# batch and head axes; masked positions become -inf so softmax gives them zero weight.
neg_inf_masked_attention_scores = attention_scores.masked_fill(
    mask.bool()[:NUM_TOKENS, :NUM_TOKENS], -torch.inf
)

print(f"Masked Attention Scores:\n{neg_inf_masked_attention_scores}\nShape: {neg_inf_masked_attention_scores.shape}\n")

Masked Attention Scores:
tensor([[[[ 1.1303,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 1.1384,  0.0604,    -inf,    -inf,    -inf,    -inf],
          [ 0.1617, -0.1433,  0.2965,    -inf,    -inf,    -inf],
          [-0.6798, -0.1134, -0.3582,  0.2127,    -inf,    -inf],
          [-0.3563,  0.0840, -0.3658,  0.0226,  0.3658,    -inf],
          [-0.5216, -0.1298, -0.2217,  0.1897,  0.2594,  0.0690]],

         [[ 0.3842,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.3910, -0.8052,    -inf,    -inf,    -inf,    -inf],
          [-0.2432, -0.1402, -0.0885,    -inf,    -inf,    -inf],
          [-0.1507,  0.0745,  0.0307,  0.0578,    -inf,    -inf],
          [ 0.0767,  0.0139,  0.0118, -0.0322, -0.0263,    -inf],
          [ 0.1009,  0.4932,  0.2674, -0.0677, -0.2197, -0.0868]]]],
       grad_fn=<MaskedFillBackward0>)
Shape: torch.Size([1, 2, 6, 6])



## Attention Weights

In [64]:
# Scale by sqrt(HEAD_DIM), then normalise each row over the key axis.
# The -inf entries from the mask come out of the softmax as exactly 0.
attention_weights = neg_inf_masked_attention_scores.div(HEAD_DIM ** 0.5).softmax(dim=-1)

print(f"Attention Weights:\n{attention_weights}\nShape: {attention_weights.shape}\n")

Attention Weights:
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6818, 0.3182, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3441, 0.2774, 0.3785, 0.0000, 0.0000, 0.0000],
          [0.1777, 0.2652, 0.2231, 0.3340, 0.0000, 0.0000],
          [0.1579, 0.2156, 0.1569, 0.2064, 0.2632, 0.0000],
          [0.1181, 0.1559, 0.1460, 0.1954, 0.2052, 0.1794]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5727, 0.4273, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3134, 0.3370, 0.3496, 0.0000, 0.0000, 0.0000],
          [0.2238, 0.2624, 0.2544, 0.2593, 0.0000, 0.0000],
          [0.2098, 0.2006, 0.2004, 0.1942, 0.1950, 0.0000],
          [0.1665, 0.2198, 0.1873, 0.1478, 0.1327, 0.1458]]]],
       grad_fn=<SoftmaxBackward0>)
Shape: torch.Size([1, 2, 6, 6])



## Dropout

In [65]:
# Randomly zero attention weights with probability 0.2. A freshly built module is in
# training mode, where surviving entries are scaled by 1 / (1 - 0.2) = 1.25; that is
# why the 1.0000 entries print as 1.2500 below and rows no longer sum to 1.
# NOTE: this overwrites attention_weights, so re-running the cell applies dropout again.
dropout = nn.Dropout(0.2)
attention_weights = dropout(attention_weights)

print(f"Attention Weights after Dropout:\n{attention_weights}\nShape: {attention_weights.shape}\n")

Attention Weights after Dropout:
tensor([[[[1.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.8523, 0.3977, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4301, 0.3467, 0.4732, 0.0000, 0.0000, 0.0000],
          [0.2221, 0.3315, 0.2788, 0.4175, 0.0000, 0.0000],
          [0.1974, 0.2695, 0.1961, 0.0000, 0.3289, 0.0000],
          [0.1477, 0.1948, 0.1826, 0.2442, 0.2565, 0.2242]],

         [[1.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.7159, 0.5341, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3917, 0.0000, 0.4370, 0.0000, 0.0000, 0.0000],
          [0.2798, 0.0000, 0.3180, 0.3242, 0.0000, 0.0000],
          [0.0000, 0.2508, 0.0000, 0.2428, 0.0000, 0.0000],
          [0.2082, 0.2747, 0.2342, 0.0000, 0.1659, 0.1823]]]],
       grad_fn=<MulBackward0>)
Shape: torch.Size([1, 2, 6, 6])



## Context Vectors

In [66]:
# Weighted sum of the value vectors for every token, independently per head:
# (B, H, T, T) x (B, H, T, Dhead) -> (B, H, T, Dhead)
context_vectors = torch.matmul(attention_weights, V)

print(f"Context Vectors:\n{context_vectors}\nShape: {context_vectors.shape}\n")

Context Vectors:
tensor([[[[ 0.8253,  0.4775],
          [ 0.8293,  0.7882],
          [ 0.5600,  0.6275],
          [ 0.2378,  0.3411],
          [ 0.2276,  0.3883],
          [ 0.1024, -0.0059]],

         [[-1.0098,  0.0353],
          [-0.7621, -0.4944],
          [-0.4748,  0.0572],
          [-0.2385,  0.1100],
          [-0.0093, -0.1904],
          [-0.2533, -0.1968]]]], grad_fn=<UnsafeViewBackward0>)
Shape: torch.Size([1, 2, 6, 2])



## Regroup Context Vectors into per Token

In [67]:
# Move the head axis back behind the token axis so that each token's per-head
# results sit next to each other: (B, H, T, Dhead) -> (B, T, H, Dhead)
context_vectors = torch.transpose(context_vectors, 1, 2)

print(f"Regrouped Context Vectors:\n{context_vectors}\nShape: {context_vectors.shape}\n")

Regrouped Context Vectors:
tensor([[[[ 0.8253,  0.4775],
          [-1.0098,  0.0353]],

         [[ 0.8293,  0.7882],
          [-0.7621, -0.4944]],

         [[ 0.5600,  0.6275],
          [-0.4748,  0.0572]],

         [[ 0.2378,  0.3411],
          [-0.2385,  0.1100]],

         [[ 0.2276,  0.3883],
          [-0.0093, -0.1904]],

         [[ 0.1024, -0.0059],
          [-0.2533, -0.1968]]]], grad_fn=<TransposeBackward0>)
Shape: torch.Size([1, 6, 2, 2])



## Reshape Context Vectors to Match final D_out

In [68]:
# Merge the per-head features back into a single vector per token:
# (NUM_BATCH, NUM_TOKENS, NUM_HEADS, HEAD_DIM) -> (NUM_BATCH, NUM_TOKENS, D_OUT).
# The preceding transpose left the tensor non-contiguous, so contiguous() lays the
# data out in the new order first; view() then reinterprets it without another copy.
context_vectors = context_vectors.contiguous().view(NUM_BATCH, NUM_TOKENS, D_OUT)

print(f"Reshaped Context Vectors:\n{context_vectors}\nShape: {context_vectors.shape}\n")

Reshaped Context Vectors:
tensor([[[ 0.8253,  0.4775, -1.0098,  0.0353],
         [ 0.8293,  0.7882, -0.7621, -0.4944],
         [ 0.5600,  0.6275, -0.4748,  0.0572],
         [ 0.2378,  0.3411, -0.2385,  0.1100],
         [ 0.2276,  0.3883, -0.0093, -0.1904],
         [ 0.1024, -0.0059, -0.2533, -0.1968]]], grad_fn=<ViewBackward0>)
Shape: torch.Size([1, 6, 4])



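Why contiguous?
* view() only reinterprets the existing memory without copying, so the tensor's memory layout must be compatible with the new shape — in practice, contiguous.
* transpose() returns a non-contiguous view of the same data, so contiguous() is needed here to lay the data out in the new order before calling view().
* A freshly created tensor is already contiguous, so view() can be applied to it directly without calling contiguous().

reshape() does not have this requirement, so why use view() instead of reshape()?
* view() guarantees zero copy (it raises an error if a copy would be needed), while reshape() silently falls back to copying when the layout is incompatible.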

## Multi-Head Attention Class

In [69]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention implemented with a weight split.

    A single d_in -> d_out projection is used for each of queries, keys and
    values; the d_out features are then split into ``num_heads`` heads of size
    ``d_out // num_heads``. Attention runs independently per head, the head
    outputs are concatenated back to d_out and passed through an output
    projection.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_size: Maximum supported sequence length; the causal mask is
            built as a (context_size, context_size) buffer.
        num_heads: Number of attention heads.
        dropout_prob: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_size, num_heads, dropout_prob):
        super().__init__()

        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.Wq = nn.Linear(d_in, d_out, bias=False)
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)

        self.dropout = nn.Dropout(dropout_prob)

        # Ones strictly above the diagonal mark future positions. Registered as a
        # buffer so it follows the module across devices without being a parameter.
        self.register_buffer("mask", torch.triu(torch.ones(context_size, context_size), diagonal=1))

        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, input_embeddings):
        """Compute causal multi-head self-attention.

        Args:
            input_embeddings: Tensor of shape (B, T, d_in), where T must not
                exceed the ``context_size`` given at construction.

        Returns:
            Tensor of shape (B, T, d_out).
        """
        B, T, _ = input_embeddings.shape

        Q = self.Wq(input_embeddings) # (B, T, D_out)
        K = self.Wk(input_embeddings) # (B, T, D_out)
        V = self.Wv(input_embeddings) # (B, T, D_out)

        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, Dhead)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, Dhead)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, Dhead)

        attention_scores = Q @ K.transpose(-2, -1) # (B, H, T, T)

        # Crop the mask to the actual sequence length; the full
        # (context_size, context_size) mask does not broadcast against (T, T)
        # scores when T < context_size.
        causal_mask = self.mask[:T, :T].bool() # (T, T), broadcast over B and H
        attention_scores = attention_scores.masked_fill(causal_mask, -torch.inf)

        # -inf entries stay -inf after scaling and become exactly 0 after softmax.
        attention_weights = torch.softmax(attention_scores / (self.head_dim ** 0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vectors = attention_weights @ V # (B, H, T, Dhead)
        # transpose() yields a non-contiguous tensor, so make it contiguous before view().
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(B, T, self.d_out) # (B, T, D_out)

        output = self.out_proj(context_vectors) # (B, T, D_out)

        return output


In [70]:
# A fresh batch of two random sequences: (2, NUM_TOKENS, D_IN).
input_embeddings = torch.randn(2, NUM_TOKENS, D_IN)

print(f"Input Embeddings:\n{input_embeddings}\nShape: {input_embeddings.shape}\n")

# Build the module with explicit keyword arguments and run one forward pass.
mha = MultiHeadAttention(
    d_in=D_IN,
    d_out=D_OUT,
    context_size=CONTEXT_SIZE,
    num_heads=NUM_HEADS,
    dropout_prob=0.2,
)
output = mha(input_embeddings)

print(f"Multi-Head Attention Output:\n{output}\nShape: {output.shape}\n")

Input Embeddings:
tensor([[[ 6.0151e-01,  5.6525e-01, -5.7093e-01, -5.5915e-01],
         [-1.5905e+00,  1.0057e+00, -1.1979e+00,  1.6426e+00],
         [-1.0055e-01, -8.8790e-01,  1.8286e-01, -6.6268e-01],
         [-5.9494e-01,  1.7040e-01, -1.0473e+00, -7.6299e-01],
         [-4.8385e-01,  1.2971e-01, -1.0164e+00,  1.0861e+00],
         [ 1.7435e+00,  2.9948e-01,  5.3035e-01, -7.9038e-01]],

        [[ 2.8088e-01, -1.1303e+00,  1.0223e-01,  7.1892e-01],
         [ 2.3506e-04, -1.6781e+00, -1.5704e+00,  6.3134e-01],
         [ 1.7100e+00, -1.9451e+00,  3.0664e-01,  3.4074e-01],
         [-1.2654e+00, -6.8019e-01, -4.2150e-01,  1.0234e-01],
         [ 1.8939e-01,  1.0149e+00, -6.2363e-01, -1.0587e-01],
         [-1.4688e+00,  1.3885e+00,  5.7455e-01, -2.1958e-02]]])
Shape: torch.Size([2, 6, 4])

Multi-Head Attention Output:
tensor([[[-0.4930, -0.2375, -0.4020, -0.1955],
         [-0.5924, -0.0874, -0.5475, -0.7912],
         [-0.4578, -0.1824, -0.4833, -0.5951],
         [-0.4172, -0.

The last output projection is not strictly necessary, but it is observed in major LLM architectures. Since the projection weight matrix has size $D_{out} \times D_{out}$, the final output has the same shape as the context vectors: $(B \times T \times D_{out}) \cdot (D_{out} \times D_{out}) \rightarrow (B \times T \times D_{out})$.