In [10]:
# --- MATS causal interventions + hydra-style analysis (Llama-3.2-3B) ---

import os, math, json, random, gc, pathlib
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt

SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)



# Evaluation settings
GEN_MAX_NEW_TOKENS = 64
TEMPERATURE = 0.0
TOP_K = None

# Anchor positions we evaluate hydra/kl at:
EVAL_POS = "last"   # 'last' or 'first_code' (we compute first code token index)



In [11]:
# ---------------- Config ----------------

HF_TOKEN   = os.environ.get("HF_TOKEN")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16
MODEL_NAME = "meta-llama/Llama-3.2-3B"
# ---------------- Load model ----------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, token=HF_TOKEN, torch_dtype=DTYPE, low_cpu_mem_usage=True
).to(DEVICE).eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
# Create a dataset pair specific config file that can be passed everywhere

token_list_value = {
    "cpp_top": ["std", "int", ";", "{", "}", "[i", "long", "return", "<<", ">>"],
    "python_top": [":", "None", "def", "print", "==", "len", "range", "str", ",", "):"],
    "science": ["energy", "water", "cells", "system", "body", "blod", "carbon", "molecules", "light", "atoms"],
    "medical": ["cause", "options", "clinical", "patient", "blood", "symptoms", "disease", "diagnosis", "pain", "condition"],
    "finance": ["company", "total", "capital", "share", "tax", "assets", "total", "rate", "ratio", "value"],
    "math": ["total", "many", "number", "per", "cost", "times", "find", "one", "amount", "money"]
}
def tok(s: str):
    return tokenizer(s, return_tensors="pt").to(DEVICE)

def untok(s: int):
    return tokenizer.decode(s)
# [tok(i)['input_ids'][0][1].item() for i in config_dict['cpp-python'].a.token_set]
token_list = {key: [tok(i)['input_ids'][0][1].item() for i in value] for key, value in token_list_value.items()}

print(token_list)
class Language:
    def __init__(self, name: str, token_set: list[str] | None = None):
        self.name = name
        if token_set is None:
            self.token_set = self.load_specific()
        else:
            self.token_set = token_set
        
    def load_tokens(self):
        with open("token_set_id.json", 'r') as f:
            data = json.load(f)
        if self.name == "cpp_top":
            return data["cpp_clean"]
        if self.name == "python_top":
            return data["python_clean"]
        return data[self.name]
    def load_specific(self):
        return token_list[self.name]
    

class DatasetPairConfig:
    def __init__(self, lang_a: str, lang_b: str, file_path: str):
        self.a = Language(lang_a)
        self.b = Language(lang_b)
        self.prompts = self.load_file(file_path)
        
    def load_file(self, file_path):
        with open(file_path, 'r') as f:
            data = json.load(f)
        prompts = []
        for element in data:
            prompts.append((element[self.a.name], element[self.b.name]))
        return prompts
        
        
# CPP_TOKENS = {"def", "import", ":", "pass"}     # Python-ish indicators
# PYTHON_TOKENS = {";", "::", "std", "#"}            # C++-ish indicators; '#' approximates #include

config_dict = {
    "cpp-python": DatasetPairConfig("cpp_top", "python_top", "prompt_set/cpp_python_100.json"),
    # "sci-math": DatasetPairConfig("science", "math", "prompt_set/sci_math_prompts.json"),
    # "medical-finance": DatasetPairConfig("medical", "finance", "prompt_set/fin_med_prompts.json"),
}
    

{'cpp_top': [1872, 396, 26, 90, 92, 1004, 4930, 693, 2501, 2511], 'python_top': [25, 4155, 755, 1374, 419, 2963, 9866, 496, 11, 1680], 'science': [17947, 13284, 37791, 9125, 2664, 2067, 74441, 76, 4238, 66650], 'medical': [1593, 2945, 91899, 23557, 51105, 24738, 67, 8747, 92724, 9233], 'finance': [10348, 5143, 66163, 19930, 18081, 5271, 5143, 7853, 46458, 970], 'math': [5143, 35676, 4174, 716, 16845, 15487, 3990, 606, 6173, 29359]}


In [13]:
# def one_token_ids_scan(cands):
#     og_len = len(cands)
#     ids, keep, remove = [], [], []
#     for s in cands:
#         t = tokenizer.encode(s, add_special_tokens=False)
#         if len(t) == 1:
#             ids.append(t[0]); keep.append(s)
#         else:
#             print(f" ids = {t}, token = {s}")
#             remove.append(s)
            
#     new_len = len(keep)
#     print(f"removed {og_len} - {new_len} = {og_len - new_len} tokens")
#     print("These tokens were removed: ")
#     print(remove)
#     return ids, keep

# ids, tokens = one_token_ids_scan(config.a.token_set)
# ids, tokens = one_token_ids_scan(config.b.token_set)



In [14]:
# ---------------- Utilities ----------------
def tok(s: str):
    return tokenizer(s, return_tensors="pt").to(DEVICE)

def generate_logits(prompt: str, max_new_tokens=0):
    """Return logits over next token for the prompt (no generation if max_new_tokens=0)."""
    with torch.no_grad():
        enc = tok(prompt)
        out = model(**enc)
        logits = out.logits  # [B, T, V]
    return enc, logits

def pick_eval_index(input_ids: torch.Tensor, text: str, mode: str) -> int:
    # mode 'last': last token in the prompt; 'first_code': first token after a code marker
    if mode == "last":
        return input_ids.shape[1] - 1
    if mode == "first_code":
        # Heuristic: find "```" or a language marker in the prompt text
        idx = text.find("```")
        if idx < 0: return input_ids.shape[1] - 1
        # tokenize up to the backticks and use its length - 1
        prefix = text[:idx+3]
        with torch.no_grad():
            ids = tokenizer(prefix, return_tensors="pt").to(DEVICE)["input_ids"]
        return min(ids.shape[1]-1, input_ids.shape[1]-1)
    return input_ids.shape[1] - 1

def kl_divergence(p_logits, q_logits):
    # KL(p || q) at a single position; stable softmax
    p = F.log_softmax(p_logits, dim=-1)
    q = F.log_softmax(q_logits, dim=-1)
    return torch.sum(torch.exp(p) * (p - q), dim=-1)

def l2(x): return torch.norm(x.float(), dim=-1)

def token_prob_sum(logits, token_strs: List[str]) -> float:
    ids = []
    for s in token_strs:
        toks = tokenizer.encode(s, add_special_tokens=False)
        if len(toks) == 1:
            ids.append(toks[0])
    if not ids: return 0.0
    probs = F.softmax(logits, dim=-1)[..., ids]
    return probs.sum(-1).item()

# ---------------- Hooking helpers ----------------
@dataclass(frozen=True)
class LayerSpec:
    kind: Literal["attn","mlp"]
    idx: int

def get_layer_module(m: nn.Module, spec: LayerSpec) -> nn.Module:
    # Llama: model.model.layers[i].self_attn / .mlp
    return getattr(m.model.layers[spec.idx], "self_attn" if spec.kind=="attn" else "mlp")

def _as_tensor(output):
    # Some HF modules may return tuple; we normalize to a tensor for replacement logic
    if isinstance(output, tuple):
        return output[0]
    return output

def _repack_like(original_output, new_tensor):
    # Put new_tensor back into the original structure if needed
    if isinstance(original_output, tuple):
        lst = list(original_output)
        lst[0] = new_tensor
        return tuple(lst)
    return new_tensor

class Capture:
    """Capture module outputs (pre-residual) at all sequence positions."""
    def __init__(self, model: nn.Module, specs: List[LayerSpec]):
        self.handles = []
        self.data: Dict[LayerSpec, torch.Tensor] = {}
        for s in specs:
            mod = get_layer_module(model, s)
            h = mod.register_forward_hook(self._make_hook(s))
            self.handles.append(h)
    def _make_hook(self, spec):
        def hook(module, inp, out):
            self.data[spec] = _as_tensor(out).detach()
        return hook
    def remove(self):
        for h in self.handles: h.remove()
        self.handles = []

class Intervention:
    """Generic intervention hook using a function f(spec, output)->new_output."""
    def __init__(self, model: nn.Module, specs: List[LayerSpec], fn):
        self.handles = []
        self.fn = fn
        for s in specs:
            mod = get_layer_module(model, s)
            h = mod.register_forward_hook(self._make_hook(s))
            self.handles.append(h)
    def _make_hook(self, spec):
        def hook(module, inp, out):
            out_t = _as_tensor(out)
            new_t = self.fn(spec, out_t)
            return _repack_like(out, new_t)
        return hook
    def remove(self):
        for h in self.handles: h.remove()
        self.handles = []

# ---------------- Core experiments ----------------
def run_baseline(prompt: str):
    with torch.no_grad():
        enc = tok(prompt)
        out = model(**enc)
        return enc, out.logits, out.hidden_states if hasattr(out, "hidden_states") else None




def cross_swap_once(config: DatasetPairConfig, idx, spec: LayerSpec, eval_pos="last"):
    """Swap the module output at spec between A and B runs (last-token position)."""
    with torch.no_grad():
        prompt_a = config.prompts[idx][0]
        prompt_b = config.prompts[idx][1]
        enc_a = tok(prompt_a); enc_b = tok(prompt_b)

        # Capture outputs
        cap_a = Capture(model, [spec]); _ = model(**enc_a); cap_a.remove()
        cap_b = Capture(model, [spec]); _ = model(**enc_b); cap_b.remove()
        A = cap_a.data[spec]; B = cap_b.data[spec]  # [1, T, d]

        # Align lengths; use last token only to be safe
        def swap_fn_factory(src_vec):
            def fn(s, out):
                out2 = out.clone()
                out2[:, -1, :] = src_vec[:, -1, :].to(out2.dtype).to(out2.device)
                return out2
            return fn

        pos_a = pick_eval_index(enc_a["input_ids"], prompt_a, eval_pos)
        pos_b = pick_eval_index(enc_b["input_ids"], prompt_b, eval_pos)

        # Baseline logits at pos
        base_a = model(**enc_a).logits[:, pos_a, :]
        base_b = model(**enc_b).logits[:, pos_b, :]

        # A <- B
        hook = Intervention(model, [spec], swap_fn_factory(B))
        out = model(**enc_a); hook.remove()
        swap_a = out.logits[:, pos_a, :]

        # B <- A
        hook = Intervention(model, [spec], swap_fn_factory(A))
        out = model(**enc_b); hook.remove()
        swap_b = out.logits[:, pos_b, :]

        # Language-bias scores
        score_a_before = token_prob_sum(base_a, list(config.a.token_set)) - token_prob_sum(base_a, list(config.b.token_set))
        score_a_after  = token_prob_sum(swap_a, list(config.a.token_set)) - token_prob_sum(swap_a, list(config.b.token_set))
        score_b_before = token_prob_sum(base_b, list(config.b.token_set)) - token_prob_sum(base_b, list(config.a.token_set))
        score_b_after  = token_prob_sum(swap_b, list(config.b.token_set)) - token_prob_sum(swap_b, list(config.a.token_set))
        push_b = token_prob_sum(swap_a, list(config.b.token_set)) - token_prob_sum(base_a, list(config.b.token_set))
        push_a = token_prob_sum(swap_b, list(config.a.token_set)) - token_prob_sum(base_b, list(config.a.token_set))
        kl_a = kl_divergence(base_a, swap_a).item()
        kl_b = kl_divergence(base_b, swap_b).item()

        return {
            "kl_a": kl_a, "kl_b": kl_b,
            # "score_a_before": score_a_before, "score_a_after": score_a_after,
            # "score_b_before": score_b_before, "score_b_after": score_b_after,
            "push_b_in_a": push_b, "push_a_in_b": push_a
        }

def compute_concept_means(samples: List[str], specs: List[LayerSpec], max_samples: int = 200):
    """Mean module output per spec over samples (last token)."""
    means: Dict[LayerSpec, torch.Tensor] = {}
    sums: Dict[LayerSpec, torch.Tensor] = {}
    count = 0
    with torch.no_grad():
        for s in samples[:max_samples]:
            enc = tok(s)
            cap = Capture(model, specs); _ = model(**enc); cap.remove()
            for spec in specs:
                vec = cap.data[spec][:, -1, :]  # [1, d], last token
                if spec not in sums: sums[spec] = vec.clone().to("cpu")
                else: sums[spec] += vec.to("cpu")
            count += 1
    for spec in specs:
        means[spec] = (sums[spec] / count)
    return means


# ---------------- Plot helpers ----------------
def plot_kl_curve(kl_avg: Dict[LayerSpec, float], title="KL after zero-ablation by layer"):
    xs_attn, ys_attn = zip(*sorted([(s.idx, v) for s,v in kl_avg.items() if s.kind=="attn"]))
    xs_mlp,  ys_mlp  = zip(*sorted([(s.idx, v) for s,v in kl_avg.items() if s.kind=="mlp"]))
    plt.figure()
    plt.plot(xs_attn, ys_attn, marker="o", label="Attention")
    plt.plot(xs_mlp, ys_mlp, marker="o", label="MLP")
    plt.xlabel("Layer index"); plt.ylabel("KL(baseline || ablated)")
    plt.title(title); plt.legend(); plt.show()

def plot_hydra(hydra_avg: Dict[LayerSpec, Tuple[float,float]], title="Hydra: Δembed vs Δunembed (KL)"):
    xs, ys, cs = [], [], []
    for s,(demb, kl) in hydra_avg.items():
        xs.append(demb); ys.append(kl); cs.append("attn" if s.kind=="attn" else "mlp")
    plt.figure()
    for kind in ["attn","mlp"]:
        X = [x for x,c in zip(xs,cs) if c==kind]
        Y = [y for y,c in zip(ys,cs) if c==kind]
        plt.scatter(X, Y, label=kind, alpha=0.8)
    plt.xlabel("‖Δembed‖ at intervention"); plt.ylabel("KL at output")
    plt.title(title); plt.legend(); plt.show()

def plot_alpha_curve(results, title="Concept-vector injection"):
    alphas = [a for a,_,_ in results]
    scores = [s for _,s,_ in results]
    kls    = [k for *_,k in results]
    plt.figure(); plt.plot(alphas, scores, marker="o"); plt.xlabel("alpha"); plt.ylabel("Language-bias score"); plt.title(title + " (bias)"); plt.show()
    plt.figure(); plt.plot(alphas, kls, marker="o"); plt.xlabel("alpha"); plt.ylabel("KL"); plt.title(title + " (KL)"); plt.show()

In [15]:
# ========= INTERPRETABLE METRICS =========
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import json, math

# Indicator tokens (single-token only; we filter to those)
# PY_CANDIDATES  = ["def", "import", ":", "pass", "lambda", "print"]
# CPP_CANDIDATES = [";", "::", "std", "using", "cout", "#"]

def one_token_ids(cands):
    og_len = len(cands)
    ids, keep = [], []
    for s in cands:
        t = tokenizer.encode(s, add_special_tokens=False)
        if len(t) == 1:
            ids.append(t[0]); keep.append(s)
    return ids, keep

# ids, tokens = one_token_ids(config.a.token_set)
# print(ids, tokens)
# ids, tokens = one_token_ids(config.b.token_set)
# print(ids, tokens)


def prob_mass_for_ids(logits, ids):
    probs = F.softmax(logits, dim=-1)[..., ids]
    return probs.sum(-1).item()

def kl_divergence(p_logits, q_logits):
    p = F.log_softmax(p_logits, dim=-1)
    q = F.log_softmax(q_logits, dim=-1)
    return torch.sum(torch.exp(p) * (p - q), dim=-1)

def next_token_logits_at_first_code(prompt: str):
    enc = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    pos = pick_eval_index(enc["input_ids"], prompt, mode="last")
    with torch.no_grad():
        out = model(**enc)
    return out.logits[:, pos, :].squeeze(0)  # [V]

def cross_swap_verbose(config: DatasetPairConfig, prompt_a: str, prompt_b: str, spec: LayerSpec):
    # Baselines
    base_a = next_token_logits_at_first_code(prompt_a)
    base_b  = next_token_logits_at_first_code(prompt_b)

    # Capture layer outputs on last prompt position
    enc_a = tokenizer(prompt_a, return_tensors="pt").to(DEVICE)
    enc_b  = tokenizer(prompt_b,  return_tensors="pt").to(DEVICE)
    capA = Capture(model, [spec]); _ = model(**enc_a); capA.remove()
    capB = Capture(model, [spec]); _ = model(**enc_b);  capB.remove()
    A = capA.data[spec][:, -1, :]  # C++ layer output at last prompt token
    B = capB.data[spec][:, -1, :]  # Python layer output

    def swap_fn(vec):
        def fn(s, out):
            out2 = out.clone()
            out2[:, -1, :] = vec.to(out2)
            return out2
        return fn

    # C++ side: insert Python vec
    hook = Intervention(model, [spec], swap_fn(B))
    out_a = model(**enc_a); hook.remove()
    pos_a = pick_eval_index(enc_a["input_ids"], prompt_a, "last")
    swap_a = out_a.logits[:, pos_a, :].squeeze(0)

    # Python side: insert C++ vec
    hook = Intervention(model, [spec], swap_fn(A))
    out_b = model(**enc_b); hook.remove()
    pos_b = pick_eval_index(enc_b["input_ids"], prompt_b, "last")
    swap_b = out_b.logits[:, pos_b, :].squeeze(0)

    # a_token_ids, _  = one_token_ids(config.a.token_set)
    # b_token_ids, _  = one_token_ids(config.b.token_set)

    # Metrics per side (note: bias = P(Py tokens) - P(C++ tokens), same sign for both)
    def summarize(base, new, token_set_a, token_set_b):
        # assume A <- B

        kl = kl_divergence(base.unsqueeze(0), new.unsqueeze(0)).item()
        shift_b = prob_mass_for_ids(new, token_set_b) - prob_mass_for_ids(base, token_set_b)
        return dict(
            KL=kl,
            shift_other=shift_b
        )

    return {
        "spec": (spec.kind, spec.idx),
        config.a.name: summarize(base_a, swap_a, config.a.token_set, config.b.token_set),
        config.b.name: summarize(base_b,  swap_b, config.b.token_set, config.a.token_set),
    }
    


# ========= BATCH & AVERAGING =========
def run_cross_swap_batch(config: DatasetPairConfig, layer_specs):
    """
    paired_prompts: list of (cpp_prompt, py_prompt)
    returns dict keyed by (kind, idx) with averages across pairs
    """
    paired_prompts = config.prompts
    agg = defaultdict(lambda: {config.a.name: defaultdict(list), config.b.name: defaultdict(list)})
    for (a_p, b_p) in paired_prompts:
        for spec in layer_specs:
            res = cross_swap_verbose(config, a_p, b_p, spec)
            kind, idx = res["spec"]
            k = f"{kind}-{idx}"
            for side in [config.a.name,config.b.name]:
                for key, val in res[side].items():
                    if key in ["side"]: continue
                    agg[k][side][key].append(val)

    # average
    # out = {}
    # for k, sides in agg.items():
    #     out[k] = {}
    #     for side, metrics in sides.items():
    #         out[k][side] = {m: float(np.mean(vals)) for m, vals in metrics.items()}
    # return out
    
    # compute mean and std
    out = {}
    for k, sides in agg.items():
        out[k] = {}
        for side, metrics in sides.items():
            out[k][side] = {
                m: {
                    "mean": float(np.mean(vals)),
                    "std": float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0
                }
                for m, vals in metrics.items()
            }
    return out

def run_cross_swap_average(config, layer_specs):
    """
    Runs cross_swap_verbose over all prompt pairs in config.prompts
    and averages KL divergence and shift metrics for each layer spec.

    Returns dict keyed by (kind, idx), with averages for both sides.
    """
    paired_prompts = config.prompts

    # aggregate storage
    agg = defaultdict(lambda: {config.a.name: defaultdict(list),
                               config.b.name: defaultdict(list)})

    # run cross swap for all prompt pairs and layer specs
    for (a_p, b_p) in paired_prompts:
        for spec in layer_specs:
            res = cross_swap_verbose(config, a_p, b_p, spec)
            k = res["spec"]
            for side in [config.a.name, config.b.name]:
                for key, val in res[side].items():
                    if key == "side": 
                        continue
                    agg[k][side][key].append(val)

    # compute averages
    out = {}
    for k, sides in agg.items():
        out[k] = {}
        for side, metrics in sides.items():
            out[k][side] = {m: float(np.mean(vals)) for m, vals in metrics.items()}
    return out


def plot_cross_swap_kl(config:DatasetPairConfig, results, title="Cross-swap KL by layer (C++ & Python sides)", save="crossswap_kl.png"):
    a = config.a.name
    b = config.b.name
    labels = []
    kl_a, kl_b = [], []
    std_a, std_b = [], []

    # collect means and stds
    # for (kind, idx), sides in sorted(results.items(), key=lambda x: (x[0][0], x[0][1])):
    for (kind, idx), sides in results.items():
        labels.append(f"{kind.upper()}-{idx}")
        kl_a.append(sides[a]["KL"]["mean"])
        kl_b.append(sides[b]["KL"]["mean"])
        std_a.append(sides[a]["KL"]["std"])
        std_b.append(sides[b]["KL"]["std"])

    x = np.arange(len(labels))
    w = 0.38
    plt.figure(figsize=(10, 4))

    # bars with error bars (std)
    plt.bar(x - w/2, kl_a, width=w, label="KL on " + a.upper() + " prompt", yerr=std_a, capsize=5)
    plt.bar(x + w/2, kl_b,  width=w, label="KL on " + b.upper() + " prompt", yerr=std_b, capsize=5)

    plt.xticks(x, labels, rotation=45, ha="right")
    plt.ylabel("KL(baseline || swapped)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    directory = os.path.dirname(save)
    if directory:
        os.makedirs(directory, exist_ok=True)
    plt.savefig(save, dpi=220, bbox_inches="tight")
    plt.show()
    
    print(f"Saved to {save}")

def plot_cross_swap_bias(config: DatasetPairConfig, results, title="Language-bias shift Δ (Py mass − C++ mass)", save="crossswap_bias.png"):
    a = config.a.name
    b = config.b.name
    labels = []
    db_a, db_b = [], []
    std_a, std_b = [], []

    # collect means and stds
    #for (kind, idx), sides in sorted(results.items(), key=lambda x:(x[0][0], x[0][1])):
    for (kind, idx), sides in results.items():
        labels.append(f"{kind.upper()}-{idx}")
        db_a.append(sides[a]["delta_bias"]["mean"])
        db_b.append(sides[b]["delta_bias"]["mean"])
        std_a.append(sides[a]["delta_bias"]["std"])
        std_b.append(sides[b]["delta_bias"]["std"])

    x = np.arange(len(labels)); w = 0.38
    plt.figure(figsize=(10,4))

    # bars with error bars (std)
    plt.bar(x - w/2, db_a, width=w, label="Δ bias on "+ a.upper() + " prompt", yerr=std_a, capsize=5)
    plt.bar(x + w/2, db_b,  width=w, label="Δ bias on " + b.upper() + " prompt", yerr=std_b, capsize=5)

    plt.axhline(0, linestyle="--", linewidth=1)
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.ylabel(f"After − Before  (P[{b} tokens] − P[{a} tokens])")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save, dpi=220, bbox_inches="tight")
    plt.show()
    # print(f"Saved {save}")





In [16]:
import torch
import torch.nn.functional as F
from typing import Dict, Tuple, List

# Assume the following classes and functions are defined from your provided context:
# - model, tokenizer, DEVICE
# - DatasetPairConfig, LayerSpec
# - Capture, Intervention
# - next_token_logits_at_first_code, pick_eval_index


def calculate_token_shifts_for_layer(
    config: DatasetPairConfig,
    prompt_a: str,
    prompt_b: str,
    spec: LayerSpec
) -> Dict[str, torch.Tensor]:
    """
    Calculates the shift in probability for every token in the vocabulary after a causal swap.

    For a given layer (`spec`), this function swaps the activation from `prompt_b` into the
    forward pass of `prompt_a` (and vice-versa) and calculates the difference in the
    final probability distribution compared to the baseline.

    Args:
        config: Configuration object for the dataset pair (e.g., containing names 'cpp', 'python').
        prompt_a: The prompt for the first domain (e.g., the C++ prompt).
        prompt_b: The prompt for the second domain (e.g., the Python prompt).
        spec: The LayerSpec object defining the component and layer index for the intervention.

    Returns:
        A dictionary containing the raw probability shift tensors for each intervention direction.
        - Keys are descriptive strings like 'shift_cpp_<-_python'.
        - Values are tensors of shape [vocab_size] where each element is the change in
          probability for that token ID.
    """
    # 1. Get baseline next-token logits for both original prompts.
    base_a_logits = next_token_logits_at_first_code(prompt_a)
    base_b_logits = next_token_logits_at_first_code(prompt_b)

    # 2. Tokenize prompts and capture the internal activations at the specified layer.
    enc_a = tokenizer(prompt_a, return_tensors="pt").to(DEVICE)
    enc_b = tokenizer(prompt_b, return_tensors="pt").to(DEVICE)

    capA = Capture(model, [spec]); _ = model(**enc_a); capA.remove()
    capB = Capture(model, [spec]); _ = model(**enc_b); capB.remove()

    # Extract the activation at the final token position for each prompt.
    activation_A = capA.data[spec][:, -1, :]
    activation_B = capB.data[spec][:, -1, :]

    # 3. Define the intervention function that will replace the activation.
    def swap_function_factory(vector_to_insert):
        def intervention_fn(spec, current_output):
            intervened_output = current_output.clone()
            # Replace the activation at the last sequence position with the swapped vector.
            intervened_output[:, -1, :] = vector_to_insert.to(current_output.dtype)
            return intervened_output
        return intervention_fn

    # 4. Perform the causal swaps and get the new, intervened logits.
    # Intervention on Prompt A: Run prompt_a's input but swap in prompt_b's activation.
    with torch.no_grad():
        hook = Intervention(model, [spec], swap_function_factory(activation_B))
        out_a_swapped = model(**enc_a); hook.remove()
        pos_a = pick_eval_index(enc_a["input_ids"], prompt_a, "last")
        swap_a_logits = out_a_swapped.logits[:, pos_a, :].squeeze(0)
    
        # Intervention on Prompt B: Run prompt_b's input but swap in prompt_a's activation.
        hook = Intervention(model, [spec], swap_function_factory(activation_A))
        out_b_swapped = model(**enc_b); hook.remove()
        pos_b = pick_eval_index(enc_b["input_ids"], prompt_b, "last")
        swap_b_logits = out_b_swapped.logits[:, pos_b, :].squeeze(0)

    # 5. Calculate the shifts in probability distributions.
    # Shift = P(token | intervened) - P(token | baseline)
    base_a_probs = F.softmax(base_a_logits, dim=-1)
    swap_a_probs = F.softmax(swap_a_logits, dim=-1)
    shift_a_receives_b = swap_a_probs - base_a_probs

    base_b_probs = F.softmax(base_b_logits, dim=-1)
    swap_b_probs = F.softmax(swap_b_logits, dim=-1)
    shift_b_receives_a = swap_b_probs - base_b_probs

    return {
        f"shift_{config.a.name}_<-_{config.b.name}": shift_a_receives_b,
        f"shift_{config.b.name}_<-_{config.a.name}": shift_b_receives_a,
    }

def analyze_all_layers_with_token_shifts(
    config: DatasetPairConfig,
    prompt_a: str,
    prompt_b: str,
    layers_to_probe: List[int],
    components_to_probe: List[str],
    top_k: int = 20
) -> Dict[Tuple[str, int], Dict[str, Dict[str, List[Tuple[str, float]]]]]:
    """
    Iterates through layers to calculate token shifts, returning two separate lists
    for the top_k promoted (positive shift) and demoted (negative shift) tokens.
    """
    full_results = {}
    for comp in components_to_probe:
        for L in layers_to_probe:
            spec = LayerSpec(comp, L)
            print(f"--- Analyzing {spec} ---")

            # Get the raw shift tensors for the current layer
            shift_tensors = calculate_token_shifts_for_layer(config, prompt_a, prompt_b, spec)

            layer_results = {}
            for key, shift_tensor in shift_tensors.items():
                # Get the top k largest positive shifts (promoted tokens)
                top_pos_vals, top_pos_indices = torch.topk(shift_tensor, top_k)

                # Get the top k largest negative shifts (demoted tokens)
                top_neg_vals, top_neg_indices = torch.topk(-shift_tensor, top_k)

                # --- MODIFICATION START ---
                # Create a list for promoted tokens
                promoted_tokens = []
                for val, idx in zip(top_pos_vals, top_pos_indices):
                    token = tokenizer.decode([idx.item()])
                    promoted_tokens.append((token, val.item()))

                # Create a list for demoted tokens
                demoted_tokens = []
                for val, idx in zip(top_neg_vals, top_neg_indices):
                    token = tokenizer.decode([idx.item()])
                    # Negate the value back to its original negative shift
                    demoted_tokens.append((token, -val.item()))

                # Store the two lists in a dictionary
                layer_results[key] = {
                    'promoted': promoted_tokens,
                    'demoted': demoted_tokens
                }
                # --- MODIFICATION END ---

            full_results[(spec.kind, spec.idx)] = layer_results
            print(f"Finished {spec}.\n")

    return full_results

# results = analyze_all_layers_with_token_shifts(config_dict['medical-finance'], 'In medicine, a prediction of future recovery is a', 'In finance, a prediction of future earnings is a', [2], ['mlp', 'attn'], 100)


In [17]:
import torch
from typing import Dict, List, Tuple

def analyze_average_token_shifts(
    config: DatasetPairConfig,
    layers_to_probe: List[int],
    components_to_probe: List[str],
    top_k: int | None = None
) -> Dict[Tuple[str, int], Dict[str, Dict[str, List[Tuple[str, float]]]]]:
    """
    Computes the AVERAGE token shifts by iterating over a list of prompt pairs.

    For each layer, this function accumulates the shift tensors from all prompts in
    `config.prompts`, averages them, and then finds the top-k promoted and demoted tokens.
    """
    if top_k is None:
        top_k = tokenizer.vocab_size
    full_results = {}
    num_prompts = len(config.prompts)
    if num_prompts == 0:
        print("Warning: No prompts found in the config. Returning empty results.")
        return {}

    # Define the keys for the two intervention directions
    key_a_receives_b = f"shift_{config.a.name}_<-_{config.b.name}"
    key_b_receives_a = f"shift_{config.b.name}_<-_{config.a.name}"

    for comp in components_to_probe:
        for L in layers_to_probe:
            spec = LayerSpec(comp, L)
            print(f"--- Analyzing {spec} ---")

            # Initialize accumulators for the shift tensors for this layer
            # We initialize with 0.0, which will broadcast to the tensor shape upon the first addition
            accumulated_shifts = {
                key_a_receives_b: 0.0,
                key_b_receives_a: 0.0
            }

            # 1. Accumulate shifts over all prompt pairs
            for i in range(num_prompts):
                prompt_a = config.prompts[i][0]
                prompt_b = config.prompts[i][1]

                # This helper function gets the shifts for a single prompt pair
                shift_tensors_for_prompt = calculate_token_shifts_for_layer(config, prompt_a, prompt_b, spec)

                accumulated_shifts[key_a_receives_b] += shift_tensors_for_prompt[key_a_receives_b]
                accumulated_shifts[key_b_receives_a] += shift_tensors_for_prompt[key_b_receives_a]

            # 2. Average the accumulated shifts
            average_shift_tensors = {
                key: total_shift / num_prompts
                for key, total_shift in accumulated_shifts.items()
            }

            # 3. Process the averaged tensors to find top-k tokens (same logic as before)
            layer_results = {}
            for key, avg_shift_tensor in average_shift_tensors.items():
                top_pos_vals, top_pos_indices = torch.topk(avg_shift_tensor, top_k)
                top_neg_vals, top_neg_indices = torch.topk(-avg_shift_tensor, top_k)

                promoted_tokens = []
                for val, idx in zip(top_pos_vals, top_pos_indices):
                    promoted_tokens.append((tokenizer.decode([idx.item()]), val.item()))

                demoted_tokens = []
                for val, idx in zip(top_neg_vals, top_neg_indices):
                    demoted_tokens.append((tokenizer.decode([idx.item()]), -val.item()))

                layer_results[key] = {
                    'promoted': promoted_tokens,
                    'demoted': demoted_tokens
                }

            full_results[(spec.kind, spec.idx)] = layer_results
            print(f"Finished {spec} (averaged over {num_prompts} prompts).\n")

    return full_results

In [18]:
from collections import defaultdict
from typing import Dict, List, Tuple

def get_domain_representative_tokens(
    results: Dict[Tuple[str, int], Dict[str, Dict[str, List[Tuple[str, float]]]]],
    intervention_key: str,
    top_n: int = 50
) -> List[str]:
    """
    Aggregates shifts across all layers to find the most representative tokens for a domain.

    Args:
        results: The output dictionary from analyze_average_token_shifts.
        intervention_key: The shift direction that promotes the desired domain's tokens.
                          For example, to get 'science' tokens, use 'shift_science_<-_math'.
        top_n: The number of representative tokens to return.

    Returns:
        A sorted list of the top_n most representative token strings.
    """
    # 1. Use a defaultdict to store the cumulative shift score for each token.
    token_scores = defaultdict(float)

    # 2. Iterate through the results for each layer.
    for layer_data in results.values():
        if intervention_key in layer_data:
            # Get the list of promoted tokens for the current layer
            promoted_list = layer_data[intervention_key]['promoted']

            # 3. Add the shift score to the token's cumulative total.
            for token, shift in promoted_list:
                token_scores[token] += shift

    # 4. Sort the tokens by their final cumulative score in descending order.
    sorted_by_score = sorted(token_scores.items(), key=lambda item: item[1], reverse=True)

    # 5. Extract just the token strings and return the top_n.
    representative_tokens = [token for token, score in sorted_by_score]

    return representative_tokens[:top_n]

# --- EXAMPLE USAGE ---

# Assume 'results' is the output from your analysis function
# set_a = 'cpp_top'
# set_b = 'python_top'
# intervention = f'shift_{set_a}_<-_{set_b}'
# rep_tokens = get_domain_representative_tokens(results, intervention, top_n=50)

# print("Top 50 most representative 'python' tokens:")
# print(rep_tokens)

In [19]:
from typing import Dict, List, Tuple

def find_all_domain_representatives(
    config_dict: Dict[str, 'DatasetPairConfig'],
    layers_to_probe: List[int],
    components_to_probe: List[str],
    top_n: int = 100
) -> Dict[str, List[str]]:
    """
    Analyzes all dataset pairs in a config dictionary to find the most representative tokens for each domain.

    This function orchestrates the entire process:
    1. Loops through each `DatasetPairConfig`.
    2. Runs the average token shift analysis for that pair.
    3. Extracts the top N representative tokens for BOTH domains in the pair.
    4. Returns a single dictionary mapping each domain name to its list of tokens.

    Args:
        config_dict: A dictionary mapping names to DatasetPairConfig objects.
        layers_to_probe: A list of layer indices to analyze.
        components_to_probe: A list of component names (e.g., ['mlp', 'attn']).
        top_n: The number of representative tokens to find for each domain.

    Returns:
        A dictionary where keys are domain names (e.g., 'cpp', 'python', 'science')
        and values are the lists of top_n representative token strings.
    """
    all_domain_tokens = {}

    print(f"Starting analysis for {len(config_dict)} dataset pairs...")
    print(f"Probing layers: {layers_to_probe} for components: {components_to_probe}\n")

    for name, config in config_dict.items():
        print(f"--- Processing dataset pair: '{name}' ({config.a.name} vs {config.b.name}) ---")

        # 1. Run the core analysis, averaging over all prompts for this config.
        #    We use the default top_k=None to analyze the full vocabulary,
        #    ensuring the subsequent aggregation is as accurate as possible.
        results_for_pair = analyze_average_token_shifts(
            config,
            layers_to_probe,
            components_to_probe,
            top_k=None
        )

        # 2. Define the intervention keys to extract tokens for each domain.
        #    To get tokens for domain 'a', we look at what's promoted when 'a' receives 'b'.
        intervention_key_for_a = f"shift_{config.a.name}_<-_{config.b.name}"
        intervention_key_for_b = f"shift_{config.b.name}_<-_{config.a.name}"

        # 3. Extract the representative tokens for the first domain in the pair.
        print(f"Aggregating results for domain: '{config.a.name}'...")
        tokens_for_a = get_domain_representative_tokens(
            results_for_pair,
            intervention_key_for_a,
            top_n=top_n
        )
        all_domain_tokens[config.a.name] = tokens_for_a

        # 4. Extract the representative tokens for the second domain in the pair.
        print(f"Aggregating results for domain: '{config.b.name}'...")
        tokens_for_b = get_domain_representative_tokens(
            results_for_pair,
            intervention_key_for_b,
            top_n=top_n
        )
        all_domain_tokens[config.b.name] = tokens_for_b
        print(f"--- Finished processing '{name}' ---\n")

    print("="*40)
    print("      ANALYSIS COMPLETE      ")
    print("="*40)
    return all_domain_tokens

# --- EXAMPLE USAGE ---

# Assume 'config_dict' is your dictionary of all DatasetPairConfig objects.
# For example:
# config_dict = {
#     'sci-math': science_math_config,
#     'code': cpp_python_config,
#     'lang': english_french_config
# }

# Define which parts of the model to analyze.
# Analyzing mid-to-late layers is often most informative for abstract concepts.
LAYERS_TO_PROBE = range(1, 28)
COMPONENTS_TO_PROBE = ['mlp', 'attn']

# Run the full analysis
representative_tokens_by_domain = find_all_domain_representatives(
    config_dict,
    layers_to_probe=LAYERS_TO_PROBE,
    components_to_probe=COMPONENTS_TO_PROBE,
    top_n=300
)

# # Print the results for inspection
for domain, tokens in representative_tokens_by_domain.items():
    print(f"\n--- Top 20 of 100 Representative Tokens for '{domain.upper()}' ---")
    # Print the first 20 for a quick look
    print(tokens[:20])

with open("token_set.json", 'w') as f:
    json.dump(representative_tokens_by_domain, f)

Starting analysis for 1 dataset pairs...
Probing layers: range(1, 28) for components: ['mlp', 'attn']

--- Processing dataset pair: 'cpp-python' (cpp_top vs python_top) ---
--- Analyzing LayerSpec(kind='mlp', idx=1) ---
Finished LayerSpec(kind='mlp', idx=1) (averaged over 100 prompts).

--- Analyzing LayerSpec(kind='mlp', idx=2) ---
Finished LayerSpec(kind='mlp', idx=2) (averaged over 100 prompts).

--- Analyzing LayerSpec(kind='mlp', idx=3) ---
Finished LayerSpec(kind='mlp', idx=3) (averaged over 100 prompts).

--- Analyzing LayerSpec(kind='mlp', idx=4) ---
Finished LayerSpec(kind='mlp', idx=4) (averaged over 100 prompts).

--- Analyzing LayerSpec(kind='mlp', idx=5) ---
Finished LayerSpec(kind='mlp', idx=5) (averaged over 100 prompts).

--- Analyzing LayerSpec(kind='mlp', idx=6) ---
Finished LayerSpec(kind='mlp', idx=6) (averaged over 100 prompts).

--- Analyzing LayerSpec(kind='mlp', idx=7) ---
Finished LayerSpec(kind='mlp', idx=7) (averaged over 100 prompts).

--- Analyzing LayerSpe

In [None]:
output_dir = "token_sets"
os.makedirs(output_dir, exist_ok=True)
with open("token_sets/tokens.json", 'w') as f:
    json.dump(representative_tokens_by_domain, f)
    
with open("token_sets/tokens.json", "r") as f:
    data = json.load(f)
    
result = {}
for dataset, token_list in data.items():
    token_id_list = []
    for element in token_list:
        id = tokenizer.encode(element)[1]
        # print(tokenizer.decode(id), id)
        token_id_list.append(id)
        
    result[dataset] = token_id_list
    
print(result)
for dataset, list in data.items():
    print(dataset, len(list))
    
with open("token_sets/ids.json", "w") as f:
    json.dump(result, f)

{'cpp_top': [11055, 66, 711, 272, 443, 674, 586, 755, 12958, 34, 898, 538, 1058, 356, 10248, 2, 1179, 475, 1723, 1019, 322, 528, 742, 734, 47316, 720, 47924, 1085, 1012, 738, 220, 10344, 928, 1181, 925, 1799, 396, 8144, 1040, 9842, 333, 2039, 3, 2707, 257, 791, 1416, 31380, 31, 1845, 1527, 128001, 505, 1487, 1674, 422, 1757, 578, 256, 40, 16, 400, 959, 14402, 1872, 55375, 42333, 571, 5688, 2997, 879, 27, 37942, 71, 77, 644, 262, 353, 4077, 87, 1379, 763, 90, 12761, 2580, 79, 5560, 939, 308, 3990, 9, 58, 4724, 767, 2028, 985, 47375, 13325, 2000, 2485, 366, 3788, 22818, 2675, 5040, 69, 1118, 2566, 1686, 1988, 5321, 5618, 7, 320, 6462, 1075, 0, 358, 8872, 4942, 1115, 13798, 1342, 817, 510, 32, 1264, 2020, 369, 1472, 3368, 2465, 2900, 7003, 4110, 1580, 362, 3295, 4, 73, 1701, 10464, 4324, 28121, 14, 7531, 1442, 5830, 4429, 865, 39, 47, 471, 74, 305, 15391, 4194, 16234, 16644, 71742, 1358, 11502, 2033, 3556, 3427, 82, 4815, 9528, 260, 7927, 330, 72, 3350, 10086, 258, 1257, 6403, 2746, 10091