
Memory usage #278

Closed
yaysummeriscoming opened this issue Nov 8, 2018 · 33 comments
Labels
bug Something isn't working

Comments


yaysummeriscoming commented Nov 8, 2018

I’ve been able to get some great speeds out of DALI with PyTorch - far beyond what torchvision can do. The problem is that I seem to be getting a memory leak that causes training to slow down and eventually crash with an OOM error. Below I’ve plotted average images/sec over the entire epoch, along with memory usage. This is training ResNet-18 on a 1000-class, 100k-example subset of ImageNet, using DALI CPU mode together with Apex on the example script here: https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/main.py

[screenshots: average images/sec per epoch and memory usage over training]

The problem is most prominent using dali-cpu; dali-gpu takes 90 epochs to reach the same memory usage. Note that as I'm only using 100k training examples, this corresponds to roughly 8 epochs on the full ImageNet dataset.

Additionally, the GPU version seems to use a massive amount of GPU memory; I’ve found I always need to halve the batch size compared to dali-cpu or torchvision. I see that this was touched on in issues #21 and #51, but is this much memory usage normal? For ImageNet, I calculate that one 256-example float32 data batch should take 256 * 3 * 224 * 224 * 4 ~= 154 MB of memory, so it seems to me that the memory usage shouldn’t be that high?
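For reference, here is that back-of-the-envelope calculation spelled out as a quick Python sketch (just the decoded float32 batch itself, ignoring nvJPEG scratch space and intermediate buffers; the sizes are the ones quoted above):

    # Size of one decoded float32 batch - a rough estimate, not DALI's actual allocation logic.
    batch_size, channels, height, width = 256, 3, 224, 224
    bytes_per_float32 = 4
    batch_bytes = batch_size * channels * height * width * bytes_per_float32
    print('%.0f MB' % (batch_bytes / 1e6))  # ~154 MB (decimal), ~147 MiB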

I’m using Ubuntu 16.04, CUDA 9.2 and cuDNN 7.3.0, together with a pre-1.0 version of PyTorch built from the dev head. I’m running on Google Cloud on a machine with 12 vCPUs (6 physical cores), 32 GB of RAM and a V100.

Edit: I'm using DALI 0.4

@JanuszL JanuszL added bug Something isn't working question Further information is requested labels Nov 9, 2018

JanuszL commented Nov 9, 2018

Hi,
It takes some time to reach the final memory consumption level in DALI, but that should happen after a dozen or so epochs. If you see memory consumption growing constantly, it means something is not working right.
And yes, DALI needs more memory for the GPU pipeline than for the CPU one. One reason is nvJPEG, which needs a lot of memory to decompress JPEG images - the more threads you use (--workers), the more instances of nvJPEG are spawned, and this adds up. Also, all the intermediate buffers that are allocated to pass data between operators take some memory.
So the total memory that DALI consumes is roughly:

  • memory for nvJPEG - for raw ImageNet images it could be as much as ~200MB
  • memory for each operator - batch_size * size_of_the_image_at_the_operators_output * 3 (channel count)

So it could take some memory.
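A quick illustrative sketch of how those two items add up (the batch size and intermediate image sizes below are assumptions chosen just to show the arithmetic, not DALI's real accounting):

    batch_size = 256
    operator_output_hw = [(256, 256), (224, 224)]  # assumed outputs: decode/resize, then crop
    nvjpeg_scratch_mb = 200                        # rough nvJPEG figure quoted above
    per_op_mb = [batch_size * h * w * 3 / 2**20 for (h, w) in operator_output_hw]
    print('~%.0f MB of buffers' % (sum(per_op_mb) + nvjpeg_scratch_mb))  # ~285 MB with these assumptions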
Nevertheless, we will look into the memory growth you are observing.
Tracked as DALI-354.
Br,
Janusz

@yaysummeriscoming
Author

Regarding RAM usage, I did a longer test of 70 epochs with dali-gpu mode:

[screenshots: average images/sec per epoch and memory usage over 70 epochs with dali-gpu]

Using dali-cpu the effect is much stronger.

Thanks for clearing up the GPU memory usage. I had thought that workers was just a CPU option, so I had set it to 12. By reducing it to 8 I was able to keep the GPU batch size the same as with dali-cpu and torchvision. Would it make sense to have separate worker options for CPU and GPU? It seems to me that the data augmentation operators are quite lightweight (for a GPU), so they shouldn't require as many threads?


JanuszL commented Nov 9, 2018

Hi,
Regarding workers - they are used for CPU and mixed operators (partially done on the CPU, partially on the GPU, as nvJPEG is). For the GPU itself there are no workers. I think setting workers to 3 or 4 per GPU should do, but you may test what works best for you.


JanuszL commented Nov 16, 2018

@yaysummeriscoming - could you provide the exact command line you used to run this PyTorch example?
What dataset did you use - raw ImageNet?

@yaysummeriscoming
Author

/mnt/disks/ssd/ImageNet100 --print-freq 100

I'm using the first 100 classes of ImageNet - I've modified the ResNet model to suit.


JanuszL commented Nov 19, 2018

I started fixing issues reported by Valgrind - you can check #308.
@yaysummeriscoming - could you also share the exact script you are using? Running the same thing you do will help me narrow down and trace the memory leak you are experiencing.

@yaysummeriscoming
Author

Great, I'm running the script here:
https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/main.py

The only change I've made is on line 197, to change the model to have 100 classes:
model = models.__dict__[args.arch](num_classes=100)

I just retested it; the leak is present with dali-gpu but is worse using dali-cpu mode:
/mnt/disks/ssd/ImageNet100 --dali_cpu --print-freq 100 --workers 12

I'm running the first 100 ImageNet classes off a Google Cloud VM together with a local SSD.


JanuszL commented Nov 21, 2018

Hi,
I checked DALI with Valgrind and fixed what I found (nothing that significant). With the CPU pipeline, your memory consumption will definitely be higher, as all buffers are kept on the CPU side. It is expected to grow because of how we handle memory now:

  • a GPU batch is a list of tensors allocated ahead of time, before each operator is executed; its size can be as large as the sum of the sizes of the biggest images in the batch
  • for the CPU we have a vector of tensors, and each tensor can be as large as the biggest image in the dataset
  • once a tensor list or tensor is allocated it is not freed, it can only get bigger - so at the end of training we can end up with a vector of tensors at the biggest possible size for the given dataset (see the sketch below)

I think we cannot do much about it now; we have a bigger architecture rework in mind. We were hoping to resolve this with the memory refactor in #120, but there are other problems with it and we won't merge it.
Can you try training with a smaller batch and see if memory consumption saturates?
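To make the third point above concrete, here is a toy sketch of a grow-only buffer (purely illustrative, not DALI's actual implementation):

    import numpy as np

    class GrowOnlyBuffer:
        """Backing storage is only ever enlarged, never shrunk or freed."""
        def __init__(self):
            self._storage = np.empty(0, dtype=np.uint8)

        def view_for(self, nbytes):
            # Reallocate only when the request exceeds the current capacity.
            if nbytes > self._storage.size:
                self._storage = np.empty(nbytes, dtype=np.uint8)
            return self._storage[:nbytes]

    buf = GrowOnlyBuffer()
    for sample_bytes in (2_000_000, 500_000, 15_000_000):  # decoded sample sizes vary
        _ = buf.view_for(sample_bytes)
    print(buf._storage.size)  # 15000000 - capacity tracks the largest sample seen and stays there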

@yaysummeriscoming
Author

Today I rebuilt PyTorch from the dev head, together with CUDA 10 & cuDNN 7.4.1. I retested using a smaller batch size; unfortunately the problem is still present. I'm using DALI 0.4.1, as #308 doesn't seem to have been merged in yet?


JanuszL commented Nov 21, 2018

@yaysummeriscoming - no, #308 is still under review, but the leaks it fixes are not that significant. So how many epochs are you able to train, and what is the memory usage after reducing the batch size?


yaysummeriscoming commented Nov 21, 2018

Ok, here's a plot of memory usage:

[screenshot: memory usage over training with batch size 128]

It looks like it's exactly following the memory usage trajectory I plotted earlier.

Edit: batch size 128 = 1000 batches/epoch

@yaysummeriscoming
Author

@JanuszL any progress on this one? This sounds related to #328


JanuszL commented Nov 30, 2018

@yaysummeriscoming - regarding #328, we are able to test with a batch of 128 per GPU with 16 GB of memory.
Regarding total memory utilization, we don't have any good intermediate solution. We have just started redesigning some aspects of the DALI architecture and we hope to address this issue there too.

@yaysummeriscoming
Author

Ok thanks, I'm looking forward to it. I'm very eager to deploy DALI into my training setup given how much faster it is!


JanuszL commented Dec 14, 2018

Tracked as DALI-452

@yaysummeriscoming
Author

So I retested this issue with DALI 0.6, CUDA 10 & cuDNN 7.4.1, but unfortunately still no luck. I have, however, been able to develop a workaround by recreating all DALI objects and re-importing DALI at the end of every epoch.

I inserted the following lines of code at line 282 of https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/resnet50/main.py (just after the DALI iterators are reset):

    def print_ram_usage():
        import os
        import psutil
        pid = os.getpid()
        py = psutil.Process(pid)
        memoryUse = py.memory_info().rss / 2. ** 30  # resident set size (RSS) of this process, in GiB
        print('memory use: %.2f GB' % memoryUse)

    print_ram_usage()

    # To work around the memory leak: delete all DALI objects, re-import DALI & then recreate all DALI objects
    del train_loader, val_loader, pipe
    import importlib
    from nvidia import dali
    importlib.reload(dali)
    pipe = HybridTrainPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank,
                           data_dir=traindir, crop=crop_size, dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank,
                         data_dir=valdir, crop=crop_size, size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

@yaysummeriscoming
Author

Together with the above workaround, I notice that sometimes at the beginning of an epoch the first training batch is all zeros. I could get around this with the following code at line 309, where the first training batch is generated:

if input.sum() == 0.:
    print('Null input encountered!!!  Loading next input')
    input, label = train_loader.next()


JanuszL commented Dec 20, 2018

@yaysummeriscoming - does it happen only when you are using this workaround, or does it happen without it as well?


JanuszL commented Dec 21, 2018

I retested with your code:

[plot: host memory usage over training iterations]

The discontinuity in the plot is the validation phase (I have put prints only inside the training part of the script).
After ~3 epochs (3000 iterations), memory consumption saturates at 20 GB. The command I used to run it:

python -m torch.distributed.launch --nproc_per_node=1 main.py -a resnet50 --dali_cpu --fp16 --b 128 --static-loss-scale 128.0 --workers 4 --lr=0.4

Another interesting observation is that an ImageNet file can be as large as 15 MB, which in the end will resize each (or most) of the prefetch buffer elements to that size. Taking into account that the reader has 1024 prefetch elements, DALI could consume 15 GB for file reading alone.
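Spelling that estimate out (the prefetch depth and maximum file size are the figures quoted above):

    prefetch_elements = 1024  # reader prefetch depth quoted above
    max_file_mb = 15          # largest ImageNet JPEG quoted above
    print(prefetch_elements * max_file_mb / 1024, 'GB just for the reader file buffers')  # 15.0 GB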
So it is not a leak, but we definitely want to improve memory consumption on the CPU.
We have a few ideas how to do that, but we need to finish the basics of the architecture rework before we can think about implementing them.

@yaysummeriscoming
Author

@JanuszL yes, the null input problem only happened with my workaround. I was also getting reduced accuracy with the workaround, so I've removed it again - I imagine this might have had something to do with shuffling.

I retested again today, using the DALI example script with a smaller batch size & DALI CPU mode (I found that the problem seemed to be worse in this mode). Unfortunately I'm not seeing an upper bound on memory usage:

[screenshot: memory usage over training with the smaller batch size and DALI CPU mode]

The options I used were:
-a resnet18 --fp16 --print-freq 100 -b 128 --workers 12 --dali_cpu

I'm seeing some weird behaviour regarding processing speed: it seems to drop well before RAM is fully utilized. On a machine with 32 GB of RAM, processing speed drops after about 12k iterations; 40 GB of RAM gets me to 70k iterations. I'm no expert on RAM usage, but I see a lot of cached memory. Could my dataset (~15 GB) be getting cached in RAM?

@blueardour

Same issue here with the package installed via pip.


JanuszL commented Jun 24, 2019

Currently, we are reworking how CPU memory is used. One of the enablers is moving from per-sample to per-batch processing on the CPU - #936. In the future, this will allow better host memory utilization, as we will be able to allocate memory for the whole batch at a time rather than per sample (we only enlarge buffers and never free them, to avoid expensive reallocations).

@JanuszL JanuszL removed the question Further information is requested label Jan 21, 2020

mzient commented Feb 4, 2020

We're pleased to say that we've changed the allocation strategy for (non-pinned) CPU buffers. It reduces memory consumption in RN50 training in PyTorch by almost 50%. Please check the latest master (or the next successful nightly) and see if your issue is resolved.
The memory is now freed when a requested tensor is smaller than a given fraction of the actual allocation. You can tweak it by setting the environment variable DALI_HOST_BUFFER_SHRINK_THRESHOLD=0.xx; the default value is 0.9. You can also set it in Python using nvidia.dali.backend.SetHostBufferShrinkThreshold(threshold).
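A minimal usage sketch of that knob (the environment variable and the Python setter are the ones named above; the 0.5 threshold is just an example value):

    # Option 1: set the environment variable before starting the process, e.g.
    #   DALI_HOST_BUFFER_SHRINK_THRESHOLD=0.5 python main.py ...
    # Option 2: set it from Python.
    import nvidia.dali.backend as backend
    backend.SetHostBufferShrinkThreshold(0.5)  # free a host buffer when a request falls below 50% of its allocation
    # Setting it to 0 retains the old grow-only behaviour (confirmed further down this thread).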

@JanuszL JanuszL added this to the Release_0.19.0 milestone Feb 4, 2020
@yaysummeriscoming
Author

@mzient great, thanks, hopefully I'll get some time next week to retest. Nice that there's an environment variable to control the behaviour. If I follow correctly, setting the threshold to 0 will retain the old behaviour?


JanuszL commented Feb 6, 2020

setting the threshold to 0 will retain old behaviour?
@yaysummeriscoming - you are correct.


JanuszL commented Mar 2, 2020

0.19 is out and should address this. Please reopen if it doesn't work.

@JanuszL JanuszL closed this as completed Mar 2, 2020

forjiuzhou commented Mar 9, 2020

0.19 is out and should address this. Please reopen if it doesn't work.

I tried the latest version, but the situation seems the same for me. I'm running an ImageNet experiment on a 1080 GPU, which has only 8 GB of memory. So for this to work, I need to make sure GPU memory usage stays the same for the entire training process, or an OOM error comes up.
Isn't there some way to keep memory usage constant? Speed can slow down; memory is my concern here.
Thank you.


JanuszL commented Mar 9, 2020

@forjiuzhou - 0.19 addresses CPU memory usage. GPU memory doesn't grow that much and should stabilize after a couple of epochs. If you have 8 GB of GPU memory, how about using a CPU pipeline?

@forjiuzhou

@JanuszL A CPU pipeline does have smaller GPU memory usage, but it still grows. After dozens of epochs, an OOM error can still happen. In my experience, a validation phase can increase memory by hundreds of MB, which is weird because the training phase actually grows very slowly.


JanuszL commented Mar 10, 2020

@forjiuzhou - the training pipeline's memory consumption grows very slowly (if at all) after a few epochs, as it has traversed many samples and hits the high-water mark pretty fast. The validation pipeline, however, sees data significantly more slowly, so it takes more iterations for its memory consumption to stabilize. But if you consider how many samples you need to process to get stable memory consumption, I assume it would be about the same in both cases.

@yaysummeriscoming
Author

Finally got a chance to test DALI 0.19. Memory usage is no longer rising, thanks for implementing this fix!

I couldn't see any difference in speed. If there is one, it's <1%.


JanuszL commented Mar 20, 2020

Finally got a chance to test DALI 0.19. Memory usage is no longer rising, thanks for implementing this fix!

I couldn't see any difference in speed. If there is one, it's <1%.

Great to hear that!
