diff --git a/Project.toml b/Project.toml index d4bcbd5d8..185d3f722 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.24" +version = "0.9.25" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index b68717693..b56d4415b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -142,7 +142,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) # Δs is the input to the propagator rule # because this is push-forward there is one per input to the function - Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs] + Δs = _propagator_inputs(n_inputs) pushforward_returns = map(1:n_outputs) do output_i ∂s = partials[output_i].args propagation_expr(Δs, ∂s) @@ -173,7 +173,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) # Δs is the input to the propagator rule # because this is a pull-back there is one per output of function - Δs = [Symbol(:Δ, i) for i in 1:n_outputs] + Δs = _propagator_inputs(n_outputs) # 1 partial derivative per input pullback_returns = map(1:n_inputs) do input_i @@ -198,6 +198,11 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) end end +# For context on why this is important, see +# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 +"Declares properly hygenic inputs for propagation expressions" +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] + """ propagation_expr(Δs, ∂s, _conj = false) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index c4a0df8d3..35a2ddd88 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -207,5 +207,24 @@ end # make sure type is exactly as expected: @test ẏ isa Composite{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} end + + @testset "Regression test against #276" begin + # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 + # Symptom of this problem is creation of global variables and type instablily + + num_globals_before = length(names(ChainRulesCore; all=true)) + + simo2(x) = (x, 2x) + @scalar_rule(simo2(x), 1.0, 2.0) + _, simo2_pb = rrule(simo2, 43.0) + # make sure it infers: inferability implies type stability + @inferred simo2_pb(Composite{Tuple{Float64, Float64}}(3.0, 6.0)) + + # Test no new globals were created + @test length(names(ChainRulesCore; all=true)) == num_globals_before + end end + + + end