In [None]:
"""
This Jupyter notebook contains the source code for multimodal integration in the MultiGAI framework.
"""

In [None]:
import random
import numpy as np
import torch
from torch.distributions import Distribution, Normal, constraints
from torch import nn
from torch.utils.data import DataLoader, Dataset
import scanpy as sc
from scipy.sparse import issparse
from torch.optim import Adam
import math
from tqdm import tqdm
import torch.nn.functional as F
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
def set_seed(seed):
    """
    Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU,
    the current CUDA device, and all CUDA devices), then switches cuDNN to
    deterministic, non-benchmarking mode.

    Parameters
    ----------
    seed : int
        Seed value passed to every generator.
    """
    seeders = (
        random.seed,                 # Python built-in RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch CPU RNG
        torch.cuda.manual_seed,      # PyTorch RNG of the current GPU
        torch.cuda.manual_seed_all,  # PyTorch RNG of every GPU
    )
    for seeder in seeders:
        seeder(seed)

    # cuDNN: force deterministic kernels and disable the auto-tuner, which
    # could otherwise pick different algorithms from run to run.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Fix random seed for reproducibility
set_seed(42)

In [None]:
class ZeroInflatedNegativeBinomial(Distribution):
    """
    Zero-Inflated Negative Binomial (ZINB) distribution.

    This distribution is commonly used to model over-dispersed count data
    with excessive zeros, such as scRNA-seq gene expression counts.

    The mixture is

        P(x) = pi * 1[x == 0] + (1 - pi) * NB(x; mu, theta),
        pi   = sigmoid(zi_logits),

    and ``log_prob`` evaluates it entirely in log-space so that it stays
    accurate when ``zi_logits`` is large in magnitude.
    """

    # Constraints on distribution parameters
    arg_constraints = {
        "mu": constraints.greater_than_eq(0),      # Mean of the NB distribution
        "theta": constraints.greater_than_eq(0),   # Inverse dispersion parameter
        "zi_logits": constraints.real,              # Logits for zero-inflation probability
        "scale": constraints.greater_than_eq(0),   # Optional scaling factor (e.g. library size)
    }

    # Support of the distribution: non-negative integers
    support = constraints.nonnegative_integer

    def __init__(self, mu, theta, zi_logits, scale, eps=1e-8, validate_args=False):
        """
        Parameters
        ----------
        mu : torch.Tensor
            Mean of the Negative Binomial distribution.
        theta : torch.Tensor
            Inverse dispersion parameter of the NB distribution.
        zi_logits : torch.Tensor
            Logits controlling the zero-inflation probability.
        scale : torch.Tensor
            Scaling factor applied to the mean (e.g. size factor). Stored for
            callers; not used by ``log_prob``.
        eps : float
            Small constant added inside the NB logarithms for numerical stability.
        validate_args : bool
            Whether to validate distribution arguments.
        """
        self.mu = mu
        self.theta = theta
        self.zi_logits = zi_logits
        self.scale = scale
        self.eps = eps

        # Initialize base Distribution class
        super().__init__(validate_args=validate_args)

    def log_prob(self, x):
        """
        Compute the log-probability of observed counts under the ZINB model.

        Parameters
        ----------
        x : torch.Tensor
            Observed count data (broadcastable against ``mu``/``theta``/``zi_logits``).

        Returns
        -------
        torch.Tensor
            Element-wise log-likelihood of each observation.
        """
        # log(pi) and log(1 - pi) straight from the logits:
        #   log(sigmoid(l))     = -softplus(-l)
        #   log(1 - sigmoid(l)) = -softplus(l)
        # Going through sigmoid() first loses all precision once |l| is large
        # (sigmoid(30) == 1.0 in float32) and clamps log(1 - pi) at log(eps).
        log_pi = -F.softplus(-self.zi_logits)
        log_one_minus_pi = -F.softplus(self.zi_logits)

        # Log-probability under the Negative Binomial distribution
        log_nb = (
            torch.lgamma(x + self.theta)
            - torch.lgamma(self.theta)
            - torch.lgamma(x + 1)
            + self.theta * torch.log(self.theta + self.eps)
            + x * torch.log(self.mu + self.eps)
            - (x + self.theta) * torch.log(self.mu + self.theta + self.eps)
        )

        # Zero-inflated mixture:
        # - x == 0: log(pi + (1 - pi) * NB(0)) via a stable log-sum-exp
        # - x  > 0: log(1 - pi) + log NB(x)
        log_prob_zero = torch.logaddexp(log_pi, log_one_minus_pi + log_nb)
        log_prob_nonzero = log_one_minus_pi + log_nb

        return torch.where(x == 0, log_prob_zero, log_prob_nonzero)

class multigai(nn.Module):
    """
    MultiGAI: A multi-modal generative integration model.

    Each modality is encoded by its own MLP and projected to a query vector.
    Every query attends over a modality-specific key/value bank (produced by
    small networks applied to identity tokens). The per-modality attention
    outputs are averaged, and the average attends over a shared key/value
    bank. That result parameterises a diagonal Gaussian posterior q(z | x).
    A shared decoder, conditioned on the batch covariate at every layer,
    feeds per-modality ZINB output heads.

    Supported modality combinations (argument ``m``):
        12 -> modality 1 + modality 2
        13 -> modality 1 + modality 3

    Note: the averaged attention output (width ``hidden``) is matched against
    keys of width ``q_dim`` in the shared stage, so ``hidden`` must equal
    ``q_dim``.
    """

    def __init__(
        self,
        input_dim1,
        input_dim2,
        input_dim3,
        n_hidden,
        hidden,
        z_dim,
        batch_dim,
        q_dim=128,
        kv_n=64,
        dropout_rate=0.1
    ):
        """
        Parameters
        ----------
        input_dim1, input_dim2, input_dim3 : int
            Number of features of modality 1, 2 and 3.
        n_hidden : int
            Number of hidden layers in every MLP (encoders, key/value nets, decoder).
        hidden : int
            Width of every hidden layer; must equal ``q_dim``.
        z_dim : int
            Latent dimension.
        batch_dim : int
            Width of the batch covariate concatenated to every decoder layer input.
        q_dim : int
            Query/key embedding dimension used in attention.
        kv_n : int
            Number of key/value tokens (rows of the identity fed to the K/V nets).
        dropout_rate : float
            Dropout probability after every hidden layer.

        Raises
        ------
        ValueError
            If ``hidden != q_dim``; the shared attention stage could not run.
        """
        super().__init__()

        # Fail fast instead of crashing with a shape error on the first forward pass.
        # This check creates no modules, so the random initialisation order below
        # (and therefore seeded reproducibility) is unaffected.
        if hidden != q_dim:
            raise ValueError(
                f"hidden ({hidden}) must equal q_dim ({q_dim}): the fused attention "
                "output (width hidden) is matched against keys of width q_dim"
            )

        # ===== Hyperparameters =====
        self.kv_n = kv_n              # Number of key/value tokens
        self.q_dim = q_dim            # Query embedding dimension
        self.z_dim = z_dim            # Latent space dimension
        self.batch_dim = batch_dim    # Batch covariate dimension
        self.hidden = hidden          # Hidden layer width

        # NOTE: submodules are created in a fixed order; changing it changes the
        # random initial weights obtained under a fixed seed.

        # ===== Shared encoder constructor =====
        def make_encoder(in_dim):
            """
            Build a multi-layer MLP encoder with LayerNorm and Dropout.
            """
            layers = []
            for _ in range(n_hidden):
                layers.append(
                    nn.Sequential(
                        nn.Linear(in_dim, hidden),
                        nn.LayerNorm(hidden),
                        nn.ReLU(),
                        nn.Dropout(dropout_rate)
                    )
                )
                in_dim = hidden
            return nn.Sequential(*layers)

        # ===== Modality-specific encoders =====
        self.encoder1 = make_encoder(input_dim1)
        self.encoder2 = make_encoder(input_dim2)
        self.encoder3 = make_encoder(input_dim3)

        # ===== Projection to query space =====
        self.q_net1 = nn.Linear(hidden, q_dim)
        self.q_net2 = nn.Linear(hidden, q_dim)
        self.q_net3 = nn.Linear(hidden, q_dim)

        # ===== Key / Value network constructor =====
        def make_kv(is_value):
            """
            Build key/value networks for attention-based latent fusion.

            Input is a (kv_n, kv_n) identity; key nets end with a projection
            to ``q_dim``, value nets stay at width ``hidden``.
            """
            layers = []
            in_dim = kv_n
            for _ in range(n_hidden):
                layers.append(
                    nn.Sequential(
                        nn.Linear(in_dim, hidden),
                        nn.LayerNorm(hidden),
                        nn.ReLU(),
                        nn.Dropout(dropout_rate)
                    )
                )
                in_dim = hidden
            if not is_value:
                layers.append(nn.Linear(hidden, q_dim))
            return nn.Sequential(*layers)

        # ===== Modality-specific key/value banks =====
        self.keys1 = make_kv(is_value=False)
        self.values1 = make_kv(is_value=True)
        self.keys2 = make_kv(is_value=False)
        self.values2 = make_kv(is_value=True)
        self.keys3 = make_kv(is_value=False)
        self.values3 = make_kv(is_value=True)

        # ===== Shared key/value banks =====
        self.keys = make_kv(is_value=False)
        self.values = make_kv(is_value=True)

        # ===== Latent Gaussian parameter heads =====
        self.m_net = nn.Linear(hidden, z_dim)   # Mean of q(z|x)
        self.l_net = nn.Linear(hidden, z_dim)   # Log-variance of q(z|x)

        # ===== Shared decoder backbone =====
        # The batch covariate is concatenated to the input of every layer.
        self.decoder_base = nn.ModuleList([
            nn.Sequential(
                nn.Linear(z_dim + batch_dim if i == 0 else hidden + batch_dim, hidden),
                nn.LayerNorm(hidden),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ) for i in range(n_hidden)
        ])

        # ===== ZINB decoders for each modality =====
        # Modality 1
        self.fc_scale1 = nn.Sequential(
            nn.Linear(hidden + batch_dim, input_dim1),
            nn.Softmax(dim=-1)
        )
        self.fc_dropout1 = nn.Linear(hidden + batch_dim, input_dim1)
        self.fc_r1 = nn.Parameter(torch.randn(input_dim1))  # Log inverse-dispersion

        # Modality 2
        self.fc_scale2 = nn.Sequential(
            nn.Linear(hidden + batch_dim, input_dim2),
            nn.Softmax(dim=-1)
        )
        self.fc_dropout2 = nn.Linear(hidden + batch_dim, input_dim2)
        self.fc_r2 = nn.Parameter(torch.randn(input_dim2))

        # Modality 3
        self.fc_scale3 = nn.Sequential(
            nn.Linear(hidden + batch_dim, input_dim3),
            nn.Softmax(dim=-1)
        )
        self.fc_dropout3 = nn.Linear(hidden + batch_dim, input_dim3)
        self.fc_r3 = nn.Parameter(torch.randn(input_dim3))

    @staticmethod
    def _attend(query, key_net, value_net, tokens, scale):
        """Scaled dot-product attention of ``query`` over the K/V bank built from ``tokens``."""
        keys = key_net(tokens)
        values = value_net(tokens)
        return torch.softmax((query @ keys.T) / scale, dim=-1) @ values

    def compute_mu_var(self, device, q1=None, q2=None, q3=None, m=None):
        """
        Compute latent mean and variance using attention-based fusion
        of modality-specific query embeddings.

        Parameters
        ----------
        device : torch.device
            Device on which the identity tokens are created.
        q1, q2, q3 : torch.Tensor or None
            Query embeddings of the modalities present for this ``m``.
        m : int
            Modality combination code (12 or 13).

        Returns
        -------
        tuple
            ``(mu, var, attn1, attn2, attn3, ae, attn)``: Gaussian mean and
            variance, the per-modality attention outputs (``None`` for absent
            modalities), their average ``ae``, and the shared-stage output ``attn``.

        Raises
        ------
        ValueError
            If ``m`` is not a supported modality code.
        """
        tokens = torch.eye(self.kv_n, device=device)   # Identity tokens
        scale = math.sqrt(self.q_dim)

        attn1 = attn2 = attn3 = None

        # ===== Modality-pair attention fusion =====
        if m == 12:
            attn1 = self._attend(q1, self.keys1, self.values1, tokens, scale)
            attn2 = self._attend(q2, self.keys2, self.values2, tokens, scale)
            ae = (attn1 + attn2) / 2.0
        elif m == 13:
            attn1 = self._attend(q1, self.keys1, self.values1, tokens, scale)
            attn3 = self._attend(q3, self.keys3, self.values3, tokens, scale)
            ae = (attn1 + attn3) / 2.0
        else:
            raise ValueError(f"Unsupported modality m={m}")

        # ===== Shared attention refinement =====
        attn = self._attend(ae, self.keys, self.values, tokens, scale)

        # ===== Latent Gaussian parameters =====
        mu = self.m_net(attn)
        logvar = self.l_net(attn)
        var = torch.exp(logvar) + 1e-8

        return mu, var, attn1, attn2, attn3, ae, attn

    def _zinb_head(self, final, counts, fc_scale, fc_dropout, fc_r):
        """
        Build one modality's ZINB output distribution.

        The mean is the softmax-normalised expression proportions times the
        observed library size (row sum of ``counts``).
        """
        scale = fc_scale(final)
        dropout = fc_dropout(final)
        library = torch.log(counts.sum(1, keepdim=True) + 1e-8)
        rate = torch.exp(library) * scale

        return ZeroInflatedNegativeBinomial(
            mu=rate,
            theta=torch.exp(fc_r),
            zi_logits=dropout,
            scale=scale
        )

    def decode_from_z(self, dz, m1, m2, m3, m, batch):
        """
        Decode latent variables into modality-specific ZINB distributions.

        Parameters
        ----------
        dz : torch.Tensor
            Latent samples.
        m1, m2, m3 : torch.Tensor
            Observed counts, used only for each modality's library size.
        m : int
            Modality combination code (12 or 13).
        batch : torch.Tensor
            Batch covariate, concatenated to every decoder layer input.

        Returns
        -------
        tuple
            ``(p1, p2, p3)``; entries for modalities absent under ``m`` are ``None``.
        """
        # One-hot batch encodings arrive as integer tensors; match the latent dtype
        # explicitly rather than relying on torch.cat type promotion.
        batch = batch.to(dz.dtype)

        for layer in self.decoder_base:
            dz = layer(torch.cat([dz, batch], dim=1))

        final = torch.cat([dz, batch], dim=1)

        p1 = p2 = p3 = None

        if m in (12, 13):
            # ===== Modality 1 =====
            p1 = self._zinb_head(final, m1, self.fc_scale1, self.fc_dropout1, self.fc_r1)

            # ===== Modality 2 =====
            if m == 12:
                p2 = self._zinb_head(final, m2, self.fc_scale2, self.fc_dropout2, self.fc_r2)

            # ===== Modality 3 =====
            if m == 13:
                p3 = self._zinb_head(final, m3, self.fc_scale3, self.fc_dropout3, self.fc_r3)

        return p1, p2, p3

    def forward(self, m1, m2, m3, m, batch):
        """
        Forward pass of MultiGAI.

        Returns
        -------
        tuple
            ``(z, p1, p2, p3, qz, pz, a1, a2, a3, fused, refined)``: latent
            sample, per-modality ZINB distributions, posterior ``qz``, standard
            normal prior ``pz``, per-modality attention outputs, their average
            (``fused``) and the shared-attention output (``refined``).
        """
        device = batch.device

        # ===== Encode =====
        q1 = q2 = q3 = None
        if m in (12, 13):
            q1 = self.q_net1(self.encoder1(m1))

            if m == 12:
                q2 = self.q_net2(self.encoder2(m2))

            if m == 13:
                q3 = self.q_net3(self.encoder3(m3))

        # ===== Latent inference =====
        mu, var, a1, a2, a3, fused, refined = self.compute_mu_var(device, q1, q2, q3, m)
        var = torch.clamp(var, min=1e-6)

        qz = Normal(mu, var.sqrt())
        z = qz.rsample()   # Reparameterised sample so gradients reach mu/var
        pz = Normal(torch.zeros_like(z), torch.ones_like(z))

        # ===== Decode =====
        p1, p2, p3 = self.decode_from_z(z, m1, m2, m3, m, batch)

        return z, p1, p2, p3, qz, pz, a1, a2, a3, fused, refined

    def loss_function(self, m1, m2, m3, m, p1, p2, p3, q, p, a1, a2, a3, w):
        """
        Compute total loss:
        reconstruction + w * KL divergence + cosine alignment loss.

        The reconstruction term is the ZINB negative log-likelihood of modality 1
        plus the paired modality (2 for ``m == 12``, 3 for ``m == 13``). The
        cosine term pulls the two per-modality attention outputs together.

        Returns
        -------
        tuple
            ``(loss, reconst_loss, kl, cos_loss)``.

        Raises
        ------
        ValueError
            If ``m`` is not a supported modality code.
        """
        if m == 12:
            p_other, x_other, a_other = p2, m2, a2
        elif m == 13:
            p_other, x_other, a_other = p3, m3, a3
        else:
            raise ValueError(f"Unsupported modality m={m}")

        reconst_loss = (
            -p1.log_prob(m1).sum(-1).mean()
            -p_other.log_prob(x_other).sum(-1).mean()
        )
        cos_loss = (1 - F.cosine_similarity(a1, a_other, dim=1)).mean()

        kl = torch.distributions.kl_divergence(q, p).sum(dim=-1).mean()

        loss = reconst_loss + w * kl + cos_loss
        return loss, reconst_loss, kl, cos_loss
    
class MultiOmicsDataset(Dataset):
    """
    PyTorch Dataset for multi-omics data.

    Positional constructor arguments (row-aligned: row i of every input
    describes the same cell):
        args[0] : modality 1 features (e.g. scRNA-seq)
        args[1] : modality 2 features (e.g. scATAC-seq)
        args[2] : modality 3 features (e.g. ADT)
        args[3] : modality indicator per cell (12 or 13)
        args[4] : batch covariate per cell (e.g. one-hot rows)

    Feature inputs may be dense arrays/tensors or scipy sparse matrices; sparse
    rows are densified one at a time in ``__getitem__``.

    Each sample is ``(m1, m2, m3, m, batch, idx)``:
    - ``m1``/``m2``/``m3``: 1-D float32 feature vectors (a single-feature
      modality keeps shape ``(1,)`` so it collates to ``(B, 1)``)
    - ``m``: 0-d float32 modality indicator
    - ``batch``: the batch covariate row, returned as stored
    - ``idx``: the sample index
    """

    def __init__(self, *args):
        self.m1_data = args[0]     # Modality 1 data (e.g., scRNA-seq)
        self.m2_data = args[1]     # Modality 2 data (e.g., scATAC-seq)
        self.m3_data = args[2]     # Modality 3 data (e.g., ADT)
        self.m_data = args[3]      # Modality indicator (12 or 13)
        self.batch_data = args[4]  # Batch labels / covariates

        # Samples are paired purely by position, so every input must have one
        # row per cell; a mismatch would otherwise silently pair wrong rows.
        n_cells = self._n_rows(self.batch_data)
        for name, data in (
            ("m1_data", self.m1_data),
            ("m2_data", self.m2_data),
            ("m3_data", self.m3_data),
            ("m_data", self.m_data),
        ):
            n_rows = self._n_rows(data)
            if n_rows != n_cells:
                raise ValueError(
                    f"{name} has {n_rows} rows but batch_data has {n_cells}; "
                    "all inputs must be row-aligned"
                )

    @staticmethod
    def _n_rows(data):
        """Number of rows of an array/tensor/sparse matrix (``len`` as fallback)."""
        shape = getattr(data, "shape", None)
        return shape[0] if shape is not None else len(data)

    @staticmethod
    def _features(row):
        """Convert one feature row (dense or sparse) to a 1-D float32 tensor."""
        if hasattr(row, "toarray"):   # scipy sparse row -> dense (1, n_features)
            row = row.toarray()
        # reshape(-1) instead of squeeze(0): a (1,) row must stay (1,), not become 0-d
        return torch.tensor(np.asarray(row), dtype=torch.float32).reshape(-1)

    def __len__(self):
        # Number of samples
        return self._n_rows(self.batch_data)

    def __getitem__(self, idx):
        # Convert each modality to a 1-D float tensor
        m1 = self._features(self.m1_data[idx])
        m2 = self._features(self.m2_data[idx])
        m3 = self._features(self.m3_data[idx])

        # Modality combination indicator as a 0-d tensor (collates to shape (B,))
        m = torch.tensor(self.m_data[idx], dtype=torch.float32).reshape(())

        # Batch covariate
        batch = self.batch_data[idx]

        return m1, m2, m3, m, batch, idx

def _modality_subbatches(batch_data, device, shuffle=False):
    """
    Split one collated batch into per-modality sub-batches moved to ``device``.

    Yields ``(m_code, tensors)`` where ``m_code`` is the integer modality code
    found in ``batch_data[3]`` and ``tensors`` is every batch component
    restricted to the cells carrying that code. With ``shuffle=True`` the order
    in which the modality groups are visited is randomised.
    """
    m_values = batch_data[3]
    unique_m = m_values.unique()
    if shuffle:
        unique_m = unique_m[torch.randperm(len(unique_m))]

    for m_curr in unique_m:
        mask = m_values == m_curr
        if mask.any():
            yield int(m_curr.item()), [d[mask].to(device) for d in batch_data]


def train_and_evaluate_model(
    output_path,
    train_loader,
    test_loader,
    adata,
    *args,
    num_epochs=200,
    use_posterior_mean=True
):
    """
    Train the MultiGAI model and extract latent representations.

    The final latent embeddings are stored in ``adata.obsm['latent']`` and
    ``adata`` is written to ``output_path`` (parent directory created if needed).

    Parameters
    ----------
    output_path : str
        Destination ``.h5ad`` file.
    train_loader, test_loader : DataLoader
        Loaders yielding ``(m1, m2, m3, m, batch, idx)`` batches.
    adata : AnnData
        Template object; only ``len(adata)``, ``.obsm`` and ``.write_h5ad`` are
        used. Row ``idx`` of the test data is stored at row ``idx`` of the result.
    *args
        Exactly nine model hyperparameters, in order: ``input_dim1, input_dim2,
        input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n``.
    num_epochs : int
        Number of training epochs.
    use_posterior_mean : bool
        If True (default), store the posterior mean of q(z|x) as the embedding,
        which is deterministic. If False, store one random sample from q(z|x)
        (the previous behaviour).

    Raises
    ------
    ValueError
        If the number of hyperparameters in ``args`` is not nine.
    """
    if len(args) != 9:
        raise ValueError(
            "expected 9 model hyperparameters (input_dim1, input_dim2, input_dim3, "
            f"n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n), got {len(args)}"
        )

    # Select training device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Unpack model hyperparameters
    input_dim1, input_dim2, input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n = args

    # Initialize model
    model = multigai(
        input_dim1, input_dim2, input_dim3,
        n_hidden, hidden, z_dim,
        batch_dim, q_dim, kv_n
    ).to(device)

    # Optimizer and learning rate scheduler (lr x 0.9 every 50 epochs)
    optimizer_main = Adam(model.parameters(), lr=0.001)
    scheduler_main = torch.optim.lr_scheduler.StepLR(
        optimizer_main, step_size=50, gamma=0.9
    )

    # Progress bar
    tqdm_bar = tqdm(range(num_epochs), desc="Training Progress")

    # ===== Training loop =====
    for epoch in tqdm_bar:
        running = {"loss": 0.0, "recon": 0.0, "kl": 0.0, "cos": 0.0}
        n_steps = 0

        # KL warm-up: KL term disabled for the first 100 epochs, then weighted 0.1
        kl_weight = 0.0 if epoch < 100 else 0.1

        model.train()

        for batch_data in train_loader:
            # Each modality combination in the batch is trained as its own
            # sub-batch (in random order), with its own optimizer step.
            for m_code, sub_batch in _modality_subbatches(batch_data, device, shuffle=True):
                m1, m2, m3, _, batch_tensor, _ = sub_batch

                # Gradients must be cleared before EVERY step; clearing once per
                # batch would add earlier sub-batches' gradients to later steps.
                optimizer_main.zero_grad()

                # Forward pass
                z, p1, p2, p3, qz, pz, a1, a2, a3, _, _ = model(
                    m1, m2, m3, m_code, batch_tensor
                )

                # Compute loss
                loss, reconst_loss, kl_loss, cos_loss = model.loss_function(
                    m1, m2, m3,
                    m_code,
                    p1, p2, p3,
                    qz, pz,
                    a1, a2, a3,
                    kl_weight
                )

                # Backpropagation and optimization
                loss.backward()
                optimizer_main.step()

                # Accumulate metrics
                running["loss"] += loss.item()
                running["recon"] += reconst_loss.item()
                running["kl"] += kl_loss.item()
                running["cos"] += cos_loss.item()
                n_steps += 1

        # Report per-step averages (a batch may contain several modality sub-batches)
        denom = max(n_steps, 1)
        postfix = {key: f"{value / denom:.4f}" for key, value in running.items()}
        postfix["w"] = f"{kl_weight:.3f}"
        tqdm_bar.set_postfix(postfix)

        scheduler_main.step()

    # ===== Evaluation and latent extraction =====
    model.eval()
    z_all = torch.zeros((len(adata), z_dim), device=device)

    with torch.no_grad():
        for batch_data in test_loader:
            for m_code, sub_batch in _modality_subbatches(batch_data, device):
                m1, m2, m3, _, batch_tensor, idx = sub_batch

                # Encode to latent space
                z, _, _, _, qz, *_ = model(m1, m2, m3, m_code, batch_tensor)

                # Store latent embeddings at original indices
                z_all[idx.long()] = qz.loc if use_posterior_mean else z

    # Save latent representation to AnnData
    adata.obsm['latent'] = z_all.cpu().numpy()
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    adata.write_h5ad(output_path)

In [None]:
# ================================================
# Load Multiome single-cell data (RNA + ATAC)
# ================================================

# Load RNA and ATAC datasets
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# Keep only cells with Modality equal to 'multiome'
rna = rna[rna.obs['Modality'] == 'multiome']

# MultiOmicsDataset pairs modalities by row position, so ATAC must be
# restricted to exactly the same cells, in the same order, as the filtered RNA.
atac = atac[rna.obs_names]

# Extract raw count matrices for RNA and ATAC
rna_d, atac_d = rna.layers['counts'], atac.layers['counts']

# Create a placeholder matrix for ADT (protein) data (not present in Multiome dataset)
adt_d = np.zeros((rna_d.shape[0], 1))

# Record feature dimensions for each modality
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], 1

# Convert sparse matrices to dense arrays if needed
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()

# Construct modality vector (all cells are Multiome)
modality_vector = np.full(rna.shape[0], 12.0)  # 12 corresponds to Multiome

# Encode batch information as one-hot vectors for batch effect correction
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Build the integrated Multi-Omics dataset
# Includes RNA, ATAC, ADT (placeholder), modality, and batch information
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_vector, batch_encoded)

# Create training and testing DataLoaders
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: The `rna` object passed here is NOT used for training.
#       It only provides metadata (.obs) and serves as the template
#       to save the learned latent variables into a .h5ad file.
#       The model integrates Multiome data (RNA + ATAC),
#       supports batch effect correction, and outputs a unified latent representation
# ================================================
train_and_evaluate_model(
    './results/neurips-multiome-multigai_latent.h5ad',
    train_loader,
    test_loader,
    rna,        # Only used to save latent variables, not for training
    rna_dim,    # input_dim1: number of RNA features
    atac_dim,   # input_dim2: number of ATAC features
    adt_dim,    # input_dim3: ADT placeholder width (modality unused here)
    1,          # n_hidden: hidden layers in every MLP (encoders, K/V nets, decoder)
    128,        # hidden: width of every hidden layer (must equal q_dim)
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the one-hot batch covariate
    128,        # q_dim: query/key dimension in attention
    128         # kv_n: number of key-value (K-V) tokens in attention
)

In [None]:
# ================================================
# Load CITE-seq single-cell data (RNA + ADT)
# ================================================

# Load RNA and protein (ADT) datasets
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# MultiOmicsDataset pairs modalities by row position, so both objects must
# list the same cells in the same order.
if not rna.obs_names.equals(adt.obs_names):
    raise ValueError("RNA and ADT objects must contain the same cells in the same order")

# Extract raw count matrices for RNA and ADT
rna_d, adt_d = rna.layers['counts'], adt.layers['counts']

# Create a placeholder matrix for ATAC (not present in CITE-seq dataset)
atac_d = np.zeros((rna_d.shape[0], 1))

# Record feature dimensions for each modality
rna_dim, atac_dim, adt_dim = rna.shape[1], 1, adt.shape[1]

# Convert sparse matrices to dense arrays if needed
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()

# Construct modality vector to distinguish Multiome and CITE-seq data
modality_map = {'multiome': 12, 'cite': 13}
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# Unmapped Modality values become NaN and would only fail later inside training
unknown_modality = np.isnan(modality_d)
if unknown_modality.any():
    raise ValueError(
        f"Unknown Modality values: {set(rna.obs['Modality'][unknown_modality])}"
    )

# Encode batch information as one-hot vectors for batch effect correction
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Build the integrated Multi-Omics dataset
# Includes RNA, ATAC (placeholder), ADT, modality, and batch information
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)

# Create training and testing DataLoaders
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: The `rna` object passed here is NOT used for training.
#       It only provides metadata (.obs) and serves as the template
#       to save the learned latent variables into a .h5ad file.
#       The model integrates CITE-seq data (RNA + ADT),
#       supports batch effect correction, and outputs a unified latent representation
# ================================================
train_and_evaluate_model(
    './results/neurips-cite-multigai_latent.h5ad',
    train_loader,
    test_loader,
    rna,        # Only used to save latent variables, not for training
    rna_dim,    # input_dim1: number of RNA features
    atac_dim,   # input_dim2: ATAC placeholder width (modality unused here)
    adt_dim,    # input_dim3: number of ADT features
    1,          # n_hidden: hidden layers in every MLP (encoders, K/V nets, decoder)
    128,        # hidden: width of every hidden layer (must equal q_dim)
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the one-hot batch covariate
    128,        # q_dim: query/key dimension in attention
    128         # kv_n: number of key-value (K-V) tokens in attention
)

In [None]:
# ================================================
# Hold-out experiment on Multiome data (RNA + ATAC):
# train without NK cells, then embed all cells
# ================================================

# Load RNA and ATAC datasets
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# MultiOmicsDataset pairs modalities by row position, so both objects must
# list the same cells in the same order.
if not rna.obs_names.equals(atac.obs_names):
    raise ValueError("RNA and ATAC objects must contain the same cells in the same order")

# Cell type excluded from training (but still embedded at test time).
# A single mask is used for every modality so the rows stay paired.
heldout_cell_type = "NK"
train_idx = np.flatnonzero((rna.obs["cell_type"] != heldout_cell_type).to_numpy())

# ================= Full data (all cells, including NK) =================
# Extract raw count matrices and convert sparse to dense if needed
rna_d, atac_d = rna.layers['counts'], atac.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()

# Placeholder for ADT (not present in Multiome)
adt_d = np.zeros((rna_d.shape[0], 1))

# Feature dimensions
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], 1

# Modality code per cell (12 = Multiome, 13 = CITE-seq)
modality_map = {'multiome': 12, 'cite': 13}
modality_d = rna.obs['Modality'].map(modality_map).to_numpy().astype(float)

# One-hot batch encoding computed ONCE on all cells, so training and test cells
# share the same batch -> column mapping and the same width (batch_dim), even if
# some batch has no cells left after the hold-out.
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# ================= Training data (exclude NK cells) =================
dataset_t = MultiOmicsDataset(
    rna_d[train_idx],
    atac_d[train_idx],
    adt_d[train_idx],
    modality_d[train_idx],
    batch_encoded[torch.from_numpy(train_idx)],
)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including NK) =================
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding NK (train_loader)
#       and evaluated on all cells including NK (test_loader)
# ================================================
train_and_evaluate_model(
    './results/neurips-multiome-NK_latent.h5ad',
    train_loader,
    test_loader,
    rna,        # Only for saving latent variables and metadata
    rna_dim,    # input_dim1: number of RNA features
    atac_dim,   # input_dim2: number of ATAC features
    adt_dim,    # input_dim3: ADT placeholder width (modality unused here)
    1,          # n_hidden: hidden layers in every MLP (encoders, K/V nets, decoder)
    128,        # hidden: width of every hidden layer (must equal q_dim)
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the one-hot batch covariate
    128,        # q_dim: query/key dimension in attention
    128         # kv_n: number of key-value (K-V) tokens in attention
)

In [None]:
# ================================================
# Hold-out experiment on Multiome data (RNA + ATAC):
# train without Lymph prog cells, then embed all cells
# ================================================

# Load RNA and ATAC datasets
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# MultiOmicsDataset pairs modalities by row position, so both objects must
# list the same cells in the same order.
if not rna.obs_names.equals(atac.obs_names):
    raise ValueError("RNA and ATAC objects must contain the same cells in the same order")

# Cell type excluded from training (but still embedded at test time).
# A single mask is used for every modality so the rows stay paired.
heldout_cell_type = "Lymph prog"
train_idx = np.flatnonzero((rna.obs["cell_type"] != heldout_cell_type).to_numpy())

# ================= Full data (all cells, including Lymph prog) =================
# Extract raw count matrices and convert sparse to dense if needed
rna_d, atac_d = rna.layers['counts'], atac.layers['counts']
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()

# Placeholder for ADT (not present in Multiome)
adt_d = np.zeros((rna_d.shape[0], 1))

# Feature dimensions
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], 1

# Modality code per cell (12 = Multiome, 13 = CITE-seq)
modality_map = {'multiome': 12, 'cite': 13}
modality_d = rna.obs['Modality'].map(modality_map).to_numpy().astype(float)

# One-hot batch encoding computed ONCE on all cells, so training and test cells
# share the same batch -> column mapping and the same width (batch_dim), even if
# some batch has no cells left after the hold-out.
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# ================= Training data (exclude Lymph prog cells) =================
dataset_t = MultiOmicsDataset(
    rna_d[train_idx],
    atac_d[train_idx],
    adt_d[train_idx],
    modality_d[train_idx],
    batch_encoded[torch.from_numpy(train_idx)],
)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including Lymph prog) =================
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding Lymph prog (train_loader)
#       and evaluated on all cells including Lymph prog (test_loader)
# ================================================
train_and_evaluate_model(
    './results/neurips-multiome-Lymphprog_latent.h5ad',
    train_loader,
    test_loader,
    rna,        # Only for saving latent variables and metadata
    rna_dim,    # input_dim1: number of RNA features
    atac_dim,   # input_dim2: number of ATAC features
    adt_dim,    # input_dim3: ADT placeholder width (modality unused here)
    1,          # n_hidden: hidden layers in every MLP (encoders, K/V nets, decoder)
    128,        # hidden: width of every hidden layer (must equal q_dim)
    30,         # z_dim: latent dimension
    batch_dim,  # batch_dim: width of the one-hot batch covariate
    128,        # q_dim: query/key dimension in attention
    128         # kv_n: number of key-value (K-V) tokens in attention
)

In [None]:
# ================================================
# Load CITE-seq single-cell data (RNA + ADT)
# ================================================

# Load RNA and protein (ADT) datasets
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# RNA and ADT matrices are paired row by row below, so order the ADT cells exactly like
# the RNA cells (indexing by name raises a KeyError if an RNA cell has no ADT profile).
if not adt.obs_names.equals(rna.obs_names):
    adt = adt[rna.obs_names].copy()

# ================= Training data (exclude CD8+ T naive cells) =================
# Select only cells that are NOT CD8+ T naive. The ADT subset reuses the RNA cell names
# so both modalities keep identical cells in identical order.
rna_t = rna[rna.obs_names[rna.obs["cell_type"] != "CD8+ T naive"]].copy()
adt_t = adt[rna_t.obs_names].copy()

# Extract raw count matrices
rna_t_d, adt_t_d = rna_t.layers['counts'], adt_t.layers['counts']

# Placeholder for ATAC (not present in CITE-seq)
atac_t_d = np.zeros((rna_t_d.shape[0], 1))

# Feature dimensions
rna_dim, atac_dim, adt_dim = rna.shape[1], 1, adt.shape[1]

# Convert sparse matrices to dense arrays if needed
if issparse(rna_t_d): rna_t_d = rna_t_d.toarray()
if issparse(adt_t_d): adt_t_d = adt_t_d.toarray()

# Construct modality vector (all cells are CITE-seq). A label missing from the map
# becomes NaN, so fail early instead of building a dataset with NaN modality ids.
modality_map = {'multiome': 12, 'cite': 13}
modality_vector = rna_t.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)
assert not np.isnan(modality_d).any(), "some cells have a 'Modality' value missing from modality_map"

# Encode batch information as one-hot. The batch categories come from the FULL dataset so
# training and test cells share one code per batch and one one-hot width (batch_dim sizes
# the model), even if a batch loses all its cells in the training subset.
batch_categories = rna.obs['batch'].astype('category').cat.categories
batch_dim = len(batch_categories)
batch_indices = torch.from_numpy(
    rna_t.obs['batch'].astype('category').cat.set_categories(batch_categories).cat.codes.values
).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices, num_classes=batch_dim)

# Build training dataset
dataset_t = MultiOmicsDataset(rna_t_d, atac_t_d, adt_t_d, modality_d, batch_encoded)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including CD8+ T naive) =================
rna_d, adt_d = rna.layers['counts'], adt.layers['counts']
atac_d = np.zeros((rna_d.shape[0], 1))

# Convert sparse to dense if needed
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()

# Modality vector
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)
assert not np.isnan(modality_d).any(), "some cells have a 'Modality' value missing from modality_map"

# Batch encoding with the same categories and width as the training data
batch_indices = torch.from_numpy(
    rna.obs['batch'].astype('category').cat.set_categories(batch_categories).cat.codes.values
).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices, num_classes=batch_dim)

# Build testing dataset
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding CD8+ T naive (train_loader)
#       and evaluated on all cells including CD8+ T naive (test_loader)
# ================================================
train_and_evaluate_model(
    './results/neurips-cite-CD8+Tnaive_latent.h5ad',
    train_loader, 
    test_loader, 
    rna,  # Only for saving latent variables and metadata
    rna_dim, 
    atac_dim, 
    adt_dim, 
    1,          # Number of hidden layers for the entire network (encoder + decoder)
    128,        # Hidden dimension for all layers in the network
    30,         # Latent dimension
    batch_dim,  # Dimension of the query vector (q) in attention mechanism
    128,        # Decoder hidden dimension (if different from shared hidden layers)
    128         # Number of key-value (K-V) pairs used in attention
)

In [None]:
# ================================================
# Load CITE-seq single-cell data (RNA + ADT)
# ================================================

# Load RNA and protein (ADT) datasets
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# RNA and ADT matrices are paired row by row below, so order the ADT cells exactly like
# the RNA cells (indexing by name raises a KeyError if an RNA cell has no ADT profile).
if not adt.obs_names.equals(rna.obs_names):
    adt = adt[rna.obs_names].copy()

# ================= Training data (exclude HSC cells) =================
# Select only cells that are NOT HSC. The ADT subset reuses the RNA cell names so both
# modalities keep identical cells in identical order.
rna_t = rna[rna.obs_names[rna.obs["cell_type"] != "HSC"]].copy()
adt_t = adt[rna_t.obs_names].copy()

# Extract raw count matrices
rna_t_d, adt_t_d = rna_t.layers['counts'], adt_t.layers['counts']

# Placeholder for ATAC (not present in CITE-seq)
atac_t_d = np.zeros((rna_t_d.shape[0], 1))

# Feature dimensions
rna_dim, atac_dim, adt_dim = rna.shape[1], 1, adt.shape[1]

# Convert sparse matrices to dense arrays if needed
if issparse(rna_t_d): rna_t_d = rna_t_d.toarray()
if issparse(adt_t_d): adt_t_d = adt_t_d.toarray()

# Construct modality vector (all cells are CITE-seq). A label missing from the map
# becomes NaN, so fail early instead of building a dataset with NaN modality ids.
modality_map = {'multiome': 12, 'cite': 13}
modality_vector = rna_t.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)
assert not np.isnan(modality_d).any(), "some cells have a 'Modality' value missing from modality_map"

# Encode batch information as one-hot. The batch categories come from the FULL dataset so
# training and test cells share one code per batch and one one-hot width (batch_dim sizes
# the model), even if a batch loses all its cells in the training subset.
batch_categories = rna.obs['batch'].astype('category').cat.categories
batch_dim = len(batch_categories)
batch_indices = torch.from_numpy(
    rna_t.obs['batch'].astype('category').cat.set_categories(batch_categories).cat.codes.values
).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices, num_classes=batch_dim)

# Build training dataset
dataset_t = MultiOmicsDataset(rna_t_d, atac_t_d, adt_t_d, modality_d, batch_encoded)
train_loader = DataLoader(dataset_t, batch_size=512, shuffle=True)

# ================= Test data (all cells, including HSC) =================
rna_d, adt_d = rna.layers['counts'], adt.layers['counts']
atac_d = np.zeros((rna_d.shape[0], 1))

# Convert sparse to dense if needed
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()

# Modality vector
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)
assert not np.isnan(modality_d).any(), "some cells have a 'Modality' value missing from modality_map"

# Batch encoding with the same categories and width as the training data
batch_indices = torch.from_numpy(
    rna.obs['batch'].astype('category').cat.set_categories(batch_categories).cat.codes.values
).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices, num_classes=batch_dim)

# Build testing dataset
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# ================================================
# Train and evaluate the MultiGAI model
# Note: `rna` is only used to save latent variables and provide metadata (.obs)
#       The model is trained on cells excluding HSC (train_loader)
#       and evaluated on all cells including HSC (test_loader)
# ================================================
train_and_evaluate_model(
    './results/neurips-cite-HSC_latent.h5ad',
    train_loader, 
    test_loader, 
    rna,  # Only for saving latent variables and metadata
    rna_dim, 
    atac_dim, 
    adt_dim, 
    1,          # Number of hidden layers for the entire network (encoder + decoder)
    128,        # Hidden dimension for all layers in the network
    30,         # Latent dimension
    batch_dim,  # Dimension of the query vector (q) in attention mechanism
    128,        # Decoder hidden dimension (if different from shared hidden layers)
    128         # Number of key-value (K-V) pairs used in attention
)

In [None]:
# ================================================
# Trimodal integration (RNA + ATAC + ADT): train MultiGAI on all cells, then save the
# latent space, the per-cell attention outputs (ae / aq) and the value embeddings.
# ================================================
rna = sc.read('./data/trimodal_rna.h5ad')
atac = sc.read('./data/trimodal_atac.h5ad')
adt = sc.read('./data/trimodal_adt.h5ad')

# The three count matrices are paired row by row, so all objects must list the same
# cells in the same order.
assert rna.obs_names.equals(atac.obs_names), "RNA and ATAC cells are not in the same order"
assert rna.obs_names.equals(adt.obs_names), "RNA and ADT cells are not in the same order"

# Output locations; create the directory so the writes at the end of the cell do not
# fail on a fresh checkout.
t_dir = "./results"
os.makedirs(t_dir, exist_ok=True)
output_path = './results/trimodal.h5ad'

# Raw counts as dense arrays, and the feature dimension of each modality
rna_d, atac_d, adt_d = rna.layers['counts'], atac.layers['counts'], adt.layers['counts']
rna_dim, atac_dim, adt_dim = rna.shape[1], atac.shape[1], adt.shape[1]
if issparse(rna_d): rna_d = rna_d.toarray()
if issparse(atac_d): atac_d = atac_d.toarray()
if issparse(adt_d): adt_d = adt_d.toarray()

# Per-cell modality id. A label missing from the map becomes NaN, and NaN never compares
# equal to any id in the loops below, so such cells would be silently skipped.
modality_map = {'multiome': 12, 'cite': 13}
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)
assert not np.isnan(modality_d).any(), "some cells have a 'Modality' value missing from modality_map"

# One-hot batch covariate
batch_indices = torch.from_numpy(rna.obs['batch'].astype('category').cat.codes.values).long()
batch_encoded = torch.nn.functional.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# The same dataset is used for training (shuffled) and for extracting embeddings (ordered)
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration
adata = rna  # receives the latent embedding in .obsm['latent'] and is written to output_path
input_dim1, input_dim2, input_dim3 = rna_dim, atac_dim, adt_dim
n_hidden, hidden, z_dim = 1, 128, 30
q_dim, kv_n = 128, 128
model = multigai(input_dim1, input_dim2, input_dim3,
                 n_hidden, hidden, z_dim,
                 batch_dim, q_dim, kv_n).to(device)

# Training schedule: the KL term is switched off for the first KL_WARMUP_EPOCHS epochs
N_EPOCHS = 200
KL_WARMUP_EPOCHS = 100
KL_WEIGHT = 0.1

optimizer_main = Adam(model.parameters(), lr=0.001)
scheduler_main = torch.optim.lr_scheduler.StepLR(optimizer_main, step_size=50, gamma=0.9)
tqdm_bar = tqdm(range(N_EPOCHS), desc="Training Progress")

for epoch in tqdm_bar:
    running_loss = 0.0
    running_recon = 0.0
    running_kl = 0.0
    running_cos = 0.0
    n_steps = 0  # optimizer steps this epoch: one per modality present in each mini-batch

    kl_weight = 0.0 if epoch < KL_WARMUP_EPOCHS else KL_WEIGHT

    model.train()
    for batch_data in train_loader:
        m_values = batch_data[3]
        # Visit the modalities present in this mini-batch in random order
        unique_m = m_values.unique()
        unique_m = unique_m[torch.randperm(len(unique_m))]

        for m_curr in unique_m:
            mask = (m_values == m_curr)
            if not mask.any():
                continue

            sub_batch = [d[mask] for d in batch_data]
            m1, m2, m3, m_tensor, batch_tensor, idx = [x.to(device) for x in sub_batch]
            modality_id = int(m_curr.item())

            # Each modality sub-batch is its own optimizer step, so clear the gradients
            # left by the previous sub-batch instead of accumulating them into this step.
            optimizer_main.zero_grad()

            z, p1, p2, p3, qz, pz, a1, a2, a3, aq, ae = model(
                m1, m2, m3,
                modality_id,
                batch_tensor
            )

            loss, reconst_loss, kl_loss, cos_loss = model.loss_function(
                m1, m2, m3,
                modality_id,
                p1, p2, p3, qz, pz,
                a1, a2, a3,
                kl_weight
            )

            loss.backward()
            optimizer_main.step()

            running_loss += loss.item()
            running_recon += reconst_loss.item()
            running_kl += kl_loss.item()
            running_cos += cos_loss.item()
            n_steps += 1

    # Average over optimizer steps: the sums above are accumulated per sub-batch step,
    # not per mini-batch.
    denom = max(n_steps, 1)
    tqdm_bar.set_postfix({
        "loss": f"{running_loss/denom:.4f}",
        "recon": f"{running_recon/denom:.4f}",
        "kl": f"{running_kl/denom:.4f}",
        "cos": f"{running_cos/denom:.4f}",
        "w": f"{kl_weight:.3f}"
    })

    scheduler_main.step()

# ---------------- Extract per-cell outputs in dataset order ----------------
model.eval()
z_all = torch.zeros((len(adata), z_dim), device=device)
ae_all_global = torch.zeros((len(adata), model.hidden), device=device)
aq_all_global = torch.zeros((len(adata), model.hidden), device=device)

with torch.no_grad():
    # model.values applied to the kv_n x kv_n identity matrix; saved as trimodal_v.pt
    val = model.values(torch.eye(model.kv_n, device=device))

    for batch_data in test_loader:
        m_values = batch_data[3]
        for m_curr in m_values.unique():
            mask = (m_values == m_curr)
            if not mask.any():
                continue

            sub_batch = [d[mask] for d in batch_data]
            m1, m2, m3, m_tensor, batch_tensor, idx = [x.to(device) for x in sub_batch]

            # Only z and the last two outputs (aq, ae) are kept
            z, *_, aq_batch, ae_batch = model(
                m1, m2, m3, int(m_curr.item()), batch_tensor
            )

            # idx (last dataset field) gives each cell's row in adata
            rows = idx.long()
            z_all[rows] = z
            ae_all_global[rows] = ae_batch
            aq_all_global[rows] = aq_batch

adata.obsm['latent'] = z_all.cpu().numpy()
adata.write_h5ad(output_path)

torch.save(ae_all_global.cpu(), os.path.join(t_dir, "trimodal_e.pt"))
torch.save(aq_all_global.cpu(), os.path.join(t_dir, "trimodal_q.pt"))
torch.save(val.cpu(), os.path.join(t_dir, "trimodal_v.pt"))