In [None]:
import sys
import os
import scanpy as sc
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import cell2location
import scvi

from matplotlib import rcParams
# Embed fonts as Type 42 (TrueType) so that text in exported PDFs stays
# selectable/editable instead of being converted to outlines.
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text for PDFs

In [None]:
# Use the current working directory as the project root for all outputs
root_path = os.getcwd()

In [None]:
# Output locations, all nested under <root_path>/deconvolution
results_folder = os.path.join(root_path, 'deconvolution')
# folder holding the reference-signature model and its AnnData
ref_run_name =  os.path.join(results_folder, 'reference_signatures') 
# folder holding the spatial mapping model and its AnnData
run_name = os.path.join(results_folder, 'cell2location_map')  

In [None]:
## Make sure the output folder for the reference signatures exists.
# os.path.isdir (rather than os.path.exists) so that a regular file sitting at
# this path is not mistaken for the folder; exist_ok=True tolerates the folder
# being created by another process between the check and the call.
if os.path.isdir(ref_run_name):
    print(f"Folder '{ref_run_name}' already exists.")
else:
    os.makedirs(ref_run_name, exist_ok=True)
    print(f"Folder '{ref_run_name}' created.")

In [None]:
## Make sure the output folder for the spatial mapping run exists.
# os.path.isdir (rather than os.path.exists) so that a regular file sitting at
# this path is not mistaken for the folder; exist_ok=True tolerates the folder
# being created by another process between the check and the call.
if os.path.isdir(run_name):
    print(f"Folder '{run_name}' already exists.")
else:
    os.makedirs(run_name, exist_ok=True)
    print(f"Folder '{run_name}' created.")

## Loading Single Cell Reference

In [None]:
# Load the annotated single-cell reference (replace with the path to your file).
# Its .obs annotations are joined onto the raw counts further down.
adata_ref_obs = sc.read_h5ad('/sw_besca2_cellbender.annotated.h5ad')

In [None]:
# List the distinct cell type labels present in the annotated reference
adata_ref_obs.obs['celltype_merged'].unique().tolist()

In [None]:
# Inspect the matrix stored in .raw of the annotated object
print(adata_ref_obs.raw.X)

For deconvolution purposes, we exclude B T cell doublets and mixed cells.

In [None]:
# Labels that should not contribute a reference signature
exclude_celltypes = ['B T cell doublet', 'mixed']
# True for every cell whose label is NOT in the exclusion list
mask = ~adata_ref_obs.obs['celltype_merged'].isin(exclude_celltypes)
# Boolean subsetting along obs; only the .obs of this object is used later
filtered_adata_ref_obs =  adata_ref_obs[mask]

In [None]:
# Confirm that the excluded labels are gone
filtered_adata_ref_obs.obs['celltype_merged'].unique().tolist()

The annotated object does not contain the raw counts, so we read the raw count matrix (.mtx) and match its barcodes with the annotation (Dblabel4) provided in this object.

In [None]:
# Folder containing matrix.mtx, barcodes.tsv and genes.tsv
path_raw = '/raw/' #replace with path to raw snRNA-seq data
# os.path.join keeps the paths valid whether or not path_raw ends with a slash.
# The matrix is transposed after reading so that rows are barcodes and
# columns are genes, matching the names assigned below.
adata_raw = sc.read_mtx(os.path.join(path_raw, 'matrix.mtx')).T
# barcodes.tsv: first column holds the cell barcodes
adata_raw.obs_names = pd.read_csv(os.path.join(path_raw, 'barcodes.tsv'), header=None, sep='\t')[0]
# genes.tsv: second column is used as gene name (presumably the gene symbol -- confirm against the file)
adata_raw.var_names = pd.read_csv(os.path.join(path_raw, 'genes.tsv'), header=None, sep='\t')[1]

In [None]:
# De-duplicate barcodes and gene names so they can be used as unique indices
adata_raw.obs_names_make_unique()
adata_raw.var_names_make_unique()

In [None]:
# Inner join on the obs index (barcodes): keeps only raw barcodes that also
# carry an annotation in the filtered reference.
# NOTE(review): assumes barcodes are formatted identically in both objects -- confirm.
combined_obs = adata_raw.obs.join(filtered_adata_ref_obs.obs, how='inner')
indices_to_keep = combined_obs.index

In [None]:
# Subset the raw counts to the annotated barcodes. An explicit copy is taken
# because .obs of this object is replaced afterwards and adata_raw is deleted;
# modifying a view would otherwise trigger an implicit copy with a warning.
adata_ref = adata_raw[indices_to_keep].copy()

In [None]:
# Attach the joined annotations; combined_obs shares its index with the subset taken above
adata_ref.obs = combined_obs

In [None]:
# Check the cell type labels that made it into the raw-count reference
adata_ref.obs['celltype_merged'].unique().tolist()

In [None]:
# Delete the intermediate objects that are no longer used
del adata_raw
del combined_obs
del filtered_adata_ref_obs
# del adata_ref_obs

In [None]:
from cell2location.utils.filtering import filter_genes
# Select genes for signature estimation using cell2location's filter.
# NOTE(review): these cutoffs differ from the values shown in the
# cell2location tutorial; see the filter_genes documentation for the exact
# meaning of each cutoff.
selected = filter_genes(adata_ref, cell_count_cutoff=25, cell_percentage_cutoff2=0.1, nonz_mean_cutoff=1.25)

# filter the object (copy so adata_ref is a standalone AnnData, not a view)
adata_ref = adata_ref[:, selected].copy()

In [None]:
# prepare anndata for the regression model
# (no batch_key is passed: the batch option below is commented out)
cell2location.models.RegressionModel.setup_anndata(adata=adata_ref,
                        # 10X reaction / sample / batch
                        # batch_key='cell_processing_protocol',
                        # cell type, covariate used for constructing signatures
                        labels_key='celltype_merged',
                        # multiplicative technical effects (platform, 3' vs 5', donor effect)
                        categorical_covariate_keys=['sex']
                       )

In [None]:
# create the regression model
from cell2location.models import RegressionModel
# NOTE: `mod` is reused for the spatial model later in the notebook
mod = RegressionModel(adata_ref)

# view anndata_setup as a sanity check
mod.view_anndata_setup()

In [None]:
# Train the reference-signature model for 400 epochs on a CUDA device
mod.train(max_epochs=400, accelerator='cuda')

In [None]:
# Plot the training history
# (the argument presumably skips the first 20 epochs -- confirm in the cell2location docs)
mod.plot_history(20)

In [None]:
# Export a summary of the posterior distribution of the regression model into adata_ref.
adata_ref = mod.export_posterior(
    adata_ref, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': True}
)

# Save model (overwrites any model already stored in ref_run_name)
mod.save(f"{ref_run_name}", overwrite=True)

# Save anndata object with results
adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref.write(adata_file)
# display the path that was written
adata_file

In [None]:
# Export the posterior again with use_quantiles=False.
# NOTE(review): this replaces the in-memory adata_ref after it was already
# written to disk above; the result of this call is not saved in this cell.
adata_ref = mod.export_posterior(
    adata_ref, use_quantiles=False,
    sample_kwargs={'batch_size': 2500, 'use_gpu': True}
)

In [None]:
# Diagnostic QC plots for the reference-signature model
mod.plot_QC()

In [None]:
# adata_file = f"{ref_run_name}/sc.h5ad"
# adata_ref = sc.read_h5ad(adata_file)
# d = cell2location.models.RegressionModel.load(f"{ref_run_name}", adata_ref)

In [None]:
# Pull the per-cluster expression signatures out of the exported posterior.
# The columns are named 'means_per_cluster_mu_fg_<factor>' and live in .varm
# when that key is present, otherwise directly in .var.
signature_columns = [
    f'means_per_cluster_mu_fg_{factor}'
    for factor in adata_ref.uns['mod']['factor_names']
]
signature_table = (
    adata_ref.varm['means_per_cluster_mu_fg']
    if 'means_per_cluster_mu_fg' in adata_ref.varm.keys()
    else adata_ref.var
)
inf_aver = signature_table[signature_columns].copy()
# rename the columns to the plain factor (cell type) names
inf_aver.columns = adata_ref.uns['mod']['factor_names']
# preview the top-left corner of the signature matrix
inf_aver.iloc[0:5, 0:5]

## Loading 10x VISIUM data

In [None]:
# Project root again (same value as at the top unless the working directory changed)
root_path = os.getcwd()
# Folder that contains the 'analyzed' sub-folder with the Visium .h5ad files
adata_folder = os.path.join(root_path)

In [None]:
# Folder holding one processed Visium AnnData (.h5ad) per sample
analyzed_dir = os.path.join(adata_folder, 'analyzed')

# os.listdir returns entries in arbitrary order; sorting makes the order of
# the loaded samples (and therefore of the concatenated object) reproducible.
file_names = sorted(
    f for f in os.listdir(analyzed_dir)
    if os.path.isfile(os.path.join(analyzed_dir, f))
)

# Only .h5ad files are read; any other file in the folder is ignored
adata_list = [anndata.read_h5ad(os.path.join(analyzed_dir, file)) for file in file_names if file.endswith('.h5ad')]

In [None]:
# Merge all Visium samples into a single AnnData.
# Each observation is tagged in obs['library_id'] with the key under which its
# sample is stored in uns['spatial'], collected here in the order of adata_list.
spatial_keys = [
    key
    for sample in adata_list
    for key in sample.uns["spatial"]
]
adata_concat = sc.concat(
    adata_list,
    join='outer',
    label="library_id",
    keys=spatial_keys,
    index_unique="-",
    uns_merge="unique",
)

We remove mito genes

In [None]:
# Keep the gene names (var index) as an explicit column; used below to flag MT genes
adata_concat.var['SYMBOL'] = adata_concat.var.index

In [None]:
# Display the gene annotation table
adata_concat.var

In [None]:
# find mitochondria-encoded (MT) genes
# (case-sensitive match on the lower-case 'mt-' prefix)
adata_concat.var['MT_gene'] = [gene.startswith('mt-') for gene in adata_concat.var['SYMBOL']]

# remove MT genes for spatial mapping (keeping their counts in the object)
# .toarray() requires X to be a sparse matrix here
adata_concat.obsm['MT'] = adata_concat[:, adata_concat.var['MT_gene'].values].X.toarray()
adata_concat = adata_concat[:, ~adata_concat.var['MT_gene'].values]

## Cell2location: spatial mapping

In [None]:
# find shared genes and subset both anndata and reference signatures
# (np.intersect1d returns the sorted common gene names, so both objects end
# up with the same genes in the same order)
intersect = np.intersect1d(adata_concat.var_names, inf_aver.index)
adata_concat = adata_concat[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()

# prepare anndata for cell2location model
# NOTE(review): 'Gender' is registered as labels_key, whereas the reference
# model registered 'sex' as a categorical covariate -- confirm this is intended.
cell2location.models.Cell2location.setup_anndata(adata=adata_concat, batch_key="Batch_ID", labels_key  = 'Gender')

In [None]:
# create the spatial mapping model (training happens in the next cell)
# NOTE: this rebinds `mod`, which previously held the regression model
mod = cell2location.models.Cell2location(
    adata_concat, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=10,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection:
    detection_alpha=100)

# view anndata_setup as a sanity check
mod.view_anndata_setup()

In [None]:
# Train the spatial mapping model
mod.train(max_epochs=50000,
          # train using full data (batch_size=None)
          batch_size=None,
          # use all data points in training because
          # we need to estimate cell abundance at all locations
          train_size=1,
          # same device selection as used for the reference model above
          accelerator='cuda',
         )

# plot ELBO loss history during training, removing first 1000 epochs from the plot
mod.plot_history(1000)
plt.legend(labels=['full data training']);

In [None]:
# In this section, we export the estimated cell abundance (summary of the posterior distribution).
# batch_size is set to the number of observations, i.e. a single batch.
adata_concat = mod.export_posterior(
    adata_concat, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': True}
)

# Save model (overwrites any model already stored in run_name)
mod.save(f"{run_name}", overwrite=True)

# mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

# Save anndata object with results
adata_file = f"{run_name}/sp.h5ad"
adata_concat.write(adata_file)
# display the path that was written
adata_file

In [None]:
# adata_file = f"{run_name}/sp.h5ad"
# adata_vis = sc.read_h5ad(adata_file)
# mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

In [None]:
# Diagnostic QC plots for the spatial mapping model
mod.plot_QC()

In [None]:
# fig = mod.plot_spatial_QC_across_batches()

In [None]:
# Inspect the per-spot metadata (includes the Sample_ID and library_id columns used for plotting)
adata_concat.obs

In [None]:
# add 5% quantile, representing confident cell abundance, 'at least this amount is present',
# to adata.obs with nice names for plotting
adata_concat.obs[adata_concat.uns['mod']['factor_names']] = adata_concat.obsm['q05_cell_abundance_w_sf']

# plot in spatial coordinates, one figure grid per library
# Iterate over 'library_id' (not 'Sample_ID'): the spots are selected by
# obs['library_id'] and sc.pl.spatial looks the image up by that same id,
# so the loop values must come from that column.
for library in adata_concat.obs["library_id"].unique().tolist():
    with mpl.rc_context({'axes.facecolor':  'black',
                     'figure.figsize': [4.5, 5]}):

        # spots belonging to the current library only
        ad = adata_concat[adata_concat.obs["library_id"] == library, :].copy()
        print(library)
        sc.pl.spatial(ad, cmap='magma',
                  # show all cell types, 4 panels per row
                  color=adata_concat.uns['mod']['factor_names'],
                  ncols=4, size=1.3,
                  img_key='hires',
                  # limit color scale at 99.2% quantile of cell abundance
                  vmin=0, vmax='p99.2',
                  library_id=library
                 )

In [None]:
# cell2location's own plotter can overlay several cell types in one panel
from cell2location.plt import plot_spatial
from cell2location.utils import select_slide

# select up to 6 clusters
clust_labels = ['periportal hepatocyte', 'pericentral hepatocyte']
# column names are the labels themselves here; adjust if they ever differ
clust_col = [f'{label}' for label in clust_labels]

# one overlay figure per sample
sample_ids = adata_concat.obs["Sample_ID"].unique().tolist()
for i, library in enumerate(sample_ids):

    # restrict the object (and its image) to the current sample
    slide = select_slide(adata_concat, library, batch_key='Sample_ID')

    with mpl.rc_context({'figure.figsize': (15, 15)}):
        fig = plot_spatial(
            adata=slide,
            # columns to colour by and the labels shown for them
            color=clust_col,
            labels=clust_labels,
            show_img=True,
            # 'fast' (white background) or 'dark_background'
            style='fast',
            # size of locations (adjust depending on figure size)
            circle_diameter=6,
            colorbar_position='bottom',
            max_color_quantile=0.97,
        )

In [None]:
# select up to 6 clusters
clust_labels = ['periportal LSEC', 'pericentral LSEC', 'midzonal LSEC']
# column names are the labels themselves here; adjust if they ever differ
clust_col = [f'{label}' for label in clust_labels]

# one overlay figure per sample
sample_ids = adata_concat.obs["Sample_ID"].unique().tolist()
for i, library in enumerate(sample_ids):

    # restrict the object (and its image) to the current sample
    slide = select_slide(adata_concat, library, batch_key='Sample_ID')

    with mpl.rc_context({'figure.figsize': (15, 15)}):
        fig = plot_spatial(
            adata=slide,
            # columns to colour by and the labels shown for them
            color=clust_col,
            labels=clust_labels,
            show_img=True,
            # 'fast' (white background) or 'dark_background'
            style='fast',
            # size of locations (adjust depending on figure size)
            circle_diameter=6,
            colorbar_position='right',
            max_color_quantile=0.97,
        )

In [None]:
# Export this notebook to HTML (the notebook file name is hard-coded; run from the notebook's folder)
! jupyter nbconvert --to html 21_Deconvolution_C2L.ipynb