From Discourse: https://discourse.julialang.org/t/zygote-gradient-accumulation/55654

I have a DenseNet-inspired architecture implemented in PyTorch and ported it to Julia. Sadly, I now get out-of-memory errors.

Here is an MWE where Julia's memory consumption is more than 2x that of PyTorch. On my laptop (GeForce RTX 2060, 5.9 GB) Julia throws an out-of-memory error while PyTorch does not; note that the PyTorch tensor is 2x bigger! Also, when using the same tensor size in both Zygote and PyTorch, PyTorch is about 2.5x faster.
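The quadratic blow-up in the MWE can be made concrete by counting temporaries (a back-of-the-envelope sketch, not a measurement of either framework): each binary `+` materializes a fresh array, and a reverse pass that accumulates cotangents out-of-place pays a mirror-image cost, while in-place accumulation into a fixed gradient buffer needs no extra arrays at all.

```python
def forward_temporaries(n):
    """Arrays allocated while computing x2..xn in the MWE."""
    # x_k = x1 + x2 + ... + x_{k-1} takes (k - 2) binary adds for k >= 3,
    # and every add materializes a fresh array
    return sum(k - 2 for k in range(3, n + 1))

def backward_temporaries(n, in_place):
    """Arrays allocated while accumulating the cotangents of x1..x_{n-1}."""
    if in_place:
        return 0  # accumulating into one grad buffer per node reuses memory
    # x_k feeds into x_{k+1}..x_n, so it receives (n - k) cotangent
    # contributions; summing them out-of-place takes (n - k - 1) adds
    return sum(n - k - 1 for k in range(1, n))

n = 12  # x1..x12 as in the MWE
print(forward_temporaries(n))          # 55
print(backward_temporaries(n, False))  # 55
print(backward_temporaries(n, True))   # 0
```

With `CUDA.randn(128,128,128,20)` each temporary is a ~167 MB Float32 array, so dozens of extra live buffers are enough to exhaust a 6 GB card if they are not freed promptly.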
Zygote

```julia
using Zygote
using CUDA

function net(x1)
    x2 = x1
    x3 = x1 + x2
    x4 = x1 + x2 + x3
    x5 = x1 + x2 + x3 + x4
    x6 = x1 + x2 + x3 + x4 + x5
    x7 = x1 + x2 + x3 + x4 + x5 + x6
    x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7
    x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
    x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
    x11 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10
    x12 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11
    # x13 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12
    return x12
end

function loss(x)
    sum(abs2, net(x))
end

x = CUDA.randn(128, 128, 128, 20)
Zygote.gradient(loss, x)  # OOM error

x = CUDA.randn(128, 128, 128, 10)
Zygote.gradient(loss, x)  # warmup
CUDA.@time for _ in 1:100
    Zygote.gradient(loss, x)
end
# 26.188100 seconds (449.09 k CPU allocations: 15.913 MiB, 46.80% gc time) (11.30 k GPU allocations: 867.188 GiB, 59.73% gc time of which 18.51% spent allocating)
```

pytorch
```python
import torch

def net(x1):
    x2 = x1
    x3 = x1 + x2
    x4 = x1 + x2 + x3
    x5 = x1 + x2 + x3 + x4
    x6 = x1 + x2 + x3 + x4 + x5
    x7 = x1 + x2 + x3 + x4 + x5 + x6
    x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7
    x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
    x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
    x11 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10
    x12 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11
    # x13 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12
    return x12

def loss(x):
    return torch.sum(x**2)

x = torch.randn(128, 128, 128, 40).to("cuda")
x.requires_grad = True
y = loss(net(x))
y.backward()
x.grad

x = torch.randn(128, 128, 128, 10).to("cuda")
x.requires_grad = True
y = loss(net(x))
y.backward()

import time
start = time.time()
# with torch.autograd.profiler.profile(use_cuda=True) as prof:
for _ in range(100):
    x.grad.zero_()
    y = loss(net(x))
    y.backward()
stop = time.time()
stop - start
# 10.764797449111938
```
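One way to picture where the memory gap could come from (an illustration of out-of-place vs in-place accumulation, not a claim about either framework's internals): summing gradient contributions out-of-place creates a fresh buffer per add, while accumulating in place reuses a single buffer. A minimal pure-Python sketch, using object identity as a stand-in for GPU allocations:

```python
# Accumulate 8 cotangent contributions into a length-4 gradient,
# out-of-place (a new buffer per add) vs in-place (one reused buffer).
contribs = [[float(i)] * 4 for i in range(8)]

# Out-of-place: every accumulation step allocates a fresh list.
grad = [0.0] * 4
history = [grad]  # keep every buffer alive so the ids stay distinct
for c in contribs:
    grad = [g + x for g, x in zip(grad, c)]  # fresh buffer each step
    history.append(grad)
assert len({id(b) for b in history}) == 9  # 1 initial + 8 temporaries

# In-place: the whole sum lands in a single buffer.
grad2 = [0.0] * 4
buf_id = id(grad2)
for c in contribs:
    for i, x in enumerate(c):
        grad2[i] += x  # mutates the existing buffer
assert id(grad2) == buf_id  # no new buffer was created

assert grad == grad2 == [28.0] * 4  # same result either way
```

Both strategies compute the same gradient; only the peak number of live buffers differs, which matters a lot when each buffer is a multi-hundred-megabyte CUDA array.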