In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
from scipy import stats
import os

from scipy import spatial
from scipy import sparse
from scipy.interpolate import CubicSpline
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from umap import UMAP
from scipy.stats import ttest_ind, mannwhitneyu
from scipy.stats import pearsonr, spearmanr, zscore
from statsmodels.stats.multitest import multipletests

import json

In [None]:
import importlib
import scroutines
importlib.reload(scroutines)
from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
importlib.reload(utils_merfish)
from utils_merfish import rot2d, st_scatter, st_scatter_ax, plot_cluster, binning
from utils_merfish import RefLineSegs

import merfish_datasets
import merfish_genesets
importlib.reload(merfish_datasets)
importlib.reload(merfish_genesets)
from merfish_datasets import merfish_datasets
from merfish_datasets import merfish_datasets_params

from scroutines import basicu

In [None]:
def get_qc_metrics(df):
    """Collect per-cell QC distributions for histogram plotting.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'volume', 'gncov' and 'gnnum'.

    Returns
    -------
    dict
        Maps each column name to a tuple ``(name, val, medval, bins)``:
        a human-readable label, the raw values, their median, and 50 evenly
        spaced bin edges spanning [0, 10 * median].
    """
    labels = {
        'volume': 'cell volume',
        'gncov': 'num transcripts',
        'gnnum': 'num genes',
    }
    out = {}
    for col, label in labels.items():
        values = df[col].values
        median = np.median(values)
        edges = np.linspace(0, 10 * median, 50)
        out[col] = (label, values, median, edges)
    return out

def get_norm_counts(adata, scaling=500):
    """Volume-normalize counts: divide each cell by its volume and multiply by `scaling`.

    With the default ``scaling=500``, every cell is rescaled as if its volume
    were 500. The result is stored in ``adata.layers['norm']`` and also returned.
    """
    per_cell_volume = adata.obs['volume'].values.reshape(-1, 1)
    normcnts = adata.X / per_cell_volume * scaling
    adata.layers['norm'] = normcnts
    return normcnts

In [None]:
def preprocessing(adata, min_cells=10):
    """Filter lowly detected genes and add CP10k / log10(CP10k+1) layers.

    Parameters
    ----------
    adata : AnnData
        Counts in ``adata.X`` (scipy sparse or dense) and per-cell library
        sizes in ``adata.obs['n_counts']``.
    min_cells : int, default 10
        Keep genes detected (count > 0) in more than this many cells.

    Returns
    -------
    AnnData
        A new (copied) object restricted to the kept genes, with sparse layers
        'norm' (counts per 10k) and 'lognorm' (log10(CP10k + 1)).
    """
    # filter genes: expressed in more than `min_cells` cells
    cond = np.ravel((adata.X > 0).sum(axis=0)) > min_cells
    # .copy() so the layers below are written into an independent object
    # rather than into a view of `adata`
    adata_sub = adata[:, cond].copy()

    x = adata_sub.X
    if not sparse.issparse(x):
        # the log step below edits .data in place, which requires a sparse matrix
        x = sparse.csr_matrix(x)
    # NOTE: library size is the stored obs['n_counts'] (computed before this
    # gene filter), not the sum over the kept genes
    cov = adata_sub.obs['n_counts'].values

    # CP10k
    xn = sparse.diags(1 / cov).dot(x) * 1e4

    # log10(CP10k+1); only stored entries are transformed, so zeros stay zero
    xln = xn.copy()
    xln.data = np.log10(xln.data + 1)

    adata_sub.layers['norm'] = xn
    adata_sub.layers['lognorm'] = xln

    return adata_sub

In [None]:
def get_hvgs(adata, layer, nbin=20, qth=0.3):
    """Select highly variable genes by dispersion within mean-expression bins.

    Genes are split into `nbin` quantile bins of their mean (over cells) in
    ``adata.layers[layer]``. Within each bin, the ``int(qth * n_genes / nbin)``
    genes with the largest variance/mean ratio are kept.

    Parameters
    ----------
    adata : AnnData
        Cells x genes; ``adata.layers[layer]`` may be scipy sparse or dense.
    layer : str
        Layer to compute the statistics on.
    nbin : int, default 20
        Number of mean-expression quantile bins.
    qth : float, default 0.3
        Fraction of the average bin size to select per bin.

    Returns
    -------
    np.ndarray
        Selected gene names from ``adata.var.index``, in original gene order.
    """
    xn = adata.layers[layer]

    # per-gene mean and second moment over cells
    if sparse.issparse(xn):
        gm = np.ravel(xn.mean(axis=0))
        gsq = np.ravel(xn.multiply(xn).mean(axis=0))
    else:
        xn = np.asarray(xn)
        gm = xn.mean(axis=0)
        gsq = (xn ** 2).mean(axis=0)
    # var = E[x^2] - E[x]^2
    gv = gsq - gm**2

    # bin genes by mean expression (quantile bins -> roughly equal gene counts)
    gres = pd.DataFrame({
        'lbl': pd.qcut(gm, nbin, labels=np.arange(nbin)),
        'mean': gm,
        'var': gv,
        'ratio': gv / gm,  # dispersion; a zero mean gives NaN/inf
    })

    # keep the most dispersed genes in each bin; level 1 of the resulting
    # MultiIndex is the gene's position in adata.var
    n_per_bin = int(qth * (len(gm) / nbin))
    top = gres.groupby('lbl', observed=True)['ratio'].nlargest(n_per_bin)
    gsel_idx = np.sort(top.index.get_level_values(1).values)

    return adata.var.index.values[gsel_idx]

In [None]:
def binning_pipe(adata, n=20, layer='lnorm', bin_type='depth_bin'):
    """Aggregate a layer's expression within depth or width bins.

    Both ``obs['depth']`` and ``obs['width']`` are binned into `n` bins with
    ``utils_merfish.binning``. The expression in ``adata.layers[layer]`` is
    then summarised per bin of the column chosen by `bin_type`. The other bin
    column stays in the frame and is averaged along with the genes.

    Returns
    -------
    tuple
        (mean, sem, std, n_per_bin, depth_binned, width_binned,
         depth_bins, width_bins)
    """
    assert bin_type in ['depth_bin', 'width_bin']

    depth_bins, depth_binned = utils_merfish.binning(adata.obs['depth'].values, n)
    width_bins, width_binned = utils_merfish.binning(adata.obs['width'].values, n)

    expr = pd.DataFrame(adata.layers[layer], columns=adata.var.index)
    expr['depth_bin'] = depth_binned
    expr['width_bin'] = width_binned

    grouped = expr.groupby(bin_type)
    summary = {how: getattr(grouped, how)(numeric_only=True) for how in ('mean', 'sem', 'std')}
    counts = expr[bin_type].value_counts(sort=False)

    return (summary['mean'], summary['sem'], summary['std'], counts,
            depth_binned, width_binned, depth_bins, width_bins)

def binning_pipe2(adata, col_to_bin, layer, bins=None, n=20):
    """Aggregate ``adata.layers[layer]`` within bins of ``adata.obs[col_to_bin]``.

    If `bins` is None, `n` bins are computed with ``utils_merfish.binning``.
    Otherwise the given edges are applied with ``pd.cut``.

    Returns
    -------
    tuple
        (mean, sem, std, n_per_bin, binned, bins)
    """
    values = adata.obs[col_to_bin].values
    if bins is not None:
        binned = pd.cut(values, bins=bins)
    else:
        bins, binned = utils_merfish.binning(values, n)

    expr = pd.DataFrame(adata.layers[layer], columns=adata.var.index)
    expr['thebin'] = binned

    grouped = expr.groupby('thebin')
    return (
        grouped.mean(numeric_only=True),
        grouped.sem(numeric_only=True),
        grouped.std(numeric_only=True),
        expr['thebin'].value_counts(sort=False),
        binned,
        bins,
    )

In [None]:
def neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None):
    """Transfer labels from reference to query cells by k-nearest-neighbour vote.

    Parameters
    ----------
    k : int
        Number of reference neighbours per query cell.
    ref_emb, qry_emb : array-like, shape (n_ref, d) and (n_qry, d)
        Embeddings in a shared space.
    ref_lbl : array-like, shape (n_ref,)
        Reference labels; compared as strings.
    p_cutoff : float, default 0.5
        A gated prediction requires the winning label's neighbour fraction to
        be strictly greater than this.
    dist_cutoff : float or None
        If given, a gated prediction also requires the farthest of the k
        neighbours to be strictly closer than this.

    Returns
    -------
    max_pred : np.ndarray of str
        Majority label per query cell (ties go to the label that sorts first).
    gated_pred : np.ndarray of str
        ``max_pred`` with cells failing the cutoffs set to 'NA'.
    max_dist : np.ndarray of float
        Distance to the farthest of the k neighbours of each query cell.
    """
    # cast labels before np.unique so the neighbour-vote comparison below is
    # string-to-string (otherwise non-string labels never match)
    ref_lbl = np.asarray(ref_lbl).astype(str)
    unq_lbls = np.unique(ref_lbl)
    qry_n = len(qry_emb)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(ref_emb)
    dists, idx = neigh.kneighbors(qry_emb, k, return_distance=True)
    raw_pred = ref_lbl[idx]  # (n_qry, k) neighbour labels

    # fraction of the k neighbours carrying each label
    pabc = np.empty((qry_n, len(unq_lbls)))
    for i, lbl in enumerate(unq_lbls):
        pabc[:, i] = np.sum(raw_pred == lbl, axis=1) / k

    max_pred = unq_lbls[np.argmax(pabc, axis=1)]
    max_dist = np.max(dists, axis=1)

    # widen the string dtype first so 'NA' is not truncated (e.g. to 'N')
    # when every label is a single character
    gated_pred = max_pred.astype(np.result_type(max_pred.dtype, np.array('NA').dtype))
    gated_pred[~(np.max(pabc, axis=1) > p_cutoff)] = 'NA'
    if dist_cutoff is not None:
        gated_pred[~(max_dist < dist_cutoff)] = 'NA'

    return max_pred, gated_pred, max_dist


def neighbor_self_nonself(k, ref_emb, qry_emb):
    """For each query cell, the fraction of its k nearest neighbours that are query cells.

    Reference and query embeddings are pooled into one kNN index. A value
    near the query share of the pool indicates good mixing; values near 1
    indicate query cells that cluster among themselves.

    Parameters
    ----------
    k : int
        Number of neighbours.
    ref_emb, qry_emb : array-like, shape (n_ref, d) and (n_qry, d)

    Returns
    -------
    np.ndarray, shape (n_qry,)
        Fraction in [0, 1] of each query cell's neighbours that are query cells.
    """
    ref_n = len(ref_emb)
    qry_n = len(qry_emb)
    # 0 = reference, 1 = query, in the row order of the pooled matrix below
    is_qry = np.array([0] * ref_n + [1] * qry_n)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(np.vstack([ref_emb, qry_emb]))
    # NOTE: query points are part of the fitted set (distance 0 to
    # themselves), so each query cell normally counts itself as one neighbour
    idx = neigh.kneighbors(qry_emb, k, return_distance=False)

    return np.sum(is_qry[idx], axis=1) / k

In [None]:
from py_pcha import PCHA
def get_aa(X):
    """Fit three archetypes to X with PCHA after reseeding numpy's global RNG.

    Parameters
    ----------
    X : array-like
        Passed straight to PCHA. Callers in this notebook pass a 2D embedding
        transposed to (2, n_cells).

    Returns
    -------
    np.ndarray
        The archetype matrix returned by PCHA, with columns reordered by
        increasing value in the first row (x coordinate).
    """
    # fixed seed for repeatable fits (presumably PCHA's initialization draws
    # from numpy's global RNG)
    np.random.seed(0)
    archetypes, _S, _C, _sse, _varexpl = PCHA(X, noc=3, delta=0)
    archetypes = np.array(archetypes)
    by_x = np.argsort(archetypes[0])
    return archetypes[:, by_x].copy()

In [None]:
def add_triangle(XC, ax, zorder=0, vertices=False, label='', linecolor='gray', linewidth=1, **kwargs):
    """Draw the closed dashed outline through the columns of XC on `ax`.

    Parameters
    ----------
    XC : np.ndarray
        Row 0 = x coordinates and row 1 = y coordinates of the vertices
        (e.g. the archetypes returned by get_aa).
    ax : matplotlib Axes
    zorder, label, linecolor, linewidth
        Styling for the outline; `zorder` is also used for the vertex markers.
    vertices : bool, default False
        If True, also mark the first three vertices in colours C0, C1 and C2.
    **kwargs
        Forwarded to ``ax.scatter`` for the vertex markers.
    """
    # repeat the first vertex so the outline closes
    xs = XC[0].tolist() + [XC[0, 0]]
    ys = XC[1].tolist() + [XC[1, 0]]
    ax.plot(xs, ys, '--',
            color=linecolor, label=label, zorder=zorder, linewidth=linewidth, markersize=3)

    if not vertices:
        return
    for k in range(3):
        ax.scatter(XC[0, k], XC[1, k], color=f'C{k}', zorder=zorder, **kwargs)

In [None]:
def p_mark(p):
    """Map a p-value to a significance mark.

    Returns '***' for p < 0.001, '*' for 0.001 <= p < 0.05, and 'ns'
    otherwise (including p >= 0.05 and NaN). There is no '**' level.
    The chain of half-open intervals covers every input; the boundary
    values 0.05 and 0.001 previously left the result unassigned.
    """
    if p < 0.001:
        return '***'
    if p < 0.05:
        return '*'
    return 'ns'

# load data

In [None]:
outfigdir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/ms_reanalysis/250414"
# create the figure output directory in pure Python (works outside IPython and
# needs no shell interpolation); no-op if it already exists
os.makedirs(outfigdir, exist_ok=True)
fig_manager = powerplots.FigManager(outfigdir)

In [None]:
# seed numpy's global RNG so that later steps drawing from it (e.g. the
# DataFrame.sample shuffles used for plotting) are reproducible when the
# notebook is run top to bottom
np.random.seed(0)

### MERFISH genes

In [None]:
f = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/merfish_genes.txt" 
genes = np.loadtxt(f, dtype='str')

genesets, typegenes_df = merfish_genesets.get_all_genesets()
for key, item in genesets.items():
    print(key, len(item))
    
genes = genesets['allmerfish']
agenes = genesets['a']
bgenes = genesets['b']
cgenes = genesets['c']
iegs   = genesets['i']

abcgenes = np.hstack([agenes, bgenes, cgenes])
genes_noniegs = np.array([g for g in genes if g not in iegs])

agenes_idx = basicu.get_index_from_array(genes, agenes)
bgenes_idx = basicu.get_index_from_array(genes, bgenes)
cgenes_idx = basicu.get_index_from_array(genes, cgenes)
igenes_idx = basicu.get_index_from_array(genes, iegs)
len(abcgenes), len(genes_noniegs)

### new ABC genes

In [None]:
# f = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/DEG_l23abc_gene_list_250409.csv"
# df_genes_newabc = pd.read_csv(f)
# # df_genes_newabc = df_genes_newabc[df_genes_newabc['cond']=='P8NR']
# abcgenes = df_genes_newabc['gene'].unique()
# agenes = df_genes_newabc.loc[df_genes_newabc['archetype']=='A', 'gene'].unique()
# bgenes = df_genes_newabc.loc[df_genes_newabc['archetype']=='B', 'gene'].unique()
# cgenes = df_genes_newabc.loc[df_genes_newabc['archetype']=='C', 'gene'].unique()

# print(len(abcgenes), 
#       len(agenes)+len(bgenes)+len(cgenes),
#       len(agenes),len(bgenes),len(cgenes), 
#      )
# abcgenes = np.intersect1d(genes, abcgenes)
# agenes = np.intersect1d(genes, agenes)
# bgenes = np.intersect1d(genes, bgenes)
# cgenes = np.intersect1d(genes, cgenes)

# print(len(abcgenes), 
#       len(agenes)+len(bgenes)+len(cgenes),
#       len(agenes),len(bgenes),len(cgenes), 
#      )

# agenes_idx = basicu.get_index_from_array(genes, agenes)
# bgenes_idx = basicu.get_index_from_array(genes, bgenes)
# cgenes_idx = basicu.get_index_from_array(genes, cgenes)


### MERFISH cells (integrated L2/3)

In [None]:
ddir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized" 

# get MERFISH data (cells) from integrated V1L23Glut 
fin = os.path.join(ddir, 'P8NR_v1l23glut_rna_merfish_250411.h5ad')
adata_ = ad.read(fin, backed='r')
adata_ = adata_[adata_.obs['modality']=='merfish'] 
print(adata_.obs['gated_pred_subclass'].unique()) # should be L2/3 only

l23_cells = adata_.obs.index.values
l23_cells.shape

### MERFISH raw - all genes

In [None]:
# %%time
names = [
    'P8NRa_ant2', 
    'P8NRb_ant2',
    'P8NRc_ant2', 
    'P8NRd_ant2',
    
    'P8NRa_pos2', 
    'P8NRb_pos2',
    'P8NRc_pos2', 
    'P8NRd_pos2',
]


mean_total_rna_target = 250
adata_merged = []

for i, name in enumerate(names):
    j = i // 4
    i = i % 4
    
    adatasub = ad.read(os.path.join(ddir, f'{name}_l2_v1_250410.h5ad')) 
    print(name, len(adatasub))
    
    adatasub.obs.index = np.char.add(f'{name}', adatasub.obs.index.values)
    adatasub.obs['sample'] = name
    
    norm_cnts = adatasub.layers['norm']
    mean_per_batch = np.mean(norm_cnts.sum(axis=1))
    # mean_per_batch_noniegs = np.mean(adatasub[:,genes_noniegs].layers['norm'].sum(axis=1))
    
    # adatasub.layers['jnorm']  = norm_cnts*(mean_total_rna_target/mean_per_batch_noniegs)
    adatasub.layers['jnorm']  = norm_cnts*(mean_total_rna_target/mean_per_batch)
    adatasub.layers['ljnorm'] = np.log2(1+adatasub.layers['jnorm'])
    
    adatasub.obs['norm_transcript_count']  = adatasub.layers['norm'].sum(axis=1)
    adatasub.obs['jnorm_transcript_count'] = adatasub.layers['jnorm'].sum(axis=1)
    
    adatasub.obs['depth_show'] = -adatasub.obs['depth'].values - i*1300 # name
    adatasub.obs['width_show'] =  adatasub.obs['width'].values - np.min(adatasub.obs['width'].values) + j*2500   # name
    
    adata_merged.append(adatasub)
    
adata_merged = ad.concat(adata_merged)

### MERFISH raw - get high-quality L2/3 cells
- filter cells by anatomical features (depth)
- filter cells by transcript counts
- check cell density (ant vs pos difference: insignificant)

In [None]:
# filter by L2/3 label
adata_l23 = adata_merged[l23_cells].copy()

# by depth and by counts
conds = np.logical_and(
    adata_l23.obs['depth']    < 400,
    adata_l23.obs['transcript_count'] > 50,
)
adata_mer = adata_l23[conds].copy()
len(adata_l23), len(adata_mer)

In [None]:
# Side-by-side layout: shift each sample so its width starts at 0, then offset
# each sample by the cumulative width of the preceding ones plus a 100-unit gap.
width_min = adata_mer.obs.groupby('sample')['width'].min().reindex(names)
width_max = adata_mer.obs.groupby('sample')['width'].max().reindex(names)
width_rng = width_max - width_min 
# width_rng[:-1] is a positional slice (the index holds sample names)
width_cum = pd.Series(np.cumsum(np.hstack([0, width_rng[:-1]+100])), index=names)

adata_mer.obs['width_n0']    =  adata_mer.obs['width']    - width_min.reindex(adata_mer.obs['sample']).values   # width relative to the sample minimum
adata_mer.obs['width_show2'] =  adata_mer.obs['width_n0'] + width_cum.reindex(adata_mer.obs['sample']).values   # tiled x coordinate
adata_mer.obs['depth_show2'] = -adata_mer.obs['depth']   # negated so greater depth plots lower

In [None]:
# depth distribution of L2/3-labelled cells per sample (before the depth/count
# filter); ant samples in C1, pos samples in black
colors = ['C1', 'k']
for i, name in enumerate(names):
    j = i // 4  # 0 for the first four names (ant), 1 for the last four (pos)
    color = colors[j]
    
    adatasub = adata_l23[adata_l23.obs['sample']==name]
    sns.histplot(adatasub.obs['depth'].values, element='step', fill=False, color=color)

In [None]:
# cell density per sample: number of cells / width range spanned by the sample
ns = adata_mer.obs.groupby('sample').size()
ls = adata_mer.obs.groupby('sample')['width'].max() - adata_mer.obs.groupby('sample')['width'].min()
# groupby sorts sample names alphabetically (a_ant, a_pos, b_ant, ...), so the
# ant/pos split must be done by name, not by position
density = (ns / ls).reindex(names)
is_ant = density.index.str.contains('ant')
is_pos = density.index.str.contains('pos')
a, b = density[is_ant], density[is_pos]
t, p = ttest_ind(a, b)

plt.bar(np.arange(len(density)), density, color=['C1' if x else 'k' for x in is_ant])
plt.xticks(np.arange(len(density)), density.index, rotation=90)
plt.title(f'cell density, ant vs pos t-test: p={p:.3g}')

### L23 RNA raw (NR as the reference) - all genes 

In [None]:
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/cheng21_cell_scrna/organized/P28NR.h5ad'
adata_rna = sc.read(f)
# select L2/3, rename genes, preprocessing (ABC feature selection happens later)
adata_rna = adata_rna[adata_rna.obs['Subclass']=='L2/3'].copy()
adata_rna.var.index = merfish_genesets.rename_genes(adata_rna.var.index.values) # rename genes according to MERFISH
adata_rna = preprocessing(adata_rna)  # gene filter + 'norm' (CP10k) and 'lognorm' layers
adata_rna

### summarize things needed

In [None]:
# add cols
adata_rna.obs['modality'] = 'rna'
adata_mer.obs['modality'] = 'merfish'

# add zscores (per gene)
lognorm_rna = np.log10(1+np.array(adata_rna.layers['norm'].todense()))
adata_rna.layers['zscore'] = zscore(lognorm_rna, axis=0)

lognorm_mer = adata_mer.layers['ljnorm']  
adata_mer.layers['zscore'] = zscore(lognorm_mer, axis=0)

### spin off an independent NR instance
adata_mer_nr = adata_mer[adata_mer.obs['sample'].str.contains('NR')].copy() 
adata_rna_nr = adata_rna.copy()  # same for RNA - for future simplicity

adata_rna, adata_mer, adata_rna_nr, adata_mer_nr

# Harmonize using ABC genes - rather than overlapping HVGs at a higher level

In [None]:
# restrict both modalities to the ABC genes
adata_r    = adata_rna[:,abcgenes]
adata_m    = adata_mer[:, abcgenes]
adata_m_nr = adata_mer_nr[:, abcgenes]

# Joint PCA on per-gene z-scores. random_state is fixed (as for the MERFISH-only
# PCA below) so results don't depend on the global RNG state or cell run order.

# nrdr (MERFISH) & nr (RNA)
adata_merge = sc.concat([adata_r, adata_m], join='outer')
adata_merge.obsm['X_pca2'] = PCA(n_components=20, random_state=0).fit_transform(adata_merge.layers['zscore']) 

# nr (MERFISH) & nr (RNA)
adata_merge_nr = sc.concat([adata_r, adata_m_nr], join='outer')
adata_merge_nr.obsm['X_pca2'] = PCA(n_components=20, random_state=0).fit_transform(adata_merge_nr.layers['zscore']) 

In [None]:
# Harmony-correct the joint PCA across modalities; the result is read back
# from obsm['X_pca_harmony'] below
sc.external.pp.harmony_integrate(adata_merge, 'modality', basis='X_pca2', max_iter_harmony=20)

In [None]:
# same Harmony correction for the RNA + NR-MERFISH object
sc.external.pp.harmony_integrate(adata_merge_nr, 'modality', basis='X_pca2', max_iter_harmony=20)

In [None]:
# assign results `X_pca_harmony` back to the original ones `adata_mer` and `adata_rna`
# (selected by cell ID, so rows line up with each object's own cell order)

adata_rna.obsm['X_pca_harmony']    = adata_merge[adata_rna.obs.index].obsm['X_pca_harmony']
adata_mer.obsm['X_pca_harmony']    = adata_merge[adata_mer.obs.index].obsm['X_pca_harmony']

adata_rna_nr.obsm['X_pca_harmony'] = adata_merge_nr[adata_rna_nr.obs.index].obsm['X_pca_harmony']
adata_mer_nr.obsm['X_pca_harmony'] = adata_merge_nr[adata_mer_nr.obs.index].obsm['X_pca_harmony']

# MERFISH only analysis - NR

In [None]:
# MERFISH-only PCA on z-scored ABC genes (no cross-modality integration)
mat_nr = np.array(adata_mer_nr[:, abcgenes].layers['zscore'])
pcs_typegenes = PCA(n_components=5, random_state=0).fit_transform(mat_nr)

adata_mer_nr.obsm['pcs_typegenes'] = pcs_typegenes

# Downstream calculations

In [None]:
def wrap_label_transfer(ref, qry, key_emb, key_lbl, k=30):
    """Transfer ``ref.obs[key_lbl]`` onto `qry` by kNN vote in ``obsm[key_emb]``.

    Only the first two embedding dimensions are used. The un-gated majority
    prediction from ``neighbor_label_transfer`` is written in place into
    ``qry.obs[key_lbl]``; nothing is returned.
    """
    predicted, _gated, _max_dist = neighbor_label_transfer(
        k,
        ref.obsm[key_emb][:, :2],
        qry.obsm[key_emb][:, :2],
        ref.obs[key_lbl].values.astype(str),
        p_cutoff=0.5,
        dist_cutoff=None,
    )
    qry.obs[key_lbl] = predicted
    return

In [None]:
def get_abc_scores(adata, agenes, bgenes, cgenes, layer='ljnorm', vmin_p=40, vmax_p=95, eps=1e-5):
    """Score cells for the A/B/C gene programs and store them in ``obsm['size_freq_abc']``.

    For each gene set:
    1. Every gene is z-scored across the cells of `adata` (``adata.layers[layer]``).
    2. The z-scores are averaged per cell.
    3. The averaged score is rescaled so its `vmin_p`-th percentile maps to 0
       and its `vmax_p`-th percentile maps to 1, then clipped to [0, 1].

    The three rescaled scores are then split into a size (their sum, 0..3) and
    relative frequencies (score / (size + eps)). Percentiles and z-scores are
    computed within `adata`, so scores from different objects are not on a
    shared scale.

    Parameters
    ----------
    adata : AnnData
        Modified in place.
    agenes, bgenes, cgenes : sequence of str
        Gene names of the A, B and C programs.
    layer : str, default 'ljnorm'
        Layer holding log-normalised expression.
    vmin_p, vmax_p : float, default 40, 95
        Percentiles mapped to 0 and 1.
    eps : float, default 1e-5
        Keeps the frequency division finite for cells whose size is 0.

    Returns
    -------
    None
        ``adata.obsm['size_freq_abc']`` is set to an (n_cells, 4) array with
        columns [freq_a, freq_b, freq_c, size].
    """
    scaled = []
    for gene_set in (agenes, bgenes, cgenes):
        raw = zscore(adata[:, gene_set].layers[layer], axis=0).mean(axis=1)
        lo = np.percentile(raw, vmin_p)
        hi = np.percentile(raw, vmax_p)
        # NOTE: a degenerate score with hi == lo divides by zero here
        scaled.append(np.clip((raw - lo) / (hi - lo), 0, 1))

    # magnitude (0..3) vs composition (frequencies summing to ~1, or 0)
    size = scaled[0] + scaled[1] + scaled[2]
    freqs = [s / (size + eps) for s in scaled]

    adata.obsm['size_freq_abc'] = np.vstack(freqs + [size]).T
    return

def get_abc_stats(adata, samples):
    """Per-sample percentages of cells assigned to A, B or C by ABC score.

    Each cell is assigned to the archetype with the largest of the first three
    columns of ``adata.obsm['size_freq_abc']`` (see get_abc_scores). Cells
    whose three frequencies are all zero are counted as 'NA'.

    Parameters
    ----------
    adata : AnnData
        Must have ``obs['sample']`` and ``obsm['size_freq_abc']``.
    samples : iterable of str
        Sample names to summarise. The 'cond' column is 'ant' or 'pos',
        depending on which substring the name contains, or 'NA' if it
        contains neither.

    Returns
    -------
    pd.DataFrame
        Indexed by sample, with columns ['cond', 'L2/3_A', 'L2/3_B',
        'L2/3_C', 'NA'] (percent of the sample's cells; NaN for an empty
        sample).
    """
    rows = []
    for sample in samples:
        # decide cond per sample, so a value never leaks from the previous iteration
        if 'ant' in sample:
            cond = 'ant'
        elif 'pos' in sample:
            cond = 'pos'
        else:
            cond = 'NA'

        adatasub = adata[adata.obs['sample'] == sample]
        freqs = np.asarray(adatasub.obsm['size_freq_abc'])[:, :3]
        n = len(adatasub)
        if n == 0:
            rows.append([sample, cond, np.nan, np.nan, np.nan, np.nan])
            continue

        cond_na = freqs.sum(axis=1) == 0
        tn = np.sum(cond_na)

        # index of the largest frequency per cell; argsort(...)[:, -1] is kept
        # (rather than argmax) so tied cells resolve exactly as before
        rank = np.argsort(freqs[~cond_na], axis=1)[:, -1]
        ta, tb, tc = (np.sum(rank == i) for i in range(3))

        assert ta + tb + tc + tn == n
        rows.append([sample, cond, ta/n*100, tb/n*100, tc/n*100, tn/n*100])

    return pd.DataFrame(rows, columns=['sample', 'cond', 'L2/3_A', 'L2/3_B', 'L2/3_C', 'NA']).set_index('sample')

def get_abc_stats_typeready(adata, samples, sample_col, type_col):
    """Percent of cells of each type per sample, from existing type labels.

    Returns a samples x types DataFrame. Rows are reindexed to `samples`, and
    each row sums to 100 over the types observed in that sample; missing
    types are NaN. A 'cond' column is added: 'pos' if the sample name contains
    'pos', otherwise 'ant'.
    """
    counts = (
        adata.obs
        .groupby([sample_col, type_col])
        .size()
        .unstack()
        .reindex(samples)
    )
    row_totals = counts.sum(axis=1)
    percents = counts.divide(row_totals, axis=0) * 100
    percents['cond'] = np.where(percents.index.str.contains('pos'), 'pos', 'ant')
    return percents

In [None]:
# # label transfer from RNA data
# kNN (k=30) majority vote in the first two Harmony PCs; writes adata_mer_nr.obs['Type']
wrap_label_transfer(adata_rna_nr, adata_mer_nr, 'X_pca_harmony', 'Type', k=30)

# # label transfer from RNA data
# same for the RNA + all-MERFISH integration; writes adata_mer.obs['Type']
wrap_label_transfer(adata_rna, adata_mer, 'X_pca_harmony', 'Type', k=30) # note that adata_rna == adata_rna_nr

In [None]:
%%time

### 3 NR
# Fit 3-archetype simplices (get_aa / PCHA) to 2D embeddings; each X is
# transposed to (2, n_cells)
# RNA+MERFISH NR 
X = adata_merge_nr.obsm['X_pca_harmony'][:,:2].T
XC0 = get_aa(X)

# RNA NR
X = adata_rna_nr.obsm['X_pca_harmony'][:,:2].T
XC1 = get_aa(X)

# MERFISH NR
X = adata_mer_nr.obsm['X_pca_harmony'][:,:2].T
XC2 = get_aa(X)

# MERFISH NRDR
X = adata_mer.obsm['X_pca_harmony'][:,:2].T
XC = get_aa(X)

# MERFISH NR no integration
# NOTE: uses PC1 and PC3 of the MERFISH-only PCA (not PC1/PC2)
X = adata_mer_nr.obsm['pcs_typegenes'][:,[0,2]].T
XC_alone = get_aa(X)

In [None]:
# abc scores
# writes obsm['size_freq_abc'] = [freq_a, freq_b, freq_c, size]; scaling is
# computed within each object, so the two calls are not on a shared scale
get_abc_scores(adata_mer, agenes, bgenes, cgenes)
get_abc_scores(adata_mer_nr, agenes, bgenes, cgenes)

# score based (ABC) assignment
typefrq_abc    = get_abc_stats(adata_mer, names)
# NOTE(review): adata_mer_nr holds all samples, but only names[:4] (ant) are summarised here
typefrq_abc_nr = get_abc_stats(adata_mer_nr, names[:4])

In [None]:
# label-transfer-based (PC1/PC2) assignment 
# percent of each transferred 'Type' per sample
typefrq_lbt = get_abc_stats_typeready(adata_mer, names, 'sample', 'Type')
typefrq_lbt_nr = get_abc_stats_typeready(adata_mer_nr, names[:4], 'sample', 'Type')

typefrq_lbt

In [None]:
# compare two results?
# what to expect?

# Viz set up

In [None]:
# Copy the first two Harmony PCs into .obs ('hpc1', 'hpc2') so seaborn can
# plot them straight from the obs tables.
for _adata in (adata_rna, adata_mer, adata_merge, adata_rna_nr, adata_mer_nr, adata_merge_nr):
    for _pc, _col in enumerate(['hpc1', 'hpc2']):
        _adata.obs[_col] = np.array(_adata.obsm['X_pca_harmony'][:, _pc])

# aliases used by the plotting cells below
adata_plot0, adata_plot1 = adata_rna, adata_mer
adata_plot00, adata_plot01 = adata_rna_nr, adata_mer_nr

# colours for transferred L2/3 type labels ('NA' = unassigned)
clsts_palette2 = {'L2/3_A': 'C0', 'L2/3_B': 'C1', 'L2/3_C': 'C2', 'NA': 'gray'}

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# ABC map: black -> archetype colour (A = C0, B = C1, C = C2)
colors_a, colors_b, colors_c = ([(0.0, 'black'), (1.0, c)] for c in ('C0', 'C1', 'C2'))
cmap_a, cmap_b, cmap_c = (
    LinearSegmentedColormap.from_list(f'cmap_{s}', stops)
    for s, stops in zip('abc', (colors_a, colors_b, colors_c))
)

# NRDR map: white -> C1 (NR), white -> black (DR), diverging C1 -> white -> black
colors_nr = [(0.0, 'white'), (1.0, 'C1')]
colors_dr = [(0.0, 'white'), (1.0, 'black')]
colors_nrdr = [(0.0, 'C1'), (0.5, 'white'), (1.0, 'black')]
cmap_nr, cmap_dr, cmap_nrdr = (
    LinearSegmentedColormap.from_list(f'cmap_{s}', stops)
    for s, stops in (('nr', colors_nr), ('dr', colors_dr), ('nrdr', colors_nrdr))
)

# MERFISH NR alone

In [None]:
# QC check: does each of the first three MERFISH-only PCs track a cell-quality
# metric? One figure per PC; Spearman r in each panel title.
metrics = ['gnnum', 'transcript_count', 'jnorm_transcript_count',] #  'depth',]
for j in range(3):
    
    y = adata_mer_nr.obsm['pcs_typegenes'][:,j]
    fig, axs = plt.subplots(1,3,figsize=(3*4,1*4), sharey=True)
    for i, metric in enumerate(metrics):
        ax = axs[i]
        x = adata_mer_nr.obs[metric]
        r, _ = spearmanr(x,y)
        ax.scatter(x, y, s=1, rasterized=True)
        ax.set_ylabel(f'PC{j+1}')
        ax.set_title(f'r={r:.2f}')
        ax.set_xscale('log')
        ax.set_xlabel(metric)
        sns.despine(ax=ax)
    fig.tight_layout()
    
    fig_manager.savefig(fig)
    plt.show()

In [None]:
# |Spearman r| between each QC metric (one panel each) and the first 5 MERFISH-only PCs
metrics = ['gnnum', 'transcript_count', 'jnorm_transcript_count',] #  'depth',]
rss = [[],[],[]]
for i, metric in enumerate(metrics):
    rs = []
    x = adata_mer_nr.obs[metric]
    for j in range(5):
        y = adata_mer_nr.obsm['pcs_typegenes'][:,j]
        r, _ = spearmanr(x,y)
        rs.append(r)
    rss[i] = rs
        

fig, axs = plt.subplots(1,3,figsize=(3*4,1*4))
for i in range(3):
    axs[i].bar(np.arange(5), np.abs(rss[i]))  # x = PC index (0-based)

In [None]:
# mean log expression of each ABC gene set on MERFISH-only PC1 vs PC2
xi, yi = 0, 1

gns = [agenes, bgenes, cgenes]
titles = ['A genes', 'B genes', 'C genes']

fig, axs = plt.subplots(1,3,figsize=(4*5,1*3), sharex=True, sharey=True)
for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
    g = adata_mer_nr[:,gn].layers['ljnorm'].mean(axis=1)
    x = adata_mer_nr.obsm['pcs_typegenes'][:,xi]
    y = adata_mer_nr.obsm['pcs_typegenes'][:,yi]

    p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, cmap='coolwarm', vmin_p=5, vmax_p=95)
    colorbar = plt.colorbar(p, aspect=5, shrink=0.3)

    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)
    
    ax.set_xlabel(f'PC{xi+1}')
    ax.set_ylabel(f'PC{yi+1}')

    ax.set_title(title)


fig_manager.savefig(fig)
plt.show()

In [None]:
# same as above on PC1 vs PC3, with the archetype triangle fitted in this
# plane (XC_alone)
xi, yi = 0, 2

gns = [agenes, bgenes, cgenes]
titles = ['A genes', 'B genes', 'C genes']

fig, axs = plt.subplots(1,3,figsize=(4*5,1*3), sharex=True, sharey=True)
for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
    g = adata_mer_nr[:,gn].layers['ljnorm'].mean(axis=1)
    x = adata_mer_nr.obsm['pcs_typegenes'][:,xi]
    y = adata_mer_nr.obsm['pcs_typegenes'][:,yi]

    p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, cmap='coolwarm', vmin_p=5, vmax_p=95)
    colorbar = plt.colorbar(p, aspect=5, shrink=0.3)

    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)
    
    ax.set_xlabel(f'PC{xi+1}')
    ax.set_ylabel(f'PC{yi+1}')
    
    
    add_triangle(XC_alone, ax)

    ax.set_title(title)


fig_manager.savefig(fig)
plt.show()

# MERFISH alone - triangle stability

In [None]:
def shuff_genes(mat, seed=0):
    """Permute every column of `mat` independently along axis 0.

    For a cell x gene matrix, this shuffles each gene's values across cells.
    Each gene keeps its marginal distribution, but gene-gene covariation is
    destroyed. Deterministic for a given `seed`; `mat` is not modified.
    """
    return np.random.default_rng(seed=seed).permuted(mat, axis=0)

In [None]:
def downsample_X(X, p=0.8):
    """Keep each column (cell) of a feature x cell matrix with probability `p`.

    Draws from numpy's global RNG (``np.random.rand``), so the result depends
    on the current global seed.
    """
    keep = np.random.rand(X.shape[1]) < p
    return X[:, keep]

In [None]:
def aa_inference(X):
    """Fit three archetypes with PCHA; columns ordered by the first row (x).

    Same as get_aa but without reseeding numpy's global RNG, so the result
    may depend on the current RNG state.
    """
    archetypes, _S, _C, _sse, _varexpl = PCHA(X, noc=3, delta=0)
    archetypes = np.array(archetypes)
    return archetypes[:, np.argsort(archetypes[0])]

# viz NR only

In [None]:
# joint RNA + MERFISH Harmony embedding, coloured by modality (left) and Type (right)
# NOTE(review): adata_merge_nr was concatenated before label transfer, so its
# MERFISH cells may have no 'Type' value here
fig, axs = plt.subplots(1,2, figsize=(2*6,1*5))
ax = axs[0]
sns.scatterplot(data=adata_merge_nr.obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', hue='modality', s=5, edgecolor='none', 
                ax=ax,
               )
add_triangle(XC0, ax)
# add_triangle(XC1, ax)
# add_triangle(XC2, ax)
ax.set_aspect('equal')

ax = axs[1]
sns.scatterplot(data=adata_merge_nr.obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', hue='Type', s=5, edgecolor='none', 
                ax=ax,
               )
add_triangle(XC0, ax)
# add_triangle(XC1, ax)
# add_triangle(XC2, ax)
ax.set_aspect('equal')
plt.show()

In [None]:
# RNA (left) and MERFISH (right) cells on the shared NR Harmony embedding,
# coloured by Type, with the joint archetype triangle XC0 and its vertices
fig, axs = plt.subplots(1,2,figsize=(2*6,1*5), sharex=True, sharey=True)
ax = axs[0]
ax.set_title('rna')
sns.scatterplot(data=adata_plot00.obs.sample(frac=1, replace=False), 
                ax=ax, x='hpc1', y='hpc2', 
                hue='Type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none', rasterized=True)
ax.set_aspect('equal')
add_triangle(XC0, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
sns.despine(ax=ax)
ax.grid(False)

ax = axs[1]
ax.set_title('merfish')
sns.scatterplot(data=adata_plot01.obs.sample(frac=1, replace=False), 
                ax=ax, x='hpc1', y='hpc2', 
                hue='Type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none', rasterized=True)
ax.set_aspect('equal')
add_triangle(XC0, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
sns.despine(ax=ax)
ax.grid(False)
fig_manager.savefig(fig)
plt.show()

In [None]:
# expression of selected genes on the MERFISH NR Harmony embedding
gns = ['Cdh13', 'Sorcs3', 'Trpc6', 'Chrm2'] 
n = len(gns)
titles = gns
dfplot = adata_plot01.obs.sample(frac=1, replace=False)

x = dfplot['hpc1'].values
y = dfplot['hpc2'].values

fig, axs = plt.subplots(1,n,figsize=(n*5,1*3), sharex=True, sharey=True)
for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
    g = adata_plot01[dfplot.index, gn].layers['ljnorm'].reshape(-1,)
    # colour range: 0th to 95th percentile of this gene
    vmin = np.percentile(g,  0)
    vmax = np.percentile(g, 95)
    
    # p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, vmin=vmin, vmax=vmax, cmap='rocket_r')
    
    # pass cells in increasing expression order (presumably so high expressers are drawn last, on top)
    sorting = np.argsort(g)
    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=5, vmin=vmin, vmax=vmax, cmap='rocket_r')
    colorbar = plt.colorbar(p, aspect=5, shrink=0.3)

    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)

    ax.set_title(title)

    add_triangle(XC0, ax)
fig_manager.savefig(fig)
plt.show()

In [None]:
# depth distribution of each L2/3 type in MERFISH NR cells: first by the
# transferred label, then by ABC frequency (> 0.6). Each histogram is
# normalised to percent within its own group.
bins = np.linspace(100,400,30)

fig, ax = plt.subplots(1,1,figsize=(1*5,1*3)) 
adatasub = adata_mer_nr
depth = -adatasub.obs['depth_show2']
lbls = adatasub.obs['Type']

p = sns.histplot(depth[lbls=='L2/3_A'], bins=bins, ax=ax, element='step', fill=False, stat='percent') 
p = sns.histplot(depth[lbls=='L2/3_B'], bins=bins, ax=ax, element='step', fill=False, stat='percent') 
p = sns.histplot(depth[lbls=='L2/3_C'], bins=bins, ax=ax, element='step', fill=False, stat='percent') 

# Show ticks but no grid
ax.axis('on')
ax.grid(False)  # Turn off grid lines
sns.despine(ax=ax)
ax.tick_params(axis='both', which='both', bottom=True, left=True)
ax.set_title('type distribution (label transfer)', y=1.1)
fig_manager.savefig(fig)
plt.show()

fig, ax = plt.subplots(1,1,figsize=(1*5,1*3))
adatasub = adata_mer_nr
depth = -adatasub.obs['depth_show2']

freq_a = adatasub.obsm['size_freq_abc'][:,0]
freq_b = adatasub.obsm['size_freq_abc'][:,1]
freq_c = adatasub.obsm['size_freq_abc'][:,2]

p = sns.histplot(depth[freq_a > 0.6], bins=bins, ax=ax, element='step', fill=False, stat='percent')
p = sns.histplot(depth[freq_b > 0.6], bins=bins, ax=ax, element='step', fill=False, stat='percent') 
p = sns.histplot(depth[freq_c > 0.6], bins=bins, ax=ax, element='step', fill=False, stat='percent') 

# Show ticks but no grid
ax.axis('on')
ax.grid(False)  # Turn off grid lines
sns.despine(ax=ax)
ax.tick_params(axis='both', which='both', bottom=True, left=True)
ax.set_title('type distribution (abc)', y=1.1)
fig_manager.savefig(fig)
plt.show()

# viz - NR vs DR

In [None]:
# MERFISH cells split by position: all samples, ant samples only, pos samples only
# (views of adata_mer, aligned with `conditions`)
adatas = [
    adata_mer,
    adata_mer[adata_mer.obs['sample'].str.contains('ant')],
    adata_mer[adata_mer.obs['sample'].str.contains('pos')],
]
conditions = ['combined', 'ant', 'pos']

In [None]:
# transferred Type on the RNA + all-MERFISH Harmony embedding, ant vs pos,
# with the archetype triangle XC fitted on all MERFISH cells
fig, axs = plt.subplots(1,2,figsize=(2*6,1*5), sharex=True, sharey=True)
ax = axs[0]
ax.set_title('MERFISH - ant')
sns.scatterplot(data=adata_plot1[adata_plot1.obs['sample'].str.contains('ant')].obs.sample(frac=1, replace=False), 
                ax=ax, x='hpc1', y='hpc2', 
                hue='Type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none')
ax.set_aspect('equal')
add_triangle(XC, ax)

ax = axs[1]
ax.set_title('MERFISH - pos')
sns.scatterplot(data=adata_plot1[adata_plot1.obs['sample'].str.contains('pos')].obs.sample(frac=1, replace=False), 
                ax=ax, x='hpc1', y='hpc2', 
                hue='Type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none')
ax.set_aspect('equal')
add_triangle(XC, ax)
plt.show()

In [None]:
# QC / position metrics of all MERFISH cells on the Harmony embedding
metrics = ['gncov', 'gnnum', 'depth', 'width_show2']

fig, axs = plt.subplots(1,4,figsize=(4*5,1*4))
for metric, ax in zip(metrics, axs):
    # g = np.log10(1+adata.obs[metric])
    g = adata_mer.obs[metric]
    x = adata_mer.obsm['X_pca_harmony'][:,0]
    y = adata_mer.obsm['X_pca_harmony'][:,1]
    utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, )
    add_triangle(XC, ax)
    ax.set_title(metric)

In [None]:
# rows = combined / ant / pos MERFISH cells; columns = mean log expression of
# each ABC gene set on the Harmony PCs
# NOTE: axss has 4 columns but there are only 3 gene sets, so the last column stays empty
gns = [agenes, bgenes, cgenes,]
titles = ['A genes', 'B genes', 'C genes', ]

fig, axss = plt.subplots(3,4,figsize=(4*5,3*3), sharex=True, sharey=True)
for i, (axs, adatasub, condition) in enumerate(zip(axss, adatas, conditions)):
    condition = conditions[i]
    for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
        g = adatasub[:,gn].layers['ljnorm'].mean(axis=1)
        x = adatasub.obsm['X_pca_harmony'][:,0]
        y = adatasub.obsm['X_pca_harmony'][:,1]
        
        # consistent over
        # conditions: colour limits come from all MERFISH cells, so rows share one scale
        g0 = adata_mer[:,gn].layers['ljnorm'].mean(axis=1)
        vmin = np.percentile(g0,  5)
        vmax = np.percentile(g0, 95)
            
        p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, cmap='coolwarm', vmin=vmin, vmax=vmax)
        colorbar = plt.colorbar(p, aspect=5, shrink=0.3)
        
        # Show ticks but no grid
        ax.set_aspect('equal')
        ax.axis('on')
        ax.grid(False)  # Turn off grid lines
        sns.despine(ax=ax)
        ax.tick_params(axis='both', which='both', bottom=True, left=True)
        
        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel(condition, rotation=0, loc='top')
            
        # add the triangle
        add_triangle(XC, ax)
fig_manager.savefig(fig)
plt.show()

In [None]:
# Single-gene log-normalized ('ljnorm') expression on the Harmony-PCA embedding:
# one row per condition, one column per gene. Cells are drawn in ascending
# expression order, so the highest expressers end up on top.
gns = ['Cdh13', 'Trpc6', 'Sorcs3', 'Chrm2', 'Fos']
titles = gns
n = len(gns)

# Colour limits (minimum to 95th percentile over all cells in adata_mer) depend only on
# the gene, so they are computed once and shared by every condition row.
clims = []
for gn in gns:
    g0 = adata_mer[:,gn].layers['ljnorm'].mean(axis=1)
    clims.append((np.percentile(g0,  0), np.percentile(g0, 95)))

fig, axss = plt.subplots(3,n,figsize=(n*5,3*3), sharex=True, sharey=True)
for i, (axs, adatasub, condition) in enumerate(zip(axss, adatas, conditions)):
    x = adatasub.obsm['X_pca_harmony'][:,0]
    y = adatasub.obsm['X_pca_harmony'][:,1]
    for j, (ax, gn, title, (vmin, vmax)) in enumerate(zip(axs, gns, titles, clims)):
        g = adatasub[:,gn].layers['ljnorm'].reshape(-1,)
        sorting = np.argsort(g)

        p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=5, vmin=vmin, vmax=vmax)
        # Attach the colorbar to this panel explicitly. Without `ax`, older matplotlib
        # takes the space from the current axes instead.
        fig.colorbar(p, ax=ax, aspect=5, shrink=0.3)

        # Show ticks but no grid
        ax.set_aspect('equal')
        ax.axis('on')
        ax.grid(False)  # Turn off grid lines
        sns.despine(ax=ax)
        ax.tick_params(axis='both', which='both', bottom=True, left=True)

        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel(condition, rotation=0, loc='top')

        add_triangle(XC, ax)
fig_manager.savefig(fig)
plt.show()

In [None]:
# Single-gene log-normalized ('ljnorm') expression on the Harmony-PCA embedding:
# one row per condition, one column per gene. Cells are drawn in ascending
# expression order, so the highest expressers end up on top.
gns = [
    'Matn2', 'Otof',
    'Hkdc1', 'Lynx1', 'Stard8', 'Lamp5', 'Rgs8', 'Igfn1']
titles = gns
n = len(gns)

# Colour limits (minimum to 95th percentile over all cells in adata_mer) depend only on
# the gene, so they are computed once and shared by every condition row.
clims = []
for gn in gns:
    g0 = adata_mer[:,gn].layers['ljnorm'].mean(axis=1)
    clims.append((np.percentile(g0,  0), np.percentile(g0, 95)))

fig, axss = plt.subplots(3,n,figsize=(n*5,3*3), sharex=True, sharey=True)
for i, (axs, adatasub, condition) in enumerate(zip(axss, adatas, conditions)):
    x = adatasub.obsm['X_pca_harmony'][:,0]
    y = adatasub.obsm['X_pca_harmony'][:,1]
    for j, (ax, gn, title, (vmin, vmax)) in enumerate(zip(axs, gns, titles, clims)):
        g = adatasub[:,gn].layers['ljnorm'].reshape(-1,)
        sorting = np.argsort(g)

        p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=5, vmin=vmin, vmax=vmax)
        # Attach the colorbar to this panel explicitly. Without `ax`, older matplotlib
        # takes the space from the current axes instead.
        fig.colorbar(p, ax=ax, aspect=5, shrink=0.3)

        # Show ticks but no grid
        ax.set_aspect('equal')
        ax.axis('on')
        ax.grid(False)  # Turn off grid lines
        sns.despine(ax=ax)
        ax.tick_params(axis='both', which='both', bottom=True, left=True)

        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel(condition, rotation=0, loc='top')

        add_triangle(XC, ax)
fig_manager.savefig(fig)
plt.show()

In [None]:
# 2-D occupancy histograms of the Harmony-PCA embedding per condition (percent of cells
# per unit bin), plus a difference map between two conditions in the last panel.
xmin, xmax = -12, 12
ymin, ymax = -7, 7 

# unit-width bin edges covering the plotting window
bins_x = np.linspace(xmin, xmax, 1*(xmax-xmin)+1)
bins_y = np.linspace(ymin, ymax, 1*(ymax-ymin)+1)

hists = []
fig, axs = plt.subplots(1,4,figsize=(4*5,1*4), sharex=True, sharey=True)
# zip stops at its shortest input, so only the leading panels (one per condition) are
# filled here; axs[3] holds the difference map drawn below
for ax, adatasub, cond, _cmap in zip(axs, adatas, conditions, ['gray_r', cmap_nr, cmap_dr]):
    x = adatasub.obsm['X_pca_harmony'][:,0]
    y = adatasub.obsm['X_pca_harmony'][:,1]
    sns.histplot(x=x, y=y, ax=ax, bins=(bins_x, bins_y), 
                 cmap=_cmap,
                 stat='percent', vmin=0, vmax=2, 
                 cbar=True, cbar_kws=dict(shrink=0.4, ticks=[0,2]))

    # same bins as the histplot, as percent of all of this condition's cells;
    # kept for the difference map
    hist, _, _ = np.histogram2d(x, y, bins=(bins_x, bins_y))
    hists.append(hist/len(x)*100)
    ax.set_title(cond)
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)

    # add the triangle
    add_triangle(XC, ax, zorder=2)

# Difference map hists[2] - hists[1], titled 'pos-ant'. This presumes conditions[2] is
# 'pos' and conditions[1] is 'ant' (confirm).
ax = axs[3] 
ax.set_title('pos-ant')
# histogram2d output is indexed (x bin, y bin); transpose so that rows run along y for imshow
im = ax.imshow(
    (hists[2]-hists[1]).T,
    origin='lower',
    extent=(xmin, xmax, ymin, ymax),
    cmap=cmap_nrdr,
    vmax=2, vmin=-2)
ax.set_aspect('equal')
ax.grid(False)
fig.colorbar(im, ax=ax, shrink=0.4, ticks=[-2,0,2])
sns.despine(ax=ax)

# add the triangle
add_triangle(XC, ax, zorder=2)
fig_manager.savefig(fig)

plt.show()

In [None]:
# Embedding (obs columns hpc1/hpc2) coloured by the 'nrdr' flag: 1 if the sample name
# contains 'DR', else 0. Shuffled scatter on the left, KDE contours on the right.
dfshow = adata_mer.obs.copy()
dfshow['nrdr'] = dfshow['sample'].str.contains('DR').astype(int)
palette = {0: 'C1', 1: 'black'}

fig, (ax_scatter, ax_kde) = plt.subplots(1, 2, figsize=(2*5, 1*4), sharex=True, sharey=True)

add_triangle(XC, ax_scatter)
sns.scatterplot(data=dfshow.sample(frac=1), x='hpc1', y='hpc2', hue='nrdr',
                s=3, edgecolor='none', palette=palette, ax=ax_scatter)
ax_scatter.legend(bbox_to_anchor=(1, 1))

add_triangle(XC, ax_kde)
sns.kdeplot(data=dfshow, x='hpc1', y='hpc2', hue='nrdr', palette=palette, legend=False, ax=ax_kde)

for ax in (ax_scatter, ax_kde):
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)

fig.tight_layout()
plt.show()

In [None]:
# One panel per sample in `names`: scatter plus KDE contours of the embedding
# (hpc1/hpc2), coloured by the nrdr flag.
fig, axs = plt.subplots(2, 4, figsize=(4*4, 2*3), sharex=True, sharey=True)
for ax, sample in zip(axs.flat, names):
    sample_df = dfshow[dfshow['sample'] == sample]
    shared_kws = dict(data=sample_df, x='hpc1', y='hpc2', hue='nrdr',
                      palette=palette, legend=False, ax=ax)

    add_triangle(XC, ax)
    sns.scatterplot(s=5, edgecolor='none', **shared_kws)
    sns.kdeplot(**shared_kws)

    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(sample)
    ax.grid(False)

fig.tight_layout()
plt.show()

In [None]:
# Per-sample 2-D occupancy histograms (percent of the sample's cells per unit bin) of the
# Harmony-PCA embedding, on the bins_x / bins_y grid defined in the per-condition
# histogram cell.
# The colormap is named here rather than read from `_cmap`, the loop variable left over
# from that cell (whose final value is cmap_dr when there are three conditions). This
# cell therefore no longer depends on another cell's loop state.
sample_cmap = cmap_dr

fig, axs = plt.subplots(2,4,figsize=(4*4,2*3), sharex=True, sharey=True) 
for sample, ax in zip(names, axs.flat):
    add_triangle(XC, ax)
    adatasub = adata_mer[adata_mer.obs['sample']==sample]
    x = adatasub.obsm['X_pca_harmony'][:,0]
    y = adatasub.obsm['X_pca_harmony'][:,1]
    sns.histplot(x=x, y=y, ax=ax, bins=(bins_x, bins_y), 
                 cmap=sample_cmap,
                 stat='percent', vmin=0, vmax=2, 
                 cbar=True, cbar_kws=dict(shrink=0.4, ticks=[0,2]))
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(sample)
    ax.grid(False)

fig.tight_layout()
plt.show()

# ABC scores - expression level distributions

In [None]:
# ECDF of the per-cell mean log-normalized expression of each gene group, with one curve
# per sample coloured by the 'ant' / 'pos' tag in the sample name.
fig, axs = plt.subplots(1,3, figsize=(4*3,4), sharex=False, sharey=True)
for ax, genegroup, title in zip(axs, 
                                [agenes, bgenes, cgenes, ], 
                                ['A genes', 'B genes', 'C genes', ],
                               ):
    for sample in names:
        if 'ant' in sample:
            color = 'C1'
        elif 'pos' in sample:
            color = 'black'
        else:
            # Without this branch `color` would be unbound (NameError) on the first
            # sample, or would silently reuse the previous sample's colour.
            color = 'gray'

        scores_ = adata_mer[adata_mer.obs['sample']==sample][:,genegroup].layers['ljnorm'].mean(axis=1)
        sns.ecdfplot(scores_, ax=ax, color=color)

    ax.set_ylabel('Cumulative proportion\nof cells')
    ax.set_xlabel('Mean log norm expr.')
    ax.set_title(title)
    sns.despine(ax=ax)
    ax.grid(False)

fig.tight_layout()
fig_manager.savefig(fig)
plt.show()

In [None]:
# ABC scores on the Harmony-PCA embedding per condition. The A/B/C frequencies in
# obsm['size_freq_abc'] go through cmap_a / cmap_b / cmap_c and their RGB channels are summed.
fig, axs = plt.subplots(1,3,figsize=(3*5,1*3), sharex=True, sharey=True)
for ax, adatasub, condition in zip(axs, adatas, conditions):
    x = adatasub.obsm['X_pca_harmony'][:,0]
    y = adatasub.obsm['X_pca_harmony'][:,1]

    freq_a = adatasub.obsm['size_freq_abc'][:,0]
    freq_b = adatasub.obsm['size_freq_abc'][:,1]
    freq_c = adatasub.obsm['size_freq_abc'][:,2]

    # Additive blending. The sum of three colormap outputs can exceed 1, which
    # matplotlib rejects as an RGB colour, so clip to [0, 1].
    additive = np.clip((cmap_a(freq_a)+cmap_b(freq_b)+cmap_c(freq_c))[:,:3], 0, 1)
    ax.scatter(x, y, c=additive, s=5, edgecolor='none')

    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)
    ax.set_title(condition)

    # add the triangle
    add_triangle(XC, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')

fig_manager.savefig(fig)
plt.show()

In [None]:

# Same additive ABC colouring, showing only cells whose largest A/B/C frequency exceeds
# the threshold.
min_freq = 0.6

fig, axs = plt.subplots(1,3,figsize=(3*5,1*3), sharex=True, sharey=True)
for ax, adatasub, condition in zip(axs, adatas, conditions):
    x = adatasub.obsm['X_pca_harmony'][:,0]
    y = adatasub.obsm['X_pca_harmony'][:,1]

    freq_a = adatasub.obsm['size_freq_abc'][:,0]
    freq_b = adatasub.obsm['size_freq_abc'][:,1]
    freq_c = adatasub.obsm['size_freq_abc'][:,2]

    # Additive blending. The sum of three colormap outputs can exceed 1, which
    # matplotlib rejects as an RGB colour, so clip to [0, 1].
    additive = np.clip((cmap_a(freq_a)+cmap_b(freq_b)+cmap_c(freq_c))[:,:3], 0, 1)

    confident = np.max([freq_a, freq_b, freq_c], axis=0) > min_freq
    ax.scatter(x[confident], y[confident], c=additive[confident], s=5, edgecolor='none')

    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)
    ax.set_title(condition)

    # add the triangle
    add_triangle(XC, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')

fig_manager.savefig(fig)
plt.show()

In [None]:
# Depth distributions (cortical depth = -depth_show2) of L2/3 A/B/C cells per condition.
# First figure: types from label transfer (obs 'Type'). Second figure: ABC-score
# assignment (obsm 'size_freq_abc' above a threshold). stat='percent' normalizes each
# subset separately.
bins = np.linspace(100,400,30)
type_labels = ['L2/3_A', 'L2/3_B', 'L2/3_C']
type_colors = ['C0', 'C1', 'C2']
abc_min_freq = 0.6
hist_kws = dict(bins=bins, element='step', fill=False, stat='percent')

def _finish_depth_ax(ax, condition):
    """Show ticks but no grid, and title the panel with its condition."""
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)
    ax.set_title(condition)

# label transfer
fig, axs = plt.subplots(1,3,figsize=(3*5,1*3), sharex=True) #, sharey=True)
for ax, adatasub, condition in zip(axs, adatas, conditions):
    depth = -adatasub.obs['depth_show2']
    lbls = adatasub.obs['Type']
    for type_label, color in zip(type_labels, type_colors):
        sns.histplot(depth[lbls==type_label], ax=ax, color=color, label=type_label, **hist_kws)
    _finish_depth_ax(ax, condition)
# each panel has three unlabeled-by-default curves; one legend identifies them
axs[-1].legend(bbox_to_anchor=(1,1), frameon=False)
fig.suptitle('type distribution (label transfer)', y=1.1)
fig_manager.savefig(fig)
plt.show()

# ABC-score assignment
fig, axs = plt.subplots(1,3,figsize=(3*5,1*3), sharex=True) #, sharey=True)
for ax, adatasub, condition in zip(axs, adatas, conditions):
    depth = -adatasub.obs['depth_show2']
    freqs = adatasub.obsm['size_freq_abc']
    for k, (type_label, color) in enumerate(zip(type_labels, type_colors)):
        sns.histplot(depth[freqs[:,k] > abc_min_freq], ax=ax, color=color, label=type_label, **hist_kws)
    _finish_depth_ax(ax, condition)
axs[-1].legend(bbox_to_anchor=(1,1), frameon=False)
fig.suptitle('type distribution (abc)', y=1.1)
fig_manager.savefig(fig)
plt.show()

In [None]:
# All cells, shuffled, on the flattened cortex coordinates (width_show2, depth_show2),
# coloured by 'Type' through clsts_palette2; width_cum labels are written at y = 0.
df_plot = adata_mer.obs.sample(frac=1, replace=False)
point_colors = [clsts_palette2[cell_type] for cell_type in df_plot['Type']]

fig, ax = plt.subplots(1, 1, figsize=(1*25, 1))
for lbl, coord in width_cum.items():
    ax.text(coord, 0, lbl, fontsize=12)

ax.scatter(df_plot['width_show2'], df_plot['depth_show2'], c=point_colors, s=5, edgecolor='none')
ax.set_aspect('equal')
ax.axis('off')

plt.show()

In [None]:
# ABC-score colours (additive blend of cmap_a/b/c applied to obsm['size_freq_abc']) on
# the flattened cortex coordinates.
# Plain arrays are needed because they are indexed positionally with argsort output below.
# Series[int array] is label-based on an integer index and a deprecated positional
# fallback otherwise.
x =  adata_mer.obs['width_show2'].to_numpy()
y =  adata_mer.obs['depth_show2'].to_numpy()

freq_a = adata_mer.obsm['size_freq_abc'][:,0]
freq_b = adata_mer.obsm['size_freq_abc'][:,1]
freq_c = adata_mer.obsm['size_freq_abc'][:,2]

# Additive blending. The sum of three colormap outputs can exceed 1, which matplotlib
# rejects as an RGB colour, so clip to [0, 1].
additive = np.clip((cmap_a(freq_a)+cmap_b(freq_b)+cmap_c(freq_c))[:,:3], 0, 1)

fig, ax = plt.subplots(1,1,figsize=(1*25,1))
for lbl, coord in width_cum.items():
    ax.text(coord, 0, lbl, fontsize=12)

# draw the dimmest cells first so that strongly coloured cells end up on top
sorting = np.argsort(np.max(additive, axis=1))
ax.scatter(x[sorting], y[sorting], c=additive[sorting], s=5, edgecolor='none') 

ax.set_aspect('equal')
ax.axis('off')

plt.show()

# score based (ABC) assignment

In [None]:
def plot_typefreq(res, include_na=True):
    """Bar + swarm plots of L2/3 type frequencies, 'ant' vs 'pos', with BH-corrected t-tests.

    Parameters
    ----------
    res : pandas.DataFrame
        Must have a 'cond' column (values 'ant' / 'pos') and one column per type label:
        'L2/3_A', 'L2/3_B', 'L2/3_C', plus 'NA' when include_na is True. Each row is one
        swarm point (presumably one sample).
    include_na : bool
        Also plot the 'NA' column (in gray).

    Returns
    -------
    matplotlib.figure.Figure
        One panel per type. Each panel is annotated with the mark produced by `p_mark` for
        the Benjamini-Hochberg q-value of a two-sample t-test (ant vs pos).
    """
    type_colors = {'L2/3_A': 'C0', 'L2/3_B': 'C1', 'L2/3_C': 'C2'}
    if include_na:
        type_colors['NA'] = 'gray'
    n = len(type_colors)

    fig, axs = plt.subplots(1, n, figsize=(2*n,4))

    # draw each panel and collect the raw t-test p-value (ant vs pos)
    pvals = []
    for ax, (col, color) in zip(axs, type_colors.items()):
        sns.barplot(data=res, x='cond', y=col, ax=ax, color=color, capsize=0.3, errwidth=1)
        sns.swarmplot(data=res, x='cond', y=col, color='k', ax=ax, )
        ax.set_title(f'type {col}', y=1.1)
        sns.despine(ax=ax)
        ax.grid(False)
        ax.set_ylabel('')

        ant_vals = res.loc[res['cond']=='ant', col]
        pos_vals = res.loc[res['cond']=='pos', col]
        pvals.append(ttest_ind(ant_vals, pos_vals)[1])

    # FDR-correct across the panels, then draw a bracket over the two bars with the mark
    qvals = multipletests(pvals, method='fdr_bh')[1]
    for ax, col, q in zip(axs, type_colors, qvals):
        bar_y, bar_h = res[col].max() + 2, 2
        ax.plot([0, 0, 1, 1], [bar_y, bar_y+bar_h, bar_y+bar_h, bar_y], lw=1.5, c='k')
        ax.text(0.5, bar_y+bar_h, p_mark(q), ha='center', va='bottom', color='k')

    axs[0].set_ylabel('L2/3 cells (%)')
    fig.tight_layout()
    return fig

In [None]:
# inspect the type-frequency table that is plotted with plot_typefreq in the next cell
typefrq_abc

In [None]:
# type frequencies from the score-based (ABC) assignment, including the unassigned 'NA' column
fig = plot_typefreq(typefrq_abc, include_na=True)
fig_manager.savefig(fig)

In [None]:
# Type frequencies without an 'NA' column. 'lbt' presumably means label transfer
# (confirm against the cell that builds typefrq_lbt).
fig = plot_typefreq(typefrq_lbt, include_na=False)
fig_manager.savefig(fig)

# visualize FISH

In [None]:
# inspect adata_mer before the spatial expression plots below
adata_mer

In [None]:
# Single-gene 'ljnorm' expression on the flattened cortex (width_show2 vs depth_show2),
# one row per gene; width_cum labels are drawn on the first row only.
gns = ['Cdh13', 'Trpc6', 'Sorcs3', 'Chrm2', 'Fos'] 
# Plain arrays are needed because they are indexed positionally with argsort output below.
# Series[int array] is label-based on an integer index and a deprecated positional
# fallback otherwise.
x =  adata_mer.obs['width_show2'].to_numpy()
y =  adata_mer.obs['depth_show2'].to_numpy()
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*25,n*1))
for i, (ax, gn) in enumerate(zip(axs, gns)):
    if i == 0:
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)

    g = adata_mer[:,gn].layers['ljnorm'].reshape(-1,)
    # colour range: minimum to 99th percentile
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g,  0)
    # draw low expressers first so high expressers end up on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax,  x[sorting],  y[sorting],  gexp=g[sorting], s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

fig_manager.savefig(fig)
plt.show()
    

In [None]:
# Single-gene 'ljnorm' expression on the flattened cortex (width_show2 vs depth_show2),
# one row per gene; width_cum labels are drawn on the first row only.
gns = [
    'Matn2', 'Otof', 
    # 'Srf', 'Jund',
    'Hkdc1', 'Lynx1', 'Stard8', 'Lamp5', 'Rgs8', 'Igfn1']
# Plain arrays are needed because they are indexed positionally with argsort output below.
# Series[int array] is label-based on an integer index and a deprecated positional
# fallback otherwise.
x =  adata_mer.obs['width_show2'].to_numpy()
y =  adata_mer.obs['depth_show2'].to_numpy()
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*25,n*1))
for i, (ax, gn) in enumerate(zip(axs, gns)):
    if i == 0:
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)

    g = adata_mer[:,gn].layers['ljnorm'].reshape(-1,)
    # colour range: minimum to 99th percentile
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g,  0)
    # draw low expressers first so high expressers end up on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax,  x[sorting],  y[sorting],  gexp=g[sorting], s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

fig_manager.savefig(fig)
plt.show()
    

In [None]:
# Single-gene 'ljnorm' expression on the flattened cortex, restricted to samples whose
# name contains 'NR'; one row per gene.
gns = ['Cdh13', 'Trpc6', 'Chrm2',] 
adata_plot = adata_mer[adata_mer.obs['sample'].str.contains('NR')]

# Plain arrays are needed because they are indexed positionally with argsort output below.
# Series[int array] is label-based on an integer index and a deprecated positional
# fallback otherwise.
x =  adata_plot.obs['width_show2'].to_numpy()
y =  adata_plot.obs['depth_show2'].to_numpy()
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*10,n*1))
for i, (ax, gn) in enumerate(zip(axs, gns)):
    g = adata_plot[:,gn].layers['ljnorm'].reshape(-1,)
    # colour range: minimum to 99th percentile
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g,  0)
    # draw low expressers first so high expressers end up on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax,  x[sorting],  y[sorting],  gexp=g[sorting], s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

fig_manager.savefig(fig)
plt.show()
    

In [None]:
# Mean 'ljnorm' expression per cortical-depth bin (depth = -depth_show2, 11 bins over
# 100-400) for NR samples. Each curve is scaled by its own maximum.
gns = ['Cdh13', 'Trpc6', 'Chrm2',] 
bins = np.linspace(100,400,12)
adata_plot = adata_mer[adata_mer.obs['sample'].str.contains('NR')] #=='P28NRa_pos']

depth = -adata_plot.obs['depth_show2'].to_numpy()

fig, ax = plt.subplots(1,1,figsize=(1*5,4*1))
for gn in gns:
    g = adata_plot[:,gn].layers['ljnorm'].reshape(-1,)

    binned = pd.DataFrame({'depth': depth, 'expr': np.asarray(g, dtype=float)})
    binned['b'] = pd.cut(binned['depth'], bins=bins)
    # observed=False keeps empty bins (as NaN rows) and is explicit about the
    # categorical-grouping default, which pandas is changing
    bin_means = binned.groupby('b', observed=False).mean()

    _x = bin_means['depth'].values   # mean depth of the cells in each bin
    _y = bin_means['expr'].values
    # nanmax: an empty bin yields NaN, and np.max would turn the whole curve into NaN
    ax.plot(_x, _y/np.nanmax(_y), label=gn)

ax.legend(bbox_to_anchor=(1,1))
ax.grid(False)
sns.despine(ax=ax)
ax.set_ylabel('norm. expression')
ax.set_xlabel('cortical depth')
ax.set_xlim([100,400])
ax.set_ylim([0,1])

fig_manager.savefig(fig)
plt.show()
    

In [None]:
# Per-cell mean z-score of each gene group (A/B/C) on the flattened cortex, one row per
# group; width_cum labels are drawn on the first row only.
gns = [agenes, bgenes, cgenes,] 
# Plain arrays are needed because they are indexed positionally with argsort output below.
# Series[int array] is label-based on an integer index and a deprecated positional
# fallback otherwise.
x =  adata_mer.obs['width_show2'].to_numpy()
y =  adata_mer.obs['depth_show2'].to_numpy()
titles = ['A genes', 'B genes', 'C genes', ]
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*25,n*1))
for i, (ax, gn, title) in enumerate(zip(axs, gns, titles)):
    if i == 0:
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)

    g = adata_mer[:,gn].layers['zscore'].mean(axis=1)
    # draw low scores first so high scores end up on top
    sorting = np.argsort(g)

    # colour range: minimum to 95th percentile
    vmin = np.percentile(g,  0)
    vmax = np.percentile(g, 95)
    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], 
                                    s=5, title='', vmin=vmin, vmax=vmax, cmap='coolwarm')
    ax.set_title(title, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

fig_manager.savefig(fig)
plt.show()
    

In [None]:
# Per-cell mean z-score of each gene group (A/B/C) on the flattened cortex, restricted
# to samples whose name contains 'NR'; one row per group.
gns = [agenes, bgenes, cgenes] 
adata_plot = adata_mer[adata_mer.obs['sample'].str.contains('NR')]

# Plain arrays are needed because they are indexed positionally with argsort output below.
# Series[int array] is label-based on an integer index and a deprecated positional
# fallback otherwise.
x =  adata_plot.obs['width_show2'].to_numpy()
y =  adata_plot.obs['depth_show2'].to_numpy()
titles = ['A genes', 'B genes', 'C genes', ]
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*10,n*1))
for i, (ax, gn, title) in enumerate(zip(axs, gns, titles)):
    g = adata_plot[:,gn].layers['zscore'].mean(axis=1)
    # draw low scores first so high scores end up on top
    sorting = np.argsort(g)

    # colour range: minimum to 95th percentile
    vmin = np.percentile(g,  0)
    vmax = np.percentile(g, 95)
    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], 
                                    s=5, title='', vmin=vmin, vmax=vmax, cmap='coolwarm')
    ax.set_title(title, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

fig_manager.savefig(fig)
plt.show()
    

# FISH stats

In [None]:
# Per-sample depth-binned statistics of the 'ljnorm' layer: 8 equal bins over 0-400,
# binned on obs 'depth'.
# NOTE(review): this rebinds `stats`, which the imports cell brings in as
# `from scipy import stats`. After this cell runs, any use of scipy.stats through `stats.`
# fails. Renaming requires updating every cell that reads `stats[name]`.
stats = {}
bins = np.linspace(0, 400, 4*2+1)

for name in names:
    adatasub = adata_mer[adata_mer.obs['sample']==name]# v1l23_data[name]
    # binning_pipe2 is defined elsewhere. The first three outputs are presumably per-bin
    # mean / sem / std tables (bins x genes); confirm against its definition.
    lnorm_mean, lnorm_sem, lnorm_std, n, d, db = binning_pipe2(adatasub, 'depth', 'ljnorm', bins=bins)
    stats[name] = (lnorm_mean, lnorm_sem, lnorm_std, n, d, db)
# NOTE(review): `d` is left over from the last loop iteration, so this shows only the last sample in `names`
d.value_counts()

In [None]:
# number of genes in the A, B and C groups (space-separated)
print(*(len(idx) for idx in (agenes_idx, bgenes_idx, cgenes_idx)))

In [None]:
# Baseline: for each gene, the mean over depth bins, averaged across the four reference
# NR samples. It is subtracted from every sample's depth profile before averaging within
# each gene group.
group_idx = (agenes_idx, bgenes_idx, cgenes_idx)
nr_reference = ['P8NRa_ant2', 'P8NRb_ant2', 'P8NRc_ant2', 'P8NRd_ant2']

# per_ref[s][k]: per-gene mean across depth bins for reference sample s, gene group k
per_ref = [
    [np.mean(stats[ref][0].iloc[:, idx], axis=0) for idx in group_idx]
    for ref in nr_reference
]
base_a0, base_b0, base_c0 = [
    np.mean([ref_bases[k] for ref_bases in per_ref], axis=0)
    for k in range(len(group_idx))
]

# Per sample and group: depth profile of the mean over the group's genes of
# (expression - baseline), plus the mean of the per-gene SEMs.
means = {}
sems = {}
for name in names:
    bin_mean, bin_sem = stats[name][0], stats[name][1]
    means[name] = [
        np.mean(bin_mean.iloc[:, idx] - base, axis=1)
        for idx, base in zip(group_idx, (base_a0, base_b0, base_c0))
    ]
    sems[name] = [np.mean(bin_sem.iloc[:, idx], axis=1) for idx in group_idx]
    

In [None]:
# centre of each depth bin (average of its left and right edges)
midpoints = (bins[:-1] + bins[1:]) / 2
midpoints

In [None]:
# Stack the baseline-subtracted profiles: (sample, gene group, depth bin).
samp_gene_dpth_mat = np.array([np.array(means[name]) for name in names]) 
print(samp_gene_dpth_mat.shape) # sample, gene group, depth

# The first four entries of `names` are treated as NR and the rest as DR
# (this assumes that ordering of `names`; TODO confirm).
n_nr = 4
nr_mat = samp_gene_dpth_mat[:n_nr]
dr_mat = samp_gene_dpth_mat[n_nr:]

# Mean and standard error of the mean across samples. The SEM uses the sample standard
# deviation (ddof=1, matching ttest_ind) and each group's own sample count. np.std
# defaults to ddof=0, which underestimates the SEM for groups this small.
nr_mean = np.mean(nr_mat, axis=0) # gene group, depth
nr_sem  = np.std(nr_mat, axis=0, ddof=1)/np.sqrt(nr_mat.shape[0]) # gene group, depth

dr_mean = np.mean(dr_mat, axis=0) # gene group, depth
dr_sem  = np.std(dr_mat, axis=0, ddof=1)/np.sqrt(dr_mat.shape[0]) # gene group, depth
nr_mean.shape, dr_mean.shape

In [None]:
# t-test between NR and DR for each gene group and each location
ts, ps = ttest_ind(nr_mat, dr_mat)
rejs, qs, _, _ = multipletests(np.nan_to_num(ps, nan=1).reshape(-1,), alpha=0.05, method='fdr_bh')
qs = qs.reshape(ps.shape)
nrdr_mean = np.stack([nr_mean, dr_mean], axis=2).mean(axis=2)

In [None]:
# Per-sample depth profiles of the A/B/C gene-group means (baseline-subtracted, from
# `means`) with +/- the mean per-gene SEM (`sems`), one panel per sample.
group_idx = [agenes_idx, bgenes_idx, cgenes_idx]
# Legend labels take the group sizes from the index lists rather than hard-coded counts,
# which go stale.
gnames = [f'{grp} genes (n={len(idx)})' for grp, idx in zip('ABC', group_idx)]
group_colors = ['C0', 'C1', 'C2']
linestyle = '-'

fig, axs = plt.subplots(2, 4, figsize=(5*4,4*2), sharex=True, sharey=True)
x = midpoints
for ax, name in zip(axs.flat, names):
    for gname, color, gmean, gsem in zip(gnames, group_colors, means[name], sems[name]):
        ax.plot(x, gmean, label=gname, color=color, linestyle=linestyle)
        ax.fill_between(x, gmean-gsem, gmean+gsem, color=color, alpha=0.1, edgecolor='none')
    ax.axhline(color='lightgray', linestyle='dotted', zorder=1)

    sns.despine(ax=ax)
    ax.set_xticks([0, 100, 200, 300, 400])
    ax.set_xlim(left=50, right=400)
    ax.set_ylim([-0.4, 0.4])
    ax.grid(False)
    ax.set_title(name)
axs.flat[0].set_ylabel('mean (expr. +/- sem)')
# the group lines are labelled, so show a legend once
axs.flat[0].legend(frameon=False, fontsize=8)

fig.subplots_adjust(wspace=0.1)
fig_manager.savefig(fig)
plt.show()

In [None]:
# NR vs DR: mean +/- SEM (across samples) of the A/B/C gene-group depth profiles,
# one panel per condition.
titles = ['NR', 'DR']
data_mean = [nr_mean, dr_mean]
data_sem = [nr_sem, dr_sem]
gnames = ['A genes', 'B genes', 'C genes']
colors = ['C0', 'C1', 'C2']

fig, axs = plt.subplots(1, 2, figsize=(5*2,4), sharex=True, sharey=True)
for ax, title, cond_mean, cond_sem in zip(axs, titles, data_mean, data_sem):
    ax.axhline(color='lightgray', linestyle='dotted', zorder=1)
    for row, (gname, color) in enumerate(zip(gnames, colors)):
        centre, spread = cond_mean[row], cond_sem[row]
        ax.plot(midpoints, centre, label=gname, color=color, linestyle='-')
        ax.fill_between(midpoints, centre-spread, centre+spread, color=color, alpha=0.1, edgecolor='none')

    sns.despine(ax=ax)
    ax.set_xticks([0, 100, 200, 300, 400])
    ax.set_xlim(left=50, right=400)
    ax.grid(False)
    ax.set_title(title)
    ax.set_xlabel('upper->lower cortical depth')

axs[0].set_ylabel('mean (expr. +/- sem)')
fig.subplots_adjust(wspace=0.1)
fig_manager.savefig(fig)
plt.show()

In [None]:
# Per gene group: NR (solid) vs DR (dashed) depth profiles, mean +/- SEM across samples.
# Depth bins where the BH q-value of the NR-vs-DR t-test is < 0.05 ('*') or < 0.001
# ('***') are marked at the average NR/DR level.
linestyles = ['-', '--']
data_mean = [nr_mean, dr_mean]
data_sem = [nr_sem, dr_sem]
gnames = ['A genes', 'B genes', 'C genes']
colors = ['C0', 'C1', 'C2']
labels = ['NR', 'DR']
sigs = qs
allmeans = nrdr_mean

fig, axs = plt.subplots(1, 3, figsize=(5*3,4), sharex=True, sharey=True)
for i, (ax, gname, color) in enumerate(zip(axs, gnames, colors)):
    ax.axhline(color='lightgray', linestyle='dotted', zorder=1)
    # label each line by its condition; the panel title already names the gene group
    for cond_mean, cond_sem, label, linestyle in zip(data_mean, data_sem, labels, linestyles):
        ax.plot(midpoints, cond_mean[i], label=label, color=color, linestyle=linestyle, marker='o', markersize=5)
        ax.fill_between(midpoints, cond_mean[i]-cond_sem[i], cond_mean[i]+cond_sem[i], color=color, alpha=0.1, edgecolor='none')

    for _x, _y, _sig in zip(midpoints, allmeans[i], sigs[i]):
        if _sig < 1e-3:
            mark = "***"
        elif _sig < 5e-2:
            mark = "*"
        else:
            continue
        ax.text(_x, _y, mark, ha='left', va='center', fontsize=12, rotation=90)
        ax.vlines(_x, _y-0.02, _y+0.02, color='k', linewidth=0.5)

    sns.despine(ax=ax)
    ax.set_xticks([0, 100, 200, 300, 400])
    ax.set_xlim(left=50, right=400)
    ax.grid(False)
    ax.set_title(gname)
    ax.set_xlabel('upper->lower cortical depth')

axs[0].set_ylabel('mean (expr. +/- sem)')
# solid vs dashed is otherwise unexplained in the figure
axs[0].legend(frameon=False)
fig.subplots_adjust(wspace=0.1)
fig_manager.savefig(fig)
plt.show()