In [None]:
import sys
import scanpy as sc
import decoupler as dc

# Only needed for processing
import numpy as np
import pandas as pd

# Add the parent directory to the import path before importing load_data.
# NOTE(review): presumably load_data.py lives one level up; '..' is resolved
# against the kernel's working directory — confirm how the notebook is launched.
sys.path.append('..')
from load_data import load_datasets

# Scanpy logging verbosity used for the rest of the notebook.
sc.settings.verbosity = 2

In [None]:
# Normalization variant forwarded to load_datasets when the data is loaded.
norm_method = 'mean' # options noted by the author: 'mean' or 'CP10k'

In [None]:
# Load the preprocessed 'luad_xing' dataset with the selected normalization.
adata = load_datasets('luad_xing', preprocessed=True, norm_method=norm_method)

# Mark the log1p base as unset.
# NOTE(review): presumably a workaround for scanpy routines that read
# adata.uns['log1p']['base'] — confirm it is still needed.
# setdefault creates the 'log1p' entry when the loader did not provide one,
# so this line no longer fails with a KeyError on such inputs.
adata.uns.setdefault('log1p', {})['base'] = None

In [None]:
# Keep a copy of the loaded expression matrix under layers['normalized'] so it
# stays available independently of later changes to adata.X.
# NOTE(review): assumes adata.X holds normalized values at this point (data was
# loaded with preprocessed=True) — confirm against load_datasets.
adata.layers['normalized'] = adata.X.copy()

In [None]:
# Recode the malignancy annotation into the two short level names used by the
# DESeq2 design: 'A' = non-malignant, 'B' = malignant.
# Labels outside this mapping become NaN.
condition_codes = {'non-malignant': 'A', 'malignant': 'B'}
malignancy_labels = adata.obs['malignant_key'].astype(str)
adata.obs['condition'] = malignancy_labels.map(condition_codes).astype('category')
adata.obs['condition']

In [None]:
# Identify highly variable genes, using sample_id as the batch key
sc.pp.highly_variable_genes(adata, batch_key='sample_id')

# Generate PCA features restricted to the highly variable genes flagged above
# NOTE(review): newer scanpy releases deprecate use_highly_variable in favour of
# mask_var — confirm against the pinned scanpy version before changing.
sc.tl.pca(adata, svd_solver='arpack', use_highly_variable=True)

# Compute distances in the PCA space, and find cell neighbors (default settings)
sc.pp.neighbors(adata)

# Generate UMAP features from the neighbor graph
sc.tl.umap(adata)

# Visualize the embedding coloured by the recoded condition and by cell type
sc.pl.umap(adata, color=['condition','celltype'], frameon=False)

In [None]:
# Build pseudo-bulk profiles: values from the 'counts' layer are summed per
# (sample_id, condition) pair, with min_cells / min_counts as quality thresholds.
psbulk_kwargs = dict(
    sample_col='sample_id',
    groups_col='condition',
    layer='counts',
    mode='sum',
    min_cells=10,
    min_counts=1000,
)
pdata = dc.get_pseudobulk(adata, **psbulk_kwargs)
pdata

In [None]:
# QC plot of the pseudobulk profiles, grouped by sample and by condition.
dc.plot_psbulk_samples(pdata, groupby=['sample_id', 'condition'], figsize=(11, 3))

In [None]:
# Diagnostic plot for the expression filter, drawn before any genes are removed.
dc.plot_filter_by_expr(pdata, group='condition')

In [None]:
# Select the genes that clear the expression thresholds within the condition groups
genes = dc.filter_by_expr(
    pdata,
    group='condition',
    min_count=10,
    min_total_count=30,
)

# Report how many genes are kept
print(len(genes))

# Restrict the pseudobulk object to the retained genes (copy, not a view)
pdata = pdata[:, genes].copy()

In [None]:
# Same diagnostic plot, redrawn on the gene-filtered pseudobulk object.
dc.plot_filter_by_expr(pdata, group='condition')

In [None]:
# Import DESeq2
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

In [None]:
# Build DESeq2 object from the filtered pseudobulk profiles, with a
# single-factor design on 'condition' (levels 'A' / 'B' as recoded earlier).
# NOTE(review): n_cpus=8 is hard-coded — adjust to the host machine.
dds = DeseqDataSet(
    adata=pdata,
    design_factors='condition',
    refit_cooks=True,
    n_cpus=8,
)

In [None]:
# Compute LFCs by running the DESeq2 fitting pipeline on dds
dds.deseq2()

In [None]:
# Show the text summary of the fitted DeseqDataSet.
print(dds)

In [None]:
# Set up the statistical test for malignant ('B') versus non-malignant ('A').
# The contrast is given explicitly as [factor, tested level, reference level]
# so the sign of the fold changes does not rely on pydeseq2's default level
# ordering. It must name the recoded 'condition' column used in the design,
# not 'malignant_key' with its original labels.
# alpha=0.01 matches the padj cutoff applied when filtering the results later.
stat_res = DeseqStats(
    dds,
    contrast=['condition', 'B', 'A'],
    alpha=0.01,
    n_cpus=8,
    joblib_verbosity=1,
)

In [None]:
# Compute Wald test and display the results summary
stat_res.summary()

In [None]:
# Inspect the LFC table held by stat_res, before shrinkage is applied.
stat_res.LFC

In [None]:
# Shrink LFCs for the B-versus-A condition coefficient
# NOTE(review): the coefficient name format depends on the pydeseq2 version —
# confirm 'condition_B_vs_A' matches the pinned release.
stat_res.lfc_shrink(coeff='condition_B_vs_A')

In [None]:
# Inspect the shrunken LFCs produced by the lfc_shrink call.
stat_res.shrunk_LFCs

In [None]:
# Extract results table
# NOTE(review): taken after lfc_shrink, so log2FoldChange presumably holds the
# shrunken estimates — confirm against the pydeseq2 version in use.
results_df = stat_res.results_df

In [None]:
# Volcano plot: log2 fold change on x, adjusted p-value on y, with top=30.
dc.plot_volcano_df(results_df, x='log2FoldChange', y='padj', top=30)

In [None]:
# Genes significantly up in 'B': padj below 0.01 and log2 fold change above 2,
# ordered by ascending padj and then by descending fold change.
is_significant = results_df.padj < 0.01
is_strongly_up = results_df.log2FoldChange > 2
tmp = results_df[is_significant & is_strongly_up].sort_values(
    by=['padj', 'log2FoldChange'],
    ascending=[True, False],
)

In [None]:
# Look up C19orf33 in the filtered hit table.
# NOTE(review): raises KeyError if this label is not in tmp's index, i.e. if
# the gene did not pass the padj / fold-change filter.
tmp.loc['C19orf33']

In [None]:
# Look up FOXA2 in the filtered hit table.
# NOTE(review): raises KeyError if this label is not in tmp's index, i.e. if
# the gene did not pass the padj / fold-change filter.
tmp.loc['FOXA2']

In [None]:
# Show the first 50 filtered hits, with the index moved into a 'genes' column.
tmp.reset_index(names='genes').head(50)