In [42]:
import os
import pathlib
from glob import glob
from typing import List, Tuple, Callable, Optional, Union, Dict,Any
from zipfile import ZipFile, ZIP_DEFLATED

import cv2
import numpy as np
from tifffile import imread, imsave, imwrite

import cellpose
from cellpose import io
from cellpose import utils
from cellpose import models

In [None]:
# Cellpose API: https://cellpose.readthedocs.io/en/latest/api.html

In [68]:
# Configuration: model, input image and output directory (edit for your setup).
MODEL_PATH = r'D:\MVD software\pipeline\segmentation\models\CP_CRC_TMA_DAPI_panEPI_v0'
IMG_PATH = r'D:\MVD software\pipeline\segmentation\data_test\reg_2a_A-9.tif'
OUT_DIR = r'D:\MVD software\pipeline\segmentation\data_test\output'

# NOTE: load_cellpose_model / eval_cellpose_model / convert_mask_labels_to_polygons /
# export_imagej_rois are defined in a later cell of this notebook; that cell must be
# executed first (a fresh "Restart & Run All" raises NameError here otherwise).

# Load cellpose model and image
model = load_cellpose_model(model_path=MODEL_PATH, gpu=True)
image = imread(IMG_PATH)
print(image.shape)

# Compute segmentation
# to setup channels see https://cellpose.readthedocs.io/en/latest/api.html#cellpose.models.Cellpose.eval
mask = eval_cellpose_model(image, model, channels=[1, 2])
print(mask.shape)

# Output names derive from the input file's stem (robust to '.tif' vs '.tiff').
os.makedirs(OUT_DIR, exist_ok=True)
img_stem = pathlib.PurePath(IMG_PATH).stem

# Export mask. tifffile always writes TIFF data whatever the extension, so use a
# .tif name; TIFF also stores label images of any integer dtype losslessly.
out_mask_path = os.path.join(OUT_DIR, img_stem + '_mask.tif')
print('out_path', out_mask_path)
imwrite(out_mask_path, mask)

# Export as ImageJ ROIs (slow: one full-image contour pass per label).
ij_rois = convert_mask_labels_to_polygons(mask)
out_polygons_path = os.path.join(OUT_DIR, img_stem + '_rois.zip')
export_imagej_rois(out_polygons_path, ij_rois)

(2, 3075, 3075)
(3075, 3075)
out_path E:\test\reg_2a_B-19_mask.png


In [67]:
def load_cellpose_model(model_path : Optional[str] = None,
                        model_type : Optional[str] = 'cyto',
                        gpu : bool = False
                       ) -> 'Union[models.CellposeModel, models.Cellpose]':
    """Instantiate a Cellpose model.

    Args:
        model_path: Path to a custom pretrained model. When truthy, a
            ``models.CellposeModel`` is built from it and ``model_type`` is
            ignored.
        model_type: Built-in model name passed to ``models.Cellpose`` when no
            ``model_path`` is given (default ``'cyto'``).
        gpu: Forwarded unchanged as the constructor's ``gpu`` argument.

    Returns:
        The constructed model instance (``CellposeModel`` when a path is
        given, otherwise ``Cellpose``).
    """
    if model_path:
        return models.CellposeModel(gpu=gpu, pretrained_model=model_path)
    return models.Cellpose(gpu=gpu, model_type=model_type)

def eval_cellpose_model(image : np.ndarray,
                        model : 'Union[models.CellposeModel, models.Cellpose]',
                        channels : Union[List[int], List[List[int]]],
                        batch_size : Optional[int] = 8,
                        diameter : Optional[float] = None,
                        flow_threshold : Optional[float] = 0.4,
                        cellprob_threshold : Optional[float] = 0.0,
                        ) -> np.ndarray:
    """Run ``model.eval`` on ``image`` and return only the label mask.

    Args:
        image: Input image, passed unchanged to ``model.eval``.
        model: A loaded Cellpose model (e.g. from ``load_cellpose_model``).
        channels: Cellpose channel specification such as ``[1, 2]``; forwarded
            as-is. See https://cellpose.readthedocs.io/en/latest/api.html
        batch_size: Forwarded to ``model.eval``.
        diameter: Forwarded to ``model.eval``; ``None`` defers to Cellpose's
            own default handling.
        flow_threshold: Forwarded to ``model.eval``.
        cellprob_threshold: Forwarded to ``model.eval``.

    Returns:
        The first element of the tuple returned by ``model.eval`` (the mask).
    """
    model_output = model.eval(image,
                              channels=channels,
                              batch_size=batch_size,
                              # Forward the caller's diameter instead of ignoring it.
                              diameter=diameter,
                              flow_threshold=flow_threshold,
                              cellprob_threshold=cellprob_threshold
                             )

    # model.eval returns a tuple whose first element is the mask.
    return model_output[0]

def convert_mask_labels_to_polygons(mask : np.ndarray,
                                    labels : Optional[List[int]] = None
                                   ) -> List[List[List[List[int]]]]:
    """Convert every labelled object in ``mask`` to an outline polygon.

    Args:
        mask: Integer label image; 0 is treated as background.
        labels: Labels to convert. Defaults to every non-zero value present in
            ``mask``. Any iterable of labels (list or numpy array) is accepted.

    Returns:
        A single-group list ``[rois]``, where each roi is ``[ys, xs]``: two
        lists of ``np.uint16`` vertex coordinates. This is the layout that
        ``export_imagej_rois`` consumes (``poly[0]`` = y, ``poly[1]`` = x).
        Labels whose polygon is empty are skipped.
    """
    if labels is None:
        unique_labels = np.unique(mask)
        # Drop the background explicitly: slicing off the first unique value
        # would discard a real object when the mask has no 0 pixels.
        unique_labels = unique_labels[unique_labels != 0]
    else:
        unique_labels = labels

    rois = []
    for label in unique_labels:
        polygon = convert_mask_label_to_polygon(mask, label)
        if not polygon:
            continue

        # Keep only non-negative vertices; point[0] is x, point[1] is y.
        kept = [(px, py) for px, py in polygon if px >= 0 and py >= 0]
        polygon_x = [px for px, _ in kept]
        polygon_y = [py for _, py in kept]

        # NOTE: uint16 wraps coordinates >= 65536 (fine for images below that size).
        rois.append([list(np.uint16(polygon_y)), list(np.uint16(polygon_x))])

    return [rois]

def convert_mask_label_to_polygon(mask : np.ndarray,
                                  label : int
                                  ) -> List[Tuple[int, int]] :
    """Return the outer contour of the pixels equal to ``label`` as (x, y) points.

    Only external contours are traced (``cv2.RETR_EXTERNAL``) with every
    boundary pixel kept (``cv2.CHAIN_APPROX_NONE``). If the label forms several
    disconnected regions, only the contour with the most vertices is returned.

    Args:
        mask: Label image.
        label: Label value to outline.

    Returns:
        List of ``(x, y)`` vertex tuples, or ``[]`` when ``label`` does not
        occur in ``mask``.
    """
    binary = (mask == label).astype(np.uint8) * 255
    # [-2] picks the contour list under both OpenCV 3 (3 return values)
    # and OpenCV 4 (2 return values).
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
    if len(contours) == 0:
        return []

    # Each contour has shape (n_points, 1, 2); choose the one with most points.
    longest = max(contours, key=len)
    return [(pt[0][0], pt[0][1]) for pt in longest]

# -- To save polygon as imageJ rois

def polyroi_bytearray(x : Union[np.ndarray, List[float]],
                      y : Union[np.ndarray, List[float]],
                      pos : Optional[int] = None,
                      subpixel : Optional[bool] = True) -> bytearray:
    """Encode one polygon as the raw bytes of an ImageJ ``.roi`` file.

    Layout follows ImageJ's decoder:
    https://github.com/imagej/imagej1/blob/master/ij/io/RoiDecoder.java

    Args:
        x: Vertex x coordinates (any array-like; flattened).
        y: Vertex y coordinates, same length as ``x``.
        pos: If given, written as a big-endian int32 at header offset 56
            (position: C, Z, or T).
        subpixel: If truthy, sets option flag 128 at offset 50 and appends
            float32 copies of the coordinates after the integer ones.

    Returns:
        bytearray: a 64-byte header, then ``n`` int16 x offsets and ``n`` int16
        y offsets relative to the rounded bounding box, then (if ``subpixel``)
        ``n`` float32 x and ``n`` float32 y absolute coordinates.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in length.
        ValueError: if the coordinate arrays are empty.
        OverflowError: if a value does not fit in its integer field.
    """
    import struct

    def _put_int(buf, offset, value, size, signed=True):
        # int.to_bytes raises OverflowError rather than silently wrapping.
        buf[offset:offset + size] = int(value).to_bytes(size, byteorder='big', signed=signed)

    subpixel = bool(subpixel)
    # Add 0.5 because ImageJ places pixel centres at (0.5, 0.5).
    # NOTE(review): np.round rounds halves to even, so integer inputs (exact .5
    # after the shift) become x or x+1 depending on parity. This affects the
    # integer coordinates and bbox only; presumably ImageJ reads the float block
    # when subpixel is set — verify before relying on subpixel=False.
    x_raw = np.asarray(x).ravel() + 0.5
    y_raw = np.asarray(y).ravel() + 0.5
    x_int = np.round(x_raw)
    y_int = np.round(y_raw)
    assert len(x_int) == len(y_int)
    top, left, bottom, right = y_int.min(), x_int.min(), y_int.max(), x_int.max()

    n_coords = len(x_int)
    header_size = 64
    int_end = header_size + 4 * n_coords            # two int16 coordinate arrays
    buf = bytearray(int_end + (8 * n_coords if subpixel else 0))

    buf[0:4] = b'Iout'                              # magic start
    _put_int(buf, 4, 227, 2)                        # version
    _put_int(buf, 6, 0, 2)                          # roi type (0 = polygon)
    _put_int(buf, 8, top, 2)                        # bbox top
    _put_int(buf, 10, left, 2)                      # bbox left
    _put_int(buf, 12, bottom, 2)                    # bbox bottom
    _put_int(buf, 14, right, 2)                     # bbox right
    _put_int(buf, 16, n_coords, 2, signed=False)    # number of coordinates
    if subpixel:
        _put_int(buf, 50, 128, 2)                   # options: subpixel resolution
    if pos is not None:
        _put_int(buf, 56, pos, 4)                   # position (C, Z, or T)

    # Integer coordinates are stored relative to the bbox top-left corner.
    y_start = header_size + 2 * n_coords
    for i, (xi, yi) in enumerate(zip(x_int, y_int)):
        _put_int(buf, header_size + 2 * i, xi - left, 2)
        _put_int(buf, y_start + 2 * i, yi - top, 2)

    if subpixel:
        # Float coordinates are absolute: all x values, then all y values.
        float_fmt = '>%df' % n_coords
        buf[int_end:int_end + 4 * n_coords] = struct.pack(float_fmt, *x_raw)
        buf[int_end + 4 * n_coords:] = struct.pack(float_fmt, *y_raw)

    return buf

def export_imagej_rois(out_path : str,
                       polygons : Union[np.ndarray, List[List[List[List[int]]]]],
                       set_position : bool = True,
                       subpixel : bool = True,
                       compression=ZIP_DEFLATED):
    """Write polygons to an ImageJ ROI-set ``.zip`` archive.

    Args:
        out_path: Destination path; ``.zip`` is appended unless it already
            ends in ``.zip``.
        polygons: A sequence of groups. Each group is a sequence of polygons,
            and each polygon is indexable as ``poly[0]`` = y coordinates and
            ``poly[1]`` = x coordinates (the layout produced by
            ``convert_mask_labels_to_polygons``). A bare ``np.ndarray`` is
            treated as a single group.
        set_position: If true, store the 1-based group index as each ROI's
            position; otherwise no position is written.
        subpixel: Forwarded to ``polyroi_bytearray``.
        compression: ``zipfile`` compression method for the archive.

    Each ROI is stored as ``{group:03d}_{polygon:03d}.roi`` (both 1-based).
    The archive is overwritten if it already exists.
    """
    if isinstance(polygons, np.ndarray):
        polygons = (polygons,)

    fname = pathlib.Path(out_path)
    zip_path = str(fname) if fname.suffix == '.zip' else str(fname) + '.zip'

    with ZipFile(zip_path, mode='w', compression=compression) as roizip:
        for group_idx, group in enumerate(polygons, start=1):
            position = group_idx if set_position else None
            for poly_idx, poly in enumerate(group, start=1):
                # poly[1] holds x and poly[0] holds y.
                roi = polyroi_bytearray(poly[1], poly[0], pos=position, subpixel=subpixel)
                roizip.writestr('{pos:03d}_{i:03d}.roi'.format(pos=group_idx, i=poly_idx), roi)