<a href="https://www.kaggle.com/code/drapes/multi-head-attention-mechanism-step-by-step?scriptVersionId=237910146" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import torch
import torch.nn as nn

# Self Attention

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention without masking.

    Args:
        din: size of the last dimension of the input.
        dout: size of the query/key/value projections (and of the output).
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, din, dout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(din, dout, bias=qkv_bias)
        self.W_key = nn.Linear(din, dout, bias=qkv_bias)
        self.W_value = nn.Linear(din, dout, bias=qkv_bias)

    def forward(self, x):
        """Let every position of ``x`` attend to every position.

        Args:
            x: tensor of shape ``(num_tokens, din)`` or, batched,
                ``(..., num_tokens, din)``.

        Returns:
            Tuple ``(context_vectors, attn_scores, attn_wt, key)``:
            ``attn_scores`` are the raw (unscaled) query-key dot products,
            ``attn_wt`` is the row-wise softmax of the scaled scores, and
            ``key`` is the key projection of ``x``.
        """
        queries = self.W_query(x)
        key = self.W_key(x)
        value = self.W_value(x)

        # Swap only the last two dims. ``key.T`` reverses *all* dims, which is
        # deprecated and gives the wrong layout for batched (3-D) inputs.
        attn_scores = queries @ key.transpose(-2, -1)

        # Scaled dot-product attention: divide by sqrt(d_k) before softmax.
        attn_wt = torch.softmax(attn_scores / key.shape[-1] ** 0.5, dim=-1)

        context_vectors = attn_wt @ value
        return context_vectors, attn_scores, attn_wt, key
        

In [3]:
# Seed the RNG so the Linear layers below start from reproducible weights.
torch.manual_seed(789)

# One 3-feature row per word of "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)
din, dout = 3, 3
self_att = SelfAttention(din, dout)
context_vectors, attn_scores, attn_wt, key = self_att(inputs)


## Causal attention

In [4]:
# Sequence length = number of rows of the (num_tokens x num_tokens) score matrix.
context_ln = attn_scores.size(0)
torch.ones((context_ln, context_ln))

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [5]:
# Lower-triangular 0/1 mask: row i keeps positions 0..i and zeroes the rest.
mask_simple = torch.ones(context_ln, context_ln).tril()
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

### Method 1: Set the upper-triangular attention weights to 0

In [6]:
# Display the unmasked attention weights returned by SelfAttention above.
attn_wt

tensor([[0.1607, 0.1599, 0.1613, 0.1679, 0.1971, 0.1531],
        [0.1606, 0.1579, 0.1590, 0.1721, 0.1907, 0.1597],
        [0.1606, 0.1580, 0.1591, 0.1720, 0.1907, 0.1595],
        [0.1637, 0.1618, 0.1623, 0.1704, 0.1767, 0.1650],
        [0.1631, 0.1629, 0.1638, 0.1676, 0.1836, 0.1590],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<SoftmaxBackward0>)

In [7]:
# Element-wise product zeroes every weight above the diagonal ("future" tokens).
masked_attn_wt = attn_wt.mul(mask_simple)

In [8]:
# Upper triangle is now zero, but each row no longer sums to 1.
masked_attn_wt

tensor([[0.1607, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1606, 0.1579, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1606, 0.1580, 0.1591, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1618, 0.1623, 0.1704, 0.0000, 0.0000],
        [0.1631, 0.1629, 0.1638, 0.1676, 0.1836, 0.0000],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<MulBackward0>)

### Problem: Each row no longer sums to 1

### Solution: Divide each row by its sum (renormalize)

In [9]:
# Rescale each row by its own total so the surviving weights sum to 1 again.
row_sums = masked_attn_wt.sum(dim=-1, keepdim=True)
masked_attn_wt = masked_attn_wt.div(row_sums)

In [10]:
# After renormalization every row sums to 1.
masked_attn_wt

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3362, 0.3307, 0.3330, 0.0000, 0.0000, 0.0000],
        [0.2487, 0.2458, 0.2465, 0.2589, 0.0000, 0.0000],
        [0.1939, 0.1937, 0.1947, 0.1993, 0.2183, 0.0000],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<DivBackward0>)

### This works, but we can use a property of softmax instead: setting the upper triangle of the *attention scores* to $-\infty$ makes softmax give those positions a weight of exactly 0 (since $e^{-\infty} = 0$), while the remaining weights in each row still sum to 1

In [11]:
# Strictly upper-triangular 0/1 mask: a 1 marks a position that must be hidden.
mask = torch.ones(context_ln, context_ln).triu(diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [12]:
# Hidden positions get -inf so that softmax later assigns them exactly zero weight.
masked_attn_scores = attn_scores.masked_fill(mask == 1, float("-inf"))

In [13]:
# Future positions (above the diagonal) now hold -inf.
masked_attn_scores

tensor([[-0.2382,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3270, -0.3568,    -inf,    -inf,    -inf,    -inf],
        [-0.3218, -0.3506, -0.3384,    -inf,    -inf,    -inf],
        [-0.1822, -0.2026, -0.1975, -0.1128,    -inf,    -inf],
        [-0.1385, -0.1401, -0.1312, -0.0915,  0.0668,    -inf],
        [-0.2439, -0.2754, -0.2699, -0.1499, -0.0948, -0.2134]],
       grad_fn=<MaskedFillBackward0>)

In [14]:
# Same sqrt(d_k) scaling as SelfAttention, then a row-wise softmax.
masked_attn_wt_soft = masked_attn_scores.div(key.size(-1) ** 0.5).softmax(dim=-1)

In [15]:
# Method 1 result (mask the weights, then renormalize), shown for comparison.
masked_attn_wt

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3362, 0.3307, 0.3330, 0.0000, 0.0000, 0.0000],
        [0.2487, 0.2458, 0.2465, 0.2589, 0.0000, 0.0000],
        [0.1939, 0.1937, 0.1947, 0.1993, 0.2183, 0.0000],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<DivBackward0>)

In [16]:
# Method 2 result (-inf mask on the scores, then softmax); compare with Method 1.
masked_attn_wt_soft

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3362, 0.3307, 0.3330, 0.0000, 0.0000, 0.0000],
        [0.2487, 0.2458, 0.2465, 0.2589, 0.0000, 0.0000],
        [0.1939, 0.1937, 0.1947, 0.1993, 0.2183, 0.0000],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<SoftmaxBackward0>)

### We get the same output with less effort


In [17]:
class CausalSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Position ``i`` can only attend to positions ``0..i``. Dropout is then
    applied to the attention weights.

    Args:
        din: size of the last dimension of the input.
        dout: size of the query/key/value projections (and of the output).
        dropout: dropout probability applied to the attention weights.
        context_ln: maximum context length. Accepted for interface parity with
            ``MultiHeadAttention``. The mask here is built per call from the
            actual sequence length, so this value is currently unused.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, din, dout, dropout, context_ln, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(din, dout, bias=qkv_bias)
        self.W_key = nn.Linear(din, dout, bias=qkv_bias)
        self.W_value = nn.Linear(din, dout, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run causal self-attention over a batch of sequences.

        Prints the attention weights before and after dropout (for teaching
        purposes).

        Args:
            x: tensor of shape ``(batch, num_tokens, din)``.

        Returns:
            Tuple ``(context_vectors, masked_attn_scores, attn_wt, key)``:
            ``masked_attn_scores`` are the unscaled scores with future positions
            set to ``-inf``. ``attn_wt`` is the softmax of the scaled masked
            scores *after* dropout. ``key`` is the key projection of ``x``.
        """
        b, no_tokens, _ = x.shape
        queries = self.W_query(x)
        key = self.W_key(x)
        value = self.W_value(x)

        attn_scores = queries @ key.transpose(1, 2)

        # True strictly above the diagonal = future positions to hide. The mask
        # is built on x's device, because masked_fill expects the mask on the
        # same device as the scores.
        mask = torch.triu(
            torch.ones(no_tokens, no_tokens, device=x.device), diagonal=1
        ).bool()
        # A (no_tokens, no_tokens) mask broadcasts over the batch dimension.
        masked_attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_wt = torch.softmax(masked_attn_scores / key.shape[-1] ** 0.5, dim=-1)
        print("Attention Weight before dropout:", attn_wt)
        attn_wt = self.dropout(attn_wt)
        print("Attention Weight after dropout:", attn_wt)
        context_vectors = attn_wt @ value
        return context_vectors, masked_attn_scores, attn_wt, key
        

In [18]:
# Re-seed so the three Linear layers start from the same weights as the
# SelfAttention run earlier in the notebook.
torch.manual_seed(789)

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

# A batch of two identical sequences, stacked along a new leading dim.
batch = torch.stack([inputs, inputs])
print(batch.shape)

din, dout = 3, 3
dropout = 0.0
context_ln = 6
cas_self_att = CausalSelfAttention(din, dout, dropout=dropout, context_ln=context_ln)
cas_context_vectors, cas_attn_scores, cas_attn_wt, key = cas_self_att(batch)


torch.Size([2, 6, 3])
Attention Weight before dropout: tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3362, 0.3307, 0.3330, 0.0000, 0.0000, 0.0000],
         [0.2487, 0.2458, 0.2465, 0.2589, 0.0000, 0.0000],
         [0.1939, 0.1937, 0.1947, 0.1993, 0.2183, 0.0000],
         [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3362, 0.3307, 0.3330, 0.0000, 0.0000, 0.0000],
         [0.2487, 0.2458, 0.2465, 0.2589, 0.0000, 0.0000],
         [0.1939, 0.1937, 0.1947, 0.1993, 0.2183, 0.0000],
         [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]]],
       grad_fn=<SoftmaxBackward0>)
Attention Weight after dropout: tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3362, 0.3307, 0.3330, 0.0000, 0.0000

In [19]:
# Context vectors for both (identical) sequences in the batch.
cas_context_vectors

tensor([[[ 0.3253, -0.5116, -0.1020],
         [ 0.4499, -0.5958, -0.0050],
         [ 0.4909, -0.6204,  0.0269],
         [ 0.4473, -0.5584,  0.0417],
         [ 0.4247, -0.4955,  0.0352],
         [ 0.4166, -0.4996,  0.0483]],

        [[ 0.3253, -0.5116, -0.1020],
         [ 0.4499, -0.5958, -0.0050],
         [ 0.4909, -0.6204,  0.0269],
         [ 0.4473, -0.5584,  0.0417],
         [ 0.4247, -0.4955,  0.0352],
         [ 0.4166, -0.4996,  0.0483]]], grad_fn=<UnsafeViewBackward0>)

# Multi-Head Attention

In [20]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention followed by an output projection.

    The ``dout`` projection is split into ``num_heads`` heads of
    ``head_dim = dout // num_heads`` features. Each head runs causal scaled
    dot-product attention on its own. The head outputs are then concatenated
    back to ``dout`` features and passed through ``op_proj``.

    Args:
        din: size of the last dimension of the input.
        dout: total size of the query/key/value projections and of the output.
            Must be divisible by ``num_heads``.
        dropout: dropout probability applied to the attention weights.
        context_ln: longest sequence accepted. The causal mask buffer is
            pre-built at this size.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, din, dout, dropout, context_ln, num_heads, qkv_bias=False):
        super().__init__()
        assert dout % num_heads == 0, (
            f"dout ({dout}) must be divisible by num_heads ({num_heads})"
        )
        self.dout = dout
        self.num_heads = num_heads
        self.head_dim = dout // num_heads

        self.W_query = nn.Linear(din, dout, bias=qkv_bias)
        self.W_key = nn.Linear(din, dout, bias=qkv_bias)
        self.W_value = nn.Linear(din, dout, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.op_proj = nn.Linear(dout, dout)

        # 1 strictly above the diagonal marks future positions. It is a buffer,
        # so it moves with the module on .to(device) and is stored in state_dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_ln, context_ln), diagonal=1),
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: tensor of shape ``(batch, num_tokens, din)`` with
                ``num_tokens <= context_ln``.

        Returns:
            Tensor of shape ``(batch, num_tokens, dout)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_ln``.
        """
        b, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Slicing the mask below would silently produce a too-small mask
            # and fail later with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_ln={max_tokens}"
            )

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Split the feature dim into heads: (b, num_tokens, num_heads, head_dim),
        # then transpose to (b, num_heads, num_tokens, head_dim).
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores: (b, num_heads, num_tokens, num_tokens)
        attn_scores = (queries @ keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Causal mask, broadcast over batch and heads: (1, 1, num_tokens, num_tokens)
        mask = self.mask[:num_tokens, :num_tokens].to(x.device).bool()
        mask = mask.unsqueeze(0).unsqueeze(0)
        attn_scores = attn_scores.masked_fill(mask, float('-inf'))

        # Apply softmax and dropout
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context = (attn_weights @ values).transpose(1, 2)
        # Concatenate the heads back into dout features.
        context = context.reshape(b, num_tokens, self.dout)

        # Final output projection
        return self.op_proj(context)


In [21]:
# Seed first: this fixes both the initial Linear weights and the dropout
# pattern drawn during the forward pass below.
torch.manual_seed(789)

# Two sequences of 11 tokens with 8 features each.
input1 = torch.tensor(
    [
        [0.43, 0.15, 0.89, 0.12, 0.44, 0.65, 0.23, 0.90],
        [0.55, 0.87, 0.66, 0.31, 0.48, 0.92, 0.76, 0.11],
        [0.57, 0.85, 0.64, 0.29, 0.30, 0.73, 0.52, 0.19],
        [0.22, 0.58, 0.33, 0.66, 0.91, 0.08, 0.42, 0.77],
        [0.77, 0.25, 0.10, 0.95, 0.14, 0.38, 0.64, 0.33],
        [0.05, 0.80, 0.55, 0.26, 0.70, 0.39, 0.48, 0.85],
        [0.63, 0.12, 0.44, 0.18, 0.27, 0.50, 0.33, 0.71],
        [0.34, 0.48, 0.91, 0.07, 0.36, 0.60, 0.29, 0.88],
        [0.11, 0.90, 0.27, 0.59, 0.46, 0.41, 0.66, 0.22],
        [0.16, 0.73, 0.31, 0.80, 0.53, 0.69, 0.19, 0.37],
        [0.94, 0.20, 0.61, 0.17, 0.34, 0.78, 0.85, 0.04],
    ]
)

input2 = torch.tensor(
    [
        [0.12, 0.33, 0.71, 0.45, 0.62, 0.89, 0.24, 0.55],
        [0.67, 0.21, 0.44, 0.58, 0.19, 0.73, 0.90, 0.37],
        [0.81, 0.42, 0.60, 0.35, 0.88, 0.11, 0.76, 0.29],
        [0.93, 0.08, 0.23, 0.77, 0.49, 0.64, 0.18, 0.31],
        [0.39, 0.55, 0.17, 0.26, 0.91, 0.30, 0.63, 0.82],
        [0.24, 0.70, 0.35, 0.14, 0.75, 0.48, 0.59, 0.66],
        [0.53, 0.62, 0.19, 0.88, 0.27, 0.34, 0.41, 0.50],
        [0.77, 0.13, 0.94, 0.61, 0.46, 0.12, 0.69, 0.38],
        [0.30, 0.47, 0.79, 0.16, 0.22, 0.81, 0.33, 0.56],
        [0.45, 0.96, 0.20, 0.52, 0.10, 0.36, 0.67, 0.15],
        [0.84, 0.25, 0.50, 0.73, 0.04, 0.60, 0.92, 0.13],
    ]
)

batch = torch.stack([input1, input2])  # (2, 11, 8)
print(batch.shape)

din, dout = 8, 8
dropout = 0.3
context_ln = 11
num_heads = 2
head_self_att = MultiHeadAttention(
    din,
    dout,
    dropout=dropout,
    context_ln=context_ln,
    num_heads=num_heads,
)
cas_context_vectors = head_self_att(batch)


torch.Size([2, 11, 8])


In [22]:
# Expect (batch, num_tokens, dout).
cas_context_vectors.size()


torch.Size([2, 11, 8])

In [23]:
# Multi-head outputs. The module was never switched to eval(), so dropout
# (p=0.3) was active and these values depend on the seeded RNG state.
cas_context_vectors

tensor([[[-0.0845,  0.0366,  0.2011,  0.2282,  0.0772, -0.4271, -0.0105,
           0.1393],
         [-0.1989,  0.3531,  0.2227, -0.1106,  0.0146, -0.0938,  0.0120,
           0.1485],
         [-0.1441,  0.2184,  0.2322,  0.0610,  0.1251, -0.2258,  0.0451,
           0.1346],
         [-0.1848,  0.3734,  0.3464,  0.0428,  0.0641, -0.2774,  0.0615,
           0.1575],
         [-0.1649,  0.3086,  0.2945,  0.0247,  0.1176, -0.1904,  0.0854,
           0.1039],
         [-0.1039,  0.2492,  0.2511,  0.0689,  0.2080, -0.2136,  0.0454,
           0.1535],
         [-0.1304,  0.3206,  0.2623, -0.0222,  0.1250, -0.1705,  0.0439,
           0.1559],
         [-0.1601,  0.3679,  0.3107, -0.0088,  0.0888, -0.1864,  0.0413,
           0.1693],
         [-0.0830,  0.1393,  0.1362,  0.0482,  0.2274, -0.2275,  0.0395,
           0.1291],
         [-0.1409,  0.2701,  0.2271, -0.0081,  0.1453, -0.1890,  0.0718,
           0.1260],
         [-0.0990,  0.1842,  0.1645,  0.0285,  0.1605, -0.1993,  0.006