# Finding the Add/Sub Circuit Using EAP(-IG)

First, we import the required packages.

In [None]:

from functools import partial
from typing import Optional, List, Union, Literal, Tuple
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
from transformers import AutoTokenizer 
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits
from eap.attribute import attribute 
from eap.attribute import tokenize_plus
from eap.metrics import logit_diff, direct_logit
import re

In [None]:
# Release memory held by PyTorch's CUDA caching allocator before the model is loaded below.
torch.cuda.empty_cache()

## Dataset and Metrics

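This package expects data to come from a dataloader. Each item consists of paired clean and corrupted inputs (strings) and a label. In this notebook, `collate_EAP` passes labels through exactly as read from the CSV; it does not convert them into a tensor of token ids. The KL metric ignores the labels, and the exact-match metric compares them, as strings, with the model's decoded top-1 prediction. For convenience, we've included a dataset in that form as a CSV (more to come with the full code of the paper).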

A metric takes the model's (possibly corrupted) logits, the clean logits, the input lengths, and the labels, and computes a value for each batch item. That value can be used as is, turned into a loss (lower is better), or averaged over the batch.

In [None]:
def collate_EAP(xs):
    """Transpose a batch of (clean, corrupted, label) triples into columns.

    Returns ``(clean, corrupted, labels)``: ``clean`` and ``corrupted`` are
    lists, while ``labels`` is left as the tuple produced by ``zip`` (no
    tensor conversion, so non-numeric labels survive collation).
    """
    clean_col, corrupted_col, label_col = zip(*xs)
    return list(clean_col), list(corrupted_col), label_col

class EAPDataset(Dataset):
    """Map-style dataset of (clean, corrupted, label) rows loaded from a CSV.

    The CSV must provide ``clean``, ``corrupted`` and ``label`` columns
    (they are read by ``__getitem__``).
    """

    def __init__(self, filepath):
        self.df = pd.read_csv(filepath)

    def __len__(self):
        return len(self.df)

    def shuffle(self, seed: Optional[int] = None):
        """Shuffle the rows in place.

        Args:
            seed: Optional seed forwarded to ``DataFrame.sample(random_state=...)``
                for a reproducible order. ``None`` (the default) gives a fresh
                random order on every call, as before.
        """
        self.df = self.df.sample(frac=1, random_state=seed)

    def head(self, n: int):
        """Truncate the dataset, in place, to its first ``n`` rows."""
        self.df = self.df.head(n)

    def __getitem__(self, index):
        # Positional lookup (iloc), so the non-contiguous index left by
        # shuffle() does not matter.
        row = self.df.iloc[index]
        return row['clean'], row['corrupted'], row['label']

    def to_dataloader(self, batch_size: int):
        """Wrap the dataset in a DataLoader that batches rows with ``collate_EAP``."""
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)
    
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    """For each batch row ``b``, return ``logits[b, input_length[b] - 1]``.

    Picks the logits at each sequence's last real position. The batch index is
    built on ``logits.device``.
    """
    rows = torch.arange(logits.size(0), device=logits.device)
    return logits[rows, input_length - 1]


def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    """KL(clean || logits) at each example's last position, averaged over the vocabulary.

    Args:
        logits: Logits to evaluate (e.g. from a patched or circuit run).
        clean_logits: Reference logits, indexed the same way as ``logits``.
        input_length: Per-example lengths; the distributions are read at
            ``input_length - 1``.
        labels: Unused; present to match the metric interface.
        mean: If True, return the batch mean (scalar); otherwise one value
            per example.
        loss: Has no effect. KL divergence is already non-negative with
            lower meaning closer, so it is usable as a loss as is. The
            argument is kept for API compatibility.

    Returns:
        Scalar tensor if ``mean`` else a tensor with one value per example.
        Values are KL divided by the vocabulary size (``.mean(-1)``, not ``.sum``).
    """
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    # log_softmax is numerically stable. softmax(...).log() gives -inf for any
    # probability that underflows to 0, and kl_div then yields NaN (0 * -inf).
    log_probs = torch.log_softmax(logits, dim=-1)
    clean_log_probs = torch.log_softmax(clean_logits, dim=-1)

    # With log_target=True, kl_div(input, target) = exp(target) * (target - input)
    # pointwise, i.e. the terms of KL(clean || logits).
    results = F.kl_div(log_probs, clean_log_probs, log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results


## Performing EAP-IG

First, we load the base model, merge the fine-tuned LoRA weights into it, wrap the result in a `HookedTransformer`, and load the dataset.

In [None]:
# TransformerLens model name used when wrapping the merged HF model below;
# presumably it must describe the same architecture as base_model_path.
model_name = 'pythia-1.4b-deduped'

from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

# Base HF checkpoint and the fine-tuned LoRA adapter file to merge into it.
base_model_path = "EleutherAI/pythia-1.4b-deduped"
lora_weights_path = "/add_sub_mul_div/finetune_pythia_steps/PEFT/add_sub/lora_results/r32a64/checkpoint-500/adapter_model.safetensors"


base_model = AutoModelForCausalLM.from_pretrained(base_model_path)
# Adapter tensors keyed by name, as stored in the safetensors file.
lora_weights = load_file(lora_weights_path)

# LoRA scaling factors for the standard and "extra" adapter pairs. The adapter
# directory name "r32a64" suggests lora_alpha / r = 64 / 32 = 2 — TODO confirm
# against the training config.
# NOTE(review): keep these in sync with the scaling arguments actually passed
# to merge_lora_weights.
scaling = 2  
scaling_extra = 2  


# Candidate key layouts for a LoRA factor, tried in order. The first is the bare
# "<layer>.lora_A" form; the others cover checkpoints that append ".weight"
# and/or prefix "base_model.model." (PEFT-style naming).
# NOTE(review): confirm against the keys actually present in the adapter file.
_LORA_KEY_TEMPLATES = (
    "{layer}.{factor}",
    "{layer}.{factor}.weight",
    "base_model.model.{layer}.{factor}.weight",
)


def _find_lora_keys(lora_weights, layer_name, suffix=""):
    """Return (A key, B key) for ``layer_name`` from the first layout containing both, else None."""
    for template in _LORA_KEY_TEMPLATES:
        a_key = template.format(layer=layer_name, factor=f"lora_A{suffix}")
        b_key = template.format(layer=layer_name, factor=f"lora_B{suffix}")
        if a_key in lora_weights and b_key in lora_weights:
            return a_key, b_key
    return None


def _lora_delta(lora_weights, keys, param, scale):
    """Return ``(B @ A) * scale`` with both factors moved to ``param``'s device."""
    a_key, b_key = keys
    lora_A = lora_weights[a_key].to(param.device)
    lora_B = lora_weights[b_key].to(param.device)
    return torch.matmul(lora_B, lora_A) * scale


def merge_lora_weights(model, lora_weights, scaling=2.0, scaling_extra=2.0, verbose=True):
    """Add LoRA (and optional "extra" LoRA) deltas into ``model``'s parameters in place.

    For each non-bias parameter ``<layer>.<attr>``, the LoRA factors stored for
    ``<layer>`` are looked up (see ``_LORA_KEY_TEMPLATES``). The summed delta
    ``(B @ A) * scaling + (B_extra @ A_extra) * scaling_extra`` is added to the
    parameter when its shape matches. For a bias parameter, the tensor stored
    under the exact parameter name, if present with the same shape, is added
    to the bias.

    Args:
        model: A ``torch.nn.Module``; its parameters are modified via ``param.data``.
        lora_weights: Mapping from key to tensor (e.g. from ``safetensors.torch.load_file``).
        scaling: Multiplier for the standard ``lora_B @ lora_A`` product.
        scaling_extra: Multiplier for the ``lora_B_extra @ lora_A_extra`` product.
        verbose: If True (default, as before), print one status line per parameter.

    Returns:
        The same ``model`` object, updated in place.
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    for name, param in model.named_parameters():
        layer_name = name.rsplit(".", 1)[0]

        if "bias" in name:
            # Biases are matched by exact parameter name only and ADDED to the
            # current value. NOTE(review): if the checkpoint stores trained full
            # bias values rather than deltas, this double-counts — confirm.
            if name in lora_weights and lora_weights[name].shape == param.data.shape:
                param.data += lora_weights[name].to(param.device)
                log(f"Applied LoRA bias to: {name}")
            else:
                log(f"Skipped LoRA bias update: {name}")
            continue

        delta_weight = None

        standard_keys = _find_lora_keys(lora_weights, layer_name)
        if standard_keys is not None:
            delta_weight = _lora_delta(lora_weights, standard_keys, param, scaling)
            log(f"Applied standard LoRA to: {layer_name}")

        extra_keys = _find_lora_keys(lora_weights, layer_name, suffix="_extra")
        if extra_keys is not None:
            extra_delta = _lora_delta(lora_weights, extra_keys, param, scaling_extra)
            delta_weight = extra_delta if delta_weight is None else delta_weight + extra_delta
            log(f"Applied extra LoRA to: {layer_name}")

        if delta_weight is None:
            log(f"No LoRA update for: {name}")
        elif delta_weight.shape == param.data.shape:
            param.data += delta_weight
            log(f"Updated weight: {name}")
        else:
            log(f"Shape mismatch: {layer_name} - ΔW {delta_weight.shape}, param {param.shape}")

    return model


# Merge the adapter into the base weights. merge_lora_weights updates the model
# in place and returns it. The module-level scaling factors are passed
# explicitly so that editing them actually takes effect.
hf_model = merge_lora_weights(base_model, lora_weights, scaling=scaling, scaling_extra=scaling_extra)

# Wrap the merged HF model in TransformerLens. Presumably model_name selects the
# config while hf_model supplies the fine-tuned weights.
model = HookedTransformer.from_pretrained(model_name, device='cuda', hf_model=hf_model)
# Expose split q/k/v inputs, per-head attention results and MLP-input hooks;
# presumably required by eap's graph/attribution hooks.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

tokenizer = model.tokenizer
ds = EAPDataset('/add_sub_mul_div/4_arithmetic_data/add_sub/add_sub_circuit.csv')
dataloader = ds.to_dataloader(64)

Then we run EAP-IG: we instantiate an unscored graph from the model and use the `attribute` method to score it. This requires a model, a graph, a dataloader, and a loss. We set `method='EAP-IG'` and choose the number of integrated-gradients steps with `ig_steps`.

In [None]:
# Build the (unscored) graph of the model's nodes and edges; attribute() scores it below.
g = Graph.from_model(model)

# Score g with EAP-IG using 5 integrated-gradients steps. The metric is the
# batch-mean KL divergence (kl_divergence ignores its `loss` flag).
# The return value is not used, so the scores are presumably stored on g in place.
attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG', ig_steps=5)

In [None]:
# Show (included node count, included edge count) for the scored graph.
g.count_included_nodes(), g.count_included_edges()

In [None]:
# Print total number of edges in the graph
# (assumes g.edges is a sized container holding every edge — TODO confirm in eap.graph).
total_edges = len(g.edges)
print(f"Total number of edges: {total_edges}")

# Circuit budget: 5% of all edges, truncated toward zero by int().
five_percent_edges = int(total_edges * 0.05)
print(f"5% of the edges: {five_percent_edges}")

We can now extract a circuit from the scored graph: we keep the top 5% of edges ranked by absolute score (`apply_topn`), prune dead nodes, and export the circuit to JSON.

In [None]:
# Keep only the five_percent_edges highest-scoring edges; absolute=True
# presumably ranks by |score|, so strongly negative edges are kept too.
g.apply_topn(five_percent_edges , absolute=True)
# Presumably removes nodes left without any kept edges (see eap.graph.prune_dead_nodes).
g.prune_dead_nodes()
# Export the resulting circuit.
g.to_json('/add_sub_mul_div/graph_results_steps/graph_add_sub_1.4b_r32_epoch4_initial.json')

In [None]:
# Circuit size after top-n selection and pruning: (included nodes, included edges).
g.count_included_nodes(), g.count_included_edges()

In [None]:
import torch
from functools import partial

def calculate_faithfulness(model, g, dataloader, metric_fn):
    """Compare a metric's mean on the full model with its mean on circuit ``g``.

    Args:
        model: HookedTransformer forwarded to the eap evaluate helpers.
        g: eap Graph whose included edges define the circuit.
        dataloader: Batches of evaluation data (presumably the
            (clean, corrupted, labels) batches built by ``collate_EAP``).
        metric_fn: Metric returning per-example values; results are averaged
            with ``.mean()``.

    Returns:
        ``(faithfulness, percentage_performance)``:
        ``faithfulness = |baseline - circuit|``, and
        ``percentage_performance = (1 - faithfulness / |baseline|) * 100``,
        or NaN when the baseline mean is exactly 0.

    Note:
        The percentage is 100 exactly when the circuit's mean equals the
        baseline's mean, in either direction. For a lower-is-better metric
        such as KL divergence, that is not the same as "circuit KL is 0".
        NOTE(review): what each evaluate call passes into metric_fn is defined
        in eap.evaluate — TODO confirm that the two means are comparable.
    """
    # Mean metric for the full model.
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()

    # Mean metric for the discovered circuit.
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn).mean().item()

    # Absolute gap between the two (0 means the circuit reproduces the baseline value).
    faithfulness = abs(baseline_performance - circuit_performance)

    if baseline_performance == 0:
        # The relative score is undefined; avoid dividing by zero.
        percentage_performance = float("nan")
        print("Baseline performance is 0; percentage of model performance is undefined.")
    else:
        # Normalise by |baseline| so that a negative baseline (e.g. a negative
        # logit difference) does not flip the sign of the relative gap.
        percentage_performance = (1 - faithfulness / abs(baseline_performance)) * 100

    print(f"Baseline performance: {baseline_performance}")
    print(f"Circuit performance: {circuit_performance}")
    print(f"Faithfulness: {faithfulness}")
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")

    return faithfulness, percentage_performance

# Per-example KL divergence metric (kl_divergence defined earlier in this notebook).
# mean=False returns one value per example; calculate_faithfulness averages them
# with .mean(). The loss flag has no effect in kl_divergence.
metric_fn = partial(kl_divergence, loss=False, mean=False)

# Compute and print faithfulness and the circuit's percentage of baseline performance.
faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)


In [None]:
from functools import partial

def exact_match_accuracy(logits, corrupted_logits, input_lengths, labels, *, to_string=None):
    """Score each example 1.0 if its greedy next-token prediction equals its label.

    Args:
        logits: Logits to score, indexed as ``logits[b, position, :]``.
        corrupted_logits: Second logits tensor supplied by the eap evaluate
            helpers; unused here (kept to match the metric interface).
        input_lengths: Per-example lengths; the prediction is read at
            position ``input_lengths - 1``.
        labels: Per-example labels (tensor elements or raw CSV values).
            They are compared as stripped strings.
        to_string: Optional callable that maps a token id to its text.
            Defaults to the notebook-global ``model.to_string``.

    Returns:
        Float tensor on ``logits.device`` with 1.0 for a match and 0.0 otherwise.
    """
    if to_string is None:
        # Fall back to the global model so existing call sites keep working.
        to_string = model.to_string

    batch_size = logits.size(0)
    device = logits.device

    # Build the batch index on the same device as logits to avoid cross-device indexing.
    positions = input_lengths.to(device) - 1
    last_logits = logits[torch.arange(batch_size, device=device), positions]
    predicted_tokens = last_logits.argmax(dim=-1)

    # NOTE(review): only the single argmax token is decoded, so a label that
    # tokenizes to several tokens can never match — confirm labels are single tokens.
    predicted_strings = [to_string(token.item()).strip() for token in predicted_tokens]
    label_strings = [
        str(lab.item() if isinstance(lab, torch.Tensor) else lab).strip()
        for lab in labels
    ]

    correct = [
        1.0 if pred == label else 0.0
        for pred, label in zip(predicted_strings, label_strings)
    ]
    return torch.tensor(correct, device=device)

# Exact-match accuracy of the full model: the mean of the per-example 0/1 scores
# (quiet=True is forwarded to eap's evaluate_baseline).
baseline_accuracy = evaluate_baseline(model, dataloader, exact_match_accuracy, quiet=True).mean().item()

# Exact-match accuracy of the circuit g (presumably the model restricted to g's
# included edges, per eap.evaluate.evaluate_graph).
graph_accuracy = evaluate_graph(model, g, dataloader, exact_match_accuracy, quiet=True).mean().item()

print(f"Baseline model accuracy: {baseline_accuracy:.4f}; Graph model accuracy: {graph_accuracy:.4f}")
