In [1]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import fdrcorrection

In [2]:
# Root of the U54 BA9 dataset; override with the U54_BA9_DATA_DIR environment variable on other machines
data_dir = os.environ.get('U54_BA9_DATA_DIR', '/mnt/home/adaly/ceph/datasets/U54_BA9/')

# AnnData object containing lambdas for all spots
adata = sc.read_h5ad(os.path.join(data_dir, 'splotch_anndata', 'adata_U54_BA9_splotch_lambdas.h5ad'))

# Drop genes that have a NaN lambda in any spot
keep_genes = ~np.isnan(adata.X).any(axis=0)
# .copy() materializes the subset: without it `adata` is an AnnData *view*, and the in-place edits below
# (assigning X, sc.pp.scale) operate on a view and trigger an implicit view-to-actual conversion.
adata = adata[:, keep_genes].copy()

# Clip expression of each gene to [0, its 99th percentile across all spots] to remove effects of outliers
cap_val = np.percentile(adata.X, 99, axis=0)
adata.X = np.clip(adata.X, 0, cap_val)  # cap_val (n_genes,) broadcasts across the spot axis

In [3]:
# Standardize every gene (zero-center and divide by its standard deviation, in place) so modules are
# compared on a common scale while variation between individuals/spots is preserved
sc.pp.scale(adata, zero_center=True, copy=False)

  view_to_actual(adata)


In [4]:
# Gene-program (co-expression module) annotations: one row per gene, indexed by gene name
label_file = os.path.join(
    data_dir,
    'coexpression_analysis',
    'U54_BA9_splotch_lambdas_corr_hclust_labels_d1700.csv',
)
gene_labels = pd.read_csv(label_file, header=0, index_col=0)

In [5]:
# Tag each spot with the array it came from (the final path component of its obs name, up to the
# first '.'), so arrays can be treated as independent samples in the significance tests
adata.obs['array_name'] = [Path(spot_id).name.split('.')[0] for spot_id in adata.obs_names]

In [6]:
def pval_to_stars(pval):
    '''
    Convert a p-value to a significance-star string for heatmap annotation.

    Thresholds (strict '<'): 1e-4 -> '****', 1e-3 -> '***', 1e-2 -> '**', 0.05 -> '*';
    anything else (including NaN, which compares False) -> ''.
    '''
    for threshold, stars in ((1e-4, '****'), (1e-3, '***'), (1e-2, '**'), (0.05, '*')):
        if pval < threshold:
            return stars
    return ''

def _group_means(values, labels):
    '''Average `values` within each distinct label (labels taken in order of first appearance).'''
    values = np.asarray(values)
    labels = np.asarray(labels)
    return np.array([values[labels == g].mean() for g in pd.unique(labels)])


def plot_module_heatmap(adata, module_genes, rows='aars', cols='Level 1', row_order=None, col_order=None,
                        col_comparisons=None, group_by=None, vmin=None, vmax=None, ax=None):
    '''
    Heatmap of mean module expression split by two obs annotations, with significance stars.

    Each heatmap cell is the mean of adata.X over all spots in that (row, col) category pair and
    all genes in the module.

    Parameters:
    ----------
    adata: AnnData
        expression data to be visualized
    module_genes: iterable of str
        names of genes in coexpression module of interest (must match naming in adata.var.index)
    rows: str
        annotations in adata.obs to separate by on rows of heatmap
    cols: str
        annotations in adata.obs to separate by on cols of heatmap
    row_order: iterable of str or None
        list of annotation categories, in order, to appear on heatmap rows (can be used to subset data)
    col_order: iterable of str or None
        list of annotation categories, in order, to appear on heatmap cols (can be used to subset data)
    col_comparisons: iterable of tuple or None
        list of (col1, col2) comparisons across which to perform BH-adjusted (Welch's) t-test; denote on
        heatmap with *'s in the col2 column. None (default) means no tests and no brackets. If several
        comparisons share the same col2, the last one determines the annotated p-value.
    group_by: str
        annotations in adata.obs denoting independent groups of samples (e.g., Visium array names); defaults to all observations being independent
    vmin, vmax: float
        min/max values of color bar
    ax: Axes or None
        axes on which to plot heatmap, or None to instantiate new

    Returns:
    -------
    ax: Axes

    Notes:
    -----
    Tests whose p-value is undefined (NaN, e.g. a side with fewer than two samples/groups) are
    excluded from the Benjamini-Hochberg family and receive no stars.
    '''
    if col_comparisons is None:
        col_comparisons = []

    adata_sub = adata[:, module_genes]

    if row_order is None:
        row_order = adata.obs[rows].unique()
    if col_order is None:
        col_order = adata.obs[cols].unique()

    # Mean expression of the module in each heatmap cell
    dat = {c: [] for c in col_order}
    for c in col_order:
        adata_sub_c = adata_sub[adata_sub.obs[cols] == c]
        for r in row_order:
            adata_sub_r = adata_sub_c[adata_sub_c.obs[rows] == r]
            dat[c].append(adata_sub_r.X.mean())
    dat = pd.DataFrame(dat, index=row_order, dtype=np.float32)

    # Raw p-values of each requested column comparison (separately per row).
    # NaN marks "not tested" so that a genuine p-value of 1.0 is not mistaken for a placeholder.
    pvals = pd.DataFrame(np.nan, index=dat.index, columns=dat.columns, dtype=np.float64)
    for c1, c2 in col_comparisons:
        adata_sub_c1 = adata_sub[adata_sub.obs[cols] == c1]
        adata_sub_c2 = adata_sub[adata_sub.obs[cols] == c2]
        for r in row_order:
            adata_sub_r1 = adata_sub_c1[adata_sub_c1.obs[rows] == r]
            adata_sub_r2 = adata_sub_c2[adata_sub_c2.obs[rows] == r]

            # Per-spot module means
            means_1 = np.asarray(adata_sub_r1.X.mean(axis=1)).ravel()
            means_2 = np.asarray(adata_sub_r2.X.mean(axis=1)).ravel()
            # group averages (e.g., Visium arrays) treated as independent samples
            if group_by is not None:
                means_1 = _group_means(means_1, adata_sub_r1.obs[group_by])
                means_2 = _group_means(means_2, adata_sub_r2.obs[group_by])

            _, pval = ttest_ind(means_1, means_2, equal_var=False)
            pvals.loc[r, c2] = pval

    # Benjamini-Hochberg correction over the finite p-values only: a NaN entering fdrcorrection
    # can propagate through its step-up cumulative minimum and turn every adjusted p-value into NaN.
    flat = pvals.to_numpy(dtype=np.float64, copy=True).ravel()
    tested = np.isfinite(flat)
    if tested.any():
        flat[tested] = fdrcorrection(flat[tested], is_sorted=False)[1]
    p_adj = pd.DataFrame(flat.reshape(pvals.shape), index=pvals.index, columns=pvals.columns)
    # Column-wise Series.map works on all pandas versions (DataFrame.applymap is deprecated in 2.1+)
    p_stars = p_adj.apply(lambda col: col.map(pval_to_stars))

    if ax is None:
        _, ax = plt.subplots(1)
    sns.heatmap(dat, ax=ax, cmap='bwr', vmin=vmin, vmax=vmax, center=0,
                annot=p_stars, fmt="",
                cbar_kws={'label': r'Mean scaled expression ($\overline{\lambda}$)'})

    # Add brackets over compared columns. Brackets are drawn at negative y (seaborn heatmaps put row 0
    # at the top); a column involved in several comparisons gets successively taller brackets.
    col_positions = list(col_order)
    bracket_count = {}
    for c1, c2 in col_comparisons:
        for c in (c1, c2):
            bracket_count[c] = bracket_count.get(c, 0) + 1
        bh = -0.1 * max(bracket_count[c1], bracket_count[c2])
        x1 = col_positions.index(c1) + 0.5
        x2 = col_positions.index(c2) + 0.5
        ymin, ymax = ax.get_ylim()
        ax.plot([x1, x1, x2, x2], [0, bh, bh, 0], c='k')
        ax.set_ylim(ymin, ymax + bh)

    return ax

In [7]:
clust_lbl = 'd1700'
aar_list = ['Layer_%d' % i for i in range(1, 7)] + ['White_matter']
age_list = ['Young', 'Middle', 'Old']

# Only modules with at least this many genes are plotted
min_module_size = 10
plot_dir = 'lambda_plots_heatmap'
os.makedirs(plot_dir, exist_ok=True)  # savefig does not create missing directories

for k in np.unique(gene_labels[clust_lbl]):
    genes_in = gene_labels.index[gene_labels[clust_lbl] == k]

    if len(genes_in) >= min_module_size:
        ax = plot_module_heatmap(adata, genes_in, rows='region', cols='Level 1',
                                 row_order=aar_list, col_order=age_list,
                                 col_comparisons=[('Young', 'Middle'), ('Young', 'Old')],
                                 group_by='array_name', vmin=-1, vmax=1)
        # Work on the figure explicitly rather than via pyplot's "current figure" state
        fig = ax.figure
        fig.subplots_adjust(left=0.25)
        fig.savefig(os.path.join(plot_dir, 'Module%d_l1.png' % k), dpi=300)
        plt.close(fig)  # free memory: one figure is created per module