# VA Attention Analysis (Problem 9, Qwen3-4B)

Goal: for the 10 exported rollout traces `problem_9_va_00..09`, compute VA-related attention numbers, pick the most distinctive trace, and plot early/mid/late-layer attention patterns with VA sentences highlighted.

Notes:
- This notebook uses existing whitebox utilities under `whitebox-analyses/attention_analysis/`.
- Attention matrices are cached to disk (default: `attn_cache/`). The first run is expensive because it requires a model forward pass; subsequent runs read from the cache and are fast.
- We intentionally do NOT plot all layers. We pick representative early/mid/late layers, and within each layer select the most VA-distinct head for the chosen rollout.


## Definitions (what the notebook measures)

We treat `chunks.json` as the sentence units. Let there be `N` chunks (sentences).

- VA mask (from `chunks_labeled.json`):

  $$V_i = 1 \;\;\text{iff}\;\; \texttt{verbalized\_evaluation\_awareness} \in \texttt{function\_tags}(i).$$

- Attention matrix (sentence-averaged):

  $$A \in \mathbb{R}^{N\times N},\; A_{ij} = \text{avg attention from sentence } i \to \text{sentence } j$$

- Vertical (receiver) score (as implemented by `get_vertical_scores`):
  we mask future attention and a near-diagonal band, optionally rank-normalize per row, then for each sentence index `j` aggregate attention into `j` from sufficiently later sentences. A simplified mental model is:

  $$v_j = \operatorname{mean}_{i \ge j + p} \; A_{ij} \quad (p = \texttt{proximity\_ignore})$$

- VA distinctiveness for a single head:

  $$\Delta_{VA} = \mathbb{E}[v_j\mid V_j=1] - \mathbb{E}[v_j\mid V_j=0]$$

- Rollout distinctiveness across representative layers (early/mid/late):

  $$S(\text{rollout}) = \sum_{\ell} \max_h |\Delta_{VA}(\ell, h)|$$

We select the rollout with the largest `S(rollout)` and plot the corresponding best head in each representative layer.
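
As a rough illustration of the simplified mental model above (not the actual `get_vertical_scores` implementation, which also applies masking and optional rank normalization), the sketch below computes $v_j$ and $\Delta_{VA}$ on a small made-up matrix. The names `simple_vertical_scores` and `delta_va` and the toy values are assumptions introduced here for illustration only.

```python
import numpy as np

def simple_vertical_scores(A: np.ndarray, p: int) -> np.ndarray:
    """v_j = mean of A[i, j] over rows i >= j + p (NaN when no such row exists)."""
    n = A.shape[0]
    v = np.full(n, np.nan)
    for j in range(n):
        col = A[j + p:, j]  # attention into sentence j from sufficiently later sentences
        if col.size:
            v[j] = col.mean()
    return v

def delta_va(v: np.ndarray, va_mask: np.ndarray) -> float:
    """Mean score of VA sentences minus mean score of non-VA sentences (NaNs ignored)."""
    return float(np.nanmean(v[va_mask]) - np.nanmean(v[~va_mask]))

# Toy lower-triangular "attention" matrix in which sentence 1 receives extra attention.
A = np.tril(np.full((5, 5), 0.1))
A[:, 1] = np.where(np.arange(5) >= 1, 0.5, 0.0)
va_mask = np.array([False, True, False, False, False])

v = simple_vertical_scores(A, p=2)
print(v)
print(delta_va(v, va_mask))
```

The last `p` sentences have no sufficiently later rows, so their scores are NaN in this sketch and are ignored when averaging.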

In [17]:
import os
import sys
import json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def find_repo_root(start: Path) -> Path:
    for p in [start] + list(start.parents):
        if (p / 'whitebox-analyses').exists():
            return p
    raise RuntimeError(f'Could not find repo root from {start}')

REPO_ROOT = find_repo_root(Path.cwd())
WHITEBOX_ROOT = (REPO_ROOT / 'whitebox-analyses').resolve()
sys.path.insert(0, str(WHITEBOX_ROOT))
print('REPO_ROOT', REPO_ROOT)
print('WHITEBOX_ROOT', WHITEBOX_ROOT)

from pytorch_models.model_config import model2layers_heads
from attention_analysis.attn_funcs import get_avg_attention_matrix, get_vertical_scores


REPO_ROOT /root/thought-anchors
WHITEBOX_ROOT /root/thought-anchors/whitebox-analyses


In [18]:
# --- Config ---
MODEL_NAME = 'qwen3-4b'
CI = 'correct_base_solution'

VA_ROOT = REPO_ROOT / 'rollouts' / 'Qwen3-4B-Thinking-2507' / 'temperature_0.6_top_p_0.95' / 'va_examples'
PROBLEM_PREFIX = 'problem_9_va_'

# Cache dir for sentence-averaged attention matrices.
# Keep this separate if you don't want to mix with other experiments.
ATTN_CACHE_DIR = 'attn_cache'

# Representative layers (early / mid / late) for Qwen3-4B (36 layers).
N_LAYERS, N_HEADS = model2layers_heads(MODEL_NAME)
LAYERS = [max(0, int(0.10 * (N_LAYERS - 1))), int(0.50 * (N_LAYERS - 1)), int(0.90 * (N_LAYERS - 1))]
LAYERS = sorted(set(LAYERS))

print('MODEL_NAME', MODEL_NAME)
print('N_LAYERS, N_HEADS', (N_LAYERS, N_HEADS))
print('Representative layers', LAYERS)
print('VA_ROOT exists', VA_ROOT.exists())


MODEL_NAME qwen3-4b
N_LAYERS, N_HEADS (36, 32)
Representative layers [3, 17, 31]
VA_ROOT exists True


In [19]:
# Discover the 10 rollout examples for problem 9
ci_dir = VA_ROOT / CI
rollout_dirs = sorted([p for p in ci_dir.iterdir() if p.is_dir() and p.name.startswith(PROBLEM_PREFIX)])
print('Found', len(rollout_dirs), 'rollout dirs')
for p in rollout_dirs:
    print('-', p.name)

assert len(rollout_dirs) > 0, f'No {PROBLEM_PREFIX} dirs under {ci_dir}'


Found 10 rollout dirs
- problem_9_va_00
- problem_9_va_01
- problem_9_va_02
- problem_9_va_03
- problem_9_va_04
- problem_9_va_05
- problem_9_va_06
- problem_9_va_07
- problem_9_va_08
- problem_9_va_09


## Step 0: Alignment choice (why we use chunks.json)

The whitebox stack can split `base_solution.json["full_cot"]` into a different number of sentences than appear in `chunks_labeled.json`.

To keep indices consistent, this notebook uses:
- sentence list = `chunks.json["chunks"]`
- text = `'\n'.join(chunks)`
- VA mask indices from `chunks_labeled.json` truncated to `len(chunks)`

So: index `i` in the analysis corresponds to `chunks[i]` from `chunks.json` (before trimming).
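
A minimal sketch of the length check implied by the truncation rule above, assuming the chunk list and the label list have already been loaded. The helper name `alignment_report` and the toy inputs are hypothetical and not part of the whitebox utilities.

```python
def alignment_report(chunks: list, labeled: list) -> dict:
    """Summarize how the label list lines up with the chunk list."""
    n = len(chunks)
    return {
        'n_chunks': n,
        'n_labeled': len(labeled),
        # Labels past the last chunk are discarded by labeled[:n].
        'n_labels_dropped': max(0, len(labeled) - n),
        # Chunks without a label entry stay non-VA in the mask.
        'n_chunks_unlabeled': max(0, n - len(labeled)),
    }

# Toy example: three chunks, four label entries.
report = alignment_report(['s0', 's1', 's2'], [{'function_tags': []}] * 4)
print(report)
```

Running such a check per rollout before computing attention makes any silent truncation visible.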

## Helper functions
We compute a per-(rollout, layer, head) distinctiveness score based on vertical attention scores:
- `vertical_scores = get_vertical_scores(avg_attention_matrix, proximity_ignore=4, control_depth=True, score_type='mean')`
- `delta = mean(vertical_scores[VA]) - mean(vertical_scores[nonVA])`

Then we pick, for each layer, the head with max `abs(delta)` and aggregate across 3 layers to pick the most distinctive rollout.

Additional descriptive matrix summaries (computed on the same attention matrix `A`):
- Incoming into VA columns from non-VA rows:

  $$\mu_{in,VA} = \operatorname{mean}(A_{ij} : V_i=0, V_j=1)$$

- Incoming into non-VA columns from non-VA rows:

  $$\mu_{in,nonVA} = \operatorname{mean}(A_{ij} : V_i=0, V_j=0)$$

- Lift ratio (if denominator is non-zero):

  $$\text{lift}_{in} = \mu_{in,VA} / \mu_{in,nonVA}$$

These are descriptive summaries, not causal effects.
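
The three summaries above can be sketched on a toy matrix as follows. The function name `incoming_lift` and the example values are assumptions for illustration; the sketch applies no causal or proximity masking, so every entry in the selected block is averaged.

```python
import numpy as np

def incoming_lift(A: np.ndarray, va_mask) -> tuple:
    """Return (mu_in_VA, mu_in_nonVA, lift_in) for a sentence-level attention matrix."""
    va = np.asarray(va_mask, dtype=bool)
    non = ~va
    # np.ix_ selects the block of non-VA rows crossed with the chosen columns.
    mu_in_va = A[np.ix_(non, va)].mean() if non.any() and va.any() else np.nan
    mu_in_non = A[np.ix_(non, non)].mean() if non.any() else np.nan
    ok = np.isfinite(mu_in_va) and np.isfinite(mu_in_non) and mu_in_non != 0
    lift = mu_in_va / mu_in_non if ok else np.nan
    return float(mu_in_va), float(mu_in_non), float(lift)

A = np.array([
    [0.2, 0.0, 0.0],
    [0.6, 0.4, 0.0],
    [0.1, 0.6, 0.3],
])
mu_va, mu_non, lift = incoming_lift(A, [False, True, False])
print(mu_va, mu_non, lift)
```

A lift above 1 means non-VA rows place more average attention on VA columns than on non-VA columns in this toy setup.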

In [20]:
def load_chunks_and_va_mask(rollout_dir: Path):
    chunks_path = rollout_dir / 'chunks.json'
    labeled_path = rollout_dir / 'chunks_labeled.json'

    chunks_obj = json.loads(chunks_path.read_text(encoding='utf-8'))
    chunks = chunks_obj.get('chunks', [])
    labeled = json.loads(labeled_path.read_text(encoding='utf-8'))

    if not isinstance(chunks, list):
        chunks = []
    if not isinstance(labeled, list):
        labeled = []

    # We treat 'chunks' as the sentence units for attention boundaries to avoid mismatch warnings
    # from base_solution['full_cot'] and to guarantee alignment with chunks_labeled.json.
    n = len(chunks)
    labeled = labeled[:n]

    va_mask = np.zeros((n,), dtype=bool)
    for i, c in enumerate(labeled):
        tags = c.get('function_tags', []) if isinstance(c, dict) else []
        if isinstance(tags, list) and any(t == 'verbalized_evaluation_awareness' for t in tags):
            va_mask[i] = True

    text = '\n'.join(str(x) for x in chunks)
    sentences = [str(x) for x in chunks]
    return text, sentences, va_mask, labeled

def safe_mean(arr: np.ndarray) -> float:
    arr = np.asarray(arr, dtype=float)
    if arr.size == 0:
        return float('nan')
    return float(np.nanmean(arr))

def compute_va_stats_for_head(avg_mat: np.ndarray, va_mask: np.ndarray, proximity_ignore: int = 4):
    n = avg_mat.shape[0]
    assert va_mask.shape[0] == n
    non_mask = ~va_mask

    vert = get_vertical_scores(avg_mat, proximity_ignore=proximity_ignore, control_depth=True, score_type='mean')
    va_vert = vert[va_mask]
    non_vert = vert[non_mask]

    mean_va_vert = safe_mean(va_vert)
    mean_non_vert = safe_mean(non_vert)
    delta_vert = mean_va_vert - mean_non_vert

    # Simple matrix-based aggregates (incoming/outgoing attention).
    # These are not causal; they are descriptive summaries.
    mean_in_va = safe_mean(avg_mat[non_mask][:, va_mask])
    mean_in_non = safe_mean(avg_mat[non_mask][:, non_mask])
    mean_out_va = safe_mean(avg_mat[va_mask][:, non_mask])
    mean_va_to_va = safe_mean(avg_mat[va_mask][:, va_mask])

    lift_in = float('nan')
    if np.isfinite(mean_in_va) and np.isfinite(mean_in_non) and mean_in_non != 0:
        lift_in = mean_in_va / mean_in_non

    return {
        'mean_va_vert': mean_va_vert,
        'mean_nonva_vert': mean_non_vert,
        'delta_va_vert': delta_vert,
        'mean_in_va': mean_in_va,
        'mean_in_nonva': mean_in_non,
        'lift_in_va': lift_in,
        'mean_out_va': mean_out_va,
        'mean_va_to_va': mean_va_to_va,
        'vert_scores': vert,
    }


## Step 1: compute distinctiveness numbers across 10 rollouts
This will trigger attention caching per rollout (expensive on first run).

If you want to do a quick dry run (no model inference), set `RUN_ATTENTION = False` to only inspect VA counts.

### Trimming convention
Some legacy plotting code trims the first and last positions (`A[1:-1, 1:-1]`). This notebook applies the same trim to both the attention matrix and VA mask.

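Interpretation: after trimming, plotted index `k` corresponds to original chunk index `k+1` in `chunks.json`. When a rollout has at least 3 chunks, the first and last chunks, including any VA label on them, are excluded from the per-head statistics and plots.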
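
A small sketch of the trim and the resulting index mapping, using a made-up 5x5 matrix. The helper name `trim_first_last` is hypothetical; it mirrors the `[1:-1]` slicing used in this notebook.

```python
import numpy as np

def trim_first_last(A: np.ndarray, va_mask: np.ndarray):
    """Drop the first and last positions; return the trimmed arrays and the index offset."""
    if A.shape[0] >= 3:
        return A[1:-1, 1:-1], va_mask[1:-1], 1
    return A, va_mask, 0  # too small to trim

A = np.arange(25, dtype=float).reshape(5, 5)
va_mask = np.array([True, False, True, False, False])

A2, va2, offset = trim_first_last(A, va_mask)
plotted_va = np.flatnonzero(va2)    # VA positions as they appear in the plots
original_va = plotted_va + offset   # the same positions as chunk indices in chunks.json
print(A2.shape, plotted_va, original_va)
```

In this toy case the VA label on chunk 0 disappears after trimming, and the VA sentence plotted at index 1 is chunk 2 in `chunks.json`.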

In [21]:
RUN_ATTENTION = True
PROXIMITY_IGNORE = 4
HEADS_TO_SCAN = list(range(N_HEADS))

rows = []

for rollout_dir in rollout_dirs:
    rollout_id = rollout_dir.name
    text, sentences, va_mask, labeled = load_chunks_and_va_mask(rollout_dir)
    n_sent = len(sentences)
    n_va = int(va_mask.sum())

    if not RUN_ATTENTION:
        rows.append({
            'rollout_id': rollout_id,
            'n_sentences': n_sent,
            'n_va': n_va,
            'va_rate': (n_va / n_sent) if n_sent else float('nan'),
        })
        continue

    # First call triggers cache for all layers/heads for this text_id (fast afterwards).
    _ = get_avg_attention_matrix(
        text=text,
        model_name=MODEL_NAME,
        layer=LAYERS[0],
        head=0,
        sentences=sentences,
        cache_dir=ATTN_CACHE_DIR,
        force_recompute=False,
    )

    for layer in LAYERS:
        for head in HEADS_TO_SCAN:
            avg_mat = get_avg_attention_matrix(
                text=text,
                model_name=MODEL_NAME,
                layer=layer,
                head=head,
                sentences=sentences,
                cache_dir=ATTN_CACHE_DIR,
                force_recompute=False,
            )

            # Keep consistent with plot_one_attn_matrix.py which trims prompt/output bins in legacy runs.
            if avg_mat.shape[0] >= 3:
                avg_mat2 = avg_mat[1:-1, 1:-1]
                va_mask2 = va_mask[1:-1]
            else:
                avg_mat2 = avg_mat
                va_mask2 = va_mask

            stats = compute_va_stats_for_head(avg_mat2, va_mask2, proximity_ignore=PROXIMITY_IGNORE)

            rows.append({
                'rollout_id': rollout_id,
                'layer': int(layer),
                'head': int(head),
                'n_sentences': int(avg_mat2.shape[0]),
                'n_va': int(va_mask2.sum()),
                'va_rate': float(va_mask2.sum() / max(1, avg_mat2.shape[0])),
                'delta_va_vert': stats['delta_va_vert'],
                'mean_va_vert': stats['mean_va_vert'],
                'mean_nonva_vert': stats['mean_nonva_vert'],
                'lift_in_va': stats['lift_in_va'],
                'mean_in_va': stats['mean_in_va'],
                'mean_in_nonva': stats['mean_in_nonva'],
                'mean_out_va': stats['mean_out_va'],
                'mean_va_to_va': stats['mean_va_to_va'],
            })

df = pd.DataFrame(rows)
df.head()


| | rollout_id | layer | head | n_sentences | n_va | va_rate | delta_va_vert | mean_va_vert | mean_nonva_vert | lift_in_va | mean_in_va | mean_in_nonva | mean_out_va | mean_va_to_va |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | problem_9_va_00 | 3 | 0 | 28 | 4 | 0.142857 | -0.365181 | 0.323032 | 0.688213 | 1.258111 | 0.002285 | 0.001816 | 0.000459 | 0.002595 |
| 1 | problem_9_va_00 | 3 | 1 | 28 | 4 | 0.142857 | -0.33896 | 0.35822 | 0.697181 | 0.598612 | 0.001574 | 0.00263 | 0.000531 | 0.006896 |
| 2 | problem_9_va_00 | 3 | 2 | 28 | 4 | 0.142857 | -0.152201 | 0.450563 | 0.602764 | 0.150247 | 0.00049 | 0.00326 | 0.000203 | 0.012684 |
| 3 | problem_9_va_00 | 3 | 3 | 28 | 4 | 0.142857 | -0.130446 | 0.460993 | 0.591439 | 0.106369 | 0.000351 | 0.003303 | 0.000156 | 0.013506 |
| 4 | problem_9_va_00 | 3 | 4 | 28 | 4 | 0.142857 | -0.071155 | 0.449774 | 0.520929 | 0.5931 | 0.001234 | 0.00208 | 0.000587 | 0.001016 |


In [22]:
if not RUN_ATTENTION:
    display(df.sort_values(['n_va', 'n_sentences'], ascending=[False, False]))
else:
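    # Pick the most VA-distinct head per layer per rollout (max abs delta).
    # Caveat: GroupBy.first() returns the first non-null value of each column, so the
    # result is only guaranteed to be a single head's row when that row contains no NaNs
    # (lift_in_va can be NaN).
    df['abs_delta_va_vert'] = df['delta_va_vert'].abs()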
    best_per_layer = (
        df.sort_values('abs_delta_va_vert', ascending=False)
          .groupby(['rollout_id', 'layer'], as_index=False)
          .first()
    )

    # Aggregate distinctiveness across the representative layers.
    agg = (
        best_per_layer.groupby('rollout_id')
        .agg({
            'abs_delta_va_vert': 'sum',
            'n_va': 'first',
            'n_sentences': 'first',
        })
        .rename(columns={'abs_delta_va_vert': 'distinctiveness_score'})
        .reset_index()
    )
    agg['va_rate'] = agg['n_va'] / agg['n_sentences'].clip(lower=1)

    display(agg.sort_values('distinctiveness_score', ascending=False))

    best_rollout_id = agg.sort_values('distinctiveness_score', ascending=False).iloc[0]['rollout_id']
    print('Best rollout by distinctiveness:', best_rollout_id)

    best_heads = best_per_layer[best_per_layer['rollout_id'] == best_rollout_id].sort_values('layer')
    display(best_heads[['layer', 'head', 'delta_va_vert', 'lift_in_va', 'n_va', 'n_sentences']])


| | rollout_id | distinctiveness_score | n_va | n_sentences | va_rate |
|---|---|---|---|---|---|
| 7 | problem_9_va_07 | 1.385077 | 2 | 21 | 0.095238 |
| 3 | problem_9_va_03 | 1.204702 | 4 | 26 | 0.153846 |
| 8 | problem_9_va_08 | 1.151686 | 2 | 22 | 0.090909 |
| 0 | problem_9_va_00 | 1.084855 | 4 | 28 | 0.142857 |
| 4 | problem_9_va_04 | 1.057501 | 4 | 32 | 0.125 |
| 2 | problem_9_va_02 | 0.964348 | 3 | 23 | 0.130435 |
| 9 | problem_9_va_09 | 0.952159 | 3 | 26 | 0.115385 |
| 6 | problem_9_va_06 | 0.791427 | 4 | 35 | 0.114286 |
| 5 | problem_9_va_05 | 0.782347 | 4 | 36 | 0.111111 |
| 1 | problem_9_va_01 | 0.612222 | 5 | 27 | 0.185185 |


Best rollout by distinctiveness: problem_9_va_07


| | layer | head | delta_va_vert | lift_in_va | n_va | n_sentences |
|---|---|---|---|---|---|---|
| 21 | 3 | 1 | -0.475174 | 0.391993 | 2 | 21 |
| 22 | 17 | 10 | -0.461536 | 0.911328 | 2 | 21 |
| 23 | 31 | 0 | 0.448367 | 2.697211 | 2 | 21 |


## Step 1b: pick the most distinctive rollout (formula recap)

For each rollout and each representative layer $\ell$, we choose the head $h$ that maximizes $|\Delta_{VA}(\ell,h)|$.
Then we sum these maxima across the three layers:

$$S(\text{rollout}) = \sum_{\ell} \max_h |\Delta_{VA}(\ell, h)|$$

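This is deliberately a *find a strong separation* score. It is meant to select an illustrative example, not to estimate an average effect across heads. Because it takes a maximum over all heads, a rollout with very few VA sentences (for example only 2) can score high by chance, so the ranking should be read with that in mind.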
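
The selection rule can be sketched on a toy table as below. The data values are made up, and this variant uses `idxmax` to pick the winning row per group, which is one possible way to do the selection rather than the exact code in the notebook cell.

```python
import pandas as pd

toy = pd.DataFrame({
    'rollout_id': ['r0'] * 4 + ['r1'] * 4,
    'layer': [3, 3, 17, 17] * 2,
    'head': [0, 1, 0, 1] * 2,
    'delta_va_vert': [-0.4, 0.1, 0.2, -0.3, 0.05, -0.1, 0.5, 0.0],
})
toy['abs_delta'] = toy['delta_va_vert'].abs()

# idxmax gives the row label of the largest |delta| in each (rollout, layer) group,
# so .loc keeps every column from that single winning row.
best = toy.loc[toy.groupby(['rollout_id', 'layer'])['abs_delta'].idxmax()]

# S(rollout): sum of the per-layer maxima.
score = best.groupby('rollout_id')['abs_delta'].sum()
print(best[['rollout_id', 'layer', 'head', 'delta_va_vert']])
print(score)
```

Note that the sign of the winning delta is discarded by the score, so a high `S` can mix layers where VA sentences receive more attention with layers where they receive less.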

## Step 2: Plot early/mid/late attention matrices for the most distinctive rollout
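We replicate the core idea of `plot_one_attn_matrix.py` inline so we can target `problem_9_va_XX` directories directly and produce multiple plots from one notebook cell. The plotting cell saves figures to disk for every rollout, one subdirectory per rollout, rather than displaying them inline; the plots for the most distinctive rollout are in its subdirectory.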

### How to read the heatmap
- Rows = source sentence index `i`
- Cols = target sentence index `j`
- Color = $A_{ij}$ (average attention from sentence `i` to sentence `j`)
- Red overlays mark VA sentences (both rows and columns are highlighted)

### How to read the vertical-score plot
- The vertical score is a receiver/sink metric computed from `A` (see Definitions at top).
- Red points are VA sentences; gray points are non-VA sentences.
- If the red points tend to be higher, you should see a positive `delta_va_vert`.

In [23]:
def overlay_va_highlights(ax, va_mask, axis='both', color='red', alpha=0.18):
    axis = (axis or 'both').lower()
    for idx, is_hit in enumerate(va_mask):
        if not is_hit:
            continue
        if axis in ('rows', 'both'):
            ax.axhspan(idx - 0.5, idx + 0.5, color=color, alpha=alpha, linewidth=0)
        if axis in ('cols', 'both'):
            ax.axvspan(idx - 0.5, idx + 0.5, color=color, alpha=alpha, linewidth=0)

def plot_attn_heatmap(avg_mat, va_mask, title, save_path=None, show=True):
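    # Use a robust vmax based on the lower triangle (diagonal included) so outliers do not dominate.
    # np.tril would zero-fill the upper triangle and pull the quantile down, so index the entries directly.
    tril_vals = np.asarray(avg_mat, dtype=float)[np.tril_indices_from(avg_mat)]
    finite = tril_vals[np.isfinite(tril_vals)]
    vmax = float(np.quantile(finite, 0.99)) if finite.size else 1.0
    if vmax <= 0:
        vmax = 1.0  # fall back when every retained value is zero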
    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(avg_mat, vmin=0, vmax=vmax, cmap=plt.cm.Blues)
    overlay_va_highlights(ax, va_mask, axis='both', color='red', alpha=0.18)
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('Sentence position')
    ax.set_ylabel('Sentence position')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=250)
    if show:
        plt.show()
    plt.close(fig)

def plot_vertical_scores(vert_scores, va_mask, title, save_path=None, show=True):
    x = np.arange(len(vert_scores))
    fig, ax = plt.subplots(figsize=(10, 2.5))
    ax.plot(x, vert_scores, color='#0f172a', linewidth=1)
    ax.scatter(x[~va_mask], np.asarray(vert_scores)[~va_mask], s=12, color='#94a3b8', alpha=0.8, label='non-VA')
    ax.scatter(x[va_mask], np.asarray(vert_scores)[va_mask], s=24, color='#ef4444', alpha=0.95, label='VA')
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('Sentence idx')
    ax.set_ylabel('Vertical score')
    ax.legend(loc='upper right')
    plt.tight_layout()
    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=250)
    if show:
        plt.show()
    plt.close(fig)


In [24]:
assert RUN_ATTENTION, 'Set RUN_ATTENTION=True to generate attention plots.'

# Save plots for *all* rollouts, not only the best one.
OUTPUT_DIR = REPO_ROOT / 'analysis' / 'va_attention' / 'problem_9'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print('Saving plots under', OUTPUT_DIR)

best_per_layer_by_rollout = (
    df.assign(abs_delta_va_vert=df['delta_va_vert'].abs())
      .sort_values('abs_delta_va_vert', ascending=False)
      .groupby(['rollout_id', 'layer'], as_index=False)
      .first()
)

# Collect per-rollout numeric summaries while saving plots
summary_rows = []

for rollout_dir in rollout_dirs:
    rollout_id = rollout_dir.name
    text, sentences, va_mask, labeled = load_chunks_and_va_mask(rollout_dir)

    per_layer = best_per_layer_by_rollout[best_per_layer_by_rollout['rollout_id'] == rollout_id].sort_values('layer')
    if per_layer.empty:
        print('Skip (no per-layer selection):', rollout_id)
        continue

    rollout_out = OUTPUT_DIR / rollout_id
    rollout_out.mkdir(parents=True, exist_ok=True)

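    # meta.json records original (untrimmed) chunk counts and indices;
    # when the [1:-1] trim is applied, plotted index k corresponds to chunk index k + 1.
    meta = {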
        'rollout_id': rollout_id,
        'n_sentences': int(len(sentences)),
        'n_va': int(va_mask.sum()),
        'va_idxs': [int(i) for i, v in enumerate(va_mask.tolist()) if v],
        'layers': [int(x) for x in per_layer['layer'].tolist()],
        'heads': [int(x) for x in per_layer['head'].tolist()],
    }
    (rollout_out / 'meta.json').write_text(json.dumps(meta, indent=2), encoding='utf-8')

    sel = per_layer.copy()
    sel['abs_delta_va_vert'] = sel['delta_va_vert'].abs()
    distinctiveness_score = float(sel['abs_delta_va_vert'].sum())

    summary_rows.append({
        'rollout_id': rollout_id,
        'n_sentences': int(len(sentences)),
        'n_va': int(va_mask.sum()),
        'va_rate': float(va_mask.sum() / max(1, len(sentences))),
        'distinctiveness_score': distinctiveness_score,
        'layers': ','.join(str(int(x)) for x in sel['layer'].tolist()),
        'heads': ','.join(str(int(x)) for x in sel['head'].tolist()),
        'delta_va_vert_by_layer': ','.join(f"{float(x):.4f}" for x in sel['delta_va_vert'].tolist()),
        'lift_in_va_by_layer': ','.join(
            f"{float(x):.3f}" if np.isfinite(float(x)) else 'nan'
            for x in sel['lift_in_va'].tolist()
        ),
    })

    for _, row in per_layer.iterrows():
        layer = int(row['layer'])
        head = int(row['head'])

        avg_mat = get_avg_attention_matrix(
            text=text,
            model_name=MODEL_NAME,
            layer=layer,
            head=head,
            sentences=sentences,
            cache_dir=ATTN_CACHE_DIR,
            force_recompute=False,
        )

        if avg_mat.shape[0] >= 3:
            avg_mat2 = avg_mat[1:-1, 1:-1]
            va_mask2 = va_mask[1:-1]
        else:
            avg_mat2 = avg_mat
            va_mask2 = va_mask

        stats = compute_va_stats_for_head(avg_mat2, va_mask2, proximity_ignore=PROXIMITY_IGNORE)

        title = f"{rollout_id} | layer={layer} head={head} | delta_va_vert={stats['delta_va_vert']:.4f} lift_in_va={stats['lift_in_va']:.3f}"
        heat_fp = rollout_out / f'attn_layer{layer:02d}_head{head:02d}.png'
        vert_fp = rollout_out / f'vert_layer{layer:02d}_head{head:02d}.png'

        plot_attn_heatmap(avg_mat2, va_mask2, title, save_path=heat_fp, show=False)
        plot_vertical_scores(stats['vert_scores'], va_mask2, f"Vertical scores | {rollout_id} | layer={layer} head={head}", save_path=vert_fp, show=False)

summary_df = pd.DataFrame(summary_rows).sort_values('distinctiveness_score', ascending=False)
display(summary_df)
summary_df.to_csv(OUTPUT_DIR / 'summary.csv', index=False)
print('Wrote', OUTPUT_DIR / 'summary.csv')
print('Done saving all rollout plots.')


Saving plots under /root/thought-anchors/analysis/va_attention/problem_9
Done saving all rollout plots.
