In [None]:
import scanpy as sc
import scanpy.external as sce

import pandas as pd
import numpy as np
import os
from functools import reduce
import gseapy as gp

import triku as tk

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

import scipy.stats as sts

import gc

from cellassign import assign_cats

import time
from scipy.stats import linregress

from tqdm.notebook import tqdm

In [None]:
# Notebook-wide random seed; run_analysis passes it to sc.pp.pca as random_state.
seed = 0

In [None]:
# Palettes for UMAP gene expression

# Sample 80 colours from the stock 'magma' map, swap the first one for a
# light grey, and build the final colormap from the first 65 samples only.
base_cmap = plt.get_cmap('magma')
magma_colors = [base_cmap(position) for position in np.linspace(0, 1, 80)]
magma_colors[0] = (0.88, 0.88, 0.88, 1)
magma = mpl.colors.LinearSegmentedColormap.from_list("", magma_colors[:65])

In [None]:
# Raise the default figure resolution to get higher quality figures.
mpl.rcParams.update({'figure.dpi': 200})

# Analyzing runtimes of a standard processing pipeline

In [None]:
# Directory holding the Reynolds 2020 dataset and the .h5ad file inside it.
reynolds_dir = 'reynolds_2020'
reynolds_h5ad = f"{reynolds_dir}/reynolds_2020.h5ad"

# Full dataset; run_analysis subsamples from it for every benchmark run.
adata_all = sc.read(reynolds_h5ad)

In [None]:
# Timing table filled by run_analysis: one list per column, one value appended
# to every list for each benchmark run.
dict_times = {
    column: []
    for column in (
        'Number of cells',
        'Filter genes',
        'Log1p',
        'Normalization',
        'PCA',
        'Batch effect correction (harmony)',
        'Batch effect correction (bbknn)',
        'Neighbors',
        'Feature selection',
        'UMAP',
        'Leiden',
        'DEGs',
        'Total',
    )
}


def run_analysis(adata, n_cells, min_cells=5, random_state=0):
    """Time every step of the processing pipeline on a random subsample.

    A subsample of ``n_cells`` cells is drawn from ``adata``, samples
    (``obs['sample_id']``) with ``min_cells`` cells or fewer are dropped, and
    the pipeline is run step by step: gene filtering, log1p, normalization,
    PCA, bbknn, harmony, neighbors, triku feature selection, UMAP, Leiden and
    DEG ranking. The elapsed time of each step is measured individually.

    Parameters
    ----------
    adata
        Dataset to subsample from (``sc.pp.subsample`` is called with
        ``copy=True``).
    n_cells : int
        Number of cells requested for the subsample.
    min_cells : int, default 5
        Only samples with strictly more than this many cells are kept.
    random_state : int, default 0
        Seed for the subsampling step. PCA uses the notebook-level ``seed``.

    Returns
    -------
    None
        One value per column is appended to the module-level ``dict_times``.
        The values are committed together once all steps have finished, so a
        step that raises leaves every list in ``dict_times`` at the same
        length instead of recording a partial row.
    """
    timings = {}

    def timed(step, func, *args, **kwargs):
        # Run one pipeline step and store its elapsed time under `step`.
        # perf_counter is monotonic, unlike time.time, so system clock
        # adjustments cannot distort the measurement.
        start = time.perf_counter()
        func(*args, **kwargs)
        timings[step] = time.perf_counter() - start

    adata_sub = sc.pp.subsample(adata, n_obs=n_cells, copy=True, random_state=random_state)

    names, counts = np.unique(adata_sub.obs['sample_id'], return_counts=True)
    keep = adata_sub.obs['sample_id'].isin(names[counts > min_cells])
    # Boolean indexing is expected to return a view of the subsample; copy it
    # here so that turning the view into a real object is not billed to the
    # first timed step ('Filter genes').
    adata_sub = adata_sub[keep].copy()
    timings['Number of cells'] = len(adata_sub)
    print(len(adata_sub))

    t_total = time.perf_counter()
    timed('Filter genes', sc.pp.filter_genes, adata_sub, min_counts=50)

    # NOTE(review): log1p runs before normalize_total here; the more common
    # scanpy order is the reverse. Kept as in the original benchmark - confirm.
    timed('Log1p', sc.pp.log1p, adata_sub)
    timed('Normalization', sc.pp.normalize_total, adata_sub)

    # Not timed on its own, but its duration is still part of 'Total'.
    sc.pp.filter_genes(adata_sub, min_counts=1)

    timed('PCA', sc.pp.pca, adata_sub, random_state=seed, n_comps=30)

    timed('Batch effect correction (bbknn)', sce.pp.bbknn, adata_sub,
          metric='angular', batch_key='sample_id', neighbors_within_batch=5)

    timed('Batch effect correction (harmony)', sce.pp.harmony_integrate, adata_sub,
          max_iter_harmony=50, key='sample_id', basis='X_pca',
          adjusted_basis='X_pca_harmony', verbose=False)

    timed('Neighbors', sc.pp.neighbors, adata_sub, use_rep='X_pca_harmony')
    timed('Feature selection', tk.tl.triku, adata_sub, use_raw=False)
    timed('UMAP', sc.tl.umap, adata_sub)
    timed('Leiden', sc.tl.leiden, adata_sub)
    timed('DEGs', sc.tl.rank_genes_groups, adata_sub, groupby='leiden')

    timings['Total'] = time.perf_counter() - t_total

    # Commit the whole row at once so the columns always stay aligned.
    for step, value in timings.items():
        dict_times[step].append(value)

In [None]:
# Requested subsample sizes: every size is benchmarked three times in a row.
cats = [
    size
    for size in (1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 400000)
    for _ in range(3)
]

In [None]:
# Start from empty timing lists so that re-running this cell does not stack a
# second batch of rows on top of the previous one (the results table pairs
# dict_times row-by-row with cats).
for step_times in dict_times.values():
    step_times.clear()

# Wrap the list itself rather than the enumerate iterator so that tqdm can
# read its length and show real progress.
for idx, n_cells in enumerate(tqdm(cats)):
    run_analysis(adata=adata_all, n_cells=n_cells, random_state=idx)

In [None]:
# Results table: one row per run, plus the total without the harmony step and
# the requested subsample size ('x') of each run.
df = pd.DataFrame(dict_times).assign(**{
    'Total (no harmony)': lambda table: table['Total'] - table['Batch effect correction (harmony)'],
    'x': cats,
})
df

In [None]:
# Make sure the output directory exists; to_csv does not create it.
os.makedirs('figures', exist_ok=True)
df.to_csv('figures/supp_table_times.csv')

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 3))

# Mean and standard deviation of the harmony-free total runtime for each
# requested subsample size.
total_by_size = df.groupby('x')['Total (no harmony)']
means = total_by_size.mean().values
stds = total_by_size.std().values
xpos = total_by_size.mean().index.values

# Everything is drawn in log10 space on linear axes; the ticks are relabelled
# with the original values below.
ax.scatter(np.log10(xpos), np.log10(means), c="#7fb3d5", marker='_')

# Vertical mean +/- std bar for every size.
for x, mean, std in zip(xpos, means, stds):
    ax.plot(np.log10([x, x]), np.log10([mean - std, mean + std]), c="#7fb3d5",)

ax.set_xticks(np.log10([1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 400000]))
ax.set_xticklabels(['1k', '2k', '5k', '10k', '20k', '50k', '100k', '200k', '400k'])

time_ticks = [30, 60, 120, 300, 900, 1800, 3600]  # seconds
ax.set_yticks(np.log10(time_ticks))
ax.set_yticklabels(['30s', '1 min', '2 min', '5 min', '15 min', '30 min',  '1 h'])

ax.set_xlabel('Number of cells')
ax.set_ylabel('Computing time')


# Plot regressions: one log-log fit per regime. The split uses `<=` so that a
# run with exactly 20000 cells is part of the low regime instead of being
# silently left out of both fits.
n_cells_run = df['Number of cells'].values
total_run = df['Total (no harmony)'].values
is_small = n_cells_run <= 20000
l1 = linregress(np.log10(n_cells_run[is_small]), np.log10(total_run[is_small]))
l2 = linregress(np.log10(n_cells_run[~is_small]), np.log10(total_run[~is_small]))
ax.plot(np.log10([900, 100000]), l1.slope * np.log10([900, 100000]) + l1.intercept, alpha=0.1, c="#008800")
ax.plot(np.log10([10000, 500000]), l2.slope * np.log10([10000, 500000]) + l2.intercept, alpha=0.1, c="#880000")

# Plot horizontal bars: one faint guide line per y tick.
alpha, color = 0.2, "#7fb3d5"
for seconds in time_ticks:
    ax.plot(np.log10([1000, 420000]), np.log10([seconds, seconds]), alpha=alpha, c=color)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Raw f-strings keep the LaTeX backslash literal (a bare '\c' is an invalid
# escape sequence), and ':+.2f' always prints the sign of the intercept so the
# formula stays well-formed when the intercept is positive.
ax.text(3, 3.35, rf'$time (s) = 10 ^ {{{l1.slope:.2f} \cdot log_{{{10}}}(n_c) {l1.intercept:+.2f}}}$', fontsize=10, alpha=0.4, c="#008800")
ax.text(3, 3.05, rf'$time (s) = 10 ^ {{{l2.slope:.2f} \cdot log_{{{10}}}(n_c) {l2.intercept:+.2f}}}$', fontsize=10, alpha=0.4, c="#880000")


plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/figtimes.pdf')

In [None]:
from scipy.stats import linregress

In [None]:
# Single log-log fit over all runs (no split by number of cells).
log_n_cells = np.log10(df['Number of cells'].values)
log_total_no_harmony = np.log10(df['Total (no harmony)'].values)
linregress(log_n_cells, log_total_no_harmony)