<a href="https://colab.research.google.com/github/OmicsML/dance-tutorials/blob/dev/dance_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Installation

DANCE is published on [PyPI](https://pypi.org/project/pydance/), so installing it is as easy as

```bash
pip install pydance
```

Alternatively, install the latest development version from GitHub:

```bash
pip install git+https://github.com/OmicsML/dance
```

However, because DANCE includes many deep-learning-based methods, it also depends on deep learning libraries such as [PyTorch](https://pytorch.org/), [PyG](https://www.pyg.org/), and [DGL](https://www.dgl.ai/). We walk through their installation below.
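Before installing anything, you can check which of these libraries are already importable in your environment (Colab, for instance, ships with PyTorch). The standard-library sketch below is only an illustration: the list of import names (`torch_geometric` for PyG, plus `torchnmf`, which this tutorial also installs) is our assumption of what is worth checking.

```python
import importlib.util

# Assumed import names for the deep learning dependencies (PyG is imported as `torch_geometric`)
DEEP_LEARNING_DEPS = ["torch", "torch_geometric", "dgl", "torchnmf"]


def missing_packages(names):
    """Return the subset of `names` that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(DEEP_LEARNING_DEPS)
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All deep learning dependencies are importable.")
```

`find_spec` only locates a package without importing it, so the check stays fast even for heavy libraries such as PyTorch.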

### 1.1. Install torch related dependencies

In [None]:
# Colab comes with torch installed, so we do not need to install pytorch here
# !pip3 install torch torchvision torchaudio

!pip install -q torch_geometric==2.3.1
!pip install -q dgl==1.1.0 -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install -q torchnmf==0.3.4

### 1.2 Install latest dev version of DANCE

In [None]:
!pip install -q git+https://github.com/OmicsML/dance.git


### 1.3 Check if DANCE is installed successfully

In [None]:
import dance
print(f"Installed DANCE version {dance.__version__}")

Installed DANCE version 1.0.0-rc.1


## 2. Data loading and processing

DANCE provides several benchmark datasets through a unified dataset object interface, which takes care of downloading, processing, and caching the data for users.
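This tutorial does not describe how DANCE derives its cache keys. The cached-load log further down only shows a file named by a 32-character hex digest (`/content/cache/<digest>.pkl`). As a hypothetical sketch of the idea, assuming the key is hashed from text descriptions of the dataset and the transform, the invented helpers `cache_path` and `load_with_cache` show why changing either the dataset options or the pipeline produces a new cache file.

```python
import hashlib
import os
import pickle


# Hypothetical sketch: NOT DANCE's implementation, only an illustration of content-addressed caching
def cache_path(dataset_repr: str, transform_repr: str, cache_dir: str = "cache") -> str:
    """Cache location derived from the MD5 hex digest of the dataset and transform descriptions."""
    key = hashlib.md5(f"{dataset_repr}\n{transform_repr}".encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{key}.pkl")


def load_with_cache(dataset_repr, transform_repr, build_fn, cache_dir="cache"):
    """Return the pickled result if it exists; otherwise call build_fn, pickle it, and return it."""
    path = cache_path(dataset_repr, transform_repr, cache_dir)
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    result = build_fn()
    os.makedirs(cache_dir, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result
```

Because the key depends on the full description, any change to a transform parameter yields a different file name, so stale cached data is never silently reused.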

### 2.1. Check available data options and load data object

In [None]:
from pprint import pprint
from dance.datasets.singlemodality import ClusteringDataset, ScDeepSortDataset

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [None]:
print("Available dataset option for ClusteringDataset:")
pprint(ClusteringDataset.get_avalilable_data())

Available dataset option for ClusteringDataset:
['10X_PBMC', 'mouse_ES_cell', 'mouse_bladder_cell', 'worm_neuron_cell']


In [None]:
print("\nAvailable dataset option for ScDeepSortDataset:")
pprint(ScDeepSortDataset.get_avalilable_data())


Available dataset option for ScDeepSortDataset:
[{'dataset': '3285', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '753', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '4682', 'species': 'mouse', 'split': 'train', 'tissue': 'Kidney'},
 {'dataset': '1970', 'species': 'mouse', 'split': 'train', 'tissue': 'Spleen'},
 {'dataset': '2695', 'species': 'mouse', 'split': 'test', 'tissue': 'Brain'},
 {'dataset': '203', 'species': 'mouse', 'split': 'test', 'tissue': 'Kidney'},
 {'dataset': '1759', 'species': 'mouse', 'split': 'test', 'tissue': 'Spleen'}]


#### Example: ClusteringDataset

In [None]:
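# The first positional argument is the data directory: this call stores files under
# ./10X_PBMC and, as the download log below shows, loads mouse_bladder_cell rather than
# the 10X PBMC data. Use ClusteringDataset("./data", "10X_PBMC") to load the PBMC data.
dataset = ClusteringDataset("10X_PBMC")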
print(dataset)

ClusteringDataset()


In [None]:
# The dataset object does not hold any data itself; the data are only loaded
# when load_data is called
data = dataset.load_data()
print(data)

[INFO][2023-06-26 00:10:59,672][dance][download_file] Downloading: 10X_PBMC/mouse_bladder_cell.h5 Bytes: 5,479,777
100%|██████████| 5.23M/5.23M [00:00<00:00, 23.3MB/s]
[INFO][2023-06-26 00:11:00,679][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-26 00:11:00,680][dance][wrapped_func] Took 0:00:02.167560 to load and process data.


Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'


#### Example: ScDeepSortDataset

In [None]:
dataset = ScDeepSortDataset(species="mouse", tissue="Brain",
                            train_dataset=["3285", "753"], test_dataset=["2695"])
data = dataset.load_data()
print(data)

[INFO][2023-06-26 00:15:17,272][dance][is_complete] ./train/mouse/mouse_Spleen1970_celltype.csv
[INFO][2023-06-26 00:15:17,274][dance][is_complete] file mouse_Spleen1970_celltype.csv doesn't exist
[INFO][2023-06-26 00:15:17,276][dance][is_complete] ./train/mouse/mouse_Spleen1970_celltype.csv
[INFO][2023-06-26 00:15:17,281][dance][is_complete] file mouse_Spleen1970_celltype.csv doesn't exist
[INFO][2023-06-26 00:15:18,392][dance][download_file] Downloading: ./train/mouse/mouse_Spleen1970_celltype.csv Bytes: 65,703
100%|██████████| 64.2k/64.2k [00:00<00:00, 15.6MB/s]
[INFO][2023-06-26 00:15:19,480][dance][download_file] Downloading: ./train/mouse/mouse_Spleen1970_data.csv Bytes: 80,382,707
100%|██████████| 76.7M/76.7M [00:03<00:00, 22.5MB/s]
[INFO][2023-06-26 00:15:23,816][dance][download_file] Downloading: ./test/mouse/mouse_Spleen1759_celltype.csv Bytes: 66,410
100%|██████████| 64.9k/64.9k [00:00<00:00, 13.2MB/s]
[INFO][2023-06-26 00:15:24,651][dance][download_file] Downloading: ./test

Data object that wraps (.data):
AnnData object with n_obs × n_vars = 6733 × 19856
    uns: 'dance_config'
    obsm: 'cell_type'


### 2.2. A quick primer on AnnData

<img
  src="https://raw.githubusercontent.com/scverse/anndata/main/docs/_static/img/anndata_schema.svg"
  align="right" width="450" alt="image"
/>

The [dance data object](https://github.com/OmicsML/dance/blob/912405cb5ab43caf16eb22b9216865c7e3976eaf/dance/data/base.py#L40) is built on top of [AnnData](https://anndata.readthedocs.io/en/latest/), a widely used data structure for representing, storing, and manipulating large annotated matrices.

> anndata is a Python package for handling annotated data matrices in memory and on disk, positioned between pandas and xarray. anndata offers a broad range of computationally efficient features...

AnnData is part of the scverse ecosystem, which makes single-cell data easy to handle with tools such as [Scanpy](https://scanpy.readthedocs.io/en/stable/).

The dance data object essentially wraps around an AnnData object,
which can be accessed in the `.data` attribute.

In [None]:
adata = data.data
print(adata)

AnnData object with n_obs × n_vars = 6733 × 19856
    uns: 'dance_config'
    obsm: 'cell_type'


In [None]:
num_cells, num_genes = adata.shape
print(f"There are {num_cells:,} cells and {num_genes:,} genes in this data object.")

There are 6,733 cells and 19,856 genes in this data object.


AnnData objects have several key attributes. For example, `.X` typically holds the main data matrix (such as gene expression, with cells as rows and genes as columns), while `.obs` and `.obsm` hold per-sample (i.e., per-cell) metadata.
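These attributes stay aligned: `.X` has one row per cell and one column per gene, `.obs` and every `.obsm` entry have one row per cell, and `.var` has one row per gene. The NumPy sketch below spells out those invariants on toy data; `check_alignment` is a hypothetical helper written for illustration, not part of the anndata API.

```python
import numpy as np


# Hypothetical helper: checks AnnData-style shape invariants on plain NumPy data
def check_alignment(X, obs_names, var_names, obsm):
    """Raise ValueError if the per-cell / per-gene annotations do not match X's shape."""
    n_obs, n_vars = X.shape
    if len(obs_names) != n_obs:
        raise ValueError("need one .obs entry (row) per cell")
    if len(var_names) != n_vars:
        raise ValueError("need one .var entry (row) per gene")
    for key, value in obsm.items():
        if value.shape[0] != n_obs:
            raise ValueError(f"obsm[{key!r}] must have one row per cell")
    return n_obs, n_vars


# Toy data: 3 cells x 4 genes, plus a one-hot cell-type matrix over 2 types
toy_X = np.array([[0, 1, 0, 2], [3, 0, 0, 0], [0, 0, 5, 1]], dtype=np.float32)
toy_obsm = {"cell_type": np.array([[1, 0], [0, 1], [1, 0]], dtype=np.float32)}
print(check_alignment(toy_X, ["c1", "c2", "c3"], ["g1", "g2", "g3", "g4"], toy_obsm))  # (3, 4)
```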

In [None]:
adata.X

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [None]:
adata.obsm["cell_type"]

| | Astrocyte | Astroglial cell | Granulocyte | Hypothalamic ependymal cell | Macrophage | Microglia | Myelinating oligodendrocyte | Neuron | Oligodendrocyte precursor cell | Pan-GABAergic | Schwann cell |
|---|---|---|---|---|---|---|---|---|---|---|---|
| mouse_Brain3285_C_1 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| mouse_Brain2695_C_2691 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain2695_C_2692 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain2695_C_2693 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain2695_C_2694 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
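Each row of `adata.obsm["cell_type"]` is a one-hot vector over the cell types. A minimal NumPy sketch for turning such rows back into label names (the helper name `one_hot_to_labels` is ours, for illustration only):

```python
import numpy as np


# Illustrative helper (not part of DANCE): decode one-hot rows into class names
def one_hot_to_labels(one_hot, class_names):
    """Return the class name of the single active (1.0) entry in each row of a one-hot matrix."""
    one_hot = np.asarray(one_hot)
    if not np.allclose(one_hot.sum(axis=1), 1.0):
        raise ValueError("each row must sum to 1 (exactly one active class)")
    return [class_names[i] for i in one_hot.argmax(axis=1)]


names = ["Astrocyte", "Myelinating oligodendrocyte", "Neuron"]
print(one_hot_to_labels([[0, 1, 0], [0, 0, 1]], names))
# ['Myelinating oligodendrocyte', 'Neuron']
```

If `adata.obsm["cell_type"]` is a pandas `DataFrame`, as the column headers above suggest, `adata.obsm["cell_type"].idxmax(axis=1)` returns the same mapping directly.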


### 2.3. Data pre-processing using transforms

Individual transformations can be applied to the data in place, one at a time:


In [None]:
import scanpy as sc

from dance.transforms import AnnDataTransform, FilterGenesPercentile

# Reload data
data = dataset.load_data()

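# Library size normalization: rescale each cell so its counts sum to target_sum
# (1e4 is the value commonly used in Scanpy workflows; 1e-4 rescales every cell to a tiny total)
AnnDataTransform(sc.pp.normalize_total, target_sum=1e-4)(data)

# Shifted log transformation, log(1 + x)
AnnDataTransform(sc.pp.log1p)(data)

# Filter out genes whose total expression (mode="sum") falls outside the
# 1st to 99th percentile range
FilterGenesPercentile(min_val=1, max_val=99, mode="sum")(data)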

[INFO][2023-06-26 00:20:11,401][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_data.csv
[INFO][2023-06-26 00:20:17,931][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_data.csv
[INFO][2023-06-26 00:20:20,479][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_data.csv
[INFO][2023-06-26 00:20:28,966][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_celltype.csv
[INFO][2023-06-26 00:20:28,977][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_celltype.csv
[INFO][2023-06-26 00:20:28,984][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_celltype.csv
[INFO][2023-06-26 00:20:32,095][dance][_load_raw_data] Loaded expression data: AnnData object with n_obs × n_vars = 6733 × 19856
[INFO][2023-06-26 00:20:32,096][dance][_load_raw_data] Number of training samples: 4,038
[INFO][2023-06-26 00:20:32,099][dance][_load_raw_data] Number of testing samples: 2,695
[INFO][2023-06-26 00:20:32,103][dance][_load
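To make the three steps above concrete, here is a NumPy sketch of what they compute on a toy count matrix. The normalization and log steps follow the documented behaviour of `sc.pp.normalize_total` and `sc.pp.log1p`; the percentile filter is our assumption of what `FilterGenesPercentile(mode="sum")` does (sum each gene's expression, keep genes between the 1st and 99th percentiles) and is not DANCE's code. The sketch uses `target_sum=1e4`.

```python
import numpy as np


def normalize_total(counts, target_sum):
    """Scale each cell (row) so its counts sum to target_sum; all-zero cells stay zero."""
    counts = np.asarray(counts, dtype=float)
    lib_size = counts.sum(axis=1, keepdims=True)
    lib_size[lib_size == 0] = 1.0  # avoid division by zero for empty cells
    return counts / lib_size * target_sum


# Assumed behaviour of a "sum"-mode percentile gene filter (hypothetical reimplementation)
def filter_genes_by_sum_percentile(X, min_pct=1, max_pct=99):
    """Keep genes (columns) whose total expression lies within the [min_pct, max_pct] percentiles."""
    gene_sums = X.sum(axis=0)
    low, high = np.percentile(gene_sums, [min_pct, max_pct])
    keep = (gene_sums >= low) & (gene_sums <= high)
    return X[:, keep], keep


toy_counts = np.array([[1, 3, 0, 6], [2, 2, 4, 2]])  # 2 cells x 4 genes
toy_X = np.log1p(normalize_total(toy_counts, target_sum=1e4))  # shifted log: log(1 + x)
toy_X_filtered, kept = filter_genes_by_sum_percentile(toy_X)
print(kept)  # [ True  True False False]
```

With only four genes, the 1st and 99th percentiles fall strictly between the smallest and largest gene sums, so the two extreme genes are dropped.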

Composing transformations into a preprocessing pipeline (feat. caching)
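Conceptually, composing transforms means calling each in-place step on the same data object, in order. The class below is a hypothetical stand-in written for illustration (it is not DANCE's `Compose`), using a plain dict as the "data".

```python
from typing import Callable, Dict, List

Transform = Callable[[Dict[str, float]], None]


# Hypothetical sketch of a transform pipeline; DANCE's Compose may differ
class ComposeSketch:
    """Apply each in-place transform to the data, in the order given."""

    def __init__(self, *transforms: Transform) -> None:
        self.transforms: List[Transform] = list(transforms)

    def __call__(self, data: Dict[str, float]) -> None:
        for transform in self.transforms:
            transform(data)  # each step mutates `data`; nothing is returned


def double(data: Dict[str, float]) -> None:
    data["x"] *= 2


def add_one(data: Dict[str, float]) -> None:
    data["x"] += 1


toy = {"x": 3.0}
ComposeSketch(double, add_one)(toy)
print(toy["x"])  # 7.0
```

Order matters: `ComposeSketch(add_one, double)` would give 8.0 instead, which is why normalization is listed before the log transform in the real pipeline.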

In [None]:
from dance.transforms import Compose

preprocessing_pipeline = Compose(
    AnnDataTransform(sc.pp.normalize_total, target_sum=1e-4),
    AnnDataTransform(sc.pp.log1p),
    FilterGenesPercentile(min_val=1, max_val=99, mode="sum"),
)

# Now we can apply the preprocessing pipeline transformation to our data
# data = dataset.load_data()
# preprocessing_pipeline(data)

# Alternatively, we can also pass the transformation to the loading function
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-26 00:21:37,779][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_data.csv
[INFO][2023-06-26 00:21:45,699][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_data.csv
[INFO][2023-06-26 00:21:47,627][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_data.csv
[INFO][2023-06-26 00:21:54,970][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_celltype.csv
[INFO][2023-06-26 00:21:54,987][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_celltype.csv
[INFO][2023-06-26 00:21:54,996][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_celltype.csv
[INFO][2023-06-26 00:21:58,969][dance][_load_raw_data] Loaded expression data: AnnData object with n_obs × n_vars = 6733 × 19856
[INFO][2023-06-26 00:21:58,970][dance][_load_raw_data] Number of training samples: 4,038
[INFO][2023-06-26 00:21:58,973][dance][_load_raw_data] Number of testing samples: 2,695
[INFO][2023-06-26 00:21:58,975][dance][_load

In [None]:
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-26 00:22:10,531][dance][_maybe_load_cache] Loading cached data at /content/cache/3b4e7591e2276386e7ba5ac8ab3d46c8.pkl
------------------------------Dataset object info-------------------------------
ScDeepSortDataset(species='mouse', tissue='Brain', train_dataset=['3285', '753'], test_dataset=['2695'])
------------------------------Transformation info-------------------------------
Compose(
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={'target_sum': 0.0001}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  FilterGenesPercentile(min_val=1, max_val=99, mode='sum'),
)
--------------------------------Loaded data info--------------------------------
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 6733 × 19657
    uns: 'dance_config', 'log1p'
    obsm: 'cell_type'
[INFO][2023-06-26 00:22:10,534][dance][load_data] Data loaded:
Data object that wraps (.data):
AnnData object with n_obs

## 3. Single modality tasks

### 3.1 Example: ACTINN for Cell Type Annotation

#### Visualization of annotation results

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/singlemodality/cell_type_visualization.png)

#### Load data

In [None]:
import numpy as np

from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.utils import set_seed

# Initialize model and get model specific preprocessing pipeline
model = ACTINN(hidden_dims=[256, 256], lambd=0.01, device='cuda')
preprocessing_pipeline = model.preprocessing_pipeline(normalize=True, filter_genes=True)

# Load data and perform necessary preprocessing
dataset = ScDeepSortDataset(species="mouse", tissue="Brain",
                            train_dataset=["3285", "753"], test_dataset=["2695"])
data = dataset.load_data(transform=preprocessing_pipeline)
print(data)

#### Train and evaluate model

In [None]:
# Obtain training and testing data
x_train, y_train = data.get_train_data(return_type="torch")
x_test, y_test = data.get_test_data(return_type="torch")

# Train and evaluate model
set_seed(10)
model.fit(x_train, y_train, lr=0.001, num_epochs=21,
          batch_size=1000, print_cost=True)
print(f"ACC: {model.score(x_test, y_test):.4f}")
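`model.score` reports accuracy (ACC). Assuming the test labels are one-hot rows like `adata.obsm["cell_type"]` and the model produces one score per cell type (both assumptions; the tutorial does not show either), accuracy is the fraction of cells whose highest-scoring type matches the true type. A NumPy sketch, not ACTINN's actual scoring code:

```python
import numpy as np


# Illustrative sketch; assumes one-hot labels and per-class scores (hypothetical inputs)
def accuracy_from_one_hot(scores, one_hot_labels):
    """Fraction of cells whose arg-max prediction matches the arg-max (one-hot) label."""
    predicted = np.argmax(np.asarray(scores), axis=1)
    actual = np.argmax(np.asarray(one_hot_labels), axis=1)
    return float(np.mean(predicted == actual))


toy_scores = [[2.0, 0.1, 0.3], [0.2, 0.5, 1.5], [0.9, 0.8, 0.1]]
toy_labels = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
print(accuracy_from_one_hot(toy_scores, toy_labels))  # 2 of 3 cells correct
```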

### 3.2 Example: GraphSCI for Imputation

#### Model structure

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/singlemodality/graphsci_visualization.png)

#### Reported results

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/singlemodality/imputation_results_example.png)

#### Load data

In [None]:
import torch

from dance.datasets.singlemodality import ImputationDataset
from dance.modules.single_modality.imputation.graphsci import GraphSCI
from dance.utils import set_seed

# Load data and perform preprocessing
set_seed(10)
dataloader = ImputationDataset(data_dir='./data', dataset='pbmc_data', train_size=0.9)
preprocessing_pipeline = GraphSCI.preprocessing_pipeline(mask=True, mask_rate=0.1)
data = dataloader.load_data(transform=preprocessing_pipeline)
print(data)

#### Train and evaluate model

In [None]:
# Obtain training and testing data
X, X_raw, g, mask = data.get_x(return_type="default")
device = 'cuda:0'
X = torch.tensor(X.toarray()).to(device)
X_raw = torch.tensor(X_raw.toarray()).to(device)
g = g.to(device)
train_idx = data.train_idx
test_idx = data.test_idx

# Train and evaluate model
model = GraphSCI(num_cells=X.shape[0], num_genes=X.shape[1],
                 dataset='pbmc_data', gpu=0)
model.fit(X, X_raw, g, train_idx, mask, n_epochs=10, la=1e-7)
model.load_model()
imputed_data = model.predict(X, X_raw, g, mask)
score = model.score(X_raw, imputed_data, test_idx, mask, metric='RMSE')
print("RMSE: %.4f" % score)
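The `mask=True, mask_rate=0.1` options suggest that a fraction of observed entries is hidden before training and that the RMSE is computed on those hidden entries for the test cells. How GraphSCI encodes `mask` is not shown here, so this NumPy sketch takes an explicit boolean `held_out` array (True marks a hidden entry, our assumption) and an optional set of rows:

```python
import numpy as np


# Illustrative sketch: `held_out` semantics (True = hidden entry) are an assumption
def masked_rmse(true, pred, held_out, rows=None):
    """RMSE over held-out entries only, optionally restricted to a subset of rows (cells)."""
    true, pred = np.asarray(true, dtype=float), np.asarray(pred, dtype=float)
    held_out = np.asarray(held_out, dtype=bool)
    if rows is not None:
        true, pred, held_out = true[rows], pred[rows], held_out[rows]
    diff = (true - pred)[held_out]
    return float(np.sqrt(np.mean(diff ** 2)))


toy_true = [[1.0, 2.0], [3.0, 4.0]]
toy_pred = [[1.0, 0.0], [3.0, 1.0]]
toy_held_out = [[False, True], [False, True]]
print(masked_rmse(toy_true, toy_pred, toy_held_out))            # sqrt((2**2 + 3**2) / 2)
print(masked_rmse(toy_true, toy_pred, toy_held_out, rows=[1]))  # 3.0
```

Scoring only hidden entries matters: observed entries were visible to the model during training, so including them would overstate imputation quality.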

### 3.3 Example: scDeepCluster for Clustering

#### Model structure

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/singlemodality/scdeepcluster_visualization.png)

#### Reported results

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/singlemodality/clustering_results_example.png)

#### Load data

In [None]:
import numpy as np

from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdeepcluster import ScDeepCluster
from dance.utils import set_seed


# Load data and perform necessary preprocessing
dataloader = ClusteringDataset('./data', '10X_PBMC')
preprocessing_pipeline = ScDeepCluster.preprocessing_pipeline()
data = dataloader.load_data(transform=preprocessing_pipeline)
print(data)

#### Train and evaluate model

In [None]:
# inputs: x, x_raw, n_counts
inputs, y = data.get_train_data()
n_clusters = len(np.unique(y))
in_dim = inputs[0].shape[1]

# Build and train model
set_seed(10)
model = ScDeepCluster(input_dim=in_dim, z_dim=32, encodeLayer=[256, 64], decodeLayer=[64, 256], device='cuda')
model.fit(inputs, y, n_clusters=n_clusters, lr=0.01, epochs=3, pt_epochs=3)

# Evaluate model predictions
score = model.score(None, y)
print(f"ARI: {score:.4f}")
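The reported score is the adjusted Rand index (ARI), which measures agreement between two clusterings corrected for chance: 1.0 means identical partitions up to relabeling, and values near 0 mean chance-level agreement. A standard-library sketch of the standard formula (scikit-learn provides the same metric as `sklearn.metrics.adjusted_rand_score`; we do not claim `model.score` uses either):

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index computed from pair counts of the contingency table."""
    n = len(labels_true)
    total = comb(n, 2)
    if total == 0:
        return 1.0  # fewer than two samples: trivially identical partitions
    sum_comb = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    sum_a = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_b = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_a * sum_b / total  # expected index under random labeling
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        return 1.0  # degenerate case, e.g. both partitions are a single cluster
    return (sum_comb - expected) / (max_index - expected)


print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0 (same partition, labels swapped)
print(adjusted_rand_index([0, 0, 1, 1], [0, 0, 1, 2]))  # about 0.571
```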

## 4. Multi-modality tasks

### 4.1 Modality Prediction

#### Task and Model Description

Modality prediction aims to predict one modality from another, following the flow of information from DNA to RNA and from RNA to protein.

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/multimodality/modality_prediction_visualization.svg)

In this section, we take RNA-to-protein prediction as an example task, using data obtained with CITE-seq. We use the BABEL [1] model to demonstrate the DANCE workflow.

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/multimodality/babel_visualization.jpeg)

[1] Wu, Kevin E., et al. "BABEL enables cross-modality translation between multiomic profiles at single-cell resolution." Proceedings of the National Academy of Sciences 118.15 (2021): e2023070118.

#### Argument parsing setting

In [None]:
import argparse
import os
import random

import anndata
import mudata
import scanpy as sc
import torch
from scipy.sparse import csr_matrix
from sklearn.decomposition import TruncatedSVD

from dance import logger
from dance.data import Data
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.babel import BabelWrapper
from dance.utils import set_seed

OPTIMIZER_DICT = {
    "adam": torch.optim.Adam,
    "rmsprop": torch.optim.RMSprop,
}
rndseed = random.randint(0, 2147483647)
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--subtask", default="openproblems_bmmc_cite_phase2_rna")
parser.add_argument("-device", "--device", default="cuda")
parser.add_argument("-cpu", "--cpus", default=1, type=int)
parser.add_argument("-seed", "--rnd_seed", default=rndseed, type=int)
parser.add_argument("--lossweight", type=float, default=1., help="Relative loss weight")
parser.add_argument("--lr", "-l", type=float, default=0.01, help="Learning rate")
parser.add_argument("--batchsize", "-b", type=int, default=64, help="Batch size")
parser.add_argument("--hidden", type=int, default=64, help="Hidden dimensions")
parser.add_argument("--earlystop", type=int, default=20, help="Early stopping after N epochs")
parser.add_argument("--naive", "-n", action="store_true", help="Use a naive model instead of lego model")

parser.add_argument("-m", "--model_folder", default="./")
parser.add_argument("--outdir", "-o", default="./", help="Directory to output to")
parser.add_argument("--resume", action="store_true")
parser.add_argument("--max_epochs", type=int, default=500)
args_defaults = parser.parse_args([])
args = argparse.Namespace(**vars(args_defaults))

rndseed = args.rnd_seed
set_seed(rndseed)
device = args.device
os.makedirs(args.model_folder, exist_ok=True)
args

#### Load data and perform necessary preprocessing

In [None]:
dataset = ModalityPredictionDataset(args.subtask)
data = dataset.load_data()
data

In [None]:
# Construct data object
data.set_config(feature_mod="mod1", label_mod="mod2")

# Obtain training and testing data
x_train, y_train = data.get_train_data(return_type="torch")

# Colab has limited memory (about 12.7 GB), so we keep only the first 10,000 training cells
x_train = x_train[:10000]
y_train = y_train[:10000]

x_test, y_test = data.get_test_data(return_type="torch")
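The slice `x_train[:10000]` keeps the first 10,000 rows. If the training cells are grouped, for example by batch or donor, that subset may not be representative. A seeded random subset is one possible alternative; `random_subsample_indices` below is a hypothetical helper, not part of DANCE.

```python
import random


# Hypothetical helper for reproducible random subsampling of row indices
def random_subsample_indices(n_total, n_keep, seed=0):
    """Draw n_keep distinct row indices uniformly at random, sorted, reproducibly via `seed`."""
    if n_keep >= n_total:
        return list(range(n_total))
    rng = random.Random(seed)
    return sorted(rng.sample(range(n_total), n_keep))
```

With PyTorch tensors, one way to apply it is `idx = torch.tensor(random_subsample_indices(len(x_train), 10000, seed=rndseed))` followed by `x_train, y_train = x_train[idx], y_train[idx]`, so that features and labels stay paired.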

#### Initialize model

In [None]:
model = BabelWrapper(args, dim_in=x_train.shape[1], dim_out=y_train.shape[1])

#### Train and evaluate model

In [None]:
model.fit(x_train.float(), y_train.float(), val_ratio=0.15)

In [None]:
model.predict(x_test.float())

In [None]:
model.score(x_test.float(), y_test.float())

## 5. Spatial

### 5.1 Spatial Domain

#### SpaGCN model for spatial domain identification

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/spatial_domain/SpaGCN_framework.png)

#### Argument parsing setting

In [None]:
import argparse

from dance.datasets.spatial import SpatialLIBDDataset
from dance.modules.spatial.spatial_domain.spagcn import SpaGCN, refine
from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import SpaGCNGraph, SpaGCNGraph2D
from dance.utils import set_seed

parser = argparse.ArgumentParser()
# Absorb the "-f <connection file>" argument that Jupyter passes to the kernel
parser.add_argument("-f", required=False)

# SpaGCN model-specific parameters to be tuned
parser.add_argument("--beta", type=int, default=49, help="")
parser.add_argument("--alpha", type=int, default=1, help="")
parser.add_argument("--p", type=float, default=0.05,
                    help="percentage of total expression contributed by neighborhoods.")
parser.add_argument("--l", type=float, default=0.5, help="the parameter to control percentage p.")
parser.add_argument("--start", type=float, default=0.01, help="starting value for searching l.")
parser.add_argument("--end", type=float, default=1000, help="ending value for searching l.")
parser.add_argument("--tol", type=float, default=5e-3, help="tolerance for searching l.")


# Generic deep learning parameters to be tuned
parser.add_argument("--cache", action="store_true", help="Cache processed data.")
parser.add_argument("--sample_number", type=str, default="151673",
                    help="Sample ID, one of the 12 human dorsolateral prefrontal cortex (DLPFC) samples for the spatial domain task.")
parser.add_argument("--max_run", type=int, default=200, help="max runs.")
parser.add_argument("--epochs", type=int, default=200, help="Number of epochs.")
parser.add_argument("--n_clusters", type=int, default=7, help="the number of clusters")
parser.add_argument("--step", type=float, default=0.1, help="")
parser.add_argument("--lr", type=float, default=0.05, help="learning rate")
parser.add_argument("--random_state", type=int, default=100, help="")
args = parser.parse_args()
set_seed(args.random_state)

#### Initialize model and get model specific preprocessing *pipeline*

In [None]:
model = SpaGCN()
# In SpaGCN, alpha and beta are used for graph construction
preprocessing_pipeline = model.preprocessing_pipeline(alpha=args.alpha, beta=args.beta)

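#### User-defined custom transforms

Alternatively, you can compose your own preprocessing pipeline from individual transforms; the pipeline below replaces the one returned by `model.preprocessing_pipeline` above.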

In [None]:
preprocessing_pipeline = Compose(
    FilterGenesMatch(prefixes=["ERCC", "MT-"]),
    SpaGCNGraph(alpha=1, beta=49),
    SpaGCNGraph2D(),
    CellPCA(n_components=40),
    SetConfig({
        "feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
        "feature_channel_type": ["obsm", "obsp", "obsp"],
        "label_channel": "label",
        "label_channel_type": "obs"
    }),
)

#### Load data and perform necessary preprocessing

In [None]:
dataloader = SpatialLIBDDataset(data_id=args.sample_number)
data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)
(x, adj, adj_2d), y = data.get_train_data()

In [None]:
data

In [None]:
data.data.obsp["SpaGCNGraph"].shape

In [None]:
data.x[0].shape

In [None]:
data.x[1].shape

#### Train and evaluate model

In [None]:
l = model.search_l(args.p, adj, start=args.start, end=args.end, tol=args.tol,
                   max_run=args.max_run)
model.set_l(l)
res = model.search_set_res((x, adj), l=l, target_num=args.n_clusters, start=0.4,
                           step=args.step, tol=args.tol, lr=args.lr,
                           epochs=args.epochs, max_run=args.max_run)
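`search_l` looks for the length scale `l` at which neighboring spots contribute a fraction `p` of the total expression, searching between `start` and `end` until the result is within `tol`. The standard-library sketch below illustrates that idea with bisection, assuming a Gaussian weighting exp(-d^2 / (2 l^2)) between spots as described for SpaGCN. The function names are hypothetical and DANCE's implementation may differ.

```python
import math


# Hypothetical sketch modeled on SpaGCN's Gaussian neighbor weighting
def neighbor_contribution(dist, l):
    """Mean over spots of the summed weights exp(-d^2 / (2 l^2)) to all *other* spots."""
    n = len(dist)
    total = 0.0
    for i in range(n):
        total += sum(math.exp(-dist[i][j] ** 2 / (2 * l ** 2)) for j in range(n) if j != i)
    return total / n


def bisect_l(p, dist, start=0.01, end=1000.0, tol=5e-3, max_run=100):
    """Bisection for l such that neighbor_contribution(dist, l) is within tol of p."""
    lo, hi = start, end
    for _ in range(max_run):
        mid = (lo + hi) / 2
        value = neighbor_contribution(dist, mid)
        if abs(value - p) <= tol:
            return mid
        if value < p:
            lo = mid  # the contribution grows with l, so search larger l
        else:
            hi = mid
    return None  # no l found within max_run iterations


toy_dist = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]  # pairwise spot distances
toy_l = bisect_l(0.5, toy_dist)
print(toy_l, neighbor_contribution(toy_dist, toy_l))
```

Bisection works here because the contribution increases monotonically with `l`, which is what lets a simple `start`/`end`/`tol` search converge.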

In [None]:
pred = model.fit_predict((x, adj), init_spa=True, init="louvain", tol=args.tol,
                         lr=args.lr, epochs=args.epochs, res=res)

In [None]:
score = model.default_score_func(y, pred)
print(f"ARI: {score:.4f}")

### 5.2 Cell Type Deconvolution

#### DSTG model for cell type deconvolution

![image](https://github.com/OmicsML/dance-tutorials/raw/dev/imgs/tutorial_v1/cell_type_deconvo/DSTG_framework.png)

#### Argument parsing setting

In [None]:
import argparse
from pprint import pprint
import numpy as np
import torch
from dance.datasets.spatial import CellTypeDeconvoDataset
from dance.modules.spatial.cell_type_deconvo import DSTG
from dance.utils import set_seed

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Absorb the "-f <connection file>" argument that Jupyter passes to the kernel
parser.add_argument("-f", required=False)


parser.add_argument("--cache", action="store_true", help="Cache processed data.")
parser.add_argument("--dataset", default="CARD_synthetic")
parser.add_argument("--datadir", default="data/spatial", help="Directory to save the data.")
parser.add_argument("--sc_ref", type=bool, default=True, help="Reference scRNA (True) or cell-mixtures (False).")
parser.add_argument("--num_pseudo", type=int, default=500, help="Number of pseudo mixtures to generate.")
parser.add_argument("--n_hvg", type=int, default=2000, help="Number of HVGs.")
parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate.")
parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay.")
parser.add_argument("--k_filter", type=int, default=200, help="Graph node filter.")
parser.add_argument("--num_cc", type=int, default=30, help="Dimension of canonical correlation analysis.")
parser.add_argument("--bias", type=bool, default=False, help="Include/Exclude bias term.")
parser.add_argument("--nhid", type=int, default=16, help="Number of neurons in latent layer.")
parser.add_argument("--dropout", type=float, default=0., help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=25, help="Number of epochs to train the model.")
parser.add_argument("--seed", type=int, default=17, help="Random seed.")
parser.add_argument("--device", default="auto", help="Computation device.")
args = parser.parse_args()
set_seed(args.seed)


#### Get model specific preprocessing *pipeline*

In [None]:
preprocessing_pipeline = DSTG.preprocessing_pipeline(
    n_pseudo=args.num_pseudo,
    n_top_genes=args.n_hvg,
    k_filter=args.k_filter,
    num_cc=args.num_cc,
)
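DSTG trains on pseudo-spots: synthetic mixtures built from a single-cell reference, whose known cell-type proportions serve as labels (`--num_pseudo` sets how many are generated). The standard-library sketch below shows the general idea; the fixed number of cells per spot and the uniform sampling are our assumptions, not DSTG's documented procedure, and `make_pseudo_spot` is a hypothetical name.

```python
import random
from collections import Counter


# Hypothetical sketch of pseudo-spot generation from a labeled single-cell reference
def make_pseudo_spot(expr, cell_types, n_cells, rng):
    """Sum the expression of n_cells reference cells drawn without replacement.

    Returns the mixed expression vector and the cell-type proportions used as its label.
    """
    picked = rng.sample(range(len(expr)), n_cells)
    n_genes = len(expr[0])
    spot = [sum(expr[i][g] for i in picked) for g in range(n_genes)]
    counts = Counter(cell_types[i] for i in picked)
    proportions = {t: counts[t] / n_cells for t in sorted(set(cell_types))}
    return spot, proportions


toy_expr = [[1, 0], [0, 2], [3, 1], [2, 2]]  # 4 reference cells x 2 genes
toy_types = ["B cell", "T cell", "T cell", "NK cell"]
toy_rng = random.Random(0)
pseudo_spots = [make_pseudo_spot(toy_expr, toy_types, 3, toy_rng) for _ in range(5)]
print(pseudo_spots[0])
```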

#### Load data and perform necessary preprocessing

In [None]:
dataset = CellTypeDeconvoDataset(data_dir=args.datadir, data_id=args.dataset)
data = dataset.load_data(transform=preprocessing_pipeline, cache=args.cache)

In [None]:
data.x

In [None]:
len(data.x)

In [None]:
data.x[0].shape

In [None]:
data.x[1].shape

In [None]:
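(adj, x), y = data.get_data(return_type="default")
x, y = torch.FloatTensor(x), torch.FloatTensor(y.values)
# Build a sparse COO adjacency tensor (the legacy torch.sparse.FloatTensor constructor is deprecated)
indices = torch.tensor(np.vstack([adj.row, adj.col]), dtype=torch.long)
values = torch.tensor(adj.data.astype(np.int32), dtype=torch.float32)
adj = torch.sparse_coo_tensor(indices, values, size=adj.shape)
train_mask = data.get_split_mask("pseudo", return_type="torch")
inputs = (adj, x, train_mask)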

#### Train and evaluate model

In [None]:
model = DSTG(nhid=args.nhid, bias=args.bias, dropout=args.dropout, device=args.device)
pred = model.fit_predict(inputs, y, lr=args.lr, max_epochs=args.epochs, weight_decay=args.wd)
test_mask = data.get_split_mask("test", return_type="torch")
score = model.default_score_func(y[test_mask], pred[test_mask])
print(f"MSE: {score:7.4f}")