# Belief Collapse v2 (Long Prompts)

This version stresses the model with **long, rich prompts** (stories, recipes, jokes, code) to see collapse dynamics on realistic sequences. We keep the valid metrics (entropy, logit lens, head alignment, convergence) and add top-k token readouts to avoid single-token myopia.


In [None]:
# Imports
import math
import json
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Output locations: everything is written below the directory the notebook
# was launched from.
NOTEBOOK_DIR = Path.cwd()
FIG_DIR = NOTEBOOK_DIR.joinpath("figs_collapse_v2")
FIG_DIR.mkdir(exist_ok=True)  # tolerate re-running the cell

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# Load the pretrained GPT-2 tokenizer and language model.
model_name = "gpt2"
print(f"Loading {model_name}...")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Move the weights to the selected device straight after loading, then
# switch the module to evaluation mode for inference-only use.
model = GPT2LMHeadModel.from_pretrained(model_name, output_attentions=True).to(device)
model.eval()

# Architecture constants read from the model config; reused by later cells.
n_layers, n_heads, d_model = (
    model.config.n_layer,
    model.config.n_head,
    model.config.n_embd,
)
d_head = d_model // n_heads
print(f"Model: {n_layers} layers, {n_heads} heads, d_model={d_model}, d_head={d_head}")


In [None]:
# Inference helper
def run_inference(text: str) -> Dict:
    """Run one forward pass over ``text`` and collect everything the metrics need.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Returns:
        A dict with the per-token strings (``"tokens"``), the encoded
        ``"input_ids"`` (left on ``device``), CPU copies of ``"logits"``,
        ``"attentions"`` and ``"hidden_states"``, and the sizes
        ``"n_layers"``, ``"n_heads"`` and ``"seq_len"``.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True, output_hidden_states=True)

    # Move the bulky per-layer tensors to the CPU so later NumPy code can use them.
    attentions_cpu = [layer_attn.cpu() for layer_attn in outputs.attentions]
    hidden_cpu = [layer_hidden.cpu() for layer_hidden in outputs.hidden_states]

    # Decode each id on its own so every position gets a printable token.
    token_strings = [tokenizer.decode([tid]) for tid in input_ids[0]]

    return {
        "tokens": token_strings,
        "input_ids": input_ids,
        "logits": outputs.logits.cpu(),
        "attentions": attentions_cpu,
        "hidden_states": hidden_cpu,
        "n_layers": len(outputs.attentions),
        "n_heads": outputs.attentions[0].size(1),
        "seq_len": input_ids.size(1),
    }

print("run_inference() ready")


In [None]:
# Metrics

def compute_attention_entropy(attn: torch.Tensor) -> float:
    """Return the mean entropy (natural log) of the distributions in ``attn``.

    Every slice along the last dimension is treated as one probability
    distribution; the per-slice entropies are averaged into a single float.
    """
    # Floor the probabilities so log() never receives an exact zero.
    safe = attn.clamp(min=1e-10)
    per_distribution = torch.sum(safe * torch.log(safe), dim=-1).neg()
    return per_distribution.mean().item()

def attention_entropy_per_layer(result: Dict, query_pos: int = -1) -> List[float]:
    """Return one attention-entropy value per entry of ``result["attentions"]``.

    For each layer the slice ``[0, :, query_pos, :]`` is taken (presumably
    batch 0, all heads, one query position, all keys -- as produced by
    ``run_inference``) and reduced with ``compute_attention_entropy``.
    """
    return [
        compute_attention_entropy(layer_attn[0, :, query_pos, :])
        for layer_attn in result["attentions"]
    ]

def logit_lens(result: Dict, pos: int = -1, top_k: int = 5) -> Tuple[List[float], List[List[str]]]:
    """Decode every stored hidden state at ``pos`` through the LM head.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        result: Output of ``run_inference``; only ``"hidden_states"`` is read.
        pos: Sequence position to decode (default: the last token).
        top_k: Number of most probable tokens to report per hidden state.

    Returns:
        ``(entropies, tops)`` where ``entropies[i]`` is the entropy (natural
        log) of the decoded distribution for hidden state ``i`` and
        ``tops[i]`` is the list of its ``top_k`` decoded token strings.
    """
    lm_head = model.lm_head
    ln_f = model.transformer.ln_f
    hidden_states = result["hidden_states"]
    last_idx = len(hidden_states) - 1
    entropies = []
    tops = []
    # This is a pure read-out, so no autograd graph is needed.
    with torch.no_grad():
        for idx, hidden in enumerate(hidden_states):
            h = hidden[0, pos, :].to(device)
            # Hugging Face GPT-2 stores the last hidden_states entry *after*
            # ln_f. Normalising it a second time would make the final row
            # disagree with the model's real output logits, so only the
            # intermediate states go through ln_f here.
            h_norm = h if idx == last_idx else ln_f(h)
            logits = lm_head(h_norm)
            probs = F.softmax(logits, dim=-1)
            # Floor the probabilities so log() never receives an exact zero.
            p = probs.clamp(min=1e-10)
            entropies.append((-(p * p.log()).sum()).item())
            top_ids = torch.topk(probs, k=top_k).indices.tolist()
            tops.append([tokenizer.decode([tid]) for tid in top_ids])
    return entropies, tops

def head_context_vectors(result: Dict, layer_idx: int, query_pos: int = -1) -> np.ndarray:
    """Return one attention-weighted sum of hidden states per head.

    For every head of ``result["attentions"][layer_idx]`` the attention row
    at ``query_pos`` is used to weight the rows of
    ``result["hidden_states"][layer_idx]`` (batch element 0 in both cases).
    No value or output projection is applied; the weights act directly on
    the stored hidden states.

    Returns:
        Array with one row per head and one column per hidden dimension.
    """
    attn = result["attentions"][layer_idx][0].numpy()
    hidden_in = result["hidden_states"][layer_idx][0].numpy()
    # attn[:, query_pos, :] yields one weight vector per head.
    return np.stack(
        [weights @ hidden_in for weights in attn[:, query_pos, :]],
        axis=0,
    )

def head_alignment(result: Dict, layer_idx: int, query_pos: int = -1) -> float:
    """Return the mean pairwise cosine similarity between head context vectors.

    The vectors come from ``head_context_vectors``; only distinct head pairs
    (the strict upper triangle of the similarity matrix) enter the mean.
    """
    vectors = head_context_vectors(result, layer_idx, query_pos)
    # The epsilon keeps an all-zero context vector from dividing by zero.
    unit = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-10)
    cosine = unit @ unit.T
    rows, cols = np.triu_indices(cosine.shape[0], k=1)
    return float(cosine[rows, cols].mean())

def head_alignment_per_layer(result: Dict, query_pos: int = -1) -> List[float]:
    """Return ``head_alignment`` for each of the ``result["n_layers"]`` layers."""
    scores = []
    for layer_idx in range(result["n_layers"]):
        scores.append(head_alignment(result, layer_idx, query_pos))
    return scores

def activation_change_across_positions(result: Dict, layer_idx: int) -> np.ndarray:
    """Return the cosine distance between hidden states of adjacent positions.

    Reads ``result["hidden_states"][layer_idx]`` (batch element 0). Entry
    ``i`` of the output is ``1 - cos(h[i], h[i + 1])``. An empty array is
    returned when there are fewer than two positions.
    """
    states = result["hidden_states"][layer_idx][0].numpy()
    if states.shape[0] < 2:
        return np.array([])
    # Normalise once so every cosine reduces to a plain dot product.
    unit = states / (np.linalg.norm(states, axis=1, keepdims=True) + 1e-10)
    return np.array(
        [1 - np.dot(current, following) for current, following in zip(unit[:-1], unit[1:])]
    )

def mean_position_change_per_layer(result: Dict) -> List[float]:
    """Return the mean adjacent-position change for every stored hidden state.

    Each entry averages ``activation_change_across_positions`` for one index
    of ``result["hidden_states"]``; 0.0 is used when that array is empty.
    """
    means = []
    for layer_idx, _ in enumerate(result["hidden_states"]):
        changes = activation_change_across_positions(result, layer_idx)
        means.append(float(changes.mean()) if len(changes) else 0.0)
    return means

def layer_to_layer_similarity(result: Dict, pos: int = -1) -> List[float]:
    """Return the cosine similarity between consecutive hidden states at ``pos``.

    Entry ``i`` compares ``result["hidden_states"][i]`` with entry ``i + 1``
    (batch element 0), so the list is one shorter than ``hidden_states``.
    """
    states = result["hidden_states"]
    sims = []
    for lower, upper in zip(states, states[1:]):
        a = lower[0, pos, :].numpy()
        b = upper[0, pos, :].numpy()
        # The epsilon guards against a zero-norm vector.
        norm_a = np.linalg.norm(a) + 1e-10
        norm_b = np.linalg.norm(b) + 1e-10
        sims.append(float(np.dot(a, b) / (norm_a * norm_b)))
    return sims

def layer_similarity_to_final(result: Dict, pos: int = -1) -> List[float]:
    """Return the cosine similarity of every hidden state to the last one at ``pos``.

    One value is produced per entry of ``result["hidden_states"]`` (batch
    element 0), including the last entry compared with itself.
    """
    states = result["hidden_states"]
    target = states[-1][0, pos, :].numpy()
    target_norm = np.linalg.norm(target) + 1e-10

    sims = []
    for state in states:
        vec = state[0, pos, :].numpy()
        # The epsilon guards against a zero-norm vector.
        vec_norm = np.linalg.norm(vec) + 1e-10
        sims.append(float(np.dot(vec, target) / (vec_norm * target_norm)))
    return sims

print("Metric functions ready")


## Long, content-rich prompts
- Story
- Recipe
- Joke
- Technical code
- Legal-ish clause
- News/analysis
- Speculative scenario


In [None]:
# Long-form prompts keyed by genre. Each value is a single string assembled
# from adjacent literals; the trailing spaces inside the literals are what
# separate the words across line breaks.
prompts = {
    "story": (
        "On the fourth night of the winter storm, the lighthouse keeper "
        "noticed the beam flicker. He grabbed his tools, descended the narrow "
        "stairs, and discovered a loose wire sparking near the fuel line. With "
        "waves smashing the rocks and a cargo ship approaching, he had one "
        "chance to repair the circuit before the coast went dark."
    ),
    "recipe": (
        "To make a crusty sourdough loaf, feed your starter the night before, "
        "then mix 500g bread flour, 350g water, 100g active starter, and 10g salt. "
        "Rest 30 minutes, stretch and fold four times every 30 minutes, bulk "
        "ferment until doubled, shape, proof overnight in the fridge, bake at "
        "250C with steam for 20 minutes, then 230C dry for 20-25 minutes until "
        "deeply browned."
    ),
    "joke": (
        "A data scientist walks into a bakery and asks for a pie chart. The baker "
        "hands over a blueberry tart and says, 'Careful, the confidence interval "
        "is deliciously narrow today.'"
    ),
    "code": (
        "Write a Python function that parses a CSV file of orders, groups them "
        "by customer, computes total spend, and returns the top five customers "
        "by revenue. The function should handle missing values, malformed rows, "
        "and should stream the file to avoid loading it all into memory."
    ),
    "legal": (
        "This agreement indemnifies the consultant against all claims arising "
        # The apostrophe is written as an escape (U+2019, right single
        # quotation mark) so it cannot be mangled into mojibake by a
        # wrong-encoding round trip of this file.
        "from negligent implementation of the client\u2019s specifications, except "
        "where gross misconduct is proven by clear and convincing evidence."
    ),
    "news": (
        "Analysts expect the central bank to pause rate hikes after inflation "
        "fell for the third consecutive month, but warn that energy volatility "
        "could force a surprise move before year-end."
    ),
    "speculative": (
        "In the distant future, autonomous probes exchange compressed knowledge "
        "packets near the heliopause, negotiating bandwidth and trust scores "
        "before relaying discoveries back to their origin worlds."
    ),
}

print(f"Loaded {len(prompts)} prompts (long-form)")


In [None]:
# Run every prompt through the model and collect all collapse metrics.
all_results = {}

for name, prompt in prompts.items():
    print(f"Processing: {name}...")
    result = run_inference(prompt)

    # Per-layer metrics; the helpers that take a position use their
    # default (the last token).
    attn_entropy = attention_entropy_per_layer(result)
    logit_entropy_vals, top_tokens_by_layer = logit_lens(result, top_k=5)
    alignment = head_alignment_per_layer(result)
    pos_change = mean_position_change_per_layer(result)
    layer_sim = layer_to_layer_similarity(result)
    sim_final = layer_similarity_to_final(result)

    # Next-token distribution at the last position of the real model output.
    final_logits = result["logits"][0, -1, :]
    probs = F.softmax(final_logits, dim=-1)
    # The clamp keeps log() away from exact zeros.
    final_entropy = torch.sum(probs * torch.log(probs.clamp(min=1e-10))).neg().item()
    topk_final = torch.topk(probs, k=5).indices.tolist()
    topk_tokens = [tokenizer.decode([tid]) for tid in topk_final]

    all_results[name] = dict(
        prompt=prompt,
        tokens=result["tokens"],
        top_tokens_by_layer=top_tokens_by_layer,
        topk_final=topk_tokens,
        final_entropy=final_entropy,
        attn_entropy=attn_entropy,
        logit_entropy=logit_entropy_vals,
        head_alignment=alignment,
        position_change=pos_change,
        layer_to_layer_sim=layer_sim,
        sim_to_final=sim_final,
    )

print("Done.")


In [None]:
# Print a fixed-width summary: one row per prompt.
print("=" * 120)
print(f"{'Prompt':<16} {'Final top-5 tokens':<60} {'Ent':>6} {'Align(L11)':>10} {'SimFinal(L0)':>12}")
print("=" * 120)

for name, r in all_results.items():
    prompt_short = name[:16]
    # Strip each token and cap the joined text at 58 characters so it fits
    # inside the 60-character column.
    top5 = ", ".join(map(str.strip, r["topk_final"]))[:58]
    ent = r["final_entropy"]
    align = r["head_alignment"][-1]  # last layer
    sim0 = r["sim_to_final"][0]  # first stored hidden state vs. the last one
    print(f"{prompt_short:<16} {top5:<60} {ent:>6.2f} {align:>10.3f} {sim0:>12.3f}")

print("=" * 120)


## Visualization: All metrics for the story prompt


In [None]:
# Six-panel figure (3 rows x 2 columns) with every metric for one prompt.
# NOTE: the figure title and the output filename below are hard-coded to
# "story"; they do not follow `target` if it is changed.
target = "story"
r = all_results[target]

fig = plt.figure(figsize=(16, 12))
gs = GridSpec(3, 2, figure=fig)

# x-axis for the metrics that have one value per attention layer.
layers = list(range(n_layers))

# Panel (0, 0): attention entropy per layer.
ax1 = fig.add_subplot(gs[0, 0])
ax1.plot(layers, r["attn_entropy"], 'o-', color='crimson', linewidth=2)
ax1.set_xlabel("Layer")
ax1.set_ylabel("Entropy (nats)")
ax1.set_title("Attention Entropy")
ax1.grid(True, alpha=0.3)

# Panel (0, 1): logit-lens output entropy. There is one value per stored
# hidden state, so the x-axis is built from the series length.
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(range(len(r["logit_entropy"])), r["logit_entropy"], 's-', color='darkorange', linewidth=2)
ax2.set_xlabel("Layer (0=emb)")
ax2.set_ylabel("Output Entropy")
ax2.set_title("Logit Lens: Output Sharpness")
ax2.grid(True, alpha=0.3)

# Panel (1, 0): mean pairwise cosine similarity between head context vectors.
ax3 = fig.add_subplot(gs[1, 0])
ax3.plot(layers, r["head_alignment"], '^-', color='forestgreen', linewidth=2)
ax3.set_xlabel("Layer")
ax3.set_ylabel("Mean Pairwise Cosine Sim")
ax3.set_title("Head Alignment")
ax3.grid(True, alpha=0.3)

# Panel (1, 1): cosine similarity of each hidden state to the last one.
ax4 = fig.add_subplot(gs[1, 1])
ax4.plot(range(len(r["sim_to_final"])), r["sim_to_final"], 'D-', color='purple', linewidth=2)
ax4.set_xlabel("Layer (0=emb)")
ax4.set_ylabel("Cosine Sim to Final")
ax4.set_title("Convergence to Final")
ax4.grid(True, alpha=0.3)

# Panel (2, 0): cosine similarity between consecutive hidden states; the
# series has one entry per transition.
ax5 = fig.add_subplot(gs[2, 0])
ax5.plot(range(len(r["layer_to_layer_sim"])), r["layer_to_layer_sim"], 'p-', color='teal', linewidth=2)
ax5.set_xlabel("Transition (L_i -> L_{i+1})")
ax5.set_ylabel("Cosine Sim")
ax5.set_title("Layer-to-Layer Stability")
ax5.grid(True, alpha=0.3)

# Panel (2, 1): text-only panel listing the top-3 logit-lens tokens for
# every hidden state; the axes are hidden because nothing is plotted.
ax6 = fig.add_subplot(gs[2, 1])
ax6.axis('off')
rows = []
for i, toks in enumerate(r["top_tokens_by_layer"]):
    rows.append(f"L{i:02d}: {', '.join([t.strip() for t in toks[:3]])}")
ax6.text(0.05, 0.95, "Top-3 tokens per layer:\n" + "\n".join(rows), va='top', fontsize=9, family='monospace')
# Unlike the rows above, the tokens in this title are not stripped.
ax6.set_title(f"Final top-5: {', '.join(r['topk_final'])}")

# Overall title shows the first 60 characters of the prompt.
plt.suptitle(f"Collapse Metrics (story): '{r['prompt'][:60]}...'", fontsize=14)
plt.tight_layout()
# Save before show() so the file is written from the fully drawn figure.
plt.savefig(FIG_DIR / "01_story_metrics.png", dpi=150)
plt.show()


## Visualization: Heatmaps across long prompts


In [None]:
# Two side-by-side heatmaps: rows are prompts (in all_results order),
# columns are layers.
names = list(all_results.keys())

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: head alignment, one row per prompt and one column per layer.
alignment_matrix = np.array([all_results[n]["head_alignment"] for n in names])
ax1 = axes[0]
im1 = ax1.imshow(alignment_matrix, aspect='auto', cmap='viridis')
ax1.set_yticks(range(len(names)))
# Row labels are the prompt names cut to 12 characters.
ax1.set_yticklabels([n[:12] for n in names], fontsize=8)
ax1.set_xlabel("Layer")
ax1.set_title("Head Alignment (long prompts)")
plt.colorbar(im1, ax=ax1)

# Right: logit-lens output entropy, one column per stored hidden state.
logit_matrix = np.array([all_results[n]["logit_entropy"] for n in names])
ax2 = axes[1]
im2 = ax2.imshow(logit_matrix, aspect='auto', cmap='magma')
ax2.set_yticks(range(len(names)))
ax2.set_yticklabels([n[:12] for n in names], fontsize=8)
ax2.set_xlabel("Layer (0=emb)")
ax2.set_title("Output Entropy (logit lens)")
plt.colorbar(im2, ax=ax2)

plt.suptitle("Heatmaps: Alignment and Output Entropy (Long Prompts)", fontsize=14)
plt.tight_layout()
# Save before show() so the file is written from the fully drawn figure.
plt.savefig(FIG_DIR / "02_heatmaps_long.png", dpi=150)
plt.show()


In [None]:
# Write all per-prompt metrics to a JSON file next to the figures.
summary = {
    "model": model_name,
    "n_layers": n_layers,
    "n_heads": n_heads,
    "prompts": {
        name: {
            "prompt": r["prompt"],
            "tokens": r["tokens"],
            "top_tokens_by_layer": r["top_tokens_by_layer"],
            "topk_final": r["topk_final"],
            "final_entropy": float(r["final_entropy"]),
            # Every per-layer series is converted to built-in floats; the
            # key order here fixes the key order in the written file.
            **{
                key: [float(x) for x in r[key]]
                for key in (
                    "attn_entropy",
                    "logit_entropy",
                    "head_alignment",
                    "position_change",
                    "layer_to_layer_sim",
                    "sim_to_final",
                )
            },
        }
        for name, r in all_results.items()
    },
}

with open(FIG_DIR / "collapse_results_v2.json", "w") as f:
    f.write(json.dumps(summary, indent=2))

print(f"Saved to {FIG_DIR / 'collapse_results_v2.json'}")


## Summary
- Prompts are now long-form (stories, recipes, jokes, code, legal, news, speculative)
- Added top-k tokens (per layer and final) to avoid single-token focus
- Metrics kept: attention entropy, logit lens entropy, head alignment, position change, convergence
- Figures:
  - `01_story_metrics.png`: All metrics for story prompt
  - `02_heatmaps_long.png`: Heatmaps of alignment and logit entropy across prompts
- Results saved to `figs_collapse_v2/collapse_results_v2.json`

Key expectation: longer prompts should show richer dynamics and more informative top-k continuations.
