In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
import pyensembl
from IPython.display import clear_output, display
import torch


In [None]:
import get_control_file as gcf
import position_encoding as pos
import pathway_encoding as path
from model import TranscriptomePredictor, PerturbationDataset

In [3]:
# Load the training AnnData object from disk. The path is relative, so it
# resolves against the notebook's current working directory.
adata = sc.read_h5ad('../vcc_data/adata_Training.h5ad')

In [9]:
# Ensembl release 109 handle; it is later handed to the positional-index
# precompute step so gene positions can be looked up.
ensembl_data = pyensembl.EnsemblRelease(109)
# Fetch the release's data files, then build pyensembl's local index.
# NOTE(review): per the log output these resolve to a per-user cache
# directory (GRCh38/ensembl109); presumably both calls are cheap once the
# cache is populated -- confirm before relying on it in a Run-All.
ensembl_data.download()
ensembl_data.index()

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /Users/steveyin/Library/Caches/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /Users/steveyin/Library/Caches/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /Users/steveyin/Library/Caches/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle


In [6]:
# Gene vocabulary: first column of the header-less gene_names.csv, as a list.
gene_names = pd.read_csv('../vcc_data/gene_names.csv', header=None).iloc[:, 0].tolist()

# Single source of hyperparameters. The same dict is passed to the positional
# and pathway precompute helpers, the dataset and the model.
CONFIG = {
    # --- Vocabulary sizes ---
    "n_genes": len(gene_names),
    # Number of distinct values in adata.obs['target_gene']; a missing value,
    # if present, is counted as its own class (dropna=False).
    "n_perturbations": adata.obs['target_gene'].nunique(dropna=False),
    # NOTE(review): presumably 22 autosomes + X + Y + MT -- TODO confirm.
    "n_chromosomes": 25,

    # --- Embedding / feature widths ---
    # Condition embedding
    "perturbation_dim": 256,
    # Learned in-model for chromosome identity
    "chrom_embedding_dim": 16,
    # Number of Fourier frequency pairs (2*F)
    "locus_fourier_features": 8,
    # From pre-trained Autoencoder(based on hallmark MSigDB)
    "pathway_dim": 50,
    # Main learnable gene embedding
    "gene_identity_dim": 189,

    # --- Backbone dims ---
    # Mamba hidden size
    "d_model": 512,
    "mamba_layers": 4,

    # --- Head ---
    # "linear" | "probabilistic"
    "prediction_head": "probabilistic",

    # --- Training ---
    "batch_size": 64,
    # Lowered LR for AdamW stability
    "learning_rate": 5e-5,
    "epochs": 100,
}

# Derived dimensions for clarity.
# Positional encoding width is counted as chrom_embedding_dim twice: once for
# the chromosome embedding and once for the second positional term, which the
# original note describes as the MLP output from the Fourier features.
# NOTE(review): the MLP's actual output width is not visible here -- confirm
# it equals chrom_embedding_dim in the model code.
POS_DIM = 2 * CONFIG["chrom_embedding_dim"]
# Per-gene feature width = gene identity + pathway + positional.
GENE_FEAT_DIM = CONFIG["gene_identity_dim"] + CONFIG["pathway_dim"] + POS_DIM

In [4]:
# Give every distinct target gene an integer id (0..K-1, following the order
# returned by .unique()) and store it per cell in adata.obs['perturbation_idx'].
unique_perturbations = adata.obs['target_gene'].unique().tolist()
perturbation_to_idx_map = dict(
    zip(unique_perturbations, range(len(unique_perturbations)))
)
adata.obs['perturbation_idx'] = adata.obs['target_gene'].map(perturbation_to_idx_map)
# Control subset of the data, as selected by the project helper.
# NOTE(review): selection criteria live in get_control_file -- not visible here.
control_adata = gcf.get_control_data(adata)

In [31]:
# Precompute per-gene positional inputs from the Ensembl annotation.
# Per the shapes printed for the 18080 genes in this run:
#   chr_idx       -> (n_genes,)
#   locus_norm    -> (n_genes, 1)
#   locus_fourier -> (n_genes, 16)
# Only chr_idx and locus_fourier are passed to the model in the cells shown.
chr_idx, locus_norm, locus_fourier = pos.precompute_positional_indices(ensembl_data, gene_names, CONFIG)



--- Preparing positional indices using pyensembl ---


Fetching gene positions: 100%|██████████| 18080/18080 [00:00<00:00, 60112.75it/s]


SUCCESS: Generated positional tensors with shapes:
Chromosome Indices (chr_idx): torch.Size([18080])
Normalized Locus (locus_norm): torch.Size([18080, 1])
Locus Fourier Features (locus_fourier): torch.Size([18080, 16])





In [7]:
# Per-gene pathway features computed from the control data only.
# Per the printed log, this trains an autoencoder for 100 epochs and returns
# a tensor of shape (n_genes, 50).
# NOTE(review): no random seed is set in the visible cells, so the result is
# presumably not reproducible run-to-run -- confirm inside pathway_encoding.
pathway_feats = path.precompute_pathway_features(control_adata, CONFIG)

--- Precomputing pathway features on control data ---
AE epoch 10/100 | recon MSE: 78.6051
AE epoch 20/100 | recon MSE: 50.8467
AE epoch 30/100 | recon MSE: 25.8626
AE epoch 40/100 | recon MSE: 13.6297
AE epoch 50/100 | recon MSE: 9.4090
AE epoch 60/100 | recon MSE: 7.9185
AE epoch 70/100 | recon MSE: 7.3963
AE epoch 80/100 | recon MSE: 7.2014
AE epoch 90/100 | recon MSE: 7.1255
AE epoch 100/100 | recon MSE: 7.0952
Generated pathway_features shape: torch.Size([18080, 50])


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the predictor from the config, the derived per-gene feature width and
# the precomputed per-gene tensors, then move it onto `device`.
model = TranscriptomePredictor(CONFIG, GENE_FEAT_DIM, pathway_feats, chr_idx, locus_fourier).to(device)

dataset = PerturbationDataset(adata, CONFIG)  # Pass CONFIG to the dataset constructor
# DataLoader is referenced through the already-imported `torch` package: the
# bare name `DataLoader` is never imported in this notebook, so the previous
# form raised NameError on a fresh kernel.
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,  # reshuffle the samples every epoch
)


array([[-4.0940959e-02, -1.7496973e-02,  3.9910398e-02, ...,
         4.0042114e-02, -1.0482197e-02,  2.9803514e-03],
       [ 4.3170462e+00,  5.3395712e-01,  3.0772862e+00, ...,
         2.6977115e+00, -3.4365354e+00, -3.8631627e-01],
       [ 7.8416473e-01,  1.3108829e-01,  5.3170836e-01, ...,
         6.6698223e-01, -5.0516069e-01, -7.5793922e-02],
       ...,
       [ 2.3736744e-01,  9.1434773e-03,  2.5253361e-01, ...,
         2.5257570e-01, -1.8625672e-01, -7.2753406e-03],
       [ 8.8144159e+00,  8.0110180e-01,  6.5307493e+00, ...,
         6.2219801e+00, -6.9380894e+00, -1.0412694e+00],
       [ 5.4845099e+00,  1.9520372e-01,  4.7072082e+00, ...,
         4.1869702e+00, -4.9142218e+00, -1.0798045e+00]], dtype=float32)