Incremental accumulation of gradients? #905

@jw3126

Description

From discourse https://discourse.julialang.org/t/zygote-gradient-accumulation/55654

I have a densenet-inspired architecture implemented in PyTorch and ported it to Julia. Sadly, I now get out-of-memory errors.
Here is an MWE where Julia's memory consumption is more than 2x that of PyTorch: on my laptop (GeForce RTX 2060, 5.9 GB) Julia throws an out-of-memory error while PyTorch does not, even though the PyTorch tensor is 2x bigger.
Also, when using the same tensor size in both, PyTorch is roughly 2.5x faster than Zygote.
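As a sanity check on the numbers (this arithmetic is mine, not part of the original report), the raw Float32 footprint of each input tensor in the two MWEs below can be computed directly:

```python
# Per-tensor footprint of the MWE inputs (Float32 = 4 bytes per element).
elems_julia   = 128 * 128 * 128 * 20   # Julia/Zygote input that OOMs
elems_pytorch = 128 * 128 * 128 * 40   # PyTorch input that does not OOM

print(elems_julia * 4 / 2**20)     # 160.0 (MiB)
print(elems_pytorch * 4 / 2**20)   # 320.0 (MiB)
```

So a single input tensor is nowhere near the 5.9 GB card limit; the pressure must come from the many intermediates the forward and reverse passes keep alive.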

Zygote

using Zygote
using CUDA

function net(x1)
    x2  = x1
    x3  = x1 + x2
    x4  = x1 + x2 + x3
    x5  = x1 + x2 + x3 + x4
    x6  = x1 + x2 + x3 + x4 + x5
    x7  = x1 + x2 + x3 + x4 + x5 + x6
    x8  = x1 + x2 + x3 + x4 + x5 + x6 + x7
    x9  = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
    x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
    x11 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10
    x12 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11
    #x13 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12
    return x12
end

function loss(x)
    sum(abs2, net(x))
end

x = CUDA.randn(128,128,128,20)
Zygote.gradient(loss, x) # OOM error
x = CUDA.randn(128,128,128,10)
Zygote.gradient(loss, x) #warmup
CUDA.@time for _ in 1:100
    Zygote.gradient(loss, x)
end
# 26.188100 seconds (449.09 k CPU allocations: 15.913 MiB, 46.80% gc time) (11.30 k GPU allocations: 867.188 GiB, 59.73% gc time of which 18.51% spent allocating)

PyTorch

import torch

def net(x1):
    x2  = x1
    x3  = x1 + x2
    x4  = x1 + x2 + x3
    x5  = x1 + x2 + x3 + x4
    x6  = x1 + x2 + x3 + x4 + x5
    x7  = x1 + x2 + x3 + x4 + x5 + x6
    x8  = x1 + x2 + x3 + x4 + x5 + x6 + x7
    x9  = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
    x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
    x11 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10
    x12 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11
    #x13 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12
    return x12

def loss(x):
    return torch.sum(x**2)

x = torch.randn(128,128,128,40).to("cuda")  # 2x the Julia tensor, yet no OOM
x.requires_grad = True
y = loss(net(x))
y.backward()
x.grad
x = torch.randn(128,128,128,10).to("cuda")
x.requires_grad = True
y = loss(net(x))
y.backward()
import time


start = time.time()
#with torch.autograd.profiler.profile(use_cuda=True) as prof:
for _ in range(100):
    x.grad.zero_()
    y = loss(net(x))
    y.backward()
stop = time.time()
stop - start
# 10.764797449111938
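The issue title points at one plausible culprit: how the reverse pass accumulates the many gradient contributions flowing into `x1`. A minimal, framework-free sketch (plain numpy on CPU; the function names are mine, not Zygote's or PyTorch's API) of the two accumulation strategies:

```python
import numpy as np

def accumulate_out_of_place(grads):
    # Each `+` allocates a fresh array: n-1 temporaries for n gradients.
    total = grads[0]
    for g in grads[1:]:
        total = total + g   # new buffer every step
    return total

def accumulate_in_place(grads):
    # One buffer is allocated up front and reused for every contribution.
    total = grads[0].copy()
    for g in grads[1:]:
        total += g          # writes into the existing buffer
    return total

# 11 gradient contributions, mirroring the fan-in onto x1 in the MWE.
grads = [np.ones((256, 256), dtype=np.float32) for _ in range(11)]
a = accumulate_out_of_place(grads)
b = accumulate_in_place(grads)
assert np.array_equal(a, b)  # same result, very different allocation behavior
```

If Zygote's generated pullback accumulates out-of-place while PyTorch's autograd accumulates into a single `.grad` buffer, that alone would account for a large allocation gap on this kind of dense fan-in graph.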
