# Coding attention mechanisms

We will implement 4 different attention mechanisms.
 - simplified self-attention
 - self-attention
 - causal attention
 - multi head attention

## Simplified self-attention

In [130]:
import torch

In [131]:
#Dummy input embeddings for 6 tokens, each with an embedding size of 3
inputs = torch.tensor(
    [[0.42, 0.15, 0.89],
    [0.78, 0.33, 0.21],
    [0.12, 0.44, 0.67],
    [0.56, 0.91, 0.73],
    [0.34, 0.29, 0.85],
    [0.63, 0.11, 0.49]
    ]
)

Attention scores

In [132]:
query = inputs[1] # Second word
attention_scores_2 = torch.empty(inputs.shape[0]) # 6 scores
for i, x_i in enumerate(inputs):
    attention_scores_2[i] = torch.dot(x_i, query) #Dot product
attention_scores_2 #This vector says how similar each word is to the second word
#Second item is the greatest in the list, because is respect to himself. 

tensor([0.5640, 0.7614, 0.3795, 0.8904, 0.5394, 0.6306])

In [133]:
attention_weights_2 = torch.softmax(attention_scores_2, dim=0) #Now sum to 1
attention_weights_2

tensor([0.1543, 0.1880, 0.1283, 0.2139, 0.1506, 0.1649])

In [134]:
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attention_weights_2[i] * x_i
context_vec_2

tensor([0.5017, 0.3981, 0.6277])

This is the context vector for only the item 2 (query).

Now, we would like to have all context vectors

In [135]:
attention_scores = inputs @ inputs.T  # Shape (6, 6)
print(attention_scores.shape)
attention_weights = torch.softmax(attention_scores, dim=-1)  # Shape (6,
print(attention_weights.shape)
context_vectors = attention_weights @ inputs
context_vectors

torch.Size([6, 6])
torch.Size([6, 6])


tensor([[0.4657, 0.3874, 0.6732],
        [0.5017, 0.3981, 0.6277],
        [0.4606, 0.4091, 0.6722],
        [0.4790, 0.4538, 0.6663],
        [0.4641, 0.3983, 0.6740],
        [0.4861, 0.3826, 0.6464]])

So, the steps are:
 - **Attention Scores**: Raw similarity between words (dot product)
 - **Attention Weights**: Normalized scores that sum to 1 (softmax)  
 - **Context Vectors**: Weighted combination of all words (final representation)

## Attention with trainable parameters

Now, let's compute attention weights with trainable parameters.
We will introduce three matrices:
 - W_q: query
 - W_k: key
 - W_v: value

In [136]:
x_2 = inputs[1]
d_in = inputs.shape[1] #Input of embedding size (3)
d_out = 2 #Output of embedding size (2)

In [137]:
W_q = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_k = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [138]:
query_2 = x_2 @ W_q
key_2 = x_2 @ W_k
value_2 = x_2 @ W_v

In [139]:
query_2, key_2, value_2

(tensor([1.0057, 0.2636]), tensor([0.8479, 1.1501]), tensor([0.9898, 0.8750]))

In [140]:
keys = inputs @ W_k
values = inputs @ W_v

Okay, let's get the attention scores

In [141]:
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([0.9975, 1.1559, 0.6061, 1.3122, 0.9206, 1.0459])

An then the attention weights. We will normalize by sqrt(len(keys)). This tries to improve the training performance by avoiding small gradients

In [142]:
attn_weights_2 = torch.softmax(attn_scores_2 / keys.shape[-1] ** 0.5, dim=-1)
attn_weights_2

tensor([0.1637, 0.1831, 0.1241, 0.2045, 0.1551, 0.1694])

Let's get the context vector for the second word.

In [143]:
context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([1.1408, 0.8940])

Query key and value are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search and retrieve information.

Query is analogous to a search query ina database. It represents the current item.

Key is like a database key used for indexing and searching. Each item in the input sequence has an associated key.

Value is the value in a key-value pair in a database. Represents the actual content of the input items.



## Compact self-attention

We will use a class to summary the concept of self-attention

In [144]:
from torch import nn
class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        context_vector = attn_weights @ values
        return context_vector

In [145]:
sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[1.0814, 0.6917],
        [1.0830, 0.6922],
        [1.0808, 0.6916],
        [1.0915, 0.7033],
        [1.0821, 0.6927],
        [1.0805, 0.6898]], grad_fn=<MmBackward0>)


Let's do the same but using linear layers. These use more optimized initial values.

In [146]:
class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)
    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim = -1)
        context_vector = attn_weights @ values
        return context_vector
        

In [147]:
sa_v2 = SelfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0234,  0.0849],
        [-0.0242,  0.0799],
        [-0.0235,  0.0848],
        [-0.0240,  0.0821],
        [-0.0235,  0.0848],
        [-0.0238,  0.0823]], grad_fn=<MmBackward0>)


## Hiding future words with causal attention

For many LLM tasks (as GPT), you will want the self-attention mechanism to consider only the tokens that appear prior the current position.

This is causal/masked attention.

In [148]:
queries = sa_v2.W_query(inputs)
key = sa_v2.W_key(inputs)
value = sa_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_scores #Squared matrix

tensor([[ 0.2368,  0.2338,  0.1602,  0.3064,  0.2241,  0.2229],
        [ 0.0069, -0.1251,  0.0575, -0.0078,  0.0246, -0.0760],
        [ 0.5268,  0.4681,  0.3771,  0.6750,  0.5057,  0.4633],
        [ 0.5142,  0.3672,  0.4041,  0.6475,  0.5060,  0.3962],
        [ 0.3489,  0.3237,  0.2443,  0.4488,  0.3331,  0.3154],
        [ 0.0286, -0.0277,  0.0417,  0.0299,  0.0347, -0.0081]],
       grad_fn=<MmBackward0>)

Let's replace the upper triangle with -inf.

In [149]:
context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[ 0.2368,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0069, -0.1251,    -inf,    -inf,    -inf,    -inf],
        [ 0.5268,  0.4681,  0.3771,    -inf,    -inf,    -inf],
        [ 0.5142,  0.3672,  0.4041,  0.6475,    -inf,    -inf],
        [ 0.3489,  0.3237,  0.2443,  0.4488,  0.3331,    -inf],
        [ 0.0286, -0.0277,  0.0417,  0.0299,  0.0347, -0.0081]],
       grad_fn=<MaskedFillBackward0>)

Now, if we use softmax, rows will sum 1. 

In [150]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim = 1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5233, 0.4767, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3498, 0.3356, 0.3147, 0.0000, 0.0000, 0.0000],
        [0.2548, 0.2296, 0.2357, 0.2799, 0.0000, 0.0000],
        [0.2011, 0.1975, 0.1867, 0.2158, 0.1988, 0.0000],
        [0.1681, 0.1615, 0.1696, 0.1682, 0.1688, 0.1638]],
       grad_fn=<SoftmaxBackward0>)

In [151]:
context_vec = attn_weights @ values
context_vec

tensor([[1.1219, 0.8538],
        [1.0589, 0.8639],
        [1.0103, 0.7918],
        [1.1774, 0.9110],
        [1.1607, 0.8931],
        [1.1172, 0.8701]], grad_fn=<MmBackward0>)

Making additional attention weights with dropout

In [152]:
from torch.nn import Dropout
dropout = Dropout(0.3)
dropout(attention_weights) #Drop some random cells (~30%).

tensor([[0.2772, 0.1809, 0.2098, 0.0000, 0.2641, 0.0000],
        [0.2204, 0.0000, 0.1833, 0.0000, 0.2151, 0.2356],
        [0.2491, 0.1785, 0.2356, 0.3179, 0.2555, 0.1920],
        [0.2164, 0.1898, 0.2029, 0.4159, 0.0000, 0.0000],
        [0.2666, 0.1781, 0.2172, 0.3042, 0.2611, 0.2014],
        [0.2523, 0.2313, 0.1935, 0.2769, 0.2388, 0.0000]])

In [153]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf)

        attn_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim = 1)
        attn_weights = self.dropout(attn_weights)
        
        context_vec = attn_weights @ values
        return context_vec


In [155]:
batch = torch.stack((inputs, inputs), dim = 0)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
context_vecs.shape

torch.Size([2, 6, 2])

## Multi head attention

In [156]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ])
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim = -1)

In [157]:
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vecs = mha(batch)
context_vecs.shape

torch.Size([2, 6, 4])