## Multi-Head Attention Breakdown

In [371]:
import torch
import torch.nn as nn

### Creating Batches

In [372]:
# Create batches by stacking 2 input vectors together
inputs = torch.nn.Embedding(6,3)
inputs = inputs.weight.data
batches = torch.stack((inputs, inputs), dim=0)


In [373]:
print(inputs)

tensor([[-0.3406,  2.0035, -0.4027],
        [ 0.4818,  0.3113,  0.0386],
        [-0.7591, -0.7128, -0.5769],
        [ 0.7443, -0.1105, -0.0138],
        [ 0.7773,  0.2322, -1.0294],
        [ 0.7965,  0.8256,  0.8181]])


In [374]:
batches

tensor([[[-0.3406,  2.0035, -0.4027],
         [ 0.4818,  0.3113,  0.0386],
         [-0.7591, -0.7128, -0.5769],
         [ 0.7443, -0.1105, -0.0138],
         [ 0.7773,  0.2322, -1.0294],
         [ 0.7965,  0.8256,  0.8181]],

        [[-0.3406,  2.0035, -0.4027],
         [ 0.4818,  0.3113,  0.0386],
         [-0.7591, -0.7128, -0.5769],
         [ 0.7443, -0.1105, -0.0138],
         [ 0.7773,  0.2322, -1.0294],
         [ 0.7965,  0.8256,  0.8181]]])

In [375]:
#Explain view methods and transpositions in the multiheadattention class!

In [376]:
## Multihead Attention Class for reference 

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec


In [377]:
#Instantiate a reference class
mha = MultiHeadAttention(d_in=3, d_out= 2, context_length= 6, dropout= 0, num_heads= 2)

In [406]:
#Find context vectors from sample ID vector batches
mha_out = mha(batches)
mha_out
print(mha_out.shape)

torch.Size([2, 6, 2])


In [389]:
batch_size, context_length, d_in = batches.shape
b, num_tokens, d_in = batches.shape
d_out = 2
#Shape of the bactch. Thus b = 2, num = 4, d_in = 3
#Decide number of attention heads...
num_heads = 2
qkv_bias = False
dropout = 0
#Calculate head dimensions
head_dim = d_out // num_heads
print(d_in)
print(batches.shape)

3
torch.Size([2, 6, 3])


In [390]:
# Creat query, key, value matrices of old vector length (2) by new vector length (3)

W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
dropout = nn.Dropout(dropout)

mask = torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)

# Weight matrices are 3 x 2
print("key weights:", W_key)
# Mask matrix is 6 x 6 
print(mask)
print(mask.shape)


key weights: Linear(in_features=3, out_features=2, bias=False)
tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
torch.Size([6, 6])


In [381]:
#Creates 3 matrices of correct dimensions by multiplying random weight matrix (3 x 2) by input matrix (2 x 6 x 3))
keys = W_key(batches) # Shape: (b, num_tokens, d_out)
queries = W_query(batches)
values = W_value(batches)
print(W_key, "X", batches.shape)

print("resulting tensor:", keys.shape)
print("Queries:", keys)

# Returns matrices of 2 (b) x 6 (number tokens) x 2 (d_out) = 2 x 6 x 2

Linear(in_features=3, out_features=2, bias=False) X torch.Size([2, 6, 3])
resulting tensor: torch.Size([2, 6, 2])
Queries: tensor([[[ 0.1446,  0.8069],
         [-0.2327, -0.1224],
         [ 0.0879, -0.0064],
         [-0.4420, -0.4205],
         [-0.9495, -0.5462],
         [ 0.0216,  0.0837]],

        [[ 0.1446,  0.8069],
         [-0.2327, -0.1224],
         [ 0.0879, -0.0064],
         [-0.4420, -0.4205],
         [-0.9495, -0.5462],
         [ 0.0216,  0.0837]]], grad_fn=<UnsafeViewBackward0>)


In [392]:
#View method converts to dimensions to 2 x 6 x 2 x 1

keys = keys.view(b, num_tokens, num_heads, head_dim) 
values = values.view(b, num_tokens, num_heads, head_dim)
queries = queries.view(b, num_tokens, num_heads, head_dim)
print(keys.shape)
print("Keys:", keys)

torch.Size([2, 6, 2, 1])
Keys: tensor([[[[ 0.1446],
          [ 0.8069]],

         [[-0.2327],
          [-0.1224]],

         [[ 0.0879],
          [-0.0064]],

         [[-0.4420],
          [-0.4205]],

         [[-0.9495],
          [-0.5462]],

         [[ 0.0216],
          [ 0.0837]]],


        [[[ 0.1446],
          [ 0.8069]],

         [[-0.2327],
          [-0.1224]],

         [[ 0.0879],
          [-0.0064]],

         [[-0.4420],
          [-0.4205]],

         [[-0.9495],
          [-0.5462]],

         [[ 0.0216],
          [ 0.0837]]]], grad_fn=<ViewBackward0>)


In [393]:
#transpose function transposes matrices on 2nd and 3rd dimensions so 2 x 6 x 2 x 1 becomes 2 x 2 x 6 x 1 

keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
print("Keys:", keys)
print(keys.shape)

Keys: tensor([[[[ 0.1446],
          [-0.2327],
          [ 0.0879],
          [-0.4420],
          [-0.9495],
          [ 0.0216]],

         [[ 0.8069],
          [-0.1224],
          [-0.0064],
          [-0.4205],
          [-0.5462],
          [ 0.0837]]],


        [[[ 0.1446],
          [-0.2327],
          [ 0.0879],
          [-0.4420],
          [-0.9495],
          [ 0.0216]],

         [[ 0.8069],
          [-0.1224],
          [-0.0064],
          [-0.4205],
          [-0.5462],
          [ 0.0837]]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 2, 6, 1])


In [None]:
#multiply queries (2x2x6x1) by keys transposed on 3rd and 4th dimension(2x2x1x6) 
attn_scores = queries @ keys.transpose(2, 3)
mask_bool = mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
#Multiplication returns a 2x2x6x6
print(attn_scores)

tensor([[[[-0.0527,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.0153,  0.0246,    -inf,    -inf,    -inf,    -inf],
          [ 0.0197, -0.0316,  0.0120,    -inf,    -inf,    -inf],
          [-0.0101,  0.0163, -0.0062,  0.0309,    -inf,    -inf],
          [-0.0391,  0.0629, -0.0238,  0.1196,  0.2569,    -inf],
          [-0.0185,  0.0297, -0.0112,  0.0564,  0.1212, -0.0028]],

         [[-0.2159,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.1774,  0.0269,    -inf,    -inf,    -inf,    -inf],
          [ 0.2699, -0.0409, -0.0021,    -inf,    -inf,    -inf],
          [-0.1970,  0.0299,  0.0016,  0.1027,    -inf,    -inf],
          [-0.3347,  0.0508,  0.0026,  0.1744,  0.2266,    -inf],
          [-0.2773,  0.0421,  0.0022,  0.1445,  0.1877, -0.0288]]],


        [[[-0.0527,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.0153,  0.0246,    -inf,    -inf,    -inf,    -inf],
          [ 0.0197, -0.0316,  0.0120,    -inf,    -inf,    -inf],
    

In [394]:
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = dropout(attn_weights)
print(attn_weights)
print (attn_weights.shape)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4900, 0.5100, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3399, 0.3229, 0.3373, 0.0000, 0.0000, 0.0000],
          [0.2455, 0.2521, 0.2465, 0.2558, 0.0000, 0.0000],
          [0.1773, 0.1964, 0.1801, 0.2078, 0.2384, 0.0000],
          [0.1587, 0.1666, 0.1599, 0.1711, 0.1825, 0.1612]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4491, 0.5509, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4009, 0.2938, 0.3054, 0.0000, 0.0000, 0.0000],
          [0.2073, 0.2601, 0.2528, 0.2797, 0.0000, 0.0000],
          [0.1372, 0.2017, 0.1923, 0.2283, 0.2405, 0.0000],
          [0.1235, 0.1700, 0.1633, 0.1883, 0.1966, 0.1583]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4900, 0.5100, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3399, 0.3229, 0.3373, 0.0000, 0.0000, 0.0000],
          [0.2455, 0.2521, 0.2465, 0.2558, 0.0000, 0.0000],
          [0.1773, 0.1964, 0.1801,

In [403]:
# Multiply attention weights (2x2x6x6) with values transposed on the 2nd and 3rd dimensions (2x6x2x1)

print(attn_weights.shape)
print(values.transpose(1,2).shape)
print("equals:")
context_vec = (attn_weights @ values).transpose(1, 2)
print(context_vec.shape)

torch.Size([2, 2, 6, 6])
torch.Size([2, 6, 2, 1])
equals:
torch.Size([2, 6, 2, 1])


In [405]:
#Contiguous converts 2x6x2x1 into 2x6x2 
context_vec = context_vec.contiguous().view(
         b, num_tokens, d_out
)
print(context_vec.shape)
print(context_vec)

torch.Size([2, 6, 2])
tensor([[[-0.6197, -0.6902],
         [-0.3749, -0.4036],
         [-0.1447, -0.2254],
         [-0.1095, -0.1181],
         [-0.0953, -0.1280],
         [-0.1446, -0.1698]],

        [[-0.6197, -0.6902],
         [-0.3749, -0.4036],
         [-0.1447, -0.2254],
         [-0.1095, -0.1181],
         [-0.0953, -0.1280],
         [-0.1446, -0.1698]]], grad_fn=<ViewBackward0>)
