# Attention Visualization for Encoder-Decoder Transformer

This notebook provides visualization tools for understanding attention patterns in the Encoder-Decoder Transformer model.

Visualizations include:
- Encoder self-attention
- Decoder self-attention (causal)
- Decoder source attention (cross-attention)

Requires: `altair`, `pandas`, `torch`

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import pandas as pd
import altair as alt

# Enable altair to handle more than 5000 rows
alt.data_transformers.disable_max_rows()

## Load Model and Data

In [None]:
from src.transformer import make_model, greedy_decode
from src.shared.training import Batch

# Create a small model for demonstration.
# NOTE: no checkpoint is loaded in this cell, so the attention maps drawn below
# come from whatever weights make_model returns. In practice, load a trained model.
V = 11  # Vocabulary size for copy task; passed as both of make_model's first two arguments
model = make_model(V, V, N=2, d_model=128, h=4)
# Put the model in evaluation mode before the visualization forward pass
# (presumably this turns off dropout -- confirm against the modules make_model builds).
model.eval()

# Report the total element count summed over every parameter tensor.
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

## Visualization Helper Functions

In [None]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """
    Convert a dense attention matrix to a long-format DataFrame for plotting.

    Only the top-left ``max_row`` x ``max_col`` corner of ``m`` is emitted, one
    record per cell in row-major order. Positions that have no entry in the
    corresponding token list are labelled ``<blank>``.

    Args:
        m: 2-D matrix exposing ``.shape`` and ``m[r, c]`` indexing
        max_row: Maximum rows to include
        max_col: Maximum columns to include
        row_tokens: List of row token labels
        col_tokens: List of column token labels

    Returns:
        DataFrame with columns ``row``, ``column``, ``value``, ``row_token``
        and ``col_token``; the token columns are zero-padded position indices
        followed by the token text (e.g. ``"003 foo"``).
    """

    def _label(index, tokens):
        # Zero-padded index keeps labels unique and sortable on the chart axis.
        token = tokens[index] if len(tokens) > index else "<blank>"
        return f"{index:03d} {token}"

    # Bound the ranges up front so only the retained cells are visited,
    # instead of scanning the whole matrix and discarding most of it.
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)

    # Column labels are the same for every row, so build them once.
    col_labels = [_label(c, col_tokens) for c in range(n_cols)]

    records = []
    for r in range(n_rows):
        row_label = _label(r, row_tokens)
        records.extend(
            (r, c, float(m[r, c]), row_label, col_labels[c])
            for c in range(n_cols)
        )

    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )


def attn_map(attn, layer, head, row_tokens, col_tokens, max_dim=30):
    """
    Render a single attention head as an interactive altair heatmap.

    Args:
        attn: Attention tensor indexed as ``attn[batch, head]``; only batch
            element 0 is plotted
        layer: Layer index; accepted for interface symmetry but not used here
        head: Head index to visualize (also shown in the chart title)
        row_tokens: Row (query) token labels
        col_tokens: Column (key) token labels
        max_dim: Cap applied to both the number of rows and of columns

    Returns:
        Altair Chart object
    """
    # Flatten the selected head's weights into one record per (query, key) cell.
    head_weights = attn[0, head].data
    cells = mtx2df(head_weights, max_dim, max_dim, row_tokens, col_tokens)

    key_axis = alt.X("col_token", axis=alt.Axis(title="Key"))
    query_axis = alt.Y("row_token", axis=alt.Axis(title="Query"))
    shading = alt.Color("value:Q", scale=alt.Scale(scheme="viridis"))

    heatmap = alt.Chart(data=cells).mark_rect()
    heatmap = heatmap.encode(
        x=key_axis,
        y=query_axis,
        color=shading,
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Head {head}")
    return heatmap.interactive()

In [None]:
def get_encoder_attn(model, layer):
    """Return the ``attn`` value held by the self-attention module of encoder layer ``layer``."""
    encoder_layer = model.encoder.layers[layer]
    return encoder_layer.self_attn.attn


def get_decoder_self_attn(model, layer):
    """Return the ``attn`` value held by the self-attention module of decoder layer ``layer``."""
    decoder_layer = model.decoder.layers[layer]
    return decoder_layer.self_attn.attn


def get_decoder_src_attn(model, layer):
    """Return the ``attn`` value held by the source (cross) attention module of decoder layer ``layer``."""
    decoder_layer = model.decoder.layers[layer]
    return decoder_layer.src_attn.attn


def visualize_layer(model, layer, getter_fn, ntokens, row_tokens, col_tokens):
    """
    Draw every attention head of one layer as a grid of heatmaps.

    Args:
        model: The transformer model
        layer: Layer index handed to ``getter_fn`` (the title shows ``layer + 1``)
        getter_fn: Callable ``(model, layer)`` returning the attention weights;
            the head count is read from dimension 1 of its result
        ntokens: Cap on the rows and columns drawn for each head
        row_tokens: Query token labels
        col_tokens: Key token labels

    Returns:
        Altair Chart with all heads, two per row
    """
    weights = getter_fn(model, layer)

    head_charts = []
    for head_idx in range(weights.shape[1]):
        head_charts.append(
            attn_map(
                weights,
                layer,
                head_idx,
                row_tokens=row_tokens,
                col_tokens=col_tokens,
                max_dim=ntokens,
            )
        )

    # Pair heads side by side; an odd head count leaves the last one alone.
    bands = [
        head_charts[start] | head_charts[start + 1]
        if start + 1 < len(head_charts)
        else head_charts[start]
        for start in range(0, len(head_charts), 2)
    ]

    # Stack the two-wide bands vertically, top to bottom.
    grid = bands[0]
    for band in bands[1:]:
        grid = grid & band

    return grid.properties(title=f"Layer {layer + 1}")

## Run a Forward Pass to Capture Attention

In [None]:
from src.transformer.attention import subsequent_mask

# Create sample input (copy task): a single sequence of token ids 1..10
src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# All-ones mask, i.e. no source position is hidden
# (assumes nonzero means "may attend" -- confirm against the attention implementation)
src_mask = torch.ones(1, 1, src.size(1))

# Create target (for teacher forcing visualization)
tgt = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]])
tgt_mask = subsequent_mask(tgt.size(1))

# This pass is inference only (nothing is backpropagated), so run it under
# no_grad: no autograd graph is built, and the attention tensors that the
# getter helpers later read from the modules do not keep one alive.
with torch.no_grad():
    # Encode source
    memory = model.encode(src, src_mask)

    # Decode
    out = model.decode(memory, src_mask, tgt, tgt_mask)

# Token labels for visualization (position-based, not the token ids themselves)
src_tokens = [f"src_{i}" for i in range(src.size(1))]
tgt_tokens = [f"tgt_{i}" for i in range(tgt.size(1))]

print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Output shape: {out.shape}")

## Encoder Self-Attention

The encoder uses bidirectional self-attention - each position can attend to every position in the source sequence, including itself.

In [None]:
# Visualize encoder self-attention for layer 0
# (index 0 is the first layer; visualize_layer titles the chart "Layer 1").
# Queries and keys are both source positions, so src_tokens labels both axes.
visualize_layer(
    model,
    layer=0,
    getter_fn=get_encoder_attn,
    ntokens=len(src_tokens),
    row_tokens=src_tokens,
    col_tokens=src_tokens,
)

In [None]:
# Visualize encoder self-attention for layer 1
# (index 1 is the second layer; visualize_layer titles the chart "Layer 2").
# Queries and keys are both source positions, so src_tokens labels both axes.
visualize_layer(
    model,
    layer=1,
    getter_fn=get_encoder_attn,
    ntokens=len(src_tokens),
    row_tokens=src_tokens,
    col_tokens=src_tokens,
)

## Decoder Self-Attention (Causal)

The decoder uses causal (masked) self-attention - each position can only attend to previous positions and itself. This preserves the autoregressive property.

In [None]:
# Visualize decoder self-attention for layer 0
# (index 0 is the first layer; visualize_layer titles the chart "Layer 1").
# Queries and keys are both target positions, so tgt_tokens labels both axes.
visualize_layer(
    model,
    layer=0,
    getter_fn=get_decoder_self_attn,
    ntokens=len(tgt_tokens),
    row_tokens=tgt_tokens,
    col_tokens=tgt_tokens,
)

In [None]:
# Visualize decoder self-attention for layer 1
# (index 1 is the second layer; visualize_layer titles the chart "Layer 2").
# Queries and keys are both target positions, so tgt_tokens labels both axes.
visualize_layer(
    model,
    layer=1,
    getter_fn=get_decoder_self_attn,
    ntokens=len(tgt_tokens),
    row_tokens=tgt_tokens,
    col_tokens=tgt_tokens,
)

## Decoder Source Attention (Cross-Attention)

The decoder's cross-attention allows each target position to attend to all source positions. This is how the decoder "reads" the encoded source.

In [None]:
# Visualize decoder source attention for layer 0
# (index 0 is the first layer; visualize_layer titles the chart "Layer 1").
# Rows are target (query) positions and columns are source (key) positions.
# ntokens caps rows and columns alike, so pass the longer of the two lengths
# to keep either axis from being truncated.
visualize_layer(
    model,
    layer=0,
    getter_fn=get_decoder_src_attn,
    ntokens=max(len(src_tokens), len(tgt_tokens)),
    row_tokens=tgt_tokens,
    col_tokens=src_tokens,
)

In [None]:
# Visualize decoder source attention for layer 1
# (index 1 is the second layer; visualize_layer titles the chart "Layer 2").
# Rows are target (query) positions and columns are source (key) positions.
# ntokens caps rows and columns alike, so pass the longer of the two lengths
# to keep either axis from being truncated.
visualize_layer(
    model,
    layer=1,
    getter_fn=get_decoder_src_attn,
    ntokens=max(len(src_tokens), len(tgt_tokens)),
    row_tokens=tgt_tokens,
    col_tokens=src_tokens,
)

## Subsequent Mask Visualization

This shows the causal mask used in decoder self-attention.

In [None]:
def visualize_subsequent_mask(size=20):
    """
    Visualize the subsequent (causal) mask as a heatmap.

    Args:
        size: Number of positions along each axis of the mask

    Returns:
        Altair Chart with key position on x, query position on y and the
        mask value as colour
    """
    mask = subsequent_mask(size)
    # assumes mask[0] is a (size, size) grid indexed [query, key] -- that is how
    # the cells are read below; confirm against subsequent_mask
    grid = mask[0].float()

    # Build every cell in one DataFrame rather than concatenating size**2
    # single-row frames. Key position is the outer loop, so rows come out
    # key-major.
    cells = [
        (float(grid[x, y]), y, x)
        for y in range(size)
        for x in range(size)
    ]
    data = pd.DataFrame(
        cells,
        columns=["Subsequent Mask", "Position (Key)", "Position (Query)"],
    )

    return (
        alt.Chart(data)
        .mark_rect()
        .properties(height=400, width=400, title="Causal (Subsequent) Mask")
        .encode(
            alt.X("Position (Key):O"),
            alt.Y("Position (Query):O"),
            alt.Color(
                "Subsequent Mask:Q",
                scale=alt.Scale(scheme="viridis"),
                legend=alt.Legend(title="Can Attend")
            ),
        )
        .interactive()
    )

# Draw the mask for 20 positions (the same value as the function's default size).
visualize_subsequent_mask(20)

## Positional Encoding Visualization

In [None]:
from src.transformer.embeddings import PositionalEncoding

def visualize_positional_encoding(d_model=64, max_len=100, dims_to_show=None):
    """
    Plot selected positional-encoding dimensions as lines over position.

    Args:
        d_model: Width passed to ``PositionalEncoding`` and to the zero input
        max_len: Number of positions to plot
        dims_to_show: Dimension indices to draw; defaults to dimensions 0-7

    Returns:
        Altair Chart with one line per selected dimension
    """
    dims_to_show = [0, 1, 2, 3, 4, 5, 6, 7] if dims_to_show is None else dims_to_show

    # An all-zero input with dropout 0 is fed through the module
    # (presumably leaving only the positional signal -- confirm against PositionalEncoding).
    encoder = PositionalEncoding(d_model, dropout=0)
    encoded = encoder.forward(torch.zeros(1, max_len, d_model))

    # One frame per dimension, stacked into long format for altair.
    frames = []
    for dim in dims_to_show:
        frames.append(
            pd.DataFrame(
                {
                    "Encoding Value": encoded[0, :, dim].detach().numpy(),
                    "Dimension": f"dim_{dim}",
                    "Position": list(range(max_len)),
                }
            )
        )
    data = pd.concat(frames)

    lines = alt.Chart(data).mark_line()
    lines = lines.properties(
        width=800, height=300, title="Sinusoidal Positional Encodings"
    )
    lines = lines.encode(
        x=alt.X("Position:Q"),
        y=alt.Y("Encoding Value:Q"),
        color=alt.Color("Dimension:N"),
    )
    return lines.interactive()

# Use the defaults: d_model=64, 100 positions, dimensions 0-7.
visualize_positional_encoding()

## Learning Rate Schedule Visualization

In [None]:
from src.shared.lr_scheduler import rate

def visualize_lr_schedule():
    """
    Plot the Noam learning rate schedule for a few example configurations.

    Each configuration is a ``(d_model, factor, warmup, label)`` tuple; the
    rate is evaluated with ``rate`` at every step from 0 to 19999.

    Returns:
        Altair Chart with one line per configuration
    """
    max_steps = 20000
    configs = [
        (512, 1, 4000, "d_model=512, warmup=4000"),
        (512, 1, 8000, "d_model=512, warmup=8000"),
        (256, 1, 4000, "d_model=256, warmup=4000"),
    ]

    # One long-format frame per configuration, stacked for altair.
    frames = []
    for d_model, factor, warmup, config_name in configs:
        schedule = [rate(step, d_model, factor, warmup) for step in range(max_steps)]
        frames.append(
            pd.DataFrame(
                {
                    "Learning Rate": schedule,
                    "Configuration": config_name,
                    "Step": list(range(max_steps)),
                }
            )
        )
    data = pd.concat(frames)

    lines = alt.Chart(data).mark_line()
    lines = lines.properties(width=600, height=300, title="Noam Learning Rate Schedule")
    lines = lines.encode(
        x=alt.X("Step:Q"),
        y=alt.Y("Learning Rate:Q"),
        color=alt.Color("Configuration:N"),
    )
    return lines.interactive()

# Plot the three configurations defined inside visualize_lr_schedule.
visualize_lr_schedule()

## Label Smoothing Visualization

In [None]:
from src.shared.loss import LabelSmoothing

def visualize_label_smoothing(vocab_size=5, smoothing=0.4):
    """
    Visualize the smoothed target distribution built by ``LabelSmoothing``.

    A fixed batch of five dummy predictions and targets is run through the
    criterion, and the ``true_dist`` it stores is drawn as a heatmap.

    Args:
        vocab_size: Vocabulary size given to ``LabelSmoothing`` and number of
            vocabulary columns plotted
        smoothing: Smoothing amount given to ``LabelSmoothing``

    Returns:
        Altair Chart with vocabulary index on x, example on y and the target
        probability as colour
    """
    crit = LabelSmoothing(vocab_size, padding_idx=0, smoothing=smoothing)

    # Dummy predictions: the same 5-way distribution for every example.
    # NOTE(review): the prediction width is fixed at 5, so a vocab_size other
    # than 5 probably does not match it -- confirm what LabelSmoothing expects.
    predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0]] * 5)
    target = torch.LongTensor([2, 1, 0, 3, 3])

    # Compute loss to populate true_dist (the loss value itself is not needed)
    crit(predict.log(), target)

    true_dist = crit.true_dist
    # Take the example count from the batch instead of a separate hard-coded 5.
    n_examples = predict.size(0)

    # Build every cell in one DataFrame rather than concatenating one
    # single-row frame per cell. Vocabulary index is the outer loop, so rows
    # come out vocabulary-major.
    cells = [
        (float(true_dist[x, y]), y, x)
        for y in range(vocab_size)
        for x in range(n_examples)
    ]
    data = pd.DataFrame(
        cells,
        columns=["Target Distribution", "Vocabulary Index", "Example"],
    )

    return (
        alt.Chart(data)
        .mark_rect()
        .properties(
            height=200,
            width=200,
            title=f"Label Smoothing (smoothing={smoothing})"
        )
        .encode(
            alt.X("Vocabulary Index:O", title="Vocab Index"),
            alt.Y("Example:O", title="Example"),
            alt.Color(
                "Target Distribution:Q",
                scale=alt.Scale(scheme="viridis")
            ),
        )
        .interactive()
    )

# Use the defaults: vocab_size=5, smoothing=0.4.
visualize_label_smoothing()