This notebook focuses on attention diagnostics and follows `transformer_visualization.ipynb`.


In [1]:
# Bring in model definitions and dependencies from the architecture notebook
%run ./transformer_architecture.ipynb


  validate(nb)


FFN(
  (fc1): Linear(in_features=512, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=512, bias=True)
)
TransformerEncoder(
  (att): TransformerAttention(
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (ffn): FFN(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (LayerNorm_att): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (LayerNorm_ffn): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
TransformerDecoder(
  (att): TransformerAttention(
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512,

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def _set_axis_labels(ax, tokens, axis):
    if tokens is None:
        return
    positions = range(len(tokens))
    if axis == "x":
        ax.set_xticks(list(positions))
        ax.set_xticklabels(tokens, rotation=45, ha="right")
    else:
        ax.set_yticks(list(positions))
        ax.set_yticklabels(tokens)

def plot_attention_heatmaps(att_weights, head_indices=None, query_tokens=None, key_tokens=None, title=None, cmap="magma", max_heads=4, batch_index=0):
    """Visualize per-head attention weights as heatmaps."""
    if not isinstance(att_weights, torch.Tensor):
        att_weights = torch.as_tensor(att_weights)
    att = att_weights.detach().cpu()
    if att.dim() != 4:
        raise ValueError("att_weights must have shape [batch, heads, query_len, key_len]")
    num_heads = att.size(1)
    if head_indices is None:
        head_indices = list(range(min(max_heads, num_heads)))
    if not head_indices:
        raise ValueError("Provide at least one head index to visualize.")
    fig, axes = plt.subplots(1, len(head_indices), figsize=(4.5 * len(head_indices), 4), squeeze=False, layout='constrained')
    axes = axes[0]
    subset = att[batch_index, head_indices]
    vmin = float(subset.min())
    vmax = float(subset.max())
    if np.isclose(vmax - vmin, 0.0):
        vmax = vmin + 1e-6
    for ax, head_idx in zip(axes, head_indices):
        head_map = att[batch_index, head_idx]
        if key_tokens is not None and len(key_tokens) != head_map.size(-1):
            raise ValueError("Number of key_tokens must match key length.")
        if query_tokens is not None and len(query_tokens) != head_map.size(-2):
            raise ValueError("Number of query_tokens must match query length.")
        im = ax.imshow(head_map, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto")
        ax.set_title(f"Head {head_idx}")
        ax.set_xlabel("Key Positions")
        ax.set_ylabel("Query Positions")
        _set_axis_labels(ax, key_tokens, "x")
        _set_axis_labels(ax, query_tokens, "y")
    fig.suptitle(title or "Attention Heatmaps", fontsize=13)
    cbar = fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
    cbar.set_label("Attention Weight")
    plt.show()

def plot_mean_attention(att_weights, query_tokens=None, key_tokens=None, title=None, cmap="magma", batch_index=0):
    """Plot the mean attention pattern across all heads."""
    if not isinstance(att_weights, torch.Tensor):
        att_weights = torch.as_tensor(att_weights)
    att = att_weights.detach().cpu()
    if att.dim() != 4:
        raise ValueError("att_weights must have shape [batch, heads, query_len, key_len]")
    mean_map = att.mean(dim=1)[batch_index]
    if key_tokens is not None and len(key_tokens) != mean_map.size(-1):
        raise ValueError("Number of key_tokens must match key length.")
    if query_tokens is not None and len(query_tokens) != mean_map.size(-2):
        raise ValueError("Number of query_tokens must match query length.")
    fig, ax = plt.subplots(figsize=(4.5, 4), layout='constrained')
    im = ax.imshow(mean_map, cmap=cmap, aspect="auto")
    ax.set_title(title or "Average Attention")
    ax.set_xlabel("Key Positions")
    ax.set_ylabel("Query Positions")
    _set_axis_labels(ax, key_tokens, "x")
    _set_axis_labels(ax, query_tokens, "y")
    cbar = fig.colorbar(im, fraction=0.046, pad=0.04)
    cbar.set_label("Attention Weight")
    plt.show()


### Attention Pattern Visualization
Shows:
- Self-attention patterns (causal masking)
- Cross-attention patterns
- Effect of padding masks
- How attention weights distribute

In [3]:
def test_decoder_causal_masking():
    torch.manual_seed(42)
    
    # Test parameters
    batch_size = 2
    seq_length = 5
    d_model = 512
    d_ff = 2048
    num_heads = 8
    
    decoder = TransformerDecoder(
        d_model=d_model,
        d_ff=d_ff,
        num_head=num_heads,
        dropout=0.1
    )
    decoder.eval()
    
    decoder_input = torch.randn(batch_size, seq_length, d_model)
    encoder_output = torch.randn(batch_size, seq_length, d_model)
    
    attention_scores = []
    
    def attention_hook(module, input, output):
        if not attention_scores:
            # Apply softmax to get actual attention probabilities
            scores = F.softmax(module.att_matrix, dim=-1)
            attention_scores.append(scores.detach())
    
    decoder.att.register_forward_hook(attention_hook)
    
    with torch.no_grad():
        output = decoder(decoder_input, encoder_output)
    
    att_weights = attention_scores[0]
    
    print("\nAttention Matrix Shape:", att_weights.shape)
    
    # Print attention pattern for first head of first batch
    print("\nAttention Pattern (first head):")
    print(att_weights[0, 0].round(decimals=4))
    
    # Check future tokens (should be 0)
    future_attention = att_weights[:, :, torch.triu_indices(seq_length, seq_length, offset=1)[0], 
                                        torch.triu_indices(seq_length, seq_length, offset=1)[1]]
    
    print("\nFuture Token Analysis:")
    print(f"Mean attention to future tokens: {future_attention.mean():.8f}")
    print(f"Max attention to future tokens: {future_attention.max():.8f}")
    print("Causal masking working:", "Yes" if future_attention.mean() < 1e-7 else "No")
    
    # Check present/past tokens
    present_past = att_weights[:, :, torch.tril_indices(seq_length, seq_length)[0],
                                    torch.tril_indices(seq_length, seq_length)[1]]
    
    print("\nPresent/Past Token Analysis:")
    print(f"Mean attention to present/past tokens: {present_past.mean():.4f}")
    print(f"Has non-zero attention patterns:", "Yes" if present_past.mean() > 0 else "No")
    
    # Verify each position's attention sums to 1
    attention_sums = att_weights.sum(dim=-1)
    print("\nAttention Sum Analysis:")
    print(f"Mean attention sum (should be 1): {attention_sums.mean():.4f}")
    print(f"Max deviation from 1: {(attention_sums - 1).abs().max():.8f}")
    
    return att_weights

attention_weights = test_decoder_causal_masking()


Attention Matrix Shape: torch.Size([2, 8, 5, 5])

Attention Pattern (first head):
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4465, 0.5535, 0.0000, 0.0000, 0.0000],
        [0.3403, 0.3496, 0.3101, 0.0000, 0.0000],
        [0.1965, 0.3485, 0.1174, 0.3377, 0.0000],
        [0.1563, 0.1571, 0.1859, 0.1948, 0.3059]])

Future Token Analysis:
Mean attention to future tokens: 0.00000000
Max attention to future tokens: 0.00000000
Causal masking working: Yes

Present/Past Token Analysis:
Mean attention to present/past tokens: 0.3333
Has non-zero attention patterns: Yes

Attention Sum Analysis:
Mean attention sum (should be 1): 1.0000
Max deviation from 1: 0.00000012


In [4]:
def test_decoder_cross_attention():
    torch.manual_seed(42)
    
    # Test parameters
    batch_size = 2
    decoder_seq_len = 5
    encoder_seq_len = 7  # Different length to make it interesting!
    d_model = 512
    d_ff = 2048
    num_heads = 8
    
    decoder = TransformerDecoder(
        d_model=d_model,
        d_ff=d_ff,
        num_head=num_heads,
        dropout=0.1
    )
    decoder.eval()
    
    # Create input sequences
    decoder_input = torch.randn(batch_size, decoder_seq_len, d_model)
    encoder_output = torch.randn(batch_size, encoder_seq_len, d_model)
    
    # Store attention scores
    cross_attention_scores = []
    
    def attention_hook(module, input, output):
        # We want the second call to att (cross-attention)
        if len(cross_attention_scores) < 2:
            scores = F.softmax(module.att_matrix, dim=-1)
            cross_attention_scores.append(scores.detach())
    
    decoder.att.register_forward_hook(attention_hook)
    
    # Forward pass
    with torch.no_grad():
        output = decoder(decoder_input, encoder_output)
    
    # Get cross-attention weights (second element in list)
    cross_att_weights = cross_attention_scores[1]  # [batch, heads, decoder_seq_len, encoder_seq_len]
    
    print("\nCross-Attention Matrix Shape:", cross_att_weights.shape)
    
    # Print attention pattern for first head of first batch
    print("\nCross-Attention Pattern (first head):")
    print(cross_att_weights[0, 0].round(decimals=4))
    
    # Verify each decoder position attends to all encoder positions
    attention_sums = cross_att_weights.sum(dim=-1)
    zero_attention = (cross_att_weights == 0).all(dim=-1)
    
    print("\nCross-Attention Analysis:")
    print(f"Mean attention weight: {cross_att_weights.mean():.4f}")
    print(f"Min attention weight: {cross_att_weights.min():.4f}")
    print(f"Max attention weight: {cross_att_weights.max():.4f}")
    
    print("\nAttention Coverage:")
    print(f"Each position's attention sums to 1: {torch.allclose(attention_sums, torch.ones_like(attention_sums))}")
    print(f"Every decoder position attends to some encoder position: {not zero_attention.any()}")
    
    # Check attention distribution
    attention_entropy = -(cross_att_weights * torch.log(cross_att_weights + 1e-9)).sum(dim=-1).mean()
    print(f"\nAttention entropy (higher means more uniform attention): {attention_entropy:.4f}")
    
    return cross_att_weights

# Run the test
cross_attention_weights = test_decoder_cross_attention()


Cross-Attention Matrix Shape: torch.Size([2, 8, 5, 7])

Cross-Attention Pattern (first head):
tensor([[0.1308, 0.1502, 0.1380, 0.1131, 0.1987, 0.1117, 0.1576],
        [0.1303, 0.1041, 0.1502, 0.1756, 0.1679, 0.1589, 0.1130],
        [0.0896, 0.2159, 0.1142, 0.1718, 0.1797, 0.0844, 0.1444],
        [0.1250, 0.1650, 0.1607, 0.1053, 0.0868, 0.2349, 0.1223],
        [0.1637, 0.0842, 0.2093, 0.1223, 0.1274, 0.1392, 0.1540]])

Cross-Attention Analysis:
Mean attention weight: 0.1429
Min attention weight: 0.0389
Max attention weight: 0.4142

Attention Coverage:
Each position's attention sums to 1: True
Every decoder position attends to some encoder position: True

Attention entropy (higher means more uniform attention): 1.8917


In [5]:
def test_decoder_cross_attention_with_padding():
    torch.manual_seed(42)
    
    # Test parameters
    batch_size = 2
    decoder_seq_len = 5
    encoder_seq_len = 7
    d_model = 512
    d_ff = 2048
    num_heads = 8
    
    decoder = TransformerDecoder(
        d_model=d_model,
        d_ff=d_ff,
        num_head=num_heads,
        dropout=0.1
    )
    decoder.eval()
    
    # Create input sequences
    decoder_input = torch.randn(batch_size, decoder_seq_len, d_model)
    encoder_output = torch.randn(batch_size, encoder_seq_len, d_model)
    
    # Create padding mask for encoder outputs
    # Mask out last 2 positions (as if they were padding in encoder output)
    padding_mask = torch.ones(batch_size, decoder_seq_len, encoder_seq_len)
    padding_mask[:, :, -2:] = float('-inf')  # Mask positions 5,6
    padding_mask = padding_mask.unsqueeze(1)  # Add head dimension [batch, 1, decoder_seq, encoder_seq]
    
    cross_attention_scores = []
    
    def attention_hook(module, input, output):
        if len(cross_attention_scores) < 2:
            scores = F.softmax(module.att_matrix, dim=-1)
            cross_attention_scores.append(scores.detach())
    
    decoder.att.register_forward_hook(attention_hook)
    
    # Forward pass
    with torch.no_grad():
        output = decoder(decoder_input, encoder_output, padding_mask)
    
    # Get cross-attention weights (second element)
    cross_att_weights = cross_attention_scores[1]
    
    print("\nCross-Attention Matrix Shape:", cross_att_weights.shape)
    
    print("\nCross-Attention Pattern (first head):")
    print("(Last two encoder positions should have zero attention)")
    print(cross_att_weights[0, 0].round(decimals=4))
    
    # Analyze masked positions (last two columns)
    masked_attention = cross_att_weights[:, :, :, -2:]
    unmasked_attention = cross_att_weights[:, :, :, :-2]
    
    print("\nMasking Analysis:")
    print(f"Mean attention to masked positions: {masked_attention.mean():.8f}")
    print(f"Max attention to masked positions: {masked_attention.max():.8f}")
    print(f"Mean attention to unmasked positions: {unmasked_attention.mean():.4f}")
    
    # Verify attention still sums to 1 (only over unmasked positions)
    attention_sums = cross_att_weights.sum(dim=-1)
    
    print("\nAttention Coverage:")
    print(f"Each position's attention sums to 1: {torch.allclose(attention_sums, torch.ones_like(attention_sums), atol=1e-6)}")
    
    # Analyze attention distribution over unmasked positions
    print("\nUnmasked Position Analysis:")
    print(f"Min attention to unmasked positions: {unmasked_attention.min():.4f}")
    print(f"Max attention to unmasked positions: {unmasked_attention.max():.4f}")
    
    return cross_att_weights

# Run the test
cross_attention_weights = test_decoder_cross_attention_with_padding()


Cross-Attention Matrix Shape: torch.Size([2, 8, 5, 7])

Cross-Attention Pattern (first head):
(Last two encoder positions should have zero attention)
tensor([[0.1791, 0.2055, 0.1888, 0.1547, 0.2719, 0.0000, 0.0000],
        [0.1789, 0.1430, 0.2063, 0.2412, 0.2306, 0.0000, 0.0000],
        [0.1162, 0.2800, 0.1480, 0.2228, 0.2330, 0.0000, 0.0000],
        [0.1945, 0.2566, 0.2500, 0.1638, 0.1350, 0.0000, 0.0000],
        [0.2316, 0.1191, 0.2961, 0.1730, 0.1802, 0.0000, 0.0000]])

Masking Analysis:
Mean attention to masked positions: 0.00000000
Max attention to masked positions: 0.00000000
Mean attention to unmasked positions: 0.2000

Attention Coverage:
Each position's attention sums to 1: True

Unmasked Position Analysis:
Min attention to unmasked positions: 0.0458
Max attention to unmasked positions: 0.4875


---

Return to `transformer_architecture.ipynb` or `transformer_testing.ipynb` to revisit core components.
