In [None]:
# --- Standard library -------------------------------------------------------
import warnings

# Silence warnings before the scientific stack is imported so import-time
# warnings are hidden too.
# NOTE(review): a blanket 'ignore' also hides deprecation / view-copy warnings
# from scanpy and anndata that would flag problems later in this notebook;
# consider narrowing the filter.
warnings.filterwarnings('ignore')

import os

# --- Third-party scientific stack -------------------------------------------
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import matplotlib.pyplot as plt
import seaborn as sns

# Record library versions for provenance, then set default figure styling.
sc.logging.print_header()
sc.settings.set_figure_params(facecolor='white', dpi=100)

## Load data

In [None]:
def _read_filtered_matrix(sample):
    """Read one sample's Cell Ranger filtered feature-barcode matrix.

    `cache=True` lets scanpy read from / write its h5ad cache for this path.
    """
    return sc.read_10x_mtx(
        path=f'./cellranger_out/{sample}/filtered_feature_bc_matrix/',
        cache=True,
    )


# Samples are read in the fixed order MTf, MTm, WTf, WTm.
ad_MTf, ad_MTm, ad_WTf, ad_WTm = [
    _read_filtered_matrix(sample_name)
    for sample_name in ('MTf', 'MTm', 'WTf', 'WTm')
]

In [None]:
# Summaries (n_obs x n_vars) of each sample before merging, in read order.
display(ad_MTf, ad_MTm, ad_WTf, ad_WTm)

In [None]:
# Stamp the experimental design onto every cell of each sample:
# genotype (MT / WT) and sex (female / male).
for sample_ad, genotype, sex in (
    (ad_MTf, 'MT', 'female'),
    (ad_MTm, 'MT', 'male'),
    (ad_WTf, 'WT', 'female'),
    (ad_WTm, 'WT', 'male'),
):
    sample_ad.obs['Genotype'] = genotype
    sample_ad.obs['Sex'] = sex

In [None]:
# Merge the four samples into a single AnnData.
# `anndata.concat` replaces the deprecated `AnnData.concatenate`:
#   - dict keys become the categorical obs['batch'] label (MTf, MTm, WTf, WTm),
#   - index_unique='-' appends '-<sample>' to every barcode, giving the same
#     obs names `concatenate` produced and keeping them unique across samples,
#   - join='inner' keeps only genes present in every sample (as before),
#   - merge='same' keeps var annotations that are identical across all samples
#     instead of duplicating them with per-sample suffixes.
adata = anndata.concat(
    {'MTf': ad_MTf, 'MTm': ad_MTm, 'WTf': ad_WTf, 'WTm': ad_WTm},
    label='batch',
    index_unique='-',
    join='inner',
    merge='same',
)
adata

## Preprocessing

In [None]:
# Cells per sample before any filtering.
display(adata.obs['batch'].value_counts())

# Permissive pre-filter: drop cells with fewer than 200 detected genes and
# genes detected in fewer than 20 cells.
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=20)

# Cells per sample after the pre-filter.
adata.obs['batch'].value_counts()

In [None]:
# Flag mitochondrial genes by the lowercase 'mt-' symbol prefix.
# NOTE(review): lowercase 'mt-' assumes a mouse reference — confirm.
adata.var['mt'] = adata.var_names.str.startswith('mt-')

# Per-cell QC metrics used below (n_genes_by_counts, total_counts,
# pct_counts_mt, ...), written into adata.obs / adata.var in place.
sc.pp.calculate_qc_metrics(
    adata,
    qc_vars=['mt'],
    inplace=True,
    log1p=False,
    percent_top=None,
)
adata.obs

In [None]:
# One wide violin plot per QC metric, split by sample, to guide the
# filtering thresholds chosen further down.
for qc_metric in ('n_genes_by_counts', 'total_counts', 'pct_counts_mt'):
    fig, ax = plt.subplots(figsize=(16, 4))
    # NOTE(review): with `groupby` set, scanpy appears to ignore
    # `multi_panel`; kept for parity with the original call.
    sc.pl.violin(
        adata,
        [qc_metric],
        groupby='batch',
        jitter=0.4,
        multi_panel=True,
        ax=ax,
    )

In [None]:
# For each sample: total counts vs detected genes coloured by % mito counts,
# then a histogram of % mito counts restricted to values below 30.
for sample in ('MTf', 'MTm', 'WTf', 'WTm'):
    target = adata[adata.obs['batch'] == sample]

    sc.pl.scatter(
        target,
        'total_counts',
        'n_genes_by_counts',
        color='pct_counts_mt',
        size=40,
    )

    pct_mt = target.obs['pct_counts_mt']
    sns.displot(pct_mt[pct_mt < 30], kde=False)
    # Render now so each histogram appears right after its sample's scatter.
    plt.show()

In [None]:
# QC thresholds, chosen from the violin / scatter / histogram plots above.
MIN_COUNTS = 2000   # minimum total counts per cell
MAX_COUNTS = 40000  # maximum total counts per cell (presumably to drop doublets — confirm)
MIN_GENES = 1000    # minimum number of detected genes per cell
MT_PCT = 10         # cells must have pct_counts_mt strictly below this value

print('Total number of cells: {:d}'.format(adata.n_obs))

# scanpy's filter_cells takes one threshold per call, hence the separate calls.
sc.pp.filter_cells(adata, min_counts=MIN_COUNTS)
print('Number of cells after min count filter: {:d}'.format(adata.n_obs))

sc.pp.filter_cells(adata, max_counts=MAX_COUNTS)
print('Number of cells after max count filter: {:d}'.format(adata.n_obs))

sc.pp.filter_cells(adata, min_genes=MIN_GENES)
print('Number of cells after gene filter: {:d}'.format(adata.n_obs))

# Boolean indexing returns a *view*. Later cells mutate adata in place
# (layers, normalization, log1p, .raw), so materialize a real AnnData here
# instead of relying on implicit view-to-copy conversion.
adata = adata[adata.obs['pct_counts_mt'] < MT_PCT].copy()
print('Number of cells after MT filter: {:d}'.format(adata.n_obs))

In [None]:
# Post-QC overview: object summary and cells per sample, then the
# distribution of total counts within each sample.
display(adata, adata.obs['batch'].value_counts())
adata.obs.groupby('batch')['total_counts'].describe()

In [None]:
# Preserve the filtered raw counts; scVI reads them later via layer='counts'.
adata.layers['counts'] = adata.X.copy()

# Scale every cell to 10,000 total counts, then apply log(1 + x).
# `normalize_total(target_sum=...)` is the current replacement for the
# deprecated `normalize_per_cell(counts_per_cell_after=...)`. Unlike the old
# call, it does not rewrite obs['n_counts'], which the min_counts filter
# already stored.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Snapshot the log-normalized, all-gene matrix in .raw before any gene subsetting.
adata.raw = adata

In [None]:
# Select 2,000 highly variable genes.
# The 'seurat_v3' flavor models *raw counts*. Running it on adata.X (which is
# log-normalized at this point) mis-ranks genes, so read the 'counts' layer
# instead. This flavor also requires the scikit-misc package.
sc.pp.highly_variable_genes(
    adata,
    flavor='seurat_v3',
    n_top_genes=2000,
    layer='counts',
)
print('\nNumber of highly variable genes: {:d}'.format(int(np.sum(adata.var['highly_variable']))))

sc.pl.highly_variable_genes(adata)

# Checkpoint the preprocessed object before dimensionality reduction.
adata.write(filename='./adata.h5ad')

## Embedding

In [None]:
# PCA restricted to the highly variable genes.
# NOTE(review): newer scanpy deprecates `use_highly_variable` in favour of
# `mask_var='highly_variable'` — update once the pinned version allows it.
sc.pp.pca(adata, svd_solver='arpack', use_highly_variable=True)

# kNN graph with default settings, then 2-D embeddings built from it.
sc.pp.neighbors(adata)
sc.tl.tsne(adata)
sc.tl.umap(adata)

# Leiden clustering at resolution 0.5, stored under obs['leiden_r0.5'].
sc.tl.leiden(adata, key_added='leiden_r0.5', resolution=0.5)

adata

In [None]:
# PCA coloured by library size and sample, for all cells together.
sc.pl.pca(adata, color=['total_counts', 'batch'])

# The same PCA coordinates, one sample at a time, coloured by library size.
for sample in ('MTf', 'MTm', 'WTf', 'WTm'):
    sample_view = adata[adata.obs['batch'] == sample]
    sc.pl.pca(sample_view, color='total_counts', size=40)

In [None]:
# UMAP (uncorrected, PCA-based) coloured by library size and sample.
sc.pl.umap(adata, color=['total_counts', 'batch'])

# Each sample on its own, on the shared UMAP coordinates.
for sample in ('MTf', 'MTm', 'WTf', 'WTm'):
    sample_view = adata[adata.obs['batch'] == sample]
    sc.pl.umap(sample_view, color='total_counts', size=20)

# Side-by-side comparison of sample identity vs Leiden clusters.
sc.pl.umap(
    adata,
    color=['batch', 'leiden_r0.5'],
    frameon=False,
    ncols=2,
)

## Batch correction

In [None]:
import scvi

In [None]:
# Seed scvi-tools globally before building the model so that model
# initialization and training are reproducible across runs (GPU kernels may
# still introduce small nondeterminism).
scvi.settings.seed = 0

# Register the raw-count layer and the per-sample batch covariate with scVI.
# NOTE(review): this trains on all genes; many workflows subset to the highly
# variable genes first to cut training time — consider it.
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    batch_key='batch',
)

model = scvi.model.SCVI(adata)
model.train()

In [None]:
# Persist the trained model, then reload it bound to `adata`. This lets
# later runs skip training and checks that the saved model round-trips.
scvi_model_dir = './models/scVI_model'
model.save(scvi_model_dir, overwrite=True)
model = scvi.model.SCVI.load(scvi_model_dir, adata=adata)

In [None]:
# Batch-corrected latent embedding, one row per cell.
latent = model.get_latent_representation()
adata.obsm['X_scVI'] = latent

# scVI-denoised expression scaled to a library size of 1e4.
# NOTE(review): this is a dense cells x genes matrix and can be large in memory.
denoised = model.get_normalized_expression(library_size=1e4)
adata.layers['scvi_normalized'] = denoised

display(adata)
adata.write(filename='./ad_scvimodel.h5ad')

In [None]:
# Rebuild the kNN graph on the scVI latent space (30 neighbours) and re-embed.
# NOTE(review): this overwrites the PCA-based neighbours/UMAP stored in adata.
# obs['leiden_r0.5'] was computed on the old graph and is not recomputed here.
sc.pp.neighbors(adata, use_rep="X_scVI", n_neighbors=30)
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(adata, color='batch')