In [None]:
import os
import numpy as np
import pandas as pd
from skimage import io, measure
import tifffile
from kneed import KneeLocator

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import plotly.io as pio
# Use the JupyterLab renderer for every plotly figure shown in this notebook.
pio.renderers.default = 'jupyterlab'
import plotly.graph_objs as go

In [None]:
import sys
# Make the dataset metadata module importable.
# NOTE(review): absolute, user-specific path -- this cell only runs on the original machine.
sys.path.insert(0, '/u/home/f/f7xiesnm/project-zipursky/code/easifish-proc/bydatasets/')
from metadata_gene_chan import get_proj_metadata

# `meta` is indexed below as meta[<dataset key>][...] with the keys
# 'channels', 'proj_targets' and 'colors'.
meta = get_proj_metadata()

In [None]:
import os, sys
# Add the directory two levels up to the import path
# (presumably where the `easi_fish` package lives -- TODO confirm).
sys.path.append('../../')

from easi_fish import roi_prop, spot, intensity
# import warnings
# warnings.filterwarnings('ignore')

# Reload the project modules so that edits to them are picked up
# without restarting the kernel.
import importlib
importlib.reload(spot)
importlib.reload(roi_prop)
importlib.reload(intensity)

In [None]:
def masks_to_labeled_masks(msk, labeled_cells):
    """Keep only the labels listed in `labeled_cells`; set every other label to 0.

    Parameters
    ----------
    msk : np.ndarray
        Integer label image of any shape.
    labeled_cells : array-like
        Label values to keep. May be empty, in which case the result is all 0.

    Returns
    -------
    np.ndarray
        Array with the same shape and dtype as `msk`. Elements whose label is
        in `labeled_cells` keep their label; all others are 0.
    """
    # `unq` holds the sorted distinct labels; `inv` maps every element of the
    # flattened mask to its *position* in `unq`.
    unq, inv = np.unique(msk.reshape(-1,), return_inverse=True)

    # Build the relabeling table by position, not by label value. Indexing
    # `unq` with the label value itself is only correct when the labels are
    # exactly 0..N-1, and otherwise zeroes the wrong entries or overflows.
    lookup = unq.copy()
    lookup[~np.isin(unq, labeled_cells)] = 0

    labeled_msk = lookup[inv].reshape(msk.shape)

    return labeled_msk

In [None]:
# NOTE(review): absolute, user-specific paths -- adjust before running elsewhere.
ddir0 = "/u/home/f/f7xiesnm/project-zipursky/easifish/lt186"
ddir1 = "/u/home/f/f7xiesnm/project-zipursky/easifish/results/viz_all_projections_jan29"
PLUS_ONE = False # False if table indices start from 1 ; True if from 0

# input files: the mask is at s3 resolution and is downsampled below to match
# the s4 reference image
f_msk  = ddir0 + '/outputs/r1v3/segmentation/r1v3-c3.tif'
f_tbl1 = ddir0 + '/proc/r1v3/spotcount.csv'
f_tbl2 = ddir0 + '/proc/r1v3/roi.csv'
f_img0 = ddir1 + '/lt186_r1_autos1_flatfused_c0_s4.tiff'

# spot file for every gene/channel
fx_spots = {
    'r1v3_c0': ddir0 + '/outputs/r1v3/spots/spots_c0.txt',
    'r1v3_c2': ddir0 + '/outputs/r1v3/spots/spots_c2.txt',
}
# per-channel intensity cutoff applied to detected spots
intn_threshs = {
    'r1v3_c0': 60,
    'r1v3_c2': 60,
}

# meta
metakey = 'lt186'
channels = meta[metakey]['channels']
proj_targets = meta[metakey]['proj_targets']
colors = meta[metakey]['colors']

# S4 image shape
img_shape = io.imread(f_img0).shape
print(img_shape)

# S3 mask
lb = io.imread(f_msk)
lb = np.array(lb)

# downsample (every 2nd voxel) and trim to the same shape as the S4 image
msk = lb[::2, ::2, ::2]
msk = msk[:-1,:,:-1]
print(lb.shape)
print(msk.shape)
assert msk.shape == img_shape, f"mask shape {msk.shape} != image shape {img_shape}"

# table - check mask numbers == number of cells
props  = pd.read_csv(f_tbl1, index_col=0)
props2 = pd.read_csv(f_tbl2, index_col=0)
props  = props.join(props2, how='left')
# array_equal also handles a length mismatch, where an element-wise `==`
# would either raise or broadcast
assert np.array_equal(np.unique(msk[msk!=0]), props.index.values), \
    "labels present in the mask do not match the table index"

# filter out cells that are too large or too small
cond_filter = np.logical_and(props['area']<5000, props['area']>500)
print(cond_filter.sum()/len(props))
# take an explicit copy: columns are assigned into this frame further down,
# which on a filtered slice triggers pandas' SettingWithCopyWarning
props = props[cond_filter].copy()

# # normalization -  to mean area
# mean_area = np.mean(props['area'])
# norm_factor = props['area']/mean_area
# for i, ch in enumerate(channels):
#     props[ch] = props[ch]/norm_factor

# props

In [None]:
# Distribution of cell areas after the size filter applied to `props`.
sns.histplot(props['area'].values)

# check - c0

In [None]:
def plot_frac(data, shff, bins=np.arange(0,11,1)):
    """Plot the per-bin ratio of shuffled to observed counts on a log y-axis.

    Parameters
    ----------
    data : array-like
        Observed values (histogrammed with `bins`).
    shff : array-like
        Values from the shuffled control (histogrammed with the same `bins`).
    bins : array-like
        Bin edges passed to `np.histogram`.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with one point per bin, drawn at the bin's right edge.
    """
    fig, ax = plt.subplots()

    cnts_data, bins = np.histogram(data, bins)
    cnts_shff, bins = np.histogram(shff, bins)

    # The ratio is undefined where no observed value falls in the bin; leave
    # those as NaN (not drawn) instead of dividing by zero.
    frac = np.full(cnts_data.shape, np.nan)
    np.divide(cnts_shff, cnts_data, out=frac, where=cnts_data > 0)

    ax.plot(bins[1:], frac, '-o')
    ax.set_yscale('log')
    ax.set_yticks([1,0.1,0.05,0.01])
    ax.set_yticklabels([1,0.1,0.05, 0.01])
    ax.set_xlabel('bin (right edge)')
    ax.set_ylabel('shuffled / data')

    return fig

def plot_reverse_cumsum(counts, bins=np.arange(0,11,1), ymax=None):
    """Plot how many entries of `counts` remain beyond each bin's right edge.

    The left y-axis shows the absolute number (len(counts) minus the running
    histogram sum); the twin right y-axis shows the same curve divided by
    len(counts).

    Parameters
    ----------
    counts : array-like
        One value per cell.
    bins : array-like
        Bin edges passed to `np.histogram`.
    ymax : number, optional
        If truthy, upper y-limit of the left axis; the right axis is limited
        to ymax / len(counts).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(8,6))
    ax2 = ax.twinx()
    for axis, ylabel in ((ax, 'num cells (cumulative)'), (ax2, 'fraction of cells')):
        axis.set_ylabel(ylabel)
    ax.set_xlabel('num spots')

    total = len(counts)
    hist, _ = np.histogram(counts, bins)
    remaining = total - np.cumsum(hist)
    right_edges = bins[1:]

    ax.plot(right_edges, remaining, '-o', )
    ax2.plot(right_edges, remaining/total, '-o', )

    if ymax:
        ax.set_ylim(ymin=0, ymax=ymax)
        ax2.set_ylim(ymin=0, ymax=ymax/total)

    for axis in (ax, ax2):
        axis.grid(False)

    return fig

def plot_reverse_cumsum_complex(counts_list, label_list=None, color_list=None, bins=np.arange(0,11,1), ymax=None):
    """Compare reverse-cumulative count curves and plot their ratio to the first.

    Left panel: for every entry of `counts_list`, the number (left axis) and
    fraction (twin right axis) of values remaining beyond each bin's right
    edge. Right panel: every curve after the first divided by the first curve,
    on a log y-axis.

    Parameters
    ----------
    counts_list : sequence of array-like
        Count vectors of equal length. The first is the reference (data); the
        rest are controls. Any number of controls is accepted.
    label_list : sequence, optional
        One legend label per entry; defaults to 0..len(counts_list)-1.
    color_list : sequence, optional
        One color per entry; defaults to a seaborn palette of matching size.
    bins : array-like
        Bin edges passed to `np.histogram`.
    ymax : number, optional
        If truthy, upper y-limit of the left axis; the twin axis is limited to
        ymax / n.

    Returns
    -------
    (matplotlib.figure.Figure, np.ndarray of Axes)
    """
    fig, axs = plt.subplots(1,2,figsize=(2*8,6))
    ax = axs[0]
    ax2 = ax.twinx()
    ax.set_xlabel('num spots')
    ax.set_ylabel('num cells (cumulative)')
    ax2.set_ylabel('fraction of cells')

    n_sets = len(counts_list)
    if label_list is None:
        label_list = np.arange(n_sets)
    if color_list is None:
        # one color per input instead of a fixed 3
        color_list = sns.color_palette(n_colors=n_sets)

    cumsum_list = []
    n = len(counts_list[0])
    for counts in counts_list:
        assert n == len(counts) # assumes len(counts) is the same
        cnts, _ = np.histogram(counts, bins)
        rev_cumsum = n-np.cumsum(cnts)
        cumsum_list.append(rev_cumsum)

    for revcnts, lb, color in zip(cumsum_list, label_list, color_list):
        ax.plot(bins[1:], revcnts, '-o', label=lb, color=color)
        ax2.plot(bins[1:], revcnts/n, '-o', label=lb, color=color)

    if ymax:
        ax.set_ylim(ymin=0, ymax=ymax)
        ax2.set_ylim(ymin=0, ymax=ymax/n)

    ax.grid(False)
    ax2.grid(False)

    # Right panel: each control relative to the reference curve. Looping over
    # everything after the first entry removes the requirement of exactly
    # three inputs.
    ax = axs[1]
    reference = cumsum_list[0]
    for revcnts, lb, color in zip(cumsum_list[1:], label_list[1:], color_list[1:]):
        # undefined where the reference curve has dropped to zero -> NaN (not drawn)
        ratio = np.full(reference.shape, np.nan)
        np.divide(revcnts, reference, out=ratio, where=reference > 0)
        ax.plot(bins[1:], ratio, '-o', label=lb, color=color)
    ax.set_yscale('log')
    ax.set_yticks([1,0.1,0.05,0.01])
    ax.set_yticklabels([1,0.1,0.05, 0.01])
    ax.set_xlabel('num spots')
    ax.set_ylabel('eFDR (shuff/data)')

    fig.subplots_adjust(wspace=0.4)

    return fig, axs

# (A second, identical definition of `plot_frac` used to live here; it only
# shadowed the one at the top of this cell and has been removed.)

In [None]:
# Divisor applied to the spot coordinates in the scatter plots below
# (presumably the sample expansion factor -- TODO confirm).
ex = 2

In [None]:
chs = channels
# voxel size used to convert spot coordinates to label-volume indices
lb_res = [1.84,1.84,1.68]
# all label ids found in the full-resolution mask, with 0 (non-cell) put back in front
lb_id = np.hstack([[0], np.unique(lb[lb!=0])])
# cells that passed the size filter
lb_id_selected = props.index.values

In [None]:
# Pick the channel to inspect and look up its spot file and intensity cutoff.
i = 0
c = chs[i]
f_spots = fx_spots[c]
intn_th = intn_threshs[c]
print(c, f_spots) #, f_intns)

# spots: one row per detected spot; columns 0-2 are used as coordinates and
# column 3 as intensity in the cells below
spots_rc = np.loadtxt(f_spots, delimiter=',')
print(len(spots_rc))

# filter: keep spots whose intensity (column 3) exceeds the channel cutoff
# NOTE(review): `spots_rc` is overwritten in place, so the unfiltered table is
# gone after this cell.
filter_cond = spots_rc[:,3] > intn_th
spots_rc = spots_rc[filter_cond]
print(len(spots_rc))

In [None]:
# Row selector for the scatter plots below: spots with positive intensity (column 3).
cond = spots_rc[:,3] > 0

In [None]:
# spot positions projected onto coordinate columns 0 and 1
plt.scatter(spots_rc[cond, 0]/ex, spots_rc[cond, 1]/ex, s=1, edgecolor='none')
plt.gca().set_aspect('equal')

In [None]:
# spot positions projected onto coordinate columns 0 and 2
plt.scatter(spots_rc[cond, 0]/ex, spots_rc[cond, 2]/ex, s=1, edgecolor='none')
plt.gca().set_aspect('equal')

In [None]:
# spot positions projected onto coordinate columns 1 and 2
plt.scatter(spots_rc[cond, 1]/ex, spots_rc[cond, 2]/ex, s=1, edgecolor='none')
plt.gca().set_aspect('equal')

In [None]:
# Histogram of spot intensities (column 3), restricted to the 400-1000 range in steps of 5.
sns.histplot(spots_rc[:,3], bins=np.arange(400,1000,5))

In [None]:
# convert spot coordinates (columns 0-2) to integer voxel indices in the label volume
spots = np.round(spots_rc[:,:3]/lb_res).astype(int)-1

# `lb` is indexed as lb[z, y, x] below, so its shape unpacks as (z, y, x).
# Unpacking it as (x, y, z) checks every coordinate against the wrong axis,
# which drops valid spots or lets out-of-range indices through whenever the
# volume is not cubic.
zlim, ylim, xlim = lb.shape
# remove outside range
in_range = np.all(spots >= 0, axis=1) & np.all(spots < [xlim, ylim, zlim], axis=1)
spots = spots[in_range]
print(len(spots))

spots_lb = lb[spots[:,2], spots[:,1], spots[:,0]] # z, y, x
spots_lb_outside = (spots_lb == 0)
print(f"fraction of spots outside of cells: {spots_lb_outside.sum()/len(spots):.2f}")
print(f"fraction of space outside of cells: {np.sum(lb==0)/lb.size:.2f}")

In [None]:
# Sanity check: largest index in column 2 of `spots` (used as the first index into `lb` above).
np.max(spots[:,2])

In [None]:
%%time
# Control 1: permute the labels over the entire volume (background included).
# The copy is contiguous, so the flattened array is a view and shuffling it
# shuffles `lb_shuff` itself.
# NOTE(review): no random seed is set, so this control differs between runs.
lb_shuff = lb.copy()
np.random.shuffle(lb_shuff.reshape(-1))
lb_shuff = lb_shuff.reshape(lb.shape)

In [None]:
%%time
# shuffle in cells: permute the labels among the non-background voxels only,
# leaving the background where it is
# NOTE(review): no random seed is set, so this control differs between runs.
i, j, k = lb.nonzero()
v = lb[i,j,k]
np.random.shuffle(v)
# zeros_like keeps the integer dtype of `lb`; np.zeros(lb.shape) would create
# a float64 volume, i.e. float labels and a much larger array
lb_shuff2 = np.zeros_like(lb)
lb_shuff2[i,j,k] = v

In [None]:
# Count spots per cell for the real mask and for both shuffled controls,
# with identical settings for all three.
_count_kwargs = dict(
    lb_id=lb_id,
    remove_noncell=True,
    selected_roi_list=lb_id_selected,
)
res        = spot.spot_counts_worker(lb,        spots_rc, lb_res, **_count_kwargs)
res_shuff  = spot.spot_counts_worker(lb_shuff,  spots_rc, lb_res, **_count_kwargs)
res_shuff2 = spot.spot_counts_worker(lb_shuff2, spots_rc, lb_res, **_count_kwargs)

In [None]:
# per-cell spot counts: data and the two shuffled controls
counts = res.values
counts_shuff  = res_shuff.values
counts_shuff2 = res_shuff2.values
bins = np.arange(40)
# bins = np.arange(400, 600, 10)

# data only
fig = plot_reverse_cumsum(counts, bins=bins)
plt.show()

# data vs. controls, plus the control/data ratio
fig, axs = plot_reverse_cumsum_complex(
    [counts, counts_shuff, counts_shuff2],
    label_list=['data', 'shuffled', 'shuffled in cells'],
    color_list=['C0', 'black', 'gray'],
    ymax=15000,
    bins=bins,
)
axs[0].legend()
plt.show()

In [None]:
# `table` is another name for `props` (same object, not a copy).
table = props # roi_meta.join(spotcount) #_merged)
table

In [None]:

# For a few candidate spot-count cutoffs, highlight (red) the cells at or above
# the cutoff on top of all cells (gray) in the three coordinate projections.
projections = [('x', 'y'), ('x', 'z'), ('y', 'z')]

for i in [5,10,12]:
    cond = table['r1v3_c0']>=i

    fig, axs = plt.subplots(1, 3, figsize=(3*8,1*6))
    fig.suptitle(f">= {i} spots (n={cond.sum()})")

    for ax, (h, v) in zip(axs, projections):
        g = ax.scatter(table[h], table[v], c='lightgray', s=1, edgecolor='none')
        g = ax.scatter(table[h][cond], table[v][cond], c='red', s=3, edgecolor='none')
        # fig.colorbar(g, shrink=0.3)
        ax.set_aspect('equal')
    plt.show()


# combine

In [None]:
# channels to overlay, the spot-count cutoff for each, and the color to draw them in
chs = ['r1v3_c0', 'r1v3_c2',]
ths = [8,7]
clrs = ['blue', 'red']

# (horizontal column, vertical column, flip the vertical axis?)
views = [('x', 'y', False), ('x', 'z', True), ('y', 'z', True)]

fig, axs = plt.subplots(3, 1, figsize=(1*8,3*6))

# background: every cell in light gray
for ax, (h, v, flip) in zip(axs, views):
    g = ax.scatter(table[h], table[v], c='lightgray', s=1, edgecolor='none')
    if flip:
        ax.invert_yaxis()
    ax.grid(False)

# overlay: cells with at least `i` spots in channel `ch`
# fig.suptitle(f">= {i} spots (n={cond.sum()})")
for ch, i, cl in zip(chs, ths, clrs):
    cond = table[ch]>=i
    for ax, (h, v, flip) in zip(axs, views):
        g = ax.scatter(table[h][cond], table[v][cond], c=cl, s=5, edgecolor='none')
        # fig.colorbar(g, shrink=0.3)
        ax.set_aspect('equal')
plt.show()

# visualize the masks

In [None]:
# print(log2ths)
# thresholds = 2**log2ths-1 # np.array([5, 7, 5])
# print(thresholds)
# Use the per-channel spot-count cutoffs chosen above (same list object as `ths`).
thresholds = ths

In [None]:
# For every channel, keep only the masks of cells whose spot count exceeds the
# channel threshold and write them out as a TIFF next to the reference image.
# NOTE(review): the selection here is strict (`>`), whereas the preview plots
# above use `>=` -- confirm which one is intended.

for ch, th in zip(channels, thresholds):

    above_th = props[ch]>th
    kept_ids = props[above_th].index.values #, props['area']<max_cellsize)].index.values
    # shift table ids by one when the table is 0-based (see PLUS_ONE)
    labeled_cells = 1+kept_ids if PLUS_ONE else kept_ids

    labeled_masks = masks_to_labeled_masks(msk, labeled_cells)
    print(np.unique(labeled_masks).shape)

    # save as tiff, deriving the file name from the reference image
    output = f_img0.replace('_c0_', f'_{ch}_').replace('.tiff', '_labeled_masks_countbased.tiff')
    print(output)
    tifffile.imwrite(output, labeled_masks)
    


In [None]:
# # download and viz

# rsync -av f7xiesnm@dtn.hoffman2.idre.ucla.edu:/u/home/f/f7xiesnm/project-zipursky/v1-bb/ms_reanalysis/240910 ~/Downloads/

# Visualize the dots 

In [None]:
# plotly 3D: all cells in gray, cells above threshold colored per channel

allprops = props # .index.values # props['area']<max_cellsize)].index.values
allx, ally, allz = (props[axis].values for axis in ('x', 'y', 'z'))

# background trace with every cell
trace = go.Scatter3d(
    x=allx, y=ally, z=allz, mode='markers',
    marker=dict(size=2, color='lightgray', opacity=0.5),
)
traces = [trace]

# one trace per channel with the cells above its threshold
for i, (ch, th) in enumerate(zip(channels, thresholds)):
    labeled_props = props[props[ch]>th] # .index.values # props['area']<max_cellsize)].index.values
    x, y, z = (labeled_props[axis].values for axis in ('x', 'y', 'z'))
    color = colors[i]

    trace = go.Scatter3d(
        x=x, y=y, z=z, mode='markers',
        marker=dict(size=3, color=color, opacity=0.5),
    )
    traces.append(trace)

# same look for all three axes: no grid, transparent background
_axis_style = dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)')
layout = go.Layout(
    title='',
    scene=dict(
        xaxis=dict(_axis_style),
        yaxis=dict(_axis_style),
        zaxis=dict(_axis_style),
        xaxis_title='x', yaxis_title='y', zaxis_title='z',
    ),
    scene_dragmode='orbit',
)

fig = go.Figure(data=traces, layout=layout)
fig.write_html("figure.html")
fig.show()

# quants and stats

In [None]:
# spatial bins: fixed-width edges from 0 up to (at least) the largest coordinate
stepsize_xy = 40
stepsize_z  = 20

xmax, ymax, zmax = (allprops[axis].max() for axis in ('x', 'y', 'z'))

xbins = np.arange(0, xmax+stepsize_xy, stepsize_xy)
ybins = np.arange(0, ymax+stepsize_xy, stepsize_xy)
zbins = np.arange(0, zmax+stepsize_z , stepsize_z)

# binary label per projection target: 1 if the channel's spot count exceeds its threshold
# NOTE: `allprops` is the same object as `props`, so these columns are added to `props` too.
for i, (ch, proj_target) in enumerate(zip(channels, proj_targets)):
    is_positive = allprops[ch] > thresholds[i]
    allprops[proj_target] = is_positive.astype(int)

# along Z

In [None]:
# options shared by every z histogram in this figure
_step_kws = dict(bins=zbins, element='step', fill=False)

fig, axs = plt.subplots(1, 2, figsize=(2*5,1*4))

# left: all cells vs. cells labeled for at least one projection target
ax = axs[0]
any_labeled = np.sum(allprops[proj_targets], axis=1)>0
sns.histplot(data=allprops['z'], color='lightgray', label='cytoDAPI', ax=ax, **_step_kws)
sns.histplot(data=allprops.loc[any_labeled, 'z'], color='black', label='projection labeled', ax=ax, **_step_kws)
ax.legend()
sns.despine(ax=ax)

# right: one curve per projection target
ax = axs[1]
for i, target in enumerate(proj_targets):
    sns.histplot(data=allprops.loc[allprops[target]>0, 'z'], color=colors[i], label=target, ax=ax, **_step_kws)
ax.legend()
sns.despine(ax=ax)
plt.show()


# along XY

In [None]:
# 2D (x, y) density of all cells next to that of cells labeled for any target
fig, axs = plt.subplots(1, 2, figsize=(2*5,1*4))

panels = [
    ('cytoDAPI labeled', allprops),
    ('projection labeled', allprops.loc[np.sum(allprops[proj_targets], axis=1)>0]),
]
for ax, (label, subset) in zip(axs, panels):
    sns.histplot(
        data=subset,
        x='x', y='y', bins=[xbins, ybins],
        element='step', fill=False,
        cmap='gray_r',
        cbar=True, cbar_kws={'shrink': 0.5},
        ax=ax,
    )
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(label)
plt.show()

# 2D (x, y) density of the labeled cells, one panel per projection target.
# The number of panels follows `proj_targets` instead of being fixed at two;
# squeeze=False keeps `axs` indexable even with a single target.
n_targets = len(proj_targets)
fig, axs = plt.subplots(1, n_targets, figsize=(n_targets*5,1*4), squeeze=False)
axs = axs[0]
for i, target in enumerate(proj_targets):
    ax = axs[i]
    color = colors[i]
    # colormap name is built from the color name, e.g. 'blue' -> 'Blues'
    # NOTE(review): only works for colors that have such a matplotlib colormap.
    sns.histplot(data=allprops.loc[allprops[target]>0],
                 x='x', y='y', bins=[xbins, ybins],
                 element='step', fill=False,
                 cmap=f'{color[0].upper()}{color[1:]}s',
                 cbar=True, cbar_kws={'shrink': 0.5},
                 ax=ax,
                )
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(target)
plt.show()