# Self-Attention
Attention in LLMs is a mechanism that lets the model focus on different parts of the input sequence when processing each token. In a transformer, each token's output is a weighted average of the value vectors of all tokens in the sequence, with weights derived from how well that token's query matches every token's key, so the model can "attend" to the relevant parts.
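
Written as a formula, scaled dot-product attention is usually stated as follows. Here $Q$, $K$, and $V$ are the query, key, and value matrices and $d_k$ is the key dimension; these symbols are introduced here for illustration (the code cell that follows uses `d` for the dimension).

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

The softmax is applied row by row, so each token gets its own distribution over the tokens it attends to.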

In [9]:
import torch
import torch.nn.functional as F

d = 4 # dimension

x = torch.rand(10, d)  # input sequence: 10 tokens, shape (tokens, embedding_dim)
print('Sample tokens:\n', x)

# Linear projections for queries, keys, values
# In a trained model these weight matrices are learned; here they are random placeholders
# that transform the input embeddings into Q, K, V representations
W_q = torch.rand(d, d)  # Query projection - "what am I looking for?"
W_k = torch.rand(d, d)  # Key projection - "what do I contain?"
W_v = torch.rand(d, d)  # Value projection - "what information do I actually provide?"

Q = x @ W_q   # Queries: each token asks "what should I attend to?"
K = x @ W_k   # Keys: each token says "here's what I represent"
V = x @ W_v   # Values: each token says "here's the info I contribute"

# Compute attention scores: how much should each token attend to every other token?
scores = Q @ K.T / (d ** 0.5)  # (10, 10), scaled dot-product attention

# Apply softmax to turn the scores into attention weights.
# Each row sums to 1 and is the attention distribution for one token.
weights = F.softmax(scores, dim=-1)  # (10, 10)

# Weighted sum of values: the final attended representation.
# For each token, mix all value vectors using that token's attention weights.
# This is the core of self-attention: every token can look at and incorporate
# information from every token in the sequence (including itself).
attended = weights @ V  # (10, d)

print("\nAttention output:\n", attended)

Sample tokens:
 tensor([[0.2429, 0.1176, 0.3660, 0.8112],
        [0.3871, 0.6056, 0.0759, 0.2050],
        [0.4745, 0.5922, 0.1595, 0.4395],
        [0.9812, 0.0948, 0.9929, 0.4543],
        [0.3246, 0.2551, 0.8802, 0.0150],
        [0.8081, 0.8946, 0.3738, 0.1147],
        [0.5037, 0.5514, 0.4478, 0.7581],
        [0.4618, 0.4430, 0.6137, 0.6494],
        [0.6258, 0.6795, 0.7773, 0.3714],
        [0.1914, 0.7553, 0.4314, 0.8364]])

Attention output:
 tensor([[1.2994, 0.7965, 0.8056, 0.8360],
        [1.2843, 0.7925, 0.8000, 0.8249],
        [1.2966, 0.8001, 0.8077, 0.8337],
        [1.3276, 0.8143, 0.8229, 0.8607],
        [1.3016, 0.7962, 0.8054, 0.8418],
        [1.3099, 0.8110, 0.8177, 0.8451],
        [1.3166, 0.8113, 0.8191, 0.8491],
        [1.3166, 0.8094, 0.8177, 0.8501],
        [1.3235, 0.8153, 0.8231, 0.8565],
        [1.3165, 0.8114, 0.8192, 0.8485]])
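
As a sanity check on the cell above, the same computation can be wrapped in a small function and its basic properties verified. This is only a minimal sketch: the helper name `self_attention` and the fixed seed are assumptions introduced here for illustration and are not part of the original cell.

```python
import torch
import torch.nn.functional as F

def self_attention(x, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (hypothetical helper)."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / (K.shape[-1] ** 0.5)   # (tokens, tokens)
    weights = F.softmax(scores, dim=-1)       # each row is a distribution
    return weights @ V, weights

torch.manual_seed(0)  # assumption: fixed seed so reruns give the same numbers
d = 4
x = torch.rand(10, d)
W_q, W_k, W_v = (torch.rand(d, d) for _ in range(3))

attended, weights = self_attention(x, W_q, W_k, W_v)
print(weights.sum(dim=-1))  # every row of the weight matrix sums to 1
print(attended.shape)       # torch.Size([10, 4])
```

The rows of the attention output printed above are nearly identical. A plausible reason is that `torch.rand` draws only non-negative values, so all queries point in a similar direction and every token ends up with a similar attention distribution. Sampling with `torch.randn` instead would likely give more varied rows.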


# Cross-Attention
Cross-attention is used in encoder-decoder models such as the original Transformer, where the decoder attends to a different sequence (e.g., the encoder outputs) rather than to its own. The queries come from the decoder's tokens, while the keys and values come from the other sequence, so the decoder can focus on the relevant parts of the input while generating each output token.
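
A minimal sketch of single-head cross-attention, in the same style as the self-attention cell: the only change is that queries are computed from one sequence while keys and values are computed from another. The tensor names (`decoder_x`, `encoder_out`) and sizes are illustrative assumptions, and the random matrices stand in for learned weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # assumption: fixed seed for reproducibility
d = 4

encoder_out = torch.rand(6, d)  # 6 source tokens (hypothetical encoder outputs)
decoder_x = torch.rand(3, d)    # 3 target tokens (hypothetical decoder states)

W_q = torch.rand(d, d)  # random stand-ins for learned projections
W_k = torch.rand(d, d)
W_v = torch.rand(d, d)

Q = decoder_x @ W_q     # (3, d): queries come from the decoder side
K = encoder_out @ W_k   # (6, d): keys come from the encoder side
V = encoder_out @ W_v   # (6, d): values come from the encoder side

scores = Q @ K.T / (d ** 0.5)        # (3, 6): one row per target token
weights = F.softmax(scores, dim=-1)  # each row sums to 1 over the 6 source tokens
attended = weights @ V               # (3, d)

print(weights.shape, attended.shape)  # torch.Size([3, 6]) torch.Size([3, 4])
```

Note that the weight matrix is no longer square: it has one row per target token and one column per source token.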

# Multi-Head Attention
Multi-head attention extends the attention mechanism used in transformers. Instead of computing attention once, it computes it several times in parallel, and each parallel computation is called a “head.” Each head has its own query, key, and value projections, so it gets a different “view” of the input and can learn to focus on different parts or aspects of it. The results are then combined, which lets the model capture several kinds of relationships in language at once.



- The input is projected into several smaller subspaces, one per head; each typically has dimension `embed_dim / num_heads`.
- Each head performs its own attention calculation independently, so each head can capture different relationships or features in the data.
- The outputs from all heads are concatenated (joined together) and then transformed one more time with a linear layer.
- Because the heads run side by side, the model can learn different kinds of relationships at the same time, which improves its understanding of the context.
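
The steps listed above can be sketched from scratch as follows. This is a simplified illustration, not how PyTorch implements it internally: the helper name `multi_head_self_attention` is hypothetical, the random matrices stand in for learned weights, and batching, masking, biases, and dropout are all omitted.

```python
import torch
import torch.nn.functional as F

def multi_head_self_attention(x, W_q, W_k, W_v, W_o, num_heads):
    """x: (seq_len, embed_dim). Hypothetical helper for illustration only."""
    seq_len, embed_dim = x.shape
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads

    def split_heads(t):
        # (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
        return t.view(seq_len, num_heads, head_dim).transpose(0, 1)

    # Step 1: project the input and split it into one smaller space per head
    Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)

    # Step 2: each head runs scaled dot-product attention independently
    scores = Q @ K.transpose(-2, -1) / (head_dim ** 0.5)  # (num_heads, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)
    heads = weights @ V                                   # (num_heads, seq_len, head_dim)

    # Step 3: concatenate the heads and apply the final linear layer
    concat = heads.transpose(0, 1).reshape(seq_len, embed_dim)
    return concat @ W_o, weights

torch.manual_seed(0)  # assumption: fixed seed for reproducibility
embed_dim, num_heads = 8, 2
x = torch.randn(5, embed_dim)
W_q, W_k, W_v, W_o = (torch.randn(embed_dim, embed_dim) for _ in range(4))

out, weights = multi_head_self_attention(x, W_q, W_k, W_v, W_o, num_heads)
print(out.shape, weights.shape)  # torch.Size([5, 8]) torch.Size([2, 5, 5])
```

The output keeps the input shape, while the weights carry one `(seq_len, seq_len)` attention matrix per head.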




In [3]:
import torch
import torch.nn as nn

# Thin wrapper around PyTorch's nn.MultiheadAttention that prints the tensor shapes
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Initialize PyTorch's built-in MultiheadAttention layer
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
    
    def forward(self, x):
        # x shape: (sequence_len, batch, embedding_dim)
        print(f"Input tensor shape: {x.shape} (sequence_len, batch, embedding_dim)")
        # Self-attention: query, key, value are all x.
        # By default the returned attention weights are averaged over the heads.
        attn_output, attn_weights = self.attention(x, x, x)
        print(f"Attention output shape: {attn_output.shape} (sequence_len, batch, embedding_dim)")
        print(f"Attention weights shape: {attn_weights.shape} (batch, sequence_len, sequence_len)")
        return attn_output

# Example usage
x = torch.rand(10, 32, 512)  # 10 tokens, batch size 32, embedding dim 512
print("Random input tensor x created with shape:", x.shape)
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
print("MultiHeadAttention module initialized with 8 heads.")
output = mha(x)
print("Final output tensor shape:", output.shape)

Random input tensor x created with shape: torch.Size([10, 32, 512])
MultiHeadAttention module initialized with 8 heads.
Input tensor shape: torch.Size([10, 32, 512]) (sequence_len, batch, embedding_dim)
Attention output shape: torch.Size([10, 32, 512]) (sequence_len, batch, embedding_dim)
Attention weights shape: torch.Size([32, 10, 10]) (batch, sequence_len, sequence_len)
Final output tensor shape: torch.Size([10, 32, 512])
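
The attention weights printed above have shape `(32, 10, 10)` even though there are 8 heads, because `nn.MultiheadAttention` averages the weights over heads by default. The sketch below shows how the per-head weights might be inspected instead. It assumes a PyTorch version whose `forward` accepts the `average_attn_weights` argument (1.11 or later, to the best of my knowledge) and calls the built-in layer directly rather than through the wrapper class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed for reproducibility
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
x = torch.rand(10, 32, 512)  # (sequence_len, batch, embedding_dim)

# Default behaviour: weights are averaged over the 8 heads
_, avg_weights = mha(x, x, x)

# Keep one attention matrix per head instead
_, head_weights = mha(x, x, x, average_attn_weights=False)

print(avg_weights.shape)   # torch.Size([32, 10, 10])
print(head_weights.shape)  # torch.Size([32, 8, 10, 10])
```

Averaging `head_weights` over the head dimension (`dim=1`) should recover `avg_weights`.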
