frule API #128

Closed
willtebbutt opened this issue Feb 12, 2020 · 12 comments
Labels
design: Requires some design before changes are made
forward-mode: Related to use of ChainRules for ForwardMode AD

Comments

@willtebbutt
Member

willtebbutt commented Feb 12, 2020

Current Situation

The current frule API is as follows:

function frule(f, x, y, df, dx, dy)
    z = some_computation
    dz = some_other_computation
    return z, dz
end

The problems start when we encounter varargs. Consider

function f(x...)
    return some_computation
end

To write an frule for this, the following (up to a probable off-by-one error) is required:

function frule(f, args...)
    n = length(args) + 1
    x = args[1:div(n, 2) - 1]
    df = args[div(n, 2)]
    dx = args[div(n, 2)+1:end]
    z = some_computation
    dz = some_other_computation
    return z, dz
end

There might be some slightly cleaner syntax to achieve the same result, but the long and short of it is that we find ourselves in the unpleasant situation of having to do arithmetic to parse the arguments to our functions 🙁.

To be clear, no one wanted this behaviour. It slipped through the cracks while we were making some API changes, because we intentionally don't have a huge number of frules implemented at the minute. (We should perhaps add some more frules to avoid future slip-ups of this kind.)
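For concreteness, here is a minimal runnable sketch of the index arithmetic described above. The names `mysum` and `my_frule` are made up for illustration (the latter stands in for `frule` so that nothing here pretends to be the real ChainRulesCore function), and the argument order `(f, primals..., df, tangents...)` is assumed from the description of the current API.

```julia
# Hypothetical varargs function, used only for illustration.
mysum(x...) = sum(x)

# Stand-in for `frule` under the current calling convention:
# (f, primals..., df, tangents...).
function my_frule(::typeof(mysum), args...)
    n = div(length(args) - 1, 2)   # number of primal arguments
    x = args[1:n]                  # primals
    df = args[n + 1]               # tangent of the function object (unused here)
    dx = args[n + 2:end]           # tangents of the primals
    return mysum(x...), sum(dx)
end

# Three primals, then the function tangent, then three tangents.
z, dz = my_frule(mysum, 1.0, 2.0, 3.0, nothing, 0.1, 0.2, 0.3)
```

Every varargs rule written this way has to repeat the same slicing, which is the pain point the proposals below try to remove.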

Proposal

I can see a couple of choices here:

Option 1: Per the API used in ChainRulesTestUtils.jl

(and one of the original proposals in #74 )

Binary function

f(x, y) = some_computation

function frule((f, df)::Tuple{typeof(f), Any}, (x, dx), (y, dy))
    z = some_computation
    dz = some_other_computation
    return z, dz
end

Varargs function:

f(x...) = some_computation

function frule((f, df)::Tuple{typeof(f), Any}, args...)
    x = first.(args)
    dx = last.(args)
    z = some_computation
    dz = some_other_computation
    return z, dz
end

Type Constraints

f(x::Real) = some_computation

function frule((f, df)::Tuple{typeof(f), Any}, (x, dx)::Tuple{Real, SomeOtherType})
    z = some_computation
    dz = some_other_computation
    return z, dz
end
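A runnable sketch of Option 1, under the assumption that each primal is paired with its tangent in a 2-tuple. `mysum`, `sq` and `my_frule` are hypothetical names chosen for this example only.

```julia
# Hypothetical functions, used only for illustration.
mysum(x...) = sum(x)
sq(x::Real) = x^2

# Varargs rule: each element of `args` is a (primal, tangent) pair.
function my_frule((f, df)::Tuple{typeof(mysum), Any}, args...)
    x = first.(args)    # tuple of primals
    dx = last.(args)    # tuple of tangents
    return f(x...), sum(dx)
end

# Type-constrained rule: the constraint has to be written on the pair.
function my_frule((f, df)::Tuple{typeof(sq), Any}, (x, dx)::Tuple{Real, Real})
    return f(x), 2 * x * dx
end

z1, dz1 = my_frule((mysum, nothing), (1.0, 0.1), (2.0, 0.2))
z2, dz2 = my_frule((sq, nothing), (3.0, 0.5))
```

The varargs case needs no index arithmetic, but every type constraint turns into a `Tuple{...}` annotation.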

Option 2: Differentials before other args

This is directly ripped off from Mike's ForwardDiff-in-Zygote PR.

Binary function

f(x, y) = some_computation

function frule((df, dx, dy), f::typeof(f), x, y)
    z = some_computation
    dz = some_other_computation
    return z, dz
end

Varargs function:

f(x...) = some_computation

# Possibly we need a really careful implementation for this?
tail(x::Tuple) = x[2:end]

function frule(dargs, f::typeof(f), x...)
    df = first(dargs)
    dx = tail(dargs)
    z = some_computation
    dz = some_other_computation
    return z, dz
end

Type Constraints

f(x::Real) = some_computation

function frule((df, dx)::Tuple{Any, Real}, f::typeof(f), x::Real)
    z = some_computation
    dz = some_other_computation
    return z, dz
end
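A runnable sketch of Option 2, with all tangents collected in a leading tuple. `mul`, `mysum` and `my_frule` are hypothetical names for this example, and `Base.tail` is used in place of the hand-written `tail` helper.

```julia
# Hypothetical functions, used only for illustration.
mul(x, y) = x * y
mysum(x...) = sum(x)

# Binary rule: the tangent tuple is destructured in the signature.
function my_frule((df, dx, dy), ::typeof(mul), x, y)
    return x * y, dx * y + x * dy
end

# Varargs rule: drop the function tangent, keep the rest.
function my_frule(dargs, ::typeof(mysum), x...)
    dx = Base.tail(dargs)
    return mysum(x...), sum(dx)
end

z1, dz1 = my_frule((nothing, 1.0, 0.0), mul, 2.0, 3.0)
z2, dz2 = my_frule((nothing, 0.5, 0.5, 0.5), mysum, 1, 2, 3)
```

Here the primal arguments keep their ordinary positions, so type constraints on them look the same as in the original function's signature.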

My thoughts

While both options resolve the varargs issue, I'm leaning towards option 2, simply because imposing type constraints in option 1 is a bit annoying (lots of ::Tuple{T1, T2}-like things).

Note that type constraints are actually most straightforward under the current API. I don't believe that this comparatively small win (relative to option 2) is worth it, given the ugliness it introduces for varargs.

If anyone has

  • any other bright ideas for interfaces, or
  • any other considerations that we should make (reasons to like / dislike the current / proposed APIs)

please make them known :) (@oxinabox @nickrobinson251 @YingboMa @shashi @MikeInnes )

@MikeInnes

One other case worth considering is the frule for Core._apply. I suspect that'd be significantly harder to get right with option 1 (but haven't thought it through deeply).

Another nice thing about option 2 is that you can split out the pushforward itself very easily, e.g.

function frule(ẋ, ::typeof(f), x...)
  push = (ẋ...) -> # something
  f(x...), push(tail(ẋ)...)
end

Not all frules can be expressed this way, but most can, and it makes it really convenient to deal with varargs, types, dispatch on the perturbations etc. Zygote's version of frule abstracts over this so that you can write:

@frule +(x...) = +(x...), +
@frule tail(x) = tail(x), tail

Example with dispatch (possibly very useful once we have Zeros etc. involved):

@frule function f(x...)
  push(ẋ::Real...) = sum(ẋ)
  push(ẋ::Int...) = sum(ẋ)
  return f(x...), push
end

There should be no runtime overhead from this (indeed we could force-inline in the macro to avoid that in most cases, if needed).
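To illustrate the pushforward-splitting idea without the macro, here is a small sketch written by hand. The names `mysum`, `MyZero` and `my_frule` are assumptions made for this example. In particular `MyZero` is only a placeholder for a zero-like tangent type, not an actual ChainRules type.

```julia
# Hypothetical function and zero-like tangent type, for illustration only.
mysum(x...) = sum(x)
struct MyZero end

function my_frule(xdot, ::typeof(mysum), x...)
    # The pushforward is an ordinary local function, so it can dispatch
    # on the types of the perturbations.
    push(dx::Real...) = sum(dx)
    push(::MyZero...) = MyZero()
    return mysum(x...), push(Base.tail(xdot)...)
end

z1, dz1 = my_frule((nothing, 0.1, 0.2), mysum, 1.0, 2.0)
z2, dz2 = my_frule((nothing, MyZero(), MyZero()), mysum, 1.0, 2.0)
```

When all perturbations are `MyZero`, the second method returns a `MyZero` without doing any arithmetic, which is the kind of dispatch the comment above has in mind.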

@shashi
Collaborator

shashi commented Feb 14, 2020

I ended up using this because it was getting painful to write rules.

@inline @generated function split_args(args...)
    n = (length(args)-1) ÷ 2 # one extra for the function!
    primals  = ntuple(i->:(args[$i]), n)
    partials = ntuple(i->:(args[$(n+i+1)]), n)
    :(($(primals...),), args[$(n+1)], ($(partials...),))
end
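A usage sketch for the helper above. The helper is repeated here (without `@inline`) so the block runs on its own. It assumes the helper receives everything after the function itself, i.e. `(primals..., df, tangents...)`, and `mysum` and `my_frule` are hypothetical names.

```julia
# Repeated from the comment above so that this block is self-contained.
@generated function split_args(args...)
    n = (length(args) - 1) ÷ 2   # one extra entry for the function tangent
    primals  = ntuple(i -> :(args[$i]), n)
    partials = ntuple(i -> :(args[$(n + i + 1)]), n)
    return :(($(primals...),), args[$(n + 1)], ($(partials...),))
end

# Hypothetical varargs function and rule, for illustration only.
mysum(x...) = sum(x)

function my_frule(::typeof(mysum), args...)
    x, df, dx = split_args(args...)
    return mysum(x...), sum(dx)
end

parts = split_args(1.0, 2.0, nothing, 0.1, 0.2)
z, dz = my_frule(mysum, 1.0, 2.0, nothing, 0.1, 0.2)
```

The index arithmetic is still there, but it is done once, at compile time, instead of in every rule.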

nickrobinson251 added the forward-mode label Feb 14, 2020
@nickrobinson251
Contributor

Is there a viable alternative to 2 that means we can avoid so many tail calls (without some other cost in its place)? e.g.

frule(df, (dx, dys...), f::typeof(f), x, ys...)

@willtebbutt
Member Author

I think your proposal is a reasonable one @nickrobinson251 . Will try it out on #129 - imagine it will make things a little cleaner.

@oxinabox
Member

oxinabox commented Feb 16, 2020

Is there a viable alternative to 2 that means we can avoid so many tail calls (without some other cost in its place)? e.g.

frule(df, (dx, dys...), f::typeof(f), x, ys...)

I definitely agree with the goal.
However, it is one of the important conceptual pieces of ChainRules that the derivative of the function object (dself) is not fundamentally any different from that of any of the other arguments.
It should therefore stay with them.

The actual syntax
frule(df, (dx, dys...), f::typeof(f), x, ys...) does not currently work (see JuliaLang/julia#34776).
However, if it did, then the optimal syntax from my perspective would be:

frule((df, dx, dys...), f::typeof(f), x, ys...)

Perhaps we can write a macro that makes that work?

@nickrobinson251
Contributor

nickrobinson251 commented Feb 16, 2020

For the syntax I meant frule(df, dargs, f, args...)

@oxinabox
Member

For the syntax I meant frule(df, dargs, f, args...)

Yes, that is a problem.
Consider foo(x, y);
then that would be frule(dself, (dx, dy), foo, x, y),
which marks dself as somehow special,
which it isn't.

@nickrobinson251
Contributor

nickrobinson251 commented Feb 16, 2020

Yeah I agree conceptually it's not special.

I'm just raising the question of whether avoiding so many tail calls makes implementing rules much nicer in practice, and is thereby worth having df as its own separate input.

Don't feel strongly either way. Just another option worth considering.

@MikeInnes

I really think we should just leave the nicer syntax to the macro. If you're doing things like generating rules, then the frule overload makes sense, but then it's easier to deal with this stuff programmatically if there's no special treatment built-in.

The idea of making it as easy as possible to just overload frule without a macro is a noble one, but I just don't think it's viable once we start dealing with things like kwargs functions. If people think it's actually viable it'd be great to understand how that's going to work.

@willtebbutt
Member Author

@MikeInnes do you have any thoughts about what frules for functions with kwargs should look like? Given our new frule design, I was thinking of something along the lines of

f(args...; a=5.0, b=4.0) = ...
frule((df, dargs..., (a=da, b=db)), f, args...; a=5.0, b=4.0)

i.e. just require that the final element of the argument tangents be an appropriate NamedTuple.
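A runnable sketch of what that proposal could look like, with the tangent tuple unpacked by hand. `scaled_sum` and `my_frule` are hypothetical names for this example, and the layout `(df, dargs..., (a=da, b=db))` is taken from the snippet above rather than from any released API.

```julia
# Hypothetical keyword-argument function, for illustration only.
scaled_sum(args...; a=5.0, b=4.0) = a * sum(args) + b

function my_frule(dall::Tuple, ::typeof(scaled_sum), args...; a=5.0, b=4.0)
    dargs = dall[2:end-1]   # tangents of the positional arguments
    dkw = last(dall)        # NamedTuple of keyword tangents, (a=da, b=db)
    z = scaled_sum(args...; a=a, b=b)
    dz = dkw.a * sum(args) + a * sum(dargs) + dkw.b
    return z, dz
end

z, dz = my_frule((nothing, 1.0, 1.0, (a=0.0, b=1.0)), scaled_sum, 1.0, 2.0; a=2.0, b=3.0)
```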

@MikeInnes

The internal frule API for kwargs is effectively fixed by Julia's internal representation of them, since it eliminates them early on in lowering:

julia> f() = foo(a, b; c = 1)
f (generic function with 1 method)

julia> @code_lowered f()
CodeInfo(
1 ─ %1 = (:c,)
│   %2 = Core.apply_type(Core.NamedTuple, %1)
│   %3 = Core.tuple(1)
│   %4 = (%2)(%3)
│   %5 = Core.kwfunc(Main.foo)
│   %6 = (%5)(%4, Main.foo, Main.a, Main.b)
└──      return %6
)

We actually call the non-kwargs function Core.kwfunc(foo)((c=1,), foo, a, b), so to overload that you'd have to define frule((dkwfoo, (c=dc,), dfoo, da, db), ::typeof(Core.kwfunc(foo)), kwargs, foo, a, b).

We unfortunately can't change that syntax, but we can do what Zygote does here: have @frule define both the foo and kwfunc(foo) overloads and automatically forward keyword arguments to the user-defined pushforward in a nicer way.

nickrobinson251 added the design label Aug 4, 2020
@oxinabox
Member

oxinabox commented Nov 9, 2020

I think we can close this now?
Julia is getting splatting in argument destructuring, so the current API works great.
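A sketch of what slurping in argument destructuring allows, assuming a Julia version that supports it (1.6 or later, to the best of my knowledge). `mysum` and `my_frule` are hypothetical names. The signature mirrors the `frule((df, dx, dys...), f::typeof(f), x, ys...)` form discussed earlier in the thread.

```julia
# Hypothetical varargs function, for illustration only.
mysum(x...) = sum(x)

# The trailing `dys...` inside the destructured tuple collects the
# remaining tangents, so no `tail` call or index arithmetic is needed.
function my_frule((df, dx, dys...), ::typeof(mysum), x, ys...)
    return mysum(x, ys...), sum((dx, dys...))
end

z, dz = my_frule((nothing, 0.1, 0.2, 0.3), mysum, 1.0, 2.0, 3.0)
```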
