# Chapter 3: Attention Mechanisms

In this chapter, we build attention mechanisms, which are at the heart of modern LLMs. Attention lets a model weigh how relevant every token in the input is to the token it is currently processing, so that each token's representation reflects its context.

## 3.3: Attending to different parts of the input with self-attention

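First, we calculate the attention scores for a single query token, here the second token, "journey" (inputs[1]). Each score is the dot product of the query token's embedding with the embedding of one input token (the query itself included); a higher score indicates greater similarity between the two embeddings.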

In [43]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

attention_scores_2 = torch.empty(inputs.shape[0])
for i in range(inputs.shape[0]):
    attention_scores_2[i] = torch.dot(inputs[1], inputs[i])  # score of query "journey" against token i

print(attention_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
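To make the arithmetic concrete, the sketch below recomputes the first score by hand and then gets all six scores with a single matrix-vector product instead of the loop. It re-declares the same inputs tensor so it runs on its own; the names manual and scores are illustrative only.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
query = inputs[1]  # "journey"

# Dot product "by hand" with the first token: 0.55*0.43 + 0.87*0.15 + 0.66*0.89
manual = sum(q * k for q, k in zip(query.tolist(), inputs[0].tolist()))
print(round(manual, 4))  # 0.9544

# All six scores at once: (6, 3) @ (3,) -> (6,)
scores = inputs @ query
print(scores)
```

The matrix-vector form gives the same values as the loop and is how the rest of the chapter computes scores.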


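Next, we compute the attention weights by normalising the attention scores so that they sum to 1.

We'll see two ways to do that:
1. Linear normalisation: w[i] = x[i] / sum(x)
2. Softmax: w[i] = exp(x[i]) / sum_j exp(x[j])

Softmax is preferred in practice: it always produces positive weights, even when some scores are negative, and exponentiation emphasises the largest scores.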

In [44]:
attention_weights_2_linear = attention_scores_2 / attention_scores_2.sum()
attention_weights_2_softmax = torch.softmax(attention_scores_2, dim=0)  # keep the default float32 dtype
print(attention_weights_2_linear, attention_weights_2_softmax)

tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656]) tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
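A short sketch of why softmax is usually preferred over linear normalisation. The scores below are hypothetical and include a negative value, which linear normalisation turns into a negative "weight". softmax_naive is an illustrative helper, not a PyTorch function; it also overflows for large scores, which is one reason to rely on torch.softmax, whose implementation is numerically stable.

```python
import torch

def softmax_naive(x):
    # exp(x_i) / sum_j exp(x_j), written out literally (no overflow protection)
    return torch.exp(x) / torch.exp(x).sum(dim=0)

scores = torch.tensor([2.0, -1.0, 0.5])  # hypothetical scores, one of them negative

linear = scores / scores.sum()  # sums to 1, but contains a negative entry and one above 1
soft = softmax_naive(scores)    # sums to 1 and every entry is positive
print(linear)
print(soft, soft.sum())

big = torch.tensor([1000.0, 1000.0])
print(softmax_naive(big))         # exp overflows to inf, and inf / inf gives nan
print(torch.softmax(big, dim=0))  # tensor([0.5000, 0.5000])
```

The linear weights here are 1.3333, -0.6667 and 0.3333, which cannot be read as proportions of attention, while the softmax weights can.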


Finally, we compute the context vector for inputs[1] as the weighted sum of all input vectors, where each input vector is multiplied by its attention weight.

In [45]:
context_vector_2 = torch.zeros(inputs.shape[1])
for i in range(inputs.shape[0]):
    context_vector_2 += inputs[i]*attention_weights_2_softmax[i]

print(context_vector_2)

tensor([0.4419, 0.6515, 0.5683])
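The loop above is a vector-matrix product in disguise: sum_i w[i] * x[i] is exactly weights @ inputs. A minimal sketch comparing the two forms, re-declaring the same inputs so it runs on its own:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
weights = torch.softmax(inputs @ inputs[1], dim=0)  # attention weights for query "journey"

# Loop form: accumulate each input vector scaled by its weight
context_loop = torch.zeros(inputs.shape[1])
for i in range(inputs.shape[0]):
    context_loop += weights[i] * inputs[i]

# Matrix form: (6,) @ (6, 3) -> (3,)
context_matmul = weights @ inputs
print(context_loop, context_matmul)
```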


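### Generating context vectors for the entire input

Replacing the loops with matrix multiplications computes every context vector at once: inputs @ inputs.T gives the 6 × 6 matrix of attention scores, softmax along each row (dim=1) turns each row into attention weights, and multiplying the weights by inputs yields one context vector per token. The second row matches context_vector_2 computed above.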

In [46]:
attention_scores = inputs @ inputs.T                        # (6, 6): every query against every key
attention_weights = torch.softmax(attention_scores, dim=1)  # normalise each row

context_vector = attention_weights @ inputs                 # (6, 6) @ (6, 3) -> (6, 3)
print(context_vector)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## 3.4: Attention Mechanism with trainable weights

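This version introduces three trainable weight matrices, W_q, W_k and W_v, which project each input embedding into a query, a key and a value vector. To compute the context vector for the second token (inputs[1]), we take the dot product of its query vector with the key vector of every token to get the attention scores. Scaling the scores by 1/sqrt(d_k), where d_k is the key dimension, and applying softmax yields the attention weights, and the weighted sum of the value vectors gives the context vector.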

In [47]:
d_in = inputs.shape[1]
d_out = 2  # in practice d_in and d_out are often the same; 2 keeps the outputs small

torch.manual_seed(123)
query_mat = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
key_mat = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
value_mat = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)

query_2 = inputs[1] @ query_mat
keys = inputs @ key_mat
values = inputs @ value_mat
attention_scores_2 = query_2 @ keys.T

attention_weights_2 = torch.softmax(attention_scores_2 / d_out**0.5, dim=0)  # divide by sqrt(d_k), here d_k = d_out
context_vector_2 = attention_weights_2 @ values
print(context_vector_2)

tensor([0.3061, 0.8210])
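Why divide by sqrt(d_k)? If query and key entries have roughly unit variance, the variance of their dot product grows roughly in proportion to the dimension d, and large scores push softmax toward a near one-hot distribution with very small gradients. A minimal sketch, assuming random standard-normal vectors as stand-ins for real queries and keys:

```python
import torch

torch.manual_seed(0)
num_samples = 5000
variances = {}
for d in (4, 64, 256):
    q = torch.randn(num_samples, d)  # hypothetical queries with unit-variance entries
    k = torch.randn(num_samples, d)  # hypothetical keys
    dots = (q * k).sum(dim=-1)       # one dot product per (q, k) pair
    # (variance of raw scores, variance after dividing by sqrt(d))
    variances[d] = (dots.var().item(), (dots / d**0.5).var().item())
    print(d, variances[d])
```

The raw variance comes out close to d, while the scaled variance stays close to 1 for every dimension, which keeps the softmax inputs in a comparable range regardless of the embedding size.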


To generate context vectors for all input tokens, we wrap this logic in a class that subclasses nn.Module. Weights assigned as nn.Parameter attributes are registered automatically, so they show up in .parameters(), receive gradients during backpropagation, and can be updated by an optimizer during training.

In [48]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
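A minimal sketch of what subclassing nn.Module buys in practice. TinyAttention is a hypothetical stand-in with the same structure as SelfAttention_v1, and the summed output used as a loss is only a placeholder for a real training objective: the nn.Parameter attributes are discovered by named_parameters(), receive gradients from backward(), and are updated by an optimizer such as torch.optim.SGD.

```python
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    """Hypothetical stand-in with the same structure as SelfAttention_v1."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries, keys, values = x @ self.W_query, x @ self.W_key, x @ self.W_value
        weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return weights @ values

torch.manual_seed(123)
model = TinyAttention(d_in=3, d_out=2)
names = [name for name, _ in model.named_parameters()]
print(names)  # ['W_query', 'W_key', 'W_value']

x = torch.rand(6, 3)
loss = model(x).sum()  # placeholder loss, only used to produce gradients
loss.backward()        # fills .grad on every registered parameter

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.W_query.detach().clone()
optimizer.step()       # updates each parameter using its .grad
```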


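Next, we replace the nn.Parameter matrices with nn.Linear layers. With the bias disabled, an nn.Linear layer performs the same matrix multiplication, but it uses a better default weight initialisation (values drawn from a small uniform range that shrinks as the input size grows, instead of torch.rand's [0, 1)), which tends to make training more stable. Because of the different initialisation and random seed, the outputs below differ from those of SelfAttention_v1.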

In [49]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # bias=False creates no bias term, so the layer computes only x @ W.T
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
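Two details of nn.Linear are worth checking directly: it stores its weight as (d_out, d_in) and computes x @ weight.T, and its default initialisation draws weights from a uniform range of roughly plus or minus 1/sqrt(d_in). A small sketch confirming both for the shapes used here:

```python
import torch
import torch.nn as nn

torch.manual_seed(789)
d_in, d_out = 3, 2
linear = nn.Linear(d_in, d_out, bias=False)
print(linear.weight.shape)  # torch.Size([2, 3]): stored as (d_out, d_in)

x = torch.rand(6, d_in)
out_linear = linear(x)
out_matmul = x @ linear.weight.T  # the same computation written out explicitly

bound = 1 / d_in**0.5  # edge of the default uniform initialisation range
print(linear.weight.abs().max().item() <= bound)  # True
```

This transpose is why copying nn.Linear weights into the SelfAttention_v1 layout would require assigning weight.T rather than weight.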


## 3.5: Hiding future words with causal attention

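Because the model is trained by predicting the next word, the context vector for a token must not contain information from the tokens that follow it; otherwise the model could simply read off the word it is supposed to predict. Causal attention enforces this by masking out the attention weights above the diagonal, so each token attends only to itself and to earlier tokens.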

In [50]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_score = queries @ keys.T
attn_weights = torch.softmax(attn_score/d_out**0.5, dim=1)

context_length = attn_weights.shape[0]
causal_mask = torch.tril(torch.ones(context_length, context_length))  # 1s on and below the diagonal
masked_weights = causal_mask * attn_weights                           # zero out future positions
print(masked_weights)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


The rows no longer sum to 1, so we renormalise them by dividing each row by its sum. It may look as if the masked future tokens could still influence the current token, because their scores contributed to the softmax denominator before masking. They do not: since softmax(z)_i = exp(z_i) / sum_j exp(z_j), dividing the surviving weights by their row sum cancels the original denominator and leaves exactly the softmax computed over the unmasked positions alone.

In [51]:
normalised_masked_weights = masked_weights / masked_weights.sum(dim=1, keepdim=True)
print(normalised_masked_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)
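The cancellation argument can be checked numerically. The sketch below uses hypothetical random scores, applies softmax, masks and renormalises as in the cell above, and compares the result with a softmax computed directly on each row's unmasked prefix.

```python
import torch

torch.manual_seed(0)
n = 4
scores = torch.randn(n, n)  # hypothetical attention scores
causal_mask = torch.tril(torch.ones(n, n))

# Mask after softmax, then renormalise (the approach used above)
weights = torch.softmax(scores, dim=-1)
masked = weights * causal_mask
renormalised = masked / masked.sum(dim=-1, keepdim=True)

# Softmax over only the positions each token may attend to
direct = torch.zeros(n, n)
for i in range(n):
    direct[i, : i + 1] = torch.softmax(scores[i, : i + 1], dim=0)

print(renormalised)
print(direct)
```

Both matrices agree up to floating-point rounding, so the masked tokens leave no trace in the renormalised weights.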


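A more efficient approach is to mask the attention scores before applying softmax: we fill the positions above the diagonal with negative infinity instead of zero. Because exp(-inf) = 0, softmax assigns those positions zero weight and normalises the remaining ones in a single step, so the separate renormalisation is no longer needed.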

In [52]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)  # 1s above the diagonal
masked_scores = attn_score.masked_fill(mask.bool(), -torch.inf)
print(masked_scores)

normalised_masked_weights = torch.softmax(masked_scores / d_out**0.5, dim=1)
print(normalised_masked_weights)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
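A useful sanity check for any causal attention implementation: changing a later token must not change the context vectors of earlier tokens. The sketch below assumes small random projection matrices as hypothetical stand-ins for trained weights, replaces only the last token, and compares the outputs.

```python
import torch

torch.manual_seed(123)
d_in, d_out, num_tokens = 3, 2, 6
# Hypothetical fixed projection matrices standing in for trained weights
W_q, W_k, W_v = (torch.rand(d_in, d_out) for _ in range(3))

def causal_context(x):
    queries, keys, values = x @ W_q, x @ W_k, x @ W_v
    scores = queries @ keys.T
    future = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
    scores = scores.masked_fill(future, -torch.inf)  # mask scores before softmax
    weights = torch.softmax(scores / keys.shape[-1]**0.5, dim=-1)
    return weights @ values, weights

x = torch.rand(num_tokens, d_in)
context, weights = causal_context(x)

x_changed = x.clone()
x_changed[-1] = torch.rand(d_in)  # replace only the last token
context_changed, _ = causal_context(x_changed)
```

The first five context vectors come out identical; only the last one, the only token allowed to attend to the modified position, can change.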


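Next, we apply dropout, which randomly zeroes attention weights during training to reduce overfitting. To keep the expected value of each weight unchanged, the surviving entries are scaled by 1 / (1 - p), where p is the dropout rate; with p = 0.5 the factor is 2, which is why the kept weights below are doubled.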

In [53]:
dropout = torch.nn.Dropout(0.5)  # p = 0.5, so kept entries are scaled by 1 / (1 - 0.5) = 2
dropped_weights = dropout(normalised_masked_weights)
print(dropped_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.4638, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.3968, 0.0000, 0.0000, 0.0000],
        [0.3869, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)
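With p = 0.5 the factors 1 / p and 1 / (1 - p) coincide, which hides which one is used. The sketch below uses p = 0.2 to show the 1 / (1 - p) scaling, that the mean is approximately preserved, and that dropout becomes the identity once the module is switched to evaluation mode with .eval().

```python
import torch

torch.manual_seed(123)
p = 0.2  # a rate other than 0.5, so 1 / (1 - p) and 1 / p differ
dropout = torch.nn.Dropout(p)
x = torch.ones(1000)

out = dropout(x)            # modules start in training mode, so dropout is active
kept = out[out != 0]
print(kept[:3])             # kept entries equal 1 / (1 - 0.2) = 1.25
print(out.mean().item())    # close to 1.0: the expected value is preserved

dropout.eval()              # in evaluation mode dropout passes inputs through unchanged
out_eval = dropout(x)
```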


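Now we consolidate everything into a single class. Unlike the earlier versions, it accepts batched input of shape (batch_size, num_tokens, d_in), such as the batches produced by a DataLoader, so it transposes only the last two dimensions of the keys and slices the mask to the actual number of tokens.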

In [54]:
class CausalAttentionV1(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s above the diagonal mark future positions; a buffer moves with the module across devices
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape  # x: (batch_size, num_tokens, d_in)
        queries = self.W_query(x)
        keys = self.W_key(x)                    # computed locally instead of relying on a global
        values = self.W_value(x)

        # Transpose only the last two dims so the batch dimension is preserved
        attn_scores = queries @ keys.transpose(1, 2)  # (batch_size, num_tokens, num_tokens)
        # Slice the mask so sequences shorter than context_length also work
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)  # two copies of inputs stacked to mimic a batch of 2 sequences

context_length = batch.shape[1]
ca = CausalAttentionV1(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

torch.Size([2, 6, 3])
tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
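The class stores its mask with register_buffer rather than as a plain attribute or an nn.Parameter. A minimal check of what that buys, using a hypothetical stripped-down module (MaskedModule is illustrative only): the mask is saved in the state_dict and follows dtype and device conversions with the rest of the module, but it is not returned by parameters(), so an optimizer never updates it. Slicing [:num_tokens, :num_tokens] lets the same buffer serve inputs shorter than context_length.

```python
import torch
import torch.nn as nn

class MaskedModule(nn.Module):
    """Hypothetical minimal module, for illustration only."""

    def __init__(self, d_in, d_out, context_length):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        # Buffer: part of the module's state, but not a trainable parameter
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

m = MaskedModule(d_in=3, d_out=2, context_length=6)
param_names = [name for name, _ in m.named_parameters()]
state_keys = list(m.state_dict().keys())
print(param_names)  # ['W_query.weight']

# A 4-token sequence uses only the top-left 4 x 4 block of the 6 x 6 mask
short_mask = m.mask.bool()[:4, :4]

m.double()  # casts floating-point parameters and buffers alike
```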


## 3.6: Multi-Head Attention

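Multi-head attention runs several causal attention modules (heads) in parallel, each with its own query, key and value weights.

Our first implementation is a wrapper that holds multiple CausalAttentionV1 heads and concatenates their outputs along the last dimension, so the final context vectors have num_heads * d_out dimensions.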

In [55]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([CausalAttentionV1(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


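The wrapper computes the heads one after another, each with its own projection matrices. A more efficient implementation uses a single W_query, W_key and W_value (each an nn.Linear(d_in, d_out)), splits their outputs into num_heads heads of size head_dim = d_out / num_heads by reshaping, and computes all heads in parallel with batched matrix multiplications. Note that d_out is now the total output size across all heads, so d_out = 2 with num_heads = 2 below gives head_dim = 1.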

In [56]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # d_out per head

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out), where d_out = num_heads * head_dim
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into separate heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        # Grouping by head lets the batched matmuls below process every head independently
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)  # (b, num_heads, num_tokens, num_tokens): one score matrix per head
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim),
        # so each token's head outputs sit next to each other before they are merged
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        # transpose returns a non-contiguous view; .contiguous() copies it so .view() can reshape it
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection that mixes the heads' outputs

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
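The view/transpose reshaping in MultiHeadAttention is equivalent to giving head h its own block of columns h*head_dim to (h+1)*head_dim in each projection matrix and running the heads separately, as the wrapper does. A sketch checking this, assuming random hypothetical weights and leaving out out_proj and dropout for clarity:

```python
import torch

torch.manual_seed(0)
b, n, d_in, d_out, num_heads = 2, 5, 4, 6, 3
head_dim = d_out // num_heads
x = torch.rand(b, n, d_in)
W_q, W_k, W_v = (torch.rand(d_in, d_out) for _ in range(3))  # hypothetical weights
mask = torch.triu(torch.ones(n, n), diagonal=1).bool()

def attend(q, k, v):
    # Causal scaled dot-product attention over the last two dimensions
    scores = (q @ k.transpose(-2, -1)).masked_fill(mask, -torch.inf)
    return torch.softmax(scores / k.shape[-1]**0.5, dim=-1) @ v

def split(t):
    # (b, n, d_out) -> (b, num_heads, n, head_dim)
    return t.view(b, n, num_heads, head_dim).transpose(1, 2)

# Fused: project once, split into heads, attend, merge heads back
fused = attend(split(x @ W_q), split(x @ W_k), split(x @ W_v))
fused = fused.transpose(1, 2).contiguous().view(b, n, d_out)

# Looped: each head uses only its own slice of columns
per_head = []
for h in range(num_heads):
    cols = slice(h * head_dim, (h + 1) * head_dim)
    per_head.append(attend(x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]))
looped = torch.cat(per_head, dim=-1)
```

Both paths give the same (b, n, d_out) tensor, so the efficient version computes the same thing as a wrapper whose heads' weights are column blocks of one larger matrix, while issuing one large matmul per projection instead of one per head.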
