From 3c94853a344cdcb3e20057092381696865cade67 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 00:23:30 -0500 Subject: [PATCH 1/5] feat: linalg det and logabsdet --- src/TracedRNumber.jl | 27 +++++++++++++--- src/stdlibs/LinearAlgebra.jl | 51 ++++++++++++++++++++++++++++-- test/integration/linear_algebra.jl | 18 +++++++++++ 3 files changed, 89 insertions(+), 7 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 8a5bb307ea..82138547aa 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -567,13 +567,30 @@ Base.Math._log(x::TracedRNumber, base, ::Symbol) = log(x) / log(Reactant._unwrap Base.isreal(::TracedRNumber) = false Base.isreal(::TracedRNumber{<:Real}) = true +Base.isinteger(x::TracedRNumber{<:Integer}) = true +Base.isinteger(x::TracedRNumber{<:AbstractFloat}) = x - trunc(x) == zero(x) + +Base.isodd(x::TracedRNumber) = isodd(real(x)) +function Base.isodd(x::TracedRNumber{<:Real}) + return ( + isinteger(x) & !iszero( + rem( + Reactant.promote_to(TracedRNumber{Int}, x), + Reactant.promote_to(TracedRNumber{Int}, 2), + ), + ) + ) +end + Base.iseven(x::TracedRNumber) = iseven(real(x)) function Base.iseven(x::TracedRNumber{<:Real}) - return iszero( - rem( - Reactant.promote_to(TracedRNumber{Int}, x), - Reactant.promote_to(TracedRNumber{Int}, 2), - ), + return ( + isinteger(x) & iszero( + rem( + Reactant.promote_to(TracedRNumber{Int}, x), + Reactant.promote_to(TracedRNumber{Int}, 2), + ), + ) ) end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index fa8285e630..d1ab528ef5 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -5,8 +5,7 @@ using ..Reactant: Reactant, Ops using ..Reactant: TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector using ..Reactant: call_with_reactant -using ReactantCore: ReactantCore -using ReactantCore: materialize_traced_array +using ReactantCore: ReactantCore, materialize_traced_array, @trace using Reactant_jll: Reactant_jll using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! @@ -657,6 +656,8 @@ struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray, info::I end +Base.size(lu::GeneralizedLU) = size(lu.factors) + Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} @@ -729,6 +730,25 @@ function LinearAlgebra.ldiv!( return B end +function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + # TODO: check for non-singular matrices + + P = prod(LinearAlgebra.diag(lu.factors)) + return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P +end + +function LinearAlgebra.logabsdet(F::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(F) + Treal = real(T) + # TODO: check for non-singular matrices + + absdet = sum(log ∘ abs, LinearAlgebra.diag(lu.factors)) + P = prod(sign, d) + s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P + return absdet, s +end + for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), aType in (:AbstractVecOrMat, :AbstractArray) @@ -795,4 +815,31 @@ for AT in ( end end +LinearAlgebra._istriu(A::AnyTracedRMatrix, k) = all(iszero, overloaded_tril(A, k - 1)) +LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k + 1)) + +# Only needed because we lack automatic if tracing +function LinearAlgebra.det(A::AnyTracedRMatrix) + T = Reactant.unwrapped_eltype(A) + S = promote_type(T, typeof((one(T) * zero(T) + zero(T)) / one(T))) + istriangular = LinearAlgebra.istriu(A) | LinearAlgebra.istril(A) + # return LinearAlgebra.det(LinearAlgebra.lu(A; check=false)) + return ReactantCore.traced_if( + istriangular, + # x -> convert(TracedRNumber{S}, LinearAlgebra.det(LinearAlgebra.UpperTriangular(x))), + x -> LinearAlgebra.det(LinearAlgebra.lu(x; check=false)), + x -> LinearAlgebra.det(LinearAlgebra.lu(x; check=false)), + (A,), + ) +end + +function LinearAlgebra.logabsdet(A::AnyTracedRMatrix) + @trace if LinearAlgebra.istriu(A) || LinearAlgebra.istril(A) + r = LinearAlgebra.logabsdet(LinearAlgebra.UpperTriangular(A)) + else + r = LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false)) + end + return r +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 2bab0b7366..65bc968cd3 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -432,3 +432,21 @@ end 1e-2 end end + +@testset "istriu" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_triu = triu(x, 4) + x_triu_ra = Reactant.to_rarray(x_triu) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4))) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3))) + @test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5))) +end + +@testset "istril" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_tril = tril(x, -4) + x_tril_ra = Reactant.to_rarray(x_tril) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4))) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) + @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) +end From 132f9d3faa0a79a2ce47a7bc8fdd7052b9b37bb0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 11:04:00 -0500 Subject: [PATCH 2/5] test: det and logabsdet --- src/stdlibs/LinearAlgebra.jl | 30 +++++++++--------------------- test/integration/linear_algebra.jl | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index d1ab528ef5..b92a5d1177 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -5,7 +5,7 @@ using ..Reactant: Reactant, Ops using ..Reactant: TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector using ..Reactant: call_with_reactant -using ReactantCore: ReactantCore, materialize_traced_array, @trace +using ReactantCore: ReactantCore, materialize_traced_array using Reactant_jll: Reactant_jll using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! @@ -738,12 +738,13 @@ function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P end -function LinearAlgebra.logabsdet(F::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(F) +function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) Treal = real(T) # TODO: check for non-singular matrices - absdet = sum(log ∘ abs, LinearAlgebra.diag(lu.factors)) + d = LinearAlgebra.diag(lu.factors) + absdet = sum(log ∘ abs, d) P = prod(sign, d) s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P return absdet, s @@ -820,26 +821,13 @@ LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k # Only needed because we lack automatic if tracing function LinearAlgebra.det(A::AnyTracedRMatrix) - T = Reactant.unwrapped_eltype(A) - S = promote_type(T, typeof((one(T) * zero(T) + zero(T)) / one(T))) - istriangular = LinearAlgebra.istriu(A) | LinearAlgebra.istril(A) - # return LinearAlgebra.det(LinearAlgebra.lu(A; check=false)) - return ReactantCore.traced_if( - istriangular, - # x -> convert(TracedRNumber{S}, LinearAlgebra.det(LinearAlgebra.UpperTriangular(x))), - x -> LinearAlgebra.det(LinearAlgebra.lu(x; check=false)), - x -> LinearAlgebra.det(LinearAlgebra.lu(x; check=false)), - (A,), - ) + # FIXME: using @trace here produces the cryptic UndefVarError + return LinearAlgebra.det(LinearAlgebra.lu(A; check=false)) end function LinearAlgebra.logabsdet(A::AnyTracedRMatrix) - @trace if LinearAlgebra.istriu(A) || LinearAlgebra.istril(A) - r = LinearAlgebra.logabsdet(LinearAlgebra.UpperTriangular(A)) - else - r = LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false)) - end - return r + # FIXME: using @trace here produces the cryptic UndefVarError + return LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false)) end end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 65bc968cd3..828e9d07ca 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -450,3 +450,24 @@ end @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) end + +@testset "logabsdet" begin + x = Reactant.TestUtils.construct_test_array(Float64, 8, 8) + x_ra = Reactant.to_rarray(x) + @test @jit(LinearAlgebra.logabsdet(x_ra)) ≈ LinearAlgebra.logabsdet(x) +end + +@testset "det" begin + x = Reactant.TestUtils.construct_test_array(Float64, 8, 8) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit LinearAlgebra.logabsdet(x_ra) + res = LinearAlgebra.logabsdet(x) + @test res_ra[1] ≈ res[1] + @test res_ra[2] ≈ res[2] + + res_ra = @jit LinearAlgebra.det(x_ra) + res = LinearAlgebra.det(x) + @test res_ra ≈ res +end + From 1803eb79d12932de700da75675d57009441c5586 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 11:05:39 -0500 Subject: [PATCH 3/5] Update test/integration/linear_algebra.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/integration/linear_algebra.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 828e9d07ca..fa96d209fd 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -451,12 +451,6 @@ end @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) end -@testset "logabsdet" begin - x = Reactant.TestUtils.construct_test_array(Float64, 8, 8) - x_ra = Reactant.to_rarray(x) - @test @jit(LinearAlgebra.logabsdet(x_ra)) ≈ LinearAlgebra.logabsdet(x) -end - @testset "det" begin x = Reactant.TestUtils.construct_test_array(Float64, 8, 8) x_ra = Reactant.to_rarray(x) @@ -470,4 +464,3 @@ end res = LinearAlgebra.det(x) @test res_ra ≈ res end - From 63da3cb60a20a214be9feaa4f965e44969a90eac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 12:41:23 -0500 Subject: [PATCH 4/5] test: use Float32 --- test/integration/linear_algebra.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index fa96d209fd..ea05cfbf0f 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -452,7 +452,7 @@ end end @testset "det" begin - x = Reactant.TestUtils.construct_test_array(Float64, 8, 8) + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) x_ra = Reactant.to_rarray(x) res_ra = @jit LinearAlgebra.logabsdet(x_ra) From d900d7c87908dc26e1aa96cb10003c9f1537f101 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 18:36:30 -0500 Subject: [PATCH 5/5] Update src/TracedRNumber.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TracedRNumber.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 82138547aa..575c53c7eb 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -573,7 +573,8 @@ Base.isinteger(x::TracedRNumber{<:AbstractFloat}) = x - trunc(x) == zero(x) Base.isodd(x::TracedRNumber) = isodd(real(x)) function Base.isodd(x::TracedRNumber{<:Real}) return ( - isinteger(x) & !iszero( + isinteger(x) & + !iszero( rem( Reactant.promote_to(TracedRNumber{Int}, x), Reactant.promote_to(TracedRNumber{Int}, 2),