# Week 3: Attention, Please!

Attention mechanisms are a cornerstone of modern natural language processing and form the heart of GPT (Generative Pre-trained Transformer) models.

We are building on [the paper that introduced the Transformer architecture, by Vaswani et al.](https://arxiv.org/abs/1706.03762)

#### What is Attention?
Attention in neural networks allows the model to focus on specific parts of the input when producing an output. In the context of language models like GPT, attention enables the model to consider relevant words or phrases when predicting the next word, regardless of their position in the sentence.

GPT models use a specific type of attention called "self-attention" within their transformer blocks. Here is where attention sits in the GPT architecture:

1. Input Embedding: Words are converted into vectors.
2. Positional Encoding: Position information is added to these vectors.
3. Transformer Blocks: Multiple layers, each containing:
    - Multi-Head Attention: Allows the model to attend to different aspects of the input simultaneously.
    - Feed-Forward Neural Network: Processes the attention output.
4. Output Layer: Produces probabilities for the next word.

In this notebook, we'll implement and visualize these attention mechanisms, providing you with a deep understanding of how GPT models process and generate text.

Let's begin our journey into the world of attention!

## 0. Query, Key, Value

In [None]:
import math
import base64
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import Image, display

In [None]:
from helpers.check_todo import check_implementation
from helpers.show_mermaid import mm

In [None]:
# Create the src directory if it doesn't exist
import os
os.makedirs('src', exist_ok=True)

In [None]:
# TODO: Define your example sentence
sentence = "Think of humanity on the path towards more unfathomable complexity"

Let's use a simplified representation of this sentence, where each word is a 3-dimensional vector whose elements are either 0 or 1. We could think of these 3 dimensions as representing very basic linguistic features. For example:
1. First dimension: could represent "action" or "concreteness".
2. Second dimension: could represent "relation" or "abstractness".
3. Third dimension: could represent "complexity" or "plurality".

Thus, the word "Think" can be represented as [1, 0, 0] since it is an action verb, concrete and singular, while "of" can be represented as [0, 1, 0] since it is a relational word, abstract and singular.

In [None]:
word_to_vec = {
    "Think": [1, 0, 0],
    "of": [0, 1, 0],
    "humanity": [1, 0, 1],
    "on": [0, 1, 1],
    "the": [1, 1, 1],
    "path": [1, 0, 0],
    "towards": [1, 1, 0],
    "more": [0, 0, 1],
    "unfathomable": [1, 0, 1],
    "complexity": [0, 1, 1]
}

In [None]:
for word, vector in word_to_vec.items():
    print(f"{word}: {vector}")

In [None]:
words = sentence.split()
query = torch.tensor([word_to_vec["path"]], dtype=torch.float)  # We're querying with the specific word
key = torch.tensor([word_to_vec[word] for word in words], dtype=torch.float)
value = torch.tensor([[i] for i in range(len(words))], dtype=torch.float)

In [None]:
print("Query shape:", query.shape)
print("Key shape:", key.shape)
print("Value shape:", value.shape)

The attention scores are computed by taking the dot product of the query vector with each key vector (each word in the sentence, in this case). The dot product is a measure of similarity. With these binary vectors and the query "path" ([1, 0, 0]), the score simply picks out the first dimension of each key:
- Words whose first dimension is 1 (for example "Think" [1, 0, 0] or "the" [1, 1, 1]) get a score of 1.
- Words whose first dimension is 0 (for example "of" [0, 1, 0]) get a score of 0.
- No score falls strictly between 0 and 1 here; graded scores only appear once the vectors hold real-valued entries.

In [None]:
scores = query @ key.T
print("\nAttention Scores:", scores)

The softmax function turns these scores into a probability distribution. This step:
1. Ensures all weights are between 0 and 1.
2. Makes the weights sum to 1.
3. Accentuates the differences between scores: because softmax exponentiates, the highest scores take a disproportionately large share of the total weight.
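
To make these three properties concrete, here is a minimal sketch on a hypothetical row of raw scores (the values are made up for illustration and are not taken from the sentence above).

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores for four keys (illustrative values only)
raw_scores = torch.tensor([[2.0, 1.0, 0.0, 1.0]])

weights = F.softmax(raw_scores, dim=-1)

print("Weights:", weights)
print("Sum of weights:", weights.sum().item())

# Softmax exponentiates, so a score gap of 2 becomes a weight ratio of e^2
ratio = (weights.max() / weights.min()).item()
print("Weight ratio (max/min):", ratio)
```

The ratio between two weights depends only on the difference between their scores, which is why adding a constant to every score leaves the weights unchanged.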

In [None]:
attention_weights = F.softmax(scores, dim=-1)
print("Attention Weights:", attention_weights)

In [None]:
# Visualize attention weights
# Think what words contribute most to the query word's representation
query_word = "path"

plt.figure(figsize=(12, 2))
plt.imshow(F.softmax(query @ key.T, dim=-1).detach(), cmap='viridis', aspect='auto')
plt.title(f"Attention Weights for Query: '{query_word}'")
plt.xticks(range(len(words)), words, rotation=45, ha='right')
plt.colorbar()
plt.tight_layout()
plt.show()

We multiply each value (which in this case is just the position of the word) by its corresponding attention weight and sum the results. In math terms this is matrix multiplication. The result represents a "weighted average" position in the sentence, where positions of words with higher attention weights contribute more to the final sum.

In [None]:
# Show the step-by-step calculation of the weighted sum
print("\nStep-by-step calculation for a weighted sum:")
for i, (weight, val) in enumerate(zip(attention_weights[0], value)):
    print(f"Word {i}: {weight.item():.2f} * {val.item()} = {weight.item()*val.item():.2f}")

In [None]:
output = attention_weights @ value
print("Weighted sum also referred to as a context vector:", output)

In [None]:
mm("""
graph TD
    A[Input Sequence] --> B[Query Q]
    A --> C[Key K]
    A --> D[Value V]
    B --> E[Compute Attention Scores]
    C --> E
    E --> F[Apply Softmax]
    F --> G[Weighted Sum]
    D --> G
    G --> H[Output]
""")

## 1. Dot-Product Attention

In [None]:
def dot_product_attention(query, key, value):
    # TODO: Implement dot-product attention    
    # Hint: Compute attention scores, apply softmax, and compute weighted sum
    
    pass

In [None]:
try:
    check_implementation(dot_product_attention)
except NotImplementedError as e:
    print(e)

In [None]:
output = dot_product_attention(query, key, value)
print("Dot-Product Attention Output:", output)

## 2. Scaled Dot-Product Attention

In scaled attention, we divide the scores by a scaling factor: the square root of the dimension of the key vectors. You might ask why scaling is needed. The reasoning runs in three steps:
1. As the dimension of the key vectors increases, the dot products tend to grow larger in magnitude.
2. Large dot products push the softmax function into regions where it has extremely small gradients.
3. This can lead to very peaked (near-one-hot) distributions after softmax, or to vanishing gradients.
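
The effect can be checked empirically. The sketch below assumes queries and keys with independent, unit-variance random entries (an idealization, not something the notebook's toy vectors satisfy) and uses hypothetical names such as `score_std`.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

def score_std(d_k, n_samples=10_000):
    """Empirical standard deviation of q . k for unit-variance random vectors."""
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    return (q * k).sum(dim=-1).std().item()

for d_k in (4, 64, 1024):
    raw = score_std(d_k)
    print(f"d_k={d_k:5d}  std of raw scores={raw:6.2f}  after / sqrt(d_k)={raw / d_k ** 0.5:5.2f}")

# Effect on softmax: the same random scores, unscaled vs. scaled
d_k = 1024
q = torch.randn(1, d_k)
keys = torch.randn(8, d_k)
scores = q @ keys.T
unscaled_max = torch.softmax(scores, dim=-1).max().item()
scaled_max = torch.softmax(scores / d_k ** 0.5, dim=-1).max().item()
print(f"Largest weight, unscaled: {unscaled_max:.3f}")
print(f"Largest weight, scaled:   {scaled_max:.3f}")
```

Under these assumptions the raw scores spread out roughly like the square root of `d_k`, and dividing by that factor brings them back to a spread of about 1, which keeps softmax away from its saturated regime.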

In [None]:
def scaled_dot_product_attention(query, key, value):
    # TODO: Implement scaled dot-product attention    
    # Hint: Similar to dot-product attention, but scale the scores
    
    pass

In [None]:
try:
    check_implementation(scaled_dot_product_attention)
except NotImplementedError as e:
    print(e)

In [None]:
output_scaled = scaled_dot_product_attention(query, key, value)
print("Scaled Dot-Product Attention Output:", output_scaled)

## 3. Self-Attention

In self-attention, instead of using the input directly as query, key, and value, we create projections of the input. These projections allow the model to transform the input into different representation spaces (and different dimensions). 

Vector projection uses the dot product because the dot product provides a measure of how much one vector aligns with another, which is exactly what projection aims to capture. 
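
One way to connect this to learned projections: each output coordinate of a linear layer is the dot product of the input with one row of the weight matrix. The sketch below uses hypothetical sizes (3 inputs, 2 outputs) and drops the bias to keep the comparison exact.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: a 3-dimensional input projected to 2 dimensions, no bias
proj = nn.Linear(3, 2, bias=False)
x_vec = torch.tensor([1.0, 0.0, 1.0])

projected = proj(x_vec)

# Each output coordinate is the dot product of the input with one weight row
by_hand = torch.stack([torch.dot(x_vec, row) for row in proj.weight])

print("nn.Linear output:  ", projected)
print("Row-wise dot prods:", by_hand)
```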

In [None]:
x = torch.tensor([word_to_vec[word] for word in words], dtype=torch.float)

Let's now review the concept of a linear transformation and `nn.Linear`. When we apply the linear layer, it performs the following operation for each sequence element:
`output = x @ weight.T + bias`

In [None]:
# Let's create a simple example
input_dim = x.shape[1]
out_dim = 3
seq_length = 2
batch_size = 1

# Create an nn.Linear layer
linear_layer = nn.Linear(input_dim, out_dim)

# Print the weight matrix and bias
print("Weight matrix shape:", linear_layer.weight.shape)
print("Weight matrix:")
print(linear_layer.weight.data)
print("\nBias shape:", linear_layer.bias.shape)
print("Bias:")
print(linear_layer.bias.data)

In [None]:
print("\nInput shape:", x.shape)
print("Input:")
print(x)

In [None]:
# Apply the linear transformation
output = linear_layer(x)

print("\nOutput shape:", output.shape)
print("Output:")
print(output)

In [None]:
# Verification with the dot product
manual_output = torch.matmul(x, linear_layer.weight.t()) + linear_layer.bias
print("\nManual calculation output:")
print(manual_output)
print("\nDoes manual calculation match nn.Linear output?", torch.allclose(output, manual_output))

The key difference between this and regular scaled dot-product attention is the use of learned projections (Q, K, V) instead of using the input directly. This self-attention mechanism allows each position in the sequence to attend to all positions, including itself.
The `nn.Linear` layers are crucial here:
1. They introduce learnable parameters, allowing the model to adapt how it projects the input for attention computation.
2. They enable the model to transform the input in ways that are most useful for the task at hand, which is learned during training.
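
As a shape-only illustration (it deliberately stops before softmax and the value projection, so it does not complete the exercise below), here is a sketch with hypothetical sizes and names (`to_q`, `to_k`, `pair_scores`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, in_dim, proj_dim = 4, 3, 5  # hypothetical sizes, chosen only for illustration
tokens = torch.randn(seq_len, in_dim)

# Separate learned projections for queries and keys
to_q = nn.Linear(in_dim, proj_dim)
to_k = nn.Linear(in_dim, proj_dim)

q = to_q(tokens)       # (seq_len, proj_dim)
k = to_k(tokens)       # (seq_len, proj_dim)
pair_scores = q @ k.T  # (seq_len, seq_len): one score per (query, key) pair

print("q:", q.shape, "k:", k.shape, "scores:", pair_scores.shape)
# With separate projections, score(i, j) and score(j, i) generally differ
print("Symmetric?", torch.allclose(pair_scores, pair_scores.T))
```

Because queries and keys go through different projections, the score matrix does not have to be symmetric: how much token i attends to token j can differ from the reverse.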

In [None]:
class SimpleSelfAttention(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(SimpleSelfAttention, self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim

        # Create linear projections for Q, K, V
        self.q_proj = nn.Linear(input_dim, out_dim)
        self.k_proj = nn.Linear(input_dim, out_dim)
        self.v_proj = nn.Linear(input_dim, out_dim)
        
        self.fc_out = nn.Linear(out_dim, input_dim) # -> (batch_size, seq_length, input_dim)

    def forward(self, x):
        #  Implement SimpleSelfAttention class forward()
        # Hint: Use self.q_proj, self.k_proj, and self.v_proj to create projections from the input
        # Apply the scaled dot-product attention mechanism using these projections
        # Use self.fc_out for the final projection

        pass

In [None]:
try:
    check_implementation(SimpleSelfAttention)
except NotImplementedError as e:
    print(e)
# even if you correctly implement forward function for this nn.Module you'll still get a warning
# do you know why?
# if not, go forward and check the next class implementation

In [None]:
self_attention = SimpleSelfAttention(input_dim, out_dim)
output, attention_weights = self_attention(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

In [None]:
# Visualize attention weights
plt.figure(figsize=(10, 8))
plt.imshow(attention_weights.detach().numpy(), cmap='viridis')  # Convert to numpy if needed
plt.title("Self-Attention Weights")
plt.xlabel("Key")
plt.ylabel("Query")
plt.xticks(range(len(words)), words, rotation=45, ha='right')
plt.yticks(range(len(words)), words)
plt.colorbar()
plt.tight_layout()
plt.show()

## 4. Causal Attention with Masking

With GPT models, the autoregressive property is important to maintain. It refers to a characteristic of certain models, where the prediction of the current step or value depends only on the information from the previous steps or values.

Causal attention, also known as masked attention, is a specialized form of self-attention.

It restricts the model to consider only the previous and current inputs in a sequence when processing any given token.

On the implementation side, `torch.triu` returns the upper triangular part of a matrix: it keeps the elements on and above the main diagonal and sets the rest to zero. It's often used to create masks for causal attention in transformer models.

In [None]:
matrix = torch.ones(5, 5)
print("Original matrix:")
print(matrix)

In [None]:
upper_triangular = torch.triu(matrix)
print("\nUpper triangular matrix:")
print(upper_triangular)

In [None]:
# Offset the diagonal
offset_triangular = torch.triu(matrix, diagonal=1)
print("\nUpper triangular matrix with offset 1:")
print(offset_triangular)

In [None]:
# Put -inf on the future positions (above the diagonal) so Softmax gives them zero weight
neg_inf = matrix.masked_fill(offset_triangular.bool(), float('-inf'))
neg_inf

Any `-inf` value in scores will result in a zero probability after applying the Softmax function, meaning those positions are ignored. This effectively ensures that each token can only attend to itself and the tokens before it in the sequence.
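
A minimal sketch of this effect, using random numbers as a stand-in for real query-key scores (the names `future` and `masked_scores` are illustrative):

```python
import torch

torch.manual_seed(0)

seq_len = 4  # hypothetical sequence length
scores = torch.randn(seq_len, seq_len)  # stand-in for query-key scores

# True above the main diagonal, i.e. where the key position is later than the query position
future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

masked_scores = scores.masked_fill(future, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(weights)
```

Each row still sums to 1, but all of the weight sits on or below the diagonal; the first token can only attend to itself, so its row is a single 1 followed by zeros.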

In [None]:
# Creating a boolean mask
mask = torch.triu(matrix, diagonal=1).bool()
print("\nBoolean mask for upper triangular with offset 1:")
print(mask)

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(CausalSelfAttention, self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim

        self.q_proj = nn.Linear(input_dim, out_dim)
        self.k_proj = nn.Linear(input_dim, out_dim)
        self.v_proj = nn.Linear(input_dim, out_dim)
        
        self.fc_out = nn.Linear(out_dim, input_dim)

    def forward(self, x):

        batch_size, seq_length, _ = x.size()

        # TODO: Implement CausalSelfAttention class forward pass
        # Hint: 1. re-use SimpleSelfAttention
        # 2. create a causal mask with torch.triu(..., diagonal=1) over a (seq_length, seq_length) matrix of ones
        # 3. set the masked (future) scores to -inf before Softmax

        pass

In [None]:
try:
    check_implementation(CausalSelfAttention)
except NotImplementedError as e:
    print(e)
# you will see here the similar warning as in SimpleSelfAttention
# where is the 'TODO' placeholder then?
# and you can move on once your forward pass is ready

In [None]:
x = x.unsqueeze(0)  # Add batch dimension
x.shape

In [None]:
causal_attention = CausalSelfAttention(input_dim, out_dim)
output, attention_weights = causal_attention(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

In [None]:
# Visualize attention weights
plt.figure(figsize=(10, 8))
plt.imshow(attention_weights[0].detach().numpy(), cmap='viridis')  # Convert to numpy if needed
plt.title("Causal Self-Attention Weights")
plt.xlabel("Key")
plt.ylabel("Query")
plt.xticks(range(len(words)), words, rotation=45, ha='right')
plt.yticks(range(len(words)), words)
plt.colorbar()
plt.tight_layout()
plt.show()

In [None]:
# Print transformed vectors for each word
print("\nTransformed word vectors:")
for i, word in enumerate(words):
    print(f"{word}: {output[0][i].detach().numpy()}")

#### Question to you

We have now seen how the autoregressive setting works with a triangular mask; this is the approach used in a decoder attention block. In an encoder block, all tokens communicate with each other. What would we need to change in the current decoder implementation to allow that kind of communication?

## 5. Multi-Head Attention

In [None]:
batch_size, seq_length, d_model = x.shape
num_heads = 3  # We'll use 3 heads since d_model is 3
head_dim = d_model // num_heads  # This will be 1 in this case

print(f"Input shape: {x.shape}")

In [None]:
# Linear projections
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

# Apply projections
Q = W_q(x)
K = W_k(x)
V = W_v(x)

print(f"\nShape after linear projection:")
print(f"Q shape: {Q.shape}")

#### Question to you

It won't be in our scope, but have you thought about "proper" weight initialization? If you'd like to explore it, check `xavier_uniform_`.
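
For the curious, here is a small sketch of what applying it could look like on a single projection layer. Zeroing the bias is a common choice assumed here for illustration, not something this notebook prescribes.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(3, 3)  # same shape as the W_q / W_k / W_v projections above

# Re-initialize in place: weights drawn from U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)  # assumption: zero bias, a common but optional choice

bound = math.sqrt(6 / (layer.in_features + layer.out_features))
print("Bound:", bound)
print("Max |weight|:", layer.weight.abs().max().item())
```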

In [None]:
# Split d_model into heads, then transpose to get [batch_size, num_heads, seq_len, head_dim]
Q = Q.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
K = K.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
V = V.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

print(f"\nShape after reshaping for multi-head:")
print(f"Q shape: {Q.shape}")

In [None]:
# Perform attention for each head
scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
attention_weights = torch.softmax(scores, dim=-1)
head_outputs = torch.matmul(attention_weights, V)

print(f"\nShape of each head's output:")
print(f"head_outputs shape: {head_outputs.shape}")

In [None]:
# Transpose and concatenate heads [batch_size, seq_len, d_model]
concat_heads = head_outputs.transpose(1, 2).contiguous().view(batch_size, seq_length, d_model)

print(f"\nShape after concatenating heads:")
print(f"concat_heads shape: {concat_heads.shape}")

In [None]:
# Project final output
W_o = nn.Linear(d_model, d_model)
output = W_o(concat_heads)

print(f"\nFinal output shape:")
print(f"output shape: {output.shape}")

In [None]:
# Print attention weights for the first head
print(f"\nAttention weights for the first head:")
print(attention_weights[0, 0])

In [None]:
%%writefile src/multiattention.py
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model # the output dimensionality
        self.num_heads = num_heads # the number of attention heads
        self.head_dim = d_model // num_heads # dimensionality of each head output

        # Making sure that each head will process a portion of the output
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"

        # Initialize weights for Q, K, V projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        # Scaling factor
        #self.scale = self.head_dim ** -0.5
        self.scale = 1 / math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):

        batch_size = query.size(0)

        # TODO: Implement a multi-head attention class forward pass
        
        # Hint: 1. Apply the linear projections self.W_q, self.W_k and self.W_v
        # 2. Use previously created CausalSelfAttention 
        # 3. Update attention scores using self.scale instead of math.sqrt(self.out_dim)
        # 4. Concatenate heads with final output projection

        pass

In [None]:
from src.multiattention import MultiHeadAttention

try:
    check_implementation(MultiHeadAttention)
except NotImplementedError as e:
    print(e)
# you can move on once your forward pass is ready

In [None]:
mha = MultiHeadAttention(d_model, num_heads)
output, attention_weights = mha(x, x, x)

In [None]:
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

## 6. Attention Mechanism Visualization

In [None]:
def visualize_attention(attention_weights, words=None, title="Attention Weights"):
    """
    Visualize attention weights as a heatmap.
    
    Parameters:
    - attention_weights: torch.Tensor or numpy array of shape (seq_length, seq_length)
    - words: list of words corresponding to the sequence (optional)
    - title: title for the plot
    """
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()
    
    fig, ax = plt.subplots(figsize=(10, 8))   
    sns.heatmap(attention_weights, 
                annot=True, 
                cmap='viridis', 
                ax=ax, 
                cbar=True,
                fmt='.2f')
    
    ax.set_title(title)
    if words:
        ax.set_xticks(range(len(words)))
        ax.set_yticks(range(len(words)))
        ax.set_xticklabels(words, rotation=45, ha='right')
        ax.set_yticklabels(words, rotation=0)
        ax.set_xlabel('Key')
        ax.set_ylabel('Query')
    else:
        ax.set_xlabel('Sequence Position (Key)')
        ax.set_ylabel('Sequence Position (Query)')
    
    # Adjust layout and display
    plt.tight_layout()
    plt.show()

In [None]:
# TODO: Experiment with the input
# This part isn't checked, but it's highly recommended

sentence = "Think of humanity on the path towards more unfathomable complexity"
words = sentence.split()

# Create input tensor
word_to_vec = {
    "Think": [1, 0, 0], 
    "of": [0, 1, 0], 
    "humanity": [1, 0, 1],
    "on": [0, 1, 1], 
    "the": [1, 1, 1], 
    "path": [1, 0, 0],
    "towards": [1, 1, 0], 
    "more": [0, 0, 1], 
    "unfathomable": [1, 0, 1],
    "complexity": [0, 1, 1]
}

x = torch.tensor([word_to_vec[word] for word in words], dtype=torch.float).unsqueeze(0)

d_model = 3  # Dimension of our word vectors
num_heads = 3  # We'll use 3 heads for simplicity, but you can increase this
mha = MultiHeadAttention(d_model, num_heads)

# Get attention weights
_, attention_weights = mha(x, x, x)

# Visualize
for i in range(num_heads):
    visualize_attention(attention_weights[0, i], words, f"Attention Weights - Head {i+1}")

#### Congratulations! You've implemented various attention mechanisms. These will be crucial components in building your LLM in the coming weeks.