# Understanding Attention Mechanisms in Deep Learning

This notebook explores attention mechanisms in deep learning, focusing on self-attention, scaled dot-product attention, multi-head attention, and masking. Attention is the core building block of the Transformer architecture, which underlies models such as BERT and GPT.

We will cover:
- **Query, Key, and Value (QKV) Representation**
- **Attention Formula and Intuition**
- **Scaling and Scaled Dot-Product Attention**
- **Multi-Head Attention**
- **Masking Techniques in Attention**

## 1. Query, Key, and Value (QKV)

The attention mechanism works by computing a weighted sum of values based on the similarity between queries and keys.
- **Query (Q):** The input vector for which we want to compute attention.
- **Key (K):** The reference vectors against which the query is compared.
- **Value (V):** The actual content vectors that are aggregated based on attention scores.

Each token in the input sequence is transformed into Q, K, and V using learnable weight matrices.
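
As a minimal sketch of what "learnable" means here, the three projections can be written as bias-free `nn.Linear` layers, whose weights are registered as parameters that an optimizer can update. The names `q_proj`, `k_proj`, `v_proj`, `x`, and the sizes are illustrative assumptions, not part of this notebook's own cells.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embedding_dim = 4  # assumed size, matching the small example in this notebook

# Hypothetical projection layers; bias=False keeps them pure matrix multiplications
q_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
k_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
v_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)

x = torch.randn(1, 3, embedding_dim)  # (batch_size, seq_length, embedding_dim)

# Each token (row of x) gets its own query, key, and value vector
Q_learned, K_learned, V_learned = q_proj(x), k_proj(x), v_proj(x)

print(Q_learned.shape)               # torch.Size([1, 3, 4])
print(q_proj.weight.requires_grad)   # True
```

The cell that follows uses plain random tensors instead, which is enough to trace the arithmetic but does not involve any training.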

In [None]:
import torch
import torch.nn.functional as F

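# Fix the seed so the printed values are reproducible
torch.manual_seed(0)

# Define input tensor (batch_size=1, seq_length=3, embedding_dim=4)
input_tensor = torch.randn(1, 3, 4)

# Define weight matrices for Q, K, V (random here; learned during training in a real model)
W_q = torch.randn(4, 4)
W_k = torch.randn(4, 4)
W_v = torch.randn(4, 4)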

# Compute Q, K, V
Q = input_tensor @ W_q
K = input_tensor @ W_k
V = input_tensor @ W_v

print("Query (Q):", Q)
print("Key (K):", K)
print("Value (V):", V)

## 2. Attention Formula

The core of the attention mechanism is:
\[
   \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
\]

Where:
- \( QK^T \) computes similarity scores between queries and keys.
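- Dividing by \( \sqrt{d_k} \), where \( d_k \) is the dimension of the key vectors, keeps the scores from growing with the dimension. Large scores push the softmax into saturated regions where gradients become very small.
- Softmax normalizes each query's scores over the keys, producing weights that are non-negative and sum to 1.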
- These probabilities weight the values (V).
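
The steps above can be traced on a toy example small enough to check by hand. This is only an illustrative sketch: the tensors `q_toy`, `k_toy`, and `v_toy` hold made-up values chosen for easy arithmetic.

```python
import torch
import torch.nn.functional as F

# One query, two keys, two values; d_k = 2 (toy values, not from a real model)
q_toy = torch.tensor([[1.0, 0.0]])
k_toy = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
v_toy = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

d_k = q_toy.shape[-1]
scores_toy = (q_toy @ k_toy.T) / d_k ** 0.5   # [[1/sqrt(2), 0]]
weights_toy = F.softmax(scores_toy, dim=-1)   # more weight on the first key
output_toy = weights_toy @ v_toy              # weighted average of the value rows

print(weights_toy)  # approximately [[0.67, 0.33]]
print(output_toy)   # approximately [[1.66, 2.66]]
```

The query matches the first key more closely, so the output lies nearer to the first value row `[1, 2]` than to the second row `[3, 4]`.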

In [None]:
# Compute scaled similarity scores: shape (batch, seq_len, seq_len)
dk = Q.shape[-1]  # query/key dimension d_k
attention_scores = (Q @ K.transpose(-2, -1)) / dk ** 0.5
# Normalize each query's scores over the keys (last dimension)
attention_weights = F.softmax(attention_scores, dim=-1)

# Compute attention output
attention_output = attention_weights @ V

print("Attention Weights:", attention_weights)
print("Attention Output:", attention_output)

## 3. Scaled Dot-Product Attention
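Scaled dot-product attention is the operation defined in Section 2. The division by \( \sqrt{d_k} \) matters because, if the components of a query and a key are independent with zero mean and unit variance, their dot product has variance \( d_k \). Dividing by \( \sqrt{d_k} \) brings the scores back to roughly unit variance, which keeps the softmax away from saturation and stabilizes training.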
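
The effect of the scaling can be checked empirically. The sketch below assumes unit-variance random vectors as stand-ins for real queries and keys, which is a simplification. All names in it are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

for dim in (4, 64, 1024):
    # Assumption: unit-variance random vectors stand in for real queries and keys
    q_rand = torch.randn(2000, dim)
    k_rand = torch.randn(2000, dim)
    raw = (q_rand * k_rand).sum(dim=-1)   # 2000 independent dot products
    scaled = raw / dim ** 0.5
    print(f"d_k={dim:5d}  std(raw)={raw.std().item():6.2f}  std(scaled)={scaled.std().item():5.2f}")

# Effect on softmax for one query against 8 keys at d_k = 1024
query = torch.randn(1, 1024)
keys = torch.randn(8, 1024)
logits = query @ keys.T
max_unscaled = F.softmax(logits, dim=-1).max().item()
max_scaled = F.softmax(logits / 1024 ** 0.5, dim=-1).max().item()
print("max weight, unscaled:", max_unscaled)
print("max weight, scaled:  ", max_scaled)
```

The standard deviation of the raw dot products should grow roughly like \( \sqrt{d_k} \), while the scaled version stays near 1. Without scaling, the softmax at large \( d_k \) tends to put almost all of its weight on a single key.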

## 4. Multi-Head Attention
Instead of using a single attention function, **multi-head attention** employs multiple attention heads to learn different aspects of the input representation.

Each head has independent weight matrices for Q, K, and V. The outputs from all heads are concatenated and linearly transformed.
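
One common way to realize the per-head matrices is to apply a single full-width projection and then split its output into heads. The sketch below, with assumed toy sizes and illustrative names such as `W_q_full`, suggests that this gives the same result as using a separate, narrower matrix for each head.

```python
import torch

torch.manual_seed(0)

embed_dim, n_heads = 4, 2            # assumed toy sizes
per_head = embed_dim // n_heads

x = torch.randn(1, 3, embed_dim)     # (batch, seq_len, embed_dim)
W_q_full = torch.randn(embed_dim, embed_dim)

# Option A: one big projection, then split the last dimension into heads
Q_split = (x @ W_q_full).view(1, 3, n_heads, per_head).transpose(1, 2)

# Option B: a separate (embed_dim x per_head) matrix per head, taken as column blocks
W_q_per_head = [W_q_full[:, h * per_head:(h + 1) * per_head] for h in range(n_heads)]
Q_separate = torch.stack([x @ W for W in W_q_per_head], dim=1)

print(Q_split.shape)                                       # torch.Size([1, 2, 3, 2])
print(torch.allclose(Q_split, Q_separate, atol=1e-6))      # True
```

Splitting a shared projection is therefore a matter of bookkeeping: each head still sees its own block of weights.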

In [None]:
# Example of Multi-Head Attention with 2 heads
num_heads = 2
batch_size, seq_len, dk = Q.shape
head_dim = dk // num_heads  # dk must be divisible by num_heads

# Split Q, K, V into heads: (batch, num_heads, seq_len, head_dim)
Q_heads = Q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
K_heads = K.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
V_heads = V.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Compute attention per head, scaling by the per-head dimension
attention_scores_heads = (Q_heads @ K_heads.transpose(-2, -1)) / head_dim ** 0.5
attention_weights_heads = F.softmax(attention_scores_heads, dim=-1)
attention_output_heads = attention_weights_heads @ V_heads

# Concatenate heads back to (batch, seq_len, dk)
concat_heads = attention_output_heads.transpose(1, 2).reshape(batch_size, seq_len, dk)

# Final linear transformation of the concatenated heads (random here; learned in practice)
W_o = torch.randn(dk, dk)
multi_head_output = concat_heads @ W_o

print("Multi-Head Attention Output:", multi_head_output)

## 5. Masking in Attention

There are two main types of masks used in attention mechanisms:
- **Padding Mask:** Ensures that padding tokens do not influence the attention scores.
- **Lookahead Mask:** Used in autoregressive models (e.g., GPT) to prevent attention to future tokens.
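
A padding mask can be sketched as follows. The example assumes that token id 0 marks padding and uses random stand-in scores. The token ids and every variable name are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumption: token id 0 marks padding; the ids themselves are made up
token_ids = torch.tensor([[5, 7, 0]])          # (batch, seq_len), last position is padding
scores_pad = torch.randn(1, 3, 3)              # stand-in scores (batch, query, key)

# True where the key position holds a real token; broadcasts over the query dimension
key_is_real = (token_ids != 0).unsqueeze(1)    # (batch, 1, seq_len)

# Padded key positions get -inf, so softmax assigns them zero weight
masked_scores_pad = scores_pad.masked_fill(~key_is_real, float("-inf"))
weights_pad = F.softmax(masked_scores_pad, dim=-1)

print(weights_pad[0, :, 2])  # tensor([0., 0., 0.]): no query attends to the padded key
```

Note that this masks padded keys only. The row for the padded query position still produces an output, which is typically ignored later, for example when computing the loss.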


In [None]:
# Example of a lookahead (causal) mask: 1 = may attend, 0 = blocked
mask = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])  # Lower-triangular: token i attends to tokens 0..i
masked_attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
masked_attention_weights = F.softmax(masked_attention_scores, dim=-1)

print("Masked Attention Weights:", masked_attention_weights)