In [None]:
from tools.cc_utils import load_latent_df


# Identifier strings handed to load_latent_df.
# NOTE(review): presumably these name trained dictionary runs; confirm against
# tools.cc_utils.load_latent_df.
K100_RUN = "gemma-2-2b-L13-k100-lr1e-04-local-shuffling-CCLoss"
SAE_RUN = "gemma-2-2b-L13-mu5.2e-02-lr1e-04-2x100M-local-shuffling-SAELoss"

# Called without an argument, load_latent_df falls back to its own default.
# This frame is the one labelled "CrossCoder" in the plot further down.
df_cc = load_latent_df()
df_k = load_latent_df(K100_RUN)
df_sae = load_latent_df(SAE_RUN)

# Bare trailing expression: shows df_sae as this cell's output.
df_sae

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def count_within_threshold(df, thres):
    """Count the rows of ``df`` whose two beta ratios are both inside (-thres, thres).

    A row is counted when ``beta_ratio_error`` and ``beta_ratio_reconstruction``
    each lie strictly between ``-thres`` and ``thres``. Rows where either value
    is NaN never satisfy the comparison and are therefore not counted.

    Args:
        df: DataFrame exposing ``beta_ratio_error`` and
            ``beta_ratio_reconstruction``.
        thres: Half-width of the symmetric acceptance interval.

    Returns:
        Number of rows satisfying both conditions.
    """
    # `@thres` lets pandas read the local variable directly, so the number is
    # never spliced into the expression text.
    selected = df.query(
        "-@thres < beta_ratio_error < @thres"
        " and -@thres < beta_ratio_reconstruction < @thres"
    )
    return len(selected)


# Threshold values for the sweep: 0.1, 0.2, ..., 0.6.
# linspace fixes the number of points explicitly; arange with a float step
# derives it from a rounded division, which numpy's docs advise against.
thresholds = np.linspace(0.1, 0.6, 6)
counts_cc = []
counts_k = []
counts_sae = []

# Apply the same criterion to all three dataframes at every threshold.
for thres in thresholds:
    count_cc = count_within_threshold(df_cc, thres)
    count_k = count_within_threshold(df_k, thres)
    count_sae = count_within_threshold(df_sae, thres)

    counts_cc.append(count_cc)
    counts_k.append(count_k)
    counts_sae.append(count_sae)

    print(
        f"Threshold {thres:.1f}: {count_cc} features (CC), {count_k} features (K), {count_sae} features (SAE)"
    )

# Plot the sweep: one line per dataframe, count of selected rows vs. threshold.
fig, ax = plt.subplots(figsize=(10, 6))

# (y-values, marker, linestyle, legend label) for each curve.
series_specs = [
    (counts_cc, "o", "-", "CrossCoder"),
    (counts_k, "s", "--", "K=100"),
    (counts_sae, "^", "-.", "SAE"),
]
for series, marker, linestyle, label in series_specs:
    ax.plot(
        thresholds,
        series,
        marker=marker,
        linestyle=linestyle,
        linewidth=2,
        label=label,
    )

ax.set_xlabel("Threshold Value", fontsize=12)
ax.set_ylabel("Number of Features", fontsize=12)
ax.set_title("Number of Features vs. Threshold Value", fontsize=14)
ax.grid(True, alpha=0.3)
# Put a tick exactly on every swept threshold.
ax.set_xticks(thresholds)
ax.legend()
fig.tight_layout()
plt.show()