<a href="https://colab.research.google.com/github/OmicsML/dance-tutorials/blob/dev/dance_tutorial_jiayuan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Installation

DANCE is published on [PyPI](https://pypi.org/project/pydance/), so installing it is as easy as

```bash
pip install pydance
```

Alternatively, install the latest dev version from GitHub:

```bash
pip install git+https://github.com/OmicsML/dance
```

However, because DANCE includes many deep-learning-based methods, it also depends on deep learning libraries such as [PyTorch](https://pytorch.org/), [PyG](https://www.pyg.org/), and [DGL](https://www.dgl.ai/). We walk through the installation process below.
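Before and after installing, it can help to check which of these libraries are already importable in the current environment. The snippet below is a minimal sketch using only the standard library; the helper name `check_dependencies` is our own and not part of DANCE, and `torch_geometric` is the import name of PyG.

```python
from importlib.util import find_spec

# Import names of the deep learning libraries mentioned above.
# "torch_geometric" is the import name of PyG.
DEEP_LEARNING_DEPS = ["torch", "torch_geometric", "dgl"]


def check_dependencies(module_names):
    """Map each top-level module name to True if it can be found, else False."""
    return {name: find_spec(name) is not None for name in module_names}


for name, found in check_dependencies(DEEP_LEARNING_DEPS).items():
    print(f"{name}: {'found' if found else 'missing'}")
```

`find_spec` only locates a top-level module without importing it, so the check is cheap, but it does not verify that the installed version matches the pins used below.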

### 1.1. Install torch related dependencies

In [1]:
# Colab comes with torch installed, so we do not need to install pytorch here
# !pip3 install torch torchvision torchaudio

!pip install torch_geometric==2.3.1
!pip install dgl==1.1.0 -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install torchnmf==0.3.4

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric==2.3.1
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... done
  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910459 sha256=3e87ddfce81fbb234dfccef50bdfa64d814455ac4527766fc1d9843c9bb0dff4
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch

### 1.2. Install the latest dev version of DANCE

In [2]:
!pip install git+https://github.com/OmicsML/dance.git

# # Clone DANCE repo and install
# !git clone https://github.com/OmicsML/dance.git pydance
# %cd /content/pydance
# !pip install -e .

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/OmicsML/dance.git
  Cloning https://github.com/OmicsML/dance.git to /tmp/pip-req-build-osb9_2tp
  Running command git clone --filter=blob:none --quiet https://github.com/OmicsML/dance.git /tmp/pip-req-build-osb9_2tp
  Resolved https://github.com/OmicsML/dance.git to commit e8b618fe937d69a59cb91a7ad07a0df307106b0b
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting leidenalg (from pydance==1.0.0rc1)
  Downloading leidenalg-0.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
Collecting mudata (from pydance==1.0.0rc1)
  Downloading mudata-0.2.3-py3-none-any.whl (24 kB)
Collecting pyro-p

### 1.3. Check if DANCE is installed successfully

In [3]:
import dance
print(f"Installed DANCE version {dance.__version__}")

Installed DANCE version 1.0.0-rc.1


## 2. Data loading and processing

DANCE provides several benchmark datasets through a unified dataset object interface, which handles data downloading, processing, and caching for the user.

### 2.1. Check available data options and load data object

In [None]:
from pprint import pprint
from dance.datasets.singlemodality import ClusteringDataset, ScDeepSortDataset

In [None]:
print("Available dataset option for ClusteringDataset:")
pprint(ClusteringDataset.get_avalilable_data())

print("\nAvailable dataset option for ScDeepSortDataset:")
pprint(ScDeepSortDataset.get_avalilable_data())

Available dataset option for ClusteringDataset:
['10X_PBMC', 'mouse_ES_cell', 'mouse_bladder_cell', 'worm_neuron_cell']

Available dataset option for ScDeepSortDataset:
[{'dataset': '3285', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '753', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '4682', 'species': 'mouse', 'split': 'train', 'tissue': 'Kidney'},
 {'dataset': '1970', 'species': 'mouse', 'split': 'train', 'tissue': 'Spleen'},
 {'dataset': '2695', 'species': 'mouse', 'split': 'test', 'tissue': 'Brain'},
 {'dataset': '203', 'species': 'mouse', 'split': 'test', 'tissue': 'Kidney'},
 {'dataset': '1759', 'species': 'mouse', 'split': 'test', 'tissue': 'Spleen'}]
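The option list printed above can be filtered to pick the training and testing dataset ids for one tissue. The following is a small sketch in plain Python: the list is copied from the output above, and `select_ids` is a hypothetical helper of our own, not a DANCE function.

```python
# Options copied from the ScDeepSortDataset listing above.
available = [
    {"dataset": "3285", "species": "mouse", "split": "train", "tissue": "Brain"},
    {"dataset": "753", "species": "mouse", "split": "train", "tissue": "Brain"},
    {"dataset": "4682", "species": "mouse", "split": "train", "tissue": "Kidney"},
    {"dataset": "1970", "species": "mouse", "split": "train", "tissue": "Spleen"},
    {"dataset": "2695", "species": "mouse", "split": "test", "tissue": "Brain"},
    {"dataset": "203", "species": "mouse", "split": "test", "tissue": "Kidney"},
    {"dataset": "1759", "species": "mouse", "split": "test", "tissue": "Spleen"},
]


def select_ids(options, species, tissue, split):
    """Return the dataset ids matching the given species, tissue, and split."""
    return [
        option["dataset"]
        for option in options
        if option["species"] == species
        and option["tissue"] == tissue
        and option["split"] == split
    ]


train_ids = select_ids(available, "mouse", "Brain", "train")
test_ids = select_ids(available, "mouse", "Brain", "test")
print("train:", train_ids, "test:", test_ids)
```

The resulting lists have the form expected by the `train_dataset` and `test_dataset` arguments of `ScDeepSortDataset`.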


In [None]:
dataset = ClusteringDataset("10X_PBMC")

# The dataset object does not contain data; it only loads the data when
# the load_data function is called
data = dataset.load_data()
print(data)

[INFO][2023-06-24 13:48:22,085][dance][download_file] Downloading: 10X_PBMC/mouse_bladder_cell.h5 Bytes: 5,479,777
100%|██████████| 5.23M/5.23M [00:00<00:00, 27.4MB/s]
[INFO][2023-06-24 13:48:23,173][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-24 13:48:23,174][dance][wrapped_func] Took 0:00:02.114399 to load and process data.


Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'


In [None]:
dataset = ScDeepSortDataset(species="mouse", tissue="Brain",
                            train_dataset=["3285", "753"], test_dataset=["2695"])
data = dataset.load_data()
print(data)

[INFO][2023-06-24 13:53:00,958][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_data.csv
[INFO][2023-06-24 13:53:11,916][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_data.csv
[INFO][2023-06-24 13:53:14,119][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_data.csv
[INFO][2023-06-24 13:53:23,130][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_celltype.csv
[INFO][2023-06-24 13:53:23,151][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_celltype.csv
[INFO][2023-06-24 13:53:23,161][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_celltype.csv
[INFO][2023-06-24 13:53:26,847][dance][_load_raw_data] Loaded expression data: AnnData object with n_obs × n_vars = 6733 × 19856
[INFO][2023-06-24 13:53:26,848][dance][_load_raw_data] Number of training samples: 4,038
[INFO][2023-06-24 13:53:26,852][dance][_load_raw_data] Number of testing samples: 2,695
[INFO][2023-06-24 13:53:26,855][dance][_load

Data object that wraps (.data):
AnnData object with n_obs × n_vars = 6733 × 19856
    uns: 'dance_config'
    obsm: 'cell_type'


### 2.2. A quick primer on AnnData

The [dance data object](https://github.com/OmicsML/dance/blob/912405cb5ab43caf16eb22b9216865c7e3976eaf/dance/data/base.py#L40) is built largely on top of [AnnData](https://anndata.readthedocs.io/en/latest/), a widely used data object for representing, storing, and manipulating large annotated matrices.

> anndata is a Python package for handling annotated data matrices in memory and on disk, positioned between pandas and xarray. anndata offers a broad range of computationally efficient features...

AnnData is part of the scverse ecosystem, which makes it convenient to handle single-cell data with tools such as [Scanpy](https://scanpy.readthedocs.io/en/stable/).

*Figure: AnnData schema (`anndata_schema.svg`).*

In [None]:
adata = data.data
adata

AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'

In [None]:
adata.X

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [None]:
adata.obsm["Group"]

array([ 1,  5, 16, ..., 10,  1,  6], dtype=int32)
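The attributes inspected above follow one rule: every annotation is aligned with an axis of the matrix `X`, so `obsm["Group"]` has one entry per observation (row). The toy class below illustrates that idea in plain Python. It is a deliberately simplified stand-in of our own (`ToyAnnotated` is hypothetical), not the anndata API.

```python
from dataclasses import dataclass, field


@dataclass
class ToyAnnotated:
    """Toy annotated matrix: X is n_obs x n_vars, annotations follow the axes."""

    X: list          # n_obs rows, each holding n_vars values
    obs_names: list  # one name per observation (cell)
    var_names: list  # one name per variable (gene)
    obsm: dict = field(default_factory=dict)  # per-observation annotations

    def __post_init__(self):
        n_obs, n_vars = self.shape
        if len(self.obs_names) != n_obs:
            raise ValueError("obs_names must have one entry per row of X")
        if any(len(row) != n_vars for row in self.X):
            raise ValueError("every row of X must have n_vars values")
        for key, values in self.obsm.items():
            if len(values) != n_obs:
                raise ValueError(f"obsm[{key!r}] must have one entry per row")

    @property
    def shape(self):
        return len(self.X), len(self.var_names)

    def subset_obs(self, keep):
        """Select rows by index and slice all row annotations consistently."""
        return ToyAnnotated(
            X=[self.X[i] for i in keep],
            obs_names=[self.obs_names[i] for i in keep],
            var_names=list(self.var_names),
            obsm={k: [v[i] for i in keep] for k, v in self.obsm.items()},
        )


toy = ToyAnnotated(
    X=[[0, 1], [2, 0], [0, 0]],
    obs_names=["cell0", "cell1", "cell2"],
    var_names=["geneA", "geneB"],
    obsm={"Group": [1, 5, 1]},
)
subset = toy.subset_obs([0, 2])
print(subset.shape, subset.obsm["Group"])
```

The point of the sketch is the design choice: keeping annotations attached to the matrix means a subset of cells can never fall out of sync with its labels.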

### 2.3. Data pre-processing using transforms

Applying individual in-place transformations to data


In [None]:
import scanpy as sc

from dance.transforms import AnnDataTransform, FilterGenesPercentile

# Reload data
data = dataset.load_data()

# Library size normalization
AnnDataTransform(sc.pp.normalize_total, target_sum=1e4)(data)

# Shifted log transformation
AnnDataTransform(sc.pp.log1p)(data)

# Filter out genes whose total expression is extreme (outside the 1st to 99th percentile)
FilterGenesPercentile(min_val=1, max_val=99, mode="sum")(data)
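To make the three steps concrete, here is a dependency-free sketch of what library size normalization, the shifted log transform, and percentile-based gene filtering compute on a tiny count matrix. The function names are our own, and the percentile rule (linear interpolation, inclusive bounds) is an assumption for illustration; the Scanpy and DANCE implementations may differ in detail.

```python
import math


def normalize_total(counts, target_sum):
    """Scale each cell (row) so that its counts sum to target_sum."""
    normalized = []
    for row in counts:
        total = sum(row)
        scale = target_sum / total if total > 0 else 0.0
        normalized.append([value * scale for value in row])
    return normalized


def log1p(matrix):
    """Apply the shifted log transform log(1 + x) element-wise."""
    return [[math.log1p(value) for value in row] for row in matrix]


def percentile(values, q):
    """Percentile with linear interpolation between closest ranks."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100
    lower, upper = math.floor(position), math.ceil(position)
    fraction = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction


def filter_genes_by_sum(matrix, min_pct, max_pct):
    """Keep genes (columns) whose total lies within the given percentiles."""
    gene_sums = [sum(column) for column in zip(*matrix)]
    low = percentile(gene_sums, min_pct)
    high = percentile(gene_sums, max_pct)
    kept = [j for j, total in enumerate(gene_sums) if low <= total <= high]
    return [[row[j] for j in kept] for row in matrix], kept


counts = [[1, 0, 3], [0, 2, 2], [4, 0, 4]]
normalized = normalize_total(counts, target_sum=10)
logged = log1p(normalized)
filtered, kept_genes = filter_genes_by_sum(logged, min_pct=10, max_pct=90)
print("row sums after normalization:", [sum(row) for row in normalized])
print("kept gene indices:", kept_genes)
```

Unlike the DANCE transforms, which modify the data object in place, these toy functions return new lists so that each intermediate result can be inspected.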



Composing transformations into a preprocessing pipeline (with caching)

In [None]:
from dance.transforms import Compose

preprocessing_pipeline = Compose(
    AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
    AnnDataTransform(sc.pp.log1p),
    FilterGenesPercentile(min_val=1, max_val=99, mode="sum"),
)

data = dataset.load_data()
preprocessing_pipeline(data)

[INFO][2023-06-24 11:54:34,794][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-24 11:54:34,795][dance][wrapped_func] Took 0:00:00.827345 to load and process data.
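Conceptually, a composed pipeline is a callable that applies its transforms one after another to the same data object. The sketch below shows that pattern on a plain dictionary; `ComposeSketch` and the two toy transforms are hypothetical names of our own and do not reproduce the internals of `dance.transforms.Compose`.

```python
class ComposeSketch:
    """Apply a sequence of in-place transforms in the order they were given."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, data):
        for transform in self.transforms:
            transform(data)  # each transform mutates `data` in place
        return data


def scale_by(factor):
    """Build a transform that multiplies every value by `factor`."""
    def transform(data):
        data["X"] = [value * factor for value in data["X"]]
    return transform


def shift_by(offset):
    """Build a transform that adds `offset` to every value."""
    def transform(data):
        data["X"] = [value + offset for value in data["X"]]
    return transform


pipeline = ComposeSketch(scale_by(2.0), shift_by(1.0))
data_toy = {"X": [1.0, 2.0]}
pipeline(data_toy)
print(data_toy["X"])

# Order matters: shifting first and scaling second gives a different result.
reversed_pipeline = ComposeSketch(shift_by(1.0), scale_by(2.0))
data_reversed = reversed_pipeline({"X": [1.0, 2.0]})
print(data_reversed["X"])
```

Because the pipeline is a single object, it can be handed to a loader and reused across datasets, which is what makes the caching shown next possible.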


In [None]:
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-24 11:54:46,010][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-24 11:54:47,141][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20463
    uns: 'dance_config', 'log1p'
    obsm: 'Group'
[INFO][2023-06-24 11:54:48,140][dance][load_data] Saved processed data to cache: /content/pydance/10X_PBMC/cache/97d10b6e54d74aff33617925c4696e31.pkl
[INFO][2023-06-24 11:54:48,141][dance][wrapped_func] Took 0:00:03.182861 to load and process data.
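The log above shows the processed data being saved under a hash-like file name. One common way to implement such a cache is to hash a description of the preprocessing pipeline and use the digest as the file name, so that a changed pipeline never reuses stale results. The sketch below illustrates this pattern with the standard library only. It is an assumption about the general technique, not DANCE's actual caching code, and `cache_key` and `load_with_cache` are hypothetical helpers.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path


def cache_key(pipeline_description):
    """Derive a file-name-safe key from a textual pipeline description."""
    return hashlib.md5(pipeline_description.encode("utf-8")).hexdigest()


def load_with_cache(cache_dir, pipeline_description, compute):
    """Return (result, cache_hit); run `compute` only when no cache file exists."""
    path = Path(cache_dir) / f"{cache_key(pipeline_description)}.pkl"
    if path.exists():
        with path.open("rb") as handle:
            return pickle.load(handle), True
    result = compute()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as handle:
        pickle.dump(result, handle)
    return result, False


calls = []


def expensive_preprocessing():
    """Stand-in for loading and transforming a dataset."""
    calls.append(1)
    return {"values": [1, 2, 3]}


with tempfile.TemporaryDirectory() as cache_dir:
    description = "Compose(normalize_total, log1p, FilterGenesPercentile)"
    first, first_hit = load_with_cache(cache_dir, description, expensive_preprocessing)
    second, second_hit = load_with_cache(cache_dir, description, expensive_preprocessing)

print("first call hit cache:", first_hit)
print("second call hit cache:", second_hit)
print("times preprocessing ran:", len(calls))
```

Keying the cache on the pipeline description rather than on the dataset name alone is the important design choice here: two different pipelines applied to the same raw data get separate cache files.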


## Single modality tasks

- Example:
  - Main: Cell Type Annotation
  - ...

In [None]:
import argparse
import pprint
from typing import get_args

import numpy as np

from dance import logger
from dance.datasets.singlemodality import ScDeepSortDataset
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.typing import LogLevel

if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
    parser.add_argument("--cache", action="store_true", help="Cache processed data.")
    parser.add_argument("--device", default="cpu", help="Computation device.")
    parser.add_argument("--hidden_dims", nargs="+", type=int, default=[2000], help="Hidden dimensions.")
    parser.add_argument("--lambd", type=float, default=0.01, help="Regularization parameter")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--log_level", type=str, default="INFO", choices=get_args(LogLevel))
    parser.add_argument("--nofilter", action="store_true", help="Disable filtering genes by expression summaries.")
    parser.add_argument(
        "--normalize", action="store_true", help="Whether to perform the normalization described in ACTINN. "
        "Disabled by default since the scDeepSort data is already normalized")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--print_cost", action="store_true", help="Print cost when training")
    parser.add_argument("--runs", type=int, default=10, help="Number of repetitions")
    parser.add_argument("--seed", type=int, default=0, help="Initial random seed, offset for each repetition")
    parser.add_argument("--species", default="mouse")
    parser.add_argument("--test_dataset", nargs="+", default=[1759], help="List of testing dataset ids.")
    parser.add_argument("--tissue", default="Spleen")
    parser.add_argument("--train_dataset", nargs="+", default=[1970], help="List of training dataset ids.")

    # Pass an empty list so the defaults are used; in a notebook, sys.argv holds kernel arguments
    args = parser.parse_args([])
    logger.setLevel(args.log_level)
    logger.info(f"Running ACTINN with the following parameters:\n{pprint.pformat(vars(args))}")

    # Initialize model and get model specific preprocessing pipeline
    model = ACTINN(hidden_dims=args.hidden_dims, lambd=args.lambd, device=args.device)
    preprocessing_pipeline = model.preprocessing_pipeline(normalize=args.normalize, filter_genes=not args.nofilter)

    # Load data and perform necessary preprocessing
    dataloader = ScDeepSortDataset(train_dataset=args.train_dataset, test_dataset=args.test_dataset, tissue=args.tissue,
                                   species=args.species)
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    # Obtain training and testing data
    x_train, y_train = data.get_train_data(return_type="torch")
    x_test, y_test = data.get_test_data(return_type="torch")

    # Train and evaluate models for several rounds
    scores = []
    for k in range(args.runs):
        model.fit(x_train, y_train, seed=args.seed + k, lr=args.learning_rate, num_epochs=args.num_epochs,
                  batch_size=args.batch_size, print_cost=args.print_cost)
        scores.append(score := model.score(x_test, y_test))
        print(f"{score=:.4f}")
    print(f"Score: {np.mean(scores):04.3f} +/- {np.std(scores):04.3f}")
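The evaluation loop above repeats training with a different seed per run (`seed + k`) and reports the mean and standard deviation of the scores. The sketch below isolates that protocol with a fake scoring function so it runs without any model. `run_repeated` and `fake_train_and_score` are hypothetical names; `statistics.pstdev` is used because `np.std` computes the population standard deviation by default.

```python
import random
import statistics


def run_repeated(train_and_score, runs, seed):
    """Run `train_and_score(seed + k)` for k in range(runs) and summarize."""
    scores = [train_and_score(seed + k) for k in range(runs)]
    return scores, statistics.fmean(scores), statistics.pstdev(scores)


def fake_train_and_score(seed):
    """Stand-in for model.fit followed by model.score; deterministic per seed."""
    rng = random.Random(seed)
    return 0.8 + rng.uniform(-0.05, 0.05)


scores, mean_score, std_score = run_repeated(fake_train_and_score, runs=10, seed=0)
print(f"Score: {mean_score:.3f} +/- {std_score:.3f}")
```

Offsetting the seed per run keeps the whole experiment reproducible while still measuring how much the score varies with initialization.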

## Multi-modality tasks
- Example: Modality prediction


### Argument parsing setting

In [None]:
import argparse
import os
import random

import anndata
import mudata
import scanpy as sc
import torch
from scipy.sparse import csr_matrix
from sklearn.decomposition import TruncatedSVD

from dance import logger
from dance.data import Data
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.babel import BabelWrapper
from dance.utils import set_seed

OPTIMIZER_DICT = {
    "adam": torch.optim.Adam,
    "rmsprop": torch.optim.RMSprop,
}
rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("--lossweight", type=float, default=1., help="Relative loss weight")
parser.add_argument("--lr", "-l", type=float, default=0.01, help="Learning rate")
parser.add_argument("--batchsize", "-b", type=int, default=64, help="Batch size")
parser.add_argument("--hidden", type=int, default=64, help="Hidden dimensions")
parser.add_argument("--earlystop", type=int, default=20, help="Early stopping after N epochs")
parser.add_argument("--naive", "-n", action="store_true", help="Use a naive model instead of lego model")

parser.add_argument("-m", "--model_folder", default="./")
parser.add_argument("--outdir", "-o", default="./", help="Directory to output to")
parser.add_argument("--resume", action="store_true")
parser.add_argument("--max_epochs", type=int, default=500)
args_defaults = parser.parse_args([])
args = argparse.Namespace(**vars(args_defaults))

rndseed = args.rnd_seed
set_seed(rndseed)
device = args.device
os.makedirs(args.model_folder, exist_ok=True)
args

[INFO][2023-06-24 14:30:07,498][dance][set_seed] Setting global random seed to 174567710


Namespace(subtask='openproblems_bmmc_cite_phase2_rna', device='cuda', cpus=1, rnd_seed=174567710, lossweight=1.0, lr=0.01, batchsize=64, hidden=64, earlystop=20, naive=False, model_folder='./', outdir='./', resume=False, max_epochs=500)
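The cell above calls `parser.parse_args([])` rather than `parser.parse_args()`. With an explicit empty list, argparse ignores `sys.argv` (which in a notebook holds the kernel's own arguments) and returns the declared defaults. The short sketch below demonstrates this with a reduced, illustrative subset of the options; the override values are made up for the example.

```python
import argparse


def build_parser():
    """Build a parser with a few of the options used in this tutorial."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", "-l", type=float, default=0.01, help="Learning rate")
    parser.add_argument("--batchsize", "-b", type=int, default=64, help="Batch size")
    parser.add_argument("--resume", action="store_true")
    return parser


parser_demo = build_parser()

# An empty list means "no command-line arguments": every option takes its default.
defaults = parser_demo.parse_args([])

# Passing a list of strings simulates a command line inside a notebook.
custom = parser_demo.parse_args(["--lr", "0.001", "-b", "128", "--resume"])

# Individual defaults can also be overridden by rebuilding the namespace.
overridden = argparse.Namespace(**{**vars(defaults), "lr": 0.005})

print(defaults)
print(custom)
print(overridden)
```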

### Load data and perform necessary preprocessing

In [None]:
dataset = ModalityPredictionDataset(args.subtask)
data = dataset.load_data()
data

[INFO][2023-06-24 14:30:10,883][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_train_mod1.h5ad
[INFO][2023-06-24 14:30:20,935][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_train_mod2.h5ad
[INFO][2023-06-24 14:30:21,674][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_test_mod1.h5ad
[INFO][2023-06-24 14:30:21,908][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna/openproblems_bmmc_cite_phase2_rna.censor_dataset.output_test_mod2.h5ad
[INFO][2023-06-24 14:30:21,937][dance][_maybe_preprocess] Preprocessing done.
[INFO][2023-06-24 14:30:31,711][dance][set_config_from_dict] Setting config 'feature_mod' to 'mod1'
[INFO][2023-06-24 14:30:31,712][dance][set_config_from_dict] Setting config 'label_mod' to '

Data object that wraps (.data):
MuData object with n_obs × n_vars = 67175 × 14087
  uns:	'dance_config'
  2 modalities
    mod1:	67175 x 13953
      obs:	'batch', 'size_factors'
      layers:	'counts'
    mod2:	67175 x 134
      obs:	'batch', 'size_factors'
      layers:	'counts'

In [None]:
# Select the feature and label modalities
data.set_config(feature_mod="mod1", label_mod="mod2")

# Obtain training and testing data
x_train, y_train = data.get_train_data(return_type="torch")

# Colab has limited memory (12.7 GB), so we subsample the training data
x_train = x_train[:10000]
y_train = y_train[:10000]

x_test, y_test = data.get_test_data(return_type="torch")

### Initialize model and get model specific preprocessing pipeline

In [None]:
model = BabelWrapper(args, dim_in=x_train.shape[1], dim_out=y_train.shape[1])

[INFO][2023-06-24 14:31:23,785][dance][__init__] ChromDecoder with 1 output activations


### Train and evaluate model

In [None]:
model.fit(x_train.float(), y_train.float(), val_ratio=0.15)

epoch:  1
training (sum of 4 losses): 1.6766332656817329
validation (prediction loss): 0.41206314481472334
epoch:  2
training (sum of 4 losses): 1.40987203354226
validation (prediction loss): 0.38240737221883014
epoch:  3
training (sum of 4 losses): 1.3574806496613008
validation (prediction loss): 0.3682764279868932
epoch:  4
training (sum of 4 losses): 1.3285951874309914
validation (prediction loss): 0.36719259341407
epoch:  5
training (sum of 4 losses): 1.312786208955865
validation (prediction loss): 0.3694950917258167
epoch:  6
training (sum of 4 losses): 1.2974244553343695
validation (prediction loss): 0.36570537171478096
epoch:  7
training (sum of 4 losses): 1.2830024088235725
validation (prediction loss): 0.36225591615754504
epoch:  8
training (sum of 4 losses): 1.2671941814566017
validation (prediction loss): 0.35869080089694916
epoch:  9
training (sum of 4 losses): 1.2591027944607842
validation (prediction loss): 0.3587988289542294
epoch:  10
training (sum of 4 losses): 1.25397

In [None]:
model.predict(x_test.float())

tensor([[0.0000, 0.1828, 1.4499,  ..., 0.6975, 0.7497, 1.2867],
        [0.0000, 0.2233, 1.4454,  ..., 0.4905, 0.5937, 0.4263],
        [0.0000, 0.2122, 1.3088,  ..., 0.6468, 0.8948, 0.4278],
        ...,
        [0.0000, 0.2826, 1.0508,  ..., 0.6880, 0.6888, 0.4230],
        [0.0000, 0.2779, 1.0777,  ..., 0.6544, 0.7332, 0.4518],
        [0.0000, 0.2132, 1.4187,  ..., 0.5752, 0.7776, 0.5988]],
       device='cuda:0')

In [None]:
model.score(x_test.float(), y_test.float())

0.5074194340577416

# Spatial Transcriptomics Module


## Task 1: Spatial Domain

### SpaGCN model for spatial domain identification
![SpaGCN_framework](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/spatial_domain/SpaGCN_framework.png)

In [22]:
import argparse

from dance.datasets.spatial import SpatialLIBDDataset
from dance.modules.spatial.spatial_domain.spagcn import SpaGCN, refine
from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import SpaGCNGraph, SpaGCNGraph2D
from dance.utils import set_seed


### Initialize model and get model specific preprocessing *pipeline*

In [23]:
model = SpaGCN()
preprocessing_pipeline = model.preprocessing_pipeline(alpha=1, beta=49) # In SpaGCN, alpha and beta are used for graph construction


### User-defined custom transform

In [24]:
preprocessing_pipeline = Compose(
    FilterGenesMatch(prefixes=["ERCC", "MT-"]),
    SpaGCNGraph(alpha=1, beta=49),
    SpaGCNGraph2D(),
    CellPCA(n_components=40),
    SetConfig({
        "feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
        "feature_channel_type": ["obsm", "obsp", "obsp"],
        "label_channel": "label",
        "label_channel_type": "obs"
    }),
)

### Load data and perform necessary preprocessing

In [25]:
dataloader = SpatialLIBDDataset(data_id="151673")
data = dataloader.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-26 05:36:27,141][dance][_load_raw_data] Loading image data from data/spatial/151673/151673_full_image.tif
[INFO][2023-06-26 05:36:28,111][dance][_load_raw_data] Loading expression data from data/spatial/151673/151673_raw_feature_bc_matrix.h5
  utils.warn_names_duplicates("var")
[INFO][2023-06-26 05:36:28,701][dance][_load_raw_data] Loading spatial info from data/spatial/151673/tissue_positions_list.txt
[INFO][2023-06-26 05:36:28,712][dance][_load_raw_data] Loading label info from data/spatial/151673/cluster_labels.csv
  utils.warn_names_duplicates("var")
[INFO][2023-06-26 05:36:28,782][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 3639 × 33538
    obs: 'key', 'ground_truth', 'SpatialDE_PCA', 'SpatialDE_pool_PCA', 'HVG_PCA', 'pseudobulk_PCA', 'markers_PCA', 'SpatialDE_UMAP', 'SpatialDE_pool_UMAP', 'HVG_UMAP', 'pseudobulk_UMAP', 'markers_UMAP', 'SpatialDE_PCA_spatial', 'SpatialDE_pool_PCA_spatial', 'HVG_PCA_spatial'

In [26]:
data

Data object that wraps (.data):
AnnData object with n_obs × n_vars = 3639 × 33516
    obs: 'key', 'ground_truth', 'SpatialDE_PCA', 'SpatialDE_pool_PCA', 'HVG_PCA', 'pseudobulk_PCA', 'markers_PCA', 'SpatialDE_UMAP', 'SpatialDE_pool_UMAP', 'HVG_UMAP', 'pseudobulk_UMAP', 'markers_UMAP', 'SpatialDE_PCA_spatial', 'SpatialDE_pool_PCA_spatial', 'HVG_PCA_spatial', 'pseudobulk_PCA_spatial', 'markers_PCA_spatial', 'SpatialDE_UMAP_spatial', 'SpatialDE_pool_UMAP_spatial', 'HVG_UMAP_spatial', 'pseudobulk_UMAP_spatial', 'markers_UMAP_spatial', 'label'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'image', 'dance_config'
    obsm: 'spatial', 'spatial_pixel', 'CellPCA'
    obsp: 'SpaGCNGraph', 'SpaGCNGraph2D'

In [27]:
data.data.obsp["SpaGCNGraph"].shape

(3639, 3639)

In [28]:
data.x[0].shape

(3639, 40)

In [29]:
data.x[1].shape

(3639, 3639)

In [31]:
(x, adj, adj_2d), y = data.get_train_data()

In [32]:
x, x.shape

(array([[ 74.59424   ,   5.102027  ,  -1.5807899 , ...,  -1.2770147 ,
           4.5531125 ,   4.498791  ],
        [-52.20469   , -29.175175  ,   0.14970411, ...,  -0.59834   ,
          -0.46938702,   2.6673055 ],
        [-46.55332   ,  97.913826  ,  -3.319207  , ...,   1.1419314 ,
           4.833102  ,  -2.9667187 ],
        ...,
        [-48.64317   ,   0.6173291 ,  -0.66554165, ...,  -0.3374023 ,
          -2.9235873 ,   0.31737345],
        [-49.08601   ,   7.224671  ,  -0.7592219 , ...,  -1.6662663 ,
           1.0502098 ,   0.1705698 ],
        [ 35.53666   , -15.498928  ,  -1.0658048 , ...,  -4.917498  ,
          -0.5530586 ,   0.6497386 ]], dtype=float32),
 (3639, 40))

In [33]:
adj, adj.shape

(array([[  0.     ,  94.04338, 107.42335, ...,  75.5097 , 119.22895,
          67.19388],
        [ 94.04338,   0.     ,  61.99059, ...,  77.6764 ,  82.52692,
          64.84502],
        [107.42335,  61.99059,   0.     , ...,  76.50982,  41.4561 ,
          97.2263 ],
        ...,
        [ 75.5097 ,  77.6764 ,  76.50982, ...,   0.     , 111.5035 ,
          44.95287],
        [119.22895,  82.52692,  41.4561 , ..., 111.5035 ,   0.     ,
         126.87175],
        [ 67.19388,  64.84502,  97.2263 , ...,  44.95287, 126.87175,
           0.     ]], dtype=float32),
 (3639, 3639))

In [34]:
y, y.shape

(array([0, 1, 2, ..., 4, 2, 5]), (3639,))

### Train and evaluate model

In [35]:
l = model.search_l(0.05, adj, start=0.01, end=1000, tol=5e-3, max_run=200)
model.set_l(l)
res = model.search_set_res((x, adj), l=l, target_num=7, start=0.4, step=0.1, tol=5e-3, lr=0.05, epochs=200, max_run=200)

[INFO][2023-06-26 05:36:55,264][dance][search_l] Run 1: l [0.01, 1000], p [0.0, 3629.7406229057055]
[INFO][2023-06-26 05:36:55,328][dance][search_l] Run 2: l [0.01, 500.005], p [0.0, 3605.191650390625]
[INFO][2023-06-26 05:36:55,389][dance][search_l] Run 3: l [0.01, 250.0075], p [0.0, 3510.283935546875]
[INFO][2023-06-26 05:36:55,452][dance][search_l] Run 4: l [0.01, 125.00874999999999], p [0.0, 3176.004150390625]
[INFO][2023-06-26 05:36:55,509][dance][search_l] Run 5: l [0.01, 62.509375], p [0.0, 2292.207275390625]
[INFO][2023-06-26 05:36:55,566][dance][search_l] Run 6: l [0.01, 31.2596875], p [0.0, 1045.6600341796875]
[INFO][2023-06-26 05:36:55,632][dance][search_l] Run 7: l [0.01, 15.63484375], p [0.0, 292.86767578125]
[INFO][2023-06-26 05:36:55,713][dance][search_l] Run 8: l [0.01, 7.822421875], p [0.0, 59.20479965209961]
[INFO][2023-06-26 05:36:55,811][dance][search_l] Run 9: l [0.01, 3.9162109375], p [0.0, 9.6504545211792]
[INFO][2023-06-26 05:36:55,888][dance][search_l] Run 10: 
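The log above shows the search interval for `l` being halved on each run, which is a bisection search. The sketch below reproduces that idea on a toy distance matrix. It assumes, following the SpaGCN method, that `p` is the average total Gaussian-kernel weight a spot receives from all other spots, `mean_i(sum_j exp(-d_ij^2 / (2 l^2))) - 1`. The function names mirror the log, but this is our own simplified code, not the DANCE implementation.

```python
import math


def calculate_p(dist, l):
    """Average neighbor weight per spot, excluding each spot's self-weight of 1."""
    n = len(dist)
    total = sum(math.exp(-(d * d) / (2 * l * l)) for row in dist for d in row)
    return total / n - 1.0


def search_l(target_p, dist, start=0.01, end=1000.0, tol=5e-3, max_run=200):
    """Bisection search for l such that calculate_p(dist, l) is close to target_p."""
    low, high = start, end
    if not calculate_p(dist, low) <= target_p <= calculate_p(dist, high):
        raise ValueError("target_p is not bracketed by [start, end]")
    for _ in range(max_run):
        mid = (low + high) / 2
        p_mid = calculate_p(dist, mid)
        if abs(p_mid - target_p) <= tol:
            return mid
        # p grows with l, so shrink the interval towards the target.
        if p_mid > target_p:
            high = mid
        else:
            low = mid
    raise RuntimeError("search did not converge within max_run iterations")


# Toy example: four spots on a line, one unit apart.
positions = [0.0, 1.0, 2.0, 3.0]
dist = [[abs(a - b) for b in positions] for a in positions]
length_scale = search_l(0.5, dist)
print(f"l = {length_scale:.4f}, p = {calculate_p(dist, length_scale):.4f}")
```

Bisection works here because `p` increases monotonically with `l`: a larger length scale gives every neighbor a larger weight.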

In [36]:
pred = model.fit_predict((x, adj), init_spa=True, init="louvain", tol=5e-3, lr=0.05, epochs=200, res=res)

[INFO][2023-06-26 05:37:07,040][dance][fit] Initializing cluster centers with louvain, resolution = 0.6
[INFO][2023-06-26 05:37:07,707][dance][fit] Epoch 0
[INFO][2023-06-26 05:37:08,517][dance][fit] Epoch 10
[INFO][2023-06-26 05:37:09,331][dance][fit] Epoch 20
[INFO][2023-06-26 05:37:10,173][dance][fit] Epoch 30
[INFO][2023-06-26 05:37:10,966][dance][fit] Epoch 40
[INFO][2023-06-26 05:37:11,765][dance][fit] Epoch 50
[INFO][2023-06-26 05:37:12,256][dance][fit] delta_label 0.004396812311074471 < tol 0.005
[INFO][2023-06-26 05:37:12,258][dance][fit] Reach tolerance threshold. Stopping training.
[INFO][2023-06-26 05:37:12,262][dance][fit] Total epoch: 55


In [37]:
score = model.default_score_func(y, pred)
print(f"ARI: {score:.4f}")

ARI: 0.1699
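The score reported above is the Adjusted Rand Index (ARI), which compares two clusterings while ignoring how the cluster labels are named and correcting for chance agreement. For reference, here is a self-contained sketch of the standard ARI formula based on pair counts. It is our own illustrative implementation rather than the function DANCE calls internally, and the handling of degenerate inputs (returning 1.0) is a convention chosen for the example.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand Index computed from the contingency table of two labelings."""
    if len(labels_true) != len(labels_pred):
        raise ValueError("both labelings must have the same length")
    total_pairs = comb(len(labels_true), 2)
    if total_pairs == 0:
        return 1.0  # fewer than two samples: treat as perfect agreement

    # Number of sample pairs that share a cluster in both, or in each, labeling.
    both = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    in_true = sum(comb(c, 2) for c in Counter(labels_true).values())
    in_pred = sum(comb(c, 2) for c in Counter(labels_pred).values())

    expected = in_true * in_pred / total_pairs  # agreement expected by chance
    maximum = (in_true + in_pred) / 2
    if maximum == expected:
        return 1.0  # degenerate case, e.g. a single cluster in both labelings
    return (both - expected) / (maximum - expected)


# Renaming the clusters does not change the score.
print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))
# Merging two true clusters into one lowers it.
print(adjusted_rand_index([0, 0, 1, 2], [0, 0, 1, 1]))
```

An ARI of 1 means the two partitions are identical up to label names, while values near 0 mean the agreement is about what random assignment would produce.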


## Task 2: Cell Type Deconvolution

### DSTG model for cell type deconvolution
![DSTG_framework](https://github.com/OmicsML/dance-tutorials/raw/4a5a17326ecc652439ec3b459788af7f41316216/imgs/tutorial_v1/cell_type_deconvo/DSTG_framework.png)

### Argument parsing setting

In [38]:
import argparse
from pprint import pprint
import numpy as np
import torch
from dance.datasets.spatial import CellTypeDeconvoDataset
from dance.modules.spatial.cell_type_deconvo import DSTG
from dance.utils import set_seed


### Get model specific preprocessing *pipeline*

In [39]:
preprocessing_pipeline = DSTG.preprocessing_pipeline(
    n_pseudo=500,
    n_top_genes=2000,
    k_filter=200,
    num_cc=30,
)

### Load data and perform necessary preprocessing

In [40]:
dataset = CellTypeDeconvoDataset(data_dir="data/spatial", data_id="CARD_synthetic")
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-26 05:40:38,116][dance][download_file] Downloading: data/spatial/CARD_synthetic.zip Bytes: 104,321,218
100%|██████████| 99.5M/99.5M [00:04<00:00, 21.3MB/s]
[INFO][2023-06-26 05:40:43,032][dance][unzip_file] Unzipping 'data/spatial/CARD_synthetic.zip'
[INFO][2023-06-26 05:40:43,293][dance][delete_file] Deleting 'data/spatial/CARD_synthetic.zip'
[INFO][2023-06-26 05:40:44,205][dance][_load_raw_data] Number of cell types: reference = 7, real = 6
[INFO][2023-06-26 05:40:44,206][dance][_load_raw_data] Subsetting to common cell types (n=6):
['Astrocytes', 'Ependymal', 'Immune', 'Neurons', 'Oligos', 'Vascular']
[INFO][2023-06-26 05:40:45,646][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 10454 × 18263
    obs: 'cellname', 'cellType', 'sampleID', 'batch'
    uns: 'dance_config'
    obsm: 'cell_type_portion', 'spatial'
[INFO][2023-06-26 05:40:45,648][dance.Compose][__call__] Applying composed transformations:
Compose(
  Fi

In [41]:
data.x

[<760x760 sparse matrix of type '<class 'numpy.float64'>'
 	with 3032 stored elements in COOrdinate format>,
 array([[-0.96134275,  0.976699  , -1.028233  , ..., -0.16210744,
          0.10654867,  0.72191226],
        [-0.96134275,  0.67003554,  0.12525097, ..., -0.16210744,
         -0.7449631 ,  0.32156822],
        [-0.96134275, -0.9604567 ,  0.06584159, ..., -0.16210744,
         -0.89897835,  1.0575923 ],
        ...,
        [-0.7177087 , -0.29554528, -0.8562431 , ...,  0.        ,
         -0.7281318 ,  2.2388463 ],
        [-0.7177087 , -0.29554528, -0.8562431 , ...,  0.        ,
         -0.9590941 , -0.49117133],
        [ 1.5180172 , -0.29554528, -0.8562431 , ...,  0.        ,
         -0.87318116, -0.49117133]], dtype=float32)]

In [42]:
len(data.x)

2

In [43]:
data.x[0], data.x[0].shape

(<760x760 sparse matrix of type '<class 'numpy.float64'>'
 	with 3032 stored elements in COOrdinate format>,
 (760, 760))

In [44]:
data.x[1], data.x[1].shape

(array([[-0.96134275,  0.976699  , -1.028233  , ..., -0.16210744,
          0.10654867,  0.72191226],
        [-0.96134275,  0.67003554,  0.12525097, ..., -0.16210744,
         -0.7449631 ,  0.32156822],
        [-0.96134275, -0.9604567 ,  0.06584159, ..., -0.16210744,
         -0.89897835,  1.0575923 ],
        ...,
        [-0.7177087 , -0.29554528, -0.8562431 , ...,  0.        ,
         -0.7281318 ,  2.2388463 ],
        [-0.7177087 , -0.29554528, -0.8562431 , ...,  0.        ,
         -0.9590941 , -0.49117133],
        [ 1.5180172 , -0.29554528, -0.8562431 , ...,  0.        ,
         -0.87318116, -0.49117133]], dtype=float32),
 (760, 2000))

In [45]:
(adj, x), y = data.get_data(return_type="default")
x, y = torch.FloatTensor(x), torch.FloatTensor(y.values)
adj = torch.sparse.FloatTensor(torch.LongTensor([adj.row.tolist(), adj.col.tolist()]),
                               torch.FloatTensor(adj.data.astype(np.int32)))


In [51]:
x, x.shape

(tensor([[-0.9613,  0.9767, -1.0282,  ..., -0.1621,  0.1065,  0.7219],
         [-0.9613,  0.6700,  0.1253,  ..., -0.1621, -0.7450,  0.3216],
         [-0.9613, -0.9605,  0.0658,  ..., -0.1621, -0.8990,  1.0576],
         ...,
         [-0.7177, -0.2955, -0.8562,  ...,  0.0000, -0.7281,  2.2388],
         [-0.7177, -0.2955, -0.8562,  ...,  0.0000, -0.9591, -0.4912],
         [ 1.5180, -0.2955, -0.8562,  ...,  0.0000, -0.8732, -0.4912]]),
 torch.Size([760, 2000]))

In [52]:
adj, adj.shape

(tensor(indices=tensor([[  0,   0,   0,  ..., 759, 759, 759],
                        [  0, 563, 647,  ..., 295, 374, 759]]),
        values=tensor([0., 0., 0.,  ..., 0., 0., 0.]),
        size=(760, 760), nnz=3032, layout=torch.sparse_coo),
 torch.Size([760, 760]))

In [53]:
y, y.shape

(tensor([[0.2000, 0.1000, 0.0000, 0.4000, 0.0000, 0.3000],
         [0.3000, 0.1000, 0.1000, 0.2000, 0.2000, 0.1000],
         [0.0000, 0.0000, 0.0000, 0.2222, 0.1111, 0.6667],
         ...,
         [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000],
         [0.1429, 0.0000, 0.0000, 0.4286, 0.4286, 0.0000],
         [0.0000, 0.0000, 0.2500, 0.7500, 0.0000, 0.0000]]),
 torch.Size([760, 6]))

In [49]:
train_mask = data.get_split_mask("pseudo", return_type="torch")
inputs = (adj, x, train_mask)
train_mask, train_mask.shape

(tensor([False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

### Train and evaluate model

In [56]:
model = DSTG(nhid=16, bias=False, dropout=0, device="auto")
pred = model.fit_predict(inputs, y, lr=0.01, max_epochs=25, weight_decay=0.0001)
pred, pred.shape

[INFO][2023-06-26 05:45:53,437][dance][fit] Epoch: 0005, train_loss=1.79176, time=0.00165
[INFO][2023-06-26 05:45:53,456][dance][fit] Epoch: 0010, train_loss=1.79176, time=0.00123
[INFO][2023-06-26 05:45:53,474][dance][fit] Epoch: 0015, train_loss=1.79176, time=0.00125
[INFO][2023-06-26 05:45:53,493][dance][fit] Epoch: 0020, train_loss=1.79176, time=0.00124
[INFO][2023-06-26 05:45:53,512][dance][fit] Epoch: 0025, train_loss=1.79176, time=0.00120


(tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         ...,
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]], device='cuda:0',
        grad_fn=<SoftmaxBackward0>),
 torch.Size([760, 6]))

In [57]:
test_mask = data.get_split_mask("test", return_type="torch")
test_mask, test_mask.shape

(tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [58]:
score = model.default_score_func(y[test_mask], pred[test_mask])
print(f"MSE: {score:7.4f}")

MSE:  0.0239
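The final cell evaluates only the rows selected by `test_mask`. The sketch below shows what a masked mean squared error computes on small nested lists. It assumes the error is averaged over every entry of the selected rows, which may not match the exact reduction used by the DANCE scoring function; `masked_mse` is a hypothetical helper.

```python
def masked_mse(y_true, y_pred, mask):
    """Mean squared error over all entries of the rows where mask is True."""
    squared_errors = [
        (true_value - pred_value) ** 2
        for true_row, pred_row, keep in zip(y_true, y_pred, mask)
        if keep
        for true_value, pred_value in zip(true_row, pred_row)
    ]
    if not squared_errors:
        raise ValueError("mask selects no rows")
    return sum(squared_errors) / len(squared_errors)


# Toy proportions for two cell types; the second row is excluded by the mask.
y_true_toy = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
y_pred_toy = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
mask_toy = [True, False, True]
print(f"MSE: {masked_mse(y_true_toy, y_pred_toy, mask_toy):7.4f}")
```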
