In [None]:
import os
import numpy as np
import pandas as pd
from skimage import io, measure
import tifffile
from kneed import KneeLocator
from itertools import combinations

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import plotly.io as pio
# Choose the default plotly renderer by name; `fig.show()` calls that do not
# pass an explicit renderer fall back to this setting.
pio.renderers.default = 'jupyterlab'
import plotly.graph_objs as go

In [None]:
import sys
import importlib

# Extend the module search path before the project-local imports below.
# The absolute directory is inserted at position 0, so it is searched first;
# '../../' is appended and is resolved relative to the kernel's working directory.
# NOTE(review): presumably these are where `metadata_gene_chan`, `easi_fish` and
# `shared_functions` live -- confirm, since the paths are machine-specific.
sys.path.insert(0, '/u/home/f/f7xiesnm/project-zipursky/code/easifish-proc/bydatasets/')
sys.path.append('../../')

# Import, then reload, so that edits made to the module on disk are picked up
# when this cell is re-run without restarting the kernel. The `from ... import`
# comes after the reload so the name is bound to the reloaded function.
import metadata_gene_chan
importlib.reload(metadata_gene_chan)
from metadata_gene_chan import get_proj_metadata

# Project metadata; later cells index it by dataset key (see `metakey`).
meta = get_proj_metadata()

from easi_fish import roi_prop, spot, intensity
# importlib.reload(spot)
# importlib.reload(roi_prop)
# importlib.reload(intensity)

# Same import -> reload -> from-import pattern for the shared helper module.
import shared_functions
importlib.reload(shared_functions)

from shared_functions import spot_to_voxel_coords, filter_raw_spots, spots_incells_metrics, scramble_cell_masks, masks_to_labeled_masks
from shared_functions import plot_frac, plot_reverse_cumsum, plot_reverse_cumsum_complex

In [None]:
# ---- dataset locations -------------------------------------------------
ddir0 = "/u/home/f/f7xiesnm/project-zipursky/easifish/lt186"
ddir1 = "/u/home/f/f7xiesnm/project-zipursky/easifish/results/viz_all_projections_jan29"
# False if table indices start from 1 ; True if from 0
PLUS_ONE = False

# ---- input files (s3 resolution - to downsamp and rounding) -------------
f_msk  = f'{ddir0}/outputs/r1v3/segmentation/r1v3-c3.tif'
f_tbl1 = f'{ddir0}/proc/r1v3/spotcount.csv'
f_tbl2 = f'{ddir0}/proc/r1v3/roi.csv'
f_img0 = f'{ddir1}/lt186_r1_autos1_flatfused_c0_s4.tiff'

# ---- per-channel spot files and intensity cut-offs -----------------------
# Both dicts are keyed by the channel name used in the project metadata.
fx_spots = {
    f'r1v3_{chan}': f'{ddir0}/outputs/r1v3/spots/spots_{chan}.txt'
    for chan in ('c0', 'c2')
}
intn_threshs = dict(
    r1v3_c0=60,
    r1v3_c2=60,
)

# ---- scale factors --------------------------------------------------------
ex = 2
# s3 resolution, one entry per axis of the label volume
lb_res = [0.23*8, 0.23*8, 0.42*4]

# ---- project metadata for this dataset ------------------------------------
metakey = 'lt186'
channels, proj_targets, colors, polarities = (
    meta[metakey][field]
    for field in ('channels', 'proj_targets', 'colors', 'axis_polarity')
)

# Reference S4 image: only its shape is needed, to validate the mask below.
img_shape = io.imread(f_img0).shape
print(img_shape)

# S3 segmentation mask, copied into a fresh array.
lb = np.array(io.imread(f_msk))

# All label ids present in the mask, with 0 (non-cell) placed first.
lb_id = np.concatenate(([0], np.unique(lb[lb != 0])))

# Take every second voxel along each axis, then drop the last plane on
# axes 0 and 2 so the mask has the same shape as the S4 image.
msk = lb[::2, ::2, ::2][:-1, :, :-1]
print(lb.shape)
print(msk.shape)
assert msk.shape == img_shape

# Per-cell tables: spot counts joined (left join on the shared index) with the
# ROI properties, so every row is one segmented cell.
props  = pd.read_csv(f_tbl1, index_col=0)
props2 = pd.read_csv(f_tbl2, index_col=0)
props  = props.join(props2, how='left')

# Sanity check: the non-zero labels in the downsampled mask must be exactly the
# table's index. np.array_equal returns False on a length mismatch, so a
# discrepancy surfaces as a clear AssertionError rather than a broadcast error.
assert np.array_equal(np.unique(msk[msk != 0]), props.index.values), \
    "mask labels and table index do not match"

# Filter out cells that are too large or too small.
cond_filter = np.logical_and(props['area'] < 5000, props['area'] > 500)
print(cond_filter.sum()/len(props))  # fraction of cells kept
# .copy(): columns are assigned to `props` below and in later cells; working on
# an explicit copy avoids pandas' SettingWithCopyWarning on a filtered slice.
props = props[cond_filter].copy()

# Labels of the cells that passed the size filter.
lb_id_selected = props.index.values

# Normalization - scale each channel's per-cell value to the mean cell area.
mean_area = np.mean(props['area'])
norm_factor = props['area']/mean_area
for ch in channels:
    props[ch] = props[ch]/norm_factor

props

In [None]:
# sns.histplot(props['area'].values)

# check - c0

Coordinate conversions: physical coords (× `ex`) → image coords ← voxel coords (× voxel size)

In [None]:
# One threshold per channel, appended in `channels` order; later cells pair
# them back up with zip(channels, thresholds).
thresholds = []

# Control: the same spots are also counted against scrambled cell masks.
# NOTE(review): if scramble_cell_masks draws random numbers, the thresholds can
# differ from run to run unless a seed is fixed -- confirm in shared_functions.
lb_shuff = scramble_cell_masks(lb)

# Settings that do not change between channels.
indices = [0,1,2]
indices_label = ['x', 'y', 'z']
bins = np.linspace(0, 20, 50)
count_kwargs = dict(
    lb_id=lb_id,
    remove_noncell=True,
    selected_roi_list=lb_id_selected,
)

# channel by channel
for ch in channels:
    f_spots = fx_spots[ch]
    intn_th = intn_threshs[ch]
    print(ch, f_spots)

    # Load the raw spot table and filter it with this channel's intensity cut-off.
    spots_rc = np.loadtxt(f_spots, delimiter=',')
    spots_rc = filter_raw_spots(spots_rc, intn_th, lb_res, lb)

    spots_incells_metrics(spots_rc, lb, lb_res)

    # Plot the filtered spots, one panel per pair of axes (xy, xz, yz).
    fig, axs = plt.subplots(1,3,figsize=(3*8,1*6))
    for ax, (i, j) in zip(axs, combinations(indices, 2)):
        # Flip an axis when the dataset metadata marks its polarity as -1.
        if polarities[i] == -1:
            ax.invert_xaxis()
        if polarities[j] == -1:
            ax.invert_yaxis()

        ax.scatter(spots_rc[:,i]/ex, spots_rc[:,j]/ex, s=1, edgecolor='none', color='gray')
        ax.set_xlabel(indices_label[i])
        ax.set_ylabel(indices_label[j])
        ax.set_aspect('equal')
    # One such figure is drawn per channel, so say which one this is.
    fig.suptitle(f'{ch}: filtered spots')

    # Count spots per selected cell, for the real and the scrambled masks.
    counts       = spot.spot_counts_worker(lb, spots_rc, lb_res, **count_kwargs)
    counts_shuff = spot.spot_counts_worker(lb_shuff, spots_rc, lb_res, **count_kwargs)
    # Same area normalization as applied to the table columns.
    norm_counts        = counts.values       /norm_factor.values
    norm_counts_shuff  = counts_shuff.values /norm_factor.values

    # Get the threshold from the data vs scrambled-cells comparison.
    fig, axs, x_th = plot_reverse_cumsum_complex(
                                    [norm_counts, norm_counts_shuff],
                                    label_list=['data', 'scrambled cells'],
                                    color_list=['C1', 'black'],
                                    # ymax=15000,
                                    bins=bins,
                                   )
    axs[0].legend()
    axs[1].set_title(f'threshold: {x_th:.2f}')
    plt.show()

    thresholds.append(x_th)
    # Keep the re-counted, normalized values; later cells threshold this column.
    props[f'new_{ch}'] = norm_counts

In [None]:
# Three 2D projections (xy, xz, yz) of the cell positions: all cells in light
# gray, with the cells at or above each channel's threshold drawn on top.
axis_labels = ['x','y','z']
axis_indices = [0,1,2]
fig, axs = plt.subplots(1, 3, figsize=(3*8,1*6))
for ax, (i, j) in zip(axs, combinations(axis_indices, 2)):
    li, lj = axis_labels[i], axis_labels[j]

    ax.set_aspect('equal')
    ax.set_xlabel(li)
    ax.set_ylabel(lj)

    # Flip an axis when the dataset metadata marks its polarity as -1.
    if polarities[i] == -1:
        ax.invert_xaxis()
    if polarities[j] == -1:
        ax.invert_yaxis()

    # Background layer: every cell.
    g = ax.scatter(props[li], props[lj], c='lightgray', s=1, edgecolor='none')
    # Foreground layers: one per channel, with the number of cells in the legend.
    for ch, th, cl in zip(channels, thresholds, colors):
        cond = props[f'new_{ch}'] >= th
        g = ax.scatter(
            props.loc[cond, li],
            props.loc[cond, lj],
            c=cl,
            s=5,
            edgecolor='none',
            label=f'{ch} (n={cond.sum():,})',
        )

axs[0].legend()
plt.show()

# visualize the masks

In [None]:
# For each channel, keep only the masks of cells whose normalized spot count
# is above the channel's threshold, and write them out as a TIFF.
for ch, th in zip(channels, thresholds):
    above_th = props[f'new_{ch}'] > th
    labeled_cells = props[above_th].index.values
    if PLUS_ONE:
        # Shift by one when the table index and the mask labels are offset.
        labeled_cells = 1 + labeled_cells

    labeled_masks = masks_to_labeled_masks(msk, labeled_cells)
    print(np.unique(labeled_masks).shape)

    # Output path: the S4 image path with the channel name swapped in and a
    # suffix in place of the extension.
    output = (
        f_img0
        .replace('_c0_', f'_{ch}_')
        .replace('.tiff', '_labeled_masks_countbased.tiff')
    )
    print(output)
    tifffile.imwrite(output, labeled_masks)

# Visualize the dots 

In [None]:
# Interactive 3D scatter (plotly): all cells in light gray, plus one trace per
# channel with the cells above that channel's threshold.

def _marker_trace(xs, ys, zs, size, color):
    """Build a half-transparent 3D marker trace from three coordinate arrays."""
    return go.Scatter3d(x=xs, y=ys, z=zs, mode='markers',
                        marker=dict(size=size, color=color, opacity=0.5))

# `allprops` is the same object as `props` (not a copy); later cells use it.
allprops = props
allx = props['x'].values
ally = props['y'].values
allz = props['z'].values

# Background trace first, then one trace per channel.
traces = [_marker_trace(allx, ally, allz, 2, 'lightgray')]

for i, (ch, th) in enumerate(zip(channels, thresholds)):
    labeled_props = props[props[f'new_{ch}'] > th]
    x, y, z = (labeled_props[axis].values for axis in ('x', 'y', 'z'))
    color = colors[i]
    trace = _marker_trace(x, y, z, 3, color)
    traces.append(trace)

# No grid and a transparent background on every axis.
layout = go.Layout(
    title='',
    scene=dict(
        xaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
        yaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
        zaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
        xaxis_title='x',
        yaxis_title='y',
        zaxis_title='z',
    ),
    scene_dragmode='orbit',
)

fig = go.Figure(data=traces, layout=layout)
fig.write_html("figure.html")
fig.show()

# quants and stats

In [None]:
# Spatial bins: fixed step sizes, covering 0 .. max coordinate on each axis.
stepsize_xy = 40
stepsize_z  = 20

xmax = allprops['x'].max()
ymax = allprops['y'].max()
zmax = allprops['z'].max()

xbins = np.arange(0, xmax+stepsize_xy, stepsize_xy)
ybins = np.arange(0, ymax+stepsize_xy, stepsize_xy)
zbins = np.arange(0, zmax+stepsize_z , stepsize_z)

# Binary label per projection target: 1 if the cell is above the channel's
# threshold, else 0.
# The thresholds were derived from the re-counted, normalized spot counts stored
# in `new_{ch}`, so they must be applied to that column (as the mask-export and
# scatter cells do), not to the original table column `ch`.
for ch, proj_target, th in zip(channels, proj_targets, thresholds):
    allprops[proj_target] = (allprops[f'new_{ch}'] > th).astype(int)

# along Z

In [None]:
# Cell counts along z. Left: all cells vs cells labeled by any projection
# target. Right: one outline per projection target.
step_kws = dict(bins=zbins, element='step', fill=False)
any_labeled = np.sum(allprops[proj_targets], axis=1) > 0

fig, axs = plt.subplots(1, 2, figsize=(2*5,1*4))

ax = axs[0]
left_panel = [
    (allprops['z'], 'lightgray', 'cytoDAPI'),
    (allprops.loc[any_labeled, 'z'], 'black', 'projection labeled'),
]
for z_values, line_color, line_label in left_panel:
    sns.histplot(data=z_values, color=line_color, label=line_label, ax=ax, **step_kws)
ax.legend()
sns.despine(ax=ax)

ax = axs[1]
for i, target in enumerate(proj_targets):
    is_target = allprops[target] > 0
    sns.histplot(data=allprops.loc[is_target, 'z'], color=colors[i], label=target,
                 ax=ax, **step_kws)
ax.legend()
sns.despine(ax=ax)
plt.show()


# along XY

In [None]:
# 2D (x, y) cell-density maps: all cells on the left, cells labeled by any
# projection target on the right. Both panels share the same bins and styling.
any_labeled = np.sum(allprops[proj_targets], axis=1) > 0
panels = [
    ('cytoDAPI labeled', allprops),
    ('projection labeled', allprops.loc[any_labeled]),
]

fig, axs = plt.subplots(1, 2, figsize=(2*5,1*4))
for ax, (label, panel_data) in zip(axs, panels):
    sns.histplot(
        data=panel_data,
        x='x', y='y', bins=[xbins, ybins],
        element='step', fill=False,
        cmap='gray_r',
        cbar=True, cbar_kws={'shrink': 0.5},
        ax=ax,
    )
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(label)
plt.show()

# 2D (x, y) cell-density map for each projection target, one panel per target.
# The grid is sized from the number of targets rather than fixed at two panels,
# so datasets with one or with more than two targets also plot correctly.
n_targets = len(proj_targets)
fig, axs = plt.subplots(1, n_targets, figsize=(n_targets*5, 1*4), squeeze=False)
axs = axs.ravel()  # squeeze=False always returns a 2D array; flatten to 1D
for i, target in enumerate(proj_targets):
    ax = axs[i]
    color = colors[i]
    sns.histplot(data=allprops.loc[allprops[target]>0],
                 x='x', y='y', bins=[xbins, ybins],
                 element='step', fill=False,
                 # Colormap name is built from the color name, e.g. 'red' -> 'Reds'.
                 # NOTE(review): this only works for colors that have a matching
                 # sequential matplotlib colormap -- confirm for the colors in meta.
                 cmap=f'{color[0].upper()}{color[1:]}s',
                 cbar=True, cbar_kws={'shrink': 0.5},
                 ax=ax,
                )
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(target)
plt.show()