diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 8a5bb307ea..575c53c7eb 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -567,13 +567,31 @@ Base.Math._log(x::TracedRNumber, base, ::Symbol) = log(x) / log(Reactant._unwrap Base.isreal(::TracedRNumber) = false Base.isreal(::TracedRNumber{<:Real}) = true +Base.isinteger(x::TracedRNumber{<:Integer}) = true +Base.isinteger(x::TracedRNumber{<:AbstractFloat}) = x - trunc(x) == zero(x) + +Base.isodd(x::TracedRNumber) = isodd(real(x)) +function Base.isodd(x::TracedRNumber{<:Real}) + return ( + isinteger(x) & + !iszero( + rem( + Reactant.promote_to(TracedRNumber{Int}, x), + Reactant.promote_to(TracedRNumber{Int}, 2), + ), + ) + ) +end + Base.iseven(x::TracedRNumber) = iseven(real(x)) function Base.iseven(x::TracedRNumber{<:Real}) - return iszero( - rem( - Reactant.promote_to(TracedRNumber{Int}, x), - Reactant.promote_to(TracedRNumber{Int}, 2), - ), + return ( + isinteger(x) & iszero( + rem( + Reactant.promote_to(TracedRNumber{Int}, x), + Reactant.promote_to(TracedRNumber{Int}, 2), + ), + ) ) end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index fa8285e630..b92a5d1177 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -5,8 +5,7 @@ using ..Reactant: Reactant, Ops using ..Reactant: TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector using ..Reactant: call_with_reactant -using ReactantCore: ReactantCore -using ReactantCore: materialize_traced_array +using ReactantCore: ReactantCore, materialize_traced_array using Reactant_jll: Reactant_jll using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! @@ -657,6 +656,8 @@ struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray, info::I end +Base.size(lu::GeneralizedLU) = size(lu.factors) + Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} @@ -729,6 +730,26 @@ function LinearAlgebra.ldiv!( return B end +function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + # TODO: check for non-singular matrices + + P = prod(LinearAlgebra.diag(lu.factors)) + return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P +end + +function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + Treal = real(T) + # TODO: check for non-singular matrices + + d = LinearAlgebra.diag(lu.factors) + absdet = sum(log ∘ abs, d) + P = prod(sign, d) + s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P + return absdet, s +end + for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), aType in (:AbstractVecOrMat, :AbstractArray) @@ -795,4 +816,18 @@ for AT in ( end end +LinearAlgebra._istriu(A::AnyTracedRMatrix, k) = all(iszero, overloaded_tril(A, k - 1)) +LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k + 1)) + +# Only needed because we lack automatic if tracing +function LinearAlgebra.det(A::AnyTracedRMatrix) + # FIXME: using @trace here produces the cryptic UndefVarError + return LinearAlgebra.det(LinearAlgebra.lu(A; check=false)) +end + +function LinearAlgebra.logabsdet(A::AnyTracedRMatrix) + # FIXME: using @trace here produces the cryptic UndefVarError + return LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false)) +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 2bab0b7366..ea05cfbf0f 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -432,3 +432,35 @@ end 1e-2 end end + +@testset "istriu" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_triu = triu(x, 4) + x_triu_ra = Reactant.to_rarray(x_triu) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4))) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3))) + @test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5))) +end + +@testset "istril" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_tril = tril(x, -4) + x_tril_ra = Reactant.to_rarray(x_tril) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4))) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) + @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) +end + +@testset "det" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit LinearAlgebra.logabsdet(x_ra) + res = LinearAlgebra.logabsdet(x) + @test res_ra[1] ≈ res[1] + @test res_ra[2] ≈ res[2] + + res_ra = @jit LinearAlgebra.det(x_ra) + res = LinearAlgebra.det(x) + @test res_ra ≈ res +end