Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Sep 22, 2020
2 parents 003aa5f + 22baf1d commit 0b9d2a7
Show file tree
Hide file tree
Showing 14 changed files with 2,790 additions and 721 deletions.
55 changes: 17 additions & 38 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,44 +51,11 @@ pip install cellpose --upgrade

If you have an older `cellpose` environment you can remove it with `conda env remove -n cellpose` before creating a new one.

Note you will always have to run **conda activate cellpose** before you run cellpose. If you want to run jupyter notebooks in this environment, then also `conda install jupyter`
and `pip install matplotlib`.
Note you will always have to run **conda activate cellpose** before you run cellpose. If you want to run jupyter notebooks in this environment, then also `conda install jupyter` and `pip install matplotlib`.

### Common issues
If you have **issues** with installation, see the [docs](https://cellpose.readthedocs.io/en/latest/installation.html) for more details, and then if the suggestions fail, open an issue.

If you receive the error: `Illegal instruction (core dumped)`, then likely mxnet does not recognize your MKL version. Please uninstall and reinstall mxnet without mkl:
~~~~
pip uninstall mxnet-mkl
pip uninstall mxnet
pip install mxnet==1.4.0
~~~~

**MAC OS ISSUE**: You may have an issue on Mac with the latest *opencv-python* library (package name *cv2*). Downgrade it with the command
~~~~
pip install opencv-python==3.4.5.20
~~~~

If you receive the error: `No module named PyQt5.sip`, then try uninstalling and reinstalling pyqt5
~~~~
pip uninstall pyqt5 pyqt5-tools
pip install pyqt5 pyqt5-tools pyqt5.sip
~~~~

If you receive an error associated with **matplotlib**, try upgrading it:
~~~~
pip install matplotlib --upgrade
~~~~


If you receive the error: `ImportError: _arpack DLL load failed`, then try uninstalling and reinstalling scipy
~~~~
pip uninstall scipy
pip install scipy
~~~~

If you are having issues with the graphical interface, make sure you have **python 3.7** and not python 3.8 installed.

**CUDA version**
### CUDA version

If you plan on running many images, you may want to install a GPU version of *mxnet*. I recommend using CUDA 10.0 or greater. Follow the instructions [here](https://mxnet.apache.org/get_started?).

Expand Down Expand Up @@ -157,8 +124,7 @@ The GUI serves two main functions:

There is a help window in the GUI that provides more instructions and
a page in the documentation [here](http://cellpose.readthedocs.io/en/latest/gui.html).
Also, if you hover over certain words in the GUI, their definitions
are revealed as tooltips.
Also, if you hover over certain words in the GUI, their definitions are revealed as tooltips.

### In a notebook

Expand All @@ -178,6 +144,19 @@ python -m cellpose --dir ~/images_nuclei/test/ --pretrained_model nuclei --diame

See the [docs](http://cellpose.readthedocs.io/en/latest/command.html) for more info.

### Timing

You can check if cellpose is running the MKL version (if you are using the CPU not the GPU) by adding the flag `--check_mkl`. If you are not using MKL cellpose will be much slower. Here are Cellpose run times divided into the time it takes to run the deep neural network (DNN) and the time for postprocessing (gradient tracking, segmentation, quality control etc.). The DNN runtime is shown using either a GPU (Nvidia GTX 1080Ti) or a CPU (Intel 10-core 7900X), with or without network ensembling (4net vs 1net). The postprocessing runtime is similar regardless of ensembling or CPU/GPU version. Runtime is shown for different image sizes, all with a cell diameter of 30 pixels (the average from our training set).

| | 256 pix | 512 pix | 1024 pix |
|----|-------|------|----------|
| DNN (1net, GPU) | 0.054 s | 0.12 s | 0.31 s |
| DNN (1net, CPU) | 0.30 s | 0.65 s | 2.4 s |
| DNN (4net, GPU) | 0.23 s | 0.41 s | 1.3 s |
| DNN (4net, CPU) | 1.3 s | 2.5 s | 9.1 s |
| | | | |
| Postprocessing (CPU) | 0.32 s | 1.2 s | 6.1 s |

## Outputs

See the [docs](http://cellpose.readthedocs.io/en/latest/outputs.html) for info.
Expand Down
146 changes: 111 additions & 35 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
raise

def get_image_files(folder, mask_filter, imf=None):
mask_filters = ['_cp_masks', '_cp_output', mask_filter]
mask_filters = ['_cp_masks', '_cp_output', '_flows', mask_filter]
image_names = []
if imf is None:
imf = ''
Expand All @@ -35,6 +35,8 @@ def get_image_files(folder, mask_filter, imf=None):
imfile = os.path.splitext(im)[0]
igood = all([(len(imfile) > len(mask_filter) and imfile[-len(mask_filter):] != mask_filter) or len(imfile) < len(mask_filter)
for mask_filter in mask_filters])
if len(imf)>0:
igood &= imfile[-len(imf):]==imf
if igood:
imn.append(im)
image_names = imn
Expand All @@ -43,23 +45,39 @@ def get_image_files(folder, mask_filter, imf=None):

def get_label_files(image_names, mask_filter, imf=None):
nimg = len(image_names)
label_names = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]

if imf is not None and len(imf) > 0:
label_names = [label_names[n][:-len(imf)] for n in range(nimg)]
label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
else:
label_names = label_names0

# check for flows
if os.path.exists(label_names0[0] + '_flows.tif'):
flow_names = [label_names0[n] + '_flows.tif' for n in range(nimg)]
else:
flow_names = [label_names[n] + '_flows.tif' for n in range(nimg)]
if not all([os.path.exists(flow) for flow in flow_names]):
flow_names = None

# check for masks
if os.path.exists(label_names[0] + mask_filter + '.tif'):
label_names = [label_names[n] + mask_filter + '.tif' for n in range(nimg)]
elif os.path.exists(label_names[0] + mask_filter + '.png'):
label_names = [label_names[n] + mask_filter + '.png' for n in range(nimg)]
else:
raise ValueError('labels not provided with correct --mask_filter')
return label_names
if not all([os.path.exists(label) for label in label_names]):
raise ValueError('labels not provided for all images in train and/or test set')

return label_names, flow_names


def main():
parser = argparse.ArgumentParser(description='cellpose parameters')
parser.add_argument('--check_mkl', action='store_true', help='check if mkl working')
parser.add_argument('--mkldnn', action='store_true', help='force MXNET_SUBGRAPH_BACKEND = "MKLDNN"')
parser.add_argument('--train', action='store_true', help='train network using images in dir (not yet implemented)')
parser.add_argument('--train', action='store_true', help='train network using images in dir')
parser.add_argument('--dir', required=False,
default=[], type=str, help='folder containing data to run or train on')
parser.add_argument('--img_filter', required=False,
Expand All @@ -70,8 +88,10 @@ def main():
# settings for running cellpose
parser.add_argument('--pretrained_model', required=False,
default='cyto', type=str, help='model to use')
#parser.add_argument('--unet', required=False,
# default=0, type=int, help='run standard unet instead of cellpose flow output')
parser.add_argument('--unet', required=False,
default=0, type=int, help='run standard unet instead of cellpose flow output')
parser.add_argument('--nclasses', required=False,
default=3, type=int, help='if running unet, choose 2 or 3, otherwise not used')
parser.add_argument('--chan', required=False,
default=0, type=int, help='channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE')
parser.add_argument('--chan2', required=False,
Expand All @@ -85,10 +105,11 @@ def main():
default=0.0, type=float, help='cell probability threshold, centered at 0.0')
parser.add_argument('--save_png', action='store_true', help='save masks as png and outlines as text file for ImageJ')
parser.add_argument('--save_tif', action='store_true', help='save masks as tif and outlines as text file for ImageJ')
parser.add_argument('--fast_mode', action='store_true', help="make code run faster by turning off augmentations and 4 network averaging")
parser.add_argument('--fast_mode', action='store_true', help="make code run faster by turning off 4 network averaging")
parser.add_argument('--no_npy', action='store_true', help='suppress saving of npy')

# settings for training
parser.add_argument('--train_size', action='store_true', help='train size network at end of training')
parser.add_argument('--mask_filter', required=False,
default='_masks', type=str, help='end string for masks to run on')
parser.add_argument('--test_dir', required=False,
Expand All @@ -99,21 +120,30 @@ def main():
default=500, type=int, help='number of epochs')
parser.add_argument('--batch_size', required=False,
default=8, type=int, help='batch size')
parser.add_argument('--residual_on', required=False,
default=1, type=int, help='use residual connections')
parser.add_argument('--style_on', required=False,
default=1, type=int, help='use style vector')
parser.add_argument('--concatenation', required=False,
default=0, type=int, help='concatenate downsampled layers with upsampled layers (off by default which means they are added)')


args = parser.parse_args()

if args.check_mkl:
tic=time.time()
print('Running test snippet to check if MKL running (https://mxnet.apache.org/versions/1.6/api/python/docs/tutorials/performance/backend/mkldnn/mkldnn_readme.html#4)')
process = subprocess.Popen('python test_mkl.py', stdout=subprocess.PIPE,
stderr=subprocess.PIPE, cwd=os.path.dirname(os.path.realpath(__file__)))
process = subprocess.Popen(['python', 'test_mkl.py'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
cwd=os.path.dirname(os.path.abspath(__file__)))
stdout, stderr = process.communicate()
if len(stdout)>0:
print('MKL version working - CPU version is hardware-accelerated.')
mkl_enabled = True
else:
print('WARNING: MKL version not working/installed - CPU version will be SLOW!')
mkl_enabled = False
print(time.time()-tic)
else:
mkl_enabled = True

Expand Down Expand Up @@ -157,7 +187,7 @@ def main():
print('>>>> using %s'%(['CPU', 'GPU'][use_gpu]))
model_dir = pathlib.Path.home().joinpath('.cellpose', 'models')

if not args.train:
if not args.train and not args.train_size:
tic = time.time()
if not (args.pretrained_model=='cyto' or args.pretrained_model=='nuclei'):
cpmodel_path = args.pretrained_model
Expand All @@ -166,8 +196,7 @@ def main():
args.pretrained_model = 'cyto'

if args.pretrained_model=='cyto' or args.pretrained_model=='nuclei':
model = models.Cellpose(device=device, model_type=args.pretrained_model,
batch_size=args.batch_size)
model = models.Cellpose(device=device, model_type=args.pretrained_model)

if args.diameter==0:
diameter = None
Expand All @@ -183,9 +212,10 @@ def main():

masks, flows, _, diams = model.eval(images, channels=channels, diameter=diameter,
do_3D=args.do_3D, net_avg=(not args.fast_mode),
augment=(not args.fast_mode),
augment=False,
flow_threshold=args.flow_threshold,
cellprob_threshold=args.cellprob_threshold)
cellprob_threshold=args.cellprob_threshold,
batch_size=args.batch_size)

else:
if args.all_channels:
Expand All @@ -200,7 +230,7 @@ def main():
rescale = model.diam_mean / diameter
masks, flows, _ = model.eval(images, channels=channels, rescale=rescale,
do_3D=args.do_3D,
augment=(not args.fast_mode),
augment=False,
flow_threshold=args.flow_threshold,
cellprob_threshold=args.cellprob_threshold)
diams = diameter * np.ones(len(images))
Expand All @@ -225,10 +255,30 @@ def main():
if args.all_channels:
channels = None

label_names = get_label_files(image_names, args.mask_filter, imf=imf)
# training data
label_names, flow_names = get_label_files(image_names, args.mask_filter, imf=imf)
nimg = len(image_names)
labels = [io.imread(label_names[n]) for n in range(nimg)]
if flow_names is not None and not args.unet:
labels = [np.concatenate((labels[n][np.newaxis,:,:], io.imread(flow_names[n])), axis=0)
for n in range(nimg)]

# testing data
test_images, test_labels, image_names_test = None, None, None
if len(args.test_dir) > 0:
image_names_test = get_image_files(args.test_dir, args.mask_filter, imf=imf)
label_names_test, flow_names_test = get_label_files(image_names_test, args.mask_filter, imf=imf)
nimg = len(image_names_test)
test_images = [io.imread(image_names_test[n]) for n in range(nimg)]
test_labels = [io.imread(label_names_test[n]) for n in range(nimg)]
if flow_names_test is not None and not args.unet:
test_labels = [np.concatenate((test_labels[n][np.newaxis,:,:], io.imread(flow_names_test[n])), axis=0)
for n in range(nimg)]

# model path
if not os.path.exists(cpmodel_path):
if not args.train:
raise ValueError('ERROR: model path missing or incorrect - cannot train size model')
cpmodel_path = False
print('>>>> training from scratch')
if args.diameter==0:
Expand All @@ -240,27 +290,53 @@ def main():
else:
rescale = True
args.diameter = szmean
print('>>>> training starting with pretrained_model %s'%cpmodel_path)
if rescale:
print('>>>> rescaling diameter for each training image to %0.1f'%args.diameter)
print('>>>> pretrained model %s is being used'%cpmodel_path)
args.residual_on = 1
args.style_on = 1
args.concatenation = 0
if rescale and args.train:
print('>>>> during training rescaling images to fixed diameter of %0.1f pixels'%args.diameter)

# initialize model
if args.unet:
model = models.UnetModel(device=device,
pretrained_model=cpmodel_path,
diam_mean=szmean,
residual_on=args.residual_on,
style_on=args.style_on,
concatenation=args.concatenation,
nclasses=args.nclasses)
else:
model = models.CellposeModel(device=device,
pretrained_model=cpmodel_path,
diam_mean=szmean,
residual_on=args.residual_on,
style_on=args.style_on,
concatenation=args.concatenation)

# train segmentation model
if args.train:
cpmodel_path = model.train(images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels, test_files=image_names_test,
learning_rate=args.learning_rate, channels=channels,
save_path=os.path.realpath(args.dir), rescale=rescale, n_epochs=args.n_epochs)
print('>>>> model trained and saved to %s'%cpmodel_path)

test_images, test_labels = None, None
if len(args.test_dir) > 0:
image_names_test = get_image_files(args.test_dir, args.mask_filter, imf=imf)
label_names_test = get_label_files(image_names_test, args.mask_filter, imf=imf)
nimg = len(image_names_test)
test_images = [io.imread(image_names_test[n]) for n in range(nimg)]
test_labels = [io.imread(label_names_test[n]) for n in range(nimg)]
#print('>>>> %s model'%(['cellpose', 'unet'][args.unet]))
model = models.CellposeModel(device=device,
pretrained_model=cpmodel_path,
diam_mean=szmean,
batch_size=args.batch_size)
n_epochs=args.n_epochs
model.train(images, labels, test_images, test_labels, learning_rate=args.learning_rate,
channels=channels, save_path=os.path.realpath(args.dir), rescale=rescale, n_epochs=n_epochs)
# train size model
if args.train_size:
sz_model = models.SizeModel(model, device=device)
sz_model.train(images, labels, test_images, test_labels, channels=channels)
if test_images is not None:
predicted_diams, diams_style = sz_model.eval(test_images, channels=channels)
if test_labels[0].ndim>2:
tlabels = [lbl[0] for lbl in test_labels]
else:
tlabels = test_labels
ccs = np.corrcoef(diams_style, np.array([utils.diameters(lbl)[0] for lbl in tlabels]))[0,1]
cc = np.corrcoef(predicted_diams, np.array([utils.diameters(lbl)[0] for lbl in tlabels]))[0,1]
print('style test correlation: %0.4f; final test correlation: %0.4f'%(ccs,cc))
np.save(os.path.join(args.test_dir, '%s_predicted_diams.npy'%os.path.split(cpmodel_path)[1]),
{'predicted_diams': predicted_diams, 'diams_style': diams_style})

if __name__ == '__main__':
main()

0 comments on commit 0b9d2a7

Please sign in to comment.