Skip to content

Commit

Permalink
adding docs for restore
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Feb 23, 2024
1 parent e453d1b commit d4857f5
Show file tree
Hide file tree
Showing 19 changed files with 1,210 additions and 1,070 deletions.
2 changes: 1 addition & 1 deletion cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def main():
invert=args.invert, batch_size=args.batch_size,
interp=(not args.no_interp), normalize=(not args.no_norm),
channel_axis=args.channel_axis, z_axis=args.z_axis,
anisotropy=args.anisotropy)
anisotropy=args.anisotropy, niter=args.niter)
masks, flows = out[:2]
if len(out) > 3:
diams = out[-1]
Expand Down
3 changes: 3 additions & 0 deletions cellpose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def get_arg_parser():
algorithm_args.add_argument(
"--cellprob_threshold", default=0, type=float,
help="cellprob threshold, default is 0, decrease to find more and larger masks")
algorithm_args.add_argument(
"--niter", default=0, type=int,
help="niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs")

algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
help="anisotropy of volume in 3D")
Expand Down
2 changes: 0 additions & 2 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,10 @@ def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
k = 0
for n in range(nimg):
if os.path.isfile(label_names[n]) or os.path.isfile(flow_names[0]):
print(image_names[n])
image = imread(image_names[n])
if label_names is not None:
label = imread(label_names[n])
if flow_names is not None:
print(flow_names[n])
flow = imread(flow_names[n])
if flow.shape[0] < 4:
label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
Expand Down
80 changes: 32 additions & 48 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class Cellpose():
device (torch device, optional): Device used for model running / training. Overrides gpu input. Recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")). Defaults to None.
Attributes:
torch (bool): Flag indicating if torch is available.
device (torch device): Device used for model running / training.
gpu (bool): Flag indicating if GPU is used.
diam_mean (float): Mean diameter for cytoplasm model.
Expand Down Expand Up @@ -202,7 +201,6 @@ class CellposeModel():
Class representing a Cellpose model.
Attributes:
torch (bool): Whether or not the torch library is available.
diam_mean (float): Mean "diameter" value for the model.
builtin (bool): Whether the model is a built-in model or not.
device (torch device): Device used for model running / training.
Expand Down Expand Up @@ -357,9 +355,10 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
Returns:
masks (list, np.ndarray): labelled image(s), where 0=no masks; 1,2,...=mask labels
flows (list): list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration
styles (list, np.ndarray): style vector summarizing each image of size 256.
A tuple containing:
- masks (list, np.ndarray): labelled image(s), where 0=no masks; 1,2,...=mask labels
- flows (list): list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration
- styles (list, np.ndarray): style vector summarizing each image of size 256.
"""
if isinstance(x, list) or x.squeeze().ndim == 5:
Expand Down Expand Up @@ -501,7 +500,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
if compute_masks:
tic = time.time()
niter0 = 200 if (do_3D and not resample) else (1 / rescale * 200)
niter = niter0 if niter is None else niter
niter = niter0 if niter is None or niter==0 else niter
if do_3D:
masks, p = dynamics.resize_and_compute_masks(
dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
Expand Down Expand Up @@ -552,17 +551,6 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
masks, p = np.zeros(0), np.zeros(0) #pass back zeros if not compute_masks
return masks, styles, dP, cellprob, p

def loss_fn(self, lbl, y):
""" loss function between true labels lbl and prediction y """
veci = 5. * self._to_device(lbl[:, 1:])
lbl = self._to_device(lbl[:, 0] > .5).float()
loss = self.criterion(y[:, :2], veci)
loss /= 2.
loss2 = self.criterion2(y[:, 2], lbl)
loss = loss + loss2
return loss


class SizeModel():
"""
Linear regression model for determining the size of objects in image
Expand Down Expand Up @@ -612,50 +600,46 @@ def __init__(self, cp_model, device=None, pretrained_size=None, **kwargs):
raise ValueError(error_message)

def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False,
augment=False, tile=True, batch_size=8, progress=None, interp=True):
augment=False, tile=True, batch_size=8, progress=None):
"""Use images x to produce style or use style input to predict size of objects in image.
Object size estimation is done in two steps:
1. Use a linear regression model to predict size from style in image.
2. Resize image to predicted size and run CellposeModel to get output masks.
Take the median object size of the predicted masks as the final predicted size.
Parameters:
x: list or array of images
Can be a list of 2D/3D images or an array of 2D/3D images.
channels: list (optional, default None)
List of channels, either of length 2 or of length number of images by 2.
The first element of the list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
The second element of the list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
Args:
x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
For instance, to segment grayscale images, input [0,0]. To segment images with cells
in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
image with cells in green and nuclei in blue, input [[0,0], [2,3]].
Defaults to None.
channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
if None, channels dimension is attempted to be automatically determined. Defaults to None.
normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
can also pass dictionary of parameters (all keys are optional, default values shown):
- "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
- "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
- "normalize"=True ; run normalization (if False, all following parameters ignored)
- "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
- "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
- "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
Defaults to True.
invert (bool, optional): Invert image pixel intensity before running network (if True, image is also normalized). Defaults to False.
augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True.
batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
(can make smaller or bigger depending on GPU memory usage). Defaults to 8.
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
channel_axis: int (optional, default None)
If None, the channels dimension is attempted to be automatically determined.
normalize: bool (default, True)
Normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel.
invert: bool (optional, default False)
Invert image pixel intensity before running the network.
augment: bool (optional, default False)
Tile the image with overlapping tiles and flips overlapped regions to augment.
tile: bool (optional, default True)
Tile the image to ensure GPU/CPU memory usage is limited (recommended).
progress: pyqt progress bar (optional, default None)
Return progress bar status to GUI.
Returns:
diam: array, float
Final estimated diameters from images x or styles style after running both steps.
diam_style: array, float
Estimated diameters from style alone.
A tuple containing:
- diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps.
- diam_style (np.ndarray): Estimated diameters from style alone.
"""

if isinstance(x, list):
Expand Down
2 changes: 2 additions & 0 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
(train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
test_data, test_labels, test_files, test_labels_files, test_probs, diam_test) = out

net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

nimg = len(train_data) if train_data is not None else len(train_files)
nimg_test = len(test_data) if test_data is not None else None
nimg_test = len(test_files) if test_files is not None else nimg_test
Expand Down
6 changes: 3 additions & 3 deletions cellpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,11 @@ def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None):
- This function is a wrapper for outlines_list_single and outlines_list_multi.
- Multiprocessing is disabled for Windows.
"""
# default to use multiprocessing if few_masks, but allow user to override
# default to use multiprocessing if not few_masks, but allow user to override
if multiprocessing is None:
few_masks = np.max(masks) < multiprocessing_threshold
multiprocessing = few_masks

multiprocessing = not few_masks
# disable multiprocessing for Windows
if os.name == "nt":
if multiprocessing:
Expand Down
20 changes: 20 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,32 @@ CellposeModel
.. autoclass:: cellpose.models.CellposeModel
:members:

CellposeDenoiseModel
~~~~~~~~~~~~~~~~~

.. autoclass:: cellpose.denoise.CellposeDenoiseModel
:members:


DenoiseModel
~~~~~~~~~~~~~~~~~

.. autoclass:: cellpose.denoise.DenoiseModel
:members:

SizeModel
~~~~~~~~~~~~~~~~~

.. autoclass:: cellpose.models.SizeModel
:members:

Training
~~~~~~~~~~~~~~~~~~

.. automodule:: cellpose.train
:members:


Metrics
~~~~~~~~~~~~~~~~~~

Expand Down
10 changes: 8 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@ can install it as ``pip install cellpose[gui]``.
You can try it out without installing at `cellpose.org`_.
Also check out these resources:

Cellpose 2.0
Cellpose3: one-click image restoration for improved cellular segmentation

- `paper <https://www.biorxiv.org/content/10.1101/2024.02.10.579780v1>`_ on biorxiv
- `thread <https://neuromatch.social/@computingnature/111932247922392030>`_

Cellpose 2.0: how to train your own model

- `paper <https://www.biorxiv.org/content/10.1101/2022.04.01.486764v1>`_ on biorxiv
- `talk <https://www.youtube.com/watch?v=3ydtAhfq6H0>`_
- twitter `thread <https://twitter.com/marius10p/status/1511415409047650307?s=20&t=umTVIG1CFKIWHYMrQqFKyQ>`_
- human-in-the-loop training protocol `video <https://youtu.be/3Y1VKcxjNy4>`_

Cellpose 1.0
Cellpose: a generalist algorithm for cellular segmentation

- `paper <https://www.biorxiv.org/content/10.1101/2020.02.02.931238v1>`_ on biorxiv (see figure 1 below) and in `nature methods <https://t.co/kBMXmPp3Yn?amp=1>`_
- twitter `thread <https://twitter.com/computingnature/status/1224477812763119617>`_
Expand All @@ -46,6 +51,7 @@ Cellpose 1.0
settings
outputs
models
restore
train
openvino
faq
Expand Down
16 changes: 10 additions & 6 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,19 @@ Dependencies
~~~~~~~~~~~~~~~~~~~~~~

cellpose relies on the following excellent packages (which are
automatically installed with conda/pip if missing):
automatically installed with pip if missing):

- `pytorch`_
- `pyqtgraph`_
- `PyQt5`_
- `PyQt5`_ or pyside or PyQt6
- `numpy`_ (>=1.16.0)
- `numba`_
- `scipy`_
- `scikit-image`_
- `tifffile`_
- `natsort`_
- `matplotlib`_
- `fastremap`_
- `roifile`_
- `superqt`_

.. _Anaconda: https://www.anaconda.com/download/
.. _environment.yml: https://github.com/MouseLand/cellpose/blob/master/environment.yml?raw=true
Expand All @@ -129,6 +131,8 @@ automatically installed with conda/pip if missing):
.. _numpy: http://www.numpy.org/
.. _numba: http://numba.pydata.org/numba-doc/latest/user/5minguide.html
.. _scipy: https://www.scipy.org/
.. _scikit-image: https://scikit-image.org/
.. _tifffile: https://pypi.org/project/tifffile/
.. _natsort: https://natsort.readthedocs.io/en/master/
.. _matplotlib: https://matplotlib.org/
.. _fastremap: https://github.com/seung-lab/fastremap
.. _roifile: https://github.com/cgohlke/roifile
.. _superqt: https://github.com/pyapp-kit/superqt

0 comments on commit d4857f5

Please sign in to comment.