<a href="https://colab.research.google.com/github/ArpitKadam/Attention-Is-All-You-Code/blob/main/LLM-from-Scratch/CHP_06_Multi_Head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **MULTI-HEAD ATTENTION**

Extending Single Head Attention to Multi-Head Attention

In [10]:
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
  """Single-head scaled dot-product self-attention with a causal mask.

  Each token attends only to itself and to earlier tokens in the sequence;
  positions above the main diagonal of the score matrix are masked out
  before the softmax.

  Args:
    d_in: Feature size of each input token.
    d_out: Feature size of the query/key/value projections (and the output).
    context_length: Largest sequence length the causal mask covers.
    dropout: Dropout probability applied to the attention weights.
    qkv_bias: Whether the query/key/value projections have a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias):
    super().__init__()
    self.d_out = d_out
    self.context_length = context_length
    self.dropout = nn.Dropout(dropout)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    ## Upper-triangular ones (diagonal excluded) mark the "future" positions.
    ## Registered as a buffer so it follows the module across devices.
    self.register_buffer(
        "simple_mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

  def forward(self, x):
    """Compute causal self-attention context vectors.

    Args:
      x: Tensor of shape (..., num_tokens, d_in); the leading batch
        dimensions are optional.

    Returns:
      Tensor of shape (..., num_tokens, d_out).

    Raises:
      ValueError: If ``x`` has fewer than two dimensions or more tokens
        than ``context_length``.
    """
    if x.dim() < 2:
      raise ValueError("x must have shape (..., num_tokens, d_in)")

    num_tokens = x.shape[-2]
    ## Without this check a too-long sequence fails with an opaque broadcast
    ## error, or (for context_length == 1) silently skips the causal masking.
    if num_tokens > self.context_length:
      raise ValueError(
          f"sequence length {num_tokens} exceeds context_length {self.context_length}"
      )

    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    attn_scores = torch.matmul(queries, keys.transpose(-2, -1))

    ## -inf scores become exactly zero weight after the softmax
    attn_scores.masked_fill_(
        self.simple_mask[:num_tokens, :num_tokens].bool(),
        -torch.inf
    )

    ## Scale by sqrt(d_out) before normalising over the key dimension
    attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)

    attn_weights = self.dropout(attn_weights)

    return torch.matmul(attn_weights, values)

In [11]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
  """Multi-head attention built from independent CausalAttention heads.

  Every head receives the same input and has its own projections; the
  per-head outputs are concatenated along the feature dimension.

  Args:
    d_in: Feature size of each input token.
    d_out: Output feature size of each individual head.
    context_length: Context length passed to every head.
    dropout: Dropout probability passed to every head.
    num_heads: Number of CausalAttention heads to create.
    qkv_bias: Whether each head's projections have a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias):
    super().__init__()
    head_modules = []
    for _ in range(num_heads):
      head_modules.append(
          CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
      )
    self.heads = nn.ModuleList(head_modules)

  def forward(self, x):
    """Run every head on ``x`` and join the results on the last dimension."""
    per_head_outputs = []
    for attention_head in self.heads:
      per_head_outputs.append(attention_head(x))
    return torch.cat(per_head_outputs, dim=-1)

In [12]:
import torch

## Toy 3-dimensional embeddings, one row per word of the example sentence.
## NOTE(review): `input` shadows the Python builtin; the name is kept as-is
## in case other cells rely on it.
input = torch.tensor([
    [0.72, 0.45, 0.31],   ## Dream
    [0.75, 0.20, 0.55],   ## big
    [0.30, 0.80, 0.40],   ## and
    [0.85, 0.35, 0.60],   ## work
    [0.55, 0.15, 0.75],   ## for
    [0.20, 0.20, 0.85],   ## it
])

words = "Dream big and work for it".split()

## Two copies of the same sentence form the batch: (batch, num_tokens, d_in)
batch = torch.stack((input,) * 2, dim=0)
print(batch.shape)

## Hyperparameters for the stacked-heads attention module
context_length, d_in = batch.shape[1:]
d_out = 2
dropout = 0.5
qkv_bias = True
num_heads = 5

torch.Size([2, 6, 3])


In [13]:
## Build the stacked-heads module from the hyperparameters defined earlier
mh_attn = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=num_heads,
    qkv_bias=qkv_bias,
)

## One forward pass over the whole batch
context_vector = mh_attn(batch)

In [14]:
## Show the attention output followed by its shape
print("Context Vector:", context_vector, sep="\n")
print(f"Shape: {context_vector.shape}")

Context Vector:
tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00, -1.2077e+00,  1.1254e+00],
         [ 1.4238e-01, -2.7253e-02, -9.7262e-01,  6.7575e-01,  2.5844e-01,
           5.7308e-01,  3.8498e-01, -7.1997e-01, -5.9959e-01,  5.5871e-01],
         [ 3.5423e-01,  3.5294e-02, -1.6565e+00,  1.3933e+00,  1.0789e-01,
           3.2194e-01,  5.6686e-01, -9.1626e-01, -1.0933e+00,  7.1606e-01],
         [ 0.0000e+00,  0.0000e+00, -1.3760e+00,  1.0847e+00,  3.5426e-01,
           8.9170e-01,  2.8736e-01, -3.1735e-01, -8.4044e-01,  5.4209e-01],
         [ 1.8945e-01,  3.4561e-02, -1.1860e+00,  7.8923e-01,  0.0000e+00,
           0.0000e+00,  8.9272e-01, -1.3910e+00, -5.0857e-01,  5.1673e-01],
         [ 1.3407e-01,  3.3340e-02, -8.1284e-01,  6.8262e-01,  0.0000e+00,
           0.0000e+00,  6.0972e-01, -9.1083e-01, -5.6090e-01,  4.4197e-01]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
 

# **IMPLEMENTING MULTI-HEAD ATTENTION WITH WEIGHT SPLITS**

In [34]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention using one fused projection per Q/K/V.

  Instead of keeping a separate module per head, queries, keys and values
  are each projected once to ``d_out`` features and then reshaped into
  ``num_heads`` slices of ``head_dim = d_out // num_heads`` features.

  Args:
    d_in: Feature size of each input token.
    d_out: Total output feature size across all heads.
    context_length: Largest sequence length the causal mask covers.
    dropout: Dropout probability applied to the attention weights.
    num_heads: Number of heads; must divide ``d_out`` evenly.
    qkv_bias: Whether the query/key/value projections have a bias term.

  Raises:
    ValueError: If ``d_out`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias):
    super().__init__()
    ## Raise instead of assert: asserts are stripped under `python -O`
    if d_out % num_heads != 0:
      raise ValueError("d_out must be divisible by num_heads")

    ## Running example used in the shape comments below:
    ## batch = 2, num_tokens = 3, d_in = 6, d_out = 6, num_heads = 2 (so head_dim = 3)

    self.d_out = d_out
    self.num_heads = num_heads
    self.context_length = context_length
    self.head_dim = d_out // num_heads   ## Features handled by each head

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)    ## Weight shape: (d_out, d_in) = (6, 6)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)      ## Weight shape: (d_out, d_in) = (6, 6)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)    ## Weight shape: (d_out, d_in) = (6, 6)
    self.out_proj = nn.Linear(d_out, d_out)    ## Linear layer to combine head outputs
    self.dropout = nn.Dropout(dropout)
    ## Upper-triangular ones (diagonal excluded) mark the "future" positions
    self.register_buffer(
        "simple_mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

  def forward(self, x):
    """Compute causal multi-head self-attention.

    Args:
      x: Tensor of shape (batch, num_tokens, d_in).

    Returns:
      Tensor of shape (batch, num_tokens, d_out).

    Raises:
      ValueError: If ``num_tokens`` exceeds ``context_length``.
    """
    batch, num_tokens, _ = x.shape      ## Shape: (2, 3, 6)

    ## Without this check a too-long sequence fails with an opaque broadcast
    ## error, or (for context_length == 1) silently skips the causal masking.
    if num_tokens > self.context_length:
      raise ValueError(
          f"sequence length {num_tokens} exceeds context_length {self.context_length}"
      )

    keys = self.W_key(x)                 ## Shape: (2, 3, 6)
    queries = self.W_query(x)            ## Shape: (2, 3, 6)
    values = self.W_value(x)             ## Shape: (2, 3, 6)

    ## Split the feature dimension into heads:
    ## (batch, num_tokens, d_out) -> (batch, num_tokens, num_heads, head_dim)
    split_shape = (batch, num_tokens, self.num_heads, self.head_dim)
    keys = keys.view(split_shape)              ## Shape: (2, 3, 2, 3)
    queries = queries.view(split_shape)        ## Shape: (2, 3, 2, 3)
    values = values.view(split_shape)          ## Shape: (2, 3, 2, 3)

    ## Move heads next to batch so matmul treats each head independently:
    ## (batch, num_tokens, num_heads, head_dim) -> (batch, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1, 2)            ## Shape: (2, 2, 3, 3)
    queries = queries.transpose(1, 2)      ## Shape: (2, 2, 3, 3)
    values = values.transpose(1, 2)        ## Shape: (2, 2, 3, 3)

    ## Attention scores per head:
    ## (batch, num_heads, num_tokens, head_dim) @ (batch, num_heads, head_dim, num_tokens)
    ##   -> (batch, num_heads, num_tokens, num_tokens)
    attn_scores = torch.matmul(queries, keys.transpose(2, 3))    ## Shape: (2, 2, 3, 3)

    ## The (num_tokens, num_tokens) mask broadcasts over batch and heads;
    ## -inf scores become exactly zero weight after the softmax
    attn_scores.masked_fill_(
        self.simple_mask[:num_tokens, :num_tokens].bool(),
        -torch.inf
    )

    ## Scale by sqrt(head_dim) before normalising over the key dimension
    attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)    ## Shape: (2, 2, 3, 3)

    attn_weights = self.dropout(attn_weights)    ## Shape: (2, 2, 3, 3)

    ## Context vectors per head:
    ## (batch, num_heads, num_tokens, num_tokens) @ (batch, num_heads, num_tokens, head_dim)
    ##   -> (batch, num_heads, num_tokens, head_dim)
    context_vec = torch.matmul(attn_weights, values)    ## Shape: (2, 2, 3, 3)

    ## (batch, num_heads, num_tokens, head_dim) -> (batch, num_tokens, num_heads, head_dim)
    context_vec = context_vec.transpose(1, 2)    ## Shape: (2, 3, 2, 3)

    ## Merge the heads back together; transpose left the tensor non-contiguous,
    ## so make it contiguous before view:
    ## (batch, num_tokens, num_heads, head_dim) -> (batch, num_tokens, d_out)
    context_vec = context_vec.contiguous().view(batch, num_tokens, self.d_out)    ## Shape: (2, 3, 6)

    return self.out_proj(context_vec)    ## Shape: (2, 3, 6)

In [38]:
import torch

## Toy 6-dimensional embeddings, one row per word of the example sentence.
## NOTE(review): `input` shadows the Python builtin; the name is kept as-is
## in case other cells rely on it.
input = torch.tensor([
    [0.72, 0.45, 0.31, 0.30, 0.80, 0.40],   ## The
    [0.75, 0.20, 0.55, 0.85, 0.35, 0.60],   ## cat
    [0.85, 0.35, 0.60, 0.20, 0.20, 0.85],   ## sleeps
])

words = "The cat sleeps".split()

## Two copies of the same sentence form the batch: (batch, num_tokens, d_in)
batch = torch.stack((input,) * 2, dim=0)
batch_size = batch.shape[0]
context_length = batch.shape[1]
d_in = batch.shape[2]

## Hyperparameters for the weight-split attention module
d_out = 6
dropout = 0.5
qkv_bias = True
num_heads = 2

In [40]:
## Build the weight-split module from the hyperparameters defined earlier
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=num_heads,
    qkv_bias=qkv_bias,
)

## One forward pass over the whole batch
context_vector = mha(batch)

In [41]:
## Show the attention output followed by its shape
print("Context Vector:", context_vector, sep="\n")
print(f"Shape: {context_vector.shape}")

Context Vector:
tensor([[[ 0.1190,  0.0905,  0.4936, -0.1211,  0.3983,  0.4535],
         [ 0.0019,  0.1892,  0.3212, -0.0644,  0.4177,  0.4157],
         [-0.0508,  0.3787, -0.2495, -0.2502,  0.2671,  0.1896]],

        [[ 0.0037,  0.1198,  0.3855, -0.0577,  0.3625,  0.4866],
         [-0.0548,  0.4741, -0.3064, -0.2718,  0.3267,  0.1731],
         [-0.0152,  0.0706,  0.3384, -0.0456,  0.4024,  0.3815]]],
       grad_fn=<ViewBackward0>)
Shape: torch.Size([2, 3, 6])


In [43]:
import torch

## Toy 6-dimensional embeddings, one row per word of the example sentence.
## NOTE(review): `input` shadows the Python builtin; the name is kept as-is
## in case other cells rely on it.
input = torch.tensor([
    [0.72, 0.45, 0.31, 0.30, 0.80, 0.40],   ## The
    [0.75, 0.20, 0.55, 0.85, 0.35, 0.60],   ## cat
    [0.85, 0.35, 0.60, 0.20, 0.20, 0.85],   ## sleeps
])

words = "The cat sleeps".split()

## Four copies of the same sentence form the batch: (batch, num_tokens, d_in)
batch = torch.stack((input,) * 4, dim=0)
batch_size = batch.shape[0]
context_length = batch.shape[1]
d_in = batch.shape[2]

## Hyperparameters: 5 heads over d_out = 10 features
d_out = 10
num_heads = 5
dropout = 0.5
qkv_bias = True

In [44]:
## Rebuild the weight-split module with the new hyperparameters
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=num_heads,
    qkv_bias=qkv_bias,
)

## One forward pass over the whole batch
context_vector = mha(batch)

In [45]:
## Show the attention output followed by its shape
print("Context Vector:", context_vector, sep="\n")
print(f"Shape: {context_vector.shape}")

Context Vector:
tensor([[[ 2.4097e-01,  1.4449e-01,  4.9991e-01,  3.3799e-01,  3.1057e-01,
          -1.6071e-01, -4.1381e-01, -2.3795e-01, -6.3857e-01,  3.1919e-01],
         [ 4.1329e-01,  1.0487e-01, -1.5155e-01,  9.6190e-02, -4.1447e-01,
          -2.0289e-01, -6.8089e-01, -3.3254e-02,  1.9961e-01,  3.4122e-01],
         [ 4.0597e-01, -1.2724e-01, -2.2650e-01,  1.1021e-01, -2.1889e-01,
           7.1342e-04, -6.4909e-01, -3.0321e-02, -8.3146e-03,  2.7135e-01]],

        [[-4.0709e-01, -4.0353e-01, -2.7830e-01, -4.5873e-01, -5.8114e-01,
          -1.5919e-01, -7.2865e-01, -9.0930e-01, -2.2909e-01,  4.7108e-01],
         [ 4.0491e-01,  2.2273e-01,  1.4551e-01,  4.3604e-01,  2.4016e-01,
           3.8988e-01, -5.3599e-01,  2.0897e-02, -4.0695e-01, -4.8698e-02],
         [-4.8674e-02, -3.9473e-01,  1.3517e-01, -1.8337e-01,  3.0090e-02,
           7.9319e-02, -4.5546e-01, -4.3734e-01, -4.7502e-01,  2.8718e-01]],

        [[ 3.5688e-01, -1.7334e-01, -3.4333e-01,  5.5416e-01, -3.1960e-02,