### SSVAE

#### Setup

In [None]:
# %pip install torch pandas numpy matplotlib scanpy scikit-learn

In [None]:
# IPython autoreload: mode 2 re-imports modules before each cell runs, so edits
# to the project's local modules (utils, data_utils, models, training) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import scanpy as sc
import matplotlib.pyplot as plt
%matplotlib inline

plt.rcParams["figure.figsize"] = (7.0, 5.5)
plt.rcParams["font.size"] = 16
plt.rcParams["image.interpolation"] = "nearest"
plt.rcParams["image.cmap"] = "gray"

DATA_PATH = "data/GSE161819_hc_t0_rna_adt_preprocessed.h5ad"
RANDOM_STATE=15179215

EPOCHS=100                      # from 50 (100?)
SUBSET_SIZE=10000               # from 15000
DATALOADER_BATCH_SIZE = 128     # from 256
LEARNING_RATE = 5e-5            # from 2e-4
WEIGHT_DECAY = 1e-3             # from 5e-3
LABEL_FRACTION = 0.5            # from 0.5


CHECKPOINT_PATH = "data/saved_models/ssvae_checkpoint.pt"
MODEL_EXISTS = os.path.exists(CHECKPOINT_PATH)

if MODEL_EXISTS:
    print(f"  Found existing model at {CHECKPOINT_PATH}")
    print("  Will load model and proceed to testing.")
else:
    print(f"  No existing model found at {CHECKPOINT_PATH}")
    print("  Will load data and train a new model.")



In [None]:
# Pick the compute device for training (CUDA GPU, Apple-silicon MPS, or CPU),
# record the floating-point precision associated with it, and seed torch.
import torch

if torch.cuda.is_available():
    # CUDA-enabled NVIDIA GPU (local machine or Google Colab).
    DEVICE = "cuda"
    DTYPE = torch.float64
elif torch.backends.mps.is_available():
    # Mac with an M1-or-newer processor on macOS 13+.
    # NOTE(review): float32 here presumably because MPS lacks float64 support.
    DEVICE = "mps"
    DTYPE = torch.float32
else:
    # Fall back to the system CPU.
    DEVICE = "cpu"
    DTYPE = torch.float64

# RANDOM_STATE is otherwise only handed to the dataset, so weight
# initialisation and the shuffled training DataLoader would differ on every
# run. Seeding torch's global generator makes training runs reproducible.
torch.manual_seed(RANDOM_STATE)

print(f"Using {DEVICE} device with {DTYPE} precision.")

#### Part 1: Load data
Load the dataset and split it into train / validation / test subsets.

In [None]:
# Reuse a previously trained model when a checkpoint is present; otherwise the
# following cells build the dataset and train from scratch.
if not MODEL_EXISTS:
    print("  No existing model - will train from scratch")
else:
    from utils import load_ssvae_conditioned_model

    model = load_ssvae_conditioned_model(CHECKPOINT_PATH, DEVICE)

In [None]:
# Release memory held by torch's MPS cache before loading data or training;
# nothing to do on CUDA/CPU.
on_mps = DEVICE == "mps"
if on_mps:
    torch.mps.empty_cache()
    print("Cleared MPS cache before loading/training")

In [None]:
if not MODEL_EXISTS:
    from data_utils.dataset import SingleCellDataset
    from torch.utils.data import DataLoader

    # Build the RNA + ADT dataset (with covariates) and split it.
    # NOTE(review): subset_size presumably limits how many cells are used — confirm
    # against SingleCellDataset.stratified_split.
    dataset = SingleCellDataset(datapath=DATA_PATH, use_covariates=True, random_state=RANDOM_STATE)

    train_dataset, val_dataset, test_dataset = dataset.stratified_split(subset_size=SUBSET_SIZE)

    # Only the training loader is shuffled; val/test keep a fixed order.
    train_loader = DataLoader(train_dataset, batch_size=DATALOADER_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=DATALOADER_BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=DATALOADER_BATCH_SIZE, shuffle=False)

    print("\nDataset splits:")
    print(f"  Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)} cells")

    print(f"\nBatch size {DATALOADER_BATCH_SIZE}")

    # Sanity check: fetch ONE training batch and move it to the device.
    # Iterating the whole loader here would cost a full, wasted pass over the
    # training data (and advance the shuffle RNG) just to keep the last batch.
    batch = next(iter(train_loader), None)
    if batch is None:
        raise ValueError("Training split is empty - check SUBSET_SIZE / stratified_split")
    rna_batch = batch['rna'].to(DEVICE)
    adt_batch = batch['adt'].to(DEVICE)
    label_batch = batch['label'].to(DEVICE)
    batch_batch = batch['batch'].to(DEVICE)
else:
    print("\nModel was loaded from disk")
    

### Train

In [None]:
if MODEL_EXISTS:
    print("Model loaded from disk")
else:
    from models.ssvae_conditioned import SSVAE_Conditioned
    from training.ssvae_conditioned_trainer import train_ssvae_conditioned

    # Release cached MPS memory before the model is allocated.
    if DEVICE == "mps":
        torch.mps.empty_cache()
        print("Cleared MPS cache")

    # Model sizes are read from the full dataset built in the data cell.
    rna_dim, adt_dim = dataset.rna_matrix.shape[1], dataset.adt_matrix.shape[1]
    n_labels, n_batches = dataset.num_classes, dataset.num_batches

    # Number of age-group categories, computed as max index + 1
    # (assumes 0-based contiguous integer codes — TODO confirm).
    n_age_groups = dataset.covariates_cat["age_group_idx"].max().item() + 1

    model = SSVAE_Conditioned(
        rna_dim=rna_dim,
        adt_dim=adt_dim,
        num_batches=n_batches,
        n_classes=n_labels,
        num_covariates_cat={"age_group_idx": n_age_groups},
        num_covariates_cont=["inflammation_score", "immunosenescence_score"],
    )

    trainer_kwargs = {
        "num_epochs": EPOCHS,
        "device": DEVICE,
        "lr": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "label_fraction": LABEL_FRACTION,
        # NOTE(review): the trainer is given "temp.pt", not CHECKPOINT_PATH, so
        # the MODEL_EXISTS check in the setup cell (presumably) never finds the
        # model trained here and every rerun retrains. Switch to CHECKPOINT_PATH
        # once tuning is done.
        "checkpoint_path": "temp.pt",
    }
    history = train_ssvae_conditioned(
        model, train_loader=train_loader, val_loader=val_loader, **trainer_kwargs
    )

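**Notes / TODO:**

- Recommended final weight: $\beta_{\text{final}} \in \{0.5, 1.0, 2.0\}$ (β presumably weights the KL term).
- Warm-up schedule: `beta = 0.1 + (epoch > warmup_epochs) * (beta_final - 0.1)`. Because `(epoch > warmup_epochs)` is either 0 or 1, this gives $\beta = 0.1$ while `epoch <= warmup_epochs` and $\beta = \beta_{\text{final}}$ afterwards. It is a single step, not a linear ramp.
- Add a clamp on the prior log-variance: `prior_logvar = torch.clamp(prior_logvar_raw, min=-3.0, max=1.5)`.
- Observation: the KL term is much larger than the ADT loss ($\text{KL} \gg \mathcal{L}_{\text{ADT}}$).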








## Diagnostic Visualizations

Monitor training dynamics, posterior collapse, latent space structure, and reconstruction quality.

In [None]:
from utils import plot_ssvae_diagnostics

# Diagnostics need the training history, which only exists when the model was
# trained in this session (not when it was loaded from a checkpoint).
if MODEL_EXISTS:
    print("Skipping diagnostic plots - model loaded from checkpoint")
elif 'history' not in locals():
    print("Training data not available - cannot generate plots")
else:
    # Hyperparameters shown alongside the diagnostic plots.
    hyperparams = dict([
        ('Epochs', EPOCHS),
        ('Learning Rate', LEARNING_RATE),
        ('Weight Decay', WEIGHT_DECAY),
        ('Batch Size', DATALOADER_BATCH_SIZE),
        ('Subset Size', SUBSET_SIZE),
    ])

    plot_ssvae_diagnostics(history, hyperparams, model, val_loader, DEVICE)

## LSTM