# Deconvolution with Cell2location

In this notebook, I follow the Cell2location tutorial using the human lymph node dataset.

You can find this tutorial here: https://cell2location.readthedocs.io/en/latest/notebooks/cell2location_tutorial.html

For this analysis, choose a configuration with **more than one GPU**.

# 0. Installation

conda create -y -n cell2loc_env python=3.9

conda activate cell2loc_env

pip install cell2location[tutorials]

pip install ipykernel

python -m ipykernel install --user --name=cell2loc_env --display-name=Cell2location

pip install torch # for using gpu

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cell2location
import scanpy as sc
import torch

In [None]:
## Result folder
# Defaults to the platform output directory used for this analysis; set the
# CELL2LOC_OUTPUT environment variable to write results somewhere else.
output_path = os.environ.get("CELL2LOC_OUTPUT", "/sbgenomics/output-files/data/Cell2Location")
os.makedirs(output_path, exist_ok=True)

# 1. Load and preprocess ST data

In [None]:
# load the spatial transcriptomics AnnData from disk
adata = sc.read_h5ad(filename='./data/ST.h5ad')

In [None]:
# unstructured metadata of the spatial AnnData (includes the 'spatial' entry used below)
adata.uns

In [None]:
# first key stored under adata.uns['spatial'] (used as the sample name in the next cell)
list(adata.uns['spatial'].keys())[0]

In [None]:
# label every spot with the first key of adata.uns['spatial'];
# this column is later used as the batch / sample key
library_id = list(adata.uns['spatial'])[0]
adata.obs['sample'] = library_id

In [None]:
# gene identifiers of the spatial data
adata.var_names

In [None]:
# per-gene annotation table of the spatial data
adata.var

In [None]:
# gene symbols, used for plotting via gene_symbols='SYMBOL'
adata.var['SYMBOL']

In [None]:
# index of adata.var (the same values as adata.var_names)
adata.var.index

In [None]:
# plot PTPRC expression on the tissue, looking the gene up in var['SYMBOL']
sc.pl.spatial(
    adata,
    color='PTPRC',
    gene_symbols='SYMBOL',
)

# 2. Load and preprocess scRNA-seq data

In [None]:
# load the scRNA-seq reference with the explicit h5ad reader, as for the spatial data
adata_ref = sc.read_h5ad('./data/scRNA.h5ad')

In [None]:
# inspect the gene annotation of the scRNA-seq reference that was just loaded
adata_ref.var

In [None]:
# gene identifiers of the reference; they must match adata.var_names for the gene intersection later
adata_ref.var.index

In [None]:
from cell2location.utils.filtering import filter_genes

In [None]:
# choose which reference genes to keep, using cell2location's filter_genes thresholds
selected = filter_genes(
    adata_ref,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.03,
    nonz_mean_cutoff=1.12,
)

In [None]:
# keep only the selected genes; .copy() turns the view into an independent AnnData
adata_ref = adata_ref[:, selected].copy()

In [None]:
# inspect the reference genes that survived filtering
adata_ref.var

# 3. Estimation of reference cell type signatures (Negative Binomial regression)

In [None]:
# use the CUDA device when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# register the scRNA-seq reference fields with the regression model
cell2location.models.RegressionModel.setup_anndata(
    adata=adata_ref,
    batch_key='Sample',                     # 10X reaction / sample / batch
    labels_key='Subset',                    # cell type, covariate used for constructing signatures
    categorical_covariate_keys=['Method'],  # multiplicative technical effects (platform, 3' vs 5', donor effect)
)

In [None]:
# create the regression model
from cell2location.models import RegressionModel
mod = RegressionModel(adata_ref)

# view anndata_setup as a sanity check
mod.view_anndata_setup()

In [None]:
# the next cell takes about 1 hour to run with 1 GPU

In [None]:
# train the reference regression model; use the GPU when one is visible and
# fall back to the CPU otherwise (much slower, but avoids a hard failure)
mod.train(max_epochs=250, use_gpu=torch.cuda.is_available())

In [None]:
# plot the training loss history; the argument is the first epoch shown (skips the first 20)
mod.plot_history(20)

In [None]:
# sample the posterior of the reference model and write the estimates back into adata_ref;
# use the GPU when one is visible, otherwise fall back to the CPU
adata_ref = mod.export_posterior(
    adata_ref,
    sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': torch.cuda.is_available()},
)

In [None]:
# diagnostic QC plots of the trained reference regression model
mod.plot_QC()

In [None]:
# export estimated expression in each cluster
if 'means_per_cluster_mu_fg' in adata_ref.varm.keys():
    inf_aver = adata_ref.varm['means_per_cluster_mu_fg'][[f'means_per_cluster_mu_fg_{i}'
                                    for i in adata_ref.uns['mod']['factor_names']]].copy()
else:
    inf_aver = adata_ref.var[[f'means_per_cluster_mu_fg_{i}'
                                    for i in adata_ref.uns['mod']['factor_names']]].copy()
inf_aver.columns = adata_ref.uns['mod']['factor_names']
inf_aver.iloc[0:5, 0:5]

# 4. Spatial mapping

In [None]:
# find shared genes and subset both anndata and reference signatures
# sorted, unique genes present both in the reference signatures and in the spatial data
intersect = np.intersect1d(inf_aver.index, adata.var_names)

In [None]:
# Fail early with a clear message if the datasets share no genes
# (e.g. one uses gene symbols and the other another ID type as var_names);
# otherwise model setup would fail later with a less obvious error.
assert len(intersect) > 0, (
    "No shared genes between the spatial data (adata.var_names) "
    "and the reference signatures (inf_aver.index)"
)
# restrict both objects to the shared genes, in the same order
adata = adata[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()

In [None]:
# prepare anndata for cell2location model
# register the spatial data; each value of adata.obs['sample'] is treated as a batch
cell2location.models.Cell2location.setup_anndata(
    adata=adata,
    batch_key="sample",
)

In [None]:
# create and train the model
mod = cell2location.models.Cell2location(
    adata, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=30,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection:
    detection_alpha=20
)
mod.view_anndata_setup()

In [None]:
# NOTE(review): the upstream tutorial trains this model for many more epochs
# (30000); check in the ELBO plot below that 3000 epochs have converged.
mod.train(max_epochs=3000,
          # train using full data (batch_size=None)
          batch_size=None,
          # use all data points in training because
          # we need to estimate cell abundance at all locations
          train_size=1,
          # use the GPU when one is visible, otherwise fall back to the CPU
          use_gpu=torch.cuda.is_available(),
         )

# plot ELBO loss history during training, removing the first 1000 epochs from the plot
mod.plot_history(1000)
plt.legend(labels=['full data training']);

Exporting estimated posterior distributions of cell abundance and saving results:

In [None]:
# sample the posterior of cell abundance at every location in a single batch
# (batch_size = number of locations); use the GPU when one is visible,
# otherwise fall back to the CPU
adata = mod.export_posterior(
    adata,
    sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': torch.cuda.is_available()},
)

In [None]:
# Save model
# persist the trained spatial model into output_path, replacing any previous save there
mod.save(dir_path=output_path, overwrite=True)

In [None]:
# Save anndata object with results
# write the spatial AnnData (now holding the posterior estimates) and show its path
adata_file = "{}/adata_after_train.h5ad".format(output_path)
adata.write(filename=adata_file)
adata_file

In [None]:
# diagnostic QC plots of the trained spatial mapping model
mod.plot_QC()

# 5. Visualising cell abundance in spatial coordinates

In [None]:
# add 5% quantile, representing confident cell abundance, 'at least this amount is present',
# to adata.obs with nice names for plotting
# Copy the 5% posterior quantile of cell abundance ("at least this amount is
# present") into adata.obs, one column per cell type, so it can be plotted by name.
cell_type_names = adata.uns['mod']['factor_names']
q05_abundance = adata.obsm['q05_cell_abundance_w_sf']
adata.obs[cell_type_names] = q05_abundance

In [None]:
# plot in spatial coordinates
# plot estimated abundance of three cell types in spatial coordinates
with mpl.rc_context({'axes.facecolor': 'black',
                     'figure.figsize': [4.5, 5]}):

    sc.pl.spatial(adata, cmap='magma',
                  # the three cell types to show (each must be a column of adata.obs)
                  color=['B_Cycling', 'B_GC_LZ', 'T_CD4+_TfH_GC'],
                  ncols=4, size=1.3,
                  img_key='hires',
                  # limit color scale at 99.2% quantile of cell abundance
                  vmin=0, vmax='p99.2'
                 )

In [None]:
# Now we use cell2location plotter that allows showing multiple cell types in one panel
from cell2location.plt import plot_spatial

In [None]:
# select up to 6 clusters
# cell types to overlay in one panel (select up to 6)
clust_labels = ['T_CD4+_naive', 'B_naive', 'FDC']
# adata.obs column for each label; identical here, adjust if column names differ
clust_col = [str(label) for label in clust_labels]

In [None]:
# overlay the selected cell types in a single panel on the tissue image
with mpl.rc_context({'figure.figsize': (15, 15)}):
    fig = plot_spatial(
        adata=adata,
        color=clust_col,            # adata.obs columns to plot
        labels=clust_labels,        # labels to show on the plot
        show_img=True,
        style='fast',               # 'fast' (white background) or 'dark_background'
        max_color_quantile=0.992,   # limit color scale at 99.2% quantile of cell abundance
        circle_diameter=6,          # size of locations (adjust depending on figure size)
        colorbar_position='right',
    )

# 6. Downstream analysis

# 6.1. Identification of spatial domains

In [None]:
# compute KNN using the cell2location output stored in adata.obsm
# k-nearest-neighbour graph built on the cell2location q05 abundance stored in adata.obsm
sc.pp.neighbors(
    adata,
    n_neighbors=15,
    use_rep='q05_cell_abundance_w_sf',
)

In [None]:
# Cluster spots into regions using scanpy
# cluster spots into regions on the abundance-based KNN graph; labels go to adata.obs['leiden']
sc.tl.leiden(adata, resolution=1.1, key_added='leiden')

In [None]:
# add region as categorical variable
# copy the Leiden labels into a categorical 'region_cluster' column used in the plots below
adata.obs["region_cluster"] = adata.obs["leiden"].astype("category")

In [None]:
# compute UMAP using KNN graph based on the cell2location output
# UMAP embedding of the KNN graph built from the cell2location output
sc.tl.umap(adata, min_dist=0.3, spread=1)

umap_style = {'axes.facecolor': 'white', 'figure.figsize': [8, 8]}
with mpl.rc_context(umap_style):
    # regions, labelled directly on the embedding
    sc.pl.umap(adata, color=['region_cluster'], size=30,
               color_map='RdPu', ncols=2, legend_loc='on data',
               legend_fontsize=20)
    # sample of origin of each spot
    sc.pl.umap(adata, color=['sample'], size=30,
               color_map='RdPu', ncols=2,
               legend_fontsize=20)

# the same regions drawn in spatial coordinates on the tissue image
spatial_style = {'axes.facecolor': 'black', 'figure.figsize': [4.5, 5]}
with mpl.rc_context(spatial_style):
    sc.pl.spatial(adata, color=['region_cluster'],
                  size=1.3, img_key='hires', alpha=0.5)

# 6.2. Identification of spatially co-occurring cell types

In [None]:
from cell2location import run_colocation
res_dict, adata = run_colocation(
    adata,
    model_name='CoLocatedGroupsSklearnNMF',
    train_args={
      # IMPORTANT: try a wide range of factor numbers. np.arange(2, 30) tries
      # 2, 3, ..., 29 (30 is excluded); it must include 3 because the
      # 'n_fact3' result is inspected afterwards.
      'n_fact': np.arange(2, 30),
      'sample_name_col': 'sample', # columns in adata.obs that identifies sample
      'n_restarts': 3 # number of training restarts
    },
    # the hyperparameters of NMF can be also adjusted:
    model_kwargs={'alpha': 0.01, 'init': 'random', "nmf_kwd_args": {"tol": 0.000001}},
    export_args={'path': f'{output_path}'}
)

In [None]:
# Here we plot the NMF weights (Same as saved to `cell_type_fractions_heatmap`)
# NMF cell-type loadings of the 3-factor solution (same as saved to `cell_type_fractions_heatmap`)
res_dict['n_fact3']['mod'].plot_cell_type_loadings()

# 6.3. Estimate cell-type specific expression of every gene in the spatial data

In [None]:
# Compute expected expression per cell type
# Compute expected expression per cell type from the posterior samples stored
# under "post_sample_q05" (the 5% quantile, judging by the key name) and the
# model's registered AnnData fields
expected_dict = mod.module.model.compute_expected_per_cell_type(
    mod.samples["post_sample_q05"], mod.adata_manager
)

In [None]:
# Add to anndata layers
# store each cell type's expected expression as an AnnData layer named after the cell type
for idx, cell_type in enumerate(mod.factor_names_):
    per_type_mu = expected_dict['mu'][idx]
    adata.layers[cell_type] = per_type_mu

In [None]:
# Save anndata object with results
# write the spatial AnnData including the per-cell-type expression layers and show its path
adata_file = "{}/cell_type_specific_expression_ST_cell2loc.h5ad".format(output_path)
adata.write(filename=adata_file)
adata_file

In [None]:
def _gene_color_max(slide, gene, quantile=0.992):
    """Return a shared colour-scale maximum for ``gene`` across all cell-type layers.

    For each name in ``slide.uns["mod"]["factor_names"]`` the ``quantile`` of the
    gene's values in ``slide.layers[name]`` is computed. The second-highest of
    these per-cell-type values is returned, or the only value when there is a
    single cell type.

    Raises
    ------
    ValueError
        If ``gene`` is not present in ``slide.var["SYMBOL"]``.
    """
    gene_mask = np.asarray(slide.var["SYMBOL"] == gene)
    if not gene_mask.any():
        raise ValueError(f"Gene {gene!r} not found in slide.var['SYMBOL']")

    per_cell_type = []
    for name in slide.uns["mod"]["factor_names"]:
        values = slide.layers[name][:, gene_mask]
        # layers may be scipy sparse matrices or dense numpy arrays
        if hasattr(values, "toarray"):
            values = values.toarray()
        per_cell_type.append(np.quantile(values, quantile))
    per_cell_type = np.asarray(per_cell_type, dtype=float)

    if per_cell_type.size == 1:
        # np.partition(..., -2) would be out of bounds for a single value
        return float(per_cell_type[0])
    # second-highest value, presumably so that one dominant cell type does not
    # compress the colour scale of all the others
    return float(np.partition(per_cell_type, -2)[-2])


def plot_genes_per_cell_type(slide, genes, ctypes):
    """Plot spatial expression of genes, overall and split by cell type.

    One row of panels is drawn per gene. Column 0 shows the gene with
    ``sc.pl.spatial`` without a layer (scanpy's default expression matrix).
    Each following column shows the gene's estimated expression in one cell
    type, read from ``slide.layers[<cell type>]``.

    Parameters
    ----------
    slide : AnnData
        Spatial AnnData with ``var["SYMBOL"]``, ``uns["mod"]["factor_names"]``
        and one layer per cell type.
    genes : list of str
        Gene symbols to plot, matched against ``slide.var["SYMBOL"]``.
    ctypes : list of str
        Cell types (layer names) to plot, one column each.

    Returns
    -------
    fig, axs
        The matplotlib figure and its 2-D array of axes
        (``len(genes)`` rows x ``len(ctypes) + 1`` columns).

    Raises
    ------
    ValueError
        If ``genes`` is empty or a gene is missing from ``slide.var["SYMBOL"]``.
    """
    if len(genes) == 0:
        raise ValueError("`genes` must contain at least one gene symbol")

    n_genes = len(genes)
    n_ctypes = len(ctypes)
    fig, axs = plt.subplots(
        nrows=n_genes,
        ncols=n_ctypes + 1,
        figsize=(4.5 * (n_ctypes + 1) + 2, 5 * n_genes + 1),
        squeeze=False,  # always a 2-D axes array, even for a single gene
    )

    for row, gene in enumerate(genes):
        # colour limit shared by all cell-type panels of this gene: 99.2% quantile
        # of gene expression computed across cell types
        vmax_cell_types = _gene_color_max(slide, gene)

        # column 0: expression of the gene without a cell-type layer
        sc.pl.spatial(
            slide,
            cmap="magma",
            color=gene,
            gene_symbols="SYMBOL",
            size=1.3,
            img_key="hires",
            # limit color scale at 99.2% quantile of gene expression
            vmin=0,
            vmax="p99.2",
            ax=axs[row, 0],
            show=False,
        )

        # columns 1..n: estimated expression of the gene in each cell type
        for col, ctype in enumerate(ctypes, start=1):
            sc.pl.spatial(
                slide,
                cmap="magma",
                color=gene,
                layer=ctype,
                gene_symbols="SYMBOL",
                size=1.3,
                img_key="hires",
                vmin=0,
                vmax=vmax_cell_types,
                ax=axs[row, col],
                show=False,
            )
            axs[row, col].set_title(f"Gene:{gene}|Cell type:{ctype}")

    return fig, axs

Here we highlight CD3D, a pan-T-cell marker expressed by two T-cell subtypes in distinct locations but not by the co-located B cells, which instead express the CR2 gene.

In [None]:
# list cell types and genes for plotting
ctypes = ['T_CD4+_TfH_GC', 'T_CD4+_naive', 'B_GC_LZ']
genes = ['CD3D', 'CR2']

with mpl.rc_context({'axes.facecolor':  'black'}):
    plot_genes_per_cell_type(adata, genes, ctypes);