In [None]:
import os
import torch
from transformers import AutoTokenizer
import numpy as np
import transformer_lens
# Allocate new tensors on the GPU when one is present; otherwise stay on the
# CPU, matching the CPU fallback used for `device` / `last_device` later on.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
import seaborn as sns
from functools import partial
import pandas as pd
from matplotlib import pyplot as plt
import json
from tqdm import tqdm
import torch.nn.functional as F
from scipy.stats import pearsonr
from enum import Enum
from src.utils import get_w_vo, rearrange_heads_by_layer, top_k_indices, get_k, load_dataset, get_topm_relation_heads
from src.maps import MAPS

# Pandas display settings for this notebook: a 1000-character display width,
# and no limit on the number of columns shown or on the width of a cell.
for option_name, option_value in (
    ("display.width", 1000),
    ("display.max_columns", None),
    ("display.max_colwidth", None),
):
    pd.set_option(option_name, option_value)

In [None]:
# Load the model via TransformerLens' `from_pretrained_no_processing` and
# freeze every parameter: the cells below only run analysis, never training.
model_name = "gpt2-xl"
model_family_name = "gpt2"
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(model_name, device_map="auto")
for param in model.parameters():
    param.requires_grad = False
model.eval()
state_dict = model.state_dict()
cfg = model.cfg
# True whenever the config defines `n_key_value_heads`. None is a singleton,
# so it is tested by identity rather than with `!=`.
is_gqa = cfg.n_key_value_heads is not None
tokenizer = AutoTokenizer.from_pretrained(model_name)
# First and last visible CUDA devices, falling back to the CPU when CUDA is
# not available.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_available else "cpu")
num_gpus = torch.cuda.device_count()
last_device = torch.device(f"cuda:{num_gpus-1}" if cuda_available else "cpu")

Predefined Relations

In [None]:
# Compute the relation scores of every attention head for one predefined
# relation and show them as a heatmap.
relation_name = "country_to_capital_wikidata"
k = get_k(model_name, relation_name)
dataset = load_dataset(relation_name)
apply_first_mlp = True
maps = MAPS(model, tokenizer)
relation_scores, suppression_relation_scores = maps.calc_relation_scores(
    dataset,
    apply_first_mlp,
    k,
)

# Draw on an explicit Axes instead of relying on pyplot's "current axes".
# The matrix is transposed before plotting; the labels put layers on the
# x-axis and heads on the y-axis.
# NOTE(review): presumably `relation_scores` is (n_layers, n_heads) - confirm.
fig, ax = plt.subplots()
sns.heatmap(relation_scores.T, vmin=0, vmax=1, ax=ax)
ax.set_xlabel("Layer")
ax.set_ylabel("Head")
ax.set_title(f"Relation scores\n{relation_name}\n{model_name}\nk={k}")
# Render the figure explicitly so the cell does not end on a `Text(...)` repr.
plt.show()

Static vs dynamic relation scores

In [None]:
# Compare the static relation scores with the dynamic relation scores obtained
# by running the model on a prompt template, and report their Pearson
# correlation.
relation_name = "country_to_capital_wikidata"
k = get_k(model_name, relation_name)
dataset = load_dataset(relation_name)
apply_first_mlp = True
template = "This is a document about <X>"
maps = MAPS(model, tokenizer)
relation_scores, suppression_relation_scores = maps.calc_relation_scores(
    dataset,
    apply_first_mlp,
    k,
)
dynamic_results = maps.calc_dynamic_relation_scores(dataset, template, k)

# Flatten each score matrix once and reuse the vectors for both the
# correlation and the scatter plot.
static_scores_flat = relation_scores.flatten()
dynamic_scores_flat = dynamic_results["wo_context_dynamic_relation_scores"].flatten()
corr, pval = pearsonr(static_scores_flat, dynamic_scores_flat)

# Draw on an explicit Axes instead of relying on pyplot's "current axes".
fig, ax = plt.subplots()
sns.scatterplot(x=static_scores_flat, y=dynamic_scores_flat, ax=ax)
ax.set_title(f"Static vs Dynamic relation scores\ntemplate={template}\nPearson corr={corr:.2f}, pval={pval:.1e}\nmodel={model_name}\nw_context=False")
ax.set_xlabel("Relation score")
ax.set_ylabel("Dynamic relation score")
# Render the figure explicitly so the cell does not end on a `Text(...)` repr.
plt.show()

Causal Experiment

In [None]:
def sample_m_random_heads(m, heads_to_exclude, cfg, rng=None):
    """Sample `m` distinct attention heads uniformly at random, skipping some.

    Args:
        m: Number of heads to draw.
        heads_to_exclude: Iterable of `(layer, head)` pairs that must not be
            drawn.
        cfg: Object exposing integer attributes `n_layers` and `n_heads`.
        rng: Optional NumPy random source with a `choice` method, e.g.
            `np.random.default_rng(seed)`. When omitted, the global
            `np.random` state is used, so `np.random.seed(...)` still
            controls the draw.

    Returns:
        A list of `m` `(layer, head)` tuples of plain Python ints.

    Raises:
        ValueError: If `m` is larger than the number of heads left after
            removing `heads_to_exclude`.
    """
    n_heads = cfg.n_heads
    # Heads are addressed by the flat index layer * n_heads + head.
    excluded_flat = {layer * n_heads + head for (layer, head) in heads_to_exclude}
    # Build the candidate pool in ascending order so that a seeded draw does
    # not depend on set iteration order.
    available_heads = [
        ix for ix in range(cfg.n_layers * n_heads) if ix not in excluded_flat
    ]
    if m > len(available_heads):
        raise ValueError(
            f"Cannot sample {m} heads: only {len(available_heads)} remain after exclusion"
        )
    sampler = np.random if rng is None else rng
    sampled_flat = sampler.choice(available_heads, size=m, replace=False)
    # Convert NumPy scalars back to Python ints before unflattening.
    return [divmod(int(ix), n_heads) for ix in sampled_flat]

In [None]:
# Causal experiment: ablate the top relation heads and track accuracy on the
# main task, on a control (copying) task, and against a random-head baseline.
# NOTE(review): `k` is not read anywhere in this cell - head selection below
# uses get_k(model_name, relation_name). Confirm which value is intended.
k = 10
relation_name = "name_copying"
template = " John-> John; Donna-> Donna; <X>->"
dataset = load_dataset(relation_name)
apply_first_mlp = True
m_heads = 150
maps = MAPS(model, tokenizer)
topm_relation_heads = get_topm_relation_heads(m_heads, maps, dataset, apply_first_mlp, get_k(model_name, relation_name), only_nonzero=True)
accuracies = maps.calc_causal_effects(dataset, template, topm_relation_heads)

# Control task: the same relation heads, evaluated on a different dataset and
# template.
control_relation_name = "general_copying_english_500"
control_template = " walk-> walk; cat-> cat; water-> water; <X>->"
control_dataset = load_dataset(control_relation_name)
control_accuracies = maps.calc_causal_effects(control_dataset, control_template, topm_relation_heads)

# Random baseline: `m_heads` heads drawn from outside the relation heads.
# Seed the global NumPy RNG so the baseline curve is the same on every run.
# NOTE(review): if `only_nonzero=True` can return fewer than `m_heads` heads,
# the baseline ablates more heads than the main curve - confirm.
random_heads_seed = 0
np.random.seed(random_heads_seed)
random_heads = sample_m_random_heads(m_heads, topm_relation_heads, cfg)
random_accuracies = maps.calc_causal_effects(dataset, template, random_heads)

# Draw all three curves on one explicit Axes instead of relying on pyplot's
# "current axes".
fig, ax = plt.subplots()
sns.lineplot(accuracies, label="Main task, ablating relation heads", ax=ax)
sns.lineplot(control_accuracies, label="Control task (Copying), ablating relation heads", ax=ax)
sns.lineplot(random_accuracies, label="Main task, ablating random heads", ax=ax)
ax.set_title(f"MAPS causal experiment\nrelation={relation_name}\ntemplate={template}\n{model_name}")
ax.set_ylabel("Accuracy")
ax.set_yticks(np.arange(0, 1.05, 0.1))
ax.set_xlabel("# heads ablated")
# Render the figure explicitly so the cell does not end on a `Text(...)` repr.
plt.show()