In [2]:
import warnings 
# Silence warnings before the heavy single-cell libraries are imported so that
# import-time warnings are suppressed too.
# NOTE(review): this ignores *every* warning for the whole session, which can
# hide deprecations or data problems raised by scanpy/anndata/muon — consider
# narrowing the filter by category or module.
warnings.simplefilter('ignore')

# NOTE(review): `sc` is not referenced in any of the cells shown here.
import scanpy as sc
import scparadise
import muon as mu
import pandas as pd
import os

In [3]:
# Load the *unintegrated* MuData object (described by the author as already
# normalized — confirm). Training and testing below select samples from this
# object by `orig.ident`.
mdata = mu.read_h5mu('PBMC_3p_CITE/mdata_unintegrated.h5mu')

In [4]:
# Take an independent copy of the RNA modality, then drop the reference to the
# full MuData container so it can be garbage-collected.
rna_view = mdata.mod['rna']
adata = rna_view.copy()
del mdata, rna_view

In [5]:
# Restrict features to the gene panel listed in the `genes` column of the CSV;
# `.copy()` materializes the subset instead of keeping an AnnData view.
genes = pd.read_csv('PBMC_3p_CITE/genes_for_AI.csv')
selected_features = genes['genes']
adata = adata[:, selected_features].copy()
del selected_features

In [6]:
# Reference samples kept for training: sample '_0' of each donor P1..P8.
lst_reference = [f'P{donor}_0' for donor in range(1, 9)]

In [7]:
# Build adata_train from the reference samples only (8 samples, one from each
# of the 8 donors; the full unintegrated object holds 24 samples).
# `isin` silently ignores IDs that do not occur in the data, so check first:
# a typo in lst_reference must not quietly shrink the training set.
missing_reference = set(lst_reference) - set(adata.obs['orig.ident'])
assert not missing_reference, f'Reference samples not found in adata: {sorted(missing_reference)}'
adata_train = adata[adata.obs['orig.ident'].isin(lst_reference)].copy()

In [8]:
# Resample the reference set per sample ('orig.ident') across all three
# annotation levels; the log shows which cell types were under-/over-sampled.
# NOTE(review): balancing is intended to target the celltype_l3 level; whether
# the resampling is seeded is up to scparadise — confirm for reproducibility.
adata_balanced = scparadise.scnoah.balance(
    adata_train,
    sample='orig.ident',
    **{f'celltype_l{level}': f'celltype_l{level}' for level in (1, 2, 3)},
)

Successfully undersampled cell types: CD14 Mono, NK, CD4 T Naive, CD4 TCM, CD8 TEM, CD8 T Naive, CD16 Mono

Successfully oversampled cell types: B naive κ, CD4 TEM, gdT, CD8 TCM, cDC2, MAIT, Treg, B naive λ, B memory κ, CD4 CTL, Platelet, B int λ, B memory λ, B int κ, NK_CD56bright, pDC, HSPC, dnT, Plasmablast, NK Prolif, cDC1, ILC, CD4 T Prolif, ASDC, CD8 T Prolif


In [9]:
# Fit an scAdam model on the balanced reference for all three annotation
# levels, reporting balanced and plain accuracy each epoch. With path='' the
# model is stored under model_name relative to the working directory, which is
# where the evaluation loop loads it from.
scparadise.scadam.train(
    adata_balanced,
    path='',
    model_name='model_PBMC_scAdam_default',
    eval_metric=['balanced_accuracy', 'accuracy'],
    **{f'celltype_l{level}': f'celltype_l{level}' for level in (1, 2, 3)},
)

Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 47808 cells, it is 90.0 % of input dataset
Test dataset contains: 5312 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 2.24877 | train_balanced_accuracy: 0.3667  | train_accuracy: 0.4812  | valid_balanced_accuracy: 0.36224 | valid_accuracy: 0.4786  |  0:00:02s
epoch 1  | loss: 1.17843 | train_balanced_accuracy: 0.71553 | train_accuracy: 0.71488 | valid_balanced_accuracy: 0.7152  | valid_accuracy: 0.71417 |  0:00:04s
epoch 2  | loss: 0.81992 | train_balanced_accuracy: 0.80681 | train_accuracy: 0.79671 | valid_balanced_accuracy: 0.8104  | valid_accuracy: 0.8002  |  0:00:07s
epoch 3  | loss: 0.67335 | train_balanced_accuracy: 0.85169 | train_accuracy: 0.8378  | valid_balanced_accuracy: 0.85469 | valid_accuracy: 0.84074 |  0:00:10s
epoch 4  | loss: 0.58694 | train_balanced_accuracy: 0.89004 | train_accuracy: 0.88149 | valid

In [10]:
# Paired test sets: each name joins two samples that were not used for
# training, as '<sampleA>_<sampleB>'.
test_pairs = [('P1_3', 'P3_3'), ('P1_7', 'P8_3'), ('P2_3', 'P4_7'), ('P2_7', 'P6_3'),
              ('P3_7', 'P7_3'), ('P4_3', 'P7_7'), ('P5_3', 'P8_7'), ('P5_7', 'P6_7')]
lst_test = ['_'.join(pair) for pair in test_pairs]
del test_pairs

In [17]:
# Evaluate the trained model on every paired test set: predict all annotation
# levels, then save one classification report per level into
# <reports_root>/<pair name>/.
reports_root = 'PBMC_3p_CITE/reports_model_PBMC_scAdam_default'
annotation_levels = ['celltype_l1', 'celltype_l2', 'celltype_l3']


def split_paired_samples(pair_name):
    """Split a paired test name such as 'P1_3_P3_3' into its two sample IDs.

    Each sample ID is assumed to look like '<donor>_<suffix>' (exactly one
    underscore), so the pair name must have four '_'-separated parts. Unlike
    fixed-width slicing, this also handles IDs such as 'P10_3'.

    Raises:
        ValueError: if `pair_name` does not have exactly four parts.
    """
    parts = pair_name.split('_')
    if len(parts) != 4:
        raise ValueError(f'Expected "<donor>_<id>_<donor>_<id>", got {pair_name!r}')
    return '_'.join(parts[:2]), '_'.join(parts[2:])


for folder in lst_test:
    samples = split_paired_samples(folder)
    adata_test = adata[adata.obs['orig.ident'].isin(samples)].copy()
    # Fail loudly instead of silently evaluating on a single (or no) sample.
    missing_samples = set(samples) - set(adata_test.obs['orig.ident'])
    assert not missing_samples, f'Samples of {folder!r} not found in adata: {sorted(missing_samples)}'

    # Forward slashes keep save_path identical across operating systems.
    report_dir = os.path.join(reports_root, folder).replace("\\", "/")
    # exist_ok=True keeps the cell re-runnable (otherwise FileExistsError).
    os.makedirs(report_dir, exist_ok=True)

    # Predict annotation levels using the pretrained scadam model
    adata_test = scparadise.scadam.predict(adata_test,
                                           path_model='model_PBMC_scAdam_default')
    # Create and save one classification report per annotation level
    for level in annotation_levels:
        scparadise.scnoah.report_classif_full(adata_test,
                                              celltype=level,
                                              pred_celltype=f'pred_{level}',
                                              report_name=f'report_test_model_scAdam_default_{level}.csv',
                                              save_path=report_dir,
                                              save_report=True)

Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
Successfully added predicted celltype_l3 and cell type probabilities
Successfully saved report

Successfully saved report

Successfully saved report

Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
Successfully added predicted celltype_l3 and cell type probabilities
Successfully saved report

Successfully saved report

Successfully saved report

Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

S

In [22]:
# Show the Python/package session information via the session_info package,
# so the environment behind the outputs above is recorded in the notebook.
# NOTE(review): this import would conventionally live in the top imports cell.
import session_info
session_info.show()