In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

import anndata as ad
import scanpy as sc

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from umap import UMAP


from scroutines import basicu
from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
from merfish_datasets import merfish_datasets
from merfish_genesets import get_all_genesets

import importlib
importlib.reload(powerplots)
importlib.reload(utils_merfish)

In [None]:
# Seed NumPy's global (legacy) RNG so np.random-based steps are reproducible.
np.random.seed(seed=0)

In [None]:
def binning_pipe(adata, n=20, layer='lnorm', bin_type='depth_bin'):
    """Bin cells along depth and width and summarize expression per bin.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obs['depth']``, ``obs['width']`` and ``layers[layer]``
        (assumed to be a dense cells x genes array -- TODO confirm).
    n : int
        Number of bins, forwarded to ``utils_merfish.binning``.
    layer : str
        Name of the layer whose values are summarized.
    bin_type : {'depth_bin', 'width_bin'}
        Which binning the summary statistics are grouped by.

    Returns
    -------
    norm_mean, norm_sem, norm_std : pd.DataFrame
        Per-bin mean / SEM / std of every gene (rows = bins, sorted).
    norm_n : pd.Series
        Number of cells per bin, indexed exactly like ``norm_mean``.
    depth_binned, width_binned
        Per-cell bin labels from ``utils_merfish.binning``.
    depth_bins, width_bins
        Bin definitions from ``utils_merfish.binning``.
    """
    assert bin_type in ['depth_bin', 'width_bin'], f"unknown bin_type: {bin_type}"

    # bin along both axes; both binnings are returned to the caller
    depth_bins, depth_binned = utils_merfish.binning(adata.obs['depth'].values, n)
    width_bins, width_binned = utils_merfish.binning(adata.obs['width'].values, n)

    norm_ = pd.DataFrame(adata.layers[layer], columns=adata.var.index)
    norm_['depth_bin'] = depth_binned
    norm_['width_bin'] = width_binned

    # Drop the bin column we are NOT grouping by; otherwise its labels would be
    # averaged per group as if they were a gene.
    other_bin = 'width_bin' if bin_type == 'depth_bin' else 'depth_bin'
    grouped = norm_.drop(columns=other_bin).groupby(bin_type)

    norm_mean = grouped.mean(numeric_only=True)
    norm_sem = grouped.sem(numeric_only=True)
    norm_std = grouped.std(numeric_only=True)
    # groupby().size() shares the sorted group index with the stats above;
    # value_counts(sort=False) did not guarantee that ordering.
    norm_n = grouped.size()

    return norm_mean, norm_sem, norm_std, norm_n, depth_binned, width_binned, depth_bins, width_bins

In [None]:
genesets = get_all_genesets()
genesets

In [None]:
directories = merfish_datasets
print(merfish_datasets)

ddir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized" 
fout = os.path.join(ddir, 'P28NRDR_v1l23_merged_240506.h5ad')
!ls $ddir/*240411.h5ad 

In [None]:
%%time
names = [
    'P28NR_ant', 
    'P28NR_pos',
    
    'P28NRb_ant', 
    'P28NRb_pos',
    
    'P28DR_ant', 
    'P28DR_pos',
    
    'P28DRb_ant', 
    'P28DRb_pos',
]

alldata = {}
for name in names:
    if 'b' not in name:
        alldata[name] = ad.read(os.path.join(ddir, f'{name}_ctxglut_240411.h5ad'))
    else:
        alldata[name] = ad.read(os.path.join(ddir, f'{name}_ctxglut_240429.h5ad'))
    print(name, len(alldata[name]))
    


In [None]:
# Spatial (x, y) maps of a few marker genes for every loaded sample.
gns = ['Scnn1a', 'Rorb', 'Igfbp4', 'Fos', 'Sorcs3']
n = len(gns)
for name, adata in alldata.items():
    print(name)
    x = adata.obs['x']
    y = adata.obs['y']

    fig, axs = plt.subplots(1, n, figsize=(n * 6, 5))
    for ax, gn in zip(axs, gns):
        # flatten the single-gene slice to 1-D, as the other plotting cells do
        g = adata[:, gn].layers['norm'].reshape(-1)
        utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=1, title=gn)
    plt.show()

# Visualization: subset to V1 L2/3 and plot marker genes

In [None]:
# Keep only cells flagged by obs['inside_v1l23']. `.copy()` turns each subset
# into an independent AnnData, so later cells that add layers (e.g. 'jnorm')
# do not write into a view of `alldata`.
v1l23_data = {}
for name in names:
    adata = alldata[name]
    adatasub = adata[adata.obs['inside_v1l23']].copy()
    v1l23_data[name] = adatasub

In [None]:
# Width-vs-(negated) depth strips of selected genes for one full sample,
# with the stored left/right/depth bounds drawn as dashed lines.
gns = ['Scnn1a', 'Rorb', 'Igfbp4', 'Whrn', 'Fos', 'Cdh13', 'Sorcs3', 'Chrm2']

name = 'P28DRb_pos'
adata = alldata[name]
bound_l, bound_r, bound_d = adata.uns['bound_lrd']
x, y = adata.obs['width'], -adata.obs['depth']
n = len(gns)

fig, axs = plt.subplots(n, 1, figsize=(10, 1.5 * n))
fig.suptitle(name)
for gn, ax in zip(gns, axs):
    g = np.log2(1 + adata[:, gn].layers['norm']).reshape(-1)
    vmin, vmax = np.percentile(g, [5, 99])
    cond = g > 1e-5

    # non-expressing cells are drawn first, expressing cells second (on top)
    for mask in (~cond, cond):
        utils_merfish.st_scatter_ax(fig, ax, x[mask], y[mask], gexp=g[mask], s=2, title='',
                                    vmin=vmin, vmax=vmax, cmap='rocket_r')

    ax.set_title(gn, loc='left', ha='right', y=0.5)
    ax.axhline(-bound_d, linestyle='--', linewidth=1, color='gray', zorder=2)
    for x_bound in (bound_l, bound_r):
        ax.axvline(x_bound, linestyle='--', linewidth=1, color='gray', zorder=2)

plt.show()

In [None]:
# Per-sample x/y maps of selected genes for the V1 L2/3 subsets; also records
# cells per unit of width range for each subset in `densities`.
gns = ['Fos', 'Cdh13', 'Sorcs3', 'Chrm2']
n = len(gns)
densities = []
for name in names:
    adata = v1l23_data[name]
    x, y = adata.obs['x'], adata.obs['y']

    width_range = adata.obs['width'].max() - adata.obs['width'].min()
    density = len(adata) / width_range
    print(name, len(adata), width_range, density)
    densities.append(density)

    fig, axs = plt.subplots(1, n, figsize=(6 * n, 5))
    for gn, ax in zip(gns, axs):
        g = np.log2(1 + adata[:, gn].layers['norm'])
        utils_merfish.st_scatter_ax(fig, ax, x, y, gexp=g, s=5, title=gn)

    plt.show()

In [None]:
# Width-vs-(negated) depth strips of selected genes for every V1 L2/3 subset.
gns = ['Cdh13', 'Sorcs3', 'Chrm2', 'Fos']
n = len(gns)
for name in names:
    adatasub = v1l23_data[name]
    x, y = adatasub.obs['width'], -adatasub.obs['depth']

    fig, axs = plt.subplots(n, 1, figsize=(10, 1.5 * n))
    fig.suptitle(name)
    for gn, ax in zip(gns, axs):
        g = np.log2(1 + adatasub[:, gn].layers['norm']).reshape(-1)
        vmin, vmax = np.percentile(g, [5, 99])
        cond = g > 1e-5

        # non-expressing cells first, expressing cells second (on top)
        for mask in (~cond, cond):
            utils_merfish.st_scatter_ax(fig, ax, x[mask], y[mask], gexp=g[mask], s=4, title='',
                                        vmin=vmin, vmax=vmax, cmap='rocket_r')

        ax.set_title(gn, loc='left', ha='right', y=0.5)

    plt.show()

# Joint clustering of all samples
- gene sets: A, B and C genes (`abcgenes`)

In [None]:
# Unpack the named gene sets; `abcgenes` concatenates the A, B and C sets.
agenes, bgenes, cgenes = genesets['a'], genesets['b'], genesets['c']
iegs = genesets['i']
up_agenes = genesets['a_up']
abcgenes = np.hstack((agenes, bgenes, cgenes))
len(abcgenes)

In [None]:
# Column indices of each gene set within the gene panel. Use an explicit
# sample instead of whatever `adatasub` a previous loop happened to leave behind.
ref_genes = v1l23_data[names[-1]].var.index.values
agenes_idx = basicu.get_index_from_array(ref_genes, agenes)
bgenes_idx = basicu.get_index_from_array(ref_genes, bgenes)
cgenes_idx = basicu.get_index_from_array(ref_genes, cgenes)
igenes_idx = basicu.get_index_from_array(ref_genes, iegs)

In [None]:
# Cell-volume distribution of each V1 L2/3 subset, one labelled curve per sample.
fig, ax = plt.subplots()
for name in names:
    adatasub = v1l23_data[name]
    vol = adatasub.obs['volume']
    sns.histplot(vol, bins=np.linspace(0, 1000, 50),
                 element='step', fill=False, stat='percent',
                 label=name, ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.show()
    

In [None]:
# Transcript-count distribution of each V1 L2/3 subset, one curve per sample.
fig, ax = plt.subplots()
for name in names:
    adatasub = v1l23_data[name]
    metric = adatasub.obs['transcript_count']
    sns.histplot(metric, bins=np.linspace(0, 500, 20),
                 element='step', fill=False, stat='percent',
                 label=name, ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.show()

# Per-cell total of the 'norm' layer; the legend shows each sample's mean total.
fig, ax = plt.subplots()
for name in names:
    adatasub = v1l23_data[name]
    b = adatasub.layers['norm'].sum(axis=1)
    m = np.mean(b)
    sns.histplot(b, bins=np.linspace(0, 500, 20),
                 element='step', fill=False, stat='percent',
                 label=f'{name} {m:.1f}', ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.show()
    

In [None]:
# Rescale each sample's 'norm' layer so its mean per-cell total equals
# mean_total_rna_target, storing the result in place as a 'jnorm' layer on
# each v1l23_data entry. The legend shows each sample's mean total
# before -> after rescaling; "after" should equal the target.
mean_total_rna_target = 250
fig, ax = plt.subplots()
for name in names:
    adatasub = v1l23_data[name]
    b = adatasub.layers['norm'].sum(axis=1)
    m = np.mean(b)
    adatasub.layers['jnorm'] = adatasub.layers['norm'] * (mean_total_rna_target / m)
    b2 = adatasub.layers['jnorm'].sum(axis=1)
    m2 = np.mean(b2)
    sns.histplot(b2, bins=np.linspace(0, 500, 50),
                 element='step', fill=False, stat='percent',
                 label=f'{name} {m:.1f} -> {m2:.1f}', ax=ax)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.show()
    

In [None]:
# Per-sample rescaling (mean per-cell 'norm' total -> mean_total_rna_target)
# into 'jnorm', then log2(1 + x) into 'ljnorm'; both layers are written in place.
mean_total_rna_target = 250
for name in names:
    adatasub = v1l23_data[name]
    mean_per_batch = np.mean(adatasub.layers['norm'].sum(axis=1))
    scaled = adatasub.layers['norm'] * (mean_total_rna_target / mean_per_batch)
    adatasub.layers['jnorm'] = scaled
    adatasub.layers['ljnorm'] = np.log2(1 + scaled)


In [None]:
# Merge the V1 L2/3 subsets of all samples into one AnnData for joint clustering.
mean_total_rna_target = 250
min_total_jnorm = 70  # cells with a rescaled total <= this are dropped

adata_list = []
for i, name in enumerate(names):
    adatasub = v1l23_data[name].copy()
    if i == 0:
        genes = adatasub.var.index.values
    else:
        # every sample must share the same gene panel, in the same order
        assert np.array_equal(genes, adatasub.var.index.values), f"gene panel mismatch in {name}"

    # norm: rescale so this sample's mean per-cell total equals the target
    mean_per_batch = np.mean(adatasub.layers['norm'].sum(axis=1))
    adatasub.layers['jnorm'] = adatasub.layers['norm'] * (mean_total_rna_target / mean_per_batch)

    # filter low-count cells; .copy() so the writes below hit a real object, not a view
    n0 = len(adatasub)
    # adatasub = adatasub[adatasub.obs['transcript_count'] > 30]
    adatasub = adatasub[adatasub.layers['jnorm'].sum(axis=1) > min_total_jnorm].copy()
    n1 = len(adatasub)
    print(name, f'{n1/n0*100:.1f}')  # percent of cells kept

    adatasub.layers['ljnorm'] = np.log2(1 + adatasub.layers['jnorm'])

    # make cell ids unique across samples and record provenance
    adatasub.obs.index = adatasub.obs.index + '_' + name
    adatasub.obs['sample'] = name
    # display coordinates: each sample shifted down by 500 so samples stack
    adatasub.obs['depth_show'] = -adatasub.obs['depth'].values - i * 500
    adatasub.obs['width_show'] = adatasub.obs['width'].values - np.min(adatasub.obs['width'].values)
    adata_list.append(adatasub)
    print(adatasub.shape)

adata_merged = ad.concat(adata_list)
adata_merged_abcgenes = adata_merged[:, abcgenes].copy()

adata_merged, adata_merged_abcgenes

In [None]:
# Joint embedding of all samples: 10-component PCA of the 'ljnorm' layer,
# a 2-D UMAP of those PCs, and a 30-neighbour graph on the same PCs.
adata = adata_merged

pca = PCA(n_components=10)
pcs = pca.fit_transform(adata.layers['ljnorm'])
adata.obsm['pca'] = pcs

ucs = UMAP(n_components=2, n_neighbors=30, random_state=0).fit_transform(pcs)
adata.obsm['umap'] = ucs

sc.pp.neighbors(adata, n_neighbors=30, use_rep='pca', random_state=0)



In [None]:
# Leiden clustering on the neighbour graph; labels go to obs['leiden_r<r>'].
r = 0.4
cluster_key = f'leiden_r{r}'
sc.tl.leiden(adata, resolution=r, key_added=cluster_key, random_state=0, n_iterations=1)

In [None]:
# Slc17a7 on PC1/PC2 and on the UMAP: first on a log2(1 + x) scale, then on
# the raw 'norm' scale.
gn = 'Slc17a7'
expr = adata[:, gn].layers['norm'].reshape(-1,)
for g in (np.log2(1 + expr), expr):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    utils_merfish.st_scatter_ax(fig, ax1, pcs[:, 0], pcs[:, 1], gexp=g)
    utils_merfish.st_scatter_ax(fig, ax2, ucs[:, 0], ucs[:, 1], gexp=g)
    plt.show()

In [None]:

# Leiden clusters, then sample of origin, on the stacked tissue layout and UMAP.
xr, yr = adata.obs['width_show'], adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:, 0], adata.obsm['umap'][:, 1]

clsts = adata.obs[f'leiden_r{r}'].astype(int)
utils_merfish.plot_cluster(clsts, xr, yr, ux, uy, s=2)

samples, uniq_labels = pd.factorize(adata.obs['sample'])
utils_merfish.plot_cluster(samples, xr, yr, ux, uy, s=2)

In [None]:
# One plot per Leiden cluster, colored by a boolean membership mask.
clsts = adata.obs[f'leiden_r{r}'].astype(int)
uniq_clsts = np.unique(clsts)

# coordinates do not depend on the cluster, so fetch them once
xr, yr = adata.obs['width_show'], adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:, 0], adata.obsm['umap'][:, 1]

for clst in uniq_clsts:
    show = clsts == clst
    utils_merfish.plot_cluster(show, xr, yr, ux, uy, s=2, cmap=plt.cm.copper_r, suptitle=clst)

# Identify major cell populations using marker genes, gene groups, and quality metrics

In [None]:
# UMAP colored by log2(1 + norm) expression of each marker gene, on a grid
# with nx columns.
marker_genes = [
       'Ptprn', 'Slc17a7', 'Gad1', 'Fos',

       'Gfap', 'Slc6a13', 'Slc47a1',
       'Grin2c', 'Aqp4', 'Rfx4', 'Sox21', 'Slc1a3',

       'Sox10', 'Pdgfra', 'Mog',

       'Pecam1', 'Cd34', 'Tnfrsf12a', 'Sema3c',
       'Zfhx3', 'Pag1', 'Slco2b1', 'Cx3cr1',
      ]
gns = marker_genes
n = len(gns)
nx = 4
ny = (n + nx - 1) // nx  # ceil(n / nx)
# squeeze=False keeps `axs` 2-D, so `.flat` works for any grid size
fig, axs = plt.subplots(ny, nx, figsize=(nx*5, ny*4), squeeze=False)
for gn, ax in zip(gns, axs.flat):
    g = np.log2(1 + adata[:, gn].layers['norm'].reshape(-1,))
    utils_merfish.st_scatter_ax(fig, ax, ucs[:, 0], ucs[:, 1], gexp=g)
    ax.set_title(gn)
# hide grid cells left over when n is not a multiple of nx (23 genes -> 1 empty)
for ax in axs.flat[n:]:
    ax.axis('off')
plt.show()



In [None]:
# UMAP colored by a low-count flag: 1 where the per-cell 'jnorm' total is < 80.
# Alternative flags tried: transcript_count < 50, volume < 60.
low_count_flag = (adata.layers['jnorm'].sum(axis=1) < 80).astype(int)
g = low_count_flag

fig, ax = plt.subplots()
p = utils_merfish.st_scatter_ax(fig, ax, ucs[:, 0], ucs[:, 1], gexp=g, s=3)
fig.colorbar(p)
ax.set_title('')
plt.show()

In [None]:
# UMAP colored by each per-cell QC / position metric; 'sample' is shown as
# integer codes from pd.factorize.
metrics = [
    'volume', 'anisotropy', 'perimeter_area_ratio', 'solidity',
    'PolyT_raw', 'PolyT_high_pass', 'DAPI_raw', 'DAPI_high_pass',
    'transcript_count', 'gncov', 'gnnum', 'fpcov',
    'depth', 'width', 'sample',
]
n = len(metrics)
nx = 5
ny = (n + nx - 1) // nx
fig, axs = plt.subplots(ny, nx, figsize=(nx * 5, ny * 4))
for metric, ax in zip(metrics, axs.flat):
    if metric != 'sample':
        g = adata.obs[metric].values
    else:
        g, uniq_lbls = pd.factorize(adata.obs[metric].values)
    utils_merfish.st_scatter_ax(fig, ax, ucs[:, 0], ucs[:, 1], gexp=g, s=3)
    ax.set_title(metric)
plt.show()


In [None]:
# Same metric grid as the previous cell, on a log10(1 + value) scale;
# 'sample' is shown as integer codes (no log).
# NOTE(review): log10(1 + x) is NaN for x < -1 -- confirm depth/width are non-negative.
n = len(metrics)
nx = 5
ny = (n + nx - 1) // nx
fig, axs = plt.subplots(ny, nx, figsize=(nx*5, ny*4))
for metric, ax in zip(metrics, axs.flat):
    values = adata.obs[metric].values
    if metric == 'sample':
        # factorize this metric's own labels, not the previous iteration's `g`
        g, uniq_lbls = pd.factorize(values)
    else:
        g = np.log10(1 + values)
    utils_merfish.st_scatter_ax(fig, ax, ucs[:, 0], ucs[:, 1], gexp=g, s=3)
    ax.set_title(metric)
plt.show()

# Broad annotation and save

In [None]:
# Broad label per Leiden cluster, listed in cluster-id order (assumed: entry i
# labels cluster i -- confirm against obs[f'leiden_r{r}']). Only the first three
# entries carry a name so far; the rest are numeric placeholders.
clst_annots = [
    "Glu0", "Glu1", "Glu2",
    "3", "4", "5", "6", "7", "8",
]
adata.uns['clst_annots'] = clst_annots

In [None]:
adata.write(fout)

In [None]:
adata.obs