In [None]:
%load_ext autoreload
%autoreload 2
import multiprocessing
multiprocessing.set_start_method("spawn")
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))

In [None]:
import argparse
import numpy as np
import torch
import os
from collections import Counter
from tqdm.auto import tqdm
import sys
from datetime import datetime
import sys
import os
from utils import (
    same_recons, get_topk_acc, get_layer_names,
    evaluate_performance, Logger, prepare_encoders, get_layer_names_new
)
from intervention_utils import intervension
from IterativePatching import initialize_model, prepare_inputs
import glob
def load_dataset(operation, base_path=None, eval="TRAIN"):
    """Load the circuit-finding dataset for one operation and split.

    Looks for files named ``CD_{eval}_{operation}_*.npy`` inside
    ``<base_path>/data/Arco/CircuitFinding`` (relative to the current working
    directory when ``base_path`` is None).

    Args:
        operation: operation name embedded in the file name.
        base_path: optional root directory that contains ``data/``.
        eval: split name embedded in the file name, e.g. ``"TRAIN"``.
            (Shadows the builtin ``eval``; name kept for caller compatibility.)

    Returns:
        The object loaded by ``np.load`` from the first matching file, taking
        matches in sorted file-name order so the choice is deterministic.

    Raises:
        FileNotFoundError: if no file matches the pattern.
    """
    if base_path is None:
        path = os.path.join("data", "Arco", "CircuitFinding")
    else:
        path = os.path.join(base_path, "data", "Arco", "CircuitFinding")
    abs_path = os.path.abspath(path)

    # Escape the directory so characters such as "[" in the path are matched
    # literally instead of being treated as glob syntax.
    pattern = os.path.join(glob.escape(abs_path), f"CD_{eval}_{operation}_*.npy")
    # glob() returns files in arbitrary (filesystem) order; sort so that the
    # same file is picked every time when several match.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No dataset file matches {pattern!r}")

    # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code; only point this at trusted data files.
    data = np.load(matches[0], allow_pickle=True)
    return data

In [None]:
import json
# Load the circuit configurations evaluated further down in this notebook.
# JSON text is UTF-8, so state the encoding explicitly instead of relying on the
# platform's locale default (which differs, e.g., on Windows).
with open("../../circuit_config_2.json", "r", encoding="utf-8") as f:
    configs = json.load(f)

In [None]:
def select_topk_dataset(ranks, datapoints, top1=0, top2=0, top3=0, seed=42):
    """Sample a subset of datapoints stratified by rank (1, 2 or 3).

    Args:
        ranks: per-datapoint rank values, aligned with ``datapoints``.
        datapoints: sequence of datapoints, same length as ``ranks``.
        top1, top2, top3: number of datapoints to draw, without replacement,
            among those whose rank is 1, 2 and 3 respectively.
        seed: seed for the sampler. A private legacy ``RandomState`` is used,
            so the draws are identical to the old ``np.random.seed(seed)``
            behaviour but NumPy's global RNG is left untouched.

    Returns:
        ``(indices, subset)``: the selected indices as a list of ints (rank-1
        picks first, then rank-2, then rank-3) and the corresponding
        datapoints as a list.

    Raises:
        ValueError: if a rank category has fewer datapoints than requested.
    """
    rng = np.random.RandomState(seed)

    ranks = np.array(ranks)
    datapoints = np.array(datapoints, dtype=object)

    requested = {1: top1, 2: top2, 3: top3}
    pools = {rank: np.where(ranks == rank)[0] for rank in requested}

    shortfalls = [
        f"top{rank}: requested {n}, available {len(pools[rank])}"
        for rank, n in requested.items()
        if len(pools[rank]) < n
    ]
    if shortfalls:
        raise ValueError(
            "Not enough samples in one or more top-k categories. "
            f"({'; '.join(shortfalls)})"
        )

    # Draw in rank order 1, 2, 3 so the RNG stream is consumed exactly as in
    # the original sequential np.random.choice calls.
    picks = [rng.choice(pools[rank], requested[rank], replace=False) for rank in (1, 2, 3)]
    all_selected_idx = np.concatenate(picks)
    subset = datapoints[all_selected_idx]
    return all_selected_idx.tolist(), subset.tolist()

In [None]:
# ITERATIVE circuit evaluation.
# For every non-CTR config, score four variants, where the components handed to
# prepare_encoders are the ones replaced by the mean/resample patches:
#   1. model              - nothing patched (baseline),
#   2. circuit            - every component outside the circuit patched,
#   3. model complement   - every component patched,
#   4. circuit complement - only the circuit's components patched,
# then print a LaTeX table row summarising the circuit results.
intervention_class, model = initialize_model(base_path="../..")

# NOTE(review): range(0, 117) yields component ids 0..116 (117 ids), not
# "0 to 117 inclusive". Confirm 117 matches the number of patchable components.
N_COMPONENTS = 117
full_range = set(range(N_COMPONENTS))
EVAL = "TRAIN"
MAX_SAMPLES = 500
TYPE_SUBSET = "correct"  # not read in this cell
layer_names = get_layer_names_new()  # loop-invariant, so computed once


def _load_patches(operation, patch_type, ctr_token, ctr):
    """Load the activation values used to patch components for one config.

    ``"mean"`` loads the shared mean-patching dict; ``"resample"`` loads the
    per-operation cached array (the CTR-only variant when ``ctr`` is truthy).

    Raises:
        NotImplementedError: for any other ``patch_type``.
    """
    if patch_type == "mean":
        return np.load("../../data/Arco/MeanPatching/mean_patching_10000.npy", allow_pickle=True).item()
    if patch_type == "resample":
        # n must match the sample count encoded in the cached file's name.
        if EVAL == "TRAIN":
            n = 100
        elif operation != "log":
            n = 400
        else:
            n = 209
        data_name = f"../../data/Arco/ResamplePatching/cached_values_{n}_{operation}_{ctr_token}_{EVAL}.npz"
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # use it as a context manager so the handle is released. Indexing it
        # reads the array fully into memory before the file is closed.
        with np.load(data_name, allow_pickle=True) as cached:
            if ctr:
                print("Using CTR ONLY")
                return cached["resample_patch_CTR"]
            return cached["resample_patch_mean"]
    raise NotImplementedError(f"Unknown patch_type: {patch_type!r}")


def _evaluate_variant(excluded, mean_patches, patch_type, inputs, operation):
    """Patch the components in ``excluded`` and score the resulting model.

    Uses the notebook globals ``layer_names``, ``model`` and
    ``intervention_class``. ``inputs`` is ``(Xs, ys, equations,
    eqs_untill_target)`` as produced by ``prepare_inputs``.

    Returns:
        ``(ranks, rank_counter, logit_score, (top1, top2, top3))`` where the
        first three are what ``evaluate_performance`` returns and the top-k
        accuracies come from ``get_topk_acc``.
    """
    Xs, ys, equations, eqs_untill_target = inputs
    encoders = prepare_encoders(layer_names, excluded, mean_patches, patch_type, model, Xs, ys)
    ranks, rank_counter, logit_score = evaluate_performance(
        model,
        intervention_class,
        encoders,
        equations,
        eqs_untill_target,
        operation,
        return_activation=True,
    )
    topk = tuple(get_topk_acc(rank_counter, k) for k in (1, 2, 3))
    return ranks, rank_counter, logit_score, topk


# NOTE: `counter` is not a loop counter - it is rebound to each variant's rank
# counter below, so after the loop it holds the last evaluated variant's
# counter (or 0 if every config was skipped).
counter = 0
for config in configs:
    OPERATION = config["operation"]
    PATCH_TYPE = config["patch_type"]  # "mean" or "resample"
    CTR_TOKEN = config["CTR_token"]
    CIRCUIT = config["circuit"]
    CTR = config["CTR"]

    # CTR configs are not evaluated in this notebook.
    if CTR:
        continue

    datapoints = np.array(load_dataset(OPERATION, base_path="../..", eval=EVAL), dtype=object)
    mean_patches = _load_patches(OPERATION, PATCH_TYPE, CTR_TOKEN, CTR)

    # NOTE(review): the positional 200 is interpreted by prepare_inputs - confirm its meaning.
    Xs, ys, equations, eqs_untill_target = prepare_inputs(intervention_class,
                                                          datapoints,
                                                          200,
                                                          max_samples=MAX_SAMPLES)
    variant_args = (mean_patches, PATCH_TYPE, (Xs, ys, equations, eqs_untill_target), OPERATION)

    # 1. Baseline: nothing patched.
    ranks, counter, model_logit_score, (model_top1, model_top2, model_top3) = _evaluate_variant([], *variant_args)
    print(f"model:              Top-1: {model_top1:.3f}, Top-2: {model_top2:.3f}, Top-3: {model_top3:.3f}, logit score: {model_logit_score.mean():.3f}")

    # 2. Circuit only: patch every component outside the circuit.
    excluded = full_range - set(CIRCUIT)
    ranks, counter, circuit_logit_score, (circuit_top1, circuit_top2, circuit_top3) = _evaluate_variant(excluded, *variant_args)
    print(f"Circuit:            Top-1: {circuit_top1:.3f}, Top-2: {circuit_top2:.3f}, Top-3: {circuit_top3:.3f}, logit score: {circuit_logit_score.mean():.3f}")

    # 3. Model complement: patch every component.
    ranks, counter, MC_logit_score, (MC_top1, MC_top2, MC_top3) = _evaluate_variant(full_range, *variant_args)
    print(f"model compliment:   Top-1: {MC_top1:.3f}, Top-2: {MC_top2:.3f}, Top-3: {MC_top3:.3f}, logit score: {MC_logit_score.mean():.3f}")

    # 4. Circuit complement: patch only the circuit's components.
    ranks, counter, CC_logit_score, (CC_top1, CC_top2, CC_top3) = _evaluate_variant(CIRCUIT, *variant_args)
    print(f"Circuit compliment: Top-1: {CC_top1:.3f}, Top-2: {CC_top2:.3f}, Top-3: {CC_top3:.3f}, logit score: {CC_logit_score.mean():.3f}")

    # LaTeX table row for this config.
    print(f"{OPERATION} {PATCH_TYPE.capitalize()[0]}{config['Evaluation_type'].capitalize()[0]} & {len(CIRCUIT)} & {circuit_top1:.3f} & {circuit_top2:.3f} & {circuit_top3:.3f} & {circuit_logit_score.mean():.3f} && {CC_top3:.3f} & {CC_logit_score.mean():.3f} & -  & - ")
print(counter)