In [1]:
%load_ext autoreload
%autoreload 2

# Preparation

## Define model

In [2]:
from nnsight import LanguageModel
import torch as t

# Release unused blocks held by PyTorch's CUDA caching allocator (e.g. after
# re-running this cell in a live kernel). Harmless when CUDA is not in use.
t.cuda.empty_cache()

# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at model load.
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
dtype = t.float32

# `device` and `dtype` are reused further down when the SAEs are loaded, so the
# model and the dictionaries end up on the same device with the same precision.
model = LanguageModel(
    'EleutherAI/pythia-70m-deduped',
    device_map=device,
    dispatch=True,  # per nnsight docs: load weights now rather than on first trace
    torch_dtype=dtype,
)

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [4]:
import os
from data_loading_utils import load_examples_rct
from dictionary_loading_utils import load_saes_and_submodules

# Prompt set to analyse; it also names the sub-directory under prompts/.
category = 'concise'
num_examples = 100
prompt_dir = os.path.join('prompts', category)

# NOTE(review): the loader is called with a literal 50 while `num_examples`
# (only used in `save_base` below) is 100. If the second argument is the
# example count, the file stem mislabels the data -- confirm which value is
# intended and pass a single variable.
train_clean, train_patch = load_examples_rct(prompt_dir, 50)
save_base = f'pythia-70m-deduped_prompt_dir_{category}_n{num_examples}'

# Load the dictionaries together with the submodules they attach to, using the
# same device/dtype as the model.
sae_options = dict(
    separate_by_type=True,
    include_embed=True,
    device=device,
    dtype=dtype,
)
submods, SAEs = load_saes_and_submodules(model, **sae_options)

# Experiment

## Clean run

In [5]:
from attribution import get_activations
from metrics import eos_metric
from activation_utils import SparseAct

# Metric used for every run in this notebook; bound to a variable so the
# later cells can pass the same function.
metric_fn = eos_metric
metric_kwargs = {}

# Interleave the per-type submodule lists: each zip step yields one
# (attn, mlp, resid) triple, appended in that order.
submods_flat = []
for layer_triple in zip(submods.attns, submods.mlps, submods.resids):
    submods_flat.extend(layer_triple)

# Activations and metric on the clean prompts.
clean_acts, clean_metric = get_activations(
    model, train_clean, submods_flat, SAEs, metric_fn=metric_fn,
)

## Patched run (with explicit behaviour elicitation)

In [6]:
# Same activation/metric pass as for the clean prompts, now on the patch
# prompts. `metric_fn` is the variable bound to eos_metric in the clean-run cell.
patch_acts, patch_metric = get_activations(
    model, train_patch, submods_flat, SAEs, metric_fn=metric_fn,
)

In [15]:
from attribution import pe_exact
# Long-running step: the recorded output shows 18 iterations taking ~36 min.
# Positional arguments are gathered first so the call itself stays short.
pe_args = (
    model,
    train_clean,
    clean_acts,
    patch_acts,
    submods_flat,
    SAEs,
    metric_fn,
)
# patch_pos=-1: presumably patches only the final token position -- TODO confirm
# against attribution.pe_exact.
effect = pe_exact(*pe_args, patch_pos=-1)

100%|██████████| 18/18 [36:07<00:00, 120.44s/it]


In [None]:
# Split the pe_exact result into its four components. `effects` is indexed by
# submodule further down; `total_effect` becomes the "y" node.
effects, deltas, grads, total_effect = effect

In [17]:
import pickle
# Persist the raw pe_exact result so the expensive run does not have to be
# repeated after a kernel restart.
with open('effect_0.pkl', mode='wb') as out_file:
    pickle.dump(effect, out_file)

In [None]:
# Short aliases for the per-type submodule collections.
embed = submods.embed
attns = submods.attns
mlps = submods.mlps
resids = submods.resids

n_layers = len(resids)

# `total_effect` is already bound by the cell that unpacks `effect`; the
# dangling `total_effect = ` that used to sit here was a SyntaxError that
# prevented this whole cell from running.
# Node table: "y" holds the total effect, and each layer contributes one entry
# per submodule type, looked up in `effects` by the submodule object.
nodes = {"y": total_effect}
for i in range(n_layers):
    nodes[f"attn_{i}"] = effects[attns[i]]
    nodes[f"mlp_{i}"] = effects[mlps[i]]
    nodes[f"resid_{i}"] = effects[resids[i]]

# Quick sanity check: display the shape of the first attention effect.
effects[attns[0]].shape