## Notebook for Smillie data integration and batch correction with `scVI` and `scANVI`

+ Developed by: Anna Maguza
+ Institute of Computational Biology - Computational Health Centre - Helmholtz Munich
+ Date created: 16th July 2023
+ Last modified: 22nd May 2024

### Load required modules

In [None]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt

In [None]:
# Check that PyTorch can see a CUDA device (model training below uses accelerator="gpu")
torch.cuda.is_available()

In [None]:
# Allow faster, lower-precision float32 matrix multiplications on supported GPUs
torch.set_float32_matmul_precision(precision='medium')

In [None]:
# Scanpy logging level (3 = hints) and notebook-wide figure defaults
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    format = 'svg',
    vector_friendly = True,
)

In [None]:
# scArches-style architecture settings for scvi-tools models.
# NOTE(review): not passed to any model in the visible cells — confirm whether
# it was meant to be unpacked into scvi.model.SCVI(...).
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 2,
}

In [None]:
def X_is_raw(adata):
    """Return True if ``adata.X`` looks like raw counts.

    Every stored value must be a non-negative whole number. Dense arrays and
    scipy sparse matrices are supported. For sparse input, only the explicitly
    stored entries are inspected, because the implicit zeros are valid counts.

    Parameters
    ----------
    adata
        AnnData-like object exposing an ``X`` matrix.

    Returns
    -------
    bool
        True when all values in ``X`` are non-negative integers.
    """
    from scipy import sparse

    X = adata.X
    # Check individual entries, not column sums: fractional values such as
    # 0.5 + 0.5 add up to an integer and would fool a sum-based check.
    values = X.data if sparse.issparse(X) else np.asarray(X)
    # np.mod(NaN or inf, 1) is NaN, so non-finite values are also rejected.
    return bool(np.all(values >= 0) and np.all(np.mod(values, 1) == 0))

### Read in datasets

In [None]:
# Location of the QC-filtered raw-count Smillie (2019, SCP259) object.
# input_dir already ends with '/', so the joined path contains '//' (harmless on POSIX).
input_dir = '/mnt/LaCIE/annaM/gut_project/raw_data/Smillie_2019/SCP259/'
adata = sc.read(filename = f'{input_dir}/Smillie_with_QC_raw.h5ad')

In [None]:
# Sanity check that adata.X holds integer (raw) counts before it is stored as .raw / the counts layer
X_is_raw(adata)

In [None]:
# Snapshot the full gene set (raw counts) in .raw before HVG subsetting below
adata.raw = adata

In [None]:
# Keep a copy of the raw counts in a layer: seurat_v3 HVG selection and scVI both read it.
adata.layers['counts'] = adata.X.copy()

# Select the 5000 most variable genes across all cells (no batch-aware
# selection) and subset adata to them in place.
sc.pp.highly_variable_genes(
    adata,
    layer = "counts",
    flavor = "seurat_v3",
    n_top_genes = 5000,
    span = 1,
    batch_key = None,
    subset = True,
)

In [None]:
# Harmonise the cell-type column name to 'Cell_Type' (used as labels_key below)
adata.obs = adata.obs.rename(columns = {'CellType': 'Cell_Type'})

### Run Integration with scVI

In [None]:
# Work on an independent copy, then register it with scvi-tools: counts come
# from the 'counts' layer, labels from Cell_Type, and Sample_ID is a
# categorical covariate.
adata = adata.copy()
scvi.model.SCVI.setup_anndata(
    adata,
    layer = "counts",
    labels_key = "Cell_Type",
    categorical_covariate_keys = ["Sample_ID"],
)

In [None]:
# scVI model: 50-dimensional latent space, 3 hidden layers, negative-binomial likelihood.
# NOTE(review): setup_anndata was called without a batch_key, so all cells share
# one batch and 'gene-batch' dispersion has nothing to vary over — confirm
# whether Sample_ID was intended as the batch.
scvi_model = scvi.model.SCVI(
    adata,
    n_layers = 3,
    n_latent = 50,
    gene_likelihood = 'nb',
    dispersion = 'gene-batch',
)

In [None]:
# Train scVI for 50 epochs on GPU 0. Validation runs every epoch, so the
# history contains an 'elbo_validation' trace for the diagnostic plot.
scvi_model.train(
    max_epochs = 50,
    accelerator = "gpu",
    devices = [0],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

In [None]:
# Store the scVI latent representation of every cell in obsm
adata.obsm["X_scVI"] = scvi_model.get_latent_representation()

#### Evaluate model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Long-format table of the per-epoch training and validation ELBO of the scVI model
history_df = (
    scvi_model.history['elbo_train'].astype(float)
    .join(scvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = (12, 6)

# Train (black) vs validation (red) ELBO curves, skipping epoch 0
p_ = (
    p.ggplot(
        data = history_df.query('epoch > 0'),
        mapping = p.aes(x = 'epoch', y = 'value', color = 'variable'),
    )
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual(values = {'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

print(p_)

### Integration with scANVI

In [None]:
# Build scANVI on top of the trained scVI model; cells whose Cell_Type is
# "Unknown" are treated as unlabelled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    adata = adata,
    unlabeled_category = "Unknown",
    labels_key = "Cell_Type",
)

In [None]:
# Train scANVI for 50 epochs on GPU 0, validating every epoch
scanvi_model.train(
    max_epochs = 50,
    accelerator = "gpu",
    devices = [0],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

In [None]:
# Store the scANVI latent representation in obsm; it is used for the neighbour graph / UMAP
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

In [None]:
# Long-format table of the per-epoch training and validation ELBO of the scANVI model
history_df = (
    scanvi_model.history['elbo_train'].astype(float)
    .join(scanvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = (12, 6)

# Train (black) vs validation (red) ELBO curves, skipping epoch 0
p_ = (
    p.ggplot(
        data = history_df.query('epoch > 0'),
        mapping = p.aes(x = 'epoch', y = 'value', color = 'variable'),
    )
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual(values = {'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

# NOTE(review): 'fig1.png' is written to the current working directory under a generic name
p_.save('fig1.png', dpi = 300)

print(p_)

### UMAP calculation

In [None]:
# 50-nearest-neighbour graph computed on the scANVI latent space
sc.pp.neighbors(adata, n_neighbors = 50, use_rep = "X_scANVI", metric = 'minkowski')

In [None]:
# UMAP embedding with a fixed seed so the layout is reproducible
sc.tl.umap(adata, random_state = 1712, spread = 4, min_dist = 0.4)

### Write AnnData object

In [None]:
# Rebuild adata from the .raw snapshot, restoring the full (pre-HVG) gene set with raw counts in X
adata = adata.raw.to_adata()

In [None]:
# Save the integrated object next to the input file
adata.write_h5ad(f'{input_dir}/Smillie_scVI_scANVI.h5ad')

In [None]:
# List the cell-level metadata columns. obs_keys is a method, so it must be
# called; without parentheses the cell only displays the bound method.
adata.obs_keys()

In [None]:
# Create a 'Sex' column from Donor_ID: the donors listed here are 'Female';
# every other donor value is labelled 'Male'.
female_donors = [
    'N7', 'N8', 'N10', 'N13', 'N14', 'N18', 'N19', 'N20', 'N21',
    'N23', 'N24', 'N44', 'N50', 'N106', 'N110', 'N111', 'N539',
]

adata.obs['Sex'] = [
    'Female' if donor in female_donors else 'Male'
    for donor in adata.obs['Donor_ID']
]

In [None]:
# Inspect the cell metadata table
adata.obs

In [None]:
sc.set_figure_params(dpi=300)
# UMAP panels coloured by annotation and donor metadata, three per row
sc.pl.umap(
    adata,
    color = ['Cell_Type', 'Diagnosis', 'Donor_ID', 'Location', 'Sex', 'Cell_States'],
    frameon = False,
    size = 1,
    legend_fontsize = 5,
    ncols = 3,
)

In [None]:
# Cast predicted_doublets to strings, presumably so scanpy colours it as a categorical in the QC UMAP
adata.obs['predicted_doublets'] = adata.obs['predicted_doublets'].astype(str)

In [None]:
sc.set_figure_params(dpi=300)
# UMAP panels coloured by per-cell QC metrics and doublet calls
sc.pl.umap(
    adata,
    color = ['n_genes_by_counts', 'total_counts', 'pct_counts_mito', 'pct_counts_ribo', 'predicted_doublets'],
    frameon = False,
    size = 1,
    legend_fontsize = 5,
    ncols = 3,
)

In [None]:
# Add a boolean 'Stem_cell' column: True where adata.obs['Cell_States'] == 'Stem', False otherwise
adata.obs['Stem_cell'] = adata.obs['Cell_States'] == 'Stem'

In [None]:
# Convert the boolean flag to 'True'/'False' strings, presumably so it is plotted as a categorical
adata.obs['Stem_cell'] = adata.obs['Stem_cell'].astype(str)

In [None]:
# Show the AnnData summary (shape, obs/var columns, obsm/uns keys)
adata

In [None]:
import os

# Two-colour palette for the 'Stem_cell' categories, applied in category order.
# NOTE(review): the 'True'/'False' strings presumably become categories sorted as
# ['False', 'True'], which gives light blue (#759EB8) to non-stem cells and
# plum/pink (#824670) to stem cells. Confirm against the rendered legend.
new_palette = ['#759EB8', '#824670']

# Assign the palette to the categories of the 'Stem_cell' column
adata.uns['Stem_cell_colors'] = new_palette

fig_dir = '/mnt/LaCIE/annaM/gut_project/Processed_data/Gut_data/Plots/Finding_stem_cells'
# plt.savefig does not create missing directories, so make sure the target exists
os.makedirs(fig_dir, exist_ok=True)

with plt.rc_context():
    sc.set_figure_params(dpi=300, figsize=(15, 15))
    sc.pl.umap(adata, frameon=False, color='Stem_cell', size=10, legend_fontsize=5, ncols=3, show=False)
    plt.savefig(f"{fig_dir}/Smillie_stem_umap.png", bbox_inches="tight")

In [None]:
# Normalise a copy so adata keeps raw counts. Each cell is scaled to 1e6
# total counts, with very highly expressed genes excluded from the size
# factor, and then log1p-transformed.
adata_log = adata.copy()
sc.pp.normalize_total(adata_log, exclude_highly_expressed = True, target_sum = 1e6)
sc.pp.log1p(adata_log)

In [None]:
# Core stem-cell marker genes used for the first marker score
stem_cells_markers = [
    'AXIN2', 'ASCL2', 'ATOH1', 'BMI1', 'CA12',
    'CLU', 'GPX2', 'HMGCS2', 'LEFTY1', 'LGR5',
    'LRIG1', 'MYC', 'OLFM4', 'SMOC2', 'TERT',
]

In [None]:
# Per-cell stem-cell marker score on the log-normalised copy, stored in adata_log.obs
sc.tl.score_genes(adata_log, gene_list = stem_cells_markers, score_name = 'Stem_cells_markers_score')

In [None]:
# Plot the stem-cell marker score on the UMAP and save it as a PNG in fig_dir
with plt.rc_context():
    sc.set_figure_params(figsize=(15, 15), dpi=300)
    sc.pl.umap(
        adata_log,
        color = ['Stem_cells_markers_score'],
        color_map = "magma_r",
        size = 8,
        frameon = False,
        show = False,
    )
    plt.savefig(f"{fig_dir}/Smillie_stem_markers.png", bbox_inches="tight")

In [None]:
# Return to raw counts, but only if a .raw snapshot actually exists.
# The object produced by `adata.raw.to_adata()` further up has no .raw of its
# own, so calling `adata.raw.to_adata()` again would fail with AttributeError
# (adata.raw is None). Normalisation was done on the separate `adata_log`
# copy, so in that case adata.X already holds the raw counts.
if adata.raw is not None:
    adata = adata.raw.to_adata()

In [None]:
# Extended stem-cell marker panel, scored on the log-normalised copy.
# Duplicate entries ('GJB1', 'ASCL2') have been removed.
# NOTE(review): some entries look like protein names/aliases or free text rather
# than gene symbols (e.g. 'CD166', 'CD29', 'ALK3', 'ALDH1', 'DCAMKL-1',
# 'Musashi-1', 'Telomerase Inhibitors'). score_genes ignores names that are
# missing from var_names, so these contribute nothing unless they are mapped to
# their gene symbols.
# NOTE: score_name matches the earlier score, so this overwrites the
# 'Stem_cells_markers_score' computed from the shorter `stem_cells_markers` list.
Stem_cells_markers = [
    'CD24', 'DCLK1', 'LGR5', 'CD166', 'CD44', 'DCAMKL-1', 'SOX9', 'ACAD10', 'ACVR1C', 'ADH1C', 'ALDH1', 'ALK3', 'ARSE',
    'ASCL2', 'ATP10B', 'BMI1', 'C16orf89', 'C6orf136', 'CD29', 'CDCA7', 'CFTR', 'CHMP4C', 'CHP2', 'CLDN15', 'CLDN18', 'CLDN2', 'CPA6', 'DAPK2',
    'DDC', 'EFNA3', 'EPHB2', 'EPYC', 'EVPL', 'F2RL1', 'FBLN2', 'FOXD2-AS1', 'GATA6-AS1', 'GDF15', 'GJB1', 'GOLT1A', 'GPX2', 'HNF1A',
    'HSD17B2', 'ITPKC', 'LEFTY1', 'LHFPL3-AS2', 'LIPG', 'LY6G6D', 'MGST1', 'MSI1', 'MYOM3', 'Musashi-1', 'NOX1', 'OLFM4', 'PCSK9', 'PDZD3',
    'PHLDA1', 'PKP2', 'PLAGL2', 'PLEKHH1', 'PPP1R1B', 'PTGDR', 'PTK7', 'RGMB', 'RNF157', 'RNF186', 'SFN', 'SLC27A2', 'SLC38A4', 'SLPI',
    'SULT1B1', 'TAF4B', 'TANC1', 'TMEM171', 'TSPAN8', 'Telomerase Inhibitors', 'URB1-AS1', 'ZBED9', 'ZNF296', 'SMOC2',
]
sc.tl.score_genes(adata_log, Stem_cells_markers, score_name = 'Stem_cells_markers_score')

sc.set_figure_params(dpi=300)
sc.pl.umap(adata_log, color= ['Stem_cells_markers_score'], color_map = "RdPu", size = 0.3, frameon = False)

In [None]:
# Reload the integrated Smillie object from a local copy and inspect QC covariates on the UMAP.
# The path variable is named `input_file` rather than `input` so the builtin is not shadowed.
input_file = '/Users/anna.maguza/Desktop/Data/Processed_datasets/1_QC/Smillie_scVI_scANVI.h5ad'
adata = sc.read(input_file)

# Cast doublet calls to strings so they are coloured as a categorical
adata.obs['predicted_doublets'] = adata.obs['predicted_doublets'].astype(str)

sc.set_figure_params(dpi=300)
# NOTE(review): uses 'total_counts', the per-cell count column plotted earlier
# for this dataset; confirm that the local copy contains it.
sc.pl.umap(adata, color=['n_genes_by_counts', 'total_counts', 'pct_counts_mito', 'pct_counts_ribo', 'predicted_doublets'],
           color_map = "RdPu", size = 1, frameon = False, ncols=6)

In [None]:
# Merge the 'Stem cells' annotation into 'Epithelial'.
# Series.replace on a categorical column that changes the categories is
# deprecated in pandas >= 2.2, so replace on plain values first and
# re-categorise afterwards. NaNs are preserved through the object cast.
# NOTE(review): if adata.uns['Cell_Type_colors'] exists, it was built for the old
# category list (one category longer), so plotted colours may shift.
adata.obs['Cell_Type'] = (
    adata.obs['Cell_Type']
    .astype(object)
    .replace({'Stem cells': 'Epithelial'})
    .astype('category')
)

In [None]:
sc.set_figure_params(dpi=300)
# UMAP coloured by the merged Cell_Type annotation
sc.pl.umap(
    adata,
    color = ['Cell_Type'],
    color_map = "RdPu",
    size = 1,
    frameon = False,
    ncols = 6,
)