Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.16"
Reactant_jll = "0.0.263"
Reactant_jll = "0.0.264"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
16 changes: 16 additions & 0 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,22 @@ for (jlop, rop, default_pivot) in (
end
end

for (jlop, rop) in ((:svd, :overloaded_svd),)
@eval begin
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
x::AbstractArray; kwargs...
)
if use_overlayed_version(x)
return TracedLinearAlgebra.$(rop)(
factorization_copy(LinearAlgebra.$(jlop), x); kwargs...
)
else
return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...)
end
end
end
end

@reactant_overlay @noinline function LinearAlgebra.dot(x::AbstractArray, y::AbstractArray)
if use_overlayed_version(x) || use_overlayed_version(y)
return TracedLinearAlgebra.overloaded_dot(x, y)
Expand Down
7 changes: 4 additions & 3 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ..MLIR: MLIR
using ..Reactant: Reactant, Ops
using ..Reactant:
TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector
using ..Reactant: call_with_reactant
using ..Reactant: call_with_reactant, unwrapped_eltype, promote_to
using ReactantCore: ReactantCore, materialize_traced_array, @trace
using Reactant_jll: Reactant_jll

Expand All @@ -15,8 +15,9 @@ using LinearAlgebra: LinearAlgebra, BLAS
using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
using LinearAlgebra:
diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv!, rmul!
using LinearAlgebra: I, diag, diagm, ldiv!, det, logabsdet, istriu, istril, triu!, tril!
using LinearAlgebra: inv!, rmul!, normalize
using LinearAlgebra: svd, lu
using Libdl: Libdl
using GPUArraysCore: @allowscalar

Expand Down
21 changes: 11 additions & 10 deletions src/stdlibs/factorization/Cholesky.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <:
GeneralizedFactorization{T}
struct BatchedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <:
BatchedFactorization{T}
factors::S
uplo::Char
info::I
end

function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I}
function BatchedCholesky(factors::S, uplo::Char, info::I) where {S,I}
@assert ndims(info) == ndims(factors) - 2
return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info)
return BatchedCholesky{eltype(factors),S,I}(factors, uplo, info)
end

Base.size(c::GeneralizedCholesky) = size(c.factors)
Base.ndims(c::GeneralizedCholesky) = ndims(c.factors)
Base.size(c::BatchedCholesky) = size(c.factors)
Base.size(c::BatchedCholesky, i::Integer) = size(c.factors, i)
Base.ndims(c::BatchedCholesky) = ndims(c.factors)

function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false)
return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check)
Expand Down Expand Up @@ -41,26 +42,26 @@ function overloaded_cholesky(
info = TracedRNumber{Bool}((), info.mlir_data)
end

return GeneralizedCholesky(factors, 'U', info)
return BatchedCholesky(factors, 'U', info)
end

function LinearAlgebra.ldiv!(
F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
) where {T,N,M}
@assert N == M + 1
ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...))
return B
end

function LinearAlgebra.ldiv!(
F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
F::BatchedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
) where {T}
B .= _cholesky_solve_core(F.factors, B, F.uplo)
return B
end

function LinearAlgebra.ldiv!(
F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
) where {T,N}
batch_shape = size(F.factors)[3:end]
@assert batch_shape == size(B)[3:end]
Expand Down
59 changes: 37 additions & 22 deletions src/stdlibs/factorization/Factorization.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
# Supports batched factorization
abstract type GeneralizedFactorization{T} <: Factorization{T} end
abstract type BatchedFactorization{T} <: Factorization{T} end

function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization)
function LinearAlgebra.TransposeFactorization(f::BatchedFactorization)
return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f)
end

function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization)
function LinearAlgebra.AdjointFactorization(f::BatchedFactorization)
return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f)
end

const GeneralizedTransposeFactorization{T} =
LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T}
const GeneralizedAdjointFactorization{T} =
LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T}
const BatchedTransposeFactorization{T} =
LinearAlgebra.TransposeFactorization{T,<:BatchedFactorization{T}} where {T}
const BatchedAdjointFactorization{T} =
LinearAlgebra.AdjointFactorization{T,<:BatchedFactorization{T}} where {T}

include("Cholesky.jl")
include("LU.jl")
include("SVD.jl")

# Overload \ to support batched factorization
for FT in (
:GeneralizedFactorization,
:GeneralizedTransposeFactorization,
:GeneralizedAdjointFactorization,
)
for FT in
(:BatchedFactorization, :BatchedTransposeFactorization, :BatchedAdjointFactorization)
for aType in (:AbstractVecOrMat, :AbstractArray)
@eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B)
end
Expand All @@ -32,18 +29,36 @@ for FT in (
) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B)
end

function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray)
return ldiv!(
F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
)
function __get_B(F::Factorization, B::AbstractArray)
m, n = size(F, 1), size(F, 2)
if m != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end

TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))

BB = similar(B, TFB, max(size(B, 1), n), size(B)[2:end]...)
if n > size(B, 1)
BB[1:m, ntuple(Returns(Colon()), ndims(B) - 1)...] = B
else
copyto!(BB, B)
end

return BB
end

function _overloaded_backslash(F::BatchedFactorization, B::AbstractArray)
BB = __get_B(F, B)
ldiv!(F, BB)
return BB[1:size(F, 2), ntuple(Returns(Colon()), ndims(B) - 1)...]
end

function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray)
function _overloaded_backslash(F::BatchedTransposeFactorization, B::AbstractArray)
return conj!(adjoint(F.parent) \ conj.(B))
end

function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray)
return ldiv!(
F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
)
function _overloaded_backslash(F::BatchedAdjointFactorization, B::AbstractArray)
BB = __get_B(F, B)
ldiv!(F, BB)
return BB[1:size(F)[2], ntuple(Returns(Colon()), ndims(B) - 1)...]
end
34 changes: 17 additions & 17 deletions src/stdlibs/factorization/LU.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
GeneralizedFactorization{T}
struct BatchedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
BatchedFactorization{T}
factors::S
ipiv::P
perm::P
info::I
end

Base.size(lu::GeneralizedLU) = size(lu.factors)
Base.size(lu::GeneralizedLU, i) = size(lu.factors, i)
Base.ndims(lu::GeneralizedLU) = ndims(lu.factors)
function Base.copy(lu::GeneralizedLU)
return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
Base.size(lu::BatchedLU) = size(lu.factors)
Base.size(lu::BatchedLU, i::Integer) = size(lu.factors, i)
Base.ndims(lu::BatchedLU) = ndims(lu.factors)
function Base.copy(lu::BatchedLU)
return BatchedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
end

function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
function BatchedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
@assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1
@assert ndims(info) == ndims(factors) - 2
return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
return BatchedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
end

function overloaded_lu(x::AbstractArray, args...; kwargs...)
Expand All @@ -37,26 +37,26 @@ function overloaded_lu(
factors = @opcall transpose(factors, invperm(permdims))
ipiv = @opcall transpose(ipiv, perm_perm)
perm = @opcall transpose(perm, perm_perm)
return GeneralizedLU(factors, ipiv, perm, info)
return BatchedLU(factors, ipiv, perm, info)
end

function LinearAlgebra.ldiv!(
lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
) where {T,P,I,N,M}
@assert N == M + 1
ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...))
return B
end

function LinearAlgebra.ldiv!(
lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
lu::BatchedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
) where {T,P,I}
B .= _lu_solve_core(lu.factors, B, lu.perm)
return B
end

function LinearAlgebra.ldiv!(
lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
) where {T,P,I,N}
batch_shape = size(lu.factors)[3:end]
@assert batch_shape == size(B)[3:end]
Expand All @@ -83,15 +83,15 @@ function LinearAlgebra.ldiv!(
return B
end

function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
function LinearAlgebra.det(lu::BatchedLU{T,<:AbstractMatrix}) where {T}
n = LinearAlgebra.checksquare(lu)
# TODO: check for non-singular matrices

P = prod(LinearAlgebra.diag(lu.factors))
return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P
end

function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
function LinearAlgebra.logabsdet(lu::BatchedLU{T,<:AbstractMatrix}) where {T}
n = LinearAlgebra.checksquare(lu)
Treal = real(T)
# TODO: check for non-singular matrices
Expand All @@ -106,7 +106,7 @@ end
for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization),
aType in (:AbstractVecOrMat, :AbstractArray)

@eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType)
@eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:BatchedLU}, B::$aType)
# TODO: implement this
error("`$(f_wrapper)` is not supported yet for LU.")
return nothing
Expand All @@ -116,7 +116,7 @@ end
# currently we lower inverse to lu decomposition + triangular solve. we should
# instead emit getri and lower that to a fallback if the backend doesn't support
# it.
function LinearAlgebra.inv!(lu::GeneralizedLU)
function LinearAlgebra.inv!(lu::BatchedLU)
@assert ndims(lu) == 2 "Only implemented for 2D tensors"
rhs = Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1))
Expand Down
Loading
Loading