In [17]:
import math
import torch
from torch import nn

In [18]:
# Hyper-parameters for the toy multi-head attention layer
embed_dim = 4
num_heads = 4
dropout = 0.0

assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads  # Dimensionality per head

# Seed the global RNG *before* the layers are built. nn.Linear draws its initial
# weights and biases when it is constructed, so a seed set any later cannot make
# them (or anything computed from them) reproducible across kernel restarts.
torch.manual_seed(42)

# A single large linear layer for Q, K, V projections.
# This is more efficient than separate layers and matches PyTorch's implementation.
# Input: (Seq_Len, Batch_Size, Embed_Dim)
# Output: (Seq_Len, Batch_Size, 3 * Embed_Dim)
in_proj = nn.Linear(embed_dim, embed_dim * 3)

# Final output projection layer: projects concatenated heads back to embed_dim
out_proj = nn.Linear(embed_dim, embed_dim)

# Dropout layer for regularizing attention weights (a no-op while dropout == 0.0)
attn_dropout = nn.Dropout(dropout)

# Scaling factor for dot-product attention (see Vaswani et al.):
# scores are divided by sqrt(head_dim) before the softmax.
scale = math.sqrt(head_dim)

In [19]:
"""
Forward pass for Multi-Head Attention.

Args:
        query: (L, B, E) - Target sequence length, Batch size, Embed_dim
        key:   (S, B, E) - Source sequence length, Batch size, Embed_dim
        value: (S, B, E) - Source sequence length, Batch size, Embed_dim
        attn_mask: Optional mask to prevent attention to certain positions (e.g., future tokens)
        key_padding_mask: Optional mask to ignore padding tokens in the key
        need_weights: If True, also return average attention weights

Returns:
        attn_output: (L, B, E) - Output of the attention layer
        attn_weights: (B, L, S) or None - Average attention weights over heads (if requested)
"""
# Toy problem: a 3-token sequence, a batch of one, 4-dimensional embeddings
seq_len, batch_size, embed_dim = 3, 1, 4

# No masking in this walkthrough; ask for the head-averaged weights as well
attn_mask = key_padding_mask = None
need_weights = True

# Seeds torch's global RNG from this point on. Modules constructed before this
# call were initialised with whatever RNG state existed at that time.
torch.manual_seed(42)

# Every tensor below is laid out as (Seq_Len, Batch, Embed_Dim) = (3, 1, 4).

# Queries: one one-hot row per token
query = torch.tensor([
    [[1.0, 0.0, 0.0, 0.0]],
    [[0.0, 1.0, 0.0, 0.0]],
    [[0.0, 0.0, 1.0, 0.0]],
])

# Keys: same values as the queries, but a separate tensor object
key = torch.tensor([
    [[1.0, 0.0, 0.0, 0.0]],
    [[0.0, 1.0, 0.0, 0.0]],
    [[0.0, 0.0, 1.0, 0.0]],
])

# Values: 0.1 .. 1.2 in steps of 0.1, four per token
value = torch.tensor([
    [[0.1, 0.2, 0.3, 0.4]],
    [[0.5, 0.6, 0.7, 0.8]],
    [[0.9, 1.0, 1.1, 1.2]],
])



# Read the sizes straight off the sequence-first inputs
seq_len_q, batch_size, _ = query.shape  # L, B, E
seq_len_kv = key.shape[0]  # S

# 1. Combined Linear Projection for Q, K, V
# in_proj stacks the three projections along its output dimension:
# the first E outputs are Q, the next E are K, the last E are V.
#
# Self-attention is detected by object identity ("the same tensor"), not by
# comparing values: an element-wise torch.equal() would scan every input on
# each pass just to choose between two branches that apply the same weights.
if query is key and key is value:  # self-attention
    # One matmul: in_proj returns (L, B, 3*E); chunk into Q, K, V along the last dim
    q, k, v = in_proj(query).chunk(3, dim=-1)
else:
    # Distinct inputs: slice the packed weight/bias into their Q, K, V thirds
    # and project each input with its own slice.
    w_q, w_k, w_v = in_proj.weight.chunk(3, dim=0)
    b_q, b_k, b_v = in_proj.bias.chunk(3, dim=0)
    q = nn.functional.linear(query, w_q, b_q)
    k = nn.functional.linear(key, w_k, b_k)
    v = nn.functional.linear(value, w_v, b_v)

# 2. Reshape for Multi-Head Computation
# Cut the embedding axis into (Num_Heads, Head_Dim) and bring batch and head to
# the front so each head attends independently:
#   q: (L, B, E) -> (L, B, H, D) -> (B, H, L, D)
#   k: (S, B, E) -> (S, B, H, D) -> (B, H, S, D)
#   v: (S, B, E) -> (S, B, H, D) -> (B, H, S, D)
# The (tensor, length) pairs are captured before q, k, v are rebound.
q, k, v = (
    projected.view(length, batch_size, num_heads, head_dim).permute(1, 2, 0, 3)
    for projected, length in ((q, seq_len_q), (k, seq_len_kv), (v, seq_len_kv))
)

# 3. Scaled Dot-Product Attention
# Compute attention scores: (B, H, L, D) x (B, H, D, S) -> (B, H, L, S)
# Each query vector attends to all key vectors; dividing by `scale` keeps the
# logits in a range where the softmax is not saturated.
scores = torch.matmul(q, k.transpose(-2, -1)) / scale

# Optionally apply an attention mask (e.g., for causal masking).
# attn_mask should be broadcastable to (B, H, L, S).
if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
        # Boolean mask: True marks a position that must NOT be attended to.
        # Adding it would only shift the logit by 1.0, so fill with -inf instead.
        scores = scores.masked_fill(attn_mask, float("-inf"))
    else:
        # Float mask: additive bias on the logits (0 keeps, -inf blocks)
        scores = scores + attn_mask

# Optionally mask out padding tokens in the key
if key_padding_mask is not None:
    # key_padding_mask: (B, S) -> (B, 1, 1, S) for broadcasting over heads and queries
    scores = scores.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
    )

# Softmax over the last dimension (S: source sequence length)
attn_weights = torch.nn.functional.softmax(scores, dim=-1)
attn_weights = attn_dropout(attn_weights)  # Regularization

# Weighted sum of value vectors, using attention weights
# (B, H, L, S) x (B, H, S, D) -> (B, H, L, D)
context = torch.matmul(attn_weights, v)

# 4. Concatenate Heads and Project
# Move the sequence axis back to the front: (B, H, L, D) -> (L, B, H, D)
context = context.permute(2, 0, 1, 3)
# permute only changes strides, so copy into contiguous memory before
# collapsing the head axes: (L, B, H, D) -> (L, B, H*D=E)
context = context.contiguous().view(seq_len_q, batch_size, embed_dim)

# Mix the concatenated heads with the output projection: (L, B, E) -> (L, B, E)
attn_output = out_proj(context)

# Show the result: the projected output and, when requested, the attention
# weights averaged across heads.
if not need_weights:
    # No weights requested: print None in their place so the pair keeps its shape
    print("3...")
    print(attn_output, None)
    print("4...")
else:
    # Collapse the head axis by averaging: (B, H, L, S) -> (B, L, S)
    print("1...")
    print(attn_output, attn_weights.mean(dim=1))
    print("2...")

1...
tensor([[[-0.2367,  0.6443,  0.3596,  0.1696]],

        [[-0.2383,  0.6387,  0.3506,  0.1805]],

        [[-0.2353,  0.6491,  0.3515,  0.1831]]], grad_fn=<ViewBackward0>) tensor([[[0.3536, 0.3213, 0.3251],
         [0.3619, 0.3334, 0.3047],
         [0.3408, 0.3264, 0.3328]]], grad_fn=<MeanBackward1>)
2...
