<a href="https://colab.research.google.com/github/OmicsML/dance-tutorials/blob/dev/dance_tutorial_full_run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Installation

DANCE is published on [PyPI](https://pypi.org/project/pydance/) as `pydance`, so installing it is as easy as

```bash
pip install pydance
```

Alternatively, install the latest development version from GitHub:

```bash
pip install git+https://github.com/OmicsML/dance
```

Because DANCE includes many deep-learning-based methods, it also depends on deep learning libraries such as [PyTorch](https://pytorch.org/), [PyG](https://www.pyg.org/), and [DGL](https://www.dgl.ai/). We walk through installing them below.

### 1.1. Install torch-related dependencies

In [1]:
# Colab comes with torch installed, so we do not need to install pytorch here
# !pip3 install torch torchvision torchaudio

!pip install -q torch_geometric==2.3.1
!pip install -q dgl==1.1.0 -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install -q torchnmf==0.3.4


### 1.2 Install latest dev version of DANCE

In [2]:
!pip install -q git+https://github.com/OmicsML/dance.git


### 1.3 Check if DANCE is installed successfully

In [3]:
import dance
print(f"Installed DANCE version {dance.__version__}")

Installed DANCE version 1.0.0-rc.1
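Because DANCE relies on several deep learning libraries, it can help to confirm which versions actually ended up installed. Below is a minimal sketch using the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper (not part of DANCE), and the distribution names are assumed from the install commands above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in installed_versions(["pydance", "torch", "torch_geometric", "dgl", "torchnmf"]).items():
    print(f"{name:>16}: {ver or 'not installed'}")
```

A package reported as not installed may simply be registered under a slightly different distribution name, so treat this as a quick check rather than a definitive test.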


## 2. Data loading and processing

DANCE provides several benchmark datasets through a unified dataset object interface, which takes care of downloading, processing, and caching the data for users.

### 2.1. Check available data options and load data object

In [4]:
import os
os.environ["DGLBACKEND"] = "pytorch"
from pprint import pprint
from dance.datasets.singlemodality import ClusteringDataset, ScDeepSortDataset

In [5]:
print("Available dataset option for ClusteringDataset:")
pprint(ClusteringDataset.get_avalilable_data())

Available dataset option for ClusteringDataset:
['10X_PBMC', 'mouse_ES_cell', 'mouse_bladder_cell', 'worm_neuron_cell']


In [6]:
print("Available dataset option for ScDeepSortDataset:")
pprint(ScDeepSortDataset.get_avalilable_data())

Available dataset option for ScDeepSortDataset:
[{'dataset': '3285', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '753', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '4682', 'species': 'mouse', 'split': 'train', 'tissue': 'Kidney'},
 {'dataset': '1970', 'species': 'mouse', 'split': 'train', 'tissue': 'Spleen'},
 {'dataset': '2695', 'species': 'mouse', 'split': 'test', 'tissue': 'Brain'},
 {'dataset': '203', 'species': 'mouse', 'split': 'test', 'tissue': 'Kidney'},
 {'dataset': '1759', 'species': 'mouse', 'split': 'test', 'tissue': 'Spleen'}]
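The option list is a plain list of dictionaries, so the dataset IDs to pass as `train_dataset` and `test_dataset` can be selected programmatically. Below is a sketch with a hypothetical helper `dataset_ids`; the `options` literal copies the output above, whereas in the notebook you would call `ScDeepSortDataset.get_avalilable_data()` directly.

```python
# Copy of the ScDeepSortDataset options printed above
options = [
    {"dataset": "3285", "species": "mouse", "split": "train", "tissue": "Brain"},
    {"dataset": "753", "species": "mouse", "split": "train", "tissue": "Brain"},
    {"dataset": "4682", "species": "mouse", "split": "train", "tissue": "Kidney"},
    {"dataset": "1970", "species": "mouse", "split": "train", "tissue": "Spleen"},
    {"dataset": "2695", "species": "mouse", "split": "test", "tissue": "Brain"},
    {"dataset": "203", "species": "mouse", "split": "test", "tissue": "Kidney"},
    {"dataset": "1759", "species": "mouse", "split": "test", "tissue": "Spleen"},
]


def dataset_ids(options, species, tissue, split):
    """Return the dataset IDs matching a species, tissue, and split (hypothetical helper)."""
    return [
        opt["dataset"]
        for opt in options
        if opt["species"] == species and opt["tissue"] == tissue and opt["split"] == split
    ]


train_ids = dataset_ids(options, "mouse", "Brain", "train")
test_ids = dataset_ids(options, "mouse", "Brain", "test")
print(train_ids, test_ids)  # ['3285', '753'] ['2695']
```

These lists match the `train_dataset=["3285", "753"]` and `test_dataset=["2695"]` arguments used in the ScDeepSortDataset example below.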


#### Example: ClusteringDataset

In [7]:
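# ClusteringDataset's first positional argument is data_dir, so pass the
# dataset name by keyword; ClusteringDataset("10X_PBMC") would instead use
# "10X_PBMC" as the data directory and load the default mouse_bladder_cell data
dataset = ClusteringDataset(data_dir="./data", dataset="10X_PBMC")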
print(dataset)

ClusteringDataset()


In [8]:
# The dataset object does not hold any data itself; the data are only
# loaded when load_data is called
data = dataset.load_data()

[INFO][2023-06-28 04:40:58,947][dance][download_file] Downloading: 10X_PBMC/mouse_bladder_cell.h5 Bytes: 5,479,777
100%|██████████| 5.23M/5.23M [00:01<00:00, 3.08MB/s]
[INFO][2023-06-28 04:41:01,474][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 2746 × 20670
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-28 04:41:01,476][dance][wrapped_func] Took 0:00:05.533764 to load and process data.


#### Example: ScDeepSortDataset

In [9]:
dataset = ScDeepSortDataset(species="mouse", tissue="Brain",
                            train_dataset=["3285", "753"], test_dataset=["2695"])
data = dataset.load_data()

[INFO][2023-06-28 04:41:01,490][dance][is_complete] ./train/mouse/mouse_Spleen1970_celltype.csv
[INFO][2023-06-28 04:41:01,492][dance][is_complete] file mouse_Spleen1970_celltype.csv doesn't exist
[INFO][2023-06-28 04:41:01,494][dance][is_complete] ./train/mouse/mouse_Spleen1970_celltype.csv
[INFO][2023-06-28 04:41:01,497][dance][is_complete] file mouse_Spleen1970_celltype.csv doesn't exist
[INFO][2023-06-28 04:41:03,209][dance][download_file] Downloading: ./train/mouse/mouse_Spleen1970_celltype.csv Bytes: 65,703
100%|██████████| 64.2k/64.2k [00:00<00:00, 13.7MB/s]
[INFO][2023-06-28 04:41:05,008][dance][download_file] Downloading: ./train/mouse/mouse_Spleen1970_data.csv Bytes: 80,382,707
100%|██████████| 76.7M/76.7M [00:04<00:00, 16.9MB/s]
[INFO][2023-06-28 04:41:11,312][dance][download_file] Downloading: ./test/mouse/mouse_Spleen1759_celltype.csv Bytes: 66,410
100%|██████████| 64.9k/64.9k [00:00<00:00, 27.2MB/s]
[INFO][2023-06-28 04:41:14,056][dance][download_file] Downloading: ./test

### 2.2. A quick primer on AnnData

<img
  src="https://raw.githubusercontent.com/scverse/anndata/main/docs/_static/img/anndata_schema.svg"
  align="right" width="450" alt="image"
/>

The [dance data object](https://github.com/OmicsML/dance/blob/912405cb5ab43caf16eb22b9216865c7e3976eaf/dance/data/base.py#L40) is built on top of [AnnData](https://anndata.readthedocs.io/en/latest/), a widely used data object for representing, storing, and manipulating large annotated matrices.

> anndata is a Python package for handling annotated data matrices in memory and on disk, positioned between pandas and xarray. anndata offers a broad range of computationally efficient features...

AnnData is part of the scverse ecosystem, so single-cell data stored in it can be analyzed directly with tools such as [Scanpy](https://scanpy.readthedocs.io/en/stable/).

The dance data object essentially wraps around an AnnData object,
which can be accessed in the `.data` attribute.

In [10]:
adata = data.data
print(adata)

AnnData object with n_obs × n_vars = 6733 × 19856
    uns: 'dance_config'
    obsm: 'cell_type'


In [11]:
num_cells, num_genes = adata.shape
print(f"There are {num_cells:,} cells and {num_genes:,} genes in this data object.")

There are 6,733 cells and 19,856 genes in this data object.


AnnData objects have several key attributes. For example, `.X` typically holds the main data matrix (cells × genes), such as gene expression, while `.obs` and `.obsm` hold metadata for each observation (i.e., each cell).
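In this dataset, `.obsm["cell_type"]` stores the labels one-hot encoded, with one column per cell type (see the table below). Below is a minimal NumPy sketch of turning such a matrix into one label per cell; the three cell types and four cells are toy values, not taken from the data.

```python
import numpy as np

# Toy stand-in for adata.obsm["cell_type"]: rows are cells, columns are cell types
cell_types = np.array(["Astrocyte", "Microglia", "Neuron"])
one_hot = np.array([
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

label_idx = one_hot.argmax(axis=1)  # column index of the 1.0 in each row
labels = cell_types[label_idx]      # cell type name for each cell

names, counts = np.unique(labels, return_counts=True)
print(dict(zip(names.tolist(), counts.tolist())))  # {'Astrocyte': 1, 'Microglia': 1, 'Neuron': 2}
```

On the real DataFrame, pandas' `adata.obsm["cell_type"].idxmax(axis=1)` should give the same per-cell labels, assuming each row has exactly one nonzero entry.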

In [12]:
adata.X

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [13]:
adata.obsm["cell_type"]

| | Astrocyte | Astroglial cell | Granulocyte | Hypothalamic ependymal cell | Macrophage | Microglia | Myelinating oligodendrocyte | Neuron | Oligodendrocyte precursor cell | Pan-GABAergic | Schwann cell |
|---|---|---|---|---|---|---|---|---|---|---|---|
| mouse_Brain3285_C_1 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain3285_C_5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| mouse_Brain2695_C_2691 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain2695_C_2692 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain2695_C_2693 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| mouse_Brain2695_C_2694 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |


### 2.3. Data pre-processing using transforms

Each transform can be applied individually to a data object, which it modifies in place.


In [14]:
import scanpy as sc
from dance.transforms import AnnDataTransform, FilterGenesPercentile

In [15]:
print(f"Library sizes before normalization: {data.data.X.sum(1).round(0)}")

# Library size normalization
AnnDataTransform(sc.pp.normalize_total, target_sum=1e4)(data)

print(f"Library sizes after normalization: {data.data.X.sum(1).round(0)}")

Library sizes before normalization: [1357. 1631. 1324. ... 2989. 3587. 3037.]
Library sizes after normalization: [10000. 10000. 10000. ... 10000. 10000. 10000.]


In [16]:
# Shifted log transformation
AnnDataTransform(sc.pp.log1p)(data)

print(f"Sum of expression per cell after log1p transformation: {data.data.X.sum(1)}")

Sum of expression per cell after log1p transformation: [1402.2233 1705.7186 1353.9008 ... 3424.3523 4250.343  3506.1055]
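The two transforms above wrap Scanpy's `sc.pp.normalize_total` and `sc.pp.log1p`. Their arithmetic on a dense matrix can be sketched in NumPy as follows; this simplified version assumes every cell has at least one count and ignores the extra options Scanpy supports. It also shows why the per-cell sums printed after `log1p` are no longer 10,000: the logarithm is applied to each entry, not to the row total.

```python
import numpy as np


def normalize_total(X, target_sum=1e4):
    """Scale each row (cell) so that its counts sum to target_sum (assumes no empty rows)."""
    library_size = X.sum(axis=1, keepdims=True)
    return X / library_size * target_sum


counts = np.array([
    [1.0, 3.0, 0.0],
    [10.0, 0.0, 10.0],
])
normalized = normalize_total(counts)  # [[2500, 7500, 0], [5000, 0, 5000]]
logged = np.log1p(normalized)         # log(1 + x) applied entry by entry

print(normalized.sum(axis=1))  # [10000. 10000.]
print(logged.sum(axis=1))      # no longer 10,000
```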


In [17]:
print(f"Number of genes before filtering: {data.shape[1]:,}")

# Filter out genes whose total expression (summed over cells) falls outside
# the 1st to 99th percentile range
FilterGenesPercentile(min_val=1, max_val=99, mode="sum")(data)

print(f"Number of genes after filtering: {data.shape[1]:,}")

Number of genes before filtering: 19,856
Number of genes after filtering: 19,657
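With `mode="sum"`, `FilterGenesPercentile` summarizes each gene by its total expression across cells and keeps the genes between the `min_val`-th and `max_val`-th percentiles. The sketch below reproduces that idea with `np.percentile`. The assumptions here are that the bounds are inclusive and that ties are handled this way; DANCE's exact behavior may differ. `filter_genes_percentile` is a hypothetical name.

```python
import numpy as np


def filter_genes_percentile(X, min_val=1, max_val=99):
    """Keep genes (columns) whose total expression lies within the given percentile range."""
    gene_sums = X.sum(axis=0)
    lo, hi = np.percentile(gene_sums, [min_val, max_val])
    keep = (gene_sums >= lo) & (gene_sums <= hi)  # inclusive bounds (an assumption)
    return X[:, keep], keep


# 2 cells x 100 genes; gene j has total expression 2 * (j + 1)
X = np.tile(np.arange(1, 101, dtype=float), (2, 1))
X_filtered, keep = filter_genes_percentile(X, min_val=1, max_val=99)
print(X_filtered.shape)  # (2, 98): the lowest- and highest-sum genes are dropped
```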


Composing transformations into a preprocessing pipeline (with caching)
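`Compose` chains transforms so that the whole preprocessing recipe can be passed around, and cached, as a single object. Conceptually it applies each transform to the same data object in order, as in this toy sketch. `ComposeSketch` is a hypothetical stand-in that operates on a plain dictionary, not DANCE's implementation.

```python
class ComposeSketch:
    """Apply in-place transforms one after another (toy stand-in for dance.transforms.Compose)."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, data):
        for transform in self.transforms:
            transform(data)  # each transform mutates `data` in place
        return data


def add_one(data):
    data["x"] += 1


def double(data):
    data["x"] *= 2


state = {"x": 3}
ComposeSketch(add_one, double)(state)
print(state["x"])  # 8, because (3 + 1) * 2
```

Order matters: `ComposeSketch(double, add_one)` would give 7. The same holds for the pipeline below, where normalization runs before `log1p` and gene filtering.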

In [18]:
from dance.transforms import Compose

preprocessing_pipeline = Compose(
    AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
    AnnDataTransform(sc.pp.log1p),
    FilterGenesPercentile(min_val=1, max_val=99, mode="sum"),
)

# Now we can apply the preprocessing pipeline to our data
# data = dataset.load_data()
# preprocessing_pipeline(data)

# Alternatively, we can pass the transformation directly to the loading function
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-28 04:43:01,621][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_data.csv
[INFO][2023-06-28 04:43:09,244][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_data.csv
[INFO][2023-06-28 04:43:11,265][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_data.csv
[INFO][2023-06-28 04:43:17,674][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_celltype.csv
[INFO][2023-06-28 04:43:17,686][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_celltype.csv
[INFO][2023-06-28 04:43:17,691][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_celltype.csv
[INFO][2023-06-28 04:43:21,274][dance][_load_raw_data] Loaded expression data: AnnData object with n_obs × n_vars = 6733 × 19856
[INFO][2023-06-28 04:43:21,276][dance][_load_raw_data] Number of training samples: 4,038
[INFO][2023-06-28 04:43:21,278][dance][_load_raw_data] Number of testing samples: 2,695
[INFO][2023-06-28 04:43:21,280][dance][_load

In [19]:
# Reloading the data with the same transformation and cache=True reads the
# preprocessed data from the cache instead of recomputing it. This makes it
# cheap to rerun evaluations with different model configurations on the
# same preprocessed data.
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-28 04:43:34,439][dance][_maybe_load_cache] Loading cached data at /content/cache/3b4e7591e2276386e7ba5ac8ab3d46c8.pkl
------------------------------Dataset object info-------------------------------
ScDeepSortDataset(species='mouse', tissue='Brain', train_dataset=['3285', '753'], test_dataset=['2695'])
------------------------------Transformation info-------------------------------
Compose(
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={'target_sum': 10000.0}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  FilterGenesPercentile(min_val=1, max_val=99, mode='sum'),
)
--------------------------------Loaded data info--------------------------------
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 6733 × 19657
    uns: 'dance_config', 'log1p'
    obsm: 'cell_type'
[INFO][2023-06-28 04:43:34,444][dance][load_data] Data loaded:
Data object that wraps (.data):
AnnData object with n_obs
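The log above shows the second call reading a `.pkl` file from a `cache` directory instead of recomputing. One way to implement this kind of caching is to derive the file name from a hash of the dataset and transform descriptions, so that changing either one produces a new cache entry. The sketch below follows that idea. The MD5-of-description key and the helper names `cache_path` and `load_with_cache` are assumptions for illustration, not DANCE's actual scheme.

```python
import hashlib
import os
import pickle
import tempfile


def cache_path(cache_dir, dataset_desc, transform_desc):
    """Build a cache file path that changes whenever the dataset or transform changes."""
    key = hashlib.md5(f"{dataset_desc}\n{transform_desc}".encode()).hexdigest()
    return os.path.join(cache_dir, f"{key}.pkl")


def load_with_cache(cache_dir, dataset_desc, transform_desc, load_and_process):
    """Return (data, cache_hit); call load_and_process() only on a cache miss."""
    path = cache_path(cache_dir, dataset_desc, transform_desc)
    if os.path.isfile(path):
        with open(path, "rb") as f:
            return pickle.load(f), True
    data = load_and_process()
    os.makedirs(cache_dir, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(data, f)
    return data, False


calls = []


def expensive_load():
    calls.append("loaded")  # record how often the slow path runs
    return {"n_obs": 6733, "n_vars": 19657}


cache_dir = tempfile.mkdtemp()
desc = ("ScDeepSortDataset(species='mouse', tissue='Brain')", "Compose(...)")
first, hit1 = load_with_cache(cache_dir, *desc, expensive_load)
second, hit2 = load_with_cache(cache_dir, *desc, expensive_load)
print(hit1, hit2, len(calls))  # False True 1
```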

## 3. Single modality tasks

### 3.1 Example: ACTINN for Cell Type Annotation

#### Model structure

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/singlemodality/mlp_visualization.png)

#### Visualization of annotation results

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/singlemodality/cell_type_visualization.png)

#### Load data

In [20]:
print("Available dataset option for ScDeepSortDataset:")
pprint(ScDeepSortDataset.get_avalilable_data())

Available dataset option for ScDeepSortDataset:
[{'dataset': '3285', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '753', 'species': 'mouse', 'split': 'train', 'tissue': 'Brain'},
 {'dataset': '4682', 'species': 'mouse', 'split': 'train', 'tissue': 'Kidney'},
 {'dataset': '1970', 'species': 'mouse', 'split': 'train', 'tissue': 'Spleen'},
 {'dataset': '2695', 'species': 'mouse', 'split': 'test', 'tissue': 'Brain'},
 {'dataset': '203', 'species': 'mouse', 'split': 'test', 'tissue': 'Kidney'},
 {'dataset': '1759', 'species': 'mouse', 'split': 'test', 'tissue': 'Spleen'}]


In [21]:
import numpy as np

from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.utils import set_seed

# Initialize model and get model specific preprocessing pipeline
model = ACTINN(hidden_dims=[256, 256], lambd=0.01, device='cuda')
preprocessing_pipeline = model.preprocessing_pipeline(normalize=True, filter_genes=True)

# Load data and perform necessary preprocessing
dataset = ScDeepSortDataset(species="mouse", tissue="Brain",
                            train_dataset=["3285", "753"], test_dataset=["2695"])
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-28 04:43:34,671][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_data.csv
[INFO][2023-06-28 04:43:41,397][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_data.csv
[INFO][2023-06-28 04:43:43,469][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_data.csv
[INFO][2023-06-28 04:43:51,601][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain3285_celltype.csv
[INFO][2023-06-28 04:43:51,612][dance][_load_dfs] Loading data from ./train/mouse/mouse_Brain753_celltype.csv
[INFO][2023-06-28 04:43:51,617][dance][_load_dfs] Loading data from ./test/mouse/mouse_Brain2695_celltype.csv
[INFO][2023-06-28 04:43:56,014][dance][_load_raw_data] Loaded expression data: AnnData object with n_obs × n_vars = 6733 × 19856
[INFO][2023-06-28 04:43:56,016][dance][_load_raw_data] Number of training samples: 4,038
[INFO][2023-06-28 04:43:56,018][dance][_load_raw_data] Number of testing samples: 2,695
[INFO][2023-06-28 04:43:56,021][dance][_load

#### Train and evaluate model

In [22]:
# Obtain training and testing data
x_train, y_train = data.get_train_data(return_type="torch")
x_test, y_test = data.get_test_data(return_type="torch")

In [23]:
print(x_train)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [24]:
print(y_train)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])


In [25]:
# Train and evaluate model
set_seed(42)
model.fit(x_train, y_train, lr=0.001, num_epochs=21,
          batch_size=1000, print_cost=True)
print(f"ACC: {model.score(x_test, y_test):.4f}")

[INFO][2023-06-28 04:44:05,746][dance][set_seed] Setting global random seed to 42


Epoch:    0 Loss: 4.8823
Epoch:   10 Loss: 2.1393
Epoch:   20 Loss: 2.0517
ACC: 0.3677


In [26]:
print(model.model)

VanillaMLP(
  (model): Sequential(
    (0): Linear(in_features=18159, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=11, bias=True)
  )
)
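The printed `VanillaMLP` consists of three linear layers with ReLU activations between them, mapping 18,159 input genes to 11 cell-type logits. Below is a NumPy sketch of the same forward pass, using a smaller input dimension and random weights so it runs quickly. The accuracy line assumes that the score compares the arg-max of the logits with the arg-max of the one-hot labels; this is an assumption about `model.score`, not a description of DANCE's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Layer sizes mirror the printed model, except the input is 50 genes instead of 18,159
dims = [50, 256, 256, 11]
# Stored as (in, out); torch.nn.Linear stores (out, in) and computes x @ W.T + b
weights = [rng.normal(scale=0.1, size=(d_in, d_out)) for d_in, d_out in zip(dims[:-1], dims[1:])]
biases = [np.zeros(d_out) for d_out in dims[1:]]


def mlp_forward(x, weights, biases):
    """Linear -> ReLU -> Linear -> ReLU -> Linear, returning one logit per cell type."""
    h = x
    for W, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ W + b, 0.0)  # ReLU
    return h @ weights[-1] + biases[-1]


x = rng.random((4, dims[0]))         # 4 cells
y_onehot = np.eye(11)[[0, 3, 3, 7]]  # one-hot labels for the 4 cells
logits = mlp_forward(x, weights, biases)
accuracy = (logits.argmax(axis=1) == y_onehot.argmax(axis=1)).mean()
print(logits.shape, accuracy)
```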


### 3.2 Example: GraphSCI for Imputation

#### Model structure

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/singlemodality/graphsci_visualization.png)

#### Reported results

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/singlemodality/imputation_results_example.png)

#### Load data

In [27]:
import torch

from dance.datasets.singlemodality import ImputationDataset
from dance.modules.single_modality.imputation.graphsci import GraphSCI
from dance.utils import set_seed

# Load data and perform preprocessing
set_seed(42)
dataloader = ImputationDataset(data_dir='./data', dataset='pbmc_data', train_size=0.9)
preprocessing_pipeline = GraphSCI.preprocessing_pipeline(mask=True, mask_rate=0.1)
data = dataloader.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-28 04:44:08,236][dance][set_seed] Setting global random seed to 42
[INFO][2023-06-28 04:44:08,240][dance][is_complete] file ./data/train doesn't exist
  utils.warn_names_duplicates("var")
[INFO][2023-06-28 04:45:18,855][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 5247 × 33538
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'dance_config'
[INFO][2023-06-28 04:45:18,857][dance.Compose][__call__] Applying composed transformations:
Compose(
  FilterGenesScanpy(min_counts=None, min_cells=0.1, max_counts=None, max_cells=None, split_name=None),
  FilterCellsScanpy(min_counts=1, min_genes=None, max_counts=None, max_genes=None, split_name=None),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  FeatureFeatureGraph(),
  CellwiseMaskData(distr='exp', mask_rate=0.1, seed=1),
  SetConfig(config_dict={'feature_channel': [None, None, 'FeatureFeatureGraph', 'train_mask'], 'featu

In [28]:
data.data.layers['train_mask']

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True, False,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ..., False,  True,  True],
       [ True,  True,  True, ...,  True,  True, False]])

In [29]:
data.data.layers['valid_mask']

array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False,  True, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ...,  True, False, False],
       [False, False, False, ..., False, False,  True]])
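The two layers above are complementary: entries marked `True` in `valid_mask` are hidden from training and later used to check how well the model recovers them, and `train_mask` is their negation. Below is a simplified sketch of building such masks. It samples held-out entries uniformly at random, whereas DANCE's `CellwiseMaskData(distr='exp', mask_rate=0.1)` samples per cell, and its sampling details are not reproduced here, so treat this as an illustration only. `random_masks` is a hypothetical name.

```python
import numpy as np


def random_masks(shape, mask_rate=0.1, seed=1):
    """Return (train_mask, valid_mask) with about mask_rate of entries held out."""
    rng = np.random.default_rng(seed)
    valid_mask = rng.random(shape) < mask_rate  # True = held out for evaluation
    train_mask = ~valid_mask                    # True = visible during training
    return train_mask, valid_mask


train_mask, valid_mask = random_masks((200, 500), mask_rate=0.1)
print(valid_mask.mean())  # close to 0.1
```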

#### Train and evaluate model

In [30]:
# Obtain training and testing data
X, X_raw, g, mask = data.get_x(return_type="default")
device = 'cuda:0'
X = torch.tensor(X.toarray()).to(device)
X_raw = torch.tensor(X_raw.toarray()).to(device)
g = g.to(device)
train_idx = data.train_idx
test_idx = data.test_idx

# Train and evaluate model
model = GraphSCI(num_cells=X.shape[0], num_genes=X.shape[1],
                 dataset='pbmc_data', gpu=0)
model.fit(X, X_raw, g, train_idx, mask, n_epochs=10, la=1e-7)
model.load_model()
imputed_data = model.predict(X, X_raw, g, mask)
score = model.score(X_raw, imputed_data, test_idx, mask, metric='RMSE')
print("RMSE: %.4f" % score)

[Epoch0], train_loss 2.985725, adj_loss 0.386764, express_loss 2.118230, kl_loss 0.480731, valid_loss 2.871126
[Epoch1], train_loss 3.081999, adj_loss 0.465519, express_loss 2.109051, kl_loss 0.507429, valid_loss 3.051302
[Epoch2], train_loss 2.857551, adj_loss 0.355556, express_loss 2.102364, kl_loss 0.399632, valid_loss 2.679845
[Epoch3], train_loss 2.627350, adj_loss 0.329589, express_loss 2.089371, kl_loss 0.208390, valid_loss 2.614986
[Epoch4], train_loss 2.500940, adj_loss 0.318771, express_loss 2.077216, kl_loss 0.104953, valid_loss 2.557731
[Epoch5], train_loss 2.488447, adj_loss 0.312193, express_loss 2.066709, kl_loss 0.109546, valid_loss 2.532905
[Epoch6], train_loss 2.654563, adj_loss 0.309210, express_loss 2.054052, kl_loss 0.291301, valid_loss 2.524547
[Epoch7], train_loss 2.395145, adj_loss 0.308574, express_loss 2.045947, kl_loss 0.040625, valid_loss 2.517977
[Epoch8], train_loss 2.378382, adj_loss 0.307692, express_loss 2.037600, kl_loss 0.033090, valid_loss 2.554008
[
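The RMSE reported by `model.score(X_raw, imputed_data, test_idx, mask, metric='RMSE')` measures how close the imputed values are to the original ones. Below is a sketch of an RMSE restricted to a set of rows (cells) and to held-out entries. It assumes that `held_out` marks the entries to evaluate (for example, `valid_mask` above); which mask DANCE uses and whether it scores raw or transformed values are not shown here. `masked_rmse` is a hypothetical helper.

```python
import numpy as np


def masked_rmse(x_true, x_pred, rows, held_out):
    """RMSE over the held-out entries of the selected rows (cells)."""
    diff = x_true[rows] - x_pred[rows]
    return float(np.sqrt(np.mean(diff[held_out[rows]] ** 2)))


x_true = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
x_pred = np.array([[1.0, 0.0], [3.0, 4.0], [0.0, 0.0]])
held_out = np.array([[False, True], [True, False], [True, True]])

# Rows 0 and 1 contribute errors 2 and 0 on their held-out entries -> sqrt(2)
print(masked_rmse(x_true, x_pred, rows=[0, 1], held_out=held_out))
```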

### 3.3 Example: scDeepCluster for Clustering

#### Model structure

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/singlemodality/scdeepcluster_visualization.png)

#### Reported results

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/singlemodality/clustering_results_example.png)

#### Load data

In [31]:
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdeepcluster import ScDeepCluster
from dance.utils import set_seed


# Load data and perform necessary preprocessing
dataloader = ClusteringDataset('./data', '10X_PBMC')
preprocessing_pipeline = ScDeepCluster.preprocessing_pipeline()
data = dataloader.load_data(transform=preprocessing_pipeline)

[INFO][2023-06-28 04:45:30,652][dance][download_file] Downloading: ./data/10X_PBMC.h5 Bytes: 12,135,959
100%|██████████| 11.6M/11.6M [00:01<00:00, 6.59MB/s]
[INFO][2023-06-28 04:45:35,056][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 4271 × 16653
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-06-28 04:45:35,058][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.scale, func_kwargs={}),
  SetConfig(config_dict={'feature_channel': [None, None, 'n_counts'], 'feature_channel_type'

#### Train and evaluate model

In [32]:
# inputs: x, x_raw, n_counts
inputs, y = data.get_train_data()
n_clusters = len(np.unique(y))
in_dim = inputs[0].shape[1]

# Build and train model
set_seed(42)
model = ScDeepCluster(input_dim=in_dim, z_dim=32, encodeLayer=[256, 64], decodeLayer=[64, 256], device='cuda')
model.fit(inputs, y, n_clusters=n_clusters, lr=0.01, epochs=3, pt_epochs=3)

# Evaluate model predictions
score = model.score(None, y)
print(f"ARI: {score:.4f}")

[INFO][2023-06-28 04:45:37,626][dance][set_seed] Setting global random seed to 42
[INFO][2023-06-28 04:45:37,752][dance][_pretrain] Pre-training started
[INFO][2023-06-28 04:45:38,613][dance][pretrain] Pretrain epoch   1, ZINB loss: 0.39054949
[INFO][2023-06-28 04:45:39,438][dance][pretrain] Pretrain epoch   2, ZINB loss: 0.29828299
[INFO][2023-06-28 04:45:40,258][dance][pretrain] Pretrain epoch   3, ZINB loss: 0.28693729
[INFO][2023-06-28 04:45:40,259][dance][_pretrain] Pre-training finished (took 2.50 seconds)
[INFO][2023-06-28 04:45:40,556][dance][fit] Initializing cluster centers with kmeans.
[INFO][2023-06-28 04:45:42,170][dance][fit] Epoch   1: Total: 0.43940250, Clustering Loss: 0.15573579, ZINB Loss: 0.28366671
[INFO][2023-06-28 04:45:43,022][dance][fit] Epoch   2: Total: 0.43976923, Clustering Loss: 0.15603647, ZINB Loss: 0.28373276
[INFO][2023-06-28 04:45:43,872][dance][fit] Epoch   3: Total: 0.43987711, Clustering Loss: 0.15615725, ZINB Loss: 0.28371986


ARI: 0.5475
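ARI (adjusted Rand index) compares the predicted clusters with the true cell groups. It equals 1 for identical partitions (up to relabeling) and is close to 0 for a random assignment. Below is a self-contained NumPy sketch of the standard formula, based on pair counts from the contingency table. `sklearn.metrics.adjusted_rand_score` is the usual reference implementation; whether DANCE calls it internally is not shown here.

```python
from math import comb

import numpy as np


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index from pair counts (assumes at least two samples)."""
    _, t = np.unique(np.asarray(labels_true), return_inverse=True)
    _, p = np.unique(np.asarray(labels_pred), return_inverse=True)
    contingency = np.zeros((t.max() + 1, p.max() + 1), dtype=np.int64)
    np.add.at(contingency, (t, p), 1)  # count samples for each (true, predicted) pair

    sum_comb = sum(comb(int(n), 2) for n in contingency.ravel())
    sum_rows = sum(comb(int(n), 2) for n in contingency.sum(axis=1))
    sum_cols = sum(comb(int(n), 2) for n in contingency.sum(axis=0))
    expected = sum_rows * sum_cols / comb(len(t), 2)
    max_index = (sum_rows + sum_cols) / 2
    if max_index == expected:  # degenerate case, e.g. a single cluster on both sides
        return 1.0
    return (sum_comb - expected) / (max_index - expected)


print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0 (same partition, labels swapped)
print(adjusted_rand_index([0, 0, 1, 2], [0, 0, 1, 1]))  # about 0.571
```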


## 4. Multi-modality tasks

### 4.1 Modality Prediction

#### Task and Model Description

Modality prediction: given one measured modality of a cell, predict another, following the flow of information from DNA to RNA and from RNA to protein.

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/multimodality/modality_prediction_visualization.svg)

In this section, we use RNA-to-protein prediction as an example task, with data obtained from CITE-seq, and use the BABEL [1] model to demonstrate the DANCE workflow.

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/multimodality/babel_visualization.jpeg)

[1] Wu, Kevin E., et al. "BABEL enables cross-modality translation between multiomic profiles at single-cell resolution." Proceedings of the National Academy of Sciences 118.15 (2021): e2023070118.

#### Import packages and initializations

In [33]:
import argparse
import os
import random

import anndata
import mudata
import scanpy as sc
import torch
from scipy.sparse import csr_matrix
from sklearn.decomposition import TruncatedSVD

from dance import logger
from dance.data import Data
from dance.datasets.multimodality import ModalityPredictionDataset
from dance.modules.multi_modality.predict_modality.babel import BabelWrapper
from dance.utils import set_seed

set_seed(42)
device = 'cuda'

[INFO][2023-06-28 04:45:44,134][dance][set_seed] Setting global random seed to 42


#### Load data and perform necessary preprocessing

In [34]:
dataset = ModalityPredictionDataset("openproblems_bmmc_cite_phase2_rna_subset")
data = dataset.load_data()

[INFO][2023-06-28 04:45:46,602][dance][download_file] Downloading: /content/data/openproblems_bmmc_cite_phase2_rna_subset.zip Bytes: 128,749,677
100%|██████████| 123M/123M [00:08<00:00, 15.6MB/s]
[INFO][2023-06-28 04:45:54,848][dance][unzip_file] Unzipping /content/data/openproblems_bmmc_cite_phase2_rna_subset.zip
[INFO][2023-06-28 04:45:55,535][dance][delete_file] Deleting /content/data/openproblems_bmmc_cite_phase2_rna_subset.zip
[INFO][2023-06-28 04:45:55,558][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna_subset/openproblems_bmmc_cite_phase2_rna_subset.censor_dataset.output_train_mod1.h5ad
[INFO][2023-06-28 04:45:57,000][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna_subset/openproblems_bmmc_cite_phase2_rna_subset.censor_dataset.output_train_mod2.h5ad
[INFO][2023-06-28 04:45:57,117][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna_subset/openproblems_bmmc_cite_phase2_rna_subset.censor_data

In [35]:
data

Data object that wraps (.data):
MuData object with n_obs × n_vars = 9000 × 14087
  uns:	'dance_config'
  2 modalities
    mod1:	9000 x 13953
      obs:	'batch', 'size_factors'
      layers:	'counts'
    mod2:	9000 x 134
      obs:	'batch', 'size_factors'
      layers:	'counts'

In [36]:
len(data.get_split_idx("train"))

8000

In [37]:
data.mod["mod2"].X.shape

(9000, 134)

In [38]:
data.get_split_idx("test")[-10:]

[8990, 8991, 8992, 8993, 8994, 8995, 8996, 8997, 8998, 8999]

In [39]:
data.shape

(9000, 14087)

In [40]:
data.get_feature(mod="mod2", split_name="test")

array([[0.23983176, 0.66712326, 1.3292243 , ..., 0.5173572 , 1.2153752 ,
        0.93940276],
       [0.        , 1.2779273 , 1.2079018 , ..., 0.33972508, 0.6785143 ,
        0.830557  ],
       [0.        , 0.40740192, 1.5360246 , ..., 0.48774227, 0.96867484,
        0.7859117 ],
       ...,
       [0.05793161, 0.8636399 , 1.2541865 , ..., 0.50444597, 0.88847566,
        0.670022  ],
       [0.1118779 , 0.6962416 , 1.314437  , ..., 0.6354239 , 0.8592239 ,
        0.60356927],
       [0.10331292, 1.4756752 , 1.0848818 , ..., 0.68286705, 0.9257406 ,
        0.7871937 ]], dtype=float32)

In [41]:
data.obs

| | mod1:batch | mod1:size_factors | mod2:batch | mod2:size_factors |
|---|---|---|---|---|
| GCATTAGCATAAGCGG-1-s1d1 | s1d1 | 0.356535 | s1d1 | 0.356535 |
| TACAGGTGTTAGAGTA-1-s1d1 | s1d1 | 1.292643 | s1d1 | 1.292643 |
| AGGATCTAGGTCTACT-1-s1d1 | s1d1 | 0.970558 | s1d1 | 0.970558 |
| GTAGAAAGTGACACAG-1-s1d1 | s1d1 | 1.232604 | s1d1 | 1.232604 |
| TCCGAAAAGGATCATA-1-s1d1 | s1d1 | 0.044585 | s1d1 | 0.044585 |
| ... | ... | ... | ... | ... |
| ATGGTTGTCGCCCAGA-1-s4d1 | s4d1 | 0.514767 | s4d1 | 0.514767 |
| ATCCTATGTTGGGATG-1-s4d1 | s4d1 | 0.234756 | s4d1 | 0.234756 |
| CTACCTGTCAAGCTGT-1-s4d1 | s4d1 | 0.953712 | s4d1 | 0.953712 |
| AACCACACAACATACC-1-s4d8 | s4d8 | 0.821679 | s4d8 | 0.821679 |


In [42]:
data.mod['mod2']

AnnData object with n_obs × n_vars = 9000 × 134
    obs: 'batch', 'size_factors'
    layers: 'counts'

In [44]:
# Use mod1 (RNA) as the features and mod2 (protein) as the labels
data.set_config(feature_mod="mod1", label_mod="mod2")

# Obtain training and testing data
x_train, y_train = data.get_train_data(return_type="torch")
x_test, y_test = data.get_test_data(return_type="torch")

In [46]:
x_test, y_test, x_test.shape, y_test.shape

(tensor([[0.0000, 0.0000, 0.0000,  ..., 0.7606, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 1.3700, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([[0.2398, 0.6671, 1.3292,  ..., 0.5174, 1.2154, 0.9394],
         [0.0000, 1.2779, 1.2079,  ..., 0.3397, 0.6785, 0.8306],
         [0.0000, 0.4074, 1.5360,  ..., 0.4877, 0.9687, 0.7859],
         ...,
         [0.0579, 0.8636, 1.2542,  ..., 0.5044, 0.8885, 0.6700],
         [0.1119, 0.6962, 1.3144,  ..., 0.6354, 0.8592, 0.6036],
         [0.1033, 1.4757, 1.0849,  ..., 0.6829, 0.9257, 0.7872]]),
 torch.Size([1000, 13953]),
 torch.Size([1000, 134]))

#### Specify hyperparameters and initialize the model

In [47]:
parser = argparse.ArgumentParser()

######## Important hyperparameters
parser.add_argument("--subtask", default="openproblems_bmmc_cite_phase2_rna_subset")
parser.add_argument("--max_epochs", type=int, default=40)
parser.add_argument("--lr", "-l", type=float, default=0.01, help="Learning rate")
parser.add_argument("--batchsize", "-b", type=int, default=64, help="Batch size")
parser.add_argument("--hidden", type=int, default=64, help="Hidden dimensions")
parser.add_argument("--earlystop", type=int, default=2, help="Early stopping after N epochs")
parser.add_argument("--naive", "-n", action="store_true", help="Use a naive model instead of lego model")
parser.add_argument("--lossweight", type=float, default=1., help="Relative loss weight")
########

parser.add_argument("--model_folder", default="./")
parser.add_argument("--outdir", "-o", default="./", help="Directory to output to")
parser.add_argument("--resume", action="store_true")
parser.add_argument("--device", default="cuda")
parser.add_argument("--cpus", default=1, type=int)
parser.add_argument("--rnd_seed", default=42, type=int)

args_defaults = parser.parse_args([])
args = argparse.Namespace(**vars(args_defaults))
args

Namespace(subtask='openproblems_bmmc_cite_phase2_rna_subset', max_epochs=40, lr=0.01, batchsize=64, hidden=64, earlystop=2, naive=False, lossweight=1.0, model_folder='./', outdir='./', resume=False, device='cuda', cpus=1, rnd_seed=42)
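A notebook has no command line, so the cell above calls `parser.parse_args([])` to take every default. To change a few values instead, you can pass a list of strings or overwrite attributes afterwards. Below is a minimal standard-library sketch. The reduced parser is a stand-in that repeats only three of the arguments defined above.

```python
import argparse

# Stand-in parser that repeats a subset of the arguments defined above.
parser = argparse.ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=40)
parser.add_argument("--lr", "-l", type=float, default=0.01)
parser.add_argument("--naive", "-n", action="store_true")

# Option 1: pass the strings exactly as they would appear on a command line.
args = parser.parse_args(["--max_epochs", "10", "-l", "0.001"])
print(args)  # Namespace(max_epochs=10, lr=0.001, naive=False)

# Option 2: start from the defaults and overwrite attributes directly.
args2 = parser.parse_args([])
args2.naive = True
```

Option 1 runs every value through the same `type=` conversion a real command line would use. Option 2 skips that conversion, so you must assign values of the correct type yourself.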

In [48]:
model = BabelWrapper(args, dim_in=x_train.shape[1], dim_out=y_train.shape[1])

[INFO][2023-06-28 04:46:02,490][dance][__init__] ChromDecoder with 1 output activations


#### Train and evaluate model

In [49]:
model.fit(x_train.float(), y_train.float(), val_ratio=0.15)

epoch:  1
training (sum of 4 losses): 1.729675633885036
validation (prediction loss): 0.43476885608907095
epoch:  2
training (sum of 4 losses): 1.4153931085194382
validation (prediction loss): 0.3933187532053297
epoch:  3
training (sum of 4 losses): 1.35525126991985
validation (prediction loss): 0.387645403155249
epoch:  4
training (sum of 4 losses): 1.3230651227113241
validation (prediction loss): 0.3739129218113424
epoch:  5
training (sum of 4 losses): 1.298898705812258
validation (prediction loss): 0.37932540665903486
epoch:  6
training (sum of 4 losses): 1.289031654874855
validation (prediction loss): 0.3697096068122427
epoch:  7
training (sum of 4 losses): 1.279110183225614
validation (prediction loss): 0.3692062800761993
epoch:  8
training (sum of 4 losses): 1.2611568976785534
validation (prediction loss): 0.37058703098469215
epoch:  9
training (sum of 4 losses): 1.2530649376806813
validation (prediction loss): 0.36730712979491736
epoch:  10
training (sum of 4 losses): 1.24354864
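The `--earlystop` argument sets how many epochs without improvement in validation loss are tolerated before training stops. The sketch below applies a generic patience-based rule to the validation losses logged above. The helper `early_stop_epoch` is hypothetical and is not BabelWrapper's actual implementation.

```python
def early_stop_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it never does.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss seen so far.
    """
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Validation losses from epochs 1-9 of the run above, rounded to 4 decimals.
val = [0.4348, 0.3933, 0.3876, 0.3739, 0.3793, 0.3697, 0.3692, 0.3706, 0.3673]
print(early_stop_epoch(val, patience=2))  # None: never two non-improving epochs in a row
print(early_stop_epoch(val, patience=1))  # 5: epoch 5 is worse than epoch 4
```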

In [50]:
model.predict(x_test.float())

tensor([[0.6611, 0.3581, 1.3748,  ..., 0.5862, 0.9071, 1.1447],
        [0.0169, 0.2847, 1.4770,  ..., 0.4408, 0.5901, 0.4358],
        [0.0000, 0.3031, 1.4437,  ..., 0.2850, 0.6995, 0.3197],
        ...,
        [0.0000, 0.2750, 1.0094,  ..., 0.6977, 0.5756, 0.4167],
        [0.0050, 0.2644, 1.0867,  ..., 0.5433, 0.6473, 0.3956],
        [0.0487, 0.2764, 1.3526,  ..., 0.0877, 0.5033, 0.4192]],
       device='cuda:0')

In [51]:
model.score(x_test.float(), y_test.float())

0.48369167180769357

### 4.2 Modality Matching

The goal is to match each cell's profile in one modality to the same cell's profile in the other modality.

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/multimodality/modality_matching_visualization.jpeg)

In this section, we take RNA-to-Protein as an example task, where the data are obtained with CITE-seq technology. We use the scMoGNN [2] model as an example to demonstrate the workflow of the DANCE package.

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/multimodality/scmogcn_visualization.jpeg)

[2] Wen, Hongzhi, et al. "Graph neural networks for multimodal single-cell data integration." Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining. 2022.

#### Load data and perform necessary preprocessing

In [52]:
from dance.datasets.multimodality import ModalityMatchingDataset
from dance.modules.multi_modality.match_modality.scmogcn import ScMoGCNWrapper
from dance.transforms.graph.cell_feature_graph import CellFeatureBipartiteGraph
import numpy as np
import torch.nn.functional as F

dataset = ModalityMatchingDataset('openproblems_bmmc_cite_phase2_rna_subset', root='./data', preprocess="pca", pkl_path='lsi_input_pca_count.pkl')
data = dataset.load_data()

[INFO][2023-06-28 04:46:21,751][dance][download_file] Downloading: /content/data/openproblems_bmmc_cite_phase2_rna_subset.zip Bytes: 34,858,859
100%|██████████| 33.2M/33.2M [00:02<00:00, 14.3MB/s]
[INFO][2023-06-28 04:46:24,202][dance][unzip_file] Unzipping /content/data/openproblems_bmmc_cite_phase2_rna_subset.zip
[INFO][2023-06-28 04:46:24,425][dance][delete_file] Deleting /content/data/openproblems_bmmc_cite_phase2_rna_subset.zip
[INFO][2023-06-28 04:46:24,432][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna_subset/openproblems_bmmc_cite_phase2_rna_subset.censor_dataset.output_train_mod1.h5ad
[INFO][2023-06-28 04:46:24,753][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna_subset/openproblems_bmmc_cite_phase2_rna_subset.censor_dataset.output_train_mod2.h5ad
[INFO][2023-06-28 04:46:24,854][dance][_load_raw_data] Loading /content/data/openproblems_bmmc_cite_phase2_rna_subset/openproblems_bmmc_cite_phase2_rna_subset.censor_dat

In [53]:
# ScMoGNN graph construction
data = CellFeatureBipartiteGraph(cell_feature_channel="X_pca", mod="mod1")(data)
data = CellFeatureBipartiteGraph(cell_feature_channel="X_pca", mod="mod2")(data)
data.set_config(feature_mod=["mod1", "mod2", "mod1", "mod2"], feature_channel_type=["uns", "uns", "obs", "obs"],
                feature_channel=["g", "g", "batch", "batch"], label_mod="mod1", label_channel="labels")

[INFO][2023-06-28 04:46:39,366][dance][set_config_from_dict] Setting config 'feature_mod' to ['mod1', 'mod2', 'mod1', 'mod2']
[INFO][2023-06-28 04:46:39,367][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['uns', 'uns', 'obs', 'obs']
[INFO][2023-06-28 04:46:39,370][dance][set_config_from_dict] Setting config 'feature_channel' to ['g', 'g', 'batch', 'batch']
[INFO][2023-06-28 04:46:39,372][dance][set_config_from_dict] Setting config 'label_mod' to 'mod1'
[INFO][2023-06-28 04:46:39,374][dance][set_config_from_dict] Setting config 'label_channel' to 'labels'
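`CellFeatureBipartiteGraph` connects cell nodes to feature nodes. The NumPy sketch below illustrates the idea by building an edge list. It assumes that each nonzero entry of the cell-by-feature matrix (here `X_pca`) becomes one weighted cell-feature edge; DANCE's exact edge and weighting rules may differ. `cell_feature_edges` is a hypothetical helper.

```python
import numpy as np


def cell_feature_edges(x):
    """Turn a dense cell-by-feature matrix into bipartite edge lists.

    Returns (cell_idx, feature_idx, weight) for every nonzero entry, so that
    cell i is connected to feature j with weight x[i, j].
    """
    cell_idx, feature_idx = np.nonzero(x)
    weight = x[cell_idx, feature_idx]
    return cell_idx, feature_idx, weight


x = np.array([[0.5, 0.0, 1.2],
              [0.0, 2.0, 0.0]])
cells, feats, w = cell_feature_edges(x)
# cells -> [0, 0, 1], features -> [0, 2, 1], weights -> [0.5, 1.2, 2.0]
```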


In [54]:
device = "cuda"  # target device for the tensors and graphs below
(g_mod1, g_mod2, batch_mod1, batch_mod2), z = data.get_data(return_type="default")
train_size = len(data.get_split_idx("train"))
test_idx = np.arange(train_size, g_mod1.num_nodes("cell"))
z_test = F.one_hot(torch.from_numpy(z[train_size:]).long())
labels1 = torch.argmax(z_test, dim=0).to(device)
labels2 = torch.argmax(z_test, dim=1).to(device)
g_mod1 = g_mod1.to(device)
g_mod2 = g_mod2.to(device)
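In the cell above, `F.one_hot` turns the partner indices into a 0/1 matching matrix. `argmax` along `dim=1` then recovers the original indices, and `argmax` along `dim=0` gives the inverse mapping. The NumPy sketch below checks both identities. It assumes the indices form a permutation of the test cells, and the `z` it uses is a made-up toy example.

```python
import numpy as np

# Hypothetical partner indices: modality-1 cell i matches modality-2 cell z[i].
z = np.array([2, 0, 3, 1])
n = len(z)

# One-hot matching matrix: row i has a single 1 in column z[i].
match = np.eye(n, dtype=int)[z]

along_cols = match.argmax(axis=1)  # like torch.argmax(..., dim=1): recovers z
along_rows = match.argmax(axis=0)  # like dim=0: for each column j, the row i with z[i] == j

assert (along_cols == z).all()
assert (z[along_rows] == np.arange(n)).all()  # along_rows is the inverse permutation
print(along_rows)  # [1 3 0 2]
```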

#### Specify hyperparameters and initialize the model

In [55]:
parser = argparse.ArgumentParser()
parser.add_argument("--layers", default=4, type=int, choices=[3, 4, 5, 6, 7])
parser.add_argument("--learning_rate", default=6e-4, type=float)
parser.add_argument("--disable_propagation", default=0, type=int, choices=[0, 1, 2])
parser.add_argument("--auxiliary_loss", default=True, type=bool)
parser.add_argument("--epochs", default=2000, type=int)
parser.add_argument("--hidden_size", default=64, type=int)
parser.add_argument("--temperature", default=2.739896, type=float)
parser.add_argument("--device", default='cuda', type=str)
parser.add_argument("--rnd_seed", default=42, type=int)

args_defaults = parser.parse_args([])
args = argparse.Namespace(**vars(args_defaults))
data_folder = './data/'
device = 'cuda'
args

Namespace(layers=4, learning_rate=0.0006, disable_propagation=0, auxiliary_loss=True, epochs=2000, hidden_size=64, temperature=2.739896, device='cuda', rnd_seed=42)
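There is a caveat with `--auxiliary_loss`. With `type=bool`, argparse calls `bool()` on the raw string, and any non-empty string is truthy, including `"False"`. This does not matter here, because `parse_args([])` only uses the default, but it would matter if you passed the flag explicitly. The standard-library demonstration below shows the problem alongside a string-parsing converter, one common workaround. `str2bool` is a hypothetical helper, not part of argparse or DANCE.

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings and reject anything else."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


naive = argparse.ArgumentParser()
naive.add_argument("--auxiliary_loss", default=True, type=bool)

fixed = argparse.ArgumentParser()
fixed.add_argument("--auxiliary_loss", default=True, type=str2bool)

print(naive.parse_args(["--auxiliary_loss", "False"]).auxiliary_loss)  # True (surprising)
print(fixed.parse_args(["--auxiliary_loss", "False"]).auxiliary_loss)  # False
```

On Python 3.9+, `argparse.BooleanOptionalAction` is another option. It generates paired `--flag` and `--no-flag` switches.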

In [56]:
model = ScMoGCNWrapper(
    args,
    [
        [(g_mod1.num_nodes("feature"), 512, 0.25), (512, 512, 0.25), (512, args.hidden_size)],
        [(g_mod2.num_nodes("feature"), 512, 0.2), (512, 512, 0.2), (512, args.hidden_size)],
        [(args.hidden_size, 512, 0.2), (512, g_mod1.num_nodes("feature"))],
        [(args.hidden_size, 512, 0.2), (512, g_mod2.num_nodes("feature"))],
    ],
    args.temperature,
)

#### Train and evaluate model

In [57]:
model.fit(g_mod1, g_mod2, labels1, labels2, train_size=train_size)

[INFO][2023-06-28 04:46:39,620][dance][fit] epoch 0
[INFO][2023-06-28 04:46:39,702][dance][fit] training loss: 18.65970, forward: 0.0008, backward: 0.0005
[INFO][2023-06-28 04:46:39,722][dance][fit] validation score: 0.00125
[INFO][2023-06-28 04:46:39,732][dance][fit] epoch 1
[INFO][2023-06-28 04:46:39,774][dance][fit] training loss: 17.66809, forward: 0.0013, backward: 0.0008
[INFO][2023-06-28 04:46:39,798][dance][fit] validation score: 0.00088
[INFO][2023-06-28 04:46:39,800][dance][fit] epoch 2
[INFO][2023-06-28 04:46:39,838][dance][fit] training loss: 15.66323, forward: 0.0018, backward: 0.0010
[INFO][2023-06-28 04:46:39,860][dance][fit] validation score: 0.00100
[INFO][2023-06-28 04:46:39,864][dance][fit] epoch 3
[INFO][2023-06-28 04:46:39,902][dance][fit] training loss: 15.15878, forward: 0.0008, backward: 0.0010
[INFO][2023-06-28 04:46:39,924][dance][fit] validation score: 0.00088
[INFO][2023-06-28 04:46:39,925][dance][fit] epoch 4
[INFO][2023-06-28 04:46:39,963][dance][fit] trai

<dance.modules.multi_modality.match_modality.scmogcn.ScMoGCNWrapper at 0x7f7447832290>

In [58]:
model.predict(test_idx, enhance=True, batch1=batch_mod1, batch2=batch_mod2)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [59]:
model.score(test_idx, labels_matrix=z_test, enhance=True, batch1=batch_mod1, batch2=batch_mod2)

0.136
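`predict` returns a cell-by-cell matching matrix, and `score` compares it with the ground-truth matching `z_test`. The sketch below computes one such score. It assumes the metric follows the convention of the NeurIPS 2021 multimodal single-cell integration competition: the total predicted weight placed on true pairs, divided by the number of cells. We have not verified this against DANCE's `score` implementation, and `matching_score` is a hypothetical name.

```python
import numpy as np


def matching_score(pred, truth):
    """Average weight that `pred` puts on the true pairs.

    pred:  (n, n) non-negative matching weights, ideally with rows summing to 1.
    truth: (n, n) 0/1 matrix with a single 1 per row marking the true partner.
    """
    pred = np.asarray(pred, dtype=float)
    truth = np.asarray(truth, dtype=float)
    return float((pred * truth).sum() / truth.shape[0])


truth = np.eye(4)
perfect = np.eye(4)
uniform = np.full((4, 4), 0.25)
half_right = np.eye(4)[[0, 1, 3, 2]]  # rows 2 and 3 point at the wrong partner

print(matching_score(perfect, truth))     # 1.0
print(matching_score(uniform, truth))     # 0.25
print(matching_score(half_right, truth))  # 0.5
```

Under this reading, spreading weight uniformly over n candidates scores 1/n, which is a useful baseline when interpreting a value such as 0.136.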

## 5. Spatial tasks

### 5.1 Spatial Domain

#### SpaGCN model for spatial domain identification

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/spatial/spagcn_framework.png)

In [60]:
from dance.datasets.spatial import SpatialLIBDDataset
from dance.modules.spatial.spatial_domain.spagcn import SpaGCN
from dance.utils import set_seed
from dance.transforms import CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import SpaGCNGraph, SpaGCNGraph2D

#### Initialize the model and get its model-specific preprocessing *pipeline*

In [61]:
model = SpaGCN()
# In SpaGCN, alpha and beta are used for graph construction
preprocessing_pipeline = model.preprocessing_pipeline(alpha=1, beta=49)
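In the original SpaGCN method, `beta` sets the size of the image patch around each spot from which a mean RGB color is taken. `alpha` scales how much that histology-derived value counts relative to the x/y coordinates when spot-to-spot distances are computed. The sketch below simplifies the distance step. It assumes the per-spot mean colors are already available (patch extraction is omitted) and follows the SpaGCN paper rather than DANCE's `SpaGCNGraph` code. `histology_distances` is a hypothetical name. The result has the same form as the `adj` matrix printed further down: zeros on the diagonal and pairwise distances elsewhere.

```python
import numpy as np


def histology_distances(x, y, rgb, alpha=1.0):
    """Pairwise Euclidean distances between spots in (x, y, z) space.

    x, y : (n,) spot coordinates in pixels.
    rgb  : (n, 3) mean RGB color of the image patch around each spot.
    alpha: weight of the histology-derived z axis relative to x and y.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    rgb = np.asarray(rgb, dtype=float)
    # Combine the color channels, weighting each by its variance across spots.
    var = rgb.var(axis=0)
    z = rgb @ var / var.sum()
    # Standardize z, then rescale it to the spread of the spatial coordinates.
    z = (z - z.mean()) / z.std()
    z = z * max(x.std(), y.std()) * alpha
    coords = np.stack([x, y, z], axis=1)
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


rng = np.random.default_rng(0)
x, y = rng.uniform(0, 1000, 5), rng.uniform(0, 1000, 5)
rgb = rng.uniform(0, 255, (5, 3))
d = histology_distances(x, y, rgb, alpha=1)
print(d.shape)  # (5, 5): symmetric, with zeros on the diagonal
```

Setting `alpha=0` removes the histology term, leaving plain spatial distances.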

#### User-defined custom transform pipeline (replaces the default pipeline above)

In [62]:
preprocessing_pipeline = Compose(
    FilterGenesMatch(prefixes=["ERCC", "MT-"]),
    SpaGCNGraph(alpha=1, beta=49),
    SpaGCNGraph2D(),
    CellPCA(n_components=40),
    SetConfig({
        "feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
        "feature_channel_type": ["obsm", "obsp", "obsp"],
        "label_channel": "label",
        "label_channel_type": "obs"
    }),
)
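Conceptually, `Compose` chains transforms in the listed order, so each one receives the data object produced by the previous one. The pure-Python sketch below shows that pattern. `SimpleCompose` and the toy transforms are hypothetical and do not reproduce DANCE's `Compose`, which, for example, also logs the transformations it applies.

```python
class SimpleCompose:
    """Apply a sequence of callables one after another."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data


# Toy transforms operating on a dict that stands in for a data object.
def drop_mito_genes(data):
    data["genes"] = [g for g in data["genes"] if not g.startswith("MT-")]
    return data


def set_config(data):
    data["config"] = {"feature_channel": "genes"}
    return data


pipeline = SimpleCompose(drop_mito_genes, set_config)
out = pipeline({"genes": ["MT-CO1", "GAPDH", "ACTB"]})
print(out)  # {'genes': ['GAPDH', 'ACTB'], 'config': {'feature_channel': 'genes'}}
```

Order matters. In the pipeline above, genes are filtered before `CellPCA` runs, and `SetConfig` comes last.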

#### Load data and perform necessary preprocessing

In [63]:
dataloader = SpatialLIBDDataset(data_id="151673")
data = dataloader.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-28 04:46:48,892][dance][is_complete] lack data/spatial/151673/151673_raw_feature_bc_matrix.h5
[INFO][2023-06-28 04:46:51,077][dance][download_file] Downloading: data/spatial/151673/151673.zip Bytes: 548,448,341
100%|██████████| 523M/523M [00:41<00:00, 13.2MB/s]
[INFO][2023-06-28 04:47:32,574][dance][unzip_file] Unzipping data/spatial/151673/151673.zip
[INFO][2023-06-28 04:47:34,516][dance][delete_file] Deleting data/spatial/151673/151673.zip
[INFO][2023-06-28 04:47:34,636][dance][_load_raw_data] Loading image data from data/spatial/151673/151673_full_image.tif
[INFO][2023-06-28 04:47:36,020][dance][_load_raw_data] Loading expression data from data/spatial/151673/151673_raw_feature_bc_matrix.h5
  utils.warn_names_duplicates("var")
[INFO][2023-06-28 04:47:36,521][dance][_load_raw_data] Loading spatial info from data/spatial/151673/tissue_positions_list.txt
[INFO][2023-06-28 04:47:36,535][dance][_load_raw_data] Loading label info from data/spatial/151673/cluster_labels.csv


In [64]:
data

Data object that wraps (.data):
AnnData object with n_obs × n_vars = 3639 × 33516
    obs: 'key', 'ground_truth', 'SpatialDE_PCA', 'SpatialDE_pool_PCA', 'HVG_PCA', 'pseudobulk_PCA', 'markers_PCA', 'SpatialDE_UMAP', 'SpatialDE_pool_UMAP', 'HVG_UMAP', 'pseudobulk_UMAP', 'markers_UMAP', 'SpatialDE_PCA_spatial', 'SpatialDE_pool_PCA_spatial', 'HVG_PCA_spatial', 'pseudobulk_PCA_spatial', 'markers_PCA_spatial', 'SpatialDE_UMAP_spatial', 'SpatialDE_pool_UMAP_spatial', 'HVG_UMAP_spatial', 'pseudobulk_UMAP_spatial', 'markers_UMAP_spatial', 'label'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'image', 'dance_config'
    obsm: 'spatial', 'spatial_pixel', 'CellPCA'
    obsp: 'SpaGCNGraph', 'SpaGCNGraph2D'

In [65]:
data.data.obsp["SpaGCNGraph"].shape

(3639, 3639)

In [66]:
data.x[0].shape

(3639, 40)

In [67]:
data.x[1].shape

(3639, 3639)

In [68]:
(x, adj, adj_2d), y = data.get_train_data()

In [69]:
x, x.shape

(array([[ 74.59441   ,   5.1020446 ,  -1.5808486 , ...,   3.0062046 ,
           2.0821643 ,   5.399893  ],
        [-52.2047    , -29.17509   ,   0.14971869, ...,   0.9300715 ,
           0.5794006 ,   0.22404312],
        [-46.553333  ,  97.91386   ,  -3.3191757 , ...,  -0.30053952,
           2.1526213 ,   2.920683  ],
        ...,
        [-48.64317   ,   0.61732996,  -0.6655372 , ...,  -0.44500932,
          -2.0972178 ,  -0.24019004],
        [-49.086014  ,   7.2246747 ,  -0.7592183 , ...,   0.39618298,
           0.83608055,  -0.8478186 ],
        [ 35.53667   , -15.498938  ,  -1.0658216 , ...,   5.82994   ,
           0.12134802,  -1.7941561 ]], dtype=float32),
 (3639, 40))

In [70]:
adj, adj.shape

(array([[  0.     ,  94.04338, 107.42335, ...,  75.5097 , 119.22895,
          67.19388],
        [ 94.04338,   0.     ,  61.99059, ...,  77.6764 ,  82.52692,
          64.84502],
        [107.42335,  61.99059,   0.     , ...,  76.50982,  41.4561 ,
          97.2263 ],
        ...,
        [ 75.5097 ,  77.6764 ,  76.50982, ...,   0.     , 111.5035 ,
          44.95287],
        [119.22895,  82.52692,  41.4561 , ..., 111.5035 ,   0.     ,
         126.87175],
        [ 67.19388,  64.84502,  97.2263 , ...,  44.95287, 126.87175,
           0.     ]], dtype=float32),
 (3639, 3639))

In [71]:
y, y.shape

(array([0, 1, 2, ..., 4, 2, 5]), (3639,))

#### Train and evaluate model

In [72]:
l = model.search_l(0.05, adj, start=0.01, end=1000, tol=5e-3, max_run=200)
model.set_l(l)
res = model.search_set_res((x, adj), l=l, target_num=7, start=0.4, step=0.1,
                           tol=5e-3, lr=0.05, epochs=200, max_run=200)

[INFO][2023-06-28 04:47:49,971][dance][search_l] Run 1: l [0.01, 1000], p [0.0, 3629.7406229057055]
[INFO][2023-06-28 04:47:50,014][dance][search_l] Run 2: l [0.01, 500.005], p [0.0, 3605.191650390625]
[INFO][2023-06-28 04:47:50,059][dance][search_l] Run 3: l [0.01, 250.0075], p [0.0, 3510.283935546875]
[INFO][2023-06-28 04:47:50,101][dance][search_l] Run 4: l [0.01, 125.00874999999999], p [0.0, 3176.004150390625]
[INFO][2023-06-28 04:47:50,144][dance][search_l] Run 5: l [0.01, 62.509375], p [0.0, 2292.207275390625]
[INFO][2023-06-28 04:47:50,186][dance][search_l] Run 6: l [0.01, 31.2596875], p [0.0, 1045.6600341796875]
[INFO][2023-06-28 04:47:50,229][dance][search_l] Run 7: l [0.01, 15.63484375], p [0.0, 292.86767578125]
[INFO][2023-06-28 04:47:50,310][dance][search_l] Run 8: l [0.01, 7.822421875], p [0.0, 59.20479965209961]
[INFO][2023-06-28 04:47:50,413][dance][search_l] Run 9: l [0.01, 3.9162109375], p [0.0, 9.6504545211792]
[INFO][2023-06-28 04:47:50,479][dance][search_l] Run 10: 
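The `search_l` log above is consistent with a bisection on the length scale `l`: each run halves the interval, and `p` grows with `l`. In the original SpaGCN, `p` is the average over spots of the summed Gaussian weights `exp(-d^2 / (2 l^2))` to all spots, minus the self-weight of 1. `search_l` looks for the `l` at which `p` is close to the target (0.05 here). The NumPy sketch below works under that assumption. `calculate_p` and `search_l_sketch` are illustrative names, not DANCE's code.

```python
import numpy as np


def calculate_p(dist, l):
    """Mean total Gaussian weight per spot, excluding the spot itself (dist has a zero diagonal)."""
    weights = np.exp(-(dist ** 2) / (2 * l ** 2))
    return float(weights.sum(axis=1).mean() - 1)


def search_l_sketch(p_target, dist, start=0.01, end=1000.0, tol=5e-3, max_run=100):
    """Bisection for l such that calculate_p(dist, l) is within tol of p_target."""
    lo, hi = start, end
    if calculate_p(dist, lo) > p_target or calculate_p(dist, hi) < p_target:
        raise ValueError("target p is not bracketed by [start, end]")
    for _ in range(max_run):
        mid = (lo + hi) / 2
        p = calculate_p(dist, mid)
        if abs(p - p_target) <= tol:
            return mid
        if p < p_target:
            lo = mid  # weights too small: increase l
        else:
            hi = mid  # weights too large: decrease l
    return None


rng = np.random.default_rng(0)
pts = rng.uniform(0, 100, size=(50, 2))
dist = np.sqrt(((pts[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1))
l = search_l_sketch(0.05, dist)
print(l, calculate_p(dist, l))
```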

In [73]:
pred = model.fit_predict((x, adj), init_spa=True, init="louvain", tol=5e-3,
                         lr=0.05, epochs=200, res=res)

[INFO][2023-06-28 04:48:22,992][dance][fit] Initializing cluster centers with louvain, resolution = 0.5
[INFO][2023-06-28 04:48:23,547][dance][fit] Epoch 0
[INFO][2023-06-28 04:48:24,451][dance][fit] Epoch 10
[INFO][2023-06-28 04:48:25,382][dance][fit] Epoch 20
[INFO][2023-06-28 04:48:26,322][dance][fit] Epoch 30
[INFO][2023-06-28 04:48:27,372][dance][fit] Epoch 40
[INFO][2023-06-28 04:48:28,652][dance][fit] delta_label 0.004122011541632316 < tol 0.005
[INFO][2023-06-28 04:48:28,654][dance][fit] Reach tolerance threshold. Stopping training.
[INFO][2023-06-28 04:48:28,656][dance][fit] Total epoch: 49
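Training stopped once `delta_label` fell below `tol`. The logged value is consistent with `delta_label` being the share of spots whose cluster assignment changed since the previous check: 0.004122 * 3639 spots is about 15 spots. The pure-Python sketch below implements that stopping criterion under this assumption. `delta_label` is a hypothetical helper named after the log field.

```python
def delta_label(prev_labels, new_labels):
    """Fraction of items whose label differs between two assignments."""
    if len(prev_labels) != len(new_labels):
        raise ValueError("label lists must have the same length")
    changed = sum(a != b for a, b in zip(prev_labels, new_labels))
    return changed / len(prev_labels)


n_spots = 3639
prev = [0] * n_spots
new = [1] * 15 + [0] * (n_spots - 15)  # 15 spots switch cluster
delta = delta_label(prev, new)
print(delta, delta < 5e-3)  # about 0.004122, below tol=0.005, so training would stop
```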


In [74]:
score = model.default_score_func(y, pred)
print(f"ARI: {score:.4f}")

ARI: 0.1778
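ARI (adjusted Rand index) compares two partitions while correcting for chance agreement. A value of 1 means identical clusterings up to relabeling, and values near 0 mean chance-level agreement. The pure-Python sketch below uses the standard pair-counting formula. In practice, `sklearn.metrics.adjusted_rand_score` computes the same quantity.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index from the pair-counting contingency table."""
    n = len(labels_true)
    if n < 2 or n != len(labels_pred):
        raise ValueError("need two label lists of equal length >= 2")
    pair_counts = Counter(zip(labels_true, labels_pred))
    index = sum(comb(c, 2) for c in pair_counts.values())
    sum_a = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_b = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_a * sum_b / comb(n, 2)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:  # degenerate case, e.g. one cluster on both sides
        return 1.0
    return (index - expected) / (max_index - expected)


print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0: same partition, different label names
print(adjusted_rand_index([0, 0, 1, 1], [0, 1, 0, 1]))  # about -0.5: worse than chance
```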


### 5.2 Cell Type Deconvolution

#### DSTG model for cell type deconvolution

![image](https://github.com/OmicsML/dance-tutorials/raw/main/imgs/tutorial_v1/spatial/dstg_framework.png)

In [75]:
import numpy as np
import torch
from dance.datasets.spatial import CellTypeDeconvoDataset
from dance.modules.spatial.cell_type_deconvo import DSTG
from dance.utils import set_seed

#### Get the model-specific preprocessing *pipeline*

In [76]:
preprocessing_pipeline = DSTG.preprocessing_pipeline(
    n_pseudo=500,
    n_top_genes=2000,
    k_filter=200,
    num_cc=30,
)
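DSTG trains on pseudo-spots built by mixing cells from the single-cell reference, so their cell-type proportions are known; `n_pseudo=500` sets how many are generated. The NumPy sketch below shows how one such pseudo-spot could be formed. It assumes a spot pools a small random number of reference cells and sums their expression. The helper `make_pseudo_spot` and the 2-8 cell range are illustrative assumptions, not DANCE's exact procedure.

```python
import numpy as np


def make_pseudo_spot(expr, cell_types, n_types, rng, min_cells=2, max_cells=8):
    """Mix randomly chosen reference cells into one pseudo-spot.

    expr       : (n_cells, n_genes) reference expression matrix.
    cell_types : (n_cells,) integer cell-type labels in [0, n_types).
    Returns the summed expression and the cell-type proportions of the mix.
    """
    k = rng.integers(min_cells, max_cells + 1)  # number of cells in this spot
    chosen = rng.choice(len(expr), size=k, replace=False)
    spot_expr = expr[chosen].sum(axis=0)
    proportions = np.bincount(cell_types[chosen], minlength=n_types) / k
    return spot_expr, proportions


rng = np.random.default_rng(0)
expr = rng.poisson(1.0, size=(100, 20))    # toy reference: 100 cells x 20 genes
cell_types = rng.integers(0, 6, size=100)  # 6 toy cell types, as in this dataset
spot, props = make_pseudo_spot(expr, cell_types, n_types=6, rng=rng)
print(spot.shape, props, props.sum())
```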

#### Load data and perform necessary preprocessing

In [77]:
dataset = CellTypeDeconvoDataset(data_dir="data/spatial", data_id="CARD_synthetic")
data = dataset.load_data(transform=preprocessing_pipeline, cache=True)

[INFO][2023-06-28 04:48:31,813][dance][download_file] Downloading: data/spatial/CARD_synthetic.zip Bytes: 104,321,218
100%|██████████| 99.5M/99.5M [00:08<00:00, 12.4MB/s]
[INFO][2023-06-28 04:48:40,235][dance][unzip_file] Unzipping data/spatial/CARD_synthetic.zip
[INFO][2023-06-28 04:48:40,561][dance][delete_file] Deleting data/spatial/CARD_synthetic.zip
[INFO][2023-06-28 04:48:41,642][dance][_load_raw_data] Number of cell types: reference = 7, real = 6
[INFO][2023-06-28 04:48:41,649][dance][_load_raw_data] Subsetting to common cell types (n=6):
['Astrocytes', 'Ependymal', 'Immune', 'Neurons', 'Oligos', 'Vascular']
[INFO][2023-06-28 04:48:43,107][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 10454 × 18263
    obs: 'cellname', 'cellType', 'sampleID', 'batch'
    uns: 'dance_config'
    obsm: 'cell_type_portion', 'spatial'
[INFO][2023-06-28 04:48:43,113][dance.Compose][__call__] Applying composed transformations:
Compose(
  Filter

In [78]:
data.x

[<760x760 sparse matrix of type '<class 'numpy.float64'>'
 	with 3032 stored elements in COOrdinate format>,
 array([[-0.96134275,  0.976699  , -1.028233  , ..., -0.16210744,
          0.10654867,  0.72191226],
        [-0.96134275,  0.67003554,  0.12525097, ..., -0.16210744,
         -0.7449631 ,  0.32156822],
        [-0.96134275, -0.9604567 ,  0.06584168, ..., -0.16210744,
         -0.89897835,  1.0575923 ],
        ...,
        [-0.7177087 , -0.29554528, -0.85624325, ...,  0.        ,
         -0.7281318 ,  2.2388465 ],
        [-0.7177087 , -0.29554528, -0.85624325, ...,  0.        ,
         -0.9590941 , -0.49117133],
        [ 1.5180172 , -0.29554528, -0.85624325, ...,  0.        ,
         -0.87318116, -0.49117133]], dtype=float32)]

In [79]:
len(data.x)

2

In [80]:
data.x[0], data.x[0].shape

(<760x760 sparse matrix of type '<class 'numpy.float64'>'
 	with 3032 stored elements in COOrdinate format>,
 (760, 760))

In [81]:
data.x[1], data.x[1].shape

(array([[-0.96134275,  0.976699  , -1.028233  , ..., -0.16210744,
          0.10654867,  0.72191226],
        [-0.96134275,  0.67003554,  0.12525097, ..., -0.16210744,
         -0.7449631 ,  0.32156822],
        [-0.96134275, -0.9604567 ,  0.06584168, ..., -0.16210744,
         -0.89897835,  1.0575923 ],
        ...,
        [-0.7177087 , -0.29554528, -0.85624325, ...,  0.        ,
         -0.7281318 ,  2.2388465 ],
        [-0.7177087 , -0.29554528, -0.85624325, ...,  0.        ,
         -0.9590941 , -0.49117133],
        [ 1.5180172 , -0.29554528, -0.85624325, ...,  0.        ,
         -0.87318116, -0.49117133]], dtype=float32),
 (760, 2000))

In [82]:
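(adj, x), y = data.get_data(return_type="default")
x, y = torch.FloatTensor(x), torch.FloatTensor(y.values)
# Convert the scipy COO adjacency into a torch sparse tensor. Keep the edge weights
# as floats: casting them to an integer type would truncate fractional weights to 0.
adj = torch.sparse_coo_tensor(torch.LongTensor([adj.row.tolist(), adj.col.tolist()]),
                              torch.FloatTensor(adj.data.astype(np.float32)), size=adj.shape)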

In [83]:
x, x.shape

(tensor([[-0.9613,  0.9767, -1.0282,  ..., -0.1621,  0.1065,  0.7219],
         [-0.9613,  0.6700,  0.1253,  ..., -0.1621, -0.7450,  0.3216],
         [-0.9613, -0.9605,  0.0658,  ..., -0.1621, -0.8990,  1.0576],
         ...,
         [-0.7177, -0.2955, -0.8562,  ...,  0.0000, -0.7281,  2.2388],
         [-0.7177, -0.2955, -0.8562,  ...,  0.0000, -0.9591, -0.4912],
         [ 1.5180, -0.2955, -0.8562,  ...,  0.0000, -0.8732, -0.4912]]),
 torch.Size([760, 2000]))

In [84]:
adj, adj.shape

(tensor(indices=tensor([[  0,   0,   0,  ..., 759, 759, 759],
                        [  0, 563, 647,  ..., 295, 374, 759]]),
        values=tensor([0., 0., 0.,  ..., 0., 0., 0.]),
        size=(760, 760), nnz=3032, layout=torch.sparse_coo),
 torch.Size([760, 760]))

In [85]:
y, y.shape

(tensor([[0.2000, 0.1000, 0.0000, 0.4000, 0.0000, 0.3000],
         [0.3000, 0.1000, 0.1000, 0.2000, 0.2000, 0.1000],
         [0.0000, 0.0000, 0.0000, 0.2222, 0.1111, 0.6667],
         ...,
         [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000],
         [0.1429, 0.0000, 0.0000, 0.4286, 0.4286, 0.0000],
         [0.0000, 0.0000, 0.2500, 0.7500, 0.0000, 0.0000]]),
 torch.Size([760, 6]))

In [86]:
train_mask = data.get_split_mask("pseudo", return_type="torch")
inputs = (adj, x, train_mask)
train_mask, train_mask.shape

(tensor([False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

#### Train and evaluate model

In [87]:
model = DSTG(nhid=16, bias=False, dropout=0, device="auto")
pred = model.fit_predict(inputs, y, lr=0.01, max_epochs=25, weight_decay=0.0001)
pred, pred.shape

[INFO][2023-06-28 04:48:55,799][dance][fit] Epoch: 0005, train_loss=1.79176, time=0.00161
[INFO][2023-06-28 04:48:55,823][dance][fit] Epoch: 0010, train_loss=1.79176, time=0.00147
[INFO][2023-06-28 04:48:55,846][dance][fit] Epoch: 0015, train_loss=1.79176, time=0.00167
[INFO][2023-06-28 04:48:55,871][dance][fit] Epoch: 0020, train_loss=1.79176, time=0.00163
[INFO][2023-06-28 04:48:55,894][dance][fit] Epoch: 0025, train_loss=1.79176, time=0.00148


(tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         ...,
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]], device='cuda:0',
        grad_fn=<SoftmaxBackward0>),
 torch.Size([760, 6]))
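Every row of `pred` is 1/6 ≈ 0.1667, and the logged training loss stays at 1.79176 from epoch 5 to epoch 25. Both match what a softmax produces when all logits are equal, for example all zero. Each of the 6 cell types then gets probability 1/6. Assuming a cross-entropy loss, the loss against any target proportions that sum to 1 is then ln 6 ≈ 1.79176. This pattern suggests the model did not learn in this run. One possible cause is consistent with the all-zero `values` in the adjacency printed above: casting fractional edge weights to an integer type truncates them to 0. The small NumPy check below illustrates both facts.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# With all-zero propagated features (e.g. an all-zero adjacency) and no bias,
# every logit is 0, so every class gets the same probability.
logits = np.zeros((3, 6))
probs = softmax(logits)
print(probs[0])  # six entries, each 1/6

# Cross-entropy of the uniform prediction against target proportions summing to 1
# (the first row of y shown earlier).
target = np.array([0.2, 0.1, 0.0, 0.4, 0.0, 0.3])
loss = -(target * np.log(probs[0])).sum()
print(round(loss, 5), round(np.log(6), 5))  # 1.79176 1.79176

# Casting fractional edge weights to an integer type truncates them to 0.
weights = np.array([0.25, 0.5, 0.99])
print(weights.astype(np.int32))  # [0 0 0]
```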

In [88]:
test_mask = data.get_split_mask("test", return_type="torch")
test_mask, test_mask.shape

(tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [89]:
score = model.default_score_func(y[test_mask], pred[test_mask])
print(f"MSE: {score:7.4f}")

MSE:  0.0239
