In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
from scipy import stats
import os

from scipy import spatial
from scipy import sparse
from scipy.interpolate import CubicSpline
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from umap import UMAP
from scipy.stats import zscore

import json

In [None]:
import importlib

from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
importlib.reload(utils_merfish)
from utils_merfish import rot2d, st_scatter, st_scatter_ax, plot_cluster, binning
from utils_merfish import RefLineSegs

import merfish_datasets
import merfish_genesets
importlib.reload(merfish_datasets)
importlib.reload(merfish_genesets)
from merfish_datasets import merfish_datasets
from merfish_datasets import merfish_datasets_params

from scroutines import basicu

In [None]:
def get_qc_metrics(df):
    """Summarize the three QC columns of a per-cell table.

    Parameters
    ----------
    df : DataFrame with columns 'volume', 'gncov' and 'gnnum'.

    Returns
    -------
    dict keyed by column name; each value is a tuple
    ``(display name, values, median, bins)`` where ``bins`` holds 50 evenly
    spaced edges running from 0 to 10x the median of that column.
    """
    display_names = {
        'volume': 'cell volume',
        'gncov': 'num transcripts',
        'gnnum': 'num genes',
    }

    def _summarize(column, label):
        values = df[column].values
        median = np.median(values)
        # histogram range is tied to the median so all metrics share a scale
        return (label, values, median, np.linspace(0, 10*median, 50))

    return {column: _summarize(column, label)
            for column, label in display_names.items()}

def get_norm_counts(adata, scaling=500):
    """Volume-normalize counts and store them in ``adata.layers['norm']``.

    Each row of ``adata.X`` is divided by that cell's ``obs['volume']`` and
    multiplied by ``scaling`` (i.e. counts per ``scaling`` units of volume).

    Returns the normalized matrix that was stored in the layer.
    """
    # column vector so the division broadcasts across genes
    per_cell_volume = adata.obs['volume'].values.reshape(-1, 1)
    scaled = adata.X / per_cell_volume * scaling
    adata.layers['norm'] = scaled
    return scaled

In [None]:
def preprocessing(adata):
    """Filter genes and add CP10k and log10(CP10k+1) layers.

    Genes with a positive count in more than 10 cells are kept. Counts are
    then scaled per cell by ``obs['n_counts']`` to counts per 10k ('norm'),
    and log10(1 + CP10k) is stored as 'lognorm'.

    Parameters
    ----------
    adata : AnnData-like object with counts in ``.X`` and a per-cell total in
        ``obs['n_counts']``. ``.X`` is expected to be a scipy sparse matrix:
        the log transform is applied to the stored entries via ``.data``.

    Returns
    -------
    A new, independent object restricted to the kept genes, carrying the
    'norm' and 'lognorm' layers. The input is not modified.
    """
    # filter genes: expressed in more than 10 cells
    n_cells_expressing = np.ravel((adata.X > 0).sum(axis=0))
    # Explicit copy: layers are written below, and writing onto a view would
    # otherwise depend on an implicit view-to-copy conversion.
    adata_sub = adata[:, n_cells_expressing > 10].copy()

    # counts
    x = adata_sub.X
    cov = adata_sub.obs['n_counts'].values

    # CP10k
    xn = (sparse.diags(1/cov).dot(x))*1e4

    # log10(CP10k+1); only stored entries change, zeros stay zero
    xln = xn.copy()
    xln.data = np.log10(xln.data+1)

    adata_sub.layers['norm'] = xn
    adata_sub.layers['lognorm'] = xln

    return adata_sub

In [None]:
def get_hvgs(adata, layer, nbin=20, qth=0.3):
    """Pick the most variable genes within bins of mean expression.

    Genes are split into ``nbin`` quantile bins by their mean in
    ``adata.layers[layer]``; within each bin the genes with the largest
    variance/mean ratio are kept.

    Parameters
    ----------
    adata : object exposing ``layers[layer]`` (sparse; ``.data`` is squared to
        get the second moment) and ``var.index`` with the gene names.
    layer : name of the layer to use.
    nbin : number of mean-expression quantile bins.
    qth : fraction used to size the per-bin selection,
        ``int(qth * n_genes / nbin)`` genes per bin.

    Returns
    -------
    Array of selected gene names, in the order they appear in ``var.index``.
    """
    mat = adata.layers[layer]

    # per-gene mean
    gene_mean = np.ravel(mat.mean(axis=0))

    # per-gene variance as E[x^2] - E[x]^2
    squared = mat.copy()
    squared.data = np.power(squared.data, 2)
    gene_var = np.ravel(squared.mean(axis=0)) - gene_mean**2

    # one row per gene, labelled with its mean-expression quantile bin
    per_gene = pd.DataFrame({
        'lbl': pd.qcut(gene_mean, nbin, labels=np.arange(nbin)),
        'mean': gene_mean,
        'var': gene_var,
        'ratio': gene_var/gene_mean,
    })

    # top genes by variance/mean ratio inside every bin
    n_per_bin = int(qth*(len(gene_mean)/nbin))
    top = per_gene.groupby('lbl')['ratio'].nlargest(n_per_bin)
    selected = np.sort(top.index.get_level_values(1).values)

    assert np.all(selected != -1)

    return adata.var.index.values[selected]

In [None]:
def add_triangle(XC, ax, zorder=0, vertices=False, **kwargs):
    """Draw a closed dashed gray outline through the columns of ``XC``.

    Parameters
    ----------
    XC : 2-row array; row 0 holds x coordinates and row 1 y coordinates.
    ax : axes-like object providing ``plot`` and ``scatter``.
    zorder : z-order applied to the outline and the vertex markers.
    vertices : if True, also mark the first three columns of ``XC`` in colors
        'C0', 'C1' and 'C2'.
    **kwargs : forwarded to ``ax.scatter`` only (not to the outline).
    """
    xs, ys = XC[0], XC[1]

    # close the polygon by repeating the first vertex at the end
    ax.plot(xs.tolist()+[xs[0]], ys.tolist()+[ys[0]], '--',
            color='gray', label='', zorder=zorder, linewidth=1, markersize=3)

    if not vertices:
        return

    for k in range(3):
        ax.scatter(XC[0,k], XC[1,k], color=f'C{k}', zorder=zorder, **kwargs)

In [None]:
def binning_pipe(adata, n=20, layer='lnorm', bin_type='depth_bin'):
    """Bin cells along depth and width and summarize expression per bin.

    Both ``obs['depth']`` and ``obs['width']`` are binned with
    ``utils_merfish.binning(values, n)``; the per-gene statistics are grouped
    by whichever of the two is named by ``bin_type``.

    Parameters
    ----------
    adata : object with ``obs['depth']``, ``obs['width']``, ``layers[layer]``
        and ``var.index``.
    n : passed to ``utils_merfish.binning`` (presumably the number of bins --
        confirm against that helper).
    layer : layer holding the expression matrix (must be accepted by
        ``pd.DataFrame``).
    bin_type : 'depth_bin' or 'width_bin'.

    Returns
    -------
    (mean, sem, std, n_per_bin, depth_binned, width_binned, depth_bins,
    width_bins)
    """
    assert bin_type in ['depth_bin', 'width_bin']

    obs = adata.obs
    depth_bins, depth_binned = utils_merfish.binning(obs['depth'].values, n)
    width_bins, width_binned = utils_merfish.binning(obs['width'].values, n)

    # cells x genes table with both bin assignments attached
    table = pd.DataFrame(adata.layers[layer], columns=adata.var.index)
    table['depth_bin'] = depth_binned
    table['width_bin'] = width_binned

    grouped = table.groupby(bin_type)
    norm_mean = grouped.mean(numeric_only=True)
    norm_sem = grouped.sem(numeric_only=True)
    norm_std = grouped.std(numeric_only=True)
    norm_n = table[bin_type].value_counts(sort=False)

    return norm_mean, norm_sem, norm_std, norm_n, depth_binned, width_binned, depth_bins, width_bins

def binning_pipe2(adata, col_to_bin, layer, bins=None, n=20):
    """Bin cells along one ``obs`` column and summarize expression per bin.

    Parameters
    ----------
    adata : object with ``obs[col_to_bin]``, ``layers[layer]`` and
        ``var.index``.
    col_to_bin : name of the ``obs`` column to bin.
    layer : layer holding the expression matrix.
    bins : explicit bin edges handed to ``pd.cut``; when None the bins come
        from ``utils_merfish.binning(values, n)``.
    n : passed to ``utils_merfish.binning`` when ``bins`` is None.

    Returns
    -------
    (mean, sem, std, n_per_bin, binned, bins)
    """
    values = adata.obs[col_to_bin].values
    if bins is not None:
        binned = pd.cut(values, bins=bins)
    else:
        bins, binned = utils_merfish.binning(values, n)

    # cells x genes table with the bin assignment attached
    table = pd.DataFrame(adata.layers[layer], columns=adata.var.index)
    table['thebin'] = binned

    grouped = table.groupby('thebin')
    norm_mean = grouped.mean(numeric_only=True)
    norm_sem = grouped.sem(numeric_only=True)
    norm_std = grouped.std(numeric_only=True)
    norm_n = table['thebin'].value_counts(sort=False)

    return norm_mean, norm_sem, norm_std, norm_n, binned, bins

In [None]:
def neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None):
    """Transfer labels from reference to query points by k-nearest-neighbor vote.

    For every query point the ``k`` nearest reference points are found and the
    fraction of neighbors carrying each label is computed.

    Parameters
    ----------
    k : number of reference neighbors per query point.
    ref_emb : reference embedding, one row per reference point.
    qry_emb : query embedding, one row per query point.
    ref_lbl : labels of the reference points (any array-like; compared as
        strings).
    p_cutoff : a prediction is kept in ``gated_pred`` only when the winning
        fraction is strictly greater than this.
    dist_cutoff : if given, a prediction is kept only when the distance to the
        farthest of the ``k`` neighbors is strictly below this.

    Returns
    -------
    max_pred : label with the largest neighbor fraction for each query point
        (ties go to the first label in sorted order, via ``np.argmax``).
    gated_pred : ``max_pred`` with rejected predictions replaced by 'NA'.
    max_dist : distance from each query point to the farthest of its ``k``
        neighbors.
    """
    # Work on string labels throughout, so the votes below are compared
    # like-for-like and list/Series/categorical inputs can be fancy-indexed.
    ref_lbl = np.asarray(ref_lbl).astype(str)
    unq_lbls = np.unique(ref_lbl) # e.g. array(['L2/3_A', 'L2/3_B', 'L2/3_C'])
    qry_n = len(qry_emb)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(ref_emb)
    dists, idx = neigh.kneighbors(qry_emb, k, return_distance=True)
    max_dist = np.max(dists, axis=1)

    # labels of the k neighbors, one row per query point
    raw_pred = ref_lbl[idx]

    # fraction of neighbors carrying each label
    pabc = np.empty((qry_n, len(unq_lbls)))
    for i, lbl in enumerate(unq_lbls):
        pabc[:,i] = np.sum(raw_pred==lbl, axis=1)/k

    # winning label
    max_pred = unq_lbls[np.argmax(pabc, axis=1)]

    # gate on vote fraction and, optionally, neighbor distance
    keep = np.max(pabc, axis=1) > p_cutoff
    if dist_cutoff is not None:
        keep &= max_dist < dist_cutoff
    # np.where widens the string dtype as needed, so 'NA' is never truncated
    # when the labels are shorter than two characters.
    gated_pred = np.where(keep, max_pred, 'NA')

    return max_pred, gated_pred, max_dist


def neighbor_self_nonself(k, ref_emb, qry_emb):
    """Fraction of each query point's neighbors that are query points.

    Reference and query embeddings are pooled, the ``k`` nearest neighbors of
    every query point are looked up in the pool, and the share of those
    neighbors that come from the query set is returned. Values near 1 mean the
    query point sits among other query points; values near 0 mean it is
    surrounded by reference points.

    Query points are themselves part of the pool, so each one is normally
    among its own ``k`` neighbors.

    Parameters
    ----------
    k : number of neighbors per query point.
    ref_emb : reference embedding, one row per reference point.
    qry_emb : query embedding, one row per query point.

    Returns
    -------
    Array with one fraction in [0, 1] per query point.
    """
    ref_n = len(ref_emb)
    qry_n = len(qry_emb)
    # 0 = reference row, 1 = query row, in the stacking order used below
    is_qry = np.array([0]*ref_n+[1]*qry_n)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(np.vstack([ref_emb, qry_emb]))
    idx = neigh.kneighbors(qry_emb, k, return_distance=False)

    return np.sum(is_qry[idx], axis=1)/k

# load data and construct adata 

In [None]:
np.random.seed(0)

In [None]:
# Directory holding the organized MERFISH .h5ad files, and the merged
# RNA + MERFISH file read in the next cell.
# NOTE(review): absolute, user-specific path -- the notebook only runs where
# this location exists.
ddir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized" 
fin = os.path.join(ddir, 'P8NR_v1l23glut_rna_merfish_250411.h5ad')
# NOTE(review): this listing matches files tagged 240723, while the files read
# in this notebook are tagged 250411 / 250410 -- confirm the pattern is current.
!ls $ddir/*l2*240723.h5ad 

In [None]:
# `ad.read` is a deprecated alias; `ad.read_h5ad` is the supported reader.
adata_merge = ad.read_h5ad(fin)

# split the merged object by modality (these are views into `adata_merge`)
adata_snrnasq = adata_merge[adata_merge.obs['modality']=='rna']
adata_merfish = adata_merge[adata_merge.obs['modality']=='merfish'] 

adata_merge, adata_snrnasq, adata_merfish

In [None]:
%%time
names = [
    'P8NRa_ant2', 
    'P8NRb_ant2',
    'P8NRc_ant2', 
    'P8NRd_ant2',
    
    'P8NRa_pos2', 
    'P8NRb_pos2',
    'P8NRc_pos2', 
    'P8NRd_pos2',
]

alldata = {}
for name in names:
    adatasub = ad.read(os.path.join(ddir, f'{name}_l2_v1_250410.h5ad')) 
    adatasub.obs.index = np.char.add(f'{name}', adatasub.obs.index.values)
    alldata[name] = adatasub 
    print(name, len(alldata[name]))
    
genes = adatasub.var.index.values
genes.shape

In [None]:
genesets, df = merfish_genesets.get_all_genesets()
for key, item in genesets.items():
    print(key, len(item))

In [None]:
agenes = genesets['a']
bgenes = genesets['b']
cgenes = genesets['c']
iegs   = genesets['i']
abcgenes = np.hstack([agenes, bgenes, cgenes])
genes_noniegs = np.array([g for g in genes if g not in iegs])

marker_genes = [
       'Ptprn', 'Slc17a7', 'Gad1', 'Fos', 
       
       'Gfap', 'Slc6a13', 'Slc47a1',
       'Grin2c', 'Aqp4', 'Rfx4', 'Sox21', 'Slc1a3',
       
       'Sox10', 'Pdgfra', 'Mog',
       
       'Pecam1', 'Cd34' , 'Tnfrsf12a', 'Sema3c', 
       'Zfhx3', 'Pag1', 'Slco2b1', 'Cx3cr1',
      ] 
len(abcgenes), len(genes_noniegs)

In [None]:
agenes_idx = basicu.get_index_from_array(adatasub.var.index.values, agenes)
bgenes_idx = basicu.get_index_from_array(adatasub.var.index.values, bgenes)
cgenes_idx = basicu.get_index_from_array(adatasub.var.index.values, cgenes)
igenes_idx = basicu.get_index_from_array(adatasub.var.index.values, iegs)

In [None]:
# Per-sample rescaling and merge.
# Each sample's 'norm' layer is rescaled so that its mean per-cell total over
# the non-IEG genes equals `mean_total_rna_target`; the result is stored as
# 'jnorm', and log2(1 + jnorm) as 'ljnorm'.
mean_total_rna_target = 250
adata_merged = []
for i, name in enumerate(names):
    # position of this sample in a 4-row grid used for the *_show coordinates:
    # j = column (0 for the first four names, 1 for the last four), i = row.
    # Note that the loop variable `i` is deliberately overwritten here.
    j = i // 4
    i = i % 4
    
    adatasub = alldata[name].copy()
    adatasub.obs['sample'] = name
    
    norm_cnts = adatasub.layers['norm']
    # mean_per_batch = np.mean(norm_cnts.sum(axis=1))
    # mean per-cell total of the 'norm' layer, IEGs excluded
    mean_per_batch_noniegs = np.mean(adatasub[:,genes_noniegs].layers['norm'].sum(axis=1))
    
    # the same scale factor is applied to all genes, IEGs included
    adatasub.layers['jnorm']  = norm_cnts*(mean_total_rna_target/mean_per_batch_noniegs)
    adatasub.layers['ljnorm'] = np.log2(1+adatasub.layers['jnorm'])
    
    # per-cell totals before and after the rescaling
    adatasub.obs['norm_transcript_count']  = adatasub.layers['norm'].sum(axis=1)
    adatasub.obs['jnorm_transcript_count'] = adatasub.layers['jnorm'].sum(axis=1)
    
    # display coordinates: depth is flipped and shifted by 1300 per grid row;
    # width is zeroed at the sample minimum and shifted by 2500 per grid column
    adatasub.obs['depth_show'] = -adatasub.obs['depth'].values - i*1300 # name
    adatasub.obs['width_show'] =  adatasub.obs['width'].values - np.min(adatasub.obs['width'].values) + j*2500   # name
    
    adata_merged.append(adatasub)
    
# the list of per-sample objects is replaced by one concatenated AnnData
adata_merged = ad.concat(adata_merged)

# get L2/3 RNA

In [None]:
adata_snrnasq

In [None]:
#
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/cheng21_cell_scrna/organized/P28NR.h5ad'
adata_rna_raw = sc.read(f)
adata_rna_raw = adata_rna_raw[adata_rna_raw.obs['Subclass']=='L2/3'].copy()
adata_rna_raw

In [None]:
# abcgenes_rna = rename_genes(abcgenes)
adata_rna_raw.var.index = merfish_genesets.rename_genes(adata_rna_raw.var.index.values) 

In [None]:
adata_rna_raw = preprocessing(adata_rna_raw)
# hvgs = get_hvgs(adata_rna_raw, 'norm')
# adata_rna = adata_rna_raw[:,hvgs]
adata_rna = adata_rna_raw[:,abcgenes]
adata_rna_raw, adata_rna

# get L2/3 only

In [None]:
l23_cells = adata_merfish[adata_merfish.obs['gated_pred_subclass']=='L2/3'].obs.index.values
adata_l23 = adata_merged[l23_cells].copy()

In [None]:
colors = ['C1', 'k']
for i, name in enumerate(names):
    j = i // 4
    color = colors[j]
    
    adatasub = adata_l23[adata_l23.obs['sample']==name]
    sns.histplot(adatasub.obs['depth'].values, element='step', fill=False, color=color)

In [None]:
adata = adata_l23[adata_l23.obs['depth'] < 400].copy()

In [None]:
# number of cells and sampled width range per sample
ns = adata.obs.groupby('sample').size()
ls = adata.obs.groupby('sample')['width'].max() - adata.obs.groupby('sample')['width'].min()

# groupby returns the samples sorted by label (a_ant2, a_pos2, b_ant2, ...),
# so put them back in `names` order (four 'ant2' samples, then four 'pos2')
# before splitting the two groups by position.
density = (ns/ls).reindex(names)
a, b = density.iloc[:4], density.iloc[4:]
t, p = stats.ttest_ind(a, b)
a, b, p

In [None]:

# Plot in `names` order (groupby output is sorted by label) and label the bars
# so each value can be attributed to its sample.
density_by_sample = (ns/ls).reindex(names)
fig, ax = plt.subplots()
ax.bar(np.arange(len(names)), density_by_sample.values)
ax.set_xticks(np.arange(len(names)))
ax.set_xticklabels(names, rotation=90)
ax.set_ylabel('cells per unit width')
plt.show()

In [None]:
width_min = adata.obs.groupby('sample')['width'].min().reindex(names)
width_max = adata.obs.groupby('sample')['width'].max().reindex(names)
width_rng = width_max - width_min 
width_cum = pd.Series(np.cumsum(np.hstack([0, width_rng[:-1]+100])), index=names)

adata.obs['width_n0'] = adata.obs['width'] - width_min.reindex(adata.obs['sample']).values
adata.obs['width_show2'] =  adata.obs['width_n0'] + width_cum.reindex(adata.obs['sample']).values
adata.obs['depth_show2'] = -adata.obs['depth']

In [None]:
# merge using overlapping genes
adata_rna = adata_rna[:,abcgenes].copy()
adata_mer = adata[adata.obs['transcript_count']>50, abcgenes]
adata_mer = adata_mer[adata_mer.obs['sample'].str.contains('NR')].copy()

adata_rna.obs['modality'] = 'rna'
adata_mer.obs['modality'] = 'merfish'
adata_merge0 = sc.concat([adata_rna, adata_mer], join='outer')

lognorm_rna = np.log10(1+np.array(adata_rna.layers['norm'].todense()))
zlognorm_rna = stats.zscore(lognorm_rna, axis=0)

lognorm_mer = adata_mer.layers['ljnorm']  # np.log10(1+adata_mer.layers['norm'])
zlognorm_mer = stats.zscore(lognorm_mer, axis=0)

adata_merge0.obsm['X_pca2'] = PCA(n_components=20).fit_transform(np.vstack([zlognorm_rna, zlognorm_mer])) 

print(len(adata_rna), len(adata), len(adata_mer))
adata_merge0

In [None]:
sc.external.pp.harmony_integrate(adata_merge0, 'modality', basis='X_pca2', max_iter_harmony=20)

In [None]:
# merge using overlapping genes
adata_rna = adata_rna[:,abcgenes].copy()
adata_mer = adata[adata.obs['transcript_count']>50, abcgenes].copy()

adata_rna.obs['modality'] = 'rna'
adata_mer.obs['modality'] = 'merfish'
adata_merge = sc.concat([adata_rna, adata_mer], join='outer')

lognorm_rna = np.log10(1+np.array(adata_rna.layers['norm'].todense()))
zlognorm_rna = stats.zscore(lognorm_rna, axis=0)

lognorm_mer = adata_mer.layers['ljnorm']  # np.log10(1+adata_mer.layers['norm'])
zlognorm_mer = stats.zscore(lognorm_mer, axis=0)

adata_merge.obsm['X_pca2'] = PCA(n_components=20).fit_transform(np.vstack([zlognorm_rna, zlognorm_mer])) 

print(len(adata_rna), len(adata), len(adata_mer))
adata_merge

In [None]:
sc.external.pp.harmony_integrate(adata_merge, 'modality', basis='X_pca2', max_iter_harmony=20)

In [None]:
# Expose the first four Harmony-corrected components as obs columns
# hpc1..hpc4 on both merged objects, for use as plot axes.
for _merged in (adata_merge0, adata_merge):
    _harmony = _merged.obsm['X_pca_harmony']
    for _pc in range(4):
        _merged.obs[f'hpc{_pc+1}'] = _harmony[:, _pc]


In [None]:
%%time

from py_pcha import PCHA


def get_aa(X):
    """Fit three archetypes to ``X`` with PCHA and return them ordered.

    ``X`` is passed to ``PCHA(X, noc=3, delta=0)`` unchanged; the callers in
    this notebook pass a (dimensions x cells) array. NumPy's global random
    seed is reset to 0 before the fit, as a side effect of every call.

    Returns the archetype matrix as an ndarray with its columns sorted by
    increasing value of the first row.
    """
    np.random.seed(0)
    # PCHA returns (XC, S, C, SSE, varexpl); only the archetypes XC are used
    archetypes = np.array(PCHA(X, noc=3, delta=0)[0])
    order = np.argsort(archetypes[0])
    return archetypes[:, order].copy()

X = adata_merge0.obsm['X_pca_harmony'][:,:2].T
XC0 = get_aa(X)

X = adata_merge0[adata_merge0.obs['modality']=='rna'].obsm['X_pca_harmony'][:,:2].T
XC1 = get_aa(X)

X = adata_merge0[adata_merge0.obs['modality']=='merfish'].obsm['X_pca_harmony'][:,:2].T
XC2 = get_aa(X)

In [None]:
fig, ax = plt.subplots()
sns.scatterplot(data=adata_merge0.obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', hue='modality', s=5, edgecolor='none', 
                ax=ax,
               )
add_triangle(XC0, ax)
add_triangle(XC1, ax)
add_triangle(XC2, ax)
plt.show()

fig, ax = plt.subplots()
sns.scatterplot(data=adata_merge0[adata_merge0.obs['modality']=='rna'].obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', hue='Type', s=5, edgecolor='none', 
                ax=ax,
               )
ax.set_aspect('equal')
ax.grid(False)
add_triangle(XC0, ax)
add_triangle(XC1, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
plt.show()

fig, ax = plt.subplots()
sns.scatterplot(data=adata_merge0[adata_merge0.obs['modality']=='merfish'].obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', #hue='', s=5, edgecolor='none', 
                ax=ax,
               )
ax.set_aspect('equal')
ax.grid(False)
add_triangle(XC0, ax)
add_triangle(XC2, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
plt.show()

In [None]:
sns.scatterplot(data=adata_merge.obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', hue='modality', s=5, edgecolor='none')
plt.show()
sns.scatterplot(data=adata_merge.obs.sample(frac=1, replace=False), 
                x='hpc1', y='hpc2', hue='Type', s=5, edgecolor='none')
plt.show()

In [None]:
clsts_palette2 = {
    'L2/3_A': 'C0',    
    'L2/3_B': 'C1',    
    'L2/3_C': 'C2',    
    'NA': 'gray',
}

In [None]:
# # label transfer from RNA data
ref = adata_merge0[adata_merge0.obs['modality']=='rna'].copy()
qry = adata_merge0[adata_merge0.obs['modality']=='merfish'].copy()

all_emb = adata_merge0.obsm['X_pca_harmony'][:,:2]
ref_emb = ref.obsm['X_pca_harmony'][:,:2]
qry_emb = qry.obsm['X_pca_harmony'][:,:2]

ref_lbl = ref.obs['Type'].values.astype(str) # res_nr['type'].values.astype(str)
print(np.unique(ref_lbl))

k = 30
max_pred, _, dists = neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None)
ps = neighbor_self_nonself(k, ref_emb, qry_emb)

adata_merge0.obs['max_pred_type'] = 'NA' 
adata_merge0.obs['frac_self_neighbors_type'] = np.nan # 'NA'
adata_merge0.obs['gated_pred_type'] = 'NA'

adata_merge0.obs.loc[qry.obs.index, 'max_pred_type'] = max_pred
adata_merge0.obs.loc[qry.obs.index, 'frac_self_neighbors_type'] = ps
adata_merge0.obs.loc[qry.obs.index, 'gated_pred_type'] = np.where(ps < 0.95, max_pred, 'NA')

adata_plot00 = adata_merge0[adata_merge0.obs['modality']=='rna']
adata_plot01 = adata_merge0[adata_merge0.obs['modality']=='merfish']

In [None]:
fig, axs = plt.subplots(1,2,figsize=(2*6,1*5), sharex=True, sharey=True)
ax = axs[0]
sns.scatterplot(data=adata_plot00.obs.sample(frac=1, replace=False), 
                ax=ax, x='hpc1', y='hpc2', 
                hue='Type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none')
ax.set_aspect('equal')
add_triangle(XC0, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
sns.despine(ax=ax)
ax.grid(False)

ax = axs[1]
sns.scatterplot(data=adata_plot01.obs.sample(frac=1, replace=False), 
                ax=ax, x='hpc1', y='hpc2', 
                hue='max_pred_type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none')
ax.set_aspect('equal')
add_triangle(XC0, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
sns.despine(ax=ax)
ax.grid(False)
plt.show()

In [None]:

gns = ['Cdh13', 'Sorcs3', 'Trpc6', 'Chrm2'] 
n = len(gns)
titles = gns
x = adata_plot01.obs['hpc1'].values
y = adata_plot01.obs['hpc2'].values

fig, axs = plt.subplots(1,n,figsize=(n*5,1*3), sharex=True, sharey=True)
for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
    g = adata_plot01[:,gn].layers['ljnorm'].reshape(-1,)

    # consistent over
    g0 = adata[:,gn].layers['ljnorm'].mean(axis=1)
    vmin = np.percentile(g0,  0)
    vmax = np.percentile(g0, 95)
    
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, vmin=vmin, vmax=vmax, cmap='rocket_r')
    # p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=5, vmin=vmin, vmax=vmax, cmap='rocket_r')
    colorbar = plt.colorbar(p, aspect=5, shrink=0.3)

    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)

    ax.set_title(title)

    add_triangle(XC0, ax)
plt.show()

In [None]:
# # label transfer from RNA data
ref = adata_merge[adata_merge.obs['modality']=='rna'].copy()
qry = adata_merge[adata_merge.obs['modality']=='merfish'].copy()

all_emb = adata_merge.obsm['X_pca_harmony'][:,:2]
ref_emb = ref.obsm['X_pca_harmony'][:,:2]
qry_emb = qry.obsm['X_pca_harmony'][:,:2]

ref_lbl = ref.obs['Type'].values.astype(str) # res_nr['type'].values.astype(str)
print(np.unique(ref_lbl))

k = 30
max_pred, _, dists = neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None)
ps = neighbor_self_nonself(k, ref_emb, qry_emb)

adata_merge.obs['max_pred_type'] = 'NA' 
adata_merge.obs['frac_self_neighbors_type'] = np.nan # 'NA'
adata_merge.obs['gated_pred_type'] = 'NA'

adata_merge.obs.loc[qry.obs.index, 'max_pred_type'] = max_pred
adata_merge.obs.loc[qry.obs.index, 'frac_self_neighbors_type'] = ps
adata_merge.obs.loc[qry.obs.index, 'gated_pred_type'] = np.where(ps < 0.95, max_pred, 'NA')

adata_plot0 = adata_merge[adata_merge.obs['modality']=='rna'].obs.sample(frac=1, replace=False)
adata_plot1 = adata_merge[adata_merge.obs['modality']=='merfish'].obs.sample(frac=1, replace=False)

In [None]:
fig, axs = plt.subplots(1,2,figsize=(2*6,1*5), sharex=True, sharey=True)
ax = axs[0]
sns.scatterplot(data=adata_plot0, 
                ax=ax, x='hpc1', y='hpc2', 
                hue='Type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                s=5, edgecolor='none')
ax.set_aspect('equal')

ax = axs[1]
sns.scatterplot(data=adata_plot1, 
                ax=ax, x='hpc1', y='hpc2', 
                hue='max_pred_type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                s=5, edgecolor='none')
ax.set_aspect('equal')
sns.despine(ax=ax)
plt.show()

In [None]:
%%time
xsign = 1

from py_pcha import PCHA
np.random.seed(0)

X = np.vstack([
    xsign*adata[adata.obs['sample'].str.contains('NR')].obsm['pcs_typegenes'][:,xi], 
    ysign*adata[adata.obs['sample'].str.contains('NR')].obsm['pcs_typegenes'][:,yi], 
])


XC, S, C, SSE, varexpl = PCHA(X, noc=3, delta=0)
XC = np.array(XC)
XC = XC[:,np.argsort(XC[0])].copy() # order this
print(XC.shape, S.shape, C.shape, SSE.shape, varexpl.shape, SSE, varexpl)

In [None]:
fig, axs = plt.subplots(1,2,figsize=(2*6,1*5), sharex=True, sharey=True)
ax = axs[0]
sns.scatterplot(data=adata_plot1[adata_plot1['sample'].str.contains('NR')], 
                ax=ax, x='hpc1', y='hpc2', 
                hue='max_pred_type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none')
ax.set_aspect('equal')

ax = axs[1]
sns.scatterplot(data=adata_plot1[adata_plot1['sample'].str.contains('DR')], 
                ax=ax, x='hpc1', y='hpc2', 
                hue='gated_pred_type', 
                palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                legend=False,
                s=5, edgecolor='none')
ax.set_aspect('equal')
plt.show()

In [None]:
adata.obsm['pcs_typegenes'] = adata_merge[adata.obs.index].obsm['X_pca_harmony'] # [:,3]
adata

In [None]:
xi, yi = 0, 1
xsign, ysign = 1, 1

In [None]:
metrics = ['gncov', 'gnnum', 'depth', 'width_show2']

fig, axs = plt.subplots(1,4,figsize=(4*5,1*4))
for metric, ax in zip(metrics, axs):
    # g = np.log10(1+adata.obs[metric])
    g = adata.obs[metric]
    x = xsign*adata.obsm['pcs_typegenes'][:,xi]
    y = ysign*adata.obsm['pcs_typegenes'][:,yi]
    utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, )
    ax.set_title(metric)

In [None]:
xi, yi = 0, 1
xsign, ysign = 1,1

In [None]:
%%time

from py_pcha import PCHA

np.random.seed(0)

X = np.vstack([
    xsign*adata.obsm['pcs_typegenes'][:,xi], 
    ysign*adata.obsm['pcs_typegenes'][:,yi], 
])

# X = np.vstack([
#     xsign*adata[adata.obs['sample'].str.contains('NR')].obsm['pcs_typegenes'][:,xi], 
#     ysign*adata[adata.obs['sample'].str.contains('NR')].obsm['pcs_typegenes'][:,yi], 
# ])


XC, S, C, SSE, varexpl = PCHA(X, noc=3, delta=0)
XC = np.array(XC)
XC = XC[:,np.argsort(XC[0])].copy() # order this
print(XC.shape, S.shape, C.shape, SSE.shape, varexpl.shape, SSE, varexpl)




In [None]:
gns = [agenes, bgenes, cgenes, iegs]
titles = ['A genes', 'B genes', 'C genes',]
adatas = [
    adata,
]
conditions = ['combined', 'NR', ]

fig, axss = plt.subplots(2,3,figsize=(3*5,3*2), sharex=True, sharey=True)
for i, (axs, adatasub, condition) in enumerate(zip(axss, adatas, conditions)):
    condition = conditions[i]
    for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
        g = adatasub[:,gn].layers['ljnorm'].mean(axis=1)
        x = xsign*adatasub.obsm['pcs_typegenes'][:,xi]
        y = ysign*adatasub.obsm['pcs_typegenes'][:,yi]
        
        # consistent over
        g0 = adata[:,gn].layers['ljnorm'].mean(axis=1)
        vmin = np.percentile(g0,  5)
        vmax = np.percentile(g0, 95)
            
        p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, cmap='coolwarm', vmin=vmin, vmax=vmax)
        colorbar = plt.colorbar(p, aspect=5, shrink=0.3)
        
        # Show ticks but no grid
        ax.set_aspect('equal')
        ax.axis('on')
        ax.grid(False)  # Turn off grid lines
        sns.despine(ax=ax)
        ax.tick_params(axis='both', which='both', bottom=True, left=True)
        
        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel(condition, rotation=0, loc='top')
            
        # add the triangle
        add_triangle(XC, ax)
plt.show()

In [None]:
gns = ['Cdh13', 'Cdh12', 'Meis2', 'Foxp1', 'Astn2']
titles = gns
adatas = [
    adata,
]
conditions = ['combined', 'NR', 'DR']

# Size the grid from the inputs: one row per dataset, one column per gene.
# A fixed 2x4 grid silently dropped the fifth gene, because zip() stops at the
# shortest iterable, and left an empty second row.
# squeeze=False keeps `axss` two-dimensional even when there is a single row.
n_rows, n_cols = len(adatas), len(gns)
fig, axss = plt.subplots(n_rows, n_cols, figsize=(n_cols*5, n_rows*3),
                         sharex=True, sharey=True, squeeze=False)
for i, (axs, adatasub, condition) in enumerate(zip(axss, adatas, conditions)):
    for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
        g = adatasub[:,gn].layers['ljnorm'].reshape(-1,)
        x = xsign*adatasub.obsm['pcs_typegenes'][:,xi]
        y = ysign*adatasub.obsm['pcs_typegenes'][:,yi]
        
        # color limits come from the full `adata` so every row shares a scale
        g0 = adata[:,gn].layers['ljnorm'].mean(axis=1)
        vmin = np.percentile(g0,  5)
        vmax = np.percentile(g0, 99)
            
        p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, vmin=vmin, vmax=vmax)
        colorbar = plt.colorbar(p, aspect=5, shrink=0.3)
        
        # Show ticks but no grid
        ax.set_aspect('equal')
        ax.axis('on')
        ax.grid(False)  # Turn off grid lines
        sns.despine(ax=ax)
        ax.tick_params(axis='both', which='both', bottom=True, left=True)
        
        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel(condition, rotation=0, loc='top')
            
        # overlay the archetype triangle
        add_triangle(XC, ax)
plt.show()

In [None]:
gns = ['Cdh13', 'Cdh12', 'Astn2', 'Meis2', 'Foxp1', 'Pcdh19']
titles = gns
adatas = [
    adata,
]
conditions = ['combined', 'NR', 'DR']

n = len(gns)

fig, axss = plt.subplots(2,n,figsize=(5*n,3*2), sharex=True, sharey=True)
for i, (axs, adatasub, condition) in enumerate(zip(axss, adatas, conditions)):
    condition = conditions[i]
    for j, (ax, gn, title,) in enumerate(zip(axs, gns, titles)):
        g = adatasub[:,gn].layers['ljnorm'].reshape(-1,)
        x = xsign*adatasub.obsm['pcs_typegenes'][:,xi]
        y = ysign*adatasub.obsm['pcs_typegenes'][:,yi]
        
        # consistent over
        g0 = adata[:,gn].layers['ljnorm'].mean(axis=1)
        vmin = np.percentile(g0,  5)
        vmax = np.percentile(g0, 99)
            
        p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, vmin=vmin, vmax=vmax)
        colorbar = plt.colorbar(p, aspect=5, shrink=0.3)
        
        # Show ticks but no grid
        ax.set_aspect('equal')
        ax.axis('on')
        ax.grid(False)  # Turn off grid lines
        sns.despine(ax=ax)
        ax.tick_params(axis='both', which='both', bottom=True, left=True)
        
        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel(condition, rotation=0, loc='top')
            
        add_triangle(XC, ax)
plt.show()

# customized colormap 

In [None]:
from matplotlib.colors import LinearSegmentedColormap

colors_a = [(0.0, 'black'), (1.0, 'C0')]      
colors_b = [(0.0, 'black'), (1.0, 'C1')]      
colors_c = [(0.0, 'black'), (1.0, 'C2')]      
colors_nrdr = [(0.0, 'C1'), (0.5, 'white'), (1.0, 'black')]
colors_nr = [(0.0, 'white'), (1.0, 'C1'),]
colors_dr = [(0.0, 'white'), (1.0, 'black'),]

# Create a custom colormap using LinearSegmentedColormap
cmap_a = LinearSegmentedColormap.from_list('cmap_a', colors_a)
cmap_b = LinearSegmentedColormap.from_list('cmap_b', colors_b)
cmap_c = LinearSegmentedColormap.from_list('cmap_c', colors_c)
cmap_nrdr = LinearSegmentedColormap.from_list('cmap_nrdr', colors_nrdr)
cmap_nr = LinearSegmentedColormap.from_list('cmap_nr', colors_nr)
cmap_dr = LinearSegmentedColormap.from_list('cmap_dr', colors_dr)

In [None]:
xmin, xmax = -12, 12
ymin, ymax = -7, 7 

bins_x = np.linspace(xmin, xmax, 1*(xmax-xmin)+1)
bins_y = np.linspace(ymin, ymax, 1*(ymax-ymin)+1)
print(bins_x)
print(bins_y)

hists = []
fig, axs = plt.subplots(1,4,figsize=(4*5,1*4), sharex=True, sharey=True)
for ax, adatasub, cond, _cmap in zip(axs, [
    adata,
    ], 
    ['Combined', 'NR', ], 
    ['gray_r', cmap_nr, cmap_dr]):
    x =  xsign*adatasub.obsm['pcs_typegenes'][:,xi]
    y =  ysign*adatasub.obsm['pcs_typegenes'][:,yi]
    sns.histplot(x=x, y=y, ax=ax, bins=(bins_x, bins_y), 
                 cmap=_cmap, # 'gray_r', 
                 stat='percent', vmin=0, vmax=2, 
                 cbar=True, cbar_kws=dict(shrink=0.4, ticks=[0,2]))
    # sns.kdeplot(x=x, y=y, ax=ax, bins=(bins_x, bins_y))
    
    hist, _, _= np.histogram2d(x, y, bins=(bins_x, bins_y))
    hist = hist/len(x)*100
    hists.append(hist)
    print(hist.shape)
    ax.set_title(cond)
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)
    
    # g = ax.imshow(pd.DataFrame(np.log2(1e-3+hist), 
    #                          index=bins_x[1:]-0.5, 
    #                          columns=bins_y[1:]-0.5).T, 
    #             origin='lower',
    #             extent=(xmin, xmax, ymin, ymax),
    #             cmap='gray_r') # , vmax=1, vmin=-1)
    
    # add the triangle
    add_triangle(XC, ax, zorder=2)
    

# add the triangle
add_triangle(XC, ax, zorder=2)

plt.show()

In [None]:
dfshow = adata.obs.copy()
dfshow['nrdr'] = dfshow['sample'].str.contains('DR').astype(int)
dfshow['dim1'] = xsign*adata.obsm['pcs_typegenes'][:,xi]
dfshow['dim2'] = ysign*adata.obsm['pcs_typegenes'][:,yi]
palette = {0: 'C1', 1: 'black'}


fig, axs = plt.subplots(1,2,figsize=(2*5,1*4), sharex=True, sharey=True) 
ax = axs[0]
add_triangle(XC, ax)
sns.scatterplot(data=dfshow.sample(frac=1), x='dim1', y='dim2', hue='nrdr', s=3, edgecolor='none', palette=palette, ax=ax)
ax.set_aspect('equal')
sns.despine(ax=ax)
ax.grid(False)
ax.legend(bbox_to_anchor=(1,1))

ax = axs[1]
add_triangle(XC, ax)
sns.kdeplot(data=dfshow, x='dim1', y='dim2', hue='nrdr', palette=palette, legend=False, ax=ax,)
ax.set_aspect('equal')
sns.despine(ax=ax)
ax.grid(False)


fig.tight_layout()
plt.show()

In [None]:
fig, axs = plt.subplots(2,4,figsize=(4*4,2*3), sharex=True, sharey=True) 
for sample, ax in zip(names, axs.flat):
    add_triangle(XC, ax)
    sns.scatterplot(data=dfshow[dfshow['sample']==sample].sample(frac=1), 
                    x='dim1', y='dim2', hue='nrdr', s=5, edgecolor='none', palette=palette, ax=ax, legend=False)
    sns.kdeplot(data=dfshow[dfshow['sample']==sample],
                x='dim1', y='dim2', hue='nrdr', palette=palette, legend=False, ax=ax,)
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(sample)
    ax.grid(False)
    # ax.legend(bbox_to_anchor=(1,1))

fig.tight_layout()
plt.show()

# customized colormap 

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Two-stop colour ramps: black at 0.0, a default-cycle colour (C0/C1/C2) at 1.0.
colors_a, colors_b, colors_c = (
    [(0.0, 'black'), (1.0, tip)] for tip in ('C0', 'C1', 'C2')
)

# One linear colormap per gene group (A, B, C).
cmap_a = LinearSegmentedColormap.from_list('cmap_a', colors_a)
cmap_b = LinearSegmentedColormap.from_list('cmap_b', colors_b)
cmap_c = LinearSegmentedColormap.from_list('cmap_c', colors_c)


In [None]:
# get ABC scores: z-score each gene across cells (axis=0) on the 'ljnorm' layer,
# then average the z-scores over the genes of each group (axis=1) -> one score per cell
g0_a = stats.zscore(adata[:,agenes].layers['ljnorm'], axis=0).mean(axis=1)
g0_b = stats.zscore(adata[:,bgenes].layers['ljnorm'], axis=0).mean(axis=1)
g0_c = stats.zscore(adata[:,cgenes].layers['ljnorm'], axis=0).mean(axis=1)

# make ABC scores comparable and norm to [0,1]: the 40th percentile maps to 0 and the
# 95th percentile maps to 1; values outside that range are clipped
vmin_p, vmax_p = 40, 95
vmin_a = np.percentile(g0_a, vmin_p)
vmax_a = np.percentile(g0_a, vmax_p)

vmin_b = np.percentile(g0_b, vmin_p)
vmax_b = np.percentile(g0_b, vmax_p)

vmin_c = np.percentile(g0_c, vmin_p)
vmax_c = np.percentile(g0_c, vmax_p)

g0_a = np.clip((g0_a-vmin_a)/(vmax_a-vmin_a), 0, 1)
g0_b = np.clip((g0_b-vmin_b)/(vmax_b-vmin_b), 0, 1)
g0_c = np.clip((g0_c-vmin_c)/(vmax_c-vmin_c), 0, 1)

# separate them into scale (sum of the three scores) and frequency (each score's share of
# that sum); the 1e-5 keeps the division finite for cells whose three scores are all 0,
# which therefore get frequencies of exactly 0
g0_sum  = (g0_a+g0_b+g0_c)
freq0_a = g0_a/(g0_sum+1e-5)
freq0_b = g0_b/(g0_sum+1e-5)
freq0_c = g0_c/(g0_sum+1e-5)

# per-cell columns: [freq A, freq B, freq C, summed score]
adata.obsm['size_freq_abc'] = np.vstack([freq0_a, freq0_b, freq0_c, g0_sum]).T

In [None]:
# sorted per-cell summed ABC score (each score is clipped to [0, 1], so the sum lies in [0, 3])
plt.plot(np.sort(g0_sum))
# sorted per-cell sum of the three frequencies, i.e. g0_sum/(g0_sum+1e-5)
plt.plot(np.sort(freq0_a+freq0_b+freq0_c))

In [None]:
# (number of cells whose frequencies sum to < 0.5, number of cells whose frequencies sum to exactly 0)
np.sum(freq0_a+freq0_b+freq0_c < 0.5), np.sum(freq0_a+freq0_b+freq0_c == 0)

In [None]:
# (number of cells whose frequencies sum to > 0.5, number of cells whose frequencies sum to > 0.9)
np.sum(freq0_a+freq0_b+freq0_c > 0.5), np.sum(freq0_a+freq0_b+freq0_c >1-1e-1)

# ABC scores - expression level distributions

In [None]:
# Empirical CDFs of the mean 'ljnorm' expression per cell, one panel per gene group
# and one curve per sample, coloured by rearing condition.
fig, axs = plt.subplots(1,4, figsize=(4*4,4), sharex=False, sharey=True)
for ax, genegroup, title in zip(axs, 
                                [agenes, bgenes, cgenes, iegs], 
                                ['A genes', 'B genes', 'C genes', 'IEGs'],
                               ):
    for sample in names:
        # per-cell mean over the genes of this group, restricted to one sample
        scores_ = adata[adata.obs['sample']==sample][:,genegroup].layers['ljnorm'].mean(axis=1)
        if 'NR' in sample:
            color = 'C1'
        elif 'DR' in sample:
            color = 'black'
        else:
            # neither tag present: use a neutral colour instead of silently
            # reusing the previous sample's colour (or failing on the first one)
            color = 'gray'

        sns.ecdfplot(scores_, ax=ax, color=color)

    ax.set_ylabel('Cumulative proportion\nof cells')
    ax.set_xlabel('Mean log norm expr.')
    ax.set_title(title)
    sns.despine(ax=ax)
    ax.grid(False)
    

fig.tight_layout()
plt.show()

In [None]:
adatas = [
    adata,
]
conditions = ['combined', 'NR', 'DR']

# One panel per dataset actually listed in `adatas`, so no empty axes are left behind;
# squeeze=False keeps `axs` an array even for a single panel.
n_panels = len(adatas)
fig, axs = plt.subplots(1, n_panels, figsize=(n_panels*5, 1*2),
                        sharex=True, sharey=True, squeeze=False)
axs = axs.ravel()
for ax, adatasub, condition in zip(axs, adatas, conditions):
    # embedding coordinates: sign-adjusted type-gene PCs
    x = xsign*adatasub.obsm['pcs_typegenes'][:,xi]
    y = ysign*adatasub.obsm['pcs_typegenes'][:,yi]
    
    # first three columns of 'size_freq_abc' hold the A/B/C frequencies
    freq_a = adatasub.obsm['size_freq_abc'][:,0]
    freq_b = adatasub.obsm['size_freq_abc'][:,1]
    freq_c = adatasub.obsm['size_freq_abc'][:,2]
    
    # visualize ABC scores using additive blending (sum the RGBA of the three
    # black->colour ramps, keep RGB only)
    additive = (cmap_a(freq_a)+cmap_b(freq_b)+cmap_c(freq_c))[:,:3]
    p = ax.scatter(x, y, c=additive, s=5, edgecolor='none') 
        
    # Show ticks but no grid
    ax.set_aspect('equal')
    ax.axis('on')
    ax.grid(False)  # Turn off grid lines
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', bottom=True, left=True)
    ax.set_title(condition)

    # add the triangle
    add_triangle(XC, ax, vertices=True, edgecolors='k', linewidths=1, marker='o')
# plt.show()

# score based (ABC) soft-assignment

In [None]:
from scipy import stats
from statsmodels.stats.multitest import multipletests

In [None]:
def p_mark(p):
    """Return the significance marker for a p-value (or adjusted q-value).

    Parameters
    ----------
    p : float
        The p-value to annotate.

    Returns
    -------
    str
        '***' if p < 0.001, '*' if 0.001 <= p < 0.05, otherwise 'ns'.
        Boundary values (exactly 0.05 or 0.001) and NaN no longer fall through
        every branch; NaN and p == 0.05 are reported as 'ns', p == 0.001 as '*'.
    """
    if p < 0.001:
        return '***'
    if p < 0.05:
        return '*'
    # covers p >= 0.05 and NaN (all comparisons with NaN are False)
    return 'ns'

In [None]:
# Per-sample percentages of cells whose largest ABC frequency is A, B or C,
# plus the percentage with all-zero frequencies ('N').
res = []
for sample in names:
    if 'NR' in sample:
        cond = 'NR'
    elif 'DR' in sample:
        cond = 'DR'

    adatasub = adata[adata.obs['sample'] == sample]
    # first three columns of 'size_freq_abc' are the A/B/C frequencies
    freq_a, freq_b, freq_c = adatasub.obsm['size_freq_abc'][:, :3].T

    n = len(adatasub)
    cond_na = (freq_a + freq_b + freq_c) == 0  # cells with no score in any group
    tn = np.sum(cond_na)

    # index (0/1/2) of the largest frequency for every scored cell
    scored = np.column_stack([freq_a, freq_b, freq_c])[~cond_na]
    rank = np.argsort(scored, axis=1)[:, -1]
    ta, tb, tc = (np.sum(rank == k) for k in range(3))

    # every cell is counted exactly once
    assert np.abs(n-(ta+tb+tc)-tn) < 1
    res.append([sample, cond] + [count/n*100 for count in (ta, tb, tc, tn)])

res = pd.DataFrame(res, columns=['sample', 'cond', 'A', 'B', 'C', 'N']).set_index('sample')
res

In [None]:
# imported here by name so the test does not depend on what `stats` is currently bound to
from scipy.stats import mannwhitneyu

# Bar + swarm plot of the per-sample percentages in `res`, one panel per assignment class.
fig, axs = plt.subplots(1, 4, figsize=(2*4,4))
unq_lbls = ['A', 'B', 'C', 'N']
unq_colors = ['C0', 'C1', 'C2', 'gray']

allps = []
for ax, col, color in zip(axs, unq_lbls, unq_colors):
    sns.barplot(data=res, x='cond', y=col, ax=ax, color=color, capsize=0.3, errwidth=1)
    sns.swarmplot(data=res, x='cond', y=col, color='k', ax=ax, )
    ax.set_title(f'type {col}', y=1.1)
    sns.despine(ax=ax)
    ax.grid(False)
    ax.set_ylabel('')
    
    # NR vs DR comparison for this class. Previously the test was commented out and
    # `allps` collected a stale global `p` left over from an earlier plotting cell.
    a = res[res['cond']=='NR'][col]
    b = res[res['cond']=='DR'][col]
    if len(a) > 0 and len(b) > 0:
        s, p = mannwhitneyu(a, b)
    else:
        # one condition has no samples: no test is possible
        p = np.nan
    
    allps.append(p)
    
# rej, allqs, _, _ = multipletests(allps, method='fdr_bh')
# for ax, col, color, q in zip(axs, unq_lbls, unq_colors, allqs):
#     mark = p_mark(q)
#     # statistical annotation
#     x1, x2 = 0, 1   
#     y, h = res[col].max() + 2, 2
#     ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1.5, c='k')
#     ax.text((x1+x2)*.5, y+h, mark, ha='center', va='bottom', color='k')
    
axs[0].set_ylabel('L2/3 cells (%)')
fig.tight_layout()
plt.show()

In [None]:
# Spatial layout of all cells, coloured by the additive blend of the A/B/C frequencies.
x =  adata.obs['width_show2']
y =  adata.obs['depth_show2']

freq_a = adata.obsm['size_freq_abc'][:,0]
freq_b = adata.obsm['size_freq_abc'][:,1]
freq_c = adata.obsm['size_freq_abc'][:,2]

# visualize ABC scores using additive blending
additive = (cmap_a(freq_a)+cmap_b(freq_b)+cmap_c(freq_c))[:,:3]

fig, ax = plt.subplots(1,1,figsize=(1*25,1))
for lbl, coord in width_cum.items():
    ax.text(coord, 0, lbl, fontsize=12)
    
# draw dim cells first so the brightest ones end up on top
sorting = np.argsort(np.max(additive, axis=1))# [::-1]
# `sorting` holds positions, so index the Series positionally with .iloc
# (plain Series[int array] is label-based / deprecated for non-integer indexes)
p = ax.scatter(x.iloc[sorting], y.iloc[sorting], c=additive[sorting], s=5, edgecolor='none') 

ax.set_aspect('equal')
ax.axis('off')

plt.show()

# label-transfer-based (PC1/PC2) hard-assignment 

In [None]:
# Cell counts per (sample, gated predicted type) for the MERFISH modality,
# rows ordered like `names`.
merfish_obs = adata_merge[adata_merge.obs['modality'] == 'merfish'].obs
num_types = (
    merfish_obs
    .groupby(['sample', 'gated_pred_type'])
    .size()
    .unstack()
    .reindex(names)
)

# Row-normalise to percentages, then tag each sample with its condition.
row_totals = num_types.sum(axis=1)
frq_types = num_types.divide(row_totals, axis=0)*100
frq_types['cond'] = np.where(frq_types.index.str.contains('DR'), 'DR', 'NR')
frq_types


In [None]:
# stacked bar chart of the type percentages, one bar per row (sample) of frq_types
frq_types.plot.bar(stacked=True)

In [None]:
# rebind `res` (previously the score-based soft-assignment table) to the label-transfer
# frequencies; this is an alias of frq_types, not a copy
res = frq_types

In [None]:
# imported here by name so the test does not depend on what `stats` is currently bound to
from scipy.stats import mannwhitneyu

# Bar + swarm plot of the per-sample type percentages in `res`, one panel per L2/3 type.
fig, axs = plt.subplots(1, 3, figsize=(2*3,4))
unq_lbls, unq_colors = ['L2/3_A', 'L2/3_B', 'L2/3_C', ], ['C0', 'C1', 'C2', ]
allps = []
for ax, col, color in zip(axs, unq_lbls, unq_colors):
    sns.barplot(data=res, x='cond', y=col, ax=ax, color=color, capsize=0.3, errwidth=1)
    sns.swarmplot(data=res, x='cond', y=col, color='k', ax=ax, )
    ax.set_title(f'type {col}', y=1.1)
    sns.despine(ax=ax)
    ax.grid(False)
    ax.set_ylabel('')
    
    # NR vs DR comparison for this type. Previously both the DR selection and the test
    # were commented out, so `allps` collected a stale global `p` from an earlier cell.
    a = res[res['cond']=='NR'][col] # /100
    b = res[res['cond']=='DR'][col] # /100
    
    # t, p = stats.ttest_ind(a, b)
    if len(a) > 0 and len(b) > 0:
        s, p = mannwhitneyu(a, b)
    else:
        # one condition has no samples: no test is possible
        p = np.nan
    allps.append(p)
    
# rej, allqs, _, _ = multipletests(allps, method='fdr_bh')

# for ax, col, color, q in zip(axs, unq_lbls, unq_colors, allqs):
#     mark = p_mark(q)
#     # statistical annotation
#     x1, x2 = 0, 1   
#     y, h = res[col].max() + 2, 2
#     ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1.5, c='k')
#     ax.text((x1+x2)*.5, y+h, mark, ha='center', va='bottom', color='k')
    
    
axs[0].set_ylabel('L2/3 cells (%)')
fig.tight_layout()
plt.show()

In [None]:
# Spatial layout of MERFISH cells coloured by gated predicted type.
# Rows are shuffled (sampling without replacement) so draw order is not tied to type.
is_merfish = adata_merge.obs['modality'] == 'merfish'
df_plot = adata_merge[is_merfish].obs.sample(frac=1, replace=False)
x, y, c = (df_plot[col] for col in ('width_show2', 'depth_show2', 'gated_pred_type'))

fig, ax = plt.subplots(1, 1, figsize=(1*25, 1))
# label text placed at each coordinate in width_cum, along y=0
for lbl, coord in width_cum.items():
    ax.text(coord, 0, lbl, fontsize=12)

point_colors = [clsts_palette2[label] for label in c]
p = ax.scatter(x, y, c=point_colors, s=5, edgecolor='none')
ax.set_aspect('equal')
ax.axis('off')

plt.show()

# visualize FISH

In [None]:
# Spatial expression strips, one row per gene, coloured by log2(1 + 'jnorm' expression).
gns = ['Cdh13', 'Sorcs3', 'Chrm2', 'Fos'] 
x =  adata.obs['width_show2']
y =  adata.obs['depth_show2']
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*25,n*1))
for i, (ax, gn) in enumerate(zip(axs, gns)):
    if i == 0:
        # text labels only on the top strip
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)
    
    g = np.log2(1+adata[:,gn].layers['jnorm'].reshape(-1,))
    # colour limits from the 5th and 99th percentiles
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g,  5)
    # plot low-expressing cells first so high-expressing cells are drawn on top
    sorting = np.argsort(g)
    
    # `sorting` holds positions, so index the Series positionally with .iloc
    # (plain Series[int array] is label-based / deprecated for non-integer indexes)
    p = utils_merfish.st_scatter_ax(fig, ax,  x.iloc[sorting],  y.iloc[sorting],  gexp=g[sorting], s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])
    
plt.show()
    

In [None]:
# display the A gene list
agenes

In [None]:
# Spatial expression strips for a second gene panel. Unlike the sibling cells, the
# cells are plotted in their stored order rather than sorted by expression.
gns = ['Cdh13', 'Cdh12', 'Pcdh19', 'Astn2', 'Meis2', 'Foxp1']
x = adata.obs['width_show2']
y = adata.obs['depth_show2']
n = len(gns)

fig, axs = plt.subplots(n, 1, figsize=(1*25, n*1))

# text labels only on the top strip
for lbl, coord in width_cum.items():
    axs[0].text(coord, 0, lbl, fontsize=12)

for ax, gn in zip(axs, gns):
    g = np.log2(1 + adata[:, gn].layers['jnorm'].reshape(-1,))
    # colour limits from the 5th and 95th percentiles
    vmin, vmax = np.percentile(g, [5, 95])
    sorting = np.argsort(g)  # computed but not used for plotting in this cell

    p = utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, title='',
                                    vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    cbar_ticks = [np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)]
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=cbar_ticks)

plt.show()
    

In [None]:
# Spatial expression strips, one row per gene, coloured by log2(1 + 'jnorm' expression).
gns = [
    'Cdh13', 'Trpc6', 'Chrm2',
] 
x =  adata.obs['width_show2']
y =  adata.obs['depth_show2']
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*25,n*1))
for i, (ax, gn) in enumerate(zip(axs, gns)):
    if i == 0:
        # text labels only on the top strip
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)
    
    g = np.log2(1+adata[:,gn].layers['jnorm'].reshape(-1,))
    # colour limits from the 5th and 95th percentiles
    vmax = np.percentile(g, 95)
    vmin = np.percentile(g,  5)
    # plot low-expressing cells first so high-expressing cells are drawn on top
    sorting = np.argsort(g)
    
    # `sorting` holds positions, so index the Series positionally with .iloc
    # (plain Series[int array] is label-based / deprecated for non-integer indexes)
    p = utils_merfish.st_scatter_ax(fig, ax,  x.iloc[sorting],  y.iloc[sorting],  gexp=g[sorting], s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    # p = utils_merfish.st_scatter_ax(fig, ax,  x,  y,  gexp=g, s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])
    
plt.show()
    

In [None]:
# Spatial strips of the mean 'ljnorm' expression of each gene group (A, B, C).
x =  adata.obs['width_show2']
y =  adata.obs['depth_show2']
gns = [agenes, bgenes, cgenes] 
titles = ['A genes', 'B genes', 'C genes',]
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*25,n*1))
for i, (ax, gn, title) in enumerate(zip(axs, gns, titles)):
    if i == 0:
        # text labels only on the top strip
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)
    
    # per-cell mean over the genes of the group
    g = adata[:,gn].layers['ljnorm'].mean(axis=1)
    # plot low-scoring cells first so high-scoring cells are drawn on top
    sorting = np.argsort(g)
    
    # colour limits from the 5th and 95th percentiles
    vmin = np.percentile(g,  5)
    vmax = np.percentile(g, 95)
    # `sorting` holds positions, so index the Series positionally with .iloc
    # (plain Series[int array] is label-based / deprecated for non-integer indexes)
    p = utils_merfish.st_scatter_ax(fig, ax, x.iloc[sorting], y.iloc[sorting], gexp=g[sorting], 
                                    s=5, title='', vmin=vmin, vmax=vmax, cmap='coolwarm') #, axis_off=False)
    ax.set_title(title, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])
    
plt.show()
    

In [None]:
# Same gene-group strips as the wide version, in a narrower figure with a free
# (non-equal) aspect ratio.
x =  adata.obs['width_show2']
y =  adata.obs['depth_show2']
gns = [agenes, bgenes, cgenes] 
titles = ['A genes', 'B genes', 'C genes',]
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*10,n*1))
for i, (ax, gn, title) in enumerate(zip(axs, gns, titles)):
    if i == 0:
        # text labels only on the top strip
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)
    
    # per-cell mean over the genes of the group
    g = adata[:,gn].layers['ljnorm'].mean(axis=1)
    # plot low-scoring cells first so high-scoring cells are drawn on top
    sorting = np.argsort(g)
    
    # colour limits from the 5th and 95th percentiles
    vmin = np.percentile(g,  5)
    vmax = np.percentile(g, 95)
    # `sorting` holds positions, so index the Series positionally with .iloc
    # (plain Series[int array] is label-based / deprecated for non-integer indexes)
    p = utils_merfish.st_scatter_ax(fig, ax, x.iloc[sorting], y.iloc[sorting], gexp=g[sorting], 
                                    s=5, title='', vmin=vmin, vmax=vmax, cmap='coolwarm') #, axis_off=False)
    ax.set_aspect('auto')
    ax.set_title(title, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])
    
plt.show()
    

In [None]:
# Gene-group strips with one fixed colour scale (0.3 to 0.6) shared by all three groups,
# so the panels can be compared directly.
x =  adata.obs['width_show2']
y =  adata.obs['depth_show2']
gns = [agenes, bgenes, cgenes] 
titles = ['A genes', 'B genes', 'C genes',]
n = len(gns)

fig, axs = plt.subplots(n,1,figsize=(1*10,n*1))
for i, (ax, gn, title) in enumerate(zip(axs, gns, titles)):
    if i == 0:
        # text labels only on the top strip
        for lbl, coord in width_cum.items():
            ax.text(coord, 0, lbl, fontsize=12)
    
    # per-cell mean over the genes of the group
    g = adata[:,gn].layers['ljnorm'].mean(axis=1)
    # plot low-scoring cells first so high-scoring cells are drawn on top
    sorting = np.argsort(g)
    
    vmin = 0.3 # np.percentile(g,  5)
    vmax = 0.6 # np.percentile(g, 95)
    # `sorting` holds positions, so index the Series positionally with .iloc
    # (plain Series[int array] is label-based / deprecated for non-integer indexes)
    p = utils_merfish.st_scatter_ax(fig, ax, x.iloc[sorting], y.iloc[sorting], gexp=g[sorting], 
                                    s=5, title='', vmin=vmin, vmax=vmax, cmap='coolwarm') #, axis_off=False)
    ax.set_aspect('auto')
    ax.set_title(title, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])
    
plt.show()
    

# stats

In [None]:
# Per-sample depth-binned expression statistics, keyed by sample name.
# NOTE(review): this rebinds `stats` to a dict, shadowing the `scipy.stats` module that is
# imported under the same name at the top of the notebook. Cells that call
# `stats.zscore` / `stats.ttest_ind` need `from scipy import stats` re-run after this cell.
stats = {}
bins = np.linspace(0, 400, 4*2+1)  # 9 evenly spaced edges from 0 to 400 -> 8 bins of width 50

for name in names:
    adatasub = adata[adata.obs['sample']==name]# v1l23_data[name]
    # binning_pipe2 is defined outside this cell; its six return values are stored unchanged
    lnorm_mean, lnorm_sem, lnorm_std, n, d, db = binning_pipe2(adatasub, 'depth', 'ljnorm', bins=bins)
    stats[name] = (lnorm_mean, lnorm_sem, lnorm_std, n, d, db)
# value counts of `d` for the last sample in `names` only (the loop variable is reused)
d.value_counts()

In [None]:
# Baseline per gene: mean expression across depth bins, averaged over the listed NR
# samples, computed separately for each gene group (A, B, C, IEG).
nr_baseline_samples = [
    'P8NRa_ant2', 'P8NRb_ant2', 'P8NRc_ant2', 'P8NRd_ant2',
    'P8NRa_pos2', 'P8NRb_pos2', 'P8NRc_pos2', 'P8NRd_pos2',
]

base_a0, base_b0, base_c0, base_i0 = (
    np.mean(
        # stats[name][0] is lnorm_mean (bins x genes); axis=0 averages over depth bins
        [np.mean(stats[name][0].iloc[:, gene_idx], axis=0) for name in nr_baseline_samples],
        axis=0,  # then average over samples
    )
    for gene_idx in (agenes_idx, bgenes_idx, cgenes_idx, igenes_idx)
)

base_a0.shape, base_b0.shape, base_c0.shape, base_i0.shape

In [None]:
# For every sample: baseline-subtracted mean expression per depth bin for each gene
# group, and the mean of the per-gene SEMs per depth bin.
# Each list is ordered [A, B, C, IEG].
gene_idx_groups = (agenes_idx, bgenes_idx, cgenes_idx, igenes_idx)
gene_baselines = (base_a0, base_b0, base_c0, base_i0)

means, sems = {}, {}
for name in names:
    lnorm_mean, lnorm_sem = stats[name][:2]

    # subtract each gene's NR baseline, then average over genes -> one value per bin
    means[name] = [
        np.mean(lnorm_mean.iloc[:, gene_idx] - baseline, axis=1)
        for gene_idx, baseline in zip(gene_idx_groups, gene_baselines)
    ]
    # average the per-gene SEMs over genes -> one value per bin
    sems[name] = [
        np.mean(lnorm_sem.iloc[:, gene_idx], axis=1)
        for gene_idx in gene_idx_groups
    ]
    

In [None]:
# centre of each depth bin: average of its left and right edge
midpoints = (bins[:-1] + bins[1:]) / 2
midpoints

In [None]:
# display the sample names (the iteration order used by the per-sample panels)
names

In [None]:
gnames = ['A genes (n=64)', 'B genes (n=35)', 'C genes (n=71)']

# One panel per sample: baseline-subtracted depth profile (+/- sem band) of the
# A, B and C gene groups. The IEG entries of `means` / `sems` are not drawn.
fig, axs = plt.subplots(2, 4, figsize=(5*4, 4*2), sharex=True, sharey=True)

linestyle = '-'
group_labels = ('A genes', 'B genes', 'C genes')
group_colors = ('C0', 'C1', 'C2')
for ax, name in zip(axs.flat, names):
    x = midpoints
    for grp_mean, grp_sem, label, color in zip(means[name][:3], sems[name][:3],
                                               group_labels, group_colors):
        ax.plot(x, grp_mean, label=label, color=color, linestyle=linestyle)
        ax.fill_between(x, grp_mean-grp_sem, grp_mean+grp_sem,
                        color=color, alpha=0.1, edgecolor='none')
    # reference line at zero (no difference from the NR baseline)
    ax.axhline(color='lightgray', linestyle='dotted', zorder=1)

    sns.despine(ax=ax)
    ax.set_xticks([0, 100, 200, 300])
    ax.set_xlim(left=0, right=400)
    ax.grid(False)
    ax.set_title(name)
axs.flat[0].set_ylabel('mean (expr. +/- sem)')

fig.subplots_adjust(wspace=0.1)
# powerplots.savefig_autodate(fig, outdatadir+'/grant_saumya_lineq_abc_v3.pdf')

In [None]:
# Stack the per-sample profiles into one array.
samp_gene_dpth_mat = np.array([np.array(means[name]) for name in names]) 
print(samp_gene_dpth_mat.shape) # sample, gene group, depth

# NOTE(review): the split assumes the first 4 entries of `names` are the NR samples and
# the remaining ones are DR - confirm against the order of `names`.
nr_mat = samp_gene_dpth_mat[:4]
nr_mean = np.mean(nr_mat, axis=0) # gene group, depth
# SEM uses the actual number of samples in the group instead of a hard-coded 4
# (np.std default ddof=0, i.e. population standard deviation, as before)
nr_sem  = np.std(nr_mat, axis=0)/np.sqrt(nr_mat.shape[0]) # gene group, depth

dr_mat = samp_gene_dpth_mat[4:]
dr_mean = np.mean(dr_mat, axis=0) # gene group, depth
dr_sem  = np.std(dr_mat, axis=0)/np.sqrt(dr_mat.shape[0]) # gene group, depth
nr_mean.shape, dr_mean.shape

In [None]:
# Independent two-sample t-test (NR vs DR) for every gene group and depth bin,
# followed by Benjamini-Hochberg correction over all tests at once.
# `stats` is re-imported here because an earlier cell rebinds that name to a dict.
from scipy import stats
from statsmodels.stats.multitest import multipletests

ts, ps = stats.ttest_ind(nr_mat, dr_mat)

# undefined p-values (NaN) are treated as 1, i.e. not significant
ps_flat = np.nan_to_num(ps, nan=1).ravel()
rejs, qs, _, _ = multipletests(ps_flat, alpha=0.05, method='fdr_bh')
qs = qs.reshape(ps.shape)

# midpoint between the NR and DR means; used as the y position of significance marks
nrdr_mean = (nr_mean + dr_mean) / 2

In [None]:
# NR (solid) vs DR (dashed) depth profiles for each gene group, with FDR-corrected
# significance marks placed midway between the two condition means.
linestyles = ['-', '--']
data_mean = [nr_mean, dr_mean]
data_sem = [nr_sem, dr_sem]
gnames = ['A genes', 'B genes', 'C genes']
titles = gnames
colors = ['C0', 'C1', 'C2']
labels = ['NR', 'DR']
sigs = qs
allmeans = nrdr_mean

fig, axs = plt.subplots(1, 3, figsize=(5*3,4), sharex=True, sharey=True)
for i, (ax, gname, color) in enumerate(zip(axs, gnames, colors)):
    ax.axhline(color='lightgray', linestyle='dotted', zorder=1)
    # iterate over the two conditions; each line is labelled with its condition
    # (previously the gene-group titles were zipped in here and `labels` went unused)
    for cond_mean, cond_sem, label, linestyle in zip(data_mean, data_sem, labels, linestyles):
        ax.plot(midpoints, cond_mean[i], label=label, color=color, linestyle=linestyle, marker='o', markersize=5)
        ax.fill_between(midpoints, cond_mean[i]-cond_sem[i], cond_mean[i]+cond_sem[i], color=color, alpha=0.1, edgecolor='none')
        
    # significance marks per depth bin: '***' for q < 1e-3, '*' for q < 5e-2
    for _x, _y, _sig in zip(midpoints, allmeans[i], sigs[i]):
        if _sig < 1e-3:
            ax.text(_x, _y, "***", ha='left', va='center', fontsize=12, rotation=90)
            ax.vlines(_x, _y-0.02, _y+0.02, color='k', linewidth=0.5)
        elif _sig < 5e-2:
            ax.text(_x, _y, "*", ha='left', va='center', fontsize=12, rotation=90)
            ax.vlines(_x, _y-0.02, _y+0.02, color='k', linewidth=0.5)

    sns.despine(ax=ax)
    ax.set_xticks([0, 100, 200, 300, 400])
    ax.set_xlim(left=50, right=400)
    # ax.set_ylim([-0.2, 0.3])
    ax.grid(False)
    ax.set_title(gname)
    ax.set_xlabel('upper->lower cortical depth')
    
axs[0].set_ylabel('mean (expr. +/- sem)')
fig.subplots_adjust(wspace=0.1)
# powerplots.savefig_autodate(fig, outdatadir+'/grant_saumya_lineq_abc_v3.pdf')
plt.show()