From 19b7cb71f1696d58c96d649a683f585e31819588 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 8 Feb 2025 11:53:49 -0500 Subject: [PATCH] feat: overload ifelse for more types --- src/TracedRNumber.jl | 8 ++++++++ test/basic.jl | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 4edacda958..88766cbb8a 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -165,6 +165,14 @@ for (jlop, hloop, hlocomp) in ( end end +function Base.ifelse(@nospecialize(pred::TracedRNumber{Bool}), x::Number, y::Number) + return ifelse( + pred, + TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(x)}, x), + TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(y)}, y), + ) +end + function Base.ifelse( @nospecialize(pred::TracedRNumber{Bool}), @nospecialize(x::TracedRNumber{T1}), diff --git a/test/basic.jl b/test/basic.jl index e70afc1d98..d8f4424b41 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -636,6 +636,13 @@ end @test @jit( ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0)) ) isa ConcreteRNumber{Float32} + + cond = ConcreteRNumber(true) + x = ConcreteRNumber(1.0) + @test @jit(ifelse(cond, x, 0.0)) == ConcreteRNumber(1.0) + @test @jit(ifelse(cond, 0.0, x)) == ConcreteRNumber(0.0) + @test @jit(ifelse(cond, 1.0, 0.0)) == ConcreteRNumber(1.0) + @test @jit(ifelse(cond, 0.0, 1.0)) == ConcreteRNumber(0.0) end @testset "fill! and zero on ConcreteRArray" begin