
Cuda OOM / GPU memory leak due to transforms on the GPU (DeepEdit) #6626

Closed
matt3o opened this issue Jun 19, 2023 · 5 comments

@matt3o
Contributor

matt3o commented Jun 19, 2023

Describe the bug
I have converted the DeepEdit transforms (https://github.com/Project-MONAI/MONAI/blob/dev/monai/apps/deepedit/transforms.py) to run on the GPU instead of the CPU. I then spent over a month debugging the code, since moving the transforms to the GPU made it impossible to run without OOM errors. I will paste some funny images below.
This means I could no longer run the code on an 11 GB GPU (with a smaller crop size) or a 24 GB GPU. Having gotten access to a big cluster, I tried 50 GB and even 80 GB GPUs and the code still crashed.
Most confusing of all, the crashes were apparently random, always at different epochs even when the same code was run twice. The memory usage appeared to follow no pattern.
After debugging my own code for weeks I realized, with the help of the garbage collector, that some references are never cleared and the GC object count always increases. This insight helped me to find issue #3423, which describes the problem pretty well.

The problematic, non-deterministic behavior is linked to the garbage collector, which only cleans up orphaned references promptly when they take up a lot of ordinary memory. That was the case for the previous transforms, since they ran in RAM, where the orphaned memory areas are rather big and get collected very soon.
It is not true, however, for GPU pointers in torch, which are only cleared at seemingly random times and apparently not often enough for the code to work. This also explains why calling torch.cuda.empty_cache() brought no relief: the references to the memory still existed even though they were out of scope, so torch could not know that it was allowed to release the GPU memory.
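To make the mechanism concrete, here is a minimal toy illustration (my own construction, not DeepEdit code): a small Python object pins a large CUDA tensor inside a reference cycle, survives torch.cuda.empty_cache(), and is only released once gc.collect() runs.

```python
import gc
import torch

class Holder:
    """Tiny Python object that pins a large CUDA tensor inside a reference cycle."""
    def __init__(self):
        self.data = torch.zeros(256, 256, 256, device="cuda")  # ~64 MiB of GPU memory
        self.me = self  # reference cycle: only the cyclic garbage collector can reclaim this

for _ in range(4):
    Holder()  # each object goes out of scope immediately but stays alive in its cycle

torch.cuda.empty_cache()  # no effect: the tensors are still referenced by the cycles
print("allocated before gc.collect():", torch.cuda.memory_allocated() // 2**20, "MiB")

gc.collect()              # break the cycles and drop the tensor references
torch.cuda.empty_cache()  # now the cached blocks can actually be returned to the driver
print("allocated after gc.collect(): ", torch.cuda.memory_allocated() // 2**20, "MiB")
```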

The fix for this random behavior is to add a GarbageCollector(trigger_event="iteration") handler to the training and validation handlers.
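For reference, a minimal runnable sketch of that wiring (a dummy dataset and network stand in for the actual DeepEdit setup):

```python
import torch
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedTrainer
from monai.handlers import GarbageCollector
from monai.networks.nets import BasicUNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dummy data just to make the sketch self-contained
data = [
    {"image": torch.rand(1, 32, 32, 32), "label": (torch.rand(1, 32, 32, 32) > 0.5).float()}
    for _ in range(4)
]
net = BasicUNet(spatial_dims=3, in_channels=1, out_channels=1).to(device)

trainer = SupervisedTrainer(
    device=device,
    max_epochs=2,
    train_data_loader=DataLoader(Dataset(data), batch_size=2),
    network=net,
    optimizer=torch.optim.Adam(net.parameters(), 1e-3),
    loss_function=torch.nn.BCEWithLogitsLoss(),
    train_handlers=[GarbageCollector(trigger_event="iteration")],  # force gc.collect() after every iteration
)
trainer.run()
# the same handler goes into val_handlers of the SupervisedEvaluator
```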

I did not find any MONAI docs that mention this behaviour, specifically when it comes to debugging OOM or cuDNN errors. However, since the GarbageCollector handler already exists, I guess other people must have run into this issue as well, which makes it even more frustrating to me.

--> Conclusion: I am not sure if there is an easy solution to this problem. Seeing that other people have run into this issue, and since these are hard, non-deterministic bugs, it is very important to fix imo. What I do not know is how complex a fix would be; maybe someone here knows more. I also don't know whether this behavior sometimes occurs with plain PyTorch code. However, if it is MONAI specific, it is framework breaking.

As a temporary fix I can add: the overhead of calling the GarbageCollector appears to be negligible in my case. Maybe it should be a default handler for SupervisedTrainer and SupervisedEvaluator, only to be turned off with a performance flag if needed.

To Reproduce
Run the DeepEdit code and follow the speedup guide, more specifically move the transforms to the GPU.
In my experience, adding ToTensord(keys=("image", "label"), device=device, track_meta=False) at the end of the transform chain is already enough to run out of GPU memory, or at least to increase the usage dramatically and, most importantly, non-deterministically (a rough sketch of such a chain is shown below).
I did, however, rework all of the transforms and moved all of them, including FindDiscrepancyRegionsDeepEditd, AddRandomGuidanceDeepEditd and AddGuidanceSignalDeepEditd, to the GPU. (Also see #1332 about that.)
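A rough sketch of the kind of transform chain meant above (illustrative parameter values, not the exact DeepEdit pipeline):

```python
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    ToTensord,
)

device = "cuda"
train_transforms = Compose([
    LoadImaged(keys=("image", "label")),
    EnsureChannelFirstd(keys=("image", "label")),
    ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    RandCropByPosNegLabeld(keys=("image", "label"), label_key="label", spatial_size=(128, 128, 128), pos=1, neg=1),
    # moving everything to the GPU at the end of the chain is where the trouble starts
    ToTensord(keys=("image", "label"), device=device, track_meta=False),
])
```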

Expected behavior
No memory leak.

Screenshots
Important info before the images: training and validation volumes were cropped to a fixed size. So in theory the GPU memory usage should remain constant over the epochs, but differ between training and validation. The spikes seen in the later images are due to the validation, which only ran every 10 epochs. The important thing here is that these spikes do not increase over time.

x axis: iterations, y axis: amount of GPU memory used as returned by nvmlDeviceGetMemoryInfo()
one epoch is about 400 samples for training and 100 for validation
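(For completeness, this kind of per-iteration logging can be reproduced with a few lines of pynvml; the helper below is an illustrative sketch, not the exact script used for the figures.)

```python
import pynvml

pynvml.nvmlInit()
_handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0

def log_gpu_memory(step, log_file="gpu_usage.csv"):
    """Append the currently used GPU memory (in MiB) for one iteration/step."""
    used_mib = pynvml.nvmlDeviceGetMemoryInfo(_handle).used // 2**20
    with open(log_file, "a") as f:
        f.write(f"{step},{used_mib}\n")
```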

Initial runs of the code
15_gpu_usage
16_gpu_usage

After a few weeks I got it to a point where it ran much more consistently. Interestingly, some operations introduce more non-determinism in the GPU memory usage than others; I developed a feeling for which ones those were and removed or replaced them with different operations. The result is the image below. However, clearly there is still something fishy going on.
59_gpu_usage

For comparison, this is how it looks after adding the GarbageCollector (iteration-level cleanup).
81_gpu_usage

And using the GarbageCollector but with epoch-level cleanup (this only works on the 80 GB GPU and crashes on the 50 GB one; as we can see in the image above, the "actually needed memory" is 33 GB for this setting - with garbage collection only once per epoch we need at least 71 GB and it might still crash on a bad GC day). What can clearly be seen is a lot more jitter.
66_gpu_usage

Environment

This bug exists independent of the environment. I started with the officially recommended one, tried out different CUDA versions, and in the end upgraded to the most recent torch version to see whether it might be fixed there. I will paste the output from the last environment even though I know it is not supported; you can verify this bug on the default MONAI pip installation as well, however.

================================
Printing MONAI config...
================================
MONAI version: 1.1.0
Numpy version: 1.24.3
Pytorch version: 2.0.0+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /homes/mhadlich/.conda/envs/monai/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.12
Nibabel version: 5.1.0
scikit-image version: 0.20.0
Pillow version: 9.5.0
Tensorboard version: 2.13.0
gdown version: 4.7.1
TorchVision version: 0.15.1+cu117
tqdm version: 4.65.0
lmdb version: 1.4.1
psutil version: 5.9.5
pandas version: 2.0.1
einops version: 0.6.1
transformers version: 4.21.3
mlflow version: 2.3.1
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 22.04.2 LTS
Platform: Linux-5.15.0-73-generic-x86_64-with-glibc2.35
Processor: x86_64
Machine: x86_64
Python version: 3.10.10
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='/projects/mhadlich_segmentation/sliding-window-based-interactive-segmentation-of-volumetric-medical-images_main/tmp.txt', fd=1, position=1040, mode='w', flags=32769)]
Num physical CPUs: 48
Num logical CPUs: 48
Num usable CPUs: 1
CPU usage (%): [100.0, 100.0, 59.9, 1.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.4, 0.2, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 2.7, 0.0, 0.0, 0.0, 0.0]
CPU freq. (MHz): 1724
Load avg. in last 1, 5, 15 mins (%): [5.1, 5.0, 5.1]
Disk usage (%): 66.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.8
Available memory (GB): 980.8
Used memory (GB): 20.0

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 11.7
cuDNN enabled: True
cuDNN version: 8500
Current device: 0
Library compiled for CUDA architectures: ['sm_37', 'sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86']
GPU 0 Name: NVIDIA RTX A6000
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 47.5
GPU 0 CUDA capability (maj.min): 8.6

Additional context
I will publish the code with my master's thesis at the end of September, so if it should be necessary, I might be able to share it beforehand.

@KumoLiu
Contributor

KumoLiu commented Jun 21, 2023

Hi @matt3o, thanks for your detailed investigation here.
I'm not sure what the purpose of adding the GarbageCollector by default would be; I guess one effect may be that users can choose whether to clear memory themselves.
Also, I think keeping memory cached can always be a way to speed things up, so maybe we should not make clearing the cache the default behavior, IMO.

Hi @ericspod, could you please also share some comments here?
Thanks in advance!

@ericspod
Member

Hi @matt3o, thanks for the work, but yes, this is a known issue that you usually don't see with most uses of MONAI, which are relatively straightforward training scripts. I'm pretty sure the issue is what you're describing: Python objects retain a hold on main and GPU memory but are scheduled for deletion by the garbage collector based on main memory usage only, so objects taking up little main memory but large GPU chunks need to be explicitly cleaned up.

I don't see an easy fix for this. Ideally the garbage collector would be modified to take multiple kinds of memory usage into account, but that is way beyond the scope of anything we can do and of what Python is intended for. I don't think this is a MONAI-specific problem or even related to PyTorch; it is a property of how garbage-collected languages work that we have to work around. I have had to use the GarbageCollector class to force collection in similar circumstances; I went through the transforms and other parts of the code I was using and concluded it was simply a lack of collection.

We can ameliorate the problem somewhat by being more careful about tensor handling: don't create large numbers of temporary tensors during calculations, reuse tensors and use in-place ops when possible, don't repeatedly recreate Python objects which could retain references to tensors and cause more things to not get cleaned up, and otherwise keep in mind that the collector isn't magical and could use some help.
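As a toy illustration of what I mean (not code from MONAI itself), compare a calculation written with a pile of temporaries against one that reuses a buffer with in-place ops:

```python
import torch

x = torch.rand(1, 1, 128, 128, 128, device="cuda")

# wasteful: every line allocates a fresh CUDA tensor that immediately becomes garbage
def normalise_wasteful(t):
    t = t - t.min()
    t = t / (t.max() + 1e-8)
    t = t * 2.0
    t = t - 1.0
    return t

# leaner: reuse one output buffer and chain in-place ops on it
def normalise_inplace(t, out=None):
    out = t.clone() if out is None else out.copy_(t)
    out.sub_(out.min()).div_(out.max() + 1e-8).mul_(2.0).sub_(1.0)
    return out
```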

Is your solution working completely for you, then? If not, we need to look into the transforms and other code involved to see where references are being retained. Either way, we should have a notebook in the tutorials repo on optimisation and the pitfalls of this sort.

CC @wyli @Nic-Ma

@matt3o
Contributor Author

matt3o commented Jun 21, 2023

Hey @KumoLiu and @ericspod! Thanks for your quick and extensive responses.
First of all: the solution with the GarbageCollector appears to be working perfectly, so all good in that regard. As I said, I have no idea if or how other frameworks deal with this, as I don't have much experience with other deep learning tools on big 3D volumes.
What you describe, @ericspod, fits perfectly with what I have seen. I think some operations have more potential to create these memory leaks by producing GPU pointers that don't get cleared later on. As I wrote, I figured out by guessing which ones these were. One example is EnsureTyped, which in my case created a lot of randomness compared to ToTensord, which appeared to behave much more reliably in terms of GPU memory consumption. (This may however be complete nonsense; that observation was made during the debugging time.)
Also, this is modified DeepEdit code, so the click transforms run 10 times on each sample, which is probably the corner case triggering this error. Plus, as I wrote, I converted the transforms from running in RAM (which the GC clears well) to the GPU, which, as we have seen, the GC does not handle very well. Initially I tried a lot of torch.cuda.empty_cache() calls, which did not clear much GPU memory, and it did not make any sense to me why the memory usage kept increasing.

What would be good, at least, is some documentation on how to handle OOMs and a hint to try out the GarbageCollector to see if it fixes the issue. So I really like your proposal of a notebook on those pitfalls. As of right now, I did not find much information when googling the issues I ran into, which was really frustrating.
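For anyone landing here from a search, one generic check that makes this kind of leak visible (an illustrative sketch, not a MONAI utility) is to ask the garbage collector which CUDA tensors are still alive after each iteration; a steadily growing count points at retained references rather than fragmentation:

```python
import gc
import torch

def count_live_cuda_tensors():
    """Return (count, total MiB) of CUDA tensors the garbage collector still tracks."""
    count, total_mib = 0, 0.0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                count += 1
                total_mib += obj.element_size() * obj.nelement() / 2**20
        except Exception:
            pass  # some tracked objects raise on attribute access
    return count, total_mib

# e.g. call once per iteration:
# n, mib = count_live_cuda_tensors()
# print(f"live CUDA tensors: {n}, ~{mib:.0f} MiB")
```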

@ericspod
Member

Hi all, I've added an issue on the tutorials repo about adding a tutorial on memory management. Let's consider adding this to the features we'd like to target for 1.3.

@matt3o
Contributor Author

matt3o commented Jul 21, 2023

I'm pretty sure by now that the problem is linked to this bug report from PyTorch: pytorch/pytorch#50185
I ran the code sample from that report and it still crashes; it looks pretty similar to what I am dealing with here. Also, adding a gc.collect() clears enough memory for the code to no longer trigger an OOM.
If more people complain about this, one could consider looking into the fix proposed there, which modifies the CUDA caching allocator.

In another report the PyTorch developers describe one other solution: explicitly deleting everything after usage, pytorch/pytorch#20199 (comment).
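Sketched out, that pattern looks roughly like this at the end of an iteration (variable names are placeholders, not the actual DeepEdit loop):

```python
import gc
import torch

def training_iteration(batch, net, opt, loss_fn, device):
    image = batch["image"].to(device)
    label = batch["label"].to(device)
    opt.zero_grad()
    pred = net(image)
    loss = loss_fn(pred, label)
    loss.backward()
    opt.step()
    loss_value = loss.item()

    # explicitly drop every reference to large CUDA tensors before the next iteration
    del image, label, pred, loss
    gc.collect()                # break any cycles still pointing at GPU memory
    torch.cuda.empty_cache()    # return the now-unused cached blocks to the driver
    return loss_value
```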

Maybe that helps someone in the future.

@wyli wyli closed this as completed Jul 21, 2023