
Preallocation Option for vjp calculation #671

Merged · 10 commits · Feb 5, 2022

Conversation

ba2tro (Contributor) commented Jan 6, 2022:

Added a precaching feature to avoid heap allocations in the vjp calculation during the reverse pass.

@@ -40,11 +40,25 @@ struct FastDense{F,F2} <: FastLayer
σ::F
initial_params::F2
bias::Bool
precache::Bool
Member:
there's no need to store this. If precache=false, then just store nothing and check for that in the function.

@@ -40,11 +40,25 @@ struct FastDense{F,F2} <: FastLayer
σ::F
initial_params::F2
bias::Bool
precache::Bool
cs :: NamedTuple
Member:
Use informative variable names. cache instead of cs.

Also, this is not type stable; instead, parameterize it as ::C.
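
A minimal sketch of what those two suggestions amount to (the FastLayer stub, field names, and layout are illustrative, not the final implementation):

```julia
abstract type FastLayer end  # stands in for DiffEqFlux's FastLayer in this sketch

struct FastDense{F,F2,C} <: FastLayer
    out::Int
    in::Int
    σ::F
    initial_params::F2
    bias::Bool
    cache::C  # NamedTuple of preallocated buffers, or `nothing` when precaching is off
end

# No separate `precache::Bool` field: the call/adjoint code just branches on the cache,
# e.g. f.cache === nothing ? allocating_path(f, x, p) : cached_path(f, x, p)
```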

@@ -78,28 +92,53 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
y = f.σ.(r)
Member:
you missed this allocation and some before.

Comment on lines 96 to 106
if typeof(f.σ) <: typeof(tanh)
f.cs.zbar = ȳ .* (1 .- y.^2)
elseif typeof(f.σ) <: typeof(identity)
f.cs.zbar = ȳ
else
f.cs.zbar = ȳ .* ForwardDiff.derivative.(f.σ,r)
end
f.cs.Wbar = f.cs.zbar * x'
f.cs.bbar = f.cs.zbar
f.cs.xbar = W' * f.cs.zbar
f.cs.pbar = if f.bias == true
Member:
These lines all still allocate. They need .=, mul!, etc.
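
For reference, a hedged sketch of in-place versions of those statements, assuming cache buffers zbar, Wbar, bbar, and xbar shaped to match the inputs (an illustration of the .=/mul! pattern, not the PR's exact code):

```julia
using LinearAlgebra: mul!
import ForwardDiff

# In-place vjp for a dense layer: every result is written into a preallocated buffer.
function dense_vjp!(cache, σ, W, x, r, y, ȳ)
    if σ === tanh
        cache.zbar .= ȳ .* (1 .- y .^ 2)
    elseif σ === identity
        cache.zbar .= ȳ
    else
        cache.zbar .= ȳ .* ForwardDiff.derivative.(σ, r)
    end
    mul!(cache.Wbar, cache.zbar, x')  # Wbar = zbar * x'
    cache.bbar .= cache.zbar          # bbar = zbar
    mul!(cache.xbar, W', cache.zbar)  # xbar = W' * zbar
    return cache
end
```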

tmp = typeof(f.cs.bbar) <: AbstractVector ? #how to find if bbar is AbstractVector and allocate its shape and size
vec(vcat(vec(f.cs.Wbar),f.cs.bbar)) :
vec(vcat(vec(f.cs.Wbar),sum(f.cs.bbar,dims=2)))
ifgpufree(f.cs.bbar)
Member:
If it's in the cache, then don't free it.

f.cs.xbar = W' * f.cs.zbar
f.cs.pbar = if f.bias == true
tmp = typeof(f.cs.bbar) <: AbstractVector ? #how to find if bbar is AbstractVector and allocate its shape and size
vec(vcat(vec(f.cs.Wbar),f.cs.bbar)) :
Member:
These are allocating statements.

ChrisRackauckas (Member):
It's a start but still has a long way to go. Need to test gradient accuracy in https://github.com/SciML/DiffEqFlux.jl/blob/master/test/fast_layers.jl (and inaccuracy in second calls), and should add a test for its usage in neural ODEs in https://github.com/SciML/DiffEqFlux.jl/blob/master/test/neural_de.jl and https://github.com/SciML/DiffEqFlux.jl/blob/master/test/neural_de_gpu.jl

ba2tro (Contributor Author) commented Jan 6, 2022:

Thanks for the feedback : ), I'll fix these issues

ChrisRackauckas (Member):
What's the status here?

The parameter numcols can be specified together with precache=true to give the maximum number of columns in the input(s); it otherwise defaults to 1.
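
A usage sketch of the option described in that commit (assuming the FastDense keyword interface added in this PR, as exercised by the tests below):

```julia
using DiffEqFlux, Zygote

fdc = FastDense(2, 25, tanh, precache = true, numcols = 4)  # buffers sized for ≤ 4 columns
fd  = FastDense(2, 25, tanh)                                # uncached layer for comparison

p = initial_params(fd)
x = rand(Float32, 2, 4)

# Gradients should agree; the cached layer avoids heap allocations in the vjp.
gc = Zygote.gradient((x, p) -> sum(fdc(x, p)), x, p)
g  = Zygote.gradient((x, p) -> sum(fd(x, p)),  x, p)
```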
ba2tro (Contributor Author) commented Jan 24, 2022:

Just pushed the required updates for allowing matrix inputs. It takes views when the number of columns in the input is less than the pre-specified number; otherwise everything is done with the full-size preallocated buffers (1 column by default).
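
Roughly, the buffer selection described above could look like this illustrative helper (not the PR's exact code):

```julia
# Use the full preallocated buffer when the input has `numcols` columns,
# otherwise a view of its leading columns.
select_buffer(buf::AbstractMatrix, x::AbstractMatrix, numcols::Int) =
    size(x, 2) < numcols ? view(buf, :, 1:size(x, 2)) : buf

ybuf = zeros(Float32, 25, 4)   # preallocated for numcols = 4
xbig = rand(Float32, 2, 4)
xsml = rand(Float32, 2, 2)
select_buffer(ybuf, xbig, 4)   # the full 25×4 buffer
select_buffer(ybuf, xsml, 4)   # a 25×2 view into it
```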

Comment on lines 96 to 98
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

Member:
We should test that these gradients match the non-caching one.

@@ -38,6 +44,12 @@ fsgrad = Flux.Zygote.gradient((x,p)->sum(fs(x,p)),x,pd)
@test fdgrad[1] ≈ fsgrad[1]
@test fdgrad[2] ≈ fsgrad[2] rtol=1e-5

fdcgrad = Flux.Zygote.gradient((x,p)->sum(fdc(x,p)),x,pd)
@test fdgrad[1] ≈ fdcgrad[1]
@test fdgrad[2] ≈ fdcgrad[2] rtol=1e-5
Member:
why so high of a tolerance? Seems like an issue?

Contributor Author:
Would 1e-9 be OK? Any specific value?

Member:
1e-9 is fine. For something that's just changing to no allocs, I would expect it to pass at like 1e-12 at least.

@test fdgrad[1] ≈ fdcgrad[1]
@test fdgrad[2] ≈ fdcgrad[2] rtol=1e-5
@allocated fdc(x, pd);
@test @allocated fdc(x, pd) == 1024
Member:
What are these allocations from?

zbar = ȳ .* (1 .- y.^2)
elseif typeof(f.σ) <: typeof(identity)
zbar = ȳ
cols = length(size(x)) == 1 ? 1 : size(x)[2]
Member:
Suggested change
cols = length(size(x)) == 1 ? 1 : size(x)[2]
cols = size(x,2)

ba2tro (Contributor Author) Jan 24, 2022:
This would cause an error when x::AbstractVector but numcols > 1, although with the separate dispatches you suggested below it won't occur.

zbar = ȳ
cols = length(size(x)) == 1 ? 1 : size(x)[2]
if !isgpu(p)
f.cache.W .= @view p[reshape(1:(f.out*f.in),f.out,f.in)]
Member:
why is this cached?

Contributor Author:
Did this just to avoid the pointer allocation from taking @view. Would you prefer it uncached?

Contributor Author:
If uncached, it causes further allocations when its transpose is taken here.

ChrisRackauckas (Member) Jan 24, 2022:
Interesting, transpose of a view allocates?

cache = nothing
end
new{typeof(σ), typeof(initial_params), typeof(cache)}(out,in,σ,initial_params,cache,bias,numcols)
# new{typeof(σ),typeof(initial_params)}(out,in,σ,initial_params,bias)
end
end

# (f::FastDense)(x,p) = f.σ.(reshape(uview(p,1:(f.out*f.in)),f.out,f.in)*x .+ uview(p,(f.out*f.in+1):lastindex(p)))
(f::FastDense)(x,p) = ((f.bias == true) ? (f.σ.(reshape(p[1:(f.out*f.in)],f.out,f.in)*x .+ p[(f.out*f.in+1):end])) : (f.σ.(reshape(p[1:(f.out*f.in)],f.out,f.in)*x)))

ZygoteRules.@adjoint function (f::FastDense)(x,p)
Member:
Make this two separate dispatches, one for x::AbstractVector and one for x::AbstractMatrix. That will make it a lot simpler.
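
One way the split could be structured (a sketch with a hypothetical dense_pullback helper, one method per input shape; the cached, in-place versions from this PR would follow the same dispatch pattern):

```julia
# Vector input: a single sample. zbar and xbar are vectors, Wbar is an outer product.
function dense_pullback(W::AbstractMatrix, x::AbstractVector, zbar::AbstractVector)
    Wbar = zbar * x'        # out × in
    bbar = zbar             # out
    xbar = W' * zbar        # in
    return Wbar, bbar, xbar
end

# Matrix input: columns are the batch. The bias gradient reduces over the batch dimension.
function dense_pullback(W::AbstractMatrix, x::AbstractMatrix, zbar::AbstractMatrix)
    Wbar = zbar * x'             # out × in
    bbar = sum(zbar, dims = 2)   # out × 1
    xbar = W' * zbar             # in × batch
    return Wbar, bbar, xbar
end
```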

ba2tro and others added 2 commits January 31, 2022 18:34
Added separate dispatches for vector and matrix inputs; fixed tests and some allocations.
ba2tro (Contributor Author) commented Jan 31, 2022:

This return is causing the allocations here. Is there any workaround for this?

ChrisRackauckas (Member):
You can have a preallocated vector for the return that you write into. Indeed it seems to instantiate the view.

Fixed some statements causing runtime allocations in the adjoint calculation.
ba2tro (Contributor Author) commented Feb 1, 2022:

Making y a preallocated vector reduced some allocations, but the FastDense_adjoint that is also being returned here seems to be the main problem.

Just for experimenting, I returned nothing, y — i.e., nothing in place of y and y in place of FastDense_adjoint — and with the following @allocated call it returned just 176, which suggests FastDense_adjoint is what's allocating.

y is just a placeholder here because if nothing is returned in place of FastDense_adjoint then Zygote.pullback will error out.

[screenshot of the @allocated output]

if typeof(f.cache) <: Nothing
y,FastDense_adjoint
else
@view(f.cache.y[:,1:f.cache.cols[1]]),FastDense_adjoint
Member:
have an out array that you write into and then return, instead of returning a view which will allocate the mutable struct of the view itself. That's probably the 176.
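
A small sketch of that suggestion: copy the active columns into a preallocated output array and return that, rather than returning the SubArray wrapper (buffer names are assumptions):

```julia
# Copy the first `cols` columns of the cached buffer into a preallocated output.
# Returning `out` avoids materializing a new view object on every call.
function copy_output!(out::AbstractMatrix, ybuf::AbstractMatrix, cols::Int)
    out .= @view ybuf[:, 1:cols]   # assumes size(out) == (size(ybuf, 1), cols)
    return out
end
```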

Contributor Author:
Can we do anything about this view without knowing that numcols will equal cols?

Contributor Author:
We can return the whole y if numcols equals cols, and the view if cols is smaller.

else
r = W*x
f.cache.yvec,FastDense_adjoint
ba2tro (Contributor Author) Feb 1, 2022:
I am talking about this one. After replacing the view with a preallocated array yvec, the allocations decrease by ~400 bytes, but it still allocates ~1400, which disappears when FastDense_adjoint is removed from above — although we can't compute anything if it's removed.

Member:
oh yes that closure will need to allocate. Don't worry about that for now. If that's all that's left, then we're good. Let's try to get this finished, merged, and then talk about what to do here.

Contributor Author:
This seems good to run tests on

@@ -38,6 +44,12 @@ fsgrad = Flux.Zygote.gradient((x,p)->sum(fs(x,p)),x,pd)
@test fdgrad[1] ≈ fsgrad[1]
@test fdgrad[2] ≈ fsgrad[2] rtol=1e-5

fdcgrad = Flux.Zygote.gradient((x,p)->sum(fdc(x,p)),x,pd)
@test fdgrad[1] ≈ fdcgrad[1]
Member:
lower tolerance

Comment on lines +63 to +65
gradsnc = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(gradsnc[x])
@test ! iszero(gradsnc[node.p])
Member:
check that this matches the one without caching to a low tolerance. Set the ODE solver tolerance low for this test.
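
A sketch of that check (names follow the surrounding test file; fastdudt is assumed to be the uncached counterpart of fastcdudt, and the p keyword is used so both networks share parameters):

```julia
# Tight solver tolerances so the cached and uncached gradients are directly comparable.
nodenc = NeuralODE(fastdudt, tspan, Tsit5(), abstol=1e-12, reltol=1e-12,
                   save_everystep=false, save_start=false)
nodec  = NeuralODE(fastcdudt, tspan, Tsit5(), abstol=1e-12, reltol=1e-12,
                   save_everystep=false, save_start=false, p=nodenc.p)

gradsnc = Zygote.gradient(() -> sum(nodenc(x)), Flux.params(x, nodenc))
gradsc  = Zygote.gradient(() -> sum(nodec(x)),  Flux.params(x, nodec))

@test gradsnc[x] ≈ gradsc[x] rtol=1e-9
@test gradsnc[nodenc.p] ≈ gradsc[nodec.p] rtol=1e-9
```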

gradsc = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(gradsc[x])
@test ! iszero(gradsc[node.p])
@test gradsnc[x] ≈ gradsc[x]
Member:
tolerance on here?

Comment on lines 101 to 126
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
@test ! iszero(grads[xs])
@test ! iszero(grads[node.p])

node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
@test ! iszero(grads[xs])
@test ! iszero(grads[node.p])

goodgrad = grads[node.p]
p = node.p

node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false, sensealg=BacksolveAdjoint(),p=p)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
@test !iszero(grads[xs])
@test ! iszero(grads[node.p])
goodgrad2 = grads[node.p]
@test goodgrad ≈ goodgrad2
Member:
etc

Comment on lines 193 to 218
node = NeuralODE(fastcdudt,tspan,Tsit5(),save_everystep=false,save_start=false)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
@test_broken ! iszero(grads[xs])
@test_broken ! iszero(grads[node.p])

node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
@test_broken ! iszero(grads[xs])
@test_broken ! iszero(grads[node.p])

node = NeuralODE(fastcdudt,tspan,Tsit5(),saveat=0.1)
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
@test ! iszero(grads[x])
@test ! iszero(grads[node.p])

@test_broken grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) isa Tuple
@test_broken ! iszero(grads[xs])
@test_broken ! iszero(grads[node.p])
Member:
It would be easier to read if the cached test is paired right next to the uncached test

@@ -244,8 +334,21 @@ grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
@test ! iszero(grads[sode.p])
@test ! iszero(grads[sode.p][end])

sode = NeuralDSDE(fastcdudt,fastcdudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1)
Member:
if you set RNG seeds, do these gradients match the uncached versions? That's a good test.
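
A rough sketch of that check, assuming fastdudt/fastdudt2 are the uncached counterparts of fastcdudt/fastcdudt2 and that NeuralDSDE accepts a p keyword like NeuralODE does (reseeding before each solve so both see the same noise realization):

```julia
using Random

sodenc = NeuralDSDE(fastdudt, fastdudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1)
sodec  = NeuralDSDE(fastcdudt, fastcdudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1, p=sodenc.p)

Random.seed!(100)
gradsnc = Zygote.gradient(() -> sum(sodenc(x)), Flux.params(x, sodenc))
Random.seed!(100)
gradsc  = Zygote.gradient(() -> sum(sodec(x)), Flux.params(x, sodec))

@test gradsnc[x] ≈ gradsc[x] rtol=1e-5
@test gradsnc[sodenc.p] ≈ gradsc[sodec.p] rtol=1e-5
```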

Comment on lines 413 to 424

fastcddudt = FastChain(FastDense(6,50,tanh,numcols=size(xs)[2],precache=true),FastDense(50,2,numcols=size(xs)[2],precache=true))
NeuralCDDE(fastcddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.1)(x)
dode = NeuralCDDE(fastcddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.0:0.1:2.0)

grads = Zygote.gradient(()->sum(dode(x)),Flux.params(x,dode))
@test ! iszero(grads[x])
@test ! iszero(grads[dode.p])

@test_broken grads = Zygote.gradient(()->sum(dode(xs)),Flux.params(xs,dode)) isa Tuple
@test_broken ! iszero(grads[xs])
@test ! iszero(grads[dode.p])
Member:
Check against the uncached.

Made corrections to existing tests and added a couple of new tests.
ChrisRackauckas (Member):
Test failure

ba2tro (Contributor Author) commented Feb 5, 2022:

On a simple adjoint calculation we get a ~2x speedup with precaching.
[screenshot of the benchmark output]

ChrisRackauckas merged commit 30274ae into SciML:master on Feb 5, 2022
ChrisRackauckas mentioned this pull request on Feb 5, 2022
ba2tro deleted the precache branch on February 7, 2022 05:52
ba2tro restored the precache branch on February 7, 2022 05:53
ba2tro deleted the precache branch on February 11, 2022 13:42