In [None]:
import scanpy as sc
import scvi
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib import cm

In [2]:
# Directories used by this notebook, expressed relative to the parent folder.
_PARENT_DIR = Path("..")
DATA = _PARENT_DIR / "data"    # input .h5ad files are read from here
# NOTE(review): MODEL is not referenced by the cells visible here; presumably
# meant as the location for saved models — confirm.
MODEL = _PARENT_DIR / "model"

In [None]:
# Read the input AnnData object from the data directory.
doublet_h5ad = DATA / "doublet.h5ad"
adata = sc.read_h5ad(doublet_h5ad)

In [None]:
# Fix the random seed before the model is built and trained, so that
# re-running this notebook reproduces the same fit (training is stochastic).
SEED = 0
scvi.settings.seed = SEED

# Hand an independent copy of the data to setup_anndata.
adata = adata.copy()

# Register the fields the model reads from `adata`:
#   - the 'counts' layer as the expression input,
#   - 'batch' as the batch key,
#   - 'pct_counts_mt' as a continuous covariate.
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    batch_key='batch',
    continuous_covariate_keys=['pct_counts_mt'],
)

# Build the SCVI model on the registered data and train it with the
# library's default training settings.
vae = scvi.model.SCVI(adata)
vae.train()

In [16]:
# Persist the trained model (together with its AnnData) to 'trained2.model',
# replacing any previous save at that location.
vae.save('trained2.model', overwrite=True, save_anndata=True)

In [None]:
# Restore the previously saved model and continue with the AnnData object
# attached to it.
saved_model_dir = 'trained2.model'
vae = scvi.model.SCVI.load(saved_model_dir)
adata = vae.adata

In [4]:
# Latent representation from the model, stored under obsm['X_scVI'].
latent_representation = vae.get_latent_representation()
adata.obsm['X_scVI'] = latent_representation

# Model-normalized expression, stored under obsm['X_normalized'].
normalized_expression = vae.get_normalized_expression()
adata.obsm['X_normalized'] = normalized_expression

In [None]:
# Neighbourhood graph built on the scVI latent space, followed by Leiden
# clustering and a UMAP embedding.
LATENT_KEY = "X_scVI"
LEIDEN_RESOLUTION = 0.5
UMAP_MIN_DIST = 0.3

sc.pp.neighbors(adata, use_rep=LATENT_KEY)
sc.tl.leiden(adata, resolution=LEIDEN_RESOLUTION)
sc.tl.umap(adata, min_dist=UMAP_MIN_DIST)

In [None]:
# UMAP coloured by cell type and by batch, one panel per row.
umap_color_keys = ["cell_type", 'batch']
sc.pl.umap(
    adata,
    color=umap_color_keys,
    frameon=False,
    ncols=1,
)
plt.show()

In [None]:
# Differential expression from the trained model, grouped by the
# 'cell_type' annotation.
de = vae.differential_expression(
    adata,
    groupby='cell_type',
)

In [13]:
def to_curly(gene: str) -> str:
    """Return the gene symbol wrapped in a pair of dollar signs.

    Args:
        gene: Gene symbol, e.g. ``"CD3E"``.

    Returns:
        The symbol with a leading and trailing ``$``, e.g. ``"$CD3E$"``.
    """
    return f"${gene}$"

In [14]:
# Keep only rows that pass all three thresholds:
# proba_de > 0.8, lfc_mean > 1 and non_zeros_proportion1 > 0.2.
passes_filters = (
    (de['proba_de'] > 0.8) &
    (de['lfc_mean'] > 1) &
    (de['non_zeros_proportion1'] > 0.2)
)
de_gene = de[passes_filters]

# Rank by proba_de (highest first) so the first rows of each group are its
# strongest candidates, then group the rows by comparison label.
de_gene = de_gene.sort_values('proba_de', ascending=False)
de_groups = de_gene.groupby('comparison')

# Top five genes per comparison, keyed by the name of the tested group.
# Comparison labels look like "<group> vs <reference>"; the group name is
# everything before the final " vs ". Splitting on the first space instead
# would truncate multi-word names ("B cells vs Rest" -> "B") and make groups
# that share a first word overwrite each other.
markers = {
    name.rsplit(' vs ', 1)[0]: de_groups.get_group(name).index.tolist()[:5]
    for name in de_groups.groups.keys()
}

# Flat list of formatted gene labels, in the same order as the genes appear
# in `markers`.
symbols = [to_curly(gene) for genes in markers.values() for gene in genes]

In [None]:
# Dot plot of the selected marker genes per cell type. With show=False the
# call hands back the plot's axes so the tick labels can be replaced before
# the figure is displayed.
dotplot_options = dict(
    groupby='cell_type',
    use_raw=True,
    standard_scale='var',
    cmap=cm.viridis_r,
    show=False,
)
ax = sc.pl.dotplot(adata, markers, **dotplot_options)

# Swap the gene tick labels on the main panel for the formatted symbols.
main_panel = ax['mainplot_ax']
main_panel.set_xticklabels(symbols)
plt.show()