DeepOMAPNet is a deep learning framework for multi-modal single-cell analysis of CITE-seq data. Given RNA expression profiles, it simultaneously predicts surface protein (ADT) levels, classifies cell types, and performs disease diagnosis (AML vs. Normal). It does this by combining Graph Attention Networks (GAT) with cross-modal Transformer Fusion.
CITE-seq experiments measure both RNA and surface protein (ADT) levels in the same cell, but protein measurements are expensive and noisy. DeepOMAPNet learns to predict protein abundance directly from RNA expression while exploiting the cell-cell neighborhood graph structure. It also jointly classifies whether a cell belongs to an AML or Normal sample, sharing representations across both tasks.
Start with Tutorials/Training.ipynb. It walks through the complete pipeline end-to-end: data loading, normalization, graph construction, model training, and evaluation. This is the recommended starting point for all users.
```bash
jupyter notebook Tutorials/Training.ipynb
```

The model processes cells as nodes in a k-nearest-neighbor graph built in PCA space:
```
Raw CITE-seq AnnData (RNA + ADT)
  --> CLR normalization (ADT) + Z-score normalization (both modalities)
  --> Train / val / test split (70 / 15 / 15, stratified)
  --> PCA (50 components) + k-NN graph (k=15) + Leiden clustering
  --> PyTorch Geometric Data objects
  --> GATWithTransformerFusion
  --> ADT predictions | AML classification | Fused cell embeddings
```
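As a rough illustration of the graph-construction stage, the sketch below builds the PCA-space k-NN graph with ScanPy and converts it into a PyTorch Geometric `Data` object. The function name `build_graph_data` and its exact arguments are illustrative, not the actual `graph_data_builder.py` API:

```python
# Minimal sketch of the graph-construction steps; the repository's
# graph_data_builder.py may differ in details.
import numpy as np
import scanpy as sc
import torch
from torch_geometric.data import Data

def build_graph_data(adata):
    """Build a PCA-space k-NN graph and wrap it as a PyG Data object."""
    sc.pp.scale(adata)                                         # Z-score normalization
    sc.tl.pca(adata, n_comps=50)                               # 50 PCA components
    sc.pp.neighbors(adata, n_neighbors=15, use_rep="X_pca")    # k-NN graph (k=15)
    sc.tl.leiden(adata)                                        # Leiden clustering

    # Convert the sparse connectivity matrix into a PyG edge list.
    conn = adata.obsp["connectivities"].tocoo()
    edge_index = torch.from_numpy(np.vstack([conn.row, conn.col])).long()

    x = torch.tensor(adata.obsm["X_pca"], dtype=torch.float)
    return Data(x=x, edge_index=edge_index)
```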
| Component | Description |
|---|---|
| `GATWithTransformerFusion` | Top-level model: GAT encoder followed by `TransformerFusion` and task-specific output heads |
| `GraphPositionalEncoding` | Enriches node embeddings with graph topology signals (node degree, clustering coefficient) |
| `SparseCrossAttentionLayer` | Sparse multi-head cross-attention operating over edge lists; scales to large graphs (>100k cells) |
| `CrossAttentionLayer` | Dense cross-attention variant with layer norm; suitable for smaller graphs |
| `AdapterLayer` | Bottleneck adapter (dim -> dim/r -> dim) for parameter-efficient fine-tuning without retraining the full model (see the sketch after this table) |
| `TransformerFusion` | Stacks multiple cross-attention layers to fuse RNA and ADT modalities bidirectionally |
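To make the adapter shape concrete, here is a minimal sketch of a bottleneck adapter with reduction factor `r`, following the dim -> dim/r -> dim structure described above. The class name, activation, and residual handling are assumptions; the real `AdapterLayer` lives in `scripts/model/doNET.py` and may differ:

```python
import torch
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    """Hypothetical bottleneck adapter: dim -> dim/r -> dim with a residual skip."""
    def __init__(self, dim: int, r: int = 4):
        super().__init__()
        self.down = nn.Linear(dim, dim // r)   # project down by factor r
        self.act = nn.GELU()
        self.up = nn.Linear(dim // r, dim)     # project back up to dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The residual connection keeps the frozen backbone's features intact,
        # so only the small down/up projections need to be trained.
        return x + self.up(self.act(self.down(x)))
```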
The model produces three outputs:

- `adt_pred` — predicted protein levels per cell (regression)
- `aml_pred` — binary disease classification score per cell
- `fused_embeddings` — latent cell representations usable for downstream analysis (UMAP, clustering)
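For example, the fused embeddings can be dropped straight into a standard ScanPy workflow for UMAP and clustering. This sketch assumes `fused_embeddings` is an `(n_cells, d)` array aligned with the cells in the original `adata`; the `obsm` key is arbitrary:

```python
# Downstream analysis on the fused embeddings (sketch; the obsm key
# "X_deepomapnet" is an illustrative choice, not a fixed convention).
import scanpy as sc

adata.obsm["X_deepomapnet"] = fused_embeddings
sc.pp.neighbors(adata, use_rep="X_deepomapnet")  # k-NN graph in latent space
sc.tl.umap(adata)                                # 2-D embedding for plotting
sc.tl.leiden(adata, key_added="fused_clusters")  # cluster on fused representation
sc.pl.umap(adata, color=["fused_clusters"])
```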
Training is handled by `train_gat_transformer_fusion()`, which provides:
- Multi-task loss: MSE for ADT regression + BCE for AML classification, with an optional cross-entropy term for cell-type classification (sketched after this list)
- Stratified train/val/test splits
- Automatic mixed precision (AMP) and gradient accumulation
- Early stopping with best-model restoration
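The multi-task loss can be written along these lines. The task weights and the keyword layout are assumptions for illustration, not the trainer's exact defaults:

```python
# Minimal sketch of the multi-task loss described above.
import torch
import torch.nn.functional as F

def multi_task_loss(adt_pred, adt_true, aml_logits, aml_labels,
                    ct_logits=None, ct_labels=None,
                    w_adt=1.0, w_aml=1.0, w_ct=1.0):
    # MSE for ADT regression.
    loss = w_adt * F.mse_loss(adt_pred, adt_true)
    # BCE (with logits) for binary AML vs. Normal classification.
    loss = loss + w_aml * F.binary_cross_entropy_with_logits(
        aml_logits, aml_labels.float())
    # Optional cross-entropy term for cell-type classification.
    if ct_logits is not None:
        loss = loss + w_ct * F.cross_entropy(ct_logits, ct_labels)
    return loss
```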
Conda (recommended)

```bash
git clone https://github.com/SreeSatyaGit/DeepOMAPNet.git
cd DeepOMAPNet
conda env create -f environment.yml
conda activate deepomapnet
```

Pip

```bash
pip install -r requirements.txt
```

Core dependencies: Python 3.8, PyTorch >= 2.0, PyTorch Geometric >= 2.3, ScanPy >= 1.9, AnnData >= 0.9.
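A quick way to sanity-check the environment is to import the core packages and print their versions, which should meet the minimums above:

```python
# Environment sanity check.
import torch, torch_geometric, scanpy, anndata
print("PyTorch:", torch.__version__)
print("PyG:", torch_geometric.__version__)
print("ScanPy:", scanpy.__version__)
print("AnnData:", anndata.__version__)
```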
Evaluated on a synthetic CITE-seq benchmark (500 cells, 250 Normal + 250 AML, 30 proteins, 500 genes).
| Metric | Value |
|---|---|
| ADT prediction — mean Pearson r | 0.785 |
| ADT prediction — best single protein r | 0.948 |
| AML classification — AUC-ROC | 0.836 |
| AML classification — F1 | 0.719 |
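For reference, metrics of this kind can be computed along the following lines. This is a sketch with SciPy/scikit-learn, assuming prediction arrays of shape `(n_cells, n_proteins)` and per-cell AML scores; the repository's own evaluation code may differ:

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score, f1_score

# adt_pred / adt_true: (n_cells, n_proteins); aml_score / aml_label: (n_cells,)
per_protein_r = np.array([
    pearsonr(adt_pred[:, j], adt_true[:, j])[0]   # Pearson r per protein
    for j in range(adt_true.shape[1])
])
print("mean Pearson r:", per_protein_r.mean())
print("best protein r:", per_protein_r.max())
print("AUC-ROC:", roc_auc_score(aml_label, aml_score))
print("F1:", f1_score(aml_label, aml_score > 0.5))   # threshold at 0.5
```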
To reproduce these results:
```bash
python run_experiment.py   # saves figures to results/DeepOMAPNet/
```
```
├── scripts/
│   ├── model/
│   │   └── doNET.py                 # GATWithTransformerFusion and all model components
│   ├── data_provider/
│   │   ├── graph_data_builder.py    # k-NN graph construction, PCA, PyG Data objects
│   │   └── data_preprocessing.py    # CLR and Z-score normalization, train/val/test splitting
│   ├── trainer/
│   │   └── gat_trainer.py           # Multi-task training loop with AMP and early stopping
│   └── visualizations.py            # Plotting utilities
├── Tutorials/
│   └── Training.ipynb               # End-to-end tutorial
├── tests/                           # pytest test suite
├── R/                               # Supporting R scripts (WNN, preprocessing)
├── research/                        # Experiment scripts
├── run_experiment.py                # Main experiment execution script
├── environment.yml
└── requirements.txt
```
```bash
# Run the full test suite
pytest

# Run a specific test file
pytest tests/test_model_components.py -v
```

| File | Tests | Coverage |
|---|---|---|
| `test_model_components.py` | 36 | Forward pass, gradients, sparse attention, adapters |
| `test_data_pipeline.py` | 25 | Normalization correctness, graph validity, split integrity |
| `test_training.py` | 10 | Loss convergence, gradient clipping, reproducibility |
| `test_performance_benchmark.py` | 16 | Pearson r vs. baselines, Wilcoxon test, AML AUC |
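As an illustration of the style of these tests, a minimal forward-pass and gradient check might look like the following. It reuses the hypothetical `BottleneckAdapter` sketch from the components section; the actual tests in `tests/` are more thorough:

```python
# Hypothetical example in the spirit of test_model_components.py;
# real tests and constructor arguments may differ.
import torch
from torch_geometric.data import Data

def test_adapter_forward_and_gradients():
    # Tiny random graph: 10 nodes, 20 directed edges, 16 features per node.
    x = torch.randn(10, 16, requires_grad=True)
    edge_index = torch.randint(0, 10, (2, 20))
    data = Data(x=x, edge_index=edge_index)

    adapter = BottleneckAdapter(dim=16, r=4)   # sketch class defined above
    out = adapter(data.x)

    assert out.shape == data.x.shape           # adapter preserves shape
    out.sum().backward()                       # gradients flow back to the input
    assert data.x.grad is not None
```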
MIT. See LICENSE.
