In [1]:
import sys
import os
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings
# Silence every warning category for the rest of the kernel session to keep logs short.
# NOTE(review): this also hides deprecation warnings from scanpy/anndata/torch;
# consider narrowing with category=/module= filters.
warnings.filterwarnings('ignore')

# Add project root to Python path. The notebook is assumed to run from a directory
# one level below the project root (…/DeepOMAPNet/Tutorials in the recorded run).
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("=== Setup Information ===")
print(f"Current directory: {current_dir}")
print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    # Query one explicit device index for both the name and the memory size so the
    # two lines always describe the same GPU on multi-GPU hosts.
    cuda_device = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(cuda_device)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(cuda_device).total_memory / 1e9:.1f} GB")


=== Setup Information ===
Current directory: /projects/vanaja_lab/satya/DeepOMAPNet/Tutorials
Project root: /projects/vanaja_lab/satya/DeepOMAPNet
Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0]
PyTorch version: 2.9.0+cu128
CUDA available: True
CUDA device: NVIDIA H200
CUDA memory: 150.1 GB


In [2]:
import sys, os, importlib

# --- Paths ---
# The working directory is the notebook's folder (…/DeepOMAPNet/Tutorials in the
# recorded run); its parent is taken to be the project root containing `scripts/`.
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)

# Prepend the root so `import scripts...` resolves; skipped if already on the path.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Report what was found so a wrong working directory is obvious before importing.
scripts_dir = os.path.join(project_root, 'scripts')
data_provider_dir = os.path.join(scripts_dir, 'data_provider')
for line in (
    "Added to Python path:",
    f"- Current directory: {current_dir}",
    f"- Project root: {project_root}",
    f"- Scripts directory exists: {os.path.exists(scripts_dir)}",
    f"- Scripts/data_provider exists: {os.path.exists(data_provider_dir)}",
):
    print(line)

# Drop stale import-finder caches so modules under project_root are (re)discovered.
importlib.invalidate_caches()

# --- Import modules (module-style, not from ... import ...) ---
# Binding the module objects (rather than individual functions) means a later cell
# can call importlib.reload(<module>) to pick up source edits, as the training
# cell does with gat_trainer. Requires project_root on sys.path (set above).
import scripts.data_provider.data_preprocessing as data_preprocessing
import scripts.data_provider.graph_data_builder as graph_data_builder
import scripts.model.doNET as doNET
import scripts.trainer.gat_trainer as gat_trainer

print("Module imports successful!")

Added to Python path:
- Current directory: /projects/vanaja_lab/satya/DeepOMAPNet/Tutorials
- Project root: /projects/vanaja_lab/satya/DeepOMAPNet
- Scripts directory exists: True
- Scripts/data_provider exists: True
Module imports successful!


In [11]:
import scanpy as sc
import anndata


# Load the preprocessed data. prepare_train_test_anndata() returns an indexable
# sequence whose first four entries are, in order: RNA, RNA test, ADT, ADT test
# (presumably the non-"test" entries are the training split — confirm in data_preprocessing).
# NOTE(review): the next cell (In [16]) reassigns rna_adata from GSE116256.h5ad,
# discarding the RNA object loaded here.
from scripts.data_provider.data_preprocessing import prepare_train_test_anndata
data = prepare_train_test_anndata()
rna_adata, rna_test, adt_adata, adt_test = (data[i] for i in range(4))





All sample IDs in gene data: ['AML0612' 'AML3762' 'AML3133' 'AML2910' 'AML3050' 'AML2451' 'AML056'
 'AML073' 'AML055' 'AML048' 'AML052' 'AML2123' 'AML1371' 'AML4340'
 'AML4897' 'AML051' 'AML0693' 'AML3948' 'AML3730' 'AML0160' 'AML0310'
 'AML0361' 'AML038' 'AML008' 'AML043' 'AML028' 'AML006' 'AML025' 'AML003'
 'AML012' 'AML005' 'AML0048' 'AML022' 'AML0024' 'AML009' 'AML026' 'AML001'
 'AML0114' 'Control4' 'Control2' 'Control1' 'Control3' 'Control5'
 'Control0004' 'Control0058' 'Control0082' 'Control4003' 'Control0005']
AML 80% train: ['AML0024', 'AML001', 'AML3050', 'AML4340', 'AML005', 'AML006', 'AML056', 'AML025', 'AML043', 'AML051', 'AML3948', 'AML055', 'AML0693', 'AML1371', 'AML0160', 'AML048', 'AML022', 'AML0612', 'AML028', 'AML2451', 'AML2123', 'AML3762', 'AML0114', 'AML0361', 'AML3133', 'AML012', 'AML026', 'AML2910', 'AML009', 'AML008', 'AML0048']
AML 20% test: ['AML052', 'AML038', 'AML3730', 'AML0310', 'AML073', 'AML4897', 'AML003']
Control 80% train: ['Control4003', 'Control1', 

In [16]:
# Reload rna_adata from GSE116256.h5ad, replacing the RNA object returned by
# prepare_train_test_anndata() in the previous cell. The file is resolved relative to
# project_root (set in the setup cell) rather than a hardcoded absolute cluster path;
# in the recorded run this resolves to /projects/vanaja_lab/satya/DeepOMAPNet/GSE116256.h5ad.
rna_adata = sc.read_h5ad(os.path.join(project_root, "GSE116256.h5ad"))

In [17]:
# Binary disease label per cell: 1 when the cell's sample ID starts with 'AML'
# (control samples are named 'Control…'), otherwise 0.
is_aml_sample = rna_adata.obs['samples'].str.startswith('AML')
rna_adata.obs['aml_labels'] = is_aml_sample.astype(int)

# Plain numpy array of the labels, passed to the trainer as aml_labels
aml_labels_array = rna_adata.obs['aml_labels'].values

# Class balance. Note: these are cell counts, not counts of distinct samples.
aml_count = aml_labels_array.sum()
normal_count = len(aml_labels_array) - aml_count
print(f"AML samples: {aml_count}")
print(f"Normal samples: {normal_count}")

AML samples: 118224
Normal samples: 39955


In [18]:
# Convert AnnData to PyTorch Geometric format (one graph per modality)
print("=== Converting to PyTorch Geometric Format ===")
from scripts.data_provider.graph_data_builder import build_pyg_data

# Convert RNA data
print("Converting RNA data...")
rna_pyg_data = build_pyg_data(rna_adata)
print(f"RNA PyG data: {rna_pyg_data}")

# Convert ADT data
print("Converting ADT data...")
adt_pyg_data = build_pyg_data(adt_adata)
print(f"ADT PyG data: {adt_pyg_data}")

# The trainer receives both graphs plus per-cell label arrays, so node i is presumably
# expected to be the same cell in both modalities (confirm in gat_trainer). Equal node
# counts do not guarantee that, so also compare cell IDs and their order.
if rna_pyg_data.num_nodes != adt_pyg_data.num_nodes:
    print("⚠️  Warning: RNA and ADT data have different number of nodes!")
elif not rna_adata.obs_names.equals(adt_adata.obs_names):
    print("⚠️  Warning: RNA and ADT data have the same number of nodes but different cell IDs or order!")
else:
    print("✅ RNA and ADT data have same number of nodes")

print("✅ PyTorch Geometric conversion complete!")


=== Converting to PyTorch Geometric Format ===
Converting RNA data...
build_pyg_data called with use_pca=True
Input adata shape: (158179, 2000)
Available obsm keys: ['X_integrated.cca', 'X_pca', 'X_scVI', 'X_umap', 'X_umap.unintegrated']
Computing PCA with exactly 50 components...
PCA computed, shape: (158179, 50)
Computing neighbor graph first...
Computing leiden clusters first...
Using PCA features, shape: (158179, 50)
RNA PyG data: Data(x=[158179, 50], edge_index=[2, 1916900], y=[158179])
Converting ADT data...
build_pyg_data called with use_pca=True
Input adata shape: (158179, 279)
Available obsm keys: []
Computing PCA with exactly 50 components...
PCA computed, shape: (158179, 50)
Computing neighbor graph first...
Computing leiden clusters first...
Using PCA features, shape: (158179, 50)
ADT PyG data: Data(x=[158179, 50], edge_index=[2, 1706010], y=[158179])
✅ RNA and ADT data have same number of nodes
✅ PyTorch Geometric conversion complete!


In [20]:
# 1) Extract per-cell labels from AnnData as a categorical series
labels_series = rna_adata.obs['Cell_type_identity'].astype('category')

# pandas encodes a missing category as code -1, which is not a valid class index.
# Fail here rather than pass -1 on to train_gat_transformer_fusion as celltype_labels.
n_missing = int(labels_series.isna().sum())
assert n_missing == 0, f"{n_missing} cells have no 'Cell_type_identity' label"

# 2) Map to integer classes; category order defines the class index
celltype_to_idx = {cat: i for i, cat in enumerate(labels_series.cat.categories)}
idx_to_celltype = dict(enumerate(labels_series.cat.categories))
celltype_labels = labels_series.cat.codes.to_numpy()  # shape [N], ints in [0, C-1]
num_cell_types = len(celltype_to_idx)
num_cell_types

54

In [21]:
%load_ext autoreload
%autoreload 2
import importlib
import time
from datetime import datetime
from scripts.trainer import gat_trainer

# Reload the module
importlib.reload(gat_trainer)

# Re-import the functions you need
from scripts.trainer.gat_trainer import train_gat_transformer_fusion

print("=== Training Enhanced GATWithTransformerFusion Model ===")

# Training parameters (only parameters accepted by train_gat_transformer_fusion)
training_config = {
    'epochs': 550,
    'use_cpu_fallback': False,
    'seed': 42,
    'learning_rate': 1e-3,           # Changed from 'lr'
    'weight_decay': 1e-4,
    'dropout_rate': 0.6,             # Changed from 'dropout'
    'hidden_channels': 64,
    
    'num_heads': 4,                  # Changed from 'heads'
    'num_attention_heads': 4,        # Changed from 'nhead'
    'num_layers': 2,
    'use_mixed_precision': True,     # Changed from 'amp'
    'early_stopping_patience': 5,    # Changed from 'patience'
    'num_cell_types': num_cell_types
}



start_time = datetime.now()
print(f"Training started at: {start_time}")


# Train the enhanced model
trained_model, rna_data_with_masks, adt_data_with_masks, training_history,adt_mean, adt_std, node_degrees_rna, node_degrees_adt, clustering_coeffs_rna, clustering_coeffs_adt = train_gat_transformer_fusion(
    rna_data=rna_pyg_data,
    adt_data=adt_pyg_data,
    aml_labels=aml_labels_array,
    rna_anndata = rna_adata,
    adt_anndata = adt_adata,
    celltype_labels=celltype_labels, 
    celltype_weight=1.0,
    **training_config
)

end_time = datetime.now()
training_duration = end_time - start_time

print("=" * 60)
print(f"\n=== Training Complete ===")
print(f"Training finished at: {end_time}")
print(f"Total training time: {training_duration}")
print(f"Training time per epoch: {training_duration / training_config['epochs']}")

# Enhanced training results analysis
print(f"\n=== Enhanced Training Results ===")
print(f"Final training loss: {training_history['train_loss'][-1]:.6f}")
print(f"Final validation MSE: {training_history['val_MSE'][-1]:.6f}")
print(f"Final validation R²: {training_history['val_R2'][-1]:.4f}")
print(f"Final test MSE: {training_history['test_MSE'][-1]:.6f}")
print(f"Final test R²: {training_history['test_R2'][-1]:.4f}")

# Regularization loss analysis
if 'reg_loss' in training_history:
    print(f"Final regularization loss: {training_history['reg_loss'][-1]:.6f}")
    print(f"Average regularization loss: {sum(training_history['reg_loss']) / len(training_history['reg_loss']):.6f}")

# Model capabilities test
print(f"\n=== Testing Enhanced Model Capabilities ===")


=== Training Enhanced GATWithTransformerFusion Model ===
Training started at: 2025-11-01 15:25:54.440155
GPU memory fraction set to 50%
Using device: cuda
Data splits — train: 126543, val: 15817, test: 15819
Preprocessing RNA data using AnnData...
RNA preprocessing applied: torch.Size([158179, 2000])
Updated RNA input dimension after preprocessing: 2000
Loading ADT data for preprocessing...
Using AnnData object for ADT preprocessing...
Data from AnnData loaded: torch.Size([158179, 279])
Data appears to be already normalized. Using provided statistics.
Statistics: mean=-0.0000, std=1.0000
AML labels processed: (158179,), Normal: 39955, AML: 118224
AML labels split — train: 126543 (Normal: 31893, AML: 94650), val: 15817 (Normal: 4026, AML: 11791), test: 15819 (Normal: 4036, AML: 11783)
Updated ADT output dimension after preprocessing: 279
Model parameters: 483,503
Model moved to cuda
RNA data moved to cuda
ADT data moved to cuda
AML labels moved to cuda: torch.Size([158179])
CellType lab

## Save Model

In [22]:
import torch

# Short alias for the network returned by train_gat_transformer_fusion.
model = trained_model

# Persist only the parameters (state_dict), not the pickled module object. Reloading
# needs a model built with the same architecture, followed by load_state_dict().
# The path is relative, so the file is written to the notebook's working directory.
weights_path = "DeepOMAPNet_weights.pth"
torch.save(model.state_dict(), weights_path)
