In [1]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
from activation_utils import SparseAct
import torch as t
import plotly.graph_objects as go
from loading_utils import load_examples
from dictionary_learning import AutoEncoder
from dictionary_learning.dictionary import IdentityDict
from ablation import run_with_ablations
from scipy import interpolate
import math
from statistics import stdev

from circuit import get_circuit

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# One device string under both spellings: later cells use `device` as well as `DEVICE`.
DEVICE = 'cuda:0'
device = DEVICE
model = LanguageModel(
    'EleutherAI/pythia-70m-deduped',
    device_map=device,
    dispatch=True,
)

start_layer = 2 # explain the model starting here

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# load submodules
# For every layer from `start_layer` onward, collect the attention module, the
# MLP module and the layer itself, in that order. A negative `start_layer`
# additionally puts the token embedding first.
submodules = []
if start_layer < 0: submodules.append(model.gpt_neox.embed_in)
# Clamp the loop start to 0: a negative `start_layer` only requests the embedding.
# Feeding it to range() unclamped would yield i = -1, which indexes the *last*
# layer and then adds that layer a second time at the end of the loop.
for i in range(max(start_layer, 0), len(model.gpt_neox.layers)):
    submodules.extend([
        model.gpt_neox.layers[i].attention,
        model.gpt_neox.layers[i].mlp,
        model.gpt_neox.layers[i]
    ])

# Human-readable name for every submodule (all layers, regardless of `start_layer`).
submod_names = {
    model.gpt_neox.embed_in : 'embed'
}
for i in range(len(model.gpt_neox.layers)):
    submod_names[model.gpt_neox.layers[i].attention] = f'attn_{i}'
    submod_names[model.gpt_neox.layers[i].mlp] = f'mlp_{i}'
    submod_names[model.gpt_neox.layers[i]] = f'resid_{i}'

In [4]:
# load dictionaries
dict_id = 10

activation_dim = 512
expansion_factor = 64
d_model = activation_dim
dict_size = expansion_factor * activation_dim

# Root folder of the trained autoencoders. Every checkpoint below is resolved
# relative to it, so relocating the dictionaries means editing this one line.
dict_root = "C:/Users/ConnardMcGregoire/Documents/MI_Internship/feature-circuits/dictionary_learning/dictionaires/pythia-70m-deduped"


def load_feature_dict(subdir):
    """Create an AutoEncoder on `device` and load `<dict_root>/<subdir>/ae.pt` into it.

    Args:
        subdir: folder name under `dict_root` (e.g. "embed", "resid_out_layer0").

    Returns:
        The AutoEncoder with the checkpoint's state dict loaded.
    """
    autoencoder = AutoEncoder(d_model, dict_size).to(device)
    autoencoder.load_state_dict(
        t.load(f"{dict_root}/{subdir}/ae.pt", map_location=device)
    )
    return autoencoder


# One autoencoder per location, keyed by the submodule it belongs to.
# Insertion order: embedding first, then (layer, attention, mlp) for each layer.
feat_dicts = {model.gpt_neox.embed_in: load_feature_dict("embed")}
for layer in range(len(model.gpt_neox.layers)):
    block = model.gpt_neox.layers[layer]
    feat_dicts[block] = load_feature_dict(f"resid_out_layer{layer}")
    feat_dicts[block.attention] = load_feature_dict(f"attn_out_layer{layer}")
    feat_dicts[block.mlp] = load_feature_dict(f"mlp_out_layer{layer}")

# An IdentityDict of width `activation_dim` for every entry of `submodules`
# (presumably the neuron-basis baseline -- confirm against its users).
neuron_dicts = {
    submod : IdentityDict(activation_dim).to(device) for submod in submodules
}


In [5]:
pythia70m = model

pythia70m_embed = pythia70m.gpt_neox.embed_in

# Per-layer handles: the layer itself, its attention module and its MLP module.
n_layers = len(pythia70m.gpt_neox.layers)
pythia70m_resids = [pythia70m.gpt_neox.layers[idx] for idx in range(n_layers)]
pythia70m_attns = [resid.attention for resid in pythia70m_resids]
pythia70m_mlps = [resid.mlp for resid in pythia70m_resids]

d_model = 512
dict_size = 32768

path = "C:/Users/ConnardMcGregoire/Documents/MI_Internship/feature-circuits/dictionary_learning/dictionaires/pythia-70m-deduped/"

# (checkpoint path relative to `path`, submodule) pairs, in load order:
# embedding first, then resid / attn / mlp for each layer.
checkpoint_targets = [(f"embed/ae.pt", pythia70m_embed)]
for layer in range(n_layers):
    checkpoint_targets.append((f"resid_out_layer{layer}/ae.pt", pythia70m_resids[layer]))
    checkpoint_targets.append((f"attn_out_layer{layer}/ae.pt", pythia70m_attns[layer]))
    checkpoint_targets.append((f"mlp_out_layer{layer}/ae.pt", pythia70m_mlps[layer]))

# NOTE(review): these are the same checkpoint files and dimensions that the
# `feat_dicts` cell loads, so the weights end up on the device twice.
dictionaries = {}
for rel_path, submodule in checkpoint_targets:
    ae = AutoEncoder(d_model, dict_size).to(DEVICE)
    ae.load_state_dict(t.load(path + rel_path, map_location=DEVICE))
    dictionaries[submodule] = ae

In [6]:
def metric_fn_v1(model, trg=None):
    """
    Logit difference at the last position of the first batch element.

    Reads `model.embed_out.output`, keeps the final sequence position, and
    returns the logit at index `trg[0]` minus the logit at index `trg[1]`.

    Raises:
        ValueError: if `trg` is not given.
    """
    if trg is None:
        raise ValueError("trg must be provided")

    last_position = model.embed_out.output[:, -1, :]
    kept_logit = last_position[0, trg[0]]
    subtracted_logit = last_position[0, trg[1]]
    return kept_logit - subtracted_logit

In [7]:
# Clean / patch prompt pair; they differ only in the second subject (John vs. Alice).
clean = "When Mary and John went to the store, John gave a drink to"
patch = "When Mary and John went to the store, Alice gave a drink to"

# Clean target: the first token id the tokenizer produces for " Mary".
trg = " Mary"
first_trg_token = pythia70m.tokenizer.encode(trg)[0]
trg_idx = t.tensor(first_trg_token, device=DEVICE)

# Patch target: hard-coded token id 253.
patch_trg_idx = t.tensor([253], device=DEVICE) # the

In [8]:
# Extra arguments for the metric (presumably forwarded by get_circuit to
# metric_fn_v1 as keyword arguments): (clean target id, patch target id).
metric_kwargs = {"trg": (trg_idx, patch_trg_idx)}

circuit = get_circuit(
    clean,
    patch,
    model,
    pythia70m_embed,
    pythia70m_attns,
    pythia70m_mlps,
    pythia70m_resids,
    dictionaries,
    metric_fn_v1,
    metric_kwargs,
)

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [9]:
# Run the patch prompt through the model and keep its output.
with model.trace(patch):
    logits = model.output.save()

# Logits at the final sequence position.
final_logits = logits[0][:, -1, :]
print(final_logits.shape)
logits = final_logits.squeeze()

# Token ids sorted by descending logit, plus the softmax over the logits.
perm = t.argsort(logits, descending=True)
p = t.softmax(logits, dim=-1)

# Ten highest-scoring tokens: probability, token id, decoded text.
for rank in range(10):
    token_id = perm[rank].item()
    print(p[token_id].item())
    print(token_id)
    print(pythia70m.tokenizer.decode([token_id]))

torch.Size([1, 50304])
0.1829976737499237
253
 the
0.07211925834417343
617
 her
0.05171777307987213
731
 them
0.03711021691560745
6393
 Mary
0.03526845946907997
779
 him
0.03157972916960716
16922
 Alice
0.0294966921210289
247
 a
0.011895938776433468
5332
 Jack
0.011458381079137325
616
 their
0.008285464718937874
521
 his


In [10]:
# use mean ablation
def ablation_fn(x):
    """Replace every entry along dim 0 by the mean over dim 0, keeping the shape of `x`."""
    mean_over_first_dim = x.mean(dim=0)
    return mean_over_first_dim.expand_as(x)

In [18]:
# get m(C) for the circuit obtained by thresholding nodes with the given threshold
def get_fcs(
    model,
    circuit,
    submodules,
    dictionaries,
    ablation_fn,
    thresholds,
    handle_errors = 'default', # also 'remove' or 'resid_only'
):
    """
    Evaluate faithfulness and completeness of the circuit at several node thresholds.

    The prompts and answer token ids are NOT parameters: they are read from the
    notebook globals `clean`, `patch`, `trg_idx` and `patch_trg_idx`. The globals
    `dict_size`, `device` and `submod_names` are used as well.

    Args:
        model: model to trace; must expose `embed_out` and `gpt_neox.layers`.
        circuit: indexable whose element 0 maps submodule names (the values of
            `submod_names`) to node effects that support `.abs() > threshold`.
        submodules: submodules handed to `run_with_ablations`; also the keys of
            the node masks built here.
        dictionaries: passed through to `run_with_ablations`.
        ablation_fn: passed through to `run_with_ablations`.
        thresholds: iterable of thresholds; each one becomes a key of the result.
        handle_errors: 'default' keeps the thresholded `resc` entries as they are,
            'remove' sets every `resc` mask to False, 'resid_only' sets `resc` to
            False for every submodule that is not in `model.gpt_neox.layers`.

    Returns:
        dict with 'fm' (metric of the unablated model), 'fempty' (metric with
        all-False node masks) and, per threshold, a dict with 'n_nodes', 'fc',
        'fccomp', 'faithfulness' and 'completeness'.

    Raises:
        ValueError: if `handle_errors` is not one of the three supported modes.
    """
    # Fail fast on a typo: an unknown mode used to fall through silently and
    # behave like 'default', producing mislabelled results.
    if handle_errors not in ('default', 'remove', 'resid_only'):
        raise ValueError(
            f"handle_errors must be 'default', 'remove' or 'resid_only', got {handle_errors!r}"
        )

    # Prompts and answer ids come from notebook globals.
    clean_inputs = clean
    clean_answer_idxs = trg_idx
    patch_inputs = patch
    patch_answer_idxs = patch_trg_idx

    def metric_fn(model):
        # Last-position logit of the clean answer minus that of the patch answer.
        return (
            - t.gather(model.embed_out.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) + \
            t.gather(model.embed_out.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
        )

    def ablated_metric(nodes, **kwargs):
        # Single call site for run_with_ablations; returns the mean metric as a float.
        return run_with_ablations(
            clean_inputs,
            patch_inputs,
            model,
            submodules,
            dictionaries,
            nodes=nodes,
            metric_fn=metric_fn,
            ablation_fn=ablation_fn,
            **kwargs,
        ).mean().item()

    # Element 0 holds the per-submodule node effects.
    circuit = circuit[0]

    with t.no_grad():
        out = {}

        # get F(M)
        with model.trace(clean_inputs):
            metric = metric_fn(model).save()
        fm = metric.value.mean().item()

        out['fm'] = fm

        # get m(∅): every node mask is all-False
        fempty = ablated_metric({
            submod : SparseAct(
                act=t.zeros(dict_size, dtype=t.bool),
                resc=t.zeros(1, dtype=t.bool)).to(device)
            for submod in submodules
        })
        out['fempty'] = fempty

        # Shared normaliser of both scores.
        full_gap = fm - fempty

        for threshold in thresholds:
            out[threshold] = {}
            # Boolean node masks: effect magnitude strictly above the threshold.
            nodes = {
                submod : circuit[submod_names[submod]].abs() > threshold for submod in submodules
            }

            if handle_errors == 'remove':
                for k in nodes: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)
            elif handle_errors == 'resid_only':
                for k in nodes:
                    if k not in model.gpt_neox.layers: nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)

            # Circuit size: number of True entries across all `act` and `resc` masks.
            n_nodes = sum([n.act.sum() + n.resc.sum() for n in nodes.values()]).item()
            out[threshold]['n_nodes'] = n_nodes

            out[threshold]['fc'] = ablated_metric(nodes)
            out[threshold]['fccomp'] = ablated_metric(nodes, complement=True)
            out[threshold]['faithfulness'] = (out[threshold]['fc'] - fempty) / full_gap
            out[threshold]['completeness'] = (out[threshold]['fccomp'] - fempty) / full_gap

    return out

In [26]:
#thresholds = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032, 0.064, 0.128, 0.256, 0.512]
# 15 log-spaced thresholds from 1e-4 to 1.
thresholds = t.logspace(-4, 0, steps=15, base=10).tolist()

# Result label -> how get_fcs treats the error (`resc`) nodes.
error_modes = {
    'features' : 'default',
    'features_wo_errs' : 'remove',
    'features_wo_some_errs' : 'resid_only',
}

outs = {
    setting : get_fcs(
        model,
        circuit,
        submodules,
        feat_dicts,
        ablation_fn=ablation_fn,
        thresholds=thresholds,
        handle_errors=mode,
    )
    for setting, mode in error_modes.items()
}

In [25]:
# plot faithfulness results
fig = go.Figure()

colors = dict(
    features='blue',
    features_wo_errs='red',
    features_wo_some_errs='green',
    neurons='purple',
    # random_features='black'
)

for setting, subouts in outs.items():
    # Circuit size and faithfulness at each threshold, in threshold order.
    node_counts = [subouts[thr]['n_nodes'] for thr in thresholds]
    faithfulness = [subouts[thr]['faithfulness'] for thr in thresholds]

    # Stay one node inside the observed range so interpolation never extrapolates.
    x_min = min(node_counts) + 1
    x_max = max(node_counts) - 1
    fs = {"ioi" : interpolate.interp1d(node_counts, faithfulness)}
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()

    # Faint line through the raw measurements.
    raw_trace = go.Scatter(
        x=node_counts,
        y=faithfulness,
        mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
    )
    fig.add_trace(raw_trace)

    # Interpolated curve (mean over the interpolators in `fs`), shown in the legend.
    mean_curve = [ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ]
    fig.add_trace(go.Scatter(
        x=xs,
        y=mean_curve,
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

fig.update_xaxes(range=(0, 1700))
fig.update_yaxes(range=(0, 1.1))

axis_style = dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True)
fig.update_layout(
    xaxis_title='Nodes',
    yaxis_title='Faithfulness',
    width=800,
    height=375,
    # transparent plot background
    plot_bgcolor='rgba(0,0,0,0)',
    # grey gridlines on both axes
    yaxis=dict(axis_style),
    xaxis=dict(axis_style),
)

# fig.show()
fig.write_image('faithfulness.pdf')

In [None]:
# plot completeness results
fig = go.Figure()

colors = {
    'features' : 'blue',
    'features_wo_errs' : 'red',
    'features_wo_some_errs' : 'green',
    'neurons' : 'purple'
}

for setting, subouts in outs.items():
    # `outs[setting]` is keyed directly by threshold (there is no per-dataset
    # level), so read the per-threshold results straight from `subouts`.
    node_counts = [subouts[thr]['n_nodes'] for thr in thresholds]
    completeness = [subouts[thr]['completeness'] for thr in thresholds]

    # Stay one node inside the observed range so interpolation never extrapolates.
    x_min = min(node_counts) + 1
    x_max = max(node_counts) - 1
    interp = interpolate.interp1d(node_counts, completeness)
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()

    # Faint line through the raw measurements.
    fig.add_trace(go.Scatter(
        x = node_counts,
        y = completeness,
        mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
    ))
    # Interpolated curve, shown in the legend.
    fig.add_trace(go.Scatter(
        x=xs,
        y=[ float(interp(x)) for x in xs ],
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

fig.update_xaxes(range=(0,300))
fig.update_yaxes(range=(-.15, 1))

fig.update_layout(
    xaxis_title='Nodes',
    # this figure shows completeness, not faithfulness
    yaxis_title='Completeness',
    width=800,
    height=375,
    # set white background color
    plot_bgcolor='rgba(0,0,0,0)',
    # add grey gridlines
    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),
    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),
)
# fig.show()
fig.write_image('completeness.pdf')