In [None]:
import os, shutil
import numpy as np
import matplotlib.pyplot as plt
from cellpose import core, utils, io, models, metrics
from glob import glob
from tqdm import tqdm
from IPython.display import clear_output
from cellpose import plot
from cellpose import io, utils
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Labels indexed by the truth value of the GPU probe (False -> 'NO', True -> 'YES').
yn = ['NO', 'YES']

# Ask cellpose once whether it can use a GPU and report the answer.
use_GPU = core.use_gpu()
print(f'>>> GPU activated? {yn[use_GPU]}')

In [None]:
def eval(eval_dir, channels, diameter, flow_threshold, cellprob_threshold,
         save_segmentation, save_roi_counts, save_plots, count=0):
    """Segment every ``*_clip.tiff`` image in ``eval_dir`` with the Cellpose ``cyto3`` model.

    NOTE: this function shadows the built-in ``eval``; the name is kept so
    existing callers continue to work.

    Parameters
    ----------
    eval_dir : str
        Directory that is scanned (non-recursively) for files whose name ends
        in ``_clip.tiff``.
    channels : list
        Forwarded unchanged to ``model.eval``, ``io.masks_flows_to_seg`` and
        ``show_segmentation``.
    diameter : number
        Forwarded to ``model.eval``. ``0`` means "use ``model.diam_labels``".
    flow_threshold, cellprob_threshold : number
        Forwarded unchanged to ``model.eval``.
    save_segmentation : bool
        If true, ``io.masks_flows_to_seg`` is called for every image.
    save_roi_counts : bool
        If true, the largest mask label of each image is written to
        ``<eval_dir>/results/<image stem>_roi_counts.txt``.
    save_plots : bool
        If true, ``show_segmentation`` is called with a target path inside
        ``<eval_dir>/results``.
    count : int, optional
        Stop after this many images have been processed. ``0`` (default)
        processes every matching file.

    Returns
    -------
    None
        Returns immediately, without loading a model or creating the results
        folder, when ``eval_dir`` contains no ``*_clip.tiff`` file.
    """
    # Sorted so that ``count`` picks the same files on every run;
    # ``os.listdir`` makes no ordering guarantee.
    files = sorted(os.path.join(eval_dir, name)
                   for name in os.listdir(eval_dir)
                   if name.endswith('_clip.tiff'))
    if not files:
        return

    result_folder = None
    if save_roi_counts or save_plots:
        # Derive the folder from eval_dir directly: splitting a file path on
        # '/' loses the leading slash of absolute paths and does not work with
        # Windows separators.
        result_folder = os.path.join(eval_dir, 'results')
        os.makedirs(result_folder, exist_ok=True)

    # The model does not depend on the image, so build it once instead of
    # once per file.
    model = models.CellposeModel(gpu=True, model_type='cyto3')

    # use model diameter if user diameter is 0
    diameter = model.diam_labels if diameter == 0 else diameter

    for n, f in enumerate(tqdm(files, desc='Segmentation')):
        # The model is always run on a one-element list holding the current image.
        images = [io.imread(f)]

        masks, flows, _styles = model.eval(images,
                                           channels=channels,
                                           diameter=diameter,
                                           flow_threshold=flow_threshold,
                                           cellprob_threshold=cellprob_threshold)

        if save_segmentation:
            io.masks_flows_to_seg(images,
                                  masks,
                                  flows,
                                  [f],
                                  diameter * np.ones(len(masks)),
                                  channels)

        if result_folder is not None:
            # Outputs in the results folder reuse the image's file name.
            result_path = os.path.join(result_folder, os.path.basename(f))

            if save_roi_counts:
                save_path = os.path.splitext(result_path)[0] + '_roi_counts.txt'
                with open(save_path, 'w') as fp:
                    # The largest label in the mask image is stored as the ROI count.
                    fp.write('%d' % int(np.max(masks[0])))

            if save_plots:
                show_segmentation(images[0], masks[0], channels=channels, file_name=result_path)

        if count != 0 and n + 1 >= count:
            return

# def run_eval(only_subfolders=False, count=0):
#     if not only_subfolders:
#         eval(dir, [chan, chan2], diameter, flow_threshold,
#             cellprob_threshold, save_segmentation, save_roi_counts, save_plots, count)
#     for o in os.listdir(dir):
#         if os.path.isdir(os.path.join(dir, o)) and o != 'results':
#             eval(os.path.join(dir, o), [chan, chan2], diameter, flow_threshold,
#                 cellprob_threshold, save_segmentation, save_roi_counts, save_plots)

def show_segmentation(img, maski, channels=None, file_name=None):
    """Overlay the predicted masks on the image and optionally save the overlay.

    Parameters
    ----------
    img : numpy.ndarray
        Image that was segmented. A 3-D array whose first axis is shorter
        than 4 is treated as channel-first and moved to channel-last.
    maski : numpy.ndarray
        Label image passed to ``plot.mask_overlay``.
    channels : list, optional
        Passed to ``plot.image_to_rgb`` when the image has to be converted to
        RGB. Defaults to ``[0, 0]``.
    file_name : str, optional
        When given, the overlay is written to
        ``<file_name without extension>_overlay.jpg`` via ``io.imsave``.

    Returns
    -------
    None
    """
    if channels is None:
        channels = [0, 0]

    img0 = img.copy()

    # Only a 3-D array can be reordered with a 3-axis permutation; a 2-D
    # image with very few rows must not be mistaken for channel-first data.
    if img0.ndim == 3 and img0.shape[0] < 4:
        img0 = np.transpose(img0, (1, 2, 0))

    if img0.ndim < 3 or img0.shape[-1] < 3:
        img0 = plot.image_to_rgb(img0, channels=channels)
    elif img0.max() <= 50.0:
        # Clip to [0, 1] first, then scale to 8-bit. Scaling before clipping
        # collapses the image to the values 0 and 1.
        img0 = np.uint8(np.clip(img0, 0, 1) * 255)

    overlay = plot.mask_overlay(img0, maski)

    # NOTE: the figure itself is never written to disk; only the overlay
    # array is saved below.
    fig = plt.figure(figsize=(8, 3))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(overlay)
        ax.set_title('predicted masks')
        ax.axis('off')

        if file_name is not None:
            save_path = os.path.splitext(file_name)[0]
            io.imsave(save_path + '_overlay.jpg', overlay)
    finally:
        # Close this specific figure, even on error, so that repeated calls
        # in a loop do not accumulate open figures.
        plt.close(fig)


In [None]:
# --- Model inputs, forwarded by eval() to model.eval -------------------------
channels = [0, 0]
diameter = 0            # 0 -> eval() substitutes model.diam_labels
flow_threshold = 0.5
cellprob_threshold = 0

# --- Output switches read by eval() ------------------------------------------
save_segmentation = True    # call io.masks_flows_to_seg for every image
save_roi_counts = False     # write <stem>_roi_counts.txt into the results folder
save_plots = False          # write <stem>_overlay.jpg into the results folder

In [None]:
# Directory handed to eval() as eval_dir; it is scanned for '*_clip.tiff' files.
# The path is relative, so it resolves against the notebook's working directory.
remote_path = "../data/Balint/STORM PooledPlasma - ch640/"

In [None]:
# Run the segmentation over the configured directory; count=0 applies no file limit.
eval(
    eval_dir=remote_path,
    channels=channels,
    diameter=diameter,
    flow_threshold=flow_threshold,
    cellprob_threshold=cellprob_threshold,
    save_segmentation=save_segmentation,
    save_roi_counts=save_roi_counts,
    save_plots=save_plots,
    count=0,
)