# Quickstart Guide: **Isolating Path Effect for Latent Circuit Discovery**


## Setup

### Models

The code in this repository is compatible with models from transformer-lens, which is built on top of PyTorch. You can install it via pip:
```bash
pip install transformer-lens
```


```python
import torch
from transformer_lens import HookedTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    device=device,
    torch_dtype=torch.float32,
    center_unembed=True,
)
```



```text
Loaded pretrained model gpt2-small into HookedTransformer
```


### Data

The current implementation targets the circuit responsible for single-token predictions. It can be adapted to other settings with appropriate changes, but this is not supported out of the box.

Furthermore, if you want to perform a positional analysis, highlighting the contributions of individual residual positions, every batch must contain prompts with the same number of tokens. This is required because the method assumes that, throughout the batch, the token at a given position has a similar meaning. If instead you perform a non-positional analysis, batches can contain prompts with different numbers of tokens.
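For a positional analysis, one simple way to meet the constant-length requirement is to bucket prompts by token count before batching. The sketch below is a hypothetical helper, not part of this library: `group_by_length` and `equal_length_batches` accept any `tokenize` callable, and the example uses a whitespace split as a stand-in for the model tokenizer.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence


def group_by_length(prompts: Sequence[str], tokenize: Callable[[str], Sequence]) -> Dict[int, List[str]]:
    """Bucket prompts by their number of tokens."""
    buckets: Dict[int, List[str]] = defaultdict(list)
    for prompt in prompts:
        buckets[len(tokenize(prompt))].append(prompt)
    return dict(buckets)


def equal_length_batches(prompts: Sequence[str], tokenize: Callable[[str], Sequence],
                         batch_size: int) -> List[List[str]]:
    """Split prompts into batches where every prompt has the same token count.

    The last batch of each length bucket may be smaller than batch_size.
    """
    batches: List[List[str]] = []
    for _, bucket in sorted(group_by_length(prompts, tokenize).items()):
        for start in range(0, len(bucket), batch_size):
            batches.append(bucket[start:start + batch_size])
    return batches


# Whitespace "tokenizer" standing in for a real tokenizer
prompts = ["a b c", "d e f", "g h", "i j k", "l m"]
batches = equal_length_batches(prompts, str.split, batch_size=2)
print(batches)  # [['g h', 'l m'], ['a b c', 'd e f'], ['i j k']]
```

With transformer-lens, a callable such as `lambda p: model.to_tokens(p)[0]` would count tokens (including BOS) instead of words.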

In the example below we use a small batch of 4 samples from the `mib-bench/ioi` dataset. This dataset is based on the Indirect Object Identification (IOI) task [(K. Wang et al.)](https://arxiv.org/abs/2211.00593), a common benchmark for mechanistic interpretability. You can find more information about the dataset [here](https://huggingface.co/datasets/mib-bench/ioi).

```python
from datasets import load_dataset

dataset = load_dataset("mib-bench/ioi", split="test")
batch_size = 4
target_length = 15  # every selected prompt must have exactly this many tokens (BOS included)

prompts, answers = [], []
counterfactual_prompts, counterfactual_answers = [], []

for sample in dataset:
    # Keep only prompts of the target length so that positions align across the batch
    if model.to_tokens(sample['prompt'], prepend_bos=True).shape[1] == target_length:
        prompts.append(sample['prompt'])
        # Leading space: GPT-2 tokenizes a mid-sentence name as " Name"
        answers.append(f" {sample['metadata']['indirect_object']}")

        cf = sample['s2_io_flip_counterfactual']
        counterfactual_prompts.append(cf['prompt'])
        counterfactual_answers.append(f" {cf['choices'][cf['answerKey']]}")
        if len(prompts) >= batch_size:
            break

print("Example prompt:\n   ``", prompts[0], "``")
print("  Answer: ``", answers[0].replace(' ', '_'), "``")

print("\nCounterfactual prompt:\n   ``", counterfactual_prompts[0], "``")
print("  Counterfactual answer: ``", counterfactual_answers[0].replace(' ', '_'), "``")
```


```text
Example prompt:
   `` Once Austin and Phil arrived at the ramp, Austin gave a backpack to ``
  Answer: `` _Phil ``

Counterfactual prompt:
   `` Once Austin and Phil arrived at the ramp, Phil gave a backpack to ``
  Counterfactual answer: `` _Austin ``
```


### Cache

We use transformer-lens mainly for its caching mechanism, which stores the model's activations during a forward pass. We need these activations repeatedly during path attribution patching, so we cache them once for both the clean and the counterfactual prompts.
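To illustrate the idea only (this is not how transformer-lens implements it), the toy sketch below runs a stack of residual-style layers on a list-of-floats "residual stream" and records the state after each layer under a hook-style name, loosely mirroring names such as `blocks.{i}.hook_resid_post`. `toy_run_with_cache` and the layer functions are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


def toy_run_with_cache(layers: List[Layer], resid: Vector) -> Tuple[Vector, Dict[str, Vector]]:
    """Apply layers in order and cache every intermediate residual state.

    Each layer's output is added to the residual stream, as in a transformer block.
    """
    cache: Dict[str, Vector] = {"hook_embed": list(resid)}
    for i, layer in enumerate(layers):
        update = layer(resid)
        resid = [r + u for r, u in zip(resid, update)]
        cache[f"blocks.{i}.hook_resid_post"] = list(resid)  # store a copy, not a reference
    return resid, cache


layers: List[Layer] = [
    lambda x: [2.0 * v for v in x],  # layer 0 writes twice its input
    lambda x: [1.0 for _ in x],      # layer 1 writes a constant bias
]
final, cache = toy_run_with_cache(layers, [1.0, -1.0])
print(final)                              # [4.0, -2.0]
print(cache["blocks.0.hook_resid_post"])  # [3.0, -3.0]
```

Once cached, any intermediate state can be read back as often as needed without re-running the forward pass.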


```python
# Clean run: logits plus every cached activation
clean_logits, clean_cache = model.run_with_cache(model.to_tokens(prompts, prepend_bos=True))
clean_probabilities = clean_logits.softmax(dim=-1)
correct_token_ids = [model.to_single_token(answer) for answer in answers]

# Counterfactual run, used as the source of patched activations
cf_logits, cf_cache = model.run_with_cache(model.to_tokens(counterfactual_prompts, prepend_bos=True))
cf_probabilities = cf_logits.softmax(dim=-1)
cf_token_ids = [model.to_single_token(answer) for answer in counterfactual_answers]
```

### Metrics

The goal of this library is to identify circuits by isolating the paths that contribute most to the prediction. To do so, we need a way to measure the contribution of a path to the final prediction; this is the role of metrics.
We provide several metrics out of the box, but you can also define your own. The only requirement is that the metric can be expressed as a function of the corrupted final residual stream obtained after path patching.

Here we use two of them: a simple percentage logit difference for the target token, and the indirect effect metric [(A. Stolfo et al.)](https://arxiv.org/abs/2305.15054).
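As we read Stolfo et al., the indirect effect compares the probabilities of the clean answer r and the counterfactual answer r′ before (P) and after (P*) an intervention. The sketch below computes that quantity from plain probabilities and averages it over a batch. It is a hedged illustration: `indirect_effect_from_probs` is hypothetical, and the library's `indirect_effect` (which works on residual streams and has options such as `set_baseline`) may differ in its details.

```python
from statistics import mean
from typing import Sequence


def indirect_effect_from_probs(
    p_clean_r: Sequence[float],    # P(r): clean run, clean answer
    p_clean_rcf: Sequence[float],  # P(r'): clean run, counterfactual answer
    p_patch_r: Sequence[float],    # P*(r): patched run, clean answer
    p_patch_rcf: Sequence[float],  # P*(r'): patched run, counterfactual answer
) -> float:
    """IE = 1/2 * [(P*(r') - P(r')) / P(r') + (P(r) - P*(r)) / P*(r)], averaged over the batch."""
    per_sample = [
        0.5 * ((prcf_star - prcf) / prcf + (pr - pr_star) / pr_star)
        for pr, prcf, pr_star, prcf_star in zip(p_clean_r, p_clean_rcf, p_patch_r, p_patch_rcf)
    ]
    return mean(per_sample)


# Two samples: the first shifts probability toward r', the second only lowers P*(r)
ie = indirect_effect_from_probs([0.5, 0.8], [0.1, 0.1], [0.25, 0.4], [0.2, 0.1])
print(ie)  # 0.75
```

Because P(r′) is usually small in the clean run, the first ratio can become very large, so indirect-effect scores are not bounded to [0, 1].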


```python
from functools import partial
from backward_search_approximated.utils.metrics import indirect_effect, compare_token_logit

# Clean final residual stream (output of the last block), the reference for both metrics
clean_final_resid = clean_cache[f'blocks.{model.cfg.n_layers - 1}.hook_resid_post']

compare_token_logit_metric = partial(compare_token_logit, clean_resid=clean_final_resid, model=model,
                                     target_tokens=correct_token_ids)
indirect_effect_metric = partial(indirect_effect, clean_resid=clean_final_resid, model=model,
                                 clean_targets=correct_token_ids, corrupt_targets=cf_token_ids,
                                 verbose=False, set_baseline=False)
```
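Both metrics are `functools.partial` objects with everything except the corrupted residual bound in advance, so a custom metric can plausibly follow the same pattern. The sketch assumes, from the calls above rather than from the library's documentation, that the search calls the metric with the corrupted final residual stream as its only positional argument. `target_logit_drop`, `logits_from_resid`, the toy unembedding and the list-based vectors (one final-position residual per sample) are hypothetical stand-ins for tensors.

```python
from functools import partial
from statistics import mean
from typing import List, Sequence

Matrix = List[List[float]]  # [d_model][d_vocab]


def logits_from_resid(resid: Sequence[float], w_u: Matrix) -> List[float]:
    """Unembed one residual vector: logits[v] = sum_d resid[d] * W_U[d][v]."""
    return [sum(resid[d] * w_u[d][v] for d in range(len(resid))) for v in range(len(w_u[0]))]


def target_logit_drop(corrupted_resid, clean_resid, w_u, target_tokens) -> float:
    """Hypothetical metric: mean drop of the target-token logit caused by patching."""
    drops = [
        logits_from_resid(clean, w_u)[tok] - logits_from_resid(corr, w_u)[tok]
        for corr, clean, tok in zip(corrupted_resid, clean_resid, target_tokens)
    ]
    return mean(drops)


W_U = [[1.0, 0.0], [0.0, 1.0]]    # toy 2x2 unembedding
clean = [[2.0, 0.0], [0.0, 3.0]]  # one final residual vector per sample
metric = partial(target_logit_drop, clean_resid=clean, w_u=W_U, target_tokens=[0, 1])
print(metric([[1.0, 0.0], [0.0, 1.0]]))  # 1.5
```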

### Nodes

Nodes are the building blocks of paths and are the core of our code. A node represents a specific component of the model at a specific layer and position (if applicable).

Currently we support the following nodes:
- `EMBED_Node`: the token embedding (layer 0) at a specific position
- `MLP_Node`: the MLP at a specific layer and position
- `ATT_Node`: the attention head or block at a specific layer and position
- `FINAL_Node`: a dummy node representing the final residual stream before the unembedding layer

They expose three main methods:
- `forward`: computes the output of the node when applying a given patching
- `get_expansion_candidates`: returns all the predecessor nodes of the current node
- `calculate_gradient`: computes the gradient of the metric with respect to the input of the node, passing through the path leading to the output
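To make the role of `get_expansion_candidates` concrete, here is a deliberately simplified, hypothetical node (not the library's classes). It ignores positions and the query versus key/value split that the real attention nodes track, and only enumerates which components can write into the residual stream read by a given node. In a GPT-2 block attention runs before the MLP, so an MLP at layer l can read the heads of layer l, while attention at layer l only sees earlier layers.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class ToyNode:
    kind: str                   # "EMBED", "ATTN", "MLP" or "FINAL"
    layer: int
    head: Optional[int] = None  # only used by ATTN nodes


def toy_expansion_candidates(node: ToyNode, n_layers: int, n_heads: int) -> List[ToyNode]:
    """Return every component whose output can reach `node` through the residual stream."""
    if node.kind == "EMBED":
        return []  # nothing is upstream of the embedding
    last_attn = {"FINAL": n_layers - 1, "MLP": node.layer, "ATTN": node.layer - 1}[node.kind]
    last_mlp = n_layers - 1 if node.kind == "FINAL" else node.layer - 1
    candidates = [ToyNode("EMBED", 0)]
    candidates += [ToyNode("ATTN", l, h) for l in range(last_attn + 1) for h in range(n_heads)]
    candidates += [ToyNode("MLP", l) for l in range(last_mlp + 1)]
    return candidates


# A 2-layer, 2-head toy model: the final node can be reached by 1 + 4 + 2 components
print(len(toy_expansion_candidates(ToyNode("FINAL", 1), n_layers=2, n_heads=2)))  # 7
```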


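In particular, every search starts from a `FINAL_Node` (here its approximated variant, `FINAL_ApproxNode`), and the way this root is initialized determines the behavior of the search: for example, `position=None` selects a non-positional analysis, and `patch_type` selects zero or counterfactual patching.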


```python
from backward_search_approximated.utils.nodes import FINAL_ApproxNode

root_node = FINAL_ApproxNode(
    model=model,
    layer=model.cfg.n_layers - 1,
    metric=indirect_effect_metric,
    position=target_length - 1,  # None for a non-positional analysis, otherwise the last position of the sequence
    parent=None,
    children=set(),
    msg_cache=dict(clean_cache),
    cf_cache=dict(cf_cache),
    gradient=None,
    patch_type='counterfactual',  # either 'zero' or 'counterfactual'
)
```


## Backward Breadth-First Search (BFS)

```python
from backward_search_approximated.utils.graph_search import IsolatingPathEffect_BW

minimum_contribution_threshold = 2

indirect_effect_paths_BW = IsolatingPathEffect_BW(
    model=model,
    metric=indirect_effect_metric,  # any metric works here (e.g. compare_token_logit_metric)
    root=root_node,  # root of the search; if root.position is None the analysis is non-positional
    min_contribution=minimum_contribution_threshold,
    include_negative=True,
    return_all=False,
    batch_heads=True,  # faster but less precise; if you must choose, batching heads is generally worse than batching positions
    batch_positions=True,  # faster but less precise (positional analysis only)
)
```


```text
(total 1)    Frontier: [(tensor(4921.8726, grad_fn=<MeanBackward0>), [FINAL_ApproxNode(layer=11)])]
100%|██████████| 1/1 [00:07<00:00,  7.91s/it]

(total 4)    Frontier: [(tensor(38.2632, grad_fn=<MeanBackward0>), [ATTN_ApproxNode(layer=9, head=9, position=None, keyvalue_position=None, patch_query=False, patch_key=True, patch_value=True), FINAL_ApproxNode(layer=11)]), (tensor(38.2632, grad_fn=<MeanBackward0>), [ATTN_ApproxNode(layer=9, head=9, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), FINAL_ApproxNode(layer=11)])]... ]
100%|██████████| 4/4 [00:16<00:00,  4.06s/it]

(total 4)    Frontier: [(tensor(6.8362, grad_fn=<MeanBackward0>), [ATTN_ApproxNode(layer=8, head=6, position=None, keyvalue_position=None, patch_query=False, patch_key=True, patch_value=True), ATTN_ApproxNode(layer=9, head=9, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), FINAL_ApproxNode(layer=11)]), (tensor(6.8362, grad_fn=<MeanBackward0>), [ATTN_ApproxNode(layer=8, head=6, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), ATTN_ApproxNode(layer=9, head=9, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), FINAL_ApproxNode(layer=11)])]... ]
100%|██████████| 4/4 [00:15<00:00,  3.92s/it]

(total 2)    Frontier: [(tensor(2.2764, grad_fn=<MeanBackward0>), [ATTN_ApproxNode(layer=5, head=5, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), ATTN_ApproxNode(layer=8, head=6, position=None, keyvalue_position=None, patch_query=False, patch_key=True, patch_value=True), ATTN_ApproxNode(layer=9, head=9, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), FINAL_ApproxNode(layer=11)]), (tensor(2.2764, grad_fn=<MeanBackward0>), [ATTN_ApproxNode(layer=5, head=5, position=None, keyvalue_position=None, patch_query=False, patch_key=True, patch_value=True), ATTN_ApproxNode(layer=8, head=6, position=None, keyvalue_position=None, patch_query=False, patch_key=True, patch_value=True), ATTN_ApproxNode(layer=9, head=9, position=None, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), FINAL_ApproxNode(layer=11)])]
100%|██████████| 2/2 [00:04<00:00,  2.07s/it]
```
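Reading the frontier logs, the search appears to grow paths backward from the root one component at a time, keeping only extensions whose contribution clears `min_contribution` (by absolute value when `include_negative=True`). The toy sketch below reproduces that control flow on a hand-written graph. The `upstream` adjacency, the `path_score` function and the pruning rule are assumptions for illustration, not the library's implementation.

```python
from typing import Callable, Dict, List, Tuple

Path = Tuple[str, ...]  # ordered from the most upstream node to the root


def backward_bfs(root: str, upstream: Dict[str, List[str]], path_score: Callable[[Path], float],
                 min_contribution: float, include_negative: bool = True) -> List[Tuple[float, Path]]:
    """Expand paths backward level by level, pruning those below the threshold.

    Returns the complete paths (those reaching a node with no upstream candidates),
    sorted by absolute score.
    """
    def keep(score: float) -> bool:
        return (abs(score) if include_negative else score) >= min_contribution

    frontier: List[Path] = [(root,)]
    complete: List[Tuple[float, Path]] = []
    while frontier:
        next_frontier: List[Path] = []
        for path in frontier:
            candidates = upstream.get(path[0], [])
            if not candidates:
                complete.append((path_score(path), path))
                continue
            for node in candidates:
                extended = (node,) + path
                if keep(path_score(extended)):
                    next_frontier.append(extended)
        frontier = next_frontier
    return sorted(complete, key=lambda item: abs(item[0]), reverse=True)


upstream = {"FINAL": ["A", "B"], "A": ["EMBED"], "B": ["EMBED"]}
scores = {("A", "FINAL"): 5.0, ("B", "FINAL"): 0.5, ("EMBED", "A", "FINAL"): -3.0}
result = backward_bfs("FINAL", upstream, lambda p: scores.get(p, 0.0), min_contribution=1.0)
print(result)  # [(-3.0, ('EMBED', 'A', 'FINAL'))]
```

In this toy, a path whose every extension falls below the threshold is dropped rather than returned, so a high threshold can yield an empty result.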


## Path Attribution Patching (PAP)

```python
from backward_search_approximated.utils.graph_search import PathAttributionPatching

minimum_contribution_threshold = 0.1

# Positional analysis: root_node.position is the last token position
indirect_effect_paths_PAP = PathAttributionPatching(
    model=model,
    metric=indirect_effect_metric,  # any metric works here (e.g. compare_token_logit_metric)
    root=root_node,  # root of the search; if root.position is None the analysis is non-positional
    min_contribution=minimum_contribution_threshold,
    include_negative=True,
    return_all=False,
)

print(f"Found {len(indirect_effect_paths_PAP)} paths")

# Non-positional analysis: reuse the same root with its position cleared
root_node.position = None
indirect_effect_paths_PAP_nonpositional = PathAttributionPatching(
    model=model,
    metric=indirect_effect_metric,
    root=root_node,
    min_contribution=minimum_contribution_threshold,
    include_negative=True,
    return_all=False,
)

print(f"Found {len(indirect_effect_paths_PAP_nonpositional)} paths")
```


```text
100%|██████████| 1/1 [00:15<00:00, 15.58s/it]
100%|██████████| 39/39 [00:10<00:00,  3.61it/s]
100%|██████████| 73/73 [00:22<00:00,  3.30it/s]
100%|██████████| 33/33 [00:06<00:00,  4.89it/s]
100%|██████████| 8/8 [00:01<00:00,  7.81it/s]

Found 5 paths

100%|██████████| 1/1 [00:00<00:00,  4.41it/s]
100%|██████████| 27/27 [00:01<00:00, 14.50it/s]
100%|██████████| 76/76 [00:04<00:00, 17.13it/s]
100%|██████████| 33/33 [00:01<00:00, 21.48it/s]
100%|██████████| 4/4 [00:00<00:00,  5.64it/s]

Found 5 paths
```
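The `*_ApproxNode` classes and the `calculate_gradient` method suggest that path effects are estimated with gradients instead of re-running the model for every candidate, in the spirit of attribution patching: the change in a metric m when an activation a is replaced by its counterfactual value a′ is approximated to first order by ∇m(a) · (a′ − a). That reading is our assumption. The toy sketch compares the estimate with the exact change for a hand-written quadratic metric; `metric`, `metric_grad` and `attribution_estimate` are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def metric(a: Sequence[float]) -> float:
    """Toy metric m(a) = sum(a_i^2) + a_0, standing in for model + real metric."""
    return sum(x * x for x in a) + a[0]


def metric_grad(a: Sequence[float]) -> Vector:
    """Analytic gradient of the toy metric: dm/da_i = 2 * a_i, plus 1 for i == 0."""
    grad = [2.0 * x for x in a]
    grad[0] += 1.0
    return grad


def attribution_estimate(a_clean: Sequence[float], a_cf: Sequence[float]) -> float:
    """First-order estimate of m(a_cf) - m(a_clean) from a single gradient at a_clean."""
    g = metric_grad(a_clean)
    return sum(gi * (cf - c) for gi, cf, c in zip(g, a_cf, a_clean))


a_clean, a_cf = [1.0, 2.0], [1.1, 1.9]
exact = metric(a_cf) - metric(a_clean)
approx = attribution_estimate(a_clean, a_cf)
# exact is about -0.08, approx about -0.1; the gap is the second-order term sum(delta_i^2) = 0.02
print(exact, approx)
```

The estimate needs one backward pass for many candidate patches, at the cost of ignoring curvature, which is the usual speed versus precision trade-off of attribution patching.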


## Paths

```python
print("Backward Breadth-First Search Paths:")
print(indirect_effect_paths_BW)

print("\nPath Attribution Patching Search Paths:")
print(indirect_effect_paths_PAP)

print("\nPath Attribution Patching Search Paths (non positional):")
print(indirect_effect_paths_PAP_nonpositional)
```

Backward Breadth-First Search Paths:
[]

Path Attribution Patching Search Paths:
[(tensor(0.8353, grad_fn=<MeanBackward0>), [EMBED_ApproxNode(layer=0, position=10), MLP_ApproxNode(layer=0, position=10), ATTN_ApproxNode(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), ATTN_ApproxNode(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_key=True, patch_value=True), ATTN_ApproxNode(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), FINAL_ApproxNode(layer=11)]), (tensor(0.2592, grad_fn=<MeanBackward0>), [EMBED_ApproxNode(layer=0, position=10), MLP_ApproxNode(layer=0, position=10), ATTN_ApproxNode(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_key=False, patch_value=False), ATTN_ApproxNode(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_key=True, patch_value=True), ATTN_ApproxNode(layer=10, he