In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import scanpy as sc

In [None]:
if __name__ == "__main__":
    # End-to-end smoke run of the pipeline on synthetic data:
    # fabricate counts -> precompute fixed features -> build model and
    # loader -> train. Swap the synthetic AnnData objects for real data.

    # --- 1. Synthetic perturbation data --------------------------------
    # Integer Poisson draws are used so the input resembles raw counts
    # (the original note says the probabilistic head wants raw-ish counts).
    n_genes = CONFIG["n_genes"]
    dummy_counts = np.random.poisson(lam=2.0, size=(500, n_genes))
    adata = sc.AnnData(dummy_counts)

    # One random perturbation label per observation, drawn from
    # [0, n_perturbations).
    perturbation_labels = np.random.randint(
        low=0,
        high=CONFIG["n_perturbations"],
        size=adata.n_obs,
    )
    adata.obs["perturbation_idx"] = perturbation_labels

    # QC metrics are written onto `adata` itself (inplace=True).
    sc.pp.calculate_qc_metrics(adata, inplace=True)

    # --- 2. Fixed features (compute once, then save/load) ---------------
    # A separate synthetic control population (1000 observations, lower
    # Poisson rate) feeds the pathway-feature precomputation.
    control_counts = np.random.poisson(lam=1.5, size=(1000, n_genes))
    control_adata = sc.AnnData(control_counts)
    pathway_feats = precompute_pathway_features(control_adata, CONFIG)

    # NOTE(review): None is passed as the first argument here; what the
    # helper does without it is not visible from this cell -- confirm.
    positional = precompute_positional_indices(None, CONFIG)
    chr_idx, locus_norm = positional

    # --- 3. Model, dataset, dataloader ----------------------------------
    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"\nUsing device: {device}")

    model = TranscriptomePredictor(CONFIG, pathway_feats, chr_idx, locus_norm)
    model = model.to(device)

    dataset = PerturbationDataset(adata)
    dataloader = DataLoader(
        dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
    )

    # --- 4. Train --------------------------------------------------------
    train_model(model, dataloader, CONFIG)
