In [1]:
datasets= ["mouse_kidney_cell","human_pbmc_cell","human_pbmc2_cell","human_skin_cell","human_lung_cell","mouse_kidney_drop","mouse_kidney_cl2","mouse_kidney_10x"]

In [2]:
import argparse

import numpy as np

from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.graphsc import GraphSC
from dance.utils import set_seed


parser = argparse.ArgumentParser()
parser.add_argument("-e", "--epochs", default=100, type=int)
parser.add_argument("-dv", "--device", default="cpu")
parser.add_argument("-if", "--in_feats", default=50, type=int)
parser.add_argument("-bs", "--batch_size", default=128, type=int)
parser.add_argument("-nw", "--normalize_weights", default="log_per_cell", choices=["log_per_cell", "per_cell"])
parser.add_argument("-ac", "--activation", default="relu", choices=["leaky_relu", "relu", "prelu", "gelu"])
parser.add_argument("-drop", "--dropout", default=0.1, type=float)
parser.add_argument("-nf", "--node_features", default="scale", choices=["scale_by_cell", "scale", "none"])
parser.add_argument("-sev", "--same_edge_values", default=False, action="store_true")
parser.add_argument("-en", "--edge_norm", default=True, action="store_true")
parser.add_argument("-hr", "--hidden_relu", default=False, action="store_true")
parser.add_argument("-hbn", "--hidden_bn", default=False, action="store_true")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-5)
parser.add_argument("-nl", "--n_layers", type=int, default=1, choices=[1, 2])
parser.add_argument("-agg", "--agg", default="sum", choices=["sum", "mean"])
parser.add_argument("-hd", "--hidden_dim", type=int, default=200)
parser.add_argument("-nh", "--n_hidden", type=int, default=1, choices=[0, 1, 2])
parser.add_argument("-h1", "--hidden_1", type=int, default=300)
parser.add_argument("-h2", "--hidden_2", type=int, default=0)
parser.add_argument("-ng", "--nb_genes", type=int, default=3000)
parser.add_argument("-nr", "--num_run", type=int, default=1)
parser.add_argument("-nbw", "--num_workers", type=int, default=1)
parser.add_argument("-eve", "--eval_epoch", action="store_true")
parser.add_argument("-show", "--show_epoch_ari", action="store_true")
parser.add_argument("-plot", "--plot", default=False, action="store_true")
parser.add_argument("-dd", "--data_dir", default="./data", type=str)
parser.add_argument("-data", "--dataset", default="10X_PBMC",
                    choices=["10X_PBMC", "mouse_bladder_cell", "mouse_ES_cell", "worm_neuron_cell","mouse_kidney_cell","human_pbmc_cell","human_pbmc2_cell","human_skin_cell","human_lung_cell","mouse_kidney_drop","mouse_kidney_cl2","mouse_kidney_10x"])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--cache", action="store_true", help="Cache processed data.")
graphsc_scores=[]
for dataset in datasets:
    args = parser.parse_args(args=['--dataset',dataset])
    set_seed(args.seed)

    # Load data and perform necessary preprocessing
    dataloader = ClusteringDataset(args.data_dir, args.dataset)
    preprocessing_pipeline = GraphSC.preprocessing_pipeline(
        n_top_genes=args.nb_genes,
        normalize_weights=args.normalize_weights,
        n_components=args.in_feats,
        normalize_edges=args.edge_norm,
    )
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    graph, y = data.get_train_data()
    n_clusters = len(np.unique(y))

    # Evaluate model for several runs
    for run in range(args.num_run):
        set_seed(args.seed + run)
        model = GraphSC(agg=args.agg, activation=args.activation, in_feats=args.in_feats, n_hidden=args.n_hidden,
                        hidden_dim=args.hidden_dim, hidden_1=args.hidden_1, hidden_2=args.hidden_2,
                        dropout=args.dropout, n_layers=args.n_layers, hidden_relu=args.hidden_relu,
                        hidden_bn=args.hidden_bn, n_clusters=n_clusters, cluster_method="leiden",
                        num_workers=args.num_workers, device=args.device)
        model.fit(graph, epochs=args.epochs, lr=args.learning_rate, show_epoch_ari=args.show_epoch_ari,
                  eval_epoch=args.eval_epoch)
        score = model.score(None, y)
        print(f"{score=:.4f}")
        graphsc_scores.append(score)
""" Reproduction information
10X PBMC:
python graphsc.py --dataset 10X_PBMC

Mouse ES:
python graphsc.py --dataset mouse_ES_cell

Worm Neuron:
python graphsc.py --dataset worm_neuron_cell

Mouse Bladder:
python graphsc.py --dataset mouse_bladder_cell
"""


[INFO][2023-09-02 15:44:15,069][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 15:44:15,235][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 15:44:15,236][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 3000, 'subset': True}),
  AnnDataTransform(func=scanpy.preproc



[INFO][2023-09-02 15:44:15,588][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:44:15,597][dance.WeightedFeaturePCA][__call__] Start decomposing None features (902, 3000) (k=50)
[INFO][2023-09-02 15:44:15,636][dance.WeightedFeaturePCA][__call__] Total explained variance: 47.53%
[INFO][2023-09-02 15:44:15,718][dance.CellFeatureGraph][__call__] Number of nonzero entries: 1,423,104
[INFO][2023-09-02 15:44:15,719][dance.CellFeatureGraph][__call__] Nonzero rate = 52.6%
[INFO][2023-09-02 15:44:16,445][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:44:16,445][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:44:16,446][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09-

score=0.4672


[INFO][2023-09-02 15:46:19,733][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 648 × 64535
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 15:46:19,734][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 3000, 'subset': True}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessin



[INFO][2023-09-02 15:46:20,202][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:46:20,207][dance.WeightedFeaturePCA][__call__] Start decomposing None features (648, 3000) (k=50)
[INFO][2023-09-02 15:46:20,233][dance.WeightedFeaturePCA][__call__] Total explained variance: 17.32%
[INFO][2023-09-02 15:46:20,253][dance.CellFeatureGraph][__call__] Number of nonzero entries: 114,062
[INFO][2023-09-02 15:46:20,254][dance.CellFeatureGraph][__call__] Nonzero rate = 5.9%
[INFO][2023-09-02 15:46:20,820][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:46:20,820][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:46:20,821][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09-02 

score=0.5969


[INFO][2023-09-02 15:46:32,055][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:46:32,078][dance.WeightedFeaturePCA][__call__] Start decomposing None features (8218, 1000) (k=50)
[INFO][2023-09-02 15:46:32,144][dance.WeightedFeaturePCA][__call__] Total explained variance: 29.20%
[INFO][2023-09-02 15:46:32,314][dance.CellFeatureGraph][__call__] Number of nonzero entries: 3,923,204
[INFO][2023-09-02 15:46:32,315][dance.CellFeatureGraph][__call__] Nonzero rate = 47.7%
[INFO][2023-09-02 15:46:33,993][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:46:33,994][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:46:33,995][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09

score=0.2442


[INFO][2023-09-02 15:52:53,704][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:52:53,718][dance.WeightedFeaturePCA][__call__] Start decomposing None features (4853, 1000) (k=50)
[INFO][2023-09-02 15:52:53,767][dance.WeightedFeaturePCA][__call__] Total explained variance: 46.04%
[INFO][2023-09-02 15:52:53,885][dance.CellFeatureGraph][__call__] Number of nonzero entries: 2,819,267
[INFO][2023-09-02 15:52:53,886][dance.CellFeatureGraph][__call__] Nonzero rate = 58.1%
[INFO][2023-09-02 15:52:54,999][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:52:55,000][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:52:55,000][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09

score=0.4977


[INFO][2023-09-02 15:57:06,967][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:57:06,972][dance.WeightedFeaturePCA][__call__] Start decomposing None features (1756, 999) (k=50)
[INFO][2023-09-02 15:57:07,000][dance.WeightedFeaturePCA][__call__] Total explained variance: 29.91%
[INFO][2023-09-02 15:57:07,026][dance.CellFeatureGraph][__call__] Number of nonzero entries: 359,184
[INFO][2023-09-02 15:57:07,027][dance.CellFeatureGraph][__call__] Nonzero rate = 20.5%
[INFO][2023-09-02 15:57:07,521][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:57:07,522][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:57:07,522][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09-02

score=0.5912


[INFO][2023-09-02 15:57:31,516][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:57:31,519][dance.WeightedFeaturePCA][__call__] Start decomposing None features (225, 3000) (k=50)
[INFO][2023-09-02 15:57:31,532][dance.WeightedFeaturePCA][__call__] Total explained variance: 46.23%
[INFO][2023-09-02 15:57:31,544][dance.CellFeatureGraph][__call__] Number of nonzero entries: 235,178
[INFO][2023-09-02 15:57:31,545][dance.CellFeatureGraph][__call__] Nonzero rate = 34.8%
[INFO][2023-09-02 15:57:32,052][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:57:32,053][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:57:32,053][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09-02




[INFO][2023-09-02 15:57:51,724][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 15:57:51,748][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 15127
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 15:57:51,749][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 3000, 'subset': True}),
  AnnDataTransform(func=scanpy.prepro

score=0.9097


[INFO][2023-09-02 15:57:51,853][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:57:51,855][dance.WeightedFeaturePCA][__call__] Start decomposing None features (225, 3000) (k=50)
[INFO][2023-09-02 15:57:51,868][dance.WeightedFeaturePCA][__call__] Total explained variance: 46.23%
[INFO][2023-09-02 15:57:51,881][dance.CellFeatureGraph][__call__] Number of nonzero entries: 235,178
[INFO][2023-09-02 15:57:51,881][dance.CellFeatureGraph][__call__] Nonzero rate = 34.8%
[INFO][2023-09-02 15:57:52,387][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:57:52,388][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:57:52,388][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09-02




[INFO][2023-09-02 15:58:12,368][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 15:58:12,500][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 15:58:12,500][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 3000, 'subset': True}),
  AnnDataTransform(func=scanpy.prepro

score=0.9097


[INFO][2023-09-02 15:58:12,974][dance.WeightedFeaturePCA][__call__] Normalizing feature before PCA decomposition with mode=standardize and axis=0
[INFO][2023-09-02 15:58:12,983][dance.WeightedFeaturePCA][__call__] Start decomposing None features (902, 3000) (k=50)
[INFO][2023-09-02 15:58:13,018][dance.WeightedFeaturePCA][__call__] Total explained variance: 47.53%
[INFO][2023-09-02 15:58:13,079][dance.CellFeatureGraph][__call__] Number of nonzero entries: 1,423,104
[INFO][2023-09-02 15:58:13,080][dance.CellFeatureGraph][__call__] Nonzero rate = 52.6%
[INFO][2023-09-02 15:58:13,786][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': 'CellFeatureGraph',
 'feature_channel_type': 'uns',
 'label_channel': 'Group'}
[INFO][2023-09-02 15:58:13,787][dance][set_config_from_dict] Setting config 'feature_channel' to 'CellFeatureGraph'
[INFO][2023-09-02 15:58:13,787][dance][set_config_from_dict] Setting config 'feature_channel_type' to 'uns'
[INFO][2023-09-

score=0.4672


' Reproduction information\n10X PBMC:\npython graphsc.py --dataset 10X_PBMC\n\nMouse ES:\npython graphsc.py --dataset mouse_ES_cell\n\nWorm Neuron:\npython graphsc.py --dataset worm_neuron_cell\n\nMouse Bladder:\npython graphsc.py --dataset mouse_bladder_cell\n'

In [3]:
graphsc_scores

[0.4672362410138263,
 0.5968663719240431,
 0.24416902238895835,
 0.49769852192025643,
 0.5912071617671083,
 0.9097075531594662,
 0.9097075531594662,
 0.4672362410138263]

In [4]:
import argparse
import os

import numpy as np

from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdcc import ScDCC
from dance.transforms.preprocess import generate_random_pair
from dance.utils import set_seed

parser = argparse.ArgumentParser(description="train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--label_cells", default=0.1, type=float)
parser.add_argument("--label_cells_files", default="label_mouse_ES_cell.txt")
parser.add_argument("--n_pairwise", default=0, type=int)
parser.add_argument("--n_pairwise_error", default=0, type=float)
parser.add_argument("--batch_size", default=256, type=int)
parser.add_argument("--data_dir", default="./data")
parser.add_argument("--dataset", default="mouse_ES_cell", type=str,
                    choices=["10X_PBMC", "mouse_bladder_cell", "mouse_ES_cell", "worm_neuron_cell","mouse_kidney_cell","human_pbmc_cell","human_pbmc2_cell","human_skin_cell","human_lung_cell","mouse_kidney_drop","mouse_kidney_cl2","mouse_kidney_10x"])
parser.add_argument("--epochs", default=500, type=int)
parser.add_argument("--pretrain_epochs", default=50, type=int)
parser.add_argument("--lr", default=0.01, type=float)
parser.add_argument("--pretrain_lr", default=0.001, type=float)
parser.add_argument("--sigma", default=2.5, type=float, help="coefficient of Gaussian noise")
parser.add_argument("--gamma", default=1., type=float, help="coefficient of clustering loss")
parser.add_argument("--ml_weight", default=1., type=float, help="coefficient of must-link loss")
parser.add_argument("--cl_weight", default=1., type=float, help="coefficient of cannot-link loss")
parser.add_argument("--update_interval", default=1, type=int)
parser.add_argument("--tol", default=0.00001, type=float)
parser.add_argument("--ae_weights", default=None)
parser.add_argument("--ae_weight_file", default="AE_weights.pth.tar")
parser.add_argument("--device", default="cpu")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--cache", action="store_true", help="Cache processed data.")
scdcc_scores=[]
for dataset in datasets:
    args = parser.parse_args(['--dataset',dataset,'--label_cells_files','label_mouse_kidney.txt','--gamma','1.5'])
    set_seed(args.seed)

    # Load data and perform necessary preprocessing
    dataloader = ClusteringDataset(args.data_dir, args.dataset)
    preprocessing_pipeline = ScDCC.preprocessing_pipeline()
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    # inputs: x, x_raw, n_counts
    inputs, y = data.get_train_data()
    n_clusters = len(np.unique(y))
    in_dim = inputs[0].shape[1]

    # Generate random pairs
    if not os.path.exists(args.label_cells_files):
        indx = np.arange(len(y))
        np.random.shuffle(indx)
        label_cell_indx = indx[0:int(np.ceil(args.label_cells * len(y)))]
    else:
        label_cell_indx = np.loadtxt(args.label_cells_files, dtype=np.int)

    if args.n_pairwise > 0:
        ml_ind1, ml_ind2, cl_ind1, cl_ind2, error_num = generate_random_pair(y, label_cell_indx, args.n_pairwise,
                                                                             args.n_pairwise_error)
        print("Must link paris: %d" % ml_ind1.shape[0])
        print("Cannot link paris: %d" % cl_ind1.shape[0])
        print("Number of error pairs: %d" % error_num)
    else:
        ml_ind1, ml_ind2, cl_ind1, cl_ind2 = np.array([]), np.array([]), np.array([]), np.array([])

    # Build and train moodel
    model = ScDCC(input_dim=in_dim, z_dim=32, n_clusters=n_clusters, encodeLayer=[256, 64], decodeLayer=[64, 256],
                  sigma=args.sigma, gamma=args.gamma, ml_weight=args.ml_weight, cl_weight=args.ml_weight,
                  device=args.device, pretrain_path=f"scdcc_{args.dataset}_pre.pkl")
    model.fit(inputs, y, lr=args.lr, batch_size=args.batch_size, epochs=args.epochs, ml_ind1=ml_ind1, ml_ind2=ml_ind2,
              cl_ind1=cl_ind1, cl_ind2=cl_ind2, update_interval=args.update_interval, tol=args.tol,
              pt_batch_size=args.batch_size, pt_lr=args.pretrain_lr, pt_epochs=args.pretrain_epochs)

    # Evaluate model predictions
    score = model.score(None, y)
    print(f"{score=:.4f}")
    scdcc_scores.append(score)
""" Reproduction information
10X PBMC:
python scdcc.py --dataset 10X_PBMC --label_cells_files label_10X_PBMC.txt --gamma=1.5

Mouse ES:
python scdcc.py --dataset mouse_ES_cell --label_cells_files label_mouse_ES_cell.txt --gamma 1 --ml_weight 0.8 --cl_weight 0.8

Worm Neuron:
python scdcc.py --dataset worm_neuron_cell --label_cells_files label_worm_neuron_cell.txt --gamma 1 --pretrain_epochs 300

Mouse Bladder:
python scdcc.py --dataset mouse_bladder_cell --label_cells_files label_mouse_bladder_cell.txt --gamma 1.5 --pretrain_epochs 100 --sigma 3
"""


[INFO][2023-09-02 16:00:11,499][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:11,625][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:11,626][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.scale, func_kwargs={}),
  SetConfig(config_dict={'feature_channel': [None, None, 'n_counts'], 'feature_channel_type': ['X', 'raw_X', 'obs'], 'label_channel': 'Group'}),
)
[INFO][2023-09-02 16:

score=0.9965


[INFO][2023-09-02 16:00:13,466][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 648 × 64535
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:13,467][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.scale, func_kwargs={}),
  SetConfig(config_dict={'feature_channel': [None, None, 'n_counts'], 'feature_channel_type': ['X', 'raw_X', 'obs'], 'label_channel': 'Group'}),
)
[INFO][2023-09-02 16:00:13,660][dance.SaveRaw][__call__] Saving data to ``.raw``
[INFO][2023-09-02 16:0

score=0.2302


[INFO][2023-09-02 16:00:17,046][dance][fit] #Epoch   1: Total: 0.9746, Clustering Loss: 0.2618, ZINB Loss: 0.7129
[INFO][2023-09-02 16:00:17,599][dance][fit] #Epoch   2: Total: 0.9754, Clustering Loss: 0.2626, ZINB Loss: 0.7129
[INFO][2023-09-02 16:00:18,151][dance][fit] #Epoch   3: Total: 0.9743, Clustering Loss: 0.2623, ZINB Loss: 0.7120
[INFO][2023-09-02 16:00:18,719][dance][fit] #Epoch   4: Total: 0.9751, Clustering Loss: 0.2618, ZINB Loss: 0.7133
[INFO][2023-09-02 16:00:19,273][dance][fit] #Epoch   5: Total: 0.9747, Clustering Loss: 0.2610, ZINB Loss: 0.7136
[INFO][2023-09-02 16:00:19,825][dance][fit] #Epoch   6: Total: 0.9739, Clustering Loss: 0.2602, ZINB Loss: 0.7136
[INFO][2023-09-02 16:00:20,377][dance][fit] #Epoch   7: Total: 0.9722, Clustering Loss: 0.2593, ZINB Loss: 0.7130
[INFO][2023-09-02 16:00:20,927][dance][fit] #Epoch   8: Total: 0.9721, Clustering Loss: 0.2582, ZINB Loss: 0.7139
[INFO][2023-09-02 16:00:21,479][dance][fit] #Epoch   9: Total: 0.9707, Clustering Loss: 

score=0.5339


[INFO][2023-09-02 16:00:22,355][dance][fit] #Epoch   1: Total: 0.9328, Clustering Loss: 0.3156, ZINB Loss: 0.6172
[INFO][2023-09-02 16:00:22,683][dance][fit] #Epoch   2: Total: 0.9370, Clustering Loss: 0.3203, ZINB Loss: 0.6168
[INFO][2023-09-02 16:00:23,007][dance][fit] #Epoch   3: Total: 0.9392, Clustering Loss: 0.3230, ZINB Loss: 0.6162
[INFO][2023-09-02 16:00:23,330][dance][fit] #Epoch   4: Total: 0.9411, Clustering Loss: 0.3249, ZINB Loss: 0.6162
[INFO][2023-09-02 16:00:23,660][dance][fit] #Epoch   5: Total: 0.9435, Clustering Loss: 0.3263, ZINB Loss: 0.6172
[INFO][2023-09-02 16:00:23,984][dance][fit] #Epoch   6: Total: 0.9426, Clustering Loss: 0.3275, ZINB Loss: 0.6152
[INFO][2023-09-02 16:00:24,308][dance][fit] #Epoch   7: Total: 0.9441, Clustering Loss: 0.3284, ZINB Loss: 0.6156
[INFO][2023-09-02 16:00:24,631][dance][fit] #Epoch   8: Total: 0.9455, Clustering Loss: 0.3292, ZINB Loss: 0.6162
[INFO][2023-09-02 16:00:24,956][dance][fit] #Epoch   9: Total: 0.9464, Clustering Loss: 

score=0.6258


[INFO][2023-09-02 16:00:31,819][dance][fit] #Epoch   1: Total: 0.7320, Clustering Loss: 0.2235, ZINB Loss: 0.5085
[INFO][2023-09-02 16:00:31,844][dance][fit] Reach tolerance threshold (0.000e+00 < 1.000e-05). Stopping training.
[INFO][2023-09-02 16:00:31,846][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:31,871][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 15127
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:31,871][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransfor

score=0.7682


[INFO][2023-09-02 16:00:32,254][dance][fit] #Epoch   1: Total: 0.7840, Clustering Loss: 0.1484, ZINB Loss: 0.6356
[INFO][2023-09-02 16:00:32,283][dance][fit] Reach tolerance threshold (0.000e+00 < 1.000e-05). Stopping training.
[INFO][2023-09-02 16:00:32,286][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:32,309][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 15127
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:32,309][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransfor

score=0.8688


[INFO][2023-09-02 16:00:32,671][dance][fit] #Epoch   1: Total: 0.7840, Clustering Loss: 0.1484, ZINB Loss: 0.6356
[INFO][2023-09-02 16:00:32,702][dance][fit] Reach tolerance threshold (0.000e+00 < 1.000e-05). Stopping training.
[INFO][2023-09-02 16:00:32,704][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:32,780][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:32,780][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransfor

score=0.8688


[INFO][2023-09-02 16:00:32,957][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': [None, None, 'n_counts'],
 'feature_channel_type': ['X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:00:32,957][dance][set_config_from_dict] Setting config 'feature_channel' to [None, None, 'n_counts']
[INFO][2023-09-02 16:00:32,958][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['X', 'raw_X', 'obs']
[INFO][2023-09-02 16:00:32,959][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:00:32,960][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    obs: 'n_counts'
    var: 'n_counts', 'mean', 'std'
    uns: 'dance_config', 'log1p'
    obsm: 'Group'
[INFO][2023-09-02 16:00:32,960][dance][wrapped_func] Took 0:00:00.255337 to load and process data.
[INFO][2023-09-02 16:00:33,057][dance][_pretrain] Loading pre-trained 

score=1.0000


' Reproduction information\n10X PBMC:\npython scdcc.py --dataset 10X_PBMC --label_cells_files label_10X_PBMC.txt --gamma=1.5\n\nMouse ES:\npython scdcc.py --dataset mouse_ES_cell --label_cells_files label_mouse_ES_cell.txt --gamma 1 --ml_weight 0.8 --cl_weight 0.8\n\nWorm Neuron:\npython scdcc.py --dataset worm_neuron_cell --label_cells_files label_worm_neuron_cell.txt --gamma 1 --pretrain_epochs 300\n\nMouse Bladder:\npython scdcc.py --dataset mouse_bladder_cell --label_cells_files label_mouse_bladder_cell.txt --gamma 1.5 --pretrain_epochs 100 --sigma 3\n'

In [5]:
scdcc_scores

[0.9965362702578435,
 0.2302306982490099,
 0.5338897606315373,
 0.6257754637662535,
 0.7682366319440681,
 0.8688432600660997,
 0.8688432600660997,
 1.0]

In [6]:
import argparse

import numpy as np

from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdeepcluster import ScDeepCluster
from dance.utils import set_seed


parser = argparse.ArgumentParser(description="train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--knn", default=20, type=int,
                    help="number of nearest neighbors, used by the Louvain algorithm")
parser.add_argument(
    "--resolution", default=.8, type=float,
    help="resolution parameter, used by the Louvain algorithm, larger value for more number of clusters")
parser.add_argument("--select_genes", default=0, type=int, help="number of selected genes, 0 means using all genes")
parser.add_argument("--batch_size", default=256, type=int)
parser.add_argument("--data_dir", default="./data")
parser.add_argument("--dataset", default="mouse_bladder_cell", type=str,
                    choices=["10X_PBMC", "mouse_bladder_cell", "mouse_ES_cell", "worm_neuron_cell","mouse_kidney_cell","human_pbmc_cell","human_pbmc2_cell","human_skin_cell","human_lung_cell","mouse_kidney_drop","mouse_kidney_cl2","mouse_kidney_10x"])
parser.add_argument("--epochs", default=500, type=int)
parser.add_argument("--pretrain_epochs", default=50, type=int)
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--pretrain_lr", default=0.001, type=float)
parser.add_argument("--gamma", default=1., type=float, help="coefficient of clustering loss")
parser.add_argument("--sigma", default=2.5, type=float, help="coefficient of random noise")
parser.add_argument("--update_interval", default=1, type=int)
parser.add_argument("--tol", default=0.001, type=float,
                    help="tolerance for delta clustering labels to terminate training stage")
parser.add_argument("--ae_weights", default=None, help="file to pretrained weights, None for a new pretraining")
parser.add_argument("--device", default="cpu")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--cache", action="store_true", help="Cache processed data.")
scdeepcluster_scores=[]
for dataset in datasets:
    args = parser.parse_args(args=['--dataset',dataset])
    set_seed(args.seed)

    # Load data and perform necessary preprocessing
    dataloader = ClusteringDataset(args.data_dir, args.dataset)
    preprocessing_pipeline = ScDeepCluster.preprocessing_pipeline()
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    # inputs: x, x_raw, n_counts
    inputs, y = data.get_train_data()
    n_clusters = len(np.unique(y))
    in_dim = inputs[0].shape[1]

    # Build and train model
    model = ScDeepCluster(input_dim=in_dim, z_dim=32, encodeLayer=[256, 64], decodeLayer=[64, 256], sigma=args.sigma,
                          gamma=args.gamma, device=args.device, pretrain_path=f"scdeepcluster_{args.dataset}_pre.pkl")
    model.fit(inputs, y, n_clusters=n_clusters, y_pred_init=None, lr=args.lr, batch_size=args.batch_size,
              epochs=args.epochs, update_interval=args.update_interval, tol=args.tol, pt_batch_size=args.batch_size,
              pt_lr=args.pretrain_lr, pt_epochs=args.pretrain_epochs)

    # Evaluate model predictions
    score = model.score(None, y)
    print(f"{score=:.4f}")
    scdeepcluster_scores.append(score)
""" Reproduction information
10X PBMC:
python scdeepcluster.py --dataset 10X_PBMC

Mouse ES:
python scdeepcluster.py --dataset mouse_ES_cell

Worm Neuron:
python scdeepcluster.py --dataset worm_neuron_cell --pretrain_epochs 300

Mouse Bladder:
python scdeepcluster.py --dataset mouse_bladder_cell --pretrain_epochs 300 --sigma 2.75
"""


[INFO][2023-09-02 16:00:33,931][dance][set_seed] Setting global random seed to 42


[INFO][2023-09-02 16:00:34,055][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:34,056][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.scale, func_kwargs={}),
  SetConfig(config_dict={'feature_channel': [None, None, 'n_counts'], 'feature_channel_type': ['X', 'raw_X', 'obs'], 'label_channel': 'Group'}),
)
[INFO][2023-09-02 16:00:34,179][dance.SaveRaw][__call__] Saving data to ``.raw``
[INFO][2023-09-02 16:0

score=0.9965


[INFO][2023-09-02 16:00:35,451][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 648 × 64535
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:35,452][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.scale, func_kwargs={}),
  SetConfig(config_dict={'feature_channel': [None, None, 'n_counts'], 'feature_channel_type': ['X', 'raw_X', 'obs'], 'label_channel': 'Group'}),
)
[INFO][2023-09-02 16:00:35,638][dance.SaveRaw][__call__] Saving data to ``.raw``
[INFO][2023-09-02 16:0

score=0.2403


[INFO][2023-09-02 16:00:44,136][dance][_pretrain] Loading pre-trained model from scdeepcluster_human_pbmc2_cell_pre.pkl
[INFO][2023-09-02 16:00:44,143][dance][fit] Initializing cluster centers with kmeans.
[INFO][2023-09-02 16:00:45,150][dance][fit] Epoch   1: Total: 0.85924018, Clustering Loss: 0.14573694, ZINB Loss: 0.71350324
[INFO][2023-09-02 16:00:45,743][dance][fit] Epoch   2: Total: 0.86499669, Clustering Loss: 0.15021640, ZINB Loss: 0.71478028
[INFO][2023-09-02 16:00:46,288][dance][fit] Epoch   3: Total: 0.86306669, Clustering Loss: 0.14789919, ZINB Loss: 0.71516750
[INFO][2023-09-02 16:00:46,832][dance][fit] Epoch   4: Total: 0.86266700, Clustering Loss: 0.14452439, ZINB Loss: 0.71814260
[INFO][2023-09-02 16:00:47,374][dance][fit] Epoch   5: Total: 0.85929741, Clustering Loss: 0.13954957, ZINB Loss: 0.71974784
[INFO][2023-09-02 16:00:47,923][dance][fit] Epoch   6: Total: 0.85530772, Clustering Loss: 0.13404534, ZINB Loss: 0.72126238
[INFO][2023-09-02 16:00:48,471][dance][fit] 

score=0.5352


[INFO][2023-09-02 16:00:51,484][dance][fit] Epoch   1: Total: 0.79569265, Clustering Loss: 0.17811839, ZINB Loss: 0.61757427
[INFO][2023-09-02 16:00:51,804][dance][fit] Epoch   2: Total: 0.81139577, Clustering Loss: 0.19312188, ZINB Loss: 0.61827388
[INFO][2023-09-02 16:00:52,124][dance][fit] Epoch   3: Total: 0.81578227, Clustering Loss: 0.19511222, ZINB Loss: 0.62067005
[INFO][2023-09-02 16:00:52,443][dance][fit] Epoch   4: Total: 0.81771320, Clustering Loss: 0.19333718, ZINB Loss: 0.62437602
[INFO][2023-09-02 16:00:52,763][dance][fit] Epoch   5: Total: 0.81784995, Clustering Loss: 0.18925953, ZINB Loss: 0.62859042
[INFO][2023-09-02 16:00:53,082][dance][fit] Epoch   6: Total: 0.81448477, Clustering Loss: 0.18354857, ZINB Loss: 0.63093620
[INFO][2023-09-02 16:00:53,402][dance][fit] Epoch   7: Total: 0.81208446, Clustering Loss: 0.17711113, ZINB Loss: 0.63497332
[INFO][2023-09-02 16:00:53,721][dance][fit] Epoch   8: Total: 0.80960464, Clustering Loss: 0.17055195, ZINB Loss: 0.63905269


score=0.6335


[INFO][2023-09-02 16:00:55,674][dance][fit] Epoch   1: Total: 0.65420954, Clustering Loss: 0.14566157, ZINB Loss: 0.50854797
[INFO][2023-09-02 16:00:55,697][dance][fit] Reach tolerance threshold (5.695e-04 < 1.000e-03). Stopping training.
[INFO][2023-09-02 16:00:55,699][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:55,723][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 15127
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:55,724][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnD

score=0.7695


[INFO][2023-09-02 16:00:56,081][dance][fit] Epoch   1: Total: 0.73605788, Clustering Loss: 0.09888156, ZINB Loss: 0.63717633
[INFO][2023-09-02 16:00:56,110][dance][fit] Reach tolerance threshold (0.000e+00 < 1.000e-03). Stopping training.
[INFO][2023-09-02 16:00:56,113][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:56,137][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 15127
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:56,137][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnD

score=0.8565


[INFO][2023-09-02 16:00:56,491][dance][fit] Epoch   1: Total: 0.73605788, Clustering Loss: 0.09888156, ZINB Loss: 0.63717633
[INFO][2023-09-02 16:00:56,520][dance][fit] Reach tolerance threshold (0.000e+00 < 1.000e-03). Stopping training.
[INFO][2023-09-02 16:00:56,523][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:00:56,621][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:56,621][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  SaveRaw(),
  AnnDataTransform(func=scanpy.preprocessing._normalization.normalize_total, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnD

score=0.8565


[INFO][2023-09-02 16:00:56,750][dance.SaveRaw][__call__] Saving data to ``.raw``
[INFO][2023-09-02 16:00:56,840][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': [None, None, 'n_counts'],
 'feature_channel_type': ['X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:00:56,841][dance][set_config_from_dict] Setting config 'feature_channel' to [None, None, 'n_counts']
[INFO][2023-09-02 16:00:56,842][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['X', 'raw_X', 'obs']
[INFO][2023-09-02 16:00:56,843][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:00:56,843][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    obs: 'n_counts'
    var: 'n_counts', 'mean', 'std'
    uns: 'dance_config', 'log1p'
    obsm: 'Group'
[INFO][2023-09-02 16:00:56,844][dance][wrapped_func] Took 0:00:00.320538 to load and pro

score=0.9965


' Reproduction information\n10X PBMC:\npython scdeepcluster.py --dataset 10X_PBMC\n\nMouse ES:\npython scdeepcluster.py --dataset mouse_ES_cell\n\nWorm Neuron:\npython scdeepcluster.py --dataset worm_neuron_cell --pretrain_epochs 300\n\nMouse Bladder:\npython scdeepcluster.py --dataset mouse_bladder_cell --pretrain_epochs 300 --sigma 2.75\n'

In [7]:
scdeepcluster_scores

[0.9965362702578435,
 0.24032257798854678,
 0.5352188341844301,
 0.6334722718620797,
 0.7695415409808991,
 0.8565223251023905,
 0.8565223251023905,
 0.9965362702578435]

In [8]:
import argparse

from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdsc import ScDSC
from dance.utils import set_seed


parser = argparse.ArgumentParser()

# model_para = [n_enc_1(n_dec_3), n_enc_2(n_dec_2), n_enc_3(n_dec_1)]
model_para = [512, 256, 256]
# Cluster_para = [n_z1, n_z2, n_z3, n_init, n_input, n_clusters]
Cluster_para = [256, 128, 32, 20, 100, 10]
# Balance_para = [binary_crossentropy_loss, ce_loss, re_loss, zinb_loss, sigma]
Balance_para = [1, 0.01, 0.1, 0.1, 1]

parser.add_argument("--data_dir", default="./data")
parser.add_argument("--dataset", type=str, default="worm_neuron_cell",
                    choices=["10X_PBMC", "mouse_bladder_cell", "mouse_ES_cell", "worm_neuron_cell","mouse_kidney_cell","human_pbmc_cell","human_pbmc2_cell","human_skin_cell","human_lung_cell","mouse_kidney_drop","mouse_kidney_cl2","mouse_kidney_10x"])
# TODO: implement callbacks for "heat_kernel" and "cosine_normalized"
parser.add_argument("--method", type=str, default="correlation", choices=["cosine", "correlation"])
parser.add_argument("--batch_size", default=256, type=int)
parser.add_argument("--n_enc_1", default=model_para[0], type=int)
parser.add_argument("--n_enc_2", default=model_para[1], type=int)
parser.add_argument("--n_enc_3", default=model_para[2], type=int)
parser.add_argument("--n_dec_1", default=model_para[2], type=int)
parser.add_argument("--n_dec_2", default=model_para[1], type=int)
parser.add_argument("--n_dec_3", default=model_para[0], type=int)
parser.add_argument("--topk", type=int, default=50)
parser.add_argument("--lr", type=float, default=1e-2)
parser.add_argument("--pretrain_lr", type=float, default=1e-3)
parser.add_argument("--pretrain_epochs", type=int, default=200)
parser.add_argument("--epochs", type=int, default=1000)
parser.add_argument("--n_z1", default=Cluster_para[0], type=int)
parser.add_argument("--n_z2", default=Cluster_para[1], type=int)
parser.add_argument("--n_z3", default=Cluster_para[2], type=int)
parser.add_argument("--n_input", type=int, default=Cluster_para[4])
parser.add_argument("--n_clusters", type=int, default=Cluster_para[5])
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--v", type=int, default=1)
parser.add_argument("--nb_genes", type=int, default=2000)
parser.add_argument("--binary_crossentropy_loss", type=float, default=Balance_para[0])
parser.add_argument("--ce_loss", type=float, default=Balance_para[1])
parser.add_argument("--re_loss", type=float, default=Balance_para[2])
parser.add_argument("--zinb_loss", type=float, default=Balance_para[3])
parser.add_argument("--sigma", type=float, default=Balance_para[4])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--cache", action="store_true", help="Cache processed data.")
scdsc_scores=[]
for dataset in datasets:
    args = parser.parse_args(['--dataset', dataset, '--method','cosine','--topk', '30', '--v', '7', '--binary_crossentropy_loss', '0.75','--ce_loss', '0.5', '--re_loss','0.1' ,'--zinb_loss','2.5','--sigma', '0.4'])
    set_seed(args.seed)

    # Load data and perform necessary preprocessing
    dataloader = ClusteringDataset(args.data_dir, args.dataset)
    preprocessing_pipeline = ScDSC.preprocessing_pipeline(n_top_genes=args.nb_genes, n_neighbors=args.topk)
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    # inputs: adj, x, x_raw, n_counts
    inputs, y = data.get_data(return_type="default")
    args.n_input = inputs[1].shape[1]

    model = ScDSC(pretrain_path=f"scdsc_{args.dataset}_pre.pkl", sigma=args.sigma, n_enc_1=args.n_enc_1,
                  n_enc_2=args.n_enc_2, n_enc_3=args.n_enc_3, n_dec_1=args.n_dec_1, n_dec_2=args.n_dec_2,
                  n_dec_3=args.n_dec_3, n_z1=args.n_z1, n_z2=args.n_z2, n_z3=args.n_z3, n_clusters=args.n_clusters,
                  n_input=args.n_input, v=args.v, device=args.device)

    # Build and train model
    model.fit(inputs, y, lr=args.lr, epochs=args.epochs, bcl=args.binary_crossentropy_loss, cl=args.ce_loss,
              rl=args.re_loss, zl=args.zinb_loss, pt_epochs=args.pretrain_epochs, pt_batch_size=args.batch_size,
              pt_lr=args.pretrain_lr)

    # Evaluate model predictions
    score = model.score(None, y)
    print(f"{score=:.4f}")
    scdsc_scores.append(score)
"""Reproduction information
10X PBMC:
python scdsc.py --dataset 10X_PBMC --method cosine --topk 30 --v 7 --binary_crossentropy_loss 0.75 --ce_loss 0.5 --re_loss 0.1 --zinb_loss 2.5 --sigma 0.4

Mouse Bladder:
python scdsc.py --dataset mouse_bladder_cell --topk 50 --v 7 --binary_crossentropy_loss 2.5 --ce_loss 0.1 --re_loss 0.5 --zinb_loss 1.5 --sigma 0.6

Mouse ES:
python scdsc.py --dataset mouse_ES_cell --topk 50 --v 7 --binary_crossentropy_loss 0.1 --ce_loss 0.01 --re_loss 1.5 --zinb_loss 0.5 --sigma 0.1

Worm Neuron:
python scdsc.py --dataset worm_neuron_cell --topk 20 --v 7 --binary_crossentropy_loss 2 --ce_loss 2 --re_loss 3 --zinb_loss 0.1 --sigma 0.4
"""


[INFO][2023-09-02 16:00:57,799][dance][set_seed] Setting global random seed to 42


[INFO][2023-09-02 16:00:57,887][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:00:57,888][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.normalize_per_cell, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 2000, 'subset': True}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=s



[INFO][2023-09-02 16:00:58,168][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:01:00,596][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:01:00,596][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:01:00,597][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:01:00,598][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:01:00,599][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 2000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean

score=0.5722


[INFO][2023-09-02 16:02:04,636][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 648 × 64535
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:02:04,636][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.normalize_per_cell, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 2000, 'subset': True}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=s



[INFO][2023-09-02 16:02:05,005][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:02:05,328][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:02:05,329][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:02:05,330][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:02:05,331][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:02:05,332][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 648 × 2000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean

score=0.2168


[INFO][2023-09-02 16:02:55,957][dance.SaveRaw][__call__] Saving data to ``.raw``




[INFO][2023-09-02 16:02:56,010][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:03:03,881][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:03:03,884][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:03:03,885][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:03:03,886][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:03:03,888][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 8218 × 1000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mea

score=0.3234


[INFO][2023-09-02 16:10:23,223][dance.SaveRaw][__call__] Saving data to ``.raw``




[INFO][2023-09-02 16:10:23,253][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:10:24,814][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:10:24,817][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:10:24,818][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:10:24,819][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:10:24,820][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 4853 × 1000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mea

score=0.5530


[INFO][2023-09-02 16:15:00,277][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:15:01,371][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:15:01,372][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:15:01,373][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:15:01,374][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:15:01,375][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 1756 × 999
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean

score=0.6597


[INFO][2023-09-02 16:16:33,732][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:16:33,802][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:16:33,803][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:16:33,804][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:16:33,805][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:16:33,806][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 2000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean

score=0.3271


[INFO][2023-09-02 16:17:00,438][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:17:00,508][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:17:00,509][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:17:00,510][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:17:00,510][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:17:00,511][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 225 × 2000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean

score=0.3271


[INFO][2023-09-02 16:17:27,510][dance.SaveRaw][__call__] Saving data to ``.raw``




[INFO][2023-09-02 16:17:27,521][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:17:28,104][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:17:28,105][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:17:28,106][dance][set_config_from_dict] Setting config 'feature_channel_type' to ['obsp', 'X', 'raw_X', 'obs']
[INFO][2023-09-02 16:17:28,107][dance][set_config_from_dict] Setting config 'label_channel' to 'Group'
[INFO][2023-09-02 16:17:28,108][dance][load_data] Data transformed:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 2000
    obs: 'n_counts'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean

score=0.5907


'Reproduction information\n10X PBMC:\npython scdsc.py --dataset 10X_PBMC --method cosine --topk 30 --v 7 --binary_crossentropy_loss 0.75 --ce_loss 0.5 --re_loss 0.1 --zinb_loss 2.5 --sigma 0.4\n\nMouse Bladder:\npython scdsc.py --dataset mouse_bladder_cell --topk 50 --v 7 --binary_crossentropy_loss 2.5 --ce_loss 0.1 --re_loss 0.5 --zinb_loss 1.5 --sigma 0.6\n\nMouse ES:\npython scdsc.py --dataset mouse_ES_cell --topk 50 --v 7 --binary_crossentropy_loss 0.1 --ce_loss 0.01 --re_loss 1.5 --zinb_loss 0.5 --sigma 0.1\n\nWorm Neuron:\npython scdsc.py --dataset worm_neuron_cell --topk 20 --v 7 --binary_crossentropy_loss 2 --ce_loss 2 --re_loss 3 --zinb_loss 0.1 --sigma 0.4\n'

In [9]:
scdsc_scores

[0.5722479940804349,
 0.21680167115856036,
 0.32343149332579424,
 0.5530234095994098,
 0.6596624201855319,
 0.3270840856044606,
 0.3270840856044606,
 0.5907197301775053]

In [10]:
import torch
torch.cuda.is_available()

True

In [11]:
import argparse

import numpy as np

from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.sctag import ScTAG
from dance.utils import set_seed

parser = argparse.ArgumentParser(description="train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--data_dir", default="./data", type=str)
parser.add_argument("--dataset", default="mouse_bladder_cell", type=str,
                    choices=["10X_PBMC", "mouse_bladder_cell", "mouse_ES_cell", "worm_neuron_cell","mouse_kidney_cell","human_pbmc_cell","human_pbmc2_cell","human_skin_cell","human_lung_cell","mouse_kidney_drop","mouse_kidney_cl2","mouse_kidney_10x"])
parser.add_argument("--k_neighbor", default=15, type=int)
parser.add_argument("--highly_genes", default=3000, type=int)
parser.add_argument("--pca_dim", default=50, type=int)
parser.add_argument("--k", default=2, type=int)
parser.add_argument("--hidden_dim", default=128, type=int)
parser.add_argument("--latent_dim", default=15, type=int)
parser.add_argument("--dec_dim", default=None, type=int)
parser.add_argument("--dropout", default=0.4, type=float)
parser.add_argument("--alpha", default=1.0, type=float)
parser.add_argument("--pretrain_epochs", default=200, type=int)
parser.add_argument("--epochs", default=500, type=int)
parser.add_argument("--device", default="cpu")
parser.add_argument("--w_a", default=1, type=float)
parser.add_argument("--w_x", default=1, type=float)
parser.add_argument("--w_d", default=0, type=float)
parser.add_argument("--w_c", default=1, type=float)
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--min_dist", default=0.5, type=float)
parser.add_argument("--max_dist", default=20.0, type=float)
parser.add_argument("--info_step", default=50, type=int)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--cache", action="store_true", help="Cache processed data.")
sctag_scores=[]
for dataset in datasets:
    args = parser.parse_args([ '--dataset',dataset,'--pretrain_epochs','100',"--k","1"])
    set_seed(args.seed)

    # Load data and perform necessary preprocessing
    dataloader = ClusteringDataset(args.data_dir, args.dataset)
    preprocessing_pipeline = ScTAG.preprocessing_pipeline(n_top_genes=args.highly_genes, n_components=args.pca_dim,
                                                          n_neighbors=args.k_neighbor)
    data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)

    # inputs: adj, x, x_raw, n_counts
    inputs, y = data.get_train_data()

    n_clusters = len(np.unique(y))

    # Build and train model
    model = ScTAG(n_clusters=n_clusters, k=args.k, hidden_dim=args.hidden_dim, latent_dim=args.latent_dim,
                  dec_dim=args.dec_dim, dropout=args.dropout, device=args.device, alpha=args.alpha,
                  pretrain_path=f"sctag_{args.dataset}_pre.pkl")
    print(np.max(inputs[0]))
    model.fit(inputs, y, epochs=args.epochs, pretrain_epochs=args.pretrain_epochs, lr=args.lr, w_a=args.w_a,
              w_x=args.w_x, w_c=args.w_c, w_d=args.w_d, info_step=args.info_step, max_dist=args.max_dist,
              min_dist=args.min_dist)

    # Evaluate model predictions
    score = model.score(None, y)
    print(f"{score=:.4f}")
    sctag_scores.append(score)
"""Reproduction information
10X PBMC:
python sctag.py --dataset 10X_PBMC --pretrain_epochs 100 --w_a 0.01 --w_x 3 --w_c 0.1 --dropout 0.5

Mouse ES:
python sctag.py --dataset mouse_ES_cell --pretrain_epochs 100 --w_a 0.01 --w_x 0.75 --w_c 1

Worm Neuron:
python sctag.py --dataset worm_neuron_cell --w_a 0.01 --w_x 2 --w_c 0.25 --k 1

Mouse Bladder:
python sctag.py --dataset mouse_bladder_cell --pretrain_epochs 100 --w_a 0.1 --w_x 2.5 --w_c 3
"""


[INFO][2023-09-02 16:18:32,108][dance][set_seed] Setting global random seed to 42
[INFO][2023-09-02 16:18:32,181][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 902 × 16468
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:18:32,181][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.normalize_per_cell, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 3000, 'subset': True}),
  AnnDataTransform(func=scanpy.preprocessi



[INFO][2023-09-02 16:18:32,491][dance.CellPCA][__call__] Start generating cell PCA features (902, 3000) (k=50)
[INFO][2023-09-02 16:18:32,527][dance.CellPCA][__call__] Top 10 explained variances: [0.17064218 0.09295565 0.03068668 0.02320614 0.01286646 0.0118582
 0.00996639 0.00687872 0.0059774  0.00535993]
[INFO][2023-09-02 16:18:32,527][dance.CellPCA][__call__] Total explained variance: 47.30%
[INFO][2023-09-02 16:18:32,528][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:18:32,570][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:18:32,570][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:18:32,571][dance][set_config_from_dict] Setting config 'feature_channel_t

1.0


  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:18:32,727][dance][fit] Epoch   1, ARI: 0.9965, Best ARI: 0.9965
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.

score=1.0000


[INFO][2023-09-02 16:18:59,653][dance][load_data] Raw data loaded:
Data object that wraps (.data):
AnnData object with n_obs × n_vars = 648 × 64535
    uns: 'dance_config'
    obsm: 'Group'
[INFO][2023-09-02 16:18:59,654][dance.Compose][__call__] Applying composed transformations:
Compose(
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 3}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_cells, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=scanpy.preprocessing._simple.normalize_per_cell, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._simple.log1p, func_kwargs={}),
  AnnDataTransform(func=scanpy.preprocessing._highly_variable_genes.highly_variable_genes, func_kwargs={'min_mean': 0.0125, 'max_mean': 4, 'flavor': 'cell_ranger', 'min_disp': 0.5, 'n_top_genes': 3000, 'subset': True}),
  AnnDataTransform(func=scanpy.preprocessing._simple.filter_genes, func_kwargs={'min_counts': 1}),
  AnnDataTransform(func=s



[INFO][2023-09-02 16:19:00,334][dance.CellPCA][__call__] Start generating cell PCA features (648, 3000) (k=50)
[INFO][2023-09-02 16:19:00,362][dance.CellPCA][__call__] Top 10 explained variances: [0.00805397 0.00736713 0.00616094 0.00608613 0.00534675 0.00517343
 0.00435965 0.00431036 0.00420092 0.0041413 ]
[INFO][2023-09-02 16:19:00,362][dance.CellPCA][__call__] Total explained variance: 18.06%
[INFO][2023-09-02 16:19:00,363][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:19:00,391][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:19:00,392][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:19:00,392][dance][set_config_from_dict] Setting config 'feature_channel_

1.0


  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:19:00,545][dance][fit] Epoch   1, ARI: 0.8408, Best ARI: 0.8408
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.

  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() ==

score=0.8596


[INFO][2023-09-02 16:19:23,699][dance.SaveRaw][__call__] Saving data to ``.raw``




[INFO][2023-09-02 16:19:23,747][dance.CellPCA][__call__] Start generating cell PCA features (8218, 1000) (k=50)
[INFO][2023-09-02 16:19:23,810][dance.CellPCA][__call__] Top 10 explained variances: [0.10556723 0.03730556 0.02041031 0.01361092 0.01271562 0.0079526
 0.00625591 0.00581772 0.00472144 0.00400747]
[INFO][2023-09-02 16:19:23,811][dance.CellPCA][__call__] Total explained variance: 28.82%
[INFO][2023-09-02 16:19:23,812][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:19:24,696][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:19:24,698][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:19:24,699][dance][set_config_from_dict] Setting config 'feature_channel_

1.0


[INFO][2023-09-02 16:19:27,218][dance][_pretrain] Loading pre-trained model from sctag_human_pbmc2_cell_pre.pkl
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:19:27,900][dance][fit] Epoch   1, ARI: 0.3945, Best ARI: 0.3945
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Ca

score=0.4008


[INFO][2023-09-02 16:28:55,278][dance.SaveRaw][__call__] Saving data to ``.raw``




[INFO][2023-09-02 16:28:55,308][dance.CellPCA][__call__] Start generating cell PCA features (4853, 1000) (k=50)
[INFO][2023-09-02 16:28:55,351][dance.CellPCA][__call__] Top 10 explained variances: [0.12303824 0.05227059 0.0387879  0.02777254 0.02466786 0.02116269
 0.01945324 0.0181231  0.01202287 0.00956288]
[INFO][2023-09-02 16:28:55,352][dance.CellPCA][__call__] Total explained variance: 46.11%
[INFO][2023-09-02 16:28:55,352][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:28:55,894][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:28:55,895][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:28:55,896][dance][set_config_from_dict] Setting config 'feature_channel

1.0


[INFO][2023-09-02 16:28:56,712][dance][_pretrain] Loading pre-trained model from sctag_human_skin_cell_pre.pkl
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:28:57,052][dance][fit] Epoch   1, ARI: 0.6784, Best ARI: 0.6784
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Can

score=0.9157


[INFO][2023-09-02 16:31:15,073][dance.CellPCA][__call__] Start generating cell PCA features (1756, 999) (k=50)
[INFO][2023-09-02 16:31:15,098][dance.CellPCA][__call__] Top 10 explained variances: [0.07306425 0.03857969 0.02582724 0.01343988 0.00957488 0.00913248
 0.00741359 0.00714034 0.00621518 0.00597752]
[INFO][2023-09-02 16:31:15,099][dance.CellPCA][__call__] Total explained variance: 30.52%
[INFO][2023-09-02 16:31:15,099][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:31:15,194][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:31:15,195][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:31:15,195][dance][set_config_from_dict] Setting config 'feature_channel_

1.0


  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() ==

score=0.8447


[INFO][2023-09-02 16:31:41,313][dance.CellPCA][__call__] Start generating cell PCA features (225, 3000) (k=50)
[INFO][2023-09-02 16:31:41,326][dance.CellPCA][__call__] Top 10 explained variances: [0.06927603 0.04143312 0.02029931 0.01362791 0.0129288  0.01088451
 0.00948721 0.00932862 0.0083904  0.00816148]
[INFO][2023-09-02 16:31:41,326][dance.CellPCA][__call__] Total explained variance: 44.79%
[INFO][2023-09-02 16:31:41,327][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:31:41,337][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:31:41,338][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:31:41,338][dance][set_config_from_dict] Setting config 'feature_channel_

1.0


[INFO][2023-09-02 16:31:41,381][dance][_pretrain] Loading pre-trained model from sctag_mouse_kidney_drop_pre.pkl
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:31:41,456][dance][fit] Epoch   1, ARI: 0.9209, Best ARI: 0.9209
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "C

score=0.9209


[INFO][2023-09-02 16:31:56,380][dance.CellPCA][__call__] Start generating cell PCA features (225, 3000) (k=50)
[INFO][2023-09-02 16:31:56,393][dance.CellPCA][__call__] Top 10 explained variances: [0.06927603 0.04143312 0.02029931 0.01362791 0.0129288  0.01088451
 0.00948721 0.00932862 0.0083904  0.00816148]
[INFO][2023-09-02 16:31:56,394][dance.CellPCA][__call__] Total explained variance: 44.79%
[INFO][2023-09-02 16:31:56,394][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:31:56,405][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:31:56,406][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:31:56,406][dance][set_config_from_dict] Setting config 'feature_channel_

1.0


[INFO][2023-09-02 16:31:56,449][dance][_pretrain] Loading pre-trained model from sctag_mouse_kidney_cl2_pre.pkl
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:31:56,524][dance][fit] Epoch   1, ARI: 0.9209, Best ARI: 0.9209
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Ca

score=0.9209


[INFO][2023-09-02 16:32:11,002][dance.SaveRaw][__call__] Saving data to ``.raw``




[INFO][2023-09-02 16:32:11,019][dance.CellPCA][__call__] Start generating cell PCA features (902, 3000) (k=50)
[INFO][2023-09-02 16:32:11,052][dance.CellPCA][__call__] Top 10 explained variances: [0.17064218 0.09295565 0.03068668 0.02320614 0.01286646 0.0118582
 0.00996639 0.00687872 0.0059774  0.00535993]
[INFO][2023-09-02 16:32:11,053][dance.CellPCA][__call__] Total explained variance: 47.30%
[INFO][2023-09-02 16:32:11,053][dance.NeighborGraph][__call__] Start computing the kNN connectivity adjacency matrix
[INFO][2023-09-02 16:32:11,094][dance.SetConfig][__call__] Updating the dance data object config options:
{'feature_channel': ['NeighborGraph', None, None, 'n_counts'],
 'feature_channel_type': ['obsp', 'X', 'raw_X', 'obs'],
 'label_channel': 'Group'}
[INFO][2023-09-02 16:32:11,094][dance][set_config_from_dict] Setting config 'feature_channel' to ['NeighborGraph', None, None, 'n_counts']
[INFO][2023-09-02 16:32:11,095][dance][set_config_from_dict] Setting config 'feature_channel_t

1.0


  assert input.numel() == input.storage().size(), "Cannot convert view " \
[INFO][2023-09-02 16:32:11,247][dance][fit] Epoch   1, ARI: 0.9965, Best ARI: 0.9965
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.numel() == input.storage().size(), "Cannot convert view " \
  assert input.

score=1.0000


  assert input.numel() == input.storage().size(), "Cannot convert view " \


'Reproduction information\n10X PBMC:\npython sctag.py --dataset 10X_PBMC --pretrain_epochs 100 --w_a 0.01 --w_x 3 --w_c 0.1 --dropout 0.5\n\nMouse ES:\npython sctag.py --dataset mouse_ES_cell --pretrain_epochs 100 --w_a 0.01 --w_x 0.75 --w_c 1\n\nWorm Neuron:\npython sctag.py --dataset worm_neuron_cell --w_a 0.01 --w_x 2 --w_c 0.25 --k 1\n\nMouse Bladder:\npython sctag.py --dataset mouse_bladder_cell --pretrain_epochs 100 --w_a 0.1 --w_x 2.5 --w_c 3\n'

In [12]:
sctag_scores

[1.0,
 0.8596006841137562,
 0.40078924676001404,
 0.9156527237548763,
 0.8447362442419666,
 0.9209258226474065,
 0.9209258226474065,
 1.0]