-
Notifications
You must be signed in to change notification settings - Fork 64
Improve docs for structs: constructors and functors #366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,33 @@ | ||
| # On writing good `rrule` / `frule` methods | ||
|
|
||
| ## Code Style | ||
|
|
||
| Use named local functions for the `pullback` in an `rrule`. | ||
|
|
||
| ```julia | ||
| # good: | ||
| function rrule(::typeof(foo), x) | ||
| Y = foo(x) | ||
| function foo_pullback(Ȳ) | ||
| return NoTangent(), bar(Ȳ) | ||
| end | ||
| return Y, foo_pullback | ||
| end | ||
| #== output | ||
| julia> rrule(foo, 2) | ||
| (4, var"#foo_pullback#11"()) | ||
| ==# | ||
|
|
||
| # bad: | ||
| function rrule(::typeof(foo), x) | ||
| return foo(x), x̄ -> (NoTangent(), bar(x̄)) | ||
| end | ||
| #== output: | ||
| julia> rrule(foo, 2) | ||
| (4, var"##9#10"()) | ||
| ==# | ||
| ``` | ||
|
|
||
| ## Use `ZeroTangent()` as the return value | ||
|
|
||
| The `ZeroTangent()` object exists as an alternative to directly returning `0` or `zeros(n)`. | ||
|
|
@@ -35,6 +63,84 @@ Examples being: | |
| - There is only one derivative being returned, so from the fact that the user called | ||
| `frule`/`rrule` they clearly will want to use that one. | ||
|
|
||
| ## Structs: constructors and functors | ||
|
|
||
| To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`. | ||
| For example, the `rrule` signature would be like: | ||
|
|
||
| ```julia | ||
| function rrule(::typeof(foo), args...; kwargs...) | ||
| ... | ||
| return y, foo_pullback | ||
| end | ||
| ``` | ||
|
|
||
| For a struct `Bar`, | ||
| ```julia | ||
| struct Bar | ||
| a::Float64 | ||
| end | ||
|
|
||
| (bar::Bar)(x, y) = return bar.a + x + y # functor (i.e. callable object, overloading the call action) | ||
| ``` | ||
| we can define an `frule`/`rrule` for the `Bar` constructor(s), as well as any `Bar` [functors](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects). | ||
|
|
||
| ### Constructors | ||
|
|
||
| To define an `rrule` for a constructor for a _type_ `Bar` we need to be careful to dispatch only on `Type{Bar}`. | ||
| For example, the `rrule` signature for a `Bar` constructor would be like: | ||
| ```julia | ||
| function ChainRulesCore.rrule(::Type{Bar}, a) | ||
| Bar_pullback(Δbar) = NoTangent(), Δbar.a | ||
| return Bar(a), Bar_pullback | ||
| end | ||
| ``` | ||
|
|
||
| Use `Type{<:Bar}` (with the `<:`) for non-concrete types, such that the `rrule` is defined for all subtypes. | ||
| In particular, be careful not to use `typeof(Bar)` here. | ||
| Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors. | ||
|
|
||
| You can check which to use with `Core.Typeof`: | ||
|
|
||
| ```julia | ||
| julia> function foo end | ||
| foo (generic function with 0 methods) | ||
|
|
||
| julia> typeof(foo) | ||
| typeof(foo) | ||
|
|
||
| julia> Core.Typeof(foob) | ||
| typeof(foo) | ||
|
|
||
| julia> typeof(Bar) | ||
| DataType | ||
|
|
||
| julia> Core.Typeof(Bar) | ||
| Type{Bar} | ||
|
|
||
| julia> abstract type AbstractT end | ||
|
|
||
| julia> typeof(AbstractT) | ||
| DataType | ||
|
|
||
| julia> Core.Typeof(AbstractT) | ||
| Type{AbstractT} | ||
| ``` | ||
|
|
||
| ### Functors (callable objects) | ||
|
|
||
| In contrast to defining a rule for a constructor, it is possible to define rules for calling an instance of an object. | ||
| In that case, use `bar::Bar`, i.e. | ||
|
|
||
| ```julia | ||
| function ChainRulesCore.rrule(bar::Bar, x, y) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @oxinabox just double checking this is right?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is right. Still even with them together this is a clear enhancement to the docs, so I am happy to have it merged and then to a follow up if you like |
||
| # Notice the first return is not `NoTangent()` | ||
| Bar_pullback(Δy) = Tangent{Bar}(;a=Δy), Δy, Δy | ||
| return bar(x, y), Bar_pullback | ||
| end | ||
| ``` | ||
| to define the rules. | ||
|
|
||
| ## Use `@not_implemented` appropriately | ||
|
|
||
| One can use [`@not_implemented`](@ref) to mark missing differentials. | ||
|
|
@@ -52,34 +158,6 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/160 | |
|
|
||
| Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead). | ||
|
|
||
| ## Code Style | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moving code in the same PR as edits somehting else makes reviewing harder. |
||
|
|
||
| Use named local functions for the `pullback` in an `rrule`. | ||
|
|
||
| ```julia | ||
| # good: | ||
| function rrule(::typeof(foo), x) | ||
| Y = foo(x) | ||
| function foo_pullback(Ȳ) | ||
| return NoTangent(), bar(Ȳ) | ||
| end | ||
| return Y, foo_pullback | ||
| end | ||
| #== output | ||
| julia> rrule(foo, 2) | ||
| (4, var"#foo_pullback#11"()) | ||
| ==# | ||
|
|
||
| # bad: | ||
| function rrule(::typeof(foo), x) | ||
| return foo(x), x̄ -> (NoTangent(), bar(x̄)) | ||
| end | ||
| #== output: | ||
| julia> rrule(foo, 2) | ||
| (4, var"##9#10"()) | ||
| ==# | ||
| ``` | ||
|
|
||
| While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it. | ||
| This makes it a lot simpler to debug from the stacktrace. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Seth's point about
<:(for non-concrete types) is worth making too#331 (comment)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in general probably should always use
<:insideType.Since it means your code doesn't have to change if it gains type-parameters, and thus is no longer concrete.
Plus saves thinking
(possibly even this should be added to BlueStyle)