In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
from scipy import stats
import os

from scipy import spatial
from scipy import sparse
from scipy.interpolate import CubicSpline
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from umap import UMAP

import json

In [None]:
import importlib

from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
importlib.reload(utils_merfish)
from utils_merfish import rot2d, st_scatter, st_scatter_ax, plot_cluster, binning
from utils_merfish import RefLineSegs

import merfish_datasets
import merfish_genesets
importlib.reload(merfish_datasets)
from merfish_datasets import merfish_datasets
from merfish_datasets import merfish_datasets_params

from scroutines import basicu

In [None]:
def get_qc_metrics(df):
    """Collect per-cell QC columns together with histogram bins for plotting.

    Parameters
    ----------
    df : table indexable by column name that exposes ``.values``
        (must contain the columns 'volume', 'gncov' and 'gnnum').

    Returns
    -------
    dict
        Maps each column name to a tuple ``(name, val, medval, bins)``:
        display name, the column values, their median, and 50 evenly
        spaced bin edges running from 0 to 10x the median.
    """
    display_names = {
        'volume': 'cell volume',
        'gncov':  'num transcripts',
        'gnnum':  'num genes',
    }

    def _summarize(col):
        # values, median and median-scaled bin edges for one column
        val = df[col].values
        medval = np.median(val)
        return val, medval, np.linspace(0, 10*medval, 50)

    return {col: (name, *_summarize(col)) for col, name in display_names.items()}

def get_norm_counts(adata, scaling=500):
    """Volume-normalize counts: divide each cell's row by its volume, then
    multiply by `scaling` (i.e. counts as if every cell had volume `scaling`).

    Reads ``adata.X`` and ``adata.obs['volume']``; stores the result in
    ``adata.layers['norm']`` and also returns it.
    """
    # column vector so the division broadcasts across genes
    volume_col = adata.obs['volume'].values.reshape(-1, 1)
    normcnts = adata.X / volume_col * scaling
    adata.layers['norm'] = normcnts
    return normcnts

In [None]:
def get_largest_spatial_components(adata, k=100, dist_th=80):
    """Find the largest spatially connected group of cells.

    Builds a k-nearest-neighbor graph on the (x, y) coordinates in
    ``adata.obs``, keeps only edges shorter than `dist_th`, and returns the
    largest connected component of the resulting graph.

    Parameters
    ----------
    adata : object whose ``.obs`` has 'x' and 'y' columns
    k : number of neighbors requested per cell (the query point itself is
        the first neighbor returned, so at most k-1 edges are created per
        cell). Clamped to the number of cells so small inputs do not make
        the neighbor search fail.
    dist_th : maximum distance for two cells to be considered connected

    Returns
    -------
    indices_selected : np.ndarray
        Row positions (into ``adata``) of the cells in the largest component.
    XY : np.ndarray
        The (n_cells, 2) coordinate array that was used.
    """
    XY = adata.obs[['x', 'y']].values
    nc = len(XY)

    # kNN -- asking for more neighbors than there are cells is an error in
    # sklearn, so never request more than `nc`
    n_neighbors = min(k, nc)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='auto').fit(XY)
    distances, indices = nbrs.kneighbors(XY)

    # flatten (source, neighbor, distance) triplets; column 0 is the query
    # point itself and is used as the edge source
    val = distances[:,1:].reshape(-1,)
    i = np.repeat(indices[:,0], n_neighbors-1)
    j = indices[:,1:].reshape(-1,)

    # keep only edges shorter than the distance threshold
    keep = val < dist_th
    indices_filtered = np.vstack([i[keep], j[keep]]).T

    # every cell is a node, so isolated cells form their own components
    G = nx.Graph()
    G.add_nodes_from(np.arange(nc))
    G.add_edges_from(indices_filtered)
    components = nx.connected_components(G)
    largest_component = max(components, key=len)
    indices_selected = np.array(list(largest_component))

    print(f"fraction of cells included: {len(largest_component)/nc: .2f}" )
    
    return indices_selected, XY

# load data and construct adata 

In [None]:
# Seed NumPy's legacy global RNG so that steps drawing from it are repeatable
# on a fresh top-to-bottom run.
np.random.seed(0)

In [None]:
# Output locations: figures go to `outdir`, organized data files to `outdatadir`.
# NOTE(review): these are absolute, user-specific paths -- adjust before
# running on another machine.
outdir     = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results_merfish/plots_240718"
outdatadir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized"

# Create both directories (including parents) if they do not exist yet;
# pure-Python equivalent of `mkdir -p`, so it does not depend on a shell.
os.makedirs(outdir, exist_ok=True)
os.makedirs(outdatadir, exist_ok=True)

In [None]:
# Load the gene sets (a mapping of set name -> genes) plus an accompanying
# table, then report how many genes each set contains.
genesets, df = merfish_genesets.get_all_genesets()
for geneset_name, geneset_members in genesets.items():
    print(geneset_name, len(geneset_members))

In [None]:
# Directory holding the per-sample .h5ad files that are read below.
ddir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized" 
# NOTE(review): `fout` is not used in any cell visible here, and its date tag
# (240508) differs from the input files listed/read below (240723) -- confirm
# whether it is still needed.
fout = os.path.join(ddir, 'P28NRDR_v1l23_merged_240508.h5ad')
# List the candidate input files (IPython shell escape).
!ls $ddir/*l2*240723.h5ad 

In [None]:
%%time
# Sample identifiers; each one maps to a file `<name>_l2_v1_240723.h5ad` in `ddir`.
names = [
    'P28NRa_ant', 
    'P28NRa_pos',
    
    'P28NRb_ant', 
    'P28NRb_pos',
    
    'P28DRa_ant', 
    'P28DRa_pos',
    
    'P28DRb_ant', 
    'P28DRb_pos',
]

# Load every sample into a dict keyed by sample name.
alldata = {}
for name in names:
    # `ad.read` is a deprecated alias; `read_h5ad` is the supported reader
    adatasub = ad.read_h5ad(os.path.join(ddir, f'{name}_l2_v1_240723.h5ad')) 
    # prefix each cell id with the sample name so ids from different samples
    # cannot collide once the samples are concatenated
    adatasub.obs.index = np.char.add(f'{name}', adatasub.obs.index.values)
    alldata[name] = adatasub 
    print(name, len(alldata[name]))
    
# Gene list taken from the last sample loaded.
# NOTE(review): assumes all samples share the same var index -- confirm.
genes = adatasub.var.index.values
genes.shape

In [None]:
# Pull the individual gene sets out of `genesets` under short names.
agenes, bgenes, cgenes = (genesets[set_key] for set_key in ('a', 'b', 'c'))
iegs      = genesets['i']
up_agenes = genesets['a_up']

# a + b + c genes stacked into a single flat array
abcgenes = np.hstack((agenes, bgenes, cgenes))
len(abcgenes)

In [None]:
# Look up each gene set in the var index of `adatasub` (the most recently
# loaded sample) via basicu.get_index_from_array.
var_names = adatasub.var.index.values
agenes_idx, bgenes_idx, cgenes_idx, igenes_idx = (
    basicu.get_index_from_array(var_names, geneset)
    for geneset in (agenes, bgenes, cgenes, iegs)
)

In [None]:
# Per-sample rescaling: bring every sample to the same mean total of
# volume-normalized counts per cell, then concatenate the samples.
mean_total_rna_target = 250

premerge_parts = []
for name in names:
    adatasub = alldata[name].copy()
    adatasub.obs['sample'] = name

    # scale factor that moves this sample's mean per-cell total onto the target
    norm_cnts = adatasub.layers['norm']
    mean_per_batch = np.mean(norm_cnts.sum(axis=1))
    batch_scale = mean_total_rna_target/mean_per_batch

    # 'jnorm': rescaled counts; 'ljnorm': log2(1 + jnorm)
    adatasub.layers['jnorm']  = norm_cnts*batch_scale
    adatasub.layers['ljnorm'] = np.log2(1+adatasub.layers['jnorm'])

    # per-cell totals before and after the rescaling
    for layer_name in ('norm', 'jnorm'):
        adatasub.obs[f'{layer_name}_transcript_count'] = adatasub.layers[layer_name].sum(axis=1)

    premerge_parts.append(adatasub)

adata_premerge = ad.concat(premerge_parts)

In [None]:
# Give every sample its own tile in a shared display coordinate system:
# samples fill a column of 4 rows (1300 apart along depth) before moving to
# the next column (2500 apart along width).
merged_parts = []
for sample_pos, name in enumerate(names):
    grid_col, grid_row = divmod(sample_pos, 4)

    is_this_sample = adata_premerge.obs['sample'] == name
    adatasub = adata_premerge[is_this_sample].copy()
    adatasub.obs['sample'] = name

    depth = adatasub.obs['depth'].values
    width = adatasub.obs['width'].values
    # depth is flipped in sign; width is shifted so each sample starts at 0
    adatasub.obs['depth_show'] = -depth - grid_row*1300
    adatasub.obs['width_show'] =  width - np.min(width) + grid_col*2500

    merged_parts.append(adatasub)
    print(adatasub.shape)

adata_merged = ad.concat(merged_parts)

In [None]:

# Histogram of per-cell totals of the rescaled ('jnorm') counts, across all
# samples, using 50 bin edges between 0 and 500.
sns.histplot(adata_premerge.obs['jnorm_transcript_count'], bins=np.linspace(0, 500, 50))

In [None]:
# Same histogram for the per-cell totals of the 'norm' layer (before the
# per-sample rescaling), for comparison.
sns.histplot(adata_premerge.obs['norm_transcript_count'], bins=np.linspace(0, 500, 50))

In [None]:
from scipy.stats import zscore

In [None]:
# Marker genes plotted on the UMAP further down.
# NOTE(review): the blank-line-separated groups presumably correspond to
# different cell classes (the second group overlaps the astrocyte gene lists
# used later in this notebook) -- confirm the intended grouping.
marker_genes = [
       'Ptprn', 'Slc17a7', 'Gad1', 'Fos', 
       
       'Gfap', 'Slc6a13', 'Slc47a1',
       'Grin2c', 'Aqp4', 'Rfx4', 'Sox21', 'Slc1a3',
       
       'Sox10', 'Pdgfra', 'Mog',
       
       'Pecam1', 'Cd34' , 'Tnfrsf12a', 'Sema3c', 
       'Zfhx3', 'Pag1', 'Slco2b1', 'Cx3cr1',
      ] 

In [None]:
# All genes except those in the `iegs` set ('i' gene set); these are the
# features used for PCA below.
kept_genes = [gene for gene in genes if gene not in iegs]
genes_noniegs = np.array(kept_genes)
genes_noniegs.shape

In [None]:
# Embed all cells: z-score -> PCA -> UMAP, then build the kNN graph.
# NOTE: `adata` is an alias of `adata_merged` (not a copy), so the embeddings
# and graph are stored on `adata_merged` itself.
adata = adata_merged

# log-normalized expression restricted to the non-IEG genes, z-scored along
# axis=1 (i.e. within each cell, across genes)
expr_z = zscore(adata[:,genes_noniegs].layers['ljnorm'], axis=1)

# PCA -- random_state is fixed like every other stochastic step in this cell,
# so re-running the cell reproduces the same components
pca = PCA(n_components=50, random_state=0)
pcs = pca.fit_transform(expr_z)
ucs = UMAP(n_components=2, n_neighbors=30, random_state=0).fit_transform(pcs)

adata.obsm['pca'] = pcs
adata.obsm['umap'] = ucs
# neighbor graph on the PCs; used by the leiden clustering that follows
sc.pp.neighbors(adata, n_neighbors=30, use_rep='pca', random_state=0)

In [None]:
# clustering
# Leiden clustering on the neighbor graph; labels are stored in
# adata.obs under a key that records the resolution used.
r = 0.3
leiden_key = f'leiden_r{r}'
sc.tl.leiden(
    adata,
    resolution=r,
    key_added=leiden_key,
    random_state=0,
    n_iterations=10,
)

In [None]:
# Show one gene on PC1/PC2 (left) and on the UMAP (right): first as
# log2(1 + jnorm), then as raw jnorm.
gn = 'Slc17a7'   # alternatives previously tried: 'Fos', 'Gad1'

for use_log in (True, False):
    g = adata[:,gn].layers['jnorm'].reshape(-1,)
    if use_log:
        g = np.log2(1+g)

    fig, (ax1, ax2) = plt.subplots(1,2,figsize=(2*5,1*4))
    for ax_cur, emb in ((ax1, pcs), (ax2, ucs)):
        utils_merfish.st_scatter_ax(fig, ax_cur, emb[:,0], emb[:,1], gexp=g)
    plt.show()

In [None]:

# Spatial (display coordinates) and UMAP views, colored first by leiden
# cluster and then by sample of origin.
clsts = adata.obs[f'leiden_r{r}'].astype(int)
samples, uniq_labels = pd.factorize(adata.obs['sample'])

xr, yr = adata.obs['width_show'], adata.obs['depth_show']
umap_xy = adata.obsm['umap']
ux, uy = umap_xy[:,0], umap_xy[:,1]

utils_merfish.plot_cluster(clsts, xr, yr, ux, uy, s=2)
utils_merfish.plot_cluster(samples, xr, yr, ux, uy, s=2)

In [None]:
# Cluster labels and the number of cells assigned to each.
np.unique(clsts, return_counts=True)

In [None]:
# One figure per cluster, highlighting the cells that belong to it.
clsts = adata.obs[f'leiden_r{r}'].astype(int)
uniq_clsts = np.unique(clsts)

# coordinates are the same for every cluster, so fetch them once
xr = adata.obs['width_show']
yr = adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:,0], adata.obsm['umap'][:,1]

for clst in uniq_clsts:
    show = (clsts == clst)   # boolean membership mask used as the color value
    utils_merfish.plot_cluster(
        show, xr, yr, ux, uy,
        s=2, cmap=plt.cm.copper_r, suptitle=clst,
    )

# Identify the major cell populations using marker genes, groups of genes, and quality metrics

In [None]:
# plot
marker_genes = [
       'Ptprn', 'Slc17a7', 'Gad1', 'Fos', 
       
       'Gfap', 'Slc6a13', 'Slc47a1',
       'Grin2c', 'Aqp4', 'Rfx4', 'Sox21', 'Slc1a3',
       
       'Sox10', 'Pdgfra', 'Mog',
       
       'Pecam1', 'Cd34' , 'Tnfrsf12a', 'Sema3c', 
       'Zfhx3', 'Pag1', 'Slco2b1', 'Cx3cr1',
      ] 
gns = marker_genes
n = len(gns)
nx = 4
ny = int((n+nx-1)/nx)
# add some quality metrics
fig, axs = plt.subplots(ny,nx,figsize=(nx*5,ny*4))
for gn, ax in zip(gns, axs.flat):
    g = adata[:,gn].layers['jnorm'].reshape(-1,)
    utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g)
    ax.set_title(gn)
plt.show()


In [None]:
# plot
gns = marker_genes
n = len(gns)
nx = 4
ny = int((n+nx-1)/nx)
# add some quality metrics
fig, axs = plt.subplots(ny,nx,figsize=(nx*5,ny*4))
for gn, ax in zip(gns, axs.flat):
    g = np.log2(1+adata[:,gn].layers['jnorm'].reshape(-1,))
    utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g)
    ax.set_title(gn)
plt.show()



In [None]:
# # add some quality metrics
# fig, ax = plt.subplots()
# # g = (adata.layers['jnorm'].sum(axis=1) < 100).astype(int)
# g = (adata.obs['jnorm_transcript_count'].values < 100).astype(int)
# # g = (adata.obs['volume'].values < 60).astype(int)
# p = utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g, s=3)
# fig.colorbar(p)
# ax.set_title('')
# plt.show()

# fig, ax = plt.subplots()
# # g = (adata.layers['volume'].sum(axis=1)) #  < 100).astype(int)
# # g = (adata.obs['transcript_count'].values < 50).astype(int)
# g = (adata.obs['volume'].values < 150).astype(int)
# p = utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g, s=3)
# fig.colorbar(p)
# ax.set_title('')
# plt.show()

# fig, ax = plt.subplots()
# # g = (adata.layers['volume'].sum(axis=1)) #  < 100).astype(int)
# g = (adata.obs['transcript_count'].values < 100).astype(int)
# # g = (adata.obs['volume'].values < 150).astype(int)
# p = utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g, s=3)
# fig.colorbar(p)
# ax.set_title('')
# plt.show()

# Flag cells that are low on both count measures and show the flag on the UMAP:
# raw transcript_count < 50 AND rescaled jnorm_transcript_count < 150.
low_raw   = adata.obs['transcript_count'].values < 50
low_jnorm = adata.obs['jnorm_transcript_count'].values < 150
g = (low_raw & low_jnorm).astype(int)   # 1 = flagged, 0 = not flagged

fig, ax = plt.subplots()
p = utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g, s=3)
fig.colorbar(p)
ax.set_title('')
plt.show()

In [None]:
# UMAP panels colored by each per-cell QC / metadata column (raw values).
metrics = [
    'volume', 'anisotropy', 'perimeter_area_ratio', 'solidity', 
    'PolyT_raw', 'PolyT_high_pass', 'DAPI_raw', 'DAPI_high_pass', 
    'transcript_count', 'jnorm_transcript_count', 'gnnum', 'fpcov', 
    'depth', 'width', 'sample' 
       ]
n = len(metrics)

# Grid shape. Deliberately NOT named `nx`/`ny`: `nx` is the networkx alias
# imported at the top of the notebook, and rebinding it to an int would break
# get_largest_spatial_components() for the rest of the session.
ncols = 5
nrows = (n + ncols - 1) // ncols   # ceil(n / ncols)

fig, axs = plt.subplots(nrows, ncols, figsize=(ncols*5, nrows*4))
for metric, ax in zip(metrics, axs.flat):
    g = adata.obs[metric].values
    if metric == 'sample':
        # sample names are strings; convert to integer codes for coloring
        g, uniq_lbls = pd.factorize(g)
    p = utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g, s=3)
    fig.colorbar(p, shrink=0.4)
    ax.set_title(metric)
plt.show()


In [None]:
# UMAP panels colored by log10(1 + value) of each QC / metadata column
# (uses the `metrics` list defined in the previous cell).
n = len(metrics)

# Grid shape. Deliberately NOT named `nx`/`ny`: `nx` is the networkx alias
# imported at the top of the notebook, and rebinding it to an int would break
# get_largest_spatial_components() for the rest of the session.
ncols = 5
nrows = (n + ncols - 1) // ncols   # ceil(n / ncols)

fig, axs = plt.subplots(nrows, ncols, figsize=(ncols*5, nrows*4))
for metric, ax in zip(metrics, axs.flat):
    # always start from this metric's own column; previously the 'sample'
    # branch factorized whatever `g` was left over from the prior iteration
    g = adata.obs[metric].values
    if metric == 'sample':
        # sample names are strings; convert to integer codes for coloring
        g, uniq_lbls = pd.factorize(g)
    else:
        # NOTE(review): log10(1 + x) is NaN for x < -1 -- confirm that
        # signed columns such as 'depth'/'width' stay in range
        g = np.log10(1+g)
    utils_merfish.st_scatter_ax(fig, ax, ucs[:,0], ucs[:,1], gexp=g, s=3)
    ax.set_title(metric)
plt.show()

In [None]:
# Display the AnnData summary (obs/var columns, layers, obsm keys) as a sanity check.
adata

# astrocytes only

In [None]:
# Keep only the cells in leiden cluster '0' (treated as the astrocyte cluster
# in this section).
# - Select from `adata_merged` (the object `adata` aliases at this point), so
#   the cell gives the same result when re-run after `adata` is rebound to
#   the subset below.
# - `.copy()` turns the view into an independent AnnData, because later cells
#   write new embeddings and cluster labels onto the subset.
# NOTE(review): cluster ids depend on the clustering above; re-check that '0'
# is still the intended cluster if any upstream parameter changes.
adata_astro = adata_merged[adata_merged.obs[f'leiden_r{r}']=='0'].copy()
adata_astro

In [None]:
# Display the summary of the cluster-'0' subset.
adata_astro

In [None]:
# Re-embed the astrocyte subset on its own: z-score -> PCA -> UMAP -> kNN graph.
# NOTE: from here on `adata` refers to the subset, not to the full merged
# data set; `pca`, `pcs` and `ucs` are overwritten as well.
adata = adata_astro

# log-normalized expression restricted to the non-IEG genes, z-scored along
# axis=1 (i.e. within each cell, across genes)
expr_z = zscore(adata[:,genes_noniegs].layers['ljnorm'], axis=1)

# PCA -- random_state is fixed like every other stochastic step in this cell,
# so re-running the cell reproduces the same components
pca = PCA(n_components=10, random_state=0)
pcs = pca.fit_transform(expr_z)
ucs = UMAP(n_components=2, n_neighbors=30, random_state=0).fit_transform(pcs)

adata.obsm['pca'] = pcs
adata.obsm['umap'] = ucs
# neighbor graph on the PCs; used by the leiden clustering that follows
sc.pp.neighbors(adata, n_neighbors=30, use_rep='pca', random_state=0)

In [None]:
# clustering
# Leiden clustering of the current `adata`; labels are written to adata.obs
# under a key that records the resolution used.
r = 0.3
leiden_key = f'leiden_r{r}'
sc.tl.leiden(
    adata,
    resolution=r,
    key_added=leiden_key,
    random_state=0,
    n_iterations=10,
)

In [None]:
# Show one gene on PC1/PC2 (left) and on the UMAP (right): first as
# log2(1 + jnorm), then as raw jnorm.
gn = 'Slc17a7'   # alternatives previously tried: 'Fos', 'Gad1'

for use_log in (True, False):
    g = adata[:,gn].layers['jnorm'].reshape(-1,)
    if use_log:
        g = np.log2(1+g)

    fig, (ax1, ax2) = plt.subplots(1,2,figsize=(2*5,1*4))
    for ax_cur, emb in ((ax1, pcs), (ax2, ucs)):
        utils_merfish.st_scatter_ax(fig, ax_cur, emb[:,0], emb[:,1], gexp=g)
    plt.show()

In [None]:

# Spatial (display coordinates) and UMAP views, colored first by leiden
# cluster and then by sample of origin.
clsts = adata.obs[f'leiden_r{r}'].astype(int)
samples, uniq_labels = pd.factorize(adata.obs['sample'])

xr, yr = adata.obs['width_show'], adata.obs['depth_show']
umap_xy = adata.obsm['umap']
ux, uy = umap_xy[:,0], umap_xy[:,1]

utils_merfish.plot_cluster(clsts, xr, yr, ux, uy, s=2)
utils_merfish.plot_cluster(samples, xr, yr, ux, uy, s=2)

In [None]:
# Cluster labels and the number of cells assigned to each.
np.unique(clsts, return_counts=True)

In [None]:
# One figure per cluster, highlighting the cells that belong to it.
clsts = adata.obs[f'leiden_r{r}'].astype(int)
uniq_clsts = np.unique(clsts)

# coordinates are the same for every cluster, so fetch them once
xr = adata.obs['width_show']
yr = adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:,0], adata.obsm['umap'][:,1]

for clst in uniq_clsts:
    show = (clsts == clst)   # boolean membership mask used as the color value
    utils_merfish.plot_cluster(
        show, xr, yr, ux, uy,
        s=2, cmap=plt.cm.copper_r, suptitle=clst,
    )

In [None]:
# plot
# For each gene: spatial view (left) and UMAP view (right), colored by raw
# 'jnorm' expression.
gns = ['Gfap', 'Slc6a13', 'Slc17a7', 'Grin2c', 'Aqp4', 'Rfx4']

# coordinates do not depend on the gene, so look them up once
x, y = adata.obs['width_show'], adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:,0], adata.obsm['umap'][:,1]

for gn in gns:
    g = adata[:,gn].layers['jnorm'].reshape(-1,)

    fig, (ax1, ax2) = plt.subplots(1,2,figsize=(2*5,1*4))
    utils_merfish.st_scatter_ax(fig, ax1, x, y, gexp=g)
    utils_merfish.st_scatter_ax(fig, ax2, ux, uy, gexp=g)
    ax2.set_title(gn)
    plt.show()
    

In [None]:
# 

# Rfx4 followed by the genes the original note labels as its targets
# ("rfx4 -> target genes").
# NOTE(review): the regulatory relationship is not verifiable from this
# notebook -- confirm the source of this list.
astro_genes = [
    "Rfx4",  "Grin2c", "Aqp4",  "Gfap",
    "Nr1d1", "Junb",   "Mertk", "Slc1a3",
    "Nrxn1", "Sox21",  "Fosl2", "Id3",
    "Stat2", "Klf3",   "Rora",  "Sdc4",
]

astro_genes

In [None]:

# For each gene in `astro_genes`: spatial view (left) and UMAP view (right),
# colored by raw 'jnorm' expression.

# coordinates do not depend on the gene, so look them up once
x, y = adata.obs['width_show'], adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:,0], adata.obsm['umap'][:,1]

for gn in astro_genes:
    g = adata[:,gn].layers['jnorm'].reshape(-1,)

    fig, (ax1, ax2) = plt.subplots(1,2,figsize=(2*5,1*4))
    utils_merfish.st_scatter_ax(fig, ax1, x, y, gexp=g)
    utils_merfish.st_scatter_ax(fig, ax2, ux, uy, gexp=g)
    ax2.set_title(gn)
    plt.show()
    