SAE Data Extraction

Extract SAE latent representations and save analysis-ready datasets

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import os
import pathlib
import json

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
class BatchTopKSAE(torch.nn.Module):
    """Sparse autoencoder that keeps the k largest latents of every input row.

    The encoder is a biased linear map ``input_dim -> feature_dim`` and the
    decoder is a bias-free linear map ``feature_dim -> input_dim``.  Sparsity
    comes from selecting the ``k_active`` largest encoder outputs along the
    last dimension; no other non-linearity is applied.
    """

    def __init__(self, input_dim, feature_dim, k_active):
        """Create the encoder/decoder pair and remember the layer sizes.

        Args:
            input_dim: Size of the vectors being encoded.
            feature_dim: Number of latent features (encoder output size).
            k_active: Number of latents kept per row; capped at
                ``feature_dim`` when the top-k selection is made.
        """
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.k_active = k_active
        self.encoder = torch.nn.Linear(input_dim, feature_dim, bias=True)
        self.decoder = torch.nn.Linear(feature_dim, input_dim, bias=False)

    def _top_k(self, dense):
        """Return the ``(values, indices)`` top-k result along the last dim."""
        n_keep = min(self.k_active, self.feature_dim)
        return torch.topk(dense, n_keep, dim=-1)

    def forward(self, x):
        """Encode ``x``, zero everything outside the top-k, and decode.

        Returns:
            dict with
            ``'reconstructed'``: decoder output for the sparse latents,
            ``'sparse_features'``: encoder output with non-top-k entries zeroed,
            ``'dense_features'``: the raw encoder output.
        """
        dense = self.encoder(x)
        kept_values, kept_indices = self._top_k(dense)
        # Start from zeros and write back only the surviving activations.
        sparse = torch.zeros_like(dense).scatter_(-1, kept_indices, kept_values)
        return {
            'reconstructed': self.decoder(sparse),
            'sparse_features': sparse,
            'dense_features': dense,
        }

    @torch.no_grad()
    def get_feature_activations(self, x):
        """Return ``(top_k_indices, dense_features)`` for ``x`` without tracking gradients."""
        dense = self.encoder(x)
        return self._top_k(dense).indices, dense


In [None]:
# Load test embeddings and SAE model
print("Loading test embeddings and trained SAE model...")

# Load test data
# NOTE(review): torch.load unpickles arbitrary Python objects; only point
# these paths at files you produced or otherwise trust.
test_pkg = torch.load("../data/processed/sae_sv_embeddings.pt", map_location="cpu")
test_emb = test_pkg["embeddings"].float().to(device)
test_sv_info = test_pkg["sv_info"]

# Load best SAE model
import pathlib
import json

# BEST_MODEL.json records, under "best_dir", the directory holding the checkpoint.
best_model_path = pathlib.Path("../data/models/BEST_MODEL.json")
if not best_model_path.exists():
    raise ValueError("No best model found")
with open(best_model_path) as f:
    best_meta = json.load(f)
model_dir = pathlib.Path(best_meta["best_dir"])
ckpt_path = model_dir / "sae.pt"

pkg = torch.load(ckpt_path, map_location="cpu")
cfg = pkg["config"]
input_dim = int(cfg["input_dim"])
feature_dim = int(cfg["feature_dim"])
k_active = int(cfg["k"])
print(f"SAE config: {input_dim} -> {feature_dim}, k={k_active}")

# Load and setup SAE
sae = BatchTopKSAE(input_dim, feature_dim, k_active).to(device)
# strict=False tolerates extra entries in the checkpoint, but it also hides
# *missing* ones.  A missing encoder/decoder tensor would leave that layer at
# its random initialisation and silently corrupt every extracted latent, so
# inspect the result instead of discarding it.
load_result = sae.load_state_dict(pkg["model_state_dict"], strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint {ckpt_path} is missing SAE parameters: {load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(f"Warning: ignoring unexpected checkpoint keys: {load_result.unexpected_keys}")
sae.eval()
print("SAE loaded")

In [None]:

def extract_sae_latents(embeddings, batch_size=128, desc="Processing"):
    """Build a boolean top-k activation mask for every embedding row.

    Uses the module-level ``sae`` model, ``feature_dim`` and ``device``.

    Args:
        embeddings: Tensor whose first dimension indexes samples; each batch
            is moved to ``device`` and fed to ``sae.get_feature_activations``.
            (Presumably shaped ``(n_samples, input_dim)`` - the mask is
            scattered along dim 1, so the model output must be 2-D.)
        batch_size: Number of rows encoded per step.
        desc: Label shown on the tqdm progress bar.

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(n_samples, feature_dim)`` that
        is True exactly where a latent was among the row's top-k.  Despite the
        ``dense`` in the variable name, it holds no activation magnitudes.
    """
    n_samples = embeddings.shape[0]
    all_dense_features = torch.zeros(n_samples, feature_dim, dtype=torch.bool)

    with torch.no_grad():
        for i in tqdm(range(0, n_samples, batch_size), desc=desc):
            end_idx = min(i + batch_size, n_samples)
            batch_emb = embeddings[i:end_idx].to(device)
            top_k_indices, _ = sae.get_feature_activations(batch_emb)

            # The mask lives on the CPU while the indices come back on
            # `device`; scatter_ needs both on the same device, so bring the
            # indices over before writing into this batch's slice (a view of
            # the output, so the write lands in all_dense_features directly).
            all_dense_features[i:end_idx].scatter_(1, top_k_indices.cpu(), True)

    return all_dense_features


print("Extracting SAE latents...")
train_dense = extract_sae_latents(train_embeddings, desc="Train set")
test_dense = extract_sae_latents(test_emb, desc="Test set")
print("Latents extracted")

In [None]:
# For training data, load from separate file
train_pkg = torch.load("../data/processed/layer26_sv_total_train_embeddings.pt", map_location="cpu")
train_embeddings = train_pkg['embeddings'].float()
train_sv_info = train_pkg['sv_info']

train_labels = torch.tensor([sv['truvari_class'] == 'tp_comp_vcf' for sv in train_sv_info])
test_labels = torch.tensor([sv['truvari_class'] == 'tp_comp_vcf' for sv in test_sv_info])

print(f"Train: {train_embeddings.shape[0]} samples ({train_labels.sum().item()} TP)")
print(f"Test: {test_emb.shape[0]} samples ({test_labels.sum().item()} TP)")


# %%
# Create datasets
# Convert every tensor to a uint8 NumPy array (True/False -> 1/0) in one pass.
train_dense_np, test_dense_np, train_labels_np, test_labels_np = (
    tensor.numpy().astype(np.uint8)
    for tensor in (train_dense, test_dense, train_labels, test_labels)
)

# Stack train rows first, then test rows, so train occupies the leading indices.
combined_sv_info = train_sv_info + test_sv_info
combined_labels = np.concatenate((train_labels_np, test_labels_np), axis=0)
combined_dense = np.concatenate((train_dense_np, test_dense_np), axis=0)

# %%
# Save everything
save_dir = "../data/sae_latents/"
os.makedirs(save_dir, exist_ok=True)

n_train = len(train_sv_info)
n_test = len(test_sv_info)
n_combined = len(combined_sv_info)

# Train data
train_payload = {
    'dense_features': train_dense_np,
    'labels': train_labels_np,
    'sv_info': train_sv_info,
    'n_samples': n_train,
    'n_features': feature_dim,
    'k_active': k_active,
}
torch.save(train_payload, f"{save_dir}sae_latents_train.pt")

# Test data
test_payload = {
    'dense_features': test_dense_np,
    'labels': test_labels_np,
    'sv_info': test_sv_info,
    'n_samples': n_test,
    'n_features': feature_dim,
    'k_active': k_active,
}
torch.save(test_payload, f"{save_dir}sae_latents_test.pt")

# Combined data: train rows come first, so the index lists split at n_train.
combined_payload = {
    'dense_features': combined_dense,
    'labels': combined_labels,
    'sv_info': combined_sv_info,
    'n_samples': n_combined,
    'n_features': feature_dim,
    'k_active': k_active,
    'train_indices': list(range(n_train)),
    'test_indices': list(range(n_train, n_combined)),
}
torch.save(combined_payload, f"{save_dir}sae_latents_combined.pt")

# Sklearn-ready data: plain NumPy matrices plus one name per latent column.
sklearn_payload = {
    'X_train': train_dense_np,
    'y_train': train_labels_np,
    'X_test': test_dense_np,
    'y_test': test_labels_np,
    'X_combined': combined_dense,
    'y_combined': combined_labels,
    'feature_names': [f'atom_{i}' for i in range(feature_dim)],
}
with open(f"{save_dir}sae_features_sklearn.pkl", 'wb') as f:
    pickle.dump(sklearn_payload, f)

print("All data saved!")
print(f"Files saved to: {save_dir}")
print(f"SAE dimensions: {input_dim} -> {feature_dim}, k={k_active}")