## Random

In [None]:

from functools import partial
from typing import Optional, List, Union, Literal, Tuple
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
from transformers import AutoTokenizer 
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits
from eap.attribute import attribute 
from eap.attribute import tokenize_plus
from eap.metrics import logit_diff, direct_logit
import re

def collate_EAP(xs):
    """Transpose a batch of (clean, corrupted, label) triples into three columns.

    Returns the clean and corrupted entries as lists and the labels as a tuple.
    """
    # zip(*xs) transposes the batch; the unpacking requires exactly three columns.
    clean_col, corrupted_col, labels = zip(*xs)
    return list(clean_col), list(corrupted_col), labels

class EAPDataset(Dataset):
    """Dataset of (clean, corrupted, label) examples backed by a CSV file.

    The CSV is read into ``self.df``; items are served from its ``clean``,
    ``corrupted`` and ``label`` columns.
    """

    def __init__(self, filepath):
        self.df = pd.read_csv(filepath)

    def __len__(self):
        return len(self.df.index)

    def shuffle(self):
        """Replace the frame with a randomly permuted copy of all its rows."""
        permuted = self.df.sample(frac=1)
        self.df = permuted

    def head(self, n: int):
        """Truncate the dataset in place to its first ``n`` rows."""
        truncated = self.df.head(n)
        self.df = truncated

    def __getitem__(self, index):
        # Positional lookup, so it keeps working after shuffle() reorders the index.
        record = self.df.iloc[index]
        return tuple(record[column] for column in ('clean', 'corrupted', 'label'))

    def to_dataloader(self, batch_size: int):
        """Wrap this dataset in a DataLoader that batches with ``collate_EAP``."""
        return DataLoader(self, collate_fn=collate_EAP, batch_size=batch_size)
    
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    """Pick, for each batch row, the logits at index ``input_length - 1`` of dim 1.

    ``logits`` is indexed as ``logits[row, input_length[row] - 1]``, so the
    result drops the second dimension of ``logits``.
    """
    row_index = torch.arange(logits.size(0), device=logits.device)
    final_position = input_length - 1
    return logits[row_index, final_position]


def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    """KL-divergence-based metric between two sets of logits at the final input position.

    Args:
        logits: logits of the run being scored.
        clean_logits: reference logits; used as the ``target`` of ``F.kl_div``.
        input_length: per-example input lengths; position ``input_length - 1`` is scored.
        labels: unused; kept so the signature stays unchanged.
        mean: if True return the batch mean (a scalar), otherwise one value per example.
        loss: unused; kept so the signature stays unchanged.

    Returns:
        A scalar tensor when ``mean`` is True, else a tensor with one entry per example.
    """
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    # log_softmax instead of softmax(...).log(): the latter yields -inf when a
    # probability underflows to 0, which turns the KL terms into NaN.
    log_probs = F.log_softmax(logits, dim=-1)
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)

    # NOTE: the pointwise terms are averaged (not summed) over the last
    # dimension, so the value is the KL divergence divided by that dimension's size.
    results = F.kl_div(log_probs, clean_log_probs, log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

from transformers import AutoConfig, AutoModelForCausalLM

model_name = 'pythia-1.4b-deduped'
base_model_name = "EleutherAI/pythia-1.4b-deduped"
# Build the HF model from its config only. NOTE(review): presumably this leaves
# the weights randomly initialised (the "Random" control) -- confirm against the
# transformers `from_config` documentation.
config = AutoConfig.from_pretrained(base_model_name)
hf_model = AutoModelForCausalLM.from_config(config)

model = HookedTransformer.from_pretrained(model_name, device='cuda', hf_model=hf_model)

# Config flags switched on before the graph is built.
# NOTE(review): presumably required by the eap attribution code -- confirm.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

tokenizer = model.tokenizer

ds = EAPDataset('/2_arithmetic_operations_100/stable_analysis/Split_data/Logical_Operations_Split_4.csv')
dataloader = ds.to_dataloader(64)

circuit_path = '/2_arithmetic_operations_100/stable_analysis/Split_circuits/random/graph_simplemath_1.4b_split4.json'
# Create the output directory up front so a missing folder cannot make the
# save fail only after the expensive attribution run has finished.
os.makedirs(os.path.dirname(circuit_path), exist_ok=True)

# Instantiate a graph with a model
g = Graph.from_model(model)

# Attribute using the model, graph, clean / corrupted data and labels, as well as a metric
attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG', ig_steps=5)

# Print total number of edges in the graph
total_edges = len(g.edges)
print(f"Total number of edges: {total_edges}")

# Calculate 5% of the edges
five_percent_edges = int(total_edges * 0.05)
print(f"5% of the edges: {five_percent_edges}")
g.apply_topn(five_percent_edges, absolute=True)
g.prune_dead_nodes()
g.to_json(circuit_path)

import torch
from functools import partial

# Function to calculate faithfulness and percentage of model performance
def calculate_faithfulness(model, g, dataloader, metric_fn):
    """Compare the circuit ``g`` against the full model under ``metric_fn``.

    Args:
        model: model passed through to ``evaluate_baseline`` / ``evaluate_graph``.
        g: the circuit graph to evaluate.
        dataloader: data both evaluations run on.
        metric_fn: metric handed to both evaluators.

    Returns:
        ``(faithfulness, percentage_performance)``: the absolute difference of
        the two mean metric values, and ``(1 - faithfulness / baseline) * 100``.
        The percentage is NaN when the baseline mean is exactly 0, because the
        ratio is undefined there.
    """
    # Evaluate baseline (full model performance)
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()

    # Evaluate the discovered circuit's performance
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn).mean().item()

    # Calculate the absolute difference (faithfulness)
    faithfulness = abs(baseline_performance - circuit_performance)

    # Calculate the percentage of model performance achieved by the circuit.
    # A zero baseline would raise ZeroDivisionError; report NaN instead.
    if baseline_performance == 0:
        percentage_performance = float("nan")
    else:
        percentage_performance = (1 - faithfulness / baseline_performance) * 100

    print(f"Baseline performance: {baseline_performance}")
    print(f"Circuit performance: {circuit_performance}")
    print(f"Faithfulness: {faithfulness}")
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")

    return faithfulness, percentage_performance

# KL-divergence metric for evaluation: mean=False keeps one value per example
# instead of collapsing the batch to a scalar.
metric_fn = partial(kl_divergence, loss=False, mean=False)

# Calculate faithfulness and percentage performance of the pruned graph `g`
faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)


from functools import partial

def exact_match_accuracy(logits, corrupted_logits, input_lengths, labels):
    """Score each example 1.0 if the decoded top-1 token equals its label, else 0.0.

    The prediction is the argmax of ``logits`` at position ``input_lengths - 1``,
    decoded with the module-level ``model`` and whitespace-stripped. Labels are
    converted with ``str`` and stripped before comparison. ``corrupted_logits``
    is not used.

    Returns:
        A float tensor with one entry per example, on the device of ``logits``.
    """
    n_examples = logits.size(0)
    target_device = logits.device

    # Top-1 token at the final input position of every example.
    final_positions = input_lengths - 1
    final_logits = logits[torch.arange(n_examples), final_positions, :]
    top_tokens = final_logits.argmax(dim=-1)

    predictions = [model.to_string(tok.item()).strip() for tok in top_tokens]

    def _label_text(raw):
        # Tensor labels are unwrapped to Python scalars before stringifying.
        value = raw.item() if isinstance(raw, torch.Tensor) else raw
        return str(value).strip()

    targets = [_label_text(labels[k]) for k in range(n_examples)]

    scores = [float(pred == target) for pred, target in zip(predictions, targets)]
    return torch.tensor(scores, device=target_device)

# Exact-match accuracy of the full model, averaged over all evaluated examples
baseline_accuracy = evaluate_baseline(model, dataloader, exact_match_accuracy, quiet=True).mean().item()

# Exact-match accuracy of the pruned graph `g` on the same dataloader
graph_accuracy = evaluate_graph(model, g, dataloader, exact_match_accuracy, quiet=True).mean().item()

print(f"Baseline model accuracy: {baseline_accuracy:.4f}; Graph model accuracy: {graph_accuracy:.4f}")

## Pretrained

In [None]:
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute
from eap.metrics import logit_diff
import re
from functools import partial

# Collate function for DataLoader
def collate_EAP(xs):
    """Split a batch of (clean, corrupted, label) triples into parallel columns.

    The clean and corrupted columns come back as lists, the labels as a tuple.
    """
    columns = zip(*xs)  # transpose: one tuple per field
    clean, corrupted, labels = columns
    clean = list(clean)
    corrupted = list(corrupted)
    return clean, corrupted, labels

# Dataset class for loading data
class EAPDataset(Dataset):
    """CSV-backed dataset yielding (clean, corrupted, label) tuples.

    Rows are held in the DataFrame ``self.df`` and read from its ``clean``,
    ``corrupted`` and ``label`` columns.
    """

    def __init__(self, filepath):
        frame = pd.read_csv(filepath)
        self.df = frame

    def __len__(self):
        return len(self.df.index)

    def shuffle(self):
        """Randomly reorder all rows (a full-size sample)."""
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        """Keep only the first ``n`` rows."""
        self.df = self.df.head(n)

    def __getitem__(self, index):
        # iloc is positional, so a shuffled index does not affect lookups.
        entry = self.df.iloc[index]
        clean, corrupted, label = entry['clean'], entry['corrupted'], entry['label']
        return clean, corrupted, label

    def to_dataloader(self, batch_size: int):
        """Return a DataLoader over this dataset that collates with ``collate_EAP``."""
        loader = DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)
        return loader

# Helper function to extract logits at specific positions
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    """Return ``logits[b, input_length[b] - 1]`` for every batch row ``b``."""
    rows = torch.arange(logits.size(0), device=logits.device)
    selected = logits[rows, input_length - 1]
    return selected

# KL divergence loss function
def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    """KL-divergence-based metric between two sets of logits at the final input position.

    Args:
        logits: logits of the run being scored.
        clean_logits: reference logits; passed to ``F.kl_div`` as the target.
        input_length: per-example input lengths; position ``input_length - 1`` is scored.
        labels: unused; kept so the signature stays unchanged.
        mean: True returns the batch mean (scalar); False returns one value per example.
        loss: unused; kept so the signature stays unchanged.
    """
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)
    # Compute log-probabilities directly: softmax(...).log() produces -inf (and
    # then NaN inside kl_div) whenever a probability underflows to zero.
    log_probs = F.log_softmax(logits, dim=-1)
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)
    # Pointwise terms are averaged, not summed, over the last dimension.
    results = F.kl_div(log_probs, clean_log_probs, log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

# Exact match accuracy metric
def exact_match_accuracy(logits, corrupted_logits, input_lengths, labels):
    """Per-example exact-match score of the top-1 token at the final input position.

    The argmax token at ``input_lengths - 1`` is decoded via the module-level
    ``model`` and compared, after stripping whitespace, with ``str(label)``.
    ``corrupted_logits`` is ignored.

    Returns:
        Float tensor on ``logits.device`` holding 1.0 for a match and 0.0 otherwise.
    """
    batch_size = logits.size(0)
    device = logits.device

    # Logits at the last input position of each sequence -> predicted token ids.
    last_positions = input_lengths - 1
    batch_rows = torch.arange(batch_size)
    predicted_tokens = logits[batch_rows, last_positions, :].argmax(dim=-1)

    # Decode every prediction to a stripped string.
    predicted_strings = []
    for token in predicted_tokens:
        decoded = model.to_string(token.item())
        predicted_strings.append(decoded.strip())

    # Normalise labels the same way (tensor labels are unwrapped first).
    label_strings = []
    for row in range(batch_size):
        label = labels[row]
        if isinstance(label, torch.Tensor):
            label = label.item()
        label_strings.append(str(label).strip())

    matches = [
        1.0 if predicted == expected else 0.0
        for predicted, expected in zip(predicted_strings, label_strings)
    ]
    return torch.tensor(matches, device=device)

# Function to calculate faithfulness and percentage performance
def calculate_faithfulness(model, g, dataloader, metric_fn):
    """Compare the circuit ``g`` with the full model under ``metric_fn``.

    Returns:
        ``(faithfulness, percentage_performance)``, where faithfulness is the
        absolute difference between the mean baseline and mean circuit metric,
        and the percentage is ``(1 - faithfulness / baseline) * 100``. The
        percentage is NaN when the baseline mean is exactly 0 (ratio undefined).
    """
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn).mean().item()
    faithfulness = abs(baseline_performance - circuit_performance)
    # Guard the ratio: a zero baseline would raise ZeroDivisionError.
    if baseline_performance == 0:
        percentage_performance = float("nan")
    else:
        percentage_performance = (1 - faithfulness / baseline_performance) * 100
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")
    return faithfulness, percentage_performance

split_data_path = '/2_arithmetic_operations_100/stable_analysis/Split_data'
circuit_save_path = '/2_arithmetic_operations_100/stable_analysis/Split_circuits/pretrained'
# Make sure the output directory exists before any circuit is written, so a
# missing folder cannot make the save fail after an attribution run.
os.makedirs(circuit_save_path, exist_ok=True)

model_name = 'pythia-1.4b-deduped'
model = HookedTransformer.from_pretrained(model_name, device='cuda')
# Config flags switched on before any graph is built.
# NOTE(review): presumably required by the eap attribution code -- confirm.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

tokenizer = model.tokenizer
# Evaluation metric: per-example KL values (mean=False), no batch averaging.
metric_fn = partial(kl_divergence, loss=False, mean=False)

n_splits = 5
batch_size = 64

# Process each split: attribute, prune to the top 5% of edges, save, then evaluate
for i in range(n_splits):
    print(f"Processing split {i} ...")
    # Input CSV and output JSON for this split
    dataset_path = os.path.join(split_data_path, f"Logical_Operations_Split_{i}.csv")
    circuit_path = os.path.join(circuit_save_path, f"graph_simplemath_1.4b_split{i}.json")

    ds = EAPDataset(dataset_path)
    dataloader = ds.to_dataloader(batch_size=batch_size)
    # A fresh graph per split, attributed with the batch-mean KL metric
    g = Graph.from_model(model)
    attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG', ig_steps=5)

    total_edges = len(g.edges)
    print(f"Total number of edges: {total_edges}")

    # Edge budget: 5% of all edges, rounded down
    five_percent_edges = int(total_edges * 0.05)
    print(f"5% of the edges: {five_percent_edges}")
    # NOTE(review): absolute=True presumably ranks edges by absolute score -- confirm in eap.graph
    g.apply_topn(five_percent_edges, absolute=True)
    g.prune_dead_nodes()

    g.to_json(circuit_path)
    print(f"Saved circuit to {circuit_path}")

    # Calculate faithfulness and percentage performance for the final graph
    faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)
    # Exact-match accuracy of the full model vs. the pruned graph on the same split
    baseline_accuracy = evaluate_baseline(model, dataloader, exact_match_accuracy, quiet=True).mean().item()
    graph_accuracy = evaluate_graph(model, g, dataloader, exact_match_accuracy, quiet=True).mean().item()
    print(f"Baseline model accuracy: {baseline_accuracy:.4f}; Graph model accuracy: {graph_accuracy:.4f}\n")


## Finetuned

In [None]:
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" 
from functools import partial
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformer_lens import HookedTransformer
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute
from eap.metrics import logit_diff
import re
from safetensors.torch import load_file

# Collate function for DataLoader
def collate_EAP(xs):
    """Regroup a batch of (clean, corrupted, label) triples field by field.

    Returns ``(clean, corrupted, labels)`` with the first two as lists and the
    labels left as a tuple.
    """
    transposed = zip(*xs)
    clean_items, corrupted_items, labels = transposed
    return [*clean_items], [*corrupted_items], labels

# Dataset class for loading data
class EAPDataset(Dataset):
    """Dataset over a CSV file with ``clean``, ``corrupted`` and ``label`` columns.

    The parsed file is kept in ``self.df``; each item is one row returned as a
    ``(clean, corrupted, label)`` tuple.
    """

    def __init__(self, filepath):
        self.df = pd.read_csv(filepath)

    def __len__(self):
        n_rows = len(self.df)
        return n_rows

    def shuffle(self):
        """Swap ``self.df`` for a random permutation of its rows."""
        reordered = self.df.sample(frac=1)
        self.df = reordered

    def head(self, n: int):
        """Shrink ``self.df`` to its first ``n`` rows."""
        leading = self.df.head(n)
        self.df = leading

    def __getitem__(self, index):
        # Positional access: unaffected by the index order left behind by shuffle().
        item = self.df.iloc[index]
        fields = [item[name] for name in ('clean', 'corrupted', 'label')]
        return tuple(fields)

    def to_dataloader(self, batch_size: int):
        """Create a DataLoader for this dataset using ``collate_EAP`` for batching."""
        return DataLoader(dataset=self, batch_size=batch_size, collate_fn=collate_EAP)

# Helper function to extract logits at specific positions
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    """Gather the logits at position ``input_length - 1`` for each batch row."""
    n_rows = logits.size(0)
    batch_rows = torch.arange(n_rows, device=logits.device)
    last_index = input_length - 1
    return logits[batch_rows, last_index]

# KL divergence loss function
def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    """KL-divergence-based metric between two sets of logits at the final input position.

    Args:
        logits: logits of the run being scored.
        clean_logits: reference logits; passed to ``F.kl_div`` as the target.
        input_length: per-example input lengths; position ``input_length - 1`` is scored.
        labels: unused; kept so the signature stays unchanged.
        mean: True returns the batch mean (scalar); False returns one value per example.
        loss: unused; kept so the signature stays unchanged.
    """
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)
    # log_softmax avoids the -inf / NaN that softmax(...).log() produces when a
    # probability underflows to zero.
    log_probs = F.log_softmax(logits, dim=-1)
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)
    # Pointwise terms are averaged, not summed, over the last dimension.
    results = F.kl_div(log_probs, clean_log_probs, log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

# Exact match accuracy metric
def exact_match_accuracy(logits, corrupted_logits, input_lengths, labels):
    """Return 1.0 / 0.0 per example for whether the top-1 token matches the label.

    The top-1 token at ``input_lengths - 1`` is decoded with the module-level
    ``model``; both it and ``str(label)`` are stripped before the comparison.
    ``corrupted_logits`` is unused.
    """
    batch_size = logits.size(0)
    device = logits.device

    # Predicted token id at the last input position of every sequence.
    row_ids = torch.arange(batch_size)
    predicted_tokens = logits[row_ids, input_lengths - 1, :].argmax(dim=-1)
    predicted_strings = [model.to_string(token.item()).strip() for token in predicted_tokens]

    def _to_text(label):
        # Unwrap tensor labels to Python scalars before converting to text.
        if isinstance(label, torch.Tensor):
            label = label.item()
        return str(label).strip()

    label_strings = [_to_text(labels[row]) for row in range(batch_size)]

    correct = []
    for predicted, expected in zip(predicted_strings, label_strings):
        correct.append(1.0 if predicted == expected else 0.0)
    return torch.tensor(correct, device=device)

# Function to calculate faithfulness and percentage performance
def calculate_faithfulness(model, g, dataloader, metric_fn):
    """Compare the circuit ``g`` with the full model under ``metric_fn``.

    Returns:
        ``(faithfulness, percentage_performance)``: the absolute difference of
        the mean baseline and mean circuit metric, and
        ``(1 - faithfulness / baseline) * 100``. The percentage is NaN when the
        baseline mean is exactly 0, since the ratio is undefined there.
    """
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn).mean().item()
    faithfulness = abs(baseline_performance - circuit_performance)
    # Guard the ratio: a zero baseline would raise ZeroDivisionError.
    if baseline_performance == 0:
        percentage_performance = float("nan")
    else:
        percentage_performance = (1 - faithfulness / baseline_performance) * 100
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")
    return faithfulness, percentage_performance

# ----------------------------
# Finetuned model loading and LoRA weight merging
# ----------------------------
model_name = 'pythia-1.4b-deduped'
base_model_path = "EleutherAI/pythia-1.4b-deduped"
lora_weights_path = "/2_arithmetic_operations_100/finetune_pythia_100/Peft/lora_results/r32a64/checkpoint-250/adapter_model.safetensors"

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(base_model_path)
# Load the LoRA weights (a name -> tensor mapping read from the safetensors file)
lora_weights = load_file(lora_weights_path)

scaling = 2  # scaling factor for the standard LoRA delta
scaling_extra = 2   # scaling factor for the extra ("enhanced") LoRA delta

def merge_lora_weights(model, lora_weights, scaling=2.0, scaling_extra=2.0):
    """Fold LoRA deltas from ``lora_weights`` into ``model``'s parameters in place.

    For a parameter named ``<layer>.<suffix>`` the following keys are looked up
    in ``lora_weights``:

    * ``<layer>.lora_A`` / ``<layer>.lora_B``: adds ``(B @ A) * scaling``
    * ``<layer>.lora_A_extra`` / ``<layer>.lora_B_extra``: adds
      ``(B_extra @ A_extra) * scaling_extra``

    Parameters whose name contains ``"bias"`` are instead looked up under their
    own full name and, if the shapes agree, the stored tensor is added to them.
    A weight delta is applied only when its shape equals the parameter's shape.

    Args:
        model: module whose ``named_parameters()`` are updated in place.
        lora_weights: mapping from key to tensor.
        scaling: multiplier for the standard LoRA product.
        scaling_extra: multiplier for the extra LoRA product.

    Returns:
        The same ``model`` object. A ``UserWarning`` is emitted when not a
        single parameter was updated, because that almost always means the key
        names in ``lora_weights`` do not follow the layout described above.
    """
    import warnings  # local import: only needed for the no-op diagnostic below

    updated_count = 0
    for name, param in model.named_parameters():
        layer_name = name.rsplit(".", 1)[0]
        if "bias" in name:
            if name in lora_weights and lora_weights[name].shape == param.data.shape:
                param.data += lora_weights[name].to(param.device)
                updated_count += 1
                print(f"Applied LoRA bias to: {name}")
            else:
                print(f"Skipped LoRA bias update: {name}")
            continue

        # (A key, B key, scale, label used in the log line) for both adapter kinds.
        adapter_specs = (
            (f"{layer_name}.lora_A", f"{layer_name}.lora_B", scaling, "standard"),
            (f"{layer_name}.lora_A_extra", f"{layer_name}.lora_B_extra", scaling_extra, "extra"),
        )

        delta_weight = None
        for a_key, b_key, factor, label in adapter_specs:
            if a_key in lora_weights and b_key in lora_weights:
                lora_a = lora_weights[a_key].to(param.device)
                lora_b = lora_weights[b_key].to(param.device)
                term = torch.matmul(lora_b, lora_a) * factor
                delta_weight = term if delta_weight is None else delta_weight + term
                print(f"Applied {label} LoRA to: {layer_name}")

        if delta_weight is None:
            print(f"No LoRA update for: {name}")
        elif delta_weight.shape == param.data.shape:
            param.data += delta_weight
            updated_count += 1
            print(f"Updated weight: {name}")
        else:
            print(f"Shape mismatch: {layer_name} - ΔW {delta_weight.shape}, param {param.shape}")

    if updated_count == 0:
        # Without this, a key-naming mismatch silently leaves the base model
        # untouched while the caller believes the adapter was merged.
        warnings.warn(
            "merge_lora_weights: no parameter was updated; check that the keys "
            "in lora_weights match the model's parameter names",
            stacklevel=2,
        )
    return model

# Pass the configured scaling factors explicitly; relying on the function's
# defaults would silently ignore any change made to `scaling` / `scaling_extra`.
hf_model = merge_lora_weights(base_model, lora_weights, scaling=scaling, scaling_extra=scaling_extra)

model = HookedTransformer.from_pretrained(model_name, device='cuda', hf_model=hf_model)
# Config flags switched on before any graph is built.
# NOTE(review): presumably required by the eap attribution code -- confirm.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
# Prepare tokenizer and metric (per-example KL values, no batch averaging)
tokenizer = model.tokenizer
metric_fn = partial(kl_divergence, loss=False, mean=False)

split_data_path = '/2_arithmetic_operations_100/stable_analysis/Split_data'
circuit_save_path = '/2_arithmetic_operations_100/stable_analysis/Split_circuits/finetuned'
os.makedirs(circuit_save_path, exist_ok=True)

n_splits = 5
batch_size = 64

for i in range(n_splits):
    print(f"\nProcessing split {i} ...")
    # Input CSV and output JSON for this split
    dataset_path = os.path.join(split_data_path, f"Logical_Operations_Split_{i}.csv")
    circuit_path = os.path.join(circuit_save_path, f"graph_simplemath_1.4b_split{i}.json")

    # Load dataset
    ds = EAPDataset(dataset_path)
    dataloader = ds.to_dataloader(batch_size=batch_size)

    # Instantiate a fresh graph with the model for every split
    g = Graph.from_model(model)

    # Attribute using the model, graph, and data, with the batch-mean KL metric
    attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG', ig_steps=5)

    # Print total number of edges in the graph
    total_edges = len(g.edges)
    print(f"Total number of edges: {total_edges}")

    # Calculate 5% of the edges (rounded down) and apply pruning
    # NOTE(review): absolute=True presumably ranks edges by absolute score -- confirm in eap.graph
    five_percent_edges = int(total_edges * 0.05)
    print(f"5% of the edges: {five_percent_edges}")
    g.apply_topn(five_percent_edges, absolute=True)
    g.prune_dead_nodes()

    # Save the graph (circuit)
    g.to_json(circuit_path)
    print(f"Saved circuit to {circuit_path}")

    # Evaluate faithfulness and exact-match accuracy of the pruned graph vs. the full model
    faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)
    baseline_accuracy = evaluate_baseline(model, dataloader, exact_match_accuracy, quiet=True).mean().item()
    graph_accuracy = evaluate_graph(model, g, dataloader, exact_match_accuracy, quiet=True).mean().item()
    print(f"Baseline model accuracy: {baseline_accuracy:.4f}; Graph model accuracy: {graph_accuracy:.4f}\n")


## Robust

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt


def load_circuits(circuit_dir, n_splits):
    """Load the edge set of every split's circuit JSON from ``circuit_dir``.

    Files are expected at ``graph_simplemath_1.4b_split{i}.json`` for
    ``i in range(n_splits)``. The ``"edges"`` entry may be:

    * a list of edges: each edge is converted to a tuple (original behaviour);
    * a mapping from edge name to per-edge info: the edge names are collected,
      skipping entries whose info dict has a falsy ``"in_graph"`` value.
      NOTE(review): this assumes a ``{name: {"in_graph": bool, ...}}`` layout --
      confirm against whatever writes these files.

    Args:
        circuit_dir: directory containing the per-split JSON files.
        n_splits: number of split files to read.

    Returns:
        A list with one set of edges per split; a file without an ``"edges"``
        key contributes an empty set.
    """
    circuits = []
    for i in range(n_splits):
        file_path = os.path.join(circuit_dir, f"graph_simplemath_1.4b_split{i}.json")
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        edges = data.get("edges", [])
        if isinstance(edges, dict):
            # Iterating a mapping yields its keys; tuple(key) would split each
            # name into characters and keep pruned edges, so handle it explicitly.
            edge_set = {
                edge_name
                for edge_name, info in edges.items()
                if not isinstance(info, dict) or info.get("in_graph", True)
            }
        else:
            edge_set = {tuple(edge) for edge in edges}
        circuits.append(edge_set)
    return circuits


def jaccard_similarity(set1, set2):
    """Return ``|set1 ∩ set2| / |set1 ∪ set2|``, or 0.0 when the union is empty."""
    combined = set1.union(set2)
    if len(combined) == 0:
        return 0.0
    shared = set1.intersection(set2)
    return len(shared) / len(combined)

def compute_pairwise_jaccard(circuits):
    """Return the Jaccard similarity of every unordered pair of circuits.

    Pairs are visited as (i, j) with i < j, in increasing i and then j, so
    ``k`` circuits yield ``k * (k - 1) / 2`` values.
    """
    count = len(circuits)
    return [
        jaccard_similarity(circuits[first], circuits[second])
        for first in range(count)
        for second in range(first + 1, count)
    ]

# Number of per-split circuit files expected in each directory
n_splits = 5

# NOTE(review): these directories (f1 / p1 / r1) differ from the output
# directories used by the attribution cells (finetuned / pretrained / random),
# and the Random cell writes only split 4 -- confirm that five circuits per
# model have been placed here.
fine_tuned_dir = "/2_arithmetic_operations_100/stable_analysis/Split_circuits/f1"
pre_trained_dir = "/2_arithmetic_operations_100/stable_analysis/Split_circuits/p1"
random_dir = "/2_arithmetic_operations_100/stable_analysis/Split_circuits/r1"


# One edge set per split for each model
circuits_finetuned = load_circuits(fine_tuned_dir, n_splits)
circuits_pretrained = load_circuits(pre_trained_dir, n_splits)
circuits_random = load_circuits(random_dir, n_splits)


# Jaccard similarity of every pair of splits within each model
jacc_finetuned = compute_pairwise_jaccard(circuits_finetuned)
jacc_pretrained = compute_pairwise_jaccard(circuits_pretrained)
jacc_random = compute_pairwise_jaccard(circuits_random)


# Per-model styling, in the same order as `all_jacc` below.
colors = ["#104670", "#E47159", "#DFA085"]
markers = ["o", "s", "D"]
labels_list = ["Fine-tuned", "Pre-trained", "Random"]

plt.figure(figsize=(8, 6))
model_indices = [0, 1, 2]
all_jacc = [jacc_finetuned, jacc_pretrained, jacc_random]

for idx, similarities in zip(model_indices, all_jacc):
    # Horizontal jitter around the model's x position so the points do not overlap.
    x_vals = np.random.normal(loc=idx, scale=0.04, size=len(similarities))
    plt.scatter(x_vals, similarities, color=colors[idx], marker=markers[idx],
                label=labels_list[idx], alpha=0.7)
    # Dashed line marking the mean pairwise similarity of this model.
    mean_sim = np.mean(similarities)
    plt.hlines(mean_sim, idx - 0.2, idx + 0.2, colors=colors[idx], linestyles='dashed', linewidth=2)
    print(f"{labels_list[idx]}: mean Jaccard similarity = {mean_sim:.4f}")

plt.xticks(model_indices, labels_list, fontsize=14, fontweight='bold')
plt.xlabel("Model", fontsize=16, fontweight='bold')
plt.ylabel("Robustness Score", fontsize=16, fontweight='bold')
plt.title("Robustness Analysis (Add/Sub)", fontsize=18, fontweight='bold')


plt.ylim(0.5, 1)
plt.yticks(fontsize=14, fontweight='bold')
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.legend(fontsize=12, frameon=True)
plt.tight_layout()
# Save before show(): once show() has displayed (and, depending on the backend,
# released) the figure, a later savefig() can write an empty canvas.
plt.savefig("robust_rebuttal.pdf")
plt.show()