# Attention Fundamentals: Self, Cross, and Multi-Head

Attention lets a model weigh every position in its context directly, rather than passing information step by step through a recurrent state. This notebook develops scaled dot-product attention step by step and prepares you to assemble full transformer blocks.

## Learning Objectives

- Derive scaled dot-product attention and understand each tensor transformation.
- Visualize attention weights to interpret where models focus.
- Implement self-attention, cross-attention, and multi-head variants.
- Build modular attention blocks with residual connections for use in transformers.

## Attention Pipeline

1. Project inputs into queries (Q), keys (K), and values (V).
2. Compute compatibility scores `QK^T / sqrt(d_k)`.
3. Apply softmax to obtain attention weights.
4. Weight values to produce context vectors.

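Scaling by `sqrt(d_k)` keeps gradients stable as dimensionality grows: the variance of an unscaled dot product increases with `d_k`, and large scores saturate the softmax, where gradients become very small.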
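
A small numerical experiment makes this concrete. The sketch below assumes queries and keys with independent unit-variance entries (an illustrative assumption, as are the helper name `score_stats` and the chosen sizes). It reports the variance of raw versus scaled scores, along with the average largest softmax weight in each case.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def score_stats(d_k, n_queries=2000, n_keys=16):
    """Return (var raw, var scaled, mean peak weight raw, mean peak weight scaled)."""
    q = torch.randn(n_queries, d_k)  # unit-variance entries (illustrative assumption)
    k = torch.randn(n_keys, d_k)
    raw = q @ k.T                    # (n_queries, n_keys) unscaled scores
    scaled = raw / d_k ** 0.5
    # Largest attention weight per query, averaged over all queries.
    raw_peak = F.softmax(raw, dim=-1).max(dim=-1).values.mean()
    scaled_peak = F.softmax(scaled, dim=-1).max(dim=-1).values.mean()
    return raw.var().item(), scaled.var().item(), raw_peak.item(), scaled_peak.item()

for d_k in (4, 64, 512):
    raw_var, scaled_var, raw_peak, scaled_peak = score_stats(d_k)
    print(f"d_k={d_k:4d}  var(raw)={raw_var:8.1f}  var(scaled)={scaled_var:5.2f}  "
          f"peak raw={raw_peak:.2f}  peak scaled={scaled_peak:.2f}")
```

Under these assumptions the raw variance should track `d_k` while the scaled variance stays near 1. A peak weight close to 1 means the softmax is nearly one-hot, which is the saturated regime the scaling is meant to avoid.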

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(5)

def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    output = weights @ v
    return output, weights

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
v = torch.randn(2, 4, 8)

context, weights = scaled_dot_product_attention(q, k, v)
print(context.shape, weights.shape)
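
As a quick sanity check on the function above, the following sketch verifies two properties: every row of the weight matrix sums to 1, and the output agrees with PyTorch's built-in `F.scaled_dot_product_attention`. The built-in was added in PyTorch 2.0, so the comparison is guarded with `hasattr`. The helper name `sdpa_reference` is hypothetical and only restates the notebook's function so the cell runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(5)

def sdpa_reference(q, k, v, mask=None):
    """Restatement of the notebook's scaled_dot_product_attention (hypothetical name)."""
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
context, weights = sdpa_reference(q, k, v)

# Softmax over the key axis turns each query's scores into a probability distribution.
row_sums = weights.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))

# Compare against the fused built-in when it is available (PyTorch >= 2.0).
if hasattr(F, "scaled_dot_product_attention"):
    builtin = F.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(context, builtin, atol=1e-5))
```

The built-in does not return the attention weights, so the explicit version remains the more convenient one for the visualizations that follow.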


### Visualizing Attention Weights

Heatmaps show how much weight each query position places on each key position. This is useful when debugging, although attention weights on their own are not a complete explanation of a model's decisions.

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))
attn = weights[0].detach()
im = ax.imshow(attn, cmap="viridis")
for i in range(attn.size(0)):
    for j in range(attn.size(1)):
        ax.text(j, i, f"{attn[i, j]:.2f}", ha="center", va="center", color="white")
ax.set_xlabel("Key index")
ax.set_ylabel("Query index")
fig.colorbar(im, ax=ax)
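# Set the title on `ax` directly; in some matplotlib versions fig.colorbar changes the current axes.
ax.set_title("Self-attention weights (sample 0)")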
plt.show()


## Multi-Head Attention

Multiple heads allow the model to attend to different representation subspaces. Each head performs attention independently before recombining.

In [None]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)

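    def forward(self, x, context=None, mask=None):
        # Self-attention when context is None; cross-attention otherwise.
        # mask must broadcast to (bsz, num_heads, query_len, key_len); zeros are blocked.
        context = x if context is None else context
        bsz = x.size(0)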

        def reshape(tensor):
            return tensor.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = reshape(self.q_proj(x))
        k = reshape(self.k_proj(context))
        v = reshape(self.v_proj(context))

        attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, -1, self.embed_dim)
        return self.out_proj(attn_output), attn_weights

mha = MultiHeadAttention(embed_dim=32, num_heads=4)
out, w = mha(torch.randn(2, 5, 32))
print(out.shape, w.shape)
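
The `view(...).transpose(1, 2)` step in `forward` is easy to get wrong, so here is a minimal sketch of it in isolation. It illustrates that splitting an embedding into heads only rearranges data: head `h` sees one contiguous slice of each token's embedding, and the inverse reshape recovers the original tensor exactly. The sizes are arbitrary illustrative values.

```python
import torch

bsz, seq_len, embed_dim, num_heads = 2, 5, 32, 4
head_dim = embed_dim // num_heads

x = torch.randn(bsz, seq_len, embed_dim)

# (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
heads = x.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.shape)

# Head h holds a contiguous slice of the embedding dimension for every token.
h = 1
print(torch.equal(heads[:, h], x[:, :, h * head_dim:(h + 1) * head_dim]))

# Merging heads: transpose back, make memory contiguous, then flatten the last two dims.
merged = heads.transpose(1, 2).contiguous().view(bsz, seq_len, embed_dim)
print(torch.equal(merged, x))
```

The `.contiguous()` call matters because `transpose` returns a non-contiguous view, and `view` needs a memory layout compatible with the requested shape.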


## Mini Task – Causal Mask

Autoregressive models (language generation) require a causal mask so each position attends only to itself and earlier positions, never to future ones. Create a lower-triangular mask and apply it to the scaled dot-product attention.

In [None]:
seq_len = 4
# TODO: create causal mask of shape (1, 1, seq_len, seq_len) and apply attention


In [None]:
seq_len = 4
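# Lower-triangular boolean mask: entry (i, j) is True when query i may attend to key j (j <= i).
mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool))
# Add a singleton head dimension so q, k, v line up with the 4-D mask.
masked_context, masked_weights = scaled_dot_product_attention(
    q.unsqueeze(1),
    k.unsqueeze(1),
    v.unsqueeze(1),
    mask=mask,
)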
print(masked_context.shape, masked_weights.shape)
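
One way to confirm that a causal mask behaves as intended is to check two things: weights above the diagonal are exactly zero, and changing a future token leaves earlier outputs untouched. The sketch below does both with a single-head helper. The name `attend` is hypothetical and mirrors the notebook's function so the cell is self-contained.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def attend(q, k, v, mask=None):
    """Scaled dot-product attention; positions where mask == 0 are blocked."""
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

seq_len, d = 4, 8
x = torch.randn(1, seq_len, d)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out, w = attend(x, x, x, mask=causal)

# 1) No weight is placed on future positions (the strict upper triangle).
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(w[0][future].abs().max().item())

# 2) Perturbing the last token must not change outputs at earlier positions.
x_changed = x.clone()
x_changed[0, -1] += 10.0
out_changed, _ = attend(x_changed, x_changed, x_changed, mask=causal)
print(torch.allclose(out[0, :-1], out_changed[0, :-1]))
```

The second check is the stronger one, since it tests the property autoregressive generation relies on rather than the mask's shape alone.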


## Cross-Attention

Cross-attention connects decoder queries to encoder outputs. This is the bridge between the sequence modeling notebook and the transformer architecture coming next.

In [None]:
encoder_outputs = torch.randn(2, 7, 32)
decoder_states = torch.randn(2, 5, 32)
cross_out, cross_weights = mha(decoder_states, context=encoder_outputs)
print(cross_out.shape, cross_weights.shape)
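
In practice encoder batches are often padded, and decoder queries should not attend to padding. The following is a minimal sketch of that case, assuming the same mask convention as `scaled_dot_product_attention` (entries equal to 0 are blocked). The source lengths are hypothetical, and learned projections are omitted to keep the focus on the mask shape.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bsz, src_len, tgt_len, d = 2, 7, 5, 32
encoder_outputs = torch.randn(bsz, src_len, d)
decoder_states = torch.randn(bsz, tgt_len, d)

# Hypothetical true lengths: sample 0 uses all 7 source tokens, sample 1 only 4.
src_lengths = torch.tensor([7, 4])
# (bsz, src_len) -> (bsz, 1, src_len) so the mask broadcasts over every decoder query.
key_mask = (torch.arange(src_len)[None, :] < src_lengths[:, None]).unsqueeze(1)

# Queries come from the decoder; keys and values come from the encoder.
scores = decoder_states @ encoder_outputs.transpose(-2, -1) / d ** 0.5  # (bsz, tgt_len, src_len)
scores = scores.masked_fill(key_mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
context = weights @ encoder_outputs                                      # (bsz, tgt_len, d)

print(context.shape, weights.shape)
print(weights[1, :, 4:].sum().item())  # total weight on the padded keys of sample 1
```

With the `MultiHeadAttention` module above, the same mask would need one more singleton dimension, for example shape `(bsz, 1, 1, src_len)`, so that it broadcasts against scores of shape `(bsz, num_heads, tgt_len, src_len)`.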


## Comprehensive Exercise – Attention Block

Implement an `AttentionBlock` that includes layer normalization, residual connections, optional cross-attention, and a position-wise feed-forward network.

In [None]:
class AttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, cross_attention=False, dropout=0.1):
        super().__init__()
        # TODO: compose self-attention, optional cross-attention, feed-forward, and residual paths

    def forward(self, x, context=None, mask=None, context_mask=None):
        raise NotImplementedError


In [None]:
class AttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, cross_attention=False, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads) if cross_attention else None
        self.norm1 = torch.nn.LayerNorm(embed_dim)
        self.norm2 = torch.nn.LayerNorm(embed_dim)
        self.norm3 = torch.nn.LayerNorm(embed_dim)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, ff_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(ff_dim, embed_dim),
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, context_mask=None):
        # Pre-norm layout: normalize, apply the sublayer, then add the residual.
        residual = x
        attn_out, _ = self.self_attn(self.norm1(x), mask=mask)
        x = residual + self.dropout(attn_out)

        # Cross-attention runs only if the block was built with it and a context is given.
        if self.cross_attn is not None and context is not None:
            residual = x
            cross_out, _ = self.cross_attn(self.norm2(x), context=context, mask=context_mask)
            x = residual + self.dropout(cross_out)

        # Position-wise feed-forward sublayer with its own residual path.
        residual = x
        ff_out = self.ff(self.norm3(x))
        x = residual + self.dropout(ff_out)
        return x

block = AttentionBlock(embed_dim=32, num_heads=4, ff_dim=64, cross_attention=True)
dummy_x = torch.randn(2, 6, 32)
dummy_context = torch.randn(2, 7, 32)
print(block(dummy_x, context=dummy_context).shape)


## Further Reading

- Vaswani et al. (2017) “Attention Is All You Need”
- PyTorch `nn.MultiheadAttention` documentation
- Annotated Transformer blog posts for step-by-step derivations
- “A Primer in BERTology” for attention interpretability techniques