In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
from activation_utils import SparseAct
import torch as t
import plotly.graph_objects as go
from loading_utils import load_examples
from dictionary_learning import AutoEncoder
from dictionary_learning.dictionary import IdentityDict
from dictionary_loading_utils import load_saes_and_submodules
from ablation import run_with_ablations
from attribution import Submodule
from scipy import interpolate
import math
from tqdm import tqdm
from statistics import stdev

In [2]:
device = 'cuda:0'
model_name = "EleutherAI/pythia-70m-deduped"
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device, dispatch=True)
# model = LanguageModel(model_name, attn_implementation="eager",
#                       torch_dtype=t.bfloat16, device_map=device, dispatch=True)

start_layer = 2 # explain the model starting here



cuda:0


In [3]:
if "gemma-2" in model_name:
    model = LanguageModel(model_name, device_map=device, dispatch=True,
                            torch_dtype=t.bfloat16, attn_implementation="eager")
    embed_submod = Submodule(
        name = "embed",
        submodule=model.model.embed_tokens,
    )
    model_layers = model.model.layers
    out_submod = model.lm_head
else:
    model = LanguageModel(model_name, device_map=device, dispatch=True)
    embed_submod = Submodule(
        name="embed",
        submodule=model.gpt_neox.embed_in
    )
    model_layers = model.gpt_neox.layers
    out_submod = model.embed_out

submodules = []
if start_layer < 0: submodules.append(embed_submod)
for i in range(start_layer, len(model_layers)):
    if "gemma-2" in model_name:
        submodules.extend([
            Submodule(submodule=model.model.layers[i].self_attn.o_proj, use_input=True, name=f"attn_{i}"),
            Submodule(submodule=model.model.layers[i].post_feedforward_layernorm, name=f"mlp_{i}"),
            Submodule(submodule=model.model.layers[i], is_tuple=True, name=f"resid_{i}")
        ])
    else:
        submodules.extend([
            Submodule(submodule=model.gpt_neox.layers[i].attention, name=f"attn_{i}", is_tuple=True),
            Submodule(submodule=model.gpt_neox.layers[i].mlp, name=f"mlp_{i}"),
            Submodule(submodule=model.gpt_neox.layers[i], name=f"resid_{i}", is_tuple=True)
        ])

cuda:0


In [4]:
dict_id = 1
if dict_id != 'id':
    if "gemma-2" in model_name:
        include_embed = False
        dict_size = 16384
        activation_dim = 2304
        dtype = t.bfloat16
    else:
        include_embed = True
        dict_size = 32768
        activation_dim = 512
        dtype = t.float32
    _, feat_dicts = load_saes_and_submodules(
        model, 
        model_name, 
        dtype=dtype,
        device=device,
        include_embed=include_embed
    )

neuron_dicts = {
    submod : IdentityDict(activation_dim).to(device) for submod in submodules
}
for nd in neuron_dicts:
    if "gemma-2" in model_name:
        if nd.name.startswith("att"):
            neuron_dicts[nd] = IdentityDict(2048).to(device)

  state_dict = t.load(path)


In [None]:
# load dictionaries
dict_id = 10

activation_dim = 512
expansion_factor = 64
dict_size = expansion_factor * activation_dim

feat_dicts = {}
feat_dicts[model.gpt_neox.embed_in] = AutoEncoder.from_pretrained(
    f'../dictionaries/pythia-70m-deduped/embed/{dict_id}_{dict_size}/ae.pt', device=device
)
for i in range(len(model.gpt_neox.layers)):
    feat_dicts[model.gpt_neox.layers[i].attention] = AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_{dict_size}/ae.pt', device=device
    )
    feat_dicts[model.gpt_neox.layers[i].mlp] = AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_{dict_size}/ae.pt', device=device
    )
    feat_dicts[model.gpt_neox.layers[i]] = AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dict_size}/ae.pt', device=device
    )

neuron_dicts = {
    submod : IdentityDict(activation_dim).to(device) for submod in submodules
}


In [5]:
# use mean ablation
ablation_fn = lambda x: x.mean(dim=0).expand_as(x)

In [14]:
# get m(C) for the circuit obtained by thresholding nodes with the given threshold
def get_fcs(
        dataset,
        model,
        submodules,
        dictionaries,
        ablation_fn,
        thresholds,
        length,
        handle_errors = 'default', # also 'remove' or 'resid_only'
        use_neurons = False,
        random = False,
):
    # load data 
    if "gemma-2" in model.config._name_or_path:
        dfeat, dneuron = (1, 'id')
        node_threshold = 0.75
        model_layers = model.model.layers
    else:
        dfeat, dneuron = (10, 'id')
        node_threshold = 0.1
        model_layers = model.gpt_neox.layers
    edge_threshold = node_threshold / 10
    if not use_neurons:
        circuit = t.load(f'../circuits/{dataset}_train_dict{dfeat}_node{node_threshold}_edge{edge_threshold}_n100_aggnone.pt')['nodes']
    else:
        circuit = t.load(f'../circuits/{dataset}_train_dict{dneuron}_node{node_threshold}_edge{edge_threshold}_n100_aggnone.pt')['nodes']
    examples = load_examples(f'/share/projects/dictionary_circuits/data/phenomena/{dataset}_test.json', 40, model, length=length)
    batch_size = 10
    num_examples = len(examples)
    n_batches = math.ceil(num_examples / batch_size)
    batches = [
        examples[batch*batch_size:(batch+1)*batch_size] for batch in range(n_batches)
    ]
    out = {}

    for batch in tqdm(batches):
        clean_inputs = t.cat([e['clean_prefix'] for e in examples], dim=0).to('cuda:0')
        clean_answer_idxs = t.tensor([e['clean_answer'] for e in examples], dtype=t.long, device='cuda:0')
        patch_inputs = t.cat([e['patch_prefix'] for e in examples], dim=0).to('cuda:0')
        patch_answer_idxs = t.tensor([e['patch_answer'] for e in examples], dtype=t.long, device='cuda:0')
        def metric_fn(model):
            return (
                - t.gather(out_submod.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) + \
                t.gather(out_submod.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
            )
        
        with t.no_grad():
            # get F(M)
            with model.trace(clean_inputs):
                metric = metric_fn(model).save()
            fm = metric.value

            if 'fm' not in out:
                out['fm'] = fm
            else:
                out['fm'] = t.cat((out['fm'], fm))

            # get m(∅)
            fempty = run_with_ablations(
                clean_inputs,
                patch_inputs,
                model,
                submodules,
                dictionaries,
                nodes = {
                    submod.submodule : SparseAct(
                        act=t.zeros(dict_size if not use_neurons else \
                            (2048 if "gemma-2" in model.config._name_or_path and submod.name.startswith("att") else activation_dim), 
                            dtype=t.bool),
                        resc=t.zeros(1, dtype=t.bool)).to(device)
                        for submod in submodules
                },
                metric_fn=metric_fn,
                ablation_fn=ablation_fn,
            )
            if 'fempty' not in out:
                out['fempty'] = fempty
            else:
                out['fempty'] = t.cat((out['fempty'], fempty))

            for threshold in thresholds:
                if threshold not in out:
                    out[threshold] = {}
                nodes = {
                    submod.submodule : circuit[submod.name].abs() > threshold for submod in submodules
                }

                if handle_errors == 'remove':
                    for k in nodes: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)
                elif handle_errors == 'resid_only':
                    for k in nodes:
                        if k not in model_layers: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)

                n_nodes = sum([n.act.sum() + n.resc.sum() for n in nodes.values()]).item()
                if random:
                    total_nodes = sum([n.act.numel() + n.resc.numel() for n in nodes.values()])
                    p = n_nodes / total_nodes
                    for k in nodes:
                        nodes[k].act = t.bernoulli(t.ones_like(nodes[k].act, dtype=t.float) * p).to(device).to(dtype=t.bool)
                        nodes[k].resc = t.ones_like(nodes[k].resc, dtype=t.bool).to(device)
                    out[threshold]['n_nodes'] = sum([n.act.sum() + n.resc.sum() for n in nodes.values()]).item()
                else:
                    out[threshold]['n_nodes'] = n_nodes
                

                
                fc = run_with_ablations(
                    clean_inputs,
                    patch_inputs,
                    model,
                    submodules,
                    dictionaries,
                    nodes=nodes,
                    metric_fn=metric_fn,
                    ablation_fn=ablation_fn,
                )
                if 'fc' not in out[threshold]:
                    out[threshold]['fc'] = fc
                else:
                    out[threshold]['fc'] = t.cat((out[threshold]['fc'], fc))

                fccomp = run_with_ablations(
                    clean_inputs,
                    patch_inputs,
                    model,
                    submodules,
                    dictionaries,
                    nodes=nodes,
                    metric_fn=metric_fn,
                    ablation_fn=ablation_fn,
                    complement=True
                )
                if 'fccomp' not in out[threshold]:
                    out[threshold]['fccomp'] = fccomp
                else:
                    out[threshold]['fccomp'] = t.cat((out[threshold]['fccomp'], fccomp))

    out['fempty'] = out['fempty'].mean().item()
    out['fm'] = out['fm'].mean().item()
    for threshold in thresholds:
        out[threshold]['fc'] = out[threshold]['fc'].mean().item()
        out[threshold]['fccomp'] = out[threshold]['fccomp'].mean().item()
        out[threshold]['faithfulness'] = (out[threshold]['fc'] - out['fempty']) / (out['fm'] - out['fempty'])
        out[threshold]['completeness'] = (out[threshold]['fccomp'] - out['fempty']) / (out['fm'] - out['fempty'])

    return out


In [16]:
# dataset : number of tokens in inputs from dataset
datasets = {
    'rc' : 6,
    # 'nounpp' : 5,
    # 'simple' : 2,
    # 'within_rc' : 5
}
thresholds = t.logspace(-4, 0, 15, 10).tolist()
if "gemma-2" in model_name:
    for structure in datasets:
        datasets[structure] += 1    # for BOS token
    thresholds = t.logspace(-4, 1.5, 15, 10).tolist()

outs = {
    'features' : {
        dataset : get_fcs(
            dataset,
            model,
            submodules,
            feat_dicts,
            ablation_fn=ablation_fn,
            thresholds = thresholds,
            length=length,
        ) for dataset, length in datasets.items()
    },
    'features_wo_errs' : {
        dataset : get_fcs(
            dataset,
            model,
            submodules,
            feat_dicts,
            ablation_fn=ablation_fn,
            thresholds = thresholds,
            length=length,
            handle_errors='remove',
        ) for dataset, length in datasets.items()
    },
    'features_wo_some_errs' : {
        dataset : get_fcs(
            dataset,
            model,
            submodules,
            feat_dicts,
            ablation_fn=ablation_fn,
            thresholds = thresholds,
            length=length,
            handle_errors='resid_only',
        ) for dataset, length in datasets.items()
    },
    'neurons' : {
        dataset : get_fcs(
            dataset,
            model,
            submodules,
            neuron_dicts,
            ablation_fn=ablation_fn,
            thresholds = thresholds,
            length=length,
            use_neurons=True,
        ) for dataset, length in datasets.items()
    },
}

  circuit = t.load(f'../circuits/{dataset}_train_dict10_node{node_threshold}_edge{edge_threshold}_n100_aggnone.pt')['nodes']


FileNotFoundError: [Errno 2] No such file or directory: '../circuits/rc_train_dict10_node0.75_edge0.075_n100_aggnone.pt'

In [39]:
# plot faithfulness results
fig = go.Figure()

colors = {
    'features' : 'blue',
    'features_wo_errs' : 'red',
    'features_wo_some_errs' : 'green',
    'neurons' : 'purple',
    # 'random_features' : 'black'
}

for setting, subouts in outs.items():

    x_min = max([min(subouts[dataset][t]['n_nodes'] for t in thresholds) for dataset in datasets]) + 1
    x_max = min([max(subouts[dataset][t]['n_nodes'] for t in thresholds) for dataset in datasets]) - 1
    fs = {
        dataset : interpolate.interp1d([subouts[dataset][t]['n_nodes'] for t in thresholds], [subouts[dataset][t]['faithfulness'] for t in thresholds])
        for dataset in datasets
    }
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()

    for dataset in datasets:
        fig.add_trace(go.Scatter(
            x = [subouts[dataset][t]['n_nodes'] for t in thresholds],
            y = [subouts[dataset][t]['faithfulness'] for t in thresholds],
            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
        ))

    fig.add_trace(go.Scatter(
        x=xs,
        y=[ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ],
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

fig.update_xaxes(range=(0, 1700))
fig.update_yaxes(range=(0, 1.1))

fig.update_layout(
    xaxis_title='Nodes',
    yaxis_title='Faithfulness',
    width=800,
    height=375,
    # set white background color
    plot_bgcolor='rgba(0,0,0,0)',
    # add grey gridlines
    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),
    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),

)

# fig.show()
fig.write_image('faithfulness_pythia.pdf')

In [32]:
# plot completeness results
fig = go.Figure()

colors = {
    'features' : 'blue',
    'features_wo_errs' : 'red',
    'features_wo_some_errs' : 'green',
    'neurons' : 'purple'
}

for setting, subouts in outs.items():

    x_min = max([min(subouts[dataset][t]['n_nodes'] for t in thresholds) for dataset in datasets]) + 1
    x_max = min([max(subouts[dataset][t]['n_nodes'] for t in thresholds) for dataset in datasets]) - 1
    fs = {
        dataset : interpolate.interp1d([subouts[dataset][t]['n_nodes'] for t in thresholds], [subouts[dataset][t]['completeness'] for t in thresholds])
        for dataset in datasets
    }
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()
    print(x_min, x_max)

    for dataset in datasets:
        fig.add_trace(go.Scatter(
            x = [subouts[dataset][t]['n_nodes'] for t in thresholds],
            y = [subouts[dataset][t]['completeness'] for t in thresholds],
            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
        ))
    fig.add_trace(go.Scatter(
        x=xs,
        y=[ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ],
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

fig.update_xaxes(range=(0,300))
fig.update_yaxes(range=(-.15, 1))

fig.update_layout(
    xaxis_title='Nodes',
    yaxis_title='Faithfulness',
    width=800,
    height=375,
    # set white background color
    plot_bgcolor='rgba(0,0,0,0)',
    # add grey gridlines
    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),
    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),
)
# fig.show()
fig.write_image('completeness_pythia.pdf')

377 728223
321 726847
377 727311
1 2114715
