<!--  -->
# Preprocessing - Joint Embedding & Doublet Removal with MultiVI
Adapted from Michael Sterr

2024-02-09 09:28:15 


# Setup

In [None]:
# General
import scipy as sci
import numpy as np
import pandas as pd
import logging
import time
import pickle
from itertools import chain
import session_info
import gc # Free memory #gc.collect()
import scipy.stats as stats

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams
from matplotlib.pyplot import rc_context
from matplotlib import cm
import seaborn as sb

# Analysis
import muon as mu
from muon import atac as ac # Import a module with ATAC-seq-related functions
import scanpy as sc
import anndata as ad

In [None]:
# Notebook-wide configuration

import warnings

# Silence every warning category for the rest of the session.
warnings.filterwarnings("ignore")

## Locations on disk
base_dir = '/mnt/hdd/'
data_dir = 'data/Healthy/'
nb_dir = 'Notebooks/Gut_project/'

# Figures live next to the notebooks; the cache sits directly under base_dir.
sc.settings.figdir = f'{base_dir}{nb_dir}Figures'
sc.settings.cachedir = f'{base_dir}Cache'

## Scanpy verbosity and environment report
sc.settings.verbosity = 3
sc.logging.print_versions()
session_info.show()

In [None]:
# Cubehelix colormap passed as `color_map` to the UMAP plots in this notebook
ch_YlRd = sb.cubehelix_palette(
    100,
    start=.7,
    rot=.25,
    gamma=0.6,
    hue=2,
    light=1,
    dark=0.05,
    as_cmap=True,
)

In [None]:
# Plot settings
%matplotlib inline

## Plotting parameters
# Default figure size; note that sc.set_figure_params below is called afterwards
# and may override rcParams set here — TODO confirm the effective figsize.
rcParams['figure.figsize']=(6,6) #rescale figures
#sc.set_figure_params(scanpy=True, frameon=False, vector_friendly=False, color_map='tab10' ,transparent=True, dpi=150, dpi_save=300)
sc.set_figure_params(scanpy=True, frameon=False, vector_friendly=False ,transparent=True, dpi=150, dpi_save=300)

## Grid & Ticks
# Fully transparent grid lines; ticks drawn on the bottom and left axes.
rcParams['grid.alpha'] = 0
rcParams['xtick.bottom'] = True
rcParams['ytick.left'] = True

## Embed font
# fonttype 42 = TrueType, so text in exported PDFs stays editable.
plt.rc('pdf', fonttype=42)

## Define new default settings
# This rebinds the name on the pyplot module to the live rcParams object (no copy).
# NOTE(review): it is unclear whether rcdefaults() consults this attribute — confirm
# that the intended "new defaults" actually take effect.
plt.rcParamsDefault = plt.rcParams

# Setup R

In [None]:
%run utils.ipynb

In [None]:
#R
# rpy2 / anndata2ri provide the Python <-> R bridge used by the %%R cells below.
import rpy2
import rpy2.robjects as ro
import rpy2.rinterface_lib.callbacks
from rpy2.robjects import pandas2ri
import anndata2ri
# setup_R is not defined in this notebook; presumably it comes from utils.ipynb
# (executed via %run above) and points rpy2 at the given R library directory — verify.
setup_R('/home/scanalysis/mnt/envs/scUV/lib/R')

In [None]:
%%R

.libPaths()

In [None]:
%%R
# Parallelization
# Three independent parallel back-ends are configured, each with 20 workers.

# BiocParallel: register a 20-worker multicore back-end with a progress bar.
library(BiocParallel)
register(MulticoreParam(20, progressbar = TRUE))

# future: multicore plan with 20 workers.
library(future)
plan("multicore", workers = 20)
# 64 * 1024^2 = 67,108,864; assuming the option is expressed in bytes this is 64 MiB.
# NOTE(review): confirm this limit is large enough for the objects exported to workers.
options(future.globals.maxSize = 64 * 1024^2)
# Calling plan() without arguments shows the currently active plan.
plan()

# doParallel: register 20 workers for foreach loops.
library(doParallel)
registerDoParallel(20)

sessionInfo()

# Load Data

In [None]:
file_base_name = 'scMultiome_Mouse_Crypts_FVF'
file_path = '/mnt/md0/Projects/scMultiome_Mouse_Crypts_FVF_P23033_Final_Notebooks/Files'

## MuData

In [None]:
multiome_samples = ['597_NVF_Crypts_Rep1', '598_FVF_Crypts_Rep1','599_FVF_Crypts_Rep2','604_NVF_Crypts_Rep2', 'FVF-high','FVF-low']

In [None]:
# Read the doublet-marked, filtered MuData object of every multiome sample
mdata_list = [
    mu.read(f'{base_dir}data/Multiome/{s}/outs/raw_feature_bc_matrix_filtered_markedDoublets.h5mu')
    for s in multiome_samples
]

In [None]:
mdata_list

### concat mdata

In [None]:
# Concatenate ATAC for meta data
# Combines the 'atac' modality of every loaded MuData object into one AnnData,
# in the order of mdata_list (i.e. the order of multiome_samples).
# join='outer' / merge='unique' are anndata.concat options; presumably the union of
# peaks is kept and only unambiguous annotations are merged — see the anndata docs.
atac = ad.concat([i.mod['atac'] for i in mdata_list], merge='unique', join = 'outer')

In [None]:
del mdata_list
atac

#### ATAC TF-IDF Normalization

In [None]:
def normalize_tfidf(atac, hvg=False, hvg_min_mean=0.05, hvg_max_mean=1.5, hvg_min_disp=0.5, remove_1st_lsi=True):
    """TF-IDF normalize an ATAC AnnData in place and compute its LSI embedding.

    Steps, in order:
      1. Store a copy of ``atac.X`` in ``atac.layers['atac_raw_counts']`` unless
         that layer already exists.
      2. Run ``muon.atac.pp.tfidf`` (overwrites ``atac.X``).
      3. Optionally flag highly variable features with scanpy and plot them.
      4. Snapshot the normalized object into ``atac.raw``.
      5. Run ``muon.atac.tl.lsi``.
      6. Optionally plot and then drop the first LSI component.

    Parameters
    ----------
    atac
        AnnData holding peak counts in ``.X``. Modified in place.
    hvg
        If True, run ``sc.pp.highly_variable_genes`` on the normalized data.
    hvg_min_mean, hvg_max_mean, hvg_min_disp
        Thresholds forwarded to ``sc.pp.highly_variable_genes``.
    remove_1st_lsi
        If True, plot the first LSI component against the depth columns in
        ``.obs`` and remove it from ``obsm['X_lsi']``, ``varm['LSI']`` and
        ``uns['lsi']['stdev']``. This branch requires the ``.obs`` columns
        ``log_counts_ATAC``, ``log_peaks_ATAC``, ``n_counts_ATAC`` and
        ``n_peaks_ATAC``.

    Returns
    -------
    None
    """
    from muon import atac as ac

    print('Normalization with TF-IDF:')

    # Preserve the original counts. An explicit copy guarantees the layer can never
    # share memory with a matrix that is later transformed.
    if 'atac_raw_counts' not in atac.layers:
        print('\tSave AnnData.X to AnnData.layers[\'atac_raw_counts\']...')
        atac.layers['atac_raw_counts'] = atac.X.copy()
    else:
        print('\t.layers[\'atac_raw_counts\'] already present, keeping it...')

    # TF-IDF normalization
    print('\tTF-IDF normalization...')
    ac.pp.tfidf(atac, scale_factor=1e4, log_tf=False, log_idf=False, log_tfidf=True)

    if hvg:
        # Feature selection on the normalized matrix
        sc.pp.highly_variable_genes(atac, min_mean=hvg_min_mean, max_mean=hvg_max_mean, min_disp=hvg_min_disp)
        sc.pl.highly_variable_genes(atac)
        print('\t\tNumber of variable features: ', np.sum(atac.var.highly_variable))

    # Save to .raw (this is the TF-IDF normalized state, not the raw counts)
    print('\tSave to .raw...')
    atac.raw = atac

    # LSI
    print('\tLSI...')
    ac.tl.lsi(atac)

    if remove_1st_lsi:
        # 1st dimension is often associated with number peaks/counts and should be removed.
        # Plot it first so the association can be inspected before it is dropped.
        _plot_first_lsi_component(atac)

        # remove 1st component from the embedding, the loadings and the stdevs
        atac.obsm['X_lsi'] = atac.obsm['X_lsi'][:,1:]
        atac.varm["LSI"] = atac.varm["LSI"][:,1:]
        atac.uns["lsi"]["stdev"] = atac.uns["lsi"]["stdev"][1:]


def _plot_first_lsi_component(atac):
    """Scatter the first LSI component against the ATAC depth columns of ``atac.obs``."""
    first_component = atac.obsm['X_lsi'][:,0]
    # (y column, colour column, y-axis label) for each of the two panels
    panels = (
        ('log_counts_ATAC', 'n_peaks_ATAC', 'Counts'),
        ('log_peaks_ATAC', 'n_counts_ATAC', 'Peaks'),
    )

    fig, axs = plt.subplots(1, 2, constrained_layout=True, figsize=(8, 4))
    for ax, (y_key, color_key, y_label) in zip(axs, panels):
        ax.scatter(first_component, y=atac.obs[y_key], s=2, alpha=0.2, c=atac.obs[color_key], cmap='rocket')
        ax.set_xlabel('LSI Dim. 1')
        ax.set_ylabel(y_label)

In [None]:
atac.layers['atac_raw_counts'] = atac.X.copy()

In [None]:
normalize_tfidf(atac, hvg=True,remove_1st_lsi=False)

In [None]:
#atac.X = sci.sparse.csr_matrix(atac.X)  
# Free cached GPU memory and run the garbage collector before duplicating the matrix.
import torch
torch.cuda.empty_cache()
gc.collect()
# Keep the TF-IDF normalized matrix as its own layer; .X is reset to raw counts
# in the next cell.
atac.layers['tf_idf'] = atac.X.copy()

In [None]:
atac.X = atac.layers['atac_raw_counts'].copy()

In [None]:
atac.write('/mnt/hdd/data/Healthy/atac_normalized_concatednated.h5ad')

In [None]:
atac = sc.read_h5ad('/mnt/hdd/data/Healthy/atac_normalized_concatednated.h5ad')

#### load adata

In [None]:
adata = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno.h5ad')

#### extract multiome normalized gex data

In [None]:
# Take an explicit copy: adata_gex is modified below (.X, .var, deleted slots), and
# modifying a view would otherwise trigger a silent implicit copy, since warnings
# are suppressed in this notebook.
adata_gex = adata[adata.obs['sample'].isin(multiome_samples)].copy()

In [None]:
# Get obs data from gex
# Replaces the ATAC cell annotations wholesale with those of the GEX subset.
# NOTE(review): this assignment is positional — it assumes atac and adata_gex hold
# the same cells in the same row order. atac was concatenated in multiome_samples
# order while adata_gex follows the row order of adata; confirm both orders agree.
atac.obs = adata_gex.obs.copy() # for concatenation

In [None]:
adata_gex.obs['sample'].value_counts()

In [None]:
adata_gex

#### clean up gex data from multiome samples

In [None]:
# Set raw count as X
# The joint object is built from the 'raw_counts' layer, not the normalized values.
adata_gex.X = adata_gex.layers['raw_counts'].copy()

# Set feature type
adata_gex.var['feature_types'] = 'Gene Expression'
# This column is not part of the selection at the end of the cell, so it is dropped again.
adata_gex.var['modality'] = 'Gene Expression'

# Take the genome annotation from the 'genome-0' column
# (presumably one 'genome-<i>' column per concatenated sample — verify).
adata_gex.var['genome'] = adata_gex.var['genome-0'].copy()

# Remove all unneccessary var data
# Only 'feature_types' and 'genome' are kept.
adata_gex.var = adata_gex.var.loc[:,['feature_types','genome']].copy()

In [None]:
# delete all uns/obsm/varm/layers/obsp/raw
# Leaves only .X, .obs and .var on the GEX object before it is joined with ATAC.
del adata_gex.uns
del adata_gex.obsm
del adata_gex.varm
del adata_gex.layers
del adata_gex.obsp
del adata_gex.raw
# Release the memory of the removed slots right away.
gc.collect()

In [None]:
adata_gex

In [None]:
gc.collect()

#### concat normalized gex and atac

In [None]:
atac

In [None]:
atac.var

In [None]:
atac.var['feature_types'] = 'Peaks'

In [None]:
# Restrict the ATAC object to the peaks flagged as highly variable
is_variable_peak = atac.var['highly_variable'].eq(True)
atac = atac[:, is_variable_peak]

In [None]:
# Joint data
# Genes and peaks are stacked along the feature axis: both objects are transposed so
# that features become observations, concatenated, and transposed back.
# NOTE(review): AnnData.concatenate is the legacy API (anndata.concat is its
# replacement); it presumably alters the concatenated names, which would explain the
# var_names reset below — confirm before migrating.
adata_multi = adata_gex.copy().transpose().concatenate(atac.copy().transpose()).transpose()

# Add modality to .obs
# This column is not part of the .obs selection at the end of the cell.
adata_multi.obs['modality'] = 'Multiome'

# Fix var_names
# Genes first, then peaks — the same order in which the two objects were concatenated.
adata_multi.var_names = list(adata_gex.var_names) + list(atac.var_names)

# Clean up .obs
# Keep only the annotation columns listed here.
adata_multi.obs = adata_multi.obs.loc[:,['sample', 'doublet_calls', 'final_doublets', 'final_doublets_cat', 'phase', 'proliferation', 'initial_cell_type']].copy()

In [None]:
# Keep the first two .var columns by position.
# NOTE(review): presumably these are 'feature_types' and 'genome'; the selection is
# positional, so verify the column order in the display below.
adata_multi.var = adata_multi.var.iloc[:,0:2].copy()

In [None]:
adata_multi.var.feature_types.value_counts()

In [None]:
adata_multi.var

# Prepare for MultiVI

In [None]:
# extract non-multiome gex data
# Explicit copy: the object is modified right away and in the following cells, and
# modifying a view would otherwise trigger a silent implicit copy, since warnings
# are suppressed in this notebook.
adata_rna = adata[~adata.obs['sample'].isin(multiome_samples)].copy()
adata_rna.obs['modality'] = 'expression'

In [None]:
adata_rna

In [None]:
adata_rna.obs['sample'].value_counts()

In [None]:
# Set raw count as X
# MultiVI input is built from the 'raw_counts' layer, not the normalized values.
adata_rna.X = adata_rna.layers['raw_counts'].copy()

# Set feature type
adata_rna.var['feature_types'] = 'Gene Expression'
# This column is not part of the selection at the end of the cell, so it is dropped again.
adata_rna.var['modality'] = 'Gene Expression'

# Take the genome annotation from the 'genome-6' column.
# NOTE(review): the multiome subset reads 'genome-0' instead; presumably each
# 'genome-<i>' column belongs to one concatenated sample — verify the index.
adata_rna.var['genome'] = adata_rna.var['genome-6'].copy()

# Remove all unneccessary var data
# Only 'feature_types' and 'genome' are kept.
adata_rna.var = adata_rna.var.loc[:,['feature_types','genome']].copy()

In [None]:
# delete all uns/obsm/varm/layers/obsp/raw
# Leaves only .X, .obs and .var on the RNA-only object.
del adata_rna.uns
del adata_rna.obsm
del adata_rna.varm
del adata_rna.layers
del adata_rna.obsp
del adata_rna.raw
# Release the memory of the removed slots right away.
gc.collect()

In [None]:
# Clean up .obs
adata_rna.obs = adata_rna.obs.loc[:,['sample', 'doublet_calls', 'final_doublets', 'final_doublets_cat', 'phase', 'proliferation', 'initial_cell_type']].copy()

In [None]:
del adata
gc.collect()

In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
import scvi
gc.collect()

In [None]:
del adata_gex
del atac
gc.collect()

In [None]:
gc.collect()

In [None]:
adata_multi

In [None]:
adata_multi.X.shape

In [None]:
adata_rna

In [None]:
adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna)

In [None]:
adata_mvi.write('/mnt/hdd/data/Healthy/multiVIorganized_object.h5ad')

In [None]:
# Order features, such that genes appear before genomic regions
# Columns are reordered by the sort order of their 'feature_types' value
# (assumes 'Gene Expression' sorts before 'Peaks' — TODO confirm for the column's dtype).
# The model below is given n_genes and n_regions as counts of these two values.
adata_mvi = adata_mvi[:, adata_mvi.var["feature_types"].argsort()].copy()
adata_mvi.var

In [None]:
# Drop features detected in fewer than 1% of all cells; report the shape before and after
print(adata_mvi.shape)
min_cells_required = int(adata_mvi.n_obs * 0.01)
sc.pp.filter_genes(adata_mvi, min_cells=min_cells_required)
print(adata_mvi.shape)

# Run MultiVI

n_hidden: 1024, layers: 4, inject: True

With these settings the model does a very good job of clustering doublets together.

## Setup Model

In [None]:
# Network size: hidden units per layer, latent dimensions, and number of layers
# (n_layers is used for both encoder and decoder when the model is built).
n_hidden=1024
n_latent=50
n_layers=4

# 'modality' is registered as the batch key in setup_anndata.
batch_key = 'modality'
# labels_key = None

# 'sample' is registered as an additional categorical covariate.
categorical_covariate_keys = ['sample']
continuous_covariate_keys = None

deeply_inject_covariates = True

# NOTE: the two settings below are currently unused — the corresponding arguments
# are commented out in the model construction cell.
modality_weights = 'cell'
model_depth = True

In [None]:
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key=batch_key, categorical_covariate_keys=categorical_covariate_keys, continuous_covariate_keys=continuous_covariate_keys)

In [None]:
# Build the MultiVI model. n_genes / n_regions are counted from the 'feature_types'
# column of .var.
model_mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(adata_mvi.var['feature_types']=='Gene Expression').sum(),
    n_regions=(adata_mvi.var['feature_types']=='Peaks').sum(),
    n_hidden=n_hidden,
    n_latent=n_latent, 
    n_layers_encoder=n_layers,
    n_layers_decoder=n_layers,
    deeply_inject_covariates=deeply_inject_covariates,
    # Disabled options, kept for reference; the library defaults apply instead.
    #modality_weights=modality_weights,
    #model_depth=model_depth,
    #gene_dispersion='gene-batch'
)
# Print the registered AnnData fields as a sanity check.
model_mvi.view_anndata_setup()

## Train

In [None]:
import torch

In [None]:
torch.cuda.empty_cache()

In [None]:
model_mvi.train()

In [None]:
# Training vs. validation reconstruction loss from the model history
loss_curves = (
    ('reconstruction_loss_train', 'train'),
    ('reconstruction_loss_validation', 'validation'),
)
for history_key, curve_label in loss_curves:
    plt.plot(model_mvi.history[history_key][history_key], label=curve_label)
plt.legend()

## Save

## Results

In [None]:
adata_mvi.obsm['X_MultiVI'] = model_mvi.get_latent_representation(adata_mvi)

In [None]:
adata_mvi.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_MutliVI.h5ad')

In [None]:
sc.pp.neighbors(adata_mvi, use_rep='X_MultiVI')

In [None]:
sc.tl.umap(adata_mvi, min_dist=0.3)

In [None]:
sc.tl.leiden(adata_mvi, resolution=1.5)

In [None]:
sc.pl.umap(adata_mvi, color=['modality','sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=2, wspace=0.65, color_map=ch_YlRd)

# Remove Doublets

In [None]:
adata_mvi_wd = adata_mvi.copy()

In [None]:
# String version of the doublet-call count, so it is treated as a discrete label
adata_mvi.obs['doublet_calls_cat'] = list(map(str, adata_mvi.obs['doublet_calls']))

In [None]:
adata_mvi.uns['doublet_calls_cat_colors'] = np.array([mpl.colors.to_hex(color, keep_alpha=True) for color in ch_YlRd(np.linspace(0,1,8))])

In [None]:
sc.pl.umap(adata_mvi, color=['doublet_calls_cat'], size=15, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, color_map=ch_YlRd)

In [None]:
plot_composition(adata_mvi, x_key='initial_cell_type', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))

In [None]:
sc.tl.leiden(adata_mvi, resolution=2.5)

In [None]:
sc.pl.umap(adata_mvi, color=['leiden','doublet_calls','final_doublets_cat'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.65,color_map=ch_YlRd)

In [None]:
pct_doublets = plot_composition(adata_mvi, x_key='leiden', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))
pct_doublets

In [None]:
# Sub-cluster only cluster '26' of the 'leiden' clustering; result goes to 'leiden_sub1'.
# NOTE(review): the cluster id is hard-coded and depends on the preceding Leiden run —
# re-check it whenever the embedding or the clustering changes.
sc.tl.leiden(adata_mvi, restrict_to=('leiden', ['26']), resolution=0.25, key_added='leiden_sub1')

In [None]:
sc.pl.umap(adata_mvi, color=['leiden_sub1','final_doublets_cat','doublet_calls'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.65,color_map=ch_YlRd)

In [None]:
sc.tl.leiden(adata_mvi, restrict_to=('leiden_sub1', ['20']), resolution=0.25, key_added='leiden_sub2')

In [None]:
sc.pl.umap(adata_mvi, color=['leiden_sub2','final_doublets_cat','doublet_calls'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.65,color_map=ch_YlRd)

In [None]:
# Keep the composition table of the sub-clustered labels: the doublet-cluster filter
# in the next cell matches `leiden_sub2` labels against pct_doublets['x_labels'], so
# the table has to be computed on `leiden_sub2`, not on the coarser `leiden` clustering.
pct_doublets = plot_composition(adata_mvi, x_key='leiden_sub2', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))
pct_doublets

In [None]:
# remove doublet clusters
# Drops every cell whose `leiden_sub2` label is listed in pct_doublets['x_labels']
# with a value below 4 in column '0' (presumably the percentage of cells with zero
# doublet calls — verify against plot_composition).
# NOTE(review): pct_doublets must have been computed with x_key='leiden_sub2' for the
# labels to match; a table built on 'leiden' cannot match sub-cluster labels.
adata_mvi = adata_mvi[~adata_mvi.obs.leiden_sub2.isin(list(pct_doublets['x_labels'][pct_doublets['0'] < 4]))].copy()

In [None]:
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI", metric='correlation', n_pcs=50, n_neighbors=20)
sc.tl.umap(adata_mvi, min_dist=0.3)

In [None]:
sc.tl.leiden(adata_mvi, resolution=3)

In [None]:
sc.pl.umap(adata_mvi, color=['leiden','doublet_calls','final_doublets_cat'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4,wspace=0.55, color_map=ch_YlRd)

In [None]:
pct_doublets = plot_composition(adata_mvi, x_key='leiden', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))
pct_doublets

In [None]:
# remove doublet clusters
adata_mvi = adata_mvi[~adata_mvi.obs.leiden.isin(list(pct_doublets['x_labels'][pct_doublets['0'] < 5]))].copy()

In [None]:
# remove cells with > 3 doublet calls
adata_mvi = adata_mvi[adata_mvi.obs.doublet_calls < 4 ].copy()

In [None]:
# Remove cells whose initial cell type is 'Non-Epithelial'
adata_mvi = adata_mvi[adata_mvi.obs.initial_cell_type != 'Non-Epithelial' ].copy()

In [None]:
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI", n_pcs=50, n_neighbors=20)
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5)

In [None]:
sc.tl.leiden(adata_mvi, resolution=1)

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.85, color_map=ch_YlRd)

In [None]:
sc.tl.paga(adata_mvi, groups='initial_cell_type')
sc.pl.paga(adata_mvi)

In [None]:
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5, init_pos='paga')

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.75, color_map=ch_YlRd)

# Update Original Files

In [None]:
adata_mvi.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi.h5ad')

In [None]:
adata_mvi = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi.h5ad')

In [None]:
# Recover the original cell names. Only the trailing '_<suffix>' is removed
# (presumably the modality tag appended by organize_multiome_anndatas — see the
# scvi-tools docs), so names that themselves contain underscores are no longer
# truncated.
is_paired = adata_mvi.obs['modality'].isin(['paired']).to_numpy()
barcodes_all = [name.rsplit('_', 1)[0] for name in adata_mvi.obs_names]
# Split into multiome ('paired') and RNA-only cells, preserving the row order.
barcodes_multiome = [barcode for barcode, paired in zip(barcodes_all, is_paired) if paired]
barcodes_gex = [barcode for barcode, paired in zip(barcodes_all, is_paired) if not paired]

In [None]:
# Restrict the MultiVI input objects to the cells that survived doublet removal.
# Both results are views, indexed by the recovered original cell names.
adata_multi = adata_multi[barcodes_multiome]
adata_rna = adata_rna[barcodes_gex]

In [None]:
# `adata` is deleted earlier in the notebook to free memory before training, so a
# top-to-bottom run would raise a NameError here. Reload the annotated object when it
# is no longer in memory.
if 'adata' not in globals():
    adata = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno.h5ad')
# Explicit copy, because the subset is written to (obsm) in the next cell.
adata = adata[barcodes_all].copy()

In [None]:
# Positional transfer of the latent space: adata was indexed with barcodes_all, which
# was derived from adata_mvi.obs_names in order, so the rows of both objects line up.
adata.obsm['X_MultiVI'] = adata_mvi.obsm['X_MultiVI']

In [None]:
adata.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi_orig_wodblts.h5ad')

In [None]:
del adata
del adata_mvi
gc.collect()

# Run MultiVI without Doublets

n_hidden: 1024, layers: 4, inject: True


In [None]:
del model_mvi

In [None]:
torch.cuda.empty_cache()

## Setup Data

In [None]:
import torch
import scvi

In [None]:
adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna)

In [None]:
# Order features, such that genes appear before genomic regions
adata_mvi = adata_mvi[:, adata_mvi.var["feature_types"].argsort()].copy()
adata_mvi.var

In [None]:
# Drop features detected in fewer than 1% of all cells; report the shape before and after
print(adata_mvi.shape)
min_cells_required = int(adata_mvi.n_obs * 0.01)
sc.pp.filter_genes(adata_mvi, min_cells=min_cells_required)
print(adata_mvi.shape)

## Setup Model

In [None]:
# Same hyperparameters as for the first MultiVI run (with doublets).
# Network size: hidden units per layer, latent dimensions, and number of layers
# (n_layers is used for both encoder and decoder when the model is built).
n_hidden=1024
n_latent=50
n_layers=4

# 'modality' is registered as the batch key in setup_anndata.
batch_key = 'modality'
# labels_key = None

# 'sample' is registered as an additional categorical covariate.
categorical_covariate_keys = ['sample']
continuous_covariate_keys = None

deeply_inject_covariates = True

# NOTE: the two settings below are currently unused — the corresponding arguments
# are commented out in the model construction cell.
modality_weights = 'cell'
model_depth = True

In [None]:
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key=batch_key, categorical_covariate_keys=categorical_covariate_keys, continuous_covariate_keys=continuous_covariate_keys)

In [None]:
# Build the MultiVI model on the doublet-free data. n_genes / n_regions are counted
# from the 'feature_types' column of .var.
model_mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(adata_mvi.var['feature_types']=='Gene Expression').sum(),
    n_regions=(adata_mvi.var['feature_types']=='Peaks').sum(),
    n_hidden=n_hidden,
    n_latent=n_latent, 
    n_layers_encoder=n_layers,
    n_layers_decoder=n_layers,
    deeply_inject_covariates=deeply_inject_covariates,
    # Disabled options, kept for reference; the library defaults apply instead.
    #modality_weights=modality_weights,
    #model_depth=model_depth,
    #gene_dispersion='gene-batch'
)
# Print the registered AnnData fields as a sanity check.
model_mvi.view_anndata_setup()

## Train

In [None]:
gc.collect()

In [None]:
torch.cuda.empty_cache()

In [None]:
torch.cuda.memory_allocated()

In [None]:
model_mvi.train()

In [None]:
# Training vs. validation reconstruction loss from the model history
loss_curves = (
    ('reconstruction_loss_train', 'train'),
    ('reconstruction_loss_validation', 'validation'),
)
for history_key, curve_label in loss_curves:
    plt.plot(model_mvi.history[history_key][history_key], label=curve_label)
plt.legend()

## Save

## Results

In [None]:
adata_mvi.obsm['X_MultiVI_rmDoublets'] = model_mvi.get_latent_representation(adata_mvi)

In [None]:
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI_rmDoublets", n_pcs=50, n_neighbors=20)
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5)

In [None]:
sc.tl.leiden(adata_mvi, resolution=1)

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=2, wspace =0.65, color_map=ch_YlRd)

In [None]:
sc.tl.paga(adata_mvi, groups='initial_cell_type')
sc.pl.paga(adata_mvi)

In [None]:
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5, init_pos='paga')

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=3,wspace=0.99, color_map=ch_YlRd)

# Update Original Files & Save

## Multiome

In [None]:
# Recover the original cell names. Only the trailing '_<suffix>' is removed
# (presumably the modality tag appended by organize_multiome_anndatas — see the
# scvi-tools docs), so names that themselves contain underscores are no longer
# truncated.
barcodes_all = [name.rsplit('_', 1)[0] for name in adata_mvi.obs_names]

In [None]:
# `adata` is deleted earlier in the notebook to free memory before training, so a
# top-to-bottom run would raise a NameError here. Reload the annotated object when it
# is no longer in memory.
if 'adata' not in globals():
    adata = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno.h5ad')
# Explicit copy, because the subset is written to (obsm) in the next cell.
adata = adata[barcodes_all].copy()

In [None]:
adata.obsm['X_MultiVI_rmDoublets'] = adata_mvi.obsm['X_MultiVI_rmDoublets']

In [None]:
adata.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi_orig_wodblts_2.h5ad')