In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import nnsight
from nnsight import LanguageModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = LanguageModel("gpt2", device_map="auto")


In [3]:
path = "data/gpt2-small_nurse_man_20251110_215002.csv"
corpus_path = "data/WikiText.txt"
df = pd.read_csv(path, comment="#")
df.head()

Unnamed: 0,rank,layer,head,nie,abs_nie
0,1,10,9,-2.486696,2.486696
1,2,9,7,-1.148109,1.148109
2,3,11,8,-0.590734,0.590734
3,4,9,5,0.564491,0.564491
4,5,9,2,-0.285976,0.285976


In [4]:
def read_topk_heads(cma_csv: str, top_k: int):
    """
    Read CMA CSV and return top-K (layer, head, nie) tuples.

    Priority:
      - If 'rank' exists, sort ascending by rank (1 = best)
      - Else sort by 'abs_nie' (desc) if present, otherwise by 'nie' (desc)
    """
    df = pd.read_csv(cma_csv, comment="#")

    if "rank" in df.columns:
        df = df.sort_values("rank", ascending=True)
    elif "abs_nie" in df.columns:
        df = df.sort_values("abs_nie", ascending=False)
    elif "nie" in df.columns:
        df["abs_nie"] = df["nie"].abs()
        df = df.sort_values("abs_nie", ascending=False)
    else:
        raise ValueError("CMA CSV must have 'rank' or 'nie' / 'abs_nie' column.")

    df = df.dropna(subset=["layer", "head"])
    if "nie" in df.columns:
        df = df.dropna(subset=["nie"])

    if df.empty:
        raise ValueError(f"CMA CSV '{cma_csv}' has no valid rows after dropping NaNs.")

    sel = df.head(top_k)[["layer", "head", "nie"]].values.tolist()
    return [(int(L), int(h), float(nie)) for (L, h, nie) in sel]


top_k = 5
top_heads = read_topk_heads(path, top_k)

In [5]:
prompt_base = "The nurse said that"
prompt_cf = "The man said that"


In [6]:
def get_head_mask(head_indices, hidden_dim, head_dim):
    x = torch.zeros(hidden_dim)              # 必须是 tensor
    l = torch.tensor(head_indices)           # head indices 
    idx = l[:, None] * head_dim + torch.arange(head_dim)
    idx = idx.reshape(-1)
    
    x[idx] = 1
    return x.to(bool)

In [7]:
top_heads = read_topk_heads(path, top_k)
top_heads = sorted(top_heads, key=lambda x: (x[0], x[1]))   
top_heads

[(9, 2, -0.2859764099121094),
 (9, 5, 0.5644912719726562),
 (9, 7, -1.1481094360351562),
 (10, 9, -2.486696243286133),
 (11, 8, -0.5907344818115234)]

In [8]:
def run_head_of_exp(model, example, csv_path, top_k):
    base_examples = [e["base"]for e in example]
    cf_examples = [e["counterfactual"]for e in example]
    
    attn_dim = model.config.n_embd // model.config.n_head
    results = []
    top_heads = read_topk_heads(csv_path, top_k)
    # sort the heads based on layer and head index
    top_heads = sorted(top_heads, key=lambda x: (x[0], x[1]))   
    # merge to [{layer: [heads]}]
    heads_by_layer = {}
    for head in top_heads:
        L, h, nie = head
        heads_by_layer.setdefault(L, []).append(h)
        
    with model.edit() as edited:
        for idx, (L, heads) in enumerate(heads_by_layer.items()):
            head_mask = get_head_mask(heads, model.config.n_embd, attn_dim)
            edited.transformer.h[L].attn.c_proj.input[:, :, head_mask] = 0.0
    with model.trace(base_examples) as t1:
        base_logits = model.output['logits'][:,-1,:].detach().cpu().save()
    with model.trace(cf_examples) as t2:
        cf_logits = model.output['logits'][:,-1,:].detach().cpu().save()   
    with edited.trace(base_examples) as et1:
        edited_base_logits = model.output['logits'][:,-1,:].detach().cpu().save()
    with edited.trace(cf_examples) as et2:
        edited_cf_logits = model.output['logits'][:,-1,:].detach().cpu().save()
        
    return (edited, base_logits, cf_logits, edited_base_logits, edited_cf_logits)
    
    
# result = run_head_of_exp(model, prompt_base, prompt_cf, path, top_k)
# edited_model, base_logits, cf_logits, edited_base_logits, edited_cf_logits = result

In [9]:
def eval_ppl(model, text):
    """Evaluate perplexity of the given text using the model."""
    inputs = model.tokenizer(text, return_tensors="pt").to(model.device)
    with model.trace(inputs):
        logits = model.output['logits'].save()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = inputs['input_ids'][..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))
    ppl = torch.exp(loss.mean()).item()
    return ppl


In [10]:
def tokenize_corpus(
    model,
    path,
    max_corpus_tokens
):
    if not path or not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    tokens = model.tokenizer.encode(text, return_tensors="pt")
    if tokens.numel() <= 1:
        return None
    if max_corpus_tokens is not None and tokens.size(1) > max_corpus_tokens:
        tokens = tokens[:, :max_corpus_tokens]
    return tokens



In [11]:

def compute_corpus_perplexity(
    model,
    tokens,
    block_size: int = 512,
) -> float:
    if tokens is None or tokens.numel() <= 1:
        return float("nan")
    total_log_prob = 0.0
    total_tokens = 0
    device = model.device
    for start in range(0, tokens.size(1) - 1, block_size):
        end = min(start + block_size + 1, tokens.size(1))
        slice_tokens = tokens[:, start:end]
        with model.trace(slice_tokens) as t:
            logits = model.output['logits'].save()
        log_probs = torch.nn.functional.log_softmax(logits[:, :-1], dim=-1)
        target = slice_tokens[:, 1:]
        slice_tokens = slice_tokens.to(device)
        target = slice_tokens[:, 1:].to(device)

        gathered = torch.gather(log_probs, 2, target.unsqueeze(-1)).squeeze(-1) 
        total_log_prob += gathered.sum().item()
        total_tokens += target.numel()
    if total_tokens == 0:
        return float("nan")
    avg_log_prob = total_log_prob / total_tokens
    return float(np.exp(-avg_log_prob))



In [110]:
corpus_tokens = tokenize_corpus(model, corpus_path, max_corpus_tokens=4096)
baseline_corpus_ppl = (compute_corpus_perplexity(model, corpus_tokens) if corpus_tokens is not None else float("nan"))
edited_corpus_ppl = (compute_corpus_perplexity(edited_model, corpus_tokens) if corpus_tokens is not None else float("nan"))

In [111]:
baseline_corpus_ppl, edited_corpus_ppl

(20.934903423804844, 22.173661220600465)

In [20]:
logs = []
for k in tqdm([1, 5, 10, 20, 50, 100]):
    log = {
        "k": k,
        "NIE_before": 0,
        "NIE_after": 0,
        "ppl_before": 0,
        "ppl_after": 0,
    }
    result = run_head_of_exp(model, prompt_base, prompt_cf, path, k+1)
    edited_model, base_logits, cf_logits, edited_base_logits, edited_cf_logits = result
    id_she = model.tokenizer(" she", add_special_tokens=False)["input_ids"][0]
    id_he = model.tokenizer(" he", add_special_tokens=False)["input_ids"][0]
    bias_clean = (base_logits[0, id_she] - base_logits[0,id_he]).item()
    bias_gated = (edited_base_logits[0,id_she] - edited_base_logits[0,id_he]).item()
    bias_cf_replaced = (edited_cf_logits[0, id_she] - edited_cf_logits[0, id_he]).item()
    nie_before = bias_cf_replaced - bias_clean
    nie_after = bias_cf_replaced - bias_gated
    log["NIE_before"] = nie_before
    log["NIE_after"] = nie_after
    corpus_tokens = tokenize_corpus(model, corpus_path, max_corpus_tokens=4096)
    baseline_corpus_ppl = (compute_corpus_perplexity(model, corpus_tokens) if corpus_tokens is not None else float("nan"))
    edited_corpus_ppl = (compute_corpus_perplexity(edited_model, corpus_tokens) if corpus_tokens is not None else float("nan"))
    log["ppl_before"] = baseline_corpus_ppl
    log["ppl_after"] = edited_corpus_ppl
    remaining_pct = (
            float(abs(bias_gated) / (abs(bias_clean) + 1e-9))
            if abs(bias_clean) > 1e-9
            else float("nan")
        )
    log["remaining_pct"] = remaining_pct
    logs.append(log)
    
    
    

100%|██████████| 6/6 [00:08<00:00,  1.40s/it]


In [13]:
from prompts_winogender import get_prompt_examples
examples = get_prompt_examples("test")


In [28]:
rs = run_head_of_exp(model, examples, path, top_k)

In [27]:
logs = []
for k in tqdm(range(10)):
    log = {
        "k": k,
        "NIE_before": 0,
        "NIE_after": 0,
        "ppl_before": 0,
        "ppl_after": 0,
    }
    result = run_head_of_exp(model, examples, path, k)
    edited_model, base_logits, cf_logits, edited_base_logits, edited_cf_logits = result
    id_she = model.tokenizer(" she", add_special_tokens=False)["input_ids"][0]
    id_he = model.tokenizer(" he", add_special_tokens=False)["input_ids"][0]
    bias_clean = (base_logits[:, id_she] - base_logits[:, id_he])
    bias_edited = (edited_base_logits[:, id_she] - edited_base_logits[:, id_he])
    bias_cf_replaced = (edited_cf_logits[:, id_she] - edited_cf_logits[:, id_he])
    log['bias_clean_raw'] = bias_clean
    log['bias_edited_raw'] = bias_edited
    log['bias_cf_replaced_raw'] = bias_cf_replaced
    
    bias_clean = bias_clean.mean().item()
    bias_edited = bias_edited.mean().item()
    bias_cf_replaced = bias_cf_replaced.mean().item()
    corpus_tokens = tokenize_corpus(model, corpus_path, max_corpus_tokens=4096)
    baseline_corpus_ppl = (compute_corpus_perplexity(model, corpus_tokens) if corpus_tokens is not None else float("nan"))
    edited_corpus_ppl = (compute_corpus_perplexity(edited_model, corpus_tokens) if corpus_tokens is not None else float("nan"))
    log["ppl_before"] = baseline_corpus_ppl
    log["ppl_after"] = edited_corpus_ppl
    log['d_ppl'] = abs(edited_corpus_ppl - baseline_corpus_ppl)
    remaining_pct = (
            float(abs(bias_edited) / (abs(bias_clean) + 1e-9))
            if abs(bias_clean) > 1e-9
            else float("nan")
        )
    log["remaining_pct"] = remaining_pct
    logs.append(log)
    


  0%|          | 0/10 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 10/10 [00:13<00:00,  1.36s/it]


In [28]:
logs    

[{'k': 0,
  'NIE_before': 0,
  'NIE_after': 0,
  'ppl_before': 20.934903423804844,
  'ppl_after': 20.934903423804844,
  'bias_clean_raw': tensor([3.1397, 0.8581, 4.7102, 1.6012, 0.3988]),
  'bias_edited_raw': tensor([3.1397, 0.8581, 4.7102, 1.6012, 0.3988]),
  'bias_cf_replaced_raw': tensor([-2.4131, -1.2378, -4.3987, -0.4387, -2.0944]),
  'd_ppl': 0.0,
  'remaining_pct': 0.9999999995330601},
 {'k': 1,
  'NIE_before': 0,
  'NIE_after': 0,
  'ppl_before': 20.934903423804844,
  'ppl_after': 20.95264363222214,
  'bias_clean_raw': tensor([3.1397, 0.8581, 4.7102, 1.6012, 0.3988]),
  'bias_edited_raw': tensor([ 0.7468, -0.1749,  2.8602,  1.2758,  0.3267]),
  'bias_cf_replaced_raw': tensor([-1.5976, -1.1484, -2.8746, -0.1965, -2.0159]),
  'd_ppl': 0.01774020841729751,
  'remaining_pct': 0.4701681683243988},
 {'k': 2,
  'NIE_before': 0,
  'NIE_after': 0,
  'ppl_before': 20.934903423804844,
  'ppl_after': 21.1700966252794,
  'bias_clean_raw': tensor([3.1397, 0.8581, 4.7102, 1.6012, 0.3988]),
  

In [14]:
result1 = run_head_of_exp(model,examples, path, 3)
edited_model, base_logits, cf_logits, edited_base_logits, edited_cf_logits = result1


In [15]:
# ppl on examples
val_example = get_prompt_examples("val")
for v in val_example:
    edited_ppl = eval_ppl(edited_model, v['base'])
    base_ppl = eval_ppl(model, v['base'])
    print(f"Base PPL: {base_ppl}, Edited PPL: {edited_ppl}, Diff: {edited_ppl - base_ppl}")

Base PPL: 61.7552490234375, Edited PPL: 69.33033752441406, Diff: 7.5750885009765625
Base PPL: 122.14643859863281, Edited PPL: 136.76490783691406, Diff: 14.61846923828125
Base PPL: 88.93116760253906, Edited PPL: 86.7463150024414, Diff: -2.1848526000976562


In [32]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def compute_ppl_sentence_by_sentence(model, sentences):

    model.eval()

    total_nll = 0.0
    total_tokens = 0


    for s in tqdm(sentences):
        with model.trace(s):
            output = model.output.save()
        logits = output['logits']
        inputs = model.tokenizer(s, return_tensors="pt").to(model.device)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = inputs['input_ids'][..., 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))
        nll = loss.sum().item()
        num_tokens = shift_labels.numel()
        total_nll += nll
        total_tokens += num_tokens
    ppl = torch.exp(torch.tensor(total_nll / total_tokens)).item()
    return ppl


In [17]:
train_examples = get_prompt_examples("train")
train_b = [e['base'] for e in train_examples]
train_cf = [e['counterfactual'] for e in train_examples]
val_examples = get_prompt_examples("val")
val_b = [e['base'] for e in val_examples]
val_cf = [e['counterfactual'] for e in val_examples]
sentences = train_b + train_cf + val_b + val_cf
model_name = "gpt2"
output = compute_ppl_sentence_by_sentence(model, sentences)
output_ed = compute_ppl_sentence_by_sentence(edited_model, sentences)
    