In this demo, we perform scMTO clustering analysis on the [Qx_Spleen](https://drive.google.com/drive/folders/1BIZxZNbouPtGf_cyu7vM44G5EcbxECeu) dataset (referred to as `Quake_10x_Spleen` in the code below).

## Import Python packages

In [6]:
import argparse
import random
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from scMTO.preprocess import prepro, normalize
from scMTO.model import scMTO
import torch.backends.cudnn as cudnn

## Parameter settings

In [11]:
# Reproducibility: restrict cuDNN to deterministic kernels. `benchmark` must stay False,
# otherwise cuDNN auto-tunes per run and may select different (non-deterministic) algorithms.
cudnn.deterministic = True
cudnn.benchmark = False
cudnn.enabled = True
torch.set_num_threads(2)  # cap the number of CPU threads torch uses for intra-op work

parser = argparse.ArgumentParser(description='train', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--name', type=str, default='Quake_10x_Spleen', help="name of scRNA-seq dataset")
parser.add_argument('--pre_lr', type=float, default=4e-4, help="learning rate of pre-training")
parser.add_argument('--lr', type=float, default=1e-5, help="learning rate of formal-training")
parser.add_argument('--pre_epoch', type=int, default=250, help="epoch numbers of pre-training")
parser.add_argument('--train_epoch', type=int, default=500, help="epoch numbers of formal-training")
parser.add_argument('--latent_dim', default=32, type=int, help="dimension of latent space")
parser.add_argument('--device', type=int, default=3, help="index of the CUDA device to use (ignored when CUDA is unavailable)")
# Parse an explicit empty list so the notebook kernel's own command-line arguments are not read as options.
args = parser.parse_args(args=[])

# Seed every RNG the pipeline may draw from (Python, torch CPU/CUDA, NumPy).
random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
np.random.seed(1)

args.cuda = torch.cuda.is_available()
if args.cuda:
    # Only select a GPU when CUDA exists; calling set_device unconditionally crashes CPU-only runs.
    torch.cuda.set_device(args.device)
    device = torch.device("cuda", args.device)
else:
    device = torch.device("cpu")

file_path = "../data/" + args.name + "/data.h5"
print("dataset: {}".format(args.name))


dataset: Quake_10x_Spleen


## Read data

In [12]:
# Load the expression matrix and labels via scMTO's `prepro`, then round every value
# up to a whole number and store the result as float32.
expr_loaded, y = prepro(file_path)
x = np.ceil(expr_loaded).astype(np.float32)

## Single-cell multi-view feature space construction

In [13]:
# Views 1 and 2 share the same preprocessing and differ only in how many highly
# variable genes (HVGs) they keep, so both are built through one helper.
def build_normalized_view(counts, labels, n_hvg):
    """Wrap `counts` in a fresh AnnData, attach `labels`, and run scMTO's `normalize`.

    Parameters
    ----------
    counts : array-like
        Cell-by-gene matrix (the rounded `x` loaded above).
    labels : array-like
        Per-cell labels, stored as obs['Group'].
    n_hvg : int
        Number of highly variable genes, passed to `normalize` as `highly_genes`.

    Returns
    -------
    The AnnData returned by `normalize` (called with copy=True and size factors,
    input normalization and log transform enabled).
    """
    view = sc.AnnData(counts)
    view.obs['Group'] = labels
    return normalize(view, copy=True, highly_genes=n_hvg, size_factors=True,
                     normalize_input=True, logtrans_input=True)


# Multi-View 1 (normalized, 2000 HVGs)
adata1 = build_normalized_view(x, y, n_hvg=2000)

# Multi-View 2 (normalized, 500 HVGs)
adata2 = build_normalized_view(x, y, n_hvg=500)

# Multi-View 3 (sparse topic patterns, 2000 HVGs): integer counts of the view-1 genes.
# The var names of an AnnData built from a plain array are the stringified column positions of `x`.
highly_genes_index = [int(gene_idx) for gene_idx in adata1.var.index]
# Select from `.raw` by gene *name*. Using the parsed positions as column indices into `.raw`
# is only correct if `.raw` still holds every gene of `x` in the original order.
# NOTE(review): this matters if `normalize` filters genes before storing `.raw`; confirm in scMTO.preprocess.
raw_data = np.ceil(adata1.raw[:, adata1.var_names].X).astype(int)

count = [adata1.X, adata2.X]
# Number of distinct labels. `max(y) - min(y) + 1` overcounts when label values have gaps
# and fails for non-numeric labels.
args.n_clusters = int(len(np.unique(y)))

## Initialize the model

In [14]:
# Build the scMTO model from the view-3 integer counts and the parsed hyperparameters,
# then move it to the selected device. Construction prints its own progress messages.
model = scMTO(
    n_z=args.latent_dim,
    n_clusters=args.n_clusters,
    x_raw=raw_data,
    device=device,
).to(device)

Calculating the cell graph...
Topic modeling...


## Pre-training stage

In [15]:
# Pre-training stage: train on the two normalized views plus the raw integer counts,
# using the pre-training learning rate and epoch budget from `args`.
model.pretrain(
    x=count,
    raw_data=raw_data,
    pre_lr=args.pre_lr,
    pre_epoch=args.pre_epoch,
)

scMTO pretraining...
0: Pretraining Loss:0.395810
1: Pretraining Loss:0.386074
2: Pretraining Loss:0.378549
3: Pretraining Loss:0.370711
4: Pretraining Loss:0.360898
5: Pretraining Loss:0.348554
6: Pretraining Loss:0.333225
7: Pretraining Loss:0.315368
8: Pretraining Loss:0.295715
9: Pretraining Loss:0.274909
10: Pretraining Loss:0.254963
11: Pretraining Loss:0.236771
12: Pretraining Loss:0.221970
13: Pretraining Loss:0.210635
14: Pretraining Loss:0.202552
15: Pretraining Loss:0.197194
16: Pretraining Loss:0.193612
17: Pretraining Loss:0.190770
18: Pretraining Loss:0.187992
19: Pretraining Loss:0.185721
20: Pretraining Loss:0.183317
21: Pretraining Loss:0.181183
22: Pretraining Loss:0.179022
23: Pretraining Loss:0.177127
24: Pretraining Loss:0.175174
25: Pretraining Loss:0.173506
26: Pretraining Loss:0.171849
27: Pretraining Loss:0.170250
28: Pretraining Loss:0.168920
29: Pretraining Loss:0.167585
30: Pretraining Loss:0.166134
31: Pretraining Loss:0.164913
32: Pretraining Loss:0.164002

## Clustering stage

In [16]:
# Clustering stage: train with the labels `y` supplied for evaluation. The call returns
# seven values; `pred` holds the predicted assignments and acc/nmi/ari the scores reported below.
z, w, h, pred, acc, nmi, ari = model.fit(
    x=count,
    y=y,
    raw_data=raw_data,
    lr=args.lr,
    train_epoch=args.train_epoch,
)

Clustering initialization: K-means: ACC: 0.975607202680067, NMI: 0.8674018678751776, ARI:0.9407655633996225
scMTO clustering...
0 :ACC 0.9750837521, NMI 0.8657787203, ARI 0.9398070068
0: Total Loss 0.2580, ZINB Loss 0.1317, OT Loss 0.3246, KL Loss 0.0829
delta_label  0.0036641541038525964 < tol  0.005
Reached tolerance threshold. Stopping training.


## Evaluation

In [17]:
# Report the final clustering scores (ACC, NMI, ARI), but only when labels are available.
if y is not None:
    print('The End: ACC {:.4f}, NMI {:.4f}, ARI {:.4f}'.format(acc, nmi, ari))

The End: ACC 0.9758, NMI 0.8685, ARI 0.9409
