# Setup

In [None]:
# General
import scipy as sci
import numpy as np
import pandas as pd
import logging
import time
import pickle
from itertools import chain
import h5py
import scipy.sparse as sparse
import anndata as ad
import scipy.stats as stats
import gc

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams
from matplotlib import cm
from matplotlib import colors
from matplotlib.pyplot import rc_context
import seaborn as sb
from plotnine import *
from adjustText import adjust_text
import umap.umap_ as umap
import pegasus as pg

# Analysis
import scanpy as sc

# Warnings
import warnings
# Silence every warning for a cleaner notebook (use action='once' to see each warning a single time)
warnings.filterwarnings(action='ignore')

# Log package versions for reproducibility
sc.logging.print_versions()

In [None]:
# Batch correction
import scvi
#import scanorama
#import harmonypy

In [None]:
# Colormap
# Custom colormap: a short grey ramp for the lowest values, then the full 'Reds' ramp
colors3 = plt.cm.Greys_r(np.linspace(0.7, 0.8, 20))
colors2 = plt.cm.Reds(np.linspace(0, 1, 128))
colorsComb = np.concatenate([colors3, colors2], axis=0)
mymap = colors.LinearSegmentedColormap.from_list('my_colormap', colorsComb)
rcParams['figure.figsize'] = (6, 6)

# Filter doublet cells

In [None]:
# Load the concatenated gut dataset
adata = sc.read('/mnt/hdd/data/GUT_concatenated.h5ad')

In [None]:
adata

## Batch correction

In [None]:
# Exploratory 80-component PCA (ARPACK) to inspect the variance-ratio curve
sc.tl.pca(adata, n_comps = 80, svd_solver='arpack')

In [None]:
sc.pl.pca_variance_ratio(adata, log=True, n_pcs = 80)

In [None]:
# Recompute with 50 components (default solver); replaces the 80-component PCA above
sc.tl.pca(adata, n_comps = 50)

### n_latent = 30, n_hidden = 512, n_layers = 2

In [None]:
# Register the 'raw_counts' layer and the 'batch' column with scVI
scvi.model.SCVI.setup_anndata(
    adata,
    layer="raw_counts",
    batch_key='batch'
)

In [None]:
# NOTE(review): this default-architecture model is replaced by the next cell
# before it is ever trained; building it only costs time and GPU memory.
model = scvi.model.SCVI(adata, dispersion = 'gene-batch')

In [None]:
# Architecture used downstream: 30 latent dims, 512 hidden units, 2 layers
model = scvi.model.SCVI(adata, n_latent = 30, n_hidden = 512, n_layers = 2, dispersion = 'gene-batch')

In [None]:
model

In [None]:
# Up to 250 epochs; early stopping may end training sooner
model.train(max_epochs = 250, early_stopping = True, enable_progress_bar=True)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

In [None]:
# Per-cell scVI latent embedding
latent = model.get_latent_representation()

In [None]:
# Store it under a key that encodes the architecture (n_latent_n_hidden_n_layers)
adata.obsm["X_scVI_30_512_2"] = latent

In [None]:
# kNN graph on the scVI embedding, then default-resolution Leiden and UMAP
sc.pp.neighbors(adata, n_neighbors = 35, n_pcs = 30, use_rep="X_scVI_30_512_2", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
# Finer clusterings used to locate doublet-rich clusters
sc.tl.leiden(adata, resolution = 2, key_added = 'leiden_2')
sc.tl.leiden(adata, resolution = 3, key_added = 'leiden_3')

In [None]:
sc.pl.umap(adata, color= ['batch', 'leiden', 'leiden_2', 'leiden_3'], size=20, color_map=mymap, ncols=2)

In [None]:
# Release Python and CUDA memory held after training (gc is also imported at the top)
import gc
gc.collect()
import torch
torch.cuda.empty_cache()

In [None]:
adata.write('/mnt/hdd/data/GUT_concatenated_scvi30512.h5ad')

## Filter doublet clusters

In [None]:
import pegasus as pg

### Distribution in leiden clusters

In [None]:
pg.compo_plot(adata, 'leiden', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
pg.compo_plot(adata, 'leiden_2', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
pg.compo_plot(adata, 'leiden_3', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
adata.obs.leiden.value_counts()

In [None]:
adata.obs.leiden_2.value_counts()

In [None]:
adata.obs.leiden_3.value_counts()

In [None]:
pd.set_option('display.max_columns', None)

In [None]:
doublets_3 = pd.DataFrame(pd.crosstab(adata.obs['leiden_3'], adata.obs['doublet_calls'], normalize = 'index'))
doublets_3.style.highlight_max(color='lightgreen', axis = 1)

### Filter - create object with doublets

In [None]:
adata[np.isin(adata.obs['leiden_3'], ['10', '17', '28'])].shape

In [None]:
adata_doublets = adata[np.isin(adata.obs['leiden_3'], ['10', '17', '28'])].copy()
adata_doublets

In [None]:
adata = adata[np.isin(adata.obs['leiden_3'], ['10', '17', '28'], invert = True)].copy()
adata

### Recalculate UMAP

In [None]:
sc.pp.neighbors(adata, n_neighbors = 35, n_pcs = 30, use_rep='X_scVI_30_512_2', metric='correlation')
sc.tl.umap(adata, min_dist = 0.2)

In [None]:
sc.tl.leiden(adata)
sc.tl.leiden(adata, resolution = 2, key_added = 'leiden_2')
sc.tl.leiden(adata, resolution = 3, key_added = 'leiden_3')

In [None]:
sc.pl.umap(adata, color= ['batch', 'leiden', 'leiden_2', 'doublet_calls'], size=20, color_map=mymap, ncols= 2)

In [None]:
pg.compo_plot(adata, 'leiden', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
pg.compo_plot(adata, 'leiden_2', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
pg.compo_plot(adata, 'leiden_3', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
doublets_3 = pd.DataFrame(pd.crosstab(adata.obs['leiden_3'], adata.obs['doublet_calls'], normalize = 'index'))
doublets_3.style.highlight_max(color='lightgreen', axis = 1)

### Filter - all with doublet calls above 3

In [None]:
adata.obs.doublet_calls.value_counts()

In [None]:
adata_doublets_2 = adata[adata.obs['doublet_calls'] > 3].copy()
adata_doublets_2

In [None]:
adata = adata[adata.obs['doublet_calls'] < 4].copy()
adata

### Append doublet objects

In [None]:
doublets = []
doublets.append(adata_doublets)
doublets.append(adata_doublets_2)

In [None]:
doublets

In [None]:
adata[adata.obs['doublet_calls'] > 3].shape

In [None]:
doublets = ad.AnnData.concatenate(
    *doublets, join = 'inner',
    batch_key = 'batch',
).copy()

In [None]:
doublets

### Save doublet object

In [None]:
doublets.write(('/mnt/hdd/data/Gut_WT_filtered_doublets.h5ad'))

In [None]:
del adata_doublets
del adata_doublets_2
del doublets

### Recalculate UMAP

In [None]:
# Rebuild graph/UMAP on the doublet-filtered cells using the same scVI embedding
sc.pp.neighbors(adata, n_neighbors = 35, n_pcs = 30, use_rep='X_scVI_30_512_2', metric='correlation')
sc.tl.umap(adata, min_dist = 0.2)

In [None]:
# NOTE(review): only 'leiden' is recomputed; 'leiden_2'/'leiden_3' still hold
# the clusterings from before the doublet_calls filter.
sc.tl.leiden(adata)

In [None]:
sc.pl.umap(adata, color= ['batch', 'leiden', 'doublet_calls'], size=20, color_map=mymap, ncols=2)

## Save adata

In [None]:
adata.write('/mnt/hdd/data/GUT_concatenated_WT_corr.h5ad')

# Initial cell type

In [None]:
# General
import scipy as sci
import numpy as np
import pandas as pd
import logging
import time
import pickle
from itertools import chain
import h5py
import scipy.sparse as sparse
import anndata as ad
import scipy.stats as stats
import gc

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams
from matplotlib import cm
from matplotlib import colors
from matplotlib.pyplot import rc_context
import seaborn as sb
from plotnine import *
from adjustText import adjust_text
import umap.umap_ as umap
import pegasus as pg

# Analysis
import scanpy as sc

# Warnings
import warnings
# Silence every warning (use action='once' to see each warning a single time)
warnings.filterwarnings(action='ignore')

# sc.logging.print_versions()  # uncomment to log package versions

In [None]:
# Batch correction
import scvi
#import scanorama
#import harmonypy

In [None]:
# Colormap
# Custom colormap: a short grey ramp for the lowest values, then the full 'Reds' ramp
colors3 = plt.cm.Greys_r(np.linspace(0.7, 0.8, 20))
colors2 = plt.cm.Reds(np.linspace(0, 1, 128))
colorsComb = np.concatenate([colors3, colors2], axis=0)
mymap = colors.LinearSegmentedColormap.from_list('my_colormap', colorsComb)
rcParams['figure.figsize'] = (6, 6)

## Samples

In [None]:
base_path = '/mnt/hdd/data'

In [None]:
adata = sc.read('/mnt/hdd/data/GUT_concatenated_WT_corr.h5ad')

In [None]:
adata

In [None]:
pg.compo_plot(adata, 'leiden', 'sample', style = 'frequency',
              sort_function=None, dpi = 150)

In [None]:
sc.tl.rank_genes_groups(adata, groupby='leiden')

## Azimuth cell types

In [None]:
import gseapy as gp

In [None]:
adata.obs['cell_type_azimuth'] = 'Unknown'
for cluster in adata.obs['leiden'].cat.categories:
    # Enrichr query with the top-100 ranked genes of this cluster
    top_genes = list(adata.uns['rank_genes_groups']['names'][cluster][:100])
    enriched = gp.enrichr(top_genes,
                          gene_sets=['Azimuth_Cell_Types_2021'],
                          outdir=None)
    # Sort once; reused for the label and the displayed table
    ranked = enriched.results.sort_values('Combined Score', ascending=False)
    if ranked.empty:
        # No enriched term: keep the 'Unknown' label and skip the plots
        continue
    # Label the cluster with its best-scoring term, cutting the ' CL...' suffix
    # (assumes Azimuth terms look like '<cell type> CL<ontology id>' — confirm)
    best_term = ranked['Term'].iloc[0]
    adata.obs.loc[adata.obs['leiden'] == cluster, 'cell_type_azimuth'] = best_term.split(' CL')[0]
    gp.dotplot(enriched.results, figsize=(6, 5), cutoff=0.1, title=cluster, cmap=plt.cm.turbo, top_term=5)
    plt.show()
    display(ranked.iloc[0:5, :])

## Marker genes

In [None]:
# Marker genes for UMAP/dot plots, grouped by the population they are used to mark
marker_genes = [
    'Lgr5', 'Mki67', 'Ascl2',   # ISC
    'Sis',                      # enterocytes
    'Dclk1',                    # tuft
    'Tff3',
    'Neurod1', 'Tph1', 'Isl1',  # EEC
    'Epcam',
    'Cd68',                     # (Cd19 left out)
    'Cd79a', 'Cd74',            # B cells (Cd4, Cd8a, Cd3g, Mcpt2 left out)
    'Ephx2',
    'Defa5', 'Defa23',          # Paneth
    'Clca1', 'Scin',            # goblet
]
# Also considered: Adgre1, Cd4

In [None]:
sc.pl.umap(adata, color = marker_genes + ['doublets_shown', 'leiden'], color_map = mymap, size = 30)#, legend_loc = 'on data')

In [None]:
sc.pl.umap(adata, color = ['leiden'], color_map = mymap, size = 30, legend_loc = 'on data')

In [None]:
adata.var_names_make_unique()

In [None]:
adata.obs_names_make_unique()

In [None]:
pg.compo_plot(adata, 'leiden', 'doublet_calls', style = 'frequency',
              sort_function=None, 
              palette=['#FFD700', '#FF7F50', '#8B0000', '#0000CD', '#6495ED', '#008080', '#B0C4DE', '#696969'], dpi = 150)

In [None]:
sc.set_figure_params(frameon=False, dpi = 80, dpi_save=300, color_map='viridis', vector_friendly=True, transparent=True)
for _ in adata.obs['sample'].value_counts().index:
    sc.pl.umap(adata, color='sample', groups=_, size = 8, add_outline=True)
rcParams['figure.figsize']=(6,7)

In [None]:
sc.pl.DotPlot(adata, var_names =  marker_genes , groupby='leiden', standard_scale='var', 
              cmap=mymap, use_raw=False).style(color_on='square', dot_edge_lw=1, grid=True, dot_min=0.15, dot_edge_color=None).show()

In [None]:
sc.pl.umap(adata, color = ['leiden', 'doublets_shown'], color_map = mymap, size = 30, legend_loc = 'on data', legend_fontoutline = 1)

In [None]:
# Defa5,23 = Paneth -1,19 
# Neurod1, TPH1, ISL1 = EEC - 7,6,11
# Dclk1 = Tuft - 10
# Clca1, Scin = Goblet - 3
# Lgr5, Mki67, Aslc2 = ISC - 0,2
# Sis = Enterocytes -4
# Cd79a, Cd74 = B-cell -12, 18
# Cd4+ = T-cells 
# Anxa5 = M-cells

In [None]:
adata.obs.leiden.cat.categories

## Annotation

In [None]:
# Manual cluster annotation: cell type -> leiden clusters carrying that label.
# Clusters not listed fall back to 'unknown' so the mapping never yields NA.
clusters_by_type = {
    'ISC': ['0', '2'],
    'Tuft': ['10'],
    'EEC': ['6', '7', '11'],
    'B-cells': ['12', '18'],
    'Enterocytes': ['4'],
    'Paneth': ['1', '19'],
    'Goblet': ['3'],
    'Cd68+': ['16'],
    'non-epithelial': ['9'],
}
type_of_cluster = {
    cluster: cell_type
    for cell_type, clusters in clusters_by_type.items()
    for cluster in clusters
}
map_names = {
    cat: type_of_cluster.get(cat, 'unknown')
    for cat in adata.obs['leiden'].cat.categories
}

adata.obs['initial_cell_type'] = adata.obs['leiden'].map(map_names).astype('category')

In [None]:
pd.value_counts(adata.obs['initial_cell_type'])

In [None]:
sc.pl.umap(adata, color=['initial_cell_type','batch','leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0)) #, wspace=0.55)

### Save

In [None]:
adata.write('/mnt/hdd/data/GUT_concatenated_WT_no_doublets.h5ad')

In [None]:
adata = sc.read_h5ad('/mnt/hdd/data/GUT_concatenated_WT_no_doublets.h5ad')

In [None]:
sc.pl.umap(adata, color=['initial_cell_type','batch','leiden'], size=10, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=1)

## Benchmark batch-correction methods

In [None]:
# Inspect the object before HVG selection
adata

### HVG, remove ambient

In [None]:
# Gene annotations, including the 'is_ambient-<n>' flag columns (presumably one per batch)
adata.var

In [None]:
adata.var['is_ambient-4'].value_counts()

In [None]:
adata.var[adata.var['is_ambient-19']==True].index

In [None]:
# Genes flagged ambient in any 'is_ambient-<n>' column (with repeats; deduplicated below).
# The flags appear to be stored inconsistently (the string 'True' for some batches,
# a real boolean for others), so compare on the string form to catch both encodings.
ambient_cols = [col for col in adata.var.columns if col.startswith('is_ambient-')]
ambient_genes = [
    gene
    for col in ambient_cols
    for gene in adata.var.index[(adata.var[col].astype(str) == 'True').to_numpy()]
]

In [None]:
ambient_genes = set(ambient_genes)
len(ambient_genes)

In [None]:
# if hgv breaks
# adata.uns['log1p']['base'] = None

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=4000, batch_key = 'batch')

In [None]:
hvg = adata.var[adata.var['highly_variable']==True].index

In [None]:
hvg= hvg[np.isin(hvg, ambient_genes, invert=True)]
hvg

In [None]:
adata=adata[:,hvg].copy()
adata

## Benchmark metrics 

In [None]:
# 55-component PCA on the HVG subset (input for the Harmony run below)
sc.tl.pca(adata, n_comps = 55)

### Harmony

In [None]:
# Harmony integration over 'batch'; the corrected embedding is read back as 'X_pca_harmony'
sc.external.pp.harmony_integrate(adata, 'batch', max_iter_harmony = 30)

In [None]:
sc.pp.neighbors(adata, n_neighbors = 75, n_pcs = 55, use_rep='X_pca_harmony', metric='correlation')
sc.tl.umap(adata, min_dist = 0.2)

In [None]:
# 'leiden' here is still the clustering carried over from the loaded object
sc.pl.umap(adata, color= ['batch', 'initial_cell_type', 'leiden'], size=20, color_map=mymap)

### scVI Default

In [None]:
# Register the raw counts and batch key on the HVG subset for all scVI candidates below
scvi.model.SCVI.setup_anndata(
    adata,
    layer="raw_counts",
    batch_key='batch'
)

In [None]:
# Candidate 1: default architecture (key name encodes n_latent=10, n_hidden=128)
model = scvi.model.SCVI(adata, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_10_128"] = latent

In [None]:
# NOTE: sc.tl.leiden without key_added overwrites adata.obs['leiden']
sc.pp.neighbors(adata, n_neighbors = 25, n_pcs = 10, use_rep="X_scVI_10_128", metric='correlation')
sc.tl.umap(adata, min_dist=0.2)
sc.tl.leiden(adata)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### n_latent = 20, n_hidden = 512

In [None]:
# Candidate: n_latent=20, n_hidden=512 (embedding 'X_scVI_20_512' is benchmarked below)
model = scvi.model.SCVI(adata, n_latent = 20, n_hidden = 512, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_20_512"] = latent

In [None]:
# NOTE: overwrites adata.obs['leiden'] and the UMAP with this candidate's result
sc.pp.neighbors(adata, n_neighbors = 35, n_pcs = 20, use_rep="X_scVI_20_512", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### n_latent = 30, n_hidden = 512, n_layers = 2

In [None]:
# Candidate: n_latent=30, n_hidden=512, n_layers=2. The key 'X_scVI_30_512_2' may
# already exist from the earlier notebook; it is overwritten with this HVG-subset model.
model = scvi.model.SCVI(adata, n_latent = 30, n_hidden = 512, n_layers = 2, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_30_512_2"] = latent

In [None]:
sc.pp.neighbors(adata, n_neighbors = 35, n_pcs = 30, use_rep="X_scVI_30_512_2", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### n_latent = 30, n_hidden = 1024

In [None]:
# Candidate: n_latent=30, n_hidden=1024 (embedding 'X_scVI_30_1024')
model = scvi.model.SCVI(adata, n_latent = 30, n_hidden = 1024, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_30_1024"] = latent

In [None]:
sc.pp.neighbors(adata, n_neighbors = 45, n_pcs = 30, use_rep="X_scVI_30_1024", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### n_latent = 50, n_hidden = 1024, n_layers = 3

In [None]:
# Candidate: n_latent=50, n_hidden=1024, n_layers=3 (embedding 'X_scVI_50_1024_3')
model = scvi.model.SCVI(adata, n_latent = 50, n_hidden = 1024, n_layers = 3, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_50_1024_3"] = latent

In [None]:
sc.pp.neighbors(adata, n_neighbors = 70, n_pcs = 50, use_rep="X_scVI_50_1024_3", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### n_latent = 50, n_hidden = 2048, n_layers = 1

In [None]:
# Candidate: n_latent=50, n_hidden=2048, n_layers=1 (embedding 'X_scVI_50_2048_1')
model = scvi.model.SCVI(adata, n_latent = 50, n_hidden = 2048, n_layers = 1, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_50_2048_1"] = latent

In [None]:
sc.pp.neighbors(adata, n_neighbors = 70, n_pcs = 50, use_rep="X_scVI_50_2048_1", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### n_latent = 60, n_hidden = 2048, n_layers = 2

In [None]:
# Candidate: n_latent=60, n_hidden=2048, n_layers=2 (embedding 'X_scVI_60_2048_2')
model = scvi.model.SCVI(adata, n_latent = 60, n_hidden = 2048, n_layers = 2, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_60_2048_2"] = latent

In [None]:
sc.pp.neighbors(adata, n_neighbors = 80, n_pcs = 60, use_rep="X_scVI_60_2048_2", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

#### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

### scPoli

In [None]:
from sklearn.metrics import classification_report
from scarches.models.scpoli import scPoli

In [None]:
# Early-stopping settings for scPoli. NOTE(review): currently unused — the
# argument that would pass them to train() below is commented out.
early_stopping_kwargs = {
    "early_stopping_metric": "val_prototype_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

In [None]:
# NOTE(review): replaces adata.X with raw counts in place; this object (raw X)
# is what the Benchmarker sees and what gets written to disk below.
adata.X = adata.layers['raw_counts'].copy()

In [None]:
#adata.X = adata.X.todense()
adata.X = adata.X.astype('float32')

In [None]:
adata.obs['initial_cell_type'].value_counts()

In [None]:
# NOTE(review): conditions on 'sample', whereas the scVI/Harmony runs use 'batch' — confirm intended
scpoli_model = scPoli(
    adata = adata,
    condition_keys = 'sample',
    cell_type_keys = 'initial_cell_type',
    hidden_layer_sizes = [100],
    latent_dim = 25,
    embedding_dims = 5,
    recon_loss = 'poisson',
)
scpoli_model.train(
    n_epochs = 250,
    pretraining_epochs = 100,
#    early_stopping_kwargs=early_stopping_kwargs,
    use_early_stopping = False,
    alpha_epoch_anneal = 1000,
    eta = 0.5
)

In [None]:
# Get the latent representation (posterior means) of the reference data
scpoli_model.model.eval()
data_latent = scpoli_model.get_latent(
    adata,
    mean=True
)

In [None]:
adata.obsm["X_scPoli"] = data_latent

In [None]:
sc.pp.neighbors(adata, use_rep="X_scPoli", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color = 'initial_cell_type', size = 5)

In [None]:
adata

### Benchmark

In [None]:
from scib_metrics.benchmark import Benchmarker

In [None]:
bm = Benchmarker(
    adata,
    batch_key = "batch",
    label_key = "initial_cell_type",
    embedding_obsm_keys = ['X_scVI_10_128', 'X_scVI_20_512', 'X_scVI_30_512_2', 'X_scVI_30_1024',
                           'X_scVI_50_1024_3', 'X_scVI_50_2048_1', 'X_scVI_60_2048_2', 'X_pca_harmony', 'X_scPoli'],
    n_jobs = 20,
)
bm.benchmark()

In [None]:
bm.plot_results_table(min_max_scale=False)

In [None]:
bm.plot_results_table(min_max_scale=True)

In [None]:
adata.write('/mnt/hdd/data/GUT_concatenated_WT_no_doublets_HVG_benchmarked.h5ad')

# Use best correction

In [None]:
# Reload the benchmarked HVG-subset object (X holds float32 raw counts at save time)
adata = sc.read_h5ad('/mnt/hdd/data/GUT_concatenated_WT_no_doublets_HVG_benchmarked.h5ad')

### HVG, remove ambient

In [None]:
# Genes flagged ambient in any 'is_ambient-<n>' column (with repeats; deduplicated below).
# The flags appear to be stored inconsistently (the string 'True' for some batches,
# a real boolean for others), so compare on the string form to catch both encodings.
ambient_cols = [col for col in adata.var.columns if col.startswith('is_ambient-')]
ambient_genes = [
    gene
    for col in ambient_cols
    for gene in adata.var.index[(adata.var[col].astype(str) == 'True').to_numpy()]
]

In [None]:
# Deduplicate into an ndarray (np.isin below handles arrays correctly, unlike sets)
ambient_genes = np.unique(ambient_genes)
len(ambient_genes)

In [None]:
# if hgv breaks
#adata.uns['log1p']['base'] = None

In [None]:
# Check X: the benchmarked file was saved with raw counts in X
adata.X.min()

In [None]:
# NOTE(review): log1p is applied to raw counts without prior normalize_total — confirm intended
adata_hvg = sc.pp.log1p(adata,copy=True)# if hgv breaks

In [None]:
# NOTE(review): this object was already subset to HVGs before benchmarking, so
# selection here operates within that subset
sc.pp.highly_variable_genes(adata_hvg, n_top_genes=4000, batch_key = 'batch')

In [None]:
hvg = adata_hvg.var[adata_hvg.var['highly_variable']==True].index

In [None]:
# Remove ambient genes from the HVG list
hvg= hvg[np.isin(hvg, ambient_genes, invert=True)]
hvg

In [None]:
adata=adata_hvg[:,hvg].copy()
adata

## n_latent = 20, n_hidden = 512

In [None]:
# Final model with the selected architecture, trained on the ambient-free HVG subset
scvi.model.SCVI.setup_anndata(
    adata,
    layer="raw_counts",
    batch_key='batch'
)

In [None]:
model = scvi.model.SCVI(adata, n_latent = 20, n_hidden = 512, dispersion = 'gene-batch')

In [None]:
model

In [None]:
model.train(max_epochs = 300, enable_progress_bar=True, early_stopping=True)

In [None]:
latent = model.get_latent_representation()

In [None]:
adata.obsm["X_scVI_20_512"] = latent

In [None]:
sc.pp.neighbors(adata, n_neighbors = 35, n_pcs = 20, use_rep="X_scVI_20_512", metric='correlation')
sc.tl.leiden(adata)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['batch', 'leiden', 'initial_cell_type'], size=20, color_map=mymap)

### Check reconstruction loss

In [None]:
# plot reconstruction loss
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
#plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.legend()

## Add to object with all genes

In [None]:
adata_all = sc.read('/mnt/hdd/data/GUT_concatenated_WT_corr.h5ad')

In [None]:
adata_all.shape

In [None]:
pd.crosstab(adata.obs['batch'], adata.obs['initial_cell_type'])

In [None]:
adata_all.obsm['X_scVI_20_512'] = adata.obsm['X_scVI_20_512']

In [None]:
adata_all.obs['initial_cell_type'] = adata.obs['initial_cell_type'].copy()

In [None]:
# Try several neighborhood sizes on the all-gene object (same scVI embedding)
sc.pp.neighbors(adata_all, n_neighbors = 50, n_pcs = 20, use_rep="X_scVI_20_512", metric='correlation')
sc.tl.leiden(adata_all)
sc.tl.umap(adata_all, min_dist = 0.2)

In [None]:
sc.pl.umap(adata_all, color=['initial_cell_type', 'leiden', 'n_counts', 'n_genes', 'doublet_calls',
                             'Gcg', 'Neurog3', 'Gip'], color_map=mymap, size=20)

In [None]:
sc.pp.neighbors(adata_all, n_neighbors = 40, n_pcs = 20, use_rep="X_scVI_20_512", metric='correlation')
sc.tl.leiden(adata_all)
sc.tl.umap(adata_all, min_dist = 0.2)

In [None]:
sc.pl.umap(adata_all, color=['initial_cell_type', 'leiden', 'n_counts', 'n_genes', 'doublet_calls',
                             'Gcg', 'Neurog3', 'Gip', 'Tph1'], color_map=mymap, size=20)

In [None]:
# Last setting run: the graph and leiden kept in adata_all come from n_neighbors=25
sc.pp.neighbors(adata_all, n_neighbors = 25, n_pcs = 20, use_rep="X_scVI_20_512", metric='correlation')
sc.tl.leiden(adata_all)
sc.tl.umap(adata_all, min_dist = 0.2)

In [None]:
sc.pl.umap(adata_all, color=['initial_cell_type', 'leiden', 'n_counts', 'n_genes', 'doublet_calls',
                             'Gcg', 'Neurog3'], color_map=mymap, size=20)

In [None]:
# Replace the UMAP just computed with the HVG object's UMAP (built with n_neighbors=35),
# so adata_all shows the same layout as the HVG run. Positional copy: assumes identical cell order.
# NOTE(review): the stored neighbor graph still comes from the n_neighbors=25 run above.
adata_all.obsm['X_umap'] = adata.obsm['X_umap'] #why??

In [None]:
sc.pl.umap(adata_all, color=['initial_cell_type', 'leiden', 'n_counts', 'n_genes', 'doublet_calls',
                             'Gcg', 'Neurog3', 'Gip', 'Tph1'], color_map=mymap, size=20)

## Slim down adata (trim .obs/.var columns) and add ambient-gene annotation

In [None]:
# Genes flagged ambient in any 'is_ambient-<n>' column, taken from adata_all (all genes).
# `adata` at this point is the HVG subset from which ambient genes were already
# removed, so collecting from it would always give an empty list.
# The flags appear to be stored inconsistently (string 'True' vs boolean), so
# compare on the string form to catch both encodings.
ambient_cols = [col for col in adata_all.var.columns if col.startswith('is_ambient-')]
ambient_genes = [
    gene
    for col in ambient_cols
    for gene in adata_all.var.index[(adata_all.var[col].astype(str) == 'True').to_numpy()]
]

In [None]:
ambient_genes = np.unique(ambient_genes)
len(ambient_genes)

In [None]:
ambient_genes

In [None]:
ambient_category = np.isin(adata_all.var.index, ambient_genes)
ambient_category

In [None]:
len(ambient_category)

In [None]:
adata_all.var['ambient_genes'] = ambient_category

In [None]:
adata_all[:, adata_all.var['ambient_genes'] == True].var.index

In [None]:
adata_all.var['genome'] = adata_all.var['genome-0'].copy()

In [None]:
list(adata_all.obs)

In [None]:
# Clean up .obs
adata_all.obs = adata_all.obs.loc[:,['n_counts',
 'sample',
 'log_cell_probs',
 'log_counts',
 'n_genes',
 'log_genes',
 'total_counts_rank',
 'ambi_frac',
 'mt_frac',
 'rp_frac',
 'leiden',
 'doublet_calls',
 'doublets_shown',
 'batch',
'Project',
 'sequencing',
 'modality',
 'condition',
 'sample number Minas',
 'Internal ID',
 'SeqID',
 'kit',
 'line',
 'strain',
 'enriched',
 'enrichment proportion',
 'treatment',
 'diet',
 'tissue',
 'structure',
 'target cell number',
 'Read Length',
 'Index Type',
 'sequencing machine',
 'size_factors',
 'leiden_2',
 'leiden_3',
 'initial_cell_type']].copy()

In [None]:
list(adata_all.var)

In [None]:
# Clean up .var
adata_all.var = adata_all.var.loc[:,['feature_types',
 'genome',
 'ambient_genes']].copy()

## Save

In [None]:
# Final all-gene object with the scVI embedding, initial labels and ambient flag
adata_all.write('/mnt/hdd/data/GUT_concatenated_WT_no_doublets_corrected.h5ad')