From 3883a655eb6d328836ce6468367f6270abd068b4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 20 Nov 2025 22:49:55 -0500 Subject: [PATCH] Improve overloading of orthnull functions --- Project.toml | 2 +- src/matrixalgebrakit.jl | 144 ++++++++++++++++------------------------ 2 files changed, 59 insertions(+), 87 deletions(-) diff --git a/Project.toml b/Project.toml index 9e0a324..24cb025 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" -version = "0.3.1" +version = "0.3.2" authors = ["ITensor developers and contributors"] [deps] diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 0a70767..e4369db 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -20,7 +20,9 @@ using MatrixAlgebraKit: using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate import MatrixAlgebraKit as MAK -DiagonalArrays.diagview(a::AbstractKroneckerMatrix) = ⊗(DiagonalArrays.diagview.(kroneckerfactors(a))...) +function DiagonalArrays.diagview(a::AbstractKroneckerMatrix) + return ⊗(DiagonalArrays.diagview.(kroneckerfactors(a))...) +end MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = DiagonalArrays.diagview(a) struct KroneckerAlgorithm{A, B} <: AbstractAlgorithm @@ -51,8 +53,10 @@ for f in ( :default_lq_algorithm, :default_qr_algorithm, :default_polar_algorithm, :default_svd_algorithm, ) - @eval function MAK.$f(A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...) - A, B = kroneckerfactortypes(A) + @eval function MAK.$f( + AB::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs... + ) + A, B = kroneckerfactortypes(AB) return KroneckerAlgorithm( MAK.$f(A; kwargs..., kwargs1...), MAK.$f(B; kwargs..., kwargs2...) @@ -68,7 +72,11 @@ for f in ( :svd_compact, :svd_full, ) f! = Symbol(f, :!) - @eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing + @eval function MAK.initialize_output( + ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm + ) + return nothing + end @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) a, b = kroneckerfactors(ab) algA, algB = kroneckerfactors(alg) @@ -92,111 +100,66 @@ for f in (:eig_vals, :eigh_vals, :svd_vals) end end -# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 -# is merged. -for kind in ("polar", "qr", "svd") +for f! in (:left_orth!, :right_orth!, :left_null!, :right_null!) @eval begin - function MAK.initialize_output( - ::typeof(left_orth!), a::AbstractKroneckerMatrix, - alg::MAK.LeftOrthAlgorithm{Symbol($kind)}, + function MAK.default_algorithm( + ::typeof($f!), AB::Type{<:AbstractKroneckerMatrix}; kwargs... ) - return nothing + A, B = kroneckerfactortypes(AB) + algA = MAK.default_algorithm($f!, A; kwargs...) + algB = MAK.default_algorithm($f!, B; kwargs...) + return KroneckerAlgorithm(algA, algB) end - function MAK.left_orth!( - ab::AbstractKroneckerMatrix, F, alg::MAK.LeftOrthAlgorithm{Symbol($kind)}; - kwargs1 = (;), kwargs2 = (;), kwargs..., + function MAK.select_algorithm( + ::typeof($f!), ab::AbstractKroneckerMatrix, alg::Symbol; kwargs... ) a, b = kroneckerfactors(ab) - Fa = MAK.left_orth!(a; kwargs..., kwargs1...) - Fb = MAK.left_orth!(b; kwargs..., kwargs2...) - return Fa .⊗ Fb + algA = MAK.select_algorithm($f!, a, alg; kwargs...) + algB = MAK.select_algorithm($f!, b, alg; kwargs...) + return KroneckerAlgorithm(algA, algB) end + MAK.initialize_output(::typeof($f!), A, alg::KroneckerAlgorithm) = nothing end end - -# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 -# is merged. -for kind in ("lq", "polar", "svd") +for f! in (:left_orth!, :right_orth!) @eval begin - function MAK.initialize_output( - ::typeof(right_orth!), a::AbstractKroneckerMatrix, - alg::MAK.RightOrthAlgorithm{Symbol($kind)}, - ) - return nothing - end - function MAK.right_orth!( - ab::AbstractKroneckerMatrix, F, alg::MAK.RightOrthAlgorithm{Symbol($kind)}; - kwargs1 = (;), kwargs2 = (;), kwargs..., + function MAK.$f!( + ab, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs..., ) a, b = kroneckerfactors(ab) - Fa = MAK.right_orth!(a; kwargs..., kwargs1...) - Fb = MAK.right_orth!(b; kwargs..., kwargs2...) + Fa = MAK.$f!(a; kwargs..., kwargs1...) + Fb = MAK.$f!(b; kwargs..., kwargs2...) return Fa .⊗ Fb end end end - -# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 -# is merged. -for Alg in ( - :(MAK.LeftNullViaQR), - :(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm}), - :(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}), - ) - @eval begin - function MAK.initialize_output( - ::typeof(left_null!), a::AbstractKroneckerMatrix, alg::$Alg - ) - return nothing - end - function MAK.left_null!( - ab::AbstractKroneckerMatrix, F, alg::$Alg; - kwargs1 = (;), kwargs2 = (;), kwargs..., - ) - a, b = kroneckerfactors(ab) - Na = MAK.left_null!(a; kwargs..., kwargs1...) - Nb = MAK.left_null!(b; kwargs..., kwargs2...) - return Na ⊗ Nb - end - end -end - -# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 -# is merged. -for Alg in ( - :(MAK.RightNullViaLQ), - :(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm}), - :(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}), - ) +for f! in (:left_null!, :right_null!) @eval begin - function MAK.initialize_output( - ::typeof(right_null!), a::AbstractKroneckerMatrix, alg::$Alg - ) - return nothing - end - function MAK.right_null!( - ab::AbstractKroneckerMatrix, F, alg::$Alg; + function MAK.$f!( + ab, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs..., ) a, b = kroneckerfactors(ab) - Na = MAK.right_null!(a; kwargs..., kwargs1...) - Nb = MAK.right_null!(b; kwargs..., kwargs2...) - return Na ⊗ Nb + Fa = MAK.$f!(a; kwargs..., kwargs1...) + Fb = MAK.$f!(b; kwargs..., kwargs2...) + return Fa ⊗ Fb end end end # Truncation - struct KroneckerTruncationStrategy{T <: TruncationStrategy} <: TruncationStrategy strategy::T end using FillArrays: OnesVector -const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} = KroneckerVector{T, A, B} -const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B} -const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B} +const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} = + KroneckerVector{T, A, B} +const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} = + KroneckerVector{T, A, B} +const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} = + KroneckerVector{T, A, B} axis(a) = only(axes(a)) @@ -208,7 +171,8 @@ function to_truncated_indices(values::OnesKroneckerVector, I) I_data = unique(kroneckerfactors.(prods, 2)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> kroneckerfactors(x, 2) == i, prods) == length(kroneckerfactors(values, 2)) + return count(x -> kroneckerfactors(x, 2) == i, prods) == + length(kroneckerfactors(values, 2)) end return I_id × I_data end @@ -218,7 +182,8 @@ function to_truncated_indices(values::KroneckerOnesVector, I) I_data = unique(kroneckerfactors.(prods, 1)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> kroneckerfactors(x, 1) == i, prods) == length(kroneckerfactors(values, 2)) + return count(x -> kroneckerfactors(x, 1) == i, prods) == + length(kroneckerfactors(values, 2)) end I_id = only(to_indices(kroneckerfactors(values, 2), (:,))) return I_data × I_id @@ -240,22 +205,29 @@ end for f in (:eig_trunc!, :eigh_trunc!) @eval function MAK.truncate( - ::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy + ::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, + strategy::TruncationStrategy, ) return MAK.truncate($f, DV, KroneckerTruncationStrategy(strategy)) end @eval function MAK.truncate( - ::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy + ::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, + strategy::KroneckerTruncationStrategy, ) I = MAK.findtruncated(MAK.diagview(D), strategy) return (D[I, I], V[(:) × (:), I]), I end end -MAK.truncate(f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy) = - MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy)) function MAK.truncate( - ::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy, + f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, + strategy::TruncationStrategy, + ) + return MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy)) +end +function MAK.truncate( + ::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, + strategy::KroneckerTruncationStrategy, ) I = MAK.findtruncated(MAK.diagview(S), strategy) return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]), I