frule API #128
Comments
One other case worth considering is the `frule` for …

Another nice thing about option 2 is that you can split out the pushforward itself very easily, e.g.

```julia
function frule(ẋ, ::typeof(f), x...)
    push = (ẋ...) -> # something
    f(x...), push(tail(ẋ)...)  # Base.tail drops the tangent w.r.t. f itself
end
```

Not all `frule`s can be expressed this way, but most can, and it makes it really convenient to deal with varargs, types, dispatch on the perturbations, etc. Zygote's version of …:

```julia
@frule +(x...) = +(x...), +
@frule tail(x) = tail(x), tail
```

Example with dispatch (possibly very useful once we have …):

```julia
@frule function f(x...)
    push(ẋ::Real...) = sum(ẋ)
    push(ẋ::Int...) = sum(ẋ)
    return f(x...), push
end
```

There should be no runtime overhead from this (indeed, we could force-inline in the macro to avoid that in most cases, if needed).
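A concrete, runnable instance of the template above, as a minimal sketch: it defines a standalone `frule` (not ChainRulesCore's) for a hypothetical varargs function `mysum`, follows the option 2 layout in which the tangent tuple `ẋ` comes first with the self-tangent in slot one, and gives the local pushforward two methods so that it dispatches on the tangent types. Treating `nothing` as a "no perturbation" tangent is an assumption made only for illustration.

```julia
using Base: tail

mysum(x...) = +(x...)   # hypothetical varargs function

# Hypothetical standalone rule in the option 2 style: ẋ holds the tangent
# w.r.t. mysum itself, followed by one tangent per argument.
function frule(ẋ::Tuple, ::typeof(mysum), x...)
    # Local pushforward with two methods, so Julia dispatches on tangent types.
    push(t::Real...) = +(t...)        # the pushforward of a sum is the sum of tangents
    push(t::Nothing...) = nothing     # assumed "no perturbation" marker
    return mysum(x...), push(tail(ẋ)...)   # tail drops the self-tangent
end

y, ẏ = frule((nothing, 1.0, 2.0, 3.0), mysum, 1.0, 2.0, 3.0)
```

Because `push` is an ordinary local function, supporting another tangent type is one more method definition rather than a change to the rule's signature.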
I ended up using this because it was getting painful to write rules:

```julia
# Splits `args = (x..., Δself, Δx...)` into `((x...,), Δself, (Δx...,))`.
@inline @generated function split_args(args...)
    n = (length(args) - 1) ÷ 2  # one extra for the function!
    primals = ntuple(i -> :(args[$i]), n)
    partials = ntuple(i -> :(args[$(n + i + 1)]), n)
    return :(($(primals...),), args[$(n + 1)], ($(partials...),))
end
```
Is there a viable alternative to option 2 that means we can avoid so many … `frule(df, (dx, dys...), f::typeof(f), x, ys...)`
I think your proposal is a reasonable one, @nickrobinson251. I'll try it out on #129; I imagine it will make things a little cleaner.
I definitely agree with the goal. The actual syntax: …
Perhaps we can write a macro that makes that work?
For the syntax, I meant …

Yes, that is a problem.
Yeah, I agree conceptually it's not special. I'm just raising the question of whether avoiding so many … I don't feel strongly either way; it's just another option worth considering.
I really think we should just leave the nicer syntax to the macro. If you're doing things like generating rules, then the … The idea of making it as easy as possible to just overload …
@MikeInnes, do you have any thoughts about what …

```julia
f(args...; a=5.0, b=4.0) = ...
frule((df, dargs..., (a=da, b=db)), f, args...; a=5.0, b=4.0)
```

i.e. just require that the final element of the argument tangent vector be an appropriate …
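A minimal sketch of the keyword convention proposed above, assuming the option 2 layout and assuming the keyword tangents arrive as a `NamedTuple` in the last slot of the tangent tuple. The function `g` and this standalone `frule` are hypothetical and are not ChainRulesCore's API.

```julia
using Base: front, tail

# Hypothetical function with keyword arguments, shaped like the example above.
g(x...; a = 5.0, b = 4.0) = a * sum(x) + b

# Hypothetical rule under the proposed convention:
# ẋ = (tangent w.r.t. g, tangents w.r.t. x..., NamedTuple of keyword tangents).
function frule(ẋ::Tuple, ::typeof(g), x...; a = 5.0, b = 4.0)
    kw = last(ẋ)              # e.g. (a = ȧ, b = ḃ)
    dx = front(tail(ẋ))       # drop the self-tangent and the keyword tangents
    y = g(x...; a = a, b = b)
    ẏ = kw.a * sum(x) + a * sum(dx) + kw.b   # product rule for a * sum(x) + b
    return y, ẏ
end

y, ẏ = frule((nothing, 1.0, 0.0, (a = 0.0, b = 1.0)), g, 1.0, 2.0)
```

The positional tangents still line up one-to-one with `x...`, so varargs stay easy; the keyword tangents simply ride along at the end.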
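The internal …:

```julia
julia> f() = foo(a, b; c = 1)
f (generic function with 1 method)

julia> @code_lowered f()
CodeInfo(
1 ─ %1 = (:c,)
│   %2 = Core.apply_type(Core.NamedTuple, %1)
│   %3 = Core.tuple(1)
│   %4 = (%2)(%3)
│   %5 = Core.kwfunc(Main.foo)
│   %6 = (%5)(%4, Main.foo, Main.a, Main.b)
└──      return %6
)
```

We actually call the non-kwargs function … That is, the keywords are packed into a `NamedTuple` (`%4`), and `Core.kwfunc(foo)` is called with that `NamedTuple` as its first positional argument, followed by `foo` and the positional arguments. We unfortunately can't change that syntax, but we can do what Zygote does here: have …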
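The lowering above shows that a keyword call becomes a positional call with a `NamedTuple` in front. A minimal sketch of writing a rule against that positional form, assuming the option 2 layout: `kwcall_sketch` is a hypothetical stand-in for `Core.kwfunc(foo)`, and this is neither Zygote's actual mechanism nor ChainRulesCore's `frule`.

```julia
foo(a, b; c = 1) = a * b + c

# Hypothetical helper mirroring the lowered call: keywords travel as a
# NamedTuple in the first positional slot, ahead of the function and its args.
kwcall_sketch(kws::NamedTuple, f, args...) = f(args...; kws...)

# Hypothetical rule written against that positional form.
# Tangent tuple layout: (self, keyword NamedTuple, f, a, b).
function frule(Δ::Tuple, ::typeof(kwcall_sketch), kws::NamedTuple, ::typeof(foo), a, b)
    _, dkws, _, da, db = Δ
    y = kwcall_sketch(kws, foo, a, b)
    dy = da * b + a * db + dkws.c      # derivative of a * b + c
    return y, dy
end

y, dy = frule((nothing, (c = 1.0,), nothing, 1.0, 0.0), kwcall_sketch, (c = 3.0,), foo, 2.0, 5.0)
```

Since the keywords are now just another positional argument, their tangent slots in naturally without any special casing in the signature.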
I think we can close this now?
Current Situation
The current `frule` API is as follows: …

The problems start when we encounter varargs. Consider …
To write an `frule` for this, the following (up to a probable off-by-one error) is required: …

There might be some slightly cleaner syntax to achieve the same result, but the long and the short of it is that we find ourselves in the unpleasant situation of having to do arithmetic to parse the arguments to our functions 🙁.
To be clear, no one wanted this behaviour; it just slipped through the cracks when we were making some API changes, because we intentionally don't have a huge number of `frule`s implemented at the minute. (We should perhaps add some more `frule`s to help us avoid future slip-ups of this kind.)

Proposal
I can see a couple of choices here:
Option 1: Per the API used in ChainRulesTestUtils.jl (and one of the original proposals in #74)
Binary function
Varargs function:
Type Constraints
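A hypothetical sketch of option 1, assuming each argument is passed as a `(primal, tangent)` pair in the style of ChainRulesTestUtils.jl's test helpers. `frule_opt1`, `mymul` and `mysum` are illustrative names; where option 1 would put the tangent with respect to the function itself is not recoverable here, so the sketch leaves it out.

```julia
mymul(x, y) = x * y
mysum(x...) = +(x...)

# Binary function: the type constraint has to be spelled on each pair.
function frule_opt1(::typeof(mymul), xdx::Tuple{Real, Real}, ydy::Tuple{Real, Real})
    (x, ẋ), (y, ẏ) = xdx, ydy
    return x * y, ẋ * y + x * ẏ
end

# Varargs function: no index arithmetic, every argument is self-describing.
function frule_opt1(::typeof(mysum), xdxs::Tuple{Real, Real}...)
    x = map(first, xdxs)   # primals
    ẋ = map(last, xdxs)    # tangents
    return mysum(x...), +(ẋ...)
end
```

The varargs case needs no counting, but every constraint lives on a `Tuple{T1, T2}` pair.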
Option 2: Differentials before other args
This is directly ripped off from Mike's ForwardDiff-in-Zygote PR.
Binary function
Varargs function:
Type Constraints
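A matching hypothetical sketch of option 2, assuming the tangents arrive as a single tuple whose first entry is the tangent with respect to the function itself (the layout used in the `frule(ẋ, ::typeof(f), x...)` comment). This standalone `frule` is not ChainRulesCore's, and `mymul` and `mysum` are illustrative.

```julia
using Base: tail

mymul(x, y) = x * y
mysum(x...) = +(x...)

# Binary function: type constraints go directly on the primal arguments.
function frule(Δ::Tuple, ::typeof(mymul), x::Real, y::Real)
    _, Δx, Δy = Δ                    # first entry is the tangent w.r.t. mymul
    return x * y, Δx * y + x * Δy
end

# Varargs function: the primals slurp naturally because nothing follows them.
function frule(Δ::Tuple, ::typeof(mysum), x::Real...)
    return mysum(x...), +(tail(Δ)...)
end
```

Here constraints read like an ordinary Julia signature, and the varargs method needs no counting.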
My thoughts
While both options resolve the varargs issue, I'm leaning towards option 2, simply because imposing type constraints in option 1 is a bit annoying (lots of `::Tuple{T1, T2}`-like things).

Note that type constraints are actually most straightforward under the current API. I don't believe that this comparatively small win (relative to option 2) is worth it, given the ugliness it introduces for varargs.
If anyone has …, please make them known :) (@oxinabox @nickrobinson251 @YingboMa @shashi @MikeInnes)