In [None]:
import os
import numpy as np
import pandas as pd
from skimage import io, measure
import tifffile
from kneed import KneeLocator

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
ddir1 = "/u/home/f/f7xiesnm/project-zipursky/easifish/results/viz_all_projections_jan29"

# Segmentation mask (s3 resolution). Downsample 2x along every axis, then drop
# the last index along axes 1 and 2 so the mask lines up with the s4 images
# loaded below (checked explicitly per channel).
f_msk = '/u/home/f/f7xiesnm/project-zipursky/easifish/cdf03_c1-2_bino/outputs/r1/segmentation/r1-c3.tif'
msk = np.asarray(io.imread(f_msk))

msk = msk[::2, ::2, ::2]
msk = msk[:, :-1, :-1]
print(msk.shape)

# Cast once instead of once per channel; regionprops expects an integer label image.
msk_int = msk.astype(int)

# images
dataset_name = 'cdf03_c1-2_bino_r1_autos1_flatfused'
channels = ['c0', 'c1', 'c2']

# One regionprops table per channel, in the same order as `channels`.
all_props = []

for ch in channels:
    print(ch)

    # s4 projection image for this channel
    f_img = os.path.join(ddir1, f'{dataset_name}_{ch}_s4.tiff')
    img = np.asarray(io.imread(f_img))
    print(img.shape)

    # Fail early with a readable message if the downsample/crop above no longer
    # matches the image grid.
    if img.shape != msk.shape:
        raise ValueError(
            f"{ch}: image shape {img.shape} does not match mask shape {msk.shape}"
        )

    # Per-object label ID, centroid, voxel count and mean intensity in this channel.
    # Rows come out in ascending label order; 'label' makes the row -> label mapping explicit.
    props = pd.DataFrame(
        measure.regionprops_table(
            msk_int,
            intensity_image=img,
            properties=['label', 'centroid', 'area', 'mean_intensity'],  # , 'max_intensity', 'min_intensity'
        )
    )
    print(props.shape)

    all_props.append(props)

# check threshold (channel by channel)

In [None]:
# Pick one intensity threshold per channel from the knee of the empirical CDF
# of log2(1 + mean intensity), fitted only on its upper tail.
# Panels: one area-vs-intensity scatter per channel, then the full ECDF of
# every channel, then the zoomed tail where the knee is searched.
CDF_TAIL_START = 0.8  # search for the knee only among objects above this ECDF value

n = len(all_props) + 2

log2ths = []

fig, axs = plt.subplots(1, n, figsize=(n * 5, 1 * 4))
ax_cdf = axs[n - 2]
ax_tail = axs[n - 1]

for i, props in enumerate(all_props):
    ax = axs[i]
    sns.scatterplot(data=props, x='mean_intensity', y='area', s=4, ax=ax)
    ax.set_title(channels[i])

    # empirical CDF of log-transformed intensities
    intn = np.log2(1 + props['mean_intensity'].values)
    x = np.sort(intn)
    y = np.arange(len(intn)) / len(intn)

    cond = y > CDF_TAIL_START
    kl = KneeLocator(x[cond], y[cond], curve="concave")

    x_elbow = kl.elbow
    y_elbow = kl.elbow_y
    # kneed returns None when no knee is found; without this check the None
    # would only surface later as a TypeError in `2**log2ths`.
    if x_elbow is None:
        raise RuntimeError(
            f"KneeLocator found no knee for channel {channels[i]}; "
            "pick a threshold for it manually"
        )
    print(x_elbow, y_elbow)

    ax_cdf.plot(x, y, color=f'C{i}', label=channels[i])
    ax_cdf.plot(x_elbow, y_elbow, 'o', color=f'C{i}')

    ax_tail.plot(x[cond], y[cond], color=f'C{i}', label=channels[i])
    ax_tail.plot(x_elbow, y_elbow, 'o', color=f'C{i}')

    log2ths.append(x_elbow)

ax_cdf.set(title='ECDF (all objects)', xlabel='log2(1 + mean intensity)', ylabel='fraction of objects')
ax_tail.set(title=f'ECDF tail (> {CDF_TAIL_START})', xlabel='log2(1 + mean intensity)', ylabel='fraction of objects')
ax_cdf.legend()

log2ths = np.array(log2ths)

plt.show()


# visualize the masks

In [None]:
def masks_to_labeled_masks(msk, labeled_cells):
    """Return a copy of a label mask that keeps only the selected labels.

    Works for any label values, including non-contiguous ones (e.g. labels that
    disappeared after downsampling). The previous version used each label value
    as a position into the unique-label array, which only worked when labels
    were exactly 0..K.

    Parameters
    ----------
    msk : np.ndarray
        Integer label image of any shape.
    labeled_cells : array-like
        Label values to keep. Values that do not occur in ``msk`` are ignored.

    Returns
    -------
    np.ndarray
        Same shape and dtype as ``msk``. Voxels whose label is in
        ``labeled_cells`` keep their label; all other voxels are set to 0.
    """
    msk = np.asarray(msk)
    keep = np.isin(msk, np.asarray(labeled_cells))
    # 0-d zero of msk's dtype so np.where does not upcast the result
    labeled_msk = np.where(keep, msk, np.zeros((), dtype=msk.dtype))

    return labeled_msk

In [None]:

print(log2ths)
# Undo the log2(1 + x) transform used for knee detection to get thresholds
# on the raw mean-intensity scale (earlier hand-picked values: np.array([5, 7, 5])).
thresholds = np.power(2, log2ths) - 1
print(thresholds)

In [None]:
# keep the masks with high intensity

# Keep objects whose mean intensity exceeds the channel threshold and whose
# area is below max_cellsize, and save them as a label mask.

max_cellsize = 1500

# regionprops_table yields one row per non-zero label, in ascending label order,
# so row k corresponds to region_labels[k]. Labels need not be contiguous (the
# 2x downsample/crop can drop objects), so "row index + 1" is not a label ID.
region_labels = np.unique(msk)
region_labels = region_labels[region_labels > 0]

for ch, props, th in zip(channels, all_props, thresholds):
    if len(props) != len(region_labels):
        raise ValueError(
            f"{ch}: {len(props)} regionprops rows but {len(region_labels)} labels in msk"
        )

    keep = ((props['mean_intensity'] > th) & (props['area'] < max_cellsize)).to_numpy()
    # 1-based row positions of the kept objects (later cells index props with
    # `labeled_cells - 1`); the actual label IDs are region_labels[keep].
    labeled_cells = 1 + props[keep].index.values
    labeled_masks = masks_to_labeled_masks(msk, region_labels[keep])
    print(np.unique(labeled_masks).shape)

    # save as tiff
    output = os.path.join(ddir1, f'{dataset_name}_{ch}_s4_labeled_masks.tiff')
    print(output)
    tifffile.imwrite(output, labeled_masks)

    # NOTE: only the first channel is exported; the next cells inspect that
    # channel's `props` / `labeled_cells`. Remove to export every channel.
    break


In [None]:
# Areas (voxel counts on the downsampled mask) of the objects selected in the
# export cell above; labeled_cells holds 1-based row positions into props.
areas = props['area'].loc[labeled_cells - 1]
areas.min()

In [None]:
# Size distribution of the kept objects (voxel counts on the downsampled mask).
ax = sns.histplot(areas)
ax.set(title='Area of kept objects', xlabel='area (voxels)', ylabel='count');

In [None]:
# Same export as above, but for bright objects SMALLER than min_cellsize
# (written to a separate *_small.tiff for inspection).

max_cellsize = 1500  # not used in this cell; kept for parity with the cell above
min_cellsize = 100

# regionprops_table yields one row per non-zero label, in ascending label order,
# so row k corresponds to region_labels[k]. Labels need not be contiguous (the
# 2x downsample/crop can drop objects), so "row index + 1" is not a label ID.
region_labels = np.unique(msk)
region_labels = region_labels[region_labels > 0]

for ch, props, th in zip(channels, all_props, thresholds):
    if len(props) != len(region_labels):
        raise ValueError(
            f"{ch}: {len(props)} regionprops rows but {len(region_labels)} labels in msk"
        )

    keep = ((props['mean_intensity'] > th) & (props['area'] < min_cellsize)).to_numpy()
    # 1-based row positions of the kept objects; actual label IDs are region_labels[keep].
    labeled_cells = 1 + props[keep].index.values
    labeled_masks = masks_to_labeled_masks(msk, region_labels[keep])
    print(np.unique(labeled_masks).shape)

    # save as tiff
    output = os.path.join(ddir1, f'{dataset_name}_{ch}_s4_labeled_masks_small.tiff')
    print(output)
    tifffile.imwrite(output, labeled_masks)

    # NOTE: only the first channel is exported. Remove to export every channel.
    break

In [None]:
# 1-based props row positions selected by the small-object export cell above
# (first channel only, since that loop breaks after one iteration).
labeled_cells

In [None]:
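# Download the results and visualize them locally:

# rsync -av f7xiesnm@dtn.hoffman2.idre.ucla.edu:/u/home/f/f7xiesnm/project-zipursky/v1-bb/ms_reanalysis/240910 ~/Downloads/
# NOTE(review): this source path differs from ddir1, where the labeled masks above are written — confirm it is the intended directory.

# TODO: switch to the s3 resolution (the mask's native level) later — simpler code and more accurate, but ~8x slower.

# TODO: the thresholding could be improved.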