In [1]:
datasets=["10k_pbmc","5k_pbmc_subset","pbmc_cite","openproblems_2022_multi_atac2gex","openproblems_2022_cite_gex2adt","GSE127064_AdBrain_gex2atac","GSE127064_p0Brain_gex2atac","GSE117089_mouse_gex2atac",
          "GSE140203_3T3_HG19_atac2gex",
        "GSE140203_3T3_MM10_atac2gex",
        "GSE140203_12878.rep2_atac2gex",
        "GSE140203_12878.rep3_atac2gex",
        "GSE140203_K562_HG19_atac2gex",
        "GSE140203_K562_MM10_atac2gex",
        "GSE140203_LUNG_atac2gex"]

In [2]:
train_leng=7000
test_leng=3000

In [3]:
import argparse
import logging
import os
import random

import torch

from dance import logger
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.babel import BabelWrapper
from dance.utils import set_seed


OPTIMIZER_DICT = {
    "adam": torch.optim.Adam,
    "rmsprop": torch.optim.RMSprop,
}
BabelWrapper_scores=[]
rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("-m", "--model_folder", default="./models")
parser.add_argument("--outdir", "-o", default="./logs", help="Directory to output to")
parser.add_argument("--lossweight", type=float, default=1., help="Relative loss weight")
parser.add_argument("--lr", "-l", type=float, default=0.01, help="Learning rate")
parser.add_argument("--batchsize", "-b", type=int, default=64, help="Batch size")
parser.add_argument("--hidden", type=int, default=64, help="Hidden dimensions")
parser.add_argument("--earlystop", type=int, default=5, help="Early stopping after N epochs")
parser.add_argument("--naive", "-n", action="store_true", help="Use a naive model instead of lego model")
parser.add_argument("--resume", action="store_true")
parser.add_argument("--span", default=0.3, type=float)
parser.add_argument("--max_epochs", type=int, default=500)
for dataset in datasets:
    v=None
    try:
        args = parser.parse_args(['--subtask',dataset,'--device','cuda:1','--span','1.0'])
        args.resume = True

        torch.set_num_threads(args.cpus)
        rndseed = args.rnd_seed
        set_seed(rndseed)
        dataset = ModalityPredictionDataset(args.subtask, preprocess="feature_selection",span=args.span)
        data = dataset.load_data()

        device = args.device
        args.outdir = os.path.abspath(args.outdir)
        os.makedirs(args.model_folder, exist_ok=True)
        os.makedirs(args.outdir, exist_ok=True)

        # Specify output log file
        fh = logging.FileHandler(f"{args.outdir}/training_{args.subtask}_{args.rnd_seed}.log", "w")
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)

        for arg in vars(args):
            logger.info(f"Parameter {arg}: {getattr(args, arg)}")

        # Obtain training and testing data
        x_train, y_train = data.get_train_data(return_type="torch")
        x_test, y_test = data.get_test_data(return_type="torch")
        x_train, y_train= x_train.float()[:train_leng], y_train.float()[:train_leng]
        x_test, y_test =x_test.float()[:test_leng], y_test.float()[:test_leng]
        # Train and evaluate the model
        model = BabelWrapper(args, dim_in=x_train.shape[1], dim_out=y_train.shape[1])
        model.fit(x_train, y_train, val_ratio=0.15)
        print(model.predict(x_test))
        score=model.score(x_test, y_test)
        
    except (Exception, BaseException) as e:
        v=e
    else:
        v=score
    finally:
        print(v)
        BabelWrapper_scores.append(v)
        torch.cuda.empty_cache()
"""To reproduce BABEL on other samples, please refer to command lines belows:
GEX to ADT (subset):
python babel.py --subtask openproblems_bmmc_cite_phase2_rna_subset --device cuda

GEX to ADT:
python babel.py --subtask openproblems_bmmc_cite_phase2_rna --device cuda

ADT to GEX:
python babel.py --subtask openproblems_bmmc_cite_phase2_mod2 --device cuda

GEX to ATAC:
python babel.py --subtask openproblems_bmmc_multiome_phase2_rna --device cuda

ATAC to GEX:
python babel.py --subtask openproblems_bmmc_multiome_phase2_mod2 --device cuda
"""


[INFO][2023-09-19 21:00:58,909][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:00:58,911][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_train_mod1.h5ad
[INFO][2023-09-19 21:00:58,975][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_train_mod2.h5ad
[INFO][2023-09-19 21:00:59,104][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_test_mod1.h5ad
[INFO][2023-09-19 21:00:59,152][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_test_mod2.h5ad
[INFO][2023-09-19 21:00:59,415][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:00:59,616][dance][set_config_from_dict] Setti

epoch:  1
training (sum of 4 losses): 0.4884380333686388
validation (prediction loss): 0.1444048840880803
epoch:  2
training (sum of 4 losses): 0.4676007402520026
validation (prediction loss): 0.14437734382443324
epoch:  3
training (sum of 4 losses): 0.41085299677265585
validation (prediction loss): 0.1421603447715486
epoch:  4
training (sum of 4 losses): 0.40321212085664915
validation (prediction loss): 0.14489776693860312
epoch:  5
training (sum of 4 losses): 0.3787506782720166
validation (prediction loss): 0.14410383723233866
epoch:  6
training (sum of 4 losses): 0.3865004012020685
validation (prediction loss): 0.14192507909423738
epoch:  7
training (sum of 4 losses): 0.3756286142173634
validation (prediction loss): 0.1387318035332448
epoch:  8
training (sum of 4 losses): 0.38754369006041556
validation (prediction loss): 0.1375600077144824
epoch:  9
training (sum of 4 losses): 0.369317932034372
validation (prediction loss): 0.1392066564413893
epoch:  10
training (sum of 4 losses): 0

[INFO][2023-09-19 21:02:12,964][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:02:12,966][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_train_mod1.h5ad
[INFO][2023-09-19 21:02:13,017][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:02:13,032][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:02:13,078][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:02:13,118][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:02:13,156][dance][set_config_

0.1266901373398157


[INFO][2023-09-19 21:02:13,165][dance][<module>] Parameter device: cuda:1
[INFO][2023-09-19 21:02:13,166][dance][<module>] Parameter cpus: 1
[INFO][2023-09-19 21:02:13,167][dance][<module>] Parameter rnd_seed: 1472461312
[INFO][2023-09-19 21:02:13,169][dance][<module>] Parameter model_folder: ./models
[INFO][2023-09-19 21:02:13,170][dance][<module>] Parameter outdir: /home/zyxing/dance/examples/multi_modality/predict_modality/logs
[INFO][2023-09-19 21:02:13,171][dance][<module>] Parameter lossweight: 1.0
[INFO][2023-09-19 21:02:13,172][dance][<module>] Parameter lr: 0.01
[INFO][2023-09-19 21:02:13,173][dance][<module>] Parameter batchsize: 64
[INFO][2023-09-19 21:02:13,173][dance][<module>] Parameter hidden: 64
[INFO][2023-09-19 21:02:13,174][dance][<module>] Parameter earlystop: 20
[INFO][2023-09-19 21:02:13,175][dance][<module>] Parameter naive: False
[INFO][2023-09-19 21:02:13,176][dance][<module>] Parameter resume: True
[INFO][2023-09-19 21:02:13,179][dance][<module>] Parameter spa

epoch:  1
training (sum of 4 losses): 3.0859150451818302
validation (prediction loss): 0.4343856859689326
epoch:  2
training (sum of 4 losses): 3.0873978654102934
validation (prediction loss): 0.4343856305087138
epoch:  3
training (sum of 4 losses): 3.0738648255856367
validation (prediction loss): 0.4343856305087138
epoch:  4
training (sum of 4 losses): 3.101783684985612
validation (prediction loss): 0.4343856305087138
epoch:  5
training (sum of 4 losses): 3.092152844994299
validation (prediction loss): 0.4343856305087138
epoch:  6
training (sum of 4 losses): 3.0926215876775083
validation (prediction loss): 0.4343856305087138
epoch:  7
training (sum of 4 losses): 3.087234290356758
validation (prediction loss): 0.4343856305087138
epoch:  8
training (sum of 4 losses): 3.065364938309436
validation (prediction loss): 0.4343856305087138
epoch:  9
training (sum of 4 losses): 3.0991285987278467
validation (prediction loss): 0.4343856305087138
epoch:  10
training (sum of 4 losses): 3.075914652

[INFO][2023-09-19 21:13:42,149][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:13:42,153][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:13:42,154][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod1.h5ad


epoch:  500
training (sum of 4 losses): 3.082521486676909
validation (prediction loss): 0.4343856305087138
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')
19.13239858295633
unknown url type: ''


[INFO][2023-09-19 21:13:44,913][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:13:45,039][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:13:46,172][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:14:03,041][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:14:09,266][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:14:09,267][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:14:09,268][danc

epoch:  1
training (sum of 4 losses): 177.49334807037025
validation (prediction loss): 79.46867807530691
epoch:  2
training (sum of 4 losses): 129.29607973816573
validation (prediction loss): 57.69537625325045
epoch:  3
training (sum of 4 losses): 74.16864686371177
validation (prediction loss): 62.00613760925471
epoch:  4
training (sum of 4 losses): 60.73457623553532
validation (prediction loss): 44.08797680110216
epoch:  5
training (sum of 4 losses): 58.37695521693076
validation (prediction loss): 50.51903749635057
epoch:  6
training (sum of 4 losses): 56.85867354690388
validation (prediction loss): 48.17986016027018
epoch:  7
training (sum of 4 losses): 54.85173416137695
validation (prediction loss): 43.13511872277288
epoch:  8
training (sum of 4 losses): 52.207647262081025
validation (prediction loss): 53.653946000414926
epoch:  9
training (sum of 4 losses): 53.18870653131957
validation (prediction loss): 44.535603300823674
epoch:  10
training (sum of 4 losses): 51.458678973618376
v

[INFO][2023-09-19 21:15:01,365][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:15:01,367][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod1.h5ad


epoch:  35
training (sum of 4 losses): 35.34936061982186
validation (prediction loss): 46.29487638780536
Early stopped.
tensor([[ 0.0000,  2.6790,  6.1411,  ...,  2.8162,  9.5749,  6.5767],
        [ 0.0000,  3.4002,  6.9197,  ...,  3.6601, 11.2052,  6.4026],
        [ 0.0000,  3.1497,  6.5538,  ...,  3.8472, 13.1526,  6.2175],
        ...,
        [ 0.0000,  2.4016,  5.9496,  ...,  3.0232, 10.2940,  7.0378],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, 46.7790,  6.2148],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, 22.3061,  4.4539]],
       device='cuda:1')
51.56493365468921


[INFO][2023-09-19 21:15:04,603][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:15:06,745][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:15:08,066][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:15:26,066][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:15:42,618][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:15:42,619][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:15:42,620][dance][load_data

epoch:  1
training (sum of 4 losses): 6.054888586844167
validation (prediction loss): 2.039194893294945
epoch:  2
training (sum of 4 losses): 3.550351599211334
validation (prediction loss): 1.5884296928334243
epoch:  3
training (sum of 4 losses): 3.055933831840433
validation (prediction loss): 1.6263510516215915
epoch:  4
training (sum of 4 losses): 2.9744460582733154
validation (prediction loss): 1.5773497617460455
epoch:  5
training (sum of 4 losses): 2.9202012323564097
validation (prediction loss): 1.590001056164847
epoch:  6
training (sum of 4 losses): 2.9055882397518364
validation (prediction loss): 1.5901638612378226
epoch:  7
training (sum of 4 losses): 2.8710520472577823
validation (prediction loss): 1.5926563774235332
epoch:  8
training (sum of 4 losses): 2.852466142305764
validation (prediction loss): 1.5811302306443198
epoch:  9
training (sum of 4 losses): 2.842612871559717
validation (prediction loss): 1.5856943665390573
epoch:  10
training (sum of 4 losses): 2.821272944891

[INFO][2023-09-19 21:18:21,804][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:18:21,805][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_AdBrain_gex2atac/GSE127064_AdBrain_gex2atac.GSE126074_dataset.output_train_mod1.h5ad


epoch:  98
training (sum of 4 losses): 2.37927657558072
validation (prediction loss): 1.5612152854233874
Early stopped.
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')
2.6246179575154547


[INFO][2023-09-19 21:18:21,905][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_AdBrain_gex2atac/GSE127064_AdBrain_gex2atac.GSE126074_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:18:22,360][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_AdBrain_gex2atac/GSE127064_AdBrain_gex2atac.GSE126074_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:18:22,427][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_AdBrain_gex2atac/GSE127064_AdBrain_gex2atac.GSE126074_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:18:23,179][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:18:24,862][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:18:24,863][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:18:24,864][dance][load_data] Raw dat

epoch:  1
training (sum of 4 losses): 1.686972430957261
validation (prediction loss): 0.11411561623502865
epoch:  2
training (sum of 4 losses): 1.1477442018447384
validation (prediction loss): 0.11383731198202648
epoch:  3
training (sum of 4 losses): 1.0694541649151874
validation (prediction loss): 0.11377274181994997
epoch:  4
training (sum of 4 losses): 1.0288011816240126
validation (prediction loss): 0.11378918352592651
epoch:  5
training (sum of 4 losses): 1.0142527280315277
validation (prediction loss): 0.1137605984789322
epoch:  6
training (sum of 4 losses): 1.0023410288236474
validation (prediction loss): 0.11376102177809314
epoch:  7
training (sum of 4 losses): 0.9935873849417574
validation (prediction loss): 0.113743790206662
epoch:  8
training (sum of 4 losses): 0.9820077271871669
validation (prediction loss): 0.11374340320740363
epoch:  9
training (sum of 4 losses): 0.9652831009639207
validation (prediction loss): 0.11373652417821603
epoch:  10
training (sum of 4 losses): 0.

[INFO][2023-09-19 21:22:31,648][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:22:31,650][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_p0Brain_gex2atac/GSE127064_p0Brain_gex2atac.GSE126074_dataset.output_train_mod1.h5ad
[INFO][2023-09-19 21:22:31,701][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_p0Brain_gex2atac/GSE127064_p0Brain_gex2atac.GSE126074_dataset.output_train_mod2.h5ad


0.10078361678486301


[INFO][2023-09-19 21:22:32,057][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_p0Brain_gex2atac/GSE127064_p0Brain_gex2atac.GSE126074_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:22:32,095][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE127064_p0Brain_gex2atac/GSE127064_p0Brain_gex2atac.GSE126074_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:22:33,067][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:22:34,061][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:22:34,062][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:22:34,064][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars = 5081 × 239429
  uns:	'dance_config'
  2 modalities
    mod1:	5081 x 10000
      obs:	'batch'
      layers:	'counts'
    mod2:	5081 x 229429
     

epoch:  1
training (sum of 4 losses): 0.6340719697376093
validation (prediction loss): 0.12085669782509582
epoch:  2
training (sum of 4 losses): 0.6005005799233913
validation (prediction loss): 0.12070306011857995
epoch:  3
training (sum of 4 losses): 0.5894700984160105
validation (prediction loss): 0.12065047952721512
epoch:  4
training (sum of 4 losses): 0.580543494472901
validation (prediction loss): 0.12073649883290594
epoch:  5
training (sum of 4 losses): 0.5766722845534483
validation (prediction loss): 0.12064669835708183
epoch:  6
training (sum of 4 losses): 0.5740952255825201
validation (prediction loss): 0.12069246455250954
epoch:  7
training (sum of 4 losses): 0.5710397822161516
validation (prediction loss): 0.12068486485358319
epoch:  8
training (sum of 4 losses): 0.5676821085313956
validation (prediction loss): 0.12073363316721383
epoch:  9
training (sum of 4 losses): 0.5645412343243758
validation (prediction loss): 0.12070713785833419
epoch:  10
training (sum of 4 losses):

[INFO][2023-09-19 21:23:49,197][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:23:49,199][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE117089_mouse_gex2atac/GSE117089_mouse_gex2atac.GSE117089_dataset.output_train_mod1.h5ad
[INFO][2023-09-19 21:23:49,333][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE117089_mouse_gex2atac/GSE117089_mouse_gex2atac.GSE117089_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:23:49,718][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE117089_mouse_gex2atac/GSE117089_mouse_gex2atac.GSE117089_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:23:49,823][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE117089_mouse_gex2atac/GSE117089_mouse_gex2atac.GSE117089_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:23:50,423][danc

epoch:  1
training (sum of 4 losses): 7.195250362478276
validation (prediction loss): 0.9033749203330936
epoch:  2
training (sum of 4 losses): 4.742904732304234
validation (prediction loss): 0.7112232107592957
epoch:  3
training (sum of 4 losses): 3.6593864861355034
validation (prediction loss): 0.6854514747191773
epoch:  4
training (sum of 4 losses): 3.248938229776198
validation (prediction loss): 0.6633752969626507
epoch:  5
training (sum of 4 losses): 3.029469874597365
validation (prediction loss): 0.6653697231500104
epoch:  6
training (sum of 4 losses): 2.863940420971122
validation (prediction loss): 0.6842632688681899
epoch:  7
training (sum of 4 losses): 2.693825178248908
validation (prediction loss): 0.67946860713661
epoch:  8
training (sum of 4 losses): 2.63599919760099
validation (prediction loss): 0.6802026963940365
epoch:  9
training (sum of 4 losses): 2.520859738831879
validation (prediction loss): 0.6953922220982671
epoch:  10
training (sum of 4 losses): 2.44619051102669
v

[INFO][2023-09-19 21:26:36,213][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:26:36,215][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_HG19_atac2gex/GSE140203_3T3_HG19_atac2gex.GSE140203_dataset.output_train_mod1.h5ad


0.5988701603438525


[INFO][2023-09-19 21:26:36,553][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_HG19_atac2gex/GSE140203_3T3_HG19_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:26:36,604][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_HG19_atac2gex/GSE140203_3T3_HG19_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:26:36,917][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_HG19_atac2gex/GSE140203_3T3_HG19_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:27:04,792][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:27:05,000][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:27:05,001][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:27:05,002][dance][load_data] R

epoch:  1
training (sum of 4 losses): 1.2440299212932586
validation (prediction loss): 0.8407128849141869
epoch:  2
training (sum of 4 losses): 1.1802061438560485
validation (prediction loss): 0.7923565389834389
epoch:  3
training (sum of 4 losses): 1.0934483140707016
validation (prediction loss): 0.6092397649799105
epoch:  4
training (sum of 4 losses): 1.0242873102426528
validation (prediction loss): 0.5396056893834356
epoch:  5
training (sum of 4 losses): 0.9408860266208648
validation (prediction loss): 0.5558832470398108
epoch:  6
training (sum of 4 losses): 0.9143607974052429
validation (prediction loss): 0.5525688354841787
epoch:  7
training (sum of 4 losses): 0.8876344472169876
validation (prediction loss): 0.552781153610146
epoch:  8
training (sum of 4 losses): 0.872772490978241
validation (prediction loss): 0.5135312232302631
epoch:  9
training (sum of 4 losses): 0.8918287307024002
validation (prediction loss): 0.5161979447056466
epoch:  10
training (sum of 4 losses): 0.8802196

[INFO][2023-09-19 21:27:25,631][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:27:25,633][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_MM10_atac2gex/GSE140203_3T3_MM10_atac2gex.GSE140203_dataset.output_train_mod1.h5ad


epoch:  58
training (sum of 4 losses): 0.8079221338033676
validation (prediction loss): 0.4892658453793959
Early stopped.
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')
0.43203375679378764


[INFO][2023-09-19 21:27:25,775][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_MM10_atac2gex/GSE140203_3T3_MM10_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:27:25,871][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_MM10_atac2gex/GSE140203_3T3_MM10_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:27:25,965][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_3T3_MM10_atac2gex/GSE140203_3T3_MM10_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:27:26,555][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:27:26,827][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:27:26,829][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:27:26,830][dance][load_data] R

epoch:  1
training (sum of 4 losses): 3.4619213008880614
validation (prediction loss): 1.7347826847989025
epoch:  2
training (sum of 4 losses): 3.2861946487426756
validation (prediction loss): 1.6167329305983797
epoch:  3
training (sum of 4 losses): 2.9744982814788816
validation (prediction loss): 1.470161311841974
epoch:  4
training (sum of 4 losses): 2.5268531131744383
validation (prediction loss): 1.3226941584896406
epoch:  5
training (sum of 4 losses): 2.0210997581481935
validation (prediction loss): 1.3289017548841087
epoch:  6
training (sum of 4 losses): 1.669931139945984
validation (prediction loss): 1.2054750354575785
epoch:  7
training (sum of 4 losses): 1.5214461517333984
validation (prediction loss): 1.11888961773238
epoch:  8
training (sum of 4 losses): 1.4825672388076783
validation (prediction loss): 1.1139806163597892
epoch:  9
training (sum of 4 losses): 1.4863713455200196
validation (prediction loss): 1.1227721566592197
epoch:  10
training (sum of 4 losses): 1.463402161

[INFO][2023-09-19 21:27:39,727][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:27:39,730][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep2_atac2gex/GSE140203_12878.rep2_atac2gex.GSE140203_dataset.output_train_mod1.h5ad


epoch:  28
training (sum of 4 losses): 1.3536356067657471
validation (prediction loss): 1.1359012666194317
Early stopped.
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.3821, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:1')
0.753119854098266


[INFO][2023-09-19 21:27:40,212][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep2_atac2gex/GSE140203_12878.rep2_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:27:40,297][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep2_atac2gex/GSE140203_12878.rep2_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:27:40,737][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep2_atac2gex/GSE140203_12878.rep2_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:28:06,653][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:28:06,885][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:28:06,887][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:28:06,888][dance][

epoch:  1
training (sum of 4 losses): 1.2733942491036874
validation (prediction loss): 0.5101479014112257
epoch:  2
training (sum of 4 losses): 1.1874933772616916
validation (prediction loss): 0.47536874625169717
epoch:  3
training (sum of 4 losses): 1.0707807717499909
validation (prediction loss): 0.44905143607359876
epoch:  4
training (sum of 4 losses): 0.9863690601454841
validation (prediction loss): 0.446120379266086
epoch:  5
training (sum of 4 losses): 0.9708190140900789
validation (prediction loss): 0.46779514971139435
epoch:  6
training (sum of 4 losses): 0.949304496800458
validation (prediction loss): 0.4296741691602186
epoch:  7
training (sum of 4 losses): 0.9371073268077992
validation (prediction loss): 0.4281656259905467
epoch:  8
training (sum of 4 losses): 0.9390008626160798
validation (prediction loss): 0.43167693386730543
epoch:  9
training (sum of 4 losses): 0.9217096854139257
validation (prediction loss): 0.43137035279933567
epoch:  10
training (sum of 4 losses): 0.92

[INFO][2023-09-19 21:28:25,152][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:28:25,154][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep3_atac2gex/GSE140203_12878.rep3_atac2gex.GSE140203_dataset.output_train_mod1.h5ad


epoch:  39
training (sum of 4 losses): 0.9019504697234543
validation (prediction loss): 0.444219938793866
Early stopped.
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')
0.55351224156586


[INFO][2023-09-19 21:28:26,062][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep3_atac2gex/GSE140203_12878.rep3_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:28:26,403][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep3_atac2gex/GSE140203_12878.rep3_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:28:26,860][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_12878.rep3_atac2gex/GSE140203_12878.rep3_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:28:44,647][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:28:46,151][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:28:46,152][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:28:46,154][dance][

epoch:  1
training (sum of 4 losses): 1.4232844678304528
validation (prediction loss): 0.5275329981623834
epoch:  2
training (sum of 4 losses): 1.1918446876669442
validation (prediction loss): 0.5229252794943787
epoch:  3
training (sum of 4 losses): 1.1316799213809352
validation (prediction loss): 0.5216518258477311
epoch:  4
training (sum of 4 losses): 1.0875388178774106
validation (prediction loss): 0.5267297165956305
epoch:  5
training (sum of 4 losses): 1.0539429251865675
validation (prediction loss): 0.5236867197497354
epoch:  6
training (sum of 4 losses): 1.04711425881232
validation (prediction loss): 0.5353541460778607
epoch:  7
training (sum of 4 losses): 1.0147536570026028
validation (prediction loss): 0.5255863259824153
epoch:  8
training (sum of 4 losses): 1.0068120821829765
validation (prediction loss): 0.5339721693394394
epoch:  9
training (sum of 4 losses): 1.0047983892502323
validation (prediction loss): 0.5234349395582156
epoch:  10
training (sum of 4 losses): 0.9878080

[INFO][2023-09-19 21:29:30,803][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:29:30,806][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_HG19_atac2gex/GSE140203_K562_HG19_atac2gex.GSE140203_dataset.output_train_mod1.h5ad


0.2339013558986142


[INFO][2023-09-19 21:29:31,037][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_HG19_atac2gex/GSE140203_K562_HG19_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:29:31,105][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_HG19_atac2gex/GSE140203_K562_HG19_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:29:31,262][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_HG19_atac2gex/GSE140203_K562_HG19_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:29:32,811][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:29:32,971][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:29:32,972][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:29:32,974][dance][load_d

epoch:  1
training (sum of 4 losses): 0.6097353416330674
validation (prediction loss): 0.19198396409003873
epoch:  2
training (sum of 4 losses): 0.5933447778224945
validation (prediction loss): 0.19290750859942385
epoch:  3
training (sum of 4 losses): 0.5904526920879588
validation (prediction loss): 0.1923082335414629
epoch:  4
training (sum of 4 losses): 0.5010337641133982
validation (prediction loss): 0.1513654647236891
epoch:  5
training (sum of 4 losses): 0.47387221280266256
validation (prediction loss): 0.15107476228790057
epoch:  6
training (sum of 4 losses): 0.46636655532261906
validation (prediction loss): 0.152706475526842
epoch:  7
training (sum of 4 losses): 0.4622322947663419
validation (prediction loss): 0.15242317922836798
epoch:  8
training (sum of 4 losses): 0.460110559621278
validation (prediction loss): 0.15608944790219065
epoch:  9
training (sum of 4 losses): 0.4594288814593764
validation (prediction loss): 0.15413981504013213
epoch:  10
training (sum of 4 losses): 0

[INFO][2023-09-19 21:30:03,089][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:30:03,091][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_MM10_atac2gex/GSE140203_K562_MM10_atac2gex.GSE140203_dataset.output_train_mod1.h5ad
[INFO][2023-09-19 21:30:03,464][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_MM10_atac2gex/GSE140203_K562_MM10_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:30:03,578][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_MM10_atac2gex/GSE140203_K562_MM10_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:30:03,869][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_K562_MM10_atac2gex/GSE140203_K562_MM10_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INF

epoch:  1
training (sum of 4 losses): 1.3197136511062753
validation (prediction loss): 0.47672586440903114
epoch:  2
training (sum of 4 losses): 0.9774285493225887
validation (prediction loss): 0.44657988185557557
epoch:  3
training (sum of 4 losses): 0.8974346415749912
validation (prediction loss): 0.44872536535112895
epoch:  4
training (sum of 4 losses): 0.8646423025377865
validation (prediction loss): 0.44248193299862215
epoch:  5
training (sum of 4 losses): 0.851755539918768
validation (prediction loss): 0.44651705061700764
epoch:  6
training (sum of 4 losses): 0.8440704376533114
validation (prediction loss): 0.4452520813177969
epoch:  7
training (sum of 4 losses): 0.8436905402561714
validation (prediction loss): 0.44608051974010976
epoch:  8
training (sum of 4 losses): 0.8319747232157608
validation (prediction loss): 0.4446780858645408
epoch:  9
training (sum of 4 losses): 0.8329041425524086
validation (prediction loss): 0.45183504584867473
epoch:  10
training (sum of 4 losses): 0

[INFO][2023-09-19 21:31:08,302][dance][set_seed] Setting global random seed to 1472461312
[INFO][2023-09-19 21:31:08,304][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_LUNG_atac2gex/GSE140203_LUNG_atac2gex.GSE140203_dataset.output_train_mod1.h5ad


epoch:  45
training (sum of 4 losses): 0.7774057994628775
validation (prediction loss): 0.4499552844208699
Early stopped.
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')
0.29593586448934517


[INFO][2023-09-19 21:31:08,599][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_LUNG_atac2gex/GSE140203_LUNG_atac2gex.GSE140203_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:31:08,635][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_LUNG_atac2gex/GSE140203_LUNG_atac2gex.GSE140203_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:31:08,892][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/GSE140203_LUNG_atac2gex/GSE140203_LUNG_atac2gex.GSE140203_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:31:14,992][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:31:15,054][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:31:15,055][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:31:15,057][dance][load_data] Raw data loaded:
Data obj

epoch:  1
training (sum of 4 losses): 1.1234569045213552
validation (prediction loss): 0.42410048656450705
epoch:  2
training (sum of 4 losses): 1.0052658869669988
validation (prediction loss): 0.3943372759482203
epoch:  3
training (sum of 4 losses): 0.8944838688923762
validation (prediction loss): 0.36574033550996954
epoch:  4
training (sum of 4 losses): 0.7780089240807754
validation (prediction loss): 0.34837227925133674
epoch:  5
training (sum of 4 losses): 0.6740867449687078
validation (prediction loss): 0.30719822200976293
epoch:  6
training (sum of 4 losses): 0.6373183589715224
validation (prediction loss): 0.3040490430341659
epoch:  7
training (sum of 4 losses): 0.6293028317964994
validation (prediction loss): 0.3129961962004108
epoch:  8
training (sum of 4 losses): 0.6117697174732502
validation (prediction loss): 0.3058858641565913
epoch:  9
training (sum of 4 losses): 0.608460022852971
validation (prediction loss): 0.3090762338200057
epoch:  10
training (sum of 4 losses): 0.60

'To reproduce BABEL on other samples, please refer to command lines belows:\nGEX to ADT (subset):\npython babel.py --subtask openproblems_bmmc_cite_phase2_rna_subset --device cuda\n\nGEX to ADT:\npython babel.py --subtask openproblems_bmmc_cite_phase2_rna --device cuda\n\nADT to GEX:\npython babel.py --subtask openproblems_bmmc_cite_phase2_mod2 --device cuda\n\nGEX to ATAC:\npython babel.py --subtask openproblems_bmmc_multiome_phase2_rna --device cuda\n\nATAC to GEX:\npython babel.py --subtask openproblems_bmmc_multiome_phase2_mod2 --device cuda\n'

In [4]:
BabelWrapper_scores

[0.1266901373398157,
 19.13239858295633,
 ValueError("unknown url type: ''"),
 51.56493365468921,
 2.6246179575154547,
 0.10078361678486301,
 0.12101614680464011,
 0.5988701603438525,
 0.43203375679378764,
 0.753119854098266,
 0.55351224156586,
 0.2339013558986142,
 0.1670045178863737,
 0.29593586448934517,
 0.25755463333076084]

In [6]:
"""Main functionality for starting training.

This code is based on https://github.com/NVlabs/MUNIT.

"""
import argparse
import os
import random

import torch
from sklearn import preprocessing

from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.cmae import CMAE
from dance.utils import set_seed


rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("--output_path", type=str, default="./predict_modality/output", help="outputs path")
parser.add_argument("--resume", action="store_true")
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-device", "--device", default="cuda:1")
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("--span", default=0.3, type=float)
parser.add_argument("--selection_threshold", default=10000, type=int)
parser.add_argument("--max_epochs", default=5, type=int, help="maximum number of training epochs")
parser.add_argument("--batch_size", default=64, type=int, help="batch size")
parser.add_argument("--log_data", default=True, type=bool, help="take a log1p of the data as input")
parser.add_argument("--feature_filter", default=False, type=bool)
parser.add_argument("--normalize_data", default=True, type=bool,
                    help="normalize the data (after the log, if applicable)")
parser.add_argument("--weight_decay", default=1e-4, type=float, help="weight decay")
parser.add_argument("--beta1", default=0.5, type=float, help="Adam parameter")
parser.add_argument("--beta2", default=0.999, type=float, help="Adam parameter")
parser.add_argument("--init", default="kaiming", type=str,
                    help="initialization [gaussian/kaiming/xavier/orthogonal]")
parser.add_argument("--lr", default=1e-4, type=float, help="initial learning rate")
parser.add_argument("--lr_policy", default="step", type=str, help="learning rate scheduler")
parser.add_argument("--step_size", default=100000, type=int, help="how often to decay learning rate")
parser.add_argument("--gamma", default=0.5, type=float, help="how much to decay learning rate")
parser.add_argument("--gan_w", default=10, type=int, help="weight of adversarial loss")
parser.add_argument("--recon_x_w", default=10, type=int, help="weight of image reconstruction loss")
parser.add_argument("--recon_h_w", default=0, type=int, help="weight of hidden reconstruction loss")
parser.add_argument("--recon_kl_w", default=0, type=int, help="weight of KL loss for reconstruction")
parser.add_argument("--supervise", default=1, type=float, help="fraction to supervise")
parser.add_argument("--super_w", default=0.1, type=float, help="weight of supervision loss")
CMAE_scores=[]
for dataset in datasets:
    try:
        opts = parser.parse_args(['--subtask',dataset,'--device',"cuda:2","--span","1.0",'--selection_threshold','10000'])
        device = opts.device

        torch.set_num_threads(opts.cpus)
        rndseed = opts.rnd_seed
        set_seed(rndseed)
        dataset = ModalityPredictionDataset(opts.subtask, preprocess="feature_selection",span=opts.span)
        data = dataset.load_data()

        output_directory = os.path.join(opts.output_path, "outputs")
        checkpoint_directory = os.path.join(output_directory, "checkpoints")
        os.makedirs(checkpoint_directory, exist_ok=True)

        # Manually set batch features
        le = preprocessing.LabelEncoder()
        batch = le.fit_transform(data.data.mod["mod1"].obs["batch"])
        data.data.mod["mod1"].obsm["batch"] = batch

        data.set_config(feature_mod=["mod1", "mod1"], label_mod="mod2", feature_channel_type=[None, "obsm"],
                        feature_channel=[None, "batch"], overwrite=True)

        # Obtain training and testing data
        (x_train, batch), y_train = data.get_train_data(return_type="torch")
        (x_test, _), y_test = data.get_test_data(return_type="torch")
        batch = batch.long().to(device)[:train_leng]
        x_train = x_train.float().to(device)[:train_leng]
        y_train = y_train.float().to(device)[:train_leng]
        x_test = x_test.float().to(device)[:test_leng]
        y_test = y_test.float().to(device)[:test_leng]

        config = vars(opts)
        # Some Fixed Settings
        config["input_dim_a"] = x_train.shape[1]
        config["input_dim_b"] = y_train.shape[1]
        config["resume"] = opts.resume
        config["num_of_classes"] = max(batch) + 1
        config["shared_layer"] = True
        config["gen"] = {
            "dim": 100,  # hidden layer
            "latent": 50,  # latent layer size
            "activ": "relu",
        }  # activation function [relu/lrelu/prelu/selu/tanh]
        config["dis"] = {
            "dim": 100,
            "norm": None,  # normalization layer [none/bn/in/ln]
            "activ": "lrelu",  # activation function [relu/lrelu/prelu/selu/tanh]
            "gan_type": "lsgan",
        }  # GAN loss [lsgan/nsgan]

        model = CMAE(config)
        model.to(device)

        model.fit(x_train, y_train, batch, checkpoint_directory)
        print(model.predict(x_test))
        score=model.score(x_test, y_test)
        print(score)
    except (Exception, BaseException) as e:
        v=e
    else:
        v=score
    finally:
        print(v)
        CMAE_scores.append(v)
        torch.cuda.empty_cache()
"""To reproduce CMAE on other samples, please refer to command lines belows:
GEX to ADT (subset):
python cmae.py --subtask openproblems_bmmc_cite_phase2_rna_subset --device cuda

GEX to ADT:
python cmae.py --subtask openproblems_bmmc_cite_phase2_rna --device cuda

ADT to GEX:
python cmae.py --subtask openproblems_bmmc_cite_phase2_mod2 --device cuda

GEX to ATAC:
python cmae.py --subtask openproblems_bmmc_multiome_phase2_rna --device cuda

ATAC to GEX:
python cmae.py --subtask openproblems_bmmc_multiome_phase2_mod2 --device cuda
"""


[INFO][2023-09-19 21:39:27,618][dance][set_seed] Setting global random seed to 1080833930
[INFO][2023-09-19 21:39:27,620][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_train_mod1.h5ad
[INFO][2023-09-19 21:39:27,718][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_train_mod2.h5ad
[INFO][2023-09-19 21:39:27,911][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_test_mod1.h5ad
[INFO][2023-09-19 21:39:27,978][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_test_mod2.h5ad
[INFO][2023-09-19 21:39:28,330][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:39:28,604][dance][set_config_from_dict] Setti

Iteration:  0
RMSE Loss: 0.30236404902800124
Iteration:  1
RMSE Loss: 1.0750895906420075
Iteration:  2
RMSE Loss: 2.07194759650223
Iteration:  3
RMSE Loss: 0.8797536537135732
Iteration:  4


In [6]:
CMAE_scores

[TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 TypeError("__init__() got an unexpected keyword argument 'selection_threshold'"),
 Typ

In [3]:
import argparse
import random

import torch

from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.scmm import MMVAE
from dance.utils import set_seed


rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("--output_path", type=str, default="./predict_modality/output", help="outputs path")
parser.add_argument("--resume", action="store_true")
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)

parser.add_argument("--experiment", type=str, default="test", metavar="E", help="experiment name")
parser.add_argument("--obj", type=str, default="m_elbo_naive_warmup", metavar="O",
                    help="objective to use (default: elbo)")
parser.add_argument(
    "--llik_scaling", type=float, default=1., help="likelihood scaling for cub images/svhn modality when running in"
    "multimodal setting, set as 0 to use default value")
parser.add_argument("--batch_size", type=int, default=64, metavar="N", help="batch size for data (default: 256)")
parser.add_argument("--epochs", type=int, default=5, metavar="E", help="number of epochs to train (default: 100)")
parser.add_argument("--lr", type=float, default=1e-4, metavar="L", help="learning rate (default: 1e-3)")
parser.add_argument("--latent_dim", type=int, default=10, metavar="L", help="latent dimensionality (default: 20)")
parser.add_argument("--num_hidden_layers", type=int, default=2, metavar="H",
                    help="number of hidden layers in enc and dec (default: 2)")
parser.add_argument("--r_hidden_dim", type=int, default=100, help="number of hidden units in enc/dec for gene")
parser.add_argument("--p_hidden_dim", type=int, default=20,
                    help="number of hidden units in enc/dec for protein/peak")
parser.add_argument("--pre_trained", type=str, default="",
                    help="path to pre-trained model (train from scratch if empty)")
parser.add_argument("--learn_prior", type=bool, default=True, help="learn model prior parameters")
parser.add_argument("--print_freq", type=int, default=0, metavar="f",
                    help="frequency with which to print stats (default: 0)")
parser.add_argument("--deterministic_warmup", type=int, default=50, metavar="W", help="deterministic warmup")
parser.add_argument("--span", default=0.3, type=float)
MMVAE_scores=[]
for dataset in datasets:
    try:
        args = parser.parse_args(['--subtask',dataset,'--device','cuda:5','--span','1.0'])
        if args.subtask in ['pbmc_cite','openproblems_2022_cite_gex2adt']:
            model_class="rna-protein"
        elif args.subtask in ['GSE117089_A549_gex2atac','GSE117089_sciCAR_gex2atac','GSE127064_AdBrain_gex2atac','GSE127064_p0Brain_gex2atac']:
            model_class="rna-dna"
        else:
            raise RuntimeError("subtask does not exist")
        torch.set_num_threads(args.cpus)
        rndseed = args.rnd_seed
        set_seed(rndseed)
        dataset = ModalityPredictionDataset(args.subtask, preprocess="feature_selection",span=args.span)
        data = dataset.load_data()

        data.set_config(feature_mod="mod1", label_mod="mod2", feature_channel_type="layers", feature_channel="counts",
                        label_channel_type="layers", label_channel="counts")

        # Obtain training and testing data
        x_train, y_train = data.get_train_data(return_type="torch")
        x_test, y_test = data.get_test_data(return_type="torch")
        x_train,y_train=x_train[:train_leng],y_train[:train_leng]
        x_test, y_test = x_test[:test_leng],y_test[:test_leng]
        args.r_dim = x_train.shape[1]
        args.p_dim = y_train.shape[1]
      
        model_class = "rna-protein" if args.subtask == "openproblems_bmmc_cite_phase2_rna" else "rna-dna"
        model = MMVAE(model_class, args).to(args.device)

        model.fit(x_train, y_train)
        print(model.predict(x_test))
        score=model.score(x_test, y_test)
        print(score)
    except (Exception, BaseException) as e:
        v=e
    else:
        v=score
    finally:
        print(v)
        MMVAE_scores.append(v)
        torch.cuda.empty_cache()
"""To reproduce scMM on other samples, please refer to command lines belows:
GEX to ADT (subset):
python scmm.py --subtask openproblems_bmmc_cite_phase2_rna_subset --device cuda

GEX to ADT:
python scmm.py --subtask openproblems_bmmc_cite_phase2_rna --device cuda

ADT to GEX:
python scmm.py --subtask openproblems_bmmc_cite_phase2_mod2 --device cuda

GEX to ATAC:
python scmm.py --subtask openproblems_bmmc_multiome_phase2_rna --device cuda

ATAC to GEX:
python scmm.py --subtask openproblems_bmmc_multiome_phase2_mod2 --device cuda
"""


[INFO][2023-09-19 21:50:02,828][dance][set_seed] Setting global random seed to 1760354272
[INFO][2023-09-19 21:50:02,828][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_train_mod1.h5ad


subtask does not exist
subtask does not exist


[INFO][2023-09-19 21:50:03,167][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:50:03,200][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:50:03,353][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:50:04,502][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:50:05,309][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:50:05,310][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:50:05,311][dance][load_data] Raw data loaded:
Data object that wraps (.data):
MuData object with n_obs × n_vars = 33454 × 10025
  uns:	'dance

====> Epoch: 001 Train loss: 5349.8446
====>             Valid loss: 5021.2367
====> Epoch: 002 Train loss: 4715.2114
====>             Valid loss: 4495.5334
====> Epoch: 003 Train loss: 4231.3334
====>             Valid loss: 4079.0045
====> Epoch: 004 Train loss: 3854.5760
====>             Valid loss: 3711.0837
====> Epoch: 005 Train loss: 3537.7190
====>             Valid loss: 3449.8149
====> Epoch: 006 Train loss: 3270.6652
====>             Valid loss: 3235.4207
====> Epoch: 007 Train loss: 3048.0922
====>             Valid loss: 2987.4435
====> Epoch: 008 Train loss: 2851.9083
====>             Valid loss: 2812.4503
====> Epoch: 009 Train loss: 2686.9450
====>             Valid loss: 2687.0422
====> Epoch: 010 Train loss: 2526.9660
====>             Valid loss: 2524.9074
Valid RMSELoss: 0.7114304730241654
tensor([[ 0.,  0.,  1.,  ...,  0.,  0.,  0.],
        [26., 21., 12.,  ...,  0.,  0.,  0.],
        [ 6.,  2.,  0.,  ...,  0., 12., 56.],
        ...,
        [17.,  9., 15., 

[INFO][2023-09-19 21:50:33,658][dance][set_seed] Setting global random seed to 1760354272
[INFO][2023-09-19 21:50:33,660][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod1.h5ad


subtask does not exist


[INFO][2023-09-19 21:50:37,035][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 21:50:38,854][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 21:50:40,017][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_cite_gex2adt/openproblems_2022_cite_gex2adt.open_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 21:50:58,388][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-09-19 21:51:16,074][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 21:51:16,075][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'
[INFO][2023-09-19 21:51:16,076][dance][load_data

====> Epoch: 001 Train loss: 20033.4733
====>             Valid loss: 19581.1498
====> Epoch: 002 Train loss: 18918.0257
====>             Valid loss: 18502.4225
====> Epoch: 003 Train loss: 17967.0510
====>             Valid loss: 17643.2184
====> Epoch: 004 Train loss: 17154.6401
====>             Valid loss: 16864.0279
====> Epoch: 005 Train loss: 16429.7727
====>             Valid loss: 16202.2957
====> Epoch: 006 Train loss: 15796.9335
====>             Valid loss: 15560.6810


In [None]:
MMVAE_scores

[RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist'),
 0.2854820506646566,
 RuntimeError('subtask does not exist'),
 0.4441797002749248,
 0.13539711569599308,
 0.13856082149851734,
 ValueError("Expected value argument (Tensor of shape (64, 252741)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution ZINB(gate: torch.Size([64, 252741])), but found invalid values:\ntensor([[0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        ...,\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 3., 0., 0.]], device='cuda:5')"),
 RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist'),
 RuntimeError('subtask does not exist')]

In [3]:
train_leng=2000
test_leng=700
import argparse
import os
from argparse import Namespace

import anndata
import mudata
import numpy as np
import torch
import scanpy as sc
from dance.data import Data
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.scmogcn import ScMoGCNWrapper
from dance.transforms.cell_feature import BatchFeature
from dance.transforms.graph import ScMoGNNGraph

from dance.utils import set_seed
pipl_scores=[]

def pipeline(inductive=False, verbose=2, logger=None, **kwargs):
    PREFIX = kwargs["prefix"]
    os.makedirs(kwargs["log_folder"], exist_ok=True)
    os.makedirs(kwargs["model_folder"], exist_ok=True)
    os.makedirs(kwargs["result_folder"], exist_ok=True)
    if verbose > 1:
        logger = open(f"{kwargs['log_folder']}/{PREFIX}.log", "w")
        logger.write(str(kwargs) + "\n")

    subtask = kwargs["subtask"]
    cell_filter=True
    dataset = ModalityPredictionDataset(subtask, preprocess=kwargs["preprocessing"])
    # dataset.download_pathway() 测试时需要恢复
    modalities = dataset.load_raw_data()
    if cell_filter:
        for moda in modalities:
            sc.pp.filter_cells(moda,min_genes=3)

    modalities[0]=modalities[0][:train_leng]
    modalities[1]=modalities[1][:train_leng]
    modalities[2]=modalities[2][:test_leng]
    modalities[3]=modalities[3][:test_leng]
    mod1 = anndata.concat((modalities[0], modalities[2]))
    mod2 = anndata.concat((modalities[1], modalities[3]))
    mod1.var_names_make_unique()
    mod2.var_names_make_unique()
    mdata = mudata.MuData({"mod1": mod1, "mod2": mod2})
    train_size = modalities[0].shape[0]
    data = Data(mdata, train_size=train_size)
    data.set_config(feature_mod="mod1", label_mod="mod2")

    data = ScMoGNNGraph(inductive, kwargs["cell_init"], kwargs["pathway"], kwargs["subtask"], kwargs["pathway_weight"],
                        kwargs["pathway_threshold"], kwargs["pathway_path"])(data)
    if not kwargs["no_batch_features"]:
        data = BatchFeature()(data)

    idx = np.random.permutation(modalities[0].shape[0])
    split = {"train": idx[:-int(len(idx) * 0.15)], "valid": idx[-int(len(idx) * 0.15):]}
    kwargs["FEATURE_SIZE"] = modalities[0].shape[1]
    kwargs["TRAIN_SIZE"] = modalities[0].shape[0]
    kwargs["OUTPUT_SIZE"] = modalities[1].shape[1]
    kwargs["CELL_SIZE"] = modalities[0].shape[0] + modalities[2].shape[0]

    if inductive:
        g, gtest = data.data.uns["g"], data.data.uns["gtest"]
    else:
        gtest = g = data.data.uns["g"]

    _, y_train = data.get_train_data(return_type="torch")
    _, y_test = data.get_test_data(return_type="torch")
    y_train=y_train[:train_leng]
    y_test=y_test[:test_leng]
    if not kwargs["no_batch_features"]:
        batch_features = torch.from_numpy(data.data["mod1"].obsm["batch_features"]).float()
        kwargs["BATCH_NUM"] = batch_features.shape[1]
        if inductive:
            g.nodes["cell"].data["bf"] = batch_features[:kwargs["TRAIN_SIZE"]]
            gtest.nodes["cell"].data["bf"] = batch_features
        else:
            g.nodes["cell"].data["bf"] = batch_features

    model = ScMoGCNWrapper(Namespace(**kwargs))

    if kwargs["sampling"]:
        model.fit_with_sampling(g, y_train, split, not inductive, verbose, y_test, logger)
    else:
        model.fit(g, y_train, split, not inductive, verbose, y_test, logger)

    print(model.predict(g, np.arange(kwargs["TRAIN_SIZE"], kwargs["CELL_SIZE"]), device="cpu"))
    score=model.score(g, np.arange(kwargs["TRAIN_SIZE"], kwargs["CELL_SIZE"]), y_test, device="cpu")
    return score



parser = argparse.ArgumentParser()
parser.add_argument("-prefix", "--prefix", default="dance_openproblems_bmmc_atac2rna_test")
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-pww", "--pathway_weight", default="pearson", choices=["cos", "one", "pearson"])
parser.add_argument("-pwth", "--pathway_threshold", type=float, default=-1.0)
parser.add_argument("-l", "--log_folder", default="./logs")
parser.add_argument("-m", "--model_folder", default="./models")
parser.add_argument("-r", "--result_folder", default="./results")
parser.add_argument("-e", "--epoch", type=int, default=5)
parser.add_argument("-nbf", "--no_batch_features", action="store_true")
parser.add_argument("-npw", "--pathway", action="store_true")
parser.add_argument("-res", "--residual", default="res_cat", choices=["none", "res_add", "res_cat"])
parser.add_argument("-inres", "--initial_residual", action="store_true")
parser.add_argument("-pwagg", "--pathway_aggregation", default="alpha",
                    choices=["sum", "attention", "two_gate", "one_gate", "alpha", "cat"])
parser.add_argument("-pwalpha", "--pathway_alpha", type=float, default=0.5)
parser.add_argument("-nrc", "--no_readout_concatenate", action="store_true")
parser.add_argument("-bs", "--batch_size", default=1000, type=int)
parser.add_argument("-nm", "--normalization", default="group", choices=["batch", "layer", "group", "none"])
parser.add_argument("-ac", "--activation", default="gelu", choices=["leaky_relu", "relu", "prelu", "gelu"])
parser.add_argument("-em", "--embedding_layers", default=1, type=int, choices=[1, 2, 3])
parser.add_argument("-ro", "--readout_layers", default=1, type=int, choices=[1, 2])
parser.add_argument("-conv", "--conv_layers", default=4, type=int, choices=[1, 2, 3, 4, 5, 6])
parser.add_argument("-agg", "--agg_function", default="mean", choices=["gcn", "mean"])
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-sb", "--save_best", action="store_true")
parser.add_argument("-sf", "--save_final", action="store_true")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-2)
parser.add_argument("-lrd", "--lr_decay", type=float, default=0.99)
parser.add_argument("-wd", "--weight_decay", type=float, default=1e-5)
parser.add_argument("-hid", "--hidden_size", type=int, default=48)
parser.add_argument("-edd", "--edge_dropout", type=float, default=0.3)
parser.add_argument("-mdd", "--model_dropout", type=float, default=0.2)
parser.add_argument("-es", "--early_stopping", type=int, default=0)
parser.add_argument("-c", "--cpu", type=int, default=1)
parser.add_argument("-or", "--output_relu", default="none", choices=["relu", "leaky_relu", "none"])
parser.add_argument("-i", "--inductive", action="store_true")
parser.add_argument("-sa", "--subpath_activation", action="store_true")
parser.add_argument("-ci", "--cell_init", default="none", choices=["none", "svd"])
parser.add_argument("-bas", "--batch_seperation", action="store_true")
parser.add_argument("-pwpath", "--pathway_path", default="./data/h.all.v7.4")
parser.add_argument("-seed", "--random_seed", type=int, default=777)
parser.add_argument("-ws", "--weighted_sum", action="store_true")
parser.add_argument("-samp", "--sampling", action="store_true")
parser.add_argument("-ns", "--node_sampling_rate", type=float, default=0.5)
parser.add_argument("-prep", "--preprocessing", default="none", choices=["none", "feature_selection", "svd"])
parser.add_argument("--span", default=0.3, type=float)
parser.add_argument("-lm", "--low_memory", type=bool, default=True)
for dataset in datasets:
    try:
        args = parser.parse_args(['--subtask',dataset,'--device','cuda:4','--span','1.0'])

        # For test only (low gpu memory setting; to reproduce competition result need >20G GPU memory - v100)
        if args.low_memory:
            print("WARNING: Running in low memory mode, some cli settings maybe overwritten!")
            args.preprocessing = "feature_selection"
            args.pathway = False
            args.sampling = True
            args.batch_size = 10000
            args.epoch = 10
        elif args.subtask == "openproblems_bmmc_multiome_phase2_mod2":
            args.preprocessing = "feature_selection"
        elif args.subtask in ["openproblems_bmmc_cite_phase2_mod2", "openproblems_bmmc_multiome_phase2_rna"]:
            args.sampling = True
            args.edge_dropout = 0

        # Regular settings
        if args.subtask.find("rna") == -1:
            args.pathway = False
        if args.sampling:
            args.pathway = False

        set_seed(args.random_seed)
        torch.set_num_threads(args.cpu)

        score=pipeline(**vars(args))
    except (Exception, BaseException) as e:
        v=e
    else:
        v=score
    finally:
        print(v)
        pipl_scores.append(v)
        torch.cuda.empty_cache()
"""To reproduce scMoGCN on other samples, please refer to command lines belows:
GEX to ADT (subset):
python scmogcn.py --subtask oopenproblems_bmmc_cite_phase2_rna_subset --device cuda

GEX to ADT:
python scmogcn.py --subtask oopenproblems_bmmc_cite_phase2_rna --device cuda

ADT to GEX:
python scmogcn.py --subtask openproblems_bmmc_cite_phase2_mod2 --device cuda

GEX to ATAC:
python scmogcn.py --subtask openproblems_bmmc_multiome_phase2_rna --device cuda

ATAC to GEX:
python scmogcn.py --subtask openproblems_bmmc_multiome_phase2_mod2 --device cuda
"""


[INFO][2023-09-19 22:01:29,845][dance][set_seed] Setting global random seed to 777
[INFO][2023-09-19 22:01:29,894][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_train_mod1.h5ad
[INFO][2023-09-19 22:01:30,013][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_train_mod2.h5ad




[INFO][2023-09-19 22:01:30,265][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_test_mod1.h5ad
[INFO][2023-09-19 22:01:30,331][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/10k_pbmc/10k_pbmc.10kanti_dataset_subset.output_test_mod2.h5ad
[INFO][2023-09-19 22:01:30,714][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 22:01:30,715][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'


{1}
epoch 0
training:  0.3652422674621689
valid:  0.3343837090721739
testing:  0.18146739410621704
epoch 1
training:  0.21908153716503745
valid:  0.3254443981713573
testing:  0.16418733275404632
epoch 2
training:  0.18339274644400624
valid:  0.3235610587692363
testing:  0.1603420157151971
epoch 3
training:  0.17604387479917108
valid:  0.32244603280291606
testing:  0.15802941970300863
epoch 4
training:  0.16405353635188533
valid:  0.32201306243133027
testing:  0.15707468314684553
epoch 5
training:  0.16612899163267797
valid:  0.3217542347692322
testing:  0.15647788714720917
epoch 6
training:  0.15674837858290816
valid:  0.3216336382370442
testing:  0.15623969997789336
epoch 7
training:  0.15608930907934368
valid:  0.32152147785841373
testing:  0.1560403848331432
epoch 8
training:  0.15424512224603673
valid:  0.3214127554673244
testing:  0.15582420626365792
epoch 9
training:  0.15984403136630385
valid:  0.3213803008142016
testing:  0.155770370338794
min testing 0.155770370338794 9
conver

[INFO][2023-09-19 22:02:23,663][dance][set_seed] Setting global random seed to 777
[INFO][2023-09-19 22:02:23,665][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_train_mod1.h5ad
[INFO][2023-09-19 22:02:23,718][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 22:02:23,731][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 22:02:23,775][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/5k_pbmc_subset/5k_pbmc_subset.5kanti_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 22:02:23,825][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 22:02:23,826][dan

0.155770370338794


[INFO][2023-09-19 22:02:24,067][dance][set_seed] Setting global random seed to 777
[INFO][2023-09-19 22:02:24,068][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_train_mod1.h5ad


{1}
Expect number of features to match number of nodes (len(u)). Got 137 and 247 instead.


[INFO][2023-09-19 22:02:24,389][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 22:02:24,416][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 22:02:24,563][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/pbmc_cite/pbmc_cite.citeanti_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 22:02:25,578][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 22:02:25,579][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'


{1}
epoch 0
training:  517.6118447978175
valid:  535.7553662820374
testing:  517.4888887696044
epoch 1
training:  512.6652111758706
valid:  533.7501756439992
testing:  515.4013605919178
epoch 2
training:  510.5312306803571
valid:  531.4411126832398
testing:  512.9911975365659
epoch 3
training:  509.37218404129607
valid:  528.8447610121518
testing:  510.27988643488584
epoch 4
training:  509.73444373026234
valid:  525.9819982661004
testing:  507.28613104440376
epoch 5
training:  512.4651268623065
valid:  522.8667373241484
testing:  504.02833501699087
epoch 6
training:  507.4267311642145
valid:  519.5155796508898
testing:  500.5200108387276
epoch 7
training:  501.42932577782085
valid:  515.9715956329379
testing:  496.8052781271552
epoch 8
training:  499.89595479959627
valid:  512.2748285832517
testing:  492.9255458484577
epoch 9
training:  495.70164287603484
valid:  508.440507827612
testing:  488.9080307174346
min testing 488.9080307174346 9
converged testing -1 488.9080307174346
tensor([

[INFO][2023-09-19 22:03:00,809][dance][set_seed] Setting global random seed to 777
[INFO][2023-09-19 22:03:00,811][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod1.h5ad


488.9080307174346


[INFO][2023-09-19 22:03:03,083][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_train_mod2.h5ad
[INFO][2023-09-19 22:03:03,178][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_test_mod1.h5ad
[INFO][2023-09-19 22:03:04,145][dance][_load_raw_data] Loading /home/zyxing/dance/examples/multi_modality/predict_modality/data/openproblems_2022_multi_atac2gex/openproblems_2022_multi_atac2gex.open_dataset.output_test_mod2.h5ad
[INFO][2023-09-19 22:03:12,647][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-09-19 22:03:12,648][dance][set_config_from_dict] Setting config 'label_mod' to 'mod2'


{'1'}
epoch 0
training:  82.27929581329224
valid:  88.21013451997027
testing:  119.55090524672951
epoch 1
training:  83.54160114035253
valid:  87.74554553775792
testing:  119.15157838993363
epoch 2
training:  82.55304757586178
valid:  87.31572951704493
testing:  118.77678151947458
epoch 3
training:  83.92819021807422
valid:  86.87693624191205
testing:  118.4019364482165
epoch 4
training:  82.51685049269861
valid:  86.41435879673007
testing:  118.01698103451045
epoch 5
training:  82.8972314167518
valid:  85.91974248353809
testing:  117.61116333999719
epoch 6
training:  80.72422915008232
valid:  85.37913777948773
testing:  117.16799837732998


In [None]:
pipl_scores

[0.155770370338794,
 dgl._ffi.base.DGLError('Expect number of features to match number of nodes (len(u)). Got 137 and 247 instead.'),
 488.9080307174346,
 115.5787365216349,
 2.49032743392298,
 0.09841357311221884,
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory'),
 FileNotFoundError(2, 'No such file or directory')]