In [44]:
from pathlib import Path

import pandas as pd
import torch
from sklearn.decomposition import PCA
from tqdm.autonotebook import tqdm
from transformer_lens import HookedTransformer

from matryoshka_cooc.subgraph_analysis import (
    analyze_subgraph,
    load_matryoshka_sae,
    load_node_info,
    plot_pca_by_feature_activation,
    plot_token_pca,
    set_device,
)

In [None]:
# Project root, taken as two directory levels above the current working directory.
# NOTE(review): the result depends on where the kernel was started — presumably
# the notebook sits two folders below the repository root; confirm before reuse.
root = Path.cwd().parent.parent
root

In [None]:
# Analysis settings. "sae_path" and "cooc_dir" are joined onto `root` when they
# are handed to analyze_subgraph below; "output_dir" is forwarded unchanged.
config = dict(
    sae_path="checkpoints/gpt2-small_blocks.8.hook_resid_pre_24576_global-matryoshka-topk_32_0.0003_final.pt",
    cooc_dir="graph_cooc_n_batches_500_layer_8/matryoshka",
    n_batches=10,
    threshold=8.0,
    subgraph_id=516,
    layer=8,
    remove_special_tokens=True,
    max_examples=5000,
    output_dir="matryoshka_pca_results",
)

# Gradient tracking is switched off globally before the analysis is launched.
torch.set_grad_enabled(False)

# Keyword arguments for the analysis, assembled from the settings above.
analysis_kwargs = {
    "sae_path": root / config["sae_path"],
    "cooc_dir": root / config["cooc_dir"],
    "subgraph_id": config["subgraph_id"],
    "layer": config["layer"],
    "activation_threshold": config["threshold"],
    "n_batches": config["n_batches"],
    "remove_special_tokens": config["remove_special_tokens"],
    "max_examples": config["max_examples"],
    "output_dir": config["output_dir"],
}
analysis_results = analyze_subgraph(**analysis_kwargs)

# Unpack the pieces of the returned mapping that the rest of the notebook uses.
# (The "model" and "subgraph_nodes" entries were extracted here at some point but
# are currently not needed.)
results, pca_df, pca = (
    analysis_results[key] for key in ("results", "pca_df", "pca")
)

# The PCA object may be absent (None); in that case no variance ratios exist.
if pca is None:
    explained_variance = None
else:
    explained_variance = pca.explained_variance_ratio_

# Flat dictionary of the most frequently used outputs.
analysis_data = {
    "pca_df": pca_df,
    "tokens": results["all_fired_tokens"],
    "contexts": results["all_token_dfs"].context.values,
    "reconstructions": results["all_reconstructions"],
    "feature_acts": results["all_graph_feature_acts"],
    "max_feature_info": results["all_max_feature_info"],
    "examples_found": results["all_examples_found"],
    "pca_explained_variance": explained_variance,
}

print(f"Analysis complete! Found {analysis_data['examples_found']} examples")
if pca is not None:
    print("\nPCA Explained Variance Ratios:")
    for component_number, ratio in enumerate(
        analysis_data["pca_explained_variance"], start=1
    ):
        print(f"PC{component_number}: {ratio:.3f}")

In [None]:
# Pull the per-token table out of the analysis results and preview its first rows.
# NOTE(review): `all_token_dfs` is read as an attribute here, whereas the analysis
# cell indexes the same object with string keys (results["all_token_dfs"]).
# Unless that object supports both access styles, one of the two will fail — confirm.
all_token_df = analysis_results["results"].all_token_dfs
all_token_df.head()

In [None]:
# Activations of the subgraph's features, taken from the analysis results and displayed.
# NOTE(review): attribute access here vs. string-key access in the analysis cell
# (results["all_graph_feature_acts"]) — confirm which style the results object supports.
# presumably a 2-D torch tensor of (examples, features); later cells call .cpu().numpy()
# on it and index its second axis — TODO confirm
all_graph_feature_acts = analysis_results["results"].all_graph_feature_acts
all_graph_feature_acts

In [None]:
# Keep a handle on the PCA table returned by the analysis and display it.
# presumably one row per example with principal-component columns — TODO confirm
pca_results = analysis_results["pca_df"]
pca_results

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from textblob import TextBlob


# Calculate sentiment scores for context
def get_sentiment_scores(text):
    """Return the TextBlob sentiment polarity of ``text``.

    Args:
        text: The text to score; it is passed straight to ``TextBlob``.

    Returns:
        The ``polarity`` field of ``TextBlob(text).sentiment``.
    """
    return TextBlob(text).sentiment.polarity


# Score the sentiment of every context string (one polarity value per row).
sentiment_scores = all_token_df["context"].apply(get_sentiment_scores)

# Move the activations to host memory as a numpy array for the correlation work.
graph_features_np = all_graph_feature_acts.cpu().numpy()

# One DataFrame column per feature, named after its column position.
feature_names = [f"feature_{i}" for i in range(graph_features_np.shape[1])]
graph_features_df = pd.DataFrame(graph_features_np, columns=feature_names)

# `.values` attaches the scores by row position, bypassing index alignment
# between the token table's index and the fresh RangeIndex created above.
graph_features_df["context_sentiment"] = sentiment_scores.values

# Pairwise correlations between all columns (features and sentiment).
correlation_df = graph_features_df.corr()

# Every column's correlation with sentiment, highest first (NaN entries sort last).
correlations = correlation_df["context_sentiment"].sort_values(ascending=False)

# Remove sentiment's correlation with itself *by label* rather than by assuming it
# is the first row: with ties at 1.0, or an all-NaN column when sentiment is
# constant, a positional skip would silently discard a real feature instead.
top_sentiment_correlations = correlations.drop("context_sentiment").head(5)

# Heatmap of the five features with the highest (signed) correlation.
plt.figure(figsize=(12, 8))
sns.heatmap(
    top_sentiment_correlations.to_frame(),
    annot=True,
    cmap="coolwarm",
    center=0,
)
plt.title("Correlation between Graph Features and Context Sentiment")
plt.tight_layout()
plt.show()

# Print the same five features that the heatmap shows.
print("\nTop 5 features most correlated with sentiment:")
print(top_sentiment_correlations)

In [None]:
# Show the first rows of the token table again (same preview as the earlier cell).
all_token_df.head()

In [None]:
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# 1. Correlate the element-wise product of every feature pair with sentiment.
n_features = graph_features_np.shape[1]
interaction_matrix = np.zeros((n_features, n_features))
pair_correlations = []

# Convert the sentiment scores once instead of on every loop iteration.
sentiment_np = np.asarray(sentiment_scores, dtype=float)

for i, j in combinations(range(n_features), 2):
    interaction_term = graph_features_np[:, i] * graph_features_np[:, j]
    # A product with zero variance (for example two features that are never
    # active together) has an undefined correlation; np.corrcoef then yields NaN.
    # The NaN is handled explicitly below, so the runtime warning is silenced.
    with np.errstate(invalid="ignore", divide="ignore"):
        correlation = np.corrcoef(interaction_term, sentiment_np)[0, 1]

    # Store symmetrically in the matrix for the heatmap
    interaction_matrix[i, j] = correlation
    interaction_matrix[j, i] = correlation

    # Store in list for ranking
    pair_correlations.append((f"feature_{i}*feature_{j}", correlation))

# Sort pairs by absolute correlation. NaN never compares as larger or smaller, so
# using abs(NaN) as a sort key gives an unreliable order; map NaN to a key below
# every real |correlation| so undefined pairs always end up last.
pair_correlations.sort(
    key=lambda pair: -1.0 if np.isnan(pair[1]) else abs(pair[1]),
    reverse=True,
)

# Only pairs with a defined correlation are eligible for the ranking.
top_pairs = [pair for pair in pair_correlations if not np.isnan(pair[1])][:15]

if top_pairs:
    # Plot top 15 feature pairs
    plt.figure(figsize=(12, 6))
    pairs, corrs = zip(*top_pairs)
    plt.bar(range(len(corrs)), [abs(c) for c in corrs])
    plt.xticks(range(len(corrs)), pairs, rotation=45, ha="right")
    plt.title("Top 15 Feature Pair Correlations with Sentiment (Absolute Value)")
    plt.ylabel("|Correlation|")
    plt.tight_layout()
    plt.show()

    # Print actual correlation values (including sign)
    print("\nTop 15 feature pairs and their correlations with sentiment:")
    for pair, corr in top_pairs:
        print(f"{pair}: {corr:.3f}")
else:
    # Fewer than two features, or every pair had an undefined correlation;
    # unpacking zip(*[]) would otherwise raise.
    print("\nNo feature pairs with a defined correlation to sentiment.")

# Plot interaction matrix heatmap (the diagonal stays at its initial 0)
plt.figure(figsize=(12, 10))
sns.heatmap(
    interaction_matrix,
    xticklabels=[f"F{i}" for i in range(n_features)],
    yticklabels=[f"F{i}" for i in range(n_features)],
    cmap="coolwarm",
    center=0,
    annot=False,
)
plt.title("Feature Pair Correlation Strength with Sentiment")
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

def _scatter_pc2_pc3(ax, coords, color_values, title, cmap):
    """Draw one PC2-vs-PC3 scatter panel with its own colorbar.

    Args:
        ax: Matplotlib axes to draw on.
        coords: 2-D array; column 1 is plotted on x and column 2 on y.
        color_values: One value per row of ``coords``, used to color the points.
        title: Panel title.
        cmap: Name of the matplotlib colormap.

    Returns:
        The collection returned by ``ax.scatter``.
    """
    points = ax.scatter(coords[:, 1], coords[:, 2], c=color_values, cmap=cmap)
    ax.set_title(title)
    ax.set_xlabel("PC2")
    ax.set_ylabel("PC3")
    # Carve a narrow strip off the right of the panel so the colorbar does not
    # shrink the neighbouring subplots.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(points, cax=colorbar_ax)
    return points


# Create a 2x2 subplot
fig, axs = plt.subplots(2, 2, figsize=(15, 15))
fig.suptitle("PCA Component 2 vs 3 with Different Colorings", fontsize=16)

# Convert PCA results to numpy if needed
# NOTE(review): columns 1 and 2 are labelled PC2 and PC3 below, which assumes the
# first three columns of pca_results are PC1..PC3 — confirm against pca_df.
if isinstance(pca_results, pd.DataFrame):
    pca_array = pca_results.values
else:
    pca_array = pca_results

# Interaction term used for the last panel and the last printed correlation.
feature_product = graph_features_np[:, 1] * graph_features_np[:, 3]

# One entry per panel: (axes, values used for coloring, title, colormap).
panel_specs = [
    (axs[0, 0], sentiment_scores, "Colored by Sentiment", "RdYlBu"),
    (axs[0, 1], graph_features_np[:, 0], "Colored by Feature 0 Activation", "viridis"),
    (axs[1, 0], graph_features_np[:, 1], "Colored by Feature 1 Activation", "viridis"),
    (axs[1, 1], feature_product, "Colored by Feature 1 × Feature 3", "viridis"),
]
for panel_ax, panel_colors, panel_title, panel_cmap in panel_specs:
    _scatter_pc2_pc3(panel_ax, pca_array, panel_colors, panel_title, panel_cmap)

plt.tight_layout()
plt.show()

# Print correlations for reference
print("Correlations with sentiment:")
correlation_report = [
    ("Feature 0", graph_features_np[:, 0]),
    ("Feature 1", graph_features_np[:, 1]),
    ("Feature 1 × Feature 3", feature_product),
]
for report_label, report_values in correlation_report:
    print(f"{report_label}: {np.corrcoef(report_values, sentiment_scores)[0, 1]:.3f}")