
RFC: return full gradient of all arguments in gradient #535

Closed
Roger-luo opened this issue Mar 6, 2020 · 13 comments

Comments

@Roger-luo
Contributor

Roger-luo commented Mar 6, 2020

I feel gradient should return the gradient of each argument once FluxML/Flux.jl#1073 is merged, since that will allow one to use optimizers directly on structures (maybe also related to FluxML/Flux.jl#637). It would be more convenient to return the gradient of all arguments, since we could then write:

using Flux
using Flux.Optimise: update!

m = Chain(Dense(100, 100), sum)
x = rand(Float32, 100)
opt = ADAM()

# train step etc.
nepochs = 1000
for k in 1:nepochs
    # proposed API: gradient returns the gradient of every argument, including m
    Δm, Δx = gradient(m, x)
    update!(opt, m, Δm)
end

(This is actually my case: the output of the model is a probability, and I need similar but more complicated code to do policy gradient.) Currently one has to work around this with gradient((m, x) -> m(x), m, x), which I find less convenient... The only thing that needs to change is the following line, though:

function pullback(f, args...)

which would also simplify the pullback logic on Zygote's side a bit, I think, since there would no longer be a need for both _pullback and pullback.
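
For illustration, a rough sketch of what the change could look like (my paraphrase of Zygote's interface, not an actual diff; sensitivity and _pullback are existing internals and the real code may differ):

function gradient(f, args...)
    y, back = _pullback(f, args...)
    # return the gradient of every argument, including f, instead of dropping f's with Base.tail
    return back(sensitivity(y))
end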

@CarloLucibello
Member

What if I do not want the gradient with respect to m?
Also, gradient((m, x) -> m(x), m, x), or more typically gradient(m -> m(x), m), doesn't look that bad.

@oxinabox
Member

What if I do not want the gradient with respect to m?

What if I don't want the gradient with respect to x?
I would use first or last, or indexing, or a destructuring assignment: _, dx = gradient(...).

I am gently pro this.
People should understand the complexities of the world:
that the function itself is one of the arguments to the operation the function performs,
that this only matters if the function is a functor/closure,
and that the function is not itself special in any way; it is just one of many arguments.

@Roger-luo
Contributor Author

I think my main point here is to be consistent and transparent with the internals; this will make things much simpler. I have a few reasons for this change:

Firstly, returning the full gradient and using something like _, gs = gradient(m, x) won't add much inconvenience.

Secondly, this will simplify the logic inside the internals, and make it easier to understand things.

Lastly, I think the semantics of the current gradient API existed mainly to support the f'(x) API, which has now been removed, so there is no need to keep it at the cost of a bunch of redundant code inside the implementation.

@CarloLucibello
Member

My point is that I don't want to waste computation computing gradients I don't need. Sometimes I need the gradient of m, sometimes the gradient of x, and it seems to me that with this interface I would always compute both. Is there something I'm not understanding?

@Roger-luo
Contributor Author

@CarloLucibello No, Zygote always computes all the gradients whether you need them or not, since what Zygote does here is simply compose different pullbacks together.

The pullback of matrix multiplication is defined as

https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L291

@adjoint function(A::AbstractMatrix * B::AbstractMatrix)
  return A * B, Δ::AbstractMatrix -> (Δ * B', A' * Δ)
end

When you call the backward pass of this pullback, both values will be calculated, and I don't think the Julia compiler will always drop the unused value, since the way Zygote drops one of them is simply to use Base.tail:

https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/interface.jl#L31
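
Roughly, the linked code does something like the following (a paraphrased sketch from memory, not the exact source):

function pullback(f, args...)
    y, back = _pullback(f, args...)
    # Base.tail drops the first entry of the tuple, i.e. the gradient with respect to f itself
    return y, Δ -> Base.tail(back(Δ))
end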

If you benchmark it, this can actually be a bit slower and allocate more due to the use of tail:

julia> A = rand(100, 100); B = rand(100, 100);

julia> _, back = Zygote._pullback((A, B)->sum(A * B), A, B)
(249662.01236290898, (#37))

julia> @benchmark back(1.0)
BenchmarkTools.Trial:
  memory estimate:  234.98 KiB
  allocs estimate:  14
  --------------
  minimum time:     879.487 μs (0.00% GC)
  median time:      988.103 μs (0.00% GC)
  mean time:        1.037 ms (3.23% GC)
  maximum time:     8.125 ms (87.60% GC)
  --------------
  samples:          4805
  evals/sample:     1

julia> @benchmark Base.tail(back(1.0))
BenchmarkTools.Trial:
  memory estimate:  235.02 KiB
  allocs estimate:  15
  --------------
  minimum time:     879.982 μs (0.00% GC)
  median time:      988.425 μs (0.00% GC)
  mean time:        1.044 ms (3.26% GC)
  maximum time:     8.084 ms (87.52% GC)
  --------------
  samples:          4777
  evals/sample:     1

The best solution to avoid such computation so far, I think, is to make use of Thunk in ChainRules, but we need to wait for @oxinabox's PR to get merged.

@CarloLucibello
Member

Thanks for the explanation, that wasn't clear to me. So, when we have thunks, how would you compute gradients only for selected parameters under your proposal?

@oxinabox
Member

oxinabox commented Mar 12, 2020

With thunks, anything not unthunked is never computed.
So one will have, say,
https://github.com/JuliaDiff/ChainRules.jl/blob/59f06bce6ebb21ed2d5f71769887d9b6523c9d61/src/rulesets/LinearAlgebra/dense.jl#L95-L100
a pullback returning grads = (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)),
and then if you do dx = unthunk(grads[2]),
only Ȳ * B' will be computed, not A' * Ȳ.
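
To make the mechanism concrete, here is a toy sketch of the idea (my own minimal version; Thunk and force here are stand-ins, not ChainRulesCore's actual types):

# a thunk just wraps the gradient computation in a closure
struct Thunk{F}
    f::F
end
force(t::Thunk) = t.f()   # hypothetical stand-in for unthunk: runs the deferred computation

A, B, Ȳ = rand(3, 3), rand(3, 3), rand(3, 3)
grads = (Thunk(() -> Ȳ * B'), Thunk(() -> A' * Ȳ))
dA = force(grads[1])      # only Ȳ * B' is evaluated; A' * Ȳ is never computed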

@CarloLucibello
Member

still, I think gradient should do the unthunking for you and keep the current behavior (otherwise I don't see how to specify what to unthunk)

@oxinabox
Member

oxinabox commented Mar 13, 2020

still, I think gradient should do the unthunking for you and keep the current behavior (otherwise I don't see how to specify what to unthunk)

You specify via unthunk.
But in any case that stuff is not in Zygote at all right now, and in the initial PR everything is unthunked immediately (preserving the current behaviour).


Back on topic: the point of mentioning thunking at all was that right now gradient does compute all derivatives, even though it doesn't return them all.
So there is no performance change with @Roger-luo's proposal.

@Roger-luo
Contributor Author

I found another related case: the current API prevents us from defining the gradient for callable objects, since the first argument will be ignored, e.g.

struct Linear
    W
    b
end

@adjoint function (l::Linear)(x)
    function pullback(Δ)
        grad_x = l.W' * Δ
        # we would also like to return the gradient of the Linear itself here
        return (grad_x,)
    end
    return l.W * x .+ l.b, pullback
end

But since a nothing is added for the callable itself in this adjoint, we are not able to define the gradient of Linear ourselves, unless we overload Zygote._pullback.

@MikeInnes
Member

@Roger-luo that syntax already does what you want. It's also unrelated to the gradient interface.

Thunking is relevant, though, since once we have that ability, we'll have to provide some way to communicate what not to calculate. Currently that's done by closing over variables rather than passing them (even though it doesn't actually improve performance as yet, it could). With this change we'd have to expose Thunk objects to the user to get the same advantage. That would make gradient increasingly complicated and hard to understand, and increasingly unrelated to the mathematical operation most users are interested in.
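
For example (a sketch, where loss, m, x and y stand for whatever the user has):

# x and y are closed over, signalling "do not differentiate these"
gradient(m -> loss(m(x), y), m)

# everything passed explicitly; gradients for all arguments would be computed
gradient((m, x, y) -> loss(m(x), y), m, x, y)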

The main advantage of the current API is that you can write things like gradient(sin, 1) and it behaves intuitively. In almost all common cases the gradient of the function is going to be nothing, and I don't really buy that users need to understand that the function is an argument to itself (which is something of a Julia-specific quirk).

The original motivating example is not very convincing, since in general you actually write something like gradient((m, x, y) -> loss(m(x), y), m, x, y), in which case returning the extra argument saves you nothing.

@oxinabox
Member

oxinabox commented Mar 16, 2020

Thunking is relevant, though, since once we have that ability, we'll have to provide some way to communicate what not to calculate

One thing we might think about is something like returning a Gradient object,
which is iterable and indexable like a tuple (or even like a named tuple?)
and which calls unthunk when you iterate or index it.

There is a related issue, JuliaDiff/ChainRulesCore.jl#121,
about adding something like this behaviour to Composite.
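
A minimal sketch of what such a wrapper could look like (hypothetical type; assumes unthunk from ChainRulesCore, which falls back to the identity on non-thunk values):

using ChainRulesCore: unthunk

struct Gradient{T<:Tuple}
    partials::T   # each entry may be a thunk
end
# unthunk lazily, only for the entries that are actually accessed
Base.getindex(g::Gradient, i::Integer) = unthunk(g.partials[i])
Base.length(g::Gradient) = length(g.partials)
Base.iterate(g::Gradient, i::Integer = 1) = i > length(g) ? nothing : (g[i], i + 1)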

@Roger-luo
Contributor Author

The main advantage of the current API is that you can write things like gradient(sin, 1) and it behaves intuitively. In almost all common cases the gradient of the function is going to be nothing, and I don't really buy that users need to understand that the function is an argument to itself (which is something of a Julia-specific quirk).

Yeah, in the context of "function"s it kinda makes sense. I'll close this issue then.
