In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
from scipy import stats
import os

from scipy import spatial
from scipy import sparse
from scipy.interpolate import CubicSpline
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from umap import UMAP
from scipy.stats import ttest_ind, mannwhitneyu
from scipy.stats import pearsonr, spearmanr, zscore
from statsmodels.stats.multitest import multipletests

import json

In [None]:
import importlib
import scroutines
importlib.reload(scroutines)
from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
importlib.reload(utils_merfish)
from utils_merfish import rot2d, st_scatter, st_scatter_ax, plot_cluster, binning
from utils_merfish import RefLineSegs

import merfish_datasets
import merfish_genesets
importlib.reload(merfish_datasets)
importlib.reload(merfish_genesets)
from merfish_datasets import merfish_datasets
from merfish_datasets import merfish_datasets_params

from scroutines import basicu

In [None]:
def get_qc_metrics(df):
    """Collect per-cell QC distributions for histogramming.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'volume', 'gncov' and 'gnnum'.

    Returns
    -------
    dict
        Keyed by column name; each value is a tuple
        ``(display_name, values, median, bins)`` where ``bins`` are 50 evenly
        spaced edges from 0 to 10x the median.
    """
    qc_columns = {
        'volume': 'cell volume',
        'gncov': 'num transcripts',
        'gnnum': 'num genes',
    }
    metrics = {}
    for column, display_name in qc_columns.items():
        values = df[column].values
        median = np.median(values)
        metrics[column] = (display_name, values, median, np.linspace(0, 10 * median, 50))
    return metrics

def get_norm_counts(adata, scaling=500):
    """Volume-normalise counts: rescale every cell as if its volume were ``scaling``.

    Each row of ``adata.X`` is divided by that cell's ``adata.obs['volume']``
    and multiplied by ``scaling`` (default 500). The result is also stored in
    ``adata.layers['norm']``.

    Returns
    -------
    The normalised count matrix (same object as the new layer).
    """
    volume_col = adata.obs['volume'].values.reshape(-1, 1)
    normcnts = adata.X / volume_col * scaling
    adata.layers['norm'] = normcnts
    return normcnts

In [None]:
def preprocessing(adata):
    """Filter lowly-detected genes and add CP10k-normalised layers.

    Genes detected (count > 0) in more than 10 cells are kept. Two layers are
    added to the returned object:

    - ``'norm'``    : counts per 10k, using ``adata.obs['n_counts']`` as each
      cell's library size;
    - ``'lognorm'`` : ``log10(norm + 1)`` applied to the stored non-zeros.

    Assumes ``adata.X`` is a scipy sparse matrix (the log step rewrites ``.data``).

    Returns
    -------
    A new AnnData (an explicit copy, not a view of ``adata``) holding the kept
    genes and the two layers.
    """
    # keep genes expressed in more than 10 cells
    keep = np.ravel((adata.X > 0).sum(axis=0)) > 10
    # explicit copy: writing layers into a view makes anndata implicitly
    # materialise it (with an ImplicitModificationWarning)
    adata_sub = adata[:, keep].copy()

    x = adata_sub.X
    cov = adata_sub.obs['n_counts'].values

    # CP10k
    xn = sparse.diags(1 / cov).dot(x) * 1e4

    # log10(CP10k + 1) on the stored non-zeros (zeros stay zero)
    xln = xn.copy()
    xln.data = np.log10(xln.data + 1)

    adata_sub.layers['norm'] = xn
    adata_sub.layers['lognorm'] = xln
    return adata_sub

In [None]:
def get_hvgs(adata, layer, nbin=20, qth=0.3):
    """Select highly variable genes by dispersion within mean-expression bins.

    Genes are split into ``nbin`` equal-count bins by their mean in
    ``adata.layers[layer]`` (``pd.qcut``). Within each bin, the
    ``int(qth * n_genes / nbin)`` genes with the largest variance/mean ratio
    are selected.

    Parameters
    ----------
    adata : AnnData-like with ``.layers`` and ``.var``
    layer : str
        Layer holding a cells x genes matrix (scipy sparse or dense array).
    nbin : int
        Number of mean-expression quantile bins.
    qth : float
        Fraction of each bin (on average) to keep.

    Returns
    -------
    numpy.ndarray
        Names of the selected genes, in their original order.
    """
    xn = adata.layers[layer]

    # per-gene mean
    gm = np.ravel(xn.mean(axis=0))

    # per-gene (population) variance as E[x^2] - E[x]^2; sparse or dense input
    if sparse.issparse(xn):
        sq = xn.power(2)
    else:
        sq = np.square(np.asarray(xn))
    gv = np.ravel(sq.mean(axis=0)) - gm**2

    # equal-count bins of mean expression
    lbl = pd.qcut(gm, nbin, labels=np.arange(nbin))
    gres = pd.DataFrame({'lbl': lbl, 'mean': gm, 'var': gv, 'ratio': gv / gm})

    # top-k dispersion per bin; level 1 of the result's index is the gene position
    n_per_bin = int(qth * (len(gm) / nbin))
    gres_sel = gres.groupby('lbl', observed=True)['ratio'].nlargest(n_per_bin)
    gsel_idx = np.sort(gres_sel.index.get_level_values(1).values)

    return adata.var.index.values[gsel_idx]

In [None]:
def binning_pipe2(adata, col_to_bin, layer, bins=None, n=20):
    """Bin cells by an obs column and summarise one layer per bin.

    Parameters
    ----------
    adata : AnnData-like with ``.obs``, ``.var`` and ``.layers``
    col_to_bin : str
        Column of ``adata.obs`` used for binning.
    layer : str
        Layer to summarise (presumably a dense cells x genes array, since it is
        wrapped directly in a DataFrame).
    bins : array-like, optional
        Bin edges passed to ``pd.cut``. If None, ``utils_merfish.binning``
        chooses ``n`` bins.
    n : int
        Number of bins when ``bins`` is None.

    Returns
    -------
    norm_mean, norm_sem, norm_std : DataFrame (bins x genes)
        Per-bin mean, SEM and SD of each gene. On the ``pd.cut`` path empty
        bins are kept as NaN rows, so rows line up with ``bins``.
    norm_n : Series
        Number of cells per bin.
    binned : per-cell bin assignment.
    bins : the bin edges used.
    """
    if bins is None:
        bins, binned = utils_merfish.binning(adata.obs[col_to_bin].values, n)
    else:
        binned = pd.cut(adata.obs[col_to_bin].values, bins=bins)

    norm_ = pd.DataFrame(adata.layers[layer], columns=adata.var.index)
    norm_['thebin'] = binned

    # observed=False pins the current behaviour of keeping empty categorical bins
    # (pandas is moving the default to observed=True, which would drop them and
    # misalign the rows with the bin midpoints)
    grouped = norm_.groupby('thebin', observed=False)
    norm_mean = grouped.mean(numeric_only=True)
    norm_sem = grouped.sem(numeric_only=True)
    norm_std = grouped.std(numeric_only=True)
    norm_n = norm_['thebin'].value_counts(sort=False)

    return norm_mean, norm_sem, norm_std, norm_n, binned, bins

In [None]:
def neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None):
    """Transfer labels from reference to query cells by k-nearest-neighbour vote.

    Parameters
    ----------
    k : int
        Number of reference neighbours per query cell.
    ref_emb, qry_emb : array-like (n_cells x n_dims)
        Reference and query coordinates in the same embedding.
    ref_lbl : array-like
        Label of each reference cell (compared as strings).
    p_cutoff : float
        A query keeps its label in ``gated_pred`` only if the winning vote
        fraction is strictly greater than this.
    dist_cutoff : float, optional
        If given, a query keeps its label only if its farthest of the k
        neighbours is strictly closer than this.

    Returns
    -------
    max_pred : ndarray of str
        Majority label per query cell (ties go to the first label in sorted order).
    gated_pred : ndarray of str
        ``max_pred`` with 'NA' where a cutoff is not met.
    max_dist : ndarray
        Distance to the farthest of the k neighbours, per query cell.
    """
    unq_lbls = np.unique(ref_lbl).astype(str)
    # compare as strings so non-string labels (e.g. ints) still match unq_lbls
    ref_lbl_str = np.asarray(ref_lbl).astype(str)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(ref_emb)
    dists, idx = neigh.kneighbors(qry_emb, k, return_distance=True)

    raw_pred = ref_lbl_str[idx]  # (n_qry, k) neighbour labels

    # fraction of the k neighbours carrying each label
    pabc = np.empty((len(qry_emb), len(unq_lbls)))
    for i, lbl in enumerate(unq_lbls):
        pabc[:, i] = np.sum(raw_pred == lbl, axis=1) / k

    max_pred = unq_lbls[np.argmax(pabc, axis=1)]
    max_dist = np.max(dists, axis=1)

    # widen the string dtype so 'NA' is not truncated to 'N' when every label
    # is a single character
    gated_pred = max_pred.astype(np.result_type(max_pred.dtype, np.dtype('<U2')))
    gated_pred[~(np.max(pabc, axis=1) > p_cutoff)] = 'NA'
    if dist_cutoff is not None:
        gated_pred[~(max_dist < dist_cutoff)] = 'NA'

    return max_pred, gated_pred, max_dist


def neighbor_self_nonself(k, ref_emb, qry_emb):
    """For each query cell, the fraction of its k nearest neighbours that are query cells.

    Neighbours are searched in the pooled (ref + qry) embedding. Because the
    query cells are part of the fitted set, each query cell normally finds
    itself among its own neighbours (distance 0, ties aside).

    Parameters
    ----------
    k : int
        Number of neighbours.
    ref_emb, qry_emb : array-like (n_cells x n_dims)

    Returns
    -------
    numpy.ndarray
        Per query cell, (number of query-cell neighbours) / k.
    """
    ref_n = len(ref_emb)
    qry_n = len(qry_emb)
    # 0 = reference row, 1 = query row, in the order of the pooled embedding
    is_qry = np.array([0] * ref_n + [1] * qry_n)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(np.vstack([ref_emb, qry_emb]))
    idx = neigh.kneighbors(qry_emb, k, return_distance=False)

    return np.sum(is_qry[idx], axis=1) / k

In [None]:
from py_pcha import PCHA
def get_aa(X):
    """Fit three archetypes to ``X`` with PCHA and return them sorted by their first coordinate.

    The global NumPy RNG is reseeded with 0 before fitting so repeated calls are
    reproducible (assuming PCHA draws from ``np.random`` -- confirm).

    Returns
    -------
    numpy.ndarray
        The archetype matrix returned by PCHA as an array, with its columns
        (one per archetype; row 0 / row 1 are used as x / y by ``add_triangle``)
        reordered by ascending row-0 value.
    """
    np.random.seed(0)
    archetypes, _S, _C, _SSE, _varexpl = PCHA(X, noc=3, delta=0)
    archetypes = np.array(archetypes)
    column_order = np.argsort(archetypes[0])
    return archetypes[:, column_order].copy()

In [None]:
def add_triangle(XC, ax, zorder=0, vertices=False, label='', linecolor='gray', linewidth=1, **kwargs):
    """Draw the triangle spanned by the three columns of ``XC`` on ``ax``.

    Row 0 of ``XC`` gives x and row 1 gives y. The outline is a closed dashed
    line. If ``vertices`` is True, the three vertices are also scattered in
    C0, C1 and C2; extra ``kwargs`` are forwarded to those scatter calls.
    """
    # close the outline by repeating the first vertex
    xs = XC[0].tolist() + [XC[0, 0]]
    ys = XC[1].tolist() + [XC[1, 0]]
    ax.plot(xs, ys, '--',
            color=linecolor, label=label, zorder=zorder, linewidth=linewidth, markersize=3)

    if not vertices:
        return
    for col, vertex_color in enumerate(('C0', 'C1', 'C2')):
        ax.scatter(XC[0, col], XC[1, col], color=vertex_color, zorder=zorder, **kwargs)

In [None]:
def p_mark(p):
    """Return a significance marker for a p-value.

    'ns' for p >= 0.05 (and for NaN), '*' for 0.001 <= p < 0.05, and '***'
    for p < 0.001. There is no '**' level.
    """
    # `not p < 0.05` also routes NaN to 'ns'; boundary values (exactly 0.05 or
    # 0.001) previously fell through every branch and left `mark` unbound
    if not p < 0.05:
        return 'ns'
    if p >= 0.001:
        return '*'
    return '***'

In [None]:
def wrap_label_transfer(ref, qry, key_emb, key_lbl, k=30):
    """Label query cells by k-nearest-neighbour majority vote in a shared embedding.

    Uses the first two columns of ``obsm[key_emb]`` for both ``ref`` and
    ``qry``, and writes the (ungated) majority label of the k nearest
    reference cells into ``qry.obs[key_lbl]`` in place. Returns None.
    """
    emb_ref, emb_qry = (adata.obsm[key_emb][:, :2] for adata in (ref, qry))
    labels_ref = ref.obs[key_lbl].values.astype(str)

    majority, _gated, _dist = neighbor_label_transfer(
        k, emb_ref, emb_qry, labels_ref, p_cutoff=0.5, dist_cutoff=None,
    )
    qry.obs[key_lbl] = majority
    return

In [None]:
def get_abc_scores(adata, agenes, bgenes, cgenes, layer='ljnorm', vmin_p=40, vmax_p=95):
    """Score every cell for the A, B and C gene sets; store proportions and magnitude.

    For each gene set:

    1. z-score each gene across cells (``adata[:, genes].layers[layer]``);
    2. average the z-scores over the set's genes per cell, skipping constant
       genes (NaN after z-scoring);
    3. map the ``vmin_p``-th percentile to 0 and the ``vmax_p``-th to 1,
       clipping to [0, 1].

    The three scores are then split into a magnitude (their sum, 0..3) and
    proportions (each score / (sum + 1e-5)).

    Parameters
    ----------
    adata : AnnData
        Cells x genes; the layer is presumably a dense array (it is passed to
        ``scipy.stats.zscore``).
    agenes, bgenes, cgenes : sequence of gene names (columns of ``adata``)
    layer : str, default 'ljnorm'
    vmin_p, vmax_p : float, default 40 and 95
        Percentiles mapped to 0 and 1.

    Returns
    -------
    None. Writes ``adata.obsm['size_freq_abc']`` with shape (n_cells, 4) and
    columns [freq_a, freq_b, freq_c, sum].
    """
    scaled = []
    for gene_set in (agenes, bgenes, cgenes):
        z = zscore(adata[:, gene_set].layers[layer], axis=0)
        # constant genes give NaN columns; skip them instead of letting them turn
        # every cell's score into NaN
        score = np.nanmean(z, axis=1)

        lo = np.percentile(score, vmin_p)
        hi = np.percentile(score, vmax_p)
        if hi > lo:
            score = np.clip((score - lo) / (hi - lo), 0, 1)
        else:
            # zero range (heavy ties): the clipped formula's limit, without the
            # 0/0 NaN for cells sitting exactly at `lo`
            score = (score > lo).astype(float)
        scaled.append(score)
    g0_a, g0_b, g0_c = scaled

    # separate them into scale and frequency (mag 0~3 vs direction 0 or 1)
    g0_sum = g0_a + g0_b + g0_c
    freq0_a = g0_a / (g0_sum + 1e-5)
    freq0_b = g0_b / (g0_sum + 1e-5)
    freq0_c = g0_c / (g0_sum + 1e-5)

    adata.obsm['size_freq_abc'] = np.vstack([freq0_a, freq0_b, freq0_c, g0_sum]).T
    return

def get_abc_stats(adata, samples):
    """Per-sample percentage of cells assigned to each of A, B, C (or NA).

    Each cell is assigned to the archetype with the largest proportion in
    ``adata.obsm['size_freq_abc'][:, :3]``. Cells whose three proportions are
    all 0 are counted as 'NA'.

    Parameters
    ----------
    adata : AnnData with ``obs['sample']`` and ``obsm['size_freq_abc']``
    samples : iterable of sample names; each must contain 'ant' or 'pos'.

    Returns
    -------
    DataFrame indexed by sample with columns
    ['cond', 'L2/3_A', 'L2/3_B', 'L2/3_C', 'NA'] (percentages).

    Raises
    ------
    ValueError
        If a sample name contains neither 'ant' nor 'pos'.
    """
    res = []
    for sample in samples:
        if 'ant' in sample:
            cond = 'ant'
        elif 'pos' in sample:
            cond = 'pos'
        else:
            # previously this raised NameError or silently reused the previous cond
            raise ValueError(f"cannot infer 'ant'/'pos' condition from sample name {sample!r}")

        adatasub = adata[adata.obs['sample'] == sample]
        freq = adatasub.obsm['size_freq_abc'][:, :3]

        n = len(adatasub)
        cond_na = freq.sum(axis=1) == 0
        tn = np.sum(cond_na)

        # index of the largest proportion per cell (last position of the ascending argsort)
        rank = np.argsort(freq[~cond_na], axis=1)[:, -1]
        ta, tb, tc = (np.sum(rank == r) for r in range(3))

        assert n == ta + tb + tc + tn
        res.append([sample, cond, ta/n*100, tb/n*100, tc/n*100, tn/n*100])

    res = pd.DataFrame(res, columns=['sample', 'cond', 'L2/3_A', 'L2/3_B', 'L2/3_C', 'NA']).set_index('sample')
    return res

def get_abc_stats_typeready(adata, samples, sample_col, type_col):
    """Per-sample percentage of cells of each type, from precomputed type labels.

    Counts cells per (``sample_col``, ``type_col``) in ``adata.obs``, orders the
    rows by ``samples``, converts each row to percentages, and adds a 'cond'
    column: 'pos' if the sample name contains 'pos', otherwise 'ant'.
    """
    counts = (
        adata.obs
        .groupby([sample_col, type_col])
        .size()
        .unstack()
        .reindex(samples)
    )
    pct = counts.div(counts.sum(axis=1), axis=0) * 100
    pct['cond'] = np.where(pct.index.str.contains('pos'), 'pos', 'ant')
    return pct

# load data

In [None]:
outfigdir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures/250414"
# create the output directory in Python rather than via the IPython-only `!mkdir -p` escape
os.makedirs(outfigdir, exist_ok=True)
fig_manager = powerplots.FigManager(outfigdir)

In [None]:
np.random.seed(seed=0)  # seed NumPy's legacy global RNG so later np.random draws are reproducible

### MERFISH genes

In [None]:
# MERFISH gene panel and the marker gene sets defined on it
f = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/merfish_genes.txt" 
# NOTE(review): this file-based panel is replaced by genesets['allmerfish'] a few lines below
genes = np.loadtxt(f, dtype='str')

genesets, typegenes_df = merfish_genesets.get_all_genesets()
for key, item in genesets.items():
    print(key, len(item))

genes = genesets['allmerfish']
agenes, bgenes, cgenes, igenes = (genesets[set_name] for set_name in ('a', 'b', 'c', 'i'))
iegs = igenes  # IEG set
genes_noniegs = np.array([g for g in genes if g not in iegs])

# positions of each gene set within `genes`
agenes_idx, bgenes_idx, cgenes_idx, igenes_idx = (
    basicu.get_index_from_array(genes, gene_set)
    for gene_set in (agenes, bgenes, cgenes, igenes)
)

### new ABC genes

In [None]:
# New A/B/C archetype DEG lists per condition (P8NR / P14NR / P21NR), restricted to the MERFISH panel.
f = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/DEG_l23abc_gene_list_250409.csv"
df_genes_newabc = pd.read_csv(f)
df_genes_newabc_p8, df_genes_newabc_p14, df_genes_newabc_p21 = (
    df_genes_newabc[df_genes_newabc['cond'] == cond_label]
    for cond_label in ('P8NR', 'P14NR', 'P21NR')
)

abcgenes = df_genes_newabc['gene'].unique()

# per condition and archetype: DEG genes that are on the MERFISH panel
# (np.intersect1d returns them sorted and de-duplicated)
p8_agenes, p8_bgenes, p8_cgenes = (
    np.intersect1d(genes, df_genes_newabc_p8.loc[df_genes_newabc_p8['archetype'] == arch, 'gene'].unique())
    for arch in ('A', 'B', 'C')
)
p14_agenes, p14_bgenes, p14_cgenes = (
    np.intersect1d(genes, df_genes_newabc_p14.loc[df_genes_newabc_p14['archetype'] == arch, 'gene'].unique())
    for arch in ('A', 'B', 'C')
)
p21_agenes, p21_bgenes, p21_cgenes = (
    np.intersect1d(genes, df_genes_newabc_p21.loc[df_genes_newabc_p21['archetype'] == arch, 'gene'].unique())
    for arch in ('A', 'B', 'C')
)

# positions of those genes within `genes`
p8_agenes_idx, p8_bgenes_idx, p8_cgenes_idx = (
    basicu.get_index_from_array(genes, gene_set) for gene_set in (p8_agenes, p8_bgenes, p8_cgenes)
)
p14_agenes_idx, p14_bgenes_idx, p14_cgenes_idx = (
    basicu.get_index_from_array(genes, gene_set) for gene_set in (p14_agenes, p14_bgenes, p14_cgenes)
)
p21_agenes_idx, p21_bgenes_idx, p21_cgenes_idx = (
    basicu.get_index_from_array(genes, gene_set) for gene_set in (p21_agenes, p21_bgenes, p21_cgenes)
)

### MERFISH cells (integrated L2/3)

In [None]:
ddir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized" 

# get MERFISH data (cells) from integrated V1L23Glut.
# `ad.read` is a deprecated alias; use `read_h5ad`. backed='r' opens the file
# read-only without loading X into memory.
fin = os.path.join(ddir, 'P21NRDR_v1l23glut_rna_merfish_250411.h5ad')
adata_ = ad.read_h5ad(fin, backed='r')
adata_ = adata_[adata_.obs['modality'] == 'merfish']
print(adata_.obs['gated_pred_subclass'].unique())  # should be L2/3 only

l23_cells = adata_.obs.index.values
l23_cells.shape

### MERFISH raw - all genes

In [None]:
# %%time
names = [
    'P21NRb_ant',
    'P21NRb_pos',
    
    'P21NRc_ant',
    'P21NRc_ant2',
    'P21NRc_pos2',
    
    'P21DRa_ant', 
    'P21DRa_pos',
    
    'P21DRb_ant', 
    'P21DRb_pos',
]

# target mean per-cell total (over non-IEG genes) after per-section rescaling
mean_total_rna_target = 500
adata_merged = []

for k, name in enumerate(names):
    # display grid for *_show coordinates: 4 sections stacked per column (i = row, j = column)
    j, i = divmod(k, 4)

    # `ad.read` is a deprecated alias of `ad.read_h5ad`
    adatasub = ad.read_h5ad(os.path.join(ddir, f'{name}_l2_v1_250410.h5ad'))
    adatasub = adatasub[adatasub.obs['transcript_count'] > 50].copy()
    print(name, len(adatasub))

    # prefix cell ids with the section name so they stay unique after concat
    adatasub.obs.index = np.char.add(f'{name}', adatasub.obs.index.values)
    adatasub.obs['sample'] = name

    # rescale the section so its mean per-cell total over non-IEG genes equals the target
    # (presumably so activity-dependent IEG expression does not drive the scale factor)
    norm_cnts = adatasub.layers['norm']
    mean_per_batch_noniegs = np.mean(adatasub[:, genes_noniegs].layers['norm'].sum(axis=1))

    adatasub.layers['jnorm'] = norm_cnts * (mean_total_rna_target / mean_per_batch_noniegs)
    adatasub.layers['ljnorm'] = np.log2(1 + adatasub.layers['jnorm'])

    adatasub.obs['norm_transcript_count'] = adatasub.layers['norm'].sum(axis=1)
    adatasub.obs['jnorm_transcript_count'] = adatasub.layers['jnorm'].sum(axis=1)

    # offset each section on the display grid
    adatasub.obs['depth_show'] = -adatasub.obs['depth'].values - i*400
    adatasub.obs['width_show'] = adatasub.obs['width'].values - np.min(adatasub.obs['width'].values) + j*2500

    adata_merged.append(adatasub)

adata_merged = ad.concat(adata_merged)

In [None]:
# cluster -> colour for the three L2/3 types (unassigned cells in gray)
clsts_palette2 = {
    'L2/3_A': 'C0',
    'L2/3_B': 'C1',
    'L2/3_C': 'C2',
    'NA': 'gray',
}

from matplotlib.colors import LinearSegmentedColormap

# ABC map: one black -> type-colour colormap per archetype
colors_a, colors_b, colors_c = ([(0.0, 'black'), (1.0, type_color)] for type_color in ('C0', 'C1', 'C2'))
cmap_a, cmap_b, cmap_c = (
    LinearSegmentedColormap.from_list(cmap_name, stops)
    for cmap_name, stops in (('cmap_a', colors_a), ('cmap_b', colors_b), ('cmap_c', colors_c))
)

### MERFISH raw - get high-qual L2/3
- filter out cells using anatomical features (cortical depth)
- filter by transcript counts
- cell density was also checked; its effect was insignificant

In [None]:
# filter by L2/3 label
adata_l23 = adata_merged[l23_cells].copy()

# filter by depth and by counts
conds = np.logical_and(
    adata_l23.obs['depth']    < 450,
    adata_l23.obs['depth']    < 450,
    # adata_l23.obs['transcript_count'] > 50,
)
adata_mer = adata_l23[conds].copy()

# add depth_norm
adata_mer.obs['depth_norm'] = 0.5
for name in names:
    adatasub = adata_mer[adata_mer.obs['sample']==name]
    depths = adatasub.obs['depth'].values
    depth_min = np.percentile(depths,  1)
    depth_max = np.percentile(depths, 99)
    depth_norm = (depths-depth_min)/(depth_max-depth_min)
    adata_mer.obs.loc[adatasub.obs.index, 'depth_norm'] = depth_norm
    
# filter by depth norm
adata_mer = adata_mer[adata_mer.obs['depth_norm'] < 1].copy()

In [None]:
# Depth distributions per section: all L2/3 cells (top), after the depth filters (middle),
# and normalised depth (bottom). Coloured by condition: NR in C1, DR in C2.
cond_colors = {'NR': 'C1', 'DR': 'C2'}
fig, axs = plt.subplots(3, 1, figsize=(1*8, 4*3))
for name in names:
    color = cond_colors['DR'] if 'DR' in name else cond_colors['NR']
    adatasub0 = adata_l23[adata_l23.obs['sample'] == name]
    adatasub = adata_mer[adata_mer.obs['sample'] == name]

    sns.histplot(adatasub0.obs['depth'].values, element='step', fill=False, stat='density', color=color, ax=axs[0])
    sns.histplot(adatasub.obs['depth'].values, element='step', fill=False, stat='density', color=color, ax=axs[1])
    sns.histplot(adatasub.obs['depth_norm'].values, element='step', fill=False, stat='density', color=color, ax=axs[2])

axs[0].set_title('L2/3 cells, before depth filtering')
axs[1].set_title('after depth filtering')
axs[2].set_title('after depth filtering')
axs[0].set_xlabel('depth')
axs[1].set_xlabel('depth')
axs[2].set_xlabel('depth_norm')

# proxy artists so the legend shows one entry per condition
for cond, color in cond_colors.items():
    axs[0].plot([], [], color=color, label=cond)
axs[0].legend()
fig.tight_layout()
plt.show()

### Add modality label and per-gene z-scores

In [None]:
# add cols
adata_mer.obs['modality'] = 'merfish'
lognorm_mer = adata_mer.layers['ljnorm']  
adata_mer.layers['zscore'] = zscore(lognorm_mer, axis=0)
adata_mer

# Downstream calc

In [None]:
# Cut a strip of width `window_size` centred on each selected section and lay the
# strips side by side (separated by `gap`) in a combined plotting frame.
selected_samples = ['P21NRb_ant', 'P21NRc_ant', 'P21DRa_ant', 'P21DRb_ant']
window_size = 1000
gap = 50

adata_mersel = adata_mer[adata_mer.obs['sample'].isin(selected_samples)].copy()

# per-sample width extent and its midpoint
width_min, width_max = (
    adata_mersel.obs.groupby('sample')['width'].agg(stat).reindex(selected_samples)
    for stat in ('min', 'max')
)
width_rng = width_max - width_min
width_mid = width_min + 0.5 * width_rng

# horizontal offset of each sample's strip in the combined frame
width_cum = pd.Series((window_size + gap) * np.arange(len(selected_samples)), index=selected_samples)

# keep cells within +/- window_size/2 of the midpoint, then shift into the combined frame
adata_mersel.obs['width_mid0'] = adata_mersel.obs['width'] - width_mid.reindex(adata_mersel.obs['sample']).values
adata_mersel = adata_mersel[np.abs(adata_mersel.obs['width_mid0']) < window_size/2].copy()
adata_mersel.obs['width_show2'] = adata_mersel.obs['width_mid0'] + width_cum.reindex(adata_mersel.obs['sample']).values

# re-zero depth at each sample's shallowest cell and negate it (deeper cells plot lower); no cells are removed
depth_min = adata_mersel.obs.groupby('sample')['depth'].min().reindex(selected_samples)
adata_mersel.obs['depth_show2'] = -(adata_mersel.obs['depth'] - depth_min.reindex(adata_mersel.obs['sample']).values)


In [None]:
# Alternative gene lists explored for the spatial strips (kept for reference):
# gns = ['Meis2', 'Epha6', 'Pcdh19', 'Astn2', 'Kcnh5', 'Chrm2', 'Sorcs3', 'Nptx2', 'Phf21b'] 
# gns = ['Syt10', 'Bdnf', 'Kcnn3'] 
# gns = ['Sox5', 'Npas4', 'Cacna2d3', 'Scg3', 'Phf21b'] 
# gns = ['Tox', 'Pcdh7', 'Grik4', 'Adamts2'] 
# gns = ['Nfib', 'Shisa9', 'Tenm1', 'Magi1', 'Met', ] 
# gns = ['Sema6d', 'Scn1a']
# gns = ['Robo3', 'Robo1', 'Cntn4', 'Sdk1', 'Chrm3', 'Chrm2']# 'Epha10', 'Foxp1']
# gns = ['Col19a1', 'Kcnq5', 'Ntn4', 'Nxph2']#, 'Robo1', 'Cntn4', 'Sdk1', 'Chrm3', 'Chrm2']# 'Epha10', 'Foxp1']

gns = ['Meis2', 'Epha6', 'Astn2', 'Kcnh5', 'Nptx2', 'Phf21b', 'Fos', 'Junb', 'Nr4a2', 'Egr4'] 
# gns = ['Fos', 'Nr4a2', ]

print([g for g in gns if g not in genes])  # genes missing from the MERFISH panel
# plain NumPy arrays: indexing a pandas Series with an integer array (x[sorting]) is
# positional only via a deprecated fallback (FutureWarning since pandas 2.1)
x = adata_mersel.obs['width_show2'].to_numpy()
y = adata_mersel.obs['depth_show2'].to_numpy()
n = len(gns)

# squeeze=False keeps `axs` iterable even when a single gene is plotted
fig, axs = plt.subplots(n, 1, figsize=(1*20, n*1), squeeze=False)
axs = axs[:, 0]
for i, (ax, gn) in enumerate(zip(axs, gns)):
    if i == 0:
        # sample labels above the first row
        for lbl, coord in width_cum.items():
            ax.text(coord, 100, lbl, fontsize=12, ha='center')

    g = adata_mersel[:, gn].layers['ljnorm'].reshape(-1,)
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g,  0)
    # draw the highest-expressing cells last so they sit on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=5, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

# fig_manager.savefig(fig)
plt.show()
    

In [None]:
# display the P21 A-gene set (restricted to the MERFISH panel) and the IEG set
p21_agenes, igenes

In [None]:
# Per-cell gene-group scores along the strips: each gene is scaled by its 98th
# percentile across cells, then averaged over the group's genes.
gn_groups = [
    p21_agenes,
    p21_bgenes,
    p21_cgenes,
    igenes,
]
titles = ["A genes", "B genes", "C genes", "IEGs"]

# plain arrays so x[sorting] is unambiguously positional
x = adata_mersel.obs['width_show2'].to_numpy()
y = adata_mersel.obs['depth_show2'].to_numpy()
n = len(gn_groups)

fig, axs = plt.subplots(n, 1, figsize=(1*20, n*1), squeeze=False)
axs = axs[:, 0]
for i, (ax, gn_group) in enumerate(zip(axs, gn_groups)):
    title = titles[i]

    if i == 0:
        for lbl, coord in width_cum.items():
            ax.text(coord, 100, lbl, fontsize=12, ha='center')

    gmat = adata_mersel[:, gn_group].layers['ljnorm']
    gmat_max = np.percentile(gmat, 98, axis=0)
    # a gene whose 98th percentile is 0 (e.g. a sparse IEG) gives 0/0 = NaN or x/0 = inf,
    # which would poison every cell's mean; leave such genes out of the score
    informative = gmat_max > 0
    g = np.mean(gmat[:, informative] / gmat_max[informative], axis=1)

    vmax = np.percentile(g, 90)
    vmin = np.percentile(g, 10)
    # draw high-scoring cells last so they are not hidden under low-scoring ones
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=8, title='', vmin=vmin, vmax=vmax, cmap='coolwarm')
    ax.set_title(title, loc='left', va='center', ha='right', y=0.5, pad=None)

# fig_manager.savefig(fig)
plt.show()
    

In [None]:
np.shape(gmat)  # (cells, genes) of the last gene group scored in the cell above

# MERFISH depth-binned expression stats

In [None]:
# stats = {}
# bins = np.linspace(0, 1, 4*3+1)
# midpoints = np.mean(np.vstack([bins[:-1], bins[1:]]), axis=0)

# for name in names:
#     adatasub = adata_mer[adata_mer.obs['sample']==name].copy() # v1l23_data[name]
#     lnorm_mean, lnorm_sem, lnorm_std, n, d, db = binning_pipe2(adatasub, 'depth_norm', 'ljnorm', bins=bins)
#     stats[name] = (lnorm_mean, lnorm_sem, lnorm_std, n, d, db)

In [None]:
stats2 = {}
bins = np.linspace(0, 1, 10+1)
midpoints = np.mean(np.vstack([bins[:-1], bins[1:]]), axis=0)
print(names)

# per sample: mean / SEM / SD of 'jnorm' expression in 10 equal bins of depth_norm
for name in names:
    adatasub = adata_mer[adata_mer.obs['sample'] == name].copy()
    norm_mean, norm_sem, norm_std, n, d, db = binning_pipe2(adatasub, 'depth_norm', 'jnorm', bins=bins)
    stats2[name] = (norm_mean, norm_sem, norm_std, n, d, db)

# each gene to its max (across bins and samples); sized from the data instead of a
# hard-coded 500 genes, which broke for any other panel size
norm_mean_max = np.zeros(stats2[names[0]][0].shape[1])
for name in names:
    (norm_mean, norm_sem, norm_std, n, d, db) = stats2[name]
    norm_mean_max = np.maximum(norm_mean_max, norm_mean.max().values)

# norm to fraction of max
for name in names:
    (norm_mean, norm_sem, norm_std, n, d, db) = stats2[name]
    norm_mean = norm_mean / norm_mean_max
    norm_sem = norm_sem / norm_mean_max
    norm_std = norm_std / norm_mean_max
    stats2[name] = (norm_mean, norm_sem, norm_std, n, d, db)

# select P21 genes

In [None]:
def get_gene_group_matrix(stats, names, gene_group_idx_list):
    """Collapse per-gene depth profiles into per-gene-group profiles.

    Parameters
    ----------
    stats : dict
        ``stats[name]`` is a 6-tuple whose first two entries are (bins x genes)
        DataFrames of the binned mean and SEM (as returned by ``binning_pipe2``).
    names : list
        Sample names to include, in output order.
    gene_group_idx_list : list of integer index arrays
        Column positions of each gene group.

    Returns
    -------
    mean_mat, sem_mat : ndarray of shape (n_samples, n_groups, n_bins)
        For each group, the per-bin average over its genes (NaNs skipped) of the
        binned mean and of the binned SEM respectively.
    """
    def _group_profiles(frame):
        # one (n_bins,) vector per gene group: mean across that group's columns
        return np.array([frame.iloc[:, gidx].mean(axis=1) for gidx in gene_group_idx_list])

    per_sample_means, per_sample_sems = [], []
    for name in names:
        norm_mean, norm_sem, _std, _n, _d, _db = stats[name]
        per_sample_means.append(_group_profiles(norm_mean))
        per_sample_sems.append(_group_profiles(norm_sem))

    mean_mat = np.array(per_sample_means)
    sem_mat = np.array(per_sample_sems)
    print(mean_mat.shape)  # sample, gene group, depth

    return mean_mat, sem_mat

In [None]:
# mean expression level across V1 L2/3 in NR
mean_mat_p21, sem_mat_p21 = get_gene_group_matrix(stats2, names, 
                                                   [p21_agenes_idx,
                                                    p21_bgenes_idx,
                                                    p21_cgenes_idx,
                                                    igenes_idx,
                                                   ])

data_mean = [
    np.mean(mean_mat_p21[ :5], axis=0),
    np.mean(mean_mat_p21[5: ], axis=0),
    ]
    
data_sem  = [
    np.std(mean_mat_p21[:5], axis=0)/np.sqrt(5),
    np.std(mean_mat_p21[5:], axis=0)/np.sqrt(4),
]

In [None]:
titles = ['P21NR', 'P21DR']
gnames = ['A genes', 'B genes', 'C genes', ]
colors = ['C0', 'C1', 'C2', ]
linestyle = '-'
x = midpoints

fig, axs = plt.subplots(1, 2, figsize=(8.5,4*1), sharex=True, sharey=True)

# left: NR average, right: DR average; one line + SEM band per archetype gene group
for ax, title, ymean, yerr in zip(axs, titles, data_mean, data_sem):
    for gname, color, mu, se in zip(gnames, colors, ymean, yerr):
        ax.plot(x, mu, label=gname, color=color, linestyle=linestyle)
        ax.fill_between(x, mu - se, mu + se, color=color, alpha=0.1, edgecolor='none')

    sns.despine(ax=ax)
    ax.grid(False)
    ax.set_title(title)
    ax.set_xlabel('top->bottom L2/3')

axs[1].legend(bbox_to_anchor=(1,1))
axs[0].set_xticks([0, 0.5, 1])
axs[0].set_ylabel('frac. max expr.')
fig.tight_layout()
fig_manager.savefig(fig)
plt.show()

In [None]:
# per-section depth profiles of the A/B/C gene groups; bands are the gene-averaged per-bin SEM
fig, axs = plt.subplots(2, 5, figsize=(5*5,4*2), sharex=True, sharey=True)
linestyle = '-'
x = midpoints
abc_labels = ['A genes', 'B genes', 'C genes']
for i, (ax, name) in enumerate(zip(axs.flat, names)):

    persamp_mat_mean = mean_mat_p21[i]
    persamp_mat_sem  = sem_mat_p21[i]

    # explicit colours (C0/C1/C2 = A/B/C) instead of relying on the default cycle
    for k, (row1, row2) in enumerate(zip(persamp_mat_mean[:3], persamp_mat_sem[:3])):
        ax.plot(x, row1, linestyle=linestyle, color=f'C{k}', label=abc_labels[k])
        ax.fill_between(x, row1-row2, row1+row2, alpha=0.1, edgecolor='none', color=f'C{k}')

    sns.despine(ax=ax)
    ax.grid(False)
    ax.set_title(name)

# the grid has more panels than sections: hide the spares and restore x tick labels on
# the panel above each one (with sharex only the bottom row shows tick labels)
ncols = axs.shape[1]
for idx in range(len(names), axs.size):
    axs.flat[idx].set_visible(False)
    if idx - ncols >= 0:
        axs.flat[idx - ncols].tick_params(labelbottom=True)

axs.flat[0].set_ylabel('frac of max expr (+/- sem)')
axs.flat[0].legend()
fig.subplots_adjust(wspace=0.1)

# fig_manager.savefig(fig)
plt.show()

In [None]:
# display the NR/DR section-averaged profiles: a list of two (gene group x depth bin) arrays
data_mean

In [None]:
titles = ['P21NR', 'P21DR']
gnames = ['B genes', 'IEGs', ]
colors = ['C1', 'C3', ]
channels = [1,3]   # gene-group rows of data_mean[j]: 1 = B genes, 3 = IEGs

fig, axs = plt.subplots(1, 2, figsize=(4*2,4*1), sharex=True, sharey=True)

linestyle = '-'
x = midpoints

for ax, title, ymean, yerr in zip(axs, titles, data_mean, data_sem):
    for gname, color, channel in zip(gnames, colors, channels):
        ax.plot(x, ymean[channel], label=gname, color=color, linestyle=linestyle)
        ax.fill_between(x, ymean[channel]-yerr[channel], ymean[channel]+yerr[channel], color=color, alpha=0.1, edgecolor='none')

    sns.despine(ax=ax)
    ax.grid(False)
    ax.set_title(title)
    ax.set_xlabel('top->bottom L2/3')

# the lines are labelled; draw the legend (as in the A/B/C figure) so B genes and IEGs are identifiable
axs[1].legend(bbox_to_anchor=(1,1))
axs[0].set_xticks([0, 0.5, 1])
axs[0].set_ylabel('fraction of max expression')
fig.tight_layout()
fig_manager.savefig(fig)
plt.show()

In [None]:
# per-section IEG depth profiles (last gene group) with +/- SEM band
fig, axs = plt.subplots(2, 5, figsize=(3*5,3*2), sharex=True, sharey=True)
linestyle = '-'
x = midpoints
for i, (ax, name) in enumerate(zip(axs.flat, names)):

    persamp_mat_mean = mean_mat_p21[i]
    persamp_mat_sem  = sem_mat_p21[i]

    row1 = persamp_mat_mean[-1]
    row2 = persamp_mat_sem[-1]
    ax.plot(x, row1, linestyle=linestyle, color='C3')
    ax.fill_between(x, row1-row2, row1+row2, alpha=0.1, edgecolor='none', color='C3')

    sns.despine(ax=ax)
    ax.grid(False)
    ax.set_title(name)

# the grid has more panels than sections: hide the spares and restore x tick labels on
# the panel above each one (with sharex only the bottom row shows tick labels)
ncols = axs.shape[1]
for idx in range(len(names), axs.size):
    axs.flat[idx].set_visible(False)
    if idx - ncols >= 0:
        axs.flat[idx - ncols].tick_params(labelbottom=True)

axs.flat[0].set_xticks([0, 0.5, 1])
axs.flat[0].set_ylabel('frac. max expr.')
axs.flat[0].set_xlabel('top->bottom L2/3')
fig.tight_layout()

# fig_manager.savefig(fig)
plt.show()

In [None]:
def per_channel_plot(channel):
    """Overlay every section's depth profile of one gene group: NR (left) vs DR (right).

    Parameters
    ----------
    channel : int
        Gene-group index: row of ``mean_mat_p21[i]`` / ``sem_mat_p21[i]`` and
        index into ``gnames`` / ``colors``.

    Notes
    -----
    Reads the notebook globals ``names``, ``midpoints``, ``mean_mat_p21``,
    ``sem_mat_p21``, ``gnames`` and ``colors``.
    """
    fig, axs = plt.subplots(1, 2, figsize=(8.5,4*1), sharex=True, sharey=True)
    ax_nr, ax_dr = axs
    x = midpoints
    color = colors[channel]
    labelled = False

    for i, name in enumerate(names):
        # panel chosen from the sample name rather than its position in `names`
        ax = ax_dr if 'DR' in name else ax_nr

        row1 = mean_mat_p21[i][channel]
        row2 = sem_mat_p21[i][channel]

        # label only the first DR line so the legend (drawn on the DR panel) has one entry
        label = ''
        if ax is ax_dr and not labelled:
            label = gnames[channel]
            labelled = True

        ax.plot(x, row1, linestyle='-', color=color, label=label)
        ax.fill_between(x, row1-row2, row1+row2, alpha=0.1, edgecolor='none', color=color)

    for ax, title in zip(axs, ['P21NR', 'P21DR']):
        sns.despine(ax=ax)
        ax.grid(False)
        ax.set_title(title)
        ax.set_xlabel('top->bottom L2/3')

    ax_nr.set_xticks([0, 0.5, 1])
    ax_dr.legend(bbox_to_anchor=(1,1))
    ax_nr.set_ylabel('frac. max expr.')
    fig.tight_layout()
    # fig_manager.savefig(fig)
    plt.show()

In [None]:
gnames = ['A genes', 'B genes', 'C genes', 'IEGs']
colors = ['C0', 'C1', 'C2', 'C3']
# one NR-vs-DR figure per gene group (per_channel_plot reads gnames and colors)
for channel in range(len(gnames)):
    per_channel_plot(channel)