### AntiSplodge 10-fold cross validation example
This notebook relies heavily on the minimal example notebook (https://github.com/HealthML/AntiSplodge/blob/main/AntiSplodge_minimal_example.ipynb).

If something related to AntiSplodge is unclear, please refer to that notebook.

First, import the packages we need. Remember to install any missing ones with "pip install PACKAGE" in your terminal.

In [5]:
import anndata as ann
import scanpy as sc
import numpy as np

import antisplodge as AS

from collections import Counter

After importing the required packages, load the dataset into memory. Remember to place the dataset in the same folder as the notebook.

You can download the dataset directly from https://cellgeni.cog.sanger.ac.uk/heartcellatlas/data/global_raw.h5ad.

In [6]:
SC = ann.read_h5ad("global_raw.h5ad")  # read_h5ad is the explicit reader for .h5ad files

In [7]:
SC.obs

Unnamed: 0,NRP,age_group,cell_source,cell_type,donor,gender,n_counts,n_genes,percent_mito,percent_ribo,region,sample,scrublet_score,source,type,version,cell_states,Used
AAACCCAAGAACGCGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Smooth_muscle_cells,H5,Female,688.0,480,0.004360,0.002907,AX,H0015_apex,0.035917,Nuclei,DBD,V3,SMC1_basic,Yes
AAACCCAAGCAAACAT-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3216.0,1365,0.000933,0.001244,AX,H0015_apex,0.147122,Nuclei,DBD,V3,vCM1,Yes
AAACCCAAGCTACTGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3182.0,1521,0.000943,0.002200,AX,H0015_apex,0.185751,Nuclei,DBD,V3,vCM1,Yes
AAACCCAGTACCGCGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Pericytes,H5,Female,1202.0,726,0.000832,0.000000,AX,H0015_apex,0.035917,Nuclei,DBD,V3,PC2_atria,Yes
AAACCCATCAAACCCA-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3804.0,1584,0.000263,0.001314,AX,H0015_apex,0.108062,Nuclei,DBD,V3,vCM2,Yes
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGTCATACGGT-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,1619.0,904,0.087708,0.024707,AX,HCAHeart8102862,0.108062,CD45+,DCD,V3,EC1_cap,Yes
TTTGTTGTCCTACCAC-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,1186.0,692,0.126476,0.015177,AX,HCAHeart8102862,0.085546,CD45+,DCD,V3,EC1_cap,Yes
TTTGTTGTCGACGCTG-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,1672.0,847,0.145335,0.044258,AX,HCAHeart8102862,0.113475,CD45+,DCD,V3,EC1_cap,Yes
TTTGTTGTCGGCTGAC-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,4676.0,1925,0.163388,0.044482,AX,HCAHeart8102862,0.042770,CD45+,DCD,V3,EC5_art,Yes


After loading the dataset into memory, we keep only cells that come from single-nucleus sequencing ('source' == 'Nuclei'), keep only cells from a single technology to avoid technology-based batch effects ('cell_source' == 'Harvard-Nuclei'), and remove cells whose cell type is 'doublets' or 'NotAssigned'.
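
A minimal sketch of the same filtering on a toy `obs` table (the cell names and values below are made up for illustration). The three conditions can also be combined into one boolean mask; on the real data such a mask could be applied in a single step, e.g. `SC[mask.values]`, instead of building a view of a view of a view.

```python
import pandas as pd

# toy stand-in for SC.obs (illustrative values only)
obs = pd.DataFrame(
    {
        "source": ["Nuclei", "Nuclei", "Cells", "Nuclei", "Nuclei"],
        "cell_source": ["Harvard-Nuclei", "Sanger-Nuclei", "Sanger-Cells",
                        "Harvard-Nuclei", "Harvard-Nuclei"],
        "cell_type": ["Pericytes", "Endothelial", "Myeloid", "doublets", "Fibroblast"],
    },
    index=[f"cell{i}" for i in range(5)],
)

# combine all three filters into one boolean mask
mask = (
    (obs["source"] == "Nuclei")
    & (obs["cell_source"] == "Harvard-Nuclei")
    & ~obs["cell_type"].isin(["doublets", "NotAssigned"])
)
kept = obs[mask]
print(list(kept.index))  # ['cell0', 'cell4']
```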

In [8]:
SC = SC[SC.obs['source'] == 'Nuclei']
SC = SC[SC.obs['cell_source'] == 'Harvard-Nuclei']
SC = SC[~SC.obs['cell_type'].isin(['doublets', 'NotAssigned'])]

Let us check the number of cells remaining (N=163959):

In [9]:
SC

View of AnnData object with n_obs × n_vars = 163959 × 33538
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45'
    uns: 'cell_type_colors'
    obsm: 'X_pca', 'X_umap'

### Find marker genes
Because of the high number of genes (N=33538), we reduce the gene set to a smaller set of marker genes to speed up training and reduce the memory footprint. Training would also work with all genes, but reaching a model of similar quality would take considerably longer.

Before looking for marker genes, we normalize each profile so that its counts sum to 1, using scanpy's normalize_total function:
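
What `normalize_total(..., target_sum=1)` does to each profile can be sketched in plain NumPy: every row (cell) is divided by its total count, so every row sums to 1. This is a simplified illustration on a dense toy matrix, not scanpy's implementation (which also handles sparse matrices and has further options); `normalize_rows_to_one` is a hypothetical helper, and its handling of empty profiles is an assumption of this sketch.

```python
import numpy as np

def normalize_rows_to_one(X):
    """Divide each row (cell profile) by its total count so that every row sums to 1."""
    X = np.asarray(X, dtype=float)
    totals = X.sum(axis=1, keepdims=True)
    # assumption of this sketch: empty profiles are left as all zeros
    totals[totals == 0] = 1.0
    return X / totals

counts = np.array([[2, 2, 4],
                   [1, 0, 3]])
normalized = normalize_rows_to_one(counts)
print(normalized.sum(axis=1))  # every row now sums to 1
```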

In [10]:
sc.pp.normalize_total(SC, target_sum=1)


We now use the marker genes found by running the minimal example notebook. Instead of recomputing them, we load them from the file marker_genes.csv.
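
How `marker_genes.csv` was written in the minimal example notebook is not shown here. Assuming it was saved from a one-column pandas DataFrame with default settings, the default column name `0` explains the `["0"]` lookup in the next cell. The sketch also checks that every marker gene exists in the dataset before subsetting, since subsetting an AnnData object by an unknown gene name is expected to fail. The gene list and `var_names` below are illustrative stand-ins, and the file is written to a temporary folder so the real `marker_genes.csv` is not touched.

```python
import os
import tempfile

import pandas as pd

marker_genes = ["ABCA6", "ABCA8", "ABCA9"]  # illustrative subset of the marker genes

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "marker_genes.csv")
    # assumption: the list was saved as a one-column DataFrame (default column name 0)
    pd.DataFrame(marker_genes).to_csv(path)
    # read it back the same way as the notebook does; the header "0" comes back as a string
    genes_read = list(pd.read_csv(path)["0"])

# hypothetical stand-in for SC.var_names
var_names = pd.Index(["ABCA6", "ABCA8", "ABCA9", "ACTB"])
missing = [g for g in genes_read if g not in var_names]
print(genes_read, "missing:", missing)
```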

In [11]:
import pandas as pd

use_genes = list(pd.read_csv("marker_genes.csv")["0"])
use_genes

['ABCA6',
 'ABCA8',
 'ABCA9',
 'ABCA9-AS1',
 'ABCC1',
 'ABCC9',
 'ABI1',
 'ABI3BP',
 'ABL1',
 'ABTB2',
 'AC003991.1',
 'AC005037.1',
 'AC005358.1',
 'AC005699.1',
 'AC007319.1',
 'AC008056.1',
 'AC008250.1',
 'AC009264.1',
 'AC010609.1',
 'AC011369.1',
 'AC011389.1',
 'AC012636.1',
 'AC013640.1',
 'AC013652.1',
 'AC015712.2',
 'AC016766.1',
 'AC016831.7',
 'AC017002.5',
 'AC018464.1',
 'AC018742.1',
 'AC019197.1',
 'AC020637.1',
 'AC022034.2',
 'AC022075.1',
 'AC027097.2',
 'AC058822.1',
 'AC068234.2',
 'AC079336.4',
 'AC084357.2',
 'AC091057.6',
 'AC091691.2',
 'AC092114.1',
 'AC092164.1',
 'AC092567.1',
 'AC092683.1',
 'AC098617.1',
 'AC100803.3',
 'AC109587.1',
 'AC109779.1',
 'AC113386.1',
 'AC114760.2',
 'AC114763.1',
 'AC120193.1',
 'AC130650.1',
 'AC140912.1',
 'ACACB',
 'ACAP1',
 'ACKR1',
 'ACLY',
 'ACSL1',
 'ACSL4',
 'ACSM1',
 'ACSM3',
 'ACSS2',
 'ACTA1',
 'ACTB',
 'ACTC1',
 'ACTN4',
 'ACVR1C',
 'ADAM19',
 'ADAM28',
 'ADAMTS12',
 'ADAMTS2',
 'ADAMTS4',
 'ADAMTS9',
 'ADAMTSL1',

In [12]:
SC = SC[:,use_genes] # filter SC

After removing the unwanted genes, we normalize again so that all profiles once more sum to 1:
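
Why the second normalization is needed can be seen on a toy matrix: after dropping columns (genes), the rows of an already normalized matrix no longer sum to 1, and dividing by the new row totals restores that. The values and the `marker_idx` column selection are made up for illustration.

```python
import numpy as np

# toy profiles that were already normalized: each row sums to 1
profiles = np.array([[0.25, 0.25, 0.50],
                     [0.25, 0.00, 0.75]])
marker_idx = [0, 2]  # hypothetical marker-gene columns to keep

subset = profiles[:, marker_idx]
row_sums = subset.sum(axis=1)  # 0.75 and 1.0: the first profile lost part of its mass
renormalized = subset / subset.sum(axis=1, keepdims=True)
print(row_sums, renormalized.sum(axis=1))
```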

In [13]:
sc.pp.normalize_total(SC, target_sum=1)


In [14]:
SC # the profiles are now much smaller, which reduces memory use and speeds up training

AnnData object with n_obs × n_vars = 163959 × 1388
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45'
    uns: 'cell_type_colors'
    obsm: 'X_pca', 'X_umap'

### Define and construct our 10 folds


In [15]:
# requires scikit-learn
from sklearn.model_selection import StratifiedKFold

Y = SC.obs['cell_type']
X = SC.X

skf = StratifiedKFold(n_splits=10)
skf.get_n_splits(X, Y)
skf

StratifiedKFold(n_splits=10, random_state=None, shuffle=False)
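
To illustrate what stratification guarantees, here is a toy example (the labels are made up): with 90 cells of type "A" and 10 of type "B", each of the 10 test folds receives 9 "A" cells and 1 "B" cell, i.e. the same 10% share of the minority type as the full data. Note that with `shuffle=False`, as above, samples within each cell type are assigned to folds in their original order; `shuffle=True` together with a `random_state` would randomize this.

```python
import numpy as np
from collections import Counter
from sklearn.model_selection import StratifiedKFold

# toy labels: 90 cells of type "A" and 10 of type "B"
y_toy = np.array(["A"] * 90 + ["B"] * 10)
X_toy = np.zeros((len(y_toy), 1))  # features do not influence a stratified split

skf_toy = StratifiedKFold(n_splits=10)
fold_counts = [Counter(y_toy[test_idx]) for _, test_idx in skf_toy.split(X_toy, y_toy)]
for counts in fold_counts:
    print(dict(counts))
```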

In [16]:
folds = []
for train_index, test_index in skf.split(X, Y):
    print("TRAIN:", train_index, "TEST:", test_index)
    # define the training part and the held-out part of this fold
    SC_fold_train = SC[train_index]
    SC_fold_holdout = SC[test_index]

    # split the held-out part once more (stratified) into a test and a validation set
    Y_fold = SC_fold_holdout.obs['cell_type']
    X_fold = SC_fold_holdout.X
    print(X_fold.shape)

    skf_fold = StratifiedKFold(n_splits=2)
    # we only need the first of the two splits
    test_index_2, val_index = next(skf_fold.split(X_fold, Y_fold))
    print("TEST (2):", test_index_2, "VAL:", val_index)
    SC_fold_val = SC_fold_holdout[val_index]
    SC_fold_test = SC_fold_holdout[test_index_2]

    folds.append([SC_fold_train, SC_fold_val, SC_fold_test])
    # verify that cell types are indeed stratified
    print(Counter(SC_fold_train.obs['cell_type']))
    print(Counter(SC_fold_test.obs['cell_type']))
    print(Counter(SC_fold_val.obs['cell_type']))

TRAIN: [ 13876  13879  13880 ... 163956 163957 163958] TEST: [    0     1     2 ... 39310 39317 39319]
(16396, 1388)
TEST (2): [ 5898  5900  5902 ... 16393 16394 16395] VAL: [    0     1     2 ... 15894 15899 15910]
Counter({'Ventricular_Cardiomyocyte': 61613, 'Pericytes': 33543, 'Fibroblast': 13899, 'Endothelial': 11495, 'Atrial_Cardiomyocyte': 7146, 'Smooth_muscle_cells': 6742, 'Myeloid': 5030, 'Lymphoid': 4268, 'Adipocytes': 1991, 'Neuronal': 1836})
Counter({'Ventricular_Cardiomyocyte': 3423, 'Pericytes': 1863, 'Fibroblast': 772, 'Endothelial': 639, 'Atrial_Cardiomyocyte': 397, 'Smooth_muscle_cells': 375, 'Myeloid': 279, 'Lymphoid': 238, 'Adipocytes': 110, 'Neuronal': 102})
Counter({'Ventricular_Cardiomyocyte': 3423, 'Pericytes': 1863, 'Fibroblast': 772, 'Endothelial': 639, 'Atrial_Cardiomyocyte': 397, 'Smooth_muscle_cells': 375, 'Myeloid': 279, 'Lymphoid': 237, 'Adipocytes': 111, 'Neuronal': 102})
TRAIN: [     0      1      2 ... 163956 163957 163958] TEST: [13876 13879 13880 ... 5
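
The nested split can be checked on toy data: the outer 10-fold split holds out 10% of the cells, and the inner 2-fold split halves that part, giving roughly 90% training, 5% validation and 5% test cells per fold (on the real data, 16,396 held-out cells in the first fold, as printed above). The sketch takes only the first split with `next(...)` and maps positions inside the held-out part back to indices of the full dataset; slicing the AnnData views in the notebook does this mapping implicitly. The labels are made up.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# toy labels: 180 cells of type "A" and 20 of type "B"
y_toy = np.array(["A"] * 180 + ["B"] * 20)
X_toy = np.zeros((len(y_toy), 1))

outer = StratifiedKFold(n_splits=10)
train_idx, holdout_idx = next(outer.split(X_toy, y_toy))

inner = StratifiedKFold(n_splits=2)
# as in the notebook, the first half of the inner split is used as test, the second as validation
test_pos, val_pos = next(inner.split(X_toy[holdout_idx], y_toy[holdout_idx]))

# map positions within the held-out part back to indices of the full dataset
test_idx, val_idx = holdout_idx[test_pos], holdout_idx[val_pos]
print(len(train_idx), len(val_idx), len(test_idx))  # 180 10 10
```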

In [17]:
experiments = []
for i in range(len(folds)):
    exp_ = AS.DeconvolutionExperiment(SC)
    exp_.SC_train = folds[i][0] # train
    exp_.SC_val = folds[i][1]   # validation
    exp_.SC_test = folds[i][2]  # test
    exp_.setCellTypeColumn('cell_type')
    exp_.setVerbosity(True)

    experiments.append(exp_)



print(experiments) # our 10 folded experiments

[<antisplodge.DeconvolutionExperiment object at 0x7f651b18ce20>, <antisplodge.DeconvolutionExperiment object at 0x7f651b00e800>, <antisplodge.DeconvolutionExperiment object at 0x7f651af729b0>, <antisplodge.DeconvolutionExperiment object at 0x7f651af72b30>, <antisplodge.DeconvolutionExperiment object at 0x7f651af72c50>, <antisplodge.DeconvolutionExperiment object at 0x7f651af738e0>, <antisplodge.DeconvolutionExperiment object at 0x7f6641382ec0>, <antisplodge.DeconvolutionExperiment object at 0x7f651af71090>, <antisplodge.DeconvolutionExperiment object at 0x7f6641381ae0>, <antisplodge.DeconvolutionExperiment object at 0x7f651af73970>]


Then, for each fold, we generate 100,000 training, 2,500 validation, and 2,500 test profiles and load them into data loaders for training the neural network.

Notice: this might require more RAM than the average laptop has, since the profiles for all 10 experiments are kept in memory.
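
A rough lower bound on the memory needed for the generated profiles can be computed directly, assuming they are stored as dense float32 matrices with 1,388 genes. The storage format AntiSplodge actually uses, and overheads such as the data loaders, may differ. Under that assumption it comes to about 0.58 GB per fold, about 5.83 GB for all 10 experiments, and twice that (about 11.66 GB) for float64. `profile_matrix_bytes` is a hypothetical helper.

```python
def profile_matrix_bytes(n_profiles, n_genes, bytes_per_value=4):
    """Bytes needed for a dense n_profiles x n_genes matrix (4 bytes per value = float32)."""
    return n_profiles * n_genes * bytes_per_value

N_GENES = 1388                    # genes left after marker-gene filtering
SPLITS = (100_000, 2_500, 2_500)  # train, validation, test profiles per fold
N_FOLDS = 10

per_fold = sum(profile_matrix_bytes(n, N_GENES) for n in SPLITS)
all_folds = N_FOLDS * per_fold
print(f"float32: {per_fold / 1e9:.2f} GB per fold, {all_folds / 1e9:.2f} GB for {N_FOLDS} folds")
print(f"float64: {2 * all_folds / 1e9:.2f} GB for {N_FOLDS} folds")
```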

In [18]:
for exp_ in experiments:
    exp_.generateTrainTestValidation(num_profiles=[100000,2500,2500], CD=[10,10])
    # Load the profiles into data loaders
    exp_.setupDataLoaders(batch_size=500)

GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATING TEST DATASET (N=2500)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=100000)
GENERATING VALIDATION DATASET (N=2500)
GENERATIN

We then define the models (with default values) and place them on the first CUDA device (cuda_id=0).

In [19]:
# Initialize Neural network-model and allocate it to the cuda_id specified
# Use 'cuda_id="cpu"' if you want to allocate it to a cpu
for exp_ in experiments:
    exp_.setupModel(cuda_id=0)

(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0


We then train each network with 25 warm restarts: whenever a training session finishes, we load the best model settings found so far back into the model and continue from there in the next session, lowering the learning rate as the restarts progress.

In [20]:
for i, exp_ in enumerate(experiments):
    # do 25 warm restarts with decreasing learning rate
    lr = 0.001 
    best_error=None # no target error to beat in the beginning
    for k in range(25):
        
        # Consider changing learning rate (lr) during run more dynamically
        if k >= 5:
            lr = 0.0005
        if k >= 10:
            lr = 0.0001
        if k >= 15:
            lr = 0.00005
        exp_.setupOptimizerAndCriterion(learning_rate = lr)
        # Train the experiment constructed by passing the experiment to the AntiSplodge training function 
        AS.train(exp_, save_file=f"ModelCheckpoint_{i}.pt", patience=5, best_loss=best_error) # For longer training, increase patience threshold
        best_error = AS.getMeanJSD(exp_, "validation") # set best error as the target error to beat
        
        print("Exp [{}] - Restart [{}] - JSDs".format(i, k),AS.getMeanJSD(exp_, "train"), AS.getMeanJSD(exp_, "validation"))

Epoch: 001 | Epochs since last increase: 000
Loss: (Train) 0.02472 | (Valid): 0.01993 - (mean)JSD: (Train) 0.11659 | (Valid) 0.12424 

Epoch: 002 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01852 | (Valid): 0.01915 - (mean)JSD: (Train) 0.10339 | (Valid) 0.11411 

Epoch: 003 | Epochs since last increase: 001
Loss: (Train) 0.01720 | (Valid): 0.01963 - (mean)JSD: (Train) 0.11191 | (Valid) 0.12282 

Epoch: 004 | Epochs since last increase: 002
Loss: (Train) 0.01658 | (Valid): 0.01808 - (mean)JSD: (Train) 0.10451 | (Valid) 0.11528 

Epoch: 005 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01616 | (Valid): 0.01741 - (mean)JSD: (Train) 0.09971 | (Valid) 0.11226 

Epoch: 006 | Epochs since last increase: 001
Loss: (Train) 0.01531 | (Valid): 0.01698 - (mean)JSD: (Train) 0.10346 | (Valid) 0.11455 

Epoch: 007 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01504 | (Valid): 0.01674 - (mean)JSD: (Train) 0.09401 | (
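
The step-wise learning-rate schedule in the restart loop above can be written as a small helper, which makes it easier to inspect or change. `learning_rate_for_restart` is a hypothetical name; inside the loop, `lr = learning_rate_for_restart(k)` would replace the chain of `if` statements. The rate starts at 0.001 and drops at restarts 5, 10 and 15.

```python
def learning_rate_for_restart(k):
    """Learning rate for warm restart k, matching the if-chain in the training loop."""
    if k >= 15:
        return 0.00005
    if k >= 10:
        return 0.0001
    if k >= 5:
        return 0.0005
    return 0.001

schedule = [learning_rate_for_restart(k) for k in range(25)]
print(sorted(set(schedule), reverse=True))
```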

In the end, we obtain a satisfactory mean test JSD of about 8.6% across the 10 folds:

In [24]:
jsd = 0
for exp in experiments:
    cur_jsd = AS.getMeanJSD(exp, "test")  # mean JSD on this fold's test profiles (lower is better)
    print("Test JSD:", "{:2f}%".format(cur_jsd*100))
    jsd += cur_jsd

print(f"Mean JSD: {(jsd/len(experiments)) * 100}%")

Test JSD: 8.646763%
Test JSD: 8.270927%
Test JSD: 9.185757%
Test JSD: 10.056727%
Test JSD: 8.671460%
Test JSD: 7.788181%
Test JSD: 8.113168%
Test JSD: 8.163213%
Test JSD: 7.717306%
Test JSD: 9.374959%
Mean JSD: 8.598846208617214%
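
The mean reported above is the average of the 10 per-fold test JSDs. The sketch below recomputes it from the printed values and also reports how much the folds differ from one another (sample standard deviation of the fold means, `ddof=1`). This describes the stability of the result across folds, whereas the SDs computed further down describe the spread over individual test profiles within a fold.

```python
import numpy as np

# per-fold mean test JSDs in %, copied from the output above
fold_jsd = np.array([8.646763, 8.270927, 9.185757, 10.056727, 8.671460,
                     7.788181, 8.113168, 8.163213, 7.717306, 9.374959])

mean_jsd = fold_jsd.mean()
sd_across_folds = fold_jsd.std(ddof=1)  # sample SD of the 10 fold means
print(f"Mean JSD: {mean_jsd:.3f}% | SD across folds: {sd_across_folds:.3f}% | "
      f"range: {fold_jsd.min():.3f}-{fold_jsd.max():.3f}%")
```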


In [30]:
from scipy.spatial import distance

jsds_ = []
for i, exp in enumerate(experiments): 
    loader      = exp.test_loader
    proportions = exp.Y_test_prop

    y_preds = AS.predict(exp, loader)
    for j in range(len(y_preds)):
        # i + 1 to get which fold the jsd belongs to
        jsds_.append([distance.jensenshannon(proportions[j], y_preds[j]), i+1]) 
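
Note that `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon *distance*, i.e. the square root of the Jensen-Shannon divergence, computed with the natural logarithm unless `base` is given. Whether `AS.getMeanJSD` uses the same convention is not stated here. A small check on toy proportion vectors: identical vectors give 0, and two non-overlapping vectors give the maximum, sqrt(ln 2) ≈ 0.833 with the natural logarithm or exactly 1 with `base=2`.

```python
import numpy as np
from scipy.spatial import distance

p = np.array([0.5, 0.3, 0.2])  # toy "true" cell-type proportions
q = np.array([0.4, 0.4, 0.2])  # toy predicted proportions

same = distance.jensenshannon(p, p)                         # identical distributions
partial = distance.jensenshannon(p, q)                      # small difference
disjoint_e = distance.jensenshannon([1, 0], [0, 1])         # natural logarithm
disjoint_2 = distance.jensenshannon([1, 0], [0, 1], base=2) # base-2 logarithm
print(same, partial, disjoint_e, disjoint_2)
```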

In [32]:
df_for_plot = pd.DataFrame(jsds_, columns=['JSD', 'FOLD'])
df_for_plot.to_csv("df_for_plot.csv")

In [40]:
for i, exp in enumerate(experiments): 
    loader      = exp.test_loader
    proportions = exp.Y_test_prop

    jsds_ = []
    y_preds = AS.predict(exp, loader)
    for j in range(len(y_preds)):
        # JSD between the true and the predicted proportions of each test profile
        jsds_.append(distance.jensenshannon(proportions[j], y_preds[j])) 
    print(f"SD: {np.std(jsds_)*100:0.3}%")

SD: 7.97%
SD: 7.11%
SD: 8.43%
SD: 8.07%
SD: 8.28%
SD: 7.74%
SD: 7.83%
SD: 8.14%
SD: 8.15%
SD: 8.53%
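
The per-fold standard deviations above could also be computed from `df_for_plot` without running the predictions again, for example with a pandas `groupby` over the `FOLD` column. One detail to watch: `np.std` defaults to the population SD (`ddof=0`), whereas pandas' `std` defaults to the sample SD (`ddof=1`), so `ddof=0` is passed explicitly to match the numbers above. The toy DataFrame below is a hypothetical stand-in for `df_for_plot`.

```python
import numpy as np
import pandas as pd

# toy stand-in for df_for_plot: one row per test profile (its JSD) and the fold it came from
df = pd.DataFrame({
    "JSD": [0.10, 0.20, 0.30, 0.05, 0.15],
    "FOLD": [1, 1, 1, 2, 2],
})

summary = df.groupby("FOLD")["JSD"].agg(
    mean="mean",
    sd=lambda s: s.std(ddof=0),  # population SD, same convention as np.std
)
print((summary * 100).round(3))  # in %
```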
