# 01 - Building a Transformer Block from Scratch

## Context

The transformer is the architecture behind nearly every modern large language
model, including GPT-4, Claude, Llama, and Gemini. Introduced in the 2017 paper
*Attention Is All You Need*, the transformer replaced recurrent networks with a
mechanism called **self-attention** that lets every token in a sequence directly
attend to every other token.

**CoCounsel context:** Understanding attention helps explain why models sometimes
lose track of long legal arguments or miss citations in lengthy briefs. Attention
is a weighted average: when a brief is 10,000 tokens long, each token's attention
weights must sum to 1 across as many as 10,000 positions. The more positions
compete for that fixed budget, the less weight any single citation tends to
receive unless the model scores it well above the rest, which is one reason a
model can appear to "forget" critical citations. Knowing this helps you structure
prompts and chunk documents effectively.
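A toy calculation makes the dilution argument concrete. Assume, purely for illustration, that one "relevant" position gets a raw score of 5.0 and every other position gets 0.0 (real scores are learned and vary). The softmax weight on the relevant position is then $e^{5}/(e^{5} + n - 1)$, which shrinks as the document length $n$ grows:

```python
import math


def weight_on_relevant_token(n_tokens: int, relevant_score: float = 5.0) -> float:
    """Softmax weight on one relevant position when the other n_tokens - 1
    positions all score 0.0 (a deliberately simplified toy setup)."""
    numerator = math.exp(relevant_score)
    denominator = numerator + (n_tokens - 1) * math.exp(0.0)
    return numerator / denominator


for n in [10, 100, 1_000, 10_000]:
    print(f"{n:>6} tokens -> weight on relevant token = {weight_on_relevant_token(n):.4f}")
```

With these assumed scores the weight falls from about 0.94 at 10 tokens to about 0.015 at 10,000 tokens. A trained model can partly compensate by scoring relevant tokens much higher, so this toy shows the pressure that length puts on attention rather than a hard limit.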

In this notebook, we build every component of a transformer block from scratch
using PyTorch:

1. Scaled dot-product attention
2. Multi-head attention
3. Position-wise feed-forward network
4. Full transformer block with residual connections and layer normalization

Then we visualize attention patterns on a legal sentence.

## Theory: Self-Attention

Self-attention lets each token in a sequence compute a weighted combination of
all tokens in the sequence. The weights are not fixed parameters: they are
computed on the fly from the content of the tokens themselves (through learned
projections), rather than being determined by position.

The mechanism uses three learned linear projections to transform each input
token into three vectors:

- **Q (Query)** -- "What am I looking for?" Each token generates a query vector
  that represents what information it needs from other tokens.
- **K (Key)** -- "What do I contain?" Each token generates a key vector that
  advertises what information it holds.
- **V (Value)** -- "What do I offer?" Each token generates a value vector that
  contains the actual information to pass along.
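As a minimal sketch of the three projections, assume a toy model dimension of 16 and projections without bias terms (a simplifying choice here; the `MultiHeadAttention` class later in this notebook does use biases). Each projection is a separate learned matrix applied to the same input embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 16  # toy embedding size (assumption for illustration)
seq_len = 5   # toy sequence length

# Stand-in for token embeddings: one row per token
x = torch.randn(seq_len, d_model)

# Three independent learned projections (bias=False is a simplifying choice)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Every token gets its own query, key, and value vector
Q, K, V = W_q(x), W_k(x), W_v(x)
print(Q.shape, K.shape, V.shape)  # each is (seq_len, d_model)
```

Because the three matrices are initialized and trained separately, the same token ends up with different query, key, and value vectors.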

### The Attention Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Step by step:

1. **Compute attention scores**: Multiply Q by K transposed. This produces a
   matrix where entry (i, j) measures how much token i should attend to token j.
2. **Scale**: Divide by $\sqrt{d_k}$ (the square root of the key dimension).
   Without this scaling, the dot products grow large for high-dimensional
   vectors, pushing softmax into regions with tiny gradients.
3. **Softmax**: Normalize each row so the attention weights sum to 1. Now each
   token has a probability distribution over all tokens.
4. **Weighted sum**: Multiply the attention weights by V. Each token's output
   is a weighted combination of all value vectors.
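The scaling argument in step 2 can be checked numerically. If the components of $q$ and $k$ are independent with zero mean and unit variance (the assumption behind the $\sqrt{d_k}$ factor), then $q \cdot k$ has variance $d_k$, so its standard deviation grows like $\sqrt{d_k}$. This sketch samples random vectors to confirm that, and compares how peaked softmax becomes with and without scaling (the sizes are arbitrary choices):

```python
import math

import torch

torch.manual_seed(0)

d_k = 512          # large key dimension so the effect is easy to see
n_samples = 10_000

# Independent standard-normal query and key vectors (zero mean, unit variance)
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)
dots = (q * k).sum(dim=-1)  # one dot product per sample

print(f"std of raw dot products: {dots.std().item():.2f}  (sqrt(d_k) = {math.sqrt(d_k):.2f})")
print(f"std after scaling:       {(dots / math.sqrt(d_k)).std().item():.2f}")

# Group samples into 1,000 rows of 10 independent scores,
# a stand-in for one query's scores against 10 keys
rows = dots.view(1_000, 10)
peak_unscaled = torch.softmax(rows, dim=-1).max(dim=-1).values.mean().item()
peak_scaled = torch.softmax(rows / math.sqrt(d_k), dim=-1).max(dim=-1).values.mean().item()
print(f"average largest weight, unscaled: {peak_unscaled:.3f}")
print(f"average largest weight, scaled:   {peak_scaled:.3f}")
```

Without scaling, most rows put nearly all their weight on a single key, which is the saturated regime where softmax gradients become tiny.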

### Why This Matters for Legal Text

Consider the sentence: *"The court held that the defendant, who had previously
filed a motion to dismiss, was liable."* The word "liable" must attend back
to "defendant" (its subject), skipping over the relative clause. Self-attention
can learn to make this long-range connection directly, without passing
information step-by-step through a recurrent chain.

## Setup

In [None]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

# Reproducibility
torch.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## Implementing Attention from Scratch

Before wrapping anything in `nn.Module`, let's build scaled dot-product
attention using raw tensor operations. This makes every step explicit.

In [None]:
# Dimensions
seq_len = 6   # number of tokens in our sequence
d_k = 8       # dimension of each Q/K/V vector

# Create random Q, K, V tensors
# In a real model, these come from linear projections of the input embeddings.
# Here we use random values to demonstrate the mechanics.
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

print(f"Q shape: {Q.shape}  (seq_len x d_k)")
print(f"K shape: {K.shape}  (seq_len x d_k)")
print(f"V shape: {V.shape}  (seq_len x d_k)")

In [None]:
# Step 1: Compute raw attention scores (QK^T)
# Each entry (i, j) measures how much token i wants to attend to token j.
scores = torch.matmul(Q, K.transpose(-2, -1))
print(f"Raw attention scores shape: {scores.shape}  (seq_len x seq_len)")
print(f"\nRaw scores:\n{scores}")

In [None]:
# Step 2: Scale by sqrt(d_k)
# Without scaling, large dot products push softmax into saturation
# (one weight near 1, rest near 0), which kills gradients during training.
scale = math.sqrt(d_k)
scaled_scores = scores / scale
print(f"Scale factor: sqrt({d_k}) = {scale:.4f}")
print(f"\nScaled scores:\n{scaled_scores}")

In [None]:
# Step 3: Apply softmax to get attention weights
# Each row sums to 1 -- it's a probability distribution over tokens.
attention_weights = torch.softmax(scaled_scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}")
print(f"\nAttention weights (each row sums to 1):\n{attention_weights}")
print(f"\nRow sums: {attention_weights.sum(dim=-1)}")

In [None]:
# Step 4: Compute the weighted sum of values
# Each token's output is a weighted combination of all value vectors.
output = torch.matmul(attention_weights, V)
print(f"Output shape: {output.shape}  (seq_len x d_k)")
print(f"\nOutput (first 3 tokens):\n{output[:3]}")

In [None]:
# Visualize the attention weights as a heatmap
token_labels = [f"tok_{i}" for i in range(seq_len)]

fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(attention_weights.detach().numpy(), cmap="Blues", vmin=0, vmax=1)
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(token_labels, rotation=45, ha="right")
ax.set_yticklabels(token_labels)
ax.set_xlabel("Key (attending to)")
ax.set_ylabel("Query (attending from)")
ax.set_title("Scaled Dot-Product Attention Weights")

# Add text annotations
for i in range(seq_len):
    for j in range(seq_len):
        val = attention_weights[i, j].item()
        color = "white" if val > 0.5 else "black"
        ax.text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=9)

fig.colorbar(im, ax=ax, shrink=0.8)
plt.tight_layout()
plt.show()

### Wrapping Attention as a Function

Let's consolidate the four steps into a reusable function.

In [None]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute scaled dot-product attention.

    Args:
        Q: Query tensor of shape (..., seq_len, d_k).
        K: Key tensor of shape (..., seq_len, d_k).
        V: Value tensor of shape (..., seq_len, d_v).
        mask: Optional boolean mask. True values are masked (not attended to).

    Returns:
        output: Weighted sum of values, shape (..., seq_len, d_v).
        weights: Attention weights, shape (..., seq_len, seq_len).
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)
    return output, weights


# Quick test
out, wts = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {out.shape}")
print(f"Weights shape: {wts.shape}")
print(f"Weights row sums: {wts.sum(dim=-1)}")
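PyTorch 2.0 and later ship a fused `torch.nn.functional.scaled_dot_product_attention`, which makes a handy cross-check. One trap: its boolean `attn_mask` uses the *opposite* convention from the function above (`True` there means "take part in attention"), so a mask must be inverted before passing it. The sketch below redefines the notebook's function so it runs on its own, builds a causal mask in which `True` marks future positions, and compares the results; it assumes PyTorch 2.0+ and uses a 4-D `(batch, heads, seq_len, d_k)` layout:

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Same logic as the notebook's version: True in `mask` means "blocked"."""
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


torch.manual_seed(0)
seq_len, d_k = 6, 8
Q, K, V = (torch.randn(1, 1, seq_len, d_k) for _ in range(3))

# Causal mask in OUR convention: True above the diagonal = future token, blocked
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
ours, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

# PyTorch's boolean attn_mask uses True = "allowed", so pass the inverse
reference = F.scaled_dot_product_attention(Q, K, V, attn_mask=~causal_mask)
# is_causal=True builds the equivalent lower-triangular mask internally
reference_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(ours, reference, atol=1e-5))
print(torch.allclose(ours, reference_causal, atol=1e-5))
print(weights[0, 0, 0])  # first token can only attend to itself
```

A causal mask like this is what decoder-only language models use so that each token sees only earlier positions.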

## Multi-Head Attention

A single attention head computes one set of attention weights. But different
aspects of language require different kinds of attention:

- One head might learn to attend to the **syntactic subject** of a verb.
- Another head might learn to attend to **nearby adjectives**.
- Another might specialize in **coreference** (linking pronouns to nouns).

**Multi-head attention** runs multiple attention heads in parallel, each with
its own learned Q/K/V projections. The outputs are concatenated and projected
back to the model dimension.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

where each $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.

The key insight: each head operates on a **smaller dimension**,
$d_k = d_{\text{model}} / h$, so the total computation is roughly the same as
single-head attention with the full dimension.
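To see what "a smaller dimension per head" means concretely, this sketch splits a projected tensor into heads with the same `view` + `transpose` pattern that `MultiHeadAttention` uses below, then merges it back. Head $i$ simply receives a contiguous slice of $d_k = d_{\text{model}}/h$ features from every token (the toy sizes are arbitrary):

```python
import torch

torch.manual_seed(0)

batch, seq_len, d_model, n_heads = 2, 5, 12, 3
d_k = d_model // n_heads  # 4 features per head

projected = torch.randn(batch, seq_len, d_model)  # e.g., the output of W_q

# Split: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
heads = projected.view(batch, seq_len, n_heads, d_k).transpose(1, 2)

# Head i holds features [i * d_k, (i + 1) * d_k) of every token
for i in range(n_heads):
    assert torch.equal(heads[:, i], projected[..., i * d_k:(i + 1) * d_k])

# Merge: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, projected))  # the round trip is lossless
```

Because the $h$ heads together still cover all $d_{\text{model}}$ features, computing the scores costs $h \cdot n^2 d_k = n^2 d_{\text{model}}$ multiply-adds for a sequence of length $n$, the same as one full-width head.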

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism.

    Splits the model dimension into multiple heads, applies scaled
    dot-product attention independently per head, then concatenates
    and projects the results.
    """

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head

        # Linear projections for Q, K, V (all heads packed into one matrix)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).
            mask: Optional boolean mask broadcastable to
                (batch, n_heads, seq_len, seq_len); True = masked out.

        Returns:
            output: Shape (batch, seq_len, d_model).
            attention_weights: Shape (batch, n_heads, seq_len, seq_len).
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # Reshape to (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Apply scaled dot-product attention per head
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.d_model)
        )

        # Final linear projection
        output = self.W_o(attn_output)
        return output, attn_weights

In [None]:
# Test the multi-head attention module
d_model = 64
n_heads = 4
batch_size = 2
seq_len = 8

mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
x = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = mha(x)
print(f"Input shape:            {x.shape}")
print(f"Output shape:           {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  -> (batch={batch_size}, heads={n_heads}, seq_len={seq_len}, seq_len={seq_len})")

# Count parameters
n_params = sum(p.numel() for p in mha.parameters())
print(f"\nMultiHeadAttention parameters: {n_params:,}")
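The printed count can be derived by hand: the module holds four `d_model x d_model` linear layers, each with a bias, so it has $4(d_{\text{model}}^2 + d_{\text{model}})$ parameters. As a sanity check, PyTorch's built-in `nn.MultiheadAttention` (default settings, equal key and value dimensions) packs the Q/K/V projections into one matrix but should arrive at the same total:

```python
import torch.nn as nn

d_model, n_heads = 64, 4

# Hand count: W_q, W_k, W_v, W_o, each d_model*d_model weights + d_model biases
expected = 4 * (d_model * d_model + d_model)

builtin = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads)
builtin_count = sum(p.numel() for p in builtin.parameters())

print(f"formula: {expected:,}  nn.MultiheadAttention: {builtin_count:,}")
```

Both come to 16,640 for `d_model=64`. Note that `n_heads` does not appear in the formula: splitting into more heads reshapes the same projections rather than adding weights.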

## Feed-Forward Network

After attention, each token passes independently through a position-wise
feed-forward network (FFN). This is simply two linear layers with a
nonlinearity in between:

$$\text{FFN}(x) = W_2 \cdot \text{GELU}(W_1 x + b_1) + b_2$$

The inner dimension is typically 4x the model dimension (e.g., GPT-2 small and
BERT-base pair d_model=768 with d_ff=3072). This expansion-then-compression
pattern lets the network learn richer per-token representations.
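"Position-wise" means the same two layers are applied to every token separately, with no mixing across the sequence. A quick sketch, using an `nn.Sequential` stand-in with the same layer structure as the `FeedForward` class below (toy sizes assumed), confirms that running the FFN on the whole sequence matches running it one token at a time:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, d_ff = 16, 64  # the usual 4x expansion
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 7, d_model)  # (batch, seq_len, d_model)

whole_sequence = ffn(x)
# Apply the FFN to each position on its own, then restack along the sequence axis
token_by_token = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(torch.allclose(whole_sequence, token_by_token, atol=1e-5))
```

This is why, within a transformer block, all communication between tokens happens in the attention sub-layer.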

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Two linear transformations with a GELU activation in between.
    Applied independently to each position (token) in the sequence.
    """

    def __init__(self, d_model: int, d_ff: int | None = None) -> None:
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: expand, activate, project back.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).

        Returns:
            Output tensor of shape (batch, seq_len, d_model).
        """
        return self.linear2(self.activation(self.linear1(x)))


# Test
ffn = FeedForward(d_model=64)
x = torch.randn(2, 8, 64)
out = ffn(x)
print(f"FFN input shape:  {x.shape}")
print(f"FFN output shape: {out.shape}")

n_params = sum(p.numel() for p in ffn.parameters())
print(f"FFN parameters:   {n_params:,}")

## Full Transformer Block

A transformer block combines everything above with **residual connections**
and **layer normalization**. The structure (using pre-norm, which is standard
in modern LLMs):

```
x -> LayerNorm -> MultiHeadAttention -> + (residual) -> LayerNorm -> FFN -> + (residual) -> output
|                                       ^              |                    ^
+---------------------------------------+              +--------------------+
```
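Written as equations, the pre-norm block in the diagram computes the following, where $\text{LN}$ is layer normalization and $\text{MHA}$ is multi-head attention. For comparison, the original 2017 transformer used the post-norm arrangement, which applies $\text{LN}$ after each residual addition instead:

$$
\begin{aligned}
&\text{Pre-norm:} && h = x + \text{MHA}(\text{LN}_1(x)), && y = h + \text{FFN}(\text{LN}_2(h)) \\
&\text{Post-norm:} && h = \text{LN}_1(x + \text{MHA}(x)), && y = \text{LN}_2(h + \text{FFN}(h))
\end{aligned}
$$

In the pre-norm form the residual path from $x$ to $y$ passes through no normalization at all, which is commonly cited as one reason it trains more stably in deep stacks.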

**Residual connections** add the input back to the output of each sub-layer.
This allows gradients to flow directly through the network, enabling training
of very deep models (GPT-3 has 96 layers).
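The gradient argument follows from the chain rule: a residual sub-layer computes $x + f(x)$, whose derivative is $1 + f'(x)$, so even when $f'(x)$ is small the product across many layers does not collapse toward zero. The scalar toy below stacks a hypothetical "weak" sub-layer $f(x) = 0.1\tanh(x)$ twenty times (not a real transformer layer) and compares the gradient that reaches the input with and without the skip connection:

```python
import torch


def weak_layer(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sub-layer whose derivative is at most 0.1."""
    return 0.1 * torch.tanh(x)


def input_gradient(n_layers: int, residual: bool) -> float:
    """d(output)/d(input) after stacking n_layers weak sub-layers."""
    x = torch.tensor(1.0, requires_grad=True)
    h = x
    for _ in range(n_layers):
        h = (h + weak_layer(h)) if residual else weak_layer(h)
    h.backward()
    return x.grad.item()


print(f"without residual: {input_gradient(20, residual=False):.3e}")
print(f"with residual:    {input_gradient(20, residual=True):.3e}")
```

Without the skip connection the gradient shrinks by roughly a factor of 10 per layer; with it, every layer contributes a factor of at least 1.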

**Layer normalization** stabilizes training by normalizing each token's
representation to zero mean and unit variance across its features, then
applying a learned per-feature scale and shift.
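A minimal from-scratch version makes the normalization explicit. It assumes PyTorch's documented behavior for `nn.LayerNorm`: biased variance over the last dimension, `eps=1e-5` by default, and a scale (`weight`, initialized to 1) and shift (`bias`, initialized to 0) applied after normalizing:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 8
x = torch.randn(2, 3, d_model) * 5 + 2  # deliberately off-center, large-scale input
layer_norm = nn.LayerNorm(d_model)       # eps=1e-5 by default


def layer_norm_by_hand(x, weight, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, as nn.LayerNorm uses
    normalized = (x - mean) / torch.sqrt(var + eps)
    return normalized * weight + bias                   # learned per-feature scale and shift


ours = layer_norm_by_hand(x, layer_norm.weight, layer_norm.bias)
print(torch.allclose(ours, layer_norm(x), atol=1e-5))
print(ours.mean(dim=-1))                 # approximately 0 for every token
print(ours.std(dim=-1, unbiased=False))  # approximately 1 for every token
```

The statistics are computed per token (over the last dimension only), so one token's values never affect another token's normalization.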

In [None]:
class TransformerBlock(nn.Module):
    """A single transformer block with pre-norm architecture.

    Components:
        1. LayerNorm -> Multi-Head Attention -> Residual
        2. LayerNorm -> Feed-Forward Network -> Residual
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int | None = None,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        verbose: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the transformer block.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).
            mask: Optional boolean mask broadcastable to
                (batch, n_heads, seq_len, seq_len); True = masked out.
            verbose: If True, print shapes at each stage.

        Returns:
            output: Shape (batch, seq_len, d_model).
            attention_weights: Shape (batch, n_heads, seq_len, seq_len).
        """
        # Sub-layer 1: LayerNorm -> Attention -> Residual
        normed = self.norm1(x)
        if verbose:
            print(f"  After LayerNorm 1:    {normed.shape}")

        attn_out, attn_weights = self.attn(normed, mask)
        if verbose:
            print(f"  After Attention:      {attn_out.shape}")

        x = x + attn_out  # Residual connection
        if verbose:
            print(f"  After Residual 1:     {x.shape}")

        # Sub-layer 2: LayerNorm -> FFN -> Residual
        normed = self.norm2(x)
        if verbose:
            print(f"  After LayerNorm 2:    {normed.shape}")

        ffn_out = self.ffn(normed)
        if verbose:
            print(f"  After FFN:            {ffn_out.shape}")

        x = x + ffn_out  # Residual connection
        if verbose:
            print(f"  After Residual 2:     {x.shape}")

        return x, attn_weights

In [None]:
# Build and test the full transformer block
d_model = 64
n_heads = 4
batch_size = 1
seq_len = 10

block = TransformerBlock(d_model=d_model, n_heads=n_heads)
x = torch.randn(batch_size, seq_len, d_model)

print(f"Input shape: {x.shape}")
print("\nShape at each stage:")
output, attn_weights = block(x)

print(f"\nFinal output shape:      {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

# Count total parameters
n_params = sum(p.numel() for p in block.parameters())
print(f"\nTotal TransformerBlock parameters: {n_params:,}")

## Visualizing Attention on Legal Text

Let's create a simple tokenized legal sentence and pass it through our
transformer block. We will extract the attention weights and visualize
which words attend to which.

Since we have not trained the model, the attention patterns are random --
but the visualization technique is the same one researchers use to
interpret trained models.

In [None]:
# A simple legal sentence, manually tokenized into words.
# In practice you would use a real tokenizer, but word-level
# tokens make the attention heatmap interpretable.
legal_tokens = [
    "The", "court", "held", "that", "the",
    "defendant", "was", "liable", "for", "damages",
]
seq_len = len(legal_tokens)
d_model = 64
n_heads = 4

# Create a simple embedding: random vectors for each token.
# In a trained model, these would be learned embeddings.
torch.manual_seed(123)
embeddings = torch.randn(1, seq_len, d_model)

# Build a fresh transformer block (untrained)
block = TransformerBlock(d_model=d_model, n_heads=n_heads)

print(f"Legal sentence: {' '.join(legal_tokens)}")
print(f"Number of tokens: {seq_len}")
print(f"Embedding shape: {embeddings.shape}")
print()

with torch.no_grad():
    output, attn_weights = block(embeddings)

print(f"\nAttention weights shape: {attn_weights.shape}")
print(f"  -> (batch=1, heads={n_heads}, tokens={seq_len}, tokens={seq_len})")

In [None]:
def plot_attention_heads(
    attention_weights: torch.Tensor,
    tokens: list[str],
    n_heads: int,
) -> None:
    """Plot attention heatmaps for each head.

    Args:
        attention_weights: Tensor of shape (1, n_heads, seq_len, seq_len).
        tokens: List of token strings for axis labels.
        n_heads: Number of attention heads.
    """
    fig, axes = plt.subplots(1, n_heads, figsize=(5 * n_heads, 5))
    if n_heads == 1:
        axes = [axes]

    for head_idx in range(n_heads):
        ax = axes[head_idx]
        weights = attention_weights[0, head_idx].detach().numpy()

        im = ax.imshow(weights, cmap="Blues", vmin=0, vmax=weights.max())
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=9)
        ax.set_yticklabels(tokens, fontsize=9)
        ax.set_title(f"Head {head_idx + 1}", fontsize=12)
        ax.set_xlabel("Key (attending to)")
        if head_idx == 0:
            ax.set_ylabel("Query (attending from)")
        fig.colorbar(im, ax=ax, shrink=0.8)

    sentence = " ".join(tokens)
    fig.suptitle(
        f'Attention Weights: "{sentence}"',
        fontsize=13,
        y=1.02,
    )
    plt.tight_layout()
    plt.show()


plot_attention_heads(attn_weights, legal_tokens, n_heads)

In [None]:
# Also plot the average attention across all heads
avg_attention = attn_weights[0].mean(dim=0).detach().numpy()

fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(avg_attention, cmap="Blues", vmin=0)
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(legal_tokens, rotation=45, ha="right", fontsize=10)
ax.set_yticklabels(legal_tokens, fontsize=10)
ax.set_xlabel("Key (attending to)", fontsize=11)
ax.set_ylabel("Query (attending from)", fontsize=11)
ax.set_title("Average Attention Across All Heads", fontsize=13)

for i in range(seq_len):
    for j in range(seq_len):
        val = avg_attention[i, j]
        color = "white" if val > avg_attention.max() * 0.6 else "black"
        ax.text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=8)

fig.colorbar(im, ax=ax, shrink=0.8)
plt.tight_layout()
plt.show()

### Interpreting the Patterns

Since this model is untrained, the attention patterns are essentially random.
In a trained model, you might see meaningful patterns such as:

- **"liable"** strongly attending to **"defendant"** (subject of the predicate)
- **"damages"** attending to **"liable"** (semantic dependency)
- **"that"** attending to **"held"** (syntactic relationship)

The visualization technique itself is what matters here. Researchers use
exactly this approach (with tools like BertViz) to understand what trained
transformers learn.
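Heatmaps get hard to read as sequences grow, so a common complement is to list, for each query token, the key it weights most. A minimal sketch, assuming a weights tensor shaped `(n_heads, seq_len, seq_len)` like `attn_weights[0]` above; the helper name `top_attended` and the hand-made weights are hypothetical, chosen so the result is easy to check:

```python
import torch


def top_attended(weights: torch.Tensor, tokens: list[str], head: int) -> list[tuple[str, str, float]]:
    """For one head, return (query, most-attended key, weight) for every query token."""
    best_weight, best_key = weights[head].max(dim=-1)
    return [
        (tokens[q], tokens[k], w)
        for q, (k, w) in enumerate(zip(best_key.tolist(), best_weight.tolist()))
    ]


tokens = ["defendant", "was", "liable"]
# Hand-made weights for a single head (each row sums to 1)
weights = torch.tensor([[
    [0.80, 0.10, 0.10],
    [0.30, 0.40, 0.30],
    [0.70, 0.10, 0.20],
]])

for query, key, w in top_attended(weights, tokens, head=0):
    print(f"{query:>10} -> {key:<10} ({w:.2f})")
```

Applied to `attn_weights[0]` with `legal_tokens`, the same helper lists each word's strongest link per head; for the untrained block above, those links are essentially arbitrary.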

## Exercises

### Exercise (a): Experiment with Head Count

How do attention patterns change with different numbers of heads? Try creating
transformer blocks with 1, 4, and 8 heads (keeping `d_model=64`).

For each configuration:
1. Pass the same legal sentence embeddings through the block.
2. Visualize the attention patterns.
3. Observe: with 1 head, the model has one "view" of the sequence. With 8
   heads, each head has a smaller dimension (64/8 = 8) but there are 8
   different attention patterns.

Questions to consider:
- Do more heads produce more diverse attention patterns?
- What happens to the per-head dimension as you increase head count?
- In a trained model, would you expect 8 heads to capture more linguistic
  phenomena than 1 head? Why?

```python
# Starter code
for n_heads in [1, 4, 8]:
    block = TransformerBlock(d_model=64, n_heads=n_heads)
    with torch.no_grad():
        _, weights = block(embeddings, verbose=False)
    print(f"\n--- {n_heads} head(s), d_k={64 // n_heads} per head ---")
    plot_attention_heads(weights, legal_tokens, n_heads)
```

### Exercise (b): Remove Layer Normalization

Layer normalization is critical for stable training. What happens without it?

1. Create a modified `TransformerBlock` that replaces `LayerNorm` with an
   identity function (`nn.Identity()`).
2. Pass the same input through 10 consecutive forward passes (feed the output
   back as input each time).
3. After each pass, record the mean and standard deviation of the output tensor
   and the attention weights.
4. Compare with the original block that uses `LayerNorm`.

Questions to consider:
- Do the activations grow or shrink without normalization?
- How do the attention weight distributions change across passes?
- Why is this a problem for training deep networks (which stack many blocks)?

```python
# Starter code
class TransformerBlockNoNorm(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = nn.Identity()  # No normalization
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.Identity()
        self.ffn = FeedForward(d_model)

    def forward(self, x, mask=None):
        normed = self.norm1(x)
        attn_out, attn_weights = self.attn(normed, mask)
        x = x + attn_out
        normed = self.norm2(x)
        ffn_out = self.ffn(normed)
        x = x + ffn_out
        return x, attn_weights

# Run 10 forward passes, track stats
x_input = embeddings.clone()
block_no_norm = TransformerBlockNoNorm(d_model=64, n_heads=4)

for i in range(10):
    with torch.no_grad():
        x_input, weights = block_no_norm(x_input)
    print(f"Pass {i+1}: mean={x_input.mean():.4f}, std={x_input.std():.4f}")
```