
Activation Patching Experiment: Power Dynamics in LLM Representations
=====================================================================

In this notebook I am trying to recreate the figures from the week 5 papers on patching, which show the causal influence of each layer on the model's answer to a prompt. Claude was heavily involved in taking [the code from the ndif/nnsight website](https://nnsight.net/notebooks/mini-papers/marks_geometry_of_truth/#%E2%9E%A1%EF%B8%8F-Let's-scale-things-up!-Steering-Llama-70B-on-NDIF) and making it operational in this context.

The code reproduces the methodology of Figure 2 of the paper [Marks & Tegmark (2023) "The Geometry of Truth"](https://arxiv.org/abs/2310.06824), adapted for power-dynamic sentence pairs.

For each (token_position, layer), we patch residual stream activations from a "source" prompt (entity HAS power) into a "base" prompt (entity LACKS power), and measure log P(TRUE) - log P(FALSE) to build a causal heatmap of where the model stores power-relationship information.


Requirements:
    pip install nnsight torch matplotlib pandas tqdm numpy
    + for remote: NDIF API key from https://login.ndif.us

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from tqdm import trange
from nnsight import LanguageModel
from IPython.display import clear_output
import pandas as pd

In [None]:
# ── Configuration ────────────────────────────────────────────────────────
#
# Everything a reader might want to tune lives in this cell: which model to
# study, whether traces run remotely on NDIF or on a local GPU, and the
# credentials / device needed for the chosen mode.

import os

from nnsight import CONFIG

MODEL_ID = "meta-llama/Llama-3.1-8B"

# Set to True for NDIF remote execution, False for local GPU
USE_REMOTE = True

# For local: device the model is loaded onto. Defined unconditionally so the
# load cell does not raise NameError when USE_REMOTE is switched to False.
DEVICE = "cuda:0"

# For remote: take the NDIF API key from the NDIF_API_KEY environment variable
# instead of hardcoding it in the notebook. The key is only registered when one
# is actually provided, so a placeholder never overwrites a key that was saved
# by an earlier call to set_default_api_key.
if USE_REMOTE:
    _ndif_api_key = os.environ.get("NDIF_API_KEY")
    if _ndif_api_key:
        CONFIG.set_default_api_key(_ndif_api_key)
    else:
        print("NDIF_API_KEY is not set; relying on a previously saved nnsight API key.")

In [None]:
# ── Load model ───────────────────────────────────────────────────────────
#
# Remote runs let the loader pick the device map; local runs pin the model to
# DEVICE in half precision.

model = (
    LanguageModel(MODEL_ID, device_map="auto")
    if USE_REMOTE
    else LanguageModel(MODEL_ID, device_map=DEVICE, torch_dtype=torch.float16)
)

# Drop any loader progress output before printing the one-line summary.
clear_output()
print("Loaded {} ({})".format(MODEL_ID, "remote" if USE_REMOTE else "local"))

In [None]:
# ── Load sentence pairs from TSV ─────────────────────────────────────────
#
# Reads the contrastive pairs, checks the expected columns are present, and
# builds the SOURCE / BASE prompts for the pair selected by PAIR_INDEX.

TSV_PATH = "FILE_PATH_TO_CONTRASTIVE_PAIRED_SENTENCES.tsv"

pairs_df = pd.read_csv(TSV_PATH, sep="\t")
assert {"name", "power", "no_power"}.issubset(pairs_df.columns), \
    "TSV must have columns: name, power, no_power"

print(f"Loaded {len(pairs_df)} sentence pairs from {TSV_PATH}")
print(pairs_df.head())

# Single-line prompt: the claim about {entity}, then the sentence to judge.
QUESTION_TEMPLATE = (
    "In the following sentence, {entity} exercises power over someone or something else. "
    "TRUE or FALSE? Sentence: {sentence} Answer:"
)

# Select one row; its "power" sentence is the patch source, "no_power" the base.
PAIR_INDEX = 0
row = pairs_df.iloc[PAIR_INDEX]
ENTITY, SOURCE_SENTENCE, BASE_SENTENCE = (
    row[column] for column in ("name", "power", "no_power")
)

source_prompt, base_prompt = (
    QUESTION_TEMPLATE.format(entity=ENTITY, sentence=sentence)
    for sentence in (SOURCE_SENTENCE, BASE_SENTENCE)
)

print(f"\n── Pair {PAIR_INDEX}: entity = '{ENTITY}' ──")
print(f"SOURCE (has power):  {source_prompt}")
print(f"BASE (lacks power):  {base_prompt}")

In [None]:
# ── Tokenize & align prompts ────────────────────────────────────────────

# Tokenize both prompts the same way; [0] drops the batch dimension.
source_ids, base_ids = (
    model.tokenizer(prompt, return_tensors="pt").input_ids[0]
    for prompt in (source_prompt, base_prompt)
)

len_source, len_base = len(source_ids), len(base_ids)
max_len = max(len_source, len_base)

# Fall back to the EOS id when the tokenizer defines no pad token.
# (Compared against None explicitly: a pad id of 0 is a valid id.)
pad_id = model.tokenizer.pad_token_id
pad_id = model.tokenizer.eos_token_id if pad_id is None else pad_id

def right_align(ids, target_len, pad_id):
    """Left-pad a 1-D token-id tensor so its last token sits at index target_len - 1.

    Args:
        ids: 1-D tensor of token ids.
        target_len: desired length after padding.
        pad_id: id used to fill the prepended positions.

    Returns:
        A new tensor of length ``target_len`` with ``pad_id`` prepended, or
        ``ids`` itself (unchanged, not copied) when it is already at least
        ``target_len`` long.
    """
    missing = target_len - len(ids)
    if missing <= 0:
        return ids
    padding = torch.full((missing,), pad_id, dtype=ids.dtype)
    return torch.cat([padding, ids])

# Pad both prompts on the left to a common length so that positions counted
# from the END of the prompt line up between source and base.
# NOTE(review): the padded ids are later passed to model.trace without an
# explicit attention mask, so pad tokens may be attended to when the two
# prompts differ in length — confirm against the nnsight input handling.
source_ids_aligned = right_align(source_ids, max_len, pad_id)
base_ids_aligned = right_align(base_ids, max_len, pad_id)

# Decode every aligned base token once; used both to locate the sentence and
# to label the patched positions.
base_token_list = [model.tokenizer.decode([base_ids_aligned[i]]) for i in range(max_len)]

# Find where the sentence starts by looking for "Sentence:" in the tokens.
# Two tokenizations are handled: "Sentence" and ":" as separate tokens, or a
# single token containing "Sentence:".
sentence_start = None
for i in range(max_len - 1):
    current_tok, next_tok = base_token_list[i], base_token_list[i + 1]
    if "Sentence" in current_tok and ":" in next_tok:
        sentence_start = i + 2
        break
    if "Sentence:" in current_tok:
        sentence_start = i + 1
        break

# Fallback: first aligned position at which the two prompts differ.
if sentence_start is None:
    sentence_start = next(
        (i for i in range(max_len) if source_ids_aligned[i] != base_ids_aligned[i]),
        None,
    )

# Fail loudly here instead of with an opaque TypeError in range() below (or an
# empty heatmap later) when there is nothing to patch over.
if sentence_start is None or sentence_start >= max_len:
    raise ValueError(
        "Could not locate the start of the sentence: no 'Sentence:' marker was "
        "found before the end of the prompt and the source and base prompts "
        "tokenize identically."
    )

# Labels for the patched positions (from the sentence start to the prompt end).
base_token_strings = base_token_list[sentence_start:]

print(f"\nSource tokens: {len_source}, Base tokens: {len_base}, Aligned length: {max_len}")
print(f"Sentence starts at position: {sentence_start}")
print(f"Patching over positions {sentence_start} to {max_len - 1} ({len(base_token_strings)} tokens)")
print(f"\nBase tokens being patched over:")
for i, t in enumerate(base_token_strings):
    print(f"  pos {sentence_start + i}: '{t}'")

In [None]:
# ── Verify model outputs ────────────────────────────────────────────────
#
# Sanity check before patching: on the unpatched SOURCE and BASE prompts,
# report how much next-token probability mass falls on TRUE-like vs
# FALSE-like tokens and what the single most likely next token is.

yes_variants = [" TRUE", " True", " true", "TRUE", "True", "true"]
no_variants = [" FALSE", " False", " false", "FALSE", "False", "false"]


def _single_token_ids(variants):
    """Return the sorted, de-duplicated ids of the variants that encode to exactly one token."""
    found = set()
    for variant in variants:
        encoded = model.tokenizer(variant, add_special_tokens=False).input_ids
        if len(encoded) == 1:
            found.add(encoded[0])
    # Sorted so printed output and downstream indexing are deterministic.
    return sorted(found)


yes_token_ids = _single_token_ids(yes_variants)
no_token_ids = _single_token_ids(no_variants)

# Without at least one id on each side the log-odds below are undefined.
if not yes_token_ids or not no_token_ids:
    raise ValueError(
        "No single-token TRUE/FALSE variants found for this tokenizer "
        f"(TRUE ids: {yes_token_ids}, FALSE ids: {no_token_ids})"
    )

print(f"\nTRUE token ids: {yes_token_ids}")
print(f"  variants: {[model.tokenizer.decode([t]) for t in yes_token_ids]}")
print(f"FALSE token ids: {no_token_ids}")
print(f"  variants: {[model.tokenizer.decode([t]) for t in no_token_ids]}")

for name, ids in [("SOURCE", source_ids_aligned), ("BASE", base_ids_aligned)]:
    # One code path for both modes: `remote` is simply switched by USE_REMOTE.
    with torch.no_grad():
        with model.trace(ids.unsqueeze(0), remote=USE_REMOTE):
            logits = model.output.logits.save()

    # Work in log space (and float32) so tiny probabilities do not underflow
    # to 0 and turn the log-ratio into inf/nan.
    log_probs = logits[0, -1].float().log_softmax(dim=-1)
    log_p_yes = torch.logsumexp(log_probs[yes_token_ids], dim=0)
    log_p_no = torch.logsumexp(log_probs[no_token_ids], dim=0)

    p_yes = log_p_yes.exp().item()
    p_no = log_p_no.exp().item()
    log_diff = (log_p_yes - log_p_no).item()
    top_tok = model.tokenizer.decode(logits[0, -1].argmax(dim=-1))
    print(f"{name}: P(TRUE)={p_yes:.4f}, P(FALSE)={p_no:.4f}, "
          f"log P(TRUE)-log P(FALSE)={log_diff:.3f}, "
          f"top token='{top_tok}'")

In [None]:
# ── Cache source activations ────────────────────────────────────────────
#
# Runs the SOURCE prompt through the model and stores, for every decoder
# layer, the first element of that layer's output. The result is
# `source_activations`, a list with one entry per layer, each forced to have a
# leading batch dimension so the patching cell can index it uniformly.

num_layers = model.config.num_hidden_layers
print(f"\nModel has {num_layers} layers. Caching source activations...")

source_activations = []

if USE_REMOTE:
    # Remote mode issues one separate trace (one remote request) per layer.
    for i in range(num_layers):
        with model.trace(source_ids_aligned.unsqueeze(0), remote=True):
            act = model.model.layers[i].output[0].save()
        # output[0] may come back without a batch dimension; presumably this
        # depends on whether the layer returns a tuple or a bare tensor in the
        # installed transformers version — TODO confirm. Normalize to 3-D.
        if act.dim() == 2:
            act = act.unsqueeze(0)
        source_activations.append(act)
        # Progress message every 8 layers.
        if (i + 1) % 8 == 0:
            print(f"  Cached {i + 1}/{num_layers} layers...")
else:
    # Local mode collects all layers in a single forward pass.
    with torch.no_grad():
        with model.trace(source_ids_aligned.unsqueeze(0)):
            for layer in model.model.layers:
                # clone() so the cached value is a copy rather than the live
                # layer output.
                source_activations.append(layer.output[0].clone().save())
    # Same batch-dimension normalization as in the remote branch.
    for i in range(len(source_activations)):
        if source_activations[i].dim() == 2:
            source_activations[i] = source_activations[i].unsqueeze(0)

print(f"Source activations cached: {len(source_activations)} layers")
print(f"Shape: {source_activations[0].shape}")

In [None]:
# ── Run patching experiment ──────────────────────────────────────────────
#
# For every (layer, token position) in the patched span, run the BASE prompt
# while overwriting that layer's output at that position with the cached
# SOURCE activation, then record log P(TRUE) - log P(FALSE) for the next token.
# `patching_results[layer][k]` is the value for position sentence_start + k.

num_tokens = len(base_token_strings)
print(f"\nRunning patching: {num_layers} layers × {num_tokens} token positions...")

# The log-odds below are undefined without at least one id on each side.
if not yes_token_ids or not no_token_ids:
    raise ValueError("yes_token_ids and no_token_ids must both be non-empty")

patching_results = []

for layer_idx in trange(num_layers, desc="Layers"):
    layer_results = []
    for token_idx in range(sentence_start, sentence_start + num_tokens):
        # One code path for both modes: `remote` is simply switched by USE_REMOTE.
        with torch.no_grad():
            with model.trace(base_ids_aligned.unsqueeze(0), remote=USE_REMOTE):
                layer_out = model.model.layers[layer_idx].output[0]
                src_act = source_activations[layer_idx]
                # The layer output may or may not carry a batch dimension;
                # the cached source activation always does.
                if layer_out.dim() == 2:
                    layer_out[token_idx, :] = src_act[0, token_idx, :]
                else:
                    layer_out[:, token_idx, :] = src_act[:, token_idx, :]
                patched_logits = model.output.logits.save()

        # log P(TRUE) - log P(FALSE) straight from the logits: the softmax
        # normalizer cancels in the difference, and staying in log space (in
        # float32) avoids the underflow of log(p_yes / p_no) on tiny probabilities.
        final_logits = patched_logits[0, -1].float()
        diff = (
            torch.logsumexp(final_logits[yes_token_ids], dim=0)
            - torch.logsumexp(final_logits[no_token_ids], dim=0)
        ).item()
        layer_results.append(diff)

    patching_results.append(layer_results)

print("Patching complete!")

In [None]:
# ── Plot the heatmap (Figure 2 style) ───────────────────────────────────
#
# Rows are layers, columns are the patched base-prompt tokens; colour encodes
# the log-odds of TRUE over FALSE after patching that (layer, token) cell.

arr = np.array(patching_results)
cmap = mcolors.LinearSegmentedColormap.from_list("white_blue", ["#FFFFFF", "#156082"])

# Widen the figure with the number of tokens so the tick labels stay legible.
fig, ax = plt.subplots(figsize=(max(10, 0.5 * len(base_token_strings)), 8))
im = ax.imshow(arr, cmap=cmap, aspect="auto", vmin=arr.min(), vmax=arr.max())

# x axis: one tick per patched token, labelled with the stripped token text.
ax.set_xticks(range(len(base_token_strings)))
ax.set_xticklabels(
    [label.strip() for label in base_token_strings],
    rotation=90, ha="center", va="top", fontsize=10,
)
ax.set_xlabel("Token Position", fontsize=12)

# y axis: one tick per layer.
ax.set_yticks(range(num_layers))
ax.set_yticklabels(range(num_layers), fontsize=8)
ax.set_ylabel("Layer", fontsize=12)

ax.set_title(
    f"Activation Patching: '{BASE_SENTENCE}' ← '{SOURCE_SENTENCE}'\n"
    "log P(TRUE) − log P(FALSE) after patching each (token, layer)",
    fontsize=11,
)

cbar = fig.colorbar(im, ax=ax, shrink=0.8)
cbar.set_label("log P(TRUE) − log P(FALSE)")

plt.tight_layout()
plt.savefig("patching_heatmap.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved to patching_heatmap.png")

In [None]:
# ── Save results ─────────────────────────────────────────────────────────
#
# Persist the prompts, token labels and the layer × token log-odds grid so the
# heatmap can be rebuilt without re-running the patching loop.

import json

results = dict(
    model=MODEL_ID,
    entity=ENTITY,
    source_prompt=source_prompt,
    base_prompt=base_prompt,
    source_sentence=SOURCE_SENTENCE,
    base_sentence=BASE_SENTENCE,
    sentence_start=sentence_start,
    token_labels=base_token_strings,
    num_layers=num_layers,
    patching_results=patching_results,
)

with open("patching_results.json", "w") as f:
    json.dump(results, f, indent=2)

print("Results saved to patching_results.json")