In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
from scipy import stats
import os

from scipy import spatial
from scipy import sparse
from scipy.interpolate import CubicSpline
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from umap import UMAP

import json

In [None]:
import importlib

from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
importlib.reload(utils_merfish)
from utils_merfish import rot2d, st_scatter, st_scatter_ax, plot_cluster, binning
from utils_merfish import RefLineSegs

from merfish_datasets import merfish_datasets
from merfish_datasets import merfish_datasets_params

In [None]:
def get_qc_metrics(df):
    """Summarize per-cell QC columns for histogram plotting.

    Parameters
    ----------
    df : DataFrame with columns 'volume', 'gncov' and 'gnnum'

    Returns
    -------
    dict keyed by column ('volume', 'gncov', 'gnnum', in that order); each value is
    the tuple (display name, values as ndarray, median, bin edges), where the bin
    edges are 50 evenly spaced points from 0 to 10x the median.
    """
    display_names = {
        'volume': 'cell volume',
        'gncov': 'num transcripts',
        'gnnum': 'num genes',
    }
    summary = {}
    for column, label in display_names.items():
        values = df[column].values
        median = np.median(values)
        edges = np.linspace(0, 10*median, 50)
        summary[column] = (label, values, median, edges)
    return summary

def get_norm_counts(adata, scaling=500):
    """Volume-normalize the count matrix.

    Each cell's counts (rows of adata.X) are divided by that cell's
    adata.obs['volume'] and multiplied by `scaling`, i.e. counts are expressed
    per `scaling` volume units (500 by default).

    Side effect: the result is stored in adata.layers['norm'].

    Returns
    -------
    the normalized matrix (same object as adata.layers['norm'])
    """
    per_cell_volume = adata.obs['volume'].values[:, np.newaxis]
    normalized = adata.X/per_cell_volume*scaling
    adata.layers['norm'] = normalized
    return normalized

In [None]:
def get_largest_spatial_components(adata, k=100, dist_th=80):
    """Find the largest spatially connected group of cells.

    Two cells are linked when one is among the other's `k` nearest neighbors in
    (adata.obs['x'], adata.obs['y']) AND their distance is below `dist_th`
    (same units as x/y). The largest connected component of that graph is kept.

    Parameters
    ----------
    adata : AnnData-like object whose .obs has 'x' and 'y' columns
    k : int
        neighbors queried per cell, counting the cell itself; clamped to the
        number of cells (NearestNeighbors rejects n_neighbors > n_samples)
    dist_th : float
        maximum length of an edge

    Returns
    -------
    indices_selected : ndarray of int
        positional indices (ascending) of the cells in the largest component
    XY : ndarray, shape (n_cells, 2)
        the coordinates used
    """
    # Local import: scipy's graph routine replaces the module-level `nx` (networkx),
    # a name that a later notebook cell rebinds to an int.
    from scipy.sparse.csgraph import connected_components

    XY = adata.obs[['x', 'y']].values
    nc = len(XY)
    if nc == 0:
        return np.empty(0, dtype=int), XY

    k = min(k, nc)

    # kNN
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(XY)
    distances, indices = nbrs.kneighbors(XY)

    # Column 0 is normally the query cell itself; the other k-1 columns are candidate edges.
    dist = distances[:, 1:].reshape(-1,)
    src = np.repeat(indices[:, 0], k-1)
    dst = indices[:, 1:].reshape(-1,)
    keep = dist < dist_th

    # Undirected adjacency over all nc cells (isolated cells form singleton components).
    graph = sparse.coo_matrix(
        (np.ones(int(keep.sum())), (src[keep], dst[keep])), shape=(nc, nc)
    )
    _, labels = connected_components(graph, directed=False)
    largest = np.argmax(np.bincount(labels))
    indices_selected = np.flatnonzero(labels == largest)

    print(f"fraction of cells included: {len(indices_selected)/nc: .2f}" )

    return indices_selected, XY

# Load data and construct the AnnData object

In [None]:
# Seed NumPy's legacy global RNG; any later step that draws from the global NumPy state depends on it
# (and on the order in which cells are executed).
np.random.seed(0)

In [None]:
# Output locations (absolute cluster paths):
#  - outdir: plots directory (by its name; not written to in the cells shown here)
#  - outdatadir: processed .h5ad files (output_l0/l1/l2 are built from it)
outdir     = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results_merfish/plots_250410"
outdatadir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized"

# Pure-Python equivalent of `mkdir -p`: no shell, so no quoting or portability issues.
os.makedirs(outdir, exist_ok=True)
os.makedirs(outdatadir, exist_ok=True)

In [None]:
# Show the dataset lookup tables (both are indexed by dataset `name` in the next cell).
merfish_datasets, merfish_datasets_params

In [None]:
# Dataset to process, its data directory and its per-dataset parameters.
name = 'P14NRa_pos'
dirc, params = merfish_datasets[name], merfish_datasets_params[name]

# Raw inputs: per-cell gene counts and per-cell metadata.
f1, f2 = (
    os.path.join('/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish', dirc, fname)
    for fname in ('cell_by_gene.csv', 'cell_metadata.csv')
)

# Processed outputs, one per level: L0 hemibrain, L1 whole cortex, L2 V1.
output_l0, output_l1, output_l2 = (
    outdatadir + f'/{name}_{level}_250410.h5ad'
    for level in ('l0_hemibrain', 'l1_wholecortex', 'l2_v1')
)

dirc, params

In [None]:
# df1: cell-by-gene counts; df2: per-cell metadata
df1 = pd.read_csv(f1)
df2 = pd.read_csv(f2)

# df1 column layout assumed here: [cell id, genes..., 50 trailing blank probes]
genes = df1.columns[1:-50]
blnks = df1.columns[-50:]
df = df2.join(df1)

print(df1.shape, df2.shape, len(genes), len(blnks))
# both tables must describe the same cells in the same row order
assert np.all(df['cell'] == df['EntityID'])

# per-cell totals: blank-probe counts, gene counts, number of detected genes
df['fpcov'] = df[blnks].sum(axis=1)
df['gncov'] = df[genes].sum(axis=1)
df['gnnum'] = df[genes].gt(0).sum(axis=1)
metacols = np.hstack([df2.columns, 'gncov', 'gnnum', 'fpcov'])

# median and histogram bins for each QC metric
metrics = get_qc_metrics(df)

# genes -> X, blank probes -> obsm['blanks'], metadata + totals -> obs
adata = ad.AnnData(
    X=df[genes].values,
    obs=df[metacols],
    var=pd.DataFrame(index=genes),
    obsm={'blanks': df[blnks].values},
)

In [None]:
# QC metric distributions; the vertical line and its label mark the median.
fig, axs = plt.subplots(1, 3, figsize=(3*5, 1*4))
for ax, (key, metric) in zip(axs, metrics.items()):
    # `label`, not `name`: `name` holds the dataset id chosen in the dataset-selection cell.
    (label, val, medval, bins) = metric

    sns.histplot(val, element='step', bins=bins, ax=ax, stat='percent')
    sns.despine(ax=ax)
    ax.axvline(medval, color='k')
    ax.text(medval, 0, f'{medval:g}')
    ax.set_xlabel(label)
plt.show()

# Pairwise log10(1 + value) scatter of volume, PolyT_raw and transcript count, titled with Pearson r.
factors = np.array(['volume', 'PolyT_raw', 'gncov'])
pairs = [[0, 1], [0, 2], [1, 2]]
fig, axs = plt.subplots(1, 3, figsize=(3*5, 1*4))
for pair, ax in zip(pairs, axs):
    # `fx`/`fy`, not `f0`/`f1`: `f1` is the cell_by_gene.csv path read by the loading cell,
    # and overwriting it breaks re-running that cell.
    fx, fy = factors[pair]
    x, y = np.log10(1+df[fx]), np.log10(1+df[fy])
    r, p = stats.pearsonr(x, y)
    ax.scatter(x, y, s=1)
    ax.set_xlabel(fx)
    ax.set_ylabel(fy)
    ax.set_title(f'r = {r:.2g}')

plt.show()

# Rotate spatial coordinates

In [None]:
# Rotation (degrees) is set by hand for this dataset instead of params['rotation'].
rotation = 90
flip = False  # mirror x after rotating

adata.uns['rotation'] = rotation

# raw stage coordinates -> rotated analysis frame stored as obs['x'] / obs['y']
x, y = adata.obs['center_x'], adata.obs['center_y']
xr, yr = rot2d(x, y, rotation)
if flip:
    xr = -xr
adata.obs['x'], adata.obs['y'] = xr, yr

# before (left) vs after (right), colored by Slc17a7
gn = 'Slc17a7'
g = adata[:, gn].X.reshape(-1,)
fig, axs = plt.subplots(1, 2, figsize=(16, 6))
ax1, ax2 = axs
st_scatter_ax(fig, ax1, x, y, gexp=g, title=gn, axis_off=False)
st_scatter_ax(fig, ax2, xr, yr, gexp=g, title=gn, axis_off=False)
plt.show()

In [None]:
# distribution of blank-probe counts per cell (guides the fpcov threshold below)
sns.histplot(data=adata.obs, x='fpcov')

# Filter 1: basic QC metrics and spatial crop

In [None]:
print(adata.shape)
# Basic QC: plausible segmented volume, few blank-probe counts, enough transcripts,
# plus a crop in the rotated frame (the L0 output is labeled "hemibrain").
cond = np.all([
    adata.obs['volume'] > 50,
    adata.obs['volume'] < 7000,
    adata.obs['fpcov'] < 10, # 5
    adata.obs['gncov'] > 10,
    adata.obs['y'] >  9000,
    adata.obs['x'] > -5000,
], axis=0)

# .copy() materializes the subset. Later cells write obs columns, layers and obsm on
# `adata`, which on a view triggers anndata's implicit view->copy with a warning.
adata = adata[cond].copy()
print(adata.shape)

In [None]:
# Slc17a7 expression on the cells that passed filter 1
gns = ['Slc17a7']
xr, yr = adata.obs['x'], adata.obs['y']
for gn in gns:
    g = adata[:, gn].X.reshape(-1,)
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    st_scatter_ax(fig, ax, xr, yr, gexp=g, title=gn, axis_off=False)
    plt.show()

In [None]:
# one figure per marker gene on the filtered cells, axes hidden
gns = ['Slc17a7', 'Gad1', 'Sox10', 'Slc6a13']
xr, yr = adata.obs['x'], adata.obs['y']
for gn in gns:
    g = adata[:, gn].X.reshape(-1,)
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    st_scatter_ax(fig, ax, xr, yr, gexp=g, title=gn, axis_off=False)
    ax.axis(False)
    plt.show()

# spatial domain analysis

In [None]:
# Re-import utils_merfish so edits to the module are picked up without a kernel restart.
# Names imported earlier via `from utils_merfish import ...` (rot2d, st_scatter_ax, ...) still
# point at the old objects; only `utils_merfish.<name>` lookups see the reloaded code.
importlib.reload(utils_merfish)

In [None]:
# QC metrics recomputed on the filtered cells
metrics = get_qc_metrics(adata.obs)
# volume-normalized counts (also stored in adata.layers['norm'])
nrmcnts = get_norm_counts(adata)

# Divide volume-normalized counts by each cell's raw transcript total (gncov), scale by 250,
# then log10(1 + .) and z-score each gene across cells.
_, cov, medcov, _ = metrics['gncov']
jnorm = adata.layers['norm']/cov.reshape(-1,1)*250
ljnorm = np.log10(1+jnorm)
zljnorm = stats.zscore(ljnorm, axis=0)
# A gene with zero variance across the retained cells gets NaN z-scores, and PCA rejects NaN input.
# Such genes carry no information, so set them to 0.
zljnorm = np.nan_to_num(zljnorm, nan=0.0)
# all genes rather than highly variable genes; seeded like the neighbors/leiden steps
pcs = PCA(n_components=50, random_state=0).fit_transform(zljnorm)

adata.layers['jnorm'] = jnorm
adata.obsm['X_pca'] = pcs

In [None]:
%%time
# this part takes long - minutes
# spatial
adata.obsm['X_xy'] = adata.obs[['x', 'y']].values

# gene neighbors and spatial neighbors
sc.pp.neighbors(adata, use_rep='X_pca', random_state=0)
sc.pp.neighbors(adata, use_rep='X_xy', key_added='xy', random_state=0)

# blending
alpha = 0.5
joint_graph = (1 - alpha) * adata.obsp["connectivities"] + alpha * adata.obsp["xy_connectivities"]
sc.tl.leiden(adata, adjacency=joint_graph, key_added="blended_domains", random_state=0)

In [None]:
n_unq_clsts = len(adata.obs['blended_domains'].unique())
# Up to 10 colors from tab20, the remainder from Set2. Clamp the first count so that fewer
# than 10 domains does not request a negative number of Set2 colors.
n_first = min(10, n_unq_clsts)
clsts_palette, clsts_cmap = utils_merfish.generate_discrete_cmap(
    [n_first, n_unq_clsts-n_first], keys=['tab20', 'Set2'])
# same palette preceded by one extra 'light:b' color (used for the per-domain panels)
clsts_palette2, clsts_cmap2 = utils_merfish.generate_discrete_cmap(
    [1, n_first, n_unq_clsts-n_first], keys=['light:b', 'tab20', 'Set2'])
clsts_cmap

In [None]:
# preview the palette with a leading 'light:b' color (used as background in the per-domain panels)
clsts_cmap2

In [None]:
# Leiden domains in the rotated tissue frame
xr, yr = adata.obs['x'], adata.obs['y']
# plot_cluster_simple_ax requires integer labels 0..N-1
clsts = adata.obs['blended_domains'].astype(int)

with sns.axes_style('white'):
    fig, ax = plt.subplots()
    utils_merfish.plot_cluster_simple_ax(fig, ax, clsts, xr, yr, s=0.5, cmap=clsts_cmap)
    plt.show()

In [None]:
# One small panel per domain: domain i in color, every other cell as background.
n = n_unq_clsts
# NOTE: not named `nx` -- that would shadow the `networkx as nx` import used by
# get_largest_spatial_components.
ncols = 5
nrows = (n + ncols - 1) // ncols
fig, axs = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i in range(n):
    ax = axs.flat[i]
    # relabel: other domains -> 0, domain i -> i+1 (cmap2 has one extra leading color)
    clsts_this = clsts.copy()
    clsts_this[clsts_this != i] = -1
    clsts_this += 1

    utils_merfish.plot_cluster_simple_ax(fig, ax, clsts_this, xr, yr, s=1, cmap=clsts_cmap2, cbar=False)
    ax.set_title(i)
# hide unused panels in the last row
for ax in axs.flat[n:]:
    ax.axis('off')
fig.tight_layout()
plt.show()

# set up the cortical grid

In [None]:
# Rotate a further 45 degrees and center at the origin -> obs['x2'], obs['y2'].
xr, yr = adata.obs['x'], adata.obs['y']
xr2, yr2 = rot2d(xr, yr, 45)
xr2, yr2 = xr2 - np.mean(xr2), yr2 - np.mean(yr2)
adata.obs['x2'], adata.obs['y2'] = xr2, yr2

# check the new frame on two markers
gns = ['Slc17a7', 'Slc6a13']
n = len(gns)
fig, axs = plt.subplots(1, n, figsize=(n*7, 6), sharex=True, sharey=True)
for ax, gn in zip(axs, gns):
    g = adata[:, gn].X.reshape(-1,)
    st_scatter_ax(fig, ax, xr2, yr2, gexp=g, title=gn, axis_off=False)
plt.show()



In [None]:
# Fit the curved reference line that defines the (depth, width) grid.
# Seed points: Slc6a13-expressing cells inside the region selected by
# utils_merfish.two_step_cut (cut semantics are defined in utils_merfish).
cond1 = utils_merfish.two_step_cut(adata.obs['x2'], adata.obs['y2'],0, 0, 0)  
cond2 = np.array(adata[:,'Slc6a13'].X).reshape(-1,)>0
adata_pia = adata[cond1 & cond2]

XY = adata_pia.obs[['x2', 'y2']].values
# all cells, drawn as gray background in each diagnostic plot
xr = adata.obs['x2']
yr = adata.obs['y2']

# 10 rounds of fit-then-trim, each plotted. XY_obj from the last round is the grid used by
# later cells; note it is fitted before that round's trim is applied.
for i in range(10):
    # fit
    XY, DW, XY_fit, XY_obj, _ = utils_merfish.line_fitting(XY)
    # d = DW[:,0] is presumably each point's distance to the fit -- TODO confirm in line_fitting
    d, w = DW[:,0], DW[:,1]
    x, y = XY[:,0], XY[:,1]
    x_fit, y_fit = XY_fit[:,0], XY_fit[:,1]
    
    # show results
    with sns.axes_style('white'): 
        fig, ax = plt.subplots()
        ax.scatter(xr, yr, s=0.5, color='lightgray', edgecolor='none')
        ax.scatter(x, y, s=10, edgecolors='k', facecolors='none', linewidth=1)
        ax.plot(x_fit, y_fit, color='k')
        ax.set_title(i+1)
        ax.axis('off')
        
        plt.show()

    # next iteration: keep points whose d is below half the 95th percentile of d
    cond = d < 0.5*np.percentile(d, 95)
    XY = XY[cond]
    
    # break

In [None]:
# domains in the rotated (x2, y2) frame with the fitted curved grid overlaid
xr = adata.obs['x2']
yr = adata.obs['y2']
clsts = adata.obs['blended_domains'].astype(int)

with sns.axes_style('white'):
    fig, ax = plt.subplots()
    utils_merfish.plot_cluster_simple_ax(fig, ax, clsts, xr, yr, s=0.5, cmap=clsts_cmap)
    XY_obj.plot_grid(ax, t_interval=500, v_length=1000)
    plt.show()

In [None]:
# L0 results to save:
#  - the initial QC filter and crop
#  - the spatial domains (obs['blended_domains'])
#  - the curved grid XY_obj <-> (poly_fit, XY_fit): its reference points go to
#    uns['ref_line'], and every cell gets (depth, width) coordinates relative to it
depth, width = XY_obj.dists_to_qps(adata.obs[['x2', 'y2']].values)
adata.obs['depth'] = depth
adata.obs['width'] = width
adata.uns['ref_line'] = XY_obj.ps.tolist()
adata


In [None]:
# write the L0 (hemibrain) AnnData
print(output_l0)
adata.write_h5ad(output_l0)

# L1 - select cortex

In [None]:
# L1: keep cells with depth < 1100 (um from the reference line)
cond = (adata.obs['depth'] < 1100).to_numpy()
adatasub = adata[cond].copy()

# integer domain labels (0..N-1 expected by the plotting helper) and coordinates for the next plots
clsts = adatasub.obs['blended_domains'].astype(int)
XY = adatasub.obs[['x2', 'y2']].values
DW = adatasub.obs[['depth', 'width']].values
x, y = XY.T
d, w = DW.T

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(12*1, 4*2))

# top: cortical coordinates (width vs -depth), zoomed to depth -100..200 um
ax = axs[0]
ax.scatter(w, -d, c=d, s=10, edgecolor='none', cmap='rocket_r')
ax.set_xlabel('width (um)')
ax.set_ylabel('depth (um)')
ax.set_ylim([-200, 100])

# bottom: the same cells in the rotated tissue frame (x2, y2) with the curved grid overlaid.
# These axes are tissue coordinates, not width/depth.
ax = axs[1]
utils_merfish.plot_cluster_simple_ax(fig, ax, clsts, x, y, s=0.5, cmap=clsts_cmap)
XY_obj.plot_grid(ax, t_interval=500, v_length=1000)
ax.set_aspect('equal')
ax.set_xlabel('x2 (um)')
ax.set_ylabel('y2 (um)')
plt.show()



In [None]:
# inspect the L1 (cortex) AnnData before saving
adatasub

In [None]:
# write the L1 (cortex) AnnData
print(output_l1)
adatasub.write_h5ad(output_l1)

In [None]:
# reload L1 from disk so the sections below can run without redoing the steps above
adatasub = ad.read_h5ad(output_l1)
adatasub

# Plot area and layer signatures

In [None]:
# Layer/area markers in cortical coordinates (x = width, y = -depth).
gns = [
    'Stard8', 'Cux2', 'Whrn', 'Syt2', 'Tle4', 'Syt6', 'Ccn2',
    'Scnn1a', 'Igfbp4', 'Syt17', 'C1ql3', 'Rorb',
]
# NumPy arrays, so x[sorting] below is plain positional indexing. On a Series with a string
# index, integer-array __getitem__ uses pandas' deprecated positional fallback.
x = adatasub.obs['width'].to_numpy()
y = -adatasub.obs['depth'].to_numpy()
n = len(gns)

fig, axs = plt.subplots(n, 1, figsize=(1*6, n*1.2))
for ax, gn in zip(axs, gns):
    g = np.log2(1+adatasub[:, gn].layers['jnorm'].reshape(-1,))
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g, 0)
    # draw cells in increasing expression so high expressers end up on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=3, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_aspect('equal')

    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

    # depth ticks every 200 um at the left edge; dashed verticals every 1000 um of width; dashed depth-0 line
    ax.hlines(-np.arange(0, 1200, 200), np.min(x)-100, np.min(x), color='gray', linestyle='-', linewidth=1)
    ax.vlines(np.arange(0, np.max(x)+1, 1000), 0, -1000, color='k', linestyle='--', linewidth=1)
    ax.axhline(0, color='k', linestyle='--', linewidth=1)

fig.subplots_adjust(hspace=0)
plt.show()

In [None]:
# 250 um bins along the width (M-L) axis, and their centers
bins = np.arange(0, np.max(adatasub.obs['width'])+1, 250)
midbins = bins[:-1]+250/2


def _marker_in_depth_band(gene, dmin=100, dmax=400):
    """Return a copy of adatasub.obs with column 'g' = log2(1 + jnorm expression of `gene`),
    keeping only cells with dmin < depth < dmax."""
    obs = adatasub.obs.copy()
    obs['g'] = np.log2(1+np.asarray(adatasub[:, gene].layers['jnorm'])).reshape(-1)
    return obs[(obs['depth'] > dmin) & (obs['depth'] < dmax)]


# the same 100-400 um depth band for all three markers
df = _marker_in_depth_band('Whrn')
df2 = _marker_in_depth_band('Rorb')
df3 = _marker_in_depth_band('Igfbp4')

In [None]:
# Mean marker expression per width bin.
# observed=False keeps empty bins (mean = NaN), so every result has exactly one row per bin
# and stays aligned with `midbins` in the plot. It is stated explicitly because pandas is
# changing the default to observed=True, which would drop empty bins.
df['width_bin'] = pd.cut(df['width'], bins)
dfmean = df[['width_bin', 'g']].groupby('width_bin', observed=False).mean()

df2['width_bin'] = pd.cut(df2['width'], bins)
dfmean2 = df2[['width_bin', 'g']].groupby('width_bin', observed=False).mean()

df3['width_bin'] = pd.cut(df3['width'], bins)
dfmean3 = df3[['width_bin', 'g']].groupby('width_bin', observed=False).mean()

In [None]:
# width range (um) used as the V1 window by the cells below
v1_window = [1800, 3500]

# Per-bin mean log expression along M-L, one y-axis per marker.
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(midbins, dfmean['g'].values, color='red')
ax2 = ax.twinx()
ax2.plot(midbins, dfmean2['g'].values, color='blue')
ax3 = ax.twinx()
ax3.plot(midbins, dfmean3['g'].values, color='green')
# Both twins default to the same right spine; push ax3's outward so the tick labels don't overprint.
ax3.spines['right'].set_position(('axes', 1.2))

# label each y-axis with its marker (Whrn / Rorb / Igfbp4, per the marker-table cell) in the line's color
for twin_ax, gene, color in ((ax, 'Whrn', 'red'), (ax2, 'Rorb', 'blue'), (ax3, 'Igfbp4', 'green')):
    twin_ax.set_ylabel(f'{gene} log expr', color=color)
    twin_ax.tick_params(axis='y', colors=color)
    twin_ax.grid(False)

ax.set_xlabel('M-L distance (um)')
sns.despine(ax=ax)
ax.xaxis.set_ticks(bins[::2], minor=True)
ax.axvline(v1_window[0], color='k', linestyle='--')
ax.axvline(v1_window[1], color='k', linestyle='--')
plt.show()

In [None]:
# Markers used to place the V1 window, with the window edges as dashed lines.
gns = [
    'Whrn', 'Rorb',
    'Scnn1a', 'Igfbp4',
]
# NumPy arrays, so x[sorting] is positional indexing (not pandas' deprecated Series fallback).
x = adatasub.obs['width'].to_numpy()
y = -adatasub.obs['depth'].to_numpy()
n = len(gns)

fig, axs = plt.subplots(n, 1, figsize=(1*6, n*1.2))
for ax, gn in zip(axs, gns):
    g = np.log2(1+adatasub[:, gn].layers['jnorm'].reshape(-1,))
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g, 0)
    # draw cells in increasing expression so high expressers end up on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting], s=3, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_aspect('equal')

    ax.set_title(gn, loc='left', va='center', ha='right', y=0.5, pad=None)
    fig.colorbar(p, pad=0, shrink=0.5, aspect=5, ticks=[np.round(vmin, decimals=1), np.round(vmax-0.1, decimals=1)])

    # depth ticks at 100/350/550/750/1000 um (presumably layer boundaries -- TODO confirm)
    ax.hlines(-np.array([100, 350, 550, 750, 1000]), np.min(x)-100, np.min(x), color='gray', linestyle='-', linewidth=1)
    ax.axis('on')
    ax.xaxis.set_ticks(bins[::2], minor=True)
    ax.axvline(v1_window[0], color='k', linestyle='--')
    ax.axvline(v1_window[1], color='k', linestyle='--')
fig.subplots_adjust(hspace=0)
plt.show()

In [None]:
# Laminar markers in a 1000 um wide strip centered on the V1 window, one narrow panel per gene.
gns = [
    'Syt17', 'Stard8', 'Whrn', 'Rorb', 'Scnn1a', 'Syt2', 'Tle4', 'Syt6', 'Ccn2',
]
midv1 = np.mean(v1_window)
x_all = adatasub.obs['width'].to_numpy()
y_all = -adatasub.obs['depth'].to_numpy()
cutoff = np.logical_and(x_all > midv1-500, x_all < midv1+500)
# Subset the coordinates once; every panel shares them. NumPy arrays, so x[sorting] is positional.
x = x_all[cutoff]
y = y_all[cutoff]
n = len(gns)

fig, axs = plt.subplots(1, n, figsize=(n*1, 3*1.2))
for ax, gn in zip(axs, gns):
    g = np.log2(1+adatasub[:, gn].layers['jnorm'].reshape(-1,))[cutoff]
    vmax = np.percentile(g, 99)
    vmin = np.percentile(g, 0)
    # draw cells in increasing expression so high expressers end up on top
    sorting = np.argsort(g)

    p = utils_merfish.st_scatter_ax(fig, ax, x[sorting], y[sorting], gexp=g[sorting],
                                    s=3, title='', vmin=vmin, vmax=vmax, cmap='rocket_r')
    ax.set_aspect('equal')

    ax.set_title(gn, fontsize=10)
    # depth ticks every 200 um just left of the strip
    ax.hlines(-np.arange(0, 1200, 200), np.min(x)-100, np.min(x)-50, color='gray', linestyle='-', linewidth=1)

fig.subplots_adjust(hspace=0, wspace=-0.05)
plt.show()

# select V1 cells

In [None]:
# Palette size = largest domain label + 1. After the cortex crop some Leiden domains can be
# absent, so labels are no longer contiguous and len(unique()) would give too few colors for
# the integer labels plotted below. (The name n_unq_clsts is kept for later cells.)
n_unq_clsts = int(adatasub.obs['blended_domains'].astype(int).max()) + 1
clsts_palette, clsts_cmap = utils_merfish.generate_discrete_cmap([n_unq_clsts], keys=['tab20', ])
clsts_palette2, clsts_cmap2 = utils_merfish.generate_discrete_cmap([1, n_unq_clsts], keys=['light:b', 'tab20', ])
clsts_cmap

In [None]:
# preview the cortex palette with its leading 'light:b' color
clsts_cmap2

In [None]:

with sns.axes_style('white'):
    fig, axs = plt.subplots(1, 2, figsize=(2*6, 1*5))
    # cells whose width lies strictly inside the V1 window
    cond0 = (adatasub.obs['width'] > v1_window[0]) & (adatasub.obs['width'] < v1_window[1])
    obs_v1 = adatasub[cond0].obs
    x, y = obs_v1['x'], obs_v1['y']
    d, w = -obs_v1['depth'], obs_v1['width']
    clsts = obs_v1['blended_domains'].astype(int)  # integer labels for the plotting helper

    # left: tissue (x, y) frame; right: cortical (width, -depth) frame
    for ax, (px, py) in zip(axs, [(x, y), (w, d)]):
        utils_merfish.plot_cluster_simple_ax(fig, ax, clsts, px, py, s=2, cmap=clsts_cmap2)
    plt.show()

In [None]:
# L2: the L1 cells inside the V1 width window (cond0), as a standalone AnnData (.copy(), not a view)
adatasub2 = adatasub[cond0].copy() # anatomical (V1 ctx) # (cell types, num transcripts)
adatasub2.uns['v1_window'] = v1_window  # record the window used for the selection
adatasub2

In [None]:
# write the L2 (V1) AnnData
print(output_l2)
adatasub2.write_h5ad(output_l2)