## Notebook for Gut Cell Atlas data integration and batch correction with `scVI` and `scANVI`

+ Developed by: Anna Maguza
+ Institute of Computational Biology - Computational Health Centre - Helmholtz Munich
+ Date created: 4th July 2023
+ Last modified: 22nd May 2024

### Load required modules

In [None]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt

In [None]:
# Check that PyTorch can see a CUDA device; the training cells below request accelerator = "gpu".
torch.cuda.is_available()

In [None]:
# Select PyTorch's 'medium' float32 matmul precision setting
# (see the torch documentation for the speed/accuracy trade-off).
torch.set_float32_matmul_precision('medium')

In [None]:
# Set scanpy's verbosity level to 3 and print the package versions of this session.
sc.settings.verbosity = 3
sc.logging.print_versions()
# Global figure defaults: dpi 180 for display, dpi 300 and SVG format for saved figures.
sc.settings.set_figure_params(dpi = 180, color_map = 'magma_r', dpi_save = 300, vector_friendly = True, format = 'svg')

In [None]:
# Architecture settings collected in one dict.
# NOTE(review): this dict is not referenced by any cell visible in this notebook;
# the SCVI model below is built with its own arguments (n_layers = 3, versus
# n_layers = 2 here). Either pass it to the model or drop it — TODO confirm intent.
arches_params = dict(
    use_layer_norm = "both",
    use_batch_norm = "none",
    encode_covariates = True,
    dropout_rate = 0.2,
    n_layers = 2,
)

In [None]:
def X_is_raw(adata):
    """Return True when ``adata.X`` holds only whole-number values (raw counts).

    Every stored value is inspected, so a matrix with fractional entries is
    rejected even when its per-gene column sums happen to be whole numbers
    (e.g. two cells with 0.5 each), which a check on column sums would accept.

    Parameters
    ----------
    adata
        Object exposing the expression matrix as ``adata.X``; dense
        array-likes and SciPy sparse matrices are handled.

    Returns
    -------
    bool
        True if no entry of ``adata.X`` has a fractional part.
    """
    # Local import: SciPy is not among this notebook's top-level imports.
    from scipy import sparse

    matrix = adata.X
    if sparse.issparse(matrix):
        # Implicit zeros are whole numbers, so only explicitly stored entries matter.
        values = matrix.tocsr().data
    else:
        values = np.asarray(matrix)

    # Boolean and integer dtypes cannot carry a fractional part.
    if values.dtype.kind in "biu":
        return True
    # NaN and inf give a NaN remainder, which compares unequal to 0 -> not raw.
    return bool(np.all(np.mod(values, 1) == 0))

### Read in datasets

In [17]:
# input_dir: location of the input .h5ad and of the integrated object written later.
# fig_dir: destination of the figures saved by the plotting cells.
input_dir = '/mnt/LaCIE/annaM/gut_project/raw_data/Elmentaite_2021/'
fig_dir = '/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Plots/Finding_stem_cells'

In [None]:
# Load the filtered object (file name suggests raw counts; checked in the next cell).
adata = sc.read(f'{input_dir}/GCA_filtered_raw.h5ad')

In [None]:
# Sanity check that adata.X looks like raw counts before it is stored as .raw and as the 'counts' layer.
X_is_raw(adata)

In [None]:
# Keep the current (pre-HVG-subsetting) object in .raw; it is restored with
# adata.raw.to_adata() once the embedding has been computed.
adata.raw = adata

In [None]:
# Keep a copy of the current matrix as the 'counts' layer; this is the layer
# registered with scVI further down.
adata.layers['counts'] = adata.X.copy()

# Select the 5000 most variable genes with the seurat_v3 flavour, computed on the
# 'counts' layer with Library_Preparation_Protocol as the batch key.
# subset = True restricts adata to the selected genes.
sc.pp.highly_variable_genes(
    adata,
    flavor = "seurat_v3",
    n_top_genes = 5000,
    layer = "counts",
    batch_key = "Library_Preparation_Protocol",
    subset = True,
    span = 1
)

### Run Integration with scVI

In [None]:
# Rebind adata to a fresh copy before registering it with scVI.
adata = adata.copy()
# Register the 'counts' layer, Cell_Type as labels and Sample_ID as a categorical covariate.
# NOTE(review): no batch_key is given here, while the model below requests
# dispersion = 'gene-batch' — confirm that handling Sample_ID as a categorical
# covariate is the intended form of batch correction.
scvi.model.SCVI.setup_anndata(adata, 
                              layer = "counts", 
                              labels_key = "Cell_Type", 
                              categorical_covariate_keys = ["Sample_ID"])

In [None]:
# Build the scVI model: 50 latent dimensions, 3 layers, negative-binomial gene likelihood.
# NOTE(review): arches_params defined earlier (n_layers = 2, layer norm, ...) is not passed here — confirm.
scvi_model = scvi.model.SCVI(adata, n_latent = 50, n_layers = 3, dispersion = 'gene-batch', gene_likelihood = 'nb')

In [None]:
# Train scVI for 50 epochs on GPU device 0. Validation runs after every epoch;
# the history plot below reads 'elbo_validation' per epoch.
scvi_model.train(
    max_epochs = 50,
    accelerator = "gpu",
    devices = [0],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

In [None]:
# Long-format table of the per-epoch training and validation ELBO of the scVI model.
history_df = scvi_model.history['elbo_train'].astype(float)
history_df = history_df.join(scvi_model.history['elbo_validation'].astype(float))
history_df = history_df.reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = (12, 6)

# One line per ELBO series (train in black, validation in red); epoch 0 is left out.
p_ = p.ggplot(history_df[history_df['epoch'] > 0], p.aes(x = 'epoch', y = 'value', color = 'variable'))
p_ = p_ + p.geom_line() + p.geom_point()
p_ = p_ + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
p_ = p_ + p.theme_minimal()

print(p_)

In [None]:
# Store the scVI latent representation (n_latent = 50 above) alongside the cells.
adata.obsm["X_scVI"] = scvi_model.get_latent_representation()

### Integration with scANVI

In [None]:
# Initialise scANVI from the trained scVI model, using Cell_Type as the label column.
# Cells whose Cell_Type equals "Unknown" are the ones declared as unlabeled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    adata=adata,
    labels_key="Cell_Type",
    unlabeled_category="Unknown",
)

In [None]:
# Train scANVI for 100 epochs on GPU device 0. Validation runs after every epoch;
# the history plot below reads 'elbo_validation' per epoch.
scanvi_model.train(
    max_epochs = 100,
    accelerator = "gpu",
    devices = [0],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

In [None]:
# Long-format table of the per-epoch training and validation ELBO of the scANVI model.
history_df = scanvi_model.history['elbo_train'].astype(float)
history_df = history_df.join(scanvi_model.history['elbo_validation'].astype(float))
history_df = history_df.reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = (12, 6)

# One line per ELBO series (train in black, validation in red); epoch 0 is left out.
p_ = p.ggplot(history_df[history_df['epoch'] > 0], p.aes(x = 'epoch', y = 'value', color = 'variable'))
p_ = p_ + p.geom_line() + p.geom_point()
p_ = p_ + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
p_ = p_ + p.theme_minimal()

# Relative file name: the image is written to the current working directory, not to fig_dir.
p_.save('fig1.png', dpi = 300)

print(p_)

In [None]:
# Store the scANVI latent representation; it is the representation used for the neighbour graph below.
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

### UMAP calculation

In [None]:
# Neighbour graph on the scANVI latent space (50 neighbours, Minkowski metric).
sc.pp.neighbors(adata, use_rep = "X_scANVI", n_neighbors = 50, metric = 'minkowski')

In [None]:
# UMAP embedding with a fixed random_state so the layout can be reproduced.
sc.tl.umap(adata, min_dist = 0.4, spread = 4, random_state = 1712)

In [None]:
# Swap the HVG-restricted object for the full matrix stored in .raw earlier.
# NOTE(review): presumably obs/obsm (X_scANVI, X_umap) carry over to the new object —
# the plotting cells below rely on X_umap; confirm.
adata = adata.raw.to_adata()

In [None]:
# Persist the integrated object next to the input data.
# NOTE(review): the file name spells 'Elementaine'; left unchanged because the
# read cell below uses the same name.
adata.write(f'{input_dir}/Elementaine_2021_scVI_scANVI_corrected.h5ad')

### Create plots

In [3]:
# Reload the integrated object written above, so plotting can start from here without re-training.
adata = sc.read_h5ad(f'{input_dir}/Elementaine_2021_scVI_scANVI_corrected.h5ad')

In [None]:
# Overview UMAPs coloured by annotation and metadata columns, three panels per row.
sc.set_figure_params(dpi=300)
sc.pl.umap(adata, frameon = False, color = ['Cell_Type', 'Diagnosis', 'Donor_ID', 'Library_Preparation_Protocol', 'Location', 'Sex', ], size = 1, legend_fontsize = 5, ncols = 3)

In [None]:
# Flag cells annotated as 'Stem cells'; the flag is stored as the strings 'True' / 'False'.
adata.obs['Stem_cell'] = (adata.obs['Cell_State'] == 'Stem cells').astype(str)

# Two-colour palette for the flag, in list order: '#759EB8' (blue-grey), '#824670' (purple).
new_palette = ['#759EB8', '#824670']
adata.uns['Stem_cell_colors'] = new_palette

fig_dir = '/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Plots/Finding_stem_cells'

# The figure settings are applied inside rc_context so they stay local to this plot.
with plt.rc_context():
    sc.set_figure_params(figsize=(15, 15), dpi=300)
    sc.pl.umap(adata, color='Stem_cell', frameon=False, size=10, legend_fontsize=5, ncols=3, show=False)
    plt.savefig(f"{fig_dir}/Elementaite_stem_umap.png", bbox_inches="tight")

In [4]:
# Work on a copy so adata itself is left untouched.
adata_log = adata.copy()
# Scale every cell to a total of 1e6 (with exclude_highly_expressed = True), then log1p-transform.
sc.pp.normalize_total(adata_log, target_sum = 1e6, exclude_highly_expressed = True)
sc.pp.log1p(adata_log)

In [5]:
# Short stem-cell marker panel used for the score and the dot plots below.
stem_cells_markers = ['AXIN2', 'ASCL2', 'ATOH1', 'BMI1', 'CA12', 'CLU', 'GPX2', 'HMGCS2', 'LEFTY1', 'LGR5', 'LRIG1', 'MYC', 'OLFM4', 'SMOC2', 'TERT']

In [None]:
# Per-cell score for the short marker panel, stored in obs['Stem_cells_markers_score'] of adata_log.
sc.tl.score_genes(adata_log, stem_cells_markers, score_name = 'Stem_cells_markers_score')

In [None]:
# UMAP coloured by the marker score, saved to fig_dir; show=False keeps the
# figure open so plt.savefig can write it.
with plt.rc_context():
    sc.set_figure_params(dpi=300, figsize=(15, 15))
    sc.pl.umap(adata_log, color= ['Stem_cells_markers_score'], color_map = "magma_r", frameon=False, size = 10, show=False)
    plt.savefig(f"{fig_dir}/Elementaite_stem_markers.png", bbox_inches="tight")

In [None]:
# Dot plot of the marker panel restricted to cells annotated as 'Stem cells';
# with that filter, groupby='Cell_State' yields a single group.
stem_cells = adata_log[adata_log.obs['Cell_State'] == 'Stem cells']
with plt.rc_context():
    sc.set_figure_params(dpi=300, figsize=(15, 15))
    sc.pl.dotplot(stem_cells, stem_cells_markers, groupby='Cell_State', cmap = 'magma_r', show=False) 
    plt.savefig(f"{fig_dir}/Elementaite_stem_markers_dotplot.png", bbox_inches="tight")

In [None]:
# Extended stem-cell marker list (duplicate 'GJB1' and 'ASCL2' entries removed).
# NOTE(review): several entries do not look like gene symbols (e.g. 'Telomerase Inhibitors',
# 'Musashi-1', 'DCAMKL-1') and will only be scored if they exist in var_names — confirm
# against adata.var_names.
# NOTE(review): this name differs from the short `stem_cells_markers` panel only by
# capitalisation; take care not to mix them up.
Stem_cells_markers = [
    'CD24', 'DCLK1', 'LGR5', 'CD166', 'CD44', 'DCAMKL-1', 'SOX9', 'ACAD10', 'ACVR1C', 'ADH1C', 'ALDH1', 'ALK3', 'ARSE',
    'ASCL2', 'ATP10B', 'BMI1', 'C16orf89', 'C6orf136', 'CD29', 'CDCA7', 'CFTR', 'CHMP4C', 'CHP2', 'CLDN15', 'CLDN18', 'CLDN2', 'CPA6', 'DAPK2',
    'DDC', 'EFNA3', 'EPHB2', 'EPYC', 'EVPL', 'F2RL1', 'FBLN2', 'FOXD2-AS1', 'GATA6-AS1', 'GDF15', 'GJB1', 'GOLT1A', 'GPX2', 'HNF1A',
    'HSD17B2', 'ITPKC', 'LEFTY1', 'LHFPL3-AS2', 'LIPG', 'LY6G6D', 'MGST1', 'MSI1', 'MYOM3', 'Musashi-1', 'NOX1', 'OLFM4', 'PCSK9', 'PDZD3',
    'PHLDA1', 'PKP2', 'PLAGL2', 'PLEKHH1', 'PPP1R1B', 'PTGDR', 'PTK7', 'RGMB', 'RNF157', 'RNF186', 'SFN', 'SLC27A2', 'SLC38A4', 'SLPI',
    'SULT1B1', 'TAF4B', 'TANC1', 'TMEM171', 'TSPAN8', 'Telomerase Inhibitors', 'URB1-AS1', 'ZBED9', 'ZNF296', 'SMOC2',
]
# NOTE(review): the score is computed on `adata`, not on the log-normalised `adata_log`
# that the short marker panel was scored on — confirm this is intended.
sc.tl.score_genes(adata, Stem_cells_markers, score_name = 'Stem_cells_markers_score')

sc.set_figure_params(dpi=300)
sc.pl.umap(adata, color= ['Stem_cells_markers_score'], color_map = "RdPu", size = 0.3, frameon = False)

In [None]:
# Path of the QC-stage object on a local machine. Named `gca_qc_path` rather than
# `input` so the Python builtin input() is not shadowed for the rest of the session.
gca_qc_path = '/Users/anna.maguza/Desktop/Data/Processed_datasets/1_QC/GCA_scVI_scANVI.h5ad'
# NOTE: this rebinds `adata` to a different object than the one plotted above.
adata = sc.read(gca_qc_path)

# Cast the doublet prediction to str — presumably so it is plotted as discrete labels.
adata.obs['predicted_doublets'] = adata.obs['predicted_doublets'].astype(str)

# QC metrics on the UMAP, five panels in one row.
sc.set_figure_params(dpi=300)
sc.pl.umap(adata, color=['n_genes_by_counts', 'total_counts', 'pct_counts_mt', 'pct_counts_ribo', 'predicted_doublets'],
             color_map = "RdPu", size = 0.3, frameon = False, ncols=5)

In [None]:
# UMAP of the currently loaded adata, coloured by the Cell_Type annotation.
sc.set_figure_params(dpi=300)
sc.pl.umap(adata, color=['Cell_Type'],
             color_map = "RdPu", size = 0.3, frameon = False, ncols=5)

+ Dot plot with selected Cell_State

In [10]:
# Number of cells per Cell_State annotation.
# NOTE(review): `df` is not used by any later cell visible in this notebook.
df = adata_log.obs['Cell_State'].value_counts()

In [12]:
# Labels shown in the figure: the coarse Cell_Type names plus the finer
# epithelial labels assigned below.
# NOTE: set_categories turns any Cell_Type value missing from this list into NaN.
figure_categories = [
    'Epithelial', 'Mesenchymal', 'T cells', 'Plasma cells', 'Myeloid',
    'Neuronal', 'B cells', 'Endothelial', 'Red blood cells',
    'Colonocyte', 'Goblet cells', 'Enterocyte', 'Paneth cells',
    'Stem cells', 'TA', 'Tuft cells',
]
adata_log.obs['states_for_figure'] = (
    adata_log.obs['Cell_Type'].copy().cat.set_categories(figure_categories)
)

# Figure label -> Cell_State annotations that are collapsed into it.
cell_states_by_label = {
    'Goblet cells': [
        'Goblet cell',
        'BEST2+ Goblet cell',
        'Goblet cells MUC2 TFF1',
        'Goblet cells SPINK4',
        'Goblet cells MUC2 TFF1-',
    ],
    'Paneth cells': ['Paneth', 'Paneth cells'],
    'TA': ['TA'],
    'Tuft cells': ['Tuft', 'Tuft cells'],
    'Enterocyte': [
        'Enterocyte',
        'Enterocytes BEST4',
        'Enterocytes TMIGD1 MEP1A GSTA1',
        'Enterocytes TMIGD1 MEP1A',
        'Enterocytes CA1 CA2 CA4-',
    ],
    'Stem cells': [
        'Stem cells OLFM4',
        'Stem cells OLFM4 GSTA1',
        'Stem cells OLFM4 LGR5',
        'Stem cells OLFM4 PCNA',
        'Stem_Cells_GCA',
        'Stem_Cells_ext',
        'Stem cells',
    ],
    'Colonocyte': ['Colonocyte'],
}

# Overwrite the coarse label wherever a cell carries one of the listed Cell_State
# values; all other cells keep the label derived from Cell_Type.
for figure_label, cell_states in cell_states_by_label.items():
    in_group = adata_log.obs['Cell_State'].isin(cell_states)
    adata_log.obs.loc[in_group, 'states_for_figure'] = figure_label

In [15]:
import matplotlib.pyplot as plt

In [18]:
# Dot plot of the short marker panel across all cells, grouped by the merged
# 'states_for_figure' labels, saved to fig_dir.
with plt.rc_context():
    sc.set_figure_params(dpi=300, figsize=(15, 15))
    sc.pl.dotplot(adata_log, stem_cells_markers, groupby='states_for_figure', cmap = 'magma_r', show=False) 
    plt.savefig(f"{fig_dir}/Elementaite_stem_markers_dotplot_all_cells.png", bbox_inches="tight")

  obs_bool.groupby(level=0).sum() / obs_bool.groupby(level=0).count()
  dot_color_df = self.obs_tidy.groupby(level=0).mean()
  dot_ax.scatter(x, y, **kwds)
