In [39]:
%load_ext autoreload
%autoreload 2

import logging
import os
import re
from os.path import join as pj

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# import plotly.graph_objs as go
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

from sae_cooccurrence.normalised_cooc_functions import (
    create_results_dir,
    neat_sae_id,
)
from sae_cooccurrence.pca import (
    analyze_specific_points,
    analyze_specific_points_from_sparse_matrix_path,
    generate_data,
    load_data_from_pickle,
    plot_doubly_clustered_activation_heatmap,
    plot_pca_feature_strength,
    plot_pca_with_active_features,
    plot_pca_with_top_feature,
    plot_subgraph_from_sparse_matrix_path,
    save_data_to_pickle,
)
from sae_cooccurrence.pca_animation import (
    analyze_specific_points_animated_from_sparse_matrix_path,
)
from sae_cooccurrence.utils.saving_loading import load_npz_files, set_device
from sae_cooccurrence.utils.set_paths import get_git_root

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def setup_logging(log_path):
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

In [3]:
# Config -------------
torch.set_grad_enabled(False)
device = set_device()
git_root = get_git_root()

Using MPS


In [6]:
save_figs = True


model_name = "gpt2-small"
sae_release_short = "res-jb-feature-splitting"
sae_id = "blocks.8.hook_resid_pre_24576"
n_batches_reconstruction = 1000


activation_threshold = 1.5
subgraph_id = 125
n_batches_generation = 500

In [10]:
fs_splitting_cluster = subgraph_id
pca_prefix = "pca"

np.random.seed(1234)


# n_batches_reconstruction = config['pca']['n_batches_reconstruction']


# Load model
model = HookedTransformer.from_pretrained(model_name, device=device)

# Process the specific subgraph
sae_id_neat = neat_sae_id(sae_id)
results_dir = create_results_dir(
    model_name, sae_release_short, sae_id_neat, n_batches_generation
)
results_path = pj(git_root, results_dir)
activation_threshold_safe = str(activation_threshold).replace(".", "_")

figures_path = pj(git_root, f"figures/{model_name}/{sae_release_short}/{sae_id_neat}")
pca_dir = f"{pca_prefix}_{activation_threshold_safe}_subgraph_{subgraph_id}"
pca_path = pj(figures_path, pca_dir)
if not os.path.exists(pca_path):
    os.makedirs(pca_path)
pickle_file = pj(pca_path, f"pca_data_subgraph_{subgraph_id}.pkl")

# Set up logging
log_path = pj(pca_path, "pca_analysis.log")
setup_logging(log_path)

# Log all settings
logging.info("Script started")
logging.info("Settings:")
logging.info(f"  save_figs: {save_figs}")
logging.info(f"  git_root: {git_root}")
logging.info(f"  sae_id: {sae_id}")
logging.info(f"  activation_threshold: {activation_threshold}")
logging.info(f"  subgraph_id: {subgraph_id}")
logging.info(f"  fs_splitting_cluster: {fs_splitting_cluster}")
logging.info(f"  pca_prefix: {pca_prefix}")
logging.info(f"  model_name: {model_name}")
logging.info(f"  sae_release_short: {sae_release_short}")
logging.info(f"  n_batches_reconstruction: {n_batches_reconstruction}")
logging.info(f"  device: {device}")
logging.info(f"  results_path: {results_path}")
logging.info(f"  pca_path: {pca_path}")

Loaded pretrained model gpt2-small into HookedTransformer


In [11]:
node_df = pd.read_csv(
    pj(results_path, f"dataframes/node_info_df_{activation_threshold_safe}.csv")
)
logging.info(
    f"Loaded node_df from {pj(results_path, f'dataframes/node_info_df_{activation_threshold_safe}.csv')}"
)

overall_feature_activations = load_npz_files(
    results_path, "feature_acts_cooc_activations"
).get(activation_threshold)

# with open(pj(results_path, f"subgraph_objects/activation_{activation_threshold_safe}/subgraph_{subgraph_id}.pkl"), 'rb') as f:
#     subgraph = pickle.load(f)


# Filter for the specific subgraph
fs_splitting_nodes = node_df.query("subgraph_id == @subgraph_id")["node_id"].tolist()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading npz files:   0%|          | 0/4 [00:00<?, ?it/s]

In [12]:
regen_data = False
if not regen_data:
    raise ValueError("Are you sure you don't want to use existing data?")

In [14]:
# parser = argparse.ArgumentParser(description="PCA analysis script")
# parser.add_argument('--save_pickle', action='store_true', help='Save generated data to pickle')
# parser.add_argument('--load_pickle', action='store_true', help='Load data from pickle instead of regenerating')
# args = parser.parse_args()


if not regen_data and os.path.exists(pickle_file):
    data = load_data_from_pickle(pickle_file)
    results = data["results"]
    pca_df = data["pca_df"]
    pca = data["pca"]
    pca_decoder = data["pca_decoder"]
    pca_decoder_df = data["pca_decoder_df"]
else:
    # Load SAE and set up activation store
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=f"{model_name}-{sae_release_short}", sae_id=sae_id, device=device
    )
    sae.fold_W_dec_norm()

    activation_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        train_batch_size_tokens=4096,
        n_batches_in_buffer=32,
        device=device,
    )

    data = generate_data(
        model=model,
        sae=sae,
        activation_store=activation_store,
        fs_splitting_nodes=fs_splitting_nodes,
        n_batches_reconstruction=n_batches_reconstruction,
        device=device,
    )

    if regen_data:
        save_data_to_pickle(data, pickle_file)

    results = data["results"]
    pca_df = data["pca_df"]
    pca = data["pca"]
    pca_decoder = data["pca_decoder"]
    pca_decoder_df = data["pca_decoder_df"]



This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.



  0%|          | 0/1000 [00:00<?, ?it/s]

Total examples found: 978


In [15]:
plot_pca_with_top_feature(
    pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, save=save_figs
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [16]:
plot_pca_feature_strength(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    pc_x="PC1",
    pc_y="PC2",
    save=save_figs,
)
plot_pca_feature_strength(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    pc_x="PC1",
    pc_y="PC3",
    save=save_figs,
)
plot_pca_feature_strength(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    pc_x="PC2",
    pc_y="PC3",
    save=save_figs,
)

In [17]:
plot_pca_with_active_features(
    pca_df,
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    pca_path,
    activation_threshold=activation_threshold,
    save=save_figs,
)

In [18]:
plot_doubly_clustered_activation_heatmap(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=save_figs,
)

In [13]:
# plot_feature_activations_combined(
#     get_point_result(results, 2),
#     fs_splitting_nodes,
#     fs_splitting_cluster,
#     activation_threshold,
#     node_df,
#     results_path,
#     pca_path,
#     save_figs=True,
# )


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



/Users/matthew/Documents/Github/PIBBSS/figures/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/pca_1_5_subgraph_125/combined_plot_subgraph_2653.png


Number of non-zero features: 38
Number of non-zero feature splitting nodes: 1
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 0.6166
Mean activation of non-zero non-feature splitting nodes: 3.0819
Median activation of non-zero feature splitting nodes: 0.6166
Median activation of non-zero non-feature splitting nodes: 1.3073
Number of splitting features active above threshold: 0
Number of non-splitting features active above threshold: 17
Sum of activation strengths for splitting features: 0.6166
Sum of activation strengths for non-splitting features: 114.0296


In [19]:
# plot_feature_activations(
#     get_point_result(results, 2),
#     fs_splitting_nodes,
#     fs_splitting_cluster,
#     activation_threshold,
#     node_df,
#     results_path,
#     save_figs=False,
#     pca_path=pca_path,
# )

FileNotFoundError: Subgraph file not found: /Users/matthew/Documents/Github/sae_cooccurrence/results/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/n_batches_500/subgraph_objects/activation_1_5/subgraph_2653.pkl

In [20]:
# # Usage example:
# pca_df, _ = perform_pca_on_results(results)
# analyze_representative_points(
#     results=results,
#     fs_splitting_nodes=fs_splitting_nodes,
#     fs_splitting_cluster=fs_splitting_cluster,
#     activation_threshold=activation_threshold,
#     node_df=node_df,
#     results_path=results_path,
#     pca_df=pca_df,
#     save_figs=True,
#     pca_path=pca_path,
# )


Analyzing representative point 1:


FileNotFoundError: Subgraph file not found: /Users/matthew/Documents/Github/sae_cooccurrence/results/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/n_batches_500/subgraph_objects/activation_1_5/subgraph_125.pkl

In [16]:
# analyze_representative_points_comp(
#     results,
#     fs_splitting_nodes,
#     activation_threshold,
#     node_df,
#     pca_df,
#     save_figs=True,
#     pca_path=pca_path,
# )


Statistics for representative point 1:
Number of non-zero features: 18
Number of non-zero feature splitting nodes: 3
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 32.3789
Mean activation of non-zero non-feature splitting nodes: 3.4057
Median activation of non-zero feature splitting nodes: 34.8842
Median activation of non-zero non-feature splitting nodes: 1.3402
Number of splitting features active above threshold: 3
Number of non-splitting features active above threshold: 6
Sum of activation strengths for splitting features: 97.1367
Sum of activation strengths for non-splitting features: 51.0860

Statistics for representative point 2:
Number of non-zero features: 19
Number of non-zero feature splitting nodes: 3
Total number of feature splitting nodes: 5
Mean activation of non-zero feature splitting nodes: 30.3896
Mean activation of non-zero non-feature splitting nodes: 3.0571
Median activation of non-zero feature splitting nodes: 17.105

In [94]:
# After creating the PCA plot and identifying interesting points
interesting_point_ids = [
    300,
    737,
    334,
    655,
    385,
    144,
    348,
]  # Replace with actual IDs of interest
analyze_specific_points(
    results,
    fs_splitting_nodes,
    fs_splitting_cluster,
    activation_threshold,
    node_df,
    results_path,
    pca_df,
    interesting_point_ids,
    save_figs=True,
    pca_path=pca_path,
)


Analyzing point with ID 291:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.




Analyzing point with ID 737:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.




Analyzing point with ID 334:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.




Analyzing point with ID 655:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.




Analyzing point with ID 385:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.




Analyzing point with ID 144:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.




Analyzing point with ID 348:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.


Glyph 29256 (\N{CJK UNIFIED IDEOGRAPH-7248}) missing from font(s) DejaVu Sans.



In [25]:
interesting_point_ids = [300, 737, 334, 655, 385, 144, 348]
analyze_specific_points_animated_from_sparse_matrix_path(
    results=results,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    # activation_threshold=activation_threshold,
    node_df=node_df,
    results_path=results_path,
    pca_df=pca_df,
    matrix_filename="sparse_thresholded_matrix_1_5.npz",
    point_ids=interesting_point_ids,
    plot_only_fs_nodes=True,
    save_gif=True,
    gif_path=os.path.join(pca_path, "length.gif"),
)

Point ID: 300<br>. You cant hide pic.twitter.com/|q|XeOlORYUb — Nick (@Nick_
animation.gif
/Users/matthew/Documents/Github/sae_cooccurrence/figures/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/pca_1_5_subgraph_125/length.gif/gif_frames
Animation saved as GIF: /Users/matthew/Documents/Github/sae_cooccurrence/figures/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/pca_1_5_subgraph_125/length.gif
Individual frames saved in folder: /Users/matthew/Documents/Github/sae_cooccurrence/figures/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/pca_1_5_subgraph_125/length.gif/gif_frames


In [26]:
# analyze_user_specified_points_comp(
#     results,
#     fs_splitting_nodes,
#     activation_threshold,
#     node_df,
#     pca_df,
#     interesting_point_ids,
#     save_figs=True,
#     pca_path=pca_path,
# )

In [27]:
# analyze_user_specified_points_comp_subgraph(
#     results,
#     fs_splitting_nodes,
#     fs_splitting_cluster,
#     activation_threshold,
#     node_df,
#     pca_df,
#     interesting_point_ids,
#     results_path,
#     save_figs=True,
#     pca_path=pca_path,
# )

In [28]:
plot_subgraph_from_sparse_matrix_path(
    results_path,
    pca_path,
    fs_splitting_cluster,
    node_df,
    save_figs=True,
    matrix_filename="sparse_thresholded_matrix_1_5.npz",
)


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



(<networkx.classes.graph.Graph at 0x3e2f6f450>,
      node_id  feature_activations
 360    19054                928.0
 361     5748                926.0
 362     1179                777.0
 363      125                655.0
 364      734                750.0)

In [29]:
# After creating the PCA plot and identifying interesting points
interesting_point_ids = [85, 253, 70, 334]  # Replace with actual IDs of interest
analyze_specific_points_from_sparse_matrix_path(
    results=results,
    results_path=results_path,
    fs_splitting_nodes=fs_splitting_nodes,
    fs_splitting_cluster=fs_splitting_cluster,
    activation_threshold=activation_threshold,
    node_df=node_df,
    pca_df=pca_df,
    point_ids=interesting_point_ids,
    save_figs=True,
    pca_path=os.path.join(pca_path, "repeats_char_2"),
)


Analyzing point with ID 85:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Subgraph figure is None. Likely no latents are within a subgraph.
Subgraph figure is None. Likely no latents are within a subgraph.

Analyzing point with ID 253:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Subgraph figure is None. Likely no latents are within a subgraph.
Subgraph figure is None. Likely no latents are within a subgraph.

Analyzing point with ID 70:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Subgraph figure is None. Likely no latents are within a subgraph.

Analyzing point with ID 334:



This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



Subgraph figure is None. Likely no latents are within a subgraph.


In [23]:
pca_df

Unnamed: 0,PC1,PC2,PC3,tokens,context,point_id
0,23.479156,1.318546,-1.092309,us,Main Street USA? Let| us| know in the comment...,0
1,23.367529,1.198425,-1.259207,m,��I��|m| not aware of any of,1
2,23.245977,1.190119,-1.273471,more,"the exploited vulnerability, is| more| widely...",2
3,22.766705,1.156939,-1.329908,z,Programmed by BlackOp|z|FX Labs.,3
4,22.185122,1.034006,-1.561345,48,", 36.6347|48| (Wikimapia",4
...,...,...,...,...,...,...
973,23.174801,1.185114,-1.281878,rather,"appears to arise from,| rather|",973
974,17.466125,3.604700,1.894514,k,<|endoftext|>/|k|] - A[i,974
975,17.466125,3.604700,1.894514,k,<|endoftext|>/|k| time. This is minimized,975
976,23.416416,1.201796,-1.253445,t,<|endoftext|>�|t| work like that.\n,976


In [31]:
def plot_pca_char_in_token(pca_df, pca_path, fs_splitting_cluster, save_figs=False):
    # Function to preprocess context by removing '<|endoftext|>' and counting characters between '|' characters
    def preprocess_and_count_chars(context):
        # Remove '<|endoftext|>' from the context
        cleaned_context = context.replace("<|endoftext|>", "")

        # Split the cleaned context by '|'
        parts = cleaned_context.split("|")

        # Count characters in the middle part if it exists
        if len(parts) >= 3:
            return len(parts[1])
        return 0

    # Apply the preprocessing and counting function
    pca_df["char_count"] = pca_df["context"].apply(preprocess_and_count_chars)

    # Get unique char_count values and sort them
    unique_counts = sorted(pca_df["char_count"].unique())

    # Create a color map for unique values
    color_map = px.colors.qualitative.Plotly
    color_discrete_map = {
        count: color_map[i % len(color_map)] for i, count in enumerate(unique_counts)
    }

    # Create the plot
    fig = go.Figure()

    # Add traces for each unique char_count
    for count in unique_counts:
        df_subset = pca_df[pca_df["char_count"] == count]
        fig.add_trace(
            go.Scatter(
                x=df_subset["PC2"],
                y=df_subset["PC3"],
                mode="markers",
                marker=dict(
                    color=color_discrete_map[count],
                    size=12,
                    line=dict(width=1, color="DarkSlateGrey"),
                ),
                name=f"Count: {count}",
                text=[
                    f"Token: {t}<br>Context: {c}<br>Char Count: {count}"
                    for t, c in zip(df_subset["tokens"], df_subset["context"])
                ],
                hoverinfo="text",
            )
        )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Discrete Context Character Count)",
        xaxis_title="PC2",
        yaxis_title="PC3",
        legend_title="Character Count",
    )

    if save_figs:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"pca_plot_n_char_in_token_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"pca_plot_n_char_in_token_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [32]:
pca_df_max_examples = pca_df.copy()
pca_df_max_examples["combined"] = (
    pca_df_max_examples["PC2"] + pca_df_max_examples["PC3"]
)
examples = pca_df_max_examples.sort_values("combined", ascending=False).head(10)
examples.rename(columns={"context": "Context"}, inplace=True)
# show only columns for point id and context
examples[["Context"]].to_csv(pj(pca_path, "top_examples_right.csv"), index=False)

In [33]:
pca_df_max_examples = pca_df.copy()
pca_df_max_examples["combined"] = -pca_df_max_examples["PC2"]
examples = pca_df_max_examples.sort_values("combined", ascending=False).head(10)
examples.rename(columns={"context": "Context"}, inplace=True)
# Remove special characters, including \n and \\n, from 'Context' before export
examples["Context"] = examples["Context"].str.replace("\n", "").str.replace("\\n", "")
# show only columns for point id and context
examples[["Context"]].to_csv(pj(pca_path, "top_examples_left.csv"), index=False)
examples

Unnamed: 0,PC1,PC2,PC3,tokens,Context,point_id,combined
65,-31.883965,-44.429882,12.508744,c,.com/e2zNEIdX5|c| — Earl Brown (@cosine55) Apr...,65,44.429882
743,-27.169886,-44.284946,11.883591,VS,.com/sF29nUxL|VS| — FireWorks (@FireWorksBAY) ...,743,44.284946
707,-31.946568,-43.908348,10.686967,zy,.twitter.com/FFwwZLuT|zy| — Kristan T. Harris ...,707,43.908348
98,-30.1394,-43.137024,12.543056,v,twitter.com/9rPy032rG|v| — Kyle Holliman (@the...,98,43.137024
348,-16.514959,-43.130455,14.356832,3,twitter.com/GF0JoXJaq|3| — Nick (@Nick_Falco) ...,348,43.130455
158,-28.061049,-41.902596,12.248115,d,com/KLvQ7Bm9S|d| — Josh Rubin (@jrubin) April 28,158,41.902596
376,-32.934372,-41.835052,10.523829,i,com/sSX4g2zVO|i| — Lea Michele (@msleamichele,376,41.835052
138,-28.204767,-41.548389,9.868689,Ph,twitter.com/OvynGycV|Ph| — SevereStudios (@se,138,41.548389
90,-18.080685,-41.534134,12.893217,9,twitter.com/EB6tlnlPI|9| — Jimmy Carter (@askj...,90,41.534134
435,-29.748013,-40.329414,8.646914,EM,.com/2eJbAhzG|EM| — Jacks (@JackkieMarrie) Dec...,435,40.329414


In [34]:
pca_df_max_examples = pca_df.copy()
pca_df_max_examples["combined"] = (
    -pca_df_max_examples["PC2"] + pca_df_max_examples["PC3"]
)
examples = pca_df_max_examples.sort_values("combined", ascending=False).head(10)
examples.rename(columns={"context": "Context"}, inplace=True)
# show only columns for point id and context
examples[["Context"]].to_csv(pj(pca_path, "top_examples_bottom.csv"), index=False)

In [35]:
# plot_subgraph_static(
#     load_subgraph(results_path, activation_threshold, subgraph_id),
#     node_df,
#     0.0,
#     os.path.join(pca_path, "overall_subgraph"),
#     overall_feature_activations,
#     normalize_globally=False,
#     save_figs=True,
# )

FileNotFoundError: Subgraph file not found: /Users/matthew/Documents/Github/sae_cooccurrence/results/gpt2-small/res-jb-feature-splitting/blocks_8_hook_resid_pre_24576/n_batches_500/subgraph_objects/activation_1_5/subgraph_125.pkl

In [36]:
plot_pca_char_in_token(pca_df, pca_path, fs_splitting_cluster, save_figs=True)

In [37]:
def plot_pca_filtered_context(pca_df, pca_path, fs_splitting_cluster, save_figs=False):
    def process_and_count_chars(context):
        # Remove '<|endoftext|>' from the context
        cleaned_context = context.replace("<|endoftext|>", "")

        # Split the cleaned context by '|'
        parts = cleaned_context.split("|")

        # Check if there's exactly one character between '|' symbols
        if len(parts) == 3 and len(parts[1]) == 1:
            # single_char = parts[1]
            before_part = parts[0]

            # Check for '/watch?' string
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                # Count characters from end of '/watch?' to the single character
                return len(before_part) - (
                    watch_index + 7
                )  # 7 is the length of '/watch?'
            else:
                # Check if there's a '/' before the single character without spaces
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    # Count characters between the last '/' and the single character
                    return len(match.group(1))

        # Return None for cases that don't meet the criteria
        return None

    # Apply the processing and counting function
    pca_df["char_count"] = pca_df["context"].apply(process_and_count_chars)

    # Filter out None values
    pca_df_filtered = pca_df.dropna(subset=["char_count"])

    # Create the plot
    fig = go.Figure()

    # Add trace for all points
    fig.add_trace(
        go.Scatter(
            x=pca_df_filtered["PC2"],
            y=pca_df_filtered["PC3"],
            mode="markers",
            marker=dict(
                color=pca_df_filtered["char_count"],
                colorscale="turbo",
                size=12,
                colorbar=dict(title="Character Count"),
                line=dict(width=1, color="DarkSlateGrey"),
            ),
            text=[
                f"Token: {t}<br>Context: {c}<br>Char Count: {count}"
                for t, c, count in zip(
                    pca_df_filtered["tokens"],
                    pca_df_filtered["context"],
                    pca_df_filtered["char_count"],
                )
            ],
            hoverinfo="text",
        )
    )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Filtered Context Character Count)",
        xaxis_title="PC2",
        yaxis_title="PC3",
    )

    if save_figs:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"pca_plot_filtered_context_char_count_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(
            pca_path,
            f"pca_plot_filtered_context_char_count_{fs_splitting_cluster}.html",
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [40]:
plot_pca_filtered_context(pca_df, pca_path, fs_splitting_cluster, save_figs=True)

In [41]:
# import plotly.graph_objs as go
# import os
# import plotly.express as px
# import re
# import pandas as pd


def plot_pca_filtered_context2(pca_df, pca_path, fs_splitting_cluster, save_figs=False):
    def process_context(context):
        # Split the context by '|'
        parts = context.split("|")

        # Check if there's exactly one character between '|' symbols
        if len(parts) == 3 and len(parts[1]) == 1:
            # single_char = parts[1]
            before_part = parts[0]
            after_part = parts[2]

            # Determine the shape based on what comes immediately after the '|'
            shape = (
                "cross"
                if after_part.startswith(" ") or after_part.startswith("<|endoftext|>")
                else "circle"
            )

            # Check for '/watch?' string
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                # Count characters from end of '/watch?' to the single character
                char_count = len(before_part) - (
                    watch_index + 7
                )  # 7 is the length of '/watch?'
            else:
                # Check if there's a '/' before the single character without spaces
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    # Count characters between the last '/' and the single character
                    char_count = len(match.group(1))
                else:
                    return None, None  # No valid count found

            return char_count, shape

        return None, None  # Criteria not met

    # Apply the processing function
    pca_df[["char_count", "shape"]] = pca_df["context"].apply(
        lambda x: pd.Series(process_context(x))
    )

    # Filter out None values
    pca_df_filtered = pca_df.dropna(subset=["char_count", "shape"])

    # Create the plot
    fig = go.Figure()

    # Add trace for circular markers
    circle_df = pca_df_filtered[pca_df_filtered["shape"] == "circle"]
    fig.add_trace(
        go.Scatter(
            x=circle_df["PC2"],
            y=circle_df["PC3"],
            mode="markers",
            marker=dict(
                color=circle_df["char_count"],
                colorscale="viridis",
                size=12,
                symbol="circle",
                line=dict(width=1, color="DarkSlateGrey"),
            ),
            text=[
                f"Token: {t}<br>Context: {c}<br>Char Count: {count}"
                for t, c, count in zip(
                    circle_df["tokens"], circle_df["context"], circle_df["char_count"]
                )
            ],
            hoverinfo="text",
            name="No space or <|endoftext|> immediately after |",
        )
    )

    # Add trace for cross markers
    cross_df = pca_df_filtered[pca_df_filtered["shape"] == "cross"]
    fig.add_trace(
        go.Scatter(
            x=cross_df["PC2"],
            y=cross_df["PC3"],
            mode="markers",
            marker=dict(
                color=cross_df["char_count"],
                colorscale="turbo",
                size=12,
                symbol="cross",
                line=dict(width=1, color="DarkSlateGrey"),
            ),
            text=[
                f"Token: {t}<br>Context: {c}<br>Char Count: {count}"
                for t, c, count in zip(
                    cross_df["tokens"], cross_df["context"], cross_df["char_count"]
                )
            ],
            hoverinfo="text",
            name="Space or <|endoftext|> immediately after |",
        )
    )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Filtered Context Character Count)",
        xaxis_title="PC2",
        yaxis_title="PC3",
        coloraxis_colorbar=dict(title="Character Count"),
        legend_title="Marker Shape",
    )

    # Update coloraxis to ensure consistent color scale across both traces
    fig.update_traces(marker=dict(coloraxis="coloraxis"))
    fig.update_layout(coloraxis=dict(colorscale="viridis"))

    if save_figs:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"pca_plot_filtered_context_char_count_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(
            pca_path,
            f"pca_plot_filtered_context_char_count_{fs_splitting_cluster}.html",
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [42]:
plot_pca_filtered_context2(pca_df, pca_path, fs_splitting_cluster, save_figs=True)

In [43]:
# import plotly.graph_objects as go


def plot_pca_domain(pca_df, pca_path, fs_splitting_cluster, save_figs=False):
    # Define colors for each category
    color_map = {
        "twitter": "#1DA1F2",  # Twitter blue
        "usat": "#FF0000",  # Red for USA Today
        "youtube": "#00FF00",  # YouTube green
        "other": "#CCCCCC",  # Gray for others
    }

    # Function to determine color
    def get_color(row):
        context = row["context"].lower()
        if "twitter" in context or "t.co" in context:
            return color_map["twitter"]
        elif "usat" in context:
            return color_map["usat"]
        elif "watch?v=" in context:
            return color_map["youtube"]
        else:
            return color_map["other"]

    # Apply the function to get colors
    pca_df["color"] = pca_df.apply(get_color, axis=1)

    # Create the plot
    fig = go.Figure()

    # Add traces for colors (categories)
    for category, color in color_map.items():
        df_category = pca_df[pca_df["color"] == color]
        fig.add_trace(
            go.Scatter(
                x=df_category["PC2"],
                y=df_category["PC3"],
                mode="markers",
                marker=dict(color=color, size=8),
                name=category.capitalize(),
                text=[
                    f"Token: {t}<br>Context: {c}"
                    for t, c in zip(df_category["tokens"], df_category["context"])
                ],
                hoverinfo="text",
            )
        )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Context Categories)",
        xaxis_title="PC2",
        yaxis_title="PC3",
        legend_title_text="Context Category",
    )

    fig.update_traces(
        marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
        selector=dict(mode="markers"),
    )

    if save_figs:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"pca_plot_context_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"pca_plot_context_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [44]:
plot_pca_domain(pca_df, pca_path, fs_splitting_cluster, save_figs=True)

In [45]:
results.all_feature_acts

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')

In [46]:
pca_df.columns

Index(['PC1', 'PC2', 'PC3', 'tokens', 'context', 'point_id', 'char_count',
       'shape', 'color'],
      dtype='object')

In [47]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_vs_char_count(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=False,
):
    def process_context(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            before_part = parts[0]
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                return len(before_part) - (watch_index + 7)
            else:
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    return len(match.group(1))
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Limit the number of examples if there are too many
    n_examples = min(feature_activations.shape[0], max_examples)
    feature_activations = feature_activations[:n_examples]

    # Calculate char_count for each example
    char_counts = pca_df["context"].iloc[:n_examples].apply(process_context)

    # Remove examples with None char_count
    valid_indices = char_counts.notna()
    feature_activations = feature_activations[valid_indices]
    char_counts = char_counts[valid_indices]

    # Sort examples by char_count
    sorted_indices = char_counts.argsort()
    feature_activations = feature_activations[sorted_indices]
    char_counts = char_counts.iloc[sorted_indices]

    # Prepare hover text
    hover_text = []
    for i, char_count in enumerate(char_counts):
        hover_info = f"Char Count: {char_count}"
        hover_text.append(hover_info)

    # Create the heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=feature_activations.T,
            x=char_counts,
            y=[f"Feature {node}" for node in fs_splitting_nodes],
            colorscale="Viridis",
            colorbar=dict(title="Avg Activation"),
            hoverinfo="text",
            text=hover_text,
        )
    )

    # Update layout
    fig.update_layout(
        title=f"Feature Activation vs Character Count - Cluster {fs_splitting_cluster}",
        xaxis_title="Character Count",
        yaxis_title="Features",
        width=800,
        height=600,
        yaxis=dict(autorange="reversed"),
    )

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"feature_activation_heatmap_nchar_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=4.0)

        svg_path = os.path.join(
            pca_path, f"feature_activation_heatmap_nchar_{fs_splitting_cluster}.svg"
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"feature_activation_heatmap_nchar_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
        fig.show()
    else:
        fig.show()

In [48]:
plot_feature_activation_vs_char_count(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=True,
)

In [49]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_stacked_bar(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=False,
):
    def process_context(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            before_part = parts[0]
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                return len(before_part) - (watch_index + 7)
            else:
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    return len(match.group(1))
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Limit the number of examples if there are too many
    n_examples = min(feature_activations.shape[0], max_examples)
    feature_activations = feature_activations[:n_examples]

    # Calculate char_count for each example
    char_counts = pca_df["context"].iloc[:n_examples].apply(process_context)

    # Remove examples with None char_count
    valid_indices = char_counts.notna()
    feature_activations = feature_activations[valid_indices]
    char_counts = char_counts[valid_indices]

    # Create a DataFrame with char_counts and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df["char_count"] = char_counts.values

    # Group by char_count and calculate mean activations
    grouped = df.groupby("char_count").mean().reset_index()
    grouped = grouped.sort_values("char_count")

    # Prepare data for stacked bar chart
    x = grouped["char_count"]
    _ = [grouped[feature] for feature in fs_splitting_nodes]

    # Create stacked bar chart
    fig = go.Figure()

    for i, feature in enumerate(fs_splitting_nodes):
        fig.add_trace(
            go.Bar(
                x=x,
                y=grouped[feature],
                name=f"Feature {feature}",
                hoverinfo="text",
                text=[
                    f"Feature: {feature}<br>Char Count: {count}<br>Activation: {act:.4f}"
                    for count, act in zip(x, grouped[feature])
                ],
            )
        )

    # Update layout
    fig.update_layout(
        title=f"Feature Activation by Character Count - Cluster {fs_splitting_cluster}",
        xaxis_title="Character Count",
        yaxis_title="Feature Activation",
        barmode="stack",
        width=1200,
        height=800,
        legend_title="Features",
        hovermode="closest",
    )

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"feature_activation_stacked_bar_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=4.0)

        svg_path = os.path.join(
            pca_path, f"feature_activation_stacked_bar_{fs_splitting_cluster}.svg"
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"feature_activation_stacked_bar_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [50]:
plot_feature_activation_stacked_bar(
    results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, save=True
)

In [51]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_area_chart(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=False,
):
    def process_context(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            before_part = parts[0]
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                return len(before_part) - (watch_index + 7)
            else:
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    return len(match.group(1))
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Limit the number of examples if there are too many
    n_examples = min(feature_activations.shape[0], max_examples)
    feature_activations = feature_activations[:n_examples]

    # Calculate char_count for each example
    char_counts = pca_df["context"].iloc[:n_examples].apply(process_context)

    # Remove examples with None char_count
    valid_indices = char_counts.notna()
    feature_activations = feature_activations[valid_indices]
    char_counts = char_counts[valid_indices]

    # Create a DataFrame with char_counts and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df["char_count"] = char_counts.values

    # Group by char_count and calculate mean activations
    grouped = df.groupby("char_count").mean().reset_index()
    grouped = grouped.sort_values("char_count")

    # Create area chart
    fig = go.Figure()

    for feature in fs_splitting_nodes:
        fig.add_trace(
            go.Scatter(
                x=grouped["char_count"],
                y=grouped[feature],
                mode="lines",
                line=dict(width=0.5),
                stackgroup="one",
                name=f"Feature {feature}",
                hoverinfo="text",
                text=[
                    f"Feature: {feature}<br>Char Count: {count}<br>Activation: {act:.4f}"
                    for count, act in zip(grouped["char_count"], grouped[feature])
                ],
            )
        )

    # Update layout
    fig.update_layout(
        title=f"Feature Activation by Character Count - Cluster {fs_splitting_cluster}",
        xaxis_title="Character Count",
        yaxis_title="Cumulative Feature Activation",
        width=1200,
        height=800,
        legend_title="Features",
        hovermode="closest",
        showlegend=True,
    )

    # Adjust y-axis to start at 0
    fig.update_yaxes(rangemode="tozero")

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"feature_activation_area_chart_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=4.0)

        svg_path = os.path.join(
            pca_path, f"feature_activation_area_chart_{fs_splitting_cluster}.svg"
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"feature_activation_area_chart_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

In [63]:
plot_feature_activation_area_chart(
    results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, save=True
)

In [53]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_normalized_area_chart(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=False,
):
    def process_context(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            before_part = parts[0]
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                return len(before_part) - (watch_index + 7)
            else:
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    return len(match.group(1))
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Limit the number of examples if there are too many
    n_examples = min(feature_activations.shape[0], max_examples)
    feature_activations = feature_activations[:n_examples]

    # Calculate char_count for each example
    char_counts = pca_df["context"].iloc[:n_examples].apply(process_context)

    # Remove examples with None char_count
    valid_indices = char_counts.notna()
    feature_activations = feature_activations[valid_indices]
    char_counts = char_counts[valid_indices]

    # Create a DataFrame with char_counts and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df["char_count"] = char_counts.values

    # Group by char_count and calculate mean activations
    grouped = df.groupby("char_count").mean().reset_index()
    grouped = grouped.sort_values("char_count")

    # Normalize activations to sum to 1 for each char_count
    activation_columns = grouped.columns.drop("char_count")
    grouped[activation_columns] = grouped[activation_columns].div(
        grouped[activation_columns].sum(axis=1),  # type: ignore
        axis=0,
    )

    # Create area chart
    fig = go.Figure()

    for feature in fs_splitting_nodes:
        fig.add_trace(
            go.Scatter(
                x=grouped["char_count"],
                y=grouped[feature],
                mode="lines",
                line=dict(width=0.5),
                stackgroup="one",
                groupnorm="fraction",
                name=f"Feature {feature}",
                hoverinfo="text",
                text=[
                    f"Feature: {feature}<br>Char Count: {count}<br>Normalized Activation: {act:.4f}"
                    for count, act in zip(grouped["char_count"], grouped[feature])
                ],
            )
        )

    # Update layout
    fig.update_layout(
        title=f"Normalized Feature Activation by Character Count - Cluster {fs_splitting_cluster}",
        xaxis_title="Character Count",
        yaxis_title="Proportion of Feature Activation",
        width=1200,
        height=800,
        legend_title="Features",
        hovermode="closest",
        showlegend=True,
        yaxis=dict(tickformat=".0%"),  # Format y-axis as percentages
    )

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path,
            f"feature_activation_normalized_area_chart_{fs_splitting_cluster}.png",
        )
        fig.write_image(png_path, scale=4.0)

        svg_path = os.path.join(
            pca_path,
            f"feature_activation_normalized_area_chart_{fs_splitting_cluster}.svg",
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path,
            f"feature_activation_normalized_area_chart_{fs_splitting_cluster}.html",
        )
        fig.write_html(html_path)
    else:
        fig.show()

In [54]:
plot_feature_activation_normalized_area_chart(
    results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, save=True
)

In [44]:
# why has one feature gone missing?

In [45]:
fs_splitting_nodes

[19054, 5748, 1179, 125, 734]

In [56]:
# import pickle

# test_website_path = os.path.join(pca_path, "streamlit_test")
# if not os.path.exists(test_website_path):
#     os.makedirs(test_website_path)

# # Define the path to save the results and pca_df
# save_path = os.path.join(pca_path, test_website_path, "results_and_pca_df.pkl")

# # Save the results and pca_df as a pickle file
# with open(save_path, "wb") as file:
#     pickle.dump({"results": results, "pca_df": pca_df}, file)

# print(f"Results and pca_df saved to {save_path}")


In [57]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_normalized_line_chart(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=False,
):
    def process_context(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            before_part = parts[0]
            watch_index = before_part.rfind("/watch?")
            if watch_index != -1:
                return len(before_part) - (watch_index + 7)
            else:
                match = re.search(r"/([^/\s]+)$", before_part)
                if match:
                    return len(match.group(1))
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Limit the number of examples if there are too many
    n_examples = min(feature_activations.shape[0], max_examples)
    feature_activations = feature_activations[:n_examples]

    # Calculate char_count for each example
    char_counts = pca_df["context"].iloc[:n_examples].apply(process_context)

    # Remove examples with None char_count
    valid_indices = char_counts.notna()
    feature_activations = feature_activations[valid_indices]
    char_counts = char_counts[valid_indices]

    # Create a DataFrame with char_counts and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df["char_count"] = char_counts.values

    # Group by char_count and calculate mean and std activations
    grouped_mean = df.groupby("char_count").mean().reset_index()
    grouped_std = df.groupby("char_count").std().reset_index()
    grouped = grouped_mean.merge(
        grouped_std, on="char_count", suffixes=("_mean", "_std")
    )
    grouped = grouped.sort_values("char_count")

    # Normalize activations to be between 0 and 1 for each feature
    activation_columns = [col for col in grouped.columns if col.endswith("_mean")]
    for column in activation_columns:
        max_value = grouped[column].max()
        if max_value != 0:
            grouped[column] = grouped[column] / max_value
            grouped[column.replace("_mean", "_std")] = (
                grouped[column.replace("_mean", "_std")] / max_value
            )

    # Create line chart with error bands
    fig = go.Figure()

    for feature in fs_splitting_nodes:
        mean_col = f"{feature}_mean"
        std_col = f"{feature}_std"

        # Add the main line
        fig.add_trace(
            go.Scatter(
                x=grouped["char_count"],
                y=grouped[mean_col],
                mode="lines",
                name=f"Feature {feature}",
                line=dict(
                    color=f"rgba({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)},1)"
                ),
                hoverinfo="text",
                text=[
                    f"Feature: {feature}<br>Char Count: {count}<br>Normalized Activation: {act:.4f}"
                    for count, act in zip(grouped["char_count"], grouped[mean_col])
                ],
            )
        )

        # Add the error band
        fig.add_trace(
            go.Scatter(
                x=grouped["char_count"].tolist() + grouped["char_count"].tolist()[::-1],
                y=(grouped[mean_col] + grouped[std_col]).tolist()
                + (grouped[mean_col] - grouped[std_col]).tolist()[::-1],
                fill="toself",
                fillcolor=f"rgba({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)},0.2)",
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                name=f"Feature {feature} Error Band",
            )
        )

    # Update layout
    fig.update_layout(
        title=f"Normalized Feature Activation by Character Count - Cluster {fs_splitting_cluster}",
        xaxis_title="Character Count (n_char)",
        yaxis_title="Normalized Feature Activation",
        width=1200,
        height=800,
        legend_title="Features",
        hovermode="closest",
        showlegend=True,
        yaxis=dict(range=[0, 1]),  # Set y-axis range from 0 to 1
    )

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path,
            f"feature_activation_normalized_line_chart_{fs_splitting_cluster}.png",
        )
        fig.write_image(png_path, scale=4.0)

        # Save as SVG
        svg_path = os.path.join(
            pca_path,
            f"feature_activation_normalized_line_chart_{fs_splitting_cluster}.svg",
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path,
            f"feature_activation_normalized_line_chart_{fs_splitting_cluster}.html",
        )
        fig.write_html(html_path)
    else:
        fig.show()
    return fig

In [58]:
plot_feature_activation_normalized_line_chart(
    results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, save=True
)

In [59]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_boxplot(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    save=False,
):
    def process_token(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            return parts[1]
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Process tokens
    tokens = pca_df["context"].apply(process_token)

    # Remove examples with None tokens
    valid_indices = tokens.notna()
    feature_activations = feature_activations[valid_indices]
    tokens = tokens[valid_indices]

    # Create a DataFrame with tokens and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df["token"] = tokens.values
    df["case"] = df["token"].apply(
        lambda x: "uppercase" if x.isupper() else "lowercase"
    )

    # Prepare data for box plot
    data = []
    for feature in fs_splitting_nodes:
        for case in ["lowercase", "uppercase"]:
            y = df[df["case"] == case][feature]
            data.append(
                go.Box(
                    y=y,
                    name=f"{feature} ({case})",
                    boxpoints="outliers",
                    # jitter=0.0,
                    # pointpos=-1.8,
                    hoverinfo="y",
                    marker_color="blue" if case == "lowercase" else "red",
                )
            )

    # Create box plot
    fig = go.Figure(data=data)

    # Update layout
    fig.update_layout(
        title=f"Feature Activation by Case - Cluster {fs_splitting_cluster}",
        xaxis_title="Features",
        yaxis_title="Feature Activation",
        width=1200,
        height=800,
        boxmode="group",
        showlegend=False,
        hovermode="closest",
    )

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"feature_activation_boxplot_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=4.0)

        # Save as SVG
        svg_path = os.path.join(
            pca_path, f"feature_activation_boxplot_{fs_splitting_cluster}.svg"
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"feature_activation_boxplot_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [60]:
plot_feature_activation_boxplot(
    results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, save=True
)

In [61]:
# import plotly.graph_objs as go
# import numpy as np
# import os
# import re
# import pandas as pd


def plot_feature_activation_barplot(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    save=False,
):
    def process_token(context):
        parts = context.split("|")
        if len(parts) == 3 and len(parts[1]) == 1:
            return parts[1]
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Process tokens
    tokens = pca_df["context"].apply(process_token)

    # Remove examples with None tokens
    valid_indices = tokens.notna()
    feature_activations = feature_activations[valid_indices]
    tokens = tokens[valid_indices]

    # Create a DataFrame with tokens and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df["token"] = tokens.values
    df["case"] = df["token"].apply(
        lambda x: "uppercase" if x.isupper() else "lowercase"
    )

    # Calculate the number of times each feature is above 1.5 activity
    above_threshold_counts = {"lowercase": [], "uppercase": []}
    for feature in fs_splitting_nodes:
        above_threshold_counts["lowercase"].append(
            df[(df["case"] == "lowercase") & (df[feature] > 1.5)][feature].count()
        )
        above_threshold_counts["uppercase"].append(
            df[(df["case"] == "uppercase") & (df[feature] > 1.5)][feature].count()
        )

    # Create x-axis labels (feature indices)
    x_labels = [str(i) for i in fs_splitting_nodes]

    # Prepare data for bar plot
    trace_lowercase = go.Bar(
        x=x_labels,
        y=above_threshold_counts["lowercase"],
        name="Lowercase",
        marker_color="blue",
    )

    trace_uppercase = go.Bar(
        x=x_labels,
        y=above_threshold_counts["uppercase"],
        name="Uppercase",
        marker_color="red",
    )

    # Create bar plot
    fig = go.Figure(data=[trace_lowercase, trace_uppercase])

    # Update layout
    fig.update_layout(
        title=f"Number of Times Feature Activation > 1.5 by Case - Cluster {fs_splitting_cluster}",
        xaxis_title="Feature Indices",
        yaxis_title="Count",
        width=1200,
        height=800,
        barmode="group",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        hovermode="closest",
    )

    # Update x-axis to be categorical
    fig.update_xaxes(type="category", tickmode="linear", tick0=0, dtick=1)

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(
            pca_path, f"feature_activation_barplot_{fs_splitting_cluster}.png"
        )
        fig.write_image(png_path, scale=4.0)

        # Save as SVG
        svg_path = os.path.join(
            pca_path, f"feature_activation_barplot_{fs_splitting_cluster}.svg"
        )
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(
            pca_path, f"feature_activation_barplot_{fs_splitting_cluster}.html"
        )
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [62]:
plot_feature_activation_barplot(
    results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, save=True
)