In [9]:
import os
from os.path import join as pj

import pandas as pd

from sae_cooccurrence.normalised_cooc_functions import neat_sae_id
from sae_cooccurrence.utils.saving_loading import set_device
from sae_cooccurrence.utils.set_paths import get_git_root

In [11]:
# Environment setup: select the compute device (set_device prints its choice)
# and resolve the repository root that all results paths are built from.
device, git_root = set_device(), get_git_root()

# Configuration: three SAE families to compare.

# GPT-2 small, residual stream entering block 8; ids differ only in the
# trailing size suffix.
gpt_model_name = "gpt2-small"
gpt_sae_release_short = "res-jb-feature-splitting"
gpt_sae_ids = [
    f"blocks.8.hook_resid_pre_{size}"
    for size in (768, 1536, 3072, 6144, 12288, 24576)
]

# Gemma-2-2B, layer 12, canonical SAEs at increasing widths.
gemma_width_model_name = "gemma-2-2b"
gemma_width_sae_release_short = "gemma-scope-2b-pt-res-canonical"
gemma_width_sae_ids = [
    f"layer_12/width_{width}/canonical"
    for width in ("16k", "32k", "65k", "262k", "524k", "1m")
]

# Gemma-2-2B, layer 12, fixed 16k width at several average-L0 settings.
# NOTE: order is the listed (lexicographic) order, not numeric L0 order.
gemma_l0_model_name = "gemma-2-2b"
gemma_l0_sae_release_short = "gemma-scope-2b-pt-res"
gemma_l0_sae_ids = [
    f"layer_12/width_16k/average_l0_{l0}"
    for l0 in ("176", "22", "41", "445", "82")
]

n_batches = 10
activation_threshold = 1.5  # tunable; also encoded into result file names
# Path-safe form of the threshold, e.g. 1.5 -> "1_5".
activation_threshold_safe = f"{activation_threshold}".replace(".", "_")

Using MPS


In [13]:
def load_node_info_df(
    model_name, sae_release_short, sae_id, activation_threshold_safe, root=None
):
    """Load the saved ``node_info_df`` CSV for one SAE, if it exists.

    The file is looked up at::

        <root>/results/<model_name>/<sae_release_short>/<neat_sae_id(sae_id)>/
            dataframes/node_info_df_<activation_threshold_safe>.csv

    Args:
        model_name: Model directory name under ``results/``.
        sae_release_short: SAE release directory name under the model directory.
        sae_id: Raw SAE id; turned into a path component via ``neat_sae_id``.
        activation_threshold_safe: Threshold string embedded in the file name
            (e.g. ``"1_5"``).
        root: Directory that contains ``results/``. Defaults to
            ``get_git_root()``, so existing four-argument callers are unaffected.

    Returns:
        The loaded DataFrame, or ``None`` (after printing a warning) when no
        regular file exists at the expected path.
    """
    if root is None:
        root = get_git_root()
    file_path = pj(
        root,
        "results",
        model_name,
        sae_release_short,
        neat_sae_id(sae_id),
        "dataframes",
        f"node_info_df_{activation_threshold_safe}.csv",
    )
    # isfile rather than exists: a directory at this path would otherwise be
    # passed to read_csv and raise instead of being reported as missing.
    if not os.path.isfile(file_path):
        print(f"Warning: File not found - {file_path}")
        return None
    return pd.read_csv(file_path)


def load_node_info_dfs(model_name, sae_release_short, sae_ids, activation_threshold_safe):
    """Load ``node_info_df`` for each SAE id, keeping only those found on disk.

    Args:
        model_name: Model directory name passed through to ``load_node_info_df``.
        sae_release_short: SAE release directory name.
        sae_ids: Iterable of raw SAE ids to load, in the desired order.
        activation_threshold_safe: Threshold string used in the CSV file names.

    Returns:
        Dict mapping each ``sae_id`` whose CSV was found to its DataFrame,
        in the order of ``sae_ids``. Missing files are skipped
        (``load_node_info_df`` prints a warning for each).
    """
    loaded = {}
    for sae_id in sae_ids:
        node_info_df = load_node_info_df(
            model_name, sae_release_short, sae_id, activation_threshold_safe
        )
        if node_info_df is not None:
            loaded[sae_id] = node_info_df
    return loaded


# Load node_info_df for GPT-2
gpt_node_info_dfs = load_node_info_dfs(
    gpt_model_name, gpt_sae_release_short, gpt_sae_ids, activation_threshold_safe
)

# Load node_info_df for Gemma (width comparison)
gemma_width_node_info_dfs = load_node_info_dfs(
    gemma_width_model_name,
    gemma_width_sae_release_short,
    gemma_width_sae_ids,
    activation_threshold_safe,
)

# Load node_info_df for Gemma (L0 comparison)
gemma_l0_node_info_dfs = load_node_info_dfs(
    gemma_l0_model_name,
    gemma_l0_sae_release_short,
    gemma_l0_sae_ids,
    activation_threshold_safe,
)

# Summary: report loaded vs. expected so partially-missing sweeps stand out.
print(f"Loaded {len(gpt_node_info_dfs)}/{len(gpt_sae_ids)} node_info_dfs for GPT-2")
print(
    f"Loaded {len(gemma_width_node_info_dfs)}/{len(gemma_width_sae_ids)} "
    "node_info_dfs for Gemma (width comparison)"
)
print(
    f"Loaded {len(gemma_l0_node_info_dfs)}/{len(gemma_l0_sae_ids)} "
    "node_info_dfs for Gemma (L0 comparison)"
)

Loaded 6 node_info_dfs for GPT-2
Loaded 3 node_info_dfs for Gemma (width comparison)
Loaded 5 node_info_dfs for Gemma (L0 comparison)


In [15]:
# Create output directories
def make_size_effects_output_dir(
    model_name, sae_release_short, comparison="l0_comparison"
):
    """Create (if needed) and return the size-effects output directory for one SAE family.

    Path: ``<git_root>/results/size_effects/<model_name>/<sae_release_short>/
    <comparison>_<activation_threshold_safe>``. Reads the notebook-level
    ``git_root`` and ``activation_threshold_safe`` set in the config cell.

    Args:
        model_name: Model directory name.
        sae_release_short: SAE release directory name.
        comparison: Prefix of the leaf directory name; defaults to the
            previously hard-coded ``"l0_comparison"``.

    Returns:
        The directory path (as a string).
    """
    output_dir = pj(
        git_root,
        "results",
        "size_effects",
        model_name,
        sae_release_short,
        f"{comparison}_{activation_threshold_safe}",
    )
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


# NOTE(review): the GPT-2 and Gemma width families vary SAE size rather than L0,
# but their directories are still named "l0_comparison_*". Kept unchanged so
# existing result paths do not move; confirm whether a width-specific name was
# intended (pass comparison=... to change it).
gpt_output_dir = make_size_effects_output_dir(gpt_model_name, gpt_sae_release_short)
gemma_width_output_dir = make_size_effects_output_dir(
    gemma_width_model_name, gemma_width_sae_release_short
)
gemma_l0_output_dir = make_size_effects_output_dir(
    gemma_l0_model_name, gemma_l0_sae_release_short
)

TypeError: load_node_info_df() missing 1 required positional argument: 'activation_threshold_safe'