In [1]:
from torch import Tensor
from torch_geometric.typing import Adj

from datasets.preprocess import remove_lsi_key, preprocess
import scanpy as sc

modalities = ["adt", "RNA", 'atac']

datasets = {}

for modality in modalities:
    try:
        datasets[modality] = preprocess(modality, f"../datasets/data/processed/PBMC-DOGMA_{modality}.h5ad", n_pcs=100)
    except ValueError:
        remove_lsi_key(f"../datasets/data/processed/PBMC-DOGMA_{modality}.h5ad")
        datasets[modality] = preprocess(modality, f"../datasets/data/processed/PBMC-DOGMA_{modality}.h5ad", n_pcs=100)
datasets

  view_to_actual(adata)


{'adt': AnnData object with n_obs × n_vars = 13763 × 210
     obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'stim', 'celltype', 'nCount_atac', 'nFeature_atac', 'nCount_adt', 'nFeature_adt'
     var: 'features', 'mean', 'std'
     uns: 'pca'
     obsm: 'X_apca', 'X_apca.raw', 'X_pca'
     varm: 'APCA.RAW', 'PCs',
 'RNA': AnnData object with n_obs × n_vars = 13763 × 300
     obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'stim', 'celltype', 'nCount_atac', 'nFeature_atac', 'nCount_adt', 'nFeature_adt'
     var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
     uns: 'log1p', 'hvg', 'pca'
     obsm: 'X_rpca', 'X_rpca.raw', 'X_pca'
     varm: 'RPCA.RAW', 'PCs',
 'atac': AnnData object with n_obs × n_vars = 13763 × 5000
     obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'stim', 'celltype', 'nCount_atac', 'nFeature_atac', 'nCount_adt', 'nFeature_adt'
     va

In [2]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X = {m: torch.tensor(datasets[m].obsm['X_pca'], dtype=torch.float, device=device) for m in modalities}
for k, v in X.items():
    print(k, v.shape)

adt torch.Size([13763, 100])
RNA torch.Size([13763, 100])
atac torch.Size([13763, 100])


# cell_seq_rep -> seq_sim -> cell

In [3]:
from torch_cluster import knn_graph
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
edge_indices = {}
for m in modalities:
    edge_indices[m] = knn_graph(
        X[m],
        k=min(X[m].shape[0] // 20, 50),
        cosine=True,
        num_workers=16
    )
edge_indices

{'adt': tensor([[ 2917,   844,  2871,  ..., 12844, 12267,  9163],
         [    0,     0,     0,  ..., 13762, 13762, 13762]], device='cuda:0'),
 'RNA': tensor([[ 2574,  9294,  4927,  ..., 12470, 12342,   480],
         [    0,     0,     0,  ..., 13762, 13762, 13762]], device='cuda:0'),
 'atac': tensor([[ 6431,  5201,  7867,  ...,  2002,  2193,  3477],
         [    0,     0,     0,  ..., 13762, 13762, 13762]], device='cuda:0')}

In [4]:
E = {}
for m in modalities:
    src, tgt = edge_indices[m]
    # E[m] = torch.bmm(X[m][src].unsqueeze(1), X[m][tgt].unsqueeze(-1)).squeeze()
    E[m] = torch.norm(X[m][src] - X[m][tgt], dim=1)
E

{'adt': tensor([10.4106, 11.0088, 10.1860,  ..., 11.6050, 12.9893, 11.3765],
        device='cuda:0'),
 'RNA': tensor([ 0.0394,  0.3796,  0.6669,  ...,  6.5936, 25.8649,  6.7953],
        device='cuda:0'),
 'atac': tensor([2.4650, 2.5006, 2.5179,  ..., 2.1866, 2.1866, 2.1866], device='cuda:0')}

In [5]:
C = torch.cat([X[m] for m in modalities], dim=1)
C.shape

torch.Size([13763, 300])

In [22]:
from torch_geometric.data import HeteroData

# (1) Assign attributes after initialization,
data = HeteroData()

data['cell'].x = C
data['cell', 'similar_to', 'cell'].edge_index = knn_graph(
    C,
    k=min(100, C.shape[0] // 200),
    cosine=True,
    num_workers=16
)

for m in modalities:
    data[m].x = X[m]
    data[m, "similar_to", m].edge_index = edge_indices[m]
    data[m, "belongs_to", "cell"].edge_index = torch.stack([torch.arange(C.shape[0]), torch.arange(C.shape[0])], dim=0)
    data[m, "similar_to", m].e = E[m] if len(E[m].shape) > 1 else E[m].unsqueeze(-1)
data

HeteroData(
  cell={ x=[13763, 300] },
  adt={ x=[13763, 100] },
  RNA={ x=[13763, 100] },
  atac={ x=[13763, 100] },
  (cell, similar_to, cell)={ edge_index=[2, 935884] },
  (adt, similar_to, adt)={
    edge_index=[2, 688150],
    e=[688150, 1],
  },
  (adt, belongs_to, cell)={ edge_index=[2, 13763] },
  (RNA, similar_to, RNA)={
    edge_index=[2, 688150],
    e=[688150, 1],
  },
  (RNA, belongs_to, cell)={ edge_index=[2, 13763] },
  (atac, similar_to, atac)={
    edge_index=[2, 688150],
    e=[688150, 1],
  },
  (atac, belongs_to, cell)={ edge_index=[2, 13763] }
)

In [113]:
from torch_geometric.loader import NeighborLoader

if 'train_mask' not in data['cell']:
    data['cell']['train_mask'] = torch.ones(data['cell'].x.size(0), dtype=torch.bool)

# Define the neighbor sampling configuration.
# For each edge type, you specify a list of numbers representing the number of neighbors
# to sample at each layer (e.g., [first_hop, second_hop]).
sampler = NeighborLoader(
    data,
    num_neighbors={
        ('adt', 'similar_to', 'adt'): [15, 10],
        ('adt', 'belongs_to', 'cell'): [1, 0],
        ('RNA', 'similar_to', 'RNA'): [15, 10],
        ('RNA', 'belongs_to', 'cell'): [1, 0],
        ('atac', 'similar_to', 'atac'): [15, 10],
        ('atac', 'belongs_to', 'cell'): [1, 0],
        ('cell', 'similar_to', 'cell'): [8, 4],
    },
    input_nodes=('cell', data['cell']['train_mask']),
    batch_size=2048,
    shuffle=True,
)

sample = next(iter(sampler))
sample

HeteroData(
  cell={
    x=[10847, 300],
    train_mask=[10847],
    n_id=[10847],
    num_sampled_nodes=[3],
    input_id=[2048],
    batch_size=2048,
  },
  adt={
    x=[9989, 100],
    n_id=[9989],
    num_sampled_nodes=[3],
  },
  RNA={
    x=[7924, 100],
    n_id=[7924],
    num_sampled_nodes=[3],
  },
  atac={
    x=[8459, 100],
    n_id=[8459],
    num_sampled_nodes=[3],
  },
  (cell, similar_to, cell)={
    edge_index=[2, 42500],
    e_id=[42500],
    num_sampled_edges=[2],
  },
  (adt, similar_to, adt)={
    edge_index=[2, 20480],
    e=[20480, 1],
    e_id=[20480],
    num_sampled_edges=[2],
  },
  (adt, belongs_to, cell)={
    edge_index=[2, 2048],
    e_id=[2048],
    num_sampled_edges=[2],
  },
  (RNA, similar_to, RNA)={
    edge_index=[2, 20480],
    e=[20480, 1],
    e_id=[20480],
    num_sampled_edges=[2],
  },
  (RNA, belongs_to, cell)={
    edge_index=[2, 2048],
    e_id=[2048],
    num_sampled_edges=[2],
  },
  (atac, similar_to, atac)={
    edge_index=[2, 20480],
  

In [114]:
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch.nn import Linear

class IntraModalityMP(MessagePassing):
    def __init__(self, node_dim, edge_attr_dim, hidden_dims):
        super(IntraModalityMP, self).__init__(aggr='add')
        # Fully connected layer to compute edge messages from concatenated features:
        # concatenated features have size: 2 * in_channels + edge_attr_dim
        self.fc_message = nn.Linear(2 * node_dim, hidden_dims)
        # Fully connected layer to update node features after aggregation
        self.fc_update = nn.Linear(hidden_dims, hidden_dims)
        # Optionally, include a non-linear activation
        self.activation = nn.SiLU()

        self.edge_gate = nn.Sequential(
            nn.Linear(edge_attr_dim, edge_attr_dim),
            nn.SiLU(),
            nn.Linear(edge_attr_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x, edge_index, edge_attr):
        """
        Args:
            x (Tensor): Node features with shape [num_nodes, in_channels].
            edge_index (LongTensor): Graph connectivity in COO format [2, num_edges].
            edge_attr (Tensor): Edge attributes with shape [num_edges, edge_attr_dim].
        Returns:
            Tensor: Updated node features with shape [num_nodes, out_channels].
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """
        For each edge, concatenate target node feature (x_i), source node feature (x_j),
        and edge attribute (edge_attr), then compute the edge message.
        """
        # Concatenate features from the target, source, and edge attribute
        m_ij = torch.cat([x_i, x_j], dim=-1)
        a_ij = self.edge_gate(edge_attr)

        # Compute the edge message and apply activation
        return self.activation(self.fc_message(m_ij)) * a_ij

    def update(self, aggr_out):
        """
        Update the node features by applying a fully connected layer on the aggregated messages.
        """
        return self.fc_update(aggr_out)

class ModalityToCellMP(MessagePassing):
    def __init__(self, node_dim, cell_dim, hidden_cell_dim):
        super(ModalityToCellMP, self).__init__(aggr='add')
        self.cell_gate = nn.Sequential(
            nn.Linear(cell_dim, cell_dim),
            nn.SiLU(),
            nn.Linear(cell_dim, 1),
            nn.Sigmoid()
        )
        self.fc_message = nn.Linear(node_dim + cell_dim, hidden_cell_dim)
        self.fc_update = nn.Linear(hidden_cell_dim, hidden_cell_dim)
        self.activation = nn.SiLU()

    def forward(self, x, c, edge_index):
        return self.propagate(edge_index, x=x, c=c)

    def message(self, c_i, x_j):
        a = self.cell_gate(c_i)
        m_ij = self.fc_message(torch.cat([x_j, c_i], dim=-1))
        return self.activation(m_ij) * a

    def update(self, aggr_out):
        return self.fc_update(aggr_out)

class MultiOmicsEmbedding(nn.Module):
    def __init__(self, modality_in_dims, cell_in_dims, edge_attr_dims, modality_hidden_dims, modalities, cell_hidden_dims, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.activation = nn.SiLU()
        self.modality_embs = nn.ModuleDict({
            m: IntraModalityMP(
                node_dim=modality_in_dims[m],
                edge_attr_dim=edge_attr_dims[m],
                hidden_dims=modality_hidden_dims,
            ) for m in modalities
        })
        self.cell_emb = nn.Sequential(
            Linear(cell_in_dims, cell_hidden_dims),
            self.activation,
        )

    def forward(self, batch):

        H = {}
        for m in modalities:
            H[m] = self.modality_embs[m](batch[m].x, batch[m, 'similar_to', m].edge_index, batch[m, 'similar_to', m].e)

        C = self.cell_emb(batch['cell'].x)

        return H, C

class MultiOmicsLayer(nn.Module):
    def __init__(self, modality_in_dims, cell_in_dims, edge_attr_dims, modality_hidden_dims, cell_hidden_dims, modalities, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.activation = nn.SiLU()
        self.intra_modality_msg = nn.ModuleDict({
            m: IntraModalityMP(
                node_dim=modality_in_dims[m],
                edge_attr_dim=edge_attr_dims[m],
                hidden_dims=modality_hidden_dims,
            ) for m in modalities
        })

        self.modality_to_cell_msg = nn.ModuleDict({
            m: ModalityToCellMP(
                node_dim=modality_hidden_dims,
                cell_dim=cell_in_dims,
                hidden_cell_dim=cell_hidden_dims,
            ) for m in modalities
        })

    def forward(self, batch, H, C):

        for m in modalities:
            H[m] = H[m] + self.intra_modality_msg[m](H[m], batch[m, 'similar_to', m].edge_index, batch[m, 'similar_to', m].e)
            C = C + self.modality_to_cell_msg[m](H[m], C, batch[m, 'belongs_to', 'cell'].edge_index)

        return H, C

class MultiOmicsIntegration(nn.Module):
    def __init__(self, modality_in_dims, cell_in_dims, edge_attr_dims, modality_hidden_dims, cell_hidden_dims,
                 modalities, layer_num=2, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.modality_in_dims = modality_in_dims
        self.cell_in_dims = cell_in_dims
        self.edge_attr_dims = edge_attr_dims
        self.modality_hidden_dims = modality_hidden_dims
        self.cell_hidden_dims = cell_hidden_dims
        self.modalities = modalities
        self.layer_num = layer_num

        self.embedding = MultiOmicsEmbedding(
            modality_in_dims=self.modality_in_dims,
            cell_in_dims=self.cell_in_dims,
            edge_attr_dims=self.edge_attr_dims,
            modality_hidden_dims=self.modality_hidden_dims,
            cell_hidden_dims=self.cell_hidden_dims,
            modalities=self.modalities,
        )

        self.layers = nn.ModuleList([
            MultiOmicsLayer(
                modality_in_dims={m: self.modality_hidden_dims for m in self.modalities},
                cell_in_dims=self.cell_hidden_dims,
                edge_attr_dims=self.edge_attr_dims,
                modality_hidden_dims=self.modality_hidden_dims,
                cell_hidden_dims=self.cell_hidden_dims,
                modalities=self.modalities
            ) for _ in range(self.layer_num)
        ])

        self.modality_bn = nn.ModuleDict({
            m: nn.BatchNorm1d(self.modality_hidden_dims) for m in self.modalities
        })

        self.cell_bn = nn.BatchNorm1d(self.cell_hidden_dims)

    def normalize_each_modality(self, H):
        for m in self.modalities:
            H[m] = self.modality_bn[m](H[m])
        return H

    def forward(self, batch):
        H, C = self.embedding(batch)
        for layer in self.layers:
            H, C = layer(batch, H, C)
            H = self.normalize_each_modality(H)
            C = self.cell_bn(C)

        return H, C

model = MultiOmicsIntegration(
    modality_in_dims={m: data[m].x.shape[1] for m in modalities},
    edge_attr_dims={m: data[m, 'similar_to', m].e.shape[1] for m in modalities},
    cell_in_dims=data['cell'].x.shape[1],
    cell_hidden_dims=256,
    modality_hidden_dims=64,
    modalities=modalities,
).to(device)



batch = model(next(iter(sampler)).to(device))
batch


({'adt': tensor([[ 2.3311,  0.5682, -0.4285,  ...,  2.4638, -1.0885, -0.8179],
          [ 1.0701, -4.5158,  0.0322,  ...,  1.1128,  0.3254,  1.8957],
          [ 0.5050, -0.3739,  1.6729,  ...,  2.5747, -0.6587, -2.3867],
          ...,
          [-0.1301,  0.1298,  0.0759,  ..., -0.1524,  0.3059, -0.1478],
          [-0.1301,  0.1298,  0.0759,  ..., -0.1524,  0.3059, -0.1478],
          [-0.1301,  0.1298,  0.0759,  ..., -0.1524,  0.3059, -0.1478]],
         device='cuda:0', grad_fn=<NativeBatchNormBackward0>),
  'RNA': tensor([[ 0.3337, -0.4875, -0.3552,  ..., -0.5535,  0.2425, -0.0739],
          [ 2.4849, -2.7782,  2.1306,  ..., -2.5821,  0.8290, -0.4272],
          [-0.1772,  0.2140, -0.3553,  ...,  0.1153, -0.0359,  0.0949],
          ...,
          [-0.0039, -0.0275, -0.2441,  ...,  0.0126, -0.3625, -0.1814],
          [-0.0039, -0.0275, -0.2441,  ...,  0.0126, -0.3625, -0.1814],
          [-0.0039, -0.0275, -0.2441,  ...,  0.0126, -0.3625, -0.1814]],
         device='cuda:0', g

In [117]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# ------------------------------------------------------------------------------
# Define the decoder.
# ------------------------------------------------------------------------------
class MultiOmicsDecoder(nn.Module):
    def __init__(self, encoder, modality_in_dims, cell_in_dims, modalities):
        super().__init__()
        self.modalities = modalities

        # For each modality, define a decoder mapping from latent dim back to the original feature dim.
        self.modality_decoders = nn.ModuleDict({
            m: nn.Sequential(
                nn.Linear(encoder.modality_hidden_dims, modality_in_dims[m]),
                nn.SiLU(),  # optional non-linearity; adjust as needed
            )
            for m in modalities
        })

        # Decoder for cell nodes.
        self.cell_decoder = nn.Sequential(
            nn.Linear(encoder.cell_hidden_dims, cell_in_dims),
            nn.SiLU(),
        )

    def forward(self, H, C):
        modality_recons = {m: self.modality_decoders[m](H[m]) for m in self.modalities}
        cell_recon = self.cell_decoder(C)
        return modality_recons, cell_recon

# ------------------------------------------------------------------------------
# Define a separate noising function.
# This function adds noise only once per batch when called.
# ------------------------------------------------------------------------------
def add_noise(batch, modalities, noise_std=0.1, drop_modality_prob=0.5):
    """
    Adds noise to the input batch.
      - For each modality, adds Gaussian noise.
      - With some probability, drops a modality (zeroes its features).
      - Also adds noise to cell node features.
    """
    noisy_batch = copy.deepcopy(batch)
    noise_mask = torch.rand(noisy_batch['cell'].x.shape[0]) > 0.1

    # Process each modality.
    for m in modalities:
        if 'x' in noisy_batch[m]:
            original = noisy_batch[m].x[noise_mask]
            if torch.rand(1).item() < drop_modality_prob:
                # Drop modality: replace with pure noise.
                noisy_batch[m].x[noise_mask] = noise_std * torch.randn_like(original)
            else:
                # Add Gaussian noise.
                noisy_batch[m].x[noise_mask] = original + noise_std * torch.randn_like(original)

    # Add noise to cell features.
    if 'x' in noisy_batch['cell']:
        original = noisy_batch['cell'].x[noise_mask]
        noisy_batch['cell'].x[noise_mask] = original + noise_std * torch.randn_like(original)

    return noisy_batch

# ------------------------------------------------------------------------------
# Example usage.
# ------------------------------------------------------------------------------
# Assume you have:
# - `data`: a dict-like heterogeneous graph with keys for each modality and 'cell'
# - `modalities`: list of modality names (e.g. ['RNA', 'adt', 'atac'])
# - `sampler`: an iterable over batches
# - `device`: torch device (e.g. 'cuda' or 'cpu')

# Instantiate the encoder.
encoder = MultiOmicsIntegration(
    modality_in_dims={m: data[m].x.shape[1] for m in modalities},
    edge_attr_dims={m: data[m, 'similar_to', m].e.shape[1] for m in modalities},
    cell_in_dims=data['cell'].x.shape[1],
    cell_hidden_dims=256,
    modality_hidden_dims=64,
    modalities=modalities,
).to(device)

# Instantiate the decoder.
decoder = MultiOmicsDecoder(
    encoder=encoder,
    modality_in_dims={m: data[m].x.shape[1] for m in modalities},
    cell_in_dims=data['cell'].x.shape[1],
    modalities=modalities,
).to(device)

# Define optimizer over the parameters of both encoder and decoder.
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

epochs = 10
for epoch in range(epochs):
    for batch in tqdm(sampler):
        # Move batch to device.
        batch = batch.to(device)

        # Apply noise once to the input batch.
        noisy_batch = add_noise(batch, modalities, noise_std=0.1, drop_modality_prob=0.003)

        # Zero gradients.
        optimizer.zero_grad()

        # Forward pass: encode then decode.
        H, C = encoder(noisy_batch)
        modality_recons, cell_recon = decoder(H, C)

        # Compute reconstruction losses, e.g. using MSE.
        loss_modality = sum(F.mse_loss(modality_recons[m], batch[m].x) for m in modalities)
        loss_cell = F.mse_loss(cell_recon, batch['cell'].x)
        loss = loss_modality + loss_cell

        # Backward pass.
        loss.backward()
        optimizer.step()

    # (Optional) Evaluate reconstruction loss at end of epoch.
    H, C = encoder(noisy_batch)
    modality_recons, cell_recon = decoder(H, C)
    loss_modality = sum(F.mse_loss(modality_recons[m], batch[m].x) for m in modalities)
    loss_cell = F.mse_loss(cell_recon, batch['cell'].x)
    loss = loss_modality + loss_cell
    print(f"Epoch {epoch + 1} Reconstruction loss: {loss.item()}")


  0%|          | 0/7 [00:00<?, ?it/s]


IndexError: The shape of the mask [10881] at index 0 does not match the shape of the indexed tensor [10066, 100] at index 0

 seq -> cell -> cell sim

In [22]:
import torch

torch.topk(torch.tensor(X[modalities[0]]), k=10, dim=1).indices

tensor([[ 74, 109,  96,  ...,  51, 192, 159],
        [ 27, 158,  57,  ..., 139, 122, 209],
        [122,  94, 175,  ...,  58,  53,   6],
        ...,
        [196,  34, 123,  ...,  31,  76, 109],
        [ 27,  26,  79,  ..., 161,  70, 165],
        [158, 132, 195,  ...,  21,  97, 166]])

In [34]:
import torch
import toponetx as tnx
from toponetx.classes.combinatorial_complex import CombinatorialComplex

offsets = [0] * len(modalities)
for idx, m in enumerate(modalities[1:]):
    offsets[idx + 1] = offsets[idx] + X[m].shape[1]

ccc: CombinatorialComplex = CombinatorialComplex()
cell = torch.cat([
    torch.topk(torch.tensor(X[m]), k=10, dim=1).indices + offset
    for offset, m in zip(offsets, modalities)
], dim=1)

ccc.add_cells_from(range(cell.shape[0]), ranks=0)
ccc.add_cells_from(cell.detach().cpu().numpy(), ranks=1)

ccc.number_of_cells()

27526

In [30]:
ccc.skeleton(rank=0)[-1]

frozenset({13762})

In [31]:
cell.max()

tensor(9789)

In [33]:
offsets = [0] * len(modalities)
for idx, m in enumerate(modalities[1:]):
    offsets[idx + 1] = offsets[idx] + X[m].shape[1]
offsets

[0, 3000, 8000]