In [1]:
# Load IPython's autoreload extension. Mode 2 re-imports modules automatically
# before each cell runs, so edits to imported project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
import json
import torch

def get_results_for_case(case_name, dir="./interp_results", tuned_lens=False):
    """Load the saved lens results for one case from its JSON file.

    The file is read from
    ``{dir}/{case_name}/{lens}/{lens}_results.json`` where ``lens`` is
    ``"tuned_lens"`` when ``tuned_lens`` is true and ``"logit_lens"`` otherwise.

    Args:
        case_name: Name of the case; used as a path component.
        dir: Root directory holding the results. (The name shadows the
            ``dir`` builtin but is kept so keyword callers keep working.)
        tuned_lens: Selects the tuned-lens results instead of the logit-lens
            results.

    Returns:
        A ``(labels, lens_results)`` tuple. ``labels`` is the tensor built from
        the file's ``"labels"`` entry; ``lens_results`` maps every other
        top-level key of the file to a tensor built from its value.

    Raises:
        FileNotFoundError: If the results file does not exist.
        KeyError: If the file has no ``"labels"`` entry.
    """
    lens_str = "tuned_lens" if tuned_lens else "logit_lens"
    file_path = f"{dir}/{case_name}/{lens_str}/{lens_str}_results.json"

    # Only the parsing needs the file handle; close it before building tensors.
    with open(file_path, "r", encoding="utf-8") as f:
        results = json.load(f)

    if "labels" not in results:
        raise KeyError(f"'labels' entry missing from {file_path}")

    labels = torch.tensor(results["labels"])
    # Skip "labels" here instead of converting it a second time and deleting it.
    lens_results = {
        name: torch.tensor(values)
        for name, values in results.items()
        if name != "labels"
    }
    return labels, lens_results

In [3]:
from interp_utils.lens.plot_utils import *
from circuits_benchmark.utils.get_cases import get_cases, get_names_of_working_cases

# Case to analyse: the same string is used to look up the case object and to
# locate its saved lens results on disk.
case_name = "20"
case = get_cases(indices=[case_name])[0]
# False -> load the "logit_lens" results; True -> load the "tuned_lens" results.
tuned_lens = False

# Load the labels and the lens results for this case. Note the second variable
# keeps the name `logit_lens_results` even when tuned_lens is True.
labels, logit_lens_results = get_results_for_case(case_name, tuned_lens=tuned_lens)

In [11]:
from sklearn import metrics
# Pull one entry of the lens results and the labels out as NumPy arrays.
# NOTE(review): the saved keys() output in this notebook lists 'L1H2' without the
# '(IC)' suffix — confirm this key exists in the results currently on disk.
x = logit_lens_results["L1H2(IC)"].detach().cpu().numpy().squeeze()
y = labels.detach().cpu().numpy().squeeze()

# Explained variance at the last index of the second axis, with the labels
# passed as y_true and the lens output as y_pred.
# Assumes both arrays are 2-D after squeeze() — TODO confirm.
metrics.explained_variance_score(y[:, -1], x[:, -1])

0.9997087717056274

In [4]:
# Produce the combined explained-variance plot for the loaded lens results.
# The saved cell output is a file path, so the helper presumably writes the
# figure to disk and returns where it was saved — confirm against plot_utils.
# NOTE(review): nodes_in_circuit is passed as an empty list — confirm that is intended.
plot_explained_variance_combined(
    labels=labels,
    lens_results=logit_lens_results,
    is_categorical=case.is_categorical(),
    nodes_in_circuit= [],
    abs_corr=False,
    tuned_lens=tuned_lens,
    case_name=case_name,
    show=True
)

'./interp_results/20/logit_lens/combined_variance_explained.png'

In [5]:
def metric(y, x):
    """Return the mean absolute error between ``x`` and ``y``.

    Args:
        y: Reference values (array-like supporting ``x - y``).
        x: Values to compare against ``y``; must broadcast with ``y``.

    Returns:
        The mean of ``abs(x - y)`` taken over all elements.
    """
    # Imported locally so the function does not rely on ``np`` leaking into the
    # notebook namespace through a wildcard import.
    import numpy as np

    return np.mean(abs(x - y))

# Plot the lens results scored with the custom `metric` callable (mean absolute
# error), labelled "Mean Absolute Error".
# NOTE(review): abs_corr=True here but False in the explained-variance cell —
# confirm the difference is intentional.
plot_metric(
    labels=labels,
    lens_results=logit_lens_results,
    metric=metric,
    metric_label="Mean Absolute Error",
    is_categorical=case.is_categorical(),
    nodes_in_circuit= [],
    abs_corr=True,
    tuned_lens=tuned_lens,
    case_name=case_name,
    show=True
)

'./interp_results/20/logit_lens/combined_mean_absolute_error.png'

In [None]:
# Show which entries are available in the loaded lens results.
logit_lens_results.keys()

dict_keys(['embed', 'pos_embed', '0_mlp_out(IC)', '1_mlp_out', 'L0H0', 'L0H1', 'L0H2', 'L0H3', 'L1H0(IC)', 'L1H1', 'L1H2', 'L1H3'])

In [13]:
# Entry of the lens results to plot; in_circuit is derived from whether the key
# contains the substring "IC".
# NOTE(review): the saved keys() output in this notebook lists 'L1H2' without the
# '(IC)' suffix — confirm this key exists in the results currently on disk.
key = "L1H2(IC)"
plot_pearson(
    key=key,
    in_circuit= "IC" in key,
    lens_results=logit_lens_results,
    labels=labels,
    # NOTE(review): the other plotting cells pass `case_name` directly rather
    # than case.get_name() — confirm which value plot_pearson expects.
    case_name=case.get_name(),
    show=True,
    tuned_lens=tuned_lens,
    is_categorical=case.is_categorical()
)

'./interp_results//3/logit_lens/L1H2(IC)/pearson.png'