In [1]:
import os
# Make CUDA kernel launches synchronous so a device-side error is reported at the
# Python line that triggered it (useful while debugging). This has to be set before
# torch initialises CUDA, and it slows GPU work considerably, so turn the flag off
# for full training runs.
DEBUG_CUDA_SYNC = True
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
from torch.utils.data import random_split, DataLoader, Dataset
from IPython.display import clear_output, display

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import scanpy as sc
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pyensembl
import time

In [4]:
import src.updatedmodel as m
import src.data_cleaning as dc
import src.position_encoding as pos
import src.pathway_encoding as path
import src.updatetraining as train

In [5]:
# Ensembl gene annotations (GRCh38, per the cache paths in the log below).
# download() and index() fill the local pyensembl cache used by the positional encoding.
ENSEMBL_RELEASE = 109
ensembl_data = pyensembl.EnsemblRelease(ENSEMBL_RELEASE)
ensembl_data.download()
ensembl_data.index()

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle


In [6]:
adata = sc.read_h5ad('vcc_data/adata_Training.h5ad')
# One gene symbol per row, no header; this ordering drives the positional features below.
gene_names = pd.read_csv('vcc_data/gene_names.csv', header=None).iloc[:, 0].tolist()


CONFIG = {
    # --- Data-derived sizes ---
    "n_genes": len(gene_names),  # number of genes whose expression the model predicts
    # Number of distinct values in obs['target_gene'] (this includes the control label).
    # NOTE(review): the 'perturbation_idx' column built in the next cell lives in
    # gene-index space [0, n_genes] (n_genes = control), NOT [0, n_perturbations).
    # If the model sizes a perturbation nn.Embedding with this value, those indices
    # are out of range (a device-side assert on CUDA) — confirm in src.updatedmodel.
    "n_perturbations": adata.obs['target_gene'].nunique(dropna=False),
    # Presumably 22 autosomes + X + Y + MT = 25 (humans have 22 autosomes, not 23)
    # — confirm against src.position_encoding.
    "n_chromosomes": 25,

    "perturbation_dim": 256,      # Condition embedding
    "chrom_embedding_dim": 16,     # Learned in-model for chromosome identity
    "locus_fourier_features": 8,   # F frequencies -> 2*F sin/cos features per gene (locus_fourier is (n_genes, 2*F))
    "pathway_dim": 50,             # From pre-trained Autoencoder (based on hallmark MSigDB)
    "gene_identity_dim": 189,       # Main learnable gene embedding

    # Backbone dims
    "d_model": 512,                # Mamba hidden size
    "mamba_layers": 4,
    "n_heads": 8,
    "n_layers": 4,

    # Head
    # NOTE(review): PerturbationDataset printed "... for linear head" even with this set
    # to "probabilistic" — check which key/value the dataset actually reads.
    "prediction_head": "probabilistic", # "linear" | "probabilistic"

    # Training
    "batch_size": 16,
    "learning_rate": 5e-5, # Lowered LR for AdamW stability
    "epochs": 100,
}

# Derived dimensions for clarity.
# Positional encoding dim = chromosome embedding + output of the MLP applied to the
# Fourier locus features.
# NOTE(review): the second term reuses chrom_embedding_dim, i.e. it assumes the Fourier
# MLP projects to chrom_embedding_dim — confirm in src.updatedmodel.
POS_DIM = CONFIG["chrom_embedding_dim"] + CONFIG["chrom_embedding_dim"]
GENE_FEAT_DIM = CONFIG["gene_identity_dim"] + CONFIG["pathway_dim"] + POS_DIM

In [7]:

# Map each cell's target gene to its column index in adata.var. Cells whose target is
# not a gene in adata.var (e.g. the 'non-targeting' controls) get the sentinel index
# CONFIG['n_genes'], so perturbation_idx spans [0, n_genes].
# NOTE(review): assumes adata.var order matches gene_names (which orders the positional
# features computed below) — verify.
CONTROL_LABEL = 'non-targeting'
unique_genes = adata.var.index.tolist()
# The control sentinel must not coincide with a real gene index.
assert len(unique_genes) <= CONFIG['n_genes'], (
    f"adata has {len(unique_genes)} genes but CONFIG['n_genes'] is {CONFIG['n_genes']}; "
    "the control sentinel index would collide with a real gene"
)
perturbation_to_idx_map = {name: i for i, name in enumerate(unique_genes)}
adata.obs['perturbation_idx'] = adata.obs['target_gene'].map(perturbation_to_idx_map)

# Any target other than the control label that fails to map would be silently merged
# with the controls by the fillna below, so surface it instead.
unmapped_targets = adata.obs.loc[adata.obs['perturbation_idx'].isna(), 'target_gene']
unexpected_targets = sorted({str(g) for g in unmapped_targets} - {CONTROL_LABEL})
if unexpected_targets:
    print(f"WARNING: {len(unexpected_targets)} target gene(s) not in adata.var are mapped "
          f"to the control index {CONFIG['n_genes']}: {unexpected_targets[:10]}")

adata.obs['perturbation_idx'] = adata.obs['perturbation_idx'].fillna(CONFIG['n_genes']).astype(int)
control_adata = dc.get_control_data(adata)
chr_idx, locus_norm, locus_fourier = pos.precompute_positional_indices(ensembl_data, gene_names, CONFIG)


--- Preparing positional indices using pyensembl ---


Fetching gene positions: 100%|██████████| 18080/18080 [02:18<00:00, 130.25it/s]


SUCCESS: Generated positional tensors with shapes:
Chromosome Indices (chr_idx): torch.Size([18080])
Normalized Locus (locus_norm): torch.Size([18080, 1])
Locus Fourier Features (locus_fourier): torch.Size([18080, 16])





In [8]:
# Per-cell metadata: target_gene, guide_id, batch and the derived perturbation_idx.
display(adata.obs)

Unnamed: 0,target_gene,guide_id,batch,perturbation_idx
AAACAAGCAACCTTGTACTTTAGG-Flex_1_01,CHMP3,CHMP3_P1P2_A|CHMP3_P1P2_B,Flex_1_01,2305
AAACAAGCATTGCCGCACTTTAGG-Flex_1_01,AKT2,AKT2_P1P2_A|AKT2_P1P2_B,Flex_1_01,15684
AAACCAATCAATGTTCACTTTAGG-Flex_1_01,SHPRH,SHPRH_P1P2_A|SHPRH_P1P2_B,Flex_1_01,6365
AAACCAATCCCTCGCTACTTTAGG-Flex_1_01,TMSB4X,TMSB4X_P1_A|TMSB4X_P1_B,Flex_1_01,17382
AAACCAATCTAAATCCACTTTAGG-Flex_1_01,KLF10,KLF10_P2_A|KLF10_P2_B,Flex_1_01,7714
...,...,...,...,...
TTTGGACGTGGTGCAGATTCGGTT-Flex_3_16,non-targeting,non-targeting_00035|non-targeting_03439,Flex_3_16,18080
TTTGTGAGTAGTAGCAATTCGGTT-Flex_3_16,KDM1A,KDM1A_P1P2_A|KDM1A_P1P2_B,Flex_3_16,264
TTTGTGAGTCCATCCTATTCGGTT-Flex_3_16,non-targeting,non-targeting_00020|non-targeting_01323,Flex_3_16,18080
TTTGTGAGTCCTGACAATTCGGTT-Flex_3_16,BIRC2,BIRC2_P1P2_A|BIRC2_P1P2_B,Flex_3_16,10238


In [9]:
# Gene-space perturbation index per cell; controls carry the sentinel CONFIG['n_genes'].
adata.obs.perturbation_idx

AAACAAGCAACCTTGTACTTTAGG-Flex_1_01     2305
AAACAAGCATTGCCGCACTTTAGG-Flex_1_01    15684
AAACCAATCAATGTTCACTTTAGG-Flex_1_01     6365
AAACCAATCCCTCGCTACTTTAGG-Flex_1_01    17382
AAACCAATCTAAATCCACTTTAGG-Flex_1_01     7714
                                      ...  
TTTGGACGTGGTGCAGATTCGGTT-Flex_3_16    18080
TTTGTGAGTAGTAGCAATTCGGTT-Flex_3_16      264
TTTGTGAGTCCATCCTATTCGGTT-Flex_3_16    18080
TTTGTGAGTCCTGACAATTCGGTT-Flex_3_16    10238
TTTGTGAGTGGACACGATTCGGTT-Flex_3_16    17065
Name: perturbation_idx, Length: 221273, dtype: int64

In [10]:
    # Pre-compute Fixed Features (run once, then save/load)
# Pathway features are fixed model inputs: they were computed once, saved to disk,
# and are reloaded here. To regenerate the cache, uncomment the precompute call below
# and np.save its result to 'pathway_matrix.npy' (assumes it returns an array — verify).
# filtered_data = dc.clean_and_preprocess_data(adata)
# TODO: find a dataset that can represent a normal cell state (the control cells are
# the current stand-in).
# pathway_feats = path.precompute_pathway_features(control_adata, CONFIG)
pathway_feats = np.load(file='pathway_matrix.npy')

In [11]:
    # Instantiate Model, Dataset, Dataloader
# Seed torch so weight initialisation and DataLoader shuffling are reproducible
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on a machine without a GPU
# (a hard-coded "cuda" device fails there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")
model = m.TranscriptomePredictor(CONFIG, GENE_FEAT_DIM, pathway_feats, chr_idx, locus_fourier)
model.to(device)

dataset = m.PerturbationDataset(adata, CONFIG)  # Pass CONFIG to the dataset constructor
train_loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True)

# TODO: hold out a validation set. Sketch (not yet enabled):
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
# print(f"Data split into {len(train_dataset)} training samples and {len(val_dataset)} validation samples.")


Using device: cuda
Dataset using log-normalized data from adata.X for linear head.


In [12]:
from tqdm.notebook import tqdm # For a progress bar

print("\n--- Checking every dataset entry for NaNs (Slow Method) ---")

# Walk the dataset in order and stop at the first sample whose target tensor
# contains a NaN; -1 means none was found.
total_samples = len(dataset)
nan_found_at_index = -1

for i in tqdm(range(total_samples), desc="Scanning dataset"):
    if torch.isnan(dataset[i]["target_expression"]).any():
        nan_found_at_index = i
        break

# --- Report the result ---
if nan_found_at_index != -1:
    print("\n!!!!! CRITICAL: NaN DETECTED !!!!!")
    print(f"A NaN value was found in the 'target_expression' for sample index: {nan_found_at_index}")
    print("This is the likely cause of your CUDA error.")
else:
    # f-prefix was missing here, so the placeholder used to be printed literally.
    print(f"\nOK: Scanned all {total_samples} samples. The dataset is 100% clean of NaNs.")


--- Checking every dataset entry for NaNs (Slow Method) ---


Scanning dataset:   0%|          | 0/221273 [00:00<?, ?it/s]


OK: Scanned all {total_samples} samples. The dataset is 100% clean of NaNs.


In [13]:
# Train the model on the full dataset.
# The unused triple-quoted copy of the setup code that used to sit here was removed:
# it duplicated the previous cell and called TranscriptomePredictor without the `m.` prefix.
# NOTE: CUDA_LAUNCH_BLOCKING=1 (set in the first cell) serialises kernel launches and
# slows training considerably; switch it off for full multi-epoch runs.
loss_history = train.train_model(model, train_loader, CONFIG, device)


--- Starting Model Training ---


                                                                               

KeyboardInterrupt: 