SAE Interpretability Analysis
Analyze which Sparse Autoencoder features are biased towards true positive vs false positive structural variant calls, and investigate biological patterns in feature activations.
 - Prerequisite: run the SAE training notebook first to generate the trained model
 - Prerequisite: run the embedding extraction notebook to produce the test-data embeddings


In [None]:
import os
import json
import pathlib
import torch
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from scipy.stats import chi2_contingency, fisher_exact, mannwhitneyu
from statsmodels.stats.multitest import multipletests
import warnings
warnings.filterwarnings('ignore')

# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

In [None]:
# Load the test-set embeddings together with the trained SAE checkpoint.
# NOTE(review): torch.load may unpickle arbitrary Python objects depending on
# the PyTorch version / weights_only setting - only load trusted files.
print("Loading test embeddings and trained SAE model...")

# Test data: embedding matrix plus one metadata record per structural variant.
test_pkg = torch.load("../data/processed/sae_sv_embeddings.pt", map_location="cpu")
test_sv_info = test_pkg["sv_info"]
test_emb = test_pkg["embeddings"].float().to(device)

# Resolve the model directory: use the recorded best run when the pointer file
# exists, otherwise fall back to the standard path.
best_model_path = pathlib.Path("../data/models/BEST_MODEL.json")
if best_model_path.exists():
    with open(best_model_path) as meta_file:
        best_meta = json.load(meta_file)
    model_dir = pathlib.Path(best_meta["best_dir"])
else:
    model_dir = pathlib.Path("../data/models/4096to4096_k64")
ckpt_path = model_dir / "sae.pt"

# Checkpoint is read onto the CPU; the architecture sizes come from its config.
pkg = torch.load(ckpt_path, map_location="cpu")
cfg = pkg["config"]
input_dim, feature_dim, k_active = (
    int(cfg[key]) for key in ("input_dim", "feature_dim", "k")
)

print(f"SAE config: {input_dim} -> {feature_dim}, k={k_active}")

In [None]:
# Recreate SAE model class if needed
class BatchTopKSAE(torch.nn.Module):
    """Sparse autoencoder that keeps the k largest encoder outputs.

    The top-k selection runs along the last dimension, i.e. independently for
    every input row. The encoder is a biased linear layer and the decoder an
    unbiased one.
    """

    def __init__(self, input_dim, feature_dim, k_active):
        super().__init__()
        self.input_dim, self.feature_dim, self.k_active = input_dim, feature_dim, k_active

        self.encoder = torch.nn.Linear(input_dim, feature_dim, bias=True)
        self.decoder = torch.nn.Linear(feature_dim, input_dim, bias=False)

    def _encode_topk(self, x):
        """Return (dense encoder output, top-k values, top-k indices) for x."""
        dense = self.encoder(x)
        # Never request more entries than there are features.
        n_keep = min(self.k_active, self.feature_dim)
        values, indices = torch.topk(dense, n_keep, dim=-1)
        return dense, values, indices

    def forward(self, x):
        """Encode x, zero everything outside the top-k, and decode.

        Returns a dict with the reconstruction ('reconstructed'), the sparsified
        code ('sparse_features') and the raw encoder output ('dense_features').
        """
        dense, values, indices = self._encode_topk(x)
        sparse = torch.zeros_like(dense).scatter_(-1, indices, values)
        return {
            'reconstructed': self.decoder(sparse),
            'sparse_features': sparse,
            'dense_features': dense,
        }

    def get_feature_activations(self, x):
        """Return (top-k feature indices, dense encoder output) without tracking gradients."""
        with torch.no_grad():
            dense, _, indices = self._encode_topk(x)
        return indices, dense

# Load and setup SAE
sae = BatchTopKSAE(input_dim, feature_dim, k_active).to(device)
# strict=False tolerates key mismatches between checkpoint and module, but a
# parameter missing from the checkpoint would silently stay at its random
# initialisation and invalidate every statistic computed below. Report both
# kinds of mismatch instead of discarding them. print() is used rather than
# warnings.warn because this notebook installs a global "ignore" filter.
load_result = sae.load_state_dict(pkg["model_state_dict"], strict=False)
if load_result.missing_keys:
    print(f"WARNING: parameters missing from checkpoint (left at random init): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: checkpoint entries not used by BatchTopKSAE: {load_result.unexpected_keys}")
sae.eval()

In [None]:
# Get whitening parameters (stored under either the Greek or the ASCII key)
mu = pkg["config"].get("μ", pkg["config"].get("mu"))
sigma = pkg["config"].get("σ", pkg["config"].get("std"))
if mu is None or sigma is None:
    # Fail with a readable message instead of an opaque tensor-construction error.
    raise KeyError("SAE checkpoint config is missing whitening parameters "
                   "(expected 'μ'/'mu' and 'σ'/'std')")
if not isinstance(mu, torch.Tensor):
    mu = torch.tensor(mu)
if not isinstance(sigma, torch.Tensor):
    sigma = torch.tensor(sigma)
# The checkpoint is loaded with map_location="cpu" while the embeddings live on
# `device`; move the statistics over so the arithmetic below never mixes devices.
mu = mu.view(1, -1).float().to(device)
# The small epsilon guards against division by zero for constant dimensions.
sigma = (sigma.view(1, -1).float() + 1e-6).to(device)

# Whiten test embeddings
test_emb_w = (test_emb - mu) / sigma

print(f"Loaded {len(test_sv_info)} test samples")

# Extract feature activations
print("Extracting SAE feature activations...")

with torch.no_grad():
    acts, _ = sae.get_feature_activations(test_emb_w)

# Create indicator matrix for analysis: ind_mat[i, j] is True when feature j
# is among the top-k indices of sample i.
N, k_active = acts.shape
ind_mat = torch.zeros(N, feature_dim, dtype=torch.bool, device=device)
ind_mat.scatter_(1, acts, True)

# Create labels (TP=True, FP=False)
labels = torch.tensor([
    sv.get('truvari_class') in ['TP', 'tp_comp_vcf']
    for sv in test_sv_info
], dtype=torch.bool).to(device)

print(f"Feature activations: {acts.shape}")
print(f"Indicator matrix: {ind_mat.shape}")
print(f"Labels: TP={labels.sum().item()}, FP={(~labels).sum().item()}")

In [None]:
# Compute TP/FP bias for each feature using chi-square test
print("Computing feature bias statistics...")

tp_total = int(labels.sum())
fp_total = int((~labels).sum())

# Count activations per feature for both classes in a single reduction each,
# rather than two boolean-indexed sums per feature inside the Python loop.
tp_on_counts = ind_mat[labels].sum(dim=0).tolist()
fp_on_counts = ind_mat[~labels].sum(dim=0).tolist()

bias_columns = ['atom', 'TP_on', 'FP_on', 'TP_off', 'FP_off', 'odds', 'support', 'p']
bias_records = []
for atom, (tp_on, fp_on) in enumerate(zip(tp_on_counts, fp_on_counts)):
    if tp_on + fp_on < 5:  # Skip rarely activated features
        continue

    tp_off = tp_total - tp_on
    fp_off = fp_total - fp_on

    try:
        _, p_value, _, _ = chi2_contingency([[tp_on, fp_on], [tp_off, fp_off]], correction=False)
    except ValueError:
        # chi2_contingency rejects tables with a zero expected frequency.
        continue

    # NOTE(review): when fp_on * tp_off == 0 the denominator is clamped to 1, so
    # the reported value is tp_on * fp_off rather than a true odds ratio; a
    # continuity correction (e.g. Haldane-Anscombe) may be preferable.
    odds_ratio = (tp_on * fp_off) / max(1, fp_on * tp_off)
    support = tp_on + fp_on

    bias_records.append({
        'atom': atom,
        'TP_on': tp_on,
        'FP_on': fp_on,
        'TP_off': tp_off,
        'FP_off': fp_off,
        'odds': odds_ratio,
        'support': support,
        'p': p_value
    })

# Passing the column list keeps the expected columns present even when no
# feature passed the filter, so later cells do not fail with KeyError.
stats_df = pd.DataFrame(bias_records, columns=bias_columns)

# Apply FDR correction
if len(stats_df) > 0:
    _, fdr_corrected, _, _ = multipletests(stats_df['p'], method='fdr_bh')
    stats_df['fdr_corrected'] = fdr_corrected
    stats_df['significant'] = fdr_corrected < 0.05
else:
    stats_df['fdr_corrected'] = pd.Series(dtype=float)
    stats_df['significant'] = pd.Series(dtype=bool)

print(f"Analyzed {len(stats_df)} features with ≥5 activations")

# Feature bias landscape overview
total = len(stats_df)
tp_bias = (stats_df["odds"] > 1).sum()
fp_bias = (stats_df["odds"] < 1).sum()
strong_tp = (stats_df["odds"] > 2).sum()
strong_fp = (stats_df["odds"] < 0.5).sum()
significant = stats_df['significant'].sum()

# Percentages are reported as 0.0 when nothing was analysed.
tp_bias_pct = tp_bias / total * 100 if total else 0.0
fp_bias_pct = fp_bias / total * 100 if total else 0.0

print(" SAE Feature Bias Landscape")
print("=" * 40)
print(f"Total features analyzed: {total}")
print(f"TP-biased (OR > 1): {tp_bias} ({tp_bias_pct:.1f}%)")
print(f"FP-biased (OR < 1): {fp_bias} ({fp_bias_pct:.1f}%)")
print(f"Strong TP-bias (OR > 2): {strong_tp}")
print(f"Strong FP-bias (OR < 0.5): {strong_fp}")
print(f"Statistically significant: {significant}")

if total > 0:
    support_stats = stats_df["support"].describe()
    print(f"\nSupport statistics:")
    print(f"  Mean: {support_stats['mean']:.1f}")
    print(f"  Median: {support_stats['50%']:.1f}")
    print(f"  Range: {support_stats['min']:.0f} - {support_stats['max']:.0f}")
else:
    print("\nSupport statistics: n/a (no feature passed the activation filter)")

In [None]:
# Top Biased Features Analysis

def show_top_features(df, title, n=10, ascending=False):
    """Print a table of the n rows of df with the most extreme 'odds' values.

    With ascending=False the largest odds are shown, otherwise the smallest.
    An empty frame produces a single "No features found" line.
    """
    if df.empty:
        print(f"{title}: No features found")
        return

    pick = df.nsmallest if ascending else df.nlargest
    ranked = pick(n, "odds")

    rule = "-" * 80
    print(f"\n{title}")
    print(rule)
    print(f"{'Atom':<6} {'TP_on':<6} {'FP_on':<6} {'Odds_Ratio':<12} {'Support':<8} {'P-value':<10} {'FDR_Sig'}")
    print(rule)

    for record in ranked.to_dict('records'):
        # Rows without a 'significant' entry are shown as not significant.
        if record.get('significant', False):
            sig_marker = "✓"
        else:
            sig_marker = "✗"
        counts_part = f"{int(record['atom']):<6} {int(record['TP_on']):<6} {int(record['FP_on']):<6} "
        stats_part = f"{record['odds']:<12.3f} {int(record['support']):<8} {record['p']:<10.2e} {sig_marker}"
        print(counts_part + stats_part)

# Split the statistics by direction of bias (odds above vs. below 1) and print
# the ten most extreme features of each group.
tp_biased_rows = stats_df[stats_df["odds"] > 1]
fp_biased_rows = stats_df[stats_df["odds"] < 1]
show_top_features(tp_biased_rows, "Top TP-biased Features", n=10)
show_top_features(fp_biased_rows, "Top FP-biased Features", n=10, ascending=True)

In [None]:
# SV Type Bias Analysis

def analyze_sv_type_bias(min_activations=10, sv_info=None, activations=None, atom_ids=None):
    """Analyze which features are biased towards insertions or deletions.

    Prints the overall SV type distribution and, when both INS and DEL
    samples exist, the five most INS-biased and five most DEL-biased features.

    Args:
        min_activations: minimum number of INS + DEL samples a feature must
            fire on to be included.
        sv_info: list of per-sample dicts with an 'svtype' entry; defaults to
            the notebook-level ``test_sv_info``.
        activations: integer tensor of active feature indices, one row per
            sample; defaults to the notebook-level ``acts``.
        atom_ids: iterable of feature ids to evaluate; defaults to the 'atom'
            column of the notebook-level ``stats_df``.

    Returns:
        DataFrame with columns atom, ins_count, del_count, total, ins_rate and
        bias_score, or None when INS or DEL samples are absent or no feature
        reaches ``min_activations``.
    """
    # Fall back to the notebook globals so the zero-argument call keeps working.
    if sv_info is None:
        sv_info = test_sv_info
    if activations is None:
        activations = acts
    if atom_ids is None:
        atom_ids = stats_df['atom'].values

    print(" SV Type Bias Analysis")
    print("=" * 40)

    # Get SV types for all samples
    sv_types = [sv.get('svtype', 'UNK') for sv in sv_info]
    sv_type_counts = Counter(sv_types)

    print("Overall SV type distribution:")
    for svtype, count in sv_type_counts.most_common():
        print(f"  {svtype}: {count} ({count/len(sv_types)*100:.1f}%)")

    # Separate masks for INS and DEL. Using "not INS" as the DEL mask would
    # count every other SV type (DUP, INV, UNK, ...) as a deletion.
    mask_device = activations.device
    ins_mask = torch.tensor([t == 'INS' for t in sv_types], dtype=torch.bool, device=mask_device)
    del_mask = torch.tensor([t == 'DEL' for t in sv_types], dtype=torch.bool, device=mask_device)

    if not (bool(ins_mask.any()) and bool(del_mask.any())):
        return None

    print(f"\n INS vs DEL Feature Bias (minimum {min_activations} activations):")
    print("-" * 70)
    print(f"{'Atom':<6} {'INS_count':<10} {'DEL_count':<10} {'INS_rate':<10} {'Bias_score':<12}")
    print("-" * 70)

    ins_del_results = []

    for atom_id in atom_ids:
        atom_id = int(atom_id)
        # Samples on which this feature is among the active indices
        atom_activations = (activations == atom_id).any(dim=1)

        ins_activations = (atom_activations & ins_mask).sum().item()
        del_activations = (atom_activations & del_mask).sum().item()
        total_activations = ins_activations + del_activations

        if total_activations < min_activations:
            continue

        ins_rate = ins_activations / total_activations if total_activations > 0 else 0
        bias_score = ins_rate - 0.5  # Bias relative to balanced (0.5)

        ins_del_results.append({
            'atom': atom_id,
            'ins_count': ins_activations,
            'del_count': del_activations,
            'total': total_activations,
            'ins_rate': ins_rate,
            'bias_score': bias_score
        })

    ins_del_df = pd.DataFrame(ins_del_results)
    if len(ins_del_df) == 0:
        return None

    def _print_rows(rows):
        # One table line per feature.
        for _, row in rows.iterrows():
            print(f"{int(row['atom']):<6} {int(row['ins_count']):<10} {int(row['del_count']):<10} "
                  f"{row['ins_rate']:<10.3f} {row['bias_score']:<12.3f}")

    # Top INS-biased, then top DEL-biased
    _print_rows(ins_del_df.nlargest(5, 'bias_score'))
    print("  ...")
    _print_rows(ins_del_df.nsmallest(5, 'bias_score'))

    print(f"\nStrong INS-biased features (bias > 0.3): {(ins_del_df['bias_score'] > 0.3).sum()}")
    print(f"Strong DEL-biased features (bias < -0.3): {(ins_del_df['bias_score'] < -0.3).sum()}")

    return ins_del_df

# Run the INS/DEL analysis with the default activation threshold. The result is
# bound to a module-level name because generate_interpretability_summary looks
# up 'ins_del_bias_df' through globals(); it is None when the analysis returns nothing.
ins_del_bias_df = analyze_sv_type_bias()

In [None]:
# Detailed Feature Investigation

def investigate_feature(atom_id, stats_df, test_sv_info, acts, labels):
    """Print a detailed report for one SAE feature.

    Args:
        atom_id: id of the feature to inspect.
        stats_df: DataFrame with at least the columns atom, TP_on, FP_on,
            support, odds and p (optionally 'significant').
        test_sv_info: list of per-sample dicts; the keys svtype, svlen, chrom,
            pos and dataset are read when present.
        acts: integer tensor of active feature indices, one row per sample.
        labels: boolean tensor, True where the sample is a true positive.

    Returns:
        None; the report is written to stdout.
    """

    # Get feature statistics
    feature_stats = stats_df[stats_df['atom'] == atom_id]
    if len(feature_stats) == 0:
        print(f"Feature {atom_id} not found in statistics")
        return

    stats = feature_stats.iloc[0]

    print(f" Feature {atom_id} Detailed Analysis")
    print("=" * 50)
    print(f"Activations: TP={stats['TP_on']}, FP={stats['FP_on']}, Total={stats['support']}")
    print(f"Odds Ratio: {stats['odds']:.3f}")
    print(f"P-value: {stats['p']:.2e}")
    if 'significant' in stats:
        print(f"FDR Significant: {'Yes' if stats['significant'] else 'No'}")

    # Find samples where this feature fires
    firing_mask = (acts == atom_id).any(dim=1)
    firing_indices = torch.where(firing_mask)[0].cpu().numpy()

    if len(firing_indices) == 0:
        print("No activations found")
        return

    # Analyze SV characteristics
    firing_svs = [test_sv_info[idx] for idx in firing_indices]

    # SV types
    sv_types = [sv.get('svtype', 'UNK') for sv in firing_svs]
    type_counts = Counter(sv_types)
    print(f"\nSV Types:")
    for svtype, count in type_counts.most_common():
        print(f"  {svtype}: {count} ({count/len(sv_types)*100:.1f}%)")

    # Size distribution. Use magnitudes: a "> 0" filter would drop every variant
    # whose length is stored as a negative number (the usual VCF encoding for
    # deletions). Missing and zero lengths are skipped.
    raw_sizes = [sv.get('svlen') for sv in firing_svs]
    sizes = [abs(s) for s in raw_sizes if s is not None and abs(s) > 0]
    if sizes:
        print(f"\nSize Statistics:")
        print(f"  Mean: {np.mean(sizes):.0f} bp")
        print(f"  Median: {np.median(sizes):.0f} bp")
        print(f"  Range: {min(sizes):.0f} - {max(sizes):.0f} bp")

    # Chromosome distribution
    chroms = [sv.get('chrom', 'unknown') for sv in firing_svs]
    chrom_counts = Counter(chroms)
    print(f"\nTop Chromosomes:")
    for chrom, count in chrom_counts.most_common(5):
        print(f"  {chrom}: {count} ({count/len(chroms)*100:.1f}%)")

    # Dataset distribution
    datasets = [sv.get('dataset', 'unknown') for sv in firing_svs]
    dataset_counts = Counter(datasets)
    print(f"\nDatasets:")
    for dataset, count in dataset_counts.items():
        print(f"  {dataset}: {count} ({count/len(datasets)*100:.1f}%)")

    # Show a few example variants
    print(f"\nExample Variants (first 5):")
    print(f"{'Index':<8} {'Chrom':<8} {'Pos':<12} {'Type':<6} {'Size':<10} {'Class'}")
    print("-" * 55)
    for idx in firing_indices[:5]:
        sv = test_sv_info[idx]
        class_label = 'TP' if labels[idx].item() else 'FP'
        # str() keeps the left-aligned layout and tolerates None-valued fields.
        print(f"{idx:<8} {str(sv.get('chrom', 'unk')):<8} {str(sv.get('pos', 0)):<12} "
              f"{str(sv.get('svtype', 'UNK')):<6} {str(sv.get('svlen', 0)):<10} {class_label}")

In [None]:
# Print a detailed report for the three features with the largest odds
# (TP-biased) followed by the three with the smallest odds (FP-biased).
print(" Investigating Top Features")
print("=" * 50)

top_tp_features = stats_df.nlargest(3, 'odds')['atom'].tolist()
top_fp_features = stats_df.nsmallest(3, 'odds')['atom'].tolist()

for atom_id in top_tp_features + top_fp_features:
    investigate_feature(atom_id, stats_df, test_sv_info, acts, labels)
    print()

In [None]:
# Feature Co-activation Analysis

def analyze_feature_coactivation(top_features, acts, min_coactivation=5):
    """Report which of the given features tend to activate on the same samples.

    Args:
        top_features: iterable of feature ids; repeated ids are considered once.
        acts: integer tensor of active feature indices, one row per sample.
        min_coactivation: minimum number of shared samples for a pair to be
            reported.

    Returns:
        None; up to ten pairs with the highest counts are printed to stdout.
    """
    from itertools import combinations

    print(" Feature Co-activation Analysis")
    print("=" * 40)

    # Drop repeated ids (order preserved) so a feature is never paired with
    # itself and each unordered pair is counted once.
    unique_features = list(dict.fromkeys(top_features))

    # Compute each feature's firing mask once instead of once per pair.
    firing_masks = {feat: (acts == feat).any(dim=1) for feat in unique_features}

    coactivation_matrix = {}
    for feat1, feat2 in combinations(unique_features, 2):
        # Count samples where both features activate
        coactivation_count = (firing_masks[feat1] & firing_masks[feat2]).sum().item()
        if coactivation_count >= min_coactivation:
            coactivation_matrix[(feat1, feat2)] = coactivation_count

    if coactivation_matrix:
        print(f"Feature pairs with ≥{min_coactivation} co-activations:")
        sorted_pairs = sorted(coactivation_matrix.items(), key=lambda x: x[1], reverse=True)

        for (feat1, feat2), count in sorted_pairs[:10]:
            print(f"  Features {feat1} & {feat2}: {count} co-activations")
    else:
        print("No significant co-activations found")

# Co-activation is examined among the ten features with the largest odds
# followed by the ten with the smallest odds.
top_biased_features = [
    *stats_df.nlargest(10, 'odds')['atom'].tolist(),
    *stats_df.nsmallest(10, 'odds')['atom'].tolist(),
]
analyze_feature_coactivation(top_biased_features, acts)


In [None]:
# Summary

def generate_interpretability_summary(stats_df):
    """Print headline numbers from the feature bias table.

    Reads the 'odds' and 'support' columns of stats_df and, when present,
    'significant'. If a module-level 'ins_del_bias_df' exists and is not
    None, INS/DEL specialisation counts are appended.
    """
    print(" SAE Interpretability Summary")
    print("=" * 50)

    # Overall statistics
    odds = stats_df['odds']
    n_strong_tp = (odds > 2).sum()
    n_strong_fp = (odds < 0.5).sum()
    if 'significant' in stats_df.columns:
        n_significant = stats_df['significant'].sum()
    else:
        n_significant = 0

    overview = (
        f"Features analyzed: {len(stats_df)}",
        f"Strong TP-biased features (OR > 2): {n_strong_tp}",
        f"Strong FP-biased features (OR < 0.5): {n_strong_fp}",
        f"Statistically significant features: {n_significant}",
    )
    for line in overview:
        print(line)

    # Key findings
    print(f"\n Key Findings:")

    if n_strong_tp > 0:
        print(f"• Strongest TP bias: OR = {odds.max():.2f}")

    if n_strong_fp > 0:
        print(f"• Strongest FP bias: OR = {odds.min():.2f}")

    # Support analysis: features above the 75th percentile of support
    support = stats_df['support']
    cutoff = support.quantile(0.75)
    print(f"• High-support features (>75th percentile): {(support > cutoff).sum()}")

    # SV-type specialization (only if the INS/DEL analysis produced a result)
    ins_del = globals().get('ins_del_bias_df')
    if ins_del is not None:
        bias = ins_del['bias_score']
        print(f"• INS-specialized features: {(bias > 0.3).sum()}")
        print(f"• DEL-specialized features: {(bias < -0.3).sum()}")

# Persist the analysis next to the models: the full bias table as CSV and the
# ten most TP- and FP-biased features as JSON, then print the summary.
results_dir = pathlib.Path("../data/models")
results_dir.mkdir(exist_ok=True)

# Full per-feature statistics
stats_df.to_csv(results_dir / "sae_feature_bias_analysis.csv", index=False)

# Compact records for the most extreme features in each direction
summary_columns = ['atom', 'odds', 'support', 'TP_on', 'FP_on']
top_features_analysis = {
    'top_tp_biased': stats_df.nlargest(10, 'odds')[summary_columns].to_dict('records'),
    'top_fp_biased': stats_df.nsmallest(10, 'odds')[summary_columns].to_dict('records'),
}

with open(results_dir / "top_biased_features.json", 'w') as out_file:
    json.dump(top_features_analysis, out_file, indent=2)

generate_interpretability_summary(stats_df)

print(f"\n Results saved to {results_dir}")
print(f" Interpretability analysis complete!")