In [9]:
from IPython import get_ipython

# Turn on autoreload (mode 2) when running under IPython/Jupyter.
# `get_ipython()` returns None outside an IPython session, in which case this is skipped.
ipython = get_ipython()
if ipython is not None:
    # `InteractiveShell.magic("%name args")` is deprecated and emits a warning on
    # every run; `run_line_magic(name, args)` is the supported equivalent.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

import os
import sys
sys.path.append('../Automatic-Circuit-Discovery/')
sys.path.append('..')
import torch
import re

import acdc
from utils.prune_utils import get_3_caches, split_layers_and_heads
from acdc.greaterthan.utils import get_all_greaterthan_things
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools
from functools import partial

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Device: cpu


  ipython.magic("%load_ext autoreload")
  ipython.magic("%autoreload 2")


# Model Setup

In [5]:
# Load GPT-2 small with centering of writing weights / unembed and LayerNorm folding
# all switched off.
pretrained_kwargs = dict(
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model = HookedTransformer.from_pretrained('gpt2-small', **pretrained_kwargs)

# Switch on the optional hook points, in the same order as before. Later cells read
# `blocks.{i}.hook_{q,k,v}_input` and call `stack_head_results()`, which presumably
# depend on these flags — confirm against the TransformerLens documentation.
for enable_hook in (
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
    model.set_use_attn_result,
):
    enable_hook(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Metric & Dataset Setup

In [12]:
# Build the greater-than task data: the validation metric, the clean prompts and
# their corrupted (reference) counterparts.
N = 25

things = get_all_greaterthan_things(num_examples=N, metric_name="greaterthan", device=device)
greaterthan_metric = things.validation_metric
clean_ds = things.validation_data  # clean data x_i
corr_ds = things.validation_patch_data  # corrupted data x_i'

# Decode the first five prompts of each dataset as a quick sanity check.
for heading, dataset in (
    ("\nClean dataset samples", clean_ds),
    ("\nReference dataset samples", corr_ds),
):
    print(heading)
    for sample_idx in range(5):
        print(model.tokenizer.decode(dataset[sample_idx]))

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cpu

Clean dataset samples
The demonstrations lasted from the year 1267 to 12
The assaults lasted from the year 1644 to 16
The affair lasted from the year 1268 to 12
The stature lasted from the year 1653 to 16
The effort lasted from the year 1318 to 13

Reference dataset samples
The demonstrations lasted from the year 1201 to 12
The assaults lasted from the year 1601 to 16
The affair lasted from the year 1201 to 12
The stature lasted from the year 1601 to 16
The effort lasted from the year 1301 to 13


# Run Experiment

In [13]:
# Run get_3_caches on the clean and corrupted datasets in "edge" mode and unpack the
# three caches it returns (per the original note: 2 forward caches and 1 backward
# cache, covering the "normalized" and "result" activations of attention layers).
three_caches = get_3_caches(
    model, clean_ds, corr_ds, metric=greaterthan_metric, mode="edge"
)
clean_cache, corrupted_cache, clean_grad_cache = three_caches

In [6]:
# Stack the per-head results of each forward cache and pass them through
# split_layers_and_heads; the clean cache is processed first, then the corrupted one.
clean_head_act, corr_head_act = (
    split_layers_and_heads(cache.stack_head_results(), model=model)
    for cache in (clean_cache, corrupted_cache)
)

In [7]:
# Gather the `hook_{q,k,v}_input` entries of clean_grad_cache into one tensor indexed
# as [letter (q/k/v), layer, head, batch, seq, d]. The trailing three sizes are taken
# from clean_head_act so both tensors line up per head.
n_batch, n_seq, d_act = clean_head_act.shape[-3:]
stacked_grad_act = torch.zeros(
    3, model.cfg.n_layers, model.cfg.n_heads, n_batch, n_seq, d_act
)

# Letters vary slowest and layers fastest, matching a nested letter/layer loop.
for (letter_idx, letter), layer_idx in itertools.product(
    enumerate("qkv"), range(model.cfg.n_layers)
):
    hook_name = f"blocks.{layer_idx}.hook_{letter}_input"
    # Move the head axis to the front so it matches the [head, batch, seq, d] slot.
    stacked_grad_act[letter_idx, layer_idx] = einops.rearrange(
        clean_grad_cache[hook_name], "batch seq n_heads d -> n_heads batch seq d"
    )

In [10]:
# Score every edge from an upstream head to the q/k/v input of a downstream head,
# restricted to downstream layers strictly after the upstream layer.
#
# The score of an edge is
#     sum(grad_at_downstream_input * (clean_upstream_act - corrupted_upstream_act))
# taken over all remaining axes, and is stored as a 0-dim tensor under the key
#     (upstream_layer, upstream_head, downstream_letter, downstream_layer, downstream_head).
results = {}

# Move the gradient tensor to the CPU once rather than once per edge.
grad_act_cpu = stacked_grad_act.cpu()

for upstream_layer_idx in range(model.cfg.n_layers):
    for upstream_head_idx in range(model.cfg.n_heads):
        # The clean-minus-corrupted activation depends only on the upstream head,
        # so compute it here instead of inside the three inner loops.
        upstream_act_diff = (
            clean_head_act[upstream_layer_idx, upstream_head_idx]
            - corr_head_act[upstream_layer_idx, upstream_head_idx]
        ).cpu()
        for downstream_letter_idx, downstream_letter in enumerate("qkv"):
            for downstream_layer_idx in range(upstream_layer_idx + 1, model.cfg.n_layers):
                for downstream_head_idx in range(model.cfg.n_heads):
                    edge_key = (
                        upstream_layer_idx,
                        upstream_head_idx,
                        downstream_letter,
                        downstream_layer_idx,
                        downstream_head_idx,
                    )
                    downstream_grad = grad_act_cpu[
                        downstream_letter_idx, downstream_layer_idx, downstream_head_idx
                    ]
                    results[edge_key] = (downstream_grad * upstream_act_diff).sum()

In [11]:
# Rank the (edge, score) pairs by absolute score, largest first.
sorted_results = list(results.items())
sorted_results.sort(key=lambda edge_and_score: edge_and_score[1].abs(), reverse=True)

In [15]:
# List the ten highest-ranked edges as "layer:head -> layer:head".
# Element 2 of each key (the downstream q/k/v letter) is not shown.
print("Top 10 most important edges:")
for rank in range(10):
    edge = sorted_results[rank][0]
    print(f"{edge[0]}:{edge[1]} -> {edge[3]}:{edge[4]}")

Top 10 most important edges:
9:9 -> 11:10
10:7 -> 11:10
5:5 -> 8:6
9:9 -> 10:7
5:5 -> 6:9
9:6 -> 11:10
4:11 -> 6:9
9:6 -> 10:7
5:5 -> 7:9
10:10 -> 11:10
