In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
from scipy import sparse 
import itertools

import matplotlib.pyplot as plt
import collections
from sklearn.cluster import KMeans
from sklearn import metrics

from sklearn.decomposition import PCA
from umap import UMAP

from py_pcha import PCHA

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu

In [None]:
outdir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results"  # results directory (not referenced by the cells below)

In [None]:
# Figure-export setup, currently disabled. Uncomment to create an output
# directory and a powerplots.FigManager for saving figures there.
# outfigdir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/results/241021"
# !mkdir -p $outfigdir
# fig_manager = powerplots.FigManager(outfigdir)

In [None]:
import glob

# One CSV of P21 PCA cell embeddings per sample; each file must have exactly two
# columns (the first two PCs), since they are relabelled PC1/PC2 below.
wkdir = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_rfx3_oe/P21_PCA_cellembeddings'
# Sort so sample order (plot panel order, dict iteration order, and thus the
# global RNG stream consumed by later seeded fits) does not depend on the
# filesystem's arbitrary glob order.
files = sorted(glob.glob(os.path.join(wkdir, '*.csv')))

data_dict = {}   # sample name -> (n_cells, 2) array of PC coordinates
adata_dict = {}  # sample name -> DataFrame with columns PC1, PC2
for f in files:
    sample = os.path.splitext(os.path.basename(f))[0]
    df = pd.read_csv(f)
    # Flip the sign of the second PC (PCA signs are arbitrary); presumably to match
    # an orientation used elsewhere -- NOTE(review): confirm.
    df['P21_2'] = -df['P21_2']

    data_dict[sample] = df.values
    adata_dict[sample] = pd.DataFrame(df.values, columns=['PC1', 'PC2'])
    print(sample, len(df))

In [None]:
# Pool replicates into one array per condition: 'ctrl' (controls) and 'oe' (Rfx3 overexpression),
# and mirror each pooled array as a PC1/PC2 DataFrame.
pooled_members = {
    'ctrl': ['controlA', 'controlB'],
    'oe': ['Rfx3OE_rep1', 'Rfx3OE_rep2_lane1', 'Rfx3OE_rep2_lane2'],
}
for group, members in pooled_members.items():
    stacked = np.vstack([data_dict[m] for m in members])
    data_dict[group] = stacked
    adata_dict[group] = pd.DataFrame(stacked, columns=['PC1', 'PC2'])

# Archetypal analysis (AA) on the PCA embeddings

In [None]:
from py_pcha import PCHA

In [None]:
def pca_pipe(adata):
    """Z-score the log-normalized expression layer and project it onto 4 PCs.

    Parameters
    ----------
    adata : AnnData-like
        Object with ``adata.layers['lognorm']`` holding a 2-D matrix (presumably
        cells x genes -- confirm), either scipy-sparse or dense.

    Returns
    -------
    zcentered : np.ndarray
        Per-column z-scored matrix, explicitly re-centered per column.
    pca : sklearn.decomposition.PCA
        The fitted 4-component PCA.
    pcs : np.ndarray
        PCA scores, shape (n_rows, 4).
    """
    # PCA's solver may be randomized; with no random_state it uses numpy's global RNG
    np.random.seed(0)
    pca = PCA(n_components=4)

    layer = adata.layers['lognorm']
    # .todense() only exists on sparse matrices / np.matrix; accept plain arrays too
    dense = layer.toarray() if sparse.issparse(layer) else np.asarray(layer)
    zlognorm = zscore(dense, axis=0)
    # zero-variance columns become NaN under zscore and PCA rejects NaN input; map them to 0
    zlognorm = np.nan_to_num(zlognorm, nan=0.0)

    pcs = pca.fit_transform(zlognorm)  # PCA centers internally

    return zlognorm - np.mean(zlognorm, axis=0), pca, pcs  # manual centering for the returned matrix

In [None]:
def get_dists_to_specialists(prj, XC):
    """Euclidean distance from every cell to every archetype ("specialist").

    Parameters
    ----------
    prj : array_like, shape (n_cells, n_dims)
        Cell coordinates.
    XC : array_like, shape (n_dims, n_archetypes)
        Archetype coordinates, one archetype per column (the layout returned by
        ``aa_inference``). Any number of archetypes is supported.

    Returns
    -------
    np.ndarray, shape (n_archetypes, n_cells)
        ``dists[a, c]`` is the distance from cell ``c`` to archetype ``a``.
    """
    prj = np.asarray(prj, dtype=float)
    XC = np.asarray(XC, dtype=float)
    # (1, n_cells, n_dims) - (n_archetypes, 1, n_dims) -> specialist x cell x dim
    diffs = prj[np.newaxis, :, :] - XC.T[:, np.newaxis, :]
    return np.sqrt(np.sum(diffs ** 2, axis=2))

In [None]:
def aa_inference(X, noc=3, delta=0):
    """Fit archetypes with PCHA and order them along the first dimension.

    Parameters
    ----------
    X : array_like, shape (n_dims, n_cells)
        Data with one cell per column (callers pass ``data.T``).
    noc : int, default 3
        Number of archetypes to fit.
    delta : float, default 0
        Passed through unchanged to ``PCHA``.

    Returns
    -------
    np.ndarray, shape (n_dims, noc)
        Archetype coordinates, one per column, sorted by ascending first
        coordinate so archetype order is stable across datasets.
    """
    XC, _, _, _, _ = PCHA(X, noc=noc, delta=delta)
    XC = np.asarray(XC)  # PCHA's output may be np.matrix; make it a plain ndarray
    return XC[:, np.argsort(XC[0])]  # assign an order according to the x-axis

In [None]:
def add_triangle(XC, ax, zorder=0, vertices=False, label='', linecolor='gray', linewidth=1, alpha=1, **kwargs):
    """Draw the closed dashed polygon through the archetypes in ``XC`` on ``ax``.

    Parameters
    ----------
    XC : array_like, shape (>=2, n_vertices)
        Vertex coordinates, one vertex per column; rows 0 and 1 are x and y.
    ax : matplotlib Axes
        Target axes.
    zorder, alpha :
        Applied to the outline and (if drawn) the vertices.
    label, linecolor, linewidth :
        Applied to the dashed outline.
    vertices : bool, default False
        If True, also scatter each vertex, coloured C0, C1, C2, ... in column order.
    **kwargs :
        Extra keyword arguments forwarded to ``ax.scatter`` for the vertices.
    """
    XC = np.asarray(XC)
    xs = XC[0].tolist()
    ys = XC[1].tolist()
    # repeat the first vertex so the outline closes
    ax.plot(xs + xs[:1], ys + ys[:1], '--',
            color=linecolor, label=label, zorder=zorder, linewidth=linewidth, markersize=3, alpha=alpha)

    if vertices:
        # one marker per vertex (previously hard-coded to exactly three)
        for i in range(XC.shape[1]):
            ax.scatter(XC[0, i], XC[1, i], color=f'C{i}', zorder=zorder, alpha=alpha, **kwargs)

In [None]:
def neighbor_label_transfer(k, ref_emb, qry_emb, ref_lbl, p_cutoff=0.5, dist_cutoff=None):
    """Transfer labels from reference cells to query cells by k-nearest-neighbor vote.

    Parameters
    ----------
    k : int
        Number of reference neighbors per query cell (1 <= k <= n_ref).
    ref_emb : array_like, shape (n_ref, n_dims)
        Reference embedding.
    qry_emb : array_like, shape (n_qry, n_dims)
        Query embedding.
    ref_lbl : array_like, shape (n_ref,)
        Reference labels; compared as strings.
    p_cutoff : float, default 0.5
        A prediction is kept only if the winning label's vote fraction is > p_cutoff.
    dist_cutoff : float or None
        If given, a prediction is also kept only if the distance to the farthest
        of the k neighbors is < dist_cutoff.

    Returns
    -------
    max_pred : np.ndarray of str, shape (n_qry,)
        Majority label (ties go to the label that sorts first).
    gated_pred : np.ndarray of str, shape (n_qry,)
        ``max_pred`` with rejected predictions replaced by 'NA'.
    max_dist : np.ndarray, shape (n_qry,)
        Euclidean distance to the farthest of the k neighbors.

    Raises
    ------
    ValueError
        If k is outside [1, n_ref].
    """
    # NearestNeighbors was never imported in this notebook (NameError); a scipy
    # KD-tree performs the same Euclidean k-NN query.
    from scipy.spatial import cKDTree

    ref_emb = np.asarray(ref_emb)
    qry_emb = np.asarray(qry_emb)
    # compare labels as str on both sides; int labels previously never matched unq_lbls
    ref_lbl = np.asarray(ref_lbl).astype(str)
    if not 1 <= k <= len(ref_emb):
        raise ValueError(f"k={k} must be between 1 and the number of reference cells ({len(ref_emb)})")

    unq_lbls = np.unique(ref_lbl)  # sorted, e.g. array(['L2/3_A', 'L2/3_B', 'L2/3_C'])
    dists, idx = cKDTree(ref_emb).query(qry_emb, k=k)
    # query() drops the neighbor axis when k == 1; restore (n_qry, k)
    dists = np.reshape(dists, (len(qry_emb), k))
    idx = np.reshape(idx, (len(qry_emb), k))
    raw_pred = ref_lbl[idx]  # (n_qry, k) neighbor labels

    # vote fraction per label: (n_qry, n_labels)
    pabc = np.stack([np.mean(raw_pred == lbl, axis=1) for lbl in unq_lbls], axis=1)
    max_pred = unq_lbls[np.argmax(pabc, axis=1)]
    max_dist = np.max(dists, axis=1)

    keep = np.max(pabc, axis=1) > p_cutoff
    if dist_cutoff is not None:
        keep &= max_dist < dist_cutoff
    # np.where widens the string dtype, so 'NA' is not truncated for 1-char labels
    gated_pred = np.where(keep, max_pred, 'NA')

    return max_pred, gated_pred, max_dist


def neighbor_self_nonself(k, ref_emb, qry_emb):
    """For each query cell, the fraction of its k nearest neighbors that are query cells.

    Neighbors are searched in the pooled set (reference rows then query rows).
    A query cell is normally its own nearest neighbor (distance 0), so it is
    counted among its k neighbors.

    Parameters
    ----------
    k : int
        Neighbors per query cell (1 <= k <= n_ref + n_qry).
    ref_emb : array_like, shape (n_ref, n_dims)
    qry_emb : array_like, shape (n_qry, n_dims)

    Returns
    -------
    np.ndarray, shape (n_qry,)
        Fraction in [0, 1] of neighbors drawn from the query set.

    Raises
    ------
    ValueError
        If k is outside [1, n_ref + n_qry].
    """
    # The original referenced an undefined `ref_lbl` (dead code) and an un-imported
    # NearestNeighbors; a scipy KD-tree performs the same Euclidean k-NN query.
    from scipy.spatial import cKDTree

    ref_emb = np.asarray(ref_emb)
    qry_emb = np.asarray(qry_emb)
    ref_n = len(ref_emb)
    qry_n = len(qry_emb)
    if not 1 <= k <= ref_n + qry_n:
        raise ValueError(f"k={k} must be between 1 and the pooled number of cells ({ref_n + qry_n})")

    # 0 = reference, 1 = query, aligned with the stacked rows below
    lbls = np.array([0] * ref_n + [1] * qry_n)

    _, idx = cKDTree(np.vstack([ref_emb, qry_emb])).query(qry_emb, k=k)
    idx = np.reshape(idx, (qry_n, k))  # query() drops the neighbor axis when k == 1

    return np.sum(lbls[idx], axis=1) / k

# Archetype location

In [None]:
# Individual samples = everything except the pooled groups built earlier.
# (Previously `list(keys)[:5]`, which silently assumed exactly 5 samples and a fixed insertion order.)
pooled = ('ctrl', 'oe')
labels = [lbl for lbl in data_dict if lbl not in pooled]
print(labels)

XC_dict = {}
for lbl in data_dict:
    # PCHA's initialization is random; reseed per dataset so each fit is reproducible
    # regardless of how many datasets precede it in iteration order.
    np.random.seed(0)
    XC_dict[lbl] = aa_inference(data_dict[lbl].T)
    

In [None]:
# Inspect the fitted archetypes: one array per dataset (rows = PC1/PC2, one column per archetype).
XC_dict

In [None]:
# Exploratory preview of the 'tab20c' palette (referenced only in comments of color_dict below).
sns.color_palette('tab20c')

In [None]:
# Plot colour per sample: both controls in C1, every Rfx3 OE replicate/lane in black.
color_dict = {
    **dict.fromkeys(['controlA', 'controlB'], 'C1'),
    **dict.fromkeys(['Rfx3OE_rep1', 'Rfx3OE_rep2_lane1', 'Rfx3OE_rep2_lane2'], 'k'),
}

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
for lbl in labels:
    XC = XC_dict[lbl]
    # closed dashed outline through the archetypes, hollow markers at the vertices
    ax.plot(XC[0].tolist() + [XC[0, 0]], XC[1].tolist() + [XC[1, 0]], '--o',
            fillstyle='none', label=lbl, color=color_dict[lbl], linewidth=1)

# Axis styling only needs to happen once, after every sample is drawn
# (it was previously repeated on each loop iteration).
ax.legend(bbox_to_anchor=(0, -0.25), loc='upper left')
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_aspect('equal')
ax.grid(False)
sns.despine(ax=ax)

plt.show()

In [None]:
# One panel per sample; sized from `labels` instead of a hard-coded 5.
n_panels = len(labels)
fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), sharex=True, sharey=True, squeeze=False)
for ax, lbl in zip(axs.flat, labels):
    XC = XC_dict[lbl]
    df = adata_dict[lbl]
    # per-cell scatter (rasterized to keep vector exports light), sample's own archetypes on top
    sns.scatterplot(data=df, x='PC1', y='PC2', s=3, edgecolor='none',
                    rasterized=True, ax=ax)
    add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
    ax.set_title(lbl)
    ax.set_aspect('equal')
    ax.grid(False)
    sns.despine(ax=ax)

plt.show()

In [None]:
# One panel per sample; sized from `labels` instead of a hard-coded 5.
n_panels = len(labels)
fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), sharex=True, sharey=True, squeeze=False)
for ax, lbl in zip(axs.flat, labels):
    XC = XC_dict[lbl]
    df = adata_dict[lbl]
    # 2-D density in 1-unit bins, as % of this sample's cells
    sns.histplot(ax=ax, data=df, x='PC1', y='PC2',
                 stat='percent', binwidth=1, vmin=0,
                 cmap='gray_r', cbar=True, cbar_kws=dict(shrink=0.5))

    add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
    ax.grid(False)
    sns.despine(ax=ax)

    ax.set_title(lbl)
    ax.set_aspect('equal')
plt.show()

In [None]:

# Pooled control vs pooled OE densities, both overlaid with the control-derived archetypes.
XC = XC_dict['ctrl']
conditions = ['ctrl', 'oe']
fig, axs = plt.subplots(1, len(conditions), figsize=(10, 4), sharex=True, sharey=True)
for ax, lbl in zip(axs, conditions):
    df = adata_dict[lbl]
    g = sns.histplot(
        ax=ax, data=df, x='PC1', y='PC2',
        stat='percent', binwidth=1, vmin=0,
        cmap='gray_r', cbar=True, cbar_kws=dict(shrink=0.5),
    )
    add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
    ax.grid(False)
    sns.despine(ax=ax)
    ax.set_title(lbl)
    ax.set_aspect('equal')
plt.show()

In [None]:
XC = XC_dict['ctrl']
fig, ax = plt.subplots(1, 1, figsize=(5, 4))

add_triangle(XC, ax, vertices=False, linewidth=1, linecolor='k', zorder=2)
ax.grid(False)
sns.despine(ax=ax)
ax.set_title('OE - ctrl')
ax.set_aspect('equal')

# 1-unit bins over a fixed window; cells outside the window are not counted
xbins = np.arange(-20, 20, 1)
ybins = np.arange(-15, 10, 1)
xmin, xmax = np.min(xbins), np.max(xbins)
ymin, ymax = np.min(ybins), np.max(ybins)

# `normed=` was removed from np.histogram2d in NumPy 1.24; for histogram2d it was
# an alias of `density=`, which is used here (probability density over the window).
X = data_dict['ctrl']
hist1, _, _ = np.histogram2d(X[:, 0], X[:, 1], bins=[xbins, ybins], density=True)
Y = data_dict['oe']
hist2, _, _ = np.histogram2d(Y[:, 0], Y[:, 1], bins=[xbins, ybins], density=True)

# symmetric limits so zero difference sits at the neutral centre of coolwarm
vlim = 0.5 * np.percentile(hist1, 95)

# transpose: histogram2d is indexed [x, y], imshow draws rows as y
g = ax.imshow((hist2 - hist1).T, cmap='coolwarm', vmin=-vlim, vmax=vlim,
              origin='lower', extent=(xmin, xmax, ymin, ymax))
fig.colorbar(g, shrink=0.4)

plt.show()

# Optimal Transport

In [None]:
import ot
import ot.plot
from matplotlib import collections as mc

In [None]:
def OT_pipe(xs, xt, numbins=12):
    """Exact optimal transport from source cells to target cells, summarized as a binned vector field.

    Each source cell gets mass 1/ns and each target cell 1/nt. The exact
    (earth mover's) plan is solved with ``ot.emd`` on the cost matrix from
    ``ot.dist`` (POT's default metric; presumably squared Euclidean -- confirm
    for the installed POT version).

    Example inputs (from the original analysis):
        xs = X_nr28.dot(V)[:,:2]
        xt = X_dr28.dot(V)[:,:2]

    Parameters
    ----------
    xs : np.ndarray, shape (ns, 2)
        Source coordinates (exactly 2 columns; the result columns are named dx/dy).
    xt : np.ndarray, shape (nt, 2)
        Target coordinates.
    numbins : int
        Number of equal-width bins per axis, spanning the range of the *source*
        coordinates (``pd.cut``), used to aggregate the per-cell arrows.

    Returns
    -------
    G0ns : scipy.sparse.coo_matrix, shape (ns, nt)
        Transport plan with each source row normalized to sum to ~1.
    arrows : pd.DataFrame, one row per source cell
        dx, dy: plan-weighted mean displacement of the cell toward the targets;
        x, y: source position; xbin, ybin: integer bin indices.
    n_arrows : pd.Series
        Number of source cells per occupied (xbin, ybin).
    mean_arrows : pd.DataFrame
        Per-bin means of dx, dy, x, y, reindexed to align row-for-row with ``n_arrows``.
    mean_mags : float
        Mean displacement magnitude over all source cells.
    """
    ns = len(xs)
    nt = len(xt)
    a = np.ones((ns,))/ns
    b = np.ones((nt,))/nt
    print(xs.shape, xt.shape)
    
    # ~5 sec for 4k cells vs 4k cells; the cost matrix and plan are dense (ns x nt)
    M = ot.dist(xs, xt)
    G0 = ot.emd(a, b, M)
    # row-normalize so each source cell's outgoing weights sum to ~1 (1e-10 guards against empty rows)
    G0n = G0/np.array(G0.sum(axis=1)+1e-10).reshape(-1,1)
    
    # organize results (ns, nt); sparse keeps only the nonzero plan entries
    G0ns = sparse.coo_matrix(G0n)
    
    # per source cell vector: sum of weight * (target - source) over that cell's nonzero entries
    alli, allj, allw = G0ns.row, G0ns.col, G0ns.data
    tmp = pd.DataFrame((xt[allj] - xs[alli])*allw.reshape(-1,1))
    tmp[2] = alli  # column 2 = source row index, used as the group key
    # reindex restores one row per source cell in xs order
    arrows = tmp.groupby(2).sum().reindex(np.arange(ns)).values
    
    # organize per cell vector 
    arrows = pd.DataFrame(arrows, columns=['dx', 'dy'])
    arrows['x'] = xs[:,0]
    arrows['y'] = xs[:,1]
    arrows['xbin'] = pd.cut(xs[:,0], numbins, labels=False)
    arrows['ybin'] = pd.cut(xs[:,1], numbins, labels=False)
    
    # local mean field: cell counts and mean arrow per occupied (xbin, ybin) bin
    n_arrows = arrows.groupby(['xbin', 'ybin']).size()
    n_arrows = n_arrows[n_arrows!=0]  # safeguard; groupby.size() only lists observed bins
    mean_arrows = arrows.groupby(['xbin', 'ybin']).mean().reindex(n_arrows.index).fillna(0)
    
    # per-cell displacement magnitude, averaged over all source cells
    mags = np.sqrt(arrows['dx']**2+arrows['dy']**2)
    mean_mags = np.mean(mags)
    
    return G0ns, arrows, n_arrows, mean_arrows, mean_mags
    

In [None]:
def OT_plot(n_arrows, mean_arrows, XC, output=None):
    """Plot the binned OT vector field with the archetype triangle on a new 7x5 figure.

    Parameters
    ----------
    n_arrows : pd.Series
        Source-cell count per (xbin, ybin), as returned by ``OT_pipe``.
    mean_arrows : pd.DataFrame
        Per-bin means with columns x, y, dx, dy, aligned row-for-row with ``n_arrows``.
    XC : array_like
        Archetype coordinates passed to ``add_triangle``.
    output : str or None
        If truthy, the figure is saved via ``powerplots.savefig_autodate`` before showing.
    """
    top = np.max(n_arrows)
    bottom = np.min(n_arrows)
    print(bottom, top)

    fig, ax = plt.subplots(figsize=(7, 5))
    vectors = mean_arrows[['x', 'y', 'dx', 'dy']].values
    # arrow opacity = bin's cell count relative to the fullest bin
    for count, (x0, y0, dx, dy) in zip(n_arrows.values, vectors):
        ax.arrow(x0, y0, dx, dy, linewidth=1, width=0.2, alpha=count / top,
                 edgecolor='none', facecolor='k')

    add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
    ax.grid(False)
    ax.set_aspect('equal')
    if output:
        powerplots.savefig_autodate(fig, output)
    plt.show()
    
def OT_plot_ax(ax, n_arrows, mean_arrows):
    """Draw the binned OT vector field onto an existing Axes.

    One black arrow per row of ``mean_arrows``, from (x, y) by (dx, dy); its
    opacity is that bin's count in ``n_arrows`` divided by the largest count.
    Prints the smallest and largest bin counts.
    """
    top = np.max(n_arrows)
    print(np.min(n_arrows), top)
    vectors = mean_arrows[['x', 'y', 'dx', 'dy']].values
    for count, (x0, y0, dx, dy) in zip(n_arrows.values, vectors):
        ax.arrow(x0, y0, dx, dy, linewidth=1, width=0.2, alpha=count / top,
                 edgecolor='none', facecolor='k')

In [None]:
ot_res = {}  # comparison name -> (n_arrows, mean_arrows, mean_mags)

# Pooled controls (source) -> pooled Rfx3 OE cells (target), summarized on a 15x15 grid.
XC = XC_dict['ctrl']
xs, xt = data_dict['ctrl'], data_dict['oe']
_, _, n_arrows, mean_arrows, mean_mags = OT_pipe(xs, xt, numbins=15)
ot_res['ctrl vs oe'] = n_arrows, mean_arrows, mean_mags

In [None]:
XC = XC_dict['ctrl']
# Take the ctrl -> oe field from ot_res rather than the bare n_arrows/mean_arrows
# globals: a later cell overwrites those with the controlA -> controlB result, so
# re-running this cell would otherwise plot the wrong field under this title.
ot_n_arrows, ot_mean_arrows = ot_res['ctrl vs oe'][:2]

fig, axs = plt.subplots(1, 3, figsize=(6*3, 5*1), sharex=True, sharey=True)
for ax, lbl in zip(axs.flat, ['ctrl', 'oe']):
    df = adata_dict[lbl]
    sns.histplot(ax=ax, data=df, x='PC1', y='PC2',
                 stat='percent', binwidth=1.2, vmin=0, vmax=3,
                 cmap='gray_r', cbar=False)

    add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
    ax.grid(False)
    sns.despine(ax=ax)

    ax.set_title(lbl)
    ax.set_aspect('equal')
    ax.set_xlim([-18, 18])
    ax.set_ylim([-15, 12])

ax = axs[2]
OT_plot_ax(ax, ot_n_arrows, ot_mean_arrows)
add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
ax.grid(False)
ax.set_aspect('equal')
ax.set_title('ctrl -> oe')
sns.despine(ax=ax)

plt.show()

In [None]:
# Add to the ot_res dict created in the ctrl-vs-oe cell instead of re-creating it:
# the former `ot_res = dict()` here silently discarded the 'ctrl vs oe' result.
XC = XC_dict['ctrl']
# replicate vs replicate: OT between the two control samples (presumably a
# reference for how much displacement arises without the perturbation)
xs = data_dict['controlA']
xt = data_dict['controlB']
_, _, n_arrows, mean_arrows, mean_mags = OT_pipe(xs, xt, numbins=15)
ot_res['ctrl A vs B'] = (n_arrows, mean_arrows, mean_mags)

In [None]:
XC = XC_dict['ctrl']
# Take the controlA -> controlB field from ot_res instead of the bare
# n_arrows/mean_arrows globals, so out-of-order re-runs cannot plot a different comparison.
ot_n_arrows, ot_mean_arrows = ot_res['ctrl A vs B'][:2]

fig, axs = plt.subplots(1, 3, figsize=(6*3, 5*1), sharex=True, sharey=True)
for ax, lbl in zip(axs.flat, ['controlA', 'controlB']):
    df = adata_dict[lbl]
    sns.histplot(ax=ax, data=df, x='PC1', y='PC2',
                 stat='percent', binwidth=1.2, vmin=0, vmax=3,
                 cmap='gray_r', cbar=False)

    add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
    ax.grid(False)
    sns.despine(ax=ax)

    ax.set_title(lbl)
    ax.set_aspect('equal')
    ax.set_xlim([-18, 18])
    ax.set_ylim([-15, 12])

ax = axs[2]
OT_plot_ax(ax, ot_n_arrows, ot_mean_arrows)
add_triangle(XC, ax, vertices=True, linewidth=1, linecolor='k', zorder=2)
ax.grid(False)
ax.set_aspect('equal')
ax.set_title('controlA -> controlB')  # this panel was previously untitled
sns.despine(ax=ax)

plt.show()

# OE − ctrl density difference (same settings as the MERFISH DR−NR analysis)

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Black -> C0/C1/C2 ramps, matching the vertex colours add_triangle uses (presumably one per archetype).
colors_a, colors_b, colors_c = ([(0.0, 'black'), (1.0, c)] for c in ('C0', 'C1', 'C2'))

# Diverging C1 (control colour in color_dict) <- white -> black (OE colour), plus its two halves.
colors_nrdr = [(0.0, 'C1'), (0.5, 'white'), (1.0, 'black')]
colors_nr = [(0.0, 'white'), (1.0, 'C1')]
colors_dr = [(0.0, 'white'), (1.0, 'black')]

# Build each colormap under the same name as the variable it is bound to.
cmap_a, cmap_b, cmap_c, cmap_nrdr, cmap_nr, cmap_dr = (
    LinearSegmentedColormap.from_list(name, spec)
    for name, spec in (
        ('cmap_a', colors_a),
        ('cmap_b', colors_b),
        ('cmap_c', colors_c),
        ('cmap_nrdr', colors_nrdr),
        ('cmap_nr', colors_nr),
        ('cmap_dr', colors_dr),
    )
)

In [None]:
# Same window and ~1.2-unit grid as the OT figures. Note: the y grid has
# int(27/1.2 + 1) = 23 edges, so its actual step is 27/22 (~1.227), not 1.2.
xmin, xmax = -18, 18
ymin, ymax = -15, 12
binwidth = 1.2

bins_x = np.linspace(xmin, xmax, int((xmax-xmin)/binwidth+1))
bins_y = np.linspace(ymin, ymax, int((ymax-ymin)/binwidth+1))
print(bins_x)
print(bins_y)

XC = XC_dict['ctrl']  # control archetypes, overlaid on every panel (was an implicit global)

hists = []
fig, axs = plt.subplots(1, 3, figsize=(3*6, 1*5), sharex=True, sharey=True)
for ax, lbl in zip(axs.flat, ['ctrl', 'oe']):
    df = adata_dict[lbl]
    x = df['PC1']
    y = df['PC2']
    sns.histplot(x=x, y=y, ax=ax, bins=(bins_x, bins_y),
                 cmap='gray_r',
                 stat='percent', vmin=0, vmax=3,
                 cbar=True, cbar_kws=dict(shrink=0.4))

    # same binning by hand, as % of this condition's cells, for the difference panel
    hist, _, _ = np.histogram2d(x, y, bins=(bins_x, bins_y))
    hist = hist/len(x)*100
    hists.append(hist)
    print(hist.shape)
    ax.set_title(lbl)
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)

    add_triangle(XC, ax, zorder=2)

ax = axs[2]
# hists[1] - hists[0] is oe - ctrl; the former 'DR-NR' title was carried over
# from the MERFISH analysis this cell mirrors.
ax.set_title('OE - ctrl')
# transpose: histogram2d is indexed [x, y], imshow draws rows as y
g = ax.imshow(
    (hists[1] - hists[0]).T,
    origin='lower',
    extent=(xmin, xmax, ymin, ymax),
    cmap=cmap_nrdr,
    vmax=0.5, vmin=-0.5)
ax.set_aspect('equal')
ax.grid(False)
fig.colorbar(g, shrink=0.4, label='difference in % of cells per bin')
sns.despine(ax=ax)

add_triangle(XC, ax, zorder=2)

plt.show()