Memory usage increases during subsequent evaluations of cellpose model #539

tcompa opened this issue on Aug 1, 2022 · 2 comments

tcompa commented Aug 1, 2022

Hi there, and thanks for your support.

While working on a different project with @mfranzon and @jluethi, we noticed an unexpected increase in RAM usage across subsequent runs of cellpose segmentation with the nuclei model. I'll report here an example that is as self-contained as possible; other pieces of information are scattered across our original issues.
The question is whether this behavior is expected/normal, or whether we could try to mitigate it. We are also wondering whether it comes from cellpose or from torch.

Context

Our goal is to perform segmentation of 3D images with the cellpose pre-trained nuclei model. We need to segment a certain number of arrays (say 20 of them), each with a shape like (30, 2160, 2560) and dtype uint16. The processing of the different arrays (i.e. the different cellpose calls) takes place sequentially, on a node with 64 GB of memory and access to a GPU. GPU memory stays under control throughout the entire run (around 4 GiB out of 16 are used), while this issue concerns the standard host RAM usage, which we monitor via mprof.

Code and results

As a minimal working example, we load a single array of shape (30, 2160, 2560) and compute the corresponding labels several times in a row. If needed, we can find the best way to share the image folder, or use other data that are already easily available for testing.

The code looks like

import sys
import time

from skimage.io import imread
import numpy as np
from cellpose import core
from cellpose import models

def run_cellpose(img, model):
    t_start = time.perf_counter()
    print(f"START | shape: {img.shape}")
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        img,
        do_3D=True,
        channels=[0, 0],
        net_avg=False,
        augment=False,
        diameter=80.0,
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    t_end = time.perf_counter()
    print(f"END  | num_labels={np.max(mask)}, elapsed_time={t_end-t_start:.3f}")
    sys.stdout.flush()
    return mask


# Read 3D stack of images (42 Z planes available)
num_z = 30
stack = np.empty((num_z, 2160, 2560), dtype=np.uint16)
for z in range(num_z):
    stack[z, :, :] = imread(f"images_v1/20200812-CardiomyocyteDifferentiation14-Cycle1_B05_T0001F002L01A01Z{z+1:02d}C01.png")

# Initialize cellpose
use_gpu = core.use_gpu()
model = models.Cellpose(gpu=use_gpu, model_type="nuclei")
print(f"End of initialization: num_z={num_z}, use_gpu={use_gpu}")

nruns = 10

for run in range(nruns):
    print(run)
    run_cellpose(stack, model)
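
As a side note, resident memory can also be sampled in-process with memory_profiler (the package behind mprof). A minimal sketch, assuming the script above has already defined run_cellpose, stack and model (the sampling interval is arbitrary):

from memory_profiler import memory_usage

# Sample resident memory (in MiB) every 0.5 s while a single eval runs;
# memory_usage returns the list of samples once the call completes.
samples = memory_usage((run_cellpose, (stack, model)), interval=0.5)
print(f"peak: {max(samples):.0f} MiB, final: {samples[-1]:.0f} MiB")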

This code runs to completion, and each segmentation takes approximately 320 seconds (finding around 3k labels). The memory trace during the first few iterations of the loop is shown below: subsequent runs use more and more memory, until the usage saturates after a few iterations. The plateau values in the trace, for instance, are roughly 12, 13.8, 14.1, 14.1, ... GiB. The memory peaks at the end of each cellpose call also shift up by a similar amount, accumulating about 2 GiB over the first 2-3 iterations.
The simplest explanation would be that cellpose or torch is caching something, but we couldn't identify what is being cached. Is this actually happening? If so, is there a way to deactivate this caching mechanism?

[Figure: mprof memory trace over the first few cellpose calls]
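
To make the measurement more explicit, this is roughly how the per-run residual memory could be inspected in-process (a sketch only; psutil is an extra dependency and not part of the script above, and torch.cuda.empty_cache() only affects GPU memory):

import gc
import os

import psutil
import torch

def report_rss(tag):
    # Resident set size of the current process, in GiB.
    rss = psutil.Process(os.getpid()).memory_info().rss / 1024**3
    print(f"{tag}: RSS = {rss:.2f} GiB")

for run in range(nruns):
    run_cellpose(stack, model)
    report_rss(f"after eval {run}")
    gc.collect()                # drop unreachable Python objects
    torch.cuda.empty_cache()    # frees cached GPU blocks, not host RAM
    report_rss(f"after cleanup {run}")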

Expected behavior and why it matters

We would expect subsequent runs on the exact same input to require a very similar amount of memory, unless some caching is in place. The relevance of this issue (for us) is that even if the memory accumulation seems mild (only about 2 GiB more than expected), in more complex/heavy use cases (including additional parallelism) it may lead to memory errors, as we found in fractal-analytics-platform/fractal-client#109 (comment). For this reason we'd really like to keep it under control, possibly by deactivating caching options (if any exist).

Environment

The Python code above is submitted to a SLURM queue and runs on a node with a GPU available.

Relevant details of the Python environment:

sys.version='3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]'
numpy.__version__='1.23.1'
torch.__version__='1.12.0+cu102'
carsen-stringer (Member) commented
I have no idea; have you tried any garbage collection? You could call cellpose as a separate process and then it will clean up after itself (all those options are available on the CLI), but then you have to re-read the saved masks.
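
A rough sketch of the process-isolation idea, using multiprocessing rather than the CLI (untested; run_cellpose_in_subprocess and _worker are placeholder names, and the eval arguments are copied from the script above):

import multiprocessing as mp

from cellpose import core, models

def _worker(img, queue):
    # Build the model inside the worker so that everything it allocates
    # is returned to the OS when the process exits.
    model = models.Cellpose(gpu=core.use_gpu(), model_type="nuclei")
    mask, flows, styles, diams = model.eval(
        img, do_3D=True, channels=[0, 0], net_avg=False, augment=False,
        diameter=80.0, anisotropy=6.0, cellprob_threshold=0.0,
    )
    queue.put(mask)

def run_cellpose_in_subprocess(img):
    ctx = mp.get_context("spawn")   # "spawn" is safer than "fork" with CUDA
    queue = ctx.Queue()
    proc = ctx.Process(target=_worker, args=(img, queue))
    proc.start()
    mask = queue.get()              # fetch before join to avoid a queue deadlock
    proc.join()
    return mask

The obvious costs are that the model weights are reloaded on every call and that the large label array is pickled through the queue; writing the masks to disk instead, as the CLI does, would avoid the latter.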

tcompa (Author) commented Aug 29, 2022

Thanks for your comment.

I confirm that adding gc.collect() here and there (both within the run_cellpose function and, especially, right after each call to it within the loop) does not lead to any relevant change in the memory trace.
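
One more thing that could be tried on top of gc.collect() is asking glibc to hand free heap pages back to the OS. This is a speculative, glibc-specific workaround (not a cellpose or torch option), and we haven't verified that it changes the trace:

import ctypes
import gc

gc.collect()
try:
    # glibc-specific: ask malloc to return free arenas/pages to the OS.
    # Not a cellpose/torch API; a speculative workaround only.
    ctypes.CDLL("libc.so.6").malloc_trim(0)
except OSError:
    pass  # not a glibc system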

At the moment we cannot go down the CLI path, since this labeling task is part of a larger bio-image processing platform (https://github.com/fractal-analytics-platform/fractal), where tasks need to be Python functions.

For now we'll just keep this issue in mind, and apply mitigation strategies (e.g. working at a lower resolution) if/when needed.
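
For completeness, the lower-resolution mitigation we have in mind looks roughly like the sketch below. The XY factor is arbitrary, and the diameter and anisotropy passed to model.eval would also need to be rescaled for the new pixel size (run_cellpose hard-codes them above):

import numpy as np
from scipy.ndimage import zoom

def run_cellpose_downscaled(img, model, xy_factor=0.5):
    # Downscale only in XY; Z is already coarse (hence the anisotropy above).
    small = zoom(img, (1.0, xy_factor, xy_factor), order=1)
    mask_small = run_cellpose(small, model)
    # Map the labels back onto the original grid; order=0 keeps integer labels.
    factors = np.array(img.shape) / np.array(mask_small.shape)
    return zoom(mask_small, factors, order=0)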

carsen-stringer added the "bug" and "help wanted" labels on Jan 12, 2023