In [1]:
import os
import sys
import re

import acdc
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

Device: cuda


# Model Setup

In [2]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Dataset Setup

In [3]:
from ioi_dataset import IOIDataset, format_prompt, make_table
N = 20
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

make_table(
  colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
  cols = [
    map(format_prompt, clean_dataset.sentences),
    model.to_string(clean_dataset.s_tokenIDs).split(),
    model.to_string(clean_dataset.io_tokenIDs).split(),
    map(format_prompt, clean_dataset.sentences),
  ],
  title = "Sentences from IOI vs ABC distribution",
)

# Metric Setup

In [4]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    # Get logits for indirect objects
    io_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.io_tokenIDs]
    s_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.s_tokenIDs]
    # Get logits for subject
    logit_diff = io_logits - s_logits
    return logit_diff if per_prompt else logit_diff.mean()

with t.no_grad():
    clean_logits = model(clean_dataset.toks)
    corrupt_logits = model(corr_dataset.toks)
    clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()
    corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
 ):
    patched_logit_diff = ave_logit_diff(logits, ioi_dataset)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

# Get clean and corrupt logit differences
with t.no_grad():
    clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, clean_dataset)
    corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, corr_dataset)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 2.6958870887756348, Corrupt direction: 1.5138511657714844
Clean metric: 1.0, Corrupt metric: 0.0


# Experiment Setup

In [5]:
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType

# Set up experiment
exp = TLACDCExperiment(
    model=model,
    threshold=0.00001,
    ds=clean_dataset.toks,
    ref_ds=corr_dataset.toks,
    metric=ioi_metric,
    zero_ablation=True,
    hook_verbose=False
)

# Set up computational graph
exp.model.reset_hooks()
exp.setup_model_hooks(
    add_sender_hooks=True,
    add_receiver_hooks=True,
    doing_acdc_runs=False,
)



ln_final.hook_normalized
ln_final.hook_scale
blocks.11.hook_resid_post
blocks.11.hook_mlp_out
blocks.11.mlp.hook_post
blocks.11.mlp.hook_pre
blocks.11.ln2.hook_normalized
blocks.11.ln2.hook_scale
blocks.11.hook_mlp_in
blocks.11.hook_resid_mid
blocks.11.hook_attn_out
blocks.11.attn.hook_result
blocks.11.attn.hook_z
blocks.11.attn.hook_pattern
blocks.11.attn.hook_attn_scores
blocks.11.attn.hook_v
blocks.11.attn.hook_k
blocks.11.attn.hook_q
blocks.11.ln1.hook_normalized
blocks.11.ln1.hook_scale
blocks.11.hook_v_input
blocks.11.hook_k_input
blocks.11.hook_q_input
blocks.11.hook_resid_pre
blocks.10.hook_resid_post
blocks.10.hook_mlp_out
blocks.10.mlp.hook_post
blocks.10.mlp.hook_pre
blocks.10.ln2.hook_normalized
blocks.10.ln2.hook_scale
blocks.10.hook_mlp_in
blocks.10.hook_resid_mid
blocks.10.hook_attn_out
blocks.10.attn.hook_result
blocks.10.attn.hook_z
blocks.10.attn.hook_pattern
blocks.10.attn.hook_attn_scores
blocks.10.attn.hook_v
blocks.10.attn.hook_k
blocks.10.attn.hook_q
blocks.10.ln

# Helper Methods

In [7]:
def remove_redundant_node(exp, node, safe=True, allow_fails=True):
        if safe:
            for parent_name in exp.corr.edges[node.name][node.index]:
                for parent_index in exp.corr.edges[node.name][node.index][parent_name]:
                    if exp.corr.edges[node.name][node.index][parent_name][parent_index].present:
                        raise Exception(f"You should not be removing a node that is still used by another node {node} {(parent_name, parent_index)}")

        bfs = [node]
        bfs_idx = 0

        while bfs_idx < len(bfs):
            cur_node = bfs[bfs_idx]
            bfs_idx += 1

            children = exp.corr.graph[cur_node.name][cur_node.index].children

            for child_node in children:
                if not cur_node.index in exp.corr.edges[child_node.name][child_node.index][cur_node.name]:
                    continue

                if exp.corr.edges[child_node.name][child_node.index][cur_node.name][cur_node.index].edge_type.value == EdgeType.PLACEHOLDER.value:
                    # TODO be a bit more permissive, this can include all things when we have dropped an attention head...
                    continue

                try:
                    exp.corr.remove_edge(
                        child_node.name, child_node.index, cur_node.name, cur_node.index
                    )
                except KeyError as e:
                    print("Got an error", e)
                    if allow_fails:
                        continue
                    else:
                        raise e

                remove_this = True
                for parent_of_child_name in exp.corr.edges[child_node.name][child_node.index]:
                    for parent_of_child_index in exp.corr.edges[child_node.name][child_node.index][parent_of_child_name]:
                        if exp.corr.edges[child_node.name][child_node.index][parent_of_child_name][parent_of_child_index].present:
                            remove_this = False
                            break
                    if not remove_this:
                        break

                if remove_this and child_node not in bfs:
                    bfs.append(child_node)

def remove_node(exp, node):
    '''
        Method that removes node from model. Assumes children point towards
        the end of the residual stream and parents point towards the beginning.

        exp: A TLACDCExperiment object with a reverse top sorted graph
        node: A TLACDCInterpNode describing the node to remove
        root: Initally the first node in the graph
    '''
    #Removing all edges pointing to the node
    for p_name in exp.corr.edges[node.name][node.index]:
        print(p_name)
        for p_idx in exp.corr.edges[node.name][node.index][p_name]:
            edge = exp.corr.edges[node.name][node.index][p_name][p_idx]
            #if edge.edge_type != EdgeType.PLACEHOLDER:
            #    edge.present = False
            edge.present = False

    # Removing all outgoing edges from the node using BFS
    remove_redundant_node(exp, node, safe=False)

def find_attn_node(exp, layer, head):
    return exp.corr.graph[f'blocks.{layer}.attn.hook_result'][TorchIndex([None, None, head])]

def split_layers_and_heads(act: Tensor, model: HookedTransformer) -> Tensor:
    return einops.rearrange(act, '(layer head) batch seq d_model -> layer head batch seq d_model',
                            layer=model.cfg.n_layers,
                            head=model.cfg.n_heads)

hook_filter = lambda name: name.endswith("ln1.hook_normalized") or name.endswith("attn.hook_result")
def get_3_caches(model, clean_input, corrupted_input, metric):
    # cache the activations and gradients of the clean inputs
    model.reset_hooks()
    clean_cache = {}

    def forward_cache_hook(act, hook):
        clean_cache[hook.name] = act.detach()

    model.add_hook(hook_filter, forward_cache_hook, "fwd")

    clean_grad_cache = {}

    def backward_cache_hook(act, hook):
        clean_grad_cache[hook.name] = act.detach()

    model.add_hook(hook_filter, backward_cache_hook, "bwd")

    value = metric(model(clean_input))
    value.backward()

    # cache the activations of the corrupted inputs
    model.reset_hooks()
    corrupted_cache = {}

    def forward_cache_hook(act, hook):
        corrupted_cache[hook.name] = act.detach()

    model.add_hook(hook_filter, forward_cache_hook, "fwd")
    model(corrupted_input)
    model.reset_hooks()

    clean_cache = ActivationCache(clean_cache, model)
    corrupted_cache = ActivationCache(corrupted_cache, model)
    clean_grad_cache = ActivationCache(clean_grad_cache, model)
    return clean_cache, corrupted_cache, clean_grad_cache

def acdc_nodes(model: HookedTransformer,
              clean_input: Tensor,
              corrupted_input: Tensor,
              metric: Callable[[Tensor], Tensor],
              threshold: float,
              exp: TLACDCExperiment,
              attr_absolute_val: bool = False) -> Tuple[
                  HookedTransformer, Bool[Tensor, 'n_layer n_heads']]:
    '''
    Runs attribution-patching-based ACDC on the model, using the given metric and data.
    Returns the pruned model, and which heads were pruned.

    Arguments:
        model: the model to prune
        clean_input: the input to the model that contains should elicit the behavior we're looking for
        corrupted_input: the input to the model that should elicit random behavior
        metric: the metric to use to compare the model's performance on the clean and corrupted inputs
        threshold: the threshold below which to prune
        create_model: a function that returns a new model of the same type as the input model
        attr_absolute_val: whether to take the absolute value of the attribution before thresholding
    '''
    # get the 2 fwd and 1 bwd caches; cache "normalized" and "result" of attn layers
    clean_cache, corrupted_cache, clean_grad_cache = get_3_caches(model, clean_input, corrupted_input, metric)

    # compute first-order Taylor approximation for each node to get the attribution
    clean_head_act = clean_cache.stack_head_results()
    corr_head_act = corrupted_cache.stack_head_results()
    clean_grad_act = clean_grad_cache.stack_head_results()

    # compute attributions of each node
    node_attr = (clean_head_act - corr_head_act) * clean_grad_act
    # separate layers and heads, sum over d_model (to complete the dot product), batch, and seq
    node_attr = split_layers_and_heads(node_attr, model).sum((2, 3, 4))

    if attr_absolute_val:
        node_attr = node_attr.abs()
    del clean_cache
    del clean_head_act
    del corrupted_cache
    del corr_head_act
    del clean_grad_cache
    del clean_grad_act
    t.cuda.empty_cache()
    # prune all nodes whose attribution is below the threshold
    should_prune = node_attr < threshold
    pruned_nodes_attr = {}
    for layer, head in itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)):
        if should_prune[layer, head]:
            print(f'PRUNING L{layer}H{head} with attribution {node_attr[layer, head]}')
            # Find the corresponding node in computation graph
            node = find_attn_node(exp, layer, head)
            print(f'\tFound node {node.name}')
            # Prune node
            remove_node(exp, node)
            print(f'\tRemoved node {node.name}')
            pruned_nodes_attr[(layer, head)] = node_attr[layer, head]
    return pruned_nodes_attr

# Run Experiment

In [8]:
THRESHOLD = 0.5
pruned_nodes_attr = acdc_nodes(
    model=exp.model,
    clean_input=clean_dataset.toks,
    corrupted_input=corr_dataset.toks,
    metric=ioi_metric,
    threshold=THRESHOLD,
    exp=exp,
    attr_absolute_val=True,
) 

PRUNING L0H0 with attribution 0.014933452010154724
	Found node blocks.0.attn.hook_result
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
	Removed node blocks.0.attn.hook_result
PRUNING L0H1 with attribution 0.004214229062199593
	Found node blocks.0.attn.hook_result
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
	Removed node blocks.0.attn.hook_result
PRUNING L0H2 with attribution 0.003962965682148933
	Found node blocks.0.attn.hook_result
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
	Removed node blocks.0.attn.hook_result
PRUNING L0H3 with attribution 0.0938766747713089
	Found node blocks.0.attn.hook_result
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
	Removed node blocks.0.attn.hook_result
PRUNING L0H4 with attribution 0.02162221074104309
	Found node blocks.0.attn.hook_result
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
	Removed node blocks.0.attn.hook_result
PRUNING L0H5 with attribution 0.01294391229

# Show resulting graph