# Import Packages

In [None]:
import os
import numpy as np
import pandas as pd
import datetime
from pathlib import Path
import scanpy as sc
import re
from pprint import pprint

from natsort import natsorted

from sklearn import metrics
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
import umap

from tqdm.auto import tqdm, trange

from copy import deepcopy

from scipy import stats, linalg

import matplotlib as mpl
import matplotlib.pyplot as plt
import cmocean
import seaborn as sns
from mpl_toolkits import mplot3d
%matplotlib inline  

# Scanpy logging verbosity.
# NOTE(review): 4 is presumably scanpy's most verbose ("debug") level — confirm against the scanpy docs.
sc.settings.verbosity = 4

In [None]:
import scvi
from scvi.model.utils import mde

# scvi-tools logging threshold.
# NOTE(review): 40 presumably corresponds to Python's logging.ERROR level — confirm.
scvi.settings.verbosity = 40

In [None]:
# Notebook-wide parameters.
# NOTE(review): n_dims, HVG_THRESH, HVG_KEY and P2 are not referenced in the
# visible part of this notebook — confirm whether they are still needed.
n_dims = 16
HVG_THRESH = 0.9
HVG_KEY = "sig"
P2 = "Olfr17"
# Upper bound applied to obs["pct_counts_mito"] in the final QC filter;
# cells above this value are dropped.
MITO_PCT = 40

# Read in data

In [None]:
# Load the concatenated dataset from an .h5ad file in the working directory.
# Later cells expect it to provide a "counts" layer and the obs columns
# "orig_ident", "cond" and "pct_counts_mito".
adata=sc.read_h5ad('Concat_dataset.h5ad')

# HVG via scanpy

In [None]:
# Annotate the top 3000 highly variable genes with the Seurat v3 flavour,
# computed on the "counts" layer with "orig_ident" as the batch key.
# subset=False leaves every gene in `adata`; the result is read back later
# from adata.var["highly_variable"].
sc.pp.highly_variable_genes(
    adata,
    layer="counts",
    flavor="seurat_v3",
    batch_key="orig_ident",
    n_top_genes=3000,
    subset=False,
)

In [None]:
# Per-gene summary statistics over adata.X, stored as columns of adata.var:
#   mean_     - mean of each gene (column) across all cells
#   frac_zero - fraction of cells in which the gene is not > 0
#
# The axis-0 reduction is a 1 x n_genes np.matrix for scipy sparse matrices
# but already 1-D for dense arrays. np.asarray(...).ravel() flattens both to
# shape (n_genes,). Indexing with [0] would, for a dense X, select a single
# scalar and silently broadcast it to every gene.
adata.var["mean_"] = np.asarray(adata.X.mean(axis=0)).ravel()
adata.var["frac_zero"] = (
    1 - np.asarray((adata.X > 0).sum(axis=0)).ravel() / adata.shape[0]
)

In [None]:
# Scatter each gene's mean (log-scaled x axis) against its fraction of
# non-positive entries, using the columns computed in adata.var.
fig, ax = plt.subplots(figsize=(9, 6))

ax.scatter(adata.var["mean_"], adata.var["frac_zero"], s=1)
ax.set_xscale("log")

In [None]:
# Second HVG selection using scvi-tools' Poisson zero-enrichment method,
# with the same gene budget and batch key as the scanpy run.
# inplace=False returns the per-gene results as a DataFrame instead of
# writing them into `adata`.
df_poisson = scvi.data.poisson_gene_selection(
    adata,
    batch_key="orig_ident",
    n_top_genes=3000,
    inplace=False,
)

In [None]:
# Show the Poisson-selected genes ordered by their zero-enrichment rank.
df_poisson.loc[df_poisson["highly_variable"]].sort_values(
    by="prob_zero_enrichment_rank"
)

In [None]:
# Contingency table comparing the two HVG selections.
# Both input Series are named "highly_variable", so without explicit axis
# names the rows and columns of the table cannot be told apart.
pd.crosstab(
    df_poisson["highly_variable"],
    adata.var["highly_variable"],
    rownames=["poisson_hvg"],
    colnames=["seurat_v3_hvg"],
)

In [None]:
# Boolean per-gene mask: the Poisson selection is the one used downstream.
is_hvg = df_poisson["highly_variable"]

In [None]:
# Keep the full Poisson gene-selection table alongside the genes in `adata`.
adata.varm["df_poisson"] = df_poisson

In [None]:
# Restrict to the selected genes (all cells kept). .copy() makes an
# independent object rather than a view of `adata`.
adata_query = adata[:, is_hvg].copy()
# Print a summary of the subset.
print(adata_query)

# Fit scvi model

In [None]:
# Register `adata_query` with scVI: the model is fit on the "counts" layer,
# with "cond" and "orig_ident" as categorical covariates and
# "pct_counts_mito" as a continuous covariate.
scvi.model.SCVI.setup_anndata(
    adata_query,
    layer="counts",
    continuous_covariate_keys=["pct_counts_mito"],
    categorical_covariate_keys=["cond", "orig_ident"],
)

In [None]:
# Build the scVI model on the HVG subset with gene_likelihood="nb"
# (presumably negative binomial — confirm against the scvi-tools docs).
model = scvi.model.SCVI(adata_query, gene_likelihood="nb")

In [None]:
# Display the data registration to check the layer and covariates
# configured in setup_anndata.
model.view_anndata_setup()

In [None]:
# Keyword arguments passed to model.train(): at most 500 epochs, with early
# stopping at a patience of 20.
# NOTE(review): the enable_* flags are presumably forwarded to the underlying
# Lightning trainer — confirm for the installed scvi-tools version.
train_kwargs = {
    "early_stopping": True,
    "early_stopping_patience": 20,
    "enable_model_summary": True,
    "enable_progress_bar": True,
    "enable_checkpointing": True,
    "max_epochs": 500,
}

In [None]:
# Fit the model using the settings collected in train_kwargs.
model.train(**train_kwargs)

In [None]:
# Overlay training and validation ELBO curves on one axis to inspect
# convergence. The first training record is dropped before plotting.
train_elbo = model.history["elbo_train"].iloc[1:]
test_elbo = model.history["elbo_validation"]

ax = train_elbo.plot()
test_elbo.plot(ax=ax)

In [None]:
# Save the trained model to the "scvi_model" directory.
# NOTE(review): re-running this cell after a re-train presumably fails if the
# directory already exists, unless overwrite is enabled — confirm the desired
# behaviour for the iterative workflow.
model.save("scvi_model")

In [None]:
# Per-cell latent embedding from the trained model (for the data the model
# was set up with, i.e. adata_query).
latent = model.get_latent_representation()

In [None]:
# Attach the embedding to the full-gene object. adata_query was created as a
# gene-only subset of `adata`, so both objects hold the same cells.
adata.obsm["X_scVI"] = latent

In [None]:
# Build the neighbour graph from the scVI latent space (not from adata.X),
# then compute the UMAP embedding from that graph.
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata, min_dist=0.5)

In [None]:
# Leiden clustering at resolution 1.2. It reuses the neighbour graph already
# computed from the scVI latent space; labels go to obs["leiden_scVI_1.2"].
sc.tl.leiden(adata, resolution=1.2, key_added="leiden_scVI_1.2")

# QC analysis

In [None]:
# UMAP panels (two per row) coloured by per-cell QC metrics.
sc.pl.umap(
    adata,
    color=[
        "n_genes",
        "total_counts",
        "pct_counts_mito",
        "log1p_total_counts",
    ],
    ncols=2,
    s=3,
    cmap="cubehelix_r",
)

In [None]:
# UMAP coloured by "orig_ident" (used earlier as the batch key) to check how
# samples mix in the embedding.
fig, ax = plt.subplots(figsize=(12, 8))
sc.pl.umap(
    adata,
    color="orig_ident",
    ax=ax,
    s=3,
    cmap="cmo.matter",
    vmax="p99.99",
)

In [None]:
# UMAP coloured by the "cond" covariate.
fig, ax = plt.subplots(figsize=(12, 8))
sc.pl.umap(
    adata,
    color="cond",
    ax=ax,
    s=3,
    cmap="cmo.matter",
    vmax="p99.99",
)

In [None]:
# UMAP coloured by Leiden cluster, with cluster labels drawn on the embedding.
fig, ax = plt.subplots(figsize=(12, 8))
sc.pl.umap(
    adata,
    color="leiden_scVI_1.2",
    ax=ax,
    s=3,
    legend_loc="on data",
)

In [None]:
# Plot expression of selected genes on the global UMAP.
# The list below is only an example; replace it with the genes of interest.
genes = [
    "CXCL14", "MEG3", "TP63", "TOP2A", "SERPINB3", "SOX9",
    "ERMN", "ACSM4", "CFTR", "SH2D7", "LHX2", "STOML3", "PLP1",
    "CD3D", "FGFBP2", "CD79A", "S100A12", "CD14", "C1QB", "TPSB2",
    "HBB", "ENG", "DCN", "ACTA2",
]

# Values come from the "norm" layer (use_raw=False); three panels per row,
# colour scale capped at the 99.9th percentile.
sc.pl.umap(
    adata,
    color=genes,
    layer="norm",
    use_raw=False,
    ncols=3,
    cmap="cmo.matter",
    vmax="p99.9",
    frameon=False,
)

In [None]:
# Copy the Leiden labels onto the HVG subset for the per-cluster QC plots.
# The clustering was stored under "leiden_scVI_1.2" (resolution 1.2);
# "leiden_scVI_1.1" is never created and would raise a KeyError.
adata_query.obs["cluster"] = adata.obs["leiden_scVI_1.2"].copy()

In [None]:
# Distribution of log1p total counts per cluster.
fig, ax = plt.subplots(figsize=(18, 6))
sns.boxenplot(
    x="cluster",
    y="log1p_total_counts",
    data=adata_query.obs,
    ax=ax,
)

In [None]:
# Distribution of mitochondrial count percentage per cluster.
fig, ax = plt.subplots(figsize=(18, 6))
sns.boxenplot(
    x="cluster",
    y="pct_counts_mito",
    data=adata_query.obs,
    ax=ax,
)

In [None]:
# Identify poor-quality clusters and highlight them on the UMAP.

# Example cluster IDs; replace with the clusters flagged by the QC plots.
bad_clust = ["40", "42"]

# Leiden labels live under "leiden_scVI_1.2" (the only clustering computed);
# the previous "leiden_scVI_1.1" key does not exist.
in_bad_clust = adata.obs["leiden_scVI_1.2"].isin(bad_clust)

fig, ax = plt.subplots(figsize=(12, 8))
xu, yu = adata.obsm["X_umap"].T
# All cells in grey as background, flagged cells drawn on top.
ax.scatter(xu, yu, s=0.1, color="0.7")
# Index the coordinate arrays with a plain boolean array.
bad_mask = in_bad_clust.to_numpy()
ax.scatter(xu[bad_mask], yu[bad_mask], s=0.1)

In [None]:
# Run differential expression on a poor-quality cluster to help determine
# its possible identity.
# Labels are read from "leiden_scVI_1.2", the key the clustering was stored
# under; "leiden_scVI_1.1" does not exist.
adata_query.obs["clusters2"] = adata.obs["leiden_scVI_1.2"].copy()

# Only group1 is given: cluster "42" is the group of interest.
df_de = model.differential_expression(adata_query, groupby="clusters2", group1="42")

In [None]:
# First 20 rows of the DE table with a positive mean log fold change.
df_de.loc[df_de["lfc_mean"] > 0].head(20)

In [None]:
# Based on the QC metrics above, drop the confirmed poor-quality clusters and
# any cell whose mitochondrial percentage exceeds MITO_PCT.
# The mask is built from `adata` (the object it is applied to afterwards;
# `adata_all` is not defined in this notebook) and from "leiden_scVI_1.2",
# the key under which the clustering was stored.
not_in_bad_clust = ~adata.obs["leiden_scVI_1.2"].isin(bad_clust)
mito_ok = adata.obs["pct_counts_mito"] <= MITO_PCT
to_keep = not_in_bad_clust & mito_ok

# Number and fraction of cells retained.
print(to_keep.sum())
print(to_keep.mean())

In [None]:
# Apply the QC mask to the cells (all genes kept) and take an independent copy.
adata_f = adata[to_keep].copy()

In [None]:
# Save the filtered dataset to an .h5ad file in the working directory.
adata_f.write('COVID_dataset_scvi_1.h5ad')

From here, the model can be iteratively re-trained and re-run, starting again from the "HVG via scanpy" step, until all low-quality cells have been eliminated.

Each time clusters of interest were subset out, the model was re-trained and re-run to allow for optimal clustering.