In [None]:
import random
import numpy as np
import torch
from torch.distributions import Distribution, Normal, constraints, kl_divergence, Bernoulli, NegativeBinomial
from torch import nn
from torch.utils.data import DataLoader, Dataset
import scanpy as sc
import pandas as pd
from scipy.sparse import issparse
from torch.optim import Adam
import pickle
import anndata as ad
import math
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
def set_seed(seed):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, current CUDA device, all CUDA devices), then switches
    cuDNN to deterministic mode with the auto-tuner disabled.

    Parameters
    ----------
    seed : int
        Seed value passed unchanged to each generator.
    """
    seeders = (
        random.seed,                 # Python built-in RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch CPU RNG
        torch.cuda.manual_seed,      # PyTorch RNG of the current GPU
        torch.cuda.manual_seed_all,  # PyTorch RNG of every GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, and no auto-tuner picking algorithms per run
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed everything once, up front, so the runs below are reproducible
set_seed(42)


class ZeroInflatedNegativeBinomial(Distribution):
    """
    Zero-Inflated Negative Binomial (ZINB) distribution.

    A two-component mixture: with probability ``sigmoid(zi_logits)`` the
    observation is a structural zero, otherwise it is drawn from a Negative
    Binomial with mean ``mu`` and inverse dispersion ``theta``. Only
    ``log_prob`` is implemented here.
    """

    # Parameter constraints (consulted by the base class when validate_args=True)
    arg_constraints = {
        # Mean of the NB component
        "mu": constraints.greater_than_eq(0),
        # Inverse dispersion of the NB component
        "theta": constraints.greater_than_eq(0),
        # Unconstrained logits of the zero-inflation probability
        "zi_logits": constraints.real,
        # Scaling factor; stored on the instance but not used by log_prob
        "scale": constraints.greater_than_eq(0),
    }

    # Observations are non-negative integer counts
    support = constraints.nonnegative_integer

    def __init__(self, mu, theta, zi_logits, scale, eps=1e-8, validate_args=False):
        """
        Parameters
        ----------
        mu : torch.Tensor
            Mean of the Negative Binomial component.
        theta : torch.Tensor
            Inverse dispersion of the Negative Binomial component.
        zi_logits : torch.Tensor
            Logits of the zero-inflation (structural zero) probability.
        scale : torch.Tensor
            Scaling factor kept alongside the other parameters.
        eps : float
            Small constant added inside logarithms for numerical stability.
        validate_args : bool
            Forwarded to ``Distribution.__init__``.
        """
        # Parameters are attached before the base initialiser runs so that
        # argument validation (when enabled) can read them.
        self.mu, self.theta = mu, theta
        self.zi_logits, self.scale = zi_logits, scale
        self.eps = eps

        super().__init__(validate_args=validate_args)

    def log_prob(self, x):
        """
        Element-wise log-likelihood of counts ``x`` under the ZINB model.

        Parameters
        ----------
        x : torch.Tensor
            Observed counts, broadcastable against the distribution parameters.

        Returns
        -------
        torch.Tensor
            Log-probability of each entry of ``x``.
        """
        mu, theta, eps = self.mu, self.theta, self.eps

        # Probability of a structural zero
        zero_prob = torch.sigmoid(self.zi_logits)

        # Negative Binomial log-pmf, accumulated term by term
        nb_ll = torch.lgamma(x + theta)
        nb_ll = nb_ll - torch.lgamma(theta)
        nb_ll = nb_ll - torch.lgamma(x + 1)
        nb_ll = nb_ll + theta * torch.log(theta + eps)
        nb_ll = nb_ll + x * torch.log(mu + eps)
        nb_ll = nb_ll - (x + theta) * torch.log(mu + theta + eps)

        # x == 0: either a structural zero or a zero emitted by the NB component
        ll_at_zero = torch.log(zero_prob + (1 - zero_prob) * torch.exp(nb_ll) + eps)
        # x > 0: must come from the NB component, weighted by (1 - zero_prob)
        ll_positive = torch.log(1 - zero_prob + eps) + nb_ll

        is_zero = (x == 0)
        return torch.where(is_zero, ll_at_zero, ll_positive)

class multigain(nn.Module):
    """
    Multi-modal VAE with one encoder per modality, a shared latent space and
    ZINB output heads.

    Two modality combinations are supported, selected by the integer code
    ``m``: 12 (modality 1 + modality 2) and 13 (modality 1 + modality 3).
    """

    # Modality codes understood by the model: 12 -> m1 + m2, 13 -> m1 + m3
    _SUPPORTED_MODALITIES = (12, 13)

    def __init__(self, input_dim1, input_dim2, input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim=128, kv_n=64, dropout_rate=0.1):
        """
        Args:
            input_dim1, input_dim2, input_dim3: Input dimensions for the three modalities
            n_hidden: Number of hidden layers in each encoder and in the decoder
            hidden: Hidden layer dimension
            z_dim: Latent variable dimension
            batch_dim: One-hot batch encoding dimension
            q_dim: Query vector dimension (stored only; not used by this class)
            kv_n: Number of key/value vectors (stored only; not used by this class)
            dropout_rate: Dropout probability
        """
        super().__init__()
        self.kv_n = kv_n
        self.q_dim = q_dim
        self.z_dim = z_dim
        self.batch_dim = batch_dim
        self.hidden = hidden

        # ===== Encoder builder =====
        # n_hidden blocks of Linear -> LayerNorm -> ReLU -> Dropout
        def make_encoder(in_dim):
            layers = []
            for _ in range(n_hidden):
                layers.append(
                    nn.Sequential(
                        nn.Linear(in_dim, hidden),
                        nn.LayerNorm(hidden),
                        nn.ReLU(),
                        nn.Dropout(dropout_rate)
                    )
                )
                in_dim = hidden
            return nn.Sequential(*layers)

        # NOTE: sub-modules are created in a fixed order (and keep their
        # attribute names) so that seeded initialisation and state_dict keys
        # stay stable.
        self.encoder1 = make_encoder(input_dim1)
        self.encoder2 = make_encoder(input_dim2)
        self.encoder3 = make_encoder(input_dim3)

        # ===== Latent variable networks =====
        # m_net outputs the posterior mean, l_net the posterior log-variance
        self.m_net = nn.Linear(hidden, z_dim)
        self.l_net = nn.Linear(hidden, z_dim)

        # ===== Decoder =====
        # Shared base layers; the batch encoding is re-injected before every layer
        self.decoder_base = nn.ModuleList([
            nn.Sequential(
                nn.Linear(z_dim + batch_dim if i == 0 else hidden + batch_dim, hidden),
                nn.LayerNorm(hidden),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ) for i in range(n_hidden)
        ])

        # ZINB output heads for the three modalities:
        #   fc_scale*   -> per-feature proportions (softmax over features)
        #   fc_dropout* -> zero-inflation logits
        #   fc_r*       -> learnable log inverse-dispersion, one value per feature
        self.fc_scale1 = nn.Sequential(nn.Linear(hidden + batch_dim, input_dim1), nn.Softmax(dim=-1))
        self.fc_dropout1 = nn.Linear(hidden + batch_dim, input_dim1)
        self.fc_r1 = nn.Parameter(torch.randn(input_dim1))

        self.fc_scale2 = nn.Sequential(nn.Linear(hidden + batch_dim, input_dim2), nn.Softmax(dim=-1))
        self.fc_dropout2 = nn.Linear(hidden + batch_dim, input_dim2)
        self.fc_r2 = nn.Parameter(torch.randn(input_dim2))

        self.fc_scale3 = nn.Sequential(nn.Linear(hidden + batch_dim, input_dim3), nn.Softmax(dim=-1))
        self.fc_dropout3 = nn.Linear(hidden + batch_dim, input_dim3)
        self.fc_r3 = nn.Parameter(torch.randn(input_dim3))

    def _zinb_head(self, final, counts, scale_head, dropout_head, log_theta):
        """Build the ZINB distribution of one modality from the decoder output `final`."""
        scale = scale_head(final)
        zi_logits = dropout_head(final)
        # Observed library size (row sum of the input counts); the small
        # constant keeps the mean strictly positive for all-zero rows.
        library = counts.sum(1, keepdim=True) + 1e-8
        rate = library * scale
        return ZeroInflatedNegativeBinomial(
            mu=rate, theta=torch.exp(log_theta), zi_logits=zi_logits, scale=scale
        )

    def decode_from_z(self, dz, m1, m2, m3, m, batch):
        """
        Decode latent variable z into ZINB distributions for each modality.

        Args:
            dz: Latent variable z
            m1, m2, m3: Original modality inputs (used only for their library sizes)
            m: Modality type, 12 -> m1+m2, 13 -> m1+m3
            batch: One-hot batch encoding

        Returns:
            p1, p2, p3: ZINB distribution objects; the modality that is not
            part of combination `m` is returned as None

        Raises:
            ValueError: if `m` is not one of the supported modality codes
        """
        # Fail loudly, like forward() and loss_function(), instead of
        # silently returning (None, None, None).
        if m not in self._SUPPORTED_MODALITIES:
            raise ValueError(f"Unsupported modality m={m}")

        # The one-hot encoding may arrive as an integer tensor; cast it once
        # so concatenation with the float activations is explicit.
        batch = batch.to(dz.dtype)

        for layer in self.decoder_base:
            dz = layer(torch.cat([dz, batch], dim=1))  # Concatenate batch info
        final = torch.cat([dz, batch], dim=1)

        # Modality 1 is present in both supported combinations
        p1 = self._zinb_head(final, m1, self.fc_scale1, self.fc_dropout1, self.fc_r1)
        p2 = p3 = None
        if m == 12:
            p2 = self._zinb_head(final, m2, self.fc_scale2, self.fc_dropout2, self.fc_r2)
        else:
            p3 = self._zinb_head(final, m3, self.fc_scale3, self.fc_dropout3, self.fc_r3)

        return p1, p2, p3

    def forward(self, m1, m2, m3, m, batch):
        """
        Forward pass.

        Args:
            m1, m2, m3: Modality inputs
            m: Modality type, 12 -> m1+m2, 13 -> m1+m3
            batch: One-hot batch encoding

        Returns:
            z: Sampled latent variable
            p1, p2, p3: Decoded ZINB distributions (None for the unused modality)
            qz: Posterior distribution
            pz: Prior distribution
            a1_all, a2_all, a3_all: Encoder hidden representations
                (None for the unused modality)

        Raises:
            ValueError: if `m` is not one of the supported modality codes
        """
        if m not in self._SUPPORTED_MODALITIES:
            raise ValueError(f"Unsupported modality m={m}")

        # ===== Encode =====
        e1 = self.encoder1(m1)
        e2 = e3 = None
        if m == 12:
            e2 = self.encoder2(m2)
            partner = e2
        else:
            e3 = self.encoder3(m3)
            partner = e3
        ae = (e1 + partner) / 2.0  # Simple averaging of the two encodings

        # Latent variable mean and variance
        mu = self.m_net(ae)
        logvar = self.l_net(ae)
        var = torch.exp(logvar) + 1e-8
        var = torch.clamp(var, min=1e-6)  # Prevent numerical instability

        qz = Normal(mu, var.sqrt())  # Posterior
        z = qz.rsample()  # Reparameterization trick
        pz = Normal(torch.zeros_like(z), torch.ones_like(z))  # Prior

        # Decode
        p1, p2, p3 = self.decode_from_z(z, m1, m2, m3, m, batch)
        return z, p1, p2, p3, qz, pz, e1, e2, e3

    def loss_function(self, m1, m2, m3, m, p1, p2, p3, q, p, a1, a2, a3, w):
        """
        Compute total loss: reconstruction + KL divergence + cosine similarity loss

        Args:
            m1, m2, m3: Observed inputs of the three modalities
            m: Modality type, 12 -> m1+m2, 13 -> m1+m3
            p1,p2,p3: Distribution objects exposing log_prob (ZINB in this notebook)
            q, p: Posterior and prior distributions
            a1,a2,a3: Encoder hidden representations
            w: KL weight

        Returns:
            loss: Total loss (reconst_loss + w * kl + cos_loss)
            reconst_loss: Reconstruction loss
            kl: KL divergence
            cos_loss: Cosine alignment loss

        Raises:
            ValueError: if `m` is not one of the supported modality codes
        """
        # Pick the partner modality that accompanies modality 1
        if m == 12:
            p_other, x_other, a_other = p2, m2, a2
        elif m == 13:
            p_other, x_other, a_other = p3, m3, a3
        else:
            raise ValueError(f"Unsupported modality m={m}")

        # Negative log-likelihood: summed over features, averaged over cells
        reconst_loss1 = -p1.log_prob(m1).sum(-1).mean()
        reconst_loss_other = -p_other.log_prob(x_other).sum(-1).mean()
        reconst_loss = reconst_loss1 + reconst_loss_other

        # Alignment loss: pull the two encoder representations together
        cos_sim = F.cosine_similarity(a1, a_other, dim=1)
        cos_loss = (1 - cos_sim).mean()

        kl = torch.distributions.kl_divergence(q, p).sum(dim=-1).mean()  # KL divergence
        loss = reconst_loss + w * kl + cos_loss
        return loss, reconst_loss, kl, cos_loss
    
class MultiOmicsDataset(Dataset):
    """
    PyTorch Dataset for multi-omics data.

    Each sample contains:
    - Three modality feature vectors (m1, m2, m3)
    - A modality indicator m
    - A batch covariate
    - The sample index itself, so callers can scatter results back in order
    """

    def __init__(self, *args):
        """
        Positional arguments, in order: modality 1 data, modality 2 data,
        modality 3 data, modality indicators, batch covariates. Each must be
        indexable by sample position; the modality data may be dense arrays
        or scipy sparse matrices.
        """
        self.m1_data = args[0]     # Modality 1 data (e.g., scRNA-seq)
        self.m2_data = args[1]     # Modality 2 data (e.g., scATAC-seq)
        self.m3_data = args[2]     # Modality 3 data (e.g., ADT)
        self.m_data = args[3]      # Modality indicator (12 or 13)
        self.batch_data = args[4]  # Batch labels / covariates

    def __len__(self):
        # Number of samples
        return len(self.batch_data)

    @staticmethod
    def _feature_vector(row):
        """Convert one sample's features to a 1-D float32 tensor."""
        # Rows of a scipy sparse matrix are densified one at a time, so the
        # full matrix never has to be converted up front.
        if issparse(row):
            row = row.toarray()
        # reshape(-1) flattens a (1, D) row to (D,) while keeping a
        # single-feature row as shape (1,). The previous squeeze(0) turned
        # such a row into a 0-d scalar, producing (B,) batches that do not
        # match a Linear(1, hidden) encoder.
        return torch.tensor(row, dtype=torch.float32).reshape(-1)

    def __getitem__(self, idx):
        """Return (m1, m2, m3, m, batch, idx) for sample `idx`."""
        m1 = self._feature_vector(self.m1_data[idx])
        m2 = self._feature_vector(self.m2_data[idx])
        m3 = self._feature_vector(self.m3_data[idx])

        # Modality combination indicator (kept as a scalar tensor)
        m = torch.tensor(self.m_data[idx], dtype=torch.float32).squeeze(0)

        # Batch covariate, returned as stored
        batch = self.batch_data[idx]

        return m1, m2, m3, m, batch, idx
    
def train_and_evaluate_model(output_path,
                             train_loader, test_loader, adata,
                             *args,
                             num_epochs=200):
    """
    Train the multigain model and extract latent representations for the dataset.

    Each loader batch is split by modality code and one optimizer step is
    taken per modality sub-batch. After training, a latent vector is computed
    for every sample yielded by `test_loader`, stored in
    ``adata.obsm['latent']`` and the AnnData object is written to disk.

    Args:
        output_path: File path to save the AnnData object with latent space
            (its parent directory is created if missing)
        train_loader: PyTorch DataLoader for training batches
        test_loader: PyTorch DataLoader for test/evaluation batches
        adata: AnnData object containing the dataset
        *args: Model hyperparameters in order:
               input_dim1, input_dim2, input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n
        num_epochs: Number of training epochs

    Raises:
        ValueError: if `args` does not contain exactly nine values, or (from
            the model) if a batch carries an unsupported modality code.
    """
    # Function-scope import: the notebook's import cell does not provide `os`
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Unpack model hyperparameters
    input_dim1, input_dim2, input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n = args

    # Make sure the output directory exists before the (long) training run,
    # so the final write cannot fail on a missing folder.
    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

    # Initialize the model and move to GPU/CPU
    model = multigain(input_dim1, input_dim2, input_dim3,
                     n_hidden, hidden, z_dim,
                     batch_dim, q_dim, kv_n).to(device)

    # Optimizer and learning rate scheduler (lr decays by 0.9 every 50 epochs)
    optimizer_main = Adam(model.parameters(), lr=0.001)
    scheduler_main = torch.optim.lr_scheduler.StepLR(optimizer_main, step_size=50, gamma=0.9)

    tqdm_bar = tqdm(range(num_epochs), desc="Training Progress")

    # ===== Training Loop =====
    for epoch in tqdm_bar:
        running_loss = 0.0
        running_recon = 0.0
        running_kl = 0.0
        running_cos = 0.0
        n_steps = 0  # optimizer steps taken this epoch (one per modality sub-batch)

        # KL weight schedule: 0 for first 100 epochs, then 0.1
        kl_weight = 0.0 if epoch < 100 else 0.1

        model.train()
        for batch_data in train_loader:
            m_values = batch_data[3]  # modality identifiers for each sample
            unique_m = m_values.unique()  # get unique modalities in this batch

            # Shuffle modalities to randomize training order
            perm = torch.randperm(len(unique_m))
            unique_m = unique_m[perm]

            # Process each modality separately
            for m_curr in unique_m:
                mask = (m_values == m_curr)  # select samples with current modality
                if not mask.any():
                    continue

                # Select sub-batch corresponding to this modality
                sub_batch = [d[mask] for d in batch_data]
                m1, m2, m3, _, batch_tensor, _ = [x.to(device) for x in sub_batch]
                m_code = int(m_curr.item())

                # Reset gradients before EVERY step. Doing this only once per
                # loader batch would let the gradients of the first modality
                # leak into the update of the second one.
                optimizer_main.zero_grad()

                # Forward pass
                z, p1, p2, p3, qz, pz, a1, a2, a3 = model(
                    m1, m2, m3, m_code, batch_tensor
                )

                # Compute loss: reconstruction + KL + cosine similarity
                loss, reconst_loss, kl_loss, cos_loss = model.loss_function(
                    m1, m2, m3, m_code,
                    p1, p2, p3, qz, pz,
                    a1, a2, a3,
                    kl_weight
                )

                # Backpropagation
                loss.backward()
                optimizer_main.step()

                # Accumulate losses for reporting
                running_loss += loss.item()
                running_recon += reconst_loss.item()
                running_kl += kl_loss.item()
                running_cos += cos_loss.item()
                n_steps += 1

        # Average over the steps actually taken: a loader batch containing two
        # modalities contributes two steps. max() guards an empty loader.
        denom = max(n_steps, 1)
        tqdm_bar.set_postfix({
            "loss": f"{running_loss/denom:.4f}",
            "recon": f"{running_recon/denom:.4f}",
            "kl": f"{running_kl/denom:.4f}",
            "cos": f"{running_cos/denom:.4f}",
            "w": f"{kl_weight:.3f}"
        })

        # Step the learning rate scheduler
        scheduler_main.step()

    # ===== Evaluation =====
    model.eval()
    # Placeholder for latent vectors; rows never visited below stay all-zero
    z_all = torch.zeros((len(adata), z_dim), device=device)

    with torch.no_grad():
        for batch_data in test_loader:
            m_values = batch_data[3]
            unique_m = m_values.unique()

            for m_curr in unique_m:
                mask = (m_values == m_curr)
                if not mask.any():
                    continue

                # Sub-batch for this modality
                sub_batch = [d[mask] for d in batch_data]
                m1, m2, m3, _, batch_tensor, idx = [x.to(device) for x in sub_batch]

                # Forward pass to extract latent variable z.
                # NOTE(review): z is a sample from the posterior, not its
                # mean, so the stored embedding includes sampling noise.
                # Confirm this is intended before switching to qz.mean.
                z, _, _, _, _, _, _, _, _ = model(
                    m1, m2, m3, int(m_curr.item()), batch_tensor
                )

                # Store latent vectors at their original dataset positions
                z_all[idx.long()] = z

    # Save latent vectors to AnnData object
    adata.obsm['latent'] = z_all.cpu().numpy()
    adata.write_h5ad(output_path)

In [None]:
# --- Paired RNA + ATAC (multiome) run ---
rna = sc.read("./data/neurips-multiome/rna_hvg.h5ad")
atac = sc.read("./data/neurips-multiome/atac_hvf.h5ad")

# Work on the 'counts' layers; densify them only if they are stored sparse
rna_d = rna.layers['counts']
atac_d = atac.layers['counts']
if issparse(rna_d):
    rna_d = rna_d.toarray()
if issparse(atac_d):
    atac_d = atac_d.toarray()

# No third modality in this dataset: a single all-zero column is the placeholder
adt_d = np.zeros((rna_d.shape[0], 1))
rna_dim = rna.shape[1]
atac_dim = atac.shape[1]
adt_dim = 1

# Per-cell modality code understood by the model (12 = m1+m2, 13 = m1+m3)
modality_map = {'multiome': 12, 'cite': 13}
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# One-hot encode the batch labels via their categorical codes
batch_codes = rna.obs['batch'].astype('category').cat.codes.values
batch_indices = torch.from_numpy(batch_codes).long()
batch_encoded = F.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Same dataset for both loaders: shuffled for training, ordered for embedding
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# Positional hyperparameters after the AnnData object:
# input_dim1, input_dim2, input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n
train_and_evaluate_model(
    './results/neurips-multiome-multigain.h5ad',
    train_loader, test_loader, rna,
    rna_dim, atac_dim, adt_dim,
    1, 128, 30,
    batch_dim, 128, 128,
)

In [None]:
# --- Paired RNA + protein (CITE-seq) run ---
rna = sc.read("./data/neurips-cite/rna_hvg.h5ad")
adt = sc.read("./data/neurips-cite/protein.h5ad")

# Work on the 'counts' layers; densify them only if they are stored sparse
rna_d = rna.layers['counts']
adt_d = adt.layers['counts']
if issparse(rna_d):
    rna_d = rna_d.toarray()
if issparse(adt_d):
    adt_d = adt_d.toarray()

# No second modality in this dataset: a single all-zero column is the placeholder
atac_d = np.zeros((rna_d.shape[0], 1))
rna_dim = rna.shape[1]
atac_dim = 1
adt_dim = adt.shape[1]

# Per-cell modality code understood by the model (12 = m1+m2, 13 = m1+m3)
modality_map = {'multiome': 12, 'cite': 13}
modality_vector = rna.obs['Modality'].map(modality_map)
modality_d = modality_vector.to_numpy().astype(float)

# One-hot encode the batch labels via their categorical codes
batch_codes = rna.obs['batch'].astype('category').cat.codes.values
batch_indices = torch.from_numpy(batch_codes).long()
batch_encoded = F.one_hot(batch_indices)
batch_dim = batch_encoded.shape[1]

# Same dataset for both loaders: shuffled for training, ordered for embedding
dataset = MultiOmicsDataset(rna_d, atac_d, adt_d, modality_d, batch_encoded)
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset, batch_size=512, shuffle=False)

# Positional hyperparameters after the AnnData object:
# input_dim1, input_dim2, input_dim3, n_hidden, hidden, z_dim, batch_dim, q_dim, kv_n
train_and_evaluate_model(
    './results/neurips-cite-multigain.h5ad',
    train_loader, test_loader, rna,
    rna_dim, atac_dim, adt_dim,
    1, 128, 30,
    batch_dim, 128, 128,
)