Closed
Description
One of the type instabilities in ChainRules identified by JuliaDiff/ChainRulesTestUtils.jl#78 is that `@scalar_rule` currently produces type-unstable `rrule`s for functions that return multiple outputs, such as `sincos` and `SpecialFunctions.logabsgamma`:
```julia
julia> using ChainRules, SpecialFunctions, Test

julia> x, Δsinx, Δcosx = randn(3);

julia> Ω, back = @inferred rrule(sincos, x);

julia> ΔΩ = Composite{typeof(Ω)}(Δsinx, Δcosx);

julia> @inferred back(ΔΩ)
ERROR: return type Tuple{Zero,Float64} does not match inferred return type Tuple{Zero,Any}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[5]:1

julia> x, Δy, Δs = randn(3);

julia> Ω, back = @inferred rrule(SpecialFunctions.logabsgamma, x);

julia> ΔΩ = Composite{typeof(Ω)}(Δy, Δs);

julia> @inferred back(ΔΩ)
ERROR: return type Tuple{Zero,Float64} does not match inferred return type Tuple{Zero,Any}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[9]:1
```
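For comparison, a pullback for `sincos` written by hand, without `@scalar_rule`, infers a concrete return type. This is only an illustrative sketch, not ChainRules code: `sincos_rrule` and `sincos_pullback` are hypothetical names, and it uses a plain tuple in place of `Composite` and `nothing` in place of `Zero()`/`NO_FIELDS` so that it runs with just the `Test` standard library. Since `x` is real, the `conj` that `@scalar_rule` applies is the identity and is left out.

```julia
using Test

# Hypothetical hand-written rrule for sincos (not the ChainRules API).
function sincos_rrule(x::Real)
    s, c = sincos(x)
    # Takes one cotangent per output and returns
    # (cotangent for the function itself, cotangent for x).
    function sincos_pullback(ΔΩ::Tuple{Real,Real})
        Δs, Δc = ΔΩ
        # d(sin x)/dx = cos x and d(cos x)/dx = -sin x
        return (nothing, muladd(c, Δs, -s * Δc))
    end
    return (s, c), sincos_pullback
end

x, Δsinx, Δcosx = 0.5, 1.0, 2.0
Ω, back = @inferred sincos_rrule(x)
# Passes: the pullback infers as Tuple{Nothing,Float64}, not Tuple{Nothing,Any}.
∂self, ∂x = @inferred back((Δsinx, Δcosx))
```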
The type instability enters when the pullback broadcasts over `Δ1` and `Δ2`, as seen in this excerpt of the typed code for the `logabsgamma` pullback:
```
│ %14 = Base.broadcasted(ChainRulesCore.conj, %13)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(conj),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(digamma),Tuple{Float64}}}}
│ %15 = Base.broadcasted(ChainRulesCore.:*, %14, ChainRulesCore.Δ1)::Any
│ %16 = Base.broadcasted(muladd, %10, ChainRulesCore.Δ2, %15)::Any
│ %17 = Base.materialize(%16)::Any
│ %18 = Core.tuple(ChainRulesCore.NO_FIELDS, %17)::Core.Compiler.PartialStruct(Tuple{Zero,Any}, Any[Core.Compiler.Const(Zero(), false), Any])
```
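Broadcasting over scalars is not unstable in itself: when every operand is a concretely typed local, the same fused expression, `muladd.(∂2, Δ2, conj.(∂1) .* Δ1)` (reconstructed here from the IR), infers as `Float64`. What stands out in the IR is that `Δ1` and `Δ2` appear as the module-qualified names `ChainRulesCore.Δ1` and `ChainRulesCore.Δ2`. One plausible reading, offered as an assumption rather than a diagnosis, is that inference then treats them like untyped global bindings, which always infer as `Any`. The sketch below reproduces that effect with a hypothetical non-`const` global; `pullback_local`, `pullback_global`, and `Δ1_global` are illustrative names only.

```julia
using Test

# Same shape as the broadcast in the IR, with every operand a typed argument.
pullback_local(∂1, ∂2, Δ1, Δ2) = muladd.(∂2, Δ2, conj.(∂1) .* Δ1)

# Hypothetical non-const global: inference cannot assume its type.
Δ1_global = 3.0
pullback_global(∂1, ∂2, Δ2) = muladd.(∂2, Δ2, conj.(∂1) .* Δ1_global)

# Broadcasting over plain numbers returns a number whose type is inferred.
y = @inferred pullback_local(1.0, 2.0, 3.0, 4.0)   # 2.0 * 4.0 + 1.0 * 3.0

# Reading the untyped global makes the result infer as Any, so @inferred
# throws the same kind of error as in the REPL session above.
@test_throws ErrorException @inferred pullback_global(1.0, 2.0, 4.0)
```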