In [1]:
import os
import numpy as np
import zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.writer import write_image
from dask import delayed
import dask.array as da
from skimage.io import imsave, imread

from cellpose import models

# NOTE(review): napari is not used in the cells visible here — presumably kept for
# interactively inspecting images/masks; confirm or drop.
import napari
from napari.settings import get_settings
# Presumably lets napari viewers run interactively inside IPython/Jupyter without
# blocking the kernel — confirm against the napari application settings docs.
get_settings().application.ipy_interactive = True

from tqdm.notebook import tqdm

from PIL import Image
# Disable Pillow's decompression-bomb pixel-count limit so very large frames/masks
# (this movie is ~8400 x 8400 px) can be opened without a warning/error.
Image.MAX_IMAGE_PIXELS = None

In [3]:
# ---- input: OME-Zarr movie to segment ---------------------------------------
# NOTE(review): absolute paths on a mapped drive; adjust per machine.
input_path_zarr = r'R:\Kasia\tracking\TrackGardener\B4_C1_small.zarr'

# ---- outputs ------------------------------------------------------------------
# NOTE(review): input is 'B4_C1' but outputs are named 'C4_*' — confirm they match.
output_path_zarr = r'R:\Kasia\tracking\TrackGardener\C4_masks.zarr'  # stacked masks (OME-Zarr)
save_png_dir = r'R:\Kasia\tracking\TrackGardener\C4_segmentation'    # one PNG mask per frame

# ---- chunk sizes for the output zarr: frames per chunk, pixels per row/col ----
size_t, size_xy = 1, 2048

# the segmentation loop writes into this folder, so make sure it exists
os.makedirs(save_png_dir, exist_ok=True)

In [4]:
# read in the data
location = parse_url(input_path_zarr)
# parse_url returns None (instead of raising) when the path does not exist,
# which would otherwise surface as an obscure error inside Reader.
if location is None:
    raise FileNotFoundError(f'No OME-Zarr found at {input_path_zarr}')
reader = Reader(location)

# first node = the image; data[0] = first multiscale level (full resolution per OME-NGFF)
image_node = list(reader())[0]
dask_data = image_node.data[0]

# the rest of the notebook indexes dask_data as (frame, row, col)
assert dask_data.ndim == 3, f'expected a (t, y, x) array, got shape {dask_data.shape}'
dask_data

  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


Unnamed: 0,Array,Chunk
Bytes,31.66 GiB,5.00 MiB
Shape,"(241, 8400, 8396)","(10, 512, 512)"
Dask graph,7225 chunks in 2 graph layers,7225 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 31.66 GiB 5.00 MiB Shape (241, 8400, 8396) (10, 512, 512) Dask graph 7225 chunks in 2 graph layers Data type uint16 numpy.ndarray",8396  8400  241,

Unnamed: 0,Array,Chunk
Bytes,31.66 GiB,5.00 MiB
Shape,"(241, 8400, 8396)","(10, 512, 512)"
Dask graph,7225 chunks in 2 graph layers,7225 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray


In [None]:
# Region to segment. Use start 0 and stop = dask_data.shape[...] for the full movie.
# Negative values count from the end, Python-style (e.g. frame_stop = -1 drops the last frame).
params = {
    'frame_start': 0,
    'row_start': 0,
    'col_start': 0,
    'frame_stop': dask_data.shape[0],
    'row_stop': dask_data.shape[1],
    'col_stop': dask_data.shape[2],
}

# axis of dask_data that each parameter indexes (explicit, rather than relying on dict order)
param_axis = {
    'frame_start': 0, 'frame_stop': 0,
    'row_start': 1, 'row_stop': 1,
    'col_start': 2, 'col_stop': 2,
}

# convert negative values to absolute indices
params = {
    key: value + dask_data.shape[param_axis[key]] if value < 0 else value
    for key, value in params.items()
}
frame_start, row_start, col_start = params['frame_start'], params['row_start'], params['col_start']
frame_stop, row_stop, col_stop = params['frame_stop'], params['row_stop'], params['col_stop']

# check that the values are acceptable: non-empty and inside the data extent
# (an out-of-range frame_stop would otherwise raise IndexError midway through the
# segmentation loop, and out-of-range rows/cols would be clipped silently)
for axis_name, start, stop, extent in (
    ('frame', frame_start, frame_stop, dask_data.shape[0]),
    ('row', row_start, row_stop, dask_data.shape[1]),
    ('col', col_start, col_stop, dask_data.shape[2]),
):
    assert start < stop, f'{axis_name}_start {start} should be smaller than {axis_name}_stop {stop}'
    assert 0 <= start and stop <= extent, (
        f'{axis_name} range [{start}, {stop}) is outside the data extent [0, {extent})'
    )

print(f'Selected data of shape: {frame_stop - frame_start, row_stop - row_start, col_stop - col_start}')

Selected data of shape: (2, 8400, 8396)


In [12]:
# Load the pretrained Cellpose 'cyto' model, requesting GPU inference.
model = models.CellposeModel(
    model_type='cyto',
    gpu=True,
)

  state_dict = torch.load(filename, map_location=device)


### Segment each frame and save the masks as PNG files

In [None]:
# set parameters for segmentation
# see https://cellpose.readthedocs.io/en/latest/api.html for details

normalize_dict = {
    "percentile": [1, 99], # defaults are 1 and 99
    "normalize": True
}

resample = False # default is True
diameter = 60  # expected object diameter in pixels, passed to Cellpose

# Segment every selected frame and write its label mask to save_png_dir.
# The PNGs are read back lazily (as uint16) in the "Save masks as zarr" section,
# so they must actually be written here.
for i in tqdm(range(frame_start, frame_stop)):

    # materialise only the cropped 2D (row, col) frame so Cellpose gets a NumPy array;
    # z_axis is not passed because each frame is 2D
    im_frame = dask_data[i, row_start:row_stop, col_start:col_stop].compute()
    mask, _, _ = model.eval(im_frame, diameter=diameter, normalize=normalize_dict, resample=resample)

    # the masks are stored as 16-bit PNGs; refuse to silently wrap label ids
    if mask.max() > np.iinfo(np.uint16).max:
        raise ValueError(f'frame {i}: max label {mask.max()} does not fit in uint16')

    save_path = os.path.join(save_png_dir, f'mask_{i:03d}.png')
    # label images usually trigger skimage's low-contrast warning; it is not meaningful here
    imsave(save_path, mask.astype(np.uint16), check_contrast=False)

  0%|          | 0/2 [00:00<?, ?it/s]

### Stack the PNG masks and save them as OME-Zarr

In [19]:
# Lazily stack the per-frame mask PNGs into a (t, y, x) dask array.
mask_paths = [os.path.join(save_png_dir, f'mask_{i:03d}.png') for i in range(frame_start, frame_stop)]

# fail now rather than halfway through writing the zarr if a frame was never saved
missing = [p for p in mask_paths if not os.path.isfile(p)]
if missing:
    raise FileNotFoundError(f'{len(missing)} mask file(s) missing, e.g. {missing[0]}')

# Cellpose masks have the spatial size of the input frame, i.e. the crop; derive it
# from the crop bounds instead of the `mask` variable leaked from the segmentation loop
frame_shape = (row_stop - row_start, col_stop - col_start)
lazy_arrays = [delayed(imread)(path) for path in mask_paths]
dask_arrays = [da.from_delayed(delayed_reader, shape=frame_shape, dtype='uint16') for delayed_reader in lazy_arrays]
stack = da.stack(dask_arrays, axis=0)
stack

In [21]:
# Save the mask stack as an OME-Zarr image, chunked per frame and per size_xy tile.
output_location = parse_url(output_path_zarr, mode="w")
store = output_location.store
root = zarr.group(store=store)

chunk_shape = (size_t, size_xy, size_xy)
# NOTE: this fails if the output store already contains arrays — remove it first to re-run
write_image(image=stack, group=root, axes="tyx", storage_options={"chunks": chunk_shape})

[]