In [101]:
import os
print(os.environ.get("HOSTNAME"))

import time

gpu38-008


In [102]:
import sys
# sys.path.append("./NetworkStructures/") # /!\ Comment out if "." is not home directory in goethe's cluster.

import torch
from transformers import logging
logging.set_verbosity_error()
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [103]:
from connectivity.effective import get_circuit_feature

from evaluation.faithfulness import faithfulness as faithfulness_fn

from data.buffer import unpack_batch

from utils.ablation_fns import zero_ablation, mean_ablation, id_ablation
from utils.savior import save_circuit
from utils.plotting import plot_faithfulness
from utils.metric_fns import metric_fn_logit, metric_fn_KL, metric_fn_statistical_distance, metric_fn_acc, metric_fn_MRR
from utils.experiments_setup import load_model_and_modules, load_saes, get_architectural_graph

import math

In [104]:
class single_input_buffer:
    def __init__(self, model, batch_size, device, ctx_len=None, perm=None):
        self.model = model
        self.batch_size = batch_size
        self.device = device
        self.ctx_len = ctx_len
        self.data = {
            "clean": ["When Mary and John went to the store, John gave a glass to"],
            "good": [[" Mary"]],
            "corr": ["When Mary and John went to the store, Paul gave a glass to"],
            "bad": [[" John"]],
        }
        self.done = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.done:
            raise StopIteration
        self.done = True
        tk = self.model.tokenizer
        clean_tokens = tk(self.data["clean"], return_tensors='pt', padding=True, return_attention_mask=False, return_token_type_ids=False)['input_ids'].to(self.device)
        trg_idx = torch.zeros(clean_tokens.size(0), device=clean_tokens.device).long() - 1
        trg = []
        for i, good in enumerate(self.data["good"]):
            trg.append(tk(good, return_tensors='pt', return_attention_mask=False, return_token_type_ids=False)['input_ids'].to(self.device)[:, -1])
        corr_tokens = tk(self.data["corr"], return_tensors='pt', padding=True, return_attention_mask=False, return_token_type_ids=False)['input_ids'].to(self.device)
        corr_trg = []
        for i, bad in enumerate(self.data["bad"]):
            corr_trg.append(tk(bad, return_tensors='pt', return_attention_mask=False, return_token_type_ids=False)['input_ids'].to(self.device)[:, -1])

        return {
            "clean": clean_tokens,
            "trg_idx": trg_idx,
            "trg": trg,
            "corr": corr_tokens,
            "corr_trg": corr_trg,
        }


In [105]:
use_attn_mlp = True
use_resid = False
start_at_layer = 2
model, name2mod = load_model_and_modules(device=DEVICE, resid=use_resid, attn=use_attn_mlp, mlp=use_attn_mlp, start_at_layer=start_at_layer)
architectural_graph = get_architectural_graph(model, name2mod)

dictionaries = load_saes(model, name2mod, device=DEVICE)
print(architectural_graph)

Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
{'y': ['mlp_5', 'attn_3', 'resid_2', 'attn_4', 'mlp_3', 'attn_5', 'mlp_4'], 'resid_2': [], 'attn_3': ['resid_2'], 'mlp_3': ['resid_2'], 'attn_4': ['mlp_3', 'resid_2', 'attn_3'], 'mlp_4': ['mlp_3', 'resid_2', 'attn_3'], 'attn_5': ['attn_3', 'resid_2', 'attn_4', 'mlp_3', 'mlp_4'], 'mlp_5': ['attn_3', 'resid_2', 'attn_4', 'mlp_3', 'mlp_4']}


In [106]:
edge_threshold = 1e-4
buffer = single_input_buffer(model, 1, DEVICE, ctx_len=None, perm=None)
steps = 10

edge_circuit = False

batch = next(buffer)
tokens, trg_idx, trg, corr, corr_trg = unpack_batch(batch)

clean = tokens
patch = corr

metric_fn = metric_fn_logit
metric_fn_dict = {
    'logit': metric_fn_logit,
    'KL': metric_fn_KL,
    'Statistical Distance': metric_fn_statistical_distance,
    # 'acc': metric_fn_acc,
    # 'MRR': metric_fn_MRR,
}

metric_kwargs = {"trg_idx": trg_idx, "trg_pos": trg, "trg_neg": corr_trg}

default_ablation = 'id'

if patch is not None:
    ablation_fn = id_ablation
if patch is None and ablation_fn is None:
    if default_ablation == 'mean':
        ablation_fn = mean_ablation
    elif default_ablation == 'zero':
        ablation_fn = zero_ablation
    elif default_ablation == 'id':
        ablation_fn = id_ablation
    else:
        raise ValueError(f"Unknown default ablation function : {default_ablation}")
    

In [107]:
edges = get_circuit_feature(
    clean=clean,
    patch=patch,
    model=model,
    architectural_graph=architectural_graph,
    name2mod=name2mod,
    dictionaries=dictionaries,
    metric_fn=metric_fn,
    metric_kwargs=metric_kwargs,
    ablation_fn=ablation_fn,
    threshold=edge_threshold,
    steps=steps,
    edge_circuit=edge_circuit,
)

In [108]:
nb_eval_thresholds = 20

thresholds = torch.logspace(math.log10(edge_threshold), 0.1, nb_eval_thresholds, 10).tolist() + [999.] # the higher the threshold, the more edges are removed. -1 is to enforce full ablation.

results = faithfulness_fn(
    model,
    name2mod,
    dictionaries,
    clean,
    edges,
    architectural_graph,
    thresholds,
    metric_fn_dict,
    metric_kwargs,
    patch,
    ablation_fn,
    default_ablation=default_ablation,
    node_ablation=(not edge_circuit),
)

In [109]:
plot_faithfulness(results, save_path=None)

logit [2.4891357421875, 2.48046875, 2.4659423828125, 2.4931640625, 2.450439453125, 2.4312744140625, 2.393310546875, 2.357177734375, 2.244873046875, 2.3397216796875, 2.3839111328125, 2.01513671875, 2.007080078125, 2.1650390625, 2.3438720703125, 2.2637939453125, 2.0863037109375, 2.194580078125, 2.1763916015625, 2.1763916015625, 2.1763916015625]
KL [0.0003356989473104477, 0.00034572184085845947, 0.0007738028652966022, 0.0010038330219686031, 0.0014394547324627638, 0.0017834627069532871, 0.0029769744724035263, 0.005828524474054575, 0.010149048641324043, 0.017404276877641678, 0.019076023250818253, 0.029309868812561035, 0.03361905738711357, 0.030940942466259003, 0.03102860599756241, 0.04362988844513893, 0.04387107864022255, 0.04991466552019119, 0.04767295718193054, 0.04767295718193054, 0.04767295718193054]
Statistical Distance [0.009500189684331417, 0.008843712508678436, 0.015384561382234097, 0.017870567739009857, 0.02204575017094612, 0.02265026792883873, 0.03066226653754711, 0.04475709050893

In [110]:
print(results)

{'complete': {'logit': tensor(2.4818, device='cuda:0'), 'KL': tensor([0.], device='cuda:0'), 'Statistical Distance': tensor([0.], device='cuda:0')}, 'empty': {'logit': tensor(2.1764, device='cuda:0'), 'KL': tensor([0.0477], device='cuda:0'), 'Statistical Distance': tensor([0.1168], device='cuda:0')}, 9.999999747378752e-05: {'n_nodes': 2917, 'n_edges': 2986511, 'avg_deg': 2047.6592389441207, 'density': tensor(0.0002), 'faithfulness': {'logit': 2.4891357421875, 'KL': 0.0003356989473104477, 'Statistical Distance': 0.009500189684331417, 'faithfulness_logit': 1.0239808559417725, 'faithfulness_KL': 0.9929583072662354, 'faithfulness_Statistical Distance': 0.9186869859695435}}, 0.0001643574796617031: {'n_nodes': 2692, 'n_edges': 2553615, 'avg_deg': 1897.1879643387815, 'density': tensor(0.0001), 'faithfulness': {'logit': 2.48046875, 'KL': 0.00034572184085845947, 'Statistical Distance': 0.008843712508678436, 'faithfulness_logit': 0.9956035017967224, 'faithfulness_KL': 0.9927480220794678, 'faithf