# Compare Scooby, Decima and seq2cells

In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
from plotnine import *
import os
import torch

from scikit_posthocs import posthoc_dunn
from scipy.stats import kruskal, mannwhitneyu

## Paths

In [2]:
# Root directory under which the Scooby evaluation files used below are stored.
scooby_dir = '/gstore/data/resbioai/grelu/decima/scooby'

In [3]:
# Scooby predictions, targets and gene names all sit in the 'count_eval'
# sub-directory of scooby_dir; only the file name differs.
scooby_all_outputs_path, scooby_all_targets_path, scooby_gene_names_path = (
    os.path.join(scooby_dir, 'count_eval', file_name)
    for file_name in (
        'count_predicted_test_no_neighbor.pq',
        'count_target_test_no_neighbor.pq',
        'gene_names.pq',
    )
)

# .h5ad file that is later opened with sc.read.
sch_outputs_path = os.path.join(
    scooby_dir,
    'schwessinger/adata_processed_schwessinger_preds_40epochs.h5ad',
)

## Load scooby predictions

In [4]:
# NOTE(review): torch.load relies on pickle-based deserialization, so these
# files must come from a trusted source. The '.pq' extension suggests parquet,
# yet both files are read with torch.load — confirm the on-disk format.
scooby_all_outputs = torch.load(scooby_all_outputs_path)
scooby_all_targets = torch.load(scooby_all_targets_path)
# The column labelled 0 of the parquet table is used as the list of gene names.
scooby_gene_names = pd.read_parquet(scooby_gene_names_path)[0].tolist()
# Display both shapes next to the number of gene names for a visual sanity check.
scooby_all_outputs.shape, scooby_all_targets.shape, len(scooby_gene_names)



((2539, 21), (2539, 21), 2539)

## Load seq2cells (Schwessinger) predictions

In [5]:
# Read the AnnData file and sum its 'predicted' layer within each
# 'l2_cell_type' group (grouping column looked up on the var axis).
# The summed values are read back from layers['sum'] further down.
sch = sc.get.aggregate(
    sc.read(sch_outputs_path),
    by='l2_cell_type',
    func='sum',
    axis='var',
    layer='predicted',
)
sch.shape

(15892, 21)

## Load Decima predictions

In [6]:
# Decima predictions and their gene identifiers, read from files located
# relative to the notebook's working directory.
decima_outputs = np.load('decima_test_preds.npy')
# NOTE(review): allow_pickle=True permits arbitrary object deserialization;
# only load this file from a trusted source.
decima_test_genes = np.load('decima_test_genes.npy', allow_pickle=True)
# Display the prediction matrix shape next to the number of gene identifiers.
decima_outputs.shape, len(decima_test_genes)

((1628, 21), 1628)

## Subset to common genes

In [7]:
# Label each row of the Decima prediction matrix with its gene identifier so
# rows can be selected by gene later on.
decima_outputs = pd.DataFrame(data=decima_outputs, index=decima_test_genes)

In [8]:
# Wrap Scooby predictions and targets in DataFrames that share the same
# gene-name index.
scooby_all_outputs, scooby_all_targets = (
    pd.DataFrame(matrix, index=scooby_gene_names)
    for matrix in (scooby_all_outputs, scooby_all_targets)
)

In [9]:
# Index the observations by gene name so rows can be selected by gene below.
# The same list is assigned to both sch.obs_names and sch.obs.index.
sch.obs_names = sch.obs.index = sch.obs.gene_name.tolist()

In [10]:
# Genes present in all three prediction sets.
# Sorting pins the row order: iterating a set of strings yields an order that
# can change between kernel sessions, which would silently reorder every
# matrix that is subset with this list.
# Assumes all gene identifiers are mutually comparable (e.g. all strings).
common_genes = sorted(
    set(decima_test_genes).intersection(scooby_gene_names, sch.obs_names)
)
len(common_genes)

1464

In [11]:
# Keep only the shared genes in both Scooby frames, in common_genes order.
scooby_all_outputs, scooby_all_targets = (
    frame.loc[common_genes]
    for frame in (scooby_all_outputs, scooby_all_targets)
)

scooby_all_outputs.shape, scooby_all_targets.shape

((1464, 21), (1464, 21))

In [12]:
# Same gene subset and row order for the Decima predictions.
decima_outputs = decima_outputs.loc[common_genes, :]
decima_outputs.shape

(1464, 21)

In [13]:
# Select the shared genes, then pull the aggregated 'sum' layer out into a
# DataFrame that carries over the object's row and column names.
sch = sch[common_genes, :]
sch = pd.DataFrame(
    data=sch.layers['sum'],
    index=sch.obs_names,
    columns=sch.var_names,
)
sch.shape

(1464, 21)

In [14]:
# Drop the pandas labels: from here on every matrix is a plain array whose
# rows follow common_genes.
scooby_all_outputs, scooby_all_targets, decima_outputs, sch = (
    frame.values
    for frame in (scooby_all_outputs, scooby_all_targets, decima_outputs, sch)
)

## Log transform

In [15]:
# Apply log(x + 1) to the Scooby predictions and targets.
# Not idempotent: re-running this cell applies the transform a second time.
scooby_all_outputs, scooby_all_targets = (
    np.log(counts + 1)
    for counts in (scooby_all_outputs, scooby_all_targets)
)

In [16]:
# Apply log(x + 1) to the aggregated values in sch.
# Not idempotent: re-running this cell applies the transform a second time.
sch = np.log(sch + 1)

## Compute metrics

In [17]:
def compute_metrics(preds, targets, n_cell_types=21):
    """Pearson correlations between predictions and targets along both axes.

    Args:
        preds: 2-D array indexed as [gene, cell type].
        targets: 2-D array with the same shape as ``preds``.
        n_cell_types: Expected number of columns. Defaults to 21, the value
            that used to be hard-coded, so existing calls behave as before.

    Returns:
        A tuple ``(per_ct_corrs, per_gene_corrs)``. The first list holds one
        correlation per column (computed across rows); the second holds one
        correlation per row (computed across columns). An entry is NaN when
        either of the two vectors being correlated is constant.

    Raises:
        AssertionError: If the shapes differ or the number of columns is not
            ``n_cell_types``.
    """
    assert preds.shape == targets.shape
    assert preds.shape[1] == n_cell_types
    n_genes = preds.shape[0]
    # Column-wise: how well the ranking of genes is recovered per cell type.
    per_ct_corrs = [
        np.corrcoef(preds[:, i], targets[:, i])[0, 1] for i in range(n_cell_types)
    ]
    # Row-wise: how well the profile across cell types is recovered per gene.
    per_gene_corrs = [
        np.corrcoef(preds[i], targets[i])[0, 1] for i in range(n_genes)
    ]
    return per_ct_corrs, per_gene_corrs

In [18]:
# Every model is scored against the same reference: the Scooby target matrix.
scooby_metrics, decima_metrics, sch_metrics = (
    compute_metrics(model_preds, scooby_all_targets)
    for model_preds in (scooby_all_outputs, decima_outputs, sch)
)



In [24]:
# For each model print the mean per-cell-type correlation followed by the
# NaN-ignoring mean per-gene correlation, both rounded to four decimals.
for model_label, model_metrics in (
    ("Scooby", scooby_metrics),
    ("Decima", decima_metrics),
    ("seq2cells", sch_metrics),
):
    print(model_label)
    print(
        np.round(np.mean(model_metrics[0]), 4),
        np.round(np.nanmean(model_metrics[1]), 4),
    )

Scooby
0.828 0.7461
Decima
0.8152 0.7494
seq2cells
0.793 0.5217


## Visualize

In [29]:
# Table with the cell-type names; the path is built from scooby_dir so the
# data root is defined in exactly one place.
cell_types = pd.read_parquet(
    os.path.join(scooby_dir, 'training_data', 'scooby_training_data', 'celltype_fixed.pq')
)

In [30]:
# One row per (cell type, model) pair holding the per-cell-type correlation.
# Assumes the order of cell_types.celltype matches the column order of the
# prediction matrices — TODO confirm.
ct_df = pd.DataFrame(
    {
        'cell type': cell_types.celltype.tolist(),
        'scooby': scooby_metrics[0],
        'decima': decima_metrics[0],
        'seq2cells': sch_metrics[0],
    }
).melt(id_vars='cell type', var_name='model')

In [41]:
# Build the annotation from the plotted data instead of hand-typed numbers, so
# the label cannot go stale when the metrics change.
ct_means = ct_df.groupby('model')['value'].mean()
ct_label = (
    f"Mean: Decima: {ct_means['decima']:.3f}  "
    f"Scooby: {ct_means['scooby']:.3f}  "
    f"seq2cells: {ct_means['seq2cells']:.3f}"
)

p = (
    ggplot(ct_df, aes(y='value', fill='model', x='cell type'))
    + geom_col(position='dodge', width=.6)
    + theme_classic() + theme(figure_size=(6,3))
    + theme(axis_text_x=element_text(angle=45, hjust=1))
    + ylab("Pearson Correlation\n per cell type")
    + ylim(0,1)
    # annotate() draws the label exactly once; geom_text() with a constant
    # label would draw one overlapping copy per row of ct_df.
    + annotate('text', label=ct_label, x=9, y=1, size=9)
)
p.save('ct.svg')



In [33]:
# Per-gene correlations for the three models; genes with a NaN correlation in
# any model are dropped so that all models are compared on the same genes.
gene_df = pd.DataFrame(
    {
        'scooby': scooby_metrics[1],
        'decima': decima_metrics[1],
        'seq2cells': sch_metrics[1],
    }
).dropna()
print(len(gene_df))
# Long format: the 'index' column keeps track of which gene each value belongs to.
gene_df = gene_df.reset_index().melt(id_vars='index', var_name='model')

1439


In [47]:
# Build the annotation from the metrics instead of hand-typed numbers.
# The means are NaN-ignoring means over each model's per-gene correlations
# (the same quantity printed in the summary cell); the boxes themselves show
# only the genes kept after dropna() when gene_df was built.
gene_label = (
    f"Mean: Decima: {np.nanmean(decima_metrics[1]):.3f}  "
    f"Scooby: {np.nanmean(scooby_metrics[1]):.3f}  "
    f"seq2cells:{np.nanmean(sch_metrics[1]):.3f}"
)

p=(
    ggplot(gene_df, aes(y='value', x='model'))
    + geom_boxplot(outlier_size=.1)
    + theme_classic() + theme(figure_size=(4.5,2))
    + ylab("Pearson correlation\n     per gene")
    + ylim(-0.4, 1.2)
    # annotate() draws the label exactly once; geom_text() with a constant
    # label would draw one overlapping copy per row of gene_df.
    + annotate('text', label=gene_label, x=2, y=1.15, size=9)
)
p.save('genes.svg')



In [35]:
# Kruskal-Wallis test across the per-gene correlations of all models.
# NOTE(review): the samples are paired by gene (same 'index' values in every
# model), while this test treats them as independent — confirm this is intended.
groups = gene_df.model.unique()
pval = kruskal(
    *(gene_df.loc[gene_df.model == group, 'value'] for group in groups)
).pvalue
print(pval)

1.524338633580944e-210


In [36]:
# Pairwise Dunn post-hoc comparisons between models, with Benjamini-Hochberg
# FDR adjustment of the p-values.
padj = posthoc_dunn(
    gene_df,
    val_col='value',
    group_col='model',
    p_adjust="fdr_bh",
)
padj

Unnamed: 0,decima,scooby,seq2cells
decima,1.0,0.7953533,1.1880880000000001e-160
scooby,0.7953533,1.0,6.459098e-158
seq2cells,1.1880880000000001e-160,6.459098e-158,1.0


In [38]:
# Mann-Whitney U test: Decima vs Scooby per-gene correlations.
# NOTE(review): values are paired by gene; a paired test such as
# scipy.stats.wilcoxon may be more appropriate — confirm.
mannwhitneyu(
    gene_df.loc[gene_df['model'].eq('decima'), 'value'],
    gene_df.loc[gene_df['model'].eq('scooby'), 'value'],
)

MannwhitneyuResult(statistic=np.float64(1042318.0), pvalue=np.float64(0.7549441925038136))

In [39]:
# Mann-Whitney U test: Decima vs seq2cells per-gene correlations.
# NOTE(review): values are paired by gene; a paired test such as
# scipy.stats.wilcoxon may be more appropriate — confirm.
mannwhitneyu(
    gene_df.loc[gene_df['model'].eq('decima'), 'value'],
    gene_df.loc[gene_df['model'].eq('seq2cells'), 'value'],
)

MannwhitneyuResult(statistic=np.float64(1637037.0), pvalue=np.float64(1.7252306611478113e-160))