Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.24"
version = "0.9.25"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
9 changes: 7 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is push-forward there is one per input to the function
Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs]
Δs = _propagator_inputs(n_inputs)
pushforward_returns = map(1:n_outputs) do output_i
∂s = partials[output_i].args
propagation_expr(Δs, ∂s)
Expand Down Expand Up @@ -173,7 +173,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is a pull-back there is one per output of function
Δs = [Symbol(:Δ, i) for i in 1:n_outputs]
Δs = _propagator_inputs(n_outputs)

# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
Expand All @@ -198,6 +198,11 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
end
end

# For context on why this is important, see
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
"Declares properly hygenic inputs for propagation expressions"
_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the gensym?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gensym gives it a name that is certain not to clash with anything.
it is a name like ##Δ2#991

So just incase the person had named a variable Δ2 and written something like

@scalar_rule foobar(x) (Δ2=10.0; Δ2*2),  x

using gensym makes sure we don't clash into that variable.

Copy link
Member

@mzgubic mzgubic Jan 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes thanks esc doesn't help us with that


"""
propagation_expr(Δs, ∂s, _conj = false)

Expand Down
19 changes: 19 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,24 @@ end
# make sure type is exactly as expected:
@test ẏ isa Composite{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}}
end

@testset "Regression test against #276" begin
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
# Symptom of this problem is creation of global variables and type instablily

num_globals_before = length(names(ChainRulesCore; all=true))

simo2(x) = (x, 2x)
@scalar_rule(simo2(x), 1.0, 2.0)
_, simo2_pb = rrule(simo2, 43.0)
# make sure it infers: inferability implies type stability
@inferred simo2_pb(Composite{Tuple{Float64, Float64}}(3.0, 6.0))

# Test no new globals were created
@test length(names(ChainRulesCore; all=true)) == num_globals_before
end
end



end