In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
import torch as t
from dictionary_loading_utils import load_saes_and_submodules
from activation_utils import SparseAct
from tqdm import tqdm

In [None]:
def run_with_ablations(
    inputs,
    model: LanguageModel,
    submodules,
    dictionaries,
    nodes: dict,
    metric_fn,
    complement=False,
):
    """Trace `model` on `inputs` with selected SAE nodes ablated, then return a metric.

    For every submodule, the activation is encoded with its dictionary and the
    dictionary's error term (activation minus reconstruction) is computed from
    the unmodified latents. The chosen latents are then zeroed, and the error
    term is zeroed when the `resc` mask selects it. Finally the patched
    reconstruction plus the error term is written back.

    Args:
        inputs: Inputs passed to `model.trace`.
        model: nnsight model to trace.
        submodules: Submodules to patch, in execution order; each exposes
            `get_activation()` and `set_activation()`.
        dictionaries: Maps each submodule to its dictionary (`encode`/`decode`).
        nodes: Maps each submodule to a mask with `.act` (latent mask) and
            `.resc` (error-term mask) fields marking the nodes to KEEP
            (boolean masks in `get_fcs`).
        metric_fn: Called on `model.output` inside the trace; the result is saved.
        complement: False -> ablate every node not in `nodes`;
            True -> ablate exactly the nodes in `nodes`.

    Returns:
        The saved value of `metric_fn(model.output)`.
    """
    with model.trace(inputs), t.no_grad():
        for submodule in submodules:
            sae = dictionaries[submodule]
            keep_mask = nodes[submodule]
            # complement=True ablates the circuit itself; otherwise everything outside it.
            ablate_mask = keep_mask if complement else ~keep_mask

            acts = submodule.get_activation()
            latents = sae.encode(acts)
            # Error term is taken from the latents BEFORE any of them are zeroed.
            error = acts - sae.decode(latents)

            latents[..., ablate_mask.act] = 0.0
            error = error * (1.0 - ablate_mask.resc.float().to(error.dtype))

            submodule.set_activation(sae.decode(latents) + error)

        saved_metric = metric_fn(model.output).save()

    return saved_metric.value

In [None]:
model_name = "google/gemma-2-2b"
device = 'cpu'
dtype = t.bfloat16

# nnsight wrapper around the HF model, using eager attention and the dtype/device above.
model = LanguageModel(
    model_name,
    attn_implementation="eager",
    torch_dtype=dtype,
    device_map=device,
    dispatch=True,
)

# SAE dictionaries and the submodules they attach to
# (include_embed=False presumably skips the embedding-layer SAE).
submodules, dictionaries = load_saes_and_submodules(
    model,
    include_embed=False,
    dtype=dtype,
    device=device,
)

In [None]:
# Input prompt and the next-token continuation whose logit the metric scores.
prompt, answer = "The doctors that the assistant follows", " go"

In [None]:
# get_fcs scores only the LAST token id of `answer`, so check that it is a single token.
answer_encoding = model.tokenizer(answer, add_special_tokens=False)
assert len(answer_encoding.input_ids) == 1, (
    f"answer {answer!r} spans multiple tokens: {answer_encoding.input_ids}"
)
answer_encoding

In [None]:
# 50 log-spaced thresholds on |node score|, from 1e-4 to 1 inclusive, swept by get_fcs.
thresholds = t.logspace(-4, 0, 50).tolist()

In [None]:
def get_fcs(
    prompt,
    answer,
    model,
    submodules,
    dictionaries,
    thresholds,
    circuit_path="../circuits/gemma-2-2b_rc_train_n100_aggnone_nodeall.pt",
):
    """Measure faithfulness and completeness of a saved circuit across node thresholds.

    The metric is the logit of `answer`'s last token id at the final position,
    minus the mean of the top-10 logits at that position. It is evaluated in
    four settings:
      - on the unpatched model (``fm``);
      - with every SAE latent and error term ablated (``fempty``);
      - for each threshold, keeping only the circuit (``fc``);
      - for each threshold, ablating only the circuit (``fccomp``).
    A node belongs to the circuit when the absolute value of its saved score
    exceeds the threshold.

    Args:
        prompt: Single input string.
        answer: Target continuation; only its last token id is scored.
        model: nnsight language model.
        submodules: Submodules to patch; each ``submod.name`` must be a key of
            the circuit's ``'nodes'`` dict.
        dictionaries: Maps each submodule to its SAE dictionary.
        thresholds: Iterable of float thresholds on absolute node scores.
        circuit_path: Saved circuit file (defaults to the previously
            hard-coded Gemma-2-2B circuit).

    Returns:
        dict with ``"fm"`` and ``"fempty"`` floats, plus one entry per threshold
        mapping to a dict with ``"n_nodes"``, ``"fc"``, ``"fccomp"``,
        ``"faithfulness"`` and ``"completeness"``. The two ratios are
        ``(x - fempty) / (fm - fempty)``, or NaN when ``fm == fempty``.

    Note:
        Tensors are placed on the notebook-global ``device``.
    """
    # weights_only=False unpickles arbitrary Python objects: only load circuit files you trust.
    circuit = t.load(
        circuit_path,
        map_location=t.device('cpu'),
        weights_only=False,
    )['nodes']
    # Move each submodule's node scores to `device` once (hoisted out of the threshold
    # loop). The threshold masks then share the activations' device, as fempty_nodes does.
    circuit_nodes = {submod: circuit[submod.name].to(device) for submod in submodules}

    inputs = [prompt]
    answer_idx = t.tensor(
        [model.tokenizer(answer).input_ids[-1]],
        dtype=t.long,
        device=device,
    )

    def metric_fn(model_output):
        # Batch of one: logits at the final position. Indexing [0, -1] (rather than
        # squeeze()[-1]) stays correct even when the prompt is a single token.
        logits = model_output.logits[0, -1]
        answer_logit = logits[answer_idx]
        mean_top_10 = t.topk(logits, 10).values.mean()
        return answer_logit - mean_top_10

    out = {}

    with t.no_grad():
        with model.trace(inputs):
            metric = metric_fn(model.output).save()
        fm = metric.value.item()
        out["fm"] = fm

        # Keep-nothing masks: every latent and every error term gets ablated.
        fempty_nodes = {
            submod: SparseAct(
                act=t.zeros(dictionaries[submod].dict_size, dtype=t.bool),
                resc=t.zeros(1, dtype=t.bool),
            ).to(device)
            for submod in submodules
        }
        fempty = run_with_ablations(
            inputs,
            model,
            submodules,
            dictionaries,
            nodes=fempty_nodes,
            metric_fn=metric_fn,
        ).item()
        out["fempty"] = fempty

        gap = fm - fempty

        def normalize(score):
            # Fraction of the (full model - empty circuit) gap recovered; NaN if there is no gap.
            return (score - fempty) / gap if gap != 0 else float("nan")

        for threshold in tqdm(thresholds, desc="Thresholds"):
            nodes = {
                submod: scores.abs() > threshold
                for submod, scores in circuit_nodes.items()
            }

            # int() accepts both a one-element tensor and the plain 0 of an empty sum.
            n_nodes = int(sum(n.act.sum() + n.resc.sum() for n in nodes.values()))

            fc = run_with_ablations(
                inputs,
                model,
                submodules,
                dictionaries,
                nodes=nodes,
                metric_fn=metric_fn,
            ).item()

            fccomp = run_with_ablations(
                inputs,
                model,
                submodules,
                dictionaries,
                nodes=nodes,
                metric_fn=metric_fn,
                complement=True,
            ).item()

            out[threshold] = {
                "n_nodes": n_nodes,
                "fc": fc,
                "fccomp": fccomp,
                "faithfulness": normalize(fc),
                "completeness": normalize(fccomp),
            }

    return out

In [None]:
# Sweep all thresholds for the chosen prompt/answer pair.
out = get_fcs(
    prompt=prompt,
    answer=answer,
    model=model,
    submodules=submodules,
    dictionaries=dictionaries,
    thresholds=thresholds,
)