In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype'] = 'none'

from tqdm import tqdm
from scipy.stats import pearsonr

In [2]:
### REPLACE with path to AnnData downloaded from GCP ###
data_dir = '/Users/adaly/Documents/mouse_colon/csplotch_anndata'
adata_betas = sc.read_h5ad(os.path.join(data_dir, 'adata_csplotch_celltype_betas.h5ad'))
adata_lambdas = sc.read_h5ad(os.path.join(data_dir, 'adata_csplotch_lambdas.h5ad'))

### REPLACE with path to AnnData downloaded from GCP ###
data_dir = '/Users/adaly/Dropbox (Simons Foundation)/cell_segmentation_colon/snrna_anndata/'
adata_snrna = sc.read_h5ad(os.path.join(data_dir, 'adata_larger_relabeling_after_tsne_stemfiltered_renamed.h5ad'))

In [3]:
# Compute mean expression within a cell group
def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None):
    if layer is not None:
        getX = lambda x: x.layers[layer]
    else:
        getX = lambda x: x.X
    if gene_symbols is not None:
        new_idx = adata.var[gene_symbols].values
    else:
        new_idx = adata.var_names

    grouped = adata.obs.groupby(group_key)
    out = pd.DataFrame(
        np.zeros((adata.shape[1], len(grouped)), dtype=np.float64),
        columns=grouped.groups.keys(),
        index=new_idx
    )

    for group, idx in grouped.indices.items():
        X = getX(adata[idx])
        out[group] = np.ravel(X.mean(axis=0, dtype=np.float64))
    return out

Pre-process snRNA-seq data and cSplotch betas

In [4]:
# Identify common genes and subset both data modalities
common_genes = np.intersect1d(adata_betas.var.index, adata_snrna.var.index)
print(len(common_genes), 'genes in common')

adata_betas = adata_betas[:, adata_betas.var.index.isin(common_genes)]
adata_snrna = adata_snrna[:, adata_snrna.var.index.isin(common_genes)].copy()

# Normalize, log-transform, and scale snRNA-seq data
sc.pp.normalize_total(adata_snrna, 3959)  # median spot depth across ST data
sc.pp.log1p(adata_snrna)

# (find marker genes for each cell type prior to scaling)
sc.tl.rank_genes_groups(adata_snrna, 'pheno_major_cell_types')
df_snrna_markers = sc.get.rank_genes_groups_df(adata_snrna, group=None, pval_cutoff=0.05, log2fc_min=0.5)

sc.pp.scale(adata_snrna)

# Calculate mean log-scaled expression within each snRNA cell type
sn_count_means = grouped_obs_mean(adata_snrna, 'pheno_major_cell_types')

11735 genes in common


In [None]:
# Find top marker genes for each cell type
n_markers_per = 50
snrna_markers = []

for ct in sn_count_means.columns:
    markers_ct = df_snrna_markers[df_snrna_markers['group'] == ct]
    snrna_markers.append(markers_ct['names'][:n_markers_per].values)
                        
snrna_markers = np.unique(np.concatenate(snrna_markers))
print(len(snrna_markers), 'marker genes across %d cell types' % len(sn_count_means.columns))

In [None]:
adata_betas.obs.index = ['%s %s %s' % (b, a, c) for b, a, c in zip(adata_betas.obs.condition, 
                                         adata_betas.obs.annotation, 
                                         adata_betas.obs.celltype)]

In [None]:
# Create a pseudo-single cell dataset from annotated ST (N_CELLS pseudo-cells per spot in proportions determined by
# SPOTlight deconvolution).

snrna_types = sn_count_means.columns
age_range = ['4w', '6w', '8w', '12w']

PCELLS_PER_SPOT = 50
pc_betas, pc_aar, pc_region, pc_age, pc_ctype = [],[],[],[],[]

obs_ag = adata_lambdas.obs[adata_lambdas.obs.Age.isin(age_range)]

for spot in tqdm(obs_ag.index):
    region = obs_ag.loc[spot, 'Region']
    age = obs_ag.loc[spot, 'Age']
    aar = obs_ag.loc[spot, 'annotation']
    
    for ct in snrna_types:
        if ct == 'Myocyte':
            ct = 'SMC'
        
        idx = '%s BL6WT.%s %s %s' % (age, region, aar, ct)
        beta_mean = np.array(adata_betas[idx].X).squeeze()

        for p in range(int(np.rint(obs_ag.loc[spot, ct] * PCELLS_PER_SPOT))):
            pc_betas.append(beta_mean)
            pc_aar.append(aar)
            pc_region.append(region)
            pc_age.append(age)

adata_pcells = ad.AnnData(X = np.array(pc_betas),
    obs = pd.DataFrame({'aar':pc_aar, 'age':pc_age, 'region':pc_region, 'celltype':pc_ctype}),
    var = adata_betas.var)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28136/28136 [13:31<00:00, 34.68it/s]


In [None]:
# Weighted Betas: z-scale
print('Scaling mean betas across all pseudocells...')
sc.pp.scale(adata_pcells)

In [None]:
def calc_profile_correlation(df_means_st, df_means_sn):
    assert df_means_st.index.equals(df_means_sn.index), 'indices do not match!'
    st_classes = df_means_st.columns
    sn_classes = df_means_sn.columns
    
    # st_celltypes x sn_celltypes matrix for storing correlations
    corr_mat = np.zeros((len(st_classes), len(sn_classes))) 
    
    for i, aar in enumerate(st_classes):
        for j, ct in enumerate(sn_classes):
            r, p = pearsonr(df_means_st[aar].values, df_means_sn[ct].values)
            corr_mat[i,j] = r
    
    return corr_mat

In [None]:
def plot_scores_dotplot(df, df_size):
    # fill with zeros
    df = df.fillna(0).T
    df_size = df_size.fillna(0).T
    
    df_size_cut=df_size
    
    df_size_cut.columns = df.columns
    df_size_cut.index = df.index
    
    # define order
    cat_order = df_size_cut.index

    obs = pd.DataFrame(df.index, index = df.index, columns = ['Cell type']).astype("category")
    mod_anndata = sc.AnnData(df, obs, dtype=np.float32)

    # plots
    vmin = -0.5
    vmax = 0.5
    cmap = 'RdBu_r'

    plt.rcParams['font.size'] = 10
    size_title = 'Mean Spot %'
    ax_dict = sc.pl.dotplot(mod_anndata, show=False, var_names = mod_anndata.var_names, dot_size_df = df_size_cut,
                            dot_color_df = df, categories_order = cat_order, size_title = size_title, colorbar_title = 'Pearson r',
                            groupby = 'Cell type', vmin = vmin, vmax = vmax, cmap=cmap, 
                            figsize = (len(mod_anndata.obs.index)+2, len(mod_anndata.obs.index)),
                            swap_axes=True)

    ax_dict['mainplot_ax'].set_yticklabels([i for i in mod_anndata.var_names]) 
    ax_dict['mainplot_ax'].tick_params(axis='y', labelleft=True, left=True, labelsize = 10, labelrotation = 0, pad = 0)   
    ax_dict['mainplot_ax'].tick_params(axis='x', labelbottom=True, bottom=True,labelsize = 10, labelrotation = 90, pad = 0) 
    ax_dict['mainplot_ax'].grid(visible=True, which='major', axis='both')
    ax_dict['size_legend_ax'].set_facecolor('white')
    ax_dict['size_legend_ax'].set_aspect(0.2)
    ax_dict['color_legend_ax'].set_aspect(0.2)
    fig = ax_dict['mainplot_ax'].get_figure()
    
    plt.xlabel('snRNA-seq')
    plt.ylabel('cSplotch')
    
    return fig, ax_dict

In [6]:
def plot_scores_heatmap(df_corr, df_size):
    fig, ax = plt.subplots(2, 2, figsize=(5.5,5), constrained_layout=True,
                           gridspec_kw={'width_ratios': [3, 1], 'height_ratios': [1, 10]})
    
    sns.heatmap(df_corr, ax=ax[1,0], vmin=-0.5, vmax=0.5, center=0, cmap='RdBu_r', cbar_ax=ax[0,0],
               cbar_kws={'label':"Pearson's r", 'orientation':'horizontal'})
    ax[1,0].set_ylabel('cSplotch')
    ax[1,0].set_xlabel('snRNA-seq')
    
    ax[1,1].barh(np.arange(len(df_size)), df_size.iloc[:,0], align='edge', height=0.95, color='gray')
    ax[1,1].set_yticks([])
    ax[1,1].invert_yaxis()
    ax[1,1].set_ylim(0, len(df_size))
    ax[1,1].set_title('Mean cell fraction', fontsize=9, loc='left')
    ax[1,1].spines[['right', 'top']].set_visible(False)
    
    ax[0,1].axis('off')
    
    return fig, ax

In [None]:
# Make correlation heatmap for abundant cell types in each AAR
abundant_thresh = 0.05
region = 'Proximal'
age_range = ['4w', '6w', '8w', '12w']

adata_pcells_sub = adata_pcells[np.logical_and(adata_pcells.obs.age.isin(age_range), adata_pcells.obs.region == region)]
age_str = age_range[0] + '-' + age_range[-1]

# Mean expression profile for each cell type in the niche
niche_beta_means = grouped_obs_mean(adata_pcells_sub, 'celltype')

for aar in adata_lambdas.obs.annotation.unique():
    obs = adata_lambdas.obs[np.logical_and(adata_lambdas.obs.annotation == aar, 
                                           adata_lambdas.obs.Age.isin(age_range))]
    obs = obs[obs.Region == region]
    
    abundant_celltypes = [ct for ct in snrna_types if obs[ct].mean() > abundant_thresh]
    
    # Make a celltypes x celltypes DataFrame containing frequency of each cell type in ST data (constant rows)
    celltype_freq = obs[abundant_celltypes].mean(axis=0)
    df_freq = pd.DataFrame(dict([(ct, celltype_freq) for ct in abundant_celltypes]))
        
    # Find top marker genes for each cell type that is abundant in current AAR
    n_markers_per = 50
    snrna_markers_aar = []
    for ct in abundant_celltypes:
        markers_ct = df_snrna_markers[df_snrna_markers['group'] == ct]
        snrna_markers_aar.append(markers_ct['names'][:n_markers_per].values)
    snrna_markers_aar = np.unique(np.concatenate(snrna_markers_aar))
        
    # Calculate mean expression profile for each cell type in current niche
    adata_pcells_niche = adata_pcells_sub[adata_pcells_sub.obs.aar == aar]
    adata_pcells_niche = adata_pcells_niche[adata_pcells_niche.obs.age.isin(age_range)]
    
    niche_beta_means = grouped_obs_mean(adata_pcells_niche, 'celltype')
    niche_beta_means_cells = niche_beta_means[abundant_celltypes]
    
    sn_count_means_cells = sn_count_means.loc[niche_beta_means_cells.index, abundant_celltypes]

    sn_count_means_cells = sn_count_means_cells.loc[snrna_markers_aar]
    niche_beta_means_cells = niche_beta_means_cells.loc[snrna_markers_aar]
        
    corr_mat = calc_profile_correlation(niche_beta_means_cells, sn_count_means_cells)
    df_corr = pd.DataFrame(corr_mat, index=abundant_celltypes, columns=abundant_celltypes)
        
    #fig, ax_dict = plot_scores_dotplot(df_corr, df_freq)
    #ax_dict['mainplot_ax'].set_title(aar)
    fig, ax = plot_scores_heatmap(df_corr, df_freq)
    plt.show()