### AntiSplodge minimal example
In this short, minimal tutorial we use the "global" dataset located at https://www.heartcellatlas.org/.
We deconvolve the major cell types of the heart and use the Jensen-Shannon divergence (JSD) between true and predicted cell-type proportions to check how well the method performs.
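For reference, the JSD between a true cell-type proportion vector p and a predicted one q is JSD(p, q) = ½ KL(p‖m) + ½ KL(q‖m), with m = ½(p + q). The sketch below assumes base-2 logarithms, which keep the JSD between 0 (identical proportions) and 1. AntiSplodge's own `getMeanJSD` may use a different base or implementation, and the helper names here are hypothetical.

```python
import numpy as np


def kl_divergence(a, b, base=2.0):
    """Kullback-Leibler divergence KL(a || b), skipping terms where a == 0."""
    mask = a > 0
    return np.sum(a[mask] * np.log(a[mask] / b[mask])) / np.log(base)


def jsd(p, q, base=2.0):
    """Jensen-Shannon divergence between two proportion vectors (hypothetical helper)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Rescale so both vectors are proper probability distributions
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)  # wherever p > 0 or q > 0, m > 0, so no division by zero
    return 0.5 * kl_divergence(p, m, base) + 0.5 * kl_divergence(q, m, base)


def mean_jsd(true_props, pred_props, base=2.0):
    """Average JSD over rows, one row per profile."""
    return float(np.mean([jsd(t, p, base) for t, p in zip(true_props, pred_props)]))


true_props = [[0.5, 0.3, 0.2], [0.1, 0.0, 0.9]]
pred_props = [[0.4, 0.4, 0.2], [0.1, 0.1, 0.8]]
print(mean_jsd(true_props, pred_props))
```

Note that SciPy's `scipy.spatial.distance.jensenshannon` returns the square root of this quantity (the Jensen-Shannon distance), so square its result before comparing it with a JSD.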

You can download the dataset directly from https://cellgeni.cog.sanger.ac.uk/heartcellatlas/data/global_raw.h5ad.

First, import the packages we need. If any of them are missing, install them with "pip install PACKAGE" in your terminal.

In [None]:
import anndata as ann
import scanpy as sc
import numpy as np

import antisplodge as AS

from collections import Counter

After importing the required packages, load the dataset into memory. Make sure the file is in the same folder as the notebook (or adjust the path).

In [None]:
SC = ann.read_h5ad("global_raw.h5ad")

In [None]:
SC.obs

Unnamed: 0,NRP,age_group,cell_source,cell_type,donor,gender,n_counts,n_genes,percent_mito,percent_ribo,region,sample,scrublet_score,source,type,version,cell_states,Used
AAACCCAAGAACGCGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Smooth_muscle_cells,H5,Female,688.0,480,0.004360,0.002907,AX,H0015_apex,0.035917,Nuclei,DBD,V3,SMC1_basic,Yes
AAACCCAAGCAAACAT-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3216.0,1365,0.000933,0.001244,AX,H0015_apex,0.147122,Nuclei,DBD,V3,vCM1,Yes
AAACCCAAGCTACTGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3182.0,1521,0.000943,0.002200,AX,H0015_apex,0.185751,Nuclei,DBD,V3,vCM1,Yes
AAACCCAGTACCGCGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Pericytes,H5,Female,1202.0,726,0.000832,0.000000,AX,H0015_apex,0.035917,Nuclei,DBD,V3,PC2_atria,Yes
AAACCCATCAAACCCA-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3804.0,1584,0.000263,0.001314,AX,H0015_apex,0.108062,Nuclei,DBD,V3,vCM2,Yes
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGTCATACGGT-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,1619.0,904,0.087708,0.024707,AX,HCAHeart8102862,0.108062,CD45+,DCD,V3,EC1_cap,Yes
TTTGTTGTCCTACCAC-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,1186.0,692,0.126476,0.015177,AX,HCAHeart8102862,0.085546,CD45+,DCD,V3,EC1_cap,Yes
TTTGTTGTCGACGCTG-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,1672.0,847,0.145335,0.044258,AX,HCAHeart8102862,0.113475,CD45+,DCD,V3,EC1_cap,Yes
TTTGTTGTCGGCTGAC-1-HCAHeart8102862,Yes,60-65,Sanger-CD45,Endothelial,D11,Female,4676.0,1925,0.163388,0.044482,AX,HCAHeart8102862,0.042770,CD45+,DCD,V3,EC5_art,Yes


Once the dataset is loaded, we remove all cells that were not profiled by nuclei sequencing ('source'), keep a single technology ('cell_source') to avoid technology-driven batch effects, and remove cells whose cell type is 'doublets' or 'NotAssigned'.

In [None]:
SC = SC[SC.obs['source'] == 'Nuclei']
SC = SC[SC.obs['cell_source'] == 'Harvard-Nuclei']
SC = SC[~SC.obs['cell_type'].isin(['doublets', 'NotAssigned'])]
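The three filters can also be combined into a single boolean mask before subsetting. The sketch below uses plain NumPy arrays with made-up toy values in place of the `SC.obs` columns, and uses the already imported `Counter` to tally the remaining cell types; `keep_mask` is a hypothetical helper. On the real AnnData object the mask could be applied in one go with `SC = SC[mask].copy()`; the `.copy()` materializes the subset instead of keeping a view, which should avoid the implicit view-to-actual conversion that `normalize_total` otherwise triggers.

```python
from collections import Counter

import numpy as np


def keep_mask(source, cell_source, cell_type,
              excluded=("doublets", "NotAssigned")):
    """Hypothetical helper: True for cells that pass all three filters."""
    source = np.asarray(source)
    cell_source = np.asarray(cell_source)
    cell_type = np.asarray(cell_type)
    return (
        (source == "Nuclei")                       # nuclei sequencing only
        & (cell_source == "Harvard-Nuclei")        # a single technology
        & ~np.isin(cell_type, list(excluded))      # drop doublets/unassigned cells
    )


# Toy stand-ins for SC.obs['source'], SC.obs['cell_source'], SC.obs['cell_type']
source = ["Nuclei", "Nuclei", "Cells", "Nuclei", "Nuclei"]
cell_source = ["Harvard-Nuclei", "Harvard-Nuclei", "Sanger-Cells",
               "Sanger-Nuclei", "Harvard-Nuclei"]
cell_type = ["Pericytes", "doublets", "Endothelial",
             "Fibroblast", "Ventricular_Cardiomyocyte"]

mask = keep_mask(source, cell_source, cell_type)
print(mask)  # [ True False False False  True]
print(Counter(np.asarray(cell_type)[mask].tolist()))
```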

Let us look at the number of cells remaining (N=163959)

In [None]:
SC

View of AnnData object with n_obs × n_vars = 163959 × 33538
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45'
    uns: 'cell_type_colors'
    obsm: 'X_pca', 'X_umap'

### Find marker genes
Because of the high number of genes (N=33538), we reduce the gene set to a smaller one to speed up training and reduce the memory footprint. Training would also work with all genes, but it would take much longer to reach an equally good model.

Before looking for marker genes, we scale each profile so that it sums to 1, using scanpy's function normalize_total, as shown below:

In [None]:
sc.pp.normalize_total(SC, target_sum=1)

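With `target_sum=1`, `normalize_total` rescales each cell so that its values sum to 1. A dense NumPy sketch of the same idea is shown below; it ignores scanpy's extra options (such as excluding highly expressed genes) and its sparse-matrix handling, and `normalize_rows` is a hypothetical helper.

```python
import numpy as np


def normalize_rows(X, target_sum=1.0):
    """Scale every row (cell) of X so that it sums to target_sum."""
    X = np.asarray(X, dtype=float)
    totals = X.sum(axis=1, keepdims=True)
    # In this sketch, all-zero cells are left as zeros instead of dividing by zero
    totals[totals == 0] = 1.0
    return X / totals * target_sum


counts = np.array([[2, 0, 6],
                   [1, 1, 2],
                   [0, 0, 0]])
print(normalize_rows(counts))
```

This is also why the profiles must be normalized again after genes are removed: dropping columns changes each row's total.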


In [None]:
#
# Rank the genes of each cell type (one group vs. the rest) with a univariate t-test;
# the top-ranked genes serve as marker genes
#
# Usually you should use method='logreg', but the t-test is faster for demonstration purposes
sc.tl.rank_genes_groups(SC, groupby='cell_type', method='t-test', key_added='ranks')
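To make the ranking less of a black box: for each cell type, `method='t-test'` compares every gene in that group against all remaining cells with a Welch-style t-statistic, and `rank_genes_groups` sorts the genes by that score. The sketch below computes such one-vs-rest scores on a small dense toy matrix. It is a simplified illustration, not scanpy's implementation (which, among other things, handles sparse data and reports p-values), and `welch_scores` is a hypothetical name.

```python
import numpy as np


def welch_scores(X, labels, group):
    """One-vs-rest Welch t-statistic per gene for the cells labelled `group`."""
    X = np.asarray(X, dtype=float)
    in_group = np.asarray(labels) == group
    a, b = X[in_group], X[~in_group]
    mean_a, mean_b = a.mean(axis=0), b.mean(axis=0)
    # Unbiased (ddof=1) variances divided by group sizes, as in Welch's test;
    # a gene that is constant in both groups would divide by zero here
    se = np.sqrt(a.var(axis=0, ddof=1) / len(a) + b.var(axis=0, ddof=1) / len(b))
    return (mean_a - mean_b) / se


# 6 toy cells x 3 genes; only gene 0 differs between the two groups
X = np.array([[5.0, 1.0, 2.0],
              [6.0, 1.5, 2.5],
              [5.5, 0.5, 2.0],
              [1.0, 1.0, 2.5],
              [0.5, 1.5, 2.0],
              [1.5, 0.5, 2.0]])
labels = ["vCM", "vCM", "vCM", "EC", "EC", "EC"]

scores = welch_scores(X, labels, "vCM")
order = np.argsort(-scores)  # highest score first, as in the ranking
print(scores, order)
```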

In [None]:
#
# Collect the marker genes of each cell type from the ranking stored in adata.uns[key]
#
def getGenes(adata, key, ct, min_genes, score_threshold=0.01):
    genes = []

    # Collect the top-ranked genes for each cell type
    for cell_ in ct:
        names = adata.uns[key]['names'][cell_]
        scores = adata.uns[key]['scores'][cell_]

        # Always take the first min_genes genes, then keep adding genes while
        # their score exceeds score_threshold (without running past the end)
        index = min_genes
        while index < len(scores) and scores[index] > score_threshold:
            index += 1

        genes_ = names[0:index]
        print(cell_, len(genes_))

        genes.extend(genes_)

    # Several cell types can share marker genes, so keep each gene only once
    np_genes = np.unique(np.array(genes))
    print("Length of unique genes:", len(np_genes))

    return np_genes

# Take at least the top 50 genes per cell type and keep adding genes while their
# score exceeds 120; only ventricular cardiomyocytes have genes beyond the top 50
# that score above this threshold
use_genes = getGenes(SC, 'ranks', np.unique(SC.obs['cell_type']), 50, 120)

Adipocytes 50
Atrial_Cardiomyocyte 50
Endothelial 50
Fibroblast 50
Lymphoid 50
Myeloid 50
Neuronal 50
Pericytes 50
Smooth_muscle_cells 50
Ventricular_Cardiomyocyte 229
Length of unique genes: 609
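The printed counts also show how much the cell types share markers: 9 × 50 + 229 = 679 selected entries collapse to 609 unique genes, so 70 entries repeat genes already picked for another cell type. The selection rule inside `getGenes` can be checked in isolation on toy arrays; `select_top` below is a hypothetical stand-alone version of that loop, and the gene names and scores are made up.

```python
import numpy as np


def select_top(names, scores, min_genes, score_threshold):
    """Keep the first min_genes names, then continue while score > threshold."""
    index = min_genes
    while index < len(scores) and scores[index] > score_threshold:
        index += 1
    return names[:index].tolist()


# Two toy "cell types", each with genes already sorted by score
names_a = np.array(["g0", "g1", "g2", "g3", "g4", "g5"])
scores_a = np.array([9.0, 8.0, 7.0, 6.0, 5.0, 1.0])
names_b = np.array(["g3", "g7", "g0", "g8", "g9", "g1"])
scores_b = np.array([4.0, 3.0, 2.0, 1.0, 0.5, 0.1])

picked_a = select_top(names_a, scores_a, min_genes=2, score_threshold=5.5)
picked_b = select_top(names_b, scores_b, min_genes=3, score_threshold=1.5)
print(picked_a)  # ['g0', 'g1', 'g2', 'g3']
print(picked_b)  # ['g3', 'g7', 'g0']
print(np.unique(picked_a + picked_b))  # 7 picks, 5 unique genes
```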


In [None]:
SC = SC[:,use_genes] # filter SC

After removing the unwanted genes, we normalize again so that every profile once more sums to 1.

In [None]:
sc.pp.normalize_total(SC, target_sum=1)



In [None]:
SC # the profiles are now much smaller, which reduces memory use and speeds up training

AnnData object with n_obs × n_vars = 163959 × 609
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45'
    uns: 'cell_type_colors', 'ranks'
    obsm: 'X_pca', 'X_umap'

### AntiSplodge experiment
Next, we set up the AntiSplodge experiment. In this step we:

1) Create a new experiment by passing SC to AntiSplodge's 'DeconvolutionExperiment'.
2) Set the cell type column to 'cell_type'.
3) Split the dataset into 80% training, 10% validation, and 10% test data.

In [None]:
# SC is the single-cell dataset as an AnnData object (read from the .h5ad file)
Exp = AS.DeconvolutionExperiment(SC)
Exp.setVerbosity(True)

# Use the 'cell_type' column of SC.obs as the cell type labels
Exp.setCellTypeColumn('cell_type')
# Use 80% as training data and split the rest 50/50 into validation and test data
Exp.splitTrainTestValidation(train=0.8, rest=0.5)
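With `train=0.8, rest=0.5`, 80% of the cells go to training and the remaining 20% is halved into validation and test data, giving the 80/10/10 split. A sketch of such a random split over cell indices is shown below; `split_indices` is hypothetical, and AntiSplodge's own split (for example, whether it is stratified by cell type) may differ.

```python
import numpy as np


def split_indices(n_cells, train=0.8, rest=0.5, seed=0):
    """Randomly split range(n_cells) into train/validation/test index arrays."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_cells)
    n_train = int(n_cells * train)
    n_valid = int((n_cells - n_train) * rest)  # `rest` share of what is left
    return idx[:n_train], idx[n_train:n_train + n_valid], idx[n_train + n_valid:]


tr, va, te = split_indices(1000)
print(len(tr), len(va), len(te))  # 800 100 100
```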

Then we generate 500,000 training, 10,000 validation, and 10,000 test profiles, and load them into data loaders for training the neural network.

In [None]:
Exp.generateTrainTestValidation(num_profiles=[500000,10000,10000], CD=[10,10])
# Load the profiles into data loaders
Exp.setupDataLoaders(batch_size=2500)

GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
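Conceptually, the generated profiles are synthetic mixtures: several single-cell profiles are summed into one pseudo-spot, and the fraction of each cell type among the sampled cells becomes the target the network learns to predict. The sketch below illustrates that idea on toy data. It is not AntiSplodge's generator: its sampling scheme, the meaning of `CD`, and its normalization are not reproduced here, and `make_mixture` is a hypothetical name.

```python
import numpy as np


def make_mixture(X, labels, cell_types, n_cells, rng):
    """Sum n_cells random single-cell profiles into one normalized pseudo-spot."""
    labels = np.asarray(labels)
    picked = rng.choice(len(X), size=n_cells, replace=True)
    profile = X[picked].sum(axis=0)
    profile = profile / profile.sum()  # scale the mixture to sum to 1
    # Target: fraction of sampled cells belonging to each cell type
    props = np.array([(labels[picked] == ct).mean() for ct in cell_types])
    return profile, props


rng = np.random.default_rng(42)
X = rng.random((20, 5))  # 20 toy cells x 5 genes
labels = ["A"] * 10 + ["B"] * 6 + ["C"] * 4
profile, props = make_mixture(X, labels, ["A", "B", "C"], n_cells=10, rng=rng)
print(profile.round(3), props)
```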


We then define the model (with default values) and place it on the first CUDA device (cuda_id=0).

In [None]:
# Initialize Neural network-model and allocate it to the cuda_id specified
# Use 'cuda_id="cpu"' if you want to allocate it to a cpu
Exp.setupModel(cuda_id=0)

(CUDA) device is: cuda:0


We then train the network with 25 warm restarts: whenever a training session finishes, the best model found so far is loaded back onto the network, and the next session continues from there with a learning rate that decreases over the restarts.
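The learning-rate steps used in the loop can be written as a small lookup, which makes the schedule easier to read and change: 0.001 for restarts 0 to 4, 0.0005 for 5 to 9, 0.0001 for 10 to 14, and 0.00005 from 15 onwards. `lr_for_restart` is a hypothetical helper, not part of AntiSplodge.

```python
def lr_for_restart(k, schedule=((15, 0.00005), (10, 0.0001), (5, 0.0005), (0, 0.001))):
    """Return the learning rate for warm restart k (first matching threshold wins)."""
    for start, lr in schedule:
        if k >= start:
            return lr
    raise ValueError("restart index must be non-negative")


print([lr_for_restart(k) for k in (0, 4, 5, 9, 10, 14, 15, 24)])
# [0.001, 0.001, 0.0005, 0.0005, 0.0001, 0.0001, 5e-05, 5e-05]
```

Inside the loop, `lr = lr_for_restart(k)` would replace the chain of `if` statements.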

In [None]:
# do 25 warm restarts with decreasing learning rate
lr = 0.001 
best_error=None # no target error to beat in the beginning
for k in range(25):
    
    # Consider changing learning rate (lr) during run more dynamically
    if k >= 5:
        lr = 0.0005
    if k >= 10:
        lr = 0.0001
    if k >= 15:
        lr = 0.00005
    Exp.setupOptimizerAndCriterion(learning_rate = lr)
    # Train the model by passing the experiment to AntiSplodge's training function
    AS.train(Exp, save_file="ModelCheckpoint.pt", patience=5, best_loss=best_error) # For longer training, increase patience threshold
    best_error = AS.getMeanJSD(Exp, "validation") # set best error as the target error to beat
    
    print("Restart [{}] - JSDs".format(k),AS.getMeanJSD(Exp, "train"), AS.getMeanJSD(Exp, "validation"))

Epoch: 001 | Epochs since last increase: 000
Loss: (Train) 0.02406 | (Valid): 0.01790 - (mean)JSD: (Train) 0.11571 | (Valid) 0.11644 

Epoch: 002 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01703 | (Valid): 0.01679 - (mean)JSD: (Train) 0.10646 | (Valid) 0.10779 

Epoch: 003 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01590 | (Valid): 0.01553 - (mean)JSD: (Train) 0.10468 | (Valid) 0.10581 

Epoch: 004 | Epochs since last increase: 001
Loss: (Train) 0.01523 | (Valid): 0.01537 - (mean)JSD: (Train) 0.10600 | (Valid) 0.10759 

Epoch: 005 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01474 | (Valid): 0.01505 - (mean)JSD: (Train) 0.10223 | (Valid) 0.10498 

Epoch: 006 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01441 | (Valid): 0.01433 - (mean)JSD: (Train) 0.09420 | (Valid) 0.09504 

Epoch: 007 | Epochs since last increase: 001
Loss: (Train) 0.01409 | (Valid): 0.01482 - (mean)
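The `patience=5` argument and the "Epochs since last increase" counter in the log appear to follow the usual early-stopping pattern: the counter resets whenever the validation loss improves, and training stops once it reaches the patience limit. The sketch below replays that rule over a list of toy validation losses; it is a generic illustration under that assumption, not AntiSplodge's `train` function, and `early_stopping_epoch` is hypothetical.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return (best_epoch, stop_epoch) for patience-based early stopping."""
    best_loss, best_epoch, since_improvement = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # Improvement: remember this epoch (a checkpoint would be saved here)
            best_loss, best_epoch, since_improvement = loss, epoch, 0
        else:
            since_improvement += 1
            if since_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses) - 1


losses = [0.50, 0.40, 0.45, 0.41, 0.42, 0.43, 0.44, 0.46]
print(early_stopping_epoch(losses, patience=5))  # (1, 6)
```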

In the end, we obtain a mean test JSD of 7.974% (0.0797), which is very satisfactory!

In [None]:
print("Test mean JSD:", "{:.3f}%".format(AS.getMeanJSD(Exp, "test")*100))

Test mean JSD: 7.974%


### Breakdown of the training
Below are the most useful lines printed during the training above (the time taken by each restart, followed by the mean train and validation JSDs):

Time elapsed: 1097.11 (18.29 Minutes)
Restart [0] - JSDs 0.0781799753933141 0.08347195511076605

Time elapsed: 324.06 (5.40 Minutes)
Restart [1] - JSDs 0.07761261954272479 0.08313330346758878

Time elapsed: 193.54 (3.23 Minutes)
Restart [2] - JSDs 0.07761261954272479 0.08313330346758878

Time elapsed: 197.02 (3.28 Minutes)
Restart [3] - JSDs 0.07761261954272479 0.08313330346758878

Time elapsed: 197.83 (3.30 Minutes)
Restart [4] - JSDs 0.07761261954272479 0.08313330346758878

Time elapsed: 799.08 (13.32 Minutes)
Restart [5] - JSDs 0.0713664292946016 0.07981690082481785

Time elapsed: 199.12 (3.32 Minutes)
Restart [6] - JSDs 0.0713664292946016 0.07981690082481785

Time elapsed: 199.62 (3.33 Minutes)
Restart [7] - JSDs 0.0713664292946016 0.07981690082481785

Time elapsed: 198.68 (3.31 Minutes)
Restart [8] - JSDs 0.0713664292946016 0.07981690082481785

Time elapsed: 199.28 (3.32 Minutes)
Restart [9] - JSDs 0.0713664292946016 0.07981690082481785

Time elapsed: 394.37 (6.57 Minutes)
Restart [10] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 196.94 (3.28 Minutes)
Restart [11] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 198.96 (3.32 Minutes)
Restart [12] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 199.34 (3.32 Minutes)
Restart [13] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 198.95 (3.32 Minutes)
Restart [14] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 199.61 (3.33 Minutes)
Restart [15] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 199.34 (3.32 Minutes)
Restart [16] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 200.27 (3.34 Minutes)
Restart [17] - JSDs 0.06827267522789586 0.07820636961783578

Time elapsed: 266.63 (4.44 Minutes)
Restart [18] - JSDs 0.06784003396816517 0.07800503775816546

Time elapsed: 199.32 (3.32 Minutes)
Restart [19] - JSDs 0.06784003396816517 0.07800503775816546

Time elapsed: 199.21 (3.32 Minutes)
Restart [20] - JSDs 0.06784003396816517 0.07800503775816546

Time elapsed: 202.27 (3.37 Minutes)
Restart [21] - JSDs 0.06784003396816517 0.07800503775816546

Time elapsed: 231.64 (3.86 Minutes)
Restart [22] - JSDs 0.06784003396816517 0.07800503775816546

Time elapsed: 233.92 (3.90 Minutes)
Restart [23] - JSDs 0.06784003396816517 0.07800503775816546

Time elapsed: 239.28 (3.99 Minutes)
Restart [24] - JSDs 0.06784003396816517 0.07800503775816546

Already after the first training session, the validation JSD is 0.0835. It then improves at the following restarts:

[1] 0.0831,
[5] 0.0798,
[10] 0.0782,
[18] 0.0780 (18 is the last improvement)

These numbers are the mean validation JSDs. Apart from restart 1, each improvement follows a learning-rate reduction (at restarts 5, 10, and 15). The JSD was already very satisfactory after the first session, but the extra training reduced it by approximately 0.5 percentage points (from 0.0835 to 0.0780). All of this took just under 2 hours (about 116 minutes summed over the 25 restarts), so all in all a very neat tool.
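The elapsed times printed above can be summed to check the total training time, and the first and best validation JSDs give the overall improvement. The values below are copied from the log; only the arithmetic is new.

```python
# "Time elapsed" (seconds) for restarts 0-24, copied from the log above
elapsed = [1097.11, 324.06, 193.54, 197.02, 197.83, 799.08, 199.12, 199.62,
           198.68, 199.28, 394.37, 196.94, 198.96, 199.34, 198.95, 199.61,
           199.34, 200.27, 266.63, 199.32, 199.21, 202.27, 231.64, 233.92,
           239.28]

first_val_jsd = 0.08347195511076605  # mean validation JSD after restart 0
best_val_jsd = 0.07800503775816546   # mean validation JSD after restart 18

total = sum(elapsed)
print("Total: {:.2f} s = {:.1f} min = {:.2f} h".format(total, total / 60, total / 3600))
drop = first_val_jsd - best_val_jsd
print("Absolute drop: {:.4f} ({:.2f} percentage points)".format(drop, drop * 100))
print("Relative drop: {:.1f}%".format(drop / first_val_jsd * 100))
```

This gives about 6965 seconds (roughly 116 minutes) of training and a drop of about 0.0055 in mean validation JSD, a relative reduction of about 6.5%. The log does not report how long profile generation took, so this total covers training only.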