In [1]:
"""
This Jupyter notebook serves as a tutorial implementing all functionalities 
described in the MultiGAI paper. 

Notes:
- The model relies on `adata.obs['Modality']` to identify the modality of each cell. 
  Valid modality labels include 'rna', 'atac', 'adt', 'multiome', or 'cite'.
- The dataset must include batch information in `adata.obs['batch']`.
- The model uses the raw count matrices stored in `adata.layers['counts']`.
"""

"\nThis Jupyter notebook serves as a tutorial implementing all functionalities \ndescribed in the MultiGAI paper. \n\nNotes:\n- The model relies on `adata.obs['Modality']` to identify the modality of each cell. \n  Valid modality labels include 'rna', 'atac', 'adt', 'multiome', or 'cite'.\n- The dataset must include batch information in `adata.obs['batch']`.\n- The model uses the raw count matrices stored in `adata.layers['counts']`.\n"

In [2]:
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
import scanpy as sc
from scipy.sparse import issparse
from torch.optim import Adam
from tqdm import tqdm
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import MultiGAI

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus the CUDA generators), then switches cuDNN into
    deterministic mode with its auto-tuner turned off.

    Args:
        seed: Integer seed applied to all generators.

    Returns:
        None. Only global RNG / backend state is modified.
    """
    # Python built-in RNG, NumPy global RNG, PyTorch CPU RNG (in that order)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # PyTorch GPU RNGs: current device first, then all devices
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: request deterministic behavior and disable the auto-tuner
    # to guarantee reproducibility
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# Fix the random seed once, up front, so every later cell is reproducible
set_seed(42)

In [None]:
# ================================================
# Load Multiome single-cell data (RNA + ATAC)
# ================================================

# Load RNA and ATAC datasets
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# Keep only cells with Modality equal to 'multiome'.
# `.copy()` turns the filtered view into a real AnnData, because `rna` is later
# handed to train_and_evaluate_model as the object the latent is saved into.
rna = rna[(rna.obs['Modality'] == 'multiome').to_numpy()].copy()

# The paired dataset below needs row i of RNA and ATAC to be the same cell.
# If the RNA filter left the two objects with different row counts, re-align
# ATAC to the RNA cells by name (raises if a cell is missing) instead of
# silently building a mismatched dataset.
if atac.n_obs != rna.n_obs:
    atac = atac[rna.obs_names].copy()

# Extract raw count matrices for RNA and ATAC, densifying sparse storage
rna_d, atac_d = rna.layers['counts'], atac.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()

# Placeholder matrix for ADT (protein) data, which the Multiome dataset lacks
adt_d = np.zeros((rna_d.shape[0], 1))

# Feature dimensions for each modality (ADT placeholder has a single column)
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], 1

# Modality vector: every cell is Multiome, encoded as 12
modality_vector = np.full(rna.shape[0], 12.0)

# Encode batch information as one-hot vectors for batch effect correction
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Build the integrated Multi-Omics dataset:
# RNA, ATAC, ADT (placeholder), modality codes and batch one-hot
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_vector, batch_encoded)

# Same dataset for both loaders: shuffled for training, ordered for evaluation
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: The `rna` object passed here is NOT used for training.
#       It only provides metadata (.obs) and serves as the template
#       to save the learned latent variables into a .h5ad file.
#       The model integrates Multiome data (RNA + ATAC),
#       supports batch effect correction, and outputs a unified latent representation
# NOTE(review): the six trailing hyperparameters are labelled on the assumption
#       that they follow the MultiGAI.multigai(...) constructor order used in the
#       trimodal cell of this notebook (n_hidden, hidden, z_dim, batch_dim, q_dim,
#       kv_n) — confirm against the train_and_evaluate_model signature.
# ================================================
MultiGAI.train_and_evaluate_model(
    './results/neurips-multiome-multigai_latent.h5ad',
    train_loader,
    test_loader,
    rna,  # Only used to save latent variables, not for training
    rna_dim,
    atac_dim,
    adt_dim,
    1,          # n_hidden: number of hidden layers
    128,        # hidden: hidden layer dimension
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the batch one-hot vector
    128,        # q_dim: dimension of the query vector (q) in attention
    128         # kv_n: number of key-value (K-V) pairs in attention
)

In [None]:
# ================================================
# Load CITE-seq single-cell data (RNA + ADT)
# ================================================

# Load RNA and protein (ADT) datasets
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# Extract raw count matrices for RNA and ADT, densifying sparse storage
rna_d, adt_d = rna.layers['counts'], adt.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()

# Placeholder matrix for ATAC, which the CITE-seq dataset lacks
atac_d = np.zeros((rna_d.shape[0], 1))

# Feature dimensions for each modality (ATAC placeholder has a single column)
rna_dim, atac_dim, adt_dim = rna.shape[1], 1, adt.shape[1]

# Translate the Modality labels into the numeric codes passed to the model.
# Series.map() yields NaN for any label missing from the dict, so unknown
# labels are rejected here instead of leaking NaN codes into training.
modality_map = {'multiome': 12, 'cite': 13}
unknown_labels = set(rna.obs['Modality'].unique()) - set(modality_map)
if unknown_labels:
    raise ValueError(f"Modality labels without a numeric code: {unknown_labels}")
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# Encode batch information as one-hot vectors for batch effect correction
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Build the integrated Multi-Omics dataset:
# RNA, ATAC (placeholder), ADT, modality codes and batch one-hot
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)

# Same dataset for both loaders: shuffled for training, ordered for evaluation
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: The `rna` object passed here is NOT used for training.
#       It only provides metadata (.obs) and serves as the template
#       to save the learned latent variables into a .h5ad file.
#       The model integrates CITE-seq data (RNA + ADT),
#       supports batch effect correction, and outputs a unified latent representation
# NOTE(review): the six trailing hyperparameters are labelled on the assumption
#       that they follow the MultiGAI.multigai(...) constructor order used in the
#       trimodal cell of this notebook (n_hidden, hidden, z_dim, batch_dim, q_dim,
#       kv_n) — confirm against the train_and_evaluate_model signature.
# ================================================
MultiGAI.train_and_evaluate_model(
    './results/neurips-cite-multigai_latent.h5ad',
    train_loader,
    test_loader,
    rna,  # Only used to save latent variables, not for training
    rna_dim,
    atac_dim,
    adt_dim,
    1,          # n_hidden: number of hidden layers
    128,        # hidden: hidden layer dimension
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the batch one-hot vector
    128,        # q_dim: dimension of the query vector (q) in attention
    128         # kv_n: number of key-value (K-V) pairs in attention
)

In [None]:
# ================================================
# Load Multiome single-cell data (RNA + ATAC)
# Train without NK cells, evaluate on all cells
# ================================================

# Load RNA and ATAC datasets
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# Cell type withheld from training
held_out_type = "NK"

# Feature dimensions (ADT is a single-column placeholder for Multiome data)
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], 1

# ================= Per-cell annotations (computed once, on ALL cells) =================
# Translate Modality labels into numeric codes. Series.map() yields NaN for
# labels missing from the dict, so unknown labels are rejected up front.
modality_map = {'multiome': 12, 'cite': 13}
unknown_labels = set(rna.obs['Modality'].unique()) - set(modality_map)
if unknown_labels:
    raise ValueError(f"Modality labels without a numeric code: {unknown_labels}")
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# One-hot batch encoding. It is built once on the full data and sliced for
# training, so a batch keeps the same column (and the one-hot keeps the same
# width, `batch_dim`) in the training and the test loader even if removing the
# held-out cell type empties a batch.
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# ================= Training data (exclude NK cells) =================
# Boolean row masks; each modality is filtered on its own `cell_type` column
keep_rna = (rna.obs["cell_type"] != held_out_type).to_numpy()
keep_atac = (atac.obs["cell_type"] != held_out_type).to_numpy()
rna_t = rna[keep_rna].copy()
atac_t = atac[keep_atac].copy()

# Raw count matrices of the training cells, densified if stored sparse
rna_t_d, atac_t_d = rna_t.layers['counts'], atac_t.layers['counts']
if issparse(rna_t_d): rna_t_d = rna_t_d.toarray()
if issparse(atac_t_d): atac_t_d = atac_t_d.toarray()

# Placeholder for ADT (not present in Multiome)
adt_t_d = np.zeros((rna_t_d.shape[0], 1))

# Build training dataset from the training rows of every per-cell array
dataset_t = MultiGAI.MultiOmicsDataset(
    rna_t_d, atac_t_d, adt_t_d,
    modality_d[keep_rna],
    batch_encoded[torch.from_numpy(keep_rna)],
)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including NK) =================
rna_d, atac_d = rna.layers['counts'], atac.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()
adt_d = np.zeros((rna_d.shape[0], 1))

# Build testing dataset (kept in file order: shuffle=False)
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding NK (train_loader)
#       and evaluated on all cells including NK (test_loader)
# NOTE(review): the six trailing hyperparameters are labelled on the assumption
#       that they follow the MultiGAI.multigai(...) constructor order used in the
#       trimodal cell of this notebook (n_hidden, hidden, z_dim, batch_dim, q_dim,
#       kv_n) — confirm against the train_and_evaluate_model signature.
# ================================================
MultiGAI.train_and_evaluate_model(
    './results/neurips-multiome-NK_latent.h5ad',
    train_loader,
    test_loader,
    rna,  # Only for saving latent variables and metadata
    rna_dim,
    atac_dim,
    adt_dim,
    1,          # n_hidden: number of hidden layers
    128,        # hidden: hidden layer dimension
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the batch one-hot vector
    128,        # q_dim: dimension of the query vector (q) in attention
    128         # kv_n: number of key-value (K-V) pairs in attention
)

In [None]:
# ================================================
# Load Multiome single-cell data (RNA + ATAC)
# Train without Lymph prog cells, evaluate on all cells
# ================================================

# Load RNA and ATAC datasets
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# Cell type withheld from training
held_out_type = "Lymph prog"

# Feature dimensions (ADT is a single-column placeholder for Multiome data)
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], 1

# ================= Per-cell annotations (computed once, on ALL cells) =================
# Translate Modality labels into numeric codes. Series.map() yields NaN for
# labels missing from the dict, so unknown labels are rejected up front.
modality_map = {'multiome': 12, 'cite': 13}
unknown_labels = set(rna.obs['Modality'].unique()) - set(modality_map)
if unknown_labels:
    raise ValueError(f"Modality labels without a numeric code: {unknown_labels}")
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# One-hot batch encoding. It is built once on the full data and sliced for
# training, so a batch keeps the same column (and the one-hot keeps the same
# width, `batch_dim`) in the training and the test loader even if removing the
# held-out cell type empties a batch.
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# ================= Training data (exclude Lymph prog cells) =================
# Boolean row masks; each modality is filtered on its own `cell_type` column
keep_rna = (rna.obs["cell_type"] != held_out_type).to_numpy()
keep_atac = (atac.obs["cell_type"] != held_out_type).to_numpy()
rna_t = rna[keep_rna].copy()
atac_t = atac[keep_atac].copy()

# Raw count matrices of the training cells, densified if stored sparse
rna_t_d, atac_t_d = rna_t.layers['counts'], atac_t.layers['counts']
if issparse(rna_t_d): rna_t_d = rna_t_d.toarray()
if issparse(atac_t_d): atac_t_d = atac_t_d.toarray()

# Placeholder for ADT (not present in Multiome)
adt_t_d = np.zeros((rna_t_d.shape[0], 1))

# Build training dataset from the training rows of every per-cell array
dataset_t = MultiGAI.MultiOmicsDataset(
    rna_t_d, atac_t_d, adt_t_d,
    modality_d[keep_rna],
    batch_encoded[torch.from_numpy(keep_rna)],
)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including Lymph prog) =================
rna_d, atac_d = rna.layers['counts'], atac.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()
adt_d = np.zeros((rna_d.shape[0], 1))

# Build testing dataset (kept in file order: shuffle=False)
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding Lymph prog (train_loader)
#       and evaluated on all cells including Lymph prog (test_loader)
# NOTE(review): the six trailing hyperparameters are labelled on the assumption
#       that they follow the MultiGAI.multigai(...) constructor order used in the
#       trimodal cell of this notebook (n_hidden, hidden, z_dim, batch_dim, q_dim,
#       kv_n) — confirm against the train_and_evaluate_model signature.
# ================================================
MultiGAI.train_and_evaluate_model(
    './results/neurips-multiome-Lymphprog_latent.h5ad',
    train_loader,
    test_loader,
    rna,  # Only for saving latent variables and metadata
    rna_dim,
    atac_dim,
    adt_dim,
    1,          # n_hidden: number of hidden layers
    128,        # hidden: hidden layer dimension
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the batch one-hot vector
    128,        # q_dim: dimension of the query vector (q) in attention
    128         # kv_n: number of key-value (K-V) pairs in attention
)

In [None]:
# ================================================
# Load CITE-seq single-cell data (RNA + ADT)
# Train without CD8+ T naive cells, evaluate on all cells
# ================================================

# Load RNA and protein (ADT) datasets
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# Cell type withheld from training
held_out_type = "CD8+ T naive"

# Feature dimensions (ATAC is a single-column placeholder for CITE-seq data)
rna_dim, atac_dim, adt_dim = rna.shape[1], 1, adt.shape[1]

# ================= Per-cell annotations (computed once, on ALL cells) =================
# Translate Modality labels into numeric codes. Series.map() yields NaN for
# labels missing from the dict, so unknown labels are rejected up front.
modality_map = {'multiome': 12, 'cite': 13}
unknown_labels = set(rna.obs['Modality'].unique()) - set(modality_map)
if unknown_labels:
    raise ValueError(f"Modality labels without a numeric code: {unknown_labels}")
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# One-hot batch encoding. It is built once on the full data and sliced for
# training, so a batch keeps the same column (and the one-hot keeps the same
# width, `batch_dim`) in the training and the test loader even if removing the
# held-out cell type empties a batch.
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# ================= Training data (exclude CD8+ T naive cells) =================
# Boolean row masks; each modality is filtered on its own `cell_type` column
keep_rna = (rna.obs["cell_type"] != held_out_type).to_numpy()
keep_adt = (adt.obs["cell_type"] != held_out_type).to_numpy()
rna_t = rna[keep_rna].copy()
adt_t = adt[keep_adt].copy()

# Raw count matrices of the training cells, densified if stored sparse
rna_t_d, adt_t_d = rna_t.layers['counts'], adt_t.layers['counts']
if issparse(rna_t_d): rna_t_d = rna_t_d.toarray()
if issparse(adt_t_d): adt_t_d = adt_t_d.toarray()

# Placeholder for ATAC (not present in CITE-seq)
atac_t_d = np.zeros((rna_t_d.shape[0], 1))

# Build training dataset from the training rows of every per-cell array
dataset_t = MultiGAI.MultiOmicsDataset(
    rna_t_d, atac_t_d, adt_t_d,
    modality_d[keep_rna],
    batch_encoded[torch.from_numpy(keep_rna)],
)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including CD8+ T naive) =================
rna_d, adt_d = rna.layers['counts'], adt.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()
atac_d = np.zeros((rna_d.shape[0], 1))

# Build testing dataset (kept in file order: shuffle=False)
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding CD8+ T naive (train_loader)
#       and evaluated on all cells including CD8+ T naive (test_loader)
# NOTE(review): the six trailing hyperparameters are labelled on the assumption
#       that they follow the MultiGAI.multigai(...) constructor order used in the
#       trimodal cell of this notebook (n_hidden, hidden, z_dim, batch_dim, q_dim,
#       kv_n) — confirm against the train_and_evaluate_model signature.
# ================================================
MultiGAI.train_and_evaluate_model(
    './results/neurips-cite-CD8+Tnaive_latent.h5ad',
    train_loader,
    test_loader,
    rna,  # Only for saving latent variables and metadata
    rna_dim,
    atac_dim,
    adt_dim,
    1,          # n_hidden: number of hidden layers
    128,        # hidden: hidden layer dimension
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the batch one-hot vector
    128,        # q_dim: dimension of the query vector (q) in attention
    128         # kv_n: number of key-value (K-V) pairs in attention
)

In [None]:
# ================================================
# Load CITE-seq single-cell data (RNA + ADT)
# Train without HSC cells, evaluate on all cells
# ================================================

# Load RNA and protein (ADT) datasets
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# Cell type withheld from training
held_out_type = "HSC"

# Feature dimensions (ATAC is a single-column placeholder for CITE-seq data)
rna_dim, atac_dim, adt_dim = rna.shape[1], 1, adt.shape[1]

# ================= Per-cell annotations (computed once, on ALL cells) =================
# Translate Modality labels into numeric codes. Series.map() yields NaN for
# labels missing from the dict, so unknown labels are rejected up front.
modality_map = {'multiome': 12, 'cite': 13}
unknown_labels = set(rna.obs['Modality'].unique()) - set(modality_map)
if unknown_labels:
    raise ValueError(f"Modality labels without a numeric code: {unknown_labels}")
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# One-hot batch encoding. It is built once on the full data and sliced for
# training, so a batch keeps the same column (and the one-hot keeps the same
# width, `batch_dim`) in the training and the test loader even if removing the
# held-out cell type empties a batch.
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# ================= Training data (exclude HSC cells) =================
# Boolean row masks; each modality is filtered on its own `cell_type` column
keep_rna = (rna.obs["cell_type"] != held_out_type).to_numpy()
keep_adt = (adt.obs["cell_type"] != held_out_type).to_numpy()
rna_t = rna[keep_rna].copy()
adt_t = adt[keep_adt].copy()

# Raw count matrices of the training cells, densified if stored sparse
rna_t_d, adt_t_d = rna_t.layers['counts'], adt_t.layers['counts']
if issparse(rna_t_d): rna_t_d = rna_t_d.toarray()
if issparse(adt_t_d): adt_t_d = adt_t_d.toarray()

# Placeholder for ATAC (not present in CITE-seq)
atac_t_d = np.zeros((rna_t_d.shape[0], 1))

# Build training dataset from the training rows of every per-cell array
dataset_t = MultiGAI.MultiOmicsDataset(
    rna_t_d, atac_t_d, adt_t_d,
    modality_d[keep_rna],
    batch_encoded[torch.from_numpy(keep_rna)],
)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including HSC) =================
rna_d, adt_d = rna.layers['counts'], adt.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()
atac_d = np.zeros((rna_d.shape[0], 1))

# Build testing dataset (kept in file order: shuffle=False)
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding HSC (train_loader)
#       and evaluated on all cells including HSC (test_loader)
# NOTE(review): the six trailing hyperparameters are labelled on the assumption
#       that they follow the MultiGAI.multigai(...) constructor order used in the
#       trimodal cell of this notebook (n_hidden, hidden, z_dim, batch_dim, q_dim,
#       kv_n) — confirm against the train_and_evaluate_model signature.
# ================================================
MultiGAI.train_and_evaluate_model(
    './results/neurips-cite-HSC_latent.h5ad',
    train_loader,
    test_loader,
    rna,  # Only for saving latent variables and metadata
    rna_dim,
    atac_dim,
    adt_dim,
    1,          # n_hidden: number of hidden layers
    128,        # hidden: hidden layer dimension
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the batch one-hot vector
    128,        # q_dim: dimension of the query vector (q) in attention
    128         # kv_n: number of key-value (K-V) pairs in attention
)

In [None]:
# ================================================
# Load trimodal single-cell data (RNA + ATAC + ADT)
# ================================================

# Output directories for the joint embedding (e_dir) and the joint value matrix (v_dir)
e_dir = "./results"
v_dir = "./results"

# Output path for latent embeddings
output_path = './results/trimodal_latent.h5ad'

# Create every output location up front so the run cannot fail at the very
# end, after training, just because a directory is missing.
for out_dir in (e_dir, v_dir, os.path.dirname(output_path)):
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

# Load RNA, ATAC, and ADT data
rna = sc.read('./data/trimodal_rna.h5ad')
atac = sc.read('./data/trimodal_atac.h5ad')
adt = sc.read('./data/trimodal_adt.h5ad')

# Extract raw count matrices, densifying sparse storage
rna_d, atac_d, adt_d = rna.layers['counts'], atac.layers['counts'], adt.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()

# Feature dimensions
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], adt.shape[1]

# Translate Modality labels (multiome / cite) into numeric codes.
# Series.map() yields NaN for labels missing from the dict, and the loops
# below call int() on each code, so unknown labels are rejected up front.
modality_map = {'multiome': 12, 'cite': 13}
unknown_labels = set(rna.obs['Modality'].unique()) - set(modality_map)
if unknown_labels:
    raise ValueError(f"Modality labels without a numeric code: {unknown_labels}")
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# Encode batch information as one-hot vectors
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Construct PyTorch dataset and dataloaders
# (same data: shuffled for training, ordered for latent extraction)
dataset = MultiGAI.MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================================================
# Initialize the MultiGAI model
# Note: parameters explanation:
# n_hidden: number of hidden layers (encoder+decoder)
# hidden: hidden layer dimension (all layers)
# z_dim: latent dimension
# batch_dim: dimension of batch one-hot vector
# q_dim: dimension of query vector in attention
# kv_n: number of key-value pairs in attention
# ================================================
# Here, 'rna' is only used to save latent variables and provide metadata (.obs)
adata, input_dim1, input_dim2, input_dim3 = rna, rna_dim, atac_dim, adt_dim
n_hidden, hidden, z_dim, q_dim, kv_n = 1, 128, 30, 128, 128

model = MultiGAI.multigai(input_dim1, input_dim2, input_dim3,
                 n_hidden, hidden, z_dim,
                 batch_dim, q_dim, kv_n).to(device)

# Optimizer and learning rate scheduler (lr is multiplied by 0.9 every 50 epochs)
optimizer_main = Adam(model.parameters(), lr=0.001)
scheduler_main = torch.optim.lr_scheduler.StepLR(optimizer_main, step_size=50, gamma=0.9)

# ================================================
# Training loop
# ================================================
tqdm_bar = tqdm(range(200), desc="Training Progress")

for epoch in tqdm_bar:
    running_loss = 0.0
    running_recon = 0.0
    running_kl = 0.0
    running_cos = 0.0
    n_steps = 0  # optimizer steps taken this epoch (one per modality sub-batch)

    # KL weight schedule: a hard switch, no KL term for the first 100 epochs,
    # then a constant weight of 0.1
    kl_weight = 0.0 if epoch < 100 else 0.1

    model.train()
    for batch_data in train_loader:
        # NOTE(review): gradients are cleared once per mini-batch, but
        # backward() + step() run once per modality below, so the second
        # modality's step also contains the first modality's gradients.
        # Left unchanged to keep the original training dynamics — confirm
        # with the authors whether zero_grad() belongs inside the inner loop.
        optimizer_main.zero_grad()

        # Modality label for each cell in the batch
        # For example, values might indicate: 12 = Multiome, 13 = CITE-seq
        m_values = batch_data[3]
        unique_m = m_values.unique()

        # Visit the modalities present in this batch in random order
        unique_m = unique_m[torch.randperm(len(unique_m))]

        for m_curr in unique_m:
            # Sub-batch of the cells that carry the current modality code.
            # The mask is never empty because m_curr comes from m_values itself.
            mask = (m_values == m_curr)
            sub_batch = [d[mask] for d in batch_data]
            m1, m2, m3, m_tensor, batch_tensor, idx = [x.to(device) for x in sub_batch]
            modality_code = int(m_curr.item())

            # Forward pass through model
            # Outputs:
            # z     : sampled latent variable from qz (z ~ Normal(mu, sigma))
            # p1/p2/p3 : predicted ZINB distributions for RNA, ATAC, and ADT, respectively
            # qz/pz : distributions of the latent variable z
            #         qz = Normal(mu, sigma^0.5) predicted by the encoder (posterior)
            #         pz = standard Normal prior (mean 0, std 1) used in the VAE KL divergence
            # a1/a2/a3 : modality-specific intermediate representations for RNA, ATAC, and ADT
            # ae        : joint (integrated) representation across all modalities
            z, p1, p2, p3, qz, pz, a1, a2, a3, ae = model(
                m1, m2, m3,
                modality_code,
                batch_tensor
            )

            # Compute loss
            # Includes reconstruction loss, KL divergence, and cosine (alignment) loss
            loss, reconst_loss, kl_loss, cos_loss = model.loss_function(
                m1, m2, m3,
                modality_code,
                p1, p2, p3, qz, pz,
                a1, a2, a3,
                kl_weight
            )

            # Backpropagation
            loss.backward()
            optimizer_main.step()

            # Accumulate metrics
            running_loss += loss.item()
            running_recon += reconst_loss.item()
            running_kl += kl_loss.item()
            running_cos += cos_loss.item()
            n_steps += 1

    # Average over the number of accumulated terms. Dividing by
    # len(train_loader) would over-report whenever a mini-batch contains
    # more than one modality.
    denom = max(n_steps, 1)
    tqdm_bar.set_postfix({
        "loss": f"{running_loss/denom:.4f}",
        "recon": f"{running_recon/denom:.4f}",
        "kl": f"{running_kl/denom:.4f}",
        "cos": f"{running_cos/denom:.4f}",
        "w": f"{kl_weight:.3f}"
    })

    # Step the learning rate scheduler
    scheduler_main.step()

# ================================================
# Latent extraction
# ================================================
model.eval()

# Initialize storage for latent variables (z) and AE global embedding (ae)
z_all = torch.zeros((len(adata), z_dim), device=device)
ae_all_global = torch.zeros((len(adata), model.hidden), device=device)

with torch.no_grad():
    # Extract the joint value matrix. Evaluated under no_grad so the saved
    # tensor does not carry an autograd graph.
    val = model.values(torch.eye(model.kv_n, device=device))

    for batch_data in test_loader:
        m_values = batch_data[3]

        for m_curr in m_values.unique():
            mask = (m_values == m_curr)
            sub_batch = [d[mask] for d in batch_data]
            # `idx` (last element of each sample) holds the row indices in the full dataset
            m1, m2, m3, m_tensor, batch_tensor, idx = [x.to(device) for x in sub_batch]

            # Forward pass for current modality batch; only the latent z
            # (first output) and the joint embedding ae (last output) are kept
            outputs = model(m1, m2, m3, int(m_curr.item()), batch_tensor)
            z, ae_batch = outputs[0], outputs[-1]

            # Store embeddings at the cells' positions in the full dataset
            z_all[idx.long()] = z
            ae_all_global[idx.long()] = ae_batch

# Save latent embeddings to AnnData object
adata.obsm['latent'] = z_all.cpu().numpy()
adata.write_h5ad(output_path)

# Save joint embeddings and joint value matrix
torch.save(ae_all_global.cpu(), os.path.join(e_dir, "trimodal_e.pt"))
torch.save(val.cpu(), os.path.join(v_dir, "trimodal_v.pt"))

In [None]:
# Run MultiGAI's RNA/ATAC mapping routine on the NeurIPS Multiome mapping data.
# Arguments as passed here: a path under ./results (presumably the output
# .h5ad), the AnnData read from ./data/neurips-multiome/mapping.h5ad, then
# five positional hyperparameters.
# NOTE(review): 1, 128, 30, 128, 128 equal the (n_hidden, hidden, z_dim, q_dim,
# kv_n) values used in the trimodal training cell of this notebook; presumably
# the same order applies here — confirm against the function signature.
MultiGAI.rnaatacmapping('./results/neurips-multiome-mapping.h5ad', 
               sc.read('./data/neurips-multiome/mapping.h5ad'), 
               1, 128, 30, 128, 128)

In [None]:
# Run MultiGAI's RNA/ADT mapping routine on the NeurIPS CITE-seq mapping data.
# Arguments as passed here: a path under ./results (presumably the output
# .h5ad), the AnnData read from ./data/neurips-cite/mapping.h5ad, then five
# positional hyperparameters.
# NOTE(review): 1, 128, 30, 128, 128 equal the (n_hidden, hidden, z_dim, q_dim,
# kv_n) values used in the trimodal training cell of this notebook; presumably
# the same order applies here — confirm against the function signature.
MultiGAI.rnaadtmapping('./results/neurips-cite-mapping.h5ad', 
               sc.read('./data/neurips-cite/mapping.h5ad'), 
               1, 128, 30, 128, 128)

In [None]:
# Run MultiGAI's trimodal mapping-and-imputing routine.
# Arguments as passed here: a path under ./results (presumably the output
# .h5ad), three AnnData objects read from the RNA, ATAC and ADT
# "mappingandimputing" files, five positional hyperparameters, and two lists of
# feature names (the first looks like RNA gene symbols, the second like ADT
# protein names — TODO confirm).
# NOTE(review): 1, 128, 30, 128, 128 equal the (n_hidden, hidden, z_dim, q_dim,
# kv_n) values used in the trimodal training cell of this notebook; presumably
# the same order applies here — confirm against the function signature.
# NOTE(review): if the two name lists are paired by position, the order looks
# inconsistent: CD3G encodes a CD3 subunit, NCAM1 encodes CD56 and MS4A1
# encodes CD20, whereas the second list is ordered CD20, CD3, CD56. Verify how
# the function uses these lists before relying on per-feature comparisons.
MultiGAI.rnaatacadtmappingandimputing('./results/trimodal-mappingandimputing.h5ad',
           sc.read('./data/trimodal_mappingandimputing_rna.h5ad'), 
           sc.read('./data/trimodal_mappingandimputing_atac.h5ad'),
           sc.read('./data/trimodal_mappingandimputing_adt.h5ad'),
           1, 128, 30, 128, 128, ["CD3G", "NCAM1", "MS4A1"], ["CD20", "CD3", "CD56"])