Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,31 @@ Base.Math._log(x::TracedRNumber, base, ::Symbol) = log(x) / log(Reactant._unwrap
Base.isreal(::TracedRNumber) = false
Base.isreal(::TracedRNumber{<:Real}) = true

Base.isinteger(x::TracedRNumber{<:Integer}) = true
Base.isinteger(x::TracedRNumber{<:AbstractFloat}) = x - trunc(x) == zero(x)

Base.isodd(x::TracedRNumber) = isodd(real(x))
function Base.isodd(x::TracedRNumber{<:Real})
return (
isinteger(x) &
!iszero(
rem(
Reactant.promote_to(TracedRNumber{Int}, x),
Reactant.promote_to(TracedRNumber{Int}, 2),
),
)
)
end

Base.iseven(x::TracedRNumber) = iseven(real(x))
function Base.iseven(x::TracedRNumber{<:Real})
return iszero(
rem(
Reactant.promote_to(TracedRNumber{Int}, x),
Reactant.promote_to(TracedRNumber{Int}, 2),
),
return (
isinteger(x) & iszero(
rem(
Reactant.promote_to(TracedRNumber{Int}, x),
Reactant.promote_to(TracedRNumber{Int}, 2),
),
)
)
end

Expand Down
39 changes: 37 additions & 2 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ using ..Reactant: Reactant, Ops
using ..Reactant:
TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector
using ..Reactant: call_with_reactant
using ReactantCore: ReactantCore
using ReactantCore: materialize_traced_array
using ReactantCore: ReactantCore, materialize_traced_array
using Reactant_jll: Reactant_jll

using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
Expand Down Expand Up @@ -657,6 +656,8 @@ struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,
info::I
end

Base.size(lu::GeneralizedLU) = size(lu.factors)

Base.ndims(lu::GeneralizedLU) = ndims(lu.factors)

function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
Expand Down Expand Up @@ -729,6 +730,26 @@ function LinearAlgebra.ldiv!(
return B
end

function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
n = LinearAlgebra.checksquare(lu)
# TODO: check for non-singular matrices

P = prod(LinearAlgebra.diag(lu.factors))
return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P
end

function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
n = LinearAlgebra.checksquare(lu)
Treal = real(T)
# TODO: check for non-singular matrices

d = LinearAlgebra.diag(lu.factors)
absdet = sum(log ∘ abs, d)
P = prod(sign, d)
s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P
return absdet, s
end

for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization),
aType in (:AbstractVecOrMat, :AbstractArray)

Expand Down Expand Up @@ -795,4 +816,18 @@ for AT in (
end
end

LinearAlgebra._istriu(A::AnyTracedRMatrix, k) = all(iszero, overloaded_tril(A, k - 1))
LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k + 1))

# Only needed because we lack automatic if tracing
function LinearAlgebra.det(A::AnyTracedRMatrix)
# FIXME: using @trace here produces the cryptic UndefVarError
return LinearAlgebra.det(LinearAlgebra.lu(A; check=false))
end

function LinearAlgebra.logabsdet(A::AnyTracedRMatrix)
# FIXME: using @trace here produces the cryptic UndefVarError
return LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false))
end

end
32 changes: 32 additions & 0 deletions test/integration/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,35 @@ end
1e-2
end
end

@testset "istriu" begin
x = Reactant.TestUtils.construct_test_array(Float32, 8, 8)
x_triu = triu(x, 4)
x_triu_ra = Reactant.to_rarray(x_triu)
@test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4)))
@test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3)))
@test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5)))
end

@testset "istril" begin
x = Reactant.TestUtils.construct_test_array(Float32, 8, 8)
x_tril = tril(x, -4)
x_tril_ra = Reactant.to_rarray(x_tril)
@test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4)))
@test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3)))
@test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5)))
end

@testset "det" begin
x = Reactant.TestUtils.construct_test_array(Float32, 8, 8)
x_ra = Reactant.to_rarray(x)

res_ra = @jit LinearAlgebra.logabsdet(x_ra)
res = LinearAlgebra.logabsdet(x)
@test res_ra[1] ≈ res[1]
@test res_ra[2] ≈ res[2]

res_ra = @jit LinearAlgebra.det(x_ra)
res = LinearAlgebra.det(x)
@test res_ra ≈ res
end
Loading