In [1]:
import json
import os
from itertools import product, chain

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer
from pyvene import (
    ConstantSourceIntervention,
    LocalistRepresentationIntervention,
    IntervenableConfig,
    RepresentationConfig,
    IntervenableModel,
    VanillaIntervention,
)

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Llama-3.1 8B
MODEL = "meta-llama/Llama-3.1-8b"
# Never commit an access token to the notebook: read it from the environment.
# When HF_TOKEN is unset this is None, which is passed to from_pretrained
# unchanged (see the from_pretrained documentation for how token=None is handled).
HF_TOKEN = os.environ.get("HF_TOKEN")
JSON_FILE = "translated_prompts_en.json"
M_NOISE   = 10      # number of noise seeds

tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
    token=HF_TOKEN,
)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

### Factual Recall

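The helper functions below score the model on the prompt we want to causally trace (for example, “A cat sat on the mat and he”): they compute the joint probability of the gold continuation with and without interventions, and locate the token positions that will be corrupted.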

In [3]:
def get_gold_joint_original(
    model,
    tokenizer,
    prompt: str,
    device: str,
    gold_ids: list[int]
) -> float:
    """
    Return the joint probability that the clean (un-intervened) `model`
    assigns to the gold subtoken sequence `gold_ids` as a continuation
    of `prompt`.

    The probability is the product over `gold_ids` of
    P(gold_ids[k] | prompt, gold_ids[:k]), obtained with one forward pass
    per gold subtoken.

    Args:
        model: callable invoked as `model(**inputs)`; its result must expose
            `.logits`, which is indexed as `[0, -1]` (first batch row, last
            position).
        tokenizer: callable invoked as `tokenizer(prompt, return_tensors="pt")`;
            the result must support `.to(device)` and provide "input_ids".
        prompt: text the continuation is conditioned on.
        device: device the encoded prompt is moved to.
        gold_ids: token ids of the gold continuation, in order.

    Returns:
        The joint probability as a Python float (1.0 when `gold_ids` is empty).
    """
    # Tokenise the prompt exactly once. The continuation is teacher-forced by
    # appending the gold *ids*; decoding a subtoken, appending the string and
    # re-tokenising is not guaranteed to reproduce the same ids because
    # neighbouring pieces can merge into different tokens.
    inputs = dict(tokenizer(prompt, return_tensors="pt").to(device))

    gold_joint = 1.0
    for gold_id in gold_ids:
        # 1) next-token distribution of the clean model
        with torch.no_grad():
            out = model(**inputs)
        # Upcast before the softmax so half-precision logits do not lose
        # small probabilities.
        probs = F.softmax(out.logits[0, -1].float(), dim=-1)

        # 2) accumulate the probability of the gold subtoken
        gold_joint *= probs[gold_id].item()

        # 3) teacher-force: extend the sequence with the gold subtoken id
        ids = inputs["input_ids"]
        inputs["input_ids"] = torch.cat(
            [ids, ids.new_full((ids.shape[0], 1), gold_id)], dim=1
        )
        if "attention_mask" in inputs:
            mask = inputs["attention_mask"]
            inputs["attention_mask"] = torch.cat(
                [mask, mask.new_ones((mask.shape[0], 1))], dim=1
            )

    return gold_joint

def get_gold_joint_intervened(
    intervenable,
    tokenizer,
    prompt: str,
    device: str,
    sources,
    unit_locations,
    gold_ids: list[int]
) -> float:
    """
    Return the joint probability of the gold subtoken sequence `gold_ids`
    under `intervenable` (with corruption+restoration), on the prompt.

    The probability is the product over `gold_ids` of
    P(gold_ids[k] | prompt, gold_ids[:k]); the same `sources` and
    `unit_locations` are applied on every forward pass.

    Args:
        intervenable: callable invoked as
            `intervenable(inputs, sources, unit_locations=unit_locations)`;
            it must return a pair whose second element exposes `.logits`,
            which is indexed as `[0, -1]`.
        tokenizer: callable invoked as `tokenizer(prompt, return_tensors="pt")`;
            the result must support `.to(device)` and provide "input_ids".
        prompt: text the continuation is conditioned on.
        device: device the encoded prompt is moved to.
        sources: forwarded unchanged to `intervenable`.
        unit_locations: forwarded unchanged to `intervenable`.
        gold_ids: token ids of the gold continuation, in order.

    Returns:
        The joint probability as a Python float (1.0 when `gold_ids` is empty).
    """
    # Tokenise the prompt exactly once. The continuation is teacher-forced by
    # appending the gold *ids*; decoding a subtoken, appending the string and
    # re-tokenising is not guaranteed to reproduce the same ids.
    inputs = dict(tokenizer(prompt, return_tensors="pt").to(device))

    gold_joint = 1.0
    for gold_id in gold_ids:
        # 1) run the intervened model to get the next-token distribution.
        #    No gradients are needed for scoring, so do not build a graph.
        with torch.no_grad():
            _, out = intervenable(
                inputs,
                sources,
                unit_locations=unit_locations
            )
        # Upcast before the softmax so half-precision logits do not lose
        # small probabilities.
        probs = F.softmax(out.logits[0, -1].float(), dim=-1)

        # 2) accumulate the probability of the gold subtoken
        gold_joint *= probs[gold_id].item()

        # 3) teacher-force: extend the sequence with the gold subtoken id
        ids = inputs["input_ids"]
        inputs["input_ids"] = torch.cat(
            [ids, ids.new_full((ids.shape[0], 1), gold_id)], dim=1
        )
        if "attention_mask" in inputs:
            mask = inputs["attention_mask"]
            inputs["attention_mask"] = torch.cat(
                [mask, mask.new_ones((mask.shape[0], 1))], dim=1
            )

    return gold_joint

def get_restore_positions(few_shot: str, target: str, tokenizer=None, position_offset: int = 1):
    """
    Return the sorted, de-duplicated token positions covering every
    occurrence of `target` that starts in the *first* or *last* line of
    `few_shot`.

    `few_shot` is tokenised without special tokens and `target` is matched
    as an exact run of token ids, so an occurrence is only found when it is
    tokenised the same way in context as it is in isolation.

    Args:
        few_shot: the full prompt; lines are separated by "\\n".
        target: the word/phrase whose subtoken positions are wanted.
        tokenizer: tokenizer to use. When None, the module-level `tokenizer`
            is used, which keeps existing two-argument calls working.
        position_offset: constant added to every returned index. The default
            of 1 reproduces the original hard-coded shift
            (NOTE(review): presumably this compensates for a BOS token that
            is prepended when the prompt is later tokenised *with* special
            tokens - confirm against the tokenizer in use).

    Returns:
        Sorted list of unique integer positions.

    Raises:
        NameError: if `tokenizer` is None and no module-level `tokenizer`
            exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError("name 'tokenizer' is not defined")

    # 1) Character spans of the first and last line. Because
    #    few_shot == "\n".join(lines), the last line starts
    #    len(lines[-1]) characters before the end of the string.
    lines = few_shot.split("\n")
    first_span = (0, len(lines[0]))
    last_span = (len(few_shot) - len(lines[-1]), len(few_shot))

    # 2) Tokenize with character offsets (no special tokens, so token
    #    index i here is the i-th real token of the prompt).
    enc = tokenizer(
        few_shot,
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    offsets = enc["offset_mapping"]  # list of (start_char, end_char)
    token_ids = list(enc["input_ids"])

    target_ids = list(tokenizer.encode(target, add_special_tokens=False))
    n_target = len(target_ids)

    def _within(start, end, span):
        return span[0] <= start and end <= span[1]

    # 3) For every token lying fully inside the first or last line, check
    #    whether the target's id sequence starts there.
    found = set()
    for i, (st, en) in enumerate(offsets):
        if not (_within(st, en, first_span) or _within(st, en, last_span)):
            continue
        if token_ids[i:i + n_target] == target_ids:
            found.update(range(i + position_offset, i + n_target + position_offset))

    return sorted(found)

def get_custom_labels(inputs, pos_restore, tokenizer=None):
    """
    Build one display label per input token, with a trailing "*" on every
    position listed in `pos_restore`.

    Args:
        inputs: encoded prompt; `inputs["input_ids"][0]` must provide
            `.tolist()` yielding the token ids of the first sequence.
        pos_restore: iterable of token positions to mark.
        tokenizer: tokenizer used to decode each id. When None, the
            module-level `tokenizer` is used.

    Returns:
        List of label strings, one per token (intended for visualisation).

    Raises:
        NameError: if `tokenizer` is None and no module-level `tokenizer`
            exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError("name 'tokenizer' is not defined")

    # Use the encoding that was passed in rather than re-tokenising a global
    # prompt, so the labels always line up with the caller's inputs.
    input_ids = inputs["input_ids"][0].tolist()

    # Decode each id on its own so every label is the text of one token.
    custom_labels = [
        tokenizer.decode(token_id, clean_up_tokenization_spaces=False)
        for token_id in input_ids
    ]

    # mark restored tokens
    for idx in pos_restore:
        custom_labels[idx] += "*"

    return custom_labels

### Intervention

In [4]:
# 1) Define the noise intervention class
class NoiseIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
    """
    Intervention that adds a fixed, seeded Gaussian noise vector to the
    first `embed_dim` features of the activation it is applied to.

    A single noise vector of shape (1, 1, embed_dim) is drawn at
    construction time from `np.random.RandomState(seed)` and broadcast over
    the leading dimensions of the activation, so every intervened position
    receives the same perturbation.

    Args:
        embed_dim: number of leading features to perturb. Falls back to
            `kwargs["latent_dim"]` when not given.
        seed: seed for the noise draw.
        noise_level: scalar the unit-variance noise is multiplied by. The
            default is the value that was previously hard-coded here.
            NOTE(review): the constant is not derived from the loaded model
            anywhere in this notebook - confirm it is the intended scale.
        **kwargs: extra construction metadata; only "latent_dim" is read.

    Raises:
        ValueError: if neither `embed_dim` nor `kwargs["latent_dim"]` is set.
    """

    def __init__(
        self,
        embed_dim=None,
        seed: int = 1,
        noise_level: float = 0.13462981581687927,
        **kwargs
    ):
        super().__init__()
        # pull the provided size
        embed_dim = embed_dim or kwargs.get("latent_dim")
        if embed_dim is None:
            raise ValueError(f"No latent_dim in kwargs: {list(kwargs)}")
        self.interchange_dim = embed_dim

        rs = np.random.RandomState(seed)
        # Kept on the CPU here; forward() moves it next to the activation it
        # perturbs, so the class does not depend on a global `device`.
        self.noise = torch.from_numpy(rs.randn(1, 1, embed_dim))
        self.noise_level = noise_level

    def forward(self, base, source=None, subspaces=None):
        """Add the scaled noise to `base` in place and return `base`."""
        # Move (once) to wherever the activation lives; with a sharded model
        # that is not necessarily the default device.
        if self.noise.device != base.device:
            self.noise = self.noise.to(base.device)
        base[..., :self.interchange_dim] += self.noise * self.noise_level
        return base
    
# 2) Build a corrupt config for Llama-3.1 8B
def corrupted_config(model_type, layer, seed: int = 1):
    """
    Build the corruption-only configuration: a single `NoiseIntervention`
    applied to the "block_input" of `layer`.

    Args:
        model_type: forwarded as `model_type` to `IntervenableConfig`.
        layer: index of the layer whose block input is corrupted.
        seed: noise seed, supplied through `intervention_additional_kwargs`.

    Returns:
        The `IntervenableConfig` describing the corruption.
    """
    noise_site = RepresentationConfig(
        layer,
        "block_input"
    )
    # NOTE(review): confirm that the installed pyvene version forwards
    # `intervention_additional_kwargs` to the intervention constructor;
    # otherwise every run uses NoiseIntervention's default seed.
    per_intervention_kwargs = [ {"seed": seed} ]

    return IntervenableConfig(
        model_type=model_type,
        representations=[noise_site],
        intervention_types=NoiseIntervention,
        intervention_additional_kwargs=per_intervention_kwargs
    )

def restore_corrupted_with_interval_config(
    intervened_layer: int,
    restore_layer: int,
    stream: str = "mlp",
    window: int = 3,
    num_layers: int = 33,
    seed: int = 1
):
    """
    Build a corrupt-then-restore configuration.

    The first representation corrupts the "block_input" of
    `intervened_layer` with `NoiseIntervention`; the remaining ones apply
    `VanillaIntervention` to `stream` at every layer of a window centred on
    `restore_layer` and clipped to `[0, num_layers)`.

    Args:
        intervened_layer: layer whose block input receives the noise.
        restore_layer: centre of the restoration window.
        stream: component name restored at each layer of the window.
        window: window width; `window // 2` layers are taken on each side
            of `restore_layer`.
        num_layers: exclusive upper bound for restored layer indices.
            NOTE(review): the default of 33 exceeds the 32 decoder layers
            shown in the model summary; callers should pass the real count.
        seed: noise seed, supplied through `intervention_additional_kwargs`
            (NOTE(review): confirm the installed pyvene version forwards
            these kwargs to the intervention constructor).

    Returns:
        An `IntervenableConfig` with 1 + (number of restored layers)
        representations; `model_type` is taken from the module-level `model`.
    """
    # compute restore interval around restore_layer
    half = window // 2
    start = max(0, restore_layer - half)
    end   = min(num_layers, restore_layer + half + 1)

    # First: the corruption site. Use the caller's `intervened_layer`
    # instead of a hard-coded 0 so the parameter is actually honoured.
    reps = [
        RepresentationConfig(
            intervened_layer,
            "block_input"
        )
    ]
    types = [NoiseIntervention]
    kwargs_list = [ {"seed": seed} ]

    # Then: restore the clean activation at each layer in [start, end)
    for layer_idx in range(start, end):
        reps.append(
            RepresentationConfig(
                layer_idx,
                stream
            )
        )
        types.append(VanillaIntervention)
        kwargs_list.append({})

    return IntervenableConfig(
        model_type=type(model),
        representations=reps,
        intervention_types=types,
        intervention_additional_kwargs=kwargs_list
    )

In [10]:
# NOTE(review): JSON_FILE is already assigned this same value in the
# configuration cell, and this cell's execution count (10) is out of sequence
# with its neighbours; consider removing the cell so Restart & Run All has a
# single source of truth.
JSON_FILE = "translated_prompts_en.json"

In [None]:
# 5) Causal tracing sweep.
#    For every prompt and noise seed: corrupt the restore-word positions at
#    the input of `layer`, then restore the clean activation of one
#    (stream, layer window, token position) at a time and record how much of
#    the gold continuation's probability comes back.
layer = 0  # layer whose block input is corrupted

df = pd.read_json(JSON_FILE)
streams = ["block_output", "attention_output", "mlp_activation"]
# Sweep exactly the decoder layers the model has; a hard-coded 33 ran one
# extra iteration whose window only covered the last real layer.
num_layers = model.config.num_hidden_layers

for entry in df.itertuples():
    pid, prompt, restore_words, gold = entry.prompt_id, entry.prompt, entry.words_restore, entry.gold
    # parse language & tense
    tense, lang = pid.split("_")
    lang = lang[:2]

    # compute restore positions once
    pos_restore = sorted(set(chain.from_iterable(
        get_restore_positions(prompt, w) for w in restore_words
    )))
    if not pos_restore:
        # Nothing to corrupt means every delta below would be meaningless.
        print(f"Skipping {pid}: no restore positions found")
        continue

    # compute clean baseline once
    gold_ids = tokenizer.encode(f" {gold}", add_special_tokens=False)
    p_clean  = get_gold_joint_original(model, tokenizer, prompt, device, gold_ids)

    # Clean encoding of the prompt: source of the restored activations.
    base = tokenizer(prompt, return_tensors="pt").to(device)
    num_positions = base.input_ids.size(1)

    for seed in range(M_NOISE):
        rows = []
        print(f"Working on seed {seed}...")
        # (re)set up noise intervention with new random seed
        config = corrupted_config(type(model), layer=layer, seed=seed)
        intervenable  = IntervenableModel(config, model)

        # corrupt only baseline
        p_corrupt = get_gold_joint_intervened(
            intervenable, tokenizer, prompt, device,
            sources=None,
            unit_locations={"base":[[pos_restore]]},
            gold_ids=gold_ids
        )

        for stream in streams:
            print(f"Working on stream {stream}...")
            for restore_layer in range(num_layers):
                # The config and wrapper depend only on (stream, layer, seed),
                # so build them once here instead of once per position.
                cfg = restore_corrupted_with_interval_config(
                    intervened_layer=layer,      # we only corrupt at input
                    restore_layer=restore_layer,
                    stream=stream,
                    window=3,
                    num_layers=num_layers,
                    seed=seed
                )
                interv = IntervenableModel(cfg, model)
                # First representation is the corruption; the rest restore.
                n_restores = len(cfg.representations) - 1
                sources = [None] + [base]*n_restores

                for pos in range(num_positions):
                    unit_locations = {
                        "sources->base": (
                            [None] + [[[pos]]]*n_restores,
                            [[pos_restore]] + [[[pos]]]*n_restores,
                        )
                    }

                    p_restored = get_gold_joint_intervened(
                        intervenable=interv,
                        tokenizer=tokenizer,
                        prompt=prompt,
                        device=device,
                        sources=sources,
                        unit_locations=unit_locations,
                        gold_ids=gold_ids,
                    )

                    rows.append({
                        "language":        lang,
                        "tense":           tense,
                        "stream":          stream,
                        "prompt_id":       pid,
                        "pos":             pos,
                        "noise_seed":      seed,
                        "restore_layer":   restore_layer,
                        "gold":            gold,
                        "p_clean":         p_clean,
                        "p_corrupt":       p_corrupt,
                        "p_restored":      p_restored,
                        "delta_corrupt":   p_clean - p_corrupt,
                        "delta_restored":  p_restored - p_corrupt
                    })

        # One CSV per (prompt, seed) so an interrupted run keeps finished seeds.
        pd.DataFrame(rows).to_csv(f"{lang}_{seed}_{pid}.csv", index=False)
        print("Saved")

Working on seed 0...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
Saved
Working on seed 1...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
Saved
Working on seed 2...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
Saved
Working on seed 3...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
Saved
Working on seed 4...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
Saved
Working on seed 5...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
Saved
Working on seed 6...
Working on stream block_output...
Working on stream attention_output...
Working on stream mlp_activation...
