Preallocation Option for vjp calculation #671
Conversation
Added a precaching feature to avoid heap allocations in the vjp calculation during the reverse pass.
src/fast_layers.jl
Outdated
@@ -40,11 +40,25 @@ struct FastDense{F,F2} <: FastLayer
σ::F
initial_params::F2
bias::Bool
precache::Bool
There's no need to store this. If precache=false, then just store nothing and check for that in the function.
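For reference, a minimal sketch of that pattern with made-up names (maybe_cached_double is not part of the PR): the cache field itself doubles as the flag, and nothing means no preallocation.

function maybe_cached_double(cache, x)
    if cache === nothing
        return 2 .* x        # no cache was set up: allocate as before
    else
        cache .= 2 .* x      # cache present: write into the preallocated buffer
        return cache
    end
end

maybe_cached_double(nothing, [1.0, 2.0])    # allocating path
maybe_cached_double(zeros(2), [1.0, 2.0])   # in-place path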
src/fast_layers.jl
Outdated
@@ -40,11 +40,25 @@ struct FastDense{F,F2} <: FastLayer
σ::F
initial_params::F2
bias::Bool
precache::Bool
cs :: NamedTuple
Use informative variable names: cache instead of cs. Also, this is not type stable; instead parameterize it as ::C.
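A sketch of the parameterization being asked for, with illustrative names (DenseCacheDemo is not the PR's struct): the cache type becomes a type parameter, so field access stays type stable whether the cache is a NamedTuple of buffers or nothing.

struct DenseCacheDemo{F,C}
    σ::F
    cache::C   # NamedTuple of buffers, or nothing when precache=false
end

function DenseCacheDemo(n_out, n_in, σ; precache = false)
    cache = precache ? (W = zeros(n_out, n_in), zbar = zeros(n_out)) : nothing
    DenseCacheDemo{typeof(σ), typeof(cache)}(σ, cache)
end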
src/fast_layers.jl
Outdated
@@ -78,28 +92,53 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
y = f.σ.(r)
you missed this allocation and some before.
src/fast_layers.jl
Outdated
if typeof(f.σ) <: typeof(tanh)
    f.cs.zbar = ȳ .* (1 .- y.^2)
elseif typeof(f.σ) <: typeof(identity)
    f.cs.zbar = ȳ
else
    f.cs.zbar = ȳ .* ForwardDiff.derivative.(f.σ,r)
end
f.cs.Wbar = f.cs.zbar * x'
f.cs.bbar = f.cs.zbar
f.cs.xbar = W' * f.cs.zbar
f.cs.pbar = if f.bias == true
These lines all still allocate. They need .=, mul!, etc.
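For reference, a hedged sketch of what the in-place versions of those statements could look like, assuming cache holds correctly sized buffers (zbar, Wbar, xbar mirror the field names above but the function is illustrative, not the PR's code):

using LinearAlgebra: mul!

function pullback_buffers!(cache, ȳ, y, W, x)
    cache.zbar .= ȳ .* (1 .- y .^ 2)   # fused broadcast writes into zbar, no temporary array
    mul!(cache.Wbar, cache.zbar, x')   # Wbar = zbar * x' without allocating a new matrix
    mul!(cache.xbar, W', cache.zbar)   # xbar = W' * zbar without allocating a new vector
    return cache                        # bbar can simply reuse zbar, as above
end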
src/fast_layers.jl
Outdated
tmp = typeof(f.cs.bbar) <: AbstractVector ? # how to find if bbar is AbstractVector and allocate its shape and size
      vec(vcat(vec(f.cs.Wbar),f.cs.bbar)) :
      vec(vcat(vec(f.cs.Wbar),sum(f.cs.bbar,dims=2)))
ifgpufree(f.cs.bbar)
If it's in the cache, then don't free it.
src/fast_layers.jl
Outdated
f.cs.xbar = W' * f.cs.zbar
f.cs.pbar = if f.bias == true
    tmp = typeof(f.cs.bbar) <: AbstractVector ? # how to find if bbar is AbstractVector and allocate its shape and size
          vec(vcat(vec(f.cs.Wbar),f.cs.bbar)) :
These are allocating statements.
It's a start but still has a long way to go. We need to test gradient accuracy in https://github.com/SciML/DiffEqFlux.jl/blob/master/test/fast_layers.jl (and check for inaccuracy on second calls), and should add a test for its usage in neural ODEs in https://github.com/SciML/DiffEqFlux.jl/blob/master/test/neural_de.jl and https://github.com/SciML/DiffEqFlux.jl/blob/master/test/neural_de_gpu.jl
Thanks for the feedback :), I'll fix these issues.
What's the status here?
The parameter numcols can be specified together with precache=true to set the maximum number of columns in the input(s); it otherwise defaults to 1.
Just pushed the required updates to allow matrix inputs. It takes views when the number of columns in the input is less than the pre-specified number; otherwise everything is done with the preallocated buffers at full size (1 by default).
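For context, a sketch of the usage this describes, based on the calls that appear in the PR's test diffs (precache and numcols are this PR's keywords, not part of released DiffEqFlux):

using DiffEqFlux

xs  = randn(Float32, 6, 8)                                           # batch with 8 columns
fdc = FastDense(6, 50, tanh, numcols = size(xs,2), precache = true)  # buffers sized for 8 columns
p   = initial_params(fdc)
y   = fdc(xs, p)            # forward pass writes into the preallocated buffers
y1  = fdc(xs[:, 1:3], p)    # fewer columns than numcols: views into the same buffers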
test/neural_de.jl
Outdated
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])
We should test that these gradients match the non-caching ones.
test/fast_layers.jl
Outdated
@@ -38,6 +44,12 @@ fsgrad = Flux.Zygote.gradient((x,p)->sum(fs(x,p)),x,pd)
@test fdgrad[1] ≈ fsgrad[1]
@test fdgrad[2] ≈ fsgrad[2] rtol=1e-5

fdcgrad = Flux.Zygote.gradient((x,p)->sum(fdc(x,p)),x,pd)
@test fdgrad[1] ≈ fdcgrad[1]
@test fdgrad[2] ≈ fdcgrad[2] rtol=1e-5
Why so high a tolerance? Seems like an issue?
Would 1e-9 be OK? Any specific value?
1e-9 is fine. For something that's just changing to no allocs, I would expect it to pass at like 1e-12 at least.
test/fast_layers.jl
Outdated
@test fdgrad[1] ≈ fdcgrad[1]
@test fdgrad[2] ≈ fdcgrad[2] rtol=1e-5
@allocated fdc(x, pd);
@test @allocated fdc(x, pd) == 1024
What are these allocations from?
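One way to answer that (a sketch, not part of the PR; needs Julia 1.8+ and assumes fdc, x, pd from this test file): the allocation profiler reports which call sites inside the cached forward pass still allocate.

using Profile

fdc(x, pd)                                    # warm-up call so compilation is not counted
Profile.Allocs.@profile sample_rate=1 fdc(x, pd)
results = Profile.Allocs.fetch()              # results.allocs lists each allocation and its stacktrace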
src/fast_layers.jl
Outdated
zbar = ȳ .* (1 .- y.^2)
elseif typeof(f.σ) <: typeof(identity)
zbar = ȳ
cols = length(size(x)) == 1 ? 1 : size(x)[2]
Suggested change:
- cols = length(size(x)) == 1 ? 1 : size(x)[2]
+ cols = size(x,2)
This would cause an error when x::AbstractVector but numcols>1, although with the separate dispatches you suggested below it won't occur.
src/fast_layers.jl
Outdated
zbar = ȳ
cols = length(size(x)) == 1 ? 1 : size(x)[2]
if !isgpu(p)
    f.cache.W .= @view p[reshape(1:(f.out*f.in),f.out,f.in)]
why is this cached?
Did this just to avoid the pointer allocation from taking a @view. Would you prefer it uncached?
If uncached, it causes further allocations when its transpose is taken here.
Interesting, transpose of a view allocates?
src/fast_layers.jl
Outdated
cache = nothing
end
new{typeof(σ), typeof(initial_params), typeof(cache)}(out,in,σ,initial_params,cache,bias,numcols)
# new{typeof(σ),typeof(initial_params)}(out,in,σ,initial_params,bias)
end
end

# (f::FastDense)(x,p) = f.σ.(reshape(uview(p,1:(f.out*f.in)),f.out,f.in)*x .+ uview(p,(f.out*f.in+1):lastindex(p)))
(f::FastDense)(x,p) = ((f.bias == true) ? (f.σ.(reshape(p[1:(f.out*f.in)],f.out,f.in)*x .+ p[(f.out*f.in+1):end])) : (f.σ.(reshape(p[1:(f.out*f.in)],f.out,f.in)*x)))

ZygoteRules.@adjoint function (f::FastDense)(x,p)
Make this two separate dispatches, one for x::AbstractVector and one for x::AbstractMatrix. That will make it a lot simpler.
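A self-contained illustration of the dispatch split (generic names, not the actual adjoint code): the vector and matrix cases each get their own method instead of branching on size(x) inside one body.

bias_grad(zbar::AbstractVector) = zbar                      # vector input: bbar is just zbar
bias_grad(zbar::AbstractMatrix) = vec(sum(zbar, dims = 2))  # matrix input: sum bbar over columns

bias_grad([1.0, 2.0])            # returns [1.0, 2.0]
bias_grad([1.0 2.0; 3.0 4.0])    # returns [3.0, 7.0]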
Looks like a real test failure? https://github.com/SciML/DiffEqFlux.jl/runs/4918538935?check_suite_focus=true#step:6:499
Added separate dispatches for vector and matrix inputs, and fixed tests and some allocations.
This return is causing the allocations here. Is there any workaround for this?
You can have a preallocated vector for the return that you write into. Indeed it seems to instantiate the view.
Fixed some statements causing runtime allocations in the adjoint calculation.
Making y a preallocated vector reduced some allocations, but the FastDense_adjoint that is also returned here seems to be the main problem. Just for experimenting I returned y as a placeholder here, because if nothing is returned in place of FastDense_adjoint then Zygote.pullback will error out.
if typeof(f.cache) <: Nothing
    y,FastDense_adjoint
else
    @view(f.cache.y[:,1:f.cache.cols[1]]),FastDense_adjoint
Have an out array that you write into and then return, instead of returning a view, which will allocate the mutable struct of the view itself. That's probably the 176.
Can we do anything about this view without confirmation that numcols will equal cols?
We can return the whole y if numcols equals cols, and the view if cols is smaller.
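A hedged sketch of that idea (cached_output and ncols are illustrative names): return the preallocated array itself when it is fully used, and only pay for the view when fewer columns are needed.

function cached_output(out::AbstractMatrix, ncols::Int)
    return size(out, 2) == ncols ? out : view(out, :, 1:ncols)
end

buf = zeros(3, 4)                 # preallocated for numcols = 4
cached_output(buf, 4) === buf     # full-width case: the buffer itself, no view allocation
cached_output(buf, 2)             # narrower input: a view, which allocates its wrapper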
else
    r = W*x
    f.cache.yvec,FastDense_adjoint
I am talking about this one. After replacing the view with a preallocated array yvec, the allocations decrease by ~400 bytes, but it still allocates ~1400, which disappears when FastDense_adjoint is removed from above, although we can't compute anything if it's removed.
Oh yes, that closure will need to allocate. Don't worry about that for now. If that's all that's left, then we're good. Let's try to get this finished, merged, and then talk about what to do here.
This seems good to run tests on.
test/fast_layers.jl
Outdated
@@ -38,6 +44,12 @@ fsgrad = Flux.Zygote.gradient((x,p)->sum(fs(x,p)),x,pd)
@test fdgrad[1] ≈ fsgrad[1]
@test fdgrad[2] ≈ fsgrad[2] rtol=1e-5

fdcgrad = Flux.Zygote.gradient((x,p)->sum(fdc(x,p)),x,pd)
@test fdgrad[1] ≈ fdcgrad[1]
Lower the tolerance.
gradsnc = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(gradsnc[x])
@test ! iszero(gradsnc[node.p])
Check that this matches the one without caching to a low tolerance. Set the ODE solver tolerance low for this test.
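A hedged sketch of such a test (it assumes this PR's precache keyword, the variables x and tspan from test/neural_de.jl, and that the cached layer can share the uncached layer's parameter vector; the tolerances are illustrative values for Float32):

fastdudt  = FastChain(FastDense(2,50,tanh), FastDense(50,2))
fastcdudt = FastChain(FastDense(2,50,tanh,precache=true), FastDense(50,2,precache=true))

node_nc = NeuralODE(fastdudt, tspan, Tsit5(), abstol=1f-7, reltol=1f-7, save_everystep=false, save_start=false)
node_c  = NeuralODE(fastcdudt, tspan, Tsit5(), abstol=1f-7, reltol=1f-7, save_everystep=false, save_start=false, p=node_nc.p)

gradsnc = Zygote.gradient(()->sum(node_nc(x)), Flux.params(x, node_nc))
gradsc  = Zygote.gradient(()->sum(node_c(x)),  Flux.params(x, node_c))
@test gradsnc[x] ≈ gradsc[x] rtol=1f-5                 # cached vs uncached input gradient
@test gradsnc[node_nc.p] ≈ gradsc[node_c.p] rtol=1f-5  # cached vs uncached parameter gradient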
test/neural_de.jl
Outdated
gradsc = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(gradsc[x])
@test ! iszero(gradsc[node.p])
@test gradsnc[x] ≈ gradsc[x]
tolerance on here?
test/neural_de.jl
Outdated
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
@test ! iszero(grads[xs])
@test ! iszero(grads[node.p])

node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
@test ! iszero(grads[xs])
@test ! iszero(grads[node.p])

goodgrad = grads[node.p]
p = node.p

node = NeuralODE(fastcdudt,tspan,Tsit5,save_everystep=false,save_start=false, sensealg=BacksolveAdjoint(),p=p)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
@test !iszero(grads[xs])
@test ! iszero(grads[node.p])
goodgrad2 = grads[node.p]
@test goodgrad ≈ goodgrad2
etc
test/neural_de.jl
Outdated
node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
@test_broken ! iszero(grads[xs])
@test_broken ! iszero(grads[node.p])

node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
@test_broken ! iszero(grads[xs])
@test_broken ! iszero(grads[node.p])

node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.1)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
@test_broken ! iszero(grads[xs])
@test_broken ! iszero(grads[node.p])
It would be easier to read if each cached test were paired right next to the corresponding uncached test.
test/neural_de.jl
Outdated
@@ -244,8 +334,21 @@ grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
@test ! iszero(grads[sode.p])
@test ! iszero(grads[sode.p][end])

sode = NeuralDSDE(fastcdudt,fastcdudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1)
If you set RNG seeds, do these gradients match the uncached versions? That would be a good test.
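A hedged sketch of what that could look like (it assumes the uncached drift/diffusion networks fastdudt/fastdudt2 and the cached fastcdudt/fastcdudt2 from this test file, that NeuralDSDE's p keyword lets the cached model reuse the uncached parameters, and that the default global RNG drives the noise process; otherwise an explicit solver seed would be needed):

using Random

sode_nc = NeuralDSDE(fastdudt, fastdudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1)
sode_c  = NeuralDSDE(fastcdudt, fastcdudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1, p=sode_nc.p)

Random.seed!(1234)
gradsnc = Zygote.gradient(()->sum(sode_nc(x)), Flux.params(x, sode_nc))
Random.seed!(1234)
gradsc  = Zygote.gradient(()->sum(sode_c(x)),  Flux.params(x, sode_c))
@test gradsnc[x] ≈ gradsc[x]   # same seed, same noise path: gradients should agree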
test/neural_de.jl
Outdated
fastcddudt = FastChain(FastDense(6,50,tanh,numcols=size(xs)[2],precache=true),FastDense(50,2,numcols=size(xs)[2],precache=true))
NeuralCDDE(fastcddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.1)(x)
dode = NeuralCDDE(fastcddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.0:0.1:2.0)

grads = Zygote.gradient(()->sum(dode(x)),Flux.params(x,dode))
@test ! iszero(grads[x])
@test ! iszero(grads[dode.p])

@test_broken grads = Zygote.gradient(()->sum(dode(xs)),Flux.params(xs,dode)) isa Tuple
@test_broken ! iszero(grads[xs])
@test ! iszero(grads[dode.p])
Check against the uncached.
Made corrections to existing tests and added a couple of new tests.
Test failure.