# Notes on use



In [37]:
%load_ext autoreload
%autoreload 2

import logging
import os
from os.path import join as pj

import pandas as pd
import numpy as np
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

from PIBBSS.graph_generation import (
    load_subgraph,
    plot_subgraph_static,
)
from PIBBSS.pca import (
    calculate_pca_decoder,
    create_pca_plots_decoder,
    perform_pca_on_results,
    plot_pca_explanation_and_save,
    plot_pca_feature_strength,
    plot_pca_with_active_features,
    plot_pca_with_top_feature,
    plot_simple_scatter,
    plot_token_pca_and_save,
    plot_doubly_clustered_activation_heatmap, 
    plot_feature_activations_combined, 
    get_point_result,
    plot_feature_activations, 
    analyze_representative_points, 
    analyze_representative_points_comp, 
    analyze_user_specified_points_comp, 
    analyze_user_specified_points_comp_subgraph, 
    analyze_specific_points,
    load_data_from_pickle, 
    save_data_to_pickle,
    generate_data
)
from PIBBSS.utils.saving_loading import load_npz_files, set_device
from PIBBSS.utils.set_paths import get_git_root

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
def setup_logging(log_path):
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

# Config -------------
torch.set_grad_enabled(False)
device = set_device()
git_root = get_git_root()

# Settings to perform PCA on a particular subgraph

In [40]:
save_figs = True

model_name = "gpt2-small"
sae_release_short = "res-jb-feature-splitting"
sae_id = "blocks.8.hook_resid_pre_24576"
n_batches_reconstruction = 100

# model_name = "gemma-2-2b"
# sae_release_short = "gemma-scope-2b-pt-res-canonical"
# sae_id = "layer_0/width_16k/canonical"
# n_batches_reconstruction = 10

activation_threshold = 1.5
subgraph_id =  2332

fs_splitting_cluster = subgraph_id
pca_prefix = "pca"


In [41]:

np.random.seed(1234)


# Load model
model = HookedTransformer.from_pretrained(model_name, device=device)

# Process the specific subgraph
sae_id_neat = sae_id.replace(".", "_").replace("/", "_")
results_dir = f"results/cooc/{model_name}/{sae_release_short}/{sae_id_neat}"
results_path = pj(git_root, results_dir)
activation_threshold_safe = str(activation_threshold).replace(".", "_")

figures_path = pj(git_root, f"figures/{model_name}/{sae_release_short}/{sae_id_neat}")
pca_dir = f"{pca_prefix}_{activation_threshold_safe}_subgraph_{subgraph_id}"
pca_path = pj(figures_path, pca_dir)
if not os.path.exists(pca_path):
    os.makedirs(pca_path)
pickle_file = pj(pca_path, f'pca_data_subgraph_{subgraph_id}.pkl')

# Set up logging
log_path = pj(pca_path, 'pca_analysis.log')
setup_logging(log_path)

# Log all settings
logging.info(f"Script started")
logging.info(f"Settings:")
logging.info(f"  save_figs: {save_figs}")
logging.info(f"  git_root: {git_root}")
logging.info(f"  sae_id: {sae_id}")
logging.info(f"  activation_threshold: {activation_threshold}")
logging.info(f"  subgraph_id: {subgraph_id}")
logging.info(f"  fs_splitting_cluster: {fs_splitting_cluster}")
logging.info(f"  pca_prefix: {pca_prefix}")
logging.info(f"  model_name: {model_name}")
logging.info(f"  sae_release_short: {sae_release_short}")
logging.info(f"  n_batches_reconstruction: {n_batches_reconstruction}")
logging.info(f"  device: {device}")
logging.info(f"  results_path: {results_path}")
logging.info(f"  pca_path: {pca_path}")



`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884



Loaded pretrained model gpt2-small into HookedTransformer


In [42]:

node_df = pd.read_csv(pj(results_path, f"dataframes/node_info_df_{activation_threshold_safe}.csv"))
logging.info(f"Loaded node_df from {pj(results_path, f'dataframes/node_info_df_{activation_threshold_safe}.csv')}")

overall_feature_activations = load_npz_files(results_path, f'feature_acts_cooc_activations').get(activation_threshold)

# with open(pj(results_path, f"subgraph_objects/activation_{activation_threshold_safe}/subgraph_{subgraph_id}.pkl"), 'rb') as f:
#     subgraph = pickle.load(f)


# Filter for the specific subgraph
fs_splitting_nodes = node_df.query('subgraph_id == @subgraph_id')['node_id'].tolist()


Loading npz files: 100%|██████████| 4/4 [00:00<00:00, 985.68it/s]


In [43]:
regen_data = True
if not regen_data:
    raise ValueError("Are you sure you don't want to use existing data?")

In [44]:
# parser = argparse.ArgumentParser(description="PCA analysis script")
# parser.add_argument('--save_pickle', action='store_true', help='Save generated data to pickle')
# parser.add_argument('--load_pickle', action='store_true', help='Load data from pickle instead of regenerating')
# args = parser.parse_args()


if not regen_data and os.path.exists(pickle_file):
    data = load_data_from_pickle(pickle_file)
    results = data['results']
    pca_df = data['pca_df']
    pca = data['pca']
    pca_decoder = data['pca_decoder']
    pca_decoder_df = data['pca_decoder_df']
else:
    if model_name == "gemma-2-2b":
        sae_release = "gemma-scope-2b-pt-res-canonical"
    else:
        sae_release = f"{model_name}-{sae_release_short}"

    # Load SAE and set up activation store
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_release,
        sae_id=sae_id,
        device=device
    )
    sae.fold_W_dec_norm()
    
    activation_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        train_batch_size_tokens=4096,
        n_batches_in_buffer=32,
        device=device,
    )

    data = generate_data(model, sae, activation_store, fs_splitting_nodes, n_batches_reconstruction, decoder=False)
    
    if regen_data:
        save_data_to_pickle(data, pickle_file)

    results = data['results']
    pca_df = data['pca_df']
    pca = data['pca']
    pca_decoder = data['pca_decoder']
    pca_decoder_df = data['pca_decoder_df']




  0%|          | 0/100 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


In [45]:

# # Save pca_df as CSV
# pca_df_filename = f"pca_df_subgraph_{subgraph_id}.csv"
# pca_df.to_csv(pj(pca_path, pca_df_filename), index=False)

plot_token_pca_and_save(pca_df, pca_path, subgraph_id, color_by='token', save=save_figs)

plot_pca_explanation_and_save(pca, pca_path, subgraph_id, save=save_figs)

plot_simple_scatter(results, pca_path, subgraph_id, fs_splitting_nodes, save=save_figs)

if pca_decoder is not None:
    pca_decoder, pca_decoder_df = calculate_pca_decoder(sae, fs_splitting_nodes)

# Save pca_decoder_df as CSV
pca_decoder_df_filename = f"pca_decoder_df_subgraph_{subgraph_id}.csv"
pca_decoder_df.to_csv(pj(pca_path, pca_decoder_df_filename), index=False)

create_pca_plots_decoder(pca_decoder_df, subgraph_id, pca_path, save=save_figs)

print(f"Processing completed for subgraph ID {subgraph_id}")

AttributeError: 'NoneType' object has no attribute 'to_csv'

In [None]:


plot_pca_with_top_feature(pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, save=save_figs)


In [None]:
plot_pca_feature_strength(pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, pc_x='PC1', pc_y='PC2', save=save_figs)
plot_pca_feature_strength(pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, pc_x='PC1', pc_y='PC3', save=save_figs)
plot_pca_feature_strength(pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, pc_x='PC2', pc_y='PC3', save=save_figs)


In [None]:
plot_pca_with_active_features(pca_df, results, fs_splitting_nodes, fs_splitting_cluster, pca_path, activation_threshold=activation_threshold, save=save_figs)


In [None]:
plot_doubly_clustered_activation_heatmap(results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster, max_examples=1000, save=save_figs)

In [None]:
plot_feature_activations_combined(get_point_result(results, 2), fs_splitting_nodes, fs_splitting_cluster, activation_threshold, node_df, results_path, pca_path, save_figs=True)

In [None]:


plot_feature_activations(get_point_result(results, 2), fs_splitting_nodes, fs_splitting_cluster, activation_threshold, node_df, results_path, save_figs=False, pca_path=pca_path)

In [None]:
# Usage example:
pca_df, _ = perform_pca_on_results(results)
analyze_representative_points(results=results, 
                              fs_splitting_nodes=fs_splitting_nodes, 
                              fs_splitting_cluster=fs_splitting_cluster, 
                              activation_threshold=activation_threshold, 
                              node_df=node_df, 
                              results_path=results_path, 
                              pca_df=pca_df, save_figs=True, pca_path=pca_path)

In [None]:
analyze_representative_points_comp(results, fs_splitting_nodes, activation_threshold, node_df, pca_df, save_figs=True, pca_path=pca_path)

In [None]:
# After creating the PCA plot and identifying interesting points
interesting_point_ids = [0] # Replace with actual IDs of interest
analyze_specific_points(results, fs_splitting_nodes, fs_splitting_cluster, activation_threshold, node_df, results_path, pca_df, interesting_point_ids, save_figs=True, pca_path=pca_path)

In [None]:
analyze_user_specified_points_comp(results, fs_splitting_nodes, activation_threshold, node_df, pca_df, interesting_point_ids, save_figs=True, pca_path=pca_path)

In [None]:
analyze_user_specified_points_comp_subgraph(results, fs_splitting_nodes, fs_splitting_cluster, activation_threshold, node_df, pca_df, interesting_point_ids, results_path, save_figs=True, pca_path=pca_path)

In [None]:
plot_subgraph_static(load_subgraph(results_path, activation_threshold, subgraph_id), node_df, 0.0, os.path.join(pca_path, 'overall_subgraph'), overall_feature_activations, normalize_globally=False, save_figs=True)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os

def plot_pca_weekdays(pca_df, pca_path, fs_splitting_cluster, plot_inner=False, save_figs=False):
    # Define colors for each day and gray for others
    if not plot_inner: 
        color_map = {
            'Monday': '#FF9999',
            'Tuesday': '#66B2FF',
            'Wednesday': '#99FF99',
            'Thursday': '#FFCC99',
            'Friday': '#FF99FF',
            'Saturday': '#99FFFF',
            'Sunday': '#FFFF99',
            'Other': '#CCCCCC'
        }
    else:
        color_map = {
            'Mon': '#FF9999',
            'Tues': '#66B2FF',
            'Wed': '#99FF99',
            'Thurs': '#FFCC99',
            'Fri': '#FF99FF',
            'Sat': '#99FFFF',
            'Sun': '#FFFF99',
            'Other': '#CCCCCC'
        }

    # Function to determine color
    def get_color(token):
        token_lower = token.lower()
        for day in color_map.keys():
            if day.lower() in token_lower:
                return color_map[day]
        return color_map['Other']

    # Apply the function to get colors
    pca_df['color'] = pca_df['tokens'].apply(get_color)

    # Create three figures for different PC combinations
    figs = []
    pc_combinations = [('PC1', 'PC2'), ('PC1', 'PC3'), ('PC2', 'PC3')]

    for pc_x, pc_y in pc_combinations:
        fig = go.Figure()

        # Add traces for colors (days)
        for day in list(color_map.keys()):
            df_day = pca_df[pca_df['color'] == color_map[day]]
            fig.add_trace(
                go.Scatter(
                    x=df_day[pc_x],
                    y=df_day[pc_y],
                    mode='markers',
                    marker=dict(color=color_map[day], size=12, line=dict(width=0)),
                    name=day,
                    text=[f"Token: {t}<br>Context: {c}" for t, c in zip(df_day['tokens'], df_day['context'])],
                    hoverinfo='text'
                )
            )

        # Update layout
        fig.update_layout(
            height=800,
            width=800,
            title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} ({pc_x} vs {pc_y})",
            xaxis_title=pc_x,
            yaxis_title=pc_y,
            legend=dict(
                groupclick="toggleitem",
                tracegroupgap=20
            )
        )

        figs.append(fig)

    outer_suffix = "" if not plot_inner else "_inner"

    if save_figs:
        for i, (pc_x, pc_y) in enumerate(pc_combinations):
            # Save as PNG
            png_path = os.path.join(pca_path, f"pca_plot_weekdays_{fs_splitting_cluster}_{pc_x}_{pc_y}{outer_suffix}.png")
            figs[i].write_image(png_path, scale=3.0)

            # Save as HTML
            html_path = os.path.join(pca_path, f"pca_plot_weekdays_{fs_splitting_cluster}_{pc_x}_{pc_y}{outer_suffix}.html")
            figs[i].write_html(html_path)
    else:
        for fig in figs:
            fig.show()

   

In [None]:
plot_pca_weekdays(pca_df, pca_path, fs_splitting_cluster, save_figs=False)

In [None]:
import plotly.graph_objects as go
import os

def plot_pca_weekdays_3d(pca_df, pca_path, fs_splitting_cluster, plot_inner=False, save_figs=False):
    # Define colors for each day and gray for others
    if not plot_inner: 
        color_map = {
            'Monday': '#FF9999',
            'Tuesday': '#66B2FF',
            'Wednesday': '#99FF99',
            'Thursday': '#FFCC99',
            'Friday': '#FF99FF',
            'Saturday': '#99FFFF',
            'Sunday': '#FFFF99',
            'Other': '#CCCCCC'
        }
    else:
        color_map = {
            'Mon': '#FF9999',
            'Tues': '#66B2FF',
            'Wed': '#99FF99',
            'Thurs': '#FFCC99',
            'Fri': '#FF99FF',
            'Sat': '#99FFFF',
            'Sun': '#FFFF99',
            'Other': '#CCCCCC'
        }

    # Function to determine color
    def get_color(token):
        token_lower = token.lower()
        for day in color_map.keys():
            if day.lower() in token_lower:
                return color_map[day]
        return color_map['Other']

    # Apply the function to get colors
    pca_df['color'] = pca_df['tokens'].apply(get_color)

    # Create a 3D figure
    fig = go.Figure()

    # Add traces for colors (days)
    for day in list(color_map.keys()):
        df_day = pca_df[pca_df['color'] == color_map[day]]
        fig.add_trace(
            go.Scatter3d(
                x=df_day['PC1'],
                y=df_day['PC2'],
                z=df_day['PC3'],
                mode='markers',
                marker=dict(color=color_map[day], size=3, line=dict(width=0)),
                name=day,
                text=[f"Token: {t}<br>Context: {c}" for t, c in zip(df_day['tokens'], df_day['context'])],
                hoverinfo='text'
            )
        )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"3D PCA Analysis - Cluster {fs_splitting_cluster}",
        scene=dict(
            xaxis_title='PC1',
            yaxis_title='PC2',
            zaxis_title='PC3'
        ),
        legend=dict(
            groupclick="toggleitem",
            tracegroupgap=20
        )
    )

    outer_suffix = "" if not plot_inner else "_inner"

    if save_figs:
        # Save as PNG
        png_path = os.path.join(pca_path, f"pca_plot_weekdays_3d_{fs_splitting_cluster}{outer_suffix}.png")
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(pca_path, f"pca_plot_weekdays_3d_{fs_splitting_cluster}{outer_suffix}.html")
        fig.write_html(html_path)
    else:
        fig.show()

In [None]:
plot_pca_weekdays_3d(pca_df, pca_path, fs_splitting_cluster, save_figs=False)

In [None]:
import plotly.graph_objs as go
import os
import plotly.express as px
import re

def plot_pca_filtered_context(pca_df, pca_path, fs_splitting_cluster, save_figs=False):
    def process_and_count_chars(context):
        # Remove '<|endoftext|>' from the context
        cleaned_context = context.replace('<|endoftext|>', '')
        
        # Split the cleaned context by '|'
        parts = cleaned_context.split('|')
        
        # Check if there's exactly one character between '|' symbols
        if len(parts) == 3 and len(parts[1]) == 1:
            single_char = parts[1]
            before_part = parts[0]
            
            # Check for '/watch?' string
            watch_index = before_part.rfind('/watch?')
            if watch_index != -1:
                # Count characters from end of '/watch?' to the single character
                return len(before_part) - (watch_index + 7)  # 7 is the length of '/watch?'
            else:
                # Check if there's a '/' before the single character without spaces
                match = re.search(r'/([^/\s]+)$', before_part)
                if match:
                    # Count characters between the last '/' and the single character
                    return len(match.group(1))
        
        # Return None for cases that don't meet the criteria
        return None

    # Apply the processing and counting function
    pca_df['char_count'] = pca_df['context'].apply(process_and_count_chars)

    # Filter out None values
    pca_df_filtered = pca_df.dropna(subset=['char_count'])

    # Create the plot
    fig = go.Figure()

    # Add trace for all points
    fig.add_trace(
        go.Scatter(
            x=pca_df_filtered['PC2'],
            y=pca_df_filtered['PC3'],
            mode='markers',
            marker=dict(
                color=pca_df_filtered['char_count'],
                colorscale='turbo',
                size=12,
                colorbar=dict(title="Character Count"),
                line=dict(width=1, color='DarkSlateGrey')
            ),
            text=[f"Token: {t}<br>Context: {c}<br>Char Count: {count}" 
                  for t, c, count in zip(pca_df_filtered['tokens'], pca_df_filtered['context'], pca_df_filtered['char_count'])],
            hoverinfo='text'
        )
    )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Filtered Context Character Count)",
        xaxis_title="PC2",
        yaxis_title="PC3",
    )

    if save_figs:
        # Save as PNG
        png_path = os.path.join(pca_path, f"pca_plot_filtered_context_char_count_{fs_splitting_cluster}.png")
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(pca_path, f"pca_plot_filtered_context_char_count_{fs_splitting_cluster}.html")
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [None]:
plot_pca_filtered_context(pca_df, pca_path, fs_splitting_cluster, save_figs=True)

In [None]:
import plotly.graph_objs as go
import numpy as np
import os
import re
import pandas as pd

def plot_feature_activation_normalized_area_chart(
    results,
    fs_splitting_nodes,
    pca_df,
    pca_path,
    fs_splitting_cluster,
    max_examples=1000,
    save=False,
):
    def process_context(context):
        parts = context.split('|')
        if len(parts) == 3 and len(parts[1]) == 1:
            before_part = parts[0]
            watch_index = before_part.rfind('/watch?')
            if watch_index != -1:
                return len(before_part) - (watch_index + 7)
            else:
                match = re.search(r'/([^/\s]+)$', before_part)
                if match:
                    return len(match.group(1))
        return None

    # Extract feature activations
    feature_activations = results.all_graph_feature_acts.cpu().numpy()

    # Limit the number of examples if there are too many
    n_examples = min(feature_activations.shape[0], max_examples)
    feature_activations = feature_activations[:n_examples]

    # Calculate char_count for each example
    char_counts = pca_df['context'].iloc[:n_examples].apply(process_context)

    # Remove examples with None char_count
    valid_indices = char_counts.notna()
    feature_activations = feature_activations[valid_indices]
    char_counts = char_counts[valid_indices]

    # Create a DataFrame with char_counts and feature activations
    df = pd.DataFrame(feature_activations, columns=fs_splitting_nodes)
    df['char_count'] = char_counts.values

    # Group by char_count and calculate mean activations
    grouped = df.groupby('char_count').mean().reset_index()
    grouped = grouped.sort_values('char_count')

    # Normalize activations to sum to 1 for each char_count
    activation_columns = grouped.columns.drop('char_count')
    grouped[activation_columns] = grouped[activation_columns].div(grouped[activation_columns].sum(axis=1), axis=0)

    # Create area chart
    fig = go.Figure()

    for feature in fs_splitting_nodes:
        fig.add_trace(go.Scatter(
            x=grouped['char_count'],
            y=grouped[feature],
            mode='lines',
            line=dict(width=0.5),
            stackgroup='one',
            groupnorm='fraction',
            name=f'Feature {feature}',
            hoverinfo='text',
            text=[f"Feature: {feature}<br>Char Count: {count}<br>Normalized Activation: {act:.4f}" 
                  for count, act in zip(grouped['char_count'], grouped[feature])],
        ))

    # Update layout
    fig.update_layout(
        title=f"Normalized Feature Activation by Character Count - Cluster {fs_splitting_cluster}",
        xaxis_title="Character Count",
        yaxis_title="Proportion of Feature Activation",
        width=1200,
        height=800,
        legend_title="Features",
        hovermode='closest',
        showlegend=True,
        yaxis=dict(tickformat='.0%')  # Format y-axis as percentages
    )

    # Show the plot
    if save:
        # Save as PNG
        png_path = os.path.join(pca_path, f"feature_activation_normalized_area_chart_{fs_splitting_cluster}.png")
        fig.write_image(png_path, scale=4.0)

        svg_path = os.path.join(pca_path, f"feature_activation_normalized_area_chart_{fs_splitting_cluster}.svg")
        fig.write_image(svg_path)

        # Save as HTML
        html_path = os.path.join(pca_path, f"feature_activation_normalized_area_chart_{fs_splitting_cluster}.html")
        fig.write_html(html_path)
    else:
        fig.show()
    return fig


In [None]:
plot_feature_activation_normalized_area_chart(results, fs_splitting_nodes, pca_df, pca_path, fs_splitting_cluster,  save=True)

In [None]:
import plotly.graph_objects as go

def plot_pca_domain(pca_df, pca_path, fs_splitting_cluster, save_figs=False):
    # Define colors for each category
    color_map = {
        'twitter': '#1DA1F2',  # Twitter blue
        'usat': '#FF0000',     # Red for USA Today
        'youtube': '#00FF00',  # YouTube red
        'other': '#CCCCCC'     # Gray for others
    }

    # Function to determine color
    def get_color(row):
        context = row['context'].lower()
        if 'twitter' in context or 't.co' in context:
            return color_map['twitter']
        elif 'usat' in context:
            return color_map['usat']
        elif 'watch?v=' in context:
            return color_map['youtube']
        else:
            return color_map['other']

    # Apply the function to get colors
    pca_df['color'] = pca_df.apply(get_color, axis=1)

    # Create the plot
    fig = go.Figure()

    # Add traces for colors (categories)
    for category, color in color_map.items():
        df_category = pca_df[pca_df['color'] == color]
        fig.add_trace(
            go.Scatter(
                x=df_category['PC2'],
                y=df_category['PC3'],
                mode='markers',
                marker=dict(color=color, size=8),
                name=category.capitalize(),
                text=[f"Token: {t}<br>Context: {c}" for t, c in zip(df_category['tokens'], df_category['context'])],
                hoverinfo='text'
            )
        )

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text=f"PCA Analysis - Cluster {fs_splitting_cluster} (Context Categories)",
        xaxis_title="PC2",
        yaxis_title="PC3",
        legend_title_text="Context Category"
    )

    fig.update_traces(marker=dict(size=12,
                              line=dict(width=2,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

    if save_figs:
        # Save as PNG
        png_path = os.path.join(pca_path, f"pca_plot_context_{fs_splitting_cluster}.png")
        fig.write_image(png_path, scale=3.0)

        # Save as HTML
        html_path = os.path.join(pca_path, f"pca_plot_context_{fs_splitting_cluster}.html")
        fig.write_html(html_path)
    else:
        fig.show()

    return fig

In [None]:
plot_pca_domain(pca_df, pca_path, fs_splitting_cluster, save_figs=True)