From 8f7535b23e4718ae4a59557beffeccdc45c6389a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 27 Aug 2025 20:34:08 -0400 Subject: [PATCH] fix: logsumexp --- src/TracedRArray.jl | 2 ++ test/nn/nnlib.jl | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index fc89f51b74..b99b0dcaf5 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -573,7 +573,9 @@ for (jlop, hloop, hlocomp, merge) in end __default_init(::Type{T}, ::typeof(Base.min)) where {T} = typemax(T) +__default_init(::Type{T}, ::typeof(Base.FastMath.min_fast)) where {T} = typemax(T) __default_init(::Type{T}, ::typeof(Base.max)) where {T} = typemin(T) +__default_init(::Type{T}, ::typeof(Base.FastMath.max_fast)) where {T} = typemin(T) function __default_init(::Type{T}, op::F) where {T,F} return Base.reduce_empty(Base.BottomRF(op), T) end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 8dfc42f54e..c92cc300b1 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -730,3 +730,12 @@ end @test @jit(NNlib.softmax(x_ra)) ≈ NNlib.softmax(x) @test @jit(NNlib.logsoftmax(x_ra)) ≈ NNlib.logsoftmax(x) end + +@testset "logsumexp #1593" begin + x = collect(Float32, 1:16) + x_ra = Reactant.to_rarray(x) + + y = logsumexp(x) + y_ra = @jit(logsumexp(x_ra)) + @test Float32(y_ra) ≈ y +end