In [None]:
import numcodecs
import numpy as np
import zarr

from zarr.storage import LocalStore

Set up logging

In [None]:
import logging, os, sys
from logging.config import fileConfig

# Path to an optional logging config file (fileConfig format); leave empty to use the defaults below.
log_config_file = ''

# An explicit, existing config file takes precedence. Otherwise fall back to a
# verbose DEBUG setup that writes every record to the notebook's stdout.
use_config_file = bool(log_config_file) and os.path.exists(log_config_file)
if not use_config_file:
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(threadName)s:%(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[stdout_handler],
    )
else:
    print(f'Initialize logging from {log_config_file}')
    fileConfig(log_config_file)


In [None]:
def open_zarr(img_dir, img_subpath=None):
    """Open a local Zarr container, or one node inside it, for reading.

    Args:
        img_dir: Filesystem path to the root of the Zarr container.
        img_subpath: Optional path of a group/array inside the container
            (e.g. 't2/2'). When falsy, the container root is returned.

    Returns:
        The Zarr node found at ``img_subpath``, or the root node opened from
        ``img_dir`` when no subpath is given.
    """
    # read_only=True matches mode='r' below and guards the input dataset
    # against accidental writes; a writable store combined with mode='r' is
    # rejected by some zarr-python 3 releases.
    data_store = LocalStore(img_dir, read_only=True)
    img_container = zarr.open(store=data_store, mode='r')
    return img_container[img_subpath] if img_subpath else img_container

Open the test input data and create the output labels container inside a temporary working directory

In [None]:
# Input image: node 't2/2' of a small stitched OME-Zarr test dataset.
test_input_img_zarr = open_zarr(
    '/Users/goinac/Work/HHMI/stitching/datasets/tiny/stitched.ome.zarr',
    't2/2',
)
test_input_img_shape = test_input_img_zarr.shape

# Scratch directory holding the labels output (and used as distributed_eval's temp dir).
test_working_dir = '/Users/goinac/Work/HHMI/segmentation/cellpose/tmp'

# Output container; mode='w' overwrites any existing container at this path.
test_output_zarr_format = 2
test_output_zarr = zarr.open_group(
    store=LocalStore(root=f'{test_working_dir}/labels.zarr'),
    mode='w',
    zarr_format=test_output_zarr_format,
)
print(test_output_zarr)

# Zarr v2 output uses '/'-separated (nested) chunk keys; otherwise the library default.
if test_output_zarr_format == 2:
    chunk_key_separator = {'name': 'v2', 'separator': '/'}
else:
    chunk_key_separator = None

# Labels array spans the trailing three dimensions of the input image.
output_compressor = numcodecs.get_codec({'id': 'zstd', 'level': 5})
test_output_labels_zarr = test_output_zarr.require_array(
    name='labels',
    shape=test_input_img_shape[-3:],
    chunks=(128, 128, 128),
    dtype=np.uint32,
    chunk_key_encoding=chunk_key_separator,
    compressors=output_compressor,
)

In [None]:
import os

# Let PyTorch fall back to CPU for operators the Apple MPS backend lacks.
# A `!VAR=1` shell escape only sets the variable inside a throwaway subshell,
# so it never reaches this kernel. Setting it in os.environ here, before torch
# is first imported (via cellpose below), makes it visible to this process and
# to child processes started afterwards (e.g. the local Dask workers).
# If torch was already imported in this kernel, restart the kernel for it to apply.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

Run distributed Cellpose segmentation (`distributed_eval`) on a local Dask cluster

In [None]:
from cellpose.contrib.distributed_segmentation import distributed_eval
from cellpose.contrib.dask_utils import myLocalCluster, ConfigureWorkerPlugin

# Local Dask cluster: 8 worker processes, one thread each, bound to localhost.
# NOTE(review): the leading positional argument (1) is forwarded as-is to
# myLocalCluster; its meaning is not visible here — confirm in dask_utils.
localCluster = myLocalCluster(
    1,
    n_workers=8,
    threads_per_worker=1,
    processes=True,
    host="localhost",
)

# NOTE(review): the positional arguments ('', '', True, 1) are opaque from this
# notebook — confirm their meaning against ConfigureWorkerPlugin.
worker_plugin = ConfigureWorkerPlugin('', '', True, 1)
localCluster.client.register_plugin(worker_plugin)

# Model selection: pretrained 'cpsam' weights on the Apple MPS GPU device.
segmentation_model_args = dict(
    use_gpu=True,
    gpu_device='mps',
    pretrained_model='cpsam',
)
# NOTE(review): if these options are forwarded to Cellpose's `normalize` dict,
# `lowhigh` means absolute intensities mapped to 0/1, whereas `percentile`
# takes percentiles; (1, 99) looks like percentiles — confirm the intent.
segmentation_normalize_args = dict(
    normalize=True,
    lowhigh=(1, 99),
)
segmentation_eval_args = dict(
    do_3D=True,
    min_size=15,
    max_size_fraction=0.4,
    cellprob_threshold=-8,
    flow3D_smooth=1,
    batch_size=8,
)

labels, boxes = distributed_eval(
    test_input_img_zarr,
    0,  # input_timeindex
    1,  # input_channels
    (128, 128, 128),  # presumably the per-block processing size — confirm
    test_output_labels_zarr,
    cellpose_model_args=segmentation_model_args,
    normalize_args=segmentation_normalize_args,
    cellpose_eval_args=segmentation_eval_args,
    cluster=localCluster,
    cluster_kwargs={},
    temp_dir=test_working_dir,
)