In [1]:
# llama attention mechanism

In [11]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

# Pick the compute device once, then report GPU details when CUDA was chosen.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # This branch is only reached when CUDA is available, so the flag is True.
    print("CUDA available:", True)
    print("Number of GPUs:", torch.cuda.device_count())

CUDA available: True
Number of GPUs: 1


In [4]:
# --- Configuration --- #

# Dimension of the embedding vector for each token.
# Example: every token (like "dog") is represented by a vector of 128 numbers.
hidden_size = 128

# How many attention heads run in parallel.
# hidden_size (128) is split across these heads, so each head sees only
# part of the vector: 128 / 16 = 8 dimensions per head.
num_attention_heads = 16

# Grouped-Query Attention (GQA): instead of 16 separate Key/Value heads we
# create only 4 K/V heads (each of size 8) and let several Q heads share them.
# This saves memory and computation while keeping good performance.
num_key_value_heads = 4

# Size of each head's Q, K, V vector.
# With hidden_size=128 and 16 heads, each head works in 8 dimensions.
head_dim = hidden_size // num_attention_heads

# Maximum sequence length (number of tokens) the model can process at once.
# Longer inputs must be truncated or split.
max_position_embeddings = 256

# Base frequency for Rotary Position Embeddings (RoPE).
# Larger theta = slower change in frequency = smoother positional encoding.
rope_theta = 10000.0

# Tiny constant added inside RMS-style normalisation to avoid division by zero.
rms_norm_eps = 1e-5

# Whether the Q/K/V/O linear layers get a bias term. Usually False.
attention_bias = False

# Dropout probability applied to attention weights.
# Often 0.0 during inference (disabled).
attention_dropout = 0.0

# Whether to normalise Q and K vectors before computing attention scores.
# This keeps dot products stable and avoids extreme attention weights.
use_qk_norm = True

# --- Sanity checks ---
# Fail here, with a readable message, rather than with an opaque shape error
# several cells later when the configuration is edited inconsistently.
assert hidden_size % num_attention_heads == 0, (
    "hidden_size must be divisible by num_attention_heads"
)
assert num_attention_heads % num_key_value_heads == 0, (
    "num_attention_heads must be a multiple of num_key_value_heads (GQA sharing)"
)
assert head_dim % 2 == 0, "head_dim must be even: RoPE rotates dimensions in pairs"

In [5]:
# --- Sample input setup ---

# Seed PyTorch's global random generator so the random tensors drawn below
# are the same on every "Restart Kernel -> Run All".
torch.manual_seed(0)

batch_size = 2
# Number of independent sequences (context windows) processed in parallel.
# Example: 2 separate sentences.

sequence_length = 10
# Number of tokens in each sequence (the length of the context window).

hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
# Random embeddings for each token in each sequence.
# Shape = (batch_size, sequence_length, hidden_size)
#        = (2, 10, 128)
# Meaning:
# - 2 sequences
# - Each sequence has 10 tokens
# - Each token is represented by a 128-dimensional vector

# --- Position IDs creation ---

position_ids = torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1)
# Broken into steps, the line above does:
#
# torch.arange(0, sequence_length)
# → Shape: (10,)
# → [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#
# .unsqueeze(0)
# Add a new dimension at the front
# → Shape: (1, 10)
# → [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
#
# .repeat(batch_size, 1)
# Repeat the row for each sequence in the batch
# → Shape: (2, 10)
# → [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
#    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]

# Intuition:
# - Each token in each sequence needs a position ID.
# - Both sequences start at position 0, because they are independent windows.

# Create a causal attention mask
# Goal: make sure each token can only see itself and tokens before it (no looking into the future)

attention_mask = torch.triu(torch.ones(sequence_length, sequence_length) * -torch.inf, diagonal=1)
# Step 1: Make a square matrix (seq x seq).
# -∞ above the diagonal = future tokens (blocked)
# 0 on and below diagonal = current/past tokens (allowed)

attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq, seq)
# Step 2: Add two extra dimensions so the mask matches attention shapes.
# Now we have 4D: [1, 1, seq, seq]

attention_mask = attention_mask.expand(batch_size, 1, -1, -1)  # Shape: (batch, 1, seq, seq)
# Step 3: Expand the mask across the batch dimension.
# The "1" in the heads dimension means the same mask is shared across all attention heads.

print("Configuration:")
print(f"  hidden_size: {hidden_size}")
print(f"  num_attention_heads: {num_attention_heads}")
print(f"  num_key_value_heads: {num_key_value_heads}")
print(f"  head_dim: {head_dim}")

print("\nSample Input Shapes:")
print(f"  hidden_states: {hidden_states.shape}")
print(f"  position_ids: {position_ids.shape}")
print(f"  attention_mask: {attention_mask.shape}")

Configuration:
  hidden_size: 128
  num_attention_heads: 16
  num_key_value_heads: 4
  head_dim: 8

Sample Input Shapes:
  hidden_states: torch.Size([2, 10, 128])
  position_ids: torch.Size([2, 10])
  attention_mask: torch.Size([2, 1, 10, 10])


In [6]:
# ## Q, K, V Projections
#
# Attention starts by mapping the input hidden states into three spaces
# with linear layers: Query (Q), Key (K), and Value (V).
#
# - Q: "what am I looking for?"        (the current token's query)
# - K: "what can I offer?"             (the key of each token in the sequence)
# - V: "what information do I carry?"  (the value of each token)
#
# Llama, like many modern transformers, uses Grouped-Query Attention (GQA):
# - There are more Q heads (16 here) than K/V heads (4 here).
# - Several Q heads share the same K/V head.
# - This reduces memory/computation without losing much performance.


# --- Linear projection layers ---
# nn.Linear owns a weight matrix W and, optionally, a bias vector b.
# PyTorch stores W as (out_features, in_features) and the forward pass
# computes:  output = input @ W.T + b

# Q: hidden_size=128 → num_attention_heads * head_dim = 16*8 = 128.
# Wq is (128, 128); every token is projected into 16 Q-heads of size 8.
q_proj = nn.Linear(
    in_features=hidden_size,
    out_features=num_attention_heads * head_dim,
    bias=attention_bias,
)

# K: hidden_size=128 → num_key_value_heads * head_dim = 4*8 = 32.
# Wk is (32, 128); every token is projected into 4 K-heads of size 8.
k_proj = nn.Linear(
    in_features=hidden_size,
    out_features=num_key_value_heads * head_dim,
    bias=attention_bias,
)

# V: same geometry as K (4 heads of size 8).
v_proj = nn.Linear(
    in_features=hidden_size,
    out_features=num_key_value_heads * head_dim,
    bias=attention_bias,
)

# O: after attention the 16 head outputs are concatenated into one vector
# (size 128); o_proj maps that vector back to hidden_size=128.
o_proj = nn.Linear(
    in_features=num_attention_heads * head_dim,
    out_features=hidden_size,
    bias=attention_bias,
)


# --- Project, then split into heads ---
# hidden_states is [batch_size, seq_len, hidden_size]; a linear layer acts on
# every token of every sequence independently, so the projections come out as
#   Q: [B, S, num_attention_heads * head_dim] = [B, S, 128]
#   K: [B, S, num_key_value_heads * head_dim] = [B, S, 32]
#   V: [B, S, num_key_value_heads * head_dim] = [B, S, 32]
#
# Each projection is then reshaped to [B, heads, S, head_dim].
# Toy illustration with made-up sizes (B=1, S=3, width=4, 2 heads of size 2):
#
#   Before the reshape: [1, 3, 4] - four flat numbers per token.
#
#   view(1, 3, 2, 2) cuts the width into (heads=2, head_dim=2):
#     token 0: [q0, q1 | q2, q3]    → head0=[q0,q1], head1=[q2,q3]
#     token 1: [q4, q5 | q6, q7]    → head0=[q4,q5], head1=[q6,q7]
#     token 2: [q8, q9 | q10, q11]  → head0=[q8,q9], head1=[q10,q11]
#
#   Swapping axes 1 and 2 puts heads before sequence → [1, 2, 3, 2], so that
#   x[b, h, :, :] holds every token for one head:
#     x[0, 0] = [[q0, q1], [q4, q5], [q8, q9]]    # head 0 across all tokens
#     x[0, 1] = [[q2, q3], [q6, q7], [q10, q11]]  # head 1 across all tokens
#
# Keys and values follow the same recipe with num_key_value_heads (4)
# instead of num_attention_heads (16), giving
#   Q: [B, num_heads,    S, head_dim]
#   K: [B, num_kv_heads, S, head_dim]
#   V: [B, num_kv_heads, S, head_dim]
query_states = (
    q_proj(hidden_states)
    .view(batch_size, sequence_length, num_attention_heads, head_dim)
    .permute(0, 2, 1, 3)
)
key_states = (
    k_proj(hidden_states)
    .view(batch_size, sequence_length, num_key_value_heads, head_dim)
    .permute(0, 2, 1, 3)
)
value_states = (
    v_proj(hidden_states)
    .view(batch_size, sequence_length, num_key_value_heads, head_dim)
    .permute(0, 2, 1, 3)
)

print("Projected Shapes:")
print(f"  query_states: {query_states.shape}")   # (batch_size, num_attention_heads, sequence_length, head_dim)
print(f"  key_states: {key_states.shape}")       # (batch_size, num_key_value_heads, sequence_length, head_dim)
print(f"  value_states: {value_states.shape}")   # (batch_size, num_key_value_heads, sequence_length, head_dim)

# Number of Q heads that share each K/V head.
num_key_value_groups = num_attention_heads // num_key_value_heads
print(f"\nNum Key/Value Groups (Q heads per K/V head): {num_key_value_groups}")

Projected Shapes:
  query_states: torch.Size([2, 16, 10, 8])
  key_states: torch.Size([2, 4, 10, 8])
  value_states: torch.Size([2, 4, 10, 8])

Num Key/Value Groups (Q heads per K/V head): 4


In [None]:
# ---------------------------------------------------------------------------
# Rotary Positional Embeddings (RoPE)
# ---------------------------------------------------------------------------
#
# Transformers need to know the order of tokens in a sequence.
# Instead of adding positional vectors (absolute embeddings),
# RoPE injects positional information by ROTATING the Q and K vectors
# before the dot product.
#
# Key idea:
# - Each pair of dimensions in Q and K is treated like coordinates in 2D.
# - We rotate them by an angle that depends on the token's position.
# - Different dimension pairs rotate at different speeds (frequencies).
#
# Why?
# - This encodes *relative position* information directly in Q and K.
# - It improves performance, especially for long sequences.
#
# Implementation notes:
# - Angles are computed from positions × frequencies.
# - Cosine and sine represent the real/imag parts of the rotation (Euler’s formula).
# - The result is stored as complex numbers (cos + i·sin).
# - Later, these rotations are applied to Q and K in `apply_rotary_emb`.

def new_func(inv_freq, t):
    """
    Outer product of token positions and rotation frequencies.

    Every position in `t` is multiplied with every frequency in `inv_freq`.
    Example: positions=[0,1,2] and freqs=[f0,f1] give
        [[0*f0, 0*f1],
         [1*f0, 1*f1],
         [2*f0, 2*f1]]

    Returns a tensor of shape [len(t), len(inv_freq)].
    """
    return torch.outer(t, inv_freq)


def simple_rope_calculation(dim, max_seq_len, base=10000.0, device=None):
    """
    Precompute the Rotary Positional Embedding (RoPE) rotations.

    RoPE does not add position vectors; it rotates Q and K by an angle that
    depends on the token position. This function builds those rotations for
    every position in [0, max_seq_len) so they can be applied to Q and K later.

    Args:
        dim: per-head dimension the rotations are built for.
        max_seq_len: number of positions to precompute.
        base: RoPE base; frequencies are base ** (-i / dim) for i = 0, 2, 4, ...
        device: device on which the tensors are created.

    Returns:
        Complex tensor of shape [max_seq_len, dim] holding cos(angle) + i*sin(angle).
    """
    # One frequency per pair of dimensions (0-1, 2-3, ...): small pair indices
    # rotate fast, large pair indices rotate slowly.
    pair_index = torch.arange(0, dim, 2, device=device).float()
    inv_freq = 1.0 / (base ** (pair_index / dim))

    # Token positions 0 .. max_seq_len-1, in the same dtype as the frequencies.
    positions = torch.arange(max_seq_len, device=device).type_as(inv_freq)

    # angle[p, f] = position p * frequency f          → [max_seq_len, dim/2]
    angles = new_func(inv_freq, positions)

    # Repeat the angles so the last axis spans the full head dimension → [max_seq_len, dim]
    full_angles = torch.cat((angles, angles), dim=-1)

    # Euler's formula e^(iθ) = cos(θ) + i·sin(θ): the real part is the cosine,
    # the imaginary part the sine. This complex value is the rotation operator.
    return torch.complex(full_angles.cos(), full_angles.sin())


In [None]:
def apply_rotary_emb_torch(
    xq: torch.Tensor,        # Queries: shape [batch, num_heads, seq_len, head_dim]
    xk: torch.Tensor,        # Keys:    shape [batch, num_kv_heads, seq_len, head_dim]
    freqs_cis: torch.Tensor,  # Precomputed complex rotations: shape [max_seq_len, head_dim]
    position_ids: Optional[torch.Tensor] = None,  # Token positions: [batch, seq_len] or [seq_len]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Positional Embeddings (RoPE) to Q and K.

    Idea:
    - Each pair of adjacent dimensions in Q and K is seen as a 2D vector
      (equivalently, one complex number).
    - That 2D vector is rotated by an angle that depends on the token's position.
    - Angles come from freqs_cis (cos + i·sin, precomputed).

    Args:
        xq: query tensor, [batch, num_heads, seq_len, head_dim].
        xk: key tensor, [batch, num_kv_heads, seq_len, head_dim].
        freqs_cis: complex rotation table indexed by position along dim 0.
        position_ids: integer positions used to index freqs_cis, shaped
            [batch, seq_len] or [seq_len]. When omitted, positions
            0 .. seq_len-1 are used for every sequence, so the function no
            longer depends on a notebook-level `position_ids` variable.

    Returns:
        (rotated_q, rotated_k) with the same shapes and dtypes as xq and xk.
    """

    # 1. Move rotation frequencies to the same device (CPU/GPU) as Q
    freqs_cis = freqs_cis.to(xq.device)

    # 2. Resolve the token positions.
    #    Default: consecutive positions starting at 0 for the query length.
    if position_ids is None:
        position_ids = torch.arange(xq.shape[-2], device=xq.device)
    else:
        position_ids = position_ids.to(xq.device)
    if position_ids.dim() == 1:
        # A single row of positions is shared by the whole batch.
        position_ids = position_ids.unsqueeze(0)

    # 3. Pick the rotation angles for the actual token positions
    #    Example: position_ids = [[0,1,2,...]] → selects rows 0,1,2... from freqs_cis
    #    Result: [batch or 1, seq_len, head_dim] (complex)
    freqs_cis = freqs_cis[position_ids]

    # 4. Add a "heads" axis for broadcasting so the same rotation is applied
    #    to all heads of the same token: [batch or 1, 1, seq_len, head_dim]
    freqs_cis = freqs_cis[:, None, :, :]

    # ---------------- Prepare Q and K as complex numbers ----------------

    # 5. Reshape last dim so that pairs of values become complex numbers
    #    Example: [q0, q1] → q0 + i*q1
    #    Shape: [B, H, S, head_dim] → [B, H, S, head_dim//2, 2] → complex → [B, H, S, head_dim//2]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # ---------------- Prepare freqs_cis ----------------

    # 6. Keep only the first head_dim//2 entries of freqs_cis, because each
    #    complex number already represents a pair of dims.
    #    Shape: [batch or 1, 1, S, head_dim//2] (complex)
    freqs_cis_broadcast = freqs_cis[..., :xq_.shape[-1]]

    # ---------------- Apply the rotation ----------------

    # 7. Rotate Q and K using element-wise complex multiplication.
    #    Multiplying by (cos + i·sin) applies the rotation; the size-1 axes in
    #    freqs_cis_broadcast broadcast across heads (and across the batch when
    #    positions are shared).
    rotated_xq = xq_ * freqs_cis_broadcast
    rotated_xk = xk_ * freqs_cis_broadcast

    # ---------------- Convert back to real ----------------

    # 8. Convert complex back to real pairs, then flatten
    #    Shape: [B, H, S, head_dim//2] complex
    #        -> view_as_real -> [B, H, S, head_dim//2, 2] real
    #        -> flatten last two dims -> [B, H, S, head_dim] real
    xq_out = torch.view_as_real(rotated_xq).flatten(3)
    xk_out = torch.view_as_real(rotated_xk).flatten(3)

    # 9. Cast back to the same dtype as the inputs (the math above ran in float32)
    return xq_out.type_as(xq), xk_out.type_as(xk)


In [None]:
# --------------------------------------------------------------
# Precompute RoPE frequencies
# --------------------------------------------------------------
# RoPE acts on head_dim (one head at a time), not on the full hidden_size.
# The rotation table (cos + i·sin) covers every position up to
# max_position_embeddings.
freqs_cis = simple_rope_calculation(
    dim=head_dim,                         # dimension per head (e.g., 8)
    max_seq_len=max_position_embeddings,  # maximum sequence length (e.g., 256)
    base=rope_theta,                      # base frequency (default 10000.0)
    device=hidden_states.device,
)
print(f"Calculated freqs_cis shape: {freqs_cis.shape}")
# Expected: (max_position_embeddings, head_dim)


# --------------------------------------------------------------
# Apply RoPE to Q and K
# --------------------------------------------------------------
# Important: RoPE is applied BEFORE K/V are repeated for grouped-query
# attention. Q and K are rotated by position-dependent angles, which injects
# relative position information directly into them.
query_states_rope, key_states_rope = apply_rotary_emb_torch(
    xq=query_states,      # shape [B, num_heads, S, head_dim]
    xk=key_states,        # shape [B, num_kv_heads, S, head_dim]
    freqs_cis=freqs_cis,  # precomputed rotations
)

# --------------------------------------------------------------
# Check shapes after applying RoPE (they must be unchanged)
# --------------------------------------------------------------
print("\nShapes after RoPE:")
print(f"  query_states_rope: {query_states_rope.shape}")  # [B, num_heads, S, head_dim]
print(f"  key_states_rope:   {key_states_rope.shape}")    # [B, num_kv_heads, S, head_dim]


In [None]:
# --------------------------------------------------------------
# Step 3: Optional Q/K Normalization (Simple L2 Norm)
# --------------------------------------------------------------
# Goal:
# - Make sure Q and K vectors have a stable scale (not too big, not too small).
# - This helps when we compute attention scores = Q @ K^T,
#   so the softmax is stable (not exploding or vanishing).
#
# Idea:
# - For each vector along the last dimension (head_dim),
#   compute its size (RMS = sqrt(mean of squares)).
# - Divide the vector by this size so its RMS is ~1 (its L2 length is then ~sqrt(head_dim)).
# - Add a tiny epsilon (eps) to avoid division by zero.
# --------------------------------------------------------------

class SimpleL2Norm(nn.Module):
    """
    Parameter-free normalisation over the last dimension.

    Each vector is divided by sqrt(mean(x^2) + eps), i.e. by its
    root-mean-square, so the output has RMS close to 1. There is no learnable
    weight.

    Args:
        eps: small constant added to the mean of squares for numerical stability.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps  # small constant for numerical stability

    def forward(self, x):
        """
        Normalise `x` along its last dimension.

        Args:
            x: tensor whose last dimension is normalised; in this notebook it
               is [B, num_heads, seq_len, head_dim].

        Returns:
            Tensor with the same shape and dtype as `x`.
        """
        # Half-precision inputs are upcast before squaring: float16 overflows
        # to inf for |x| > ~255, which would turn the scale into 0 and wipe
        # out the vector. float32/float64 inputs are used as they are.
        if x.dtype in (torch.float16, torch.bfloat16):
            work = x.float()
        else:
            work = x

        # 1. Mean of squares along the last dimension
        #    keepdim=True → so the result can broadcast back to x
        mean_sq = work.pow(2).mean(-1, keepdim=True)

        # 2. Add epsilon and take the reciprocal square root
        #    torch.rsqrt(y) = 1 / sqrt(y)
        scale = torch.rsqrt(mean_sq + self.eps)

        # 3. Scale each vector, then return in the caller's dtype
        return (work * scale).type_as(x)


# --------------------------------------------------------------
# Apply normalization (if enabled)
# --------------------------------------------------------------
if use_qk_norm:
    # Use the epsilon from the configuration cell instead of the class
    # default, so the value documented in the config is the one in effect.
    qk_norm = SimpleL2Norm(eps=rms_norm_eps)

    # Normalize queries and keys after RoPE
    query_states_final = qk_norm(query_states_rope)
    key_states_final   = qk_norm(key_states_rope)

    print("\nApplied QK Norm")  # confirms normalization was used
else:
    # Skip normalization: just use RoPE outputs
    query_states_final = query_states_rope
    key_states_final   = key_states_rope
    print("\nSkipped QK Norm")


# --------------------------------------------------------------
# Check final shapes before attention scores
# --------------------------------------------------------------
print("\nShapes before attention score calculation:")
print(f"  query_states_final: {query_states_final.shape}")  # [B, num_heads, S, head_dim]
print(f"  key_states_final:   {key_states_final.shape}")    # [B, num_kv_heads, S, head_dim]


In [None]:
# --------------------------------------------------------------
# Step 4: Grouped-Query Attention (GQA) - Key/Value Repeating
# --------------------------------------------------------------
# Context:
# - We already have Q, K, V after projection.
# - But in GQA, we have MORE Q heads than K/V heads.
#   Example: Q = 16 heads, K/V = 4 heads.
# - Solution: "repeat" each K/V head enough times so every Q head
#   has a matching K and V.
# - After repeating:
#     Q: [B, 16, S, head_dim]
#     K: [B, 16, S, head_dim]  <-- repeated from 4 heads
#     V: [B, 16, S, head_dim]  <-- repeated from 4 heads
# --------------------------------------------------------------

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate each Key/Value head so K/V line up with the Query heads (GQA).

    Args:
        hidden_states: tensor of shape [B, num_kv_heads, S, head_dim]
        n_rep: number of copies of each K/V head
               (num_attention_heads // num_kv_heads)

    Returns:
        Tensor of shape [B, num_kv_heads * n_rep, S, head_dim]; copies of the
        same K/V head sit next to each other along the head axis. When
        n_rep == 1 the input tensor itself is returned.
    """
    batch, num_kv_heads, slen, head_dim = hidden_states.shape

    if n_rep != 1:
        # New axis right after the kv-head axis, broadcast to n_rep copies:
        # [B, kv, S, D] → [B, kv, 1, S, D] → [B, kv, n_rep, S, D]
        tiled = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, slen, head_dim)

        # Fold (kv, n_rep) into a single head axis → [B, kv * n_rep, S, D]
        return tiled.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

    # KV heads already match Q heads: nothing to repeat.
    return hidden_states


# --------------------------------------------------------------
# Expand K and V to one head per Q head
# --------------------------------------------------------------

# Keys: the RoPE-rotated (and, if enabled, normalized) version is repeated.
key_states_repeated = repeat_kv(hidden_states=key_states_final, n_rep=num_key_value_groups)

# Values: the plain projected values are repeated (no RoPE, no norm on V).
value_states_repeated = repeat_kv(hidden_states=value_states, n_rep=num_key_value_groups)

print("\nShapes after repeating K/V for GQA:")
print(f"  key_states_repeated:   {key_states_repeated.shape}")   # head count should match Q
print(f"  value_states_repeated: {value_states_repeated.shape}") # head count should match Q

In [None]:
# --------------------------------------------------------------
# Step 5: Scaled Dot-Product Attention
# --------------------------------------------------------------
# Context recap:
# - We now have:
#     Q = query_states_final        # [B, num_heads, S, head_dim]
#     K = key_states_repeated       # [B, num_heads, S, head_dim]  (after GQA repeat)
#     V = value_states_repeated     # [B, num_heads, S, head_dim]  (after GQA repeat)
# - We also have a causal attention_mask: [B, 1, S, S]
# - Goal: weights = softmax((Q @ K^T) / sqrt(d)) ; output = weights @ V
# --------------------------------------------------------------

# 1) Raw attention scores = Q @ K^T
#    For each batch and head, compare every query token (S) against every key token (S).
#    Shapes:
#      Q: [B, H, S, D]
#      K^T: transpose last two dims → [B, H, D, S]
#      result: [B, H, S, S]  (score of each query token vs each key token)
attn_weights = torch.matmul(query_states_final, key_states_repeated.transpose(2, 3))

# 2) Scale by 1 / sqrt(head_dim)
#    Reason: keep magnitudes stable so softmax doesn’t become too peaky or too flat.
scaling_factor = 1.0 / math.sqrt(head_dim)
attn_weights = attn_weights * scaling_factor

# 3) Apply causal mask (no looking ahead)
#    Mask shape is [B, 1, S, S] and broadcasts across heads.
#    We slice mask’s key-length dimension to match K/V sequence length if needed.
if attention_mask is not None:
    print(f"\nApplying attention mask with shape: {attention_mask.shape}")
    causal_mask = attention_mask[:, :, :, :key_states_repeated.shape[-2]]  # ensure key-dim match
    # Add -inf to forbidden positions so softmax → 0 there.
    attn_weights = attn_weights + causal_mask
else:
    print("\nNo attention mask applied.")

# 4) Softmax over keys dimension
#    Turn scores into probabilities along the last dim (which indexes keys).
#    The softmax itself runs in float32 and the result is cast back to the
#    query dtype; for float32 inputs this is identical to a plain softmax,
#    for half-precision inputs it avoids reduced-precision exp/sum.
#    Result stays [B, H, S, S].
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

# 5) (Optional) Dropout on attention weights (used in training; skipped here for inference)
#    Inside an nn.Module this would read as below; `self.training` does not
#    exist at notebook level, so the line stays commented out.
# attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=self.training)

# 6) Weighted sum of Values → attention output
#    Multiply probabilities [B, H, S, S] by V [B, H, S, D] along the key/token axis.
#    Each query token gets a weighted combination of all value tokens → [B, H, S, D].
attn_output = torch.matmul(attn_weights, value_states_repeated)

print("\nAttention Calculation Shapes:")
print(f"  attn_weights (after mask+softmax): {attn_weights.shape}")  # [B, H, S, S]
print(f"  attn_output: {attn_output.shape}")                         # [B, H, S, D]

In [None]:
# --------------------------------------------------------------
# Step 6: Reshape and Output Projection
# --------------------------------------------------------------
# Context recap:
# - After Step 5, attn_output = [B, num_heads, S, head_dim]
#   Each head produced its own [S, head_dim] output.
# - Next steps:
#   1) Concatenate all heads together (merge num_heads * head_dim).
#   2) Pass through a final linear layer (o_proj) to map back
#      to hidden_size, so the rest of the Transformer sees the
#      same shape it started with.
# --------------------------------------------------------------

# 1) Move heads dimension after sequence
#    From [B, H, S, D] → [B, S, H, D]
#    .contiguous() is required because .view() below needs contiguous memory.
attn_output = attn_output.transpose(1, 2).contiguous()

# 2) Flatten heads into one big vector per token
#    Shape: [B, S, H*D]. The width is written as num_attention_heads * head_dim
#    because that is exactly what o_proj takes as input; with this notebook's
#    configuration it equals hidden_size (16 * 8 = 128).
attn_output = attn_output.view(batch_size, sequence_length, num_attention_heads * head_dim)

# 3) Final linear projection back to hidden_size
#    This mixes information across heads and ensures output
#    has the same dimension as the model embedding (hidden_size).
final_attn_output = o_proj(attn_output)

print("\nFinal Output Shapes:")
print(f"  attn_output (reshaped): {attn_output.shape}")       # [B, S, num_heads * head_dim]
print(f"  final_attn_output: {final_attn_output.shape}")      # [B, S, hidden_size]