Performed for a run where the original graphs were generated using 100 batches.

In [None]:
%load_ext autoreload
%autoreload 2
import logging
import os
from os.path import join as pj

import numpy as np
import pandas as pd
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

from sae_cooccurrence.normalised_cooc_functions import (
    create_results_dir,
    get_sae_release,
    neat_sae_id,
)
from sae_cooccurrence.pca import (
    analyze_specific_points_from_thresholded,
    calculate_pca_decoder,
    create_pca_plots_decoder,
    generate_data,
    load_data_from_pickle,
    plot_doubly_clustered_activation_heatmap,
    plot_feature_activation_counts,
    plot_feature_activation_trends_representative_points,
    plot_pca_explanation_and_save,
    plot_pca_feature_strength,
    plot_pca_with_active_features,
    plot_pca_with_top_feature,
    plot_simple_scatter,
    plot_token_pca_and_save,
    save_data_to_pickle,
)
from sae_cooccurrence.pca_animation import (
    analyze_specific_points_animated_from_thresholded,
)
from sae_cooccurrence.utils.saving_loading import load_npz_files, set_device
from sae_cooccurrence.utils.set_paths import get_git_root

# Set up logging and paths


In [None]:
def setup_logging(log_path):
    """Send root-logger output (INFO and above) to the file at ``log_path``.

    Args:
        log_path: Path of the log file. ``logging.basicConfig`` opens it in
            append mode by default.

    Notes:
        ``force=True`` removes and closes any handlers already attached to the
        root logger before configuring it. Without it, ``basicConfig`` does
        nothing when the root logger is already configured, so re-running the
        notebook for a different subgraph (or importing a library that
        configures logging first) would keep writing to the old destination.
    """
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )


# Config -------------
# Turn off autograd for the whole session.
torch.set_grad_enabled(False)
# Device is chosen by the project helper; it is passed to the model, SAE and
# activation store in later cells.
device = set_device()
# Repository root; results and figure paths in later cells are joined onto it.
git_root = get_git_root()

# Settings to perform PCA on a particular subgraph

In [81]:
# Passed as `save=` to the plotting helpers in later cells.
save_figs = True

# Alternative configuration (GPT-2 small), kept for reference:
# model_name = "gpt2-small"
# sae_release_short = "res-jb-feature-splitting"
# sae_id = "blocks.8.hook_resid_pre_24576"
# n_batches_reconstruction = 25

# Model / SAE selection used to load the model and SAE and to build result paths.
model_name = "gemma-2-2b"
sae_release_short = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_12/width_16k/canonical"
# Forwarded to generate_data when the PCA data is (re)generated.
n_batches_reconstruction = 500
remove_special_tokens = True

# Selects which thresholded matrix / node-info files are loaded from the results dir.
activation_threshold = 1.5
# Subgraph whose nodes are analysed in this notebook.
subgraph_id = 4334
# Passed to create_results_dir to locate the results directory
# (see the note at the top of the notebook: graphs were generated with 100 batches).
n_batches_generation = 100

In [None]:
# Seed NumPy's global RNG so NumPy-based randomness in this session is repeatable.
np.random.seed(1234)

fs_splitting_cluster = subgraph_id
pca_prefix = "pca"

# Load model
model = HookedTransformer.from_pretrained(model_name, device=device)

# Locate the results directory of the co-occurrence run for this model / SAE.
sae_id_neat = neat_sae_id(sae_id)
results_dir = create_results_dir(
    model_name, sae_release_short, sae_id_neat, n_batches_generation
)
results_path = pj(git_root, results_dir)
# "1.5" -> "1_5", the form used in the result file names below.
activation_threshold_safe = str(activation_threshold).replace(".", "_")

# np.load on an .npz returns an NpzFile that holds the archive open; read the
# array inside a context manager so the file handle is closed afterwards.
with np.load(
    pj(
        results_path,
        f"thresholded_matrices/thresholded_matrix_{activation_threshold_safe}.npz",
    )
) as thresholded_npz:
    thresholded_matrix = thresholded_npz["arr_0"]

# Output locations for figures and the cached PCA data of this subgraph.
figures_path = pj(git_root, f"figures/{model_name}/{sae_release_short}/{sae_id_neat}")
pca_dir = f"{pca_prefix}_{activation_threshold_safe}_subgraph_{subgraph_id}"
pca_path = pj(figures_path, pca_dir)
# exist_ok avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(pca_path, exist_ok=True)
pickle_file = pj(pca_path, f"pca_data_subgraph_{subgraph_id}.pkl")

# Set up logging: the log file lives next to the figures for this subgraph.
log_path = pj(pca_path, "pca_analysis.log")
setup_logging(log_path)

# Log all settings so a saved figure can be traced back to its configuration.
# n_batches_generation and remove_special_tokens are included because they
# determine which results are loaded and how the data is generated.
logged_settings = {
    "save_figs": save_figs,
    "git_root": git_root,
    "sae_id": sae_id,
    "activation_threshold": activation_threshold,
    "subgraph_id": subgraph_id,
    "fs_splitting_cluster": fs_splitting_cluster,
    "pca_prefix": pca_prefix,
    "model_name": model_name,
    "sae_release_short": sae_release_short,
    "n_batches_reconstruction": n_batches_reconstruction,
    "n_batches_generation": n_batches_generation,
    "remove_special_tokens": remove_special_tokens,
    "device": device,
    "results_path": results_path,
    "pca_path": pca_path,
}
logging.info("Script started")
logging.info("Settings:")
for setting_name, setting_value in logged_settings.items():
    logging.info(f"  {setting_name}: {setting_value}")

In [None]:
# Node table written by the co-occurrence run for this activation threshold.
node_info_csv = pj(
    results_path, f"dataframes/node_info_df_{activation_threshold_safe}.csv"
)
node_df = pd.read_csv(node_info_csv)
logging.info(f"Loaded node_df from {node_info_csv}")

# Co-occurrence activations are looked up by threshold; .get() yields None
# when there is no entry for the current threshold.
cooc_activations_by_threshold = load_npz_files(
    results_path, "feature_acts_cooc_activations"
)
overall_feature_activations = cooc_activations_by_threshold.get(activation_threshold)

# Node IDs belonging to the subgraph under study.
in_subgraph = node_df["subgraph_id"] == subgraph_id
fs_splitting_nodes = node_df.loc[in_subgraph, "node_id"].tolist()

In [None]:
# regen_data = True discards the cached pickle and regenerates the PCA data,
# which is expensive. The guard stops "Run All" in that case so regeneration
# must be confirmed deliberately (remove or comment out the raise once sure).
# With the default regen_data = False the notebook proceeds and uses the cache.
regen_data = False
if regen_data:
    raise ValueError("Are you sure you don't want to use existing data?")

In [85]:
# Use the cached PCA data when allowed and present; otherwise generate it.
# NOTE: the cache is a pickle file - only load pickles produced by this notebook.
if not regen_data and os.path.exists(pickle_file):
    data = load_data_from_pickle(pickle_file)
else:
    sae_release = get_sae_release(model_name, sae_release_short)

    # Load SAE and set up activation store
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_release, sae_id=sae_id, device=device
    )
    sae.fold_W_dec_norm()

    activation_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        train_batch_size_tokens=4096,
        n_batches_in_buffer=32,
        device=device,
    )

    data = generate_data(
        model,
        sae,
        activation_store,
        fs_splitting_nodes,
        n_batches_reconstruction,
        decoder=False,
        remove_special_tokens=remove_special_tokens,
        device=device,
    )

    # Always cache freshly generated data. Previously it was only saved when
    # regen_data was True, so a first run without a cache regenerated the data
    # on every subsequent run as well.
    save_data_to_pickle(data, pickle_file)

# Unpack once, regardless of where the data came from.
results = data["results"]
pca_df = data["pca_df"]
pca = data["pca"]
pca_decoder = data["pca_decoder"]
pca_decoder_df = data["pca_decoder_df"]

In [None]:
# Optional export of the PCA coordinates:
# # Save pca_df as CSV
# pca_df_filename = f"pca_df_subgraph_{subgraph_id}.csv"
# pca_df.to_csv(pj(pca_path, pca_df_filename), index=False)

plot_token_pca_and_save(pca_df, pca_path, subgraph_id, color_by="token", save=save_figs)

plot_pca_explanation_and_save(pca, pca_path, subgraph_id, save=save_figs)

plot_simple_scatter(results, pca_path, subgraph_id, fs_splitting_nodes, save=save_figs)

if pca_decoder is not None:
    # `sae` is only created when the data was generated in this session. When
    # the data came from the pickle cache there is no SAE in scope, so reuse the
    # cached decoder PCA instead of failing with a NameError.
    if "sae" in globals():
        pca_decoder, pca_decoder_df = calculate_pca_decoder(sae, fs_splitting_nodes)
    # Save pca_decoder_df as CSV
    pca_decoder_df_filename = f"pca_decoder_df_subgraph_{subgraph_id}.csv"
    pca_decoder_df.to_csv(pj(pca_path, pca_decoder_df_filename), index=False)

    create_pca_plots_decoder(pca_decoder_df, subgraph_id, pca_path, save=save_figs)

print(f"Processing completed for subgraph ID {subgraph_id}")

In [87]:
# PCA plot keyed on each point's top feature (presumably the most strongly
# active subgraph feature - confirm in sae_cooccurrence.pca).
# The figure is written under pca_path only when save_figs is True.
plot_pca_with_top_feature(
    pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, save=save_figs
)

In [88]:
# Feature-strength plots for each pair of the first three principal components.
# The positional arguments are identical for all three views, so they are
# bundled once and only the axis pair changes per call.
feature_strength_args = (
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
)
plot_pca_feature_strength(
    *feature_strength_args, pc_x="PC1", pc_y="PC2", save=save_figs
)
plot_pca_feature_strength(
    *feature_strength_args, pc_x="PC1", pc_y="PC3", save=save_figs
)
plot_pca_feature_strength(
    *feature_strength_args, pc_x="PC2", pc_y="PC3", save=save_figs
)

In [89]:
# PCA plot of active features, using the notebook-wide activation_threshold
# (presumably the cutoff above which a feature counts as active - confirm
# in sae_cooccurrence.pca).
plot_pca_with_active_features(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    activation_threshold=activation_threshold,
    save=save_figs,
)

In [90]:
# Clustered activation heatmap for the subgraph's features, limited to
# max_examples=1000 (presumably a cap on the number of examples drawn -
# confirm in sae_cooccurrence.pca).
plot_doubly_clustered_activation_heatmap(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=save_figs,
)

In [None]:
# just straight line
interesting_point_ids_name_to_thing = [
    15554,
    20044,
    8096,
    9367,
    5629,
    15233,
    10617,
    14254,
    6310,
    9660,
    4074,
]  # Replace with actual IDs of interest
analyze_specific_points_from_thresholded(
    results=results,
    thresholded_matrix=thresholded_matrix,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    activation_threshold=activation_threshold,
    node_df=node_df,
    pca_df=pca_df,
    point_ids=interesting_point_ids_name_to_thing,
    subdir="straight_line_name_to_thing",
    save_figs=True,
    pca_path=pca_path,
)

In [None]:
# Animated version of the "name to thing" analysis: steps through the selected
# point IDs and writes a GIF plus its frames under pca_path.
# Requires interesting_point_ids_name_to_thing from the previous cell.
analyze_specific_points_animated_from_thresholded(
    results=results,
    thresholded_matrix=thresholded_matrix,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    # activation_threshold=activation_threshold,
    node_df=node_df,
    pca_df=pca_df,
    point_ids=interesting_point_ids_name_to_thing,
    # results_path=results_path,
    plot_only_fs_nodes=True,
    save_gif=True,
    gif_path=pca_path,
    gif_filename="straight_line_name_to_thing.gif",
    frame_folder_name="straight_gif_frames_name_to_thing",
)

In [None]:
# Feature-activation trends across the "name to thing" points.
# Saving follows the notebook-wide save_figs switch instead of always saving.
plot_feature_activation_trends_representative_points(
    results,
    fs_splitting_nodes,
    interesting_point_ids_name_to_thing,
    pca_df,
    save_figs=save_figs,
    pca_path=pca_path,
    subdir="straight_line_name_to_thing",
    filename="feature_activation_trends_name_to_thing",
)

In [None]:
# just straight line
interesting_point_ids_name_to_generic_person = [
    15554,
    20105,
    17434,
    6837,
    11308,
    19926,
    5875,
    15795,
    14216,
    # 17671,
    20406,
    17631,
    17110,
    11443,
    17133,
    17330,
    4567,
    4895,
]  # Replace with actual IDs of interest
analyze_specific_points_from_thresholded(
    results=results,
    thresholded_matrix=thresholded_matrix,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    activation_threshold=activation_threshold,
    node_df=node_df,
    pca_df=pca_df,
    point_ids=interesting_point_ids_name_to_generic_person,
    subdir="straight_line_name_to_generic_person",
    save_figs=True,
    pca_path=pca_path,
)

In [None]:
# Animated version of the "name to generic person" analysis: steps through the
# selected point IDs and writes a GIF plus its frames under pca_path.
# Requires interesting_point_ids_name_to_generic_person from the previous cell.
analyze_specific_points_animated_from_thresholded(
    results=results,
    thresholded_matrix=thresholded_matrix,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    # activation_threshold=activation_threshold,
    node_df=node_df,
    pca_df=pca_df,
    point_ids=interesting_point_ids_name_to_generic_person,
    # results_path=results_path,
    plot_only_fs_nodes=True,
    save_gif=True,
    gif_path=pca_path,
    gif_filename="straight_line_name_to_generic_person.gif",
    frame_folder_name="straight_gif_frames_name_to_generic_person",
)

In [None]:
# Feature-activation trends across the "name to generic person" points.
# Saving follows the notebook-wide save_figs switch instead of always saving.
plot_feature_activation_trends_representative_points(
    results,
    fs_splitting_nodes,
    interesting_point_ids_name_to_generic_person,
    pca_df,
    save_figs=save_figs,
    pca_path=pca_path,
    subdir="straight_line_name_to_generic_person",
    filename="feature_activation_trends_name_to_generic_person",
)

In [111]:
# Ten points with the largest PC2 + PC3, exported as the "top right" examples.
# `Context` holds repr(context) so newlines and quotes are escaped in the CSV.
pca_df_max_examples = pca_df.assign(
    Context=pca_df["context"].map(repr),
    combined=pca_df["PC2"] + pca_df["PC3"],
)
# Drop the raw `context` column rather than renaming it to `Context`: the
# rename produced two columns with the same label, so selecting ["Context"]
# wrote both the raw and the escaped text to the CSV.
examples = (
    pca_df_max_examples.sort_values("combined", ascending=False)
    .head(10)
    .drop(columns="context")
)
# Export only the context column.
examples[["Context"]].to_csv(pj(pca_path, "top_examples_top_right.csv"), index=False)

In [107]:
# Display the current `examples` frame (whichever top-examples cell was run last).
examples

Unnamed: 0,PC1,PC2,PC3,tokens,Context,point_id,active_feature,top_activation,top_features,combined
4583,23.958549,7.801569,15.605243,',the buyers changed their position in reliance...,4583,4572,35.897873,"4572, 9754",23.406813
17337,20.419363,7.034844,15.721357,',2002. It is apparent from plaintiffs|'| papers...,17337,4572,34.139557,"4572, 9754",22.756201
17308,21.14625,7.231211,15.448147,',scour the purchase agreement looking for the ...,17308,4572,34.269295,"4572, 9754",22.679358
17330,20.692991,7.188835,15.105166,',and those records are more than adequate to m...,17330,4572,33.712051,"4572, 9754",22.294001
1070,25.607754,8.240004,13.391506,',Regarding Jury Deliberations\n\nWe next consi...,1070,4572,34.466255,"4572, 9754",21.63151
9976,17.885668,6.633146,14.860841,’,"playgrounds, were joined by an increasing awa...",9976,4572,32.005165,"4572, 9754",21.493986
17354,23.038532,6.061219,15.394769,',alone lay persons on a jury. To prove plainti...,17354,4572,34.520355,"4572, 9754",21.455988
3276,23.302017,7.641175,13.647782,',unconstitutional. The only issue presented in ...,3276,4572,33.419617,"4572, 9754",21.288958
1077,19.350863,7.044068,14.21327,',"these rules of law.""\nAll of the plaintiffs|'...",1077,4572,32.176796,"4572, 9754",21.257338
4584,20.210949,7.266536,13.953338,',rescind the earnest money agreement because o...,4584,4572,32.389084,"4572, 9754",21.219873


In [112]:
# Ten points with the most negative PC2, exported as the "left" examples.
# `Context` holds repr(context) so newlines and quotes are escaped in the CSV.
pca_df_max_examples = pca_df.assign(
    Context=pca_df["context"].map(repr),
    combined=-pca_df["PC2"],
)
# Drop the raw `context` column rather than renaming it to `Context`: the
# rename produced two columns with the same label, so selecting ["Context"]
# wrote both the raw and the escaped text to the CSV.
examples = (
    pca_df_max_examples.sort_values("combined", ascending=False)
    .head(10)
    .drop(columns="context")
)
# Export only the context column.
examples[["Context"]].to_csv(pj(pca_path, "top_examples_left.csv"), index=False)

In [113]:
# Ten points with the largest PC2 - PC3, exported as the "bottom right" examples.
# `Context` holds repr(context) so newlines and quotes are escaped in the CSV.
pca_df_max_examples = pca_df.assign(
    Context=pca_df["context"].map(repr),
    combined=pca_df["PC2"] - pca_df["PC3"],
)
# Drop the raw `context` column rather than renaming it to `Context`: the
# rename produced two columns with the same label, so selecting ["Context"]
# wrote both the raw and the escaped text to the CSV.
examples = (
    pca_df_max_examples.sort_values("combined", ascending=False)
    .head(10)
    .drop(columns="context")
)
# Export only the context column.
examples[["Context"]].to_csv(pj(pca_path, "top_examples_bottom_right.csv"), index=False)

In [117]:
# Feature-activation counts for the "name to generic person" points.
# Saving follows the notebook-wide save_figs switch instead of always saving.
# NOTE(review): the threshold here is a fixed 10, not the notebook-wide
# activation_threshold - confirm this is intentional.
# NOTE(review): no subdir/filename is passed, so this call and the
# "name to thing" counts call may write to the same file under pca_path -
# confirm against plot_feature_activation_counts.
plot_feature_activation_counts(
    results,
    fs_splitting_nodes,
    interesting_point_ids_name_to_generic_person,
    pca_df,
    save_figs=save_figs,
    pca_path=pca_path,
    activation_threshold=10,
)

In [120]:
# Feature-activation counts for the "name to thing" points.
# Saving follows the notebook-wide save_figs switch instead of always saving.
# NOTE(review): the threshold here is a fixed 10, not the notebook-wide
# activation_threshold - confirm this is intentional.
# NOTE(review): no subdir/filename is passed, so this call and the
# "name to generic person" counts call may write to the same file under
# pca_path - confirm against plot_feature_activation_counts.
plot_feature_activation_counts(
    results,
    fs_splitting_nodes,
    interesting_point_ids_name_to_thing,
    pca_df,
    save_figs=save_figs,
    pca_path=pca_path,
    activation_threshold=10,
)