diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl
index acbd83e54..4b4c49d00 100644
--- a/src/rule_definition_tools.jl
+++ b/src/rule_definition_tools.jl
@@ -1,3 +1,5 @@
+using Base.Meta
+
 # These are some macros (and supporting functions) to make it easier to define rules.
 """
     @scalar_rule(f(x₁, x₂, ...),
@@ -198,7 +200,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
     end
 end
 
-# For context on why this is important, see 
+# For context on why this is important, see
 # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
 "Declares properly hygenic inputs for propagation expressions"
 _propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n]
@@ -307,11 +309,11 @@ macro non_differentiable(sig_expr)
     unconstrained_args = _unconstrain.(constrained_args)
 
     primal_invoke = if !has_vararg
-        :($(primal_name)($(unconstrained_args...); kwargs...))
+        :($(primal_name)($(unconstrained_args...)))
     else
         normal_args = unconstrained_args[1:end-1]
         var_arg = unconstrained_args[end]
-        :($(primal_name)($(normal_args...), $(var_arg)...; kwargs...))
+        :($(primal_name)($(normal_args...), $(var_arg)...))
     end
 
     quote
@@ -320,11 +322,19 @@ macro non_differentiable(sig_expr)
     end
 end
 
+"Rewrites the call expression `f(x, y)` into `f(x, y; kwargs...)`."
+function _with_kwargs_expr(call_expr::Expr)
+    @assert isexpr(call_expr, :call)
+    return Expr(
+        :call, call_expr.args[1], Expr(:parameters, :(kwargs...)), call_expr.args[2:end]...
+    )
+end
+
 function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
     return esc(:(
         function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
             # Julia functions always only have 1 output, so return a single DoesNotExist()
-            return ($primal_invoke, DoesNotExist())
+            return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist())
         end
     ))
 end
@@ -349,11 +359,16 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
         Expr(:call, propagator_name(primal_name, :pullback), :_),
         Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr))
     )
-    return esc(:(
-        function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...)
+    return esc(quote
+        # Manually defined kw version to save compiler work. See explanation in rules.jl
+        # `rrule` is qualified: `esc`aped code resolves bare names in the caller's scope.
+        function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
+            return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
+        end
+        function ChainRulesCore.rrule($(primal_sig_parts...))
             return ($primal_invoke, $pullback_expr)
         end
-    ))
+    end)
 end
 
 
diff --git a/src/rules.jl b/src/rules.jl
index 35967d27b..f0609894c 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -103,4 +103,16 @@ true
 
 See also: [`frule`](@ref), [`@scalar_rule`](@ref)
 """
-rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
+rrule(::Any, ::Vararg{Any}) = nothing
+
+# Manual fallback for keyword arguments. Usually this would be generated by
+#
+#   rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
+#
+# However - the fallback method is so hot that we want to avoid any extra code
+# that would be required to have the automatically generated method package up
+# the keyword arguments (which the optimizer will throw away, but the compiler
+# still has to manually analyze). Manually declare this method with an
+# explicitly empty body to save the compiler that work.
+
+(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing