From 8c26257e593767d1cdc5185a87f6e0fc151c7fb9 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:46:02 +0100 Subject: [PATCH 1/3] Implement `isnan` for TracedRNumber --- src/TracedRNumber.jl | 4 ++++ test/basic.jl | 17 +++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 445fd9a783..5475e842b0 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -23,6 +23,10 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end +function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat} + return !Reactant.Ops.is_finite(x) & (x != typemax(T)) & (x != typemin(T)) +end + function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} return print(io, "TracedRNumber{", T, "}(", X.paths, ")") end diff --git a/test/basic.jl b/test/basic.jl index 9ef7807e6c..a1c9355be4 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -371,26 +371,26 @@ end @testset "Number and RArray" for a in [1.0f0, 1.0e0] typeof_a = typeof(a) - _b = [2.0, 3.0, 4.0] .|> typeof_a - _c = [2.0 3.0 4.0] .|> typeof_a + _b = typeof_a.([2.0, 3.0, 4.0]) + _c = typeof_a.([2.0 3.0 4.0]) b = Reactant.to_rarray(_b) c = Reactant.to_rarray(_c) - + # vcat test y = @jit vcat(a, b) @test y == vcat(a, _b) @test y isa ConcreteRArray{typeof_a,1} - + ## vcat test - adjoint y1 = @jit vcat(a, c') @test y1 == vcat(a, _c') @test y1 isa ConcreteRArray{typeof_a,2} - + # hcat test z = @jit hcat(a, c) @test z == hcat(a, _c) @test z isa ConcreteRArray{typeof_a,2} - + ## hcat test - adjoint z1 = @jit hcat(a, b') @test z1 == hcat(a, _b') @@ -1028,3 +1028,8 @@ end @test res[2] isa ConcreteRNumber{Float32} end end + +@testset "isnan" begin + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test Reactant.@jit(isnan.(x)) == [false, true, false, false, true] +end From 5b2c66a436a26d81fe82609f26ef9f039f4dabbc Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:17:30 +0100 Subject: [PATCH 2/3] isfinite and complex --- src/TracedRNumber.jl | 9 ++++++++- test/basic.jl | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 5475e842b0..b2ac6cd733 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -23,8 +23,15 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end +function Base.isfinite(x::TracedRNumber{T}) where {T<:AbstractFloat} + return Reactant.Ops.is_finite(x) +end + function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat} - return !Reactant.Ops.is_finite(x) & (x != typemax(T)) & (x != typemin(T)) + return !isfinite(x) & (x != typemax(T)) & (x != typemin(T)) +end +function Base.isnan(x::TracedRNumber{<:Complex}) + return isnan(real(x)) | isnan(imag(x)) end function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} diff --git a/test/basic.jl b/test/basic.jl index a1c9355be4..b6205daa53 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1029,7 +1029,20 @@ end end end +@testset "isfinite" begin + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false] + + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false] +end + +broadcast_isnan(x) = isnan.(x) + @testset "isnan" begin x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) @test Reactant.@jit(isnan.(x)) == [false, true, false, false, true] + + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @show Reactant.@jit(broadcast_isnan(x)) end From 02fce1d3dc91be76e1b95943b953970c8e3ae8f3 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:45:04 +0100 Subject: [PATCH 3/3] update --- src/TracedRNumber.jl | 3 +++ test/basic.jl | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index b2ac6cd733..9dcc6dd39a 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -23,6 +23,9 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end +function Base.isfinite(x::TracedRNumber{<:Complex}) + return isfinite(real(x)) & isfinite(imag(x)) +end function Base.isfinite(x::TracedRNumber{T}) where {T<:AbstractFloat} return Reactant.Ops.is_finite(x) end diff --git a/test/basic.jl b/test/basic.jl index b6205daa53..61a644580f 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1037,12 +1037,10 @@ end @test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false] end -broadcast_isnan(x) = isnan.(x) - @testset "isnan" begin x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) @test Reactant.@jit(isnan.(x)) == [false, true, false, false, true] x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) - @show Reactant.@jit(broadcast_isnan(x)) + @test Reactant.@jit(isnan.(x)) == [false, true, false, false, true] end