In [1]:
import os
import sys
sys.path.append('../..//Automatic-Circuit-Discovery/')
sys.path.append('..')
import re

import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

from utils.prune_utils import get_3_caches, split_layers_and_heads
from ACDCPPExperiment import ACDCPPExperiment

from acdc.hybridretrieval.utils import (
    get_all_hybrid_retrieval_things,
    get_gpt2_small
)

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [2]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
all_kbicr_items = get_all_hybrid_retrieval_things(num_examples=20, device=device, metric_name='logit_diff')
clean_dataset = all_kbicr_items.validation_data
corrupted_dataset = all_kbicr_items.validation_patch_data
wrong_labels = all_kbicr_items.validation_wrong_labels
correct_labels = all_kbicr_items.validation_labels

tl_model = all_kbicr_items.tl_model
validation_metric = all_kbicr_items.validation_metric
validation_data = all_kbicr_items.validation_data
validation_labels = all_kbicr_items.validation_labels
validation_patch_data = all_kbicr_items.validation_patch_data
test_metrics = all_kbicr_items.test_metrics
test_data = all_kbicr_items.test_data
test_labels = all_kbicr_items.test_labels
test_patch_data = all_kbicr_items.test_patch_data

Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda
Clean Data Datasets:
torch.Size([20, 26])

Corrupted Data Datasets:
torch.Size([20, 26])

Clean Labels shape:
torch.Size([20, 1])

Clean Wrong Labels shape:
torch.Size([20, 1])
Shape of validation_data: torch.Size([10, 26])
Shape of validation_patch_data: torch.Size([10, 26])
Shape of test_data: torch.Size([10, 26])
Shape of test_patch_data: torch.Size([10, 26])


In [4]:
logit_diff = all_kbicr_items.test_metrics['logit_diff']
logit_diff

functools.partial(<function logit_diff_metric at 0x7f0cae7dfeb0>, correct_labels=tensor([[13256],
        [11006],
        [19727],
        [18861],
        [13787],
        [31160],
        [13256],
        [11006],
        [19727],
        [18861]]), wrong_labels=tensor([[17940],
        [15039],
        [14053],
        [31632],
        [45001],
        [45355],
        [17940],
        [15039],
        [14053],
        [31632]]))

In [5]:
import torch
from functools import partial

# Example definition of logit_diff_metric
def avg_logit_diff(correct_logits, corrupt_logits):
    # Compute the logit difference metric
    logit_diff = correct_logits - corrupt_logits
    return logit_diff.mean()

# Ensure the labels are of type int64
correct_labels = correct_labels.long()
wrong_labels = wrong_labels.long()

# Move the labels to the device
correct_labels = correct_labels.to(device)
wrong_labels = wrong_labels.to(device)

print(correct_labels.shape, wrong_labels.shape)

# Ensure your model and dataset are on the same device
model.to(device)
clean_dataset = clean_dataset.to(device)
corrupted_dataset = corrupted_dataset.to(device)

with torch.no_grad():
    clean_logits = model(clean_dataset).to(device)
    corrupt_logits = model(corrupted_dataset).to(device)

# Compute the logit difference metric
clean_logit_diff = avg_logit_diff(clean_logits, corrupt_logits).item()

print(clean_logit_diff)

torch.Size([10, 1]) torch.Size([10, 1])
Moving model to device:  cuda
0.05820911005139351


In [6]:
def abs_kbicr_metric(logits):
    return -abs(test_metrics['logit_diff'](logits))

In [7]:
## Calc Node attributions

clean_cache, corrupted_cache, clean_grad_cache = get_3_caches(tl_model, test_data, test_patch_data, test_metrics['logit_diff'], mode="node")
# compute first-order Taylor approximation for each node to get the attribution
clean_head_act = clean_cache.stack_head_results()
corr_head_act = corrupted_cache.stack_head_results()
clean_grad_act = clean_grad_cache.stack_head_results()

# compute attributions of each node
node_attr = (clean_head_act - corr_head_act) * clean_grad_act
# separate layers and heads, sum over d_model (to complete the dot product), batch, and seq
node_attr = split_layers_and_heads(node_attr, tl_model).sum((2, 3, 4)).flatten().abs()

In [8]:
## Calc edge attributions

threshold_dummy = 0.1 # Does not make a difference when only running edge based attribution patching, as all attributions are saved in the result dict anyways
acdcpp_exp = ACDCPPExperiment(tl_model,
                              test_data,
                              test_patch_data,
                              test_metrics['logit_diff'],
                              abs_kbicr_metric,
                              threshold_dummy,
                              verbose=False,
                              attr_absolute_val=True,
                              save_graphs_after=0,
                              pruning_mode="edge",
                              no_pruned_nodes_attr=1
                             )
tlacdc_exp = acdcpp_exp.setup_exp(threshold=threshold_dummy)
attr_results = acdcpp_exp.run_acdcpp(exp=tlacdc_exp, threshold=threshold_dummy)



ln_final.hook_normalized
ln_final.hook_scale
blocks.11.hook_resid_post
blocks.11.hook_mlp_out
blocks.11.mlp.hook_post
blocks.11.mlp.hook_pre
blocks.11.ln2.hook_normalized
blocks.11.ln2.hook_scale
blocks.11.hook_mlp_in
blocks.11.hook_resid_mid
blocks.11.hook_attn_out
blocks.11.attn.hook_result
blocks.11.attn.hook_z
blocks.11.attn.hook_pattern
blocks.11.attn.hook_attn_scores
blocks.11.attn.hook_v
blocks.11.attn.hook_k
blocks.11.attn.hook_q
blocks.11.ln1.hook_normalized
blocks.11.ln1.hook_scale
blocks.11.hook_v_input
blocks.11.hook_k_input
blocks.11.hook_q_input
blocks.11.hook_resid_pre
blocks.10.hook_resid_post
blocks.10.hook_mlp_out
blocks.10.mlp.hook_post
blocks.10.mlp.hook_pre
blocks.10.ln2.hook_normalized
blocks.10.ln2.hook_scale
blocks.10.hook_mlp_in
blocks.10.hook_resid_mid
blocks.10.hook_attn_out
blocks.10.attn.hook_result
blocks.10.attn.hook_z
blocks.10.attn.hook_pattern
blocks.10.attn.hook_attn_scores
blocks.10.attn.hook_v
blocks.10.attn.hook_k
blocks.10.attn.hook_q
blocks.10.ln

OutOfMemoryError: CUDA out of memory. Tried to allocate 148.00 MiB. GPU 0 has a total capacty of 7.79 GiB of which 99.38 MiB is free. Including non-PyTorch memory, this process has 7.68 GiB memory in use. Of the allocated memory 6.98 GiB is allocated by PyTorch, and 527.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

: 

In [None]:
attr_vals = np.array(list(attr_results[1].values()))
thresholds_edge = np.linspace(min(attr_vals), 0.2, 100)
thresholds_node = np.linspace(min(attr_vals), max(attr_vals), 100)
num_edges_above_thresh = np.array(
    [sum(attr_vals > t) for t in thresholds_edge]
)
num_nodes_above_thresh = np.array(
    [sum(node_attr > t) for t in thresholds_node]
)



fig, ax = plt.subplots(2,1, figsize=(6,8))

ax[0].plot(thresholds_edge, num_edges_above_thresh)
ax[1].plot(thresholds_node, num_nodes_above_thresh)
ax[1].set_xlabel("threshold")
ax[0].set_ylabel("number of remaining edges")
ax[1].set_ylabel("number of remaining nodes")
ax[0].set_title("Number of remaining edges/nodes after ACDC++ only \n(Docstring task, ADCD threshold: 0.067)")