In [11]:
import os

import pandas as pd
import plotly.express as px
import wandb

# Root directory containing one sub-directory per SAE experiment; each experiment
# directory is read for its W&B run-id text files further down the notebook.
# Override via the GPT2_SAE_RESULT_DIR environment variable when not running on the
# shared cluster (the default is the original shared path).
result_dir = os.environ.get(
    "GPT2_SAE_RESULT_DIR",
    "/remote-home/share/research/mechinterp/gpt2-dictionary/results",
)

In [12]:
# Collect final summary metrics for every (layer, hook) SAE from W&B, both for the
# training run ("Before Pruning") and the pruning run ("After Pruning").
# Each experiment directory under `result_dir` must contain `train_wandb_id.txt`
# and `prune_wandb_id.txt`, each holding a single W&B run id.
WANDB_PROJECT = "fnlp-mechinterp/gpt2-sae"
EXP_SUFFIX = "-l1-0.00012-lr-0.001-32x"  # hyperparameter tag shared by all experiments
N_LAYERS = 12
# (hook point name, abbreviation used in experiment names, display label)
HOOKS = [("hook_mlp_out", "M", "MLP"), ("hook_attn_out", "A", "Attn")]

api = wandb.Api()  # one client reused for every run lookup


def read_run_id(exp: str, filename: str) -> str:
    """Return the stripped W&B run id stored in `result_dir/exp/filename`."""
    with open(os.path.join(result_dir, exp, filename), "r") as f:
        return f.read().strip()


def summary_metrics(run) -> list:
    """Return [l0, explained variance, CE-loss score] from a W&B run's summary."""
    return [
        run.summary["metrics/l0"],
        run.summary["metrics/explained_variance"],
        run.summary["metrics/ce_loss_score"],
    ]


metric_rows, metric_index = [], []
prune_rows, prune_index = [], []
for layer in range(N_LAYERS):
    for _hook_suffix, hook_abbr, hook_label in HOOKS:
        exp = f"L{layer}{hook_abbr}{EXP_SUFFIX}"

        train_run = api.run(f"{WANDB_PROJECT}/{read_run_id(exp, 'train_wandb_id.txt')}")
        metric_rows.append(summary_metrics(train_run) + [layer, hook_label, "Before Pruning"])
        metric_index.append(exp)

        prune_run = api.run(f"{WANDB_PROJECT}/{read_run_id(exp, 'prune_wandb_id.txt')}")
        metric_rows.append(summary_metrics(prune_run) + [layer, hook_label, "After Pruning"])
        metric_index.append(exp + "-prune")

        prune_rows.append([prune_run.summary["sparsity/total_pruned_features"], layer, hook_label])
        prune_index.append(exp)

# Build each frame once from the collected rows (avoids row-by-row .loc enlargement,
# which is slow and leaves every column with object dtype).
df = pd.DataFrame(
    metric_rows, index=metric_index, columns=["l0", "ev", "ce", "layer", "hook", "prune"]
)
df_prune = pd.DataFrame(
    prune_rows, index=prune_index, columns=["feature_pruned", "layer", "hook"]
)

In [13]:
# Plot SAE quality metrics and pruning counts against layer, one facet per hook type.
# Human-readable axis / legend names for the columns produced by the fetch cell.
AXIS_LABELS = {
    "layer": "Layer",
    "hook": "Hook",
    "prune": "Stage",
    "l0": "L0",
    "ev": "Explained Variance",
    "ce": "CE Loss Score",
    "feature_pruned": "Total Pruned Features",
}


def plot_vs_layer(data: pd.DataFrame, y: str, title: str, **kwargs):
    """Return a plotly line figure of column `y` against `layer`, faceted by `hook`.

    Rows are sorted by layer first because px.line connects points in row order.
    Extra keyword arguments (e.g. `color`) are forwarded to `px.line`.
    """
    return px.line(
        data.sort_values("layer", kind="stable"),
        x="layer",
        y=y,
        title=title,
        facet_col="hook",
        labels=AXIS_LABELS,
        **kwargs,
    )


fig_l0 = plot_vs_layer(df, "l0", "L0 vs Layer", color="prune")
fig_ev = plot_vs_layer(df, "ev", "Explained Variance vs Layer", color="prune")
fig_ce = plot_vs_layer(df, "ce", "CE Score vs Layer", color="prune")
fig_prune = plot_vs_layer(df_prune, "feature_pruned", "Total Pruned Features vs Layer")

for fig in (fig_l0, fig_ev, fig_ce, fig_prune):
    fig.show()