Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
end

function Base.isfinite(x::TracedRNumber{<:Complex})
return isfinite(real(x)) & isfinite(imag(x))
end
function Base.isfinite(x::TracedRNumber{T}) where {T<:AbstractFloat}
return Reactant.Ops.is_finite(x)
end

function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat}
return !isfinite(x) & (x != typemax(T)) & (x != typemin(T))
end
function Base.isnan(x::TracedRNumber{<:Complex})
return isnan(real(x)) | isnan(imag(x))
end

function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
end
Expand Down
28 changes: 22 additions & 6 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,26 +371,26 @@ end

@testset "Number and RArray" for a in [1.0f0, 1.0e0]
typeof_a = typeof(a)
_b = [2.0, 3.0, 4.0] .|> typeof_a
_c = [2.0 3.0 4.0] .|> typeof_a
_b = typeof_a.([2.0, 3.0, 4.0])
_c = typeof_a.([2.0 3.0 4.0])
b = Reactant.to_rarray(_b)
c = Reactant.to_rarray(_c)

# vcat test
y = @jit vcat(a, b)
@test y == vcat(a, _b)
@test y isa ConcreteRArray{typeof_a,1}

## vcat test - adjoint
y1 = @jit vcat(a, c')
@test y1 == vcat(a, _c')
@test y1 isa ConcreteRArray{typeof_a,2}

# hcat test
z = @jit hcat(a, c)
@test z == hcat(a, _c)
@test z isa ConcreteRArray{typeof_a,2}

## hcat test - adjoint
z1 = @jit hcat(a, b')
@test z1 == hcat(a, _b')
Expand Down Expand Up @@ -1028,3 +1028,19 @@ end
@test res[2] isa ConcreteRNumber{Float32}
end
end

@testset "isfinite" begin
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
@test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false]

x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
@test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false]
end

@testset "isnan" begin
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
@test Reactant.@jit(isnan.(x)) == [false, true, false, false, true]

x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
@test Reactant.@jit(isnan.(x)) == [false, true, false, false, true]
end
Loading