Analysis performed on a run in which the original co-occurrence graphs were generated from 100 batches.

In [1]:
%load_ext autoreload
%autoreload 2
import logging
import os
from os.path import join as pj

import numpy as np
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

from sae_cooccurrence.normalised_cooc_functions import (
    create_results_dir,
    get_sae_release,
    neat_sae_id,
)
from sae_cooccurrence.pca import (
    calculate_pca_decoder,
    create_pca_plots_decoder,
    generate_data,
    load_data_from_pickle,
    plot_doubly_clustered_activation_heatmap,
    plot_pca_explanation_and_save,
    plot_pca_feature_strength,
    plot_pca_with_active_features,
    plot_pca_with_top_feature,
    plot_simple_scatter,
    plot_token_pca_and_save,
    save_data_to_pickle,
)
from sae_cooccurrence.utils.saving_loading import set_device
from sae_cooccurrence.utils.set_paths import get_git_root

  warn(


# Set up logging and paths


In [2]:
def setup_logging(log_path, force=True):
    """Send root-logger records at INFO and above to the file ``log_path``.

    Parameters
    ----------
    log_path : str
        Path of the log file the root logger should write to.
    force : bool, default True
        Forwarded to ``logging.basicConfig``. Without it, ``basicConfig`` is a
        no-op whenever the root logger already has a handler. That is always the
        case after the first call in a kernel, and may also be the case if an
        imported library configured logging. Re-running the notebook, e.g. with
        a different ``subgraph_id``, would then keep writing to the old log file.
        With ``force=True`` the existing root handlers are removed and closed
        first.
    """
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=force,
    )


# Config -------------
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Repository root; the results and figures paths below are built relative to it.
git_root = get_git_root()
device = set_device()

Using GPU 0 out of 1 available GPUs


# Settings to perform PCA on a particular subgraph

In [3]:
# --- Output / provenance ----------------------------------------------------
save_figs = True  # forwarded as `save=` to every plotting helper below
n_batches_generation = 100  # batches used when the original graphs were generated

# --- Model and SAE ----------------------------------------------------------
# Alternative configuration (GPT-2 small, feature-splitting SAEs):
# model_name = "gpt2-small"
# sae_release_short = "res-jb-feature-splitting"
# sae_id = "blocks.8.hook_resid_pre_24576"
# n_batches_reconstruction = 25
model_name = "gemma-2-2b"
sae_release_short = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_0/width_16k/canonical"

# --- PCA data generation ----------------------------------------------------
n_batches_reconstruction = 100  # passed to generate_data
remove_special_tokens = True

# --- Subgraph selection -----------------------------------------------------
activation_threshold = 1.5
subgraph_id = "test"

In [4]:
# Seed numpy's global RNG so any numpy-based sampling downstream is reproducible.
np.random.seed(1234)

fs_splitting_cluster = 0
pca_prefix = "pca"

# Load model
model = HookedTransformer.from_pretrained(model_name, device=device)

# Resolve where the co-occurrence results for this SAE / generation run live.
sae_id_neat = neat_sae_id(sae_id)
results_dir = create_results_dir(
    model_name, sae_release_short, sae_id_neat, n_batches_generation
)
results_path = pj(git_root, results_dir)
# e.g. 1.5 -> "1_5", so the threshold can be embedded in directory/file names.
activation_threshold_safe = str(activation_threshold).replace(".", "_")

figures_path = pj(git_root, f"figures/{model_name}/{sae_release_short}/{sae_id_neat}")
pca_dir = f"{pca_prefix}_{activation_threshold_safe}_subgraph_{subgraph_id}"
pca_path = pj(figures_path, pca_dir)
# exist_ok makes the cell idempotent and avoids the check-then-create race.
os.makedirs(pca_path, exist_ok=True)
# Note: this cache path does not encode fs_splitting_nodes,
# n_batches_reconstruction or remove_special_tokens.
pickle_file = pj(pca_path, f"pca_data_subgraph_{subgraph_id}.pkl")

# Set up logging
log_path = pj(pca_path, "pca_analysis.log")
setup_logging(log_path)

# Log every setting that affects the outputs, so the log alone identifies the run.
logging.info("Script started")
logging.info("Settings:")
logging.info(f"  save_figs: {save_figs}")
logging.info(f"  git_root: {git_root}")
logging.info(f"  sae_id: {sae_id}")
logging.info(f"  activation_threshold: {activation_threshold}")
logging.info(f"  subgraph_id: {subgraph_id}")
logging.info(f"  fs_splitting_cluster: {fs_splitting_cluster}")
logging.info(f"  pca_prefix: {pca_prefix}")
logging.info(f"  model_name: {model_name}")
logging.info(f"  sae_release_short: {sae_release_short}")
logging.info(f"  n_batches_generation: {n_batches_generation}")
logging.info(f"  n_batches_reconstruction: {n_batches_reconstruction}")
logging.info(f"  remove_special_tokens: {remove_special_tokens}")
logging.info(f"  device: {device}")
logging.info(f"  results_path: {results_path}")
logging.info(f"  pca_path: {pca_path}")
logging.info(f"  pickle_file: {pickle_file}")



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [5]:
# Disabled loaders for the node table, overall activations and the pickled
# subgraph. NOTE(review): `pd`, `pickle` and `load_npz_files` are not imported
# in this notebook, so uncommenting these also requires adding those imports.
# node_df = pd.read_csv(
#     pj(results_path, f"dataframes/node_info_df_{activation_threshold_safe}.csv")
# )
# logging.info(
#     f"Loaded node_df from {pj(results_path, f'dataframes/node_info_df_{activation_threshold_safe}.csv')}"
# )

# overall_feature_activations = load_npz_files(
#     results_path, "feature_acts_cooc_activations"
# ).get(activation_threshold)

# with open(pj(results_path, f"subgraph_objects/activation_{activation_threshold_safe}/subgraph_{subgraph_id}.pkl"), 'rb') as f:
#     subgraph = pickle.load(f)


# SAE feature indices analysed below. They are hard-coded, not read from a
# subgraph object, so they are not tied to `subgraph_id`.
fs_splitting_nodes = [
    6449,
    8129,
    13989,
    13623,
    10032,
    1469,
]

In [6]:
# True: always regenerate the PCA data (and overwrite the cache).
# False: reuse the cached pickle if it exists, otherwise regenerate and cache it.
regen_data = True
if not regen_data:
    # Previously this raised, which made the cache-loading branch of the next
    # cell unreachable. Warn instead. The cache path encodes model, SAE,
    # activation_threshold and subgraph_id, but NOT fs_splitting_nodes,
    # n_batches_reconstruction or remove_special_tokens, so a stale cache
    # can be silently reused.
    _cache_warning = (
        f"regen_data is False: cached PCA data at {pickle_file} will be reused "
        "if it exists. Its path does not encode fs_splitting_nodes, "
        "n_batches_reconstruction or remove_special_tokens; make sure they "
        "match the run that produced the cache."
    )
    print(_cache_warning)
    logging.warning(_cache_warning)

In [7]:
# Reuse cached PCA data when allowed and present; otherwise regenerate it.
sae_release = get_sae_release(model_name, sae_release_short)

# Load the SAE regardless of which branch runs. A later cell calls
# calculate_pca_decoder(sae, ...), and `sae` used to exist only when the data
# was regenerated, so that cell raised NameError after a cache hit.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=sae_release, sae_id=sae_id, device=device
)
sae.fold_W_dec_norm()

if not regen_data and os.path.exists(pickle_file):
    # NOTE(review): load_data_from_pickle presumably unpickles, which can run
    # arbitrary code; only load caches written by this notebook.
    data = load_data_from_pickle(pickle_file)
else:
    activation_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        train_batch_size_tokens=4096,
        n_batches_in_buffer=32,
        device=device,
    )

    data = generate_data(
        model,
        sae,
        activation_store,
        fs_splitting_nodes,
        n_batches_reconstruction,
        decoder=False,
        remove_special_tokens=remove_special_tokens,
        device=device,
    )

    # Always cache freshly generated data. This used to be gated on
    # regen_data, so a cache miss with regen_data=False regenerated the data
    # on every run without ever saving it.
    save_data_to_pickle(data, pickle_file)

# Both branches produce the same dict; unpack it once.
results = data["results"]
pca_df = data["pca_df"]
pca = data["pca"]
pca_decoder = data["pca_decoder"]
pca_decoder_df = data["pca_decoder_df"]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]


Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.



  0%|          | 0/100 [00:00<?, ?it/s]

Total examples found: 3346


In [8]:
# # Save pca_df as CSV
# pca_df_filename = f"pca_df_subgraph_{subgraph_id}.csv"
# pca_df.to_csv(pj(pca_path, pca_df_filename), index=False)

# Token-level PCA, explained-variance plot and a simple activation scatter.
plot_token_pca_and_save(pca_df, pca_path, subgraph_id, color_by="token", save=save_figs)
plot_pca_explanation_and_save(pca, pca_path, subgraph_id, save=save_figs)
plot_simple_scatter(results, pca_path, subgraph_id, fs_splitting_nodes, save=save_figs)

# Decoder-direction PCA, which requires `sae` from the data-loading cell.
# NOTE(review): generate_data is called with decoder=False, so pca_decoder is
# presumably None and this block is skipped. When it does run, it recomputes
# pca_decoder and overwrites the loaded one. Confirm whether the condition
# was meant to be `is None`.
if pca_decoder is not None:
    pca_decoder, pca_decoder_df = calculate_pca_decoder(sae, fs_splitting_nodes)
    pca_decoder_df.to_csv(
        pj(pca_path, f"pca_decoder_df_subgraph_{subgraph_id}.csv"), index=False
    )
    create_pca_plots_decoder(pca_decoder_df, subgraph_id, pca_path, save=save_figs)

print(f"Processing completed for subgraph ID {subgraph_id}")

Processing completed for subgraph ID test


In [9]:
# PCA scatter labelled by the top subgraph feature (per the helper's name).
plot_pca_with_top_feature(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    save=save_figs,
)

In [10]:
# Feature-strength overlays for each pair of the first three principal
# components. This replaces three copy-pasted calls that differed only in
# (pc_x, pc_y); the pairs and their order are unchanged.
for pc_x, pc_y in [("PC1", "PC2"), ("PC1", "PC3"), ("PC2", "PC3")]:
    plot_pca_feature_strength(
        pca_df,
        results,
        fs_splitting_nodes,
        fs_splitting_cluster,
        pca_path,
        pc_x=pc_x,
        pc_y=pc_y,
        save=save_figs,
    )

In [11]:
# PCA scatter marking subgraph features active above `activation_threshold`
# (interpretation from the helper's name; confirm in sae_cooccurrence.pca).
plot_pca_with_active_features(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    save=save_figs,
    activation_threshold=activation_threshold,
)

In [12]:
# Clustered heatmap of subgraph-feature activations (per the helper's name),
# capped at `max_heatmap_examples` examples.
max_heatmap_examples = 1000
plot_doubly_clustered_activation_heatmap(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=max_heatmap_examples,
    save=save_figs,
)

In [19]:
# NOTE(review): this call is commented out, but the printed output below is
# from an earlier execution and is stale. For example, it reports
# "Total number of feature splitting nodes: 5" while fs_splitting_nodes
# currently has 6 entries. Uncommenting alone will not work:
# `analyze_representative_points` is not imported in this notebook, and
# `node_df` is only created by code that is itself commented out.
# analyze_representative_points(
#     results=results,
#     fs_splitting_nodes=fs_splitting_nodes,
#     fs_splitting_cluster=fs_splitting_cluster,
#     activation_threshold=activation_threshold,
#     node_df=node_df,
#     results_path=results_path,
#     pca_df=pca_df,
#     save_figs=True,
#     pca_path=pca_path,
# )


Analyzing representative point 1:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Number of non-zero features: 39
Number of non-zero feature splitting nodes: 1
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 0.3027
Mean activation of non-zero non-feature splitting nodes: 3.2734
Median activation of non-zero feature splitting nodes: 0.3027
Median activation of non-zero non-feature splitting nodes: 1.1038
Number of splitting features active above threshold: 0
Number of non-splitting features active above threshold: 15
Sum of activation strengths for splitting features: 0.3027
Sum of activation strengths for non-splitting features: 124.3899

Analyzing representative point 2:


Number of non-zero features: 23
Number of non-zero feature splitting nodes: 1
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 0.4615
Mean activation of non-zero non-feature splitting nodes: 4.6290
Median activation of non-zero feature splitting nodes: 0.4615
Median activation of non-zero non-feature splitting nodes: 1.4941
Number of splitting features active above threshold: 0
Number of non-splitting features active above threshold: 11
Sum of activation strengths for splitting features: 0.4615
Sum of activation strengths for non-splitting features: 101.8379

Analyzing representative point 3:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Number of non-zero features: 55
Number of non-zero feature splitting nodes: 1
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 1.2277
Mean activation of non-zero non-feature splitting nodes: 2.9216
Median activation of non-zero feature splitting nodes: 1.2277
Median activation of non-zero non-feature splitting nodes: 1.2821
Number of splitting features active above threshold: 0
Number of non-splitting features active above threshold: 26
Sum of activation strengths for splitting features: 1.2277
Sum of activation strengths for non-splitting features: 157.7681

Analyzing representative point 4:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Number of non-zero features: 55
Number of non-zero feature splitting nodes: 1
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 1.2277
Mean activation of non-zero non-feature splitting nodes: 2.9216
Median activation of non-zero feature splitting nodes: 1.2277
Median activation of non-zero non-feature splitting nodes: 1.2821
Number of splitting features active above threshold: 0
Number of non-splitting features active above threshold: 26
Sum of activation strengths for splitting features: 1.2277
Sum of activation strengths for non-splitting features: 157.7681

Analyzing representative point 5:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Number of non-zero features: 44
Number of non-zero feature splitting nodes: 1
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 1.6251
Mean activation of non-zero non-feature splitting nodes: 3.3963
Median activation of non-zero feature splitting nodes: 1.6251
Median activation of non-zero non-feature splitting nodes: 1.1159
Number of splitting features active above threshold: 1
Number of non-splitting features active above threshold: 19
Sum of activation strengths for splitting features: 1.6251
Sum of activation strengths for non-splitting features: 146.0398


In [23]:
# NOTE(review): this call is commented out; any output shown below is from an
# earlier execution and is stale. Re-enabling it requires importing
# `analyze_specific_points` (not imported in this notebook) and defining
# `node_df`, whose loading code is commented out. It also passes
# save_figs=True directly rather than the `save_figs` setting.
# # After creating the PCA plot and identifying interesting points
# interesting_point_ids = [0]  # Replace with actual IDs of interest
# analyze_specific_points(
#     results,
#     fs_splitting_nodes,
#     fs_splitting_cluster,
#     activation_threshold,
#     node_df,
#     results_path,
#     pca_df,
#     interesting_point_ids,
#     save_figs=True,
#     pca_path=pca_path,
# )


Analyzing point with ID 0:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



In [35]:
# NOTE(review): this call is commented out; any output shown below is from an
# earlier execution and is stale. Re-enabling it requires importing
# `plot_subgraph_static` and `load_subgraph` (neither is imported in this
# notebook) and defining `node_df` and `overall_feature_activations`, whose
# loading code is commented out.
# plot_subgraph_static(
#     subgraph=load_subgraph(results_path, activation_threshold, subgraph_id),
#     node_info_df=node_df,
#     output_path=os.path.join(pca_path, "overall_subgraph"),
#     activation_array=overall_feature_activations,
#     normalize_globally=False,
#     save_figs=True,
# )


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.

