In [None]:
import sys
import os
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plta
import seaborn as sns
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Add project root to Python path
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("=== Setup Information ===")
print(f"Current directory: {current_dir}")
print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


=== Setup Information ===
Current directory: /projects/vanaja_lab/satya/DeepOMAPNet/Tutorials
Project root: /projects/vanaja_lab/satya/DeepOMAPNet
Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0]
PyTorch version: 2.9.0+cu128
CUDA available: True
CUDA device: NVIDIA H200
CUDA memory: 150.1 GB


In [2]:
import sys, os, importlib

# --- Paths ---
current_dir = os.getcwd()  # This will be .../DeepOMAPNet/Notebooks
project_root = os.path.dirname(current_dir)  # This will be .../DeepOMAPNet

# Add project root to Python path
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("Added to Python path:")
print(f"- Current directory: {current_dir}")
print(f"- Project root: {project_root}")
print(f"- Scripts directory exists: {os.path.exists(os.path.join(project_root, 'scripts'))}")
print(f"- Scripts/data_provider exists: {os.path.exists(os.path.join(project_root, 'scripts', 'data_provider'))}")

# Clear any cached imports
importlib.invalidate_caches()

# --- Import modules (module-style, not from ... import ...) ---
import scripts.data_provider.data_preprocessing as data_preprocessing
import scripts.data_provider.graph_data_builder as graph_data_builder
import scripts.model.doNET as doNET
import scripts.trainer.gat_trainer as gat_trainer

print("Module imports successful!")

Added to Python path:
- Current directory: /projects/vanaja_lab/satya/DeepOMAPNet/Tutorials
- Project root: /projects/vanaja_lab/satya/DeepOMAPNet
- Scripts directory exists: True
- Scripts/data_provider exists: True
Module imports successful!


In [3]:
# Load training data
print("=== Loading Training Data ===")
# Load RNA and ADT data
rna_adata = sc.read_h5ad("/projects/vanaja_lab/satya/Datasets/GSMControlRNA.h5ad")
adt_adata = sc.read_h5ad("/projects/vanaja_lab/satya/Datasets/ControlADT.h5ad")


# Load the preprocessed data
from scripts.data_provider.data_preprocessing import prepare_train_test_anndata
data = prepare_train_test_anndata()
rna_adata = data[0]  # RNA data
rna_test = data[1]
adt_adata = data[2]   # ADT data
adt_test = data[3]

sc.pp.highly_variable_genes(rna_adata, n_top_genes=2000, flavor='cell_ranger')
# Subset by highly variable genes
hvg_mask = rna_adata.var['highly_variable']
rna_adata = rna_adata[:, hvg_mask].copy()

=== Loading Training Data ===
All sample IDs in gene data: ['AML0612' 'AML3762' 'AML3133' 'AML2910' 'AML3050' 'AML2451' 'AML056'
 'AML073' 'AML055' 'AML048' 'AML052' 'AML2123' 'AML1371' 'AML4340'
 'AML4897' 'AML051' 'AML0693' 'AML3948' 'AML3730' 'AML0160' 'AML0310'
 'AML0361' 'AML038' 'AML008' 'AML043' 'AML028' 'AML006' 'AML025' 'AML003'
 'AML012' 'AML005' 'AML0048' 'AML022' 'AML0024' 'AML009' 'AML026' 'AML001'
 'AML0114' 'Control4' 'Control2' 'Control1' 'Control3' 'Control5'
 'Control0004' 'Control0058' 'Control0082' 'Control4003' 'Control0005']
AML 80% train: ['AML0024', 'AML001', 'AML3050', 'AML4340', 'AML005', 'AML006', 'AML056', 'AML025', 'AML043', 'AML051', 'AML3948', 'AML055', 'AML0693', 'AML1371', 'AML0160', 'AML048', 'AML022', 'AML0612', 'AML028', 'AML2451', 'AML2123', 'AML3762', 'AML0114', 'AML0361', 'AML3133', 'AML012', 'AML026', 'AML2910', 'AML009', 'AML008', 'AML0048']
AML 20% test: ['AML052', 'AML038', 'AML3730', 'AML0310', 'AML073', 'AML4897', 'AML003']
Control 80% train

In [4]:
# Subsampling configuration
SUBSAMPLE_SIZE = 50000  # Adjust based on your GPU memory
TRAIN_SUBSAMPLE_SIZE = int(SUBSAMPLE_SIZE * 0.8)  # 40,000 train
TEST_SUBSAMPLE_SIZE = int(SUBSAMPLE_SIZE * 0.2)   # 10,000 test

print(f"=== Subsampling Dataset ===")
print(f"Target subsample size: {SUBSAMPLE_SIZE}")
print(f"Train subsample: {TRAIN_SUBSAMPLE_SIZE}")
print(f"Test subsample: {TEST_SUBSAMPLE_SIZE}")

# Check original sizes
print(f"Original train RNA: {rna_adata.shape}")
print(f"Original train ADT: {adt_adata.shape}")
print(f"Original test RNA: {rna_test.shape}")
print(f"Original test ADT: {adt_test.shape}")

# Get cell indices for each class in train data
aml_train_indices = []
normal_train_indices = []

# Separate train indices by class
for i, sample_id in enumerate(rna_adata.obs['samples']):
    if sample_id.startswith('AML'):
        aml_train_indices.append(i)
    else:  # Control samples
        normal_train_indices.append(i)

# Get cell indices for each class in test data
aml_test_indices = []
normal_test_indices = []

# Separate test indices by class
for i, sample_id in enumerate(rna_test.obs['samples']):
    if sample_id.startswith('AML'):
        aml_test_indices.append(i)
    else:  # Control samples
        normal_test_indices.append(i)

print(f"Original - AML train: {len(aml_train_indices)}, Normal train: {len(normal_train_indices)}")
print(f"Original - AML test: {len(aml_test_indices)}, Normal test: {len(normal_test_indices)}")

# Calculate subsample sizes maintaining class balance
aml_train_subsize = int(TRAIN_SUBSAMPLE_SIZE * 0.747)  # 74.7% AML
normal_train_subsize = TRAIN_SUBSAMPLE_SIZE - aml_train_subsize
aml_test_subsize = int(TEST_SUBSAMPLE_SIZE * 0.747)
normal_test_subsize = TEST_SUBSAMPLE_SIZE - aml_test_subsize

print(f"Subsample - AML train: {aml_train_subsize}, Normal train: {normal_train_subsize}")
print(f"Subsample - AML test: {aml_test_subsize}, Normal test: {normal_test_subsize}")

# Randomly sample indices
import numpy as np
np.random.seed(42)

# Sample train indices
aml_train_subset = np.random.choice(aml_train_indices, aml_train_subsize, replace=False)
normal_train_subset = np.random.choice(normal_train_indices, normal_train_subsize, replace=False)

# Sample test indices
aml_test_subset = np.random.choice(aml_test_indices, aml_test_subsize, replace=False)
normal_test_subset = np.random.choice(normal_test_indices, normal_test_subsize, replace=False)

# Combine train and test indices
train_subset_indices = np.concatenate([aml_train_subset, normal_train_subset])
test_subset_indices = np.concatenate([aml_test_subset, normal_test_subset])

print(f"Final subsample - Train: {len(train_subset_indices)}, Test: {len(test_subset_indices)}")

# Create subsampled data using your existing split
rna_adata_subset = rna_adata[train_subset_indices].copy()
adt_adata_subset = adt_adata[train_subset_indices].copy()

# Create test data using your existing split
rna_test_subset = rna_test[test_subset_indices].copy()
adt_test_subset = adt_test[test_subset_indices].copy()

print(f"Subsampled train RNA: {rna_adata_subset.shape}")
print(f"Subsampled train ADT: {adt_adata_subset.shape}")
print(f"Subsampled test RNA: {rna_test_subset.shape}")
print(f"Subsampled test ADT: {adt_test_subset.shape}")

# Verify class balance
aml_count_train = sum(1 for sample in rna_adata_subset.obs['samples'] if sample.startswith('AML'))
normal_count_train = len(rna_adata_subset) - aml_count_train
aml_count_test = sum(1 for sample in rna_test_subset.obs['samples'] if sample.startswith('AML'))
normal_count_test = len(rna_test_subset) - aml_count_test

print(f"Class balance - Train: AML {aml_count_train} ({aml_count_train/len(rna_adata_subset)*100:.1f}%), Normal {normal_count_train}")
print(f"Class balance - Test: AML {aml_count_test} ({aml_count_test/len(rna_test_subset)*100:.1f}%), Normal {normal_count_test}")

# Now use the subsampled data for training
print(f"\n=== Ready for GPU Training ===")
print(f"Use rna_adata_subset and adt_adata_subset for training")
print(f"Use rna_test_subset and adt_test_subset for testing")

=== Subsampling Dataset ===
Target subsample size: 50000
Train subsample: 40000
Test subsample: 10000
Original train RNA: (158179, 2000)
Original train ADT: (158179, 279)
Original test RNA: (46922, 36601)
Original test ADT: (46922, 279)
Original - AML train: 118224, Normal train: 39955
Original - AML test: 36633, Normal test: 10289
Subsample - AML train: 29880, Normal train: 10120
Subsample - AML test: 7470, Normal test: 2530
Final subsample - Train: 40000, Test: 10000
Subsampled train RNA: (40000, 2000)
Subsampled train ADT: (40000, 279)
Subsampled test RNA: (10000, 36601)
Subsampled test ADT: (10000, 279)
Class balance - Train: AML 29880 (74.7%), Normal 10120
Class balance - Test: AML 7470 (74.7%), Normal 2530

=== Ready for GPU Training ===
Use rna_adata_subset and adt_adata_subset for training
Use rna_test_subset and adt_test_subset for testing


In [5]:
# Get all indices from the original dataset
all_indices = np.arange(rna_adata.n_obs)

# Combine the subsampled indices (train + test)
subsample_indices = np.concatenate([train_subset_indices, test_subset_indices])

# Identify the remaining indices (not used in the 50k subsample)
remaining_indices = np.setdiff1d(all_indices, subsample_indices)

# Create new AnnData objects for the remaining portions
rna_adata_remaining = rna_adata[remaining_indices].copy()
adt_adata_remaining = adt_adata[remaining_indices].copy()



# Check shapes of remaining datasets
print(f"Remaining RNA data shape: {rna_adata_remaining.shape}")
print(f"Remaining ADT data shape: {adt_adata_remaining.shape}")


Remaining RNA data shape: (110673, 2000)
Remaining ADT data shape: (110673, 279)


In [6]:
# Convert AnnData to PyTorch Geometric format
print("=== Converting to PyTorch Geometric Format ===")
from scripts.data_provider.graph_data_builder import build_pyg_data
# Convert RNA data
print("Converting RNA data...")
rna_pyg_data = build_pyg_data(rna_adata_remaining)
print(f"RNA PyG data: {rna_pyg_data}")

# Convert ADT data
print("Converting ADT data...")
adt_pyg_data = build_pyg_data(adt_adata_remaining)
print(f"ADT PyG data: {adt_pyg_data}")


if rna_pyg_data.num_nodes != adt_pyg_data.num_nodes:
    print("⚠️  Warning: RNA and ADT data have different number of nodes!")
else:
    print("✅ RNA and ADT data have same number of nodes")

print("✅ PyTorch Geometric conversion complete!")


=== Converting to PyTorch Geometric Format ===
Converting RNA data...
build_pyg_data called with use_pca=True
Input adata shape: (110673, 2000)
Available obsm keys: ['X_integrated.cca', 'X_pca', 'X_umap', 'X_umap.unintegrated']
Computing PCA with exactly 50 components...
PCA computed, shape: (110673, 50)
Computing neighbor graph first...
Computing leiden clusters first...
Using PCA features, shape: (110673, 50)
RNA PyG data: Data(x=[110673, 50], edge_index=[2, 1265303], y=[110673])
Converting ADT data...
build_pyg_data called with use_pca=True
Input adata shape: (110673, 279)
Available obsm keys: []
Computing PCA with exactly 50 components...
PCA computed, shape: (110673, 50)
Computing neighbor graph first...
Computing leiden clusters first...
Using PCA features, shape: (110673, 50)
ADT PyG data: Data(x=[110673, 50], edge_index=[2, 1193139], y=[110673])
✅ RNA and ADT data have same number of nodes
✅ PyTorch Geometric conversion complete!


In [7]:
# One-liner using pandas string methods
rna_adata_remaining.obs['aml_labels'] = rna_adata_remaining.obs['samples'].str.startswith('AML').astype(int)

# Convert to numpy array for training
aml_labels_array = rna_adata_remaining.obs['aml_labels'].values

# Check distribution
print(f"AML samples: {aml_labels_array.sum()}")
print(f"Normal samples: {len(aml_labels_array) - aml_labels_array.sum()}")

AML samples: 83194
Normal samples: 27479


In [8]:
# 1) Extract labels from AnnData
labels_series = rna_adata_remaining.obs['Cell_type_identity'].astype('category')

# 2) Map to integer classes
celltype_to_idx = {cat: i for i, cat in enumerate(labels_series.cat.categories)}
idx_to_celltype = {i: cat for cat, i in celltype_to_idx.items()}
celltype_labels = labels_series.cat.codes.to_numpy()  # shape [N], ints in [0, C-1]
num_cell_types = len(celltype_to_idx)
num_cell_types

54

In [10]:
from scripts.trainer.fineTune import fine_tune_model, load_and_finetune

fine_tuned_model = load_and_finetune(
    model_path="DeepOMAPNet_weights.pth",
    rna_data=rna_adata_remaining,
    adt_data=adt_adata_remaining,
    aml_labels=aml_labels_array,
    model_params={
        'in_channels': 2000,
        'hidden_channels': 32,
        'out_channels': 279,
        'heads': 2,
        'dropout': 0.6,
        'nhead': 2,
        'num_layers': 1,
        'use_adapters': True,
        'reduction_factor': 4,
        'adapter_l2_reg': 5e-5,
        'use_positional_encoding': True,
        'num_cell_types': 54
    },
    # pass these as function args, not model_params
    celltype_labels=celltype_labels,     # ndarray or tensor of class indices [N]
    num_cell_types=54,                   # optional; enables head post-load if missing
    celltype_weight=1.0,
    epochs=600,
    learning_rate=1e-4,
    freeze_encoder=False
)

Loading pre-trained model from DeepOMAPNet_weights.pth...
✅ Model loaded successfully!
FINE-TUNING DEEPOMAPNET MODEL
Device: cuda
Epochs: 600
Learning rate: 0.0001
Freeze encoder: False
Data shapes:
  RNA: torch.Size([110673, 2000])
  ADT targets: torch.Size([110673, 279])

Trainable parameters: 483,503 / 483,503 (100.0%)

Starting fine-tuning...
Epoch 1/600 | ADT Loss: 0.779449 | AML Loss: 0.884034 | CellType Loss: 2.200339 | Total Loss: 3.863822
Epoch 60/600 | ADT Loss: 0.758864 | AML Loss: 0.432907 | CellType Loss: 2.177469 | Total Loss: 3.369240
Epoch 120/600 | ADT Loss: 0.754847 | AML Loss: 0.317327 | CellType Loss: 2.157536 | Total Loss: 3.229710
Epoch 180/600 | ADT Loss: 0.749220 | AML Loss: 0.300951 | CellType Loss: 2.141824 | Total Loss: 3.191995
Epoch 240/600 | ADT Loss: 0.748075 | AML Loss: 0.295211 | CellType Loss: 2.127924 | Total Loss: 3.171210
Epoch 300/600 | ADT Loss: 0.743310 | AML Loss: 0.290573 | CellType Loss: 2.108609 | Total Loss: 3.142492
Epoch 360/600 | ADT Loss

In [11]:
import torch

model = fine_tuned_model

# Option 2: Save only the weights
torch.save(model.state_dict(), "DeepOMAPNet_finetuned.pth")
