In [None]:
import os
import numpy as np
import pandas as pd
from skimage import io, measure
import tifffile
from kneed import KneeLocator
from itertools import combinations

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Plotly: render interactive figures with the JupyterLab renderer
import plotly.io as pio
pio.renderers.default = 'jupyterlab'
import plotly.graph_objs as go

In [None]:
# Project-local imports. sys.path is extended BEFORE importing so that the project
# modules below (presumably metadata_gene_chan / shared_functions / scroutines) resolve;
# statement order matters here.
import sys
import importlib

sys.path.insert(0, '/u/home/f/f7xiesnm/project-zipursky/code/easifish-proc/bydatasets/')
sys.path.append('../../')

from scroutines import powerplots

# reload so edits to the module are picked up without restarting the kernel
import metadata_gene_chan
importlib.reload(metadata_gene_chan)
from metadata_gene_chan import get_proj_metadata

# per-dataset metadata dict (channels, proj_targets, colors, axis_polarity, ... are read below)
meta = get_proj_metadata()

from easi_fish import roi_prop, spot, intensity
# importlib.reload(spot)
# importlib.reload(roi_prop)
# importlib.reload(intensity)

import shared_functions
importlib.reload(shared_functions)

from shared_functions import spot_to_voxel_coords, filter_raw_spots, spots_incells_metrics, scramble_cell_masks, masks_to_labeled_masks
from shared_functions import plot_frac, plot_reverse_cumsum, plot_reverse_cumsum_complex

In [None]:
outdir = "/u/home/f/f7xiesnm/project-zipursky/easifish/results/ms_multiome/lt186"
# Create the output directory in pure Python (portable, no IPython shell magic);
# exist_ok=True makes re-running this cell a no-op, like `mkdir -p`.
os.makedirs(outdir, exist_ok=True)

In [None]:
# Input locations for dataset lt186
ddir0 = "/u/home/f/f7xiesnm/project-zipursky/easifish/lt186"
ddir1 = "/u/home/f/f7xiesnm/project-zipursky/easifish/results/viz_all_projections_jan29"

# Index convention of the ROI table relative to mask labels:
# False if table indices start from 1 ; True if from 0
PLUS_ONE = False

# segmentation mask + ROI table (s3), and the s4 reference image (channel c0) used for its shape
f_msk  = ddir0 + '/outputs/r1v3/segmentation/r1v3-c3.tif'
f_tbl  = ddir0 + '/proc/r1v3/roi.csv'
f_img0 = ddir1 + '/lt186_r1_autos1_flatfused_c0_s4.tiff'

# per-channel spot files and spot-intensity thresholds, keyed by channel id
fx_spots = {
    'r1v3_c0': ddir0 + '/outputs/r1v3/spots/spots_c0.txt',
    'r1v3_c2': ddir0 + '/outputs/r1v3/spots/spots_c2.txt',
}
intn_threshs = {'r1v3_c0': 70, 'r1v3_c2': 60}

ex = 2
lb_res = [0.23*8, 0.23*8, 0.42*4]  # s3 resolution

# project metadata for this dataset
metakey = 'lt186'
meta_ds = meta[metakey]
channels     = meta_ds['channels']
proj_targets = meta_ds['proj_targets']
colors       = meta_ds['colors']
polarities   = meta_ds['axis_polarity']

np.random.seed(0)

In [None]:
# Load the S4 reference image shape and the S3 label mask, downsample the mask onto
# the S4 grid, then load the per-cell ROI table and keep cells within a size range.

# S4 image shape (only the shape of the reference image is needed)
img_shape = io.imread(f_img0).shape
print(img_shape)

# S3 mask
lb = io.imread(f_msk)
lb = np.array(lb)

# label ids present in the mask, with 0 (non-cell) prepended
lb_id = np.unique(lb[lb!=0])
lb_id = np.hstack([[0], lb_id])

# Downsample by 2 along every axis and trim so the result matches the S4 image shape:
# r* trailing voxels are dropped so the stride-2 slice yields exactly i* samples.
mx, my, mz = lb.shape
ix, iy, iz = img_shape
rx, ry, rz = mx-ix*2, my-iy*2, mz-iz*2
msk = lb[0:mx-rx:2,
         0:my-ry:2,
         0:mz-rz:2,]
print(lb.shape)
print(msk.shape)
assert tuple(msk.shape) == tuple(img_shape), f"downsampled mask {msk.shape} != image {img_shape}"

# table - check that the mask label ids match the table index (one row per cell).
# np.array_equal returns False on a length mismatch instead of failing on an
# elementwise comparison of differently sized arrays.
props = pd.read_csv(f_tbl, index_col=0)
mask_ids = np.unique(msk[msk!=0])
assert np.array_equal(mask_ids, props.index.values), (
    f"mask labels ({len(mask_ids)}) and table rows ({len(props)}) do not match")

# filter out cells that are too large or too small (strict bounds, table 'area' units)
min_area, max_area = 1000, 4000
cond_filter = np.logical_and(props['area']<max_area, props['area']>min_area)
print(cond_filter.sum()/len(props))  # fraction of cells kept
props = props[cond_filter]

# ids of the cells kept; used to restrict spot counting
lb_id_selected = props.index.values

# normalization factor per cell: area relative to the mean area of the kept cells
mean_area = np.mean(props['area'])
norm_factor = props['area']/mean_area


In [None]:
# Distribution of cell areas for the cells kept by the size filter (units as in the ROI table).
ax = sns.histplot(props['area'].values)
ax.set_xlabel('area')
ax.set_title('cell area (after size filter)');

# check - c0

Coordinate conventions: physical coordinates (× `ex`) → image coordinates ← voxel coordinates (× voxel size).

In [None]:
# shuffled masks: scrambled version of the label mask (shared_functions.scramble_cell_masks),
# used below as the null ("scrambled cells") when counting spots per cell
lb_shuff = scramble_cell_masks(lb)

In [None]:

# Per-channel spot counting and threshold selection.
# For each channel: load and intensity-filter spots, count spots per selected cell in the
# real mask and in the scrambled mask (null), normalize by relative cell area, then pick
# the count threshold where the FDR curve (data vs scrambled) crosses `alpha`.
alpha = 0.05
thresholds = []
bins = np.linspace(0, 30, 50)
# (x0, x1) starting points for the secant root search, one pair per channel, in `channels` order
x0x1s = [(3, 20), (3, 15)]
# zip() below would silently drop channels if these lists disagree in length
assert len(x0x1s) == len(channels) == len(proj_targets), \
    f"need one (x0, x1) pair per channel: {len(x0x1s)} pairs for {len(channels)} channels"

from scipy.interpolate import interp1d
from scipy.optimize import fsolve, root_scalar

# channel by channel
for ch, x0x1, proj_target in zip(channels, x0x1s, proj_targets):
    f_spots = fx_spots[ch]
    intn_th = intn_threshs[ch]
    print(ch, f_spots)
    x0, x1 = x0x1

    # spots: raw table -> intensity-filtered spots
    spots_rc = np.loadtxt(f_spots, delimiter=',')
    spots_rc = filter_raw_spots(spots_rc, intn_th, lb_res, lb)

    spots_incells_metrics(spots_rc, lb, lb_res)

    # plot filtered spots (coordinates divided by `ex`) in the three axis-pair projections,
    # flipping any axis whose polarity is -1
    fig, axs = plt.subplots(1,3,figsize=(3*8,1*6))
    indices = [0,1,2]
    indices_label = ['x', 'y', 'z']
    for idx, (i, j) in enumerate(combinations(indices, 2)):
        ax = axs[idx]
        pi = polarities[i]
        pj = polarities[j]
        if pi == -1:
            ax.invert_xaxis()
        if pj == -1:
            ax.invert_yaxis()

        ax.scatter(spots_rc[:,i]/ex, spots_rc[:,j]/ex, s=1, edgecolor='none', color='gray')
        ax.set_xlabel(indices_label[i])
        ax.set_ylabel(indices_label[j])
        ax.set_aspect('equal')

    # count spots per selected cell: real mask vs scrambled mask (null)
    counts      = spot.spot_counts_worker(lb, spots_rc, lb_res,
                                  lb_id=lb_id,
                                  remove_noncell=True,
                                  selected_roi_list=lb_id_selected,
                                  )
    counts_shuff = spot.spot_counts_worker(lb_shuff, spots_rc, lb_res,
                                  lb_id=lb_id,
                                  remove_noncell=True,
                                  selected_roi_list=lb_id_selected,
                                  )
    # normalize by relative cell area
    # NOTE(review): assumes counts come back in the same order as props / lb_id_selected — confirm
    norm_counts        = counts.values       /norm_factor.values
    norm_counts_shuff  = counts_shuff.values /norm_factor.values

    # get threshold: solve -log10(FDR(x)) = -log10(alpha) on the linearly interpolated curve
    fig, axs, fdrs = plot_reverse_cumsum_complex(
                                    bins,
                                    [norm_counts, norm_counts_shuff],
                                    label_list=['data', 'scrambled cells'],
                                    color_list=['C1', 'black'],
                                   )
    func = interp1d(bins[1:],
                    -np.log10(fdrs)-(-np.log10(alpha)),
                    kind='linear')
    solution = root_scalar(func, x0=x0, x1=x1)
    if not solution.converged:
        # Fail here with a clear message: a None threshold would otherwise crash below in
        # axvline, in the f'{x_th:.2f}' title and in the `norm_counts > x_th` comparison.
        raise RuntimeError(
            f"{ch}: FDR threshold search did not converge "
            f"(x0={x0}, x1={x1}, flag={solution.flag!r}); adjust x0x1s")
    x_th = solution.root
    print(x_th)

    axs[0].legend()

    axs[1].axvline(x_th,             linestyle='--', color='gray')
    axs[1].axhline(func(x_th)+(-np.log10(alpha)), linestyle='--', color='gray')
    axs[1].set_title(f'threshold = {x_th:.2f}')

    output = os.path.join(outdir, f'1_{ch}.pdf')
    powerplots.savefig_autodate(fig, output)
    plt.show()

    # save
    thresholds.append(x_th)
    # area-normalized counts
    props[f'new_{ch}'] = norm_counts
    # binary call: 1 if the normalized count exceeds the channel threshold
    props[f'proj_{ch}'] = (norm_counts > x_th).astype(int)
    props[proj_target] = props[f'proj_{ch}'].copy()
    

In [None]:
# per-axis polarity values from the metadata (-1 means the axis is inverted in the plots)
polarities

In [None]:
# Cell centroids in the three axis-pair projections: all cells in light gray,
# projection-positive cells colored per target, and cells positive in every channel in purple.
axis_labels = ['x','y','z']
axis_indices = [0,1,2]
fig, axs = plt.subplots(1, 3, figsize=(3*8,1*6))
for idx, (i, j) in enumerate(combinations(axis_indices, 2)):
    ax = axs[idx]
    ax.set_aspect('equal')

    li = axis_labels[i]
    lj = axis_labels[j]
    ax.set_xlabel(li)
    ax.set_ylabel(lj)

    ax.scatter(props[li], props[lj], c='lightgray', s=1, edgecolor='none')
    # label by projection target (consistent with the rectangles figure below)
    for ch, target, cl in zip(channels, proj_targets, colors):
        cond = props[f'proj_{ch}'] > 0
        num = cond.sum()
        ax.scatter(props[li][cond], props[lj][cond], c=cl, s=5, edgecolor='none',
                   label=f'{target} (n={num:,})', rasterized=True)

    # cells positive in all channels
    cond = np.sum(props[[f'proj_{ch}' for ch in channels]], axis=1) == len(channels)
    num = cond.sum()
    ax.scatter(props[li][cond], props[lj][cond], c='purple', s=5, edgecolor='none',
               label=f'both (n={num:,})', rasterized=True)

axs[0].legend(bbox_to_anchor=(0,0))

# Orientation is fixed by hand here rather than from `polarities`:
# flip Z
axs[1].invert_yaxis()
axs[2].invert_yaxis()

# rotate 180 degree (-x, -y)
axs[0].invert_xaxis()
axs[0].invert_yaxis()
axs[1].invert_xaxis()
axs[2].invert_xaxis()

output = os.path.join(outdir, '2.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

# Export projection-labeled masks (for visualization)

In [None]:
# Export one labeled-mask TIFF per channel, keeping only the projection-positive cells.
for ch in channels:
    is_positive = props[f'proj_{ch}'].values > 0
    cell_ids = props.index.values[is_positive]
    # shift ids by one when the table is 0-indexed (see PLUS_ONE in the config cell)
    labeled_cells = cell_ids + 1 if PLUS_ONE else cell_ids

    labeled_masks = masks_to_labeled_masks(msk, labeled_cells)
    print(np.unique(labeled_masks).shape)

    # save as tiff next to the reference image, swapping the channel tag in the file name
    output = f_img0.replace('_c0_', f'_{ch}_').replace('.tiff', '_labeled_masks_countbased.tiff')
    print(output)
    tifffile.imwrite(output, labeled_masks)
    

# Visualize labeled cells (centroids) in 3D

In [None]:
# plotly 3D: interactive view of cell centroids. All cells in light gray, then, per channel,
# the cells whose area-normalized count exceeds that channel's threshold.
traces = []

allx = props['x'].values
ally = props['y'].values
allz = props['z'].values

# named traces so the legend reads "all cells" / target names instead of "trace N"
trace = go.Scatter3d(x=allx, y=ally, z=allz, mode='markers', name='all cells',
                     marker=dict(size=2, color='lightgray', opacity=0.5))
traces.append(trace)

for i, (ch, th) in enumerate(zip(channels, thresholds)):
    labeled_props = props[props[f'new_{ch}']>th]
    x = labeled_props['x'].values
    y = labeled_props['y'].values
    z = labeled_props['z'].values
    color = colors[i]

    trace = go.Scatter3d(x=x, y=y, z=z, mode='markers', name=proj_targets[i],
                         marker=dict(size=3, color=color, opacity=0.5))
    traces.append(trace)

layout = go.Layout(title='',
                   scene=dict(
                       xaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
                       yaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
                       zaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
                       xaxis_title='x', yaxis_title='y', zaxis_title='z',
                   ),
                   scene_dragmode='orbit',
                  )

fig = go.Figure(data=traces, layout=layout)
# write next to the other outputs instead of the notebook's current working directory
fig.write_html(os.path.join(outdir, "figure.html"))
fig.show()

# Quantification and statistics

In [None]:
# bin spatial: assign each cell centroid to fixed-size x/y/z bins starting at 0.
# zbins_2x is a coarser (2x step) z grid used for the per-rectangle depth proportions.
stepsize_xy = 100
stepsize_z  = 20

xmax = props['x'].max()
ymax = props['y'].max()
zmax = props['z'].max()

xbins = np.arange(0, xmax+stepsize_xy, stepsize_xy)
ybins = np.arange(0, ymax+stepsize_xy, stepsize_xy)
zbins = np.arange(0, zmax+stepsize_z , stepsize_z)
zbins_2x = np.arange(0, zmax+stepsize_z*2 , stepsize_z*2)

# pd.cut bins are right-closed (a, b]; include_lowest=True puts a coordinate exactly at 0
# into the first bin instead of leaving it NaN (and silently dropped by later groupbys).
props['xbin'] = pd.cut(props['x'], bins=xbins, labels=False, include_lowest=True)
props['ybin'] = pd.cut(props['y'], bins=ybins, labels=False, include_lowest=True)
props['zbin'] = pd.cut(props['z'], bins=zbins, labels=False, include_lowest=True)

# along Z

In [None]:
# Subsets used below:
#   props_labeled      - cells positive for at least one projection target
#   props_lm, props_rl - cells positive for LM / RL
n_targets_pos = props[proj_targets].sum(axis=1)
props_labeled = props.loc[n_targets_pos > 0]
props_lm = props.loc[props['LM'] > 0]
props_rl = props.loc[props['RL'] > 0]

# cell counts per z bin: all cells vs projection-labeled cells
histz_allcell = np.histogram(props['z'], bins=zbins)[0]
histz_labeled = np.histogram(props_labeled['z'], bins=zbins)[0]

In [None]:
# Depth (z) profiles. Left: cell counts per z bin for all vs projection-labeled cells, with
# the labeled fraction on a twin axis. Right: per-target counts plus cells labeled for all targets.
l23_start, l23_end = 100, 400  # depth range marked with dashed lines (presumably L2/3 — confirm)

fig, axs = plt.subplots(1, 2, figsize=(2*7,1*4))
ax = axs[0]
ax.plot(zbins[1:], histz_allcell, '-',
        color='lightgray', label='cytoDAPI')
ax.plot(zbins[1:], histz_labeled, '-',
        color='black', label='projection labeled')

ax.grid(False)
ax.axvline(l23_start, color='gray', linestyle='--')
ax.axvline(l23_end, color='gray', linestyle='--')
ax.set_xlabel('cortical depth (um)')
ax.set_ylabel('number of cells')

ax2 = ax.twinx()
# empty z bins give 0/0 -> NaN (gap in the line); silence the RuntimeWarning
with np.errstate(divide='ignore', invalid='ignore'):
    frac_labeled = histz_labeled/histz_allcell
ax2.plot(zbins[1:], frac_labeled, '-',
        color='red', label='frac projection labeled')
ax2.set_ylabel('fraction of cells')

ax2.grid(False)
sns.despine(ax=ax)

# the line labels were never displayed; merge handles from both y-axes into one legend
handles, labels = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax.legend(handles + handles2, labels + labels2, loc='upper right')


ax = axs[1]
for i, target in enumerate(proj_targets):
    datasub = props.loc[props[target]>0, 'z']
    histz_sub, _ = np.histogram(datasub, bins=zbins)
    ax.plot(zbins[1:], histz_sub, '-',
            color=colors[i],
            label=f'{target} (n={len(datasub):,})',
            )

# cells positive for every projection target
datasub = props.loc[np.sum(props[proj_targets], axis=1)==len(proj_targets), 'z']
histz_sub, _ = np.histogram(datasub, bins=zbins)
ax.plot(zbins[1:], histz_sub, '-',
        color='purple',
        label=f'both (n={len(datasub):,})',
        )
ax.legend(bbox_to_anchor=(1,1))
sns.despine(ax=ax)

ax.axvline(l23_start, color='gray', linestyle='--')
ax.axvline(l23_end, color='gray', linestyle='--')
ax.grid(False)
ax.set_xlabel('cortical depth (um)')
ax.set_ylabel('number of cells')

fig.tight_layout()
output = os.path.join(outdir, 'proj_bias_z.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()


In [None]:
# Depth (z) profile per projection target, plus cells labeled for all targets (single panel).
l23_start, l23_end = 100, 400  # depth range marked with dashed lines (presumably L2/3 — confirm)

fig, ax = plt.subplots(1, 1, figsize=(1*7,1*4))
for i, target in enumerate(proj_targets):
    datasub = props.loc[props[target]>0, 'z']
    histz_sub, _ = np.histogram(datasub, bins=zbins)
    ax.plot(zbins[1:], histz_sub, '-',
            color=colors[i],
            label=f'{target} (n={len(datasub):,})',
            )

# cells positive for every projection target
datasub = props.loc[np.sum(props[proj_targets], axis=1)==len(proj_targets), 'z']
histz_sub, _ = np.histogram(datasub, bins=zbins)
ax.plot(zbins[1:], histz_sub, '-',
        color='purple',
        label=f'both (n={len(datasub):,})',
        )
ax.legend(bbox_to_anchor=(1,1))
sns.despine(ax=ax)

ax.axvline(l23_start, color='gray', linestyle='--')
ax.axvline(l23_end, color='gray', linestyle='--')
ax.grid(False)
ax.set_xlabel('cortical depth (um)')
ax.set_ylabel('number of cells')

fig.tight_layout()
# distinct file name: 'proj_bias_z.pdf' is already written by the two-panel figure above,
# and reusing it here silently overwrote that figure
output = os.path.join(outdir, 'proj_bias_z_bytarget.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Save the per-cell table (counts, binary projection calls, spatial bins) as CSV.
# NOTE(review): this writes outside `outdir`, to a hard-coded location.
fout = '/u/home/f/f7xiesnm/v1_multiome/easifish_sample1_res.csv'
props.to_csv(fout)
# !head $fout

# along XY

In [None]:
def _xy_hist(data, ax, cmap, title):
    """Step-style 2D (x, y) count histogram of cell centroids on the shared xbins/ybins grid."""
    sns.histplot(data=data,
                 x='x', y='y', bins=[xbins, ybins],
                 element='step', fill=False,
                 cmap=cmap,
                 cbar=True, cbar_kws={'shrink': 0.5},
                 ax=ax,
                )
    ax.grid(False)
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(title)


# Figure 4-1: all cells vs cells labeled for at least one projection target
fig, axs = plt.subplots(1, 2, figsize=(2*5,1*4))
_xy_hist(props, axs[0], 'gray_r', 'cytoDAPI labeled')
any_target = np.sum(props[proj_targets], axis=1) > 0
_xy_hist(props.loc[any_target], axs[1], 'gray_r', 'projection labeled')
output = os.path.join(outdir, '4-1.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

# Figure 4-2: one panel per projection target, using the sequential colormap named after
# the target color (e.g. 'red' -> 'Reds')
fig, axs = plt.subplots(1, 2, figsize=(2*5,1*4))
for k, target in enumerate(proj_targets):
    color = colors[k]
    _xy_hist(props.loc[props[target]>0], axs[k], f'{color[0].upper()}{color[1:]}s', target)
output = os.path.join(outdir, '4-2.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# 2D (x, y) cell-count histograms per projection target on the shared xbins/ybins grid
# (np.histogram2d returns (counts, xedges, yedges); only the counts are kept).
histxy_lm = np.histogram2d(props_lm['x'], props_lm['y'], bins=[xbins, ybins])[0]
histxy_rl = np.histogram2d(props_rl['x'], props_rl['y'], bins=[xbins, ybins])[0]

In [None]:
# Spatial projection bias: log2 ratio of RL to LM cell counts per (x, y) tile.
fig, ax = plt.subplots(1, 1, figsize=(1*5,1*4))
# 0/0 tiles become NaN (left blank); silence the divide/invalid RuntimeWarnings.
# NOTE(review): tiles where only one count is 0 give +/-inf, which will likely be drawn at the
# colormap extremes — consider a pseudocount or a minimum-count mask.
with np.errstate(divide='ignore', invalid='ignore'):
    log2_ratio = np.log2(histxy_rl / histxy_lm)
sns.heatmap(log2_ratio.T,  # transpose: rows = y bins, columns = x bins
            cmap='coolwarm_r',
            center=0,
            vmax=3, vmin=-3,
            cbar_kws=dict(shrink=0.5,
                          label='log2([#RL]/[#LM])',
                          ticks=[-3,0,3],
                         ),
            ax=ax,
           )
ax.grid(False)
ax.set_aspect('equal')
sns.despine(ax=ax)
ax.set_title('projection bias')

output = os.path.join(outdir, 'proj_bias_xy.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

# subpanels

In [None]:

# Sub-regions (xmin, xmax, ymin, ymax) in the same x/y units as props; they are outlined
# on the x-y projection here and analyzed one by one in the following cells.
rectangles = [
    (150, 300, 100, 250),
    (150, 300, 550, 700),
    (450, 600, 300, 450),
    (700, 850, 100, 250),
    (700, 850, 550, 700),
]

axis_labels = ['x', 'y', 'z']
axis_indices = [0, 1, 2]
fig, axs = plt.subplots(1, 3, figsize=(3*8, 1*6))
for idx, (i, j) in enumerate(combinations(axis_indices, 2)):
    ax = axs[idx]
    ax.set_aspect('equal')

    li, lj = axis_labels[i], axis_labels[j]
    ax.set_xlabel(li)
    ax.set_ylabel(lj)

    # flip any axis whose polarity is -1
    pi, pj = polarities[i], polarities[j]
    if pi == -1:
        ax.invert_xaxis()
    if pj == -1:
        ax.invert_yaxis()

    ax.scatter(props[li], props[lj], c='lightgray', s=1, edgecolor='none', rasterized=True)
    for ch, target, _th, cl in zip(channels, proj_targets, thresholds, colors):
        cond = props[f'proj_{ch}'] > 0
        num = cond.sum()
        ax.scatter(props[li][cond], props[lj][cond], c=cl, s=5, edgecolor='none',
                   label=f'{target} (n={num:,})', rasterized=True)

    cond = np.sum(props[[f'proj_{ch}' for ch in channels]], axis=1) == 2
    num = cond.sum()
    ax.scatter(props[li][cond], props[lj][cond], c='purple', s=5, edgecolor='none',
               label=f'both (n={num:,})', rasterized=True)

axs[0].legend(bbox_to_anchor=(0,0))

# outline and number each rectangle on the x-y panel
for k, (xmin, xmax, ymin, ymax) in enumerate(rectangles):
    outline_x = [xmin, xmax, xmax, xmin, xmin]
    outline_y = [ymin, ymin, ymax, ymax, ymin]
    axs[0].plot(outline_x, outline_y, '-', color='k')
    axs[0].text(xmin+50, ymin-50, f"#{k+1}", bbox=dict(facecolor='lightgray', linewidth=0))

output = os.path.join(outdir, 'new1.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Per-rectangle depth (z) histograms (counts) of projection-positive cells, one panel per rectangle.
n = len(rectangles)

fig, axs = plt.subplots(1, n, figsize=(n*5,1*4), sharex=True, sharey=True)
for idx, (xmin, xmax, ymin, ymax) in enumerate(rectangles):
    ax = axs[idx]
    ax.set_title(f'#{idx+1}')

    # cells strictly inside the rectangle
    inside = ((props['x'] > xmin) & (props['x'] < xmax)
              & (props['y'] > ymin) & (props['y'] < ymax))
    props_sub = props[inside]

    for i, target in enumerate(proj_targets):
        datasub = props_sub.loc[props_sub[target]>0, 'z']
        sns.histplot(data=datasub, bins=zbins, color=colors[i],
                     element='step', fill=False,
                     label=f'{target} (n={len(datasub):,})',
                     ax=ax,
                    )

    n_pos = np.sum(props_sub[proj_targets], axis=1)
    datasub = props_sub.loc[n_pos==2, 'z']
    sns.histplot(data=datasub, bins=zbins, color='purple',
                 element='step', fill=False,
                 label=f'both (n={len(datasub):,})',
                 ax=ax,
                )
    ax.grid(False)
    ax.legend(bbox_to_anchor=(0,-0.2), loc='upper left')
    sns.despine(ax=ax)

output = os.path.join(outdir, 'new2.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Per-rectangle depth (z) distributions as within-target proportions on the coarser zbins_2x grid.
n = len(rectangles)

fig, axs = plt.subplots(1, n, figsize=(n*5,1*4), sharex=True, sharey=True)
for idx, (xmin, xmax, ymin, ymax) in enumerate(rectangles):
    ax = axs[idx]
    ax.set_title(f'#{idx+1}')

    # cells strictly inside the rectangle
    inside = ((props['x'] > xmin) & (props['x'] < xmax)
              & (props['y'] > ymin) & (props['y'] < ymax))
    props_sub = props[inside]

    for i, target in enumerate(proj_targets):
        datasub = props_sub.loc[props_sub[target]>0, 'z']
        sns.histplot(data=datasub, bins=zbins_2x, color=colors[i],
                     stat='proportion',
                     element='step', fill=False,
                     label=f'{target} (n={len(datasub):,})',
                     ax=ax,
                    )

    ax.grid(False)
    ax.legend(bbox_to_anchor=(0,-0.2), loc='upper left')
    sns.despine(ax=ax)

output = os.path.join(outdir, 'new3.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

# density map

In [None]:
def plot_density_maps(dfsize_1, dfsize_2, dfsize_3, titles=('LM', 'RL', 'both')):
    """Show density heatmaps for three groups side by side, plus the log2 ratio of the first two.

    Parameters
    ----------
    dfsize_1, dfsize_2, dfsize_3 : 2D table-like counts (e.g. DataFrame), each transposed
        before plotting. NOTE(review): assumed to share the same shape/labels — confirm with callers.
    titles : tuple of 3 str, optional
        Panel titles for the three groups (default 'LM', 'RL', 'both'); the fourth panel is
        titled 'log2(<titles[0]>/<titles[1]>)'.

    The fourth panel shows log2((dfsize_1 + 1) / (dfsize_2 + 1)) (pseudocount 1).
    Displays the figure with plt.show(); returns None.
    """
    fig, axs = plt.subplots(1,4,figsize=(4*6,1*5), sharex=True,sharey=True)
    for ax, dfsize, title in zip(axs[:3], (dfsize_1, dfsize_2, dfsize_3), titles):
        sns.heatmap(dfsize.T, ax=ax, cmap='rocket_r', cbar_kws=dict(shrink=0.5))
        ax.set_aspect('equal')
        ax.set_title(title)

    ax = axs[3]
    # pass ax explicitly: previously this heatmap relied on whatever the "current axes" was
    sns.heatmap(np.log2((dfsize_1.T+1)/(dfsize_2.T+1)),
                cmap='coolwarm',
                center=0,
                cbar_kws=dict(shrink=0.5),
                ax=ax,
               )
    ax.set_aspect('equal')
    ax.set_title(f'log2({titles[0]}/{titles[1]})')
    plt.show()

In [None]:
# plot_density_maps(dfsize_xz1, dfsize_xz2, dfsize_xz3)
# plot_density_maps(dfsize_yz1, dfsize_yz2, dfsize_yz3)
# plot_density_maps(dfsize_xy1, dfsize_xy2, dfsize_xy3)



# At every (x, y) tile, compare the depth (z) distributions of RL- and LM-labeled cells

In [None]:
# dfsub = df[df['zrbin'] == 8]
# plot_easifish_proj(dfsub, 'xr', 'yr', invert_yaxis=True)
# plot_easifish_proj(dfsub, 'xr', 'zr', invert_yaxis=True)
# plot_easifish_proj(dfsub, 'yr', 'zr', invert_yaxis=True)

In [None]:
# inspect the per-cell table (size-filtered cells with counts, projection calls and spatial bins)
props

In [None]:
# xy: mean depth (z) of RL- and LM-positive cells in each (xbin, ybin) tile
# (rows = xbin, columns = ybin), and the per-tile difference RL - LM.
mean_z = {
    target: props.loc[props[target] > 0].groupby(['xbin', 'ybin'])['z'].mean().unstack()
    for target in ('RL', 'LM')
}
dfmean_rl = mean_z['RL']
dfmean_lm = mean_z['LM']
dfmean_del = dfmean_rl - dfmean_lm

In [None]:
# per-tile RL - LM mean depth difference (NaN where a tile lacks RL or LM cells)
dfmean_del

In [None]:
# Permutation test for the per-tile RL - LM mean-depth difference.
# Null: the RL and LM labels are each permuted across cells (keeping the number of positive
# cells per target), and the per-tile difference is recomputed n_rep times.
# Two-sided p per tile: (1 + #{|null| >= |observed|}) / (1 + #permutations with a defined null).
dfmean_del_shuffs = []

n_rep = 1000
for i in range(n_rep):
    props_shuff = props.copy()
    props_shuff['RL'] = np.random.choice(props['RL'].values, size=len(props), replace=False)
    props_shuff['LM'] = np.random.choice(props['LM'].values, size=len(props), replace=False)

    dfmean_rl_shuff = props_shuff[props_shuff['RL']>0].groupby(['xbin', 'ybin'])['z'].mean().unstack()
    dfmean_lm_shuff = props_shuff[props_shuff['LM']>0].groupby(['xbin', 'ybin'])['z'].mean().unstack()

    # align to the observed grid on BOTH axes: a ybin column absent from a permutation becomes
    # NaN instead of raising KeyError (as `.reindex(index)[columns]` did)
    dfmean_rl_shuff = dfmean_rl_shuff.reindex(index=dfmean_del.index, columns=dfmean_del.columns)
    dfmean_lm_shuff = dfmean_lm_shuff.reindex(index=dfmean_del.index, columns=dfmean_del.columns)

    dfmean_del_shuff = dfmean_rl_shuff - dfmean_lm_shuff
    dfmean_del_shuffs.append(dfmean_del_shuff.values)

dfmean_del_shuffs = np.array(dfmean_del_shuffs)  # (n_rep, n_xbins, n_ybins)
dfmean_del_shuffmean = np.nanmean(dfmean_del_shuffs, axis=0)

obs_abs = np.abs(dfmean_del.values)[np.newaxis, :, :]
null_abs = np.abs(dfmean_del_shuffs)
# permutations where the tile had both RL and LM cells (NaN nulls must not count as "less extreme")
valid_null = ~np.isnan(null_abs)
with np.errstate(invalid='ignore'):
    n_extreme = np.sum(valid_null & (null_abs >= obs_abs), axis=0)
n_valid = valid_null.sum(axis=0)
# +1 in numerator AND denominator keeps p in (0, 1]; tiles with no observed difference get NaN
# (previously they silently received p = 1/n_rep)
pvals = (1 + n_extreme) / (1 + n_valid)
pvals = np.where(np.isnan(dfmean_del.values), np.nan, pvals)
pvals

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Diverging colormap: black -> white (midpoint) -> C1.
# Use a dedicated name: reassigning `colors` here clobbered the per-target colors loaded from
# the project metadata, breaking any earlier plotting cell that is re-run afterwards.
cmap_colors = ["black", 'white', "C1"]

# Create colormap
custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", cmap_colors)


In [None]:



# Per-tile mean depth of LM- and RL-positive cells, and their difference (RL - LM).
fig, axs = plt.subplots(1, 3, figsize=(6*3,5))
for i, (dfmean, title) in enumerate(zip(
    [dfmean_lm, dfmean_rl, dfmean_del],
    ['LM', 'RL', 'RL-LM'],
    )):
    ax = axs[i]
    if i == 2:
        # signed difference: diverging colormap centered at 0
        cmap, center, vmin, vmax = custom_cmap, 0, -100, +100
        cbar_label = 'mean depth diff (um; RL-LM)'
    else:
        cmap, center, vmin, vmax = 'gray_r', None, 100, 400
        cbar_label = 'mean depth (um)'

    sns.heatmap(dfmean.T,
                cmap=cmap, center=center, vmin=vmin, vmax=vmax,
                cbar_kws=dict(shrink=.5, label=cbar_label),
                ax=ax,
               )
    ax.set_title(title)  # the zipped title was previously never applied
    ax.set_aspect('equal')
    ax.grid(False)

fig.tight_layout()
output = os.path.join(outdir, 'dist_diff_heatmap.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()


In [None]:

# Same maps as above for ONE null draw: the last permutation of the loop in the permutation cell.
fig, axs = plt.subplots(1, 3, figsize=(6*3,5))
panels = [
    (dfmean_lm_shuff, 'LM (shuffled)'),
    (dfmean_rl_shuff, 'RL (shuffled)'),
    (dfmean_del_shuff, 'RL-LM (shuffled)'),
]
for i, (dfmean, title) in enumerate(panels):
    ax = axs[i]
    if i == 2:
        cmap, center, vmin, vmax = 'coolwarm', 0, -100, +100
        cbar_label = 'mean depth diff (um; RL-LM)'
    else:
        cmap, center, vmin, vmax = 'rocket_r', None, 100, 400
        cbar_label = 'mean depth (um)'

    sns.heatmap(dfmean.T,
                cmap=cmap, center=center, vmin=vmin, vmax=vmax,
                cbar_kws=dict(shrink=.5, label=cbar_label),
                ax=ax,
               )
    ax.set_title(title)
    ax.set_aspect('equal')
    ax.grid(False)

fig.tight_layout()
plt.show()

In [None]:
# Per-tile RL - LM mean depth difference vs permutation p-value (dashed line: p = 0.05).
fig, ax = plt.subplots(1,1,figsize=(1*6,1*5))
# flatten both grids in the same (transposed) order so points pair up tile by tile
ax.scatter(dfmean_del.values.T.ravel(), -np.log10(pvals).T.ravel(), s=5, color='k')
ax.set_xlabel('mean depth diff (um; RL-LM)')  # was 'zr', a stale column name
ax.set_ylabel('-log10(p-value)')
ax.set_xlim([-150, 150])
sns.despine(ax=ax)
ax.axhline(-np.log10(0.05), color='k', linestyle='--')
plt.show()

In [None]:
# Distribution of per-tile RL - LM mean-depth differences: pooled permutations (null) vs observed.
del_bins = np.linspace(-150, 150, 50)

fig, ax = plt.subplots(figsize=(6,4))
hist_kws = dict(stat='density', bins=del_bins, element='step', ax=ax)
sns.histplot(dfmean_del_shuffs.ravel(), color='gray', label='shuffled projection labels', **hist_kws)
sns.histplot(dfmean_del.values.ravel(), color='C1', label='data', **hist_kws)

ax.legend(bbox_to_anchor=(1,1), loc='upper left')
sns.despine(ax=ax)
ax.grid(False)
ax.set_xlabel('mean depth diff (um; RL-LM)')

output = os.path.join(outdir, 'dist_shuff_data.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()