From d5d6218e9452cf433c7919e4cbf696a2cf2167e2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 11:15:39 -0600 Subject: [PATCH 01/11] feat: add cheaper det/logabsdet if triangular --- src/stdlibs/LinearAlgebra.jl | 29 +++++++++++++++++++++++------ test/integration/linear_algebra.jl | 23 ++++++++++++++--------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index b92a5d1177..6fe3157e72 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -5,7 +5,7 @@ using ..Reactant: Reactant, Ops using ..Reactant: TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector using ..Reactant: call_with_reactant -using ReactantCore: ReactantCore, materialize_traced_array +using ReactantCore: ReactantCore, materialize_traced_array, @trace using Reactant_jll: Reactant_jll using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! @@ -15,7 +15,7 @@ using LinearAlgebra: LinearAlgebra, BLAS using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular -using LinearAlgebra: diag, diagm, ldiv! +using LinearAlgebra: diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril using Libdl: Libdl function __init__() @@ -821,13 +821,30 @@ LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k # Only needed because we lack automatic if tracing function LinearAlgebra.det(A::AnyTracedRMatrix) - # FIXME: using @trace here produces the cryptic UndefVarError - return LinearAlgebra.det(LinearAlgebra.lu(A; check=false)) + @trace if istriu(A) || istril(A) + _det = det(UpperTriangular(A)) + else + _det = det(lu(A; check=false)) + end + return _det end function LinearAlgebra.logabsdet(A::AnyTracedRMatrix) - # FIXME: using @trace here produces the cryptic UndefVarError - return LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false)) + @trace if istriu(A) || istril(A) + _logabsdet = logabsdet(UpperTriangular(A)) + else + _logabsdet = logabsdet(lu(A; check=false)) + end + return _logabsdet +end + +function LinearAlgebra.logabsdet( + A::Union{UpperTriangular{T,<:AnyTracedRMatrix},LowerTriangular{T,<:AnyTracedRMatrix}} +) where {T} + d = LinearAlgebra.diag(A) + sgn = prod(sign, d) + abs_det = sum(log ∘ abs, d) + return abs_det, sgn end end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index ea05cfbf0f..5de03f9ae7 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -452,15 +452,20 @@ end end @testset "det" begin - x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) - x_ra = Reactant.to_rarray(x) + x_lowtri = Float32[1 0; 2 2] + x_reg = Float32[1 -1; 2 2] + x_uptri = Float32[1 2; 0 2] - res_ra = @jit LinearAlgebra.logabsdet(x_ra) - res = LinearAlgebra.logabsdet(x) - @test res_ra[1] ≈ res[1] - @test res_ra[2] ≈ res[2] + for x in (x_lowtri, x_reg, x_uptri) + x_ra = Reactant.to_rarray(x) - res_ra = @jit LinearAlgebra.det(x_ra) - res = LinearAlgebra.det(x) - @test res_ra ≈ res + res_ra = @jit LinearAlgebra.logabsdet(x_ra) + res = LinearAlgebra.logabsdet(x) + @test res_ra[1] ≈ res[1] + @test res_ra[2] ≈ res[2] + + res_ra = @jit LinearAlgebra.det(x_ra) + res = LinearAlgebra.det(x) + @test res_ra ≈ res + end end From 4bba667b6451600f922f4d961a5806222aa195fb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 11:37:30 -0600 Subject: [PATCH 02/11] feat: lowering of inv --- src/stdlibs/LinearAlgebra.jl | 52 ++++++++++++++++++++++++++++-- test/integration/linear_algebra.jl | 14 ++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 6fe3157e72..23b21eefae 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -15,7 +15,8 @@ using LinearAlgebra: LinearAlgebra, BLAS using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular -using LinearAlgebra: diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril +using LinearAlgebra: + diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv, inv! using Libdl: Libdl function __init__() @@ -657,8 +658,11 @@ struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray, end Base.size(lu::GeneralizedLU) = size(lu.factors) - +Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) +function Base.copy(lu::GeneralizedLU) + return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) +end function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 @@ -760,6 +764,18 @@ for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFac end end +# currently we lower inverse to lu decomposition + triangular solve. we should +# instead emit getri and lower that to a fallback if the backend doesn't support +# it. +function LinearAlgebra.inv!(lu::GeneralizedLU) + @assert ndims(lu) == 2 "Only implemented for 2D tensors" + rhs = Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) + ) + ldiv!(lu, rhs) + return rhs +end + function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) permuted_B = B[Int64.(perm), :] return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) @@ -847,4 +863,36 @@ function LinearAlgebra.logabsdet( return abs_det, sgn end +function Base.inv(A::TracedRArray{T,2}) where {T} # don't overload Any* here + LinearAlgebra.checksquare(A) + @trace if istriu(A) + Ai = triu!(parent(inv(UpperTriangular(A)))) + elseif istril(A) + Ai = tril!(parent(inv(LowerTriangular(A)))) + else + Ai = inv!(lu(A; check=false)) + end + return Ai +end + +for (wT, lower, ud) in ( + (:UpperTriangular, false, false), + (:LowerTriangular, true, false), + (:UnitUpperTriangular, false, true), + (:UnitLowerTriangular, true, true), +) + @eval function Base.inv(A::LinearAlgebra.$(wT){T,<:AnyTracedRMatrix}) where {T} + S = typeof(inv(oneunit(Reactant.unwrapped_eltype(T)))) + rhs = Reactant.promote_to(TracedRArray{S,2}, LinearAlgebra.I(size(A, 1))) + return @opcall triangular_solve( + parent(A), + rhs; + left_side=false, + lower=$(lower), + transpose_a='N', + unit_diagonal=$(ud), + ) + end +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 5de03f9ae7..f92ce7758c 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -469,3 +469,17 @@ end @test res_ra ≈ res end end + +@testset "inv" begin + x_lowtri = Float32[1 0; 2 2] + x_reg = Float32[1 -1; 2 2] + x_uptri = Float32[1 2; 0 2] + + for x in (x_lowtri, x_reg, x_uptri) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit inv(x_ra) + res = inv(x) + @test res_ra ≈ res + end +end From 718d97d1a3b2b677c111c6f0592a0f118890d513 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 11:42:58 -0600 Subject: [PATCH 03/11] fix: accidental promotion in norm --- src/stdlibs/LinearAlgebra.jl | 2 +- test/integration/linear_algebra.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 23b21eefae..e3a78237d5 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -328,7 +328,7 @@ end # LinearAlgebra defines norm with some conditionals which cannot be traced directly function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} isinf(p) && return maximum(abs, x) - return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) + return mapreduce(Base.Fix2(^, p), +, x)^(T(1 / p)) end function LinearAlgebra._diagm(shape, kv::Pair{<:Integer,<:AnyTracedRVector}...) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index f92ce7758c..55334c30b8 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -483,3 +483,8 @@ end @test res_ra ≈ res end end + +@testset "norm accidental promotion" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 4)) + @test @jit(norm(x_ra)) isa ConcreteRNumber{Float32} +end From c7c90a99990afaefdafa419d8df0f8ef5112a291 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 11:46:54 -0600 Subject: [PATCH 04/11] feat: support cross --- src/TracedRArray.jl | 4 ++++ src/TracedRNumber.jl | 12 ++++++++++++ src/stdlibs/LinearAlgebra.jl | 17 +++++++++++++++++ test/integration/linear_algebra.jl | 11 +++++++++++ 4 files changed, 44 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 73dacadfb2..2038f7d731 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -29,6 +29,10 @@ Base.elsize(::Type{TracedRArray{T,N}}) where {T,N} = sizeof(T) # we use it Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x) +# Base.first is very common usecase for getting first element to get the type +# inside LinearAlgebra.jl +Base.first(x::TracedRArray{T,N}) where {T,N} = @allowscalar(x[1]) + # Base.complex Base.complex(x::TracedRArray{<:Real}) = complex.(x) Base.complex(x::TracedRArray{<:Complex}) = x diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 575c53c7eb..8b53c5237e 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -26,13 +26,25 @@ Base.copy(x::TracedRNumber{T}) where {T} = TracedRNumber{T}((), x.mlir_data) function Base.eps(::Type{TracedRNumber{T}}) where {T} return Reactant.promote_to(TracedRNumber{T}, eps(T)) end +Base.eps(x::TracedRNumber{T}) where {T} = eps(typeof(x)) function Base.typemin(::Type{TracedRNumber{T}}) where {T} return Reactant.promote_to(TracedRNumber{T}, typemin(T)) end +Base.typemin(x::TracedRNumber{T}) where {T} = typemin(typeof(x)) + function Base.typemax(::Type{TracedRNumber{T}}) where {T} return Reactant.promote_to(TracedRNumber{T}, typemax(T)) end +Base.typemax(x::TracedRNumber{T}) where {T} = typemax(typeof(x)) + +function Base.nextfloat(x::TracedRNumber{T}) where {T<:AbstractFloat} + return @opcall next_after(x, typemax(x)) +end + +function Base.prevfloat(x::TracedRNumber{T}) where {T<:AbstractFloat} + return @opcall next_after(x, typemin(x)) +end function Base.rtoldefault(T::Type{<:TracedRNumber}) return T(Base.rtoldefault(unwrapped_eltype(T))) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index e3a78237d5..422f99c864 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -18,6 +18,7 @@ using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular using LinearAlgebra: diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv, inv! using Libdl: Libdl +using GPUArraysCore: @allowscalar function __init__() if Reactant_jll.is_available() @@ -895,4 +896,20 @@ for (wT, lower, ud) in ( end end +function LinearAlgebra.cross(x::AnyTracedRVector, y::AbstractVector) + return LinearAlgebra.cross(x, Reactant.promote_to(TracedRArray{eltype(y),1}, y)) +end + +function LinearAlgebra.cross(x::AbstractVector, y::AnyTracedRVector) + return LinearAlgebra.cross(Reactant.promote_to(TracedRArray{eltype(x),1}, x), y) +end + +function LinearAlgebra.cross(x::AnyTracedRVector, y::AnyTracedRVector) + x_ = materialize_traced_array(x) + y_ = materialize_traced_array(y) + @allowscalar a1, a2, a3 = x_ + @allowscalar b1, b2, b3 = y_ + return Reactant.aos_to_soa([a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1]) +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 55334c30b8..2794c19a99 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -488,3 +488,14 @@ end x_ra = Reactant.to_rarray(rand(Float32, 4, 4)) @test @jit(norm(x_ra)) isa ConcreteRNumber{Float32} end + +@testset "cross" begin + x = Float32[0; 1; 0] + x_ra = Reactant.to_rarray(x) + y = Float32[0; 0; 1] + y_ra = Reactant.to_rarray(y) + + @test @jit(LinearAlgebra.cross(x_ra, y_ra)) ≈ LinearAlgebra.cross(x, y) + @test @jit(LinearAlgebra.cross(x_ra, y)) ≈ LinearAlgebra.cross(x, y) + @test @jit(LinearAlgebra.cross(x, y_ra)) ≈ LinearAlgebra.cross(x, y) +end \ No newline at end of file From 6e5774e544610e41c42502a50c450c257dcd4e81 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 11:58:34 -0600 Subject: [PATCH 05/11] feat: symmetric/hermitian/banded check --- src/stdlibs/LinearAlgebra.jl | 14 +++++++ test/integration/linear_algebra.jl | 66 ++++++++++++++++++++++-------- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 422f99c864..29508719b1 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -912,4 +912,18 @@ function LinearAlgebra.cross(x::AnyTracedRVector, y::AnyTracedRVector) return Reactant.aos_to_soa([a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1]) end +function LinearAlgebra.issymmetric(A::AnyTracedRMatrix) + axes(A, 1) == axes(A, 2) || return false + return all(A .== transpose(A)) +end + +function LinearAlgebra.ishermitian(A::AnyTracedRMatrix) + axes(A, 1) == axes(A, 2) || return false + return all(A .== adjoint(A)) +end + +function LinearAlgebra.isbanded(A::AnyTracedRMatrix, kl::Integer, ku::Integer) + return istriu(A, kl) & istril(A, ku) +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 2794c19a99..3fb64331f2 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -433,22 +433,56 @@ end end end -@testset "istriu" begin - x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) - x_triu = triu(x, 4) - x_triu_ra = Reactant.to_rarray(x_triu) - @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4))) - @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3))) - @test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5))) -end +@testset "structure check" begin + @testset "istriu" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_triu = triu(x, 4) + x_triu_ra = Reactant.to_rarray(x_triu) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4))) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3))) + @test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5))) + end + + @testset "istril" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_tril = tril(x, -4) + x_tril_ra = Reactant.to_rarray(x_tril) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4))) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) + @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) + end + + @testset "banded" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x = tril(triu(x, -3), 4) + x_ra = Reactant.to_rarray(x) -@testset "istril" begin - x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) - x_tril = tril(x, -4) - x_tril_ra = Reactant.to_rarray(x_tril) - @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4))) - @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) - @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) + @test Bool(@jit(LinearAlgebra.isbanded(x_ra, -3, 4))) + @test Bool(@jit(LinearAlgebra.isbanded(x_ra, -3, 5))) + @test Bool(@jit(LinearAlgebra.isbanded(x_ra, -4, 4))) + @test !Bool(@jit(LinearAlgebra.isbanded(x_ra, -2, 4))) + @test !Bool(@jit(LinearAlgebra.isbanded(x_ra, -3, 3))) + end + + @testset "issymmetric/ishermitian" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x2 = x .+ x' + x_ra = Reactant.to_rarray(x) + x2_ra = Reactant.to_rarray(x2) + @test Bool(@jit(LinearAlgebra.issymmetric(x2_ra))) + @test Bool(@jit(LinearAlgebra.ishermitian(x2_ra))) + @test !Bool(@jit(LinearAlgebra.issymmetric(x_ra))) + @test !Bool(@jit(LinearAlgebra.ishermitian(x_ra))) + + x = Reactant.TestUtils.construct_test_array(ComplexF32, 8, 8) + x2 = x .+ x' + x_ra = Reactant.to_rarray(x) + x2_ra = Reactant.to_rarray(x2) + @test !Bool(@jit(LinearAlgebra.issymmetric(x2_ra))) + @test Bool(@jit(LinearAlgebra.ishermitian(x2_ra))) + @test !Bool(@jit(LinearAlgebra.issymmetric(x_ra))) + @test !Bool(@jit(LinearAlgebra.ishermitian(x_ra))) + end end @testset "det" begin @@ -498,4 +532,4 @@ end @test @jit(LinearAlgebra.cross(x_ra, y_ra)) ≈ LinearAlgebra.cross(x, y) @test @jit(LinearAlgebra.cross(x_ra, y)) ≈ LinearAlgebra.cross(x, y) @test @jit(LinearAlgebra.cross(x, y_ra)) ≈ LinearAlgebra.cross(x, y) -end \ No newline at end of file +end From 0500f19b33ab255ae33bed7177be28db447e4471 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 12:23:53 -0600 Subject: [PATCH 06/11] feat: normalize/normalize\! --- src/stdlibs/LinearAlgebra.jl | 38 +++++++++++++++++++++++++++++- test/integration/linear_algebra.jl | 11 +++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 29508719b1..a76d337ad8 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -16,7 +16,7 @@ using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular using LinearAlgebra: - diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv, inv! + diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv!, rmul! using Libdl: Libdl using GPUArraysCore: @allowscalar @@ -926,4 +926,40 @@ function LinearAlgebra.isbanded(A::AnyTracedRMatrix, kl::Integer, ku::Integer) return istriu(A, kl) & istril(A, ku) end +@static if isdefined(LinearAlgebra, :__normalize!) + function LinearAlgebra.__normalize!(a::AnyTracedRArray, nrm) + # The largest positive floating point number whose inverse is less than infinity + δ = inv(prevfloat(typemax(nrm))) + @trace if nrm ≥ δ # Safe to multiply with inverse + invnrm = inv(nrm) + rmul!(a, invnrm) + else # scale elements to avoid overflow + εδ = eps(one(nrm)) / δ + rmul!(a, εδ) + rmul!(a, inv(nrm * εδ)) + end + return a + end +end + +function LinearAlgebra.rmul!(A::AnyTracedRArray, b::Number) + @. A *= b + return A +end + +function LinearAlgebra.lmul!(b::Number, A::AnyTracedRArray) + @. A = b * A + return A +end + +function LinearAlgebra.rdiv!(A::AnyTracedRArray, b::Number) + @. A /= b + return A +end + +function LinearAlgebra.ldiv!(b::Number, A::AnyTracedRArray) + @. A = b \ A + return A +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 3fb64331f2..7fa390b3a3 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -533,3 +533,14 @@ end @test @jit(LinearAlgebra.cross(x_ra, y)) ≈ LinearAlgebra.cross(x, y) @test @jit(LinearAlgebra.cross(x, y_ra)) ≈ LinearAlgebra.cross(x, y) end + +@testset "normalize/normalize!" begin + x = Reactant.TestUtils.construct_test_array(Float32, 4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(LinearAlgebra.normalize(x_ra)) ≈ LinearAlgebra.normalize(x) + + LinearAlgebra.normalize!(x) + @jit LinearAlgebra.normalize!(x_ra) + @test x_ra ≈ x +end From ba3251a4e540a71414de77066c309918d9938927 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 10:07:20 -0500 Subject: [PATCH 07/11] Apply suggestion from @avik-pal --- src/TracedRArray.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 2038f7d731..73dacadfb2 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -29,10 +29,6 @@ Base.elsize(::Type{TracedRArray{T,N}}) where {T,N} = sizeof(T) # we use it Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x) -# Base.first is very common usecase for getting first element to get the type -# inside LinearAlgebra.jl -Base.first(x::TracedRArray{T,N}) where {T,N} = @allowscalar(x[1]) - # Base.complex Base.complex(x::TracedRArray{<:Real}) = complex.(x) Base.complex(x::TracedRArray{<:Complex}) = x From b9648db9ff9b85fe92245a061872b580d25e2a66 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 10:09:06 -0500 Subject: [PATCH 08/11] feat: cholesky decomposition (#1884) * refactor: move LU into a separate file * feat: lower cholesky * feat: lowering cholesky ldiv * test: cholesky * fix: revert change to is symm/hermitian --- src/Overlay.jl | 64 +++---- src/Reactant.jl | 2 +- src/TracedRArray.jl | 17 +- src/stdlibs/LinearAlgebra.jl | 184 +-------------------- src/stdlibs/factorization/Cholesky.jl | 96 +++++++++++ src/stdlibs/factorization/Factorization.jl | 49 ++++++ src/stdlibs/factorization/LU.jl | 131 +++++++++++++++ test/integration/linear_algebra.jl | 64 ++++++- 8 files changed, 387 insertions(+), 220 deletions(-) create mode 100644 src/stdlibs/factorization/Cholesky.jl create mode 100644 src/stdlibs/factorization/Factorization.jl create mode 100644 src/stdlibs/factorization/LU.jl diff --git a/src/Overlay.jl b/src/Overlay.jl index e837966cb7..f539ef8199 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -230,36 +230,42 @@ end end # LinearAlgebra -@reactant_overlay @noinline function LinearAlgebra.lu(x::AbstractArray; kwargs...) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu( - x::AbstractArray, pivot::RowMaximum; kwargs... -) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu!( - x::AbstractArray, pivot::RowMaximum; kwargs... +## Various factorizations +## TODO: specialize for `cholesky!` --> cholcopy +factorization_copy(f::F, x, pivot) where {F} = x +factorization_copy(f::F, x) where {F} = x + +for (jlop, rop, default_pivot) in ( + (:lu, :overloaded_lu, RowMaximum), + (:lu!, :overloaded_lu, RowMaximum), + (:cholesky, :overloaded_cholesky, NoPivot), + (:cholesky!, :overloaded_cholesky, NoPivot), ) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...) + @eval begin + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray; kwargs... + ) + if use_overlayed_version(x) + pivot = $(default_pivot)() + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs... + ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...) + end + end + + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray, pivot::$(default_pivot); kwargs... + ) + if use_overlayed_version(x) + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs... + ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x, pivot; kwargs...) + end + end end end diff --git a/src/Reactant.jl b/src/Reactant.jl index 69cee2b3b0..e164886fb9 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -3,7 +3,7 @@ module Reactant using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array -using LinearAlgebra: LinearAlgebra, RowMaximum +using LinearAlgebra: LinearAlgebra, RowMaximum, NoPivot using Random: Random, AbstractRNG using EnumX: @enumx using Functors: Functors, @leaf diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 73dacadfb2..f01c57231b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -730,11 +730,20 @@ end # stack function overloaded_stack(dims::Union{Integer,Colon}, xs) - @assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \ - dimensions..." - dims = dims isa Colon ? ndims(first(xs)) + 1 : dims + dims = dims isa Colon ? nothing : dims res = [] - for x in xs + prev_dims = nothing + for x in unwrapped_broadcast(identity, xs) + cur_dims = ndims(x) + if prev_dims === nothing + prev_dims = cur_dims + else + @assert prev_dims == cur_dims "All arrays must have the same number of \ + dimensions..." + end + + dims === nothing && (dims = cur_dims + 1) + new_shape = ntuple( i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1 ) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index a76d337ad8..ac23ae964f 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -12,7 +12,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! using ..Ops: @opcall using LinearAlgebra: LinearAlgebra, BLAS -using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum +using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular using LinearAlgebra: @@ -40,6 +40,8 @@ function __init__() return nothing end +include("factorization/Factorization.jl") + # Various Wrapper Arrays defined in LinearAlgebra function ReactantCore.materialize_traced_array( x::Transpose{TracedRNumber{T},<:AnyTracedRArray} @@ -633,186 +635,6 @@ LinearAlgebra.transpose!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, tr LinearAlgebra.adjoint!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, adjoint(A)) -# Supports batched factorization -abstract type GeneralizedFactorization{T} <: Factorization{T} end - -function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) - return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) -end - -function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) - return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) -end - -const GeneralizedTransposeFactorization{T} = - LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} -const GeneralizedAdjointFactorization{T} = - LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} - -# LU Factorization -struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: - GeneralizedFactorization{T} - factors::S - ipiv::P - perm::P - info::I -end - -Base.size(lu::GeneralizedLU) = size(lu.factors) -Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) -Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) -function Base.copy(lu::GeneralizedLU) - return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) -end - -function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} - @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 - @assert ndims(info) == ndims(factors) - 2 - return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) -end - -function overloaded_lu(x::AbstractArray, args...; kwargs...) - return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...) -end - -function overloaded_lu( - A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false -) where {T,N} - # TODO: don't ignore the check and allowsingular flags - # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions - permdims = vcat(collect(Int64, 3:N), 1, 2) - A = @opcall transpose(materialize_traced_array(A), permdims) - factors, ipiv, perm, info = @opcall lu(A) - - # Permute back to the original dimensions - perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) - factors = @opcall transpose(factors, invperm(permdims)) - ipiv = @opcall transpose(ipiv, perm_perm) - perm = @opcall transpose(perm, perm_perm) - return GeneralizedLU(factors, ipiv, perm, info) -end - -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} -) where {T,P,I,N,M} - @assert N == M + 1 - ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) - return B -end - -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} -) where {T,P,I} - B .= _lu_solve_core(lu.factors, B, lu.perm) - return B -end - -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} -) where {T,P,I,N} - batch_shape = size(lu.factors)[3:end] - @assert batch_shape == size(B)[3:end] - - permutation = vcat(collect(Int64, 3:N), 1, 2) - - factors = @opcall transpose(materialize_traced_array(lu.factors), permutation) - B_permuted = @opcall transpose(materialize_traced_array(B), permutation) - perm = @opcall transpose( - materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1) - ) - - res = @opcall transpose( - only( - @opcall( - batch( - _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape) - ) - ), - ), - invperm(permutation), - ) - B .= res - return B -end - -function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(lu) - # TODO: check for non-singular matrices - - P = prod(LinearAlgebra.diag(lu.factors)) - return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P -end - -function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(lu) - Treal = real(T) - # TODO: check for non-singular matrices - - d = LinearAlgebra.diag(lu.factors) - absdet = sum(log ∘ abs, d) - P = prod(sign, d) - s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P - return absdet, s -end - -for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), - aType in (:AbstractVecOrMat, :AbstractArray) - - @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) - # TODO: implement this - error("`$(f_wrapper)` is not supported yet for LU.") - return nothing - end -end - -# currently we lower inverse to lu decomposition + triangular solve. we should -# instead emit getri and lower that to a fallback if the backend doesn't support -# it. -function LinearAlgebra.inv!(lu::GeneralizedLU) - @assert ndims(lu) == 2 "Only implemented for 2D tensors" - rhs = Reactant.promote_to( - TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) - ) - ldiv!(lu, rhs) - return rhs -end - -function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) - permuted_B = B[Int64.(perm), :] - return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) -end - -# Overload \ to support batched factorization -for FT in ( - :GeneralizedFactorization, - :GeneralizedTransposeFactorization, - :GeneralizedAdjointFactorization, -) - for aType in (:AbstractVecOrMat, :AbstractArray) - @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) - end - - @eval Base.:(\)( - F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}} - ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) -end - -function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) -end - -function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) - return conj!(adjoint(F.parent) \ conj.(B)) -end - -function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) -end - # indexing into specific wrapepd array types # TODO: specialize these ones. We don't need to make the arrays dense (though our passes # should be able to optimize them out) diff --git a/src/stdlibs/factorization/Cholesky.jl b/src/stdlibs/factorization/Cholesky.jl new file mode 100644 index 0000000000..e0f30ad586 --- /dev/null +++ b/src/stdlibs/factorization/Cholesky.jl @@ -0,0 +1,96 @@ +struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <: + GeneralizedFactorization{T} + factors::S + uplo::Char + info::I +end + +function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I} + @assert ndims(info) == ndims(factors) - 2 + return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info) +end + +Base.size(c::GeneralizedCholesky) = size(c.factors) +Base.ndims(c::GeneralizedCholesky) = ndims(c.factors) + +function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false) + return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check) +end + +function overloaded_cholesky( + A::AnyTracedRArray{T,N}, ::NoPivot; check::Bool=false +) where {T,N} + # TODO: dont ignore check + # move the batching dims to the front + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + + factors = @opcall cholesky(A; lower=false) + factors = @opcall transpose(factors, invperm(permdims)) + + # stablehlo doesn't return the info + info = materialize_traced_array( + dropdims( + Reactant.CallWithReactant(mapreduce)( + isfinite, &, mapslices(LinearAlgebra.triu, factors; dims=(1, 2)); dims=1:2 + ); + dims=(1, 2), + ), + ) + if N == 2 + info = TracedRNumber{Bool}((), info.mlir_data) + end + + return GeneralizedCholesky(factors, 'U', info) +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M} +) where {T,N,M} + @assert N == M + 1 + ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2} +) where {T} + B .= _cholesky_solve_core(F.factors, B, F.uplo) + return B +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N} +) where {T,N} + batch_shape = size(F.factors)[3:end] + @assert batch_shape == size(B)[3:end] + + base_fn = F.uplo == 'U' ? _cholesky_solve_core_upper : _cholesky_solve_core_lower + + permutation = vcat(collect(Int64, 3:N), 1, 2) + + factors = @opcall transpose(materialize_traced_array(F.factors), permutation) + B_permuted = @opcall transpose(materialize_traced_array(B), permutation) + + res = @opcall transpose( + only(@opcall(batch(base_fn, [factors, B_permuted], collect(Int64, batch_shape)))), + invperm(permutation), + ) + B .= res + return B +end + +function _cholesky_solve_core(factors::AbstractMatrix, B::AbstractMatrix, uplo::Char) + if uplo == 'U' + return _cholesky_solve_core_upper(factors, B) + else + return _cholesky_solve_core_lower(factors, B) + end +end + +function _cholesky_solve_core_lower(factors::AbstractMatrix, B::AbstractMatrix) + return adjoint(LowerTriangular(factors)) \ (LowerTriangular(factors) \ B) +end +function _cholesky_solve_core_upper(factors::AbstractMatrix, B::AbstractMatrix) + return UpperTriangular(factors) \ (adjoint(UpperTriangular(factors)) \ B) +end diff --git a/src/stdlibs/factorization/Factorization.jl b/src/stdlibs/factorization/Factorization.jl new file mode 100644 index 0000000000..5114a47d7b --- /dev/null +++ b/src/stdlibs/factorization/Factorization.jl @@ -0,0 +1,49 @@ +# Supports batched factorization +abstract type GeneralizedFactorization{T} <: Factorization{T} end + +function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) + return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) +end + +function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) + return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) +end + +const GeneralizedTransposeFactorization{T} = + LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} +const GeneralizedAdjointFactorization{T} = + LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} + +include("Cholesky.jl") +include("LU.jl") + +# Overload \ to support batched factorization +for FT in ( + :GeneralizedFactorization, + :GeneralizedTransposeFactorization, + :GeneralizedAdjointFactorization, +) + for aType in (:AbstractVecOrMat, :AbstractArray) + @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) + end + + @eval Base.:(\)( + F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}} + ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) +end + +function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) + return ldiv!( + F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) + ) +end + +function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) + return conj!(adjoint(F.parent) \ conj.(B)) +end + +function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) + return ldiv!( + F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) + ) +end diff --git a/src/stdlibs/factorization/LU.jl b/src/stdlibs/factorization/LU.jl new file mode 100644 index 0000000000..144e7b9dcd --- /dev/null +++ b/src/stdlibs/factorization/LU.jl @@ -0,0 +1,131 @@ +struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: + GeneralizedFactorization{T} + factors::S + ipiv::P + perm::P + info::I +end + +Base.size(lu::GeneralizedLU) = size(lu.factors) +Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) +Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) +function Base.copy(lu::GeneralizedLU) + return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) +end + +function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} + @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 + @assert ndims(info) == ndims(factors) - 2 + return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) +end + +function overloaded_lu(x::AbstractArray, args...; kwargs...) + return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...) +end + +function overloaded_lu( + A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false +) where {T,N} + # TODO: don't ignore the check and allowsingular flags + # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + factors, ipiv, perm, info = @opcall lu(A) + + # Permute back to the original dimensions + perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) + factors = @opcall transpose(factors, invperm(permdims)) + ipiv = @opcall transpose(ipiv, perm_perm) + perm = @opcall transpose(perm, perm_perm) + return GeneralizedLU(factors, ipiv, perm, info) +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} +) where {T,P,I,N,M} + @assert N == M + 1 + ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} +) where {T,P,I} + B .= _lu_solve_core(lu.factors, B, lu.perm) + return B +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} +) where {T,P,I,N} + batch_shape = size(lu.factors)[3:end] + @assert batch_shape == size(B)[3:end] + + permutation = vcat(collect(Int64, 3:N), 1, 2) + + factors = @opcall transpose(materialize_traced_array(lu.factors), permutation) + B_permuted = @opcall transpose(materialize_traced_array(B), permutation) + perm = @opcall transpose( + materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1) + ) + + res = @opcall transpose( + only( + @opcall( + batch( + _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape) + ) + ), + ), + invperm(permutation), + ) + B .= res + return B +end + +function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + # TODO: check for non-singular matrices + + P = prod(LinearAlgebra.diag(lu.factors)) + return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P +end + +function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + Treal = real(T) + # TODO: check for non-singular matrices + + d = LinearAlgebra.diag(lu.factors) + absdet = sum(log ∘ abs, d) + P = prod(sign, d) + s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P + return absdet, s +end + +for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), + aType in (:AbstractVecOrMat, :AbstractArray) + + @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) + # TODO: implement this + error("`$(f_wrapper)` is not supported yet for LU.") + return nothing + end +end + +# currently we lower inverse to lu decomposition + triangular solve. we should +# instead emit getri and lower that to a fallback if the backend doesn't support +# it. +function LinearAlgebra.inv!(lu::GeneralizedLU) + @assert ndims(lu) == 2 "Only implemented for 2D tensors" + rhs = Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) + ) + ldiv!(lu, rhs) + return rhs +end + +function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) + permuted_B = B[Int64.(perm), :] + return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 7fa390b3a3..08199af5a4 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -351,24 +351,31 @@ end end end -solve_with_lu(A, b) = lu(A) \ b -function solve_with_lu_batched(A::AbstractArray{T,N}, B::AbstractArray{T,N}) where {T,N} +solve_with_fact(f::F, A, b) where {F} = f(A) \ b +function solve_with_fact_batched( + f::F, A::AbstractArray{T,N}, B::AbstractArray{T,N} +) where {F,T,N} A2 = reshape(A, size(A, 1), size(A, 2), prod(size(A)[3:end])) B2 = reshape(B, size(B, 1), size(B, 2), prod(size(B)[3:end])) @assert size(A2, 3) == size(B2, 3) return reshape( - stack(lu(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)), + stack(f(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)), size(A2, 1), size(B2, 2), size(A)[3:end]..., ) end -function solve_with_lu_batched(A::AbstractArray{T,N}, b::AbstractArray{T,M}) where {T,N,M} +function solve_with_fact_batched( + f::F, A::AbstractArray{T,N}, b::AbstractArray{T,M} +) where {F,T,N,M} @assert N == M + 1 B = reshape(b, size(b, 1), 1, size(b)[2:end]...) - return dropdims(solve_with_lu_batched(A, B); dims=2) + return dropdims(solve_with_fact_batched(f, A, B); dims=2) end +solve_with_lu(A, b) = solve_with_fact(lu, A, b) +solve_with_lu_batched(A, b) = solve_with_fact_batched(lu, A, b) + @testset "LU Factorization" begin @testset "Un-batched" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -433,6 +440,53 @@ end end end +solve_with_cholesky(A, b) = solve_with_fact(cholesky, A, b) +solve_with_cholesky_batched(A, b) = solve_with_fact_batched(cholesky, A, b) + +@testset "Cholesky Factorization" begin + @testset "Un-batched" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + + A = rand(T, 4, 4) + A = A * A' + A_ra = Reactant.to_rarray(A) + + b = rand(T, 4) + b_ra = Reactant.to_rarray(b) + + B = rand(T, 4, 3) + B_ra = Reactant.to_rarray(B) + + @test @jit(solve_with_cholesky(A_ra, b_ra)) ≈ solve_with_cholesky(A, b) atol = + 1e-4 rtol = 1e-2 + @test @jit(solve_with_cholesky(A_ra, B_ra)) ≈ solve_with_cholesky(A, B) atol = + 1e-4 rtol = 1e-2 + end + end + + @testset "Batched" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + + A = rand(T, 4, 4, 6) + A = reshape(stack((r * r' for r in eachslice(A; dims=3))), 4, 4, 3, 2) + A_ra = Reactant.to_rarray(A) + + b = rand(T, 4, 3, 2) + b_ra = Reactant.to_rarray(b) + + B = rand(T, 4, 5, 3, 2) + B_ra = Reactant.to_rarray(B) + + @test @jit(solve_with_cholesky(A_ra, b_ra)) ≈ solve_with_cholesky_batched(A, b) atol = + 1e-4 rtol = 1e-2 + @test @jit(solve_with_cholesky(A_ra, B_ra)) ≈ solve_with_cholesky_batched(A, B) atol = + 1e-4 rtol = 1e-2 + end + end +end + @testset "structure check" begin @testset "istriu" begin x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) From 9b183f47928e0f5152c6ff24aa06c76242781b31 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 10:41:07 -0500 Subject: [PATCH 09/11] fix: overload normalize --- src/stdlibs/LinearAlgebra.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index ac23ae964f..ff67859c5e 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -748,6 +748,16 @@ function LinearAlgebra.isbanded(A::AnyTracedRMatrix, kl::Integer, ku::Integer) return istriu(A, kl) & istril(A, ku) end +function LinearAlgebra.normalize(a::AnyTracedRArray{T}, p::Real=2) where {T} + nrm = LinearAlgebra.norm(a, p) + if !isempty(a) + aa = LinearAlgebra.copymutable_oftype(a, typeof(zero(T) / nrm)) + return LinearAlgebra.__normalize!(aa, nrm) + else + return typeof(zero(T) / nrm)[] + end +end + @static if isdefined(LinearAlgebra, :__normalize!) function LinearAlgebra.__normalize!(a::AnyTracedRArray, nrm) # The largest positive floating point number whose inverse is less than infinity From e3eb94ed844a1c21977dc48b0a721cde5f300819 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 14:10:16 -0500 Subject: [PATCH 10/11] test: use low cond number matrix --- test/integration/linear_algebra.jl | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 08199af5a4..609e342803 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -443,12 +443,36 @@ end solve_with_cholesky(A, b) = solve_with_fact(cholesky, A, b) solve_with_cholesky_batched(A, b) = solve_with_fact_batched(cholesky, A, b) +function random_matrix_with_cond( + ::Type{T}, rows::Int, cols::Int, cond_number::Float64 +) where {T} + # Generate random orthogonal matrices U and V + U = ( + LinearAlgebra.qr(randn(rows, rows)).Q * + Diagonal(sign.(diag(LinearAlgebra.qr(randn(rows, rows)).R))) + ) + V = ( + LinearAlgebra.qr(randn(cols, cols)).Q * + Diagonal(sign.(diag(LinearAlgebra.qr(randn(cols, cols)).R))) + ) + + min_dim = min(rows, cols) + singular_values = exp.(range(log(1.0), log(1.0 / cond_number); length=min_dim)) + + S = zeros(Float64, rows, cols) + @inbounds for i in 1:min_dim + S[i, i] = singular_values[i] + end + + return T.(U * S * V') +end + @testset "Cholesky Factorization" begin @testset "Un-batched" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) (T == ComplexF64 || T == Float64) && RunningOnTPU && continue - A = rand(T, 4, 4) + A = random_matrix_with_cond(T, 4, 4, 1.001) # avoid ill conditioned A = A * A' A_ra = Reactant.to_rarray(A) @@ -469,7 +493,9 @@ solve_with_cholesky_batched(A, b) = solve_with_fact_batched(cholesky, A, b) @testset for T in (Float32, Float64, ComplexF32, ComplexF64) (T == ComplexF64 || T == Float64) && RunningOnTPU && continue - A = rand(T, 4, 4, 6) + A = stack( + random_matrix_with_cond(T, 4, 4, 1.001) for _ in 1:6 + ) A = reshape(stack((r * r' for r in eachslice(A; dims=3))), 4, 4, 3, 2) A_ra = Reactant.to_rarray(A) From 4fb1151cfa9d0902197e0f17c355ea09fa214340 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 14:15:07 -0500 Subject: [PATCH 11/11] Update test/integration/linear_algebra.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/integration/linear_algebra.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 609e342803..cbf2cb02af 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -493,9 +493,7 @@ end @testset for T in (Float32, Float64, ComplexF32, ComplexF64) (T == ComplexF64 || T == Float64) && RunningOnTPU && continue - A = stack( - random_matrix_with_cond(T, 4, 4, 1.001) for _ in 1:6 - ) + A = stack(random_matrix_with_cond(T, 4, 4, 1.001) for _ in 1:6) A = reshape(stack((r * r' for r in eachslice(A; dims=3))), 4, 4, 3, 2) A_ra = Reactant.to_rarray(A)