<a href="https://colab.research.google.com/github/OmicsML/dance-tutorials/blob/dev/dance_tutorial_remy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Installation

DANCE is published on [PyPI](https://pypi.org/project/pydance/), so installing it is as easy as

```bash
pip install pydance
```

Alternatively, install the latest development version from GitHub:

```bash
pip install git+https://github.com/OmicsML/dance
```

However, because DANCE includes many deep learning based methods, it also depends on deep learning libraries such as [PyTorch](https://pytorch.org/), [PyG](https://www.pyg.org/), and [DGL](https://www.dgl.ai/). The steps below walk through the installation process.

### 1.1. Install torch related dependencies

In [None]:
# Colab comes with torch installed, so we do not need to install pytorch here
# !pip3 install torch torchvision torchaudio

!pip install torch_geometric==2.3.1
!pip install dgl==1.1.0 -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install torchnmf==0.3.4

### 1.2. Install latest dev version of DANCE

In [None]:
!pip install git+https://github.com/OmicsML/dance.git

# # Clone DANCE repo and install
# !git clone https://github.com/OmicsML/dance.git pydance
# %cd /content/pydance
# !pip install -e .

### 1.3. Check if DANCE is installed successfully

In [None]:
import dance
print(f"Installed DANCE version {dance.__version__}")

Installed DANCE version 1.0.0-rc.1
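
Since several dependencies are installed separately, it can help to confirm which ones the current environment actually sees. The following is a minimal sketch using only the standard library. The distribution names in `PACKAGES` are assumptions taken from the install commands in this section, and `installed_versions` is a hypothetical helper, not part of DANCE.

```python
from importlib import metadata

# Distribution names assumed from the pip commands used in this tutorial.
PACKAGES = ["pydance", "torch", "torch_geometric", "dgl", "torchnmf"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name:16s} {version or 'not installed'}")
```

A `None` entry only means that no distribution with that exact name was found, so a package installed under a different name would need its name adjusted in the list.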


## 2. Data loading and processing

DANCE comes with several benchmarking datasets in a unified dataset object format. This makes data downloading, processing, and caching easy for users through our dataset object interface.

### 2.1. Check available data options and load data object

In [None]:
from pprint import pprint
from dance.datasets.singlemodality import ClusteringDataset, ScDeepSortDataset

In [None]:
print("Available dataset option for ClusteringDataset:")
pprint(ClusteringDataset.get_avalilable_data())

print("\nAvailable dataset option for ScDeepSortDataset:")
pprint(ScDeepSortDataset.get_avalilable_data())

Available dataset option for ClusteringDataset:
['10X_PBMC', 'mouse_ES_cell', 'mouse_bladder_cell', 'worm_neuron_cell']

Available dataset option for ScDeepSortDataset:
[{'dataset': '3285', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '753', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '4682', 'species': 'mouse', 'split': 'train', 'tissue': 'Kidney'},
 {'dataset': '1970', 'species': 'mouse', 'split': 'train', 'tissue': 'Spleen'},
 {'dataset': '2695', 'species': 'mouse', 'split': 'test', 'tissue': 'Brain'},
 {'dataset': '203', 'species': 'mouse', 'split': 'test', 'tissue': 'Kidney'},
 {'dataset': '1759', 'species': 'mouse', 'split': 'test', 'tissue': 'Spleen'}]
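
The option list printed for `ScDeepSortDataset` is a plain list of dictionaries, so it can be filtered with ordinary Python to pick dataset ids for a given tissue and split. The sketch below copies the entries from the output above; `select_ids` is a hypothetical helper introduced only for illustration.

```python
# Options copied from the printed output above.
available = [
    {"dataset": "3285", "species": "mouse", "split": "train", "tissue": "Brain"},
    {"dataset": "753", "species": "mouse", "split": "train", "tissue": "Brain"},
    {"dataset": "4682", "species": "mouse", "split": "train", "tissue": "Kidney"},
    {"dataset": "1970", "species": "mouse", "split": "train", "tissue": "Spleen"},
    {"dataset": "2695", "species": "mouse", "split": "test", "tissue": "Brain"},
    {"dataset": "203", "species": "mouse", "split": "test", "tissue": "Kidney"},
    {"dataset": "1759", "species": "mouse", "split": "test", "tissue": "Spleen"},
]


def select_ids(options, species, tissue, split):
    """Return the dataset ids matching a species, tissue, and split."""
    return [
        option["dataset"]
        for option in options
        if (option["species"], option["tissue"], option["split"]) == (species, tissue, split)
    ]


train_ids = select_ids(available, "mouse", "Brain", "train")
test_ids = select_ids(available, "mouse", "Brain", "test")
print(train_ids, test_ids)  # ['3285', '753'] ['2695']
```

These are the same ids that the `ScDeepSortDataset` call further down passes as `train_dataset` and `test_dataset`.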


In [None]:
dataset = ClusteringDataset("10X_PBMC")

# The dataset object does not hold any data itself; the data is only loaded
# when the load_data function is called
data = dataset.load_data()
print(data)

[INFO][2023-06-24 13:48:22,085][dance][download_file] Downloading: 10X_PBMC/mouse_bladder_cell.h5 Bytes: 5,479,777
100%|██████████| 5.23M/5.23M [00:00<00:00, 27.4MB/s]
[INFO][2023-06-24 13:48:23,173][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-24 13:48:23,174][dance][wrapped_func] Took 0:00:02.114399 to load and process data.


Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'


In [None]:
dataset = ScDeepSortDataset(species="mouse", tissue="Brain",
                            train_dataset=["3285", "753"], test_dataset=["2695"])
data = dataset.load_data()
print(data)

[INFO][2023-06-24 13:53:00,958][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_data.csv
[INFO][2023-06-24 13:53:11,916][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_data.csv
[INFO][2023-06-24 13:53:14,119][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_data.csv
[INFO][2023-06-24 13:53:23,130][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_celltype.csv
[INFO][2023-06-24 13:53:23,151][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_celltype.csv
[INFO][2023-06-24 13:53:23,161][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_celltype.csv
[INFO][2023-06-24 13:53:26,847][dance][_load_raw_data] Loaded expression data: AnnData object with n_obs × n_vars = 6733 × 19856
[INFO][2023-06-24 13:53:26,848][dance][_load_raw_data] Number of training samples: 4,038
[INFO][2023-06-24 13:53:26,852][dance][_load_raw_data] Number of testing samples: 2,695
[INFO][2023-06-24 13:53:26,855][dance][_load

Data object that wraps (.data):
AnnData object with n_obs × n_vars = 6733 × 19856
    uns: 'dance_config'
    obsm: 'cell_type'


### 2.2. A quick primer on AnnData

<img
  src="https://raw.githubusercontent.com/scverse/anndata/main/docs/_static/img/anndata_schema.svg"
  align="right" width="450" alt="image"
/>

The [dance data object](https://github.com/OmicsML/dance/blob/912405cb5ab43caf16eb22b9216865c7e3976eaf/dance/data/base.py#L40) builds heavily on [AnnData](https://anndata.readthedocs.io/en/latest/), a widely used data object for representing, storing, and manipulating large annotated matrices.

> anndata is a Python package for handling annotated data matrices in memory and on disk, positioned between pandas and xarray. anndata offers a broad range of computationally efficient features...

AnnData is part of the scverse ecosystem, which makes it easier to handle single-cell data with tools such as [Scanpy](https://scanpy.readthedocs.io/en/stable/).

In [None]:
adata = data.data
adata

AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'

In [None]:
adata.X

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [None]:
adata.obsm["Group"]

array([ 1,  5, 16, ..., 10,  1,  6], dtype=int32)
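
The key idea behind this layout is that annotations stay aligned with the rows (observations) of the expression matrix. The toy class below is a rough sketch of that alignment in plain Python. `MiniAnnData` is hypothetical and far simpler than AnnData itself, which also tracks variable-aligned and unstructured annotations.

```python
from dataclasses import dataclass, field


@dataclass
class MiniAnnData:
    """Toy stand-in for AnnData: a matrix plus annotations aligned to its rows."""

    X: list  # n_obs rows, each holding n_vars values
    obsm: dict = field(default_factory=dict)  # per-observation annotations

    def __post_init__(self):
        n_obs = len(self.X)
        for key, value in self.obsm.items():
            if len(value) != n_obs:
                raise ValueError(f"obsm[{key!r}] has {len(value)} rows, expected {n_obs}")

    @property
    def shape(self):
        n_vars = len(self.X[0]) if self.X else 0
        return len(self.X), n_vars

    def subset_obs(self, indices):
        """Select observations; every aligned annotation is subset the same way."""
        return MiniAnnData(
            X=[self.X[i] for i in indices],
            obsm={key: [value[i] for i in indices] for key, value in self.obsm.items()},
        )


toy = MiniAnnData(X=[[0, 1, 0, 2], [3, 0, 0, 0], [0, 0, 5, 1]], obsm={"Group": [1, 5, 1]})
subset = toy.subset_obs([0, 2])
print(toy.shape, subset.shape, subset.obsm["Group"])  # (3, 4) (2, 4) [1, 1]
```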

### 2.3. Data pre-processing using transforms

Applying individual in-place transformations to data


In [None]:
import scanpy as sc

from dance.transforms import AnnDataTransform, FilterGenesPercentile

# Reload data
data = dataset.load_data()

# Library size normalization
AnnDataTransform(sc.pp.normalize_total, target_sum=1e4)(data)

# Shifted log transformation
AnnDataTransform(sc.pp.log1p)(data)

# Filter out genes whose summed expression falls outside the 1st to 99th percentile range
FilterGenesPercentile(min_val=1, max_val=99, mode="sum")(data)
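
To make the three steps above concrete, here is a minimal pure-Python sketch of the arithmetic on a tiny count matrix. It is an illustration under assumptions, not the Scanpy or DANCE implementation: the helper names are hypothetical, the percentile uses linear interpolation, and the bounds are treated as inclusive. The toy example uses the 10th and 90th percentiles so that the filtering is visible on four genes.

```python
import math


def normalize_total(rows, target_sum):
    """Scale each cell (row) so that its counts sum to target_sum."""
    scaled = []
    for row in rows:
        total = sum(row)
        factor = target_sum / total if total > 0 else 0.0
        scaled.append([value * factor for value in row])
    return scaled


def log1p(rows):
    """Apply log(1 + x) element-wise."""
    return [[math.log1p(value) for value in row] for row in rows]


def percentile(values, q):
    """Percentile with linear interpolation between the closest ranks."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100
    low, high = math.floor(position), math.ceil(position)
    return ordered[low] + (ordered[high] - ordered[low]) * (position - low)


def filter_genes_by_sum(rows, min_q, max_q):
    """Keep genes (columns) whose summed expression lies within the percentiles."""
    gene_sums = [sum(column) for column in zip(*rows)]
    lower, upper = percentile(gene_sums, min_q), percentile(gene_sums, max_q)
    keep = [j for j, total in enumerate(gene_sums) if lower <= total <= upper]
    return [[row[j] for j in keep] for row in rows], keep


counts = [[1, 0, 3, 0], [2, 2, 4, 0], [0, 1, 8, 1]]
normalized = normalize_total(counts, target_sum=10)
logged = log1p(normalized)
filtered, kept = filter_genes_by_sum(logged, min_q=10, max_q=90)
print("kept gene indices:", kept)  # kept gene indices: [0, 1]
```

The most and least expressed genes (indices 2 and 3) are dropped, which mirrors the intent of trimming extreme genes before modeling.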



Composing transformations into a preprocessing pipeline (feat. caching)

In [None]:
from dance.transforms import Compose

preprocessing_pipeline = Compose(
    AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
    AnnDataTransform(sc.pp.log1p),
    FilterGenesPercentile(min_val=1, max_val=99, mode="sum"),
)

# Now we can apply the preprocessing pipeline transformation to our data
data = dataset.load_data()
preprocessing_pipeline(data)

# Alternatively, we can also pass the transformation to the loading function
data = dataset.load_data(transform=preprocessing_pipeline)

[INFO][2023-06-24 11:54:34,794][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-24 11:54:34,795][dance][wrapped_func] Took 0:00:00.827345 to load and process data.
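
Conceptually, composing transforms means calling each one in order on the same object. The sketch below shows that pattern with a hypothetical `SimpleCompose` class and two toy transforms; it is not DANCE's `Compose`, which also integrates with logging and caching.

```python
class SimpleCompose:
    """Hypothetical stand-in: apply a sequence of in-place transforms in order."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, data):
        for transform in self.transforms:
            transform(data)  # each transform mutates data in place
        return data


def double_values(data):
    data["values"] = [value * 2 for value in data["values"]]


def drop_zeros(data):
    data["values"] = [value for value in data["values"] if value != 0]


pipeline = SimpleCompose(double_values, drop_zeros)
record = {"values": [0, 1, 2]}
pipeline(record)
print(record)  # {'values': [2, 4]}
```

Because the order matters, swapping two transforms in a pipeline can change the result, which is one reason to define the pipeline once and reuse it.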


In [None]:
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-24 11:54:46,010][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-24 11:54:47,141][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20463
    uns: 'dance_config', 'log1p'
    obsm: 'Group'
[INFO][2023-06-24 11:54:48,140][dance][load_data] Saved processed data to cache: /content/pydance/10X_PBMC/cache/97d10b6e54d74aff33617925c4696e31.pkl
[INFO][2023-06-24 11:54:48,141][dance][wrapped_func] Took 0:00:03.182861 to load and process data.
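
The cache file name in the log above looks like a hash digest, which suggests the cache is keyed by the preprocessing configuration. The following is a sketch of that general idea under assumptions: `cache_key` and `load_or_compute` are hypothetical helpers, SHA-256 is chosen arbitrarily, and DANCE's actual hashing scheme and file layout may differ.

```python
import hashlib
import json
import pickle
import tempfile
from pathlib import Path


def cache_key(pipeline_config):
    """Derive a stable hex digest from a JSON-serializable pipeline description."""
    payload = json.dumps(pipeline_config, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def load_or_compute(cache_dir, pipeline_config, compute):
    """Return (result, cache_hit): reuse the cached result or compute and store it."""
    path = Path(cache_dir) / f"{cache_key(pipeline_config)}.pkl"
    if path.exists():
        with path.open("rb") as handle:
            return pickle.load(handle), True
    result = compute()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as handle:
        pickle.dump(result, handle)
    return result, False


config = [{"name": "normalize_total"}, {"name": "log1p"}]
with tempfile.TemporaryDirectory() as tmp:
    first, hit_first = load_or_compute(tmp, config, lambda: [1, 2, 3])
    second, hit_second = load_or_compute(tmp, config, lambda: [1, 2, 3])
print(hit_first, hit_second)  # False True
```

Keying the cache on the configuration means that changing any transform or parameter produces a different file, so stale results are not reused by accident.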


## Single modality tasks

- Example:
  - Main: Cell Type Annotation
  - ...

In [None]:
import argparse
import pprint
from typing import get_args

import numpy as np

from dance import logger
from dance.datasets.singlemodality import ScDeepSortDataset
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.typing import LogLevel

if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
    parser.add_argument("--cache", action="store_true", help="Cache processed data.")
    parser.add_argument("--device", default="cpu", help="Computation device.")
    parser.add_argument("--hidden_dims", nargs="+", type=int, default=[2000], help="Hidden dimensions.")
    parser.add_argument("--lambd", type=float, default=0.01, help="Regularization parameter")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--log_level", type=str, default="INFO", choices=get_args(LogLevel))
    parser.add_argument("--nofilter", action="store_true", help="Disable filtering genes by expression summaries.")
    parser.add_argument(
        "--normalize", action="store_true", help="Whether to perform the normalization described in ACTINN. "
        "Disabled by default since the scDeepSort data is already normalized")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--print_cost", action="store_true", help="Print cost when training")
    parser.add_argument("--runs", type=int, default=10, help="Number of repetitions")
    parser.add_argument("--seed", type=int, default=0, help="Initial random seed, offset for each repetition")
    parser.add_argument("--species", default="mouse")
    parser.add_argument("--test_dataset", nargs="+", default=[1759], help="List of testing dataset ids.")
    parser.add_argument("--tissue", default="Spleen")
    parser.add_argument("--train_dataset", nargs="+", default=[1970], help="List of training dataset ids.")

    args = parser.parse_args()
    logger.setLevel(args.log_level)
    logger.info(f"Running ACTINN with the following parameters:\n{pprint.pformat(vars(args))}")

    # Initialize model and get model specific preprocessing pipeline
    model = ACTINN(hidden_dims=args.hidden_dims, lambd=args.lambd, device=args.device)
    preprocessing_pipeline = model.preprocessing_pipeline(normalize=args.normalize, filter_genes=not args.nofilter)

    # Load data and perform necessary preprocessing
    dataloader = ScDeepSortDataset(train_dataset=args.train_dataset, test_dataset=args.test_dataset, tissue=args.tissue,
                                   species=args.species)
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    # Obtain training and testing data
    x_train, y_train = data.get_train_data(return_type="torch")
    x_test, y_test = data.get_test_data(return_type="torch")

    # Train and evaluate models for several rounds
    scores = []
    for k in range(args.runs):
        model.fit(x_train, y_train, seed=args.seed + k, lr=args.learning_rate, num_epochs=args.num_epochs,
                  batch_size=args.batch_size, print_cost=args.print_cost)
        scores.append(score := model.score(x_test, y_test))
        print(f"{score=:.4f}")
    print(f"Score: {np.mean(scores):04.3f} +/- {np.std(scores):04.3f}")
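
The evaluation loop above repeats training with a different seed per run (`args.seed + k`) and reports the mean and standard deviation. The pattern can be sketched in isolation as follows. `run_once` is a hypothetical stand-in that returns a pseudo-score instead of training a model.

```python
import random
import statistics


def run_once(seed):
    """Hypothetical stand-in for one fit/score round; returns a pseudo-score."""
    rng = random.Random(seed)
    return 0.8 + rng.uniform(-0.05, 0.05)


def repeated_scores(base_seed, runs):
    """Use seed = base_seed + k for the k-th repetition, as in the loop above."""
    return [run_once(base_seed + k) for k in range(runs)]


scores = repeated_scores(base_seed=0, runs=10)
mean = statistics.fmean(scores)
std = statistics.pstdev(scores)  # population std, matching numpy's default ddof=0
print(f"Score: {mean:.3f} +/- {std:.3f}")
```

Offsetting the seed keeps each repetition reproducible while still varying the initialization between runs.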

## Multi-modality tasks
- Example: Modality prediction


## Spatial transcriptomics tasks
- Example: Spatial domain


### Argument parsing setting

In [None]:
import argparse

from dance.datasets.spatial import SpatialLIBDDataset
from dance.modules.spatial.spatial_domain.spagcn import SpaGCN, refine
from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import SpaGCNGraph, SpaGCNGraph2D
from dance.utils import set_seed

parser = argparse.ArgumentParser()
parser.add_argument("-f", required=False)

parser.add_argument("--cache", action="store_true", help="Cache processed data.")
parser.add_argument("--sample_number", type=str, default="151673",
                    help="12 human dorsolateral prefrontal cortex datasets for the spatial domain task.")
parser.add_argument("--beta", type=int, default=49, help="")
parser.add_argument("--alpha", type=int, default=1, help="")
parser.add_argument("--p", type=float, default=0.05,
                    help="percentage of total expression contributed by neighborhoods.")
parser.add_argument("--l", type=float, default=0.5, help="the parameter to control percentage p.")
parser.add_argument("--start", type=float, default=0.01, help="starting value for searching l.")
parser.add_argument("--end", type=float, default=1000, help="ending value for searching l.")
parser.add_argument("--tol", type=float, default=5e-3, help="tolerant value for searching l.")
parser.add_argument("--max_run", type=int, default=200, help="max runs.")
parser.add_argument("--epochs", type=int, default=200, help="Number of epochs.")
parser.add_argument("--n_clusters", type=int, default=7, help="the number of clusters")
parser.add_argument("--step", type=float, default=0.1, help="")
parser.add_argument("--lr", type=float, default=0.05, help="learning rate")
parser.add_argument("--random_state", type=int, default=100, help="")
args = parser.parse_args()
set_seed(args.random_state)

[INFO][2023-06-24 03:54:36,167][dance][set_seed] Setting global random seed to 100
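
The `-f` argument registered above is a notebook workaround: Jupyter launches the kernel with `-f <connection file>`, which `parse_args()` would otherwise reject as unknown. Two alternative ways to handle this are sketched below with a reduced, hypothetical parser; the connection file path is made up for the example.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_clusters", type=int, default=7)
    parser.add_argument("--lr", type=float, default=0.05)
    return parser


# Simulate the argument vector a Jupyter kernel is launched with.
notebook_argv = ["-f", "/tmp/kernel-1234.json"]

# Option 1: ignore anything the parser does not recognize.
args, unknown = build_parser().parse_known_args(notebook_argv)

# Option 2: bypass sys.argv entirely and use the defaults.
defaults = build_parser().parse_args([])

print(args.n_clusters, unknown, defaults.lr)
```

Passing an explicit list to `parse_args` also makes it easy to override a single value in a notebook, for example `parse_args(["--lr", "0.01"])`.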


### Initialize model and get model-specific preprocessing pipeline

In [None]:
model = SpaGCN()
# In SpaGCN, alpha and beta are used for graph construction
preprocessing_pipeline = model.preprocessing_pipeline(alpha=args.alpha, beta=args.beta)


### User-defined custom transform

In [None]:
preprocessing_pipeline = Compose(
    FilterGenesMatch(prefixes=["ERCC", "MT-"]),
    SpaGCNGraph(alpha=1, beta=49),
    SpaGCNGraph2D(),
    CellPCA(n_components=40),
    SetConfig({
        "feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
        "feature_channel_type": ["obsm", "obsp", "obsp"],
        "label_channel": "label",
        "label_channel_type": "obs"
    }),
    )
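
The `SetConfig` step pairs each feature channel name with the AnnData slot it lives in, which is what later allows the training data to be unpacked as `(x, adj, adj_2d), y`. The sketch below mimics that lookup with plain dictionaries. `resolve` and the toy `store` are hypothetical and only illustrate the name-to-slot pairing, not DANCE's internals.

```python
# Toy containers mirroring AnnData's obs / obsm / obsp slots.
store = {
    "obs": {"label": [0, 1, 1]},
    "obsm": {"CellPCA": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]},
    "obsp": {
        "SpaGCNGraph": [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        "SpaGCNGraph2D": [[1, 1, 0], [1, 1, 0], [0, 0, 1]],
    },
}

config = {
    "feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
    "feature_channel_type": ["obsm", "obsp", "obsp"],
    "label_channel": "label",
    "label_channel_type": "obs",
}


def resolve(store, config):
    """Pair each channel name with its slot type and fetch the entries."""
    features = tuple(
        store[slot][name]
        for name, slot in zip(config["feature_channel"], config["feature_channel_type"])
    )
    label = store[config["label_channel_type"]][config["label_channel"]]
    return features, label


(x, adj, adj_2d), y = resolve(store, config)
print(len(x), len(adj), len(adj_2d), y)  # 3 3 3 [0, 1, 1]
```

The order of `feature_channel` determines the order of the unpacked features, so it has to match what the model expects.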

### Load data and perform necessary preprocessing

In [None]:
dataloader = SpatialLIBDDataset(data_id=args.sample_number)
data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)
(x, adj, adj_2d), y = data.get_train_data()

[INFO][2023-06-24 03:54:37,850][dance][_load_raw_data] Loading image data from data/spatial/151673/151673_full_image.tif
[INFO][2023-06-24 03:54:38,829][dance][_load_raw_data] Loading expression data from data/spatial/151673/151673_raw_feature_bc_matrix.h5
  utils.warn_names_duplicates("var")
[INFO][2023-06-24 03:54:39,428][dance][_load_raw_data] Loading spatial info from data/spatial/151673/tissue_positions_list.txt
[INFO][2023-06-24 03:54:39,441][dance][_load_raw_data] Loading label info from data/spatial/151673/cluster_labels.csv
  utils.warn_names_duplicates("var")
[INFO][2023-06-24 03:54:39,514][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 3639 × 33538
    obs: 'key', 'ground_truth', 'SpatialDE_PCA', 'SpatialDE_pool_PCA', 'HVG_PCA', 'pseudobulk_PCA', 'markers_PCA', 'SpatialDE_UMAP', 'SpatialDE_pool_UMAP', 'HVG_UMAP', 'pseudobulk_UMAP', 'markers_UMAP', 'SpatialDE_PCA_spatial', 'SpatialDE_pool_PCA_spatial', 'HVG_PCA_spatial'

In [None]:
data

Data object that wraps (.data):
AnnData object with n_obs × n_vars = 3639 × 33516
    obs: 'key', 'ground_truth', 'SpatialDE_PCA', 'SpatialDE_pool_PCA', 'HVG_PCA', 'pseudobulk_PCA', 'markers_PCA', 'SpatialDE_UMAP', 'SpatialDE_pool_UMAP', 'HVG_UMAP', 'pseudobulk_UMAP', 'markers_UMAP', 'SpatialDE_PCA_spatial', 'SpatialDE_pool_PCA_spatial', 'HVG_PCA_spatial', 'pseudobulk_PCA_spatial', 'markers_PCA_spatial', 'SpatialDE_UMAP_spatial', 'SpatialDE_pool_UMAP_spatial', 'HVG_UMAP_spatial', 'pseudobulk_UMAP_spatial', 'markers_UMAP_spatial', 'label'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'image', 'dance_config'
    obsm: 'spatial', 'spatial_pixel', 'CellPCA'
    obsp: 'SpaGCNGraph', 'SpaGCNGraph2D'

In [None]:
data.data.obsp["SpaGCNGraph"].shape

(3639, 3639)

In [None]:
data.x[0].shape

(3639, 40)

In [None]:
data.x[1].shape

(3639, 3639)

In [None]:
data.x[2].shape

(3639, 3639)

### Train and evaluate model

In [None]:
l = model.search_l(args.p, adj, start=args.start, end=args.end, tol=args.tol, max_run=args.max_run)
model.set_l(l)
res = model.search_set_res((x, adj), l=l, target_num=args.n_clusters, start=0.4, step=args.step, tol=args.tol, lr=args.lr, epochs=args.epochs, max_run=args.max_run)

[INFO][2023-06-24 03:54:48,962][dance][search_l] Run 1: l [0.01, 1000], p [0.0, 3629.7406229057055]
[INFO][2023-06-24 03:54:49,019][dance][search_l] Run 2: l [0.01, 500.005], p [0.0, 3605.191650390625]
[INFO][2023-06-24 03:54:49,074][dance][search_l] Run 3: l [0.01, 250.0075], p [0.0, 3510.283935546875]
[INFO][2023-06-24 03:54:49,129][dance][search_l] Run 4: l [0.01, 125.00874999999999], p [0.0, 3176.004150390625]
[INFO][2023-06-24 03:54:49,182][dance][search_l] Run 5: l [0.01, 62.509375], p [0.0, 2292.207275390625]
[INFO][2023-06-24 03:54:49,236][dance][search_l] Run 6: l [0.01, 31.2596875], p [0.0, 1045.6600341796875]
[INFO][2023-06-24 03:54:49,292][dance][search_l] Run 7: l [0.01, 15.63484375], p [0.0, 292.86767578125]
[INFO][2023-06-24 03:54:49,369][dance][search_l] Run 8: l [0.01, 7.822421875], p [0.0, 59.20479965209961]
[INFO][2023-06-24 03:54:49,464][dance][search_l] Run 9: l [0.01, 3.9162109375], p [0.0, 9.6504545211792]
[INFO][2023-06-24 03:54:49,546][dance][search_l] Run 10: 
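
The log above shows the search interval for `l` being halved on each run until the resulting `p` is close to the target. A minimal sketch of such a bisection is given below. It assumes that `p` is the average summed Gaussian weight of each spot's neighbors, following the SpaGCN formulation; both helpers are illustrative re-implementations on a toy distance matrix rather than the DANCE code.

```python
import math


def calculate_p(distances, l):
    """Average summed Gaussian weight contributed by each spot's neighbors."""
    row_sums = [
        sum(math.exp(-(d ** 2) / (2 * l ** 2)) for d in row) for row in distances
    ]
    return sum(row_sums) / len(row_sums) - 1  # subtract the self-weight exp(0) = 1


def search_l(p, distances, start=0.01, end=1000, tol=5e-3, max_run=200):
    """Bisect on l until calculate_p(l) is within tol of the target p."""
    low, high = start, end
    for _ in range(max_run):
        mid = (low + high) / 2
        p_mid = calculate_p(distances, mid)
        if abs(p_mid - p) <= tol:
            return mid
        if p_mid > p:
            high = mid  # neighbors contribute too much: shrink l
        else:
            low = mid  # neighbors contribute too little: grow l
    return None  # no l found within max_run iterations


# Four spots on a line, one unit apart.
distances = [[abs(i - j) for j in range(4)] for i in range(4)]
found = search_l(0.5, distances)
print(found, calculate_p(distances, found))
```

Bisection works here because `p` grows monotonically with `l`: a larger `l` widens the Gaussian kernel and lets more distant spots contribute.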

In [None]:
pred = model.fit_predict((x, adj), init_spa=True, init="louvain", tol=args.tol, lr=args.lr, epochs=args.epochs, res=res)

[INFO][2023-06-24 03:55:08,464][dance][fit] Initializing cluster centers with louvain, resolution = 0.5
[INFO][2023-06-24 03:55:09,035][dance][fit] Epoch 0
[INFO][2023-06-24 03:55:09,828][dance][fit] Epoch 10
[INFO][2023-06-24 03:55:10,645][dance][fit] Epoch 20
[INFO][2023-06-24 03:55:11,466][dance][fit] Epoch 30
[INFO][2023-06-24 03:55:12,259][dance][fit] Epoch 40
[INFO][2023-06-24 03:55:13,291][dance][fit] Epoch 50
[INFO][2023-06-24 03:55:14,394][dance][fit] Epoch 60
[INFO][2023-06-24 03:55:14,578][dance][fit] delta_label 0.003572410002748008 < tol 0.005
[INFO][2023-06-24 03:55:14,586][dance][fit] Reach tolerance threshold. Stopping training.
[INFO][2023-06-24 03:55:14,588][dance][fit] Total epoch: 61
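
Training stops early in the log above because `delta_label` dropped below `tol`. A plausible reading, assumed here from the original SpaGCN procedure, is that `delta_label` is the fraction of spots whose cluster assignment changed since the previous check. The helpers below are hypothetical and only sketch that criterion.

```python
def delta_label(previous, current):
    """Fraction of samples whose cluster assignment changed."""
    changed = sum(1 for old, new in zip(previous, current) if old != new)
    return changed / len(current)


def should_stop(previous, current, tol):
    """Stop once fewer than a tol fraction of assignments have changed."""
    return delta_label(previous, current) < tol


previous = [0] * 1000
current = [0] * 997 + [1] * 3  # three assignments changed
print(delta_label(previous, current), should_stop(previous, current, tol=0.005))
```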


In [None]:
score = model.default_score_func(y, pred)
print(f"ARI: {score:.4f}")

ARI: 0.1610
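
The adjusted Rand index (ARI) reported above measures agreement between two labelings while correcting for chance, and it ignores how the clusters are named. For reference, a compact pure-Python version based on the standard pair-counting formula is sketched below. It is not the scoring function used by DANCE, only an illustration of the metric.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """ARI from the contingency table of two labelings (pair-counting form)."""
    n = len(labels_true)
    if n < 2:
        return 1.0
    cells = Counter(zip(labels_true, labels_pred))
    sum_cells = sum(comb(count, 2) for count in cells.values())
    sum_true = sum(comb(count, 2) for count in Counter(labels_true).values())
    sum_pred = sum(comb(count, 2) for count in Counter(labels_pred).values())
    expected = sum_true * sum_pred / comb(n, 2)
    maximum = (sum_true + sum_pred) / 2
    if maximum == expected:
        return 1.0  # degenerate case, e.g. a single cluster in both labelings
    return (sum_cells - expected) / (maximum - expected)


print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0 (same partition, renamed)
print(adjusted_rand_index([0, 0, 1, 1], [0, 1, 0, 1]))
```

A value near 0 indicates agreement no better than random assignment, and negative values are possible when the agreement is worse than chance.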


In [None]:
!wget https://www.dropbox.com/sh/dg10o9wmfmd2cpi/AABWGBng2HeU3g14D1dD20Wia?dl=1 -O openproblems.zip
!unzip openproblems.zip

--2023-04-26 00:41:11--  https://www.dropbox.com/sh/dg10o9wmfmd2cpi/AABWGBng2HeU3g14D1dD20Wia?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6019:18::a27d:412
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /sh/dl/dg10o9wmfmd2cpi/AABWGBng2HeU3g14D1dD20Wia [following]
--2023-04-26 00:41:11--  https://www.dropbox.com/sh/dl/dg10o9wmfmd2cpi/AABWGBng2HeU3g14D1dD20Wia
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uce2b20a115a81401477259fe87e.dl.dropboxusercontent.com/zip_download_get/BeTTrWzKtMmgllMt3wIwhg3LkPVcjYs2JS640oWG-k8aFWhtseSMjK70mncZtF_eoFztQuQuYTAaFnmSUM9fInGI51BDRgPwSuHEXHtS6OPPUA?dl=1# [following]
--2023-04-26 00:41:12--  https://uce2b20a115a81401477259fe87e.dl.dropboxusercontent.com/zip_download_get/BeTTrWzKtMmgllMt3wIwhg3LkPVcjYs2JS640oWG-k8aFWhtseSMjK70mncZtF_eoFztQuQuYTAaFnmSUM9fInGI5

In [None]:
!git clone https://github.com/OmicsML/dance.git pydance
!pip install torch_geometric
!pip install packaging
!pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html
!pip install -r /content/pydance/requirements.txt
!pip install -e /content/pydance

fatal: destination path 'pydance' already exists and is not an empty directory.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import argparse
import random
import torch

OPTIMIZER_DICT = {
    "adam": torch.optim.Adam,
    "rmsprop": torch.optim.RMSprop,
}
rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("-m", "--model_folder", default="./models")
parser.add_argument("--outdir", "-o", default="./logs", help="Directory to output to")
parser.add_argument("--lossweight", type=float, default=1., help="Relative loss weight")
parser.add_argument("--lr", "-l", type=float, default=0.01, help="Learning rate")
parser.add_argument("--batchsize", "-b", type=int, default=64, help="Batch size")
parser.add_argument("--hidden", type=int, default=64, help="Hidden dimensions")
parser.add_argument("--earlystop", type=int, default=20, help="Early stopping after N epochs")
parser.add_argument("--naive", "-n", action="store_true", help="Use a naive model instead of lego model")
parser.add_argument("--resume", action="store_true")
parser.add_argument("--max_epochs", type=int, default=500)
args_defaults = parser.parse_args([])
args = argparse.Namespace(**vars(args_defaults))

args

Namespace(subtask='openproblems_bmmc_cite_phase2_rna', device='cuda', cpus=1, rnd_seed=1000370692, model_folder='./models', outdir='./logs', lossweight=1.0, lr=0.01, batchsize=64, hidden=64, earlystop=20, naive=False, resume=False, max_epochs=500)

In [None]:
import logging
import os

import anndata
import mudata
import torch
import scanpy as sc

from dance import logger
from dance.data import Data
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.babel import BabelWrapper
from dance.utils import set_seed
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import csr_matrix

torch.set_num_threads(args.cpus)
rndseed = args.rnd_seed
set_seed(rndseed)
device = args.device
os.makedirs(args.model_folder, exist_ok=True)
os.makedirs(args.outdir, exist_ok=True)
args.outdir = os.path.abspath(args.outdir)

if not os.path.isdir(os.path.dirname(args.outdir)):
    os.makedirs(os.path.dirname(args.outdir))

# Specify output log file
fh = logging.FileHandler(f"{args.outdir}/training_{args.subtask}_{args.rnd_seed}.log", "w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)

for arg in vars(args):
    logger.info(f"Parameter {arg}: {getattr(args, arg)}")

dataset = ModalityPredictionDataset(args.subtask)
tsvd = TruncatedSVD(n_components=256)
X_train = anndata.read_h5ad('Gex_processed_training.h5ad')
X_test = anndata.read_h5ad('Gex_processed_testing.h5ad')
X_train = anndata.AnnData(X=csr_matrix(tsvd.fit_transform(X_train.X)), obs=X_train.obs)
X_test = anndata.AnnData(X=csr_matrix(tsvd.transform(X_test.X)), obs=X_test.obs)

dataset.modalities = [
    X_train,
    anndata.read_h5ad("Adt_processed_training.h5ad"),
    X_test,
    anndata.read_h5ad("Adt_processed_testing.h5ad"),
]


# Construct data object
mod1 = anndata.concat((dataset.modalities[0], dataset.modalities[2]))
mod2 = anndata.concat((dataset.modalities[1], dataset.modalities[3]))
mod1.var_names_make_unique()
mod2.var_names_make_unique()
mdata = mudata.MuData({"mod1": mod1, "mod2": mod2})
mdata.var_names_make_unique()
train_size = dataset.modalities[0].shape[0]
del dataset.modalities, X_train, X_test
data = Data(mdata, train_size=train_size)
data.set_config(feature_mod="mod1", label_mod="mod2")

# Obtain training and testing data
x_train, y_train = data.get_train_data(return_type="torch")
x_test, y_test = data.get_test_data(return_type="torch")

# Train and evaluate the model
model = BabelWrapper(args, dim_in=x_train.shape[1], dim_out=y_train.shape[1])
model.fit(x_train.float(), y_train.float(), val_ratio=0.15)
print(model.predict(x_test.float()))
print(model.score(x_test.float(), y_test.float()))
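
In the cell above, training and testing samples are concatenated into a single object, and `train_size` records where the training portion ends. The toy sketch below shows why the size must be captured before the separate pieces are discarded. `split_by_train_size` is a hypothetical helper and assumes that the first `train_size` samples form the training split.

```python
def split_by_train_size(samples, train_size):
    """Recover the train/test split from a concatenated sequence."""
    return samples[:train_size], samples[train_size:]


train_cells = ["cell_a", "cell_b", "cell_c"]
test_cells = ["cell_d", "cell_e"]

combined = train_cells + test_cells  # order matters: train first, then test
train_size = len(train_cells)  # record the boundary before discarding the parts

recovered_train, recovered_test = split_by_train_size(combined, train_size)
print(recovered_train, recovered_test)
```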

[INFO][2023-04-26 01:35:03,316][dance][set_seed] Setting global random seed to 1000370692
[INFO][2023-04-26 01:35:03,336][dance][<cell line: 33>] Parameter subtask: openproblems_bmmc_cite_phase2_rna
[INFO][2023-04-26 01:35:03,341][dance][<cell line: 33>] Parameter device: cuda
[INFO][2023-04-26 01:35:03,345][dance][<cell line: 33>] Parameter cpus: 1
[INFO][2023-04-26 01:35:03,361][dance][<cell line: 33>] Parameter rnd_seed: 1000370692
[INFO][2023-04-26 01:35:03,365][dance][<cell line: 33>] Parameter model_folder: ./models
[INFO][2023-04-26 01:35:03,372][dance][<cell line: 33>] Parameter outdir: /content/logs
[INFO][2023-04-26 01:35:03,386][dance][<cell line: 33>] Parameter lossweight: 1.0
[INFO][2023-04-26 01:35:03,394][dance][<cell line: 33>] Parameter lr: 0.01
[INFO][2023-04-26 01:35:03,403][dance][<cell line: 33>] Parameter batchsize: 64
[INFO][2023-04-26 01:35:03,414][dance][<cell line: 33>] Parameter hidden: 64
[INFO][2023-04-26 01:35:03,422][dance][<cell line: 33>] Parameter earl

epoch:  1
training (sum of 4 losses): 2.85336529314518
validation (prediction loss): 0.3432220473770043
epoch:  2
training (sum of 4 losses): 2.610635186093194
validation (prediction loss): 0.3331627611409845
epoch:  3
training (sum of 4 losses): 2.58757663496903
validation (prediction loss): 0.3309708068422586
epoch:  4
training (sum of 4 losses): 2.573169892174857
validation (prediction loss): 0.3261576296222578
epoch:  5
training (sum of 4 losses): 2.5631031351430074
validation (prediction loss): 0.32456449571622425
epoch:  6
training (sum of 4 losses): 2.5579753650086268
validation (prediction loss): 0.3211859431323911
epoch:  7
training (sum of 4 losses): 2.5499601781368257
validation (prediction loss): 0.3224165896502167
epoch:  8
training (sum of 4 losses): 2.5460016880716596
validation (prediction loss): 0.32148665859457726
epoch:  9
training (sum of 4 losses): 2.5443559527397155
validation (prediction loss): 0.31938622254656607
epoch:  10
training (sum of 4 losses): 2.54078204