In [None]:
import os
import sys
sys.path.append('../Automatic-Circuit-Discovery/')
sys.path.append('..')
import re

import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
from acdc.greaterthan.utils import get_all_greaterthan_things

from ACDCPPExperiment import ACDCPPExperiment

import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

# Use the GPU when PyTorch reports one as available; otherwise run on the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

# Model Setup

In [None]:
# Load GPT-2 small with the three weight-processing options passed to
# `from_pretrained` (writing-weight centering, unembed centering, LayerNorm
# folding) all switched off, placing the model on the selected device.
pretrained_options = dict(
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model = HookedTransformer.from_pretrained('gpt2-small', **pretrained_options)

# Switch on the MLP-input, split-QKV-input and attention-result options.
# NOTE(review): presumably these expose the hook points ACDC/ACDC++ needs to
# patch individual edges — confirm against the TransformerLens docs.
for enable_option in (
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
    model.set_use_attn_result,
):
    enable_option(True)

# Dataset Setup

In [None]:
# Build the clean dataset, the reference (corrupted) dataset and the metric
# for the greater-than task.
N = 25
# Upper bound on how many prompts from each dataset are decoded and printed.
NUM_PREVIEW = 5

things = get_all_greaterthan_things(
    num_examples=N, metric_name="greaterthan", device=device
)
greaterthan_metric = things.validation_metric
toks_int_values = things.validation_data # clean data x_i
toks_int_values_other = things.validation_patch_data # corrupted data x_i'

# Preview the first few prompts of each dataset. Slicing, rather than indexing
# with a fixed range(5), keeps this cell from raising an IndexError when N is
# set below NUM_PREVIEW; for N >= NUM_PREVIEW the printed output is unchanged.
# NOTE(review): assumes both datasets support slicing (e.g. tensors or lists) —
# confirm against get_all_greaterthan_things.
for title, dataset_toks in (
    ("Clean dataset samples", toks_int_values),
    ("Reference dataset samples", toks_int_values_other),
):
    print(f"\n{title}")
    for prompt_toks in dataset_toks[:NUM_PREVIEW]:
        print(model.tokenizer.decode(prompt_toks))

# Run Experiment

In [None]:
# Name of this run (reused when naming the saved result files) and the list of
# thresholds handed to the experiment.
RUN_NAME = 'greaterthan_first_pass'
THRESHOLDS = [0.01585]

# The greater-than metric is deliberately supplied twice (4th and 5th
# positional arguments).
# NOTE(review): presumably one slot is the ACDC metric and the other the ACDC++
# attribution metric — confirm against ACDCPPExperiment's signature.
# NOTE(review): save_graphs_after (0.07) is larger than the only threshold in
# THRESHOLDS — verify whether graphs are expected to be saved for this run.
acdcpp_exp = ACDCPPExperiment(
    model,
    toks_int_values,
    toks_int_values_other,
    greaterthan_metric,
    greaterthan_metric,
    THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=0.07,
)

# run() returns three results, which the save step later indexes by threshold.
experiment_results = acdcpp_exp.run()
pruned_heads, num_passes, pruned_attrs = experiment_results

# Save Data

In [None]:
import json

# Convert the first two entries stored under each threshold to lists, in place.
# NOTE(review): presumably these entries are sets of pruned heads, which json
# cannot serialize directly — confirm against ACDCPPExperiment.run().
for head_groups in pruned_heads.values():
    head_groups[0] = list(head_groups[0])
    head_groups[1] = list(head_groups[1])

# Re-key the attributions from (layer, head) tuples to 'L<layer>H<head>'
# strings, since tuples are not valid JSON object keys.
cleaned_attrs = {
    thresh: {
        f'L{layer}H{head}': attr
        for (layer, head), attr in head_attrs.items()
    }
    for thresh, head_attrs in pruned_attrs.items()
}

# Write each result to its own JSON file, prefixed with the run name.
json_outputs = (
    (f'{RUN_NAME}_pruned_heads.json', pruned_heads),
    (f'{RUN_NAME}_num_passes.json', num_passes),
    (f'{RUN_NAME}_pruned_attrs.json', cleaned_attrs),
)
for out_path, payload in json_outputs:
    with open(out_path, 'w') as f:
        json.dump(payload, f)