In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
import torch as t
from dictionary_loading_utils import load_gemma_transcoders_and_submodules
from activation_utils import SparseAct
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "google/gemma-2-2b"
device = 'cpu'
dtype = t.bfloat16

model = LanguageModel(model_name, attn_implementation="eager", torch_dtype=dtype, device_map=device, dispatch=True)

submodules, dictionaries = load_gemma_transcoders_and_submodules(model, dtype=dtype, device=device)

Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.42s/it]
Fetching 26 files: 100%|██████████| 26/26 [00:00<00:00, 37.42it/s]
Loading Gemma transcoders:  96%|█████████▋| 26/27 [00:00<?, ?it/s]


In [3]:
# FOR NOW, LOAD JUST THE COMPLETE GRAPH
# IN THE FUTURE WE WILL HOPEFULLY LOAD A LIST OF GRAPHS PRUNED AT DIFFERENT LEVELS  
from pathlib import Path

graphs_dir = Path('../attribution_graphs')
graphs_name = 'example_graph.pt'
graphs_path = graphs_dir / graphs_name

graphs = t.load(graphs_path, weights_only=False)

# IN THE FUTURE ALSO LOAD THE THRESHOLD VALUES CORRESPONDING TO THE GRAPHS
thresholds = t.logspace(-4, 0, 1).tolist()

In [4]:
# WILL BE REMOVED IF GRAPHS ARE SAVED IN CORRECT FORMAT 
d_transcoder = 16384

def prepare_graph(graph):
    nodes = {}

    n_layers = graph['cfg'].n_layers
    n_pos = len(graph['input_tokens'])
    for layer_idx in range(n_layers):
        nodes[submodules[layer_idx]] = SparseAct(
            act=t.zeros((n_pos, d_transcoder), dtype=t.bool),
            resc=t.ones((n_pos, 1), dtype=t.bool)
        )

    all_active_features = graph['active_features']
    selected_feature_indices = graph['selected_features']
    selected_features = all_active_features[selected_feature_indices]
    for feature in selected_features:
        layer_idx, pos, feat_id = (
            feature[0].item(),
            feature[1].item(),
            feature[2].item(),
        )
        nodes[submodules[layer_idx]].act[pos, feat_id] = True

    return nodes

graphs = [prepare_graph(graphs)]

In [5]:
def run_with_ablations(
    inputs,
    model: LanguageModel,
    submodules,
    dictionaries,
    nodes: dict,
    metric_fn,
    complement=False,
):
    with model.trace(inputs), t.no_grad():
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            nodes_to_keep = nodes[submodule]

            x = submodule[0].get_activation()
            f = dictionary.encode(x)
            res = submodule[1].get_activation() - dictionary.decode(f)

            if complement:
                nodes_to_ablate = nodes_to_keep
            else:
                nodes_to_ablate = ~nodes_to_keep

            f[..., nodes_to_ablate.act] = 0.0

            res_multiplier = (1.0 - nodes_to_ablate.resc.float().to(res.dtype))
            res = res * res_multiplier

            submodule[1].set_activation(dictionary.decode(f) + res)

        metric = metric_fn(model.output).save()

    return metric.value

In [6]:
prompt = "The doctors that the assistant follows"
answer = " go"

In [7]:
def get_fcs(
    prompt,
    answer,
    model,
    submodules,
    dictionaries,
    thresholds,
):
    inputs = [prompt]
    answer_idx = t.tensor(
        [model.tokenizer(answer).input_ids[-1]],
        dtype=t.long,
        device=device
    )

    def metric_fn(model_output):
        logits = model_output.logits.squeeze()[-1]
        answer_logit = logits[answer_idx]
        mean_top_10 = t.topk(logits, 10).values.mean()
        return answer_logit - mean_top_10

    out = {}

    with t.no_grad():
        with model.trace(inputs):
            metric = metric_fn(model.output).save()
        fm = metric.value.item()
        out["fm"] = fm

        fempty_nodes = {
            submod : SparseAct(
                act=t.zeros(dictionaries[submod].d_transcoder, dtype=t.bool),
                resc=t.zeros(1, dtype=t.bool)
            ).to(device)
            for submod in submodules
        }
        fempty = run_with_ablations(
            inputs,
            model,
            submodules,
            dictionaries,
            nodes=fempty_nodes,
            metric_fn=metric_fn,
        ).item()
        out["fempty"] = fempty

        for idx in tqdm(range(len(thresholds)), desc=f"Thresholds"):
            threshold = thresholds[idx]
            nodes = graphs[idx]

            out[threshold] = {}

            n_nodes = sum(
                [
                    n.act.sum() + n.resc.sum() for n in nodes.values()
                ]
            ).item()
            out[threshold]['n_nodes'] = n_nodes

            out[threshold]["fc"] = run_with_ablations(
                inputs,
                model,
                submodules,
                dictionaries,
                nodes=nodes,
                metric_fn=metric_fn,
            ).item()

            out[threshold]["fccomp"] = run_with_ablations(
                inputs,
                model,
                submodules,
                dictionaries,
                nodes=nodes,
                metric_fn=metric_fn,
                complement=True
            ).item()

            out[threshold]["faithfulness"] = (
                out[threshold]["fc"] - out["fempty"]
            ) / (out["fm"] - out["fempty"])
            
            out[threshold]["completeness"] = (
                out[threshold]["fccomp"] - out["fempty"]
            ) / (out["fm"] - out["fempty"])

    return out

In [8]:
out = get_fcs(
    prompt,
    answer,
    model,
    submodules,
    dictionaries,
    thresholds,
)

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Thresholds:   0%|          | 0/1 [00:00<?, ?it/s]


IndexError: Above exception when execution Node: 'getitem_0' in Graph: '2397136607392'