# Managing large and complex training datasets

Large, complex datasets are difficult to use for model training because of computational limits: simply opening such a dataset can require an enormous amount of RAM. <br>

The time needed to select genes for model training also grows with dataset size. <br>

At the same time, adding more training data does not always improve model quality significantly.

Here, we present a way to work around these computational limits when training models on large, complex datasets with many donors.

In [1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import numpy as np
import pandas as pd
import os

sc.set_figure_params(dpi = 120)

In [2]:
# Create folder to save files
# Dataset: single-nucleus RNA-seq of the human retina
os.makedirs('snRNAseq_human_retina', exist_ok = True) # do not fail if the folder already exists

In [3]:
# Download the CELLxGENE dataset (single-nucleus RNA-seq of the human retina):
# https://cellxgene.cziscience.com/collections/1ca90a2d-2943-483d-b678-b809bf464c30
!wget https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad

--2025-02-08 13:49:05--  https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.24, 52.85.49.28, 52.85.49.17, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.24|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 37946797973 (35G) [binary/octet-stream]
Saving to: ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’


2025-02-08 14:13:58 (22.1 MB/s) - ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’ saved [37946797973/37946797973]



## Train scAdam model using a fraction of the dataset
The entire dataset contains 3,177,310 cells and 36,406 genes (35.34 GB), which is too large to open on a standard computer. <br>

Selecting genes for model training on a dataset of this size also requires significant computational power and time. <br>

Therefore, the scParadise team recommends extracting a small portion of the dataset for the following steps.
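
To make the scale concrete, here is a back-of-the-envelope estimate of the memory a dense copy of the expression matrix would need. The 4-byte float32 layout is an assumption made for illustration; the file on disk is much smaller (35.34 GB), presumably because the matrix is stored in a sparse format.

```python
# Rough memory estimate for a dense copy of the expression matrix.
# Assumption (not from the tutorial): each value is a 4-byte float32.
n_cells = 3_177_310
n_genes = 36_406
bytes_per_value = 4

dense_bytes = n_cells * n_genes * bytes_per_value
dense_gb = dense_bytes / 1e9

# The same estimate for a 25,000-cell fraction
fraction_gb = 25_000 * n_genes * bytes_per_value / 1e9

print(f"Full dataset, dense float32: {dense_gb:.1f} GB")
print(f"25,000-cell fraction, dense float32: {fraction_gb:.2f} GB")
```

Under these assumptions the full matrix would need hundreds of gigabytes, while a 25,000-cell fraction fits comfortably in the memory of an ordinary workstation.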

In [4]:
# Obtain 25000 cells randomly
adata_fraction = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                            fraction = 25000,
                                            path_save = 'snRNAseq_human_retina',
                                            celltype = 'cell_type',
                                            random_state = 0)
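
The `celltype` argument suggests that the subset is drawn with the cell type labels in mind. As an illustration only, the sketch below draws a fixed number of cells while keeping each cell type's share of the data roughly constant. The function `stratified_sample` and the toy labels are hypothetical, and this is not necessarily how `get_frac` works internally.

```python
import random
from collections import defaultdict

def stratified_sample(labels, n_total, seed=0):
    """Pick about n_total indices, keeping each label's share of the data."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    picked = []
    for label, indices in by_label.items():
        # Proportional quota, but keep at least one cell of every type
        quota = max(1, round(n_total * len(indices) / len(labels)))
        picked.extend(rng.sample(indices, min(quota, len(indices))))
    return sorted(picked)

# Toy labels (made up): 900 rods, 90 cones, 10 microglia
labels = ['rod'] * 900 + ['cone'] * 90 + ['microglia'] * 10
subset = stratified_sample(labels, n_total=100, seed=0)
counts = {lab: sum(labels[i] == lab for i in subset) for lab in set(labels)}
print(len(subset), counts)
```

Fixing the seed makes the draw reproducible, which is the role `random_state` plays in the cell above.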

In [5]:
# Get raw counts from adata_fraction.raw
adata_fraction = adata_fraction.raw.to_adata()
# Replace variable names with gene names
adata_fraction.var.set_index('feature_name', inplace = True)
adata_fraction.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_fraction, target_sum = None)
sc.pp.log1p(adata_fraction)
adata_fraction.raw = adata_fraction
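
With `target_sum = None`, `sc.pp.normalize_total` scales each cell so that its total count equals the median of the per-cell totals, and `sc.pp.log1p` then applies log(1 + x). The pure-Python sketch below reproduces that arithmetic on a toy count matrix. The values are made up, the sketch assumes that no cell has zero counts, and scanpy performs the same steps on sparse matrices far more efficiently.

```python
import math
from statistics import median

# Toy raw counts: 3 cells x 4 genes (made-up values)
counts = [
    [10, 0, 5, 5],   # total 20
    [2, 2, 0, 6],    # total 10
    [0, 30, 5, 5],   # total 40
]

totals = [sum(row) for row in counts]
target = median(totals)  # median library size across cells

# Scale each cell so that its counts sum to the target
normalized = [[value * target / total for value in row]
              for row, total in zip(counts, totals)]

# Natural logarithm of (1 + x), as in sc.pp.log1p
logged = [[math.log1p(value) for value in row] for row in normalized]

print([round(sum(row), 6) for row in normalized])
print([[round(value, 3) for value in row] for row in logged])
```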

In [6]:
# Find genes for model training (marker genes of cell types)
lst_genes = []
annotations = ['majorclass', 'cell_type'] # annotation levels
for annotation in annotations:
    sc.tl.rank_genes_groups(adata_fraction, 
                            groupby = annotation,
                            method = 't-test_overestim_var', pts = True)
    # Filter marker genes of cell types
    sc.tl.filter_rank_genes_groups(adata_fraction, 
                                   min_fold_change = 1.0, 
                                   min_in_group_fraction = 0.4,
                                   key_added = 'filtered_rank_genes_groups')
    # Create list of genes for model training
    for i in adata_fraction.obs[annotation].unique():
        df = sc.get.rank_genes_groups_df(adata_fraction, group = i, key = 'filtered_rank_genes_groups', pval_cutoff = 0.05)
        # Ratio of the fraction of expressing cells inside the group to the fraction outside it
        df['pts_comparison'] = df['pct_nz_group']/df['pct_nz_reference']
        # Keep the top 20 genes by log fold change and the top 20 by expression ratio
        lst_genes.extend(df.sort_values(by = 'logfoldchanges', ascending = False).head(20)['names'].tolist())
        lst_genes.extend(df.sort_values(by = 'pts_comparison', ascending = False).head(20)['names'].tolist())
# Remove duplicates 
lst_genes = np.unique(lst_genes).tolist()
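
The selection logic in the loop above combines two rankings per cell type and removes duplicates at the end. The sketch below shows the same idea on a toy marker table in plain Python. The gene names and values are made up, and the helpers `expression_ratio` and `top_names` are hypothetical. Treating a gene that is absent from the reference as infinitely specific is one possible convention, chosen here for illustration.

```python
import math

# Toy marker table for one cell type (made-up values)
markers = [
    {'names': 'GENE_A', 'logfoldchanges': 5.0, 'pct_nz_group': 0.9, 'pct_nz_reference': 0.30},
    {'names': 'GENE_B', 'logfoldchanges': 4.5, 'pct_nz_group': 0.8, 'pct_nz_reference': 0.02},
    {'names': 'GENE_C', 'logfoldchanges': 4.0, 'pct_nz_group': 0.5, 'pct_nz_reference': 0.25},
    {'names': 'GENE_D', 'logfoldchanges': 1.5, 'pct_nz_group': 0.6, 'pct_nz_reference': 0.00},
]

def expression_ratio(row):
    # Fraction of expressing cells in the group relative to the reference
    if row['pct_nz_reference'] == 0:
        return math.inf
    return row['pct_nz_group'] / row['pct_nz_reference']

def top_names(rows, key, n):
    # Names of the n rows with the largest key value
    return [row['names'] for row in sorted(rows, key=key, reverse=True)[:n]]

n_top = 2
by_fold_change = top_names(markers, lambda row: row['logfoldchanges'], n_top)
by_ratio = top_names(markers, expression_ratio, n_top)

# Union of both rankings, sorted and without duplicates
selected = sorted(set(by_fold_change + by_ratio))
print(by_fold_change, by_ratio, selected)
```

A gene that ranks highly on both criteria (GENE_B here) appears only once in the final list, which is what `np.unique` achieves in the notebook.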

In [7]:
# Subset genes for model training
adata_fraction = adata_fraction[:, lst_genes]

In [8]:
# Alternative way to select genes for model training
# sc.pp.highly_variable_genes(adata_fraction,
#                             n_top_genes = 1000,
#                             subset = True)
# lst_genes = adata_fraction.var_names.tolist()

In [9]:
adata_balanced = scparadise.scnoah.balance(adata_fraction, 
                                           celltype_l1 = annotations[0], # majorclass
                                           celltype_l2 = annotations[1], # cell_type
                                           sample = 'donor_id')

Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell

Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
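
The messages above report that abundant cell types were undersampled and rare ones oversampled. A minimal sketch of that idea follows: every label is resampled to the same target size. The function `balance_indices` is hypothetical and is not the algorithm used by `scparadise.scnoah.balance`, which also receives a `sample` column (`donor_id`) that this sketch ignores.

```python
import random
from collections import Counter, defaultdict

def balance_indices(labels, target, seed=0):
    """Resample so that every label contributes exactly `target` indices."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    balanced = []
    for label, indices in by_label.items():
        if len(indices) >= target:
            # Undersample: draw without replacement
            balanced.extend(rng.sample(indices, target))
        else:
            # Oversample: keep every cell, then duplicate randomly chosen ones
            extra = rng.choices(indices, k=target - len(indices))
            balanced.extend(indices + extra)
    return balanced

# Toy labels (made-up counts)
labels = ['retinal rod cell'] * 80 + ['Mueller cell'] * 15 + ['astrocyte'] * 5
balanced = balance_indices(labels, target=20, seed=0)
print(Counter(labels[i] for i in balanced))
```

Oversampling by duplication adds no new information about rare cell types; it only prevents abundant types from dominating the training loss.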


In [10]:
# Train scAdam model using the balanced dataset (adata_balanced)
scparadise.scadam.train(adata_balanced,
                        path = 'snRNAseq_human_retina', # path to save model
                        model_name = 'model_scAdam', # folder name with model
                        celltype_l1 = 'celltype_l1', # previously: majorclass
                        celltype_l2 = 'celltype_l2', # previously: cell_type
                        eval_metric = ['balanced_accuracy', 'accuracy'])

Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 2.81661 | train_balanced_accuracy: 0.07509 | train_accuracy: 0.23478 | valid_balanced_accuracy: 0.07484 | valid_accuracy: 0.23441 |  0:00:01s
epoch 1  | loss: 2.15053 | train_balanced_accuracy: 0.29301 | train_accuracy: 0.46274 | valid_balanced_accuracy: 0.29203 | valid_accuracy: 0.46203 |  0:00:02s
epoch 2  | loss: 1.40403 | train_balanced_accuracy: 0.53809 | train_accuracy: 0.6624  | valid_balanced_accuracy: 0.54454 | valid_accuracy: 0.66587 |  0:00:03s
epoch 3  | loss: 0.89481 | train_balanced_accuracy: 0.75518 | train_accuracy: 0.80697 | valid_balanced_accuracy: 0.76594 | valid_accuracy: 0.81095 |  0:00:04s
epoch 4  | loss: 0.66516 | train_balanced_accuracy: 0.81854 | train_accuracy: 0.85328 | valid
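
The log reports both accuracy and balanced accuracy, and the two can differ widely (0.23478 versus 0.07509 on the training data at epoch 0). Balanced accuracy is the unweighted mean of per-class recall, so rare cell types count as much as abundant ones. The sketch below illustrates the difference with made-up labels; the helper functions are written for this example and are not part of scParadise.

```python
from collections import defaultdict

def accuracy(y_true, y_pred):
    """Fraction of correctly classified cells."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def balanced_accuracy(y_true, y_pred):
    """Unweighted mean of per-class recall."""
    hits, totals = defaultdict(int), defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += (t == p)
    return sum(hits[c] / totals[c] for c in totals) / len(totals)

# A classifier that always predicts the majority class
y_true = ['Rod'] * 95 + ['Microglia'] * 5
y_pred = ['Rod'] * 100

print(accuracy(y_true, y_pred), balanced_accuracy(y_true, y_pred))
```

Here the accuracy looks high because rods dominate, while the balanced accuracy exposes that microglia are never recognised.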

## Evaluation of model quality

For model evaluation, we use another subset of 25,000 cells generated using a different random state.

In [11]:
# Get test dataset for model quality evaluation
adata_test = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                        fraction = 25000,
                                        path_save = 'snRNAseq_human_retina',
                                        celltype = 'cell_type',
                                        random_state = 42)

In [12]:
# Check common cells between test and training datasets
lst_train = adata_fraction.obs_names.tolist()
lst_test = adata_test.obs_names.tolist()
lst_train.extend(lst_test)
lst_train = np.unique(lst_train)
percent = round((2 * len(lst_test) - len(lst_train))/len(lst_test)*100, 5)
print(f"There are {percent} % common cells ({2 * len(lst_test) - len(lst_train)} cells) between the test and training datasets")

There are 0.836 % common cells (209 cells) between the test and training datasets


Less than 1% of the cells are shared between the test and training datasets. <br>

Such a small overlap can be ignored, so we can proceed to evaluate the model's quality.
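
The overlap formula in the cell above (`2 * len(lst_test) - len(lst_train)`) relies on both subsets having the same number of cells. A set intersection counts the shared cells directly and also works for subsets of different sizes. The barcodes below are hypothetical placeholders for `obs_names`.

```python
# Toy cell barcodes (hypothetical); the real lists come from obs_names
train_ids = ['cell_01', 'cell_02', 'cell_03', 'cell_04', 'cell_05']
test_ids = ['cell_04', 'cell_05', 'cell_06', 'cell_07']

# Cells present in both datasets
shared = set(train_ids) & set(test_ids)
percent_of_test = len(shared) / len(test_ids) * 100

print(f"{len(shared)} shared cells ({percent_of_test:.1f} % of the test dataset)")

# Optionally drop the shared cells to obtain a fully held-out test set
held_out = [cell for cell in test_ids if cell not in shared]
print(held_out)
```

With AnnData objects, the same filter could be written as `adata_test[~adata_test.obs_names.isin(adata_fraction.obs_names)]`, should a strictly disjoint test set be required.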

In [13]:
# Apply the same preprocessing steps to the test dataset as used for training
# Get raw counts from adata_test.raw
adata_test = adata_test.raw.to_adata()

# Replace variable names with gene names
adata_test.var.set_index('feature_name', inplace = True)
adata_test.var_names_make_unique()

# Normalize data
sc.pp.normalize_total(adata_test, target_sum = None)
sc.pp.log1p(adata_test)
adata_test.raw = adata_test

In [14]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test, 
                                       path_model = 'snRNAseq_human_retina/model_scAdam')

Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities


In [15]:
# Check model quality (level 1 annotation)
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype = 'majorclass',
                                              pred_celltype = 'pred_celltype_l1')
df_l1

| majorclass | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9998 | 0.9964 | 1.0 | 0.9981 | 0.9982 | 0.996 | 4496.0 |
| Astrocyte | 1.0 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111.0 |
| BC | 0.9969 | 0.9994 | 0.9991 | 0.9982 | 0.9993 | 0.9986 | 5437.0 |
| Cone | 1.0 | 0.999 | 1.0 | 0.9995 | 0.9995 | 0.9989 | 1000.0 |
| HC | 0.9984 | 1.0 | 1.0 | 0.9992 | 1.0 | 1.0 | 634.0 |
| MG | 0.9994 | 0.9983 | 1.0 | 0.9989 | 0.9991 | 0.9981 | 1744.0 |
| Microglia | 1.0 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 | 39.0 |
| RGC | 0.9981 | 0.9997 | 0.9997 | 0.9989 | 0.9997 | 0.9994 | 3144.0 |
| RPE | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7.0 |
| Rod | 0.9998 | 0.9999 | 0.9999 | 0.9998 | 0.9999 | 0.9998 | 8388.0 |


In [16]:
# Check model quality (level 2 annotation)
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype = 'cell_type',
                                              pred_celltype = 'pred_celltype_l2')
df_l2

| cell_type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.994 | 0.9881 | 0.9992 | 0.991 | 0.9936 | 0.9862 | 2855.0 |
| H1 horizontal cell | 0.9871 | 0.9907 | 0.9997 | 0.9889 | 0.9952 | 0.9896 | 540.0 |
| H2 horizontal cell | 0.9355 | 0.9255 | 0.9998 | 0.9305 | 0.9619 | 0.9184 | 94.0 |
| Mueller cell | 0.9994 | 0.9977 | 1.0 | 0.9986 | 0.9988 | 0.9974 | 1744.0 |
| OFF midget ganglion cell | 0.9146 | 0.896 | 0.9944 | 0.9052 | 0.9439 | 0.8822 | 1577.0 |
| OFF parasol ganglion cell | 0.9157 | 0.962 | 0.9997 | 0.9383 | 0.9807 | 0.9581 | 79.0 |
| OFFx cell | 0.9449 | 0.9836 | 0.9997 | 0.9639 | 0.9916 | 0.9817 | 122.0 |
| ON midget ganglion cell | 0.9259 | 0.9003 | 0.9964 | 0.9129 | 0.9471 | 0.8884 | 1193.0 |
| ON parasol ganglion cell | 0.8889 | 0.9796 | 0.9998 | 0.932 | 0.9896 | 0.9774 | 49.0 |
| ON-blue cone bipolar cell | 0.875 | 0.913 | 0.9999 | 0.8936 | 0.9555 | 0.905 | 23.0 |


The model performs well for every cell type except 'retinal ganglion cell'. <br> 

You could try a different random state to generate another test dataset.
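
The columns of these reports can be derived from one-vs-rest confusion counts for each cell type. The sketch below uses the definitions of the geometric mean and the index balanced accuracy (with weighting factor alpha = 0.1) from the imbalanced-learn library. These definitions reproduce the values in the level 1 report (for AC, the square root of 0.9964 × 1.0 is about 0.9982), but that `report_classif_full` uses exactly these formulas is an assumption. The function `class_metrics` and the counts are made up for illustration.

```python
import math

def class_metrics(tp, fp, fn, tn, alpha=0.1):
    """Per-class metrics from one-vs-rest confusion counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)            # also called sensitivity
    specificity = tn / (tn + fp)
    f1 = 2 * precision * recall / (precision + recall)
    geometric_mean = math.sqrt(recall * specificity)
    # Index balanced accuracy with weighting factor alpha
    iba = (1 + alpha * (recall - specificity)) * recall * specificity
    return {
        'precision': precision,
        'recall/sensitivity': recall,
        'specificity': specificity,
        'f1-score': f1,
        'geometric mean': geometric_mean,
        'index balanced accuracy': iba,
    }

# Made-up counts for one cell type: 90 found, 10 missed, 5 false alarms
metrics = class_metrics(tp=90, fp=5, fn=10, tn=895)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

Because specificity is close to 1 for almost every cell type in a many-class problem, precision and recall are usually the more informative columns.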

## Iterative warm start training (optional)

You can continue training on further subsets of the whole dataset to improve model generalization, although repeated training may also lead to overfitting.
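
The idea behind a warm start is to continue training from the weights of an existing model instead of re-initialising them. The toy sketch below illustrates this with a one-parameter model fitted by gradient descent. It is a conceptual illustration only: the function `fit_slope` and the data are hypothetical and unrelated to the internals of `scparadise.scadam.warm_start`.

```python
def fit_slope(data, w_init=0.0, epochs=5, lr=0.01):
    """Fit y = w * x by gradient descent on squared error, starting from w_init."""
    w = w_init
    for _ in range(epochs):
        for x, y in data:
            w -= lr * 2 * (w * x - y) * x
    return w

# Two toy subsets drawn from the same relationship y = 3 * x
subset_a = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]
subset_b = [(1.5, 4.5), (2.5, 7.5), (0.5, 1.5)]

w_first = fit_slope(subset_a)                   # primary training from scratch
w_warm = fit_slope(subset_b, w_init=w_first)    # continue from the trained weight
w_cold = fit_slope(subset_b)                    # retrain from scratch instead

print(w_first, w_warm, w_cold)
```

With the same training budget on the second subset, the warm-started weight ends up closer to the true slope than the one trained from scratch, because it builds on what the first round already learned.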

In [17]:
# Keep the lower bound of the range at 1: random_state = 0 was already used for the primary training
# (and do not extend the range to 42, which was used to generate the test dataset)
for i in range(1, 4):
    # Obtain 25,000 cells randomly
    adata_fraction = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                                fraction = 25000,
                                                path_save = 'snRNAseq_human_retina',
                                                celltype = 'cell_type',
                                                random_state = i)
    # Get raw counts from adata_fraction.raw
    adata_fraction = adata_fraction.raw.to_adata()
    
    # Replace variable names with gene names
    adata_fraction.var.set_index('feature_name', inplace=True)
    adata_fraction.var_names_make_unique()
    
    # Normalize data
    sc.pp.normalize_total(adata_fraction, target_sum = None)
    sc.pp.log1p(adata_fraction)
    adata_fraction.raw = adata_fraction
    
    # Subset genes for model training
    adata_fraction = adata_fraction[:, lst_genes]
    # Balance dataset
    adata_balanced = scparadise.scnoah.balance(adata_fraction, 
                                               celltype_l1 = annotations[0], # majorclass
                                               celltype_l2 = annotations[1], # cell_type
                                               sample = 'donor_id')
    adata_balanced.raw = adata_balanced
    # Warm start requires second training dataset and path to pretrained model 
    scparadise.scadam.warm_start(adata_balanced,
                                 path = 'snRNAseq_human_retina', # path to save model
                                 model_name = 'model_scAdam', # folder name with pretrained model
                                 celltype_l1 = 'celltype_l1', # previously: majorclass
                                 celltype_l2 = 'celltype_l2', # previously: cell_type
                                 eval_metric = ['balanced_accuracy', 'accuracy'])

Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell

Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of inpu

In [18]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test, 
                                       path_model = 'snRNAseq_human_retina/model_scAdam')

Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities


In [19]:
# Check model quality after warm start (level 1 annotation)
df_warm_start_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                                         celltype='majorclass',
                                                         pred_celltype='pred_celltype_l1')
df_warm_start_l1

| majorclass | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9993 | 0.9949 | 0.9999 | 0.9971 | 0.9974 | 0.9942 | 4496.0 |
| Astrocyte | 0.9909 | 0.982 | 1.0 | 0.9864 | 0.9909 | 0.9802 | 111.0 |
| BC | 0.9971 | 0.9996 | 0.9992 | 0.9983 | 0.9994 | 0.9989 | 5437.0 |
| Cone | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1000.0 |
| HC | 0.9953 | 1.0 | 0.9999 | 0.9976 | 0.9999 | 0.9999 | 634.0 |
| MG | 1.0 | 0.9989 | 1.0 | 0.9994 | 0.9994 | 0.9987 | 1744.0 |
| Microglia | 0.95 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 | 39.0 |
| RGC | 0.9978 | 0.9997 | 0.9997 | 0.9987 | 0.9997 | 0.9994 | 3144.0 |
| RPE | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7.0 |
| Rod | 0.9996 | 0.9995 | 0.9998 | 0.9996 | 0.9997 | 0.9993 | 8388.0 |


In [20]:
# Check model quality after warm start (level 2 annotation)
df_warm_start_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                                         celltype='cell_type',
                                                         pred_celltype='pred_celltype_l2')
df_warm_start_l2

| cell_type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9965 | 0.9832 | 0.9995 | 0.9898 | 0.9913 | 0.9811 | 2855.0 |
| H1 horizontal cell | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 | 540.0 |
| H2 horizontal cell | 0.949 | 0.9894 | 0.9998 | 0.9688 | 0.9946 | 0.9881 | 94.0 |
| Mueller cell | 1.0 | 0.9977 | 1.0 | 0.9989 | 0.9989 | 0.9975 | 1744.0 |
| OFF midget ganglion cell | 0.9604 | 0.877 | 0.9976 | 0.9168 | 0.9353 | 0.8643 | 1577.0 |
| OFF parasol ganglion cell | 0.963 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 | 79.0 |
| OFFx cell | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 | 122.0 |
| ON midget ganglion cell | 0.928 | 0.9405 | 0.9963 | 0.9342 | 0.968 | 0.9318 | 1193.0 |
| ON parasol ganglion cell | 0.98 | 1.0 | 1.0 | 0.9899 | 1.0 | 1.0 | 49.0 |
| ON-blue cone bipolar cell | 0.88 | 0.9565 | 0.9999 | 0.9167 | 0.978 | 0.9523 | 23.0 |


In [22]:
pd.set_option('display.max_rows', 100)
df_l2.compare(df_warm_start_l2, keep_equal=True, align_axis = 0, result_names=('default', 'warm start'))

| cell_type | model | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | default | 0.994 | 0.9881 | 0.9992 | 0.991 | 0.9936 | 0.9862 |
| GABAergic amacrine cell | warm start | 0.9965 | 0.9832 | 0.9995 | 0.9898 | 0.9913 | 0.9811 |
| H1 horizontal cell | default | 0.9871 | 0.9907 | 0.9997 | 0.9889 | 0.9952 | 0.9896 |
| H1 horizontal cell | warm start | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 |
| H2 horizontal cell | default | 0.9355 | 0.9255 | 0.9998 | 0.9305 | 0.9619 | 0.9184 |
| H2 horizontal cell | warm start | 0.949 | 0.9894 | 0.9998 | 0.9688 | 0.9946 | 0.9881 |
| Mueller cell | default | 0.9994 | 0.9977 | 1.0 | 0.9986 | 0.9988 | 0.9974 |
| Mueller cell | warm start | 1.0 | 0.9977 | 1.0 | 0.9989 | 0.9989 | 0.9975 |
| OFF midget ganglion cell | default | 0.9146 | 0.896 | 0.9944 | 0.9052 | 0.9439 | 0.8822 |
| OFF midget ganglion cell | warm start | 0.9604 | 0.877 | 0.9976 | 0.9168 | 0.9353 | 0.8643 |


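Iterative warm start training increased all aggregate model quality metrics (the macro average, weighted average, accuracy, and balanced accuracy rows), although the per-cell-type metrics did not improve uniformly: for example, recall for 'OFF midget ganglion cell' fell from 0.896 to 0.877. <br>

Additionally, the model's sensitivity increased by 19.5% and precision by 5.5% for 'retinal ganglion cell'.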
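
Changes between the two models are easiest to read as differences in percentage points. The sketch below computes them for two cell types, using precision and recall values copied from the comparison table above; the dictionary layout is chosen only for this example.

```python
# Precision and recall copied from the comparison table above
default = {
    'H2 horizontal cell': {'precision': 0.9355, 'recall/sensitivity': 0.9255},
    'OFF midget ganglion cell': {'precision': 0.9146, 'recall/sensitivity': 0.8960},
}
warm_start = {
    'H2 horizontal cell': {'precision': 0.9490, 'recall/sensitivity': 0.9894},
    'OFF midget ganglion cell': {'precision': 0.9604, 'recall/sensitivity': 0.8770},
}

# Difference in percentage points (warm start minus default)
deltas = {}
for cell_type, metrics in default.items():
    for metric, before in metrics.items():
        after = warm_start[cell_type][metric]
        deltas[(cell_type, metric)] = round((after - before) * 100, 2)

for (cell_type, metric), delta in deltas.items():
    print(f"{cell_type} | {metric}: {delta:+.2f} percentage points")
```

Assuming both reports are numeric DataFrames with matching row labels, `df_warm_start_l2 - df_l2` would give the same differences for every row and column at once.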

In [21]:
import session_info
session_info.show()