In [1]:
import torch
from transformers import logging
logging.set_verbosity_error()
from tqdm import tqdm

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from connectivity.effective import get_circuit_feature
from evaluation.faithfulness import faithfulness as faithfulness_fn

from data.buffer import unpack_batch

from utils.ablation_fns import zero_ablation, mean_ablation, id_ablation
from utils.savior import save_circuit
from utils.plotting import plot_faithfulness
from utils.metric_fns import metric_fn_logit, metric_fn_KL, metric_fn_statistical_distance, metric_fn_acc, metric_fn_MRR
from utils.experiments_setup import load_model_and_modules, load_saes, get_architectural_graph

In [3]:
class single_input_buffer:
    """Buffer that yields one fixed, hand-written clean/corrupted prompt pair.

    The data lives in ``self.data``:
      * ``clean``: "When Mary and John went to the store, John gave a glass to",
        with ``good`` target " Mary";
      * ``corr``: the same prompt with the second "John" replaced by "Paul",
        with ``bad`` target " John".

    Iterating yields exactly one batch dict with keys ``clean``, ``trg_idx``,
    ``trg``, ``corr`` and ``corr_trg``. Calling ``iter()`` rewinds the buffer,
    so a ``for`` loop over it (e.g. a re-run notebook cell) yields the batch
    again. Calling ``next()`` on an exhausted buffer without re-iterating still
    raises ``StopIteration``.

    Args:
        model: object with a ``tokenizer`` attribute. The tokenizer is called
            with ``return_tensors='pt'`` and must return a mapping containing
            ``'input_ids'``.
        batch_size: stored on the instance; not otherwise used here.
        device: torch device that every token tensor is moved to.
        ctx_len: stored on the instance; not otherwise used here.
        perm: accepted for signature compatibility; ignored.
    """

    def __init__(self, model, batch_size, device, ctx_len=None, perm=None):
        self.model = model
        self.batch_size = batch_size
        self.device = device
        self.ctx_len = ctx_len
        self.data = {
            "clean": ["When Mary and John went to the store, John gave a glass to"],
            "good": [[" Mary"]],
            "corr": ["When Mary and John went to the store, Paul gave a glass to"],
            "bad": [[" John"]],
        }
        self.done = False

    def __iter__(self):
        # Rewind on every new iteration. Otherwise re-running a `for` loop over
        # this buffer silently yields nothing after the first pass.
        self.done = False
        return self

    def __next__(self):
        """Return the single batch once, then raise StopIteration."""
        if self.done:
            raise StopIteration
        self.done = True

        clean_tokens = self._tokenize(self.data["clean"], padding=True)
        # One entry per prompt, all -1. Presumably this means "last position"
        # to the downstream metric. TODO confirm.
        trg_idx = torch.full(
            (clean_tokens.size(0),), -1, dtype=torch.long, device=clean_tokens.device
        )
        return {
            "clean": clean_tokens,
            "trg_idx": trg_idx,
            "trg": self._last_token_ids(self.data["good"]),
            "corr": self._tokenize(self.data["corr"], padding=True),
            "corr_trg": self._last_token_ids(self.data["bad"]),
        }

    def _tokenize(self, texts, **tokenizer_kwargs):
        """Tokenize ``texts`` to an ``input_ids`` tensor on ``self.device``."""
        encoded = self.model.tokenizer(
            texts,
            return_tensors='pt',
            return_attention_mask=False,
            return_token_type_ids=False,
            **tokenizer_kwargs,
        )
        return encoded['input_ids'].to(self.device)

    def _last_token_ids(self, groups):
        """For each group of target strings, return the id of the last token of each string."""
        return [self._tokenize(group)[:, -1] for group in groups]


In [4]:
# Load the model plus a name -> submodule mapping (resid=False is forwarded to the loader as-is).
model, name2mod = load_model_and_modules(resid=False, device=DEVICE)

# Dependency graph between named modules: each key lists the modules that feed into it.
architectural_graph = get_architectural_graph(model, name2mod)
print(architectural_graph)

Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [5]:
# SAE dictionaries for the modules in name2mod (loaded onto DEVICE).
dictionaries = load_saes(model, name2mod, device=DEVICE)

# Wrap the single hand-written clean/corrupted prompt pair in the buffer interface.
buffer = single_input_buffer(model, batch_size=1, device=DEVICE)


In [6]:
# Discover a feature-level circuit for every batch in the buffer. Keep each result:
# the previous version threw away the return value of this expensive call.
circuits = []
for batch in tqdm(buffer):
    tokens, trg_idx, trg, corr, corr_trg = unpack_batch(batch)

    circuits.append(
        get_circuit_feature(
            clean=tokens,
            patch=corr,
            model=model,
            architectural_graph=architectural_graph,
            name2mod=name2mod,
            dictionaries=dictionaries,
            metric_fn=metric_fn_logit,
            metric_kwargs={"trg_idx": trg_idx, "trg_pos": trg, "trg_neg": corr_trg},
            ablation_fn=zero_ablation,
            edge_threshold=0.01,
        )
    )

{'embed': [], 'attn_0': ['embed'], 'mlp_0': ['embed'], 'y': ['mlp_4', 'attn_4', 'attn_3', 'attn_1', 'mlp_2', 'mlp_5', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'attn_5', 'mlp_1', 'mlp_3'], 'attn_1': ['attn_0', 'embed', 'mlp_0'], 'mlp_1': ['attn_0', 'embed', 'mlp_0'], 'attn_2': ['attn_1', 'attn_0', 'embed', 'mlp_0', 'mlp_1'], 'mlp_2': ['attn_1', 'attn_0', 'embed', 'mlp_0', 'mlp_1'], 'attn_3': ['attn_1', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'mlp_1', 'mlp_2'], 'mlp_3': ['attn_1', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'mlp_1', 'mlp_2'], 'attn_4': ['attn_3', 'attn_1', 'mlp_2', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'mlp_1', 'mlp_3'], 'mlp_4': ['attn_3', 'attn_1', 'mlp_2', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'mlp_1', 'mlp_3'], 'attn_5': ['mlp_4', 'attn_4', 'attn_3', 'attn_1', 'mlp_2', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'mlp_1', 'mlp_3'], 'mlp_5': ['mlp_4', 'attn_4', 'attn_3', 'attn_1', 'mlp_2', 'attn_0', 'embed', 'mlp_0', 'attn_2', 'mlp_1', 'mlp_3']}


In [None]:
# Draw the batch from a fresh buffer. The loop in the previous cell already consumed
# `buffer` (it sets `done = True`), so `next(buffer)` raised StopIteration on
# Restart & Run All.
batch = next(iter(single_input_buffer(model, 1, DEVICE, ctx_len=None, perm=None)))
tokens, trg_idx, trg, corr, corr_trg = unpack_batch(batch)

# Feature-level circuit for the clean prompt, patched from the corrupted prompt,
# scored with the logit metric and zero-ablation; edges below 0.01 are presumably pruned.
circuit = get_circuit_feature(
    clean=tokens,
    patch=corr,
    model=model,
    architectural_graph=architectural_graph,
    name2mod=name2mod,
    dictionaries=dictionaries,
    metric_fn=metric_fn_logit,
    metric_kwargs={"trg_idx": trg_idx, "trg_pos": trg, "trg_neg": corr_trg},
    ablation_fn=zero_ablation,
    edge_threshold=0.01,
)

In [None]:
# Faithfulness sweep over a list of thresholds.
# NOTE(review): torch.logspace(0.1, 0, 10, 10) gives 10 values decreasing from
# 10**0.1 (~1.26) to 10**0 = 1.0, a very narrow range. The circuit above was
# built with edge_threshold=0.01. TODO confirm this sweep range is intended.
thresholds = torch.logspace(0.1, 0, 10, 10).tolist()

# NOTE(review): several arguments below use names that no visible cell defines
# (submodules, name_dict, tot_nodes, tot_edges, metric_fn_dict, ablation_fn,
# node_ablation). On a fresh kernel this cell raises NameError. TODO bind them
# explicitly. Possible candidates from this notebook: name2mod, circuit,
# metric_fn_logit (or a dict of the imported metric_fn_* functions), zero_ablation.
faithfulness = faithfulness_fn(
    model,
    submodules=submodules,  # TODO: undefined in the visible notebook
    sae_dict=dictionaries,
    name_dict=name_dict,  # TODO: undefined; possibly name2mod — confirm
    clean=tokens,
    circuit=(tot_nodes, tot_edges),  # TODO: undefined; presumably derived from `circuit` — confirm
    thresholds=thresholds,
    metric_fn=metric_fn_dict,  # TODO: undefined in the visible notebook
    metric_fn_kwargs={"trg_idx": trg_idx, "trg_pos": trg, "trg_neg": corr_trg},
    ablation_fn=ablation_fn,  # TODO: undefined; the circuit above used zero_ablation
    patch=corr,
    node_ablation=node_ablation,  # TODO: undefined in the visible notebook
)