In [49]:
%load_ext autoreload
%autoreload 2

import logging
import os
from os.path import join as pj

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scipy.sparse as sparse
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

from sae_cooccurrence.normalised_cooc_functions import (
    create_results_dir,
)
from sae_cooccurrence.pca import (
    analyze_representative_points,
    analyze_specific_points,
    create_pca_plots_decoder,
    generate_data,
    generate_subgraph_plot_data_sparse,
    get_point_result,
    load_data_from_pickle,
    perform_pca_on_results,
    plot_doubly_clustered_activation_heatmap,
    plot_feature_activations,
    plot_pca_explanation_and_save,
    plot_pca_feature_strength,
    plot_pca_single_feature_strength,
    plot_pca_with_active_features,
    plot_pca_with_top_feature,
    plot_simple_scatter,
    plot_subgraph_static_from_nx,
    plot_token_pca_and_save,
    save_data_to_pickle,
)
from sae_cooccurrence.utils.saving_loading import load_npz_files, set_device
from sae_cooccurrence.utils.set_paths import get_git_root

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def setup_logging(log_path):
    """Send root-logger output to the file at ``log_path`` at INFO level.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers. In a notebook this function is typically called again after
    changing the subgraph or output directory, so ``force=True`` is passed to
    drop (and close) the previous handlers and make the new path take effect.

    Args:
        log_path: Path of the log file to write to.
    """
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,  # Python 3.8+: replace handlers left over from an earlier call
    )

In [3]:
# Runtime configuration -------------
# Repository root; later cells join results/figure paths onto it.
git_root = get_git_root()
# Compute device chosen by the project helper (the recorded run printed "Using MPS").
device = set_device()
# The notebook never trains anything, so autograd is switched off globally.
torch.set_grad_enabled(False)

Using MPS


In [13]:
# --- Output switch: forwarded as save=save_figs to most plotting helpers ---
save_figs = True

# --- Model / SAE selection ---
model_name = "gpt2-small"
sae_release_short = "res-jb"
sae_id = "blocks.0.hook_resid_pre"

# --- Co-occurrence run to analyse ---
# n_batches_generation is passed to create_results_dir to locate the results folder.
n_batches_generation = 1000
activation_threshold = 1.5
subgraph_id = 3240

# --- PCA data generation: number of batches handed to generate_data ---
n_batches_reconstruction = 100

In [14]:
# The cluster id used in plot titles/file names is the subgraph under study.
fs_splitting_cluster = subgraph_id
pca_prefix = "pca"

# Seeds NumPy's global RNG only.
np.random.seed(1234)


# n_batches_reconstruction = config['pca']['n_batches_reconstruction']


# Load model
model = HookedTransformer.from_pretrained(model_name, device=device)

# Process the specific subgraph
# File-system friendly variants: dots replaced by underscores.
sae_id_neat = sae_id.replace(".", "_")
results_dir = create_results_dir(
    model_name, sae_release_short, sae_id_neat, n_batches_generation
)
results_path = pj(git_root, results_dir)
activation_threshold_safe = str(activation_threshold).replace(".", "_")

# All figures, the cached PCA pickle and the log file live under pca_path.
figures_path = pj(git_root, f"figures/{model_name}/{sae_release_short}/{sae_id_neat}")
pca_dir = f"{pca_prefix}_{activation_threshold_safe}_subgraph_{subgraph_id}"
pca_path = pj(figures_path, pca_dir)
# exist_ok avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(pca_path, exist_ok=True)
pickle_file = pj(pca_path, f"pca_data_subgraph_{subgraph_id}.pkl")

# Set up logging
log_path = pj(pca_path, "pca_analysis.log")
setup_logging(log_path)

# Log all settings
logging.info("Script started")
logging.info("Settings:")
logging.info(f"  save_figs: {save_figs}")
logging.info(f"  git_root: {git_root}")
logging.info(f"  sae_id: {sae_id}")
logging.info(f"  activation_threshold: {activation_threshold}")
logging.info(f"  subgraph_id: {subgraph_id}")
logging.info(f"  fs_splitting_cluster: {fs_splitting_cluster}")
logging.info(f"  pca_prefix: {pca_prefix}")
logging.info(f"  model_name: {model_name}")
logging.info(f"  sae_release_short: {sae_release_short}")
logging.info(f"  n_batches_reconstruction: {n_batches_reconstruction}")
# n_batches_generation selects the results directory, so record it as well.
logging.info(f"  n_batches_generation: {n_batches_generation}")
logging.info(f"  device: {device}")
logging.info(f"  results_path: {results_path}")
logging.info(f"  pca_path: {pca_path}")

Loaded pretrained model gpt2-small into HookedTransformer


In [15]:
# Node metadata for the chosen activation threshold; the path is built once
# and reused for the log message.
node_df_path = pj(
    results_path, f"dataframes/node_info_df_{activation_threshold_safe}.csv"
)
node_df = pd.read_csv(node_df_path)
logging.info(f"Loaded node_df from {node_df_path}")

# load_npz_files returns a mapping that is looked up by threshold here; a
# missing key yields None instead of raising, so record that explicitly.
overall_feature_activations = load_npz_files(
    results_path, "feature_acts_cooc_activations"
).get(activation_threshold)
if overall_feature_activations is None:
    logging.warning(
        f"No 'feature_acts_cooc_activations' entry found for activation_threshold={activation_threshold}"
    )

# with open(pj(results_path, f"subgraph_objects/activation_{activation_threshold_safe}/subgraph_{subgraph_id}.pkl"), 'rb') as f:
#     subgraph = pickle.load(f)


# Filter for the specific subgraph
fs_splitting_nodes = node_df.query("subgraph_id == @subgraph_id")["node_id"].tolist()
# Every later cell is driven by this node list; fail here with a clear message
# rather than deep inside data generation or plotting.
if not fs_splitting_nodes:
    raise ValueError(
        f"No nodes found for subgraph_id={subgraph_id} in {node_df_path}"
    )

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading npz files:   0%|          | 0/4 [00:00<?, ?it/s]

In [36]:
# True: rebuild the PCA data from the model/SAE and overwrite the cached pickle.
# False: reuse the cached pickle when it exists.
regen_data = False
# The original guard raised whenever regen_data was False, i.e. in exactly the
# "use existing data" configuration, which aborted a top-to-bottom run. The
# question it asks only makes sense when existing data is about to be replaced,
# so surface it as a notice in that case instead of an exception.
if regen_data and os.path.exists(pickle_file):
    print(
        f"regen_data=True: existing data at {pickle_file} will be regenerated and overwritten."
    )

In [47]:
# parser = argparse.ArgumentParser(description="PCA analysis script")
# parser.add_argument('--save_pickle', action='store_true', help='Save generated data to pickle')
# parser.add_argument('--load_pickle', action='store_true', help='Load data from pickle instead of regenerating')
# args = parser.parse_args()


# Obtain the PCA data bundle: regenerate it when requested or when no cached
# pickle exists, otherwise read the cache.
if regen_data or not os.path.exists(pickle_file):
    # Load SAE and set up activation store
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=f"{model_name}-{sae_release_short}", sae_id=sae_id, device=device
    )
    sae.fold_W_dec_norm()

    activation_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        train_batch_size_tokens=4096,
        n_batches_in_buffer=32,
        device=device,
    )

    data = generate_data(
        model,
        sae,
        activation_store,
        fs_splitting_nodes,
        n_batches_reconstruction,
        device=device,
    )

    # NOTE(review): data generated only because the cache was missing is not
    # written back unless regen_data is True, so the next run regenerates it.
    if regen_data:
        save_data_to_pickle(data, pickle_file)
else:
    data = load_data_from_pickle(pickle_file)

# Both paths yield the same bundle; unpack it once for the cells below.
results, pca_df, pca = data["results"], data["pca_df"], data["pca"]
pca_decoder, pca_decoder_df = data["pca_decoder"], data["pca_decoder_df"]

  0%|          | 0/100 [00:00<?, ?it/s]

Total examples found: 165


In [None]:
# Standard figure set for this subgraph, produced by the sae_cooccurrence.pca
# helpers. Each helper is given pca_path and save=save_figs; what exactly is
# drawn or written is defined in that module (not visible in this notebook).
plot_token_pca_and_save(pca_df, pca_path, subgraph_id, color_by="token", save=save_figs)

plot_pca_explanation_and_save(pca, pca_path, subgraph_id, save=save_figs)

plot_simple_scatter(results, pca_path, subgraph_id, fs_splitting_nodes, save=save_figs)

# pca_decoder, pca_decoder_df = calculate_pca_decoder(sae, fs_splitting_nodes)

# Save pca_decoder_df as CSV
# Written into pca_path without the index column, regardless of save_figs.
pca_decoder_df_filename = f"pca_decoder_df_subgraph_{subgraph_id}.csv"
pca_decoder_df.to_csv(pj(pca_path, pca_decoder_df_filename), index=False)

create_pca_plots_decoder(pca_decoder_df, subgraph_id, pca_path, save=save_figs)

print(f"Processing completed for subgraph ID {subgraph_id}")

In [None]:
# PCA plot from the project helper plot_pca_with_top_feature for this subgraph's
# nodes; presumably colours each point by its strongest feature — TODO confirm
# against sae_cooccurrence.pca.
plot_pca_with_top_feature(
    pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, save=save_figs
)

In [13]:
# The three calls differ only in the principal-component pair, so the shared
# positional arguments are collected once and unpacked into each call.
feature_strength_args = (
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
)
plot_pca_feature_strength(
    *feature_strength_args, pc_x="PC1", pc_y="PC2", save=save_figs
)
plot_pca_feature_strength(
    *feature_strength_args, pc_x="PC1", pc_y="PC3", save=save_figs
)
# Kept as the cell's final expression so any returned value is still displayed.
plot_pca_feature_strength(
    *feature_strength_args, pc_x="PC2", pc_y="PC3", save=save_figs
)

In [None]:
# NOTE(review): identical to the last call of the previous cell (PC2 vs PC3);
# this cell re-runs that plot on its own.
plot_pca_feature_strength(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    pc_x="PC2",
    pc_y="PC3",
    save=save_figs,
)

In [None]:
# Display the node ids belonging to the selected subgraph (the node_id values
# from node_df where subgraph_id matches).
fs_splitting_nodes

In [None]:
# Single-feature variant of the PC2/PC3 plot.
# NOTE(review): 3266 is hard-coded; presumably one of the ids in
# fs_splitting_nodes for subgraph 3240 — confirm when changing subgraph_id.
plot_pca_single_feature_strength(
    pca_df,
    results,
    3266,
    fs_splitting_cluster,
    pca_path,
    pc_x="PC2",
    pc_y="PC3",
    save=save_figs,
)

In [None]:
# Single-feature variant of the PC2/PC3 plot.
# NOTE(review): 8838 is hard-coded; presumably one of the ids in
# fs_splitting_nodes for subgraph 3240 — confirm when changing subgraph_id.
plot_pca_single_feature_strength(
    pca_df,
    results,
    8838,
    fs_splitting_cluster,
    pca_path,
    pc_x="PC2",
    pc_y="PC3",
    save=save_figs,
)

In [None]:
# Project helper plot_pca_with_active_features, given the same
# activation_threshold that selected the co-occurrence results; presumably it
# marks which subgraph features count as active per point — TODO confirm.
plot_pca_with_active_features(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    activation_threshold=activation_threshold,
    save=save_figs,
)

In [None]:
# Activation heatmap from the project helper, limited via max_examples=1000.
# Note the argument order differs from the other helpers: pca_df comes third.
plot_doubly_clustered_activation_heatmap(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=save_figs,
)

In [None]:
# plot_feature_activations_combined(
#     get_point_result(results, 2),
#     fs_splitting_nodes,
#     fs_splitting_cluster,
#     activation_threshold,
#     node_df,
#     results_path,
#     pca_path,
#     save_figs=True,
# )

In [None]:
# Feature activations for one point: get_point_result(results, 2) selects the
# entry passed to the plotting helper (2 is hard-coded). save_figs=False is
# passed explicitly here, independent of the notebook-level save_figs flag.
plot_feature_activations(
    get_point_result(results, 2),
    fs_splitting_nodes,
    fs_splitting_cluster,
    activation_threshold,
    node_df,
    results_path,
    save_figs=False,
    pca_path=pca_path,
)

In [None]:
# Usage example:
# NOTE(review): this rebinds pca_df, replacing the frame taken from the data
# bundle earlier; all later cells see the recomputed one. It also binds `_`,
# which IPython otherwise uses for the last cell output.
pca_df, _ = perform_pca_on_results(results)
# save_figs=True is hard-coded here rather than following the notebook-level flag.
analyze_representative_points(
    results=results,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    activation_threshold=activation_threshold,
    node_df=node_df,
    results_path=results_path,
    pca_df=pca_df,
    save_figs=True,
    pca_path=pca_path,
)

In [None]:
# analyze_representative_points_comp(
#     results,
#     fs_splitting_nodes,
#     activation_threshold,
#     node_df,
#     pca_df,
#     save_figs=True,
#     pca_path=pca_path,
# )

In [None]:
# After creating the PCA plot and identifying interesting points
# NOTE(review): these ids are placeholders. The recorded data-generation run
# reported "Total examples found: 165", while ids here go up to 1001 — confirm
# they refer to points that exist in `results` before running.
interesting_point_ids = [54, 357, 178, 930, 1001]  # Replace with actual IDs of interest
# save_figs=True is hard-coded here rather than following the notebook-level flag.
analyze_specific_points(
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    activation_threshold,
    node_df,
    results_path,
    pca_df,
    interesting_point_ids,
    save_figs=True,
    pca_path=pca_path,
)

In [None]:
# analyze_user_specified_points_comp(
#     results,
#     fs_splitting_nodes,
#     activation_threshold,
#     node_df,
#     pca_df,
#     interesting_point_ids,
#     save_figs=True,
#     pca_path=pca_path,
# )

In [None]:
# analyze_user_specified_points_comp_subgraph(
#     results,
#     fs_splitting_nodes,
#     fs_splitting_cluster,
#     activation_threshold,
#     node_df,
#     pca_df,
#     interesting_point_ids,
#     results_path,
#     save_figs=True,
#     pca_path=pca_path,
# )

In [None]:
def plot_pca_weekdays(
    pca_df,
    pca_path,
    fs_splitting_cluster,
    plot_inner=False,
    save_figs=False,
    pc_x="PC2",
    pc_y="PC3",
):
    """Scatter two principal components with points coloured by weekday token.

    Each row is classified by a case-insensitive substring match of its
    ``tokens`` value against a list of day names. Rows that match no day get
    the "Other" colour and the "circle" marker class; matching rows get the
    day's colour and the "cross" marker class when the token contains a space,
    "circle" otherwise. One trace is added per day colour and one
    (transparent-fill, outlined) trace per marker class.

    Args:
        pca_df: DataFrame with ``tokens`` and ``context`` columns plus the
            columns named by ``pc_x`` and ``pc_y``. It is not modified.
        pca_path: Directory that receives the PNG/HTML files when saving.
        fs_splitting_cluster: Cluster id used in the title and file names.
        plot_inner: If False, match full day names ("Monday", ...). If True,
            match the short forms ("Mon", "Tues", ...); being substring
            matches, these also hit the full names and any other token that
            contains them.
        save_figs: If True, write a PNG (scale 3.0) and an HTML file to
            ``pca_path``; otherwise call ``fig.show()``.
        pc_x: Column plotted on the x axis (default "PC2").
        pc_y: Column plotted on the y axis (default "PC3").

    Returns:
        The plotly ``go.Figure`` that was built.
    """
    # Define colors for each day and gray for others
    if not plot_inner:
        color_map = {
            "Monday": "#FF9999",
            "Tuesday": "#66B2FF",
            "Wednesday": "#99FF99",
            "Thursday": "#FFCC99",
            "Friday": "#FF99FF",
            "Saturday": "#99FFFF",
            "Sunday": "#FFFF99",
            "Other": "#CCCCCC",
        }
    else:
        color_map = {
            "Mon": "#FF9999",
            "Tues": "#66B2FF",
            "Wed": "#99FF99",
            "Thurs": "#FFCC99",
            "Fri": "#FF99FF",
            "Sat": "#99FFFF",
            "Sun": "#FFFF99",
            "Other": "#CCCCCC",
        }

    # "Other" is only a legend bucket. It must not be matched as a day name,
    # otherwise tokens such as " other" would be given a "cross" marker while
    # every other non-day token gets "circle".
    day_names = [name for name in color_map if name != "Other"]

    # Function to determine color and marker shape
    def get_color_and_marker(token):
        token_lower = token.lower()
        for day in day_names:
            if day.lower() in token_lower:
                return color_map[day], "cross" if " " in token else "circle"
        return color_map["Other"], "circle"

    def hover_text(frame):
        return [
            f"Token: {t}<br>Context: {c}"
            for t, c in zip(frame["tokens"], frame["context"])
        ]

    # Classify every token. Building two plain lists (rather than unpacking
    # zip(*...)) also works for an empty frame, and assign() returns a new
    # DataFrame so the caller's pca_df keeps its original columns.
    styles = [get_color_and_marker(token) for token in pca_df["tokens"]]
    plot_df = pca_df.assign(
        color=[color for color, _ in styles],
        marker=[marker for _, marker in styles],
    )

    # Create the plot
    fig = go.Figure()

    # Add traces for colors (days)
    for day in color_map:
        df_day = plot_df[plot_df["color"] == color_map[day]]
        fig.add_trace(
            go.Scatter(
                x=df_day[pc_x],
                y=df_day[pc_y],
                mode="markers",
                marker=dict(color=color_map[day], size=8),
                name=day,
                legendgroup="days",
                legendgrouptitle_text="Day of Week",
                text=hover_text(df_day),
                hoverinfo="text",
            )
        )

    # Add traces for shapes (with/without space)
    for marker, label in [("circle", "No Space"), ("cross", "With Space")]:
        df_marker = plot_df[plot_df["marker"] == marker]
        fig.add_trace(
            go.Scatter(
                x=df_marker[pc_x],
                y=df_marker[pc_y],
                mode="markers",
                marker=dict(symbol=marker, size=8, color="rgba(0,0,0,0)"),
                name=label,
                legendgroup="shapes",
                legendgrouptitle_text="Token Type",
                text=hover_text(df_marker),
                hoverinfo="text",
            )
        )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Weekdays)",
        xaxis_title=pc_x,
        yaxis_title=pc_y,
        legend=dict(groupclick="toggleitem", tracegroupgap=20),
    )

    # Applies to every trace: final marker size is 12 with a dark outline.
    fig.update_traces(
        marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
        selector=dict(mode="markers"),
    )

    outer_suffix = "" if not plot_inner else "_inner"
    # Default axes keep the original file names; other axis pairs get their
    # own names so they do not overwrite the default plot.
    if (pc_x, pc_y) != ("PC2", "PC3"):
        outer_suffix += f"_{pc_x}_{pc_y}"

    if save_figs:
        os.makedirs(pca_path, exist_ok=True)

        # Save as PNG
        png_path = os.path.join(
            pca_path, f"pca_plot_weekdays_{fs_splitting_cluster}{outer_suffix}.png"
        )
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"pca_plot_weekdays_{fs_splitting_cluster}{outer_suffix}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [None]:
# Weekday plot matching full day names (plot_inner defaults to False); saves
# PNG and HTML into pca_path. The returned figure is the cell's displayed output.
plot_pca_weekdays(pca_df, pca_path, fs_splitting_cluster, save_figs=True)

In [None]:
# Weekday plot matching the short day forms (plot_inner=True); the files get
# an "_inner" suffix so they do not overwrite the full-name plot.
plot_pca_weekdays(
    pca_df, pca_path, fs_splitting_cluster, plot_inner=True, save_figs=True
)

In [50]:
subgraph_id = fs_splitting_cluster
# Build the matrix file name from activation_threshold_safe (as the node_info
# CSV path does) instead of a hard-coded "1_5", so changing
# activation_threshold in the parameters cell loads the matching matrix.
sparse_thresholded_matrix = sparse.load_npz(
    os.path.join(
        results_path,
        "thresholded_matrices",
        f"sparse_thresholded_matrix_{activation_threshold_safe}.npz",
    ),
)
# Extract the graph and per-node frame for this subgraph via the project helper.
subgraph, subgraph_df = generate_subgraph_plot_data_sparse(
    sparse_thresholded_matrix=sparse_thresholded_matrix,
    node_df=node_df,
    subgraph_id=subgraph_id,
)
# Static rendering; output goes to "subgraph_static" under pca_path.
plot_subgraph_static_from_nx(
    subgraph=subgraph,
    output_path=pj(pca_path, "subgraph_static"),
    subgraph_df=subgraph_df,
    node_info_df=node_df,
    save_figs=True,
    show_plot=True,
)


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.

