In [1]:
import numpy as np
import pandas as pd
import pickle
from time import time
from scipy.stats import spearmanr, gamma, poisson
from anndata import AnnData, read_h5ad
import scanpy as sc
from scanpy import read
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import tensor
from torch.cuda import is_available
from scMMT.scMMT_API import scMMT_API
from sklearn.metrics import f1_score, accuracy_score
import warnings
warnings.filterwarnings("ignore")

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [2]:
seed = 5
# Seed every random generator this notebook touches directly: torch (CPU, current GPU,
# all GPUs) and NumPy's legacy global generator, in that order.
for _seed_rng in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,  # relevant for multi-GPU setups
    np.random.seed,
):
    _seed_rng(seed)

### Data preprocessing

In [3]:
# Read in the raw data: the RNA (gene) and ADT (protein) measurements of the PBMC dataset
# are stored as two separate .h5ad files, read in that order.
adata_gene, adata_protein = (
    sc.read(h5ad_path)
    for h5ad_path in ("./pbmc/pbmc_gene.h5ad", "./pbmc/pbmc_protein.h5ad")
)

In [4]:
# Store both expression matrices as dense NumPy arrays.
# The hasattr guard makes the cell idempotent: a dense ndarray has no .toarray(), so
# re-running the unguarded version raised AttributeError; now it is a no-op instead.
for _adata in (adata_gene, adata_protein):
    if hasattr(_adata.X, "toarray"):
        _adata.X = _adata.X.toarray()

In [5]:
# Protein (ADT) preprocessing: normalize_total + log1p over all cells, then z-score each
# donor separately with sc.pp.scale. Any other protein-processing scheme can be swapped in here.
# NOTE: this cell mutates adata_protein in place and is NOT idempotent -- re-running it
# normalises already-processed values again; re-run from the data-loading cell instead.
sc.pp.normalize_total(adata_protein)
sc.pp.log1p(adata_protein)
donor_labels = adata_protein.obs['donor'].to_numpy()
patients = np.unique(donor_labels)
for patient in patients:
    # Boolean row mask for this donor (vectorised instead of a per-cell Python comprehension).
    indices = donor_labels == patient
    # Scale an explicit copy: scaling a view makes scanpy silently materialise a copy anyway
    # (its notice is hidden by the global warnings filter).
    sub_adata = adata_protein[indices].copy()
    sc.pp.scale(sub_adata)
    # Write the per-donor scaled values back into the full matrix.
    adata_protein.X[indices] = sub_adata.X

In [6]:
# Create training and testing dataset
# Split cells by donor: donors P1, P3, P4 and P7 form the training (reference) set, and every
# other donor is held out as the test set.
# The mask is computed from the protein object but applied to both modalities, which is only
# valid if both objects list the same cells in the same order -- check that before splitting.
assert adata_gene.n_obs == adata_protein.n_obs, "gene and protein data have different numbers of cells"
assert (
    adata_gene.obs['donor'].astype(str).to_numpy()
    == adata_protein.obs['donor'].astype(str).to_numpy()
).all(), "gene and protein cells are not row-aligned (donor labels differ)"
train_bool = [x in ['P1', 'P3', 'P4', 'P7'] for x in adata_protein.obs['donor']]
test_bool = np.invert(train_bool)
adata_gene_train = adata_gene[train_bool].copy()
adata_protein_train = adata_protein[train_bool].copy()
adata_gene_test = adata_gene[test_bool].copy()
adata_protein_test = adata_protein[test_bool].copy()

### Train scMMT model

In [7]:
scMMT = scMMT_API(
    # Reference (training) data for each modality, plus the query RNA data to annotate.
    gene_trainsets=[adata_gene_train],
    protein_trainsets=[adata_protein_train],
    gene_test=adata_gene_test,
    # obs column(s) holding the batch label of each training set / of the test set.
    train_batchkeys=['donor'],
    test_batchkey='donor',
    # Whether to log-normalize the scRNA-seq data.
    log_normalize=True,
    # obs column with the cell-type labels to transfer (in the protein dataset).
    type_key='celltype.l3',
    # Save path for the processed data, and whether to load an existing file instead of reprocessing.
    data_dir="preprocess_data_l3.pkl",
    data_load=False,
    # Whether there is a batch effect between the training set and the test set.
    dataset_batch=True,
    # Log weighting applied to the different cell types.
    log_weight=3,
    # Whether to split off a validation set that follows the test-set distribution (None here).
    val_split=None,
    # Minimum-cell-count and minimum-gene-count filtering thresholds (0 = no filtering, presumably).
    min_cells=0,
    min_genes=0,
    # Output dimensions of the TruncatedSVD and factor-analysis (FA) reductions.
    n_svd=300,
    n_fa=180,
    # Number of highly variable genes to select.
    n_hvg=550,
)

Searching for GPU
GPU detected, using GPU

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

combat

TSVD...

FA ...

Finding HVGs

Normalizing Gene Training Data by Batch

Normalizing Gene Testing Data by Batch


##### Choosing `label_smoothing`: this value depends on how likely the cell-type annotations are to be correct and on how many cell types there are. Label smoothing softens the one-hot training targets, so the model is penalized less for disagreeing with labels that may be wrong. Generally speaking, the poorer the annotation quality and the more cell types there are, the larger the label-smoothing value should be. The dataset used here has a large number of cell types (58), and the reference dataset does not annotate these 58 cell types accurately, so a relatively large value of 0.4 was chosen.

In [8]:
scMMT.train(
    n_epochs=100,                # upper bound on the number of training epochs
    ES_max=12,                   # presumably early-stopping patience, in epochs
    decay_max=6,                 # presumably epochs without improvement before the LR is decayed
    decay_step=0.1,              # LR decay factor (the log shows "Decaying loss to 0.0001" from lr=1e-3)
    lr=10**(-3),                 # initial learning rate
    label_smoothing=0.4,         # large value because the 58 reference cell-type labels are noisy
    h_size=600,                  # presumably hidden-layer width
    drop_rate=0.15,              # presumably dropout rate
    n_layer=4,                   # presumably number of layers
    weights_dir="model_weight",  # directory for the model weights
    load=False,                  # presumably: train from scratch rather than load saved weights
)

Epoch 0 validation accuracy = 0.897,mseloss=0.760
Epoch 1 validation accuracy = 0.905,mseloss=0.702
Epoch 2 validation accuracy = 0.914,mseloss=0.696
Epoch 3 validation accuracy = 0.912,mseloss=0.689
Epoch 4 validation accuracy = 0.911,mseloss=0.689
Epoch 5 validation accuracy = 0.901,mseloss=0.685
Epoch 6 validation accuracy = 0.907,mseloss=0.688
Epoch 7 validation accuracy = 0.906,mseloss=0.688
Epoch 8 validation accuracy = 0.905,mseloss=0.692
Decaying loss to 0.0001
Epoch 9 validation accuracy = 0.918,mseloss=0.665
Epoch 10 validation accuracy = 0.920,mseloss=0.663
Epoch 11 validation accuracy = 0.921,mseloss=0.663
Epoch 12 validation accuracy = 0.922,mseloss=0.662
Epoch 13 validation accuracy = 0.922,mseloss=0.662
Epoch 14 validation accuracy = 0.921,mseloss=0.662
Epoch 15 validation accuracy = 0.919,mseloss=0.666
Epoch 16 validation accuracy = 0.922,mseloss=0.661
Epoch 17 validation accuracy = 0.920,mseloss=0.662
Epoch 18 validation accuracy = 0.921,mseloss=0.660
Epoch 19 validati

### Cell annotation, protein prediction and embedding 

In [9]:
# Run the trained model on the test cells. The returned AnnData carries the transferred
# cell-type labels in obs['transfered cell labels'] (its variables are presumably the
# predicted proteins -- TODO confirm against scMMT_API.predict).
predicted_test = scMMT.predict()

In [10]:
# Inspect the prediction result (shape and obs columns, including the transferred labels).
predicted_test

AnnData object with n_obs × n_vars = 86008 × 224
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'batch', 'scale_factor', 'dataset', 'transfered cell labels'

In [11]:
# Evaluate the transferred labels against the ground-truth level-3 annotation of the test cells.
# Plain arrays are used because `==` between two categorical Series raises TypeError when
# their category sets differ.
y_true = predicted_test.obs['celltype.l3'].to_numpy()
y_pred = predicted_test.obs['transfered cell labels'].to_numpy()
acc = accuracy_score(y_true, y_pred)
# Per-class F1 (average=None), summarised by its median so that rare cell types count as much
# as abundant ones. sklearn's argument order is (y_true, y_pred); the original call had them
# swapped. Classes present in only one of the two label sets contribute an F1 of 0.
f1 = f1_score(y_true, y_pred, average=None)
f1_median = np.median(f1)
print("ACC:",acc," F1:",f1_median)

ACC: 0.9139963724304716  F1: 0.8674924924924925


In [12]:
# Compute the learned cell embedding. Per the displayed output it contains more cells than the
# test set and carries a 'dataset' obs column, so it presumably spans training + test cells.
embedding = scMMT.embed()

In [13]:
# Inspect the embedding AnnData (shape and obs columns).
embedding

AnnData object with n_obs × n_vars = 161764 × 512
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'batch', 'dataset'