Rules for mutating functions, @adjoint! and its documentation #1228

Closed
maximilian-gelbrecht opened this issue May 17, 2022 · 11 comments

@maximilian-gelbrecht

maximilian-gelbrecht commented May 17, 2022

(This is related to https://discourse.julialang.org/t/zygote-jl-adjoint-mutating-inplace-adjoints/78241)

Inspecting the Zygote code, I can see that aside from @adjoint there is also @adjoint!, which is used to declare the adjoints of some mutating functions (like push! etc.). I can’t find any docstrings or documentation on when and how this can be used. I suspect there are some limitations, as Zygote generally forbids mutation. Additionally, ChainRules says in its documentation (https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/which_functions_need_rules.html#Functions-which-mutate-arrays):

Rules for functions which mutate its arguments, e.g. sort!, should not be written at the moment. While technically they are supported, they would break Zygote.jl such that it would sometimes quietly return the wrong answer. This may be resolved in the future by allowing AD systems to opt-in or opt-out of certain types of rules.

It then goes on to demonstrate how to write such a rule for Zygote nonetheless, using the example of a function that adds to the input array in place. This seems to work (and is probably translated to an @adjoint! rule?).

(just copied from the ChainRules doc)

using ChainRules, Zygote
function addone!(array)
    array .+= 1
    return sum(array)
end

function ChainRules.rrule(::typeof(addone!), a)
    y = addone!(a)
    function addone!_pullback(ȳ)
        return NoTangent(), ones(length(a))
    end
    return y, addone!_pullback
end
julia> gradient(addone!, a)
([1.0, 1.0, 1.0],)

So what are the requirements for these rules, defined via ChainRules or @adjoint!, to work? In the linked issue at ChainRulesCore, they also can't name these requirements exactly. At the very least there should be some documentation on that and even better Zygote should return some kind of warning if they are prone to fail.

@ToucheSir
Member

To elaborate on my point in the topic, this is the totality of what @adjoint! does, translated to ChainRules syntax:

function ChainRules.rrule(::typeof(addone!), a)
    y = addone!(a)

    # @adjoint! adds:
    addone!_pullback(::NoTangent) = (NoTangent(), NoTangent())
    # note that in ZygoteRules parlance this would be:
    # addone!_pullback(::Nothing) = nothing

    function addone!_pullback(ȳ)
        return NoTangent(), ones(length(a))
    end
    return y, addone!_pullback
end

This is because the return values of mutating functions are often unused, and thus AD may pass in a null gradient to the pullback:

x = [...]
__unused__ = push!(x)
return sum(x)

So you can see that @adjoint! is not required to define rules like this. It is purely a convenience provided by ZygoteRules. Hence why I've put out a call for anyone interested to help with rewriting all our existing ones as rrules: #1209.

So what are the requirements for these rules, defined via ChainRules or @adjoint!, to work? In the linked issue at ChainRulesCore, they also can't name these requirements exactly.

The requirement is that the AD support mutation sufficiently well. How that looks in practice is a little fuzzier, since AFAIK we don't have a ChainRules-compatible source-to-source AD which does support it. Note that Zygote does support a limited amount of mutation already, but not of array types. setfield! is tracked without too many problems, as are mutating ops on internal structures that no user can tamper with and constrained types like Zygote.Buffer. Which I think brings us to your next point...

At the very least there should be some documentation on that and even better Zygote should return some kind of warning if they are prone to fail.

In general, it is not safe and Zygote does try to catch it. This is why you get those array mutation errors when using it. However, Zygote's compiler (i.e. the AD transform) has very little information to work with when generating pullback code and thus has to punt the responsibility for checking for mutation to the runtime.

I do agree that there should be better docs for this somewhere outside of the mention in the Buffer docs.

@maximilian-gelbrecht
Author

maximilian-gelbrecht commented May 19, 2022

Sorry, I may have to phrase the issue/question differently: my primary interest is not exactly @adjoint! but the capabilities of Zygote for routines that mutate arrays in place. You mention that Zygote has a limited capacity for this, but what exactly are the limits? The example from the ChainRules documentation, which mutates an array, does work:

using ChainRulesCore, Zygote
function addone!(array)
    array .+= 1
    return sum(array)
end

function ChainRulesCore.rrule(::typeof(addone!), a)
    y = addone!(a)
    function addone!_pullback(ȳ)
        return NoTangent(), ones(length(a))
    end
    return y, addone!_pullback
end

a = [3.3,2.1,2.3]
gradient(addone!, a)

So, which functions like this work and which don't? It just seems very unclear to me.

@mcabbott
Member

mcabbott commented May 19, 2022

As written this one gives wrong answers in most cases:

julia> Zygote.gradient([1,2,3]) do x
         addone!(x)^2
       end
ȳ = 18
([1.0, 1.0, 1.0],)

julia> ForwardDiff.gradient([1,2,3]) do x
         addone!(x)^2
       end
3-element Vector{Int64}:
 18
 18
 18

That's easy to correct, fill(ȳ, size(a)), but mutating functions in general are tricky. See examples here for fill!: JuliaDiff/ChainRules.jl#521 (comment) . This addone! doesn't seem to cause problems plugged into those examples, perhaps because the change in x is just addition, so you can blindly accumulate gradients from before & after, so perhaps it's pathologically simple. It also doesn't return the array it mutates, which is the typical pattern.
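
To make the `fill(ȳ, size(a))` correction concrete, here is a minimal sketch in plain Julia that applies the chain rule by hand instead of going through Zygote. The helper `addone_with_pullback!` and the variable `dx` are hypothetical names for this illustration; the function only mimics the value-plus-pullback shape of an `rrule`.

```julia
# Same forward computation as in the thread: mutate in place, then reduce.
function addone!(array)
    array .+= 1
    return sum(array)
end

# Hand-written stand-in for an rrule: returns the value and a pullback.
function addone_with_pullback!(a)
    y = addone!(a)
    # d(sum(a .+ 1))/da[i] == 1, so every entry receives ȳ rather than a constant 1.
    pullback(ȳ) = fill(ȳ, size(a))
    return y, pullback
end

x = [1.0, 2.0, 3.0]
y, back = addone_with_pullback!(x)   # y == 9.0, x is now [2.0, 3.0, 4.0]
dx = back(2y)                        # seed with d(y^2)/dy = 2y = 18.0
```

With the seed `2y`, `dx` comes out as `[18.0, 18.0, 18.0]`, which agrees with the ForwardDiff result above, whereas `ones(length(a))` ignores the seed entirely.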

Edit: here's an example where mutation causes problems:

function ChainRulesCore.rrule(::typeof(addone!), a)
    y = addone!(a)
    addone!_pullback(ȳ) = NoTangent(), fill(@show(ȳ), size(a))
    y, addone!_pullback
end

using LinearAlgebra  # provides dot

gradient([1,2,3]) do x
    y = dot(x, x)  # pullback for dot closes over x
    z = addone!(x)  # pullback for addone! is not run, unless you uncomment +z:
    x[1] + y  # + z 
end

@ToucheSir
Member

I got the intent of the question from the start :). I would recommend re-reading the second half of my comment above.

@maximilian-gelbrecht
Author

So, do I get this right: in general it is not recommended to write these rules for mutating functions. It is possible in some cases, but it's not possible to say in which cases it is and in which it isn't a priori?

@mcabbott
Member

mcabbott commented May 19, 2022

Yes, it's not safe in general to define such rules. I think any function which mutates an input array can give you wrong answers.

not possible to say in which cases it is and in which it isn't a priori?

It's certainly repeatable, thus is known from the code. But correctness isn't a property of the rule alone. The dangers are that (1) some other rule may depend on x still having the value it did before, (2) if the return value of the function is not used, then Zygote may not run the pullback function, and (3) accumulation of contributions to x's gradient from before and after f! may not be correct, or may fail completely.

If you know that x is always something just made with similar etc, or is just some cache to save allocations, then it can be fine to define a rule for your own use. For library use, maybe a smarter Zygote could solve (2), (3). Solving (1) would I think need lots of copies, and ideally only when the array is in fact (going to be) mutated.

@maximilian-gelbrecht
Author

maximilian-gelbrecht commented May 19, 2022

Thanks a lot for the answers.
I would actually be interested exactly in the case you are mentioning:

or is just some cache to save allocations, then it can be fine to define a rule for your own use.

A typical case is preallocating memory for a function that saves its results in this array, e.g. a Fourier transform, where one would define a rule for in-place plans.

But I already worry enough about correctness that I am not sure I want to go down this route. Maybe I'll do some tests, though.

@ToucheSir
Member

ToucheSir commented May 19, 2022

It also doesn't return the array it mutates, which is the typical pattern.

Note that this must be done transitively, which means that if you have a nested chain of functions:

f(x, y) = g(x, y) + sum(x)
function g(x, y)
  push!(x, 3)
  return y
end

Both g and f would have to be augmented to return/pass through the return value of push!(x) such that the gradient info is correctly propagated:

function f(x, y) 
  newx, res = g(x, y)
  return sum(newx) + res
end
function g(x, y)
  newx = push!(x, 3)
  return newx, y
end

Note that because Zygote does not have enough information at "compile" time to know which functions transitively call a mutating function (which itself returns the mutated value) and don't thread the return value through to their return value, every function would have to be augmented this way. That could cause performance and possibly even correctness issues.

Another problem is that not all mutating functions return the mutated value! setindex! is the biggest offender here, which means we'd have to rewrite all uses of it and any functions like it to something like x = setindex!(x, v, i). Again this may be technically feasible, but it's unclear whether or not it would have any unintended, possibly breaking consequences.

Now the question becomes: given all these caveats, how do setfield! and Buffer work at all? The key is that Zygote has another mechanism for caching gradients to keep track of non-local mutations to certain data types. Appropriately, this is the cache field of Zygote.Context, which you see referenced as __context__ in @adjoint rules and is made explicit for the ChainRules RuleConfig. The main downside of this mechanism is that all mutating rules must call and rely on Zygote-specific functionality.

Why not enable the mutable value gradient cache for all arrays and not just Buffer? That's a deeper question which I don't have a clear answer to. The docs for Buffer note:

Buffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.

But does not elaborate. #75 was an experimental effort to make array mutation work, but I'm not sure if it ran into any fundamental issues which would've prevented further progress. Perhaps @MikeInnes would be able to provide a historical perspective on this?

@ToucheSir
Member

A typical case is preallocating memory for a function that saves its results in this array, e.g. a Fourier transform, where one would define a rule for in-place plans.

But I already worry enough about correctness that I am not sure I want to go down this route. Maybe I'll do some tests, though.

This can be done, but the preallocation part needs to be hidden in a rule and there are some caveats around usage. See PumasAI/SimpleChains.jl#59 for a bit more discussion.
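
As a rough illustration of what "hidden in a rule" could look like, the sketch below uses plain Julia with a hand-written pullback and no AD package. `apply_with_pullback`, `w`, and `cache` are hypothetical names made up for this example; they do not come from Zygote, ChainRules, or SimpleChains, and a real rule would wrap the same idea in an `rrule`.

```julia
# Computes sum(w .* x), writing the intermediate product into a preallocated
# scratch buffer, and returns a hand-written pullback for x.
function apply_with_pullback(w, cache, x)
    cache .= w .* x            # in-place write; callers never differentiate through this
    y = sum(cache)
    # The pullback reads only w, so overwriting cache later cannot corrupt it.
    pullback(ȳ) = ȳ .* w
    return y, pullback
end

w = [1.0, 2.0, 3.0]
cache = similar(w)             # allocated once, reused across calls
y, back = apply_with_pullback(w, cache, [1.0, 1.0, 1.0])   # y == 6.0
dx = back(2.0)                                              # [2.0, 4.0, 6.0]
```

One usage caveat this makes visible: if the pullback did need the contents of `cache`, a second forward call before running the pullback would overwrite them, so the rule would have to copy what it needs or forbid that call order.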

@maximilian-gelbrecht
Author

maximilian-gelbrecht commented May 20, 2022

Thanks for all the comments. I'll close this issue now, it is clearer to me what can work and what kind of tests I could do.

@MikeInnes
Member

The simple answer to the original question is: you're free to use @adjoint!, and it'll work, but only for functions which mutate data structures (like Dict), and not arrays. (Which is probably not what you wanted or need.)

The more technical answer is that mutation works for anything that isn't captured by value in a pullback, which in practice means AbstractArray. For example, the gradient of c = a*b is ā = c̄*b; but if b is modified by the forward pass, the backwards pass will see b′ != b and calculate a meaningless ā. Simple examples won't reveal this bug, but it's a nasty silent error in larger programs. Despite it being array-like, we are able to make Buffer an exception simply by decreeing that it won't be captured by value (and explicitly managing its reference-like gradient), which avoids this issue.
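
The stale-capture problem can be reproduced without any AD package. The following is a minimal sketch using an elementwise product and a hand-written pullback; `mul_with_pullback`, `da_before`, and `da_after` are hypothetical names for this illustration.

```julia
# Forward pass for c = a .* b, plus a pullback that closes over a and b.
function mul_with_pullback(a, b)
    c = a .* b
    # The closure holds references, so it reads a and b when called, not when created.
    pullback(dc) = (dc .* b, dc .* a)
    return c, pullback
end

a = [1.0, 2.0]
b = [3.0, 4.0]
c, back = mul_with_pullback(a, b)

da_before, _ = back([1.0, 1.0])   # [3.0, 4.0], the correct gradient for a
b .= 0.0                          # a later in-place mutation in the forward pass
da_after, _ = back([1.0, 1.0])    # [0.0, 0.0], silently wrong
```

No error is raised at any point; the second call simply returns a gradient computed from the mutated `b`.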


Ok, here's more detail on why this stuff is hard to fix. Buckle up.

Supporting array mutation is indeed hard, but not for the obvious reasons. Arrays introduce two kinks that are actually pretty easy to deal with:

Firstly, the above issue with values being captured. But the structure of AD gives a surprisingly easy solution: just undo all the mutations in the backwards pass. For example you'll notice that setindex! preserves the overwritten data and restores it. There are cases where this could be surprising, but as long as you ask users to provide a referentially-transparent f to gradient(f, x) it's fine.

Secondly, mutation introduces non-local data flow. If thread A modifies an array and thread B uses it, information has teleported from A to B, and the gradient needs to teleport back during the backwards pass. For this reason gradients of reference types are themselves a unique reference, stored in a global Context object. Pullbacks for mutable types like Buffer ignore the input and produce nothing as an output, essentially pretending to be non-differentiable, but store gradient information in that global side-channel, then produce normal gradients for other values that interact with the Buffer.
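
The first point above, undoing mutations in the backwards pass, can be sketched in plain Julia as follows. This is only an illustration of the idea under simplifying assumptions (a single scalar index, a hand-written pullback), not Zygote's actual adjoint; `setindex_with_restore!` is a hypothetical name.

```julia
# Writes v into xs[i], remembering the overwritten value so the pullback can undo it.
function setindex_with_restore!(xs, v, i)
    old = xs[i]
    xs[i] = v
    function pullback(dxs)
        xs[i] = old            # restore the data that earlier pullbacks captured
        dv = dxs[i]            # gradient that flows to the written value
        dxs[i] = zero(dv)      # the overwritten slot no longer contributes
        return dv
    end
    return xs, pullback
end

xs = [1.0, 2.0, 3.0]
_, back = setindex_with_restore!(xs, 10.0, 2)
after_write = copy(xs)         # [1.0, 10.0, 3.0]

dxs = [1.0, 1.0, 1.0]          # incoming gradient, e.g. from sum(xs)
dv = back(dxs)                 # 1.0; xs is [1.0, 2.0, 3.0] again, dxs is [1.0, 0.0, 1.0]
```

Because pullbacks run in reverse order, any earlier pullback that closed over `xs` sees the original contents again by the time it executes.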

You'll notice that Zygote claims the gradient of a Buffer is nothing:

xs = Zygote.Buffer([1])
xs[1] = 2
Zygote.gradient(xs -> xs[1]^2, xs) # => (nothing,)

The gradient isn't really nothing, it just isn't propagated as a regular value, but hidden in the Context. We could say Buffer has a "reference-like gradient" instead of a "value-like gradient".
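
A toy version of such a side channel, assuming an `IdDict` as a stand-in for the cache in Zygote's Context, might look like the sketch below. The names `grad_cache`, `grad_for!`, and `getindex_with_pullback` are hypothetical and chosen only for this illustration.

```julia
# Stand-in for the gradient cache: keyed by object identity, not by value.
grad_cache = IdDict{Any,Any}()

# Look up (or lazily create) the accumulated gradient for a mutable array.
grad_for!(cache, xs) = get!(() -> zeros(length(xs)), cache, xs)

function getindex_with_pullback(cache, xs, i)
    y = xs[i]
    function pullback(ȳ)
        grad_for!(cache, xs)[i] += ȳ   # accumulate in the side channel
        return nothing                  # nothing is propagated as a regular value
    end
    return y, pullback
end

xs = [3.0, 4.0]
y, back = getindex_with_pullback(grad_cache, xs, 1)
result = back(2y)          # pullback of y^2 seeded with 1, so ȳ = 2y = 6.0
hidden = grad_cache[xs]    # [6.0, 0.0]
```

The pullback returns `nothing`, matching what `gradient` reports for a Buffer, while the actual gradient sits in the cache keyed by the array's identity.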

(Dicts actually propagate the ref-like gradient as a value, taking care not to actually accumulate them as normal gradients. But this isn't functional: it's solely to avoid giving users nothing as above. [Alternatively, gradient could look up gradients for reference inputs, but this wouldn't work well with nested data structures. Although it could handle aliased inputs better...])

This is all perfectly workable, but you'll notice that the adjoints for Buffer interact directly with the gradient cache – they are explicitly written for a reference-like gradient. But we don't want to complicate a simple adjoint like a*b, and even if we did, it wouldn't support the value-like case (eg StaticArrays), so you'd have to write it twice. Abstracting over mutability is where this gets hairy.

In #75, Zygote does a bunch of magic to make value-like adjoints do the right thing for references, eg adjoints that produce a mutable need to retrieve (and clear) the accumulated gradient, which then is passed to the user's pullback. But adjoints might receive or produce arrays wrapped in things like Transpose, or other arbitrary data structures, which Zygote can't look into and fix up; boom, silent bug, or perhaps an obscure, undebuggable error miles from the source of the problem. It's not impossible to get right, but it's a minefield for users (adjoint authors). And that's before you get to performance issues with the extra indirection.

I can imagine ways to clean this situation up, eg by formalising this fixup operation, but not without ugliness. For my part, I'm going to go watch some Rich Hickey talks instead :)

(The other downside to having reference-like gradients is that it's more type restrictive: we have to fix the type of the gradient ahead of time for performance, as we do with Buffer. This wouldn't have been a huge deal without the other stuff though.)
