
Memory (CPU and GPU) leaks during the 1st epoch #1510

Closed
alexeykarnachev opened this issue Apr 16, 2020 · 20 comments · Fixed by #1528
Labels
bug (Something isn't working), help wanted (Open to be worked on)

Comments

@alexeykarnachev (Contributor) commented Apr 16, 2020

🐛 Bug

Hello.
This memory leak occurs during the first epoch. If the epoch is long (mine was > 10 days), an OOM error eventually occurs. Interestingly, in precision=16 mode it leaks on both the GPU and the CPU; with AMP switched off (precision=32), the leak is only on the CPU.
I also tracked the number of tensors known to the garbage collector. It increases linearly during the first epoch, then drops back to the initial value when the 2nd epoch starts and begins increasing again.
Here are the plots:


Experiment 1: amp_level='O2', precision=16

[image: the number of tensors tracked by the garbage collector]
[image: GPU (the 2nd in my case) memory usage, tracked by pytorch-lightning]
[image: CPU memory usage by the process (bytes)]

Experiment 2: amp_level=None, precision=None

[image: the number of tensors tracked by the garbage collector]
[image: GPU (the 2nd in my case) memory usage, tracked by pytorch-lightning]
[image: CPU memory usage by the process (bytes)]


As you can see, both cases have a CPU leak; the AMP case also has a GPU leak.
It's also clear that the leaky behavior stops when the 2nd epoch starts: on these plots, the 2nd epoch begins at the second "saw tooth" of the "Num-of-tensors" plot.
One more observation: the number of tensors grows by exactly 1001 per step. This is my training_step method:

    def training_step(self, batch, batch_idx):
        losses = self.forward(batch)
        num_of_tensors = get_num_of_tensors()
        log = {'Num-of-tensors': num_of_tensors, 'Cpu-mem-usg': get_cpu_mem()}

        for i, loss in enumerate(losses):
            log[f'loss{i}'] = loss

        print(num_of_tensors)
        return {'loss': losses[0], 'log': log}

Here I return exactly 1001 tensors: one for the loss and 1000 in the log.
In my real experiments I had only 3 tensors, and it took ~2-3 days to hit the OOM, but the example below (see To Reproduce) will crash much faster.
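
For reference, helpers like get_num_of_tensors() and get_cpu_mem() can be implemented roughly as follows (a minimal sketch assuming psutil is installed; the exact implementations live in the gist linked below and may differ):

    import gc

    import psutil
    import torch


    def get_num_of_tensors() -> int:
        # count objects currently tracked by the garbage collector that are tensors
        return sum(1 for obj in gc.get_objects() if torch.is_tensor(obj))


    def get_cpu_mem() -> int:
        # resident set size (RSS) of the current process, in bytes
        return psutil.Process().memory_info().rss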

To Reproduce

Steps to reproduce the behavior:

  1. Run the Code sample (the script takes no arguments, so change any needed values manually in the script).
  2. Open TensorBoard to check the plots.

Code sample

https://gist.github.com/alexeykarnachev/47de06b93a717ab0664eded42ed2826a

Expected behavior

The number of tensors and the GPU and CPU memory usage do not increase during training.

Environment

PyTorch version: 1.4.0
OS: Ubuntu 16.04.6 LTS
Python version: 3.7

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] pytorch-lightning==0.7.3
[pip] torch==1.4.0
[pip] torchvision==0.5.0

Additional context

Sorry for the messy flow of information; I don't know how to structure it more clearly.

@alexeykarnachev added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Apr 16, 2020
@williamFalcon (Contributor)

By "leak" do you mean tensors build up during epoch 1, but after that the memory stays constant? I.e., there is no more "leak" for epochs >= 2?

@alexeykarnachev (Contributor Author) commented Apr 16, 2020

Yes, the memory stays constant after the 1st epoch ends (although the number of tensors begins increasing again).

@BartekRoszak commented Apr 16, 2020

The whole output of a training step is stored, and in your code every training step creates new tensors.
With log[f'loss{i}'] = loss.item() there is no leak.

I think there is a mistake in optimizer_closure() in the training loop, which returns the whole batch output dict. It should be enough to return only callback_metrics instead of the whole batch output.
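
Concretely, the suggested change applied to the training step above would look roughly like this (a sketch, not tested against the gist):

    def training_step(self, batch, batch_idx):
        losses = self.forward(batch)
        log = {'Num-of-tensors': get_num_of_tensors(), 'Cpu-mem-usg': get_cpu_mem()}

        for i, loss in enumerate(losses):
            # .item() turns the scalar tensor into a plain Python float, so the
            # logged dict keeps no tensors (and no autograd graphs) alive
            log[f'loss{i}'] = loss.item()

        return {'loss': losses[0], 'log': log}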

@alexeykarnachev (Contributor Author)

Yes, I agree that with .item() there is no leak, because all tensors "disappear in place" (I did not check it, but I believe so). However, I suppose .item() will slow my code down.
On the other hand, .item() is performed by the Trainer itself anyway (before logging), so maybe it's not a big deal to call .item() beforehand, at least as a hotfix.

@alexeykarnachev (Contributor Author) commented Apr 16, 2020

Oh no, sorry, I just checked: there will still be a leak even if we perform log[f'loss{i}'] = loss.item(), because we still have the 'loss': losses[0] part (the actual loss tensor, which needs to be minimized).
So the leak proceeds at 1 tensor per step. That is very slow, but the OOM will still occur in 6-9 days.

@williamFalcon (Contributor)

Can you submit a PR? I thought we took care of all the metrics.
We should also use detach instead of item, no? So we don't slow the code down.
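
Roughly, the difference between the two (a small self-contained illustration, not Lightning code):

    import torch

    x = torch.randn(8, 8, requires_grad=True)
    loss = (x ** 2).mean()

    detached = loss.detach()  # still a tensor on the same device, but cut off from the autograd graph
    scalar = loss.item()      # a plain Python float; if the tensor lives on a GPU, this forces a host sync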

@BartekRoszak commented Apr 17, 2020

We take care of it in process_output(), but then in optimizer_closure() we return the original output_dict again.
We then pass a list of the original outputs to training_epoch_end().
I think we should not do that, because loss, log and progress_bar are already handled properly by us, so we should pass only the other keys from output_dict to training_epoch_end and let the user manage them.

@alexeykarnachev (Contributor Author)

What about fp32 mode? There is no GPU leak in that case. What could be the reason?

@alexeykarnachev (Contributor Author)

@AratorField, do you mean this?
https://github.com/PyTorchLightning/pytorch-lightning/blob/9b31272cf0f3079a244944096b4a81eec20fe555/pytorch_lightning/trainer/training_loop.py#L427-L428

Here is a list that stores all train step outputs during the epoch.
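
A self-contained illustration of why such a list grows memory: appending the raw step output keeps every per-step tensor, and its autograd graph, alive until the epoch ends (toy example, not Lightning code):

    import torch

    outputs = []
    for step in range(1000):
        x = torch.randn(64, 64, requires_grad=True)
        loss = (x ** 2).mean()
        # storing the live loss keeps its whole graph (including x) alive for the epoch
        outputs.append({'loss': loss})
        # storing a detached copy (or loss.item()) would let the graph be freed each step:
        # outputs.append({'loss': loss.detach()})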

@williamFalcon (Contributor)

@alexeykarnachev (Contributor Author) commented Apr 17, 2020

Yes, but the tensors are still on the GPU after detach. So with long epochs or large training-step outputs, GPU memory will blow up after some time.

@BartekRoszak

We could create something like _recursive_item(), or remove the keys loss, log, and progress_bar from batch_output before appending it to outputs.
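
A rough sketch of what such a recursive helper could look like (illustrative only, not the actual Lightning implementation):

    import torch


    def recursive_detach(value):
        # walk dicts, lists and tuples, detaching any tensor found and leaving other values as-is
        if isinstance(value, torch.Tensor):
            return value.detach()
        if isinstance(value, dict):
            return {k: recursive_detach(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(recursive_detach(v) for v in value)
        return value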

@alexeykarnachev (Contributor Author)

Is it good practice in general to store values during the epoch? The size of such a bookkeeping list is unbounded in the general case. One could have an almost infinite epoch and sooner or later hit an OOM (GPU or CPU, it doesn't matter).

@williamFalcon (Contributor)

The thing is that .item() slows things down, so we want to detach, but not .item().

The tradeoff is that we plug the memory leak but slow things down.

@BartekRoszak

There is no reason to store loss, log and progress_bar for the whole epoch.
Any other key in output_dict could be valuable and has to be stored, e.g. for metric calculation.

@alexeykarnachev (Contributor Author)

Maybe it's possible to introduce a flag that controls whether we store tensors in this list during an epoch.
Or maybe you can advise some hotfix that I can apply locally, because right now I can't train even 1 epoch :)

@alexeykarnachev (Contributor Author)

I don't even have a training_epoch_end method. Maybe we could check whether this method is defined by the user and, if it isn't, skip the batch-results bookkeeping?
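
A generic way to check for an override (a sketch of the idea, not necessarily how Lightning implements it):

    import pytorch_lightning as pl


    def is_overridden(model: pl.LightningModule, name: str) -> bool:
        # the hook counts as overridden only if the subclass attribute exists
        # and differs from the base-class implementation
        sub_attr = getattr(type(model), name, None)
        base_attr = getattr(pl.LightningModule, name, None)
        return sub_attr is not None and sub_attr is not base_attr

    # e.g. the trainer could skip accumulating per-step outputs when
    # is_overridden(model, 'training_epoch_end') is False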

@BartekRoszak

@alexeykarnachev (Contributor Author)

Thank you, I'll patch it locally for now.

@johngrabner

I am using pytorch-lightning 1.5.10 and have deleted all training logs (i.e. commented them out in the code); the only remaining logs are for validation, and I set limit_val_batches to a small value of 3000. During training with limit_train_batches=15000 (about 1 hr), I can barely fit in 256 GB of CPU memory.
With limit_train_batches=30000 (about 2 hrs), the machine freezes because it runs out of CPU memory; there is zero chance of completing a full epoch.

Is there a patch for 1.5.10?
