In [1]:
datasets=['GSE140203_BRAIN_atac2gex',
 'GSE140203_SKIN_atac2gex',
 'openproblems_2022_cite_gex2adt',
 'openproblems_2022_multi_atac2gex']

In [5]:

DCCA_scores=[]

import argparse

import anndata as ad
import numpy as np
import torch
import torch.utils.data as data_utils
from sklearn import preprocessing

import dance.utils.metrics as metrics
from dance.datasets.multimodality import JointEmbeddingNIPSDataset
from dance.modules.multi_modality.joint_embedding.dcca import DCCA
import scanpy as sc

def parameter_setting():
    parser = argparse.ArgumentParser(description="Single cell Multi-omics data analysis")

    parser.add_argument("--latent_fusion", "-olf1", type=str, default="First_simulate_fusion.csv",
                        help="fusion latent code file")
    parser.add_argument("--latent_1", "-ol1", type=str, default="scRNA_latent_combine.csv",
                        help="first latent code file")
    parser.add_argument("--latent_2", "-ol2", type=str, default="scATAC_latent.csv", help="seconde latent code file")
    parser.add_argument("--denoised_1", "-od1", type=str, default="scRNA_seq_denoised.csv",
                        help="outfile for denoised file1")
    parser.add_argument("--normalized_1", "-on1", type=str, default="scRNA_seq_normalized_combine.tsv",
                        help="outfile for normalized file1")
    parser.add_argument("--denoised_2", "-od2", type=str, default="scATAC_seq_denoised.csv",
                        help="outfile for denoised file2")

    parser.add_argument("--workdir", "-wk", type=str, default="./new_test/", help="work path")
    parser.add_argument("--outdir", "-od", type=str, default="./new_test/", help="Output path")

    parser.add_argument("--lr", type=float, default=1E-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-6, help="weight decay")
    parser.add_argument("--eps", type=float, default=0.01, help="eps")

    parser.add_argument("--batch_size", "-b", type=int, default=64, help="Batch size")

    parser.add_argument("--seed", type=int, default=1, help="Random seed for repeat results")
    parser.add_argument("--latent", "-l", type=int, default=10, help="latent layer dim")
    parser.add_argument("--max_epoch", "-me", type=int, default=10, help="Max epoches")
    parser.add_argument("--max_iteration", "-mi", type=int, default=1500, help="Max iteration")
    parser.add_argument("--anneal_epoch", "-ae", type=int, default=200, help="Anneal epoch")
    parser.add_argument("--epoch_per_test", "-ept", type=int, default=5, help="Epoch per test")
    parser.add_argument("--max_ARI", "-ma", type=int, default=-200, help="initial ARI")
    parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2")
    parser.add_argument("-device", "--device", default="cuda")
    parser.add_argument("--final_rate", type=float, default=1e-4)
    parser.add_argument("--scale_factor", type=float, default=4)
    parser.add_argument("--span", default=0.3, type=float)
    return parser


parser = parameter_setting()
for dataset in datasets:
    args = parser.parse_args(['--subtask',dataset,'--device','cuda',"--span",'1.0'])

    args.sf1 = 5
    args.sf2 = 1
    args.cluster1 = args.cluster2 = 4
    args.lr1 = 0.0001
    args.flr1 = 0.0001
    args.lr2 = 0.0005
    args.flr2 = 0.0005

    dataset = JointEmbeddingNIPSDataset(args.subtask, root="../../../../data/joint_embedding", preprocess="feature_selection",span=args.span)
    data = dataset.load_data()

    le = preprocessing.LabelEncoder()
    labels = le.fit_transform(data.mod["test_sol"].obs["cell_type"])
    

    # sc.pp.filter_genes(data.mod["mod1"],min_counts=3)
    # sc.pp.filter_genes(data.mod["mod2"],min_counts=3)

    sc.pp.log1p(data.mod["mod2"])
    sc.pp.log1p(data.mod["mod1"])

    
    data.mod["mod2"].obsm["size_factors"] = np.sum(data.mod["mod2"].X.todense(), 1) / 100
    # # data.mod["mod1"].obsm["size_factors"] = data.mod["mod1"].obs["size_factors"]
    data.mod["mod1"].obsm["size_factors"] = np.sum(data.mod["mod1"].X.todense(), 1) / 100

    
    # data.mod["mod1"].obsm["size_factors"] = data.mod["mod1"].obs["size_factors"]
    # data.mod["mod2"].obsm["size_factors"] = data.mod["mod1"].obs["size_factors"]

    data.mod["mod1"].obsm["labels"] = labels

    data.set_config(feature_mod=["mod1", "mod2", "mod1", "mod2", "mod1", "mod2"], label_mod="mod1",
                    feature_channel_type=["layers", "layers", None, None, "obsm", "obsm"],
                    feature_channel=["counts", "counts", None, None, "size_factors",
                                     "size_factors"], label_channel="labels")
    (x_train, y_train, x_train_raw, y_train_raw, x_train_size,
     y_train_size), train_labels = data.get_train_data(return_type="torch")
    (x_test, y_test, x_test_raw, y_test_raw, x_test_size,
     y_test_size), test_labels = data.get_test_data(return_type="torch")

    Nfeature1 = x_train.shape[1]
    Nfeature2 = y_train.shape[1]

    device = torch.device(args.device)

    model = DCCA(layer_e_1=[Nfeature1, 128], hidden1_1=128, Zdim_1=4, layer_d_1=[4, 128], hidden2_1=128,
                 layer_e_2=[Nfeature2, 1500, 128], hidden1_2=128, Zdim_2=4, layer_d_2=[4], hidden2_2=4, args=args,
                 Type_1="NB", Type_2="Bernoulli", ground_truth1=torch.cat([train_labels, test_labels]), cycle=1,
                 attention_loss="Eucli",droprate=0)  # yapf: disable
    model.to(device)
    train = data_utils.TensorDataset(x_train.float(), x_train_raw, x_train_size.float(), y_train.float(), y_train_raw,
                                     y_train_size.float())


    train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True)

    test = data_utils.TensorDataset(x_test.float(), x_test_raw, x_test_size.float(), y_test.float(), y_test_raw,
                                    y_test_size.float())

    test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False)

    total = data_utils.TensorDataset(
        torch.cat([x_train, x_test]).float(), torch.cat([x_train_raw, x_test_raw]),
        torch.cat([x_train_size, x_test_size]).float(),
        torch.cat([y_train, y_test]).float(), torch.cat([y_train_raw, y_test_raw]),
        torch.cat([y_train_size, y_test_size]).float())

    total_loader = data_utils.DataLoader(total, batch_size=args.batch_size, shuffle=False)

    model.fit(train_loader, test_loader, total_loader, "RNA")

    with torch.no_grad():
        emb1, emb2 = model.predict(total_loader)

    embeds = np.concatenate([emb1, emb2], 1)
    print(embeds)
    print(model.score(total_loader))

    mod1_obs = data.mod["mod1"].obs
    mod1_uns = data.mod["mod1"].uns
    adata = ad.AnnData(
        X=embeds,
        obs=mod1_obs,
        uns={
            "dataset_id": mod1_uns["dataset_id"],
            "method_id": "scmogcn",
        },
    )

    NMI_score, ARI_score=metrics.labeled_clustering_evaluate(adata, data.mod["test_sol"])
    DCCA_scores.append({"NMI_score":NMI_score,"ARI_score":ARI_score})
"""To reproduce DCCA on other samples, please refer to command lines belows:

GEX-ADT:
python dcca.py --subtask openproblems_bmmc_cite_phase2 --device cuda

GEX-ATAC:
python dcca.py --subtask openproblems_bmmc_multiome_phase2 --device cuda

"""


[INFO][2023-10-01 21:08:31,307][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod1.h5ad


[INFO][2023-10-01 21:08:31,804][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-01 21:08:31,898][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-01 21:08:32,409][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-01 21:08:32,574][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_solution.h5ad
[INFO][2023-10-01 21:09:36,904][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-01 21:09:37,896][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars = 3291 × 897209
  uns:	'dance_config'
  5 mod

scRNA-ARI: 0.002 NMI: 0.012 scEpigenomics-ARI: 0.019 NMI: 0.052
Finish training, total time is: 2.5105178356170654s
False
train likelihood is :  100000 epoch: 0
Finish training, total time is: 2.0973620414733887s
False
train likelihood is :  100000 epoch: 0
scRNA-ARI: 0.002 NMI: 0.009 scEpigenomics-ARI: 0.014 NMI: 0.045
Finish training, total time is: 2.1798858642578125s
False
train likelihood is :  100000 epoch: 0
[[ 7.7893804e-03  8.5386321e-02 -1.9330960e-03 ...  3.4647911e+00
   3.7470562e+00 -2.5562053e+00]
 [-1.8570170e-01 -8.3318189e-02  9.6009038e-02 ...  3.2889309e+00
   3.5041676e+00 -2.6117847e+00]
 [-1.8328802e-01 -6.4812332e-02 -2.9120460e-01 ...  2.7171960e+00
   2.4213703e+00 -2.0959344e+00]
 ...
 [-8.3694747e-03 -3.1025354e-02 -1.0799071e-01 ...  2.2904475e+00
   2.4613416e+00 -1.7188892e+00]
 [ 5.6520768e-02  3.6582332e-02 -3.7959747e-02 ...  3.1053631e+00
   3.1622539e+00 -2.2706914e+00]
 [ 3.5017252e-02 -9.9636808e-02  1.7075598e-02 ...  2.3208294e+00
   2.4863646e+0

[INFO][2023-10-01 21:09:47,570][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod1.h5ad


NMI: 0.121 ARI: 0.036


[INFO][2023-10-01 21:09:48,842][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-01 21:09:49,066][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-01 21:09:51,665][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-01 21:09:52,116][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_solution.h5ad
[INFO][2023-10-01 21:09:58,576][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-01 21:09:59,582][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars = 34054 × 732480
  uns:	'dance_config'
  5 modalities

scRNA-ARI: 0.121 NMI: 0.041 scEpigenomics-ARI: 0.109 NMI: 0.045
Finish training, total time is: 25.54280972480774s
False
train likelihood is :  100000 epoch: 0
Finish training, total time is: 24.399454355239868s
False
train likelihood is :  100000 epoch: 0
scRNA-ARI: 0.078 NMI: 0.107 scEpigenomics-ARI: 0.128 NMI: 0.084
Finish training, total time is: 28.645615577697754s
False
train likelihood is :  100000 epoch: 0
[[-0.25106993  0.28049928 -0.55008644 ... -2.5539322  -2.9442744
  -3.1217422 ]
 [-0.02091426 -0.07577233 -0.7010712  ... -2.6711977  -3.130181
  -2.84519   ]
 [-0.46733496 -0.06588307  0.46740353 ... -2.7163928  -3.0954204
  -3.0968907 ]
 ...
 [ 0.7045505   0.30586278  0.21634519 ... -1.9041781  -3.1879153
  -2.2486959 ]
 [ 0.36578104 -0.14034665 -0.12904985 ... -2.4020846  -2.7188373
  -2.9552426 ]
 [-0.32595018 -0.04635978  0.03599662 ... -2.8619401  -2.8932853
  -3.3556836 ]]
scRNA-ARI: 0.078 NMI: 0.107 scEpigenomics-ARI: 0.19 NMI: 0.109
(0.107, 0.078, 0.109, 0.19)


[INFO][2023-10-01 21:11:38,498][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod1.h5ad


NMI: 0.15 ARI: 0.097


[INFO][2023-10-01 21:11:40,069][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod2.h5ad
[INFO][2023-10-01 21:11:40,164][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-01 21:11:44,555][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-01 21:11:44,698][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_solution.h5ad
[INFO][2023-10-01 21:11:54,893][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-01 21:11:55,109][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars = 70988 × 54450
 

scRNA-ARI: 0.041 NMI: 0.04 scEpigenomics-ARI: 0.134 NMI: 0.15
Finish training, total time is: 39.55171251296997s
False
train likelihood is :  100000 epoch: 0
Finish training, total time is: 35.00210785865784s
False
train likelihood is :  100000 epoch: 0
scRNA-ARI: 0.155 NMI: 0.185 scEpigenomics-ARI: 0.274 NMI: 0.219
Finish training, total time is: 34.28476905822754s
False
train likelihood is :  100000 epoch: 0
[[-2.168189    3.9037285   3.1827476  ...  5.3174214  -0.6608299
   1.1123077 ]
 [-2.6857333   4.4641213   4.4365983  ...  5.106319   -0.7351744
   1.4727234 ]
 [-2.3172736   3.2875607   2.325252   ...  4.5123677  -1.3269365
   1.3042661 ]
 ...
 [ 0.20897624  3.4813457   0.78894746 ...  4.3303175  -1.1400051
   1.4252745 ]
 [ 0.02423564 -2.1603315   2.4668574  ...  3.4652672  -0.8617171
   0.8102826 ]
 [-2.5454156   4.640644    2.5733957  ...  4.25856    -1.2654804
   1.2884603 ]]
scRNA-ARI: 0.155 NMI: 0.185 scEpigenomics-ARI: 0.104 NMI: 0.07
(0.185, 0.155, 0.07, 0.104)


[INFO][2023-10-01 21:14:07,142][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod1.h5ad


NMI: 0.166 ARI: 0.098


[INFO][2023-10-01 21:14:10,178][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod2.h5ad
[INFO][2023-10-01 21:14:11,822][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-01 21:14:17,637][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-01 21:14:21,484][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_solution.h5ad
[INFO][2023-10-01 21:14:48,725][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-01 21:14:49,242][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars =

scRNA-ARI: 0.011 NMI: 0.005 scEpigenomics-ARI: 0.028 NMI: 0.025
Finish training, total time is: 99.04896450042725s
False
train likelihood is :  100000 epoch: 0
Finish training, total time is: 59.12557649612427s
False
train likelihood is :  100000 epoch: 0
scRNA-ARI: 0.201 NMI: 0.225 scEpigenomics-ARI: 0.004 NMI: 0.005
Finish training, total time is: 68.65862393379211s
False
train likelihood is :  100000 epoch: 0
[[-1.8570884   0.04997385 -1.162971   ... -0.6702888   3.4764884
   0.4755263 ]
 [-1.9770805   0.43467298 -0.6536358  ... -0.5521196   3.5503116
   0.25821343]
 [ 1.3098079   0.8582083   1.3729968  ... -0.07480803  4.2486453
   1.8689553 ]
 ...
 [ 1.3577638  -0.5907472   0.10596782 ... -3.8812916   1.7039121
   4.8353686 ]
 [-1.8294289   0.07269537 -0.33176917 ... -0.9449903   3.2171643
  -0.3425314 ]
 [ 0.8234539   0.39255467  2.4930747  ... -3.0130513   3.590608
  -1.1587826 ]]
scRNA-ARI: 0.201 NMI: 0.225 scEpigenomics-ARI: 0.195 NMI: 0.254
(0.225, 0.201, 0.254, 0.195)
NMI: 0

'To reproduce DCCA on other samples, please refer to command lines belows:\n\nGEX-ADT:\npython dcca.py --subtask openproblems_bmmc_cite_phase2 --device cuda\n\nGEX-ATAC:\npython dcca.py --subtask openproblems_bmmc_multiome_phase2 --device cuda\n\n'

In [6]:
DCCA_scores

[{'NMI_score': 0.121, 'ARI_score': 0.036},
 {'NMI_score': 0.15, 'ARI_score': 0.097},
 {'NMI_score': 0.166, 'ARI_score': 0.098},
 {'NMI_score': 0.286, 'ARI_score': 0.228}]

In [2]:
JAE_scores=[]
import argparse
import random
from sklearn import preprocessing
import numpy as np
import torch

from dance.datasets.multimodality import JointEmbeddingNIPSDataset
from dance.modules.multi_modality.joint_embedding.jae import JAEWrapper
from dance.utils import set_seed

rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--subtask", default=datasets[0],
                    choices=datasets)
parser.add_argument("-d", "--data_folder", default="../../../../data/joint_embedding")
parser.add_argument("-pre", "--pretrained_folder", default="./data/joint_embedding/pretrained")
parser.add_argument("-csv", "--csv_path", default="decoupled_lsi.csv")
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-device", "--device", default="cpu")
parser.add_argument("-bs", "--batch_size", default=128, type=int)
parser.add_argument("-nm", "--normalize", default=1, type=int, choices=[0, 1])
parser.add_argument("--span", default=0.3, type=float)

for dataset in datasets:
    args = parser.parse_args(['--subtask',dataset,'--device','cpu','--span','1.0'])

    device = args.device
    pre_normalize = bool(args.normalize)
    torch.set_num_threads(args.cpus)
    rndseed = args.rnd_seed
    set_seed(rndseed)

    dataset = JointEmbeddingNIPSDataset(args.subtask, root=args.data_folder, preprocess="feature_selection", normalize=True,span=args.span)
    data = dataset.load_data()
    le = preprocessing.LabelEncoder()
    labels = le.fit_transform(data.mod["test_sol"].obs["cell_type"])
    data.mod["mod1"].obsm["labels"] = labels
    data.set_config(
        feature_mod=["mod1", "mod2"],
        label_mod="mod1",
        feature_channel=["counts", "counts"],
        feature_channel_type=["layers", "layers"],
        label_channel="labels",
    )
    (X_mod1_train, X_mod2_train), (cell_type) = data.get_train_data(return_type="torch")
    (X_mod1_test, X_mod2_test), (cell_type_test) = data.get_test_data(return_type="torch")
    print(X_mod1_train.shape,X_mod1_test.shape)
    X_train = torch.cat([X_mod1_train, X_mod2_train], dim=1)
    phase_score =torch.transpose(torch.tensor([[0.0]*(X_train.shape[0]),[0]*X_train.shape[0]]),0,1)
    batch_label=torch.tensor([0.0]*(X_train.shape[0]))
    # data.set_config(
    #     feature_mod=["mod1", "mod2"],
    #     label_mod=["mod1", "mod1", "mod1", "mod1", "mod1"],
    #     feature_channel=["X_pca", "X_pca"],
    #     label_channel=["cell_type", "batch_label", "phase_labels", "S_scores", "G2M_scores"],
    # )
    # (X_mod1_train, X_mod2_train), (cell_type, batch_label, phase_label, S_score,
    #                                G2M_score) = data.get_train_data(return_type="torch")
    # (X_mod1_test, X_mod2_test), (cell_type_test, _, _, _, _) = data.get_test_data(return_type="torch")
    # X_train = torch.cat([X_mod1_train, X_mod2_train], dim=1)
    # phase_score = torch.cat([S_score[:, None], G2M_score[:, None]], 1)
    model = JAEWrapper(args, num_celL_types=int(cell_type.max() + 1), num_batches=int(batch_label.max() + 1),#这里记得从data里的config里的batch_label里取
                       num_phases=phase_score.shape[1], num_features=X_train.shape[1])
    model.fit(X_train, cell_type, batch_label, phase_score)
    model.load(f"models/model_joint_embedding_{rndseed}.pth")

    with torch.no_grad():
        X_test = torch.cat([X_mod1_test, X_mod2_test], dim=1).float().to(device)
        test_id = np.arange(X_test.shape[0])
        labels = cell_type_test.numpy()
        embeds = model.predict(X_test, test_id).cpu().numpy()
        print(embeds)
        score=model.score(X_test, test_id, labels, metric="clustering")
        print(score)
        JAE_scores.append(score)
"""To reproduce JAE on other samples, please refer to command lines belows:

GEX-ADT:
python jae.py --subtask openproblems_bmmc_cite_phase2 --device cuda

GEX-ATAC:
python jae.py --subtask openproblems_bmmc_multiome_phase2 --device cuda

"""
# TODO
# 把所有preprocess修为feature_selection,可以尝试以下
#最好将output涵盖住train,然后sol说明output的细胞类型
#phase可以改没或者修改为同一时期的

[INFO][2023-10-03 15:06:37,185][dance][set_seed] Setting global random seed to 1206479118
[INFO][2023-10-03 15:06:37,186][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod1.h5ad
[INFO][2023-10-03 15:06:37,512][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-03 15:06:37,631][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 15:06:37,932][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 15:06:38,019][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.outpu

torch.Size([2303, 10000]) torch.Size([988, 10000])
19 1 2 20000
epoch 0
loss1 2.973365271792692, loss2 2.7845839191885555, loss3 -1.1687184553466068e-09, loss4 0.37793319190249725, 
val-loss1 3.060255527496338 val-loss2 2.7973666191101074 val-loss3 -5.160574745310953e-10 val-loss4 0.021960800513625145
val score 2.702750233069336
epoch 1
loss1 3.1081478735979866, loss2 2.2150117369259106, loss3 -1.1687184553466068e-09, loss4 0.29792946752379923, 
val-loss1 3.0539305210113525 val-loss2 2.3885557651519775 val-loss3 -5.160574745310953e-10 val-loss4 0.045025117695331573
val score 2.617713773597306
epoch 2
loss1 3.062740417087779, loss2 1.9152780210270601, loss3 -1.1687184553466068e-09, loss4 0.2709759377381381, 
val-loss1 3.0444154739379883 val-loss2 2.028909683227539 val-loss3 -5.160574745310953e-10 val-loss4 0.052663687616586685
val score 2.5395059527571258
epoch 3
loss1 2.9800869997809913, loss2 1.693169404478634, loss3 -1.1687184553466068e-09, loss4 0.2544134495889439, 
val-loss1 3.0278

[INFO][2023-10-03 15:08:20,868][dance][set_seed] Setting global random seed to 1206479118
[INFO][2023-10-03 15:08:20,869][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod1.h5ad


[ 3  5 18  3 18 10 10 11  6  7  3 10  8  1  4 18 17 12  9 18 10  5  5 17
  1 10 15 12  6 13  6 13  0 16  6  4 11 18 10 10  4 10  1 10 12 10  4 11
  2 10 10 10 12  3  3  6 14 15  1  0 17 10  9  4  6  1 15 11  2  8  2  5
 15  5 10 13  6 11 12 16 11  1  1 10  4 10 10 16 10 15 10  6 15 11  1  9
 17  3  3  3  3  3  3 14  2 17  5 15 15 12 17 10 14 10  0 10 10 15 15  5
 14  1 10  1 10  3 15 10  1  1 15 16  5 10  1  5  3  1  3 10 10 10  2 15
  1 11 10 16  5  3  2 18  5 18 11  1  3 10  5  1 11  2  2 10 11 18 18  6
  9 10  7 10 17  1  2  5  9 10  0  3 17  2  4 15  5 10  1  3 10  7 12  3
  7  2  1  2  6  4  6 18 10 17 10  0  1  5 17 10 15 10  0 16  0 10  5  5
 10  1 10 10  5  2  3 12  2 18 18 16 18  7 15 17 16 16 15 15 10 15  2 14
 10  2 11  5 12  5  5 18  2  1  1 15 18  4  6  8 11  0 17 15 13  4  1  7
  1  1 10  4  6  1  6 12 10  2  4 10 10  3  5  2  4  0 11  0  2  9  3 12
 10 11 10 11  1 10  8 11 10 17 10  1 10  2  7  5  7 10 12 10 15 10  6 10
 16  8 10 14 10  5 16  2 15 15 10 10 10 10  5 10  1

[INFO][2023-10-03 15:08:22,912][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-03 15:08:23,230][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 15:08:24,609][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 15:08:24,849][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_solution.h5ad
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)
[INFO][2023-10-03 15:08:40,700][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-03 15:08:45,206][dance][load_data] Raw data loaded:
Data object that wraps (.data):
Mu

torch.Size([24341, 10000]) torch.Size([10433, 10000])
23 1 2 20000
epoch 0
loss1 3.1685548417789993, loss2 2.2883755114189412, loss3 -9.64412143491109e-10, loss4 0.2619856252053449, 
val-loss1 2.963571310043335 val-loss2 1.6585065126419067 val-loss3 -4.8956578374559356e-11 val-loss4 0.04858076944947243
val score 2.408630258028741
epoch 1
loss1 2.5373286632604377, loss2 1.671169947746188, loss3 -9.64412143491109e-10, loss4 0.14225028810459514, 
val-loss1 1.391563892364502 val-loss2 1.7820154428482056 val-loss3 -4.8956578374559356e-11 val-loss4 0.018176157027482986
val score 1.3314066210737188
epoch 2
loss1 1.055997990245043, loss2 1.6867857263531796, loss3 -9.64412143491109e-10, loss4 0.09847766945008622, 
val-loss1 0.5775074362754822 val-loss2 1.5237488746643066 val-loss3 -4.8956578374559356e-11 val-loss4 0.016961075365543365
val score 0.7098530340915281
epoch 3
loss1 0.6612909089687259, loss2 1.351150124572044, loss3 -9.64412143491109e-10, loss4 0.07899486444630595, 
val-loss1 0.52588

[INFO][2023-10-03 15:22:59,893][dance][set_seed] Setting global random seed to 1206479118
[INFO][2023-10-03 15:22:59,895][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod1.h5ad


[18 18 18 ... 13 13 13] [9 5 0 ... 0 0 0]
(0.075, 0.074)


[INFO][2023-10-03 15:23:02,712][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod2.h5ad
[INFO][2023-10-03 15:23:02,837][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 15:23:04,814][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 15:23:04,901][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_solution.h5ad
  view_to_actual(adata)
[INFO][2023-10-03 15:23:31,045][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-03 15:23:31,170][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × 

torch.Size([49691, 10000]) torch.Size([21297, 10000])
7 1 2 10140
epoch 0
loss1 535.3763093784877, loss2 1.3551160260609219, loss3 -9.356125764613056e-10, loss4 3.5897841104439325, 
val-loss1 460.01739501953125 val-loss2 1.9513113498687744 val-loss3 -2.3985761293809915e-11 val-loss4 9.75899887084961
val score 322.8903887271869
epoch 1
loss1 330.95955448695594, loss2 2.03310606275286, loss3 -9.356125764613056e-10, loss4 7.766586474009922, 
val-loss1 228.9608612060547 val-loss2 2.846067190170288 val-loss3 -2.3985761293809915e-11 val-loss4 9.168705940246582
val score 161.30025157928347
epoch 2
loss1 118.1912338256836, loss2 2.9483223257746016, loss3 -9.356125764613056e-10, loss4 7.694849019050598, 
val-loss1 56.90885543823242 val-loss2 2.9680776596069336 val-loss3 -2.3985761293809915e-11 val-loss4 4.287264823913574
val score 40.64417757987856
epoch 3
loss1 65.23069524492536, loss2 2.991837411948613, loss3 -9.356125764613056e-10, loss4 1.4918665719032287, 
val-loss1 62.34651184082031 val-l

[INFO][2023-10-03 15:27:49,797][dance][set_seed] Setting global random seed to 1206479118
[INFO][2023-10-03 15:27:49,799][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod1.h5ad


[2 2 2 ... 1 3 1] [6 6 6 ... 5 6 5]
(0.223, 0.169)


[INFO][2023-10-03 15:27:54,819][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod2.h5ad
[INFO][2023-10-03 15:27:58,118][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 15:28:01,994][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 15:28:04,593][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_solution.h5ad
  view_to_actual(adata)
  view_to_actual(adata)
[INFO][2023-10-03 15:29:54,183][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-03 15:29:54,656][dance][load_data] Raw data loaded:
Data object that wr

torch.Size([74107, 10000]) torch.Size([31761, 10000])
7 1 2 20000
epoch 0
loss1 9.854349013032584, loss2 0.8411020556628932, loss3 -9.580846604578998e-10, loss4 0.1237062658486581, 
val-loss1 3.2408111095428467 val-loss2 0.962090790271759 val-loss3 -1.6085452414493773e-11 val-loss4 0.004456561058759689
val score 2.461208762786478
epoch 1
loss1 2.420678300419073, loss2 0.7737769584431959, loss3 -9.580846604578998e-10, loss4 0.04332884040329306, 
val-loss1 2.157156467437744 val-loss2 0.5962890386581421 val-loss3 -1.6085452414493773e-11 val-loss4 0.0015593749703839421
val score 1.6293453036857644
epoch 2
loss1 2.2567175183716404, loss2 0.6051138292898164, loss3 -9.580846604578998e-10, loss4 0.021058024536154563, 
val-loss1 1.9873970746994019 val-loss2 0.5005156397819519 val-loss3 -1.6085452414493773e-11 val-loss4 0.0012321437243372202
val score 1.4913426874313844
epoch 3
loss1 2.205635035403387, loss2 0.5261032117452201, loss3 -9.580846604578998e-10, loss4 0.00967300720637757, 
val-loss1 

'To reproduce JAE on other samples, please refer to command lines belows:\n\nGEX-ADT:\npython jae.py --subtask openproblems_bmmc_cite_phase2 --device cuda\n\nGEX-ATAC:\npython jae.py --subtask openproblems_bmmc_multiome_phase2 --device cuda\n\n'

In [3]:
JAE_scores

[(0.333, 0.23), (0.075, 0.074), (0.223, 0.169), (0.25, 0.18)]

In [2]:
scMoGCN_scores=[]
import argparse
import random
from sklearn import preprocessing
import numpy as np
import torch

from dance.datasets.multimodality import JointEmbeddingNIPSDataset
from dance.modules.multi_modality.joint_embedding.scmogcn import ScMoGCNWrapper
from dance.transforms.graph.cell_feature_graph import CellFeatureBipartiteGraph
from dance.utils import set_seed


rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--subtask", default=datasets[0],
                    choices=datasets)
parser.add_argument("-d", "--data_folder", default="../../../../data/joint_embedding")
parser.add_argument("-pre", "--pretrained_folder", default="./data/joint_embedding/pretrained")
parser.add_argument("-csv", "--csv_path", default="decoupled_lsi.csv")
parser.add_argument("-l", "--layers", default=3, type=int, choices=[3, 4, 5, 6, 7])
parser.add_argument("-dis", "--disable_propagation", default=0, type=int, choices=[0, 1, 2])
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-bs", "--batch_size", default=512, type=int)
parser.add_argument("-nm", "--normalize", default=1, type=int, choices=[0, 1])
parser.add_argument("--span", default=0.3, type=float)

for dataset in datasets:
    args = parser.parse_args(['--subtask',dataset,'--device','cuda','--span','1.0'])

    device = args.device
    pre_normalize = bool(args.normalize)
    torch.set_num_threads(args.cpus)
    rndseed = args.rnd_seed
    set_seed(rndseed)

    dataset = JointEmbeddingNIPSDataset(args.subtask, root=args.data_folder, preprocess="feature_selection", normalize=True,span=args.span)
    data = dataset.load_data()
    train_size = len(data.get_split_idx("train"))

    le = preprocessing.LabelEncoder()
    labels = le.fit_transform(data.mod["test_sol"].obs["cell_type"])
    data.mod["mod1"].obsm["labels"] = labels
    
    data = CellFeatureBipartiteGraph(cell_feature_channel="X_pca", mod="mod1")(data)
    data = CellFeatureBipartiteGraph(cell_feature_channel="X_pca", mod="mod2")(data)
    data.set_config(
        feature_mod=["mod1", "mod2"],
        label_mod=["mod1"],
        feature_channel=["X_pca", "X_pca"],
        label_channel=["labels"],
    )
    (x_mod1, x_mod2), (cell_type) = data.get_data(return_type="torch")
    phase_score =torch.transpose(torch.tensor([[0.0]*(x_mod1.shape[0]),[0]*x_mod1.shape[0]]),0,1)
    batch_label=torch.tensor([0.0]*(x_mod1.shape[0]))

    model = ScMoGCNWrapper(args, num_celL_types=int(cell_type.max() + 1), num_batches=int(batch_label.max() + 1),
                           num_phases=phase_score.shape[1], num_features=x_mod1.shape[1] + x_mod2.shape[1])
    model.fit(
        g_mod1=data.data["mod1"].uns["g"],
        g_mod2=data.data["mod2"].uns["g"],
        train_size=train_size,
        cell_type=cell_type,
        batch_label=batch_label,
        phase_score=phase_score,
    )
    model.load(f"models/model_joint_embedding_{rndseed}.pth")

    with torch.no_grad():
        test_id = np.arange(train_size, x_mod1.shape[0])
        labels = cell_type.numpy()[test_id]
        embeds = model.predict(test_id).cpu().numpy()
        print(embeds)
        score=model.score(test_id, labels, metric="clustering")
        print(score)
        scMoGCN_scores.append(score)
"""To reproduce scMoGCN on other samples, please refer to command lines belows:

GEX-ADT:
python scmogcn.py --subtask openproblems_bmmc_cite_phase2 --device cuda

GEX-ATAC:
python scmogcn.py --subtask openproblems_bmmc_multiome_phase2 --device cuda

"""


[INFO][2023-10-03 17:38:14,049][dance][set_seed] Setting global random seed to 1247938998
[INFO][2023-10-03 17:38:14,051][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod1.h5ad
[INFO][2023-10-03 17:38:14,458][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-03 17:38:14,562][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 17:38:14,850][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 17:38:14,946][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.outpu

epoch 0
loss1 1.0276559326383803, loss2 3.099752320183648, loss3 0.9825230439503988, loss4 0.442961464325587, 
val-loss1 1.0264090299606323 val-loss2 2.92504620552063 val-loss3 0.07733790576457977 val-loss4 0.011519686318933964
val score 1.3079384416807445
epoch 1
loss1 1.0256121224827237, loss2 3.003436221016778, loss3 1.0041138264867995, loss4 0.413780798514684, 
val-loss1 1.0257552862167358 val-loss2 2.8912832736968994 val-loss3 0.15937145054340363 val-loss4 0.02396094612777233
val score 1.3054519749246536
epoch 2
loss1 1.023488528198666, loss2 2.9367370075649686, loss3 0.9532757931285434, loss4 0.3987369305557675, 
val-loss1 1.0250908136367798 val-loss2 2.816216230392456 val-loss3 0.2464807778596878 val-loss4 0.04908547177910805
val score 1.295585128106177
epoch 3
loss1 1.0213981072107952, loss2 2.858206960890028, loss3 0.8819088008668687, loss4 0.37190412481625873, 
val-loss1 1.0244345664978027 val-loss2 2.7212274074554443 val-loss3 0.3116030991077423 val-loss4 0.07549013942480087

[INFO][2023-10-03 17:39:01,966][dance][set_seed] Setting global random seed to 1247938998
[INFO][2023-10-03 17:39:01,968][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod1.h5ad


(0.536, 0.464)


[INFO][2023-10-03 17:39:04,152][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-03 17:39:04,497][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 17:39:06,023][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 17:39:06,285][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_solution.h5ad
  utils.warn_names_duplicates("obs")
  self.idf = X.shape[0] / X.sum(axis=0)
  view_to_actual(adata)
  view_to_actual(adata)
[INFO][2023-10-03 17:42:47,861][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-03 17:42:52,782][dance][load_data] Raw data lo

epoch 0
loss1 0.9766389089961385, loss2 3.0757083504699, loss3 2.1226531877074133, loss4 0.3670721026354058, 
val-loss1 0.9533858895301819 val-loss2 2.577671766281128 val-loss3 3.2521517276763916 val-loss4 0.06938908249139786
val score 1.3489815164357424
epoch 1
loss1 0.9563381505566974, loss2 2.4858868510224097, loss3 1.7462219138478123, loss4 0.29313984513282776, 
val-loss1 0.92420494556427 val-loss2 2.057708501815796 val-loss3 2.827094554901123 val-loss4 0.05714154243469238
val score 1.202696967124939
epoch 2
loss1 0.9108819615009219, loss2 2.083017193993857, loss3 1.5208737378896668, loss4 0.23084795509659967, 
val-loss1 0.861436665058136 val-loss2 1.7471261024475098 val-loss3 2.184696674346924 val-loss4 0.04255502671003342
val score 1.0637934710830448
epoch 3
loss1 0.8361747625262238, loss2 1.8321943726650505, loss3 1.3425771219785823, loss4 0.1941535119400468, 
val-loss1 0.774044930934906 val-loss2 1.5737709999084473 val-loss3 1.7143093347549438 val-loss4 0.030860282480716705
val

[INFO][2023-10-03 17:43:35,850][dance][set_seed] Setting global random seed to 1247938998
[INFO][2023-10-03 17:43:35,852][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod1.h5ad


(0.388, 0.275)


[INFO][2023-10-03 17:43:38,863][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod2.h5ad
[INFO][2023-10-03 17:43:39,007][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 17:43:41,169][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 17:43:41,266][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_solution.h5ad
  self.idf = X.shape[0] / X.sum(axis=0)
  tf = X.multiply(1 / X.sum(axis=1))
  view_to_actual(adata)
[INFO][2023-10-03 17:47:48,205][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-03 17:47:48,351][dance][load_data]

epoch 0
loss1 0.9636432907798074, loss2 1.560280141505328, loss3 2.078423125500029, loss4 0.3786299215121703, 
val-loss1 0.9325199723243713 val-loss2 0.9174661040306091 val-loss3 6.0477118492126465 val-loss4 0.08128175884485245
val score 1.1427068818360568
epoch 1
loss1 0.8868953856554899, loss2 0.8454443663358688, loss3 1.5039144872941754, loss4 0.2443485569886186, 
val-loss1 0.8494113683700562 val-loss2 0.572718620300293 val-loss3 3.828327178955078 val-loss4 0.03873004764318466
val score 0.9024845432490111
epoch 2
loss1 0.8406203050505031, loss2 0.6014278862964023, loss3 1.2224178290502592, loss4 0.16939426721497017, 
val-loss1 0.82344651222229 val-loss2 0.4538191556930542 val-loss3 2.7190845012664795 val-loss4 0.021746991202235222
val score 0.8042179643176497
epoch 3
loss1 0.8203147094358098, loss2 0.49676038121635263, loss3 1.028939708728682, loss4 0.12810799725015054, 
val-loss1 0.8044955134391785 val-loss2 0.38955509662628174 val-loss3 1.9769529104232788 val-loss4 0.0142141766846

[INFO][2023-10-03 17:49:16,628][dance][set_seed] Setting global random seed to 1247938998
[INFO][2023-10-03 17:49:16,630][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod1.h5ad


(0.533, 0.507)


[INFO][2023-10-03 17:49:22,237][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod2.h5ad
[INFO][2023-10-03 17:49:25,826][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-03 17:49:29,592][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-03 17:49:32,168][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_solution.h5ad
  self.idf = X.shape[0] / X.sum(axis=0)
  view_to_actual(adata)
  view_to_actual(adata)
[INFO][2023-10-03 18:08:52,652][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-03 18:08:53,100][dance][load_da

epoch 0
loss1 0.9126302859255375, loss2 1.1689277182098563, loss3 1.877712727048015, loss4 0.3264447872647802, 
val-loss1 0.7879327535629272 val-loss2 0.6117417812347412 val-loss3 6.011832237243652 val-loss4 0.03547019511461258
val score 0.9762664053589106
epoch 1
loss1 0.7167177186667464, loss2 0.5916646989702269, loss3 1.256161720925615, loss4 0.17249394447066402, 
val-loss1 0.6336088180541992 val-loss2 0.45634716749191284 val-loss3 3.370448112487793 val-loss4 0.014594639651477337
val score 0.7040477437432855
epoch 2
loss1 0.6437534466954588, loss2 0.4780203813814935, loss3 0.9741421021122969, loss4 0.11781825278779022, 
val-loss1 0.5955461859703064 val-loss2 0.39776623249053955 val-loss3 2.1228954792022705 val-loss4 0.008486340753734112
val score 0.6030046676751226
epoch 3
loss1 0.618204427584437, loss2 0.4292549636527782, loss3 0.7548332770589654, loss4 0.0847671890759286, 
val-loss1 0.5808871984481812 val-loss2 0.3675185739994049 val-loss3 1.2701523303985596 val-loss4 0.0049369079

'To reproduce scMoGCN on other samples, please refer to command lines belows:\n\nGEX-ADT:\npython scmogcn.py --subtask openproblems_bmmc_cite_phase2 --device cuda\n\nGEX-ATAC:\npython scmogcn.py --subtask openproblems_bmmc_multiome_phase2 --device cuda\n\n'

In [3]:
scMoGCN_scores

[(0.536, 0.464), (0.388, 0.275), (0.533, 0.507), (0.546, 0.48)]

In [3]:
scMVAE_scores=[]
import argparse

import numpy as np
import torch
import torch.utils.data as data_utils
from sklearn import preprocessing
import scanpy as sc
from dance.datasets.multimodality import JointEmbeddingNIPSDataset
from dance.modules.multi_modality.joint_embedding.scmvae import scMVAE
from dance.transforms.preprocess import calculate_log_library_size


def parameter_setting():
    parser = argparse.ArgumentParser(description="Single cell Multi-omics data analysis")

    parser.add_argument("--workdir", "-wk", type=str, default="./new_test", help="work path")
    parser.add_argument("--outdir", "-od", type=str, default="./new_test", help="Output path")

    parser.add_argument("--lr", type=float, default=1E-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-6, help="weight decay")
    parser.add_argument("--eps", type=float, default=0.01, help="eps")

    parser.add_argument("--batch_size", "-b", type=int, default=64, help="Batch size")
    parser.add_argument("--seed", type=int, default=200, help="Random seed for repeat results")
    parser.add_argument("--latent", "-l", type=int, default=10, help="latent layer dim")
    parser.add_argument("--max_epoch", "-me", type=int, default=20, help="Max epoches")
    parser.add_argument("--max_iteration", "-mi", type=int, default=2000, help="Max iteration")
    parser.add_argument("--anneal_epoch", "-ae", type=int, default=200, help="Anneal epoch")
    parser.add_argument("--epoch_per_test", "-ept", type=int, default=5,
                        help="Epoch per test, must smaller than max iteration.")
    parser.add_argument("--max_ARI", "-ma", type=int, default=-200, help="initial ARI")
    parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2")
    parser.add_argument("-device", "--device", default="cuda")
    parser.add_argument("--final_rate", type=float, default=1e-4)
    parser.add_argument("--scale_factor", type=float, default=4)
    parser.add_argument("--span", default=0.3, type=float)
    return parser




for dataset in datasets:
    parser = parameter_setting()
    args = parser.parse_args(['--subtask',dataset,'--device','cpu','--span','1.0'])
    assert args.max_iteration > args.epoch_per_test

    dataset = JointEmbeddingNIPSDataset(args.subtask, root="../../../../data/joint_embedding", preprocess="feature_selection",span=args.span)
    data = dataset.load_data()

    le = preprocessing.LabelEncoder()
    labels = le.fit_transform(data.mod["test_sol"].obs["cell_type"])
    
    data.mod["mod1"].obsm["labels"] = labels
    data.set_config(feature_mod=["mod1", "mod2"], label_mod="mod1", feature_channel_type=["layers", "layers"],
                    feature_channel=["counts", "counts"], label_channel="labels")
    # sc.pp.log1p(data.mod["mod2"])
    # sc.pp.log1p(data.mod["mod1"])
   

    (x_train, y_train), _ = data.get_train_data(return_type="torch")
    (x_test, y_test), labels = data.get_test_data(return_type="torch")

    lib_mean1, lib_var1 = calculate_log_library_size(np.concatenate([x_train.numpy(), x_test.numpy()]))
    lib_mean2, lib_var2 = calculate_log_library_size(np.concatenate([y_train.numpy(), y_test.numpy()]))
    lib_mean1 = torch.from_numpy(lib_mean1)
    lib_var1 = torch.from_numpy(lib_var1)
    lib_mean2 = torch.from_numpy(lib_mean2)
    lib_var2 = torch.from_numpy(lib_var2)

    Nfeature1 = x_train.shape[1]
    Nfeature2 = y_train.shape[1]

    device = torch.device(args.device)

    model = scMVAE(
        encoder_1=[Nfeature1, 1024, 128, 128],
        hidden_1=128,
        Z_DIMS=22,
        decoder_share=[22, 128, 256],
        share_hidden=128,
        decoder_1=[128, 128, 1024],
        hidden_2=1024,
        encoder_l=[Nfeature1, 128],
        hidden3=128,
        encoder_2=[Nfeature2, 1024, 128, 128],
        hidden_4=128,
        encoder_l1=[Nfeature2, 128],
        hidden3_1=128,
        decoder_2=[128, 128, 1024],
        hidden_5=1024,
        drop_rate=0.1,
        log_variational=True,
        Type="ZINB",
        device=device,
        n_centroids=22,
        penality="GMM",
        model=1,
    )

    args.lr = 0.001
    args.anneal_epoch = 200

    model.to(device)
    train_size = len(data.get_split_idx("train"))
    train = data_utils.TensorDataset(x_train, lib_mean1[:train_size], lib_var1[:train_size], lib_mean2[:train_size],
                                     lib_var2[:train_size], y_train)

    valid = data_utils.TensorDataset(x_test, lib_mean1[train_size:], lib_var1[train_size:], lib_mean2[train_size:],
                                     lib_var2[train_size:], y_test)

    total = data_utils.TensorDataset(torch.cat([x_train, x_test]), torch.cat([y_train, y_test]))

    total_loader = data_utils.DataLoader(total, batch_size=args.batch_size, shuffle=False)
    model.init_gmm_params(total_loader)
    model.fit(args, train, valid, args.final_rate, args.scale_factor, device)

    embeds = model.predict(torch.cat([x_train, x_test]), torch.cat([y_train, y_test])).cpu().numpy()
    print(embeds)

    nmi_score, ari_score = model.score(x_test, y_test, labels)
    print(f"NMI: {nmi_score:.3f}, ARI: {ari_score:.3f}")
    scMVAE_scores.append({"nmi_score":nmi_score,"ari_score":ari_score})
"""To reproduce scMVAE on other samples, please refer to command lines belows:

GEX-ADT:
python scmvae.py --subtask openproblems_bmmc_cite_phase2 --device cuda

GEX-ATAC:
python scmvae.py --subtask openproblems_bmmc_multiome_phase2 --device cuda

"""


[INFO][2023-10-04 10:00:25,923][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod1.h5ad
[INFO][2023-10-04 10:00:26,338][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-04 10:00:26,473][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-04 10:00:26,866][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-04 10:00:26,959][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_BRAIN_atac2gex/GSE140203_BRAIN_atac2gex.GSE140203_dataset.output_solution.h5ad
[INFO][2023-10-04 10:00:51,510][dance][_maybe_preprocess] Preprocessing do

5   4870.607091632938   5760.052151201152   134.73367309570312   3731.04443359375  kl_divergence_l:  1.09037639894252 kl_weight: 0.025 kl_divergence_z: 28.28915023803711
10   4890.69387233602   5347.943882437796   135.32171630859375   3283.099853515625  kl_divergence_l:  0.8434103989101192 kl_weight: 0.05 kl_divergence_z: 26.715003967285156
20   3700.3808181958057   4625.158626285617   142.5497283935547   3513.345947265625  kl_divergence_l:  0.9631066363136009 kl_weight: 0.1 kl_divergence_z: 26.495540618896484
Finish training, total time: 198.2806613445282s epoch: 20 status:  Reached 20 epoch, training complete. 
[[ 0.16673023  0.06058602 -0.92712873 ... -1.873663    0.33918393
   1.0599216 ]
 [ 2.2512248  -0.8462337  -2.4604926  ...  6.4582996  -0.02524094
  -0.11451921]
 [ 0.30436543  0.5637595   0.01622492 ... -0.60097295  0.86624026
  -1.212858  ]
 ...
 [ 0.8500484   1.0427794  -0.03551135 ...  1.9562974  -0.69208586
  -2.3211868 ]
 [ 0.97587556  0.5382068  -0.09366035 ... -0.05597

[INFO][2023-10-04 10:04:14,136][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod1.h5ad


NMI: 0.412, ARI: 0.260


[INFO][2023-10-04 10:04:16,238][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_mod2.h5ad
[INFO][2023-10-04 10:04:16,582][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-10-04 10:04:18,099][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-10-04 10:04:18,347][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/GSE140203_SKIN_atac2gex/GSE140203_SKIN_atac2gex.GSE140203_dataset.output_solution.h5ad
  utils.warn_names_duplicates("obs")
[INFO][2023-10-04 10:07:47,896][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-04 10:07:53,167][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars = 35494 × 730259

5   1863.0816018127828   6201.330424286834   207.4984130859375   1221.3515625  kl_divergence_l:  3.6345413724494584 kl_weight: 0.025 kl_divergence_z: 28.52562713623047
10   1888.6250997608918   2379.479025048003   201.53872680664062   1193.8466796875  kl_divergence_l:  3.583497680371307 kl_weight: 0.05 kl_divergence_z: 11.419819831848145
15   1922.2007676485446   2252.770514677865   202.55029296875   1170.1573486328125  kl_divergence_l:  3.7431742611158425 kl_weight: 0.075 kl_divergence_z: 6.449275970458984
20   1484.5260860181595   2166.7726841994863   203.54898071289062   1148.37353515625  kl_divergence_l:  3.7746769860005873 kl_weight: 0.1 kl_divergence_z: 3.2357211112976074
Finish training, total time: 2005.356766462326s epoch: 20 status:  Reached 20 epoch, training complete. 
[[-0.28312817  0.3167457   0.2871076  ... -0.5831218  -0.33563468
  -1.1981037 ]
 [ 0.01919515 -0.18245426  0.16589062 ...  0.42177296 -0.21458617
  -2.3657603 ]
 [-0.49299112  1.1364506   0.29133832 ... -0.0

[INFO][2023-10-04 10:41:47,224][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod1.h5ad


NMI: 0.398, ARI: 0.285


[INFO][2023-10-04 10:41:50,965][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_mod2.h5ad
[INFO][2023-10-04 10:41:51,102][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-04 10:41:53,726][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-04 10:41:53,833][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_solution.h5ad
  tf = X.multiply(1 / X.sum(axis=1))
[INFO][2023-10-04 10:46:07,772][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-04 10:46:07,922][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object 

5   17985.007032570193   16382.881477100365   3793.132080078125   853.6953735351562  kl_divergence_l:  8.454174704874704 kl_weight: 0.025 kl_divergence_z: 48.263954162597656
10   16206.075509077205   16262.422772492015   3764.930908203125   855.4279174804688  kl_divergence_l:  8.544236952374451 kl_weight: 0.05 kl_divergence_z: 37.27095031738281
15   16635.13916145159   15459.252839351577   3563.09716796875   859.9569702148438  kl_divergence_l:  7.861500253318012 kl_weight: 0.075 kl_divergence_z: 51.291568756103516
20   15799.468724735767   15202.643965870297   3500.545166015625   860.4047241210938  kl_divergence_l:  7.981056714463123 kl_weight: 0.1 kl_divergence_z: 21.73849105834961
Finish training, total time: 2135.826046228409s epoch: 20 status:  Reached 20 epoch, training complete. 
[[-1.9540087  -0.6546646  -0.05632153 ... -1.9980044  -1.6573197
   0.9333149 ]
 [-1.6789708  -1.0913044  -0.42879108 ... -2.2321908  -0.51619124
   0.05385058]
 [ 0.50575674 -0.24153113 -1.529554   ... 

[INFO][2023-10-04 11:22:16,198][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod1.h5ad


NMI: 0.375, ARI: 0.274


[INFO][2023-10-04 11:22:21,845][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_mod2.h5ad
[INFO][2023-10-04 11:22:25,197][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod1.h5ad
[INFO][2023-10-04 11:22:28,707][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod2.h5ad
[INFO][2023-10-04 11:22:31,093][dance][_load_raw_data] Loading /home/zyxing/data/joint_embedding/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_solution.h5ad
[INFO][2023-10-04 11:42:59,779][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-10-04 11:43:00,267][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars =

5   13839.542842427742   13859.240867715967   1098.2010498046875   8836.97265625  kl_divergence_l:  6.5338216131399145 kl_weight: 0.025 kl_divergence_z: 31.927928924560547
10   13178.796699621582   13749.321095138544   1076.4365234375   8864.8779296875  kl_divergence_l:  8.18936132681222 kl_weight: 0.05 kl_divergence_z: 11.160152435302734
15   13598.348587159035   13712.04085653718   1070.6529541015625   8813.5673828125  kl_divergence_l:  9.21884509199748 kl_weight: 0.075 kl_divergence_z: 5.951897621154785
20   12752.13133829034   13693.182149188237   1066.2730712890625   8823.2724609375  kl_divergence_l:  10.258705066704177 kl_weight: 0.1 kl_divergence_z: 3.848344326019287
Finish training, total time: 4140.331627130508s epoch: 20 status:  Reached 20 epoch, training complete. 
[[ 0.12440279  0.98521894 -0.17474233 ...  0.6150322  -1.0702327
   0.24965405]
 [ 1.4394023  -0.11239281 -1.7417555  ... -0.41631913  1.5900434
  -0.6065301 ]
 [ 0.64373845  1.4043484   0.5709433  ... -0.4568111

'To reproduce scMVAE on other samples, please refer to command lines belows:\n\nGEX-ADT:\npython scmvae.py --subtask openproblems_bmmc_cite_phase2 --device cuda\n\nGEX-ATAC:\npython scmvae.py --subtask openproblems_bmmc_multiome_phase2 --device cuda\n\n'

In [4]:
scMVAE_scores

[{'nmi_score': 0.412, 'ari_score': 0.26},
 {'nmi_score': 0.398, 'ari_score': 0.285},
 {'nmi_score': 0.375, 'ari_score': 0.274},
 {'nmi_score': 0.406, 'ari_score': 0.348}]