In [1]:
import os
# Debugging aid: make CUDA kernel launches synchronous so a device-side
# failure is reported at the call that triggered it. This is set before
# torch is imported further down in the notebook.
os.environ.update({'CUDA_LAUNCH_BLOCKING': "1"})

In [2]:
from torch.utils.data import random_split, DataLoader, Dataset
from IPython.display import clear_output, display

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import scanpy as sc
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pyensembl
import time

In [4]:
import src.modelnew as m
import src.data_cleaning as dc
import src.position_encoding as pos
import src.pathway_encoding as path
import src.training as train

In [5]:
# Gene annotation handle for Ensembl release 109 (the recorded run resolved
# this to the Homo_sapiens GRCh38 caches). It is later passed to the
# positional-feature helper to look up gene positions.
ensembl_data = pyensembl.EnsemblRelease(release=109)

# Make sure the annotation files are present locally, then build/load the
# lookup index on top of them. Both calls must precede any gene query.
ensembl_data.download()
ensembl_data.index()

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle


In [6]:
# Training AnnData object and the ordered list of genes to predict
# (first column of a header-less CSV).
adata = sc.read_h5ad('vcc_data/adata_Training.h5ad')
gene_names = pd.read_csv('vcc_data/gene_names.csv', header=None).iloc[:, 0].tolist()

# Single source of truth for the hyperparameters handed to the feature
# builders, the model, the dataset and the training loop.
CONFIG = {
    # --- Problem size ---
    "n_genes": len(gene_names),  # number of genes whose expression is predicted
    "n_perturbations": len(adata.obs['target_gene'].unique()),  # distinct target_gene values
    # Chromosome vocabulary size (original note: "23+X+Y").
    # NOTE(review): confirm which labels the 25 slots correspond to.
    "n_chromosomes": 25,

    # --- Feature / embedding widths ---
    "perturbation_dim": 256,       # condition (perturbation) embedding
    "chrom_embedding_dim": 16,     # chromosome identity, learned in-model
    "locus_fourier_features": 8,   # F Fourier frequency pairs -> 2*F features
    "pathway_dim": 50,             # from a pre-trained autoencoder (hallmark MSigDB)
    "gene_identity_dim": 189,      # main learnable gene embedding

    # --- Backbone ---
    "d_model": 512,                # Mamba hidden size
    "mamba_layers": 4,
    "n_heads": 8,
    "n_layers": 4,

    # --- Output head: "linear" | "probabilistic" ---
    "prediction_head": "linear",

    # --- Optimisation ---
    "batch_size": 16,
    "learning_rate": 5e-5,         # lowered for AdamW stability
    "epochs": 100,
}

# Derived widths.
# Positional width = chromosome embedding + Fourier-feature MLP output. It is
# computed as twice chrom_embedding_dim, i.e. it assumes the MLP output has the
# same width as the chromosome embedding -- TODO confirm against the model.
POS_DIM = 2 * CONFIG["chrom_embedding_dim"]
# Per-gene feature width = gene identity + pathway + positional parts.
GENE_FEAT_DIM = sum(CONFIG[key] for key in ("gene_identity_dim", "pathway_dim")) + POS_DIM

In [None]:

# Give every distinct target gene an integer id, numbered in the order that
# pandas' unique() returns them, and store that id on each cell.
unique_perturbations = adata.obs['target_gene'].unique().tolist()
perturbation_to_idx_map = dict(
    zip(unique_perturbations, range(len(unique_perturbations)))
)
adata.obs['perturbation_idx'] = adata.obs['target_gene'].map(perturbation_to_idx_map)

# Control subset of the data, selected by the project helper.
control_adata = dc.get_control_data(adata)

# Per-gene positional features looked up through pyensembl. In the recorded run
# the shapes were chr_idx [18080], locus_norm [18080, 1], locus_fourier [18080, 16].
chr_idx, locus_norm, locus_fourier = pos.precompute_positional_indices(
    ensembl_data, gene_names, CONFIG
)


--- Preparing positional indices using pyensembl ---


Fetching gene positions: 100%|██████████| 18080/18080 [03:32<00:00, 84.95it/s] 


SUCCESS: Generated positional tensors with shapes:
Chromosome Indices (chr_idx): torch.Size([18080])
Normalized Locus (locus_norm): torch.Size([18080, 1])
Locus Fourier Features (locus_fourier): torch.Size([18080, 16])





In [8]:
    # Pre-compute fixed pathway features (run once, then save/load).
# Disabled preprocessing step, kept for reference:
#filtered_data = dc.clean_and_preprocess_data(adata)
    # Find a dataset that can represent a normal cell set; the control subset
    # (control_adata) is what gets passed in below.
# In the recorded run this call trained an autoencoder for 50 epochs and
# reported a pathway_features tensor of shape [18080, 50].
# NOTE(review): the width presumably follows CONFIG["pathway_dim"] -- confirm
# in src/pathway_encoding.py.
pathway_feats = path.precompute_pathway_features(control_adata, CONFIG)

--- Precomputing pathway features on control data ---
AE epoch 10/50 | recon MSE: 25.4281
AE epoch 20/50 | recon MSE: 10.0758
AE epoch 30/50 | recon MSE: 7.7874
AE epoch 40/50 | recon MSE: 7.2923
AE epoch 50/50 | recon MSE: 7.1357
Generated pathway_features shape: torch.Size([18080, 50])


In [9]:
    # Instantiate Model, Dataset, Dataloader
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# cell does not fail at `.to(device)` on a machine without a usable GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

# Build the predictor from the config and the precomputed per-gene features
# (pathway features, chromosome indices, locus Fourier features), then move
# it to the selected device. `.to()` returns the module itself.
model = m.TranscriptomePredictor(
    CONFIG, GENE_FEAT_DIM, pathway_feats, chr_idx, locus_fourier
).to(device)

# CONFIG is passed to the dataset constructor; the recorded run reports that
# it uses log-normalized data from adata.X for the linear head.
dataset = m.PerturbationDataset(adata, CONFIG)

# The whole dataset is used for training, reshuffled every epoch
# (the train/validation split below is currently disabled).
train_loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True)


#train_size = int(0.8 * len(dataset))
#val_size = len(dataset) - train_size

# 2. Perform the split
#train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# 3. Create separate DataLoaders for each set
#train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

#print(f"Data split into {len(train_dataset)} training samples and {len(val_dataset)} validation samples.")


Using device: cuda
Dataset using log-normalized data from adata.X for linear head.


In [10]:
from tqdm.notebook import tqdm # For a progress bar

# Sanity check: look for NaNs in the "target_expression" tensor of every
# sample before training, stopping at the first one found.
print("\n--- Checking every dataset entry for NaNs (Slow Method) ---")

# Index of the first sample containing a NaN; -1 means "none found".
nan_found_at_index = -1
total_samples = len(dataset)

# Samples are fetched one at a time through dataset[i], hence "slow";
# tqdm provides the progress bar.
for i in tqdm(range(total_samples), desc="Scanning dataset"):
    sample = dataset[i]
    target_expr = sample["target_expression"]

    # Check if this tensor has any NaN values
    if torch.isnan(target_expr).any():
        nan_found_at_index = i
        break  # Stop as soon as we find one

# --- Report the result ---
if nan_found_at_index != -1:
    print("\n!!!!! CRITICAL: NaN DETECTED !!!!!")
    print(f"A NaN value was found in the 'target_expression' for sample index: {nan_found_at_index}")
    print("This is the likely cause of your CUDA error.")
else:
    # The f-prefix was missing, so the message printed the literal text
    # "{total_samples}" instead of the number of samples scanned.
    print(f"\nOK: Scanned all {total_samples} samples. The dataset is 100% clean of NaNs.")


--- Checking every dataset entry for NaNs (Slow Method) ---


Scanning dataset:   0%|          | 0/221273 [00:00<?, ?it/s]


OK: Scanned all {total_samples} samples. The dataset is 100% clean of NaNs.


In [11]:
# The triple-quoted string below is a disabled earlier version of the setup
# (device selection, model, dataset, loader). As a bare string expression it
# is never executed.
'''device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")
model = TranscriptomePredictor(CONFIG, GENE_FEAT_DIM, pathway_feats, chr_idx, locus_fourier).to(device)

dataset = m.PerturbationDataset(adata, CONFIG) # Pass CONFIG to the dataset constructor
dataloader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True)'''

    # Train
# Runs the project training helper on the model and loader built in the
# earlier cells and keeps its return value as loss_history.
# NOTE(review): the recorded run of this cell stopped with
# "NameError: name 'clear_output' is not defined". This notebook does import
# clear_output from IPython.display, so the name is presumably missing inside
# src/training.py -- confirm and add the import there.
loss_history = train.train_model(model, train_loader, CONFIG,device)


--- Starting Model Training ---


                                                                                   

NameError: name 'clear_output' is not defined