# Coding attention mechanisms

We will implement 4 different attention mechanisms.
 - simplified self-attention
 - self-attention
 - causal attention
 - multi head attention

## Simplified self-attention

In [65]:
import torch

In [66]:
#Dummy input embeddings for 6 tokens, each with an embedding size of 3
inputs = torch.tensor(
    [[0.42, 0.15, 0.89],
    [0.78, 0.33, 0.21],
    [0.12, 0.44, 0.67],
    [0.56, 0.91, 0.73],
    [0.34, 0.29, 0.85],
    [0.63, 0.11, 0.49]
    ]
)

Attention scores

In [67]:
query = inputs[1] # Second word
attention_scores_2 = torch.empty(inputs.shape[0]) # 6 scores
for i, x_i in enumerate(inputs):
    attention_scores_2[i] = torch.dot(x_i, query) #Dot product
attention_scores_2 #This vector says how similar each word is to the second word
#Second item is the greatest in the list, because is respect to himself. 

tensor([0.5640, 0.7614, 0.3795, 0.8904, 0.5394, 0.6306])

In [68]:
attention_weights_2 = torch.softmax(attention_scores_2, dim=0) #Now sum to 1
attention_weights_2

tensor([0.1543, 0.1880, 0.1283, 0.2139, 0.1506, 0.1649])

In [69]:
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attention_weights_2[i] * x_i
context_vec_2

tensor([0.5017, 0.3981, 0.6277])

This is the context vector for only the item 2 (query).

Now, we would like to have all context vectors

In [70]:
attention_scores = inputs @ inputs.T  # Shape (6, 6)
print(attention_scores.shape)
attention_weights = torch.softmax(attention_scores, dim=-1)  # Shape (6,
print(attention_weights.shape)
context_vectors = attention_weights @ inputs
context_vectors

torch.Size([6, 6])
torch.Size([6, 6])


tensor([[0.4657, 0.3874, 0.6732],
        [0.5017, 0.3981, 0.6277],
        [0.4606, 0.4091, 0.6722],
        [0.4790, 0.4538, 0.6663],
        [0.4641, 0.3983, 0.6740],
        [0.4861, 0.3826, 0.6464]])

So, the steps are:
 - **Attention Scores**: Raw similarity between words (dot product)
 - **Attention Weights**: Normalized scores that sum to 1 (softmax)  
 - **Context Vectors**: Weighted combination of all words (final representation)

## Attention with trainable parameters

Now, let's compute attention weights with trainable parameters.
We will introduce three matrices:
 - W_q: query
 - W_k: key
 - W_v: value

In [71]:
x_2 = inputs[1]
d_in = inputs.shape[1] #Input of embedding size (3)
d_out = 2 #Output of embedding size (2)

In [72]:
W_q = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_k = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [73]:
query_2 = x_2 @ W_q
key_2 = x_2 @ W_k
value_2 = x_2 @ W_v

In [74]:
query_2, key_2, value_2

(tensor([0.6035, 0.8222]), tensor([0.7056, 0.5192]), tensor([0.5968, 0.9680]))

In [75]:
keys = inputs @ W_k
values = inputs @ W_v

Okay, let's get the attention scores

In [76]:
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([0.8740, 0.8528, 0.8435, 1.5476, 0.9330, 0.7289])

An then the attention weights. We will normalize by sqrt(len(keys)). This tries to improve the training performance by avoiding small gradients

In [77]:
attn_weights_2 = torch.softmax(attn_scores_2 / keys.shape[-1] ** 0.5, dim=-1)
attn_weights_2

tensor([0.1534, 0.1511, 0.1501, 0.2470, 0.1599, 0.1384])

Let's get the context vector for the second word.

In [78]:
context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([0.8122, 1.0227])

Query key and value are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search and retrieve information.

Query is analogous to a search query ina database. It represents the current item.

Key is like a database key used for indexing and searching. Each item in the input sequence has an associated key.

Value is the value in a key-value pair in a database. Represents the actual content of the input items.



## Compact self-attention

We will use a class to summary the concept of self-attention

In [79]:
from torch import nn
class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        context_vector = attn_weights @ values
        return context_vector

In [80]:
sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.6221, 0.5223],
        [0.6230, 0.5217],
        [0.6220, 0.5222],
        [0.6279, 0.5256],
        [0.6225, 0.5226],
        [0.6216, 0.5214]], grad_fn=<MmBackward0>)


Let's do the same but using linear layers. These use more optimized initial values.

In [81]:
class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)
    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim = -1)
        context_vector = attn_weights @ values
        return context_vector
        

In [82]:
sa_v2 = SelfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.4356, -0.6395],
        [-0.4389, -0.6404],
        [-0.4352, -0.6389],
        [-0.4367, -0.6394],
        [-0.4354, -0.6393],
        [-0.4375, -0.6401]], grad_fn=<MmBackward0>)


## Hiding future words with causal attention

For many LLM tasks (as GPT), you will want the self-attention mechanism to consider only the tokens that appear prior the current position.

This is causal/masked attention.

In [83]:
queries = sa_v2.W_query(inputs)
key = sa_v2.W_key(inputs)
value = sa_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_scores #Squared matrix

tensor([[0.5268, 0.3769, 0.4977, 0.8049, 0.5535, 0.3767],
        [0.2446, 0.1040, 0.2256, 0.3075, 0.2525, 0.1425],
        [0.5180, 0.3810, 0.4902, 0.8012, 0.5450, 0.3752],
        [0.4012, 0.2659, 0.3774, 0.5933, 0.4203, 0.2772],
        [0.5255, 0.3799, 0.4968, 0.8066, 0.5525, 0.3776],
        [0.3706, 0.2235, 0.3469, 0.5275, 0.3868, 0.2460]],
       grad_fn=<MmBackward0>)

Let's replace the upper triangle with -inf.

In [84]:
context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.5268,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2446, 0.1040,   -inf,   -inf,   -inf,   -inf],
        [0.5180, 0.3810, 0.4902,   -inf,   -inf,   -inf],
        [0.4012, 0.2659, 0.3774, 0.5933,   -inf,   -inf],
        [0.5255, 0.3799, 0.4968, 0.8066, 0.5525,   -inf],
        [0.3706, 0.2235, 0.3469, 0.5275, 0.3868, 0.2460]],
       grad_fn=<MaskedFillBackward0>)

Now, if we use softmax, rows will sum 1. 

In [85]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim = 1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5248, 0.4752, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3462, 0.3143, 0.3395, 0.0000, 0.0000, 0.0000],
        [0.2477, 0.2251, 0.2435, 0.2837, 0.0000, 0.0000],
        [0.1953, 0.1762, 0.1913, 0.2382, 0.1990, 0.0000],
        [0.1687, 0.1520, 0.1659, 0.1884, 0.1706, 0.1544]],
       grad_fn=<SoftmaxBackward0>)

In [86]:
context_vec = attn_weights @ values
context_vec

tensor([[0.9036, 0.9912],
        [0.7578, 0.9802],
        [0.7242, 0.8928],
        [0.8062, 1.0259],
        [0.8215, 1.0179],
        [0.7979, 0.9951]], grad_fn=<MmBackward0>)

Making additional attention weights with dropout

In [87]:
from torch.nn import Dropout
dropout = Dropout(0.3)
dropout(attention_weights) #Drop some random cells (~30%).

tensor([[0.0000, 0.0000, 0.0000, 0.2857, 0.2641, 0.2108],
        [0.2204, 0.2686, 0.0000, 0.0000, 0.2151, 0.2356],
        [0.2491, 0.1785, 0.2356, 0.0000, 0.2555, 0.1920],
        [0.0000, 0.1898, 0.2029, 0.0000, 0.0000, 0.1753],
        [0.2666, 0.1781, 0.2172, 0.3042, 0.2611, 0.0000],
        [0.2523, 0.2313, 0.1935, 0.0000, 0.0000, 0.0000]])

In [88]:
class CausalAttention(nn.Module):
    # Implements causal (masked) self-attention for a single attention head.
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Linear layers to project input to queries, keys, and values
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Register a causal mask to prevent attending to future tokens
        # This is a buffer, which means it's part of the model but doesn't need gradients
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        ) #It saves memory and is faster.

    def forward(self, x):
        # x: (batch_size, num_tokens, d_in)
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)  # (batch_size, num_tokens, d_out)
        keys = self.W_key(x)       # (batch_size, num_tokens, d_out)
        values = self.W_value(x)   # (batch_size, num_tokens, d_out)

        # Compute attention scores
        attn_scores = queries @ keys.transpose(1, 2)  # (batch_size, num_tokens, num_tokens)
        # Apply causal mask (upper triangle set to -inf)
        attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf)
        #It used the registered buffer to mask the future tokens.
        # Softmax over the last dimension (tokens), then apply dropout
        attn_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim = 1)
        attn_weights = self.dropout(attn_weights)
        
        # Weighted sum of values
        context_vec = attn_weights @ values  # (batch_size, num_tokens, d_out)
        return context_vec


In [89]:
batch = torch.stack((inputs, inputs), dim = 0)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
context_vecs.shape

torch.Size([2, 6, 2])

## Multi head attention

In [90]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
            for _ in range(num_heads)
        ]) #We create a list of causal attention heads.
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim = -1) #We concatenate the outputs of all heads.

In [94]:
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vecs = mha(batch)
context_vecs.shape

torch.Size([2, 6, 4])

## Implementing multi-head attention with weight splits

Let's implement a more efficient multi-head attention. We can improve the last implementation by processing the heads in parallel. One way to achieve this is by computing the outputs for all attention heads simultaneously via matrix multiplication. 

In [None]:
# This class implements a more efficient version of multi-head self-attention.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
        context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Linear layer to project input to queries
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer to project input to keys
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer to project input to values
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Output projection after concatenating all heads
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Register a causal mask to prevent attending to future tokens
        #Again, this is a buffer, which means it's part of the model but doesn't need gradients.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        )
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # Project input to keys, queries, and values
        #The output tensor shape is (batch, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Reshape for multi-head: (batch, tokens, heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move heads dimension forward: (batch, heads, tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute attention scores: (batch, heads, tokens, tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Create boolean mask for causal attention
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        # Apply mask: set future positions to -inf
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Softmax over last dimension to get attention weights
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        # Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Concatenate all heads
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        # Final output projection
        context_vec = self.out_proj(context_vec)
        return context_vec

In [93]:
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vecs = mha(batch)
display(context_vecs)
print(context_vecs.shape)

tensor([[[-0.0988,  0.3279],
         [-0.2344,  0.4464],
         [-0.2199,  0.4497],
         [-0.2460,  0.4768],
         [-0.2205,  0.4535],
         [-0.2223,  0.4507]],

        [[-0.0988,  0.3279],
         [-0.2344,  0.4464],
         [-0.2199,  0.4497],
         [-0.2460,  0.4768],
         [-0.2205,  0.4535],
         [-0.2223,  0.4507]]], grad_fn=<ViewBackward0>)

torch.Size([2, 6, 2])
