In [4]:
import os
import sys
sys.path.append('Automatic-Circuit-Discovery/')
sys.path.append('utils/')
import re

import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

# ACDCpp helpers
# from utils.prune_utils import get_nodes, acdc_nodes
# from utils.graphics_utils import show


device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

Device: cpu


# Model Setup

In [5]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Dataset Setup

In [6]:
from acdc.greaterthan.utils import get_year_data
N=25

dataset = get_year_data(N, model)
dataset[1]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cpu


['The demonstrations lasted from the year 1145 to 11',
 'The assaults lasted from the year 1684 to 16',
 'The affair lasted from the year 1222 to 12',
 'The stature lasted from the year 1784 to 17',
 'The effort lasted from the year 1630 to 16',
 'The experiments lasted from the year 1477 to 14',
 'The employment lasted from the year 1286 to 12',
 'The hostility lasted from the year 1645 to 16',
 'The collaboration lasted from the year 1557 to 15',
 'The competition lasted from the year 1454 to 14',
 'The plan lasted from the year 1651 to 16',
 'The negotiation lasted from the year 1247 to 12',
 'The endeavor lasted from the year 1739 to 17',
 'The expansion lasted from the year 1733 to 17',
 'The relationship lasted from the year 1184 to 11',
 'The challenge lasted from the year 1136 to 11',
 'The evaluation lasted from the year 1780 to 17',
 'The tests lasted from the year 1429 to 14',
 'The slump lasted from the year 1214 to 12',
 'The existence lasted from the year 1181 to 11',
 'T

# Metric Setup

In [7]:
from acdc.greaterthan.utils import greaterthan_metric

In [8]:
with t.no_grad():
    sample_prompt_tokens = dataset[0][0].unsqueeze(0)
    logits = model(sample_prompt_tokens)
    print(greaterthan_metric(logits, sample_prompt_tokens))

tensor(-0.7149)


# ACDC++

In [11]:
from time import time
THRESHOLDS = [0.005, 0.01, 0.015, 0.02, 0.025]
run_name = 'greaterthan_thresh_run'
pruned_nodes_per_thresh = {}
num_forward_passes_per_thresh = {}
heads_per_thresh = {}
os.makedirs(f'./greaterthan_ims/{run_name}', exist_ok=True)
for threshold in THRESHOLDS:
    start_thresh_time = time()
    # Set up model
    # Set up experiment
    exp = TLACDCExperiment(
        model=model,
        threshold=threshold,
        run_name=run_name,
        ds=dataset[0],
        ref_ds=None,
        metric=greaterthan_metric,
        zero_ablation=True,
        hook_verbose=False
    )
    print('Setting up graph')
    # Set up computational graph
    exp.model.reset_hooks()
    exp.setup_model_hooks(
        add_sender_hooks=True,
        add_receiver_hooks=True,
        doing_acdc_runs=False,
    )
    exp_time = time()
    print(f'Time to set up exp: {exp_time - start_thresh_time}')
    for _ in range(10):
        pruned_nodes_attr = acdc_nodes(
            model=exp.model,
            clean_input=dataset[0],
            corrupted_input=dataset[0],
            metric=greaterthan_metric,
            threshold=threshold,
            exp=exp,
            attr_absolute_val=True,
        ) 
        t.cuda.empty_cache()
    acdcpp_time = time()
    print(f'ACDC++ time: {acdcpp_time - exp_time}')
    heads_per_thresh[threshold] = [get_nodes(exp.corr)]
    pruned_nodes_per_thresh[threshold] = pruned_nodes_attr
    show(exp.corr, fname=f'ims/{run_name}/thresh{threshold}_before_acdc.png')
    



ln_final.hook_normalized
ln_final.hook_scale
blocks.11.hook_resid_post
blocks.11.hook_mlp_out
blocks.11.mlp.hook_post
blocks.11.mlp.hook_pre
blocks.11.ln2.hook_normalized
blocks.11.ln2.hook_scale
blocks.11.hook_mlp_in
blocks.11.hook_resid_mid
blocks.11.hook_attn_out
blocks.11.attn.hook_result
blocks.11.attn.hook_z
blocks.11.attn.hook_pattern
blocks.11.attn.hook_attn_scores
blocks.11.attn.hook_v
blocks.11.attn.hook_k
blocks.11.attn.hook_q
blocks.11.ln1.hook_normalized
blocks.11.ln1.hook_scale
blocks.11.hook_v_input
blocks.11.hook_k_input
blocks.11.hook_q_input
blocks.11.hook_resid_pre
blocks.10.hook_resid_post
blocks.10.hook_mlp_out
blocks.10.mlp.hook_post
blocks.10.mlp.hook_pre
blocks.10.ln2.hook_normalized
blocks.10.ln2.hook_scale
blocks.10.hook_mlp_in
blocks.10.hook_resid_mid
blocks.10.hook_attn_out
blocks.10.attn.hook_result
blocks.10.attn.hook_z
blocks.10.attn.hook_pattern
blocks.10.attn.hook_attn_scores
blocks.10.attn.hook_v
blocks.10.attn.hook_k
blocks.10.attn.hook_q
blocks.10.ln

TypeError: greaterthan_metric() missing 1 required positional argument: 'tokens'

In [12]:
# Next: add ref ds *from 1900*