Skip to content

Commit

Permalink
fixing 3D mask throw-out bug + adding anisotropy + adding pyinstaller…
Browse files Browse the repository at this point in the history
… files
  • Loading branch information
carsen-stringer committed Aug 16, 2020
1 parent 183acde commit 007a7c3
Show file tree
Hide file tree
Showing 19 changed files with 883 additions and 356 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ You can quickly try out Cellpose on the [website](http://www.cellpose.org) first

You can also run Cellpose in google colab with a GPU -> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MouseLand/cellpose/blob/master/notebooks/run_cellpose_GPU.ipynb). This is recommended if you have issues with MKL or run speed on your local computer (and are running 3D volumes). Colab does not allow you to run the GUI, but you can save `*_seg.npy` files in colab that you can download and open in the GUI.

### Detailed documentation at [www.cellpose.org/docs](http://www.cellpose.org/static/docs/index.html).
### Detailed documentation at [www.cellpose.org/docs](http://www.cellpose.org/docs).

## System requirements

Expand Down Expand Up @@ -51,7 +51,8 @@ pip install cellpose --upgrade

If you have an older `cellpose` environment you can remove it with `conda env remove -n cellpose` before creating a new one.

Note you will always have to run **conda activate cellpose** before you run cellpose. If you want to run jupyter notebooks in this environment, then also `conda install jupyter`.
Note you will always have to run **conda activate cellpose** before you run cellpose. If you want to run jupyter notebooks in this environment, then also `conda install jupyter`
and `pip install matplotlib`.

### Common issues

Expand Down Expand Up @@ -155,7 +156,7 @@ The GUI serves two main functions:
2. Manually labelling data.

There is a help window in the GUI that provides more instructions and
a page in the documentation [here](https://cellpose.readthedocs.io/en/latest/gui.html).
a page in the documentation [here](http://cellpose.readthedocs.io/en/latest/gui.html).
Also, if you hover over certain words in the GUI, their definitions
are revealed as tooltips.

Expand All @@ -175,11 +176,11 @@ You can specify the diameter for all the images or set to 0 if you want the algo
python -m cellpose --dir ~/images_nuclei/test/ --pretrained_model nuclei --diameter 0. --save_png
~~~

See the [docs](http://www.cellpose.org/static/docs/command.html) for more info.
See the [docs](http://cellpose.readthedocs.io/en/latest/command.html) for more info.

## Outputs

See the [docs](http://www.cellpose.org/static/docs/outputs.html) for info.
See the [docs](http://cellpose.readthedocs.io/en/latest/outputs.html) for info.

## Dependencies
cellpose relies on the following excellent packages (which are automatically installed with conda/pip if missing):
Expand All @@ -189,6 +190,4 @@ cellpose relies on the following excellent packages (which are automatically ins
- [numpy](http://www.numpy.org/) (>=1.16.0)
- [numba](http://numba.pydata.org/numba-doc/latest/user/5minguide.html)
- [scipy](https://www.scipy.org/)
- [scikit-image](https://scikit-image.org/)
- [natsort](https://natsort.readthedocs.io/en/master/)
- [matplotlib](https://matplotlib.org/)
96 changes: 65 additions & 31 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os, argparse, glob, pathlib, time
import subprocess
import numpy as np
import mxnet as mx
import os, argparse, glob, pathlib
import skimage
from natsort import natsorted

from . import utils, models, io
Expand All @@ -20,8 +19,11 @@
GUI_IMPORT = False
raise

def get_image_files(folder, mask_filter):
def get_image_files(folder, mask_filter, imf=None):
mask_filters = ['_cp_masks', '_cp_output', mask_filter]
image_names = []
if imf is None:
imf = ''
image_names.extend(glob.glob(folder + '/*%s.png'%imf))
image_names.extend(glob.glob(folder + '/*%s.jpg'%imf))
image_names.extend(glob.glob(folder + '/*%s.jpeg'%imf))
Expand All @@ -31,19 +33,18 @@ def get_image_files(folder, mask_filter):
imn = []
for im in image_names:
imfile = os.path.splitext(im)[0]
if len(imfile) > len(mask_filter):
if imfile[-len(mask_filter):] != mask_filter:
imn.append(im)
else:
igood = all([(len(imfile) > len(mask_filter) and imfile[-len(mask_filter):] != mask_filter) or len(imfile) < len(mask_filter)
for mask_filter in mask_filters])
if igood:
imn.append(im)
image_names = imn

return image_names

def get_label_files(image_names, imf, mask_filter):
def get_label_files(image_names, mask_filter, imf=None):
nimg = len(image_names)
label_names = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
if len(imf) > 0:
if imf is not None and len(imf) > 0:
label_names = [label_names[n][:-len(imf)] for n in range(nimg)]
if os.path.exists(label_names[0] + mask_filter + '.tif'):
label_names = [label_names[n] + mask_filter + '.tif' for n in range(nimg)]
Expand All @@ -54,8 +55,10 @@ def get_label_files(image_names, imf, mask_filter):
return label_names


if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser(description='cellpose parameters')
parser.add_argument('--check_mkl', action='store_true', help='check if mkl working')
parser.add_argument('--mkldnn', action='store_true', help='force MXNET_SUBGRAPH_BACKEND = "MKLDNN"')
parser.add_argument('--train', action='store_true', help='train network using images in dir (not yet implemented)')
parser.add_argument('--dir', required=False,
default=[], type=str, help='folder containing data to run or train on')
Expand All @@ -80,7 +83,9 @@ def get_label_files(image_names, imf, mask_filter):
default=0.4, type=float, help='flow error threshold, 0 turns off this optional QC step')
parser.add_argument('--cellprob_threshold', required=False,
default=0.0, type=float, help='cell probability threshold, centered at 0.0')
parser.add_argument('--save_png', action='store_true', help='save masks as png')
parser.add_argument('--save_png', action='store_true', help='save masks as png and outlines as text file for ImageJ')
parser.add_argument('--save_tif', action='store_true', help='save masks as tif and outlines as text file for ImageJ')
parser.add_argument('--fast_mode', action='store_true', help="make code run faster by turning off augmentations and 4 network averaging")
parser.add_argument('--no_npy', action='store_true', help='suppress saving of npy')

# settings for training
Expand All @@ -95,8 +100,30 @@ def get_label_files(image_names, imf, mask_filter):
parser.add_argument('--batch_size', required=False,
default=8, type=int, help='batch size')


args = parser.parse_args()

if args.check_mkl:
print('Running test snippet to check if MKL running (https://mxnet.apache.org/versions/1.6/api/python/docs/tutorials/performance/backend/mkldnn/mkldnn_readme.html#4)')
process = subprocess.Popen('python test_mkl.py', stdout=subprocess.PIPE,
stderr=subprocess.PIPE, cwd=os.path.dirname(os.path.realpath(__file__)))
stdout, stderr = process.communicate()
if len(stdout)>0:
print('MKL version working - CPU version is hardware-accelerated.')
mkl_enabled = True
else:
print('WARNING: MKL version not working/installed - CPU version will be SLOW!')
mkl_enabled = False
else:
mkl_enabled = True

if not args.train and (mkl_enabled and args.mkldnn):
os.environ["MXNET_SUBGRAPH_BACKEND"]="MKLDNN"
else:
os.environ["MXNET_SUBGRAPH_BACKEND"]=""

import mxnet as mx

if len(args.dir)==0:
if not GUI_ENABLED:
print('ERROR: %s'%GUI_ERROR)
Expand All @@ -114,12 +141,13 @@ def get_label_files(image_names, imf, mask_filter):
if len(args.img_filter)>0:
imf = args.img_filter
else:
imf = ''
imf = None

image_names = get_image_files(args.dir, args.mask_filter)
image_names = get_image_files(args.dir, args.mask_filter, imf=imf)
nimg = len(image_names)
images = [skimage.io.imread(image_names[n]) for n in range(nimg)]
images = [io.imread(image_names[n]) for n in range(nimg)]


if args.use_gpu:
use_gpu = utils.use_gpu()
if use_gpu:
Expand All @@ -130,6 +158,7 @@ def get_label_files(image_names, imf, mask_filter):
model_dir = pathlib.Path.home().joinpath('.cellpose', 'models')

if not args.train:
tic = time.time()
if not (args.pretrained_model=='cyto' or args.pretrained_model=='nuclei'):
cpmodel_path = args.pretrained_model
if not os.path.exists(cpmodel_path):
Expand All @@ -153,7 +182,8 @@ def get_label_files(image_names, imf, mask_filter):
(nimg, cstr0[channels[0]], cstr1[channels[1]]))

masks, flows, _, diams = model.eval(images, channels=channels, diameter=diameter,
do_3D=args.do_3D,
do_3D=args.do_3D, net_avg=(not args.fast_mode),
augment=(not args.fast_mode),
flow_threshold=args.flow_threshold,
cellprob_threshold=args.cellprob_threshold)

Expand All @@ -170,33 +200,34 @@ def get_label_files(image_names, imf, mask_filter):
rescale = model.diam_mean / diameter
masks, flows, _ = model.eval(images, channels=channels, rescale=rescale,
do_3D=args.do_3D,
augment=(not args.fast_mode),
flow_threshold=args.flow_threshold,
cellprob_threshold=args.cellprob_threshold)
diams = diameter * np.ones(len(images))

print('>>>> saving results')
if not args.no_npy:
io.masks_flows_to_seg(images, masks, flows, diams, image_names, channels)
if args.save_png:
io.save_to_png(images, masks, flows, image_names)
if args.save_png or args.save_tif:
io.save_masks(images, masks, flows, image_names, png=args.save_png, tif=args.save_tif)
print('>>>> completed in %0.3f sec'%(time.time()-tic))
else:
if args.pretrained_model=='cyto' or args.pretrained_model=='nuclei':
cpmodel_path = os.fspath(model_dir.joinpath('%s_0'%(args.pretrained_model)))
if args.pretrained_model=='cyto':
szmean = 27.
szmean = 30.
else:
szmean = 15.
szmean = 17.
else:
cpmodel_path = os.fspath(args.pretrained_model)
szmean = 27.
szmean = 30.

if args.all_channels:
channels = None

label_names = get_label_files(image_names, imf, args.mask_filter)
label_names = get_label_files(image_names, args.mask_filter, imf=imf)
nimg = len(image_names)
labels = [skimage.io.imread(label_names[n]) for n in range(nimg)]
labels = [io.imread(label_names[n]) for n in range(nimg)]
if not os.path.exists(cpmodel_path):
cpmodel_path = False
print('>>>> training from scratch')
Expand All @@ -205,10 +236,10 @@ def get_label_files(image_names, imf, mask_filter):
print('>>>> median diameter set to 0 => no rescaling during training')
else:
rescale = True
szmean = args.diameter * (np.pi**0.5/2)
szmean = args.diameter
else:
rescale = True
args.diameter = szmean / (np.pi**0.5/2)
args.diameter = szmean
print('>>>> training starting with pretrained_model %s'%cpmodel_path)
if rescale:
print('>>>> rescaling diameter for each training image to %0.1f'%args.diameter)
Expand All @@ -217,16 +248,19 @@ def get_label_files(image_names, imf, mask_filter):

test_images, test_labels = None, None
if len(args.test_dir) > 0:
image_names_test = get_image_files(args.test_dir, args.mask_filter)
label_names_test = get_label_files(image_names_test, imf, args.mask_filter)
image_names_test = get_image_files(args.test_dir, args.mask_filter, imf=imf)
label_names_test = get_label_files(image_names_test, args.mask_filter, imf=imf)
nimg = len(image_names_test)
test_images = [skimage.io.imread(image_names_test[n]) for n in range(nimg)]
test_labels = [skimage.io.imread(label_names_test[n]) for n in range(nimg)]
test_images = [io.imread(image_names_test[n]) for n in range(nimg)]
test_labels = [io.imread(label_names_test[n]) for n in range(nimg)]
#print('>>>> %s model'%(['cellpose', 'unet'][args.unet]))
model = models.CellposeModel(device=device,
pretrained_model=cpmodel_path,
diam_mean=szmean,
batch_size=args.batch_size)
n_epochs=args.n_epochs
model.train(images, labels, test_images, test_labels, learning_rate=args.learning_rate,
channels=channels, save_path=os.path.realpath(args.dir), rescale=rescale, n_epochs=n_epochs)
channels=channels, save_path=os.path.realpath(args.dir), rescale=rescale, n_epochs=n_epochs)

if __name__ == '__main__':
main()
72 changes: 57 additions & 15 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from scipy.ndimage.filters import maximum_filter1d
import scipy.ndimage
import skimage.morphology
import numpy as np
from tqdm import trange
import time
Expand Down Expand Up @@ -276,6 +275,7 @@ def follow_flows(dP, niter=200):
np.arange(shape[2]), indexing='ij')
p = np.array(p).astype(np.float32)
# run dynamics on subset of pixels
#inds = np.array(np.nonzero(dP[0]!=0)).astype(np.int32).T
inds = np.array(np.nonzero(np.abs(dP[0])>1e-3)).astype(np.int32).T
p = steps3D(p, dP, inds, niter)
else:
Expand Down Expand Up @@ -428,10 +428,10 @@ def get_masks(p, iscell=None, rpad=20, flows=None, threshold=0.4):
for i in range(dims):
pflows[i] = pflows[i] + rpad
M0 = M[tuple(pflows)]
_,counts = np.unique(M0, return_counts=True)

# remove big masks
big = shape0[0] * shape0[1] * 0.35
_,counts = np.unique(M0, return_counts=True)
big = np.prod(shape0) * 0.4
for i in np.nonzero(counts > big)[0]:
M0[M0==i] = 0
_,M0 = np.unique(M0, return_inverse=True)
Expand All @@ -445,37 +445,79 @@ def get_masks(p, iscell=None, rpad=20, flows=None, threshold=0.4):
return M0

def fill_holes(masks, min_size=15):
""" fill holes in masks (2D) and discard masks smaller than min_size
""" fill holes in masks (2D/3D) and discard masks smaller than min_size (2D)
fill holes in each mask using scipy.ndimage.morphology.binary_fill_holes
Parameters
----------------
masks: int, 2D array
masks: int, 2D or 3D array
labelled masks, 0=NO masks; 1,2,...=mask labels,
size [Ly x Lx]
size [Ly x Lx] or [Lz x Ly x Lx]
min_size: int (optional, default 15)
minimum number of pixels per mask
Returns
---------------
masks: int, 2D array
masks: int, 2D or 3D array
masks with holes filled and masks smaller than min_size removed,
0=NO masks; 1,2,...=mask labels,
size [Ly x Lx]
size [Ly x Lx] or [Lz x Ly x Lx]
"""

if masks.ndim > 3 or masks.ndim < 2:
raise ValueError('masks_to_outlines takes 2D or 3D array, not %dD array'%masks.ndim)

slices = scipy.ndimage.find_objects(masks)
i = 0
for sr, sc in slices:
msk = masks[sr, sc] == (i+1)
msk = scipy.ndimage.morphology.binary_fill_holes(msk)
sm = np.logical_and(msk, ~skimage.morphology.remove_small_objects(msk, min_size=min_size, connectivity=1))
masks[sr, sc][msk] = (i+1)
masks[sr, sc][sm] = 0
for slc in slices:
if slc is not None:
msk = masks[slc] == (i+1)
if msk.ndim==3:
small_objects = np.zeros(msk.shape, np.bool)
for k in range(msk.shape[0]):
msk[k] = scipy.ndimage.morphology.binary_fill_holes(msk[k])
#small_objects[k] = ~remove_small_objects(msk[k], min_size=min_size)
else:
msk = scipy.ndimage.morphology.binary_fill_holes(msk)
small_objects = ~remove_small_objects(msk, min_size=min_size)
sm = np.logical_and(msk, small_objects)
#~skimage.morphology.remove_small_objects(msk, min_size=min_size, connectivity=1))
masks[slc][msk] = (i+1)
masks[slc][sm] = 0
i+=1
return masks

def remove_small_objects(ar, min_size=64, connectivity=1):
""" copied from skimage.morphology.remove_small_objects (required to be separate for pyinstaller) """
out = ar.copy()

if min_size == 0: # shortcut for efficiency
return out

if out.dtype == bool:
selem = scipy.ndimage.generate_binary_structure(ar.ndim, connectivity)
ccs = np.zeros_like(ar, dtype=np.int32)
scipy.ndimage.label(ar, selem, output=ccs)
else:
ccs = out

try:
component_sizes = np.bincount(ccs.ravel())
except ValueError:
raise ValueError("Negative value labels are not supported. Try "
"relabeling the input with `scipy.ndimage.label` or "
"`skimage.morphology.label`.")

if len(component_sizes) == 2 and out.dtype != bool:
warn("Only one label was provided to `remove_small_objects`. "
"Did you mean to use a boolean array?")

too_small = component_sizes < min_size
too_small_mask = too_small[ccs]
out[too_small_mask] = 0

return out

0 comments on commit 007a7c3

Please sign in to comment.