Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
version = "0.3.1"
version = "0.3.2"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[deps]
Expand Down
144 changes: 58 additions & 86 deletions src/matrixalgebrakit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ using MatrixAlgebraKit:
using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate
import MatrixAlgebraKit as MAK

DiagonalArrays.diagview(a::AbstractKroneckerMatrix) = ⊗(DiagonalArrays.diagview.(kroneckerfactors(a))...)
function DiagonalArrays.diagview(a::AbstractKroneckerMatrix)
return ⊗(DiagonalArrays.diagview.(kroneckerfactors(a))...)
end
MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = DiagonalArrays.diagview(a)

struct KroneckerAlgorithm{A, B} <: AbstractAlgorithm
Expand Down Expand Up @@ -51,8 +53,10 @@ for f in (
:default_lq_algorithm, :default_qr_algorithm,
:default_polar_algorithm, :default_svd_algorithm,
)
@eval function MAK.$f(A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...)
A, B = kroneckerfactortypes(A)
@eval function MAK.$f(
AB::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...
)
A, B = kroneckerfactortypes(AB)
return KroneckerAlgorithm(
MAK.$f(A; kwargs..., kwargs1...),
MAK.$f(B; kwargs..., kwargs2...)
Expand All @@ -68,7 +72,11 @@ for f in (
:svd_compact, :svd_full,
)
f! = Symbol(f, :!)
@eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing
@eval function MAK.initialize_output(
::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm
)
return nothing
end
@eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm)
a, b = kroneckerfactors(ab)
algA, algB = kroneckerfactors(alg)
Expand All @@ -92,111 +100,66 @@ for f in (:eig_vals, :eigh_vals, :svd_vals)
end
end

# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for kind in ("polar", "qr", "svd")
for f! in (:left_orth!, :right_orth!, :left_null!, :right_null!)
@eval begin
function MAK.initialize_output(
::typeof(left_orth!), a::AbstractKroneckerMatrix,
alg::MAK.LeftOrthAlgorithm{Symbol($kind)},
function MAK.default_algorithm(
::typeof($f!), AB::Type{<:AbstractKroneckerMatrix}; kwargs...
)
return nothing
A, B = kroneckerfactortypes(AB)
algA = MAK.default_algorithm($f!, A; kwargs...)
algB = MAK.default_algorithm($f!, B; kwargs...)
return KroneckerAlgorithm(algA, algB)
end
function MAK.left_orth!(
ab::AbstractKroneckerMatrix, F, alg::MAK.LeftOrthAlgorithm{Symbol($kind)};
kwargs1 = (;), kwargs2 = (;), kwargs...,
function MAK.select_algorithm(
::typeof($f!), ab::AbstractKroneckerMatrix, alg::Symbol; kwargs...
)
a, b = kroneckerfactors(ab)
Fa = MAK.left_orth!(a; kwargs..., kwargs1...)
Fb = MAK.left_orth!(b; kwargs..., kwargs2...)
return Fa .⊗ Fb
algA = MAK.select_algorithm($f!, a, alg; kwargs...)
algB = MAK.select_algorithm($f!, b, alg; kwargs...)
return KroneckerAlgorithm(algA, algB)
end
MAK.initialize_output(::typeof($f!), A, alg::KroneckerAlgorithm) = nothing
end
end

# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for kind in ("lq", "polar", "svd")
for f! in (:left_orth!, :right_orth!)
@eval begin
function MAK.initialize_output(
::typeof(right_orth!), a::AbstractKroneckerMatrix,
alg::MAK.RightOrthAlgorithm{Symbol($kind)},
)
return nothing
end
function MAK.right_orth!(
ab::AbstractKroneckerMatrix, F, alg::MAK.RightOrthAlgorithm{Symbol($kind)};
kwargs1 = (;), kwargs2 = (;), kwargs...,
function MAK.$f!(
ab, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Fa = MAK.right_orth!(a; kwargs..., kwargs1...)
Fb = MAK.right_orth!(b; kwargs..., kwargs2...)
Fa = MAK.$f!(a; kwargs..., kwargs1...)
Fb = MAK.$f!(b; kwargs..., kwargs2...)
return Fa .⊗ Fb
end
end
end

# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for Alg in (
:(MAK.LeftNullViaQR),
:(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm}),
:(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}),
)
@eval begin
function MAK.initialize_output(
::typeof(left_null!), a::AbstractKroneckerMatrix, alg::$Alg
)
return nothing
end
function MAK.left_null!(
ab::AbstractKroneckerMatrix, F, alg::$Alg;
kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Na = MAK.left_null!(a; kwargs..., kwargs1...)
Nb = MAK.left_null!(b; kwargs..., kwargs2...)
return Na ⊗ Nb
end
end
end

# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for Alg in (
:(MAK.RightNullViaLQ),
:(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm}),
:(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}),
)
for f! in (:left_null!, :right_null!)
@eval begin
function MAK.initialize_output(
::typeof(right_null!), a::AbstractKroneckerMatrix, alg::$Alg
)
return nothing
end
function MAK.right_null!(
ab::AbstractKroneckerMatrix, F, alg::$Alg;
function MAK.$f!(
ab, F, alg::KroneckerAlgorithm;
kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Na = MAK.right_null!(a; kwargs..., kwargs1...)
Nb = MAK.right_null!(b; kwargs..., kwargs2...)
return NaNb
Fa = MAK.$f!(a; kwargs..., kwargs1...)
Fb = MAK.$f!(b; kwargs..., kwargs2...)
return FaFb
end
end
end

# Truncation


struct KroneckerTruncationStrategy{T <: TruncationStrategy} <: TruncationStrategy
strategy::T
end

using FillArrays: OnesVector
const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} = KroneckerVector{T, A, B}
const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B}
const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B}
const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} =
KroneckerVector{T, A, B}
const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} =
KroneckerVector{T, A, B}
const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} =
KroneckerVector{T, A, B}

axis(a) = only(axes(a))

Expand All @@ -208,7 +171,8 @@ function to_truncated_indices(values::OnesKroneckerVector, I)
I_data = unique(kroneckerfactors.(prods, 2))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> kroneckerfactors(x, 2) == i, prods) == length(kroneckerfactors(values, 2))
return count(x -> kroneckerfactors(x, 2) == i, prods) ==
length(kroneckerfactors(values, 2))
end
return I_id × I_data
end
Expand All @@ -218,7 +182,8 @@ function to_truncated_indices(values::KroneckerOnesVector, I)
I_data = unique(kroneckerfactors.(prods, 1))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> kroneckerfactors(x, 1) == i, prods) == length(kroneckerfactors(values, 2))
return count(x -> kroneckerfactors(x, 1) == i, prods) ==
length(kroneckerfactors(values, 2))
end
I_id = only(to_indices(kroneckerfactors(values, 2), (:,)))
return I_data × I_id
Expand All @@ -240,22 +205,29 @@ end

for f in (:eig_trunc!, :eigh_trunc!)
@eval function MAK.truncate(
::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy
::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix},
strategy::TruncationStrategy,
)
return MAK.truncate($f, DV, KroneckerTruncationStrategy(strategy))
end
@eval function MAK.truncate(
::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy
::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix},
strategy::KroneckerTruncationStrategy,
)
I = MAK.findtruncated(MAK.diagview(D), strategy)
return (D[I, I], V[(:) × (:), I]), I
end
end

MAK.truncate(f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy) =
MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy))
function MAK.truncate(
::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy,
f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix},
strategy::TruncationStrategy,
)
return MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy))
end
function MAK.truncate(
::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix},
strategy::KroneckerTruncationStrategy,
)
I = MAK.findtruncated(MAK.diagview(S), strategy)
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]), I
Expand Down
Loading