diff --git a/Project.toml b/Project.toml index de46232..a5193ad 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.2.7" +version = "0.2.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index 78362c2..eb6587a 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -1,6 +1,6 @@ module KroneckerArraysTensorAlgebraExt -using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2 +using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, arg1, arg2 using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize @@ -10,7 +10,7 @@ struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle end KroneckerArrays.arg1(style::KroneckerFusion) = style.a KroneckerArrays.arg2(style::KroneckerFusion) = style.b -function TensorAlgebra.FusionStyle(a::KroneckerArray) +function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) end function matricize_kronecker( diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 28328a7..6a1e454 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -1,3 +1,35 @@ +""" + abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end + +Abstract supertype for arrays that have a kronecker product structure, +i.e. that can be written as `AB = A ⊗ B`. +""" +abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end + +const AbstractKroneckerVector{T} = AbstractKroneckerArray{T, 1} +const AbstractKroneckerMatrix{T} = AbstractKroneckerArray{T, 2} + +@doc """ + arg1(AB::AbstractKroneckerArray{T, N}) + +Extract the first factor (`A`) of the Kronecker array `AB = A ⊗ B`. +""" arg1 + +@doc """ + arg2(AB::AbstractKroneckerArray{T, N}) + +Extract the second factor (`B`) of the Kronecker array `AB = A ⊗ B`. +""" arg2 + +arg1type(x::AbstractKroneckerArray) = arg1type(typeof(x)) +arg1type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg1type`.") +arg2type(x::AbstractKroneckerArray) = arg2type(typeof(x)) +arg2type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg2type`.") + +arguments(a::AbstractKroneckerArray) = (arg1(a), arg2(a)) +arguments(a::AbstractKroneckerArray, n::Int) = arguments(a)[n] +argument_types(a::AbstractKroneckerArray) = argument_types(typeof(a)) + function unwrap_array(a::AbstractArray) p = parent(a) p ≡ a && return a @@ -26,7 +58,7 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix) end struct KroneckerArray{T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} <: - AbstractArray{T, N} + AbstractKroneckerArray{T, N} arg1::A1 arg2::A2 end @@ -48,6 +80,10 @@ const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = Kro @inline arg1(a::KroneckerArray) = getfield(a, :arg1) @inline arg2(a::KroneckerArray) = getfield(a, :arg2) +arg1type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A1 +arg2type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A2 + +argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2) function mutate_active_args!(f!, f, dest, src) (isactive(arg1(dest)) || isactive(arg2(dest))) || @@ -66,7 +102,7 @@ function mutate_active_args!(f!, f, dest, src) end using Adapt: Adapt, adapt -function Adapt.adapt_structure(to, a::KroneckerArray) +function Adapt.adapt_structure(to, a::AbstractKroneckerArray) # TODO: Is this a good definition? It is similar to # the definition of `similar`. return if isactive(arg1(a)) == isactive(arg2(a)) @@ -78,18 +114,22 @@ function Adapt.adapt_structure(to, a::KroneckerArray) end end -function Base.copy(a::KroneckerArray) - return copy(arg1(a)) ⊗ copy(arg2(a)) +Base.copy(a::AbstractKroneckerArray) = copy(arg1(a)) ⊗ copy(arg2(a)) +function Base.copy!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray) + return mutate_active_args!(copy!, copy, dest, src) end +# TODO: copyto! is typically reserved for contiguous copies (i.e. also for copying from a +# vector into an array), it might be better to not define that here. function Base.copyto!(dest::KroneckerArray{<:Any, N}, src::KroneckerArray{<:Any, N}) where {N} return mutate_active_args!(copyto!, copy, dest, src) end function Base.convert( - ::Type{KroneckerArray{T, N, A1, A2}}, a::KroneckerArray - ) where {T, N, A1, A2} - return _convert(A1, arg1(a)) ⊗ _convert(A2, arg2(a)) + ::Type{KroneckerArray{T, N, A1, A2}}, a::AbstractKroneckerArray + )::KroneckerArray{T, N, A1, A2} where {T, N, A1, A2} + typeof(a) === KroneckerArray{T, N, A1, A2} && return a + return KroneckerArray(_convert(A1, arg1(a)), _convert(A2, arg2(a))) end # Promote the element type if needed. @@ -98,7 +138,7 @@ end maybe_promot_eltype(a, elt) = eltype(a) <: elt ? a : elt.(a) function Base.similar( - a::KroneckerArray, + a::AbstractKroneckerArray, elt::Type, axs::Tuple{ CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, @@ -115,7 +155,7 @@ function Base.similar( maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt, arg2.(axs)) end end -function Base.similar(a::KroneckerArray, elt::Type) +function Base.similar(a::AbstractKroneckerArray, elt::Type) # TODO: Is this a good definition? return if isactive(arg1(a)) == isactive(arg2(a)) similar(arg1(a), elt) ⊗ similar(arg2(a), elt) @@ -125,7 +165,7 @@ function Base.similar(a::KroneckerArray, elt::Type) maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt) end end -function Base.similar(a::KroneckerArray) +function Base.similar(a::AbstractKroneckerArray) # TODO: Is this a good definition? return if isactive(arg1(a)) == isactive(arg2(a)) similar(arg1(a)) ⊗ similar(arg2(a)) @@ -147,16 +187,18 @@ function Base.similar( end function Base.similar( - arrayt::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, + ::Type{ArrayT}, axs::Tuple{ CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}}, }, - ) where {A1, A2} + ) where {ArrayT <: AbstractKroneckerArray} + A1, A2 = arg1type(ArrayT), arg2type(ArrayT) return similar(A1, map(arg1, axs)) ⊗ similar(A2, map(arg2, axs)) end function Base.similar( - ::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, sz::Tuple{Int, Vararg{Int}} - ) where {A1, A2} + ::Type{ArrayT}, sz::Tuple{Int, Vararg{Int}} + ) where {ArrayT <: AbstractKroneckerArray} + A1, A2 = arg1type(ArrayT), arg2type(ArrayT) return similar(promote_type(A1, A2), sz) end @@ -169,15 +211,15 @@ function Base.similar( return similar(arrayt, map(arg1, axs)) ⊗ similar(arrayt, map(arg2, axs)) end -function Base.permutedims(a::KroneckerArray, perm) +function Base.permutedims(a::AbstractKroneckerArray, perm) return permutedims(arg1(a), perm) ⊗ permutedims(arg2(a), perm) end using DerivableInterfaces: DerivableInterfaces, permuteddims -function DerivableInterfaces.permuteddims(a::KroneckerArray, perm) +function DerivableInterfaces.permuteddims(a::AbstractKroneckerArray, perm) return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm) end -function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm) +function Base.permutedims!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray, perm) return mutate_active_args!( (dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src ) @@ -208,9 +250,10 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2) kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2) # Eagerly collect arguments to make more general on GPU. -Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a))) +Base.collect(a::AbstractKroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a))) +Base.collect(T::Type, a::AbstractKroneckerArray) = kron_nd(collect(T, arg1(a)), collect(T, arg2(a))) -function Base.zero(a::KroneckerArray) +function Base.zero(a::AbstractKroneckerArray) return if isactive(arg1(a)) == isactive(arg2(a)) # TODO: Maybe this should zero both arguments? # This is how `a * false` would behave. @@ -223,7 +266,7 @@ function Base.zero(a::KroneckerArray) end using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::KroneckerArray) +function DerivableInterfaces.zero!(a::AbstractKroneckerArray) (isactive(arg1(a)) || isactive(arg2(a))) || error("Can't mutate immutable KroneckerArray.") isactive(arg1(a)) && zero!(arg1(a)) @@ -231,15 +274,13 @@ function DerivableInterfaces.zero!(a::KroneckerArray) return a end -function Base.Array{T, N}(a::KroneckerArray{S, N}) where {T, S, N} - return convert(Array{T, N}, collect(a)) +function Base.Array{T, N}(a::AbstractKroneckerArray{S, N}) where {T, S, N} + return convert(Array{T, N}, collect(T, a)) end -function Base.size(a::KroneckerArray) - return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a)) -end +Base.size(a::AbstractKroneckerArray) = size(arg1(a)) .* size(arg2(a)) -function Base.axes(a::KroneckerArray) +function Base.axes(a::AbstractKroneckerArray) return ntuple(ndims(a)) do dim return CartesianProductUnitRange( axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim)) @@ -247,11 +288,6 @@ function Base.axes(a::KroneckerArray) end end -arguments(a::KroneckerArray) = (arg1(a), arg2(a)) -arguments(a::KroneckerArray, n::Int) = arguments(a)[n] -argument_types(a::KroneckerArray) = argument_types(typeof(a)) -argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2) - function Base.print_array(io::IO, a::KroneckerArray) Base.print_array(io, arg1(a)) println(io, "\n ⊗") @@ -285,7 +321,7 @@ end # Indexing logic. function Base.to_indices( - a::KroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg} + a::AbstractKroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg} ) I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) @@ -293,37 +329,40 @@ function Base.to_indices( end function Base.getindex( - a::KroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N} + a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N} ) where {N} I′ = to_indices(a, I) return arg1(a)[arg1.(I′)...] ⊗ arg2(a)[arg2.(I′)...] end # Fix ambigiuity error. -Base.getindex(a::KroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[] +Base.getindex(a::AbstractKroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[] arg1(::Colon) = (:) arg2(::Colon) = (:) arg1(::Base.Slice) = (:) arg2(::Base.Slice) = (:) function Base.view( - a::KroneckerArray{<:Any, N}, + a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N}, ) where {N} return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) end -function Base.view(a::KroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N} +function Base.view(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N} return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...) end # Fix ambigiuity error. -Base.view(a::KroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a)) +Base.view(a::AbstractKroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a)) -function Base.:(==)(a::KroneckerArray, b::KroneckerArray) +function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray) return arg1(a) == arg1(b) && arg2(a) == arg2(b) end -function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...) + +# TODO: this definition doesn't fully retain the original meaning: +# ‖a - b‖ < atol could be true even if the following check isn't +function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...) return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...) end -function Base.iszero(a::KroneckerArray) +function Base.iszero(a::AbstractKroneckerArray) return iszero(arg1(a)) || iszero(arg2(a)) end function Base.isreal(a::KroneckerArray) @@ -335,8 +374,8 @@ function DiagonalArrays.diagonal(a::KroneckerArray) return diagonal(arg1(a)) ⊗ diagonal(arg2(a)) end -Base.real(a::KroneckerArray{<:Real}) = a -function Base.real(a::KroneckerArray) +Base.real(a::AbstractKroneckerArray{<:Real}) = a +function Base.real(a::AbstractKroneckerArray) if iszero(imag(arg1(a))) || iszero(imag(arg2(a))) return real(arg1(a)) ⊗ real(arg2(a)) elseif iszero(real(arg1(a))) || iszero(real(arg2(a))) @@ -344,8 +383,8 @@ function Base.real(a::KroneckerArray) end return real(arg1(a)) ⊗ real(arg2(a)) - imag(arg1(a)) ⊗ imag(arg2(a)) end -Base.imag(a::KroneckerArray{<:Real}) = zero(a) -function Base.imag(a::KroneckerArray) +Base.imag(a::AbstractKroneckerArray{<:Real}) = zero(a) +function Base.imag(a::AbstractKroneckerArray) if iszero(imag(arg1(a))) || iszero(real(arg2(a))) return real(arg1(a)) ⊗ imag(arg2(a)) elseif iszero(real(arg1(a))) || iszero(imag(arg2(a))) @@ -356,14 +395,14 @@ end for f in [:transpose, :adjoint, :inv] @eval begin - function Base.$f(a::KroneckerArray) + function Base.$f(a::AbstractKroneckerArray) return $f(arg1(a)) ⊗ $f(arg2(a)) end end end function Base.reshape( - a::KroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}} + a::AbstractKroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}} ) return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax)) end @@ -383,8 +422,8 @@ end function KroneckerStyle{N, A1, A2}(v::Val{M}) where {N, A1, A2, M} return KroneckerStyle{M, typeof(A1)(v), typeof(A2)(v)}() end -function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any, N, A1, A2}}) where {N, A1, A2} - return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2)) +function Base.BroadcastStyle(::Type{T}) where {T <: AbstractKroneckerArray} + return KroneckerStyle{ndims(T)}(BroadcastStyle(arg1type(T)), BroadcastStyle(arg2type(T))) end function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} style_a = BroadcastStyle(arg1(style1), arg1(style2)) @@ -403,10 +442,10 @@ function Base.similar( return a ⊗ b end -function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) +function Base.map(f, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...) return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...) end -function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) +function Base.map!(f, dest::AbstractKroneckerArray, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...) dest .= f.(a1, a_rest...) return dest end @@ -438,7 +477,7 @@ end function Base.copy(a::Summed{<:KroneckerStyle}) return copy(KroneckerBroadcast(a)) end -function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle}) +function Base.copyto!(dest::AbstractKroneckerArray, a::Summed{<:KroneckerStyle}) return copyto!(dest, KroneckerBroadcast(a)) end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index de67466..c0f08d5 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -30,22 +30,20 @@ function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) return pinv(arg1(a); kwargs...) ⊗ pinv(arg2(a); kwargs...) end -function LinearAlgebra.diag(a::KroneckerArray) +function LinearAlgebra.diag(a::AbstractKroneckerArray) return copy(DiagonalArrays.diagview(a)) end -function Base.:*(a::KroneckerArray, b::KroneckerArray) +function Base.:*(a::AbstractKroneckerArray, b::AbstractKroneckerArray) return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) end function LinearAlgebra.mul!( - c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number + c::AbstractKroneckerArray, a::AbstractKroneckerArray, b::AbstractKroneckerArray, α::Number, β::Number ) - iszero(β) || - iszero(c) || - throw( + iszero(β) || iszero(c) || throw( ArgumentError( - "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + "Can't multiply KroneckerArrays with nonzero β and nonzero destination." ), ) # TODO: Only perform in-place operation on the non-active argument(s). @@ -55,12 +53,12 @@ function LinearAlgebra.mul!( end using LinearAlgebra: tr -function LinearAlgebra.tr(a::KroneckerArray) +function LinearAlgebra.tr(a::AbstractKroneckerArray) return tr(arg1(a)) * tr(arg2(a)) end using LinearAlgebra: norm -function LinearAlgebra.norm(a::KroneckerArray, p::Int = 2) +function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Int = 2) return norm(arg1(a), p) * norm(arg2(a), p) end @@ -99,7 +97,7 @@ const MATRIX_FUNCTIONS = [ for f in MATRIX_FUNCTIONS @eval begin - function Base.$f(a::KroneckerArray) + function Base.$f(a::AbstractKroneckerArray) return if isone(arg1(a)) arg1(a) ⊗ $f(arg2(a)) elseif isone(arg2(a)) @@ -115,30 +113,30 @@ end # than `LinearAlgebra.checksquare`, for example it compares axes and can check # that the codomain and domain are dual of each other. using DiagonalArrays: DiagonalArrays, checksquare, issquare -function DiagonalArrays.issquare(a::KroneckerArray) +function DiagonalArrays.issquare(a::AbstractKroneckerArray) return issquare(arg1(a)) && issquare(arg2(a)) end using LinearAlgebra: det -function LinearAlgebra.det(a::KroneckerArray) +function LinearAlgebra.det(a::AbstractKroneckerArray) checksquare(a) return det(arg1(a))^size(arg2(a), 1) * det(arg2(a))^size(arg1(a), 1) end -function LinearAlgebra.svd(a::KroneckerArray) +function LinearAlgebra.svd(a::AbstractKroneckerArray) F1 = svd(arg1(a)) F2 = svd(arg2(a)) return SVD(F1.U ⊗ F2.U, F1.S ⊗ F2.S, F1.Vt ⊗ F2.Vt) end -function LinearAlgebra.svdvals(a::KroneckerArray) +function LinearAlgebra.svdvals(a::AbstractKroneckerArray) return svdvals(arg1(a)) ⊗ svdvals(arg2(a)) end -function LinearAlgebra.eigen(a::KroneckerArray) +function LinearAlgebra.eigen(a::AbstractKroneckerArray) F1 = eigen(arg1(a)) F2 = eigen(arg2(a)) return Eigen(F1.values ⊗ F2.values, F1.vectors ⊗ F2.vectors) end -function LinearAlgebra.eigvals(a::KroneckerArray) +function LinearAlgebra.eigvals(a::AbstractKroneckerArray) return eigvals(arg1(a)) ⊗ eigvals(arg2(a)) end @@ -151,10 +149,10 @@ end function Base.:*(a::KroneckerQ, b::KroneckerQ) return (arg1(a) * arg1(b)) ⊗ (arg2(a) * arg2(b)) end -function Base.:*(a1::KroneckerQ, a2::KroneckerArray) +function Base.:*(a1::KroneckerQ, a2::AbstractKroneckerArray) return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) end -function Base.:*(a1::KroneckerArray, a2::KroneckerQ) +function Base.:*(a1::AbstractKroneckerArray, a2::KroneckerQ) return (arg1(a1) * arg1(a2)) ⊗ (arg2(a1) * arg2(a2)) end function Base.adjoint(a::KroneckerQ) @@ -171,7 +169,7 @@ Base.iterate(F::KroneckerQR, ::Val{:done}) = nothing function ⊗(a1::LinearAlgebra.QRCompactWYQ, a2::LinearAlgebra.QRCompactWYQ) return KroneckerQ(a1, a2) end -function LinearAlgebra.qr(a::KroneckerArray) +function LinearAlgebra.qr(a::AbstractKroneckerArray) Fa = qr(arg1(a)) Fb = qr(arg2(a)) return KroneckerQR(Fa.Q ⊗ Fb.Q, Fa.R ⊗ Fb.R) @@ -187,7 +185,7 @@ Base.iterate(F::KroneckerLQ, ::Val{:done}) = nothing function ⊗(a1::LinearAlgebra.LQPackedQ, a2::LinearAlgebra.LQPackedQ) return KroneckerQ(a1, a2) end -function LinearAlgebra.lq(a::KroneckerArray) +function LinearAlgebra.lq(a::AbstractKroneckerArray) Fa = lq(arg1(a)) Fb = lq(arg2(a)) return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index f2061e4..5383130 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -1,61 +1,23 @@ using MatrixAlgebraKit: MatrixAlgebraKit, - AbstractAlgorithm, - TruncationStrategy, - default_eig_algorithm, - default_eigh_algorithm, - default_lq_algorithm, - default_polar_algorithm, - default_qr_algorithm, - default_svd_algorithm, - eig_full!, - eig_full, - eig_trunc!, - eig_trunc, - eig_vals!, - eig_vals, - eigh_full!, - eigh_full, - eigh_trunc!, - eigh_trunc, - eigh_vals!, - eigh_vals, + AbstractAlgorithm, TruncationStrategy, + default_eig_algorithm, default_eigh_algorithm, default_lq_algorithm, + default_polar_algorithm, default_qr_algorithm, default_svd_algorithm, + eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals, + eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals, initialize_output, - left_null!, - left_null, - left_orth!, - left_orth, - left_polar!, - left_polar, - lq_compact!, - lq_compact, - lq_full!, - lq_full, - qr_compact!, - qr_compact, - qr_full!, - qr_full, - right_null!, - right_null, - right_orth!, - right_orth, - right_polar!, - right_polar, - svd_compact!, - svd_compact, - svd_full!, - svd_full, - svd_trunc!, - svd_trunc, - svd_vals!, - svd_vals, + left_null!, left_null, left_orth!, left_orth, left_polar!, left_polar, + lq_compact!, lq_compact, lq_full!, lq_full, + qr_compact!, qr_compact, qr_full!, qr_full, + right_null!, right_null, right_orth!, right_orth, right_polar!, right_polar, + svd_compact!, svd_compact, svd_full!, svd_full, svd_trunc!, svd_trunc, svd_vals!, svd_vals, truncate using DiagonalArrays: DiagonalArrays, diagview -function DiagonalArrays.diagview(a::KroneckerMatrix) +function DiagonalArrays.diagview(a::AbstractKroneckerMatrix) return diagview(arg1(a)) ⊗ diagview(arg2(a)) end -MatrixAlgebraKit.diagview(a::KroneckerMatrix) = diagview(a) +MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = diagview(a) struct KroneckerAlgorithm{A1, A2} <: AbstractAlgorithm arg1::A1 @@ -66,53 +28,35 @@ end using MatrixAlgebraKit: copy_input, - eig_full, - eig_vals, - eigh_full, - eigh_vals, - qr_compact, - qr_full, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full + eig_full, eig_vals, eigh_full, eigh_vals, + qr_compact, qr_full, + left_null, left_orth, left_polar, + lq_compact, lq_full, + right_null, right_orth, right_polar, + svd_compact, svd_full for f in [ - :eig_full, - :eigh_full, - :qr_compact, - :qr_full, - :left_polar, - :lq_compact, - :lq_full, - :right_polar, - :svd_compact, - :svd_full, + :eig_full, :eigh_full, + :qr_compact, :qr_full, + :lq_compact, :lq_full, + :left_polar, :right_polar, + :svd_compact, :svd_full, ] @eval begin - function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) + function MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractKroneckerMatrix) return copy_input($f, arg1(a)) ⊗ copy_input($f, arg2(a)) end end end for f in [ - :default_eig_algorithm, - :default_eigh_algorithm, - :default_lq_algorithm, - :default_qr_algorithm, - :default_polar_algorithm, - :default_svd_algorithm, + :default_eig_algorithm, :default_eigh_algorithm, + :default_lq_algorithm, :default_qr_algorithm, + :default_polar_algorithm, :default_svd_algorithm, ] @eval begin function MatrixAlgebraKit.$f( - A::Type{<:KroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs... + A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs... ) A1, A2 = argument_types(A) return KroneckerAlgorithm( @@ -123,16 +67,11 @@ for f in [ end for f in [ - :eig_full, - :eigh_full, - :left_polar, - :lq_compact, - :lq_full, - :qr_compact, - :qr_full, - :right_polar, - :svd_compact, - :svd_full, + :eig_full, :eigh_full, + :left_polar, :right_polar, + :lq_compact, :lq_full, + :qr_compact, :qr_full, + :svd_compact, :svd_full, ] f! = Symbol(f, :!) @eval begin @@ -142,10 +81,10 @@ for f in [ return nothing end function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs... + a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm ) - a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...) - a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...) + a1 = $f(arg1(a), arg1(alg)) + a2 = $f(arg2(a), arg2(alg)) return a1 .⊗ a2 end end @@ -160,10 +99,10 @@ for f in [:eig_vals, :eigh_vals, :svd_vals] return nothing end function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs... + a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm ) - a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...) - a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...) + a1 = $f(arg1(a), arg1(alg)) + a2 = $f(arg2(a), arg2(alg)) return a1 ⊗ a2 end end @@ -172,11 +111,11 @@ end for f in [:left_orth, :right_orth] f! = Symbol(f, :!) @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f!), a::KroneckerMatrix) + function MatrixAlgebraKit.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) return nothing end function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... + a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... ) a1 = $f(arg1(a); kwargs..., kwargs1...) a2 = $f(arg2(a); kwargs..., kwargs2...) @@ -188,11 +127,11 @@ end for f in [:left_null, :right_null] f! = Symbol(f, :!) @eval begin - function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + function MatrixAlgebraKit.initialize_output(::typeof($f), a::AbstractKroneckerMatrix) return nothing end function MatrixAlgebraKit.$f!( - a::KroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... + a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs... ) a1 = $f(arg1(a); kwargs..., kwargs1...) a2 = $f(arg2(a); kwargs..., kwargs2...) @@ -248,7 +187,7 @@ function to_truncated_indices(values::KroneckerVector, I) end function MatrixAlgebraKit.findtruncated( - values::KroneckerVector, strategy::KroneckerTruncationStrategy + values::AbstractKroneckerVector, strategy::KroneckerTruncationStrategy ) I = findtruncated(Vector(values), strategy.strategy) return to_truncated_indices(values, I) @@ -257,12 +196,12 @@ end for f in [:eig_trunc!, :eigh_trunc!] @eval begin function MatrixAlgebraKit.truncate( - ::typeof($f), DV::NTuple{2, KroneckerMatrix}, strategy::TruncationStrategy + ::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy ) return truncate($f, DV, KroneckerTruncationStrategy(strategy)) end function MatrixAlgebraKit.truncate( - ::typeof($f), (D, V)::NTuple{2, KroneckerMatrix}, strategy::KroneckerTruncationStrategy + ::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy ) I = findtruncated(diagview(D), strategy) return (D[I, I], V[(:) × (:), I]), I @@ -271,13 +210,13 @@ for f in [:eig_trunc!, :eigh_trunc!] end function MatrixAlgebraKit.truncate( - f::typeof(svd_trunc!), USVᴴ::NTuple{3, KroneckerMatrix}, strategy::TruncationStrategy + f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy ) return truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy)) end function MatrixAlgebraKit.truncate( ::typeof(svd_trunc!), - (U, S, Vᴴ)::NTuple{3, KroneckerMatrix}, + (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy, ) I = findtruncated(diagview(S), strategy)