diff --git a/Project.toml b/Project.toml index f951234a4..b213cc39d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.30" +version = "0.9.31" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 4b4c49d00..0a841409e 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -361,7 +361,7 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) ) return esc(quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...)) + function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...)) return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr) end function ChainRulesCore.rrule($(primal_sig_parts...)) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 75c03c86a..c563786ea 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -251,3 +251,24 @@ end end + + +module IsolatedModuleForTestingScoping + using Test + # need to make sure macros work in something that hasn't imported all exports + # all that matters is that the following don't error, since they will resolve at + # parse time + using ChainRulesCore: ChainRulesCore + + @testset "@non_differentiable" begin + # this is + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 + fixed(x) = :abc + ChainRulesCore.@non_differentiable fixed(x) + end + + @testset "@scalar_rule" begin + my_id(x) = x + ChainRulesCore.@scalar_rule(my_id(x), 1.0) + end +end