In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
from scipy import stats
import os

from scipy import spatial
from scipy import sparse
from scipy.interpolate import CubicSpline
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from umap import UMAP

import json

In [None]:
import importlib

from scroutines import powerplots
from scroutines.miscu import is_in_polygon

import utils_merfish
importlib.reload(utils_merfish)
from utils_merfish import rot2d, st_scatter, st_scatter_ax, plot_cluster, binning
from utils_merfish import RefLineSegs

import merfish_datasets
import merfish_genesets
importlib.reload(merfish_datasets)
from merfish_datasets import merfish_datasets
from merfish_datasets import merfish_datasets_params

from scroutines import basicu

In [None]:
def get_qc_metrics(df):
    """Collect per-cell QC distributions for histogram plotting.

    For each of the columns ``volume``, ``gncov`` and ``gnnum`` of `df`,
    gathers the display name, the raw values, their median, and 50 evenly
    spaced bin edges spanning [0, 10 * median].

    Returns
    -------
    dict
        column key -> (name, values, median, bins)
    """
    spec = (
        ('volume', 'cell volume'),
        ('gncov', 'num transcripts'),
        ('gnnum', 'num genes'),
    )

    def _summarize(values):
        median = np.median(values)
        return values, median, np.linspace(0, 10 * median, 50)

    return {key: (label, *_summarize(df[key].values)) for key, label in spec}

def get_norm_counts(adata, scaling=500):
    """Volume-normalize counts and store them in ``adata.layers['norm']``.

    Each cell's counts (rows of ``adata.X``) are divided by that cell's
    ``obs['volume']`` and multiplied by `scaling` (500 by default), i.e.
    counts are expressed per `scaling` units of volume.

    Returns the normalized matrix, which is also written to the 'norm' layer.
    """
    per_cell_volume = adata.obs['volume'].values[:, np.newaxis]
    normalized = adata.X / per_cell_volume * scaling
    adata.layers['norm'] = normalized
    return normalized

In [None]:
def get_largest_spatial_components(adata, k=100, dist_th=80):
    """Return the cells in the largest spatially connected component.

    A graph is built on the cells' ``obs[['x', 'y']]`` coordinates: every cell
    is linked to each of its (k - 1) nearest other cells whose Euclidean
    distance is below `dist_th`. The members of the largest connected
    component are returned.

    Parameters
    ----------
    adata : AnnData-like with ``obs['x']`` and ``obs['y']``
    k : int
        Neighborhood size counting the cell itself, so k - 1 other cells are
        examined. Clipped to the number of cells.
    dist_th : float
        Maximum distance (same units as x/y) for two cells to be linked.

    Returns
    -------
    indices_selected : np.ndarray
        Sorted positional indices of the cells in the largest component.
    XY : np.ndarray
        The (n_cells, 2) coordinate array used to build the graph.

    Raises
    ------
    ValueError
        If `adata` has no cells.
    """
    XY = adata.obs[['x', 'y']].values
    nc = len(XY)
    if nc == 0:
        raise ValueError("adata has no cells")

    n_other = min(k, nc) - 1
    if n_other > 0:
        # kneighbors() without a query excludes each point from its own
        # neighbor list, so the source of each edge is exactly the row index.
        # Re-querying with XY and dropping column 0 assumed "self comes
        # first", which does not hold for cells with identical coordinates.
        nbrs = NearestNeighbors(n_neighbors=n_other, algorithm='auto').fit(XY)
        distances, indices = nbrs.kneighbors()
        keep = distances.ravel() < dist_th
        src = np.repeat(np.arange(nc), n_other)[keep]
        dst = indices.ravel()[keep]
        indices_filtered = np.column_stack([src, dst])
    else:
        indices_filtered = np.empty((0, 2), dtype=int)

    # NOTE: relies on the module-level `nx` being networkx; rebinding `nx`
    # elsewhere in the notebook (e.g. as a subplot count) breaks this function.
    G = nx.Graph()
    G.add_nodes_from(np.arange(nc))
    G.add_edges_from(indices_filtered)
    largest_component = max(nx.connected_components(G), key=len)
    indices_selected = np.array(sorted(largest_component))

    print(f"fraction of cells included: {len(largest_component)/nc: .2f}" )

    return indices_selected, XY

In [None]:
def preprocessing(adata, min_cells=10):
    """Filter rarely detected genes and add CP10k-normalized layers.

    Genes detected (count > 0) in more than `min_cells` cells are kept. Each
    cell's counts are scaled by ``obs['n_counts']`` to counts per 10,000
    (layer 'norm'), and log10(CP10k + 1) is stored in layer 'lognorm'.

    Parameters
    ----------
    adata : AnnData
        Expects a scipy sparse ``X`` (the log transform rewrites ``.data``)
        and an ``obs['n_counts']`` column of per-cell totals (presumably
        totals over all genes, before this filter — TODO confirm).
    min_cells : int, default 10
        Genes must be detected in strictly more than this many cells.

    Returns
    -------
    AnnData
        A new (non-view) object restricted to the kept genes, with the 'norm'
        and 'lognorm' layers added. The input object is not modified.
    """
    detected_in = np.ravel((adata.X > 0).sum(axis=0))
    # .copy() materializes the subset instead of keeping a view, so adding
    # layers below does not go through anndata's implicit view-to-actual
    # conversion (and its ImplicitModificationWarning).
    adata_sub = adata[:, detected_in > min_cells].copy()

    # counts
    x = adata_sub.X
    cov = adata_sub.obs['n_counts'].values

    # CP10k: scale each row by 1e4 / its total count
    xn = (sparse.diags(1 / cov).dot(x)) * 1e4

    # log10(CP10k + 1); only stored (non-zero) entries need transforming
    xln = xn.copy()
    xln.data = np.log10(xln.data + 1)

    adata_sub.layers['norm'] = xn
    adata_sub.layers['lognorm'] = xln

    return adata_sub

In [None]:
def get_hvgs(adata, layer, nbin=20, qth=0.3):
    """
    """
    xn = adata.layers[layer]
    
    # min
    gm = np.ravel(xn.mean(axis=0))

    # var
    tmp = xn.copy()
    tmp.data = np.power(tmp.data, 2)
    gv = np.ravel(tmp.mean(axis=0))-gm**2

    # cut 
    lbl = pd.qcut(gm, nbin, labels=np.arange(nbin))
    gres = pd.DataFrame()
    gres['lbl'] = lbl
    gres['mean'] = gm
    gres['var'] = gv
    gres['ratio']= gv/gm

    # select
    gres_sel = gres.groupby('lbl')['ratio'].nlargest(int(qth*(len(gm)/nbin))) #.reset_index()
    gsel_idx = np.sort(gres_sel.index.get_level_values(1).values)

    assert np.all(gsel_idx != -1)
    
    return adata.var.index.values[gsel_idx]

In [None]:
def neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None):
    """Transfer labels from reference to query cells by k-NN majority vote.

    For each query cell, its `k` nearest reference cells in the shared
    embedding are found. The vote share of every unique reference label is
    computed among those neighbors.

    Parameters
    ----------
    k : int
        Number of reference neighbors per query cell.
    ref_emb, qry_emb : array-like, (n_ref, d) and (n_qry, d)
        Reference / query coordinates in the same embedding space.
    ref_lbl : array-like, (n_ref,)
        Reference labels; converted to str before voting.
    p_cutoff : float
        A prediction survives gating only if its vote share is > p_cutoff.
    dist_cutoff : float or None
        If given, predictions whose farthest neighbor is >= dist_cutoff are
        also gated to 'NA'.

    Returns
    -------
    max_pred : np.ndarray of str
        Most frequent neighbor label (ties -> first label in sorted order).
    gated_pred : np.ndarray of str
        `max_pred` with low-confidence entries replaced by 'NA'.
    max_dist : np.ndarray
        Distance to the farthest (k-th) neighbor of each query cell.
    """
    ref_lbl = np.asarray(ref_lbl).astype(str)
    unq_lbls = np.unique(ref_lbl)

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(ref_emb)
    dists, idx = neigh.kneighbors(qry_emb, k, return_distance=True)
    max_dist = np.max(dists, axis=1)

    raw_pred = ref_lbl[idx]  # neighbor labels, one row per query cell

    # pabc[i, j]: fraction of query i's neighbors labeled unq_lbls[j]
    pabc = np.stack([np.sum(raw_pred == lbl, axis=1) / k for lbl in unq_lbls], axis=1)

    max_pred = unq_lbls[np.argmax(pabc, axis=1)]

    keep = np.max(pabc, axis=1) > p_cutoff
    if dist_cutoff is not None:
        keep &= max_dist < dist_cutoff
    # np.where promotes to a string dtype wide enough for both inputs, so 'NA'
    # is never truncated (assigning 'NA' into a '<U1' array stores 'N').
    gated_pred = np.where(keep, max_pred, 'NA')

    return max_pred, gated_pred, max_dist


def neighbor_self_nonself(k, ref_emb, qry_emb):
    """Fraction of each query cell's neighbors that are themselves query cells.

    Reference and query embeddings are pooled. For every query cell, its `k`
    nearest neighbors in the pool are found; the cell itself is in the pool
    and is normally returned among them. Values near 1 mean the query cell
    sits in a region populated almost only by query cells, i.e. it is poorly
    mixed with the reference.

    Parameters
    ----------
    k : int
        Number of neighbors per query cell.
    ref_emb, qry_emb : array-like, (n_ref, d) and (n_qry, d)

    Returns
    -------
    np.ndarray, (n_qry,)
        Fraction in [0, 1] of neighbors that belong to the query set.
    """
    ref_n = len(ref_emb)
    qry_n = len(qry_emb)
    # 0 = reference, 1 = query; aligned with the row order of the stacked pool
    is_qry = np.concatenate([np.zeros(ref_n, dtype=int), np.ones(qry_n, dtype=int)])

    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(np.vstack([ref_emb, qry_emb]))
    idx = neigh.kneighbors(qry_emb, k, return_distance=False)

    return np.sum(is_qry[idx], axis=1) / k

# load data and construct adata 

In [None]:
# seed numpy's legacy global RNG (also used by sklearn/pandas calls that get no random_state)
np.random.seed(seed=0)

In [None]:
# Output folders for figures and organized data; created if missing
outdir     = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results_merfish/plots_240718"
outdatadir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized"
for _dir in (outdir, outdatadir):
    os.makedirs(_dir, exist_ok=True)

In [None]:
genesets, df = merfish_genesets.get_all_genesets()
# report the size of each named gene set
for setname, members in genesets.items():
    print(setname, len(members))

In [None]:
ddir  = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/merfish/organized" 
fout  = os.path.join(ddir, 'P21NR_v1_rna_merfish_250411.h5ad')
fout2 = os.path.join(ddir, 'P21NR_v1glut_rna_merfish_250411.h5ad')
fout3 = os.path.join(ddir, 'P21NR_v1l23glut_rna_merfish_250411.h5ad')
fout4 = os.path.join(ddir, 'P21NR_v1gaba_rna_merfish_250411.h5ad')
# fout5 = os.path.join(ddir, 'P8NR_v1l56itglut_rna_merfish_250411.h5ad')
!ls $ddir/*l2*250410.h5ad 

In [None]:
%%time
names = [
    'P21NRb_ant',
    'P21NRb_pos',
    
    'P21NRc_ant',
    'P21NRc_ant2',
    'P21NRc_pos2',
]

alldata = {}
for name in names:
    adatasub = ad.read(os.path.join(ddir, f'{name}_l2_v1_250410.h5ad')) 
    adatasub.obs.index = np.char.add(f'{name}', adatasub.obs.index.values)
    alldata[name] = adatasub 
    print(name, len(alldata[name]))
    
genes = adatasub.var.index.values
genes.shape

In [None]:
agenes, bgenes, cgenes, iegs = (genesets[key] for key in ('a', 'b', 'c', 'i'))
abcgenes = np.hstack([agenes, bgenes, cgenes])

# panel genes excluding the 'i' set (IEGs), in panel order
ieg_lookup = set(iegs)
genes_noniegs = np.array([g for g in genes if g not in ieg_lookup])

marker_genes = [
    'Ptprn', 'Slc17a7', 'Gad1', 'Fos',

    'Gfap', 'Slc6a13', 'Slc47a1',
    'Grin2c', 'Aqp4', 'Rfx4', 'Sox21', 'Slc1a3',

    'Sox10', 'Pdgfra', 'Mog',

    'Pecam1', 'Cd34', 'Tnfrsf12a', 'Sema3c',
    'Zfhx3', 'Pag1', 'Slco2b1', 'Cx3cr1',
]
len(abcgenes), len(genes_noniegs)

In [None]:
# Positions of each gene set within the panel. Uses `genes` (the shared panel
# from the load cell) instead of `adatasub`, a loop variable that later cells
# rebind.
agenes_idx, bgenes_idx, cgenes_idx, igenes_idx = (
    basicu.get_index_from_array(genes, gset)
    for gset in (agenes, bgenes, cgenes, iegs)
)

In [None]:
mean_total_rna_target = 250
rows_per_column = 4  # display layout: sample i -> row i % 4, column i // 4

adata_pieces = []
for sample_idx, name in enumerate(names):
    col, row = divmod(sample_idx, rows_per_column)

    adatasub = alldata[name].copy()
    adatasub.obs['sample'] = name

    norm_cnts = adatasub.layers['norm']
    # scale factor from the mean per-cell total over non-IEG genes only
    mean_per_batch_noniegs = np.mean(adatasub[:, genes_noniegs].layers['norm'].sum(axis=1))
    scale = mean_total_rna_target / mean_per_batch_noniegs

    adatasub.layers['jnorm'] = norm_cnts * scale
    adatasub.layers['ljnorm'] = np.log2(1 + adatasub.layers['jnorm'])

    adatasub.obs['norm_transcript_count'] = norm_cnts.sum(axis=1)
    adatasub.obs['jnorm_transcript_count'] = adatasub.layers['jnorm'].sum(axis=1)

    # offset each sample so they tile side by side in the combined plots
    width = adatasub.obs['width'].values
    adatasub.obs['depth_show'] = -adatasub.obs['depth'].values - row * 1300
    adatasub.obs['width_show'] = width - np.min(width) + col * 2500

    adata_pieces.append(adatasub)

adata_merged = ad.concat(adata_pieces)

In [None]:
# NOTE: `adata` aliases `adata_merged` (no copy); writes below show up on both.
adata = adata_merged

# PCA on non-IEG genes; each cell's 'ljnorm' profile is first z-scored across
# genes (axis=1). random_state pins the result if sklearn picks its randomized
# SVD solver, matching the fixed UMAP seed.
pca = PCA(n_components=50, random_state=0)
pcs = pca.fit_transform(stats.zscore(adata[:, genes_noniegs].layers['ljnorm'], axis=1))
ucs = UMAP(n_components=2, n_neighbors=30, random_state=0).fit_transform(pcs)

adata.obsm['pca'] = pcs
adata.obsm['umap'] = ucs
sc.pp.neighbors(adata, n_neighbors=30, use_rep='pca', random_state=0)

In [None]:
# Leiden clustering on the kNN graph built above; later cells read the result
# back via f'leiden_r{r}', so `r` must keep holding the resolution.
r = 0.3
leiden_key = f'leiden_r{r}'
sc.tl.leiden(adata, resolution=r, key_added=leiden_key, random_state=0, n_iterations=10)

In [None]:
# Leiden clusters, then sample of origin, on (width_show, depth_show) and UMAP
clsts = adata.obs[f'leiden_r{r}'].astype(int)
samples, uniq_labels = pd.factorize(adata.obs['sample'])
xr, yr = adata.obs['width_show'], adata.obs['depth_show']
ux, uy = adata.obsm['umap'][:, 0], adata.obsm['umap'][:, 1]

utils_merfish.plot_cluster(clsts, xr, yr, ux, uy, s=2)
utils_merfish.plot_cluster(samples, xr, yr, ux, uy, s=2)

In [None]:
clsts = adata.obs[f'leiden_r{r}'].astype(int)
uniq_clsts = np.unique(clsts)

# coordinates are identical for every panel; only the highlighted cluster changes
xr = adata.obs['width_show']
yr = adata.obs['depth_show']
ux = adata.obsm['umap'][:, 0]
uy = adata.obsm['umap'][:, 1]
for clst in uniq_clsts:
    utils_merfish.plot_cluster(clsts == clst, xr, yr, ux, uy, s=2,
                               cmap=plt.cm.copper_r, suptitle=clst)

In [None]:
# UMAP colored by each marker gene's 'jnorm' expression
gns = marker_genes
n = len(gns)
# NOT `nx`: that name is the networkx module used by get_largest_spatial_components
ncols = 4
nrows = -(-n // ncols)  # ceil(n / ncols)
fig, axs = plt.subplots(nrows, ncols, figsize=(ncols*5, nrows*4), squeeze=False)
for gn, ax in zip(gns, axs.flat):
    g = adata[:, gn].layers['jnorm'].reshape(-1,)
    utils_merfish.st_scatter_ax(fig, ax, ucs[:, 0], ucs[:, 1], gexp=g)
    ax.set_title(gn)
# hide panels left over when n is not a multiple of ncols
for ax in axs.flat[n:]:
    ax.axis('off')
plt.show()



# RNA data basic analysis 

In [None]:
# snRNA-seq reference dataset (P28NR) used for label transfer
rna_ref_path = os.path.join(
    '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/cheng21_cell_scrna/organized',
    'P28NR.h5ad',
)
adata_rna_raw = sc.read(rna_ref_path)
adata_rna_raw

In [None]:
# gene filtering + CP10k layers, then HVGs picked from the CP10k ('norm') layer
adata_rna_raw = preprocessing(adata_rna_raw)
hvgs = get_hvgs(adata_rna_raw, layer='norm')
adata_rna = adata_rna_raw[:, hvgs]
(adata_rna_raw, adata_rna)

In [None]:
# Restrict both modalities to their shared genes and stack them (RNA first)
genes_overlap = np.intersect1d(adata_rna.var.index.values, adata.var.index.values)
adata_rna = adata_rna[:, genes_overlap].copy()
# MERFISH cells with more than 50 transcripts only
adata_mer = adata[adata.obs['transcript_count'] > 50, genes_overlap].copy()

adata_rna.obs['modality'] = 'rna'
adata_mer.obs['modality'] = 'merfish'
adata_merge = sc.concat([adata_rna, adata_mer], join='outer')

# Per-gene z-scores of log expression, per modality. The log bases differ
# (log10 for RNA, log2 'ljnorm' for MERFISH), but a per-gene z-score is
# invariant to that constant factor. nan_to_num turns z-scores of
# zero-variance genes (NaN) into 0 so PCA does not fail on them.
lognorm_rna = np.log10(1 + adata_rna.layers['norm'].toarray())
zlognorm_rna = np.nan_to_num(stats.zscore(lognorm_rna, axis=0))

lognorm_mer = adata_mer.layers['ljnorm']
zlognorm_mer = np.nan_to_num(stats.zscore(lognorm_mer, axis=0))

# rows follow the concat order above: RNA then MERFISH.
# random_state pins the randomized SVD solver sklearn may select.
adata_merge.obsm['X_pca2'] = PCA(n_components=20, random_state=0).fit_transform(
    np.vstack([zlognorm_rna, zlognorm_mer]))

print(len(adata_rna), len(adata), len(adata_mer))
adata_merge

In [None]:
# Harmony correction of the joint PCA across modalities; the corrected
# embedding is read back below from obsm['X_pca_harmony'].
sc.external.pp.harmony_integrate(
    adata_merge,
    key='modality',
    basis='X_pca2',
    max_iter_harmony=20,
)

In [None]:
# kNN graph + UMAP on the harmony-corrected joint embedding
sc.pp.neighbors(adata_merge, use_rep='X_pca_harmony')
sc.tl.umap(adata_merge)
umap_xy = adata_merge.obsm['X_umap']
adata_merge.obs['umap1'], adata_merge.obs['umap2'] = umap_xy[:, 0], umap_xy[:, 1]


In [None]:
# MERFISH-only UMAP from earlier; RNA rows carry whatever fill value the outer
# concat gave them (presumably NaN — TODO confirm)
self_umap = adata_merge.obsm['umap']
adata_merge.obs['umap1_self'] = self_umap[:, 0]
adata_merge.obs['umap2_self'] = self_umap[:, 1]

In [None]:
# joint UMAP by modality; rows shuffled so neither modality is drawn wholly on top
shuffled_obs = adata_merge.obs.sample(frac=1, replace=False)
sns.scatterplot(data=shuffled_obs, x='umap1', y='umap2', hue='modality',
                s=1, edgecolor='none')

In [None]:
# Class-level (Class_broad) label transfer: RNA = reference, MERFISH = query
ref = adata_merge[adata_merge.obs['modality'] == 'rna'].copy()
qry = adata_merge[adata_merge.obs['modality'] == 'merfish'].copy()

all_emb = adata_merge.obsm['X_pca_harmony']
ref_emb, qry_emb = (part.obsm['X_pca_harmony'] for part in (ref, qry))

ref_lbl = ref.obs['Class_broad'].values.astype(str)
print(np.unique(ref_lbl))

k = 100
max_pred, _, dists = neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None)
ps = neighbor_self_nonself(k, ref_emb, qry_emb)

# MERFISH cells whose neighborhoods are >= 98% MERFISH get 'NA' in gated_pred
for col, fill, vals in (
    ('max_pred', 'NA', max_pred),
    ('frac_self_neighbors', np.nan, ps),
    ('gated_pred', 'NA', np.where(ps < 0.98, max_pred, 'NA')),
):
    adata_merge.obs[col] = fill
    adata_merge.obs.loc[qry.obs.index, col] = vals

In [None]:
# per-modality views of the merged object; row order must match the inputs
modality = adata_merge.obs['modality']
adata_snrnasq = adata_merge[modality == 'rna']
adata_merfish = adata_merge[modality == 'merfish']

for merged_view, source in ((adata_snrnasq, adata_rna), (adata_merfish, adata_mer)):
    assert np.all(merged_view.obs.index.values == source.obs.index.values)

adata_snrnasq, adata_merfish

In [None]:
# transferred class, mixing score, and transcript depth on the joint UMAP
sc.pl.umap(adata_merfish, color=['gated_pred'])
# only the top of the self-neighbor range matters (gating threshold is 0.98)
sc.pl.umap(adata_merfish, color=['frac_self_neighbors'], vmin=0.95)
# full color range for counts: vmin=0.95 was copied from the fraction plot and
# is meaningless here (cells were filtered to transcript_count > 50)
sc.pl.umap(adata_merfish, color=['transcript_count'])

In [None]:
# preview of the palette the class colors below are drawn from
palette_tab20 = sns.color_palette('tab20')
palette_tab20

In [None]:
# spatial (width, -depth) and joint-UMAP coordinates of the MERFISH cells
x, y = adata_merfish.obs['width'], -adata_merfish.obs['depth']
ux, uy = adata_merfish.obsm['X_umap'][:, 0], adata_merfish.obsm['X_umap'][:, 1]
clsts_lbl = adata_merfish.obs['gated_pred'].values
clsts, _ = pd.factorize(clsts_lbl)  # integer codes 0..N-1, required by the cmap helper
clsts_palette, clsts_cmap = utils_merfish.generate_discrete_cmap([len(np.unique(clsts))], keys=['Set2',])

# fixed class -> color mapping (8 colors from tab20) shared by all class plots;
# the dict order doubles as the legend/hue order
csel = 'tab20'
_class_colors = sns.color_palette(csel, 8)
_class_order = [
    'Excitatory', 'Inhibitory',
    'OPCs', 'Oligodendrocytes',
    'Microglia', 'Endothelial',
    'Astrocytes', 'VLMCs',
]
clsts_palette2 = dict(zip(_class_order, _class_colors))
clsts_palette2['NA'] = 'gray'

In [None]:
# MERFISH-only UMAP highlighting cells with fewer than 100 transcripts
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ux = adata_merfish.obs['umap1_self'].values
uy = adata_merfish.obs['umap2_self'].values
c = adata_merfish.obs['transcript_count'].values

g = ax.scatter(ux, uy, c=c < 100, s=3, edgecolor='none', cmap='rocket_r')
ax.set_aspect('equal')
ax.axis('off')
fig.colorbar(g, label='transcript_count < 100 (1 = yes)')
plt.show()

In [None]:
def _plot_umap_panels(panels, **subplot_kw):
    """Draw one scatterplot per panel, colored by class with `clsts_palette2`.

    panels : list of (adata, hue column, x column, y column), one per axis.
    subplot_kw : forwarded to plt.subplots (e.g. sharex=True, sharey=True).
    Rows are shuffled before plotting so no group is drawn wholly on top.
    Returns the Figure.
    """
    n_panels = len(panels)
    fig, axs = plt.subplots(1, n_panels, figsize=(8*n_panels, 6), squeeze=False, **subplot_kw)
    for ax, (adata_mod, col, xcol, ycol) in zip(axs.flat, panels):
        sns.scatterplot(data=adata_mod.obs.sample(frac=1, replace=False),
                        x=xcol, y=ycol,
                        hue=col,
                        palette=clsts_palette2, hue_order=list(clsts_palette2.keys()),
                        s=5, edgecolor='none', ax=ax, rasterized=True)
        ax.axis('off')
        ax.set_aspect('equal')
        ax.legend(bbox_to_anchor=(1, 1), fontsize=10)
    fig.tight_layout()
    plt.show()
    return fig


# joint UMAP: MERFISH (transferred class) next to snRNA-seq (annotated class)
fig = _plot_umap_panels(
    [(adata_merfish, 'gated_pred', 'umap1', 'umap2'),
     (adata_snrnasq, 'Class_broad', 'umap1', 'umap2')],
    sharex=True, sharey=True,
)
# powerplots.savefig_autodate(fig, os.path.join(outdir, 'fig1_umap.pdf'))

# MERFISH transferred class on its own UMAP vs. on the joint UMAP
fig = _plot_umap_panels(
    [(adata_merfish, 'gated_pred', 'umap1_self', 'umap2_self'),
     (adata_merfish, 'gated_pred', 'umap1', 'umap2')],
)
# powerplots.savefig_autodate(fig, os.path.join(outdir, 'fig1_umap.pdf'))

In [None]:
# MERFISH cells in the tiled section layout, colored by transferred class
fig, ax = plt.subplots(figsize=(8, 8))
shuffled = adata_merfish.obs.sample(frac=1, replace=False)
sns.scatterplot(
    data=shuffled, x='width_show', y='depth_show', hue='gated_pred',
    palette=clsts_palette2, hue_order=list(clsts_palette2),
    s=2, edgecolor='none', ax=ax, rasterized=True,
)
ax.legend(bbox_to_anchor=(1, 1), fontsize=10)
ax.axis('off')
ax.set_aspect('equal')
# powerplots.savefig_autodate(fig, os.path.join(outdir, 'fig1_dw.pdf'))
plt.show()

In [None]:
# One panel per class: all cells in light gray, the class on top in its color.
# The grid is sized from the number of classes instead of assuming exactly 9.
dfshow = adata_merfish.obs.sample(frac=1, replace=False).copy()
types = list(clsts_palette2.keys())
print(types)

ncols = 3
nrows = -(-len(types) // ncols)  # ceil
fig, axs = plt.subplots(nrows, ncols, figsize=(8*ncols, 8*nrows), squeeze=False)
for thistype, ax in zip(types, axs.flat):
    ax.set_title(thistype)
    sns.scatterplot(data=dfshow,
                    x='width_show', y='depth_show',
                    color='lightgray',
                    s=1, edgecolor='none', ax=ax, rasterized=True)
    sns.scatterplot(data=dfshow[dfshow['gated_pred'] == thistype],
                    x='width_show', y='depth_show',
                    color=clsts_palette2[thistype],
                    s=10, edgecolor='none', ax=ax, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')
# hide unused panels
for ax in axs.flat[len(types):]:
    ax.axis('off')
# powerplots.savefig_autodate(fig, os.path.join(outdir, 'fig1_dw.pdf'))
plt.show()

In [None]:
# Class proportions: snRNA-seq annotation (m) vs. MERFISH transferred label (n)
m = adata_snrnasq.obs['Class_broad'].value_counts()
m = m / m.sum()
n = adata_merfish.obs['gated_pred'].value_counts()
# fraction of MERFISH cells left unassigned ('NA') by the gating
print(n.get('NA', 0) / n.sum())
# align to the RNA classes; a class never predicted in MERFISH counts as 0
# instead of raising KeyError ('NA' is dropped here)
n = n.reindex(m.index, fill_value=0)
n = n / n.sum()

m, n


In [None]:
fig, ax = plt.subplots()
ax.scatter(m, n, color='k')
# rho/pval, not r/p: `r` holds the Leiden resolution read via f'leiden_r{r}'
rho, pval = stats.spearmanr(m.values, n.values)
for label, fx, fy in zip(m.index.values, m.values, n.values):
    ax.text(fx, fy, label, fontsize=10)

ax.set_title(f"Spearman r={rho:.2g}")

ax.plot([0, 0.6], [0, 0.6], '--', color='gray', zorder=0)
ax.set_aspect('equal')
sns.despine(ax=ax)

# these are Class_broad (class-level) frequencies
ax.set_xlabel('class freq. snRNA-seq')
ax.set_ylabel('class freq. MERFISH')
# powerplots.savefig_autodate(fig, os.path.join(outdir, 'fig1_scatter.pdf'))
plt.show()

In [None]:
# 'Doublet' is dropped before writing (original note: "for save the framework";
# presumably the column does not serialize to h5ad — TODO confirm).
# errors='ignore' keeps the cell re-runnable once the column is gone.
adata_merge.obs = adata_merge.obs.drop(columns='Doublet', errors='ignore')

In [None]:
# persist the merged RNA + MERFISH object (with class-level predictions)
print(fout)
adata_merge.write_h5ad(fout)

# add subclass predictions 

In [None]:
adata_merge = sc.read_h5ad(fout)  # reload the merged object saved above

In [None]:
# Excitatory / inhibitory cells: annotated (RNA) or predicted (MERFISH)
merged_obs = adata_merge.obs
cond1 = (merged_obs['Class_broad'] == 'Excitatory') | (merged_obs['gated_pred'] == 'Excitatory')
cond2 = (merged_obs['Class_broad'] == 'Inhibitory') | (merged_obs['gated_pred'] == 'Inhibitory')

adata_merge_exc = adata_merge[cond1].copy()
adata_merge_inh = adata_merge[cond2].copy()
adata_merge_exc, adata_merge_inh

# Exc: subclass label transfer within excitatory cells

In [None]:
# Subclass-level label transfer within excitatory cells (RNA -> MERFISH)
exc_modality = adata_merge_exc.obs['modality']
ref = adata_merge_exc[exc_modality == 'rna'].copy()
qry = adata_merge_exc[exc_modality == 'merfish'].copy()

all_emb = adata_merge_exc.obsm['X_pca_harmony']
ref_emb, qry_emb = (part.obsm['X_pca_harmony'] for part in (ref, qry))

ref_lbl = ref.obs['Subclass'].values.astype(str)
print(np.unique(ref_lbl))

k = 30
max_pred, _, dists = neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None)
ps = neighbor_self_nonself(k, ref_emb, qry_emb)

# MERFISH cells whose neighborhoods are >= 98% MERFISH get 'NA' in the gated column
for col, fill, vals in (
    ('max_pred_subclass', 'NA', max_pred),
    ('frac_self_neighbors_subclass', np.nan, ps),
    ('gated_pred_subclass', 'NA', np.where(ps < 0.98, max_pred, 'NA')),
):
    adata_merge_exc.obs[col] = fill
    adata_merge_exc.obs.loc[qry.obs.index, col] = vals

In [None]:
# display the excitatory subset with its new *_subclass obs columns
adata_merge_exc

# Inh: subclass label transfer within inhibitory cells

In [None]:
# rename
def rename_inh_subclass(old):
    """Rename selected inhibitory subclass labels.

    'Stac' becomes 'Vip' and 'Frem1' becomes 'Sncg'; any other value
    (including non-string values) is returned unchanged.
    """
    return {'Stac': 'Vip', 'Frem1': 'Sncg'}.get(old, old)

# renamed subclass labels; values not in the rename map pass through unchanged
adata_merge_inh.obs['Subclass_new'] = adata_merge_inh.obs['Subclass'].apply(rename_inh_subclass)

In [None]:
# Subclass-level label transfer within inhibitory cells (RNA -> MERFISH),
# using the renamed reference labels
inh_modality = adata_merge_inh.obs['modality']
ref = adata_merge_inh[inh_modality == 'rna'].copy()
qry = adata_merge_inh[inh_modality == 'merfish'].copy()

all_emb = adata_merge_inh.obsm['X_pca_harmony']
ref_emb, qry_emb = (part.obsm['X_pca_harmony'] for part in (ref, qry))

ref_lbl = ref.obs['Subclass_new'].values.astype(str)
print(np.unique(ref_lbl))

k = 30
max_pred, _, dists = neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None)
ps = neighbor_self_nonself(k, ref_emb, qry_emb)

# MERFISH cells whose neighborhoods are >= 98% MERFISH get 'NA' in the gated column
for col, fill, vals in (
    ('max_pred_subclass', 'NA', max_pred),
    ('frac_self_neighbors_subclass', np.nan, ps),
    ('gated_pred_subclass', 'NA', np.where(ps < 0.98, max_pred, 'NA')),
):
    adata_merge_inh.obs[col] = fill
    adata_merge_inh.obs.loc[qry.obs.index, col] = vals

In [None]:
# display the inhibitory subset with its new *_subclass obs columns
adata_merge_inh

# Select L2/3 excitatory cells and save the Exc, L2/3 Exc and Inh subsets

In [None]:
# L2/3 excitatory cells: annotated L2/3 (RNA) or predicted L2/3 (MERFISH)
exc_obs = adata_merge_exc.obs
cond = (exc_obs['Subclass'] == 'L2/3') | (exc_obs['gated_pred_subclass'] == 'L2/3')
adata_merge_l23exc = adata_merge_exc[cond].copy()
adata_merge_l23exc

In [None]:
# persist the excitatory, L2/3 excitatory and inhibitory subsets
for subset, path in (
    (adata_merge_exc, fout2),
    (adata_merge_l23exc, fout3),
    (adata_merge_inh, fout4),
):
    subset.write(path)