# Set configs


In [1]:
from pathlib import Path
import ast
import pandas as pd
import numpy as np

## Inspect the overlapping neuron indices

In [2]:
# Root of the local workspace; the results directories below hang off it.
ROOT = Path("/Users/jliu/workspace/RAG/")
# Directory the neuron CSVs (boost/suppress) are read from.
neuron_path = ROOT.joinpath("results", "token_freq")

In [3]:
def compare_idx(neuron_path,vec,model,neuron_num):
    """Report whether the boost and suppress neuron lists share any index.

    Each CSV stores one stringified Python list per row in its
    ``top_neurons`` column. The rows of each file are parsed individually
    and pooled into a single set before the two sets are intersected.

    Note: a later cell defines another ``compare_idx``; once that cell runs,
    it replaces this definition.

    Args:
        neuron_path: Base directory (``pathlib.Path``) containing the
            ``boost`` and ``suppress`` sub-directories.
        vec: Name of the vector sub-directory (e.g. ``"longtail"``).
        model: Model size tag used in ``pythia-{model}-deduped``.
        neuron_num: Neuron count used in the ``500_{neuron_num}.csv`` name.

    Returns:
        ``True`` if at least one neuron index occurs in both files,
        ``False`` otherwise.
    """
    def _pooled_neurons(effect):
        # Union of the neuron indices listed in every row of one CSV file.
        csv_file = neuron_path/effect/vec/"EleutherAI"/f"pythia-{model}-deduped"/f"500_{neuron_num}.csv"
        pooled = set()
        for cell in pd.read_csv(csv_file)["top_neurons"]:
            # literal_eval only accepts a string, so parse cell by cell
            # rather than handing it the whole column as a list.
            pooled.update(ast.literal_eval(cell))
        return pooled

    return bool(_pooled_neurons("boost") & _pooled_neurons("suppress"))

In [22]:
def compare_idx(neuron_path, vec, model, neuron_num):
    """Check, row by row, whether boost and suppress neuron lists overlap.

    Rows of the two CSV files are paired by position. For every pair whose
    ``top_neurons`` lists share at least one index, the ``step`` value of the
    suppress row and the shared indices are printed.

    Args:
        neuron_path: Base directory (``pathlib.Path``) containing the
            ``boost`` and ``suppress`` sub-directories.
        vec: Name of the vector sub-directory (e.g. ``"longtail"``).
        model: Model size tag used in ``pythia-{model}-deduped``.
        neuron_num: Neuron count used in the ``500_{neuron_num}.csv`` name.

    Returns:
        ``True`` if any paired row shares a neuron index. ``False`` if no
        row overlaps, if either file is missing, or if the files cannot be
        read or parsed (a message is printed in the latter two cases).
    """
    try:
        # Construct file paths
        boost_file = neuron_path / "boost" / vec / "EleutherAI" / f"pythia-{model}-deduped" / f"500_{neuron_num}.csv"
        suppress_file = neuron_path / "suppress" / vec / "EleutherAI" / f"pythia-{model}-deduped" / f"500_{neuron_num}.csv"
        # Check if files exist
        if not boost_file.exists() or not suppress_file.exists():
            print(f"Warning: One or both files do not exist: {boost_file}, {suppress_file}")
            return False
        # Read CSV files
        boost_df = pd.read_csv(boost_file)
        suppress_df = pd.read_csv(suppress_file)

        # Materialise each column once instead of on every iteration.
        boost_rows = boost_df["top_neurons"].tolist()
        suppress_rows = suppress_df["top_neurons"].tolist()
        if len(boost_rows) != len(suppress_rows):
            # Pairing is positional, so only the common prefix is comparable.
            print(
                f"Warning: row counts differ ({len(boost_rows)} boost vs "
                f"{len(suppress_rows)} suppress); comparing the first "
                f"{min(len(boost_rows), len(suppress_rows))} rows only"
            )

        found_overlap = False
        for row, (boost_cell, suppress_cell) in enumerate(zip(boost_rows, suppress_rows)):
            shared = set(ast.literal_eval(boost_cell)) & set(ast.literal_eval(suppress_cell))
            if shared:
                # The step column is only needed for rows that overlap.
                print(suppress_df["step"].tolist()[row])
                print(shared)
                found_overlap = True
        return found_overlap
    except (OSError, KeyError, ValueError, SyntaxError, TypeError) as e:
        # Expected failure modes: unreadable file, missing column, or a cell
        # that is not a valid Python list literal.
        print(f"Error in compare_idx: {e}")
        return False

In [23]:
neuron_name = [10]
models = ["70m","410m"]
vecs = ["longtail","mean"]
# Walk every (vec, model, neuron_num) combination in the same order as the
# equivalent three nested loops: vec outermost, neuron_num innermost.
for vec, model, neuron_num in [(v, m, k) for v in vecs for m in models for k in neuron_name]:
    print(f"{vec}:{model}:{neuron_num}:{compare_idx(neuron_path,vec,model,neuron_num)}")

90000
{'5.311', '5.986', '5.404'}
0
{'5.884', '5.281', '5.381'}
40000
{'5.1622', '5.311', '5.1212', '5.213'}
1
{'5.381', '5.1661', '5.1681', '5.544', '5.1258'}
91000
{'5.302'}
41000
{'5.1622', '5.311', '5.1631', '5.213'}
2
{'5.884', '5.1661', '5.1202', '5.553', '5.1957'}
92000
{'5.311', '5.638', '5.1571', '5.404'}
42000
{'5.1622', '5.311'}
4
{'5.544', '5.1258', '5.553', '5.1007'}
93000
{'5.251', '5.638'}
43000
{'5.367', '5.1622', '5.213', '5.311', '5.1631'}
8
{'5.544', '5.884', '5.1202', '5.770'}
94000
{'5.1571', '5.404'}
44000
{'5.1622', '5.311', '5.838', '5.213'}
95000
{'5.1420', '5.1342', '5.251', '5.1571', '5.311', '5.838'}
16
{'5.1957', '5.1258', '5.68'}
45000
{'5.311'}
32
{'5.1138', '5.1298', '5.1585', '5.1739', '5.68', '5.347'}
96000
{'5.1420', '5.404', '5.302'}
46000
{'5.1622', '5.311', '5.213'}
97000
{'5.1420', '5.1571'}
64
{'5.1486', '5.1298', '5.1366'}
47000
{'5.1622', '5.311', '5.367'}
98000
{'5.302'}
128
{'5.1486', '5.1227', '5.121'}
48000
{'5.1175', '5.311', '5.213'}
9900

## Inspect the full ablation experiment (.feather)

In [2]:
# Root of the local workspace; the results directories below hang off it.
ROOT = Path("/Users/jliu/workspace/RAG/")
# Directory the ablation .feather files are read from.
ablation_path = ROOT.joinpath("results", "ablations")
# Directory the neuron CSVs (boost/suppress) are read from.
neuron_path = ROOT.joinpath("results", "token_freq")

In [3]:
# Selection of the ablation run to inspect.
vector = "longtail"
model = "70m"
ckpt = "10000"
# <ablation_path>/<vector>/EleutherAI/pythia-<model>-deduped/<ckpt>/500/k10.feather
feather_path = ablation_path.joinpath(
    vector, "EleutherAI", f"pythia-{model}-deduped", ckpt, "500", "k10.feather"
)
data = pd.read_feather(feather_path)

In [15]:
# Number of distinct `component_name` values in the loaded frame.
len(set(data["component_name"]))

2048

In [11]:
# Largest (closest to zero) value among the strictly negative activations.
data[data["activation"]<0]["activation"].max()

-1.5936679e-07

In [12]:
# Write the first 100 rows of `data` to `sample.csv`; the path is relative,
# so the file lands in the kernel's current working directory.
sel_df = data.head(100)
sel_df.to_csv("sample.csv")

In [31]:
effect = "suppress"

# Reload the ablation frame and append the loss-delta columns (signed and
# absolute) measured against the un-ablated `loss`.
final_df = pd.read_feather(feather_path).assign(
    abs_delta_loss_post_ablation=lambda df: np.abs(df["loss_post_ablation"] - df["loss"]),
    abs_delta_loss_post_ablation_with_frozen_unigram=lambda df: np.abs(
        df["loss_post_ablation_with_frozen_unigram"] - df["loss"]
    ),
    delta_loss_post_ablation=lambda df: df["loss_post_ablation"] - df["loss"],
    delta_loss_post_ablation_with_frozen_unigram=lambda df: (
        df["loss_post_ablation_with_frozen_unigram"] - df["loss"]
    ),
)

# The KL-difference columns are only added when the frame carries KL data.
if "kl_divergence_before" in final_df.columns:
    final_df = final_df.assign(
        kl_from_unigram_diff=lambda df: df["kl_divergence_after"] - df["kl_divergence_before"],
        abs_kl_from_unigram_diff=lambda df: np.abs(
            df["kl_divergence_after"] - df["kl_divergence_before"]
        ),
    )


In [30]:
# Total number of rows in `final_df`.
final_df.shape[0]

5144576

In [37]:

# Partition the rows of `final_df` by the sign of `kl_from_unigram_diff`:
# negative, positive and exactly zero.
final_df1 = final_df.loc[final_df["kl_from_unigram_diff"].lt(0)]
final_df2 = final_df.loc[final_df["kl_from_unigram_diff"].gt(0)]
final_df3 = final_df.loc[final_df["kl_from_unigram_diff"].eq(0)]

In [33]:
# Row count of `final_df1` (the `kl_from_unigram_diff < 0` subset in the
# visible cell that assigns it).
final_df1.shape[0]

2011341

In [38]:
# Sanity check: the three sign-based subsets should add up to the row count
# of `final_df`; the total falls short if any `kl_from_unigram_diff` is NaN,
# because NaN fails all three comparisons.
final_df2.shape[0] + final_df1.shape[0] + final_df3.shape[0]

5144576

In [None]:
# Display the `kl_from_unigram_diff` column of `final_df1`.
final_df1 ["kl_from_unigram_diff"]

In [4]:

def select_top_token_frequency_neurons(feather_path: Path, top_n: int, step:int,effect:str) -> "pd.DataFrame | None":
    """Rank neurons from one ablation file and summarise the top ``top_n``.

    The feather file is expected to hold one row per (neuron, token) pair
    with the columns ``component_name``, ``loss``, ``loss_post_ablation``,
    ``loss_post_ablation_with_frozen_unigram``, ``kl_divergence_before`` and
    ``kl_divergence_after``.

    Args:
        feather_path: Path of the ablation ``.feather`` file.
        top_n: Number of top-ranked neurons to keep.
        step: Value written to the ``step`` column of the result.
        effect: ``"suppress"`` keeps rows with a negative KL difference,
            ``"boost"`` keeps rows with a positive one; any other value
            applies no filter.

    Returns:
        A one-row DataFrame with the columns ``step``, ``top_neurons``,
        ``med_effect``, ``kl_diff``, ``delta_loss_post`` and
        ``delta_loss_post_frozen``; every column except ``step`` holds a
        list of up to ``top_n`` values in ranking order. ``None`` is
        returned when ``feather_path`` is not an existing file.

    Raises:
        KeyError: If the KL-divergence columns are missing; they are needed
            both for the effect filter and for the ranking.
    """
    if not feather_path.is_file():
        return None

    final_df = pd.read_feather(feather_path)

    # Fail early with a readable message instead of an opaque KeyError later.
    missing_cols = [
        col for col in ("kl_divergence_before", "kl_divergence_after") if col not in final_df.columns
    ]
    if missing_cols:
        raise KeyError(f"{feather_path} is missing required column(s): {missing_cols}")

    delta_loss = final_df["loss_post_ablation"] - final_df["loss"]
    delta_loss_frozen = final_df["loss_post_ablation_with_frozen_unigram"] - final_df["loss"]
    kl_diff = final_df["kl_divergence_after"] - final_df["kl_divergence_before"]
    final_df = final_df.assign(
        abs_delta_loss_post_ablation=delta_loss.abs(),
        abs_delta_loss_post_ablation_with_frozen_unigram=delta_loss_frozen.abs(),
        delta_loss_post_ablation=delta_loss,
        delta_loss_post_ablation_with_frozen_unigram=delta_loss_frozen,
        kl_from_unigram_diff=kl_diff,
        abs_kl_from_unigram_diff=kl_diff.abs(),
    )

    # NOTE(review): which sign corresponds to pushing towards vs. away from
    # the unigram distribution depends on how the KL columns were computed
    # upstream - confirm before relying on the "suppress"/"boost" labels.
    if effect == "suppress":
        # keep rows where ablation lowered the KL divergence
        final_df = final_df[final_df["kl_from_unigram_diff"] < 0]
    elif effect == "boost":
        # keep rows where ablation raised the KL divergence
        final_df = final_df[final_df["kl_from_unigram_diff"] > 0]

    # Calculate the mediation effect. A zero denominator (ablation left the
    # loss unchanged) makes the ratio undefined; it is mapped to NaN so the
    # row is skipped by the per-neuron mean rather than contributing +/-inf.
    denominator = final_df["abs_delta_loss_post_ablation"].where(
        final_df["abs_delta_loss_post_ablation"] != 0
    )
    # assign() returns a new frame, so no write happens on a filtered slice.
    final_df = final_df.assign(
        mediation_effect=1 - final_df["abs_delta_loss_post_ablation_with_frozen_unigram"] / denominator
    )

    # group by neuron idx
    per_neuron = final_df.groupby("component_name").mean(numeric_only=True).reset_index()
    ranked_neurons = per_neuron.sort_values(
        by=["mediation_effect", "abs_kl_from_unigram_diff"], ascending=[False, False]
    )

    # Select top N neurons, preserving the original sorting
    header_dict = {
        "component_name":"top_neurons",
        "mediation_effect":"med_effect",
        "kl_from_unigram_diff":"kl_diff",
        "delta_loss_post_ablation": "delta_loss_post",
        "delta_loss_post_ablation_with_frozen_unigram": "delta_loss_post_frozen"
        }
    # One row; each cell holds the whole top-N list for that statistic.
    stat_df = pd.DataFrame(
        {
            out_header: [ranked_neurons[sel_header].head(top_n).tolist()]
            for sel_header, out_header in header_dict.items()
        }
    )
    stat_df.insert(0,"step",step)
    return stat_df

In [75]:
# NOTE(review): in the visible notebook `ranked_neurons` is only assigned
# inside select_top_token_frequency_neurons, so this cell presumably relied
# on a top-level assignment that is no longer present - confirm.
ranked_neurons

Unnamed: 0,component_name,index,batch,pos,token_id,entropy,top_logit,pred,loss,top_logp,...,kl_divergence_after,kl_divergence_after_frozen_unigram,longtail_threshold,num_longtail_tokens,diff_loss_frozen,abs_delta_loss_post_ablation,abs_delta_loss_post_ablation_with_frozen_unigram,kl_from_unigram_diff,abs_kl_from_unigram_diff,mediation_effect
1869,5.838,122772.892691,479.069984,130.976672,4927.985226,3.779051,16.797192,2775.627527,3.853251,-1.435874,...,5.291354,9.990311e+09,9.445668e-07,2240.0,8.436818e-07,2.297556e-06,1.866601e-06,-0.000007,0.000007,0.198883
1176,5.213,115657.675335,451.301024,124.613081,3612.337273,3.964005,16.413382,1818.835303,3.969293,-1.508335,...,5.096713,9.990311e+09,9.445668e-07,2240.0,8.659589e-06,1.468293e-05,1.339358e-05,-0.000007,0.000007,0.193907
1413,5.427,116241.998522,453.575758,126.604582,4607.801183,3.847726,16.566679,2072.930525,3.880392,-1.452252,...,5.063638,9.990311e+09,9.445668e-07,2240.0,-5.171027e-07,2.282566e-06,1.953657e-06,-0.000005,0.000005,0.185331
1285,5.311,112062.910883,437.267350,122.469243,4696.208991,3.908638,16.342705,1879.518139,3.908251,-1.449447,...,4.882395,9.990311e+09,9.445668e-07,2240.0,-7.406170e-06,1.092419e-05,9.695507e-06,-0.000014,0.000014,0.172399
1065,5.1957,114920.743448,448.434483,121.515862,3521.129655,4.909103,15.232484,3560.506207,5.039191,-1.902696,...,4.659582,9.990311e+09,9.445668e-07,2240.0,9.197318e-06,2.025649e-05,1.952193e-05,-0.000006,0.000006,0.105653
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
158,5.114,111679.030651,435.736398,130.512644,4839.815326,3.597455,16.835875,2168.600000,3.636356,-1.324581,...,5.150748,9.990311e+09,9.445667e-07,2240.0,8.017324e-07,1.636691e-06,1.996828e-06,-0.000001,0.000001,
1339,5.360,117205.374175,457.325165,130.131951,4973.139491,3.961027,16.375248,2179.238454,3.959995,-1.481460,...,4.921102,9.990311e+09,9.445667e-07,2240.0,-4.772515e-07,9.874302e-07,9.949229e-07,-0.000001,0.000001,
1341,5.362,114706.296528,447.574306,127.274306,4398.476389,3.979105,16.363232,2440.754167,3.938786,-1.478477,...,4.969233,9.990311e+09,9.445668e-07,2240.0,3.469671e-07,1.166591e-06,1.153758e-06,-0.000001,0.000001,
924,5.183,115021.935702,448.776650,135.113367,4368.744501,4.428389,15.811915,2988.164129,4.175528,-1.683964,...,4.774828,9.990311e+09,9.445668e-07,2240.0,-7.939974e-07,8.300968e-07,8.171875e-07,-0.000001,0.000001,


In [74]:
# Display the `diff_loss_frozen` column of `ranked_neurons`.
# NOTE(review): no visible top-level cell assigns `ranked_neurons` - confirm
# where it comes from before re-running.
ranked_neurons["diff_loss_frozen"]

1869    8.436818e-07
1176    8.659589e-06
1413   -5.171027e-07
1285   -7.406170e-06
1065    9.197318e-06
            ...     
158     8.017324e-07
1339   -4.772515e-07
1341    3.469671e-07
924    -7.939974e-07
1704   -3.162685e-07
Name: diff_loss_frozen, Length: 2048, dtype: float64

In [70]:
# NOTE(review): no visible cell assigns `kl_diff`; this display presumably
# depends on leftover kernel state - confirm.
kl_diff

[-7.130773610697361e-06,
 -7.038608600851148e-06,
 -4.76986951980507e-06,
 -1.4116523743723519e-05,
 -5.648876140185166e-06,
 -3.520868267514743e-05,
 -3.3175299449794693e-06,
 -4.951208666170714e-06,
 -1.767237154126633e-05,
 -3.821765403699828e-06]

In [48]:
# Preview the first 10 rows of `data`.
data.head(10)

Unnamed: 0,index,str_tokens,unique_token,context,batch,pos,label,token_id,entropy,top_logit,...,loss_post_ablation_with_frozen_unigram,entropy_post_ablation,entropy_post_ablation_with_frozen_unigram,kl_divergence_before,kl_divergence_after,kl_divergence_after_frozen_unigram,ablation_mode,longtail_threshold,num_longtail_tokens,diff_loss_frozen
0,8707,’,’/3,<|endoftext|> support Amazon|’|s,34,3,34/3,457,0.916765,16.608845,...,0.091855,0.91676,0.91676,4.79005,4.790049,9990311000.0,longtail,9.445668e-07,2240.0,-6.556511e-07
1,8708,s,s/4,<|endoftext|> support Amazon’|s| arrival,34,4,34/4,84,7.373738,12.153364,...,9.126356,7.373729,7.37373,3.420141,3.420139,9990311000.0,longtail,9.445668e-07,2240.0,-1.907349e-06
2,8709,arrival,arrival/5,"<|endoftext|> support Amazon’s| arrival|,",34,5,34/5,13024,4.47116,14.695408,...,2.422079,4.471157,4.471158,2.991667,2.991667,9990311000.0,longtail,9.445668e-07,2240.0,-2.384186e-07
3,8710,",",",/6","support Amazon’s arrival|,| several",34,6,34/6,13,5.7082,14.091806,...,7.784938,5.708184,5.708185,3.667704,3.667702,9990311000.0,longtail,9.445668e-07,2240.0,-2.384186e-06
4,8711,several,several/7,"Amazon’s arrival,| several| surveys",34,7,34/7,2067,6.886786,12.835464,...,8.572117,6.886886,6.886879,3.310757,3.310778,9990311000.0,longtail,9.445668e-07,2240.0,2.193451e-05
5,8712,surveys,surveys/8,"’s arrival, several| surveys| have",34,8,34/8,17276,4.310981,16.292671,...,1.798479,4.310988,4.310988,3.779734,3.779734,9990311000.0,longtail,9.445668e-07,2240.0,8.34465e-07
6,8713,have,have/9,"s arrival, several surveys| have| found",34,9,34/9,452,4.383304,17.071587,...,2.388429,4.383301,4.383301,5.107384,5.107384,9990311000.0,longtail,9.445668e-07,2240.0,-2.384186e-07
7,8714,found,found/10,"arrival, several surveys have| found|.",34,10,34/10,1119,4.619333,15.078238,...,3.841269,4.61937,4.619363,3.19087,3.190871,9990311000.0,longtail,9.445668e-07,2240.0,3.33786e-06
8,8715,.,./11,", several surveys have found|.| Business",34,11,34/11,15,4.389691,14.009556,...,10.034325,4.389685,4.389686,3.516019,3.516018,9990311000.0,longtail,9.445668e-07,2240.0,-9.536743e-07
9,8716,Business,Business/12,several surveys have found.| Business| organi...,34,12,34/12,10518,4.128812,15.715575,...,6.508544,4.128804,4.128805,3.807681,3.80768,9990311000.0,longtail,9.445668e-07,2240.0,-9.536743e-07


In [69]:
# Number of distinct `diff_loss_frozen` values in `final_df`.
# NOTE(review): the recorded result (2048) matches the number of distinct
# components rather than the 5,144,576 rows reported for `final_df` earlier,
# so `final_df` was presumably a per-neuron aggregate when this ran - confirm.
len(set(final_df["diff_loss_frozen"]))

2048

In [32]:
def get_mean(med_lst):
    """Return the arithmetic mean of a list of numbers.

    Args:
        med_lst: Either a string holding a Python list literal (as stored in
            the CSV cells, e.g. ``"[1.0, 0.5]"``) or an already-parsed
            sequence of numbers.

    Returns:
        The mean of the values, or ``nan`` when the list is empty.
    """
    # CSV cells arrive as strings; literal_eval rejects anything that is not
    # a string, so already-parsed sequences are used as they are.
    values = ast.literal_eval(med_lst) if isinstance(med_lst, str) else list(med_lst)
    if not values:
        # The mean of nothing is undefined; avoid a ZeroDivisionError.
        return float("nan")
    return sum(values)/len(values)

In [34]:
def compute_mean(neuron_num,model,vec,neuron_type):
    """Collect the distinct per-row mean mediation effects of one neuron CSV.

    Reads ``<neuron_path>/<neuron_type>/<vec>/EleutherAI/pythia-<model>-deduped/500_<neuron_num>.csv``,
    averages the list stored in each ``med_effect`` cell with ``get_mean``
    and returns the resulting values as a set.
    """
    model_dir = neuron_path / neuron_type / vec / "EleutherAI" / f"pythia-{model}-deduped"
    neuron_stats = pd.read_csv(model_dir / f"500_{neuron_num}.csv")
    row_means = neuron_stats["med_effect"].apply(get_mean)
    return set(row_means)

In [37]:
neuron_name = [10,50,500]
models = ["70m","410m"]
vecs = ["longtail","mean"]
neuron_types = ["boost","suppress"]
# Print the set of per-row mean mediation effects for every
# (vec, neuron_type, model, neuron_num) combination.
for vec in vecs:
    for neuron_type in neuron_types:
        for model in models:
            for neuron_num in neuron_name:
                # compute_mean takes four arguments; vec and neuron_type
                # select the sub-directory and must be passed along.
                mean_effects = compute_mean(neuron_num, model, vec, neuron_type)
                print(f"{vec}:{neuron_type}:{model}:{neuron_num}:{mean_effects}")

longtail:boost:70m:10:{1.0}
longtail:boost:70m:50:{1.0}
longtail:boost:70m:500:{1.0}
longtail:boost:410m:10:{1.0}
longtail:boost:410m:50:{1.0}
longtail:boost:410m:500:{1.0}
longtail:suppress:70m:10:{1.0}
longtail:suppress:70m:50:{1.0}
longtail:suppress:70m:500:{1.0}
longtail:suppress:410m:10:{1.0}
longtail:suppress:410m:50:{1.0}
longtail:suppress:410m:500:{1.0}
mean:boost:70m:10:{1.0}
mean:boost:70m:50:{1.0}
mean:boost:70m:500:{1.0}
mean:boost:410m:10:{1.0}
mean:boost:410m:50:{1.0}
mean:boost:410m:500:{1.0}
mean:suppress:70m:10:{1.0}
mean:suppress:70m:50:{1.0}
mean:suppress:70m:500:{1.0}
mean:suppress:410m:10:{1.0}
mean:suppress:410m:50:{1.0}
mean:suppress:410m:500:{1.0}


In [24]:
# Preview the first 10 rows of `data`.
# NOTE(review): the recorded output shows neuron-CSV columns (step,
# top_neurons, med_effect, kl_diff), not the feather frame loaded into
# `data` earlier, so `data` was presumably rebound in a cell that is not
# shown - confirm.
data.head(10)

Unnamed: 0.1,Unnamed: 0,step,top_neurons,med_effect,kl_diff
0,0,90000,"['5.110', '5.96', '5.838', '5.1622', '5.587', ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.001056671142578125, 0.00030994415283203125,..."
1,0,0,"['5.45', '5.1781', '5.1585', '5.1585', '5.1585...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[4.935264587402344e-05, 3.790855407714844e-05,..."
2,0,40000,"['5.189', '5.771', '5.1099', '5.1018', '5.67',...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0013713836669921875, 0.0013532638549804688,..."
3,0,1,"['5.1585', '5.2037', '5.175', '5.1585', '5.166...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[3.6716461181640625e-05, 3.647804260253906e-05..."
4,0,91000,"['5.1622', '5.1622', '5.1622', '5.1622', '5.16...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.00021791458129882812, 0.00017547607421875, ..."
5,0,41000,"['5.1622', '5.1273', '5.1191', '5.1273', '5.11...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.00026988983154296875, 9.822845458984375e-05..."
6,0,2,"['5.2026', '5.1585', '5.1957', '5.945', '5.158...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[3.9577484130859375e-05, 3.886222839355469e-05..."
7,0,92000,"['5.214', '5.1132', '5.1954', '5.1978', '5.769...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.00029277801513671875, 0.00017547607421875, ..."
8,0,42000,"['5.518', '5.1622', '5.1622', '5.1622', '5.147...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.00026702880859375, 0.0001735687255859375, 0..."
9,0,4,"['5.1585', '5.1880', '5.1585', '5.1585', '5.15...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[4.267692565917969e-05, 3.743171691894531e-05,..."


In [28]:
# Display `data` in full (pandas truncates the middle rows when rendering).
data

Unnamed: 0.1,Unnamed: 0,step,top_neurons,med_effect,kl_diff
0,0,90000,"['5.110', '5.96', '5.838', '5.1622', '5.587', ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.001056671142578125, 0.00030994415283203125,..."
1,0,0,"['5.45', '5.1781', '5.1585', '5.1585', '5.1585...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[4.935264587402344e-05, 3.790855407714844e-05,..."
2,0,40000,"['5.189', '5.771', '5.1099', '5.1018', '5.67',...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0013713836669921875, 0.0013532638549804688,..."
3,0,1,"['5.1585', '5.2037', '5.175', '5.1585', '5.166...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[3.6716461181640625e-05, 3.647804260253906e-05..."
4,0,91000,"['5.1622', '5.1622', '5.1622', '5.1622', '5.16...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.00021791458129882812, 0.00017547607421875, ..."
...,...,...,...,...,...
149,0,39000,"['5.759', '5.1954', '5.1954', '5.1306', '5.147...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0001163482666015625, 8.821487426757812e-05,..."
150,0,140000,"['5.1174', '5.769', '5.372', '5.587', '5.703',...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.00018262863159179688, 0.000171661376953125,..."
151,0,141000,"['5.1758', '5.587', '5.769', '5.703', '5.587',...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0002574920654296875, 0.0001983642578125, 0...."
152,0,142000,"['5.1263', '5.1529', '5.457', '5.587', '5.50',...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0061244964599609375, 0.0006308555603027344,..."


In [29]:
# Add a `mean_med` column: the mean of the list stored in each `med_effect`
# cell, computed by get_mean. Mutates `data` in place.
data["mean_med"] = data["med_effect"].apply(get_mean)

In [31]:
# Distinct values of the `mean_med` column.
set(data["mean_med"])

{1.0}