In [None]:
import os
import numpy as np
import pandas as pd
from skimage import io, measure
import tifffile
from kneed import KneeLocator

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# --- configuration -----------------------------------------------------------
ddir1 = "/u/home/f/f7xiesnm/project-zipursky/easifish/results/viz_all_projections_jan29"
# How table indices map to mask labels when selecting labels later on:
#   False -> the table index already equals the mask label (indices start from 1)
#   True  -> mask label = table index + 1 (indices start from 0)
PLUS_ONE = False

# --- segmentation mask -------------------------------------------------------
# Keep every 2nd voxel along each axis, then trim the last element of the last
# two axes, to bring the mask to the s3 resolution.
f_msk = '/u/home/f/f7xiesnm/project-zipursky/easifish/cdf03_c1-2_bino/outputs/r1/segmentation/r1-c3.tif'
msk = np.asarray(io.imread(f_msk))[::2, ::2, ::2][:, :-1, :-1]
print(msk.shape)


# # images 
dataset_name = 'cdf03_c1-2_bino_r1_autos1_flatfused'
# channels = ['c0', 'c1', 'c2']

# all_props = []

# for ch in channels:
#     print(ch)
    
#     # image 
#     f_img = os.path.join(ddir1, f'{dataset_name}_{ch}_s4.tiff')

#     # view 1 image
#     img = io.imread(f_img)
#     img = np.array(img)

#     print(img.shape)
    
#     # measure intensity
#     props = measure.regionprops_table(msk.astype(int), 
#                                       intensity_image=img,
#                                       properties=['centroid', 'area', 'mean_intensity']) # , 'max_intensity', 'min_intensity'
#     props = pd.DataFrame(props)
#     print(props.shape)

#     all_props.append(props)


# --- per-cell tables ---------------------------------------------------------
# Read spotcount.csv and roi.csv (each indexed by its first column) in that
# order, then left-join the roi columns onto the spot-count table by index.
f_tbl1 = '/u/home/f/f7xiesnm/project-zipursky/easifish/cdf03_c1-2_bino/proc/r1-v1/spotcount.csv'
f_tbl2 = '/u/home/f/f7xiesnm/project-zipursky/easifish/cdf03_c1-2_bino/proc/r1-v1/roi.csv'
props, props2 = (pd.read_csv(path, index_col=0) for path in (f_tbl1, f_tbl2))
props = props.join(props2, how='left')
props


In [None]:
# Sorted unique non-zero labels still present in the downsampled mask
np.unique(msk[np.nonzero(msk)])

# Choose an intensity threshold per channel (knee of the ECDF)

In [None]:
n = 2
channels = ['r1_c0', 'r1_c1', 'r1_c2']
log2ths = []


fig, axs = plt.subplots(1,n,figsize=(n*5,1*4))
for i, ch in enumerate(channels):
    ax = axs[0]
    # x y 
    intn = np.log2(1+props[ch].values)
    x  = np.sort(intn)
    y  = np.arange(len(intn))/len(intn)


    cond = y > 0.8
    kl = KneeLocator(x[cond], y[cond], curve="concave")

    x_elbow = kl.elbow
    y_elbow = kl.elbow_y
    log2ths.append(x_elbow)

    ax.plot(x, y, color=f'C{i}')
    ax.plot(x_elbow, y_elbow, 'o', color=f'C{i}')
    print(x_elbow, y_elbow)


    ax = axs[1]
    ax.plot(x[cond], y[cond], color=f'C{i}')
    ax.plot(x_elbow, y_elbow, 'o', color=f'C{i}')

log2ths = np.array(log2ths)
    
plt.show()

# Build per-channel labeled masks (cells above threshold)

In [None]:
def masks_to_labeled_masks(msk, labeled_cells):
    """Keep only the selected labels in a label image and set every other voxel to 0.

    Parameters
    ----------
    msk : array-like
        Integer label image of any shape; 0 is treated as background.
    labeled_cells : array-like
        Label values to keep. Values absent from ``msk`` are ignored.

    Returns
    -------
    np.ndarray
        Array with the same shape and dtype as ``msk``. Voxels whose label is in
        ``labeled_cells`` keep their label; all others are 0.
    """
    msk = np.asarray(msk)
    # Vectorized membership test on the label values themselves. This works for
    # any label set, including non-contiguous labels (e.g. after downsampling
    # removes small objects). Indexing a unique-value array by label value does not.
    keep = np.isin(msk, labeled_cells)
    return np.where(keep, msk, 0).astype(msk.dtype, copy=False)

In [None]:

# Map the knee positions from log2(1 + v) space back to v: v = 2**t - 1.
# (A manual alternative noted earlier: np.array([5, 7, 5]).)
print(log2ths)
thresholds = np.power(2, log2ths) - 1
print(thresholds)

In [None]:
# For each channel, keep only the mask labels of cells whose value exceeds that
# channel's threshold. The printed shape is the number of distinct values in the
# result (including 0 for background).
for ch, th in zip(channels, thresholds):
    above_idx = props.loc[props[ch] > th].index.values  # , props['area']<max_cellsize)
    # table index -> mask label; see PLUS_ONE in the configuration cell
    labeled_cells = above_idx + 1 if PLUS_ONE else above_idx

    labeled_masks = masks_to_labeled_masks(msk, labeled_cells)
    print(np.unique(labeled_masks).shape)

    # # # save as tiff
    # output = os.path.join(ddir1, f'{dataset_name}_{ch}_s4_labeled_masks_countbased.tiff')
    # print(output)
    # tifffile.imwrite(output, labeled_masks)


In [None]:
# # download and viz

# rsync -av f7xiesnm@dtn.hoffman2.idre.ucla.edu:/u/home/f/f7xiesnm/project-zipursky/v1-bb/ms_reanalysis/240910 ~/Downloads/

# Visualize the cells as dots in 3D

In [None]:
import plotly.io as pio
pio.renderers.default = 'jupyterlab'
import plotly.graph_objs as go

In [None]:
# plotly 3D scatter: every cell in light gray, plus the cells above each
# channel's threshold in that channel's color. Also written to figure.html.
colors = ['red', 'green', 'blue']
traces = []

# Work on a copy so that columns added to `allprops` later do not also modify `props`.
allprops = props.copy()
allx = allprops['x'].values
ally = allprops['y'].values
allz = allprops['z'].values

trace = go.Scatter3d(x=allx, y=ally, z=allz, mode='markers', name='cytoDAPI',
                     marker=dict(size=2, color='lightgray', opacity=0.5))
traces.append(trace)

for ch, th, color in zip(channels, thresholds, colors):
    labeled_props = props[props[ch] > th]
    # Named traces give a readable legend: channel and number of cells passing.
    trace = go.Scatter3d(x=labeled_props['x'].values,
                         y=labeled_props['y'].values,
                         z=labeled_props['z'].values,
                         mode='markers',
                         name=f'{ch} (n={len(labeled_props)})',
                         marker=dict(size=3, color=color, opacity=0.5))
    traces.append(trace)

layout = go.Layout(title='',
                   scene=dict(
                       xaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
                       yaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
                       zaxis=dict(showgrid=False, backgroundcolor='rgba(0,0,0,0)'),
                       xaxis_title='x', yaxis_title='y', zaxis_title='z',
                   ),
                   scene_dragmode='orbit',
                  )

fig = go.Figure(data=traces, layout=layout)
fig.write_html("figure.html")
fig.show()

# Quantification and statistics

In [None]:
# Binary projection calls: a cell is positive (1) for a target when its value in
# the corresponding channel exceeds that channel's threshold, else 0.
# `allprops` is rebuilt from `props` here instead of adding columns in place.
# This keeps the cell idempotent and leaves `props` unchanged.
allprops = props.assign(**{
    'LM':    (props['r1_c0'] > thresholds[0]).astype(int),
    'AM/PM': (props['r1_c1'] > thresholds[1]).astype(int),
    'RL':    (props['r1_c2'] > thresholds[2]).astype(int),
})

allprops

In [None]:
# Summary statistics of the x/y/z coordinates (presumably informs the bin ranges below)
allprops.loc[:, ['x', 'y', 'z']].describe()

# along Z

In [None]:
def fixed_width_bins(stop, width, start=0):
    """Return bin edges from `start` with constant `width`, extended until they cover `stop`.

    Unlike ``np.linspace(start, stop, int((stop - start) / width) + 1)``, the
    spacing is exactly `width` even when ``stop - start`` is not a multiple of it.
    In that case the last edge lies beyond `stop`.
    """
    n_bins = int(np.ceil((stop - start) / width))
    return start + width * np.arange(n_bins + 1, dtype=float)


# Extents presumably chosen to cover the coordinate ranges; widths: z 20, x/y 40.
zbins = fixed_width_bins(460, 20)
xbins = fixed_width_bins(1000, 40)
ybins = fixed_width_bins(860, 40)  # 860 is not a multiple of 40 -> last edge at 880

zbins, xbins, ybins

In [None]:
# Distribution of cells along z (step histograms on the shared zbins grid).
# Left: all cells vs. cells positive for at least one projection target.
# Right: one curve per target, colored as in the 3D scatter.
proj_targets = ['LM', 'AM/PM', 'RL']
is_labeled = allprops[proj_targets].sum(axis=1) > 0
step_hist = dict(bins=zbins, element='step', fill=False)

fig, axs = plt.subplots(1, 2, figsize=(2*5, 1*4))
left, right = axs

sns.histplot(data=allprops['z'], color='lightgray', label='cytoDAPI', ax=left, **step_hist)
sns.histplot(data=allprops.loc[is_labeled, 'z'], color='black',
             label='projection labeled', ax=left, **step_hist)

for target, color in zip(proj_targets, colors):
    sns.histplot(data=allprops.loc[allprops[target] > 0, 'z'], color=color,
                 label=target, ax=right, **step_hist)

for ax in axs:
    ax.legend()
    sns.despine(ax=ax)
plt.show()


# along XY

In [None]:
# 2-D histograms of cell positions in x/y on the shared xbins/ybins grid.
# First figure: all cells vs. cells positive for at least one projection target.
# Second figure: one panel per target, with a colormap matching its 3D-scatter color.
# Only 2-D options are passed; seaborn's `element`/`fill` apply to 1-D histograms only.

def plot_xy_hist(ax, data, bins, cmap, title):
    """Draw a 2-D x/y histogram of `data` on `ax` with a shrunken colorbar.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    data : DataFrame with 'x' and 'y' columns.
    bins : [x_edges, y_edges], passed to ``seaborn.histplot``.
    cmap : matplotlib colormap name.
    title : Axes title.

    Returns the Axes.
    """
    sns.histplot(data=data, x='x', y='y', bins=bins, cmap=cmap,
                 cbar=True, cbar_kws={'shrink': 0.5}, ax=ax)
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.set_title(title)
    return ax


xy_bins = [xbins, ybins]
is_labeled = allprops[proj_targets].sum(axis=1) > 0  # positive for >= 1 target

fig, axs = plt.subplots(1, 2, figsize=(2*5, 1*4))
plot_xy_hist(axs[0], allprops, xy_bins, 'gray_r', 'cytoDAPI')
plot_xy_hist(axs[1], allprops.loc[is_labeled], xy_bins, 'gray_r', 'projection labeled')
plt.show()

fig, axs = plt.subplots(1, 3, figsize=(3*5, 1*4))
for ax, target, color in zip(axs, proj_targets, colors):
    # sequential matplotlib colormap matching the trace color, e.g. 'red' -> 'Reds'
    plot_xy_hist(ax, allprops.loc[allprops[target] > 0], xy_bins,
                 f'{color.capitalize()}s', target)
plt.show()