SAE Training and Evaluation

Train a Sparse Autoencoder (SAE) on Evo2 embeddings and evaluate how well its
features distinguish true-positive from false-positive structural variant calls.

In [None]:
import os
import time
import json
import math
import pathlib
import datetime
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score
from sklearn.model_selection import StratifiedKFold, train_test_split, cross_val_score
from sklearn.linear_model import LogisticRegression
from scipy.stats import chi2_contingency
from statsmodels.stats.multitest import multipletests
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Global settings
DRIVE_ROOT = "../data/models"  # Root directory for checkpoints and evaluation outputs
EPOCHS = 400
BATCH_SIZE = 32
LR = 3e-4
λ_L1 = 0.05  # Weight of the L1 penalty on the sparse activations
β_KL = 0.5  # Weight of the KL activation-rate penalty
ρ_TARGET = 0.02  # Desired active-feature probability
λ_ORTH = 0.01  # Weight of the decoder orthogonality penalty
INPUT_DIM = 4096  # Evo-2 layer-26 embedding dimension
FEATURE_DIM = 4096  # Selected based on interpretability
K_ACTIVE = 64  # Selected based on interpretability
device = "cuda" if torch.cuda.is_available() else "cpu"

# Class weights for TP/FP imbalance (TP: 88%, FP: 12%).
# The tensor is indexed by the integer label (FP = 0, TP = 1), so the minority
# FP class must sit at index 0 to receive the inverse-frequency weight
# (FP weight = TP_count / FP_count ≈ 7.17). The previous order [1.0, 7.17]
# up-weighted the majority TP class instead.
CLASS_WEIGHTS = torch.tensor([7.17, 1.0]).to(device)

print(f"Device: {device}")
print(f"Model will be saved to: {DRIVE_ROOT}")

In [None]:
# SAE Model Definition

class BatchTopKSAE(torch.nn.Module):
    """Sparse autoencoder that keeps the k largest encoder outputs per sample.

    The encoder is a biased linear map ``input_dim -> feature_dim``; the
    decoder is a bias-free linear map ``feature_dim -> input_dim``. Top-k
    selection is applied independently along the last dimension of the
    encoder output, so every input row keeps exactly ``min(k_active,
    feature_dim)`` entries and all others are set to zero.
    """

    def __init__(self, input_dim, feature_dim, k_active):
        """Build the encoder/decoder pair.

        Args:
            input_dim: Size of the last dimension of the inputs.
            feature_dim: Number of dictionary features (encoder outputs).
            k_active: Number of features kept per sample; capped at
                ``feature_dim`` when selecting.
        """
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.k_active = k_active

        # Encoder: input -> features
        self.encoder = torch.nn.Linear(input_dim, feature_dim, bias=True)
        # Decoder: features -> input (no bias)
        self.decoder = torch.nn.Linear(feature_dim, input_dim, bias=False)

        # Initialize weights
        torch.nn.init.kaiming_uniform_(self.encoder.weight)
        torch.nn.init.kaiming_uniform_(self.decoder.weight)

    def _topk(self, features):
        """Return ``(values, indices)`` of the k largest entries along the last dim."""
        k = min(self.k_active, self.feature_dim)
        return torch.topk(features, k, dim=-1)

    def forward(self, x):
        """Encode ``x``, keep the top-k features per row, and decode.

        Returns:
            dict with keys
            ``'reconstructed'``: decoder output for the sparse features,
            ``'sparse_features'``: encoder output with non-top-k entries zeroed
            (kept values are not rectified, so they may be negative),
            ``'dense_features'``: the raw encoder output.
        """
        features = self.encoder(x)
        topk_values, topk_indices = self._topk(features)

        # Scatter the kept values back into an all-zero tensor of the dense shape
        sparse_features = torch.zeros_like(features)
        sparse_features.scatter_(-1, topk_indices, topk_values)

        return {
            'reconstructed': self.decoder(sparse_features),
            'sparse_features': sparse_features,
            'dense_features': features
        }

    def get_feature_activations(self, x):
        """Return the top-k feature indices and dense encoder output for ``x``.

        Runs under ``torch.no_grad()``.

        Returns:
            ``(topk_indices, features)`` where ``topk_indices`` has
            ``min(k_active, feature_dim)`` entries along the last dimension
            and ``features`` is the raw encoder output.
        """
        with torch.no_grad():
            features = self.encoder(x)
            _, topk_indices = self._topk(features)
            return topk_indices, features

In [None]:
# Load the training embeddings and standardize them per dimension
print("Loading training embeddings...")
train_pkg = torch.load("../data/processed/sae_sv_embeddings.pt", map_location="cpu")
train_emb = train_pkg["embeddings"].float().to(device)  # [N_train, 4096]
train_sv_info = train_pkg["sv_info"]

# Per-dimension mean / std (epsilon keeps constant dimensions from dividing by zero);
# these statistics are reused later to standardize the evaluation embeddings.
mu = train_emb.mean(0, keepdim=True)
σ = train_emb.std(0, keepdim=True) + 1e-6
train_emb_w = (train_emb - mu) / σ

train_loader = DataLoader(
    TensorDataset(train_emb_w),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
N_train = train_emb_w.size(0)
print(f"Loaded {N_train:,} training embeddings (standardized)")

# Integer labels aligned with the embedding order: TP -> 1, everything else -> 0
tp_flags = [1 if sv['truvari_class'] == 'TP' else 0 for sv in train_sv_info]
train_labels = torch.tensor(tp_flags, dtype=torch.long).to(device)

In [None]:

# Train SAE (or reuse a previously saved checkpoint for the same configuration)
tag = f"{INPUT_DIM}to{FEATURE_DIM}_k{K_ACTIVE}"
out_dir = pathlib.Path(DRIVE_ROOT) / tag
ckpt_path = out_dir / "sae.pt"

# Check if model already exists
if ckpt_path.exists():
    print(f"Found existing SAE model at {ckpt_path}, skipping training")
    # Load existing model and metadata
    sae = BatchTopKSAE(INPUT_DIM, FEATURE_DIM, K_ACTIVE).to(device).float()
    pkg = torch.load(ckpt_path, map_location=device)
    sae.load_state_dict(pkg["model_state_dict"])
    loss_curve = pkg["loss_curve"]
    meta_path = out_dir / "meta.json"
    if meta_path.exists():
        with open(meta_path) as meta_file:
            meta = json.load(meta_file)
        atoms_used = meta["atoms_used"]
        sparsity = meta["sparsity_ratio"]
        train_secs = meta["train_minutes"] * 60
    else:
        # Compute atoms used and sparsity if metadata is missing
        sae.eval()
        with torch.no_grad():
            acts, _ = sae.get_feature_activations(train_emb_w)
        atoms_used = int(torch.unique(acts).numel())
        sparsity = K_ACTIVE / FEATURE_DIM
        train_secs = 0  # Unknown without metadata
    print(f"   atoms-used: {atoms_used}/{FEATURE_DIM}  (sparsity ≈ {K_ACTIVE}/{FEATURE_DIM} = {sparsity:.4f})")
else:
    print(f"Training SAE {INPUT_DIM} → {FEATURE_DIM} (k={K_ACTIVE})")
    t0 = time.time()

    sae = BatchTopKSAE(INPUT_DIM, FEATURE_DIM, K_ACTIVE).to(device).float()
    opt = torch.optim.Adam(sae.parameters(), lr=LR)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

    # Pair every embedding with its label inside the dataset. Slicing the label
    # tensor by batch index (as done before) is wrong once the loader shuffles:
    # the labels no longer correspond to the samples in the batch.
    labelled_loader = DataLoader(
        TensorDataset(train_emb_w, train_labels),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    loss_curve = []
    sae.train()
    for ep in range(EPOCHS):
        running_loss = 0.0
        for x, batch_labels in labelled_loader:
            x = x.to(device)
            batch_labels = batch_labels.to(device)
            out = sae(x)
            s = F.relu(out["sparse_features"])  # Enforce non-negativity

            # Sparsity penalties
            l1 = λ_L1 * s.mean()
            # NOTE(review): the activation rate comes from a hard threshold
            # (s > 0), which carries no gradient, so this KL term only shifts
            # the reported loss value. Confirm whether that is intended.
            p_hat = (s > 0).float().mean(0).clamp(1e-4, 1 - 1e-4)
            kl = β_KL * (ρ_TARGET * torch.log(ρ_TARGET / p_hat) +
                         (1 - ρ_TARGET) * torch.log((1 - ρ_TARGET) / (1 - p_hat))).mean()

            # NOTE(review): decoder.weight has shape (input_dim, feature_dim),
            # so W @ W.T and the norm over dim=1 below act on its rows (one per
            # input dimension), not on per-feature columns. Confirm intended axis.
            W = sae.decoder.weight
            orth = λ_ORTH * ((W @ W.T - torch.eye(W.size(0), device=W.device))**2).mean()

            # Class-weighted reconstruction loss. Both factors are 1-D [batch]
            # tensors so each sample's error is scaled by its own class weight;
            # a [batch, 1] weight column would broadcast to [batch, batch] and
            # reduce the weighting to a product of two batch means.
            per_sample_mse = F.mse_loss(out["reconstructed"], x, reduction='none').mean(dim=1)
            weighted_recon = (CLASS_WEIGHTS[batch_labels] * per_sample_mse).mean()

            loss = weighted_recon + l1 + kl + orth
            opt.zero_grad()
            loss.backward()
            opt.step()
            running_loss += loss.item()

            # Weight normalization
            with torch.no_grad():
                W.div_(W.norm(dim=1, keepdim=True) + 1e-8)

        sched.step()
        loss_curve.append(running_loss / len(labelled_loader))
        if ep % 100 == 0 or ep == EPOCHS - 1:
            print(f"  epoch {ep:3d}  loss={loss_curve[-1]:.4f}")

    train_secs = time.time() - t0
    print(f"Training finished in {train_secs/60:.1f} min")

    # Compute sparsity and usage stats
    sae.eval()
    with torch.no_grad():
        acts, _ = sae.get_feature_activations(train_emb_w)  # [N_train, k]
    atoms_used = int(torch.unique(acts).numel())
    sparsity = K_ACTIVE / FEATURE_DIM
    print(f"   atoms-used: {atoms_used}/{FEATURE_DIM}  (sparsity ≈ {K_ACTIVE}/{FEATURE_DIM} = {sparsity:.4f})")

    # Save model and config (the standardization statistics travel with the weights)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save({
        "model_state_dict": sae.state_dict(),
        "config": dict(input_dim=INPUT_DIM, feature_dim=FEATURE_DIM, k=K_ACTIVE, μ=mu.cpu(), σ=σ.cpu()),
        "loss_curve": loss_curve
    }, ckpt_path)

    meta = dict(
        feature_dim=FEATURE_DIM,
        k_active=K_ACTIVE,
        epochs=EPOCHS,
        lr=LR,
        l1=λ_L1,
        kl_beta=β_KL,
        kl_rho=ρ_TARGET,
        orth=λ_ORTH,
        atoms_used=atoms_used,
        sparsity_ratio=sparsity,
        train_minutes=round(train_secs/60, 2),
        n_samples=N_train
    )
    with open(out_dir / "meta.json", "w") as meta_file:
        json.dump(meta, meta_file, indent=2)
    print(f"Saved → {ckpt_path}")

# Clean up to free memory
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Load the embeddings used for evaluation and derive boolean TP/FP labels
# NOTE(review): this path is identical to the one read by the training cell,
# so evaluation appears to run on the training samples — confirm the intended file.
print("Loading testing embeddings...")
test_pkg = torch.load("../data/processed/sae_sv_embeddings.pt", map_location="cpu")
test_emb = test_pkg["embeddings"].float().to(device)  # [N_test, 4096]
test_sv_info = test_pkg["sv_info"]

# Debug: show which truvari_class values are present and how often
print("Debugging test_sv_info:")
print(f"Total test samples: {len(test_sv_info)}")
truvari_classes = [sv['truvari_class'] for sv in test_sv_info]
print(f"Unique truvari_class values: {set(truvari_classes)}")
print(f"truvari_class counts: {pd.Series(truvari_classes).value_counts().to_dict()}")

# Labels: True for 'TP' / 'tp_comp_vcf', False for everything else.
# Built as integers first, then cast to bool so the tensor can be used directly
# as a mask and with `~` / `&`.
is_true_positive = [sv['truvari_class'] in ['TP', 'tp_comp_vcf'] for sv in test_sv_info]
test_labels = torch.tensor(is_true_positive, dtype=torch.long).to(device)
test_labels = test_labels.bool()

# Sanity checks: labels line up with embeddings and both classes are present
print(f"Length of test_labels: {len(test_labels)}")
print(f"Label counts: {torch.unique(test_labels, return_counts=True)}")
print(f"Length of test_emb: {test_emb.shape[0]}")
assert len(test_labels) == test_emb.shape[0], "Mismatch between test embeddings and labels"
assert len(torch.unique(test_labels)) > 1, "Test labels contain only one class"

# Standardize with the statistics computed on the training embeddings
test_emb_w = (test_emb - mu) / σ

# Locations of previously written evaluation artifacts
models_root = pathlib.Path(DRIVE_ROOT)
pointer_file = models_root / "BEST_MODEL.json"
csv_path = models_root / "sae_eval_metrics.csv"
current_model_dir = str(out_dir)

# Filled either from the cached CSV or by the evaluation cell
records = []
baseline_f1 = None
baseline_f1_std = None

In [None]:
# Reuse cached evaluation rows when the CSV already covers both models
if not csv_path.exists():
    print(f"No existing evaluation results at {csv_path}, computing evaluation")
else:
    print(f"Found existing evaluation results at {csv_path}")
    df = pd.read_csv(csv_path)
    # Only the current SAE and the Goodfire SAE rows are of interest
    relevant_models = [current_model_dir, "Goodfire_Evo2_SAE"]
    df_filtered = df[df['model_dir'].isin(relevant_models)]
    if len(df_filtered) != len(relevant_models):
        print(f"Missing results for {current_model_dir}, computing evaluation")
    else:
        print(f"Found results for {current_model_dir} and Goodfire_Evo2_SAE, skipping evaluation")
        # Baseline metrics are taken from the first CSV row
        baseline_f1 = float(df['Baseline_F1'].iloc[0])
        baseline_f1_std = float(df['Baseline_F1_std'].iloc[0])
        records = df_filtered.to_dict('records')

# %%
# Run the full evaluation unless cached rows were loaded above
if not records:
    # Logistic regression baseline on raw embeddings
    print("Computing baseline F1 on raw embeddings...")
    clf = LogisticRegression(max_iter=1000, class_weight="balanced", n_jobs=-1)
    try:
        scores = cross_val_score(clf, test_emb_w.cpu().numpy(), test_labels.cpu().numpy(), cv=5, scoring='f1')
        baseline_f1 = float(np.mean(scores))
        baseline_f1_std = float(np.std(scores))
        print(f"Baseline F1 on raw embeddings: {baseline_f1:.3f} ± {baseline_f1_std:.3f}")
    except ValueError as e:
        print(f"Error in baseline F1 calculation: {e}")
        baseline_f1, baseline_f1_std = 0.0, 0.0

    # Evaluate trained SAE
    print("Evaluating trained SAE...")
    sae.eval()
    with torch.no_grad():
        acts, _ = sae.get_feature_activations(test_emb_w)  # Shape: (N_test, k_active)

    # Boolean sample x atom indicator: True where the atom is among the sample's top-k
    N = acts.size(0)
    ind_mat = torch.zeros(N, FEATURE_DIM, dtype=torch.bool, device=device)
    ind_mat.scatter_(1, acts, True)

    tp_total = int(test_labels.sum())
    fp_total = int((~test_labels).sum())

    # Per-atom activation counts for each class, computed once instead of one
    # device round-trip per atom.
    tp_on_per_atom = ind_mat[test_labels].sum(dim=0).tolist()
    fp_on_per_atom = ind_mat[~test_labels].sum(dim=0).tolist()

    valid, pvals, odds, delta = [], [], [], []
    for atom in range(FEATURE_DIM):
        tp_on, fp_on = tp_on_per_atom[atom], fp_on_per_atom[atom]
        if tp_on + fp_on < 10:
            continue
        tp_off, fp_off = tp_total - tp_on, fp_total - fp_on
        if tp_off + fp_off == 0:
            # Atom fires on every sample: the contingency table has an empty
            # row (chi2_contingency would raise) and carries no class signal.
            continue
        _, p, _, _ = chi2_contingency([[tp_on, fp_on], [tp_off, fp_off]], correction=False)
        OR = (tp_on * fp_off) / max(1, fp_on * tp_off)
        Δprop = abs(tp_on / tp_total - fp_on / fp_total)
        valid.append(atom)
        pvals.append(p)
        odds.append(OR)
        delta.append(Δprop)

    # Atoms are assigned to a class purely by odds ratio thresholds
    FP_atoms = {a for a, o in zip(valid, odds) if o < 0.5}
    TP_atoms = {a for a, o in zip(valid, odds) if o > 2.0}

    # Explicit integer dtype: an empty list would otherwise yield a float tensor
    fp_tensor = torch.tensor(sorted(FP_atoms), dtype=torch.long, device=device)
    best_f1, best_t, best_tp, best_fp = 0, None, None, None
    best_precision, best_recall = 0, 0

    # Hard filter: keep a call when fewer than t of its active atoms are FP atoms.
    # The per-sample FP-atom count does not depend on t, so compute it once.
    fp_counts = torch.isin(acts, fp_tensor).sum(dim=1)
    labels_cpu = test_labels.cpu()
    for t in range(1, K_ACTIVE + 1):
        keep = (fp_counts < t)
        tp_kept = int((keep & test_labels).sum())
        fp_kept = int((keep & ~test_labels).sum())
        keep_cpu = keep.cpu()
        p = precision_score(labels_cpu, keep_cpu)
        r = recall_score(labels_cpu, keep_cpu)
        f = f1_score(labels_cpu, keep_cpu)
        if f > best_f1:
            best_f1, best_t, best_tp, best_fp = f, t, tp_kept, fp_kept
            best_precision, best_recall = p, r

    # Logistic regression on SAE activations
    X = ind_mat.cpu().numpy().astype("uint8")
    y = test_labels.cpu().numpy().astype("uint8")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    lr = LogisticRegression(max_iter=1000, class_weight="balanced", n_jobs=-1)

    cv_auprc_scores = []
    cv_precision_scores = []
    cv_recall_scores = []
    cv_f1_scores = []

    for tr, val in cv.split(X_train, y_train):
        lr.fit(X_train[tr], y_train[tr])
        prob = lr.predict_proba(X_train[val])[:, 1]
        pred = lr.predict(X_train[val])
        cv_auprc_scores.append(average_precision_score(y_train[val], prob))
        cv_precision_scores.append(precision_score(y_train[val], pred))
        cv_recall_scores.append(recall_score(y_train[val], pred))
        cv_f1_scores.append(f1_score(y_train[val], pred))

    # Refit on the full training split and score the held-out 30%
    lr.fit(X_train, y_train)
    prob_test = lr.predict_proba(X_test)[:, 1]
    pred_test = lr.predict(X_test)
    test_precision = precision_score(y_test, pred_test)
    test_recall = recall_score(y_test, pred_test)
    test_f1 = f1_score(y_test, pred_test)

    records.append({
        'model_dir': str(out_dir),
        'feature_dim': FEATURE_DIM,
        'k_active': K_ACTIVE,
        'sparsity_ratio': K_ACTIVE / FEATURE_DIM,
        'atoms_used': int(torch.unique(acts).numel()),
        'FP_sig_atoms': len(FP_atoms),
        'TP_sig_atoms': len(TP_atoms),
        'TP_retained': best_tp,
        'FP_remaining': best_fp,
        'best_hard_F1': best_f1,
        'best_hard_Precision': best_precision,
        'best_hard_Recall': best_recall,
        'best_t': best_t,
        'LogReg_CV_AUPRC': float(np.mean(cv_auprc_scores)),
        'LogReg_CV_Precision': float(np.mean(cv_precision_scores)),
        'LogReg_CV_Recall': float(np.mean(cv_recall_scores)),
        'LogReg_CV_F1': float(np.mean(cv_f1_scores)),
        'LogReg_test_AUPRC': float(average_precision_score(y_test, prob_test)),
        'LogReg_test_Precision': float(test_precision),
        'LogReg_test_Recall': float(test_recall),
        'LogReg_test_F1': float(test_f1),
        'Baseline_F1': baseline_f1,
        'Baseline_F1_std': baseline_f1_std
    })

    # Evaluate Goodfire Evo-2 SAE
    print("Evaluating Goodfire Evo-2 SAE...")
    # Check if Goodfire results exist in df
    if csv_path.exists() and 'df' in locals() and not df.empty and "Goodfire_Evo2_SAE" in df['model_dir'].values:
        print("Goodfire Evo-2 SAE evaluation already in results, reusing")
        goodfire_record = df[df['model_dir'] == "Goodfire_Evo2_SAE"].iloc[0].to_dict()
        records.append(goodfire_record)
    else:
        # Only try to evaluate Goodfire if the required objects exist from previous notebooks.
        # The comparison is best-effort: any failure is reported and skipped.
        try:
            # Check if we have the necessary objects from the embedding extraction notebook
            if 'evo2_model' in globals() and 'SV_Evo2_Encoder' in globals():
                print("Found evo2_model and SV_Evo2_Encoder from previous notebooks")
                # Generate or load Goodfire activations
                if 'evo2_activations' not in globals() or 'evo2_config' not in globals():
                    print("Generating Goodfire Evo-2 activations for test set...")
                    sv_encoder = SV_Evo2_Encoder(evo2_model)
                    with open('../data/processed/sae_sequences.json', 'r') as f:
                        seq_data = json.load(f)
                        test_sequences = seq_data['sequences']
                    evo2_activations = sv_encoder.extract_embeddings_for_svs(test_sequences, batch_size=4, layer=26)
                    evo2_config = {"feature_dim": 32768}  # Adjust based on Goodfire SAE specs

                evo2_activations = evo2_activations[:len(test_sv_info)]  # Ensure alignment
                # Named explicitly: binding this to `F` would overwrite the
                # torch.nn.functional alias used by the training cell.
                gf_feature_dim = evo2_config["feature_dim"]
                k_active = evo2_activations.shape[1]

                ind_mat_goodfire = torch.zeros(len(test_sv_info), gf_feature_dim, dtype=torch.bool, device=device)
                row_idx = torch.arange(len(test_sv_info)).view(-1, 1).expand(-1, k_active)
                ind_mat_goodfire[row_idx, evo2_activations] = True

                # Kept separate from the trained-SAE `atoms_used` so it is not clobbered
                atoms_used_goodfire = int(ind_mat_goodfire.any(0).sum())

                tp_on_goodfire = ind_mat_goodfire[test_labels].sum(dim=0).tolist()
                fp_on_goodfire = ind_mat_goodfire[~test_labels].sum(dim=0).tolist()

                valid, pvals, odds, delta = [], [], [], []
                valid_tp_on, valid_fp_on = [], []
                for a in torch.nonzero(ind_mat_goodfire.any(0), as_tuple=False).flatten().tolist():
                    tp_on, fp_on = tp_on_goodfire[a], fp_on_goodfire[a]
                    if tp_on + fp_on < 10:
                        continue
                    tp_off, fp_off = tp_total - tp_on, fp_total - fp_on
                    if tp_off + fp_off == 0:
                        # Atom fires on every sample: no usable contingency table
                        continue
                    _, p, _, _ = chi2_contingency([[tp_on, fp_on], [tp_off, fp_off]], correction=False)
                    OR = (tp_on * fp_off) / max(1, fp_on * tp_off)
                    Δprop = abs(tp_on / tp_total - fp_on / fp_total)
                    valid.append(a)
                    valid_tp_on.append(tp_on)
                    valid_fp_on.append(fp_on)
                    pvals.append(p)
                    odds.append(OR)
                    delta.append(Δprop)

                df_evo2 = pd.DataFrame({
                    'atom': valid,
                    'TP_on': valid_tp_on,
                    'FP_on': valid_fp_on,
                    'odds': odds,
                    'Δprop': delta,
                    'p': pvals
                })
                df_evo2["fdr"] = multipletests(df_evo2["p"], method="fdr_bh")[1]

                FP_atoms_goodfire = set(df_evo2.query("odds < 0.5")["atom"])
                TP_atoms_goodfire = set(df_evo2.query("odds > 2.0")["atom"])

                best_f1_goodfire, best_t_goodfire, best_tp_evo2, best_fp_evo2 = 0, None, None, None
                best_precision_evo2, best_recall_evo2 = 0, 0

                fp_tensor_goodfire = torch.tensor(
                    sorted(FP_atoms_goodfire), dtype=torch.long, device=evo2_activations.device
                )
                # Threshold-independent quantities hoisted out of the sweep
                fp_counts = torch.isin(evo2_activations, fp_tensor_goodfire).sum(1).cpu().numpy()
                labels_np = test_labels.cpu().numpy()
                labels_cpu = test_labels.cpu()
                for t in range(1, k_active + 1):
                    keep = fp_counts < t
                    tp_kept = int((keep & labels_np).sum())
                    fp_kept = int((keep & ~labels_np).sum())
                    p = precision_score(labels_cpu, keep)
                    r = recall_score(labels_cpu, keep)
                    f = f1_score(labels_cpu, keep)
                    if f > best_f1_goodfire:
                        best_f1_goodfire, best_t_goodfire, best_tp_evo2, best_fp_evo2 = f, t, tp_kept, fp_kept
                        best_precision_evo2, best_recall_evo2 = p, r

                X_goodfire = ind_mat_goodfire.cpu().numpy().astype("uint8")
                y_goodfire = test_labels.cpu().numpy().astype("uint8")
                X_train_gf, X_test_gf, y_train_gf, y_test_gf = train_test_split(X_goodfire, y_goodfire, test_size=0.3, random_state=42, stratify=y_goodfire)
                cv_gf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
                lr_gf = LogisticRegression(max_iter=1000, class_weight="balanced", n_jobs=-1)

                cv_auprc_scores_gf = []
                cv_precision_scores_gf = []
                cv_recall_scores_gf = []
                cv_f1_scores_gf = []

                for tr, val in cv_gf.split(X_train_gf, y_train_gf):
                    lr_gf.fit(X_train_gf[tr], y_train_gf[tr])
                    prob = lr_gf.predict_proba(X_train_gf[val])[:, 1]
                    pred = lr_gf.predict(X_train_gf[val])
                    cv_auprc_scores_gf.append(average_precision_score(y_train_gf[val], prob))
                    cv_precision_scores_gf.append(precision_score(y_train_gf[val], pred))
                    cv_recall_scores_gf.append(recall_score(y_train_gf[val], pred))
                    cv_f1_scores_gf.append(f1_score(y_train_gf[val], pred))

                lr_gf.fit(X_train_gf, y_train_gf)
                prob_test_gf = lr_gf.predict_proba(X_test_gf)[:, 1]
                pred_test_gf = lr_gf.predict(X_test_gf)
                test_precision_evo2 = precision_score(y_test_gf, pred_test_gf)
                test_recall_evo2 = recall_score(y_test_gf, pred_test_gf)

                records.append({
                    'model_dir': "Goodfire_Evo2_SAE",
                    'feature_dim': gf_feature_dim,
                    'k_active': k_active,
                    'sparsity_ratio': k_active / gf_feature_dim,
                    'atoms_used': atoms_used_goodfire,
                    'FP_sig_atoms': len(FP_atoms_goodfire),
                    'TP_sig_atoms': len(TP_atoms_goodfire),
                    'TP_retained': best_tp_evo2,
                    'FP_remaining': best_fp_evo2,
                    'best_hard_F1': best_f1_goodfire,
                    'best_hard_Precision': best_precision_evo2,
                    'best_hard_Recall': best_recall_evo2,
                    'best_t': best_t_goodfire,
                    'LogReg_CV_AUPRC': float(np.mean(cv_auprc_scores_gf)),
                    'LogReg_CV_Precision': float(np.mean(cv_precision_scores_gf)),
                    'LogReg_CV_Recall': float(np.mean(cv_recall_scores_gf)),
                    'LogReg_CV_F1': float(np.mean(cv_f1_scores_gf)),
                    'LogReg_test_AUPRC': float(average_precision_score(y_test_gf, prob_test_gf)),
                    'LogReg_test_Precision': float(test_precision_evo2),
                    'LogReg_test_Recall': float(test_recall_evo2),
                    'LogReg_test_F1': float(f1_score(y_test_gf, pred_test_gf)),
                    'Baseline_F1': baseline_f1,
                    'Baseline_F1_std': baseline_f1_std
                })
                print("Completed Goodfire Evo-2 SAE evaluation")
            else:
                print("evo2_model or SV_Evo2_Encoder not available from previous notebooks")
                print("Skipping Goodfire SAE evaluation")
        except Exception as e:
            print(f"Error evaluating Goodfire SAE: {e}")
            print("Skipping Goodfire evaluation")

In [None]:
# Summary table and results
# `display` / `Markdown` are not imported anywhere else in this file; import them
# here so the cell also works on a fresh kernel.
from IPython.display import Markdown, display

column_order = [
    'model_dir', 'feature_dim', 'k_active', 'sparsity_ratio', 'atoms_used',
    'FP_sig_atoms', 'TP_sig_atoms', 'TP_retained', 'FP_remaining',
    'best_hard_F1', 'best_hard_Precision', 'best_hard_Recall', 'best_t',
    'LogReg_CV_AUPRC', 'LogReg_CV_F1', 'LogReg_CV_Precision', 'LogReg_CV_Recall',
    'LogReg_test_AUPRC', 'LogReg_test_F1', 'LogReg_test_Precision', 'LogReg_test_Recall',
    'Baseline_F1', 'Baseline_F1_std'
]

# Add a baseline-only row if a baseline was computed. The membership check
# keeps the cell idempotent: re-running it must not append a second
# Raw_Embeddings record.
baseline_already_listed = any(rec.get('model_dir') == 'Raw_Embeddings' for rec in records)
if baseline_f1 is not None and not baseline_already_listed:
    baseline_record = dict.fromkeys(column_order)  # every metric defaults to None
    baseline_record.update({
        'model_dir': 'Raw_Embeddings',
        'feature_dim': INPUT_DIM,
        'Baseline_F1': baseline_f1,
        'Baseline_F1_std': baseline_f1_std,
    })
    records.append(baseline_record)

# Best hard-filter F1 first; ties broken by TP retained, then fewer FP atoms, then sparsity
df = pd.DataFrame(records).sort_values(
    ["best_hard_F1", "TP_retained", "FP_sig_atoms", "sparsity_ratio"],
    ascending=[False, False, True, True]
).reset_index(drop=True)

display(Markdown("## Trained SAE + Goodfire Evo-2 + Raw Embeddings – Evaluation Summary"))
# Cached CSV rows may lack some columns, so only show those that exist
existing_columns = [col for col in column_order if col in df.columns]
df_display = df[existing_columns]
display(df_display)

In [None]:
# Pick the winner (exclude Raw_Embeddings for best model selection)
# `df` is already sorted best-first, so the first remaining row is the winner.
df_models_for_best = df[df['model_dir'] != 'Raw_Embeddings']
if not df_models_for_best.empty:
    best_row = df_models_for_best.iloc[0]
    BEST_PATH = best_row["model_dir"]
    print(f"Best model: {BEST_PATH}")
    display_metrics = [
        "feature_dim", "k_active", "best_hard_F1", "best_hard_Precision", "best_hard_Recall",
        "TP_retained", "FP_remaining", "LogReg_test_F1", "LogReg_test_Precision",
        "LogReg_test_Recall", "Baseline_F1"
    ]
    existing_display_metrics = [col for col in display_metrics if col in best_row.index]
    print(best_row[existing_display_metrics])

    # Save results. NumPy scalars (e.g. int64 from integer columns) are not
    # JSON-serializable, so unwrap them to native Python values first.
    best_metrics = {
        key: (value.item() if isinstance(value, np.generic) else value)
        for key, value in best_row.to_dict().items()
    }
    pointer_file = pathlib.Path(DRIVE_ROOT) / "BEST_MODEL.json"
    with open(pointer_file, "w") as pointer_fh:
        json.dump(dict(best_dir=BEST_PATH, metrics=best_metrics), pointer_fh, indent=2)
    print(f"Pointer saved → {pointer_file}")
else:
    print("No models available for best model selection")

# The metrics table is written regardless of whether a winner was found
csv_path = pathlib.Path(DRIVE_ROOT) / "sae_eval_metrics.csv"
df.to_csv(csv_path, index=False)
print(f"Metrics CSV → {csv_path}")

In [None]:

# Summary statistics (exclude Raw_Embeddings for hard filter and logistic regression metrics)
if not df.empty:
    df_models = df[df['model_dir'] != 'Raw_Embeddings']
    print("PRECISION/RECALL SUMMARY:")
    print("="*50)
    if not df_models.empty:
        print("Hard Filter Approach:")
        print(f"  Best Precision: {df_models['best_hard_Precision'].max():.3f}")
        print(f"  Best Recall: {df_models['best_hard_Recall'].max():.3f}")
        print(f"  Best F1: {df_models['best_hard_F1'].max():.3f}")
        print("Logistic Regression (Test Set):")
        print(f"  Best Precision: {df_models['LogReg_test_Precision'].max():.3f}")
        print(f"  Best Recall: {df_models['LogReg_test_Recall'].max():.3f}")
        print(f"  Best F1: {df_models['LogReg_test_F1'].max():.3f}")

    if 'Baseline_F1' in df.columns and not df['Baseline_F1'].isna().all():
        # Read the baseline from the first row that actually has one, rather
        # than assuming row 0 is populated.
        baseline_row = df[df['Baseline_F1'].notna()].iloc[0]
        print("Raw Embeddings Baseline:")
        print(f"  Baseline F1: {baseline_row['Baseline_F1']:.3f} ± {baseline_row['Baseline_F1_std']:.3f}")
else:
    print("No evaluation results to summarize")

# Clean up. Each name is dropped independently: a single
# `del sae, acts, ind_mat, test_emb` stops at the first undefined name (e.g.
# `acts` when cached results were reused) and leaves the rest allocated.
for _stale_name in ("sae", "acts", "ind_mat", "test_emb"):
    globals().pop(_stale_name, None)
torch.cuda.empty_cache()
gc.collect()

print("\nSAE training and evaluation complete!")