diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 445fd9a783..9dcc6dd39a 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -23,6 +23,20 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end +function Base.isfinite(x::TracedRNumber{<:Complex}) + return isfinite(real(x)) & isfinite(imag(x)) +end +function Base.isfinite(x::TracedRNumber{T}) where {T<:AbstractFloat} + return Reactant.Ops.is_finite(x) +end + +function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat} + return !isfinite(x) & (x != typemax(T)) & (x != typemin(T)) +end +function Base.isnan(x::TracedRNumber{<:Complex}) + return isnan(real(x)) | isnan(imag(x)) +end + function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} return print(io, "TracedRNumber{", T, "}(", X.paths, ")") end diff --git a/test/basic.jl b/test/basic.jl index 9ef7807e6c..61a644580f 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -371,26 +371,26 @@ end @testset "Number and RArray" for a in [1.0f0, 1.0e0] typeof_a = typeof(a) - _b = [2.0, 3.0, 4.0] .|> typeof_a - _c = [2.0 3.0 4.0] .|> typeof_a + _b = typeof_a.([2.0, 3.0, 4.0]) + _c = typeof_a.([2.0 3.0 4.0]) b = Reactant.to_rarray(_b) c = Reactant.to_rarray(_c) - + # vcat test y = @jit vcat(a, b) @test y == vcat(a, _b) @test y isa ConcreteRArray{typeof_a,1} - + ## vcat test - adjoint y1 = @jit vcat(a, c') @test y1 == vcat(a, _c') @test y1 isa ConcreteRArray{typeof_a,2} - + # hcat test z = @jit hcat(a, c) @test z == hcat(a, _c) @test z isa ConcreteRArray{typeof_a,2} - + ## hcat test - adjoint z1 = @jit hcat(a, b') @test z1 == hcat(a, _b') @@ -1028,3 +1028,19 @@ end @test res[2] isa ConcreteRNumber{Float32} end end + +@testset "isfinite" begin + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false] + + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false] +end + +@testset "isnan" begin + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test Reactant.@jit(isnan.(x)) == [false, true, false, false, true] + + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test Reactant.@jit(isnan.(x)) == [false, true, false, false, true] +end