In [28]:
import os
from os.path import join as pj

import pandas as pd
import plotly.graph_objects as go

from sae_cooccurrence.utils.set_paths import get_git_root

In [5]:
# Load the CSV files for all folders that start with layer_12_width_16k_average
base_path = pj(get_git_root(), "results", "gemma-2-2b", "gemma-scope-2b-pt-res")
folders = [
    f for f in os.listdir(base_path) if f.startswith("layer_12_width_16k_average")
]

node_info_dfs = []
for folder in folders:
    file_path = pj(base_path, folder, "dataframes", "node_info_df_1_5.csv")
    if os.path.exists(file_path):
        df = pd.read_csv(file_path)
        df["l0_value"] = int(
            folder.split("average_l0_")[1]
        )  # Extract number after 'average_l0_'
        node_info_dfs.append(df)

# Combine all dataframes
node_info_df = pd.concat(node_info_dfs, ignore_index=True)

In [6]:
node_info_df

Unnamed: 0,node_id,activity_threshold,subgraph_id,subgraph_size,feature_activations,top_10_tokens,neuronpedia_link,density,max_avg_degree_ratio,avg_clustering,diameter,single_node_score,hub_spoke_score,strongly_connected_score,linear_score,quicklist_link,l0_value
0,0,1.5,0,1,3860.0,"['```', '\ufeff/**', 'стви', ""'];?>"", 'Kariera...",,0.000000,0.000000,0.000000,0,1.0,0.000000,0.00000,0.000000,,445
1,1,1.5,1,1,1051.0,"['OnInit', ' وتسجيلات', ' ModelExpression', ' ...",,0.000000,0.000000,0.000000,0,1.0,0.000000,0.00000,0.000000,,445
2,8384,1.5,2,17,430.0,"['Datuak', ' hashtag', ' Савезне', '!#', ' #',...",,0.345588,2.712766,0.739101,3,0.0,0.541574,0.90638,0.552046,,445
3,2,1.5,2,17,889.0,"['PhysRev', 'anchor', ' anchor', 'jaan', 'crea...",,0.345588,2.712766,0.739101,3,0.0,0.541574,0.90638,0.552046,,445
4,3582,1.5,2,17,176.0,"[' AssemblyProduct', '曖昧さ回避', 'Personensuche',...",,0.345588,2.712766,0.739101,3,0.0,0.541574,0.90638,0.552046,,445
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
81915,16379,1.5,13787,1,1.0,"['expandindo', ' مشين', ' дописавши', ' bezeic...",,0.000000,0.000000,0.000000,0,1.0,0.000000,0.00000,0.000000,,22
81916,16380,1.5,13788,1,715.0,"[' a', ' an', ' eventual', ' resulting', ' man...",,0.000000,0.000000,0.000000,0,1.0,0.000000,0.00000,0.000000,,22
81917,16381,1.5,13789,1,566.0,"[' into', ' onto', 'Datuak', ' to', ' INTO', '...",,0.000000,0.000000,0.000000,0,1.0,0.000000,0.00000,0.000000,,22
81918,16382,1.5,13790,1,63.0,"['HtmlAttribute', ' Incidentally', '顺便', ' &__...",,0.000000,0.000000,0.000000,0,1.0,0.000000,0.00000,0.000000,,22


In [15]:
# Calculate mean subgraph size for all subgraphs
mean_subgraph_size = (
    node_info_df.groupby("l0_value")["subgraph_size"].mean().reset_index()
)

# Calculate mean subgraph size excluding subgraphs of size 1
mean_subgraph_size_excluding_ones = (
    node_info_df[node_info_df["subgraph_size"] > 1]
    .groupby("l0_value")["subgraph_size"]
    .mean()
    .reset_index()
)

In [19]:
# Create the plot using Plotly
fig = go.Figure()

# Plot mean subgraph size including all subgraphs
fig.add_trace(
    go.Scatter(
        x=mean_subgraph_size["l0_value"],
        y=mean_subgraph_size["subgraph_size"],
        mode="markers+lines",
        name="Including size 1 subgraphs",
        marker=dict(size=10, color="blue"),
        line=dict(color="blue"),
    )
)

# Plot mean subgraph size excluding subgraphs of size 1
fig.add_trace(
    go.Scatter(
        x=mean_subgraph_size_excluding_ones["l0_value"],
        y=mean_subgraph_size_excluding_ones["subgraph_size"],
        mode="markers+lines",
        name="Excluding size 1 subgraphs",
        marker=dict(size=10, color="red"),
        line=dict(color="red"),
    )
)

fig.update_layout(
    title="Mean Subgraph Size vs L0 Value",
    xaxis_title="L0 Value",
    yaxis_title="Mean Subgraph Size",
    height=800,
    width=800,
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

# Adjust x-axis to show all L0 values
fig.update_xaxes(tickmode="array", tickvals=mean_subgraph_size["l0_value"])

# Save the plot as HTML and PNG
output_dir = pj(
    get_git_root(),
    "results",
    "gemma-2-2b",
    "gemma-scope-2b-pt-res",
    "mean_subgraph_size_analysis",
)
os.makedirs(output_dir, exist_ok=True)
fig.write_html(pj(output_dir, "mean_subgraph_size_vs_l0_value.html"))
fig.write_image(pj(output_dir, "mean_subgraph_size_vs_l0_value.png"), scale=4.0)

# Display the plot
fig.show()

In [18]:
# Calculate the fraction of subgraphs that are size 1
fraction_size_1 = (
    node_info_df.groupby("l0_value")
    .apply(lambda x: (x["subgraph_size"] == 1).mean())
    .reset_index()
)
fraction_size_1.columns = ["l0_value", "fraction_size_1"]

# Create a new figure for the fraction of size 1 subgraphs
fig_fraction = go.Figure()

fig_fraction.add_trace(
    go.Scatter(
        x=fraction_size_1["l0_value"],
        y=fraction_size_1["fraction_size_1"],
        mode="markers+lines",
        name="Fraction of size 1 subgraphs",
        marker=dict(size=10, color="green"),
        line=dict(color="green"),
    )
)

fig_fraction.update_layout(
    title="Fraction of Subgraphs of Size 1 vs L0 Value",
    xaxis_title="L0 Value",
    yaxis_title="Fraction of Subgraphs of Size 1",
    height=800,
    width=800,
    yaxis=dict(range=[0, 1]),  # Set y-axis range from 0 to 1
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

# Adjust x-axis to show all L0 values
fig_fraction.update_xaxes(tickmode="array", tickvals=fraction_size_1["l0_value"])

# Save the plot as HTML and PNG
fig_fraction.write_html(pj(output_dir, "fraction_size_1_subgraphs_vs_l0_value.html"))
fig_fraction.write_image(
    pj(output_dir, "fraction_size_1_subgraphs_vs_l0_value.png"), scale=4.0
)

# Display the plot
fig_fraction.show()





In [20]:
# Calculate the total number of feature activations in graphs of size 1 for each width
total_activations_size_1 = (
    node_info_df[node_info_df["subgraph_size"] == 1]
    .groupby("l0_value")["activity_threshold"]
    .count()
    .reset_index(name="total_activations_size_1")
)

# Calculate the total number of feature activations for each width
total_activations = (
    node_info_df.groupby("l0_value")["activity_threshold"]
    .count()
    .reset_index(name="total_activations")
)

# Merge the two dataframes
activations_df = pd.merge(total_activations_size_1, total_activations, on="l0_value")

# Calculate the fraction of activations in graphs of size 1
activations_df["fraction_activations_size_1"] = (
    activations_df["total_activations_size_1"] / activations_df["total_activations"]
)

# Create a new figure for the fraction of activations in size 1 subgraphs
fig_activations = go.Figure()

fig_activations.add_trace(
    go.Scatter(
        x=activations_df["l0_value"],
        y=activations_df["fraction_activations_size_1"],
        mode="markers+lines",
        name="Fraction of activations in size 1 subgraphs",
        marker=dict(size=10, color="purple"),
        line=dict(color="purple"),
    )
)

fig_activations.update_layout(
    title="Fraction of Feature Activations in Size 1 Subgraphs vs L0 Value",
    xaxis_title="L0 Value",
    yaxis_title="Fraction of Feature Activations in Size 1 Subgraphs",
    height=800,
    width=800,
    yaxis=dict(range=[0, 1]),  # Set y-axis range from 0 to 1
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

# Adjust x-axis to show all L0 values
fig_activations.update_xaxes(tickmode="array", tickvals=activations_df["l0_value"])

# Save the plot as HTML and PNG
fig_activations.write_html(
    pj(output_dir, "fraction_activations_size_1_subgraphs_vs_l0_value.html")
)
fig_activations.write_image(
    pj(output_dir, "fraction_activations_size_1_subgraphs_vs_l0_value.png"), scale=4.0
)

# Display the plot
fig_activations.show()

In [23]:
# Load the CSV files for all folders that start with layer_12_width_ of the form layer_12_width_16k_canonical
base_path = pj(
    get_git_root(), "results", "gemma-2-2b", "gemma-scope-2b-pt-res-canonical"
)
folders = [f for f in os.listdir(base_path) if f.startswith("layer_12_width_")]

node_info_dfs = []
for folder in folders:
    file_path = pj(base_path, folder, "dataframes", "node_info_df_1_5.csv")
    if os.path.exists(file_path):
        df = pd.read_csv(file_path)
        df["width"] = int(
            folder.split("width_")[1].split("k")[0]
        )  # Extract number before 'k' after 'width_'
        node_info_dfs.append(df)

# Combine all dataframes
node_info_df_widths = pd.concat(node_info_dfs, ignore_index=True)

In [24]:
node_info_df_widths

Unnamed: 0,node_id,activity_threshold,subgraph_id,subgraph_size,feature_activations,top_10_tokens,neuronpedia_link,density,max_avg_degree_ratio,avg_clustering,diameter,single_node_score,hub_spoke_score,strongly_connected_score,linear_score,quicklist_link,width
0,0,1.5,0,1,19.0,"[' TextStyle', ' MessageBoxIcon', 'lihood', ' ...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,65
1,1,1.5,1,1,176.0,"[' ', ' Christian', ' the', ' C', 'bibitem', '...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,65
2,2,1.5,2,1,3514.0,"[' is', ' not', ' a', ' also', ' are', ' was',...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,65
3,3,1.5,3,1,139.0,"[' actual', 'actual', ' ACTUAL', ' Actual', 'A...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,65
4,62465,1.5,4,7,7.0,"['normal', ' normal', 'Normal', ' Normal', ' N...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,1.0,1.0,1.0,1,0.0,0.499701,1.0,0.500299,https://neuronpedia.org/quick-list/?name=tempo...,65
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
114683,32763,1.5,31265,1,2.0,"['Skocz', 'Hentet', ' barter', ' fossa', 'Cros...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,32
114684,32764,1.5,31266,1,158.0,"['BrowserModule', ' autorytatywna', ' BrowserM...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,32
114685,32765,1.5,31267,1,262.0,"[' Roskov', 'spesies', ' braccio', 'edicated',...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,32
114686,32766,1.5,31268,1,774.0,"['Datuak', ' Wikiseite', 'nsis', ' ADV', ' sat...",https://neuronpedia.org/gemma-2-2b/12-gemmasco...,0.0,0.0,0.0,0,1.0,0.000000,0.0,0.000000,https://neuronpedia.org/quick-list/?name=tempo...,32


In [30]:
# Calculate mean subgraph size for each width
# Calculate mean subgraph size for each width
mean_subgraph_size = (
    node_info_df_widths.groupby("width")["subgraph_size"].mean().reset_index()
)

# Calculate mean subgraph size excluding subgraphs of size 1
mean_subgraph_size_excluding_ones = (
    node_info_df_widths[node_info_df_widths["subgraph_size"] > 1]
    .groupby("width")["subgraph_size"]
    .mean()
    .reset_index()
)

# Create the plot using Plotly
fig = go.Figure()

# Plot mean subgraph size including all subgraphs
fig.add_trace(
    go.Scatter(
        x=mean_subgraph_size["width"],
        y=mean_subgraph_size["subgraph_size"],
        mode="markers+lines",
        name="Including size 1 subgraphs",
        marker=dict(size=10, color="blue"),
        line=dict(color="blue"),
    )
)

# Plot mean subgraph size excluding subgraphs of size 1
fig.add_trace(
    go.Scatter(
        x=mean_subgraph_size_excluding_ones["width"],
        y=mean_subgraph_size_excluding_ones["subgraph_size"],
        mode="markers+lines",
        name="Excluding size 1 subgraphs",
        marker=dict(size=10, color="red"),
        line=dict(color="red"),
    )
)

fig.update_layout(
    title="Mean Subgraph Size vs Width",
    xaxis_title="Width",
    yaxis_title="Mean Subgraph Size",
    height=800,
    width=800,
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

# Adjust x-axis to show all widths
fig.update_xaxes(tickmode="array", tickvals=mean_subgraph_size["width"])

# Save the plot as HTML and PNG
output_dir = pj(
    get_git_root(),
    "results",
    "gemma-2-2b",
    "gemma-scope-2b-pt-res-canonical",
    "mean_subgraph_size_analysis",
)
os.makedirs(output_dir, exist_ok=True)
fig.write_html(pj(output_dir, "mean_subgraph_size_vs_width.html"))
fig.write_image(pj(output_dir, "mean_subgraph_size_vs_width.png"), scale=4.0)

# Display the plot
fig.show()

In [31]:
# Calculate the fraction of subgraphs that are size 1
fraction_size_1 = (
    node_info_df_widths.groupby("width")
    .apply(lambda x: (x["subgraph_size"] == 1).mean())
    .reset_index()
)
fraction_size_1.columns = ["width", "fraction_size_1"]

# Create a new figure for the fraction of size 1 subgraphs
fig_fraction = go.Figure()

fig_fraction.add_trace(
    go.Scatter(
        x=fraction_size_1["width"],
        y=fraction_size_1["fraction_size_1"],
        mode="markers+lines",
        name="Fraction of size 1 subgraphs",
        marker=dict(size=10, color="green"),
        line=dict(color="green"),
    )
)

fig_fraction.update_layout(
    title="Fraction of Subgraphs of Size 1 vs Width",
    xaxis_title="Width",
    yaxis_title="Fraction of Subgraphs of Size 1",
    height=800,
    width=800,
    yaxis=dict(range=[0, 1]),  # Set y-axis range from 0 to 1
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

# Adjust x-axis to show all width values
fig_fraction.update_xaxes(tickmode="array", tickvals=fraction_size_1["width"])

# Save the plot as HTML and PNG
fig_fraction.write_html(pj(output_dir, "fraction_size_1_subgraphs_vs_width.html"))
fig_fraction.write_image(
    pj(output_dir, "fraction_size_1_subgraphs_vs_width.png"), scale=4.0
)

# Display the plot
fig_fraction.show()



