
Scalar rules with multiple outputs are type-unstable #265

@sethaxen


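One of the type instabilities in ChainRules identified by JuliaDiff/ChainRulesTestUtils.jl#78 is that `@scalar_rule` currently produces type-unstable `rrule`s for functions that return multiple outputs, such as `sincos` and `SpecialFunctions.logabsgamma`. The `rrule` call itself infers, but the pullback's return type is inferred as `Tuple{Zero,Any}` instead of `Tuple{Zero,Float64}`: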

```julia
julia> using ChainRules, SpecialFunctions, Test

julia> x, Δsinx, Δcosx = randn(3);

julia> Ω, back = @inferred rrule(sincos, x);

julia> ΔΩ = Composite{typeof(Ω)}(Δsinx, Δcosx);

julia> @inferred back(ΔΩ)
ERROR: return type Tuple{Zero,Float64} does not match inferred return type Tuple{Zero,Any}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[5]:1

julia> x, Δy, Δs = randn(3);

julia> Ω, back = @inferred rrule(SpecialFunctions.logabsgamma, x);

julia> ΔΩ = Composite{typeof(Ω)}(Δy, Δs);

julia> @inferred back(ΔΩ)
ERROR: return type Tuple{Zero,Float64} does not match inferred return type Tuple{Zero,Any}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[9]:1
```

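The type instability enters when the generated pullback broadcasts over `Δ1` and `Δ2`. In the typed code for the `logabsgamma` pullback (whose partial involves `digamma`), the nested `Broadcasted` object is still concretely typed, but as soon as it is combined with `Δ1` the result is inferred as `Any`, and that `Any` propagates through `muladd`, `materialize`, and the returned tuple: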

```
│   %14 = Base.broadcasted(ChainRulesCore.conj, %13)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(conj),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(digamma),Tuple{Float64}}}}
│   %15 = Base.broadcasted(ChainRulesCore.:*, %14, ChainRulesCore.Δ1)::Any
│   %16 = Base.broadcasted(muladd, %10, ChainRulesCore.Δ2, %15)::Any
│   %17 = Base.materialize(%16)::Any
│   %18 = Core.tuple(ChainRulesCore.NO_FIELDS, %17)::Core.Compiler.PartialStruct(Tuple{Zero,Any}, Any[Core.Compiler.Const(Zero(), false), Any])
```
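For comparison, a minimal sketch of the same accumulation written without broadcasting. This is a hand-written illustration, not ChainRules code: `scalar_pullback` and `sincos_pullback` are hypothetical names, and the sketch assumes all partials and cotangents are plain `Number`s (no `Composite`/`Zero` handling). It combines the two output cotangents with `muladd` in the same shape as the IR above, and `@inferred` succeeds with a concrete `Float64`:

```julia
using Test

# Hypothetical helper (not part of ChainRulesCore): combine the partials of a
# scalar function with two outputs into a single input cotangent, using plain
# scalar arithmetic instead of broadcasting.
scalar_pullback(∂1::Number, ∂2::Number, Δ1::Number, Δ2::Number) =
    muladd(conj(∂2), Δ2, conj(∂1) * Δ1)

# sincos(x) = (sin(x), cos(x)), so ∂sin/∂x = cos(x) and ∂cos/∂x = -sin(x).
function sincos_pullback(x::Real, Δsin::Real, Δcos::Real)
    s, c = sincos(x)
    return scalar_pullback(c, -s, Δsin, Δcos)
end

# Inference yields Float64 here rather than Any.
Δx = @inferred sincos_pullback(0.3, 1.0, 2.0)
```

Whether `@scalar_rule` should special-case scalar cotangents this way, or fix the broadcast path instead, is a design question this sketch does not settle.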
