Skip to content

Commit

Permalink
implementing fixes for #543, #554, #555, #557 and #558
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Sep 9, 2022
1 parent 91dd7ab commit 02b83a5
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 62 deletions.
40 changes: 27 additions & 13 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# settings re-grouped a bit
def main():
parser = argparse.ArgumentParser(description='cellpose parameters')

# settings for CPU vs GPU
hardware_args = parser.add_argument_group("hardware arguments")
hardware_args.add_argument('--use_gpu', action='store_true', help='use gpu if torch with cuda installed')
Expand All @@ -35,6 +35,8 @@ def main():
input_img_args = parser.add_argument_group("input image arguments")
input_img_args.add_argument('--dir',
default=[], type=str, help='folder containing data to run or train on.')
input_img_args.add_argument('--image_path',
default=[], type=str, help='if given and --dir not given, run on single image instead of folder (cannot train with this option)')
input_img_args.add_argument('--look_one_level_down', action='store_true', help='run processing on all subdirectories of current folder')
input_img_args.add_argument('--img_filter',
default=[], type=str, help='end string for images to run on')
Expand All @@ -52,6 +54,7 @@ def main():
# model settings
model_args = parser.add_argument_group("model arguments")
model_args.add_argument('--pretrained_model', required=False, default='cyto', type=str, help='model to use for running or starting training')
model_args.add_argument('--add_model', required=False, default=None, type=str, help='model path to copy model to hidden .cellpose folder for using in GUI/CLI')
model_args.add_argument('--unet', action='store_true', help='run standard unet instead of cellpose flow output')
model_args.add_argument('--nclasses',default=3, type=int, help='if running unet, choose 2 or 3; cellpose always uses 3')

Expand Down Expand Up @@ -128,14 +131,17 @@ def main():
else:
mkl_enabled = True

if len(args.dir)==0:
if not GUI_ENABLED:
print('GUI ERROR: %s'%GUI_ERROR)
if GUI_IMPORT:
print('GUI FAILED: GUI dependencies may not be installed, to install, run')
print(' pip install cellpose[gui]')
if len(args.dir)==0 and len(args.image_path)==0:
if args.add_model:
io.add_model(args.add_model)
else:
gui.run()
if not GUI_ENABLED:
print('GUI ERROR: %s'%GUI_ERROR)
if GUI_IMPORT:
print('GUI FAILED: GUI dependencies may not be installed, to install, run')
print(' pip install cellpose[gui]')
else:
gui.run()

else:
if args.verbose:
Expand Down Expand Up @@ -183,13 +189,21 @@ def main():
szmean = 30.
builtin_size = model_type == 'cyto' or model_type == 'cyto2' or model_type == 'nuclei'

if len(args.image_path) > 0 and (args.train or args.train_size):
raise ValueError('ERROR: cannot train model with single image input')

if not args.train and not args.train_size:
tic = time.time()

image_names = io.get_image_files(args.dir,
args.mask_filter,
imf=imf,
look_one_level_down=args.look_one_level_down)
if len(args.dir) > 0:
image_names = io.get_image_files(args.dir,
args.mask_filter,
imf=imf,
look_one_level_down=args.look_one_level_down)
else:
if os.path.exists(args.image_path):
image_names = [args.image_path]
else:
raise ValueError(f'ERROR: no file found at {args.image_path}')
nimg = len(image_names)

cstr0 = ['GRAY', 'RED', 'GREEN', 'BLUE']
Expand Down
40 changes: 9 additions & 31 deletions cellpose/gui/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import fastremap

from .. import utils, plot, transforms, models
from ..io import imread, imsave, outlines_to_text
from ..io import imread, imsave, outlines_to_text, add_model, remove_model
from ..transforms import normalize99

try:
Expand All @@ -28,31 +28,17 @@

def _init_model_list(parent):
models.MODEL_DIR.mkdir(parents=True, exist_ok=True)
parent.model_list_path = os.fspath(models.MODEL_DIR.joinpath('gui_models.txt'))
parent.model_strings = []
if not os.path.exists(parent.model_list_path):
textfile = open(parent.model_list_path, 'w')
textfile.close()
else:
with open(parent.model_list_path, 'r') as textfile:
lines = [line.rstrip() for line in textfile]
if len(lines) > 0:
parent.model_strings.extend(lines)
parent.model_list_path = models.MODEL_LIST_PATH
parent.model_strings = models.get_user_models()

def _add_model(parent, filename=None, load_model=True):
if filename is None:
name = QFileDialog.getOpenFileName(
parent, "Add model to GUI"
)
filename = name[0]
add_model(filename)
fname = os.path.split(filename)[-1]
try:
shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
except shutil.SameFileError:
pass
print(f'GUI_INFO: {filename} copied to models folder {os.fspath(models.MODEL_DIR)}')
with open(parent.model_list_path, 'a') as textfile:
textfile.write(fname + '\n')
parent.ModelChoose.addItems([fname])
parent.model_strings.append(fname)
if len(parent.model_strings) > 0:
Expand All @@ -63,7 +49,7 @@ def _add_model(parent, filename=None, load_model=True):
if model_string == fname:
_remove_model(parent, ind=ind+1, verbose=False)

parent.ModelChoose.setCurrentIndex(len(parent.model_strings))
parent.ModelChoose.setCurrentIndex(len(parent.model_strings))
if load_model:
parent.model_choose(len(parent.model_strings))

Expand All @@ -72,28 +58,20 @@ def _remove_model(parent, ind=None, verbose=True):
ind = parent.ModelChoose.currentIndex()
if ind > 0:
ind -= 1
if verbose:
print(f'GUI_INFO: deleting {parent.model_strings[ind]} from GUI')
parent.ModelChoose.removeItem(ind+1)
del parent.model_strings[ind]
custom_strings = parent.model_strings
if len(custom_strings) > 0:
with open(parent.model_list_path, 'w') as textfile:
for fname in custom_strings:
textfile.write(fname + '\n')
# remove model from txt path
modelstr = parent.ModelChoose.currentText()
remove_model(modelstr)
if len(parent.model_strings) > 0:
parent.ModelChoose.setCurrentIndex(len(parent.model_strings))
else:
# write empty file
textfile = open(parent.model_list_path, 'w')
textfile.close()
parent.ModelChoose.setCurrentIndex(0)
parent.ModelButton.setStyleSheet(parent.styleInactive)
parent.ModelButton.setEnabled(False)
else:
print('ERROR: no model selected to delete')



def _get_train_set(image_names):
""" get training data and labels for images in current folder image_names"""
train_data, train_labels, train_files = [], [], []
Expand Down
66 changes: 54 additions & 12 deletions cellpose/io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, datetime, gc, warnings, glob
import os, datetime, gc, warnings, glob, shutil
from natsort import natsorted
import numpy as np
import cv2
Expand All @@ -7,6 +7,7 @@
from tqdm import tqdm
from pathlib import Path


try:
from PyQt5 import QtGui, QtCore, Qt, QtWidgets
from PyQt5.QtWidgets import QMessageBox
Expand Down Expand Up @@ -66,7 +67,9 @@ def outlines_to_text(base, outlines):
f.write('\n')

def imread(filename):
ext = os.path.splitext(filename)[-1]
""" read in image with tif or image file type supported by cv2 """
# ensure that extension check is not case sensitive
ext = os.path.splitext(filename)[-1].lower()
if ext== '.tif' or ext=='.tiff':
with tifffile.TiffFile(filename) as tif:
ltif = len(tif.pages)
Expand Down Expand Up @@ -108,9 +111,38 @@ def imread(filename):
io_logger.critical('ERROR: could not read masks from file, %s'%e)
return None

def remove_model(filename, delete=False):
""" remove model from .cellpose custom model list """
filename = os.path.split(filename)[-1]
from . import models
model_strings = models.get_user_models()
if len(model_strings) > 0:
with open(models.MODEL_LIST_PATH, 'w') as textfile:
for fname in model_strings:
textfile.write(fname + '\n')
else:
# write empty file
textfile = open(models.MODEL_LIST_PATH, 'w')
textfile.close()
print(f'{filename} removed from custom model list')
if delete:
os.remove(os.fspath(models.MODEL_DIR.joinpath(fname)))
print('model deleted')

def add_model(filename):
""" add model to .cellpose models folder to use with GUI or CLI """
from . import models
fname = os.path.split(filename)[-1]
try:
shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
except shutil.SameFileError:
pass
print(f'{filename} copied to models folder {os.fspath(models.MODEL_DIR)}')
with open(models.MODEL_LIST_PATH, 'a') as textfile:
textfile.write(fname + '\n')

def imsave(filename, arr):
ext = os.path.splitext(filename)[-1]
ext = os.path.splitext(filename)[-1].lower()
if ext== '.tif' or ext=='.tiff':
tifffile.imsave(filename, arr)
else:
Expand All @@ -130,13 +162,23 @@ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
if look_one_level_down:
folders = natsorted(glob.glob(os.path.join(folder, "*/")))
folders.append(folder)

exts = ['.png', '.jpg', '.jpeg', '.tif', '.tiff']
l0 = 0
al = 0
for folder in folders:
image_names.extend(glob.glob(folder + '/*%s.png'%imf))
image_names.extend(glob.glob(folder + '/*%s.jpg'%imf))
image_names.extend(glob.glob(folder + '/*%s.jpeg'%imf))
image_names.extend(glob.glob(folder + '/*%s.tif'%imf))
image_names.extend(glob.glob(folder + '/*%s.tiff'%imf))
all_files = glob.glob(folder + '/*')
al += len(all_files)
for ext in exts:
image_names.extend(glob.glob(folder + f'/*{imf}{ext}'))
image_names.extend(glob.glob(folder + f'/*{imf}{ext.upper()}'))
l0 += len(image_names)

# return error if no files found
if al==0:
raise ValueError('ERROR: no files in --dir folder ')
elif l0==0:
raise ValueError("ERROR: no images in --dir folder with extensions '.png', '.jpg', '.jpeg', '.tif', '.tiff'")

image_names = natsorted(image_names)
imn = []
for im in image_names:
Expand All @@ -148,10 +190,10 @@ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
if igood:
imn.append(im)
image_names = imn

if len(image_names)==0:
raise ValueError('ERROR: no images in --dir folder')

if len(image_names)==0:
raise ValueError('ERROR: no images in --dir folder without _masks or _flows ending')

return image_names

def get_label_files(image_names, mask_filter, imf=None):
Expand Down
7 changes: 4 additions & 3 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
MODEL_NAMES = ['cyto','nuclei','tissuenet','livecell', 'cyto2', 'general',
'CP', 'CPx', 'TN1', 'TN2', 'TN3', 'LC1', 'LC2', 'LC3', 'LC4']

MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath('gui_models.txt'))

def model_path(model_type, model_index, use_torch=True):
torch_str = 'torch'
if model_type=='cyto' or model_type=='cyto2' or model_type=='nuclei':
Expand All @@ -42,10 +44,9 @@ def cache_model_path(basename):
return cached_file

def get_user_models():
model_list_path = os.fspath(MODEL_DIR.joinpath('gui_models.txt'))
model_strings = []
if os.path.exists(model_list_path):
with open(model_list_path, 'r') as textfile:
if os.path.exists(MODEL_LIST_PATH):
with open(MODEL_LIST_PATH, 'r') as textfile:
lines = [line.rstrip() for line in textfile]
if len(lines) > 0:
model_strings.extend(lines)
Expand Down
9 changes: 8 additions & 1 deletion cellpose/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import warnings
import cv2
import torch

import logging
transforms_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -240,7 +241,13 @@ def convert_image(x, channels, channel_axis=None, z_axis=None,
do_3D=False, normalize=True, invert=False,
nchan=2):
""" return image with z first, channels last and normalized intensities """


# check if image is a torch array instead of numpy array
# converts torch to numpy
if torch.is_tensor(x):
transforms_logger.warning('torch array used as input, converting to numpy')
x = x.cpu().numpy()

# squeeze image, and if channel_axis or z_axis given, transpose image
if x.ndim > 3:
to_squeeze = np.array([int(isq) for isq,s in enumerate(x.shape) if s==1])
Expand Down
32 changes: 32 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,36 @@ Plot functions
.. automodule:: cellpose.plot
:members:

I/O functions
~~~~~~~~~~~~~~~~~~

.. automodule:: cellpose.io
:members:

Utils functions
~~~~~~~~~~~~~~~~~~

.. automodule:: cellpose.utils
:members:

Network classes
~~~~~~~~~~~~~~~~~~~~~

.. automodule:: cellpose.resnet_torch
:members:


Core functions
~~~~~~~~~~~~~~~~~~~~~

.. automodule:: cellpose.core
:members:

All models functions
~~~~~~~~~~~~~~~~~~~~~

.. automodule:: cellpose.models
:members:



8 changes: 6 additions & 2 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ with the model, or will be used if the diameter is 0

These models can be loaded and used in the notebook with e.g.
``models.CellposeModel(model_type='name_in_gui')`` or with the full path
``models.CellposeModel(pretrained_model='/full/path/to/model')`` .
``models.CellposeModel(pretrained_model='/full/path/to/model')`` . If you trained in the
GUI, you can automatically use the ``model_type`` argument. If you trained in the
command line, you need to first add the model to the cellpose path either in the GUI
in the Models menu, or using the command line:
``python -m cellpose --add_model /full/path/to/model``.

Or they can be used in the command line with ``python -m cellpose --pretrained_model name_in_gui``
Or these models can be used in the command line with ``python -m cellpose --pretrained_model name_in_gui``
or ``python -m cellpose --pretrained_model /full/path/to/model`` .
2 changes: 2 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def test_class_train(data_dir):
cpmodel_path = model.train(images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels, test_files=image_names_test,
channels=[2,1], save_path=train_dir, n_epochs=3)
io.add_model(cpmodel_path)
io.remove_model(cpmodel_path, delete=True)
print('>>>> model trained and saved to %s'%cpmodel_path)

def test_cli_train(data_dir):
Expand Down

0 comments on commit 02b83a5

Please sign in to comment.