diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 23895f5de..e8010566c 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -67,6 +67,40 @@ julia> rrule(foo, 2) While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it. This makes it a lot simpler to debug from the stacktrace. +## Use rule definition tools + +Rule definition tools can help you write more `frule`s and the `rrule`s with less lines of code. + +### [`@non_differentiable`](@ref) + +For non-differentiable functions the [`@non_differentiable`](@ref) macro can be used. +For example, instead of manually defining the `frule` and the `rrule` for string concatenation `*(String..)`, the macro call +``` +@non_differentiable *(String...) +``` +defines the following `frule` and `rrule` automatically +``` +function ChainRulesCore.frule(var"##_#1600", ::Core.Typeof(*), String::Any...; kwargs...) + return (*(String...; kwargs...), DoesNotExist()) +end +function ChainRulesCore.rrule(::Core.Typeof(*), String::Any...; kwargs...) + return (*(String...; kwargs...), function var"*_pullback"(_) + (Zero(), ntuple((_->DoesNotExist()), 0 + length(String))...) + end) +end +``` +Note that the types of arguments are propagated to the `frule` and `rrule` definitions. +This is needed in case the function differentiable for some but not for other types of arguments. +For example `*(1, 2, 3)` is differentiable, and is not defined with the macro call above. + +### [`@scalar_rule`](@ref) + +For functions involving only scalars, i.e. subtypes of `Number` (no `struct`s, `String`s...), both the `frule` and the `rrule` can be defined using a single [`@scalar_rule`](@ref) macro call. + +Note that the function does not have to be $\mathbb{R} \rightarrow \mathbb{R}$. +In fact, any number of scalar arguments is supported, as is returning a tuple of scalars. + +See docstrings for the comprehensive usage instructions. ## Write tests In [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl)