diff --git a/Project.toml b/Project.toml index dae7bab..9e0a324 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" +version = "0.3.1" authors = ["ITensor developers and contributors"] -version = "0.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -35,7 +35,7 @@ FillArrays = "1.13" GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.10" -MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5" +MatrixAlgebraKit = "0.6" TensorAlgebra = "0.3.10, 0.4" TensorProducts = "0.1.7" TypeParameterAccessors = "0.4.2" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 76962ab..19e7e77 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -42,7 +42,7 @@ kroneckerfactortypes(T::Type) = throw(MethodError(kroneckerfactortypes, (T,))) Construct an object that represents the Kronecker product of the provided `args`. """ (⊗) -function ⊗(a, b) end +function ⊗ end const otimes = ⊗ # non-unicode alternative # Includes diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index d52c4e7..24b9702 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -415,6 +415,29 @@ function Base.reshape( return reshape(a, kroneckerfactors.(ax, 1)) ⊗ reshape(b, kroneckerfactors.(ax, 2)) end +function Base.fill!(ab::AbstractKroneckerArray, v) + a, b = kroneckerfactors(ab) + fill!(a, √v) + fill!(b, √v) + return ab +end +function Base.fill!(ab::AbstractKroneckerMatrix, v) + a, b = kroneckerfactors(ab) + (!isactive(a) && isone(a)) && (fill!(b, v); return ab) + (!isactive(b) && isone(b)) && (fill!(a, v); return ab) + fill!(a, √v) + fill!(b, √v) + return ab +end +function Base.fill!(ab::AbstractKroneckerVector, v) + a, b = kroneckerfactors(ab) + (!isactive(a) && all(isone, a)) && (fill!(b, v); return ab) + (!isactive(b) && all(isone, b)) && (fill!(a, v); return ab) + fill!(a, √v) + fill!(b, √v) + return ab +end + using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted struct KroneckerStyle{N, A, B} <: BC.AbstractArrayStyle{N} end diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index 12df9c0..0a70767 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -80,7 +80,11 @@ end for f in (:eig_vals, :eigh_vals, :svd_vals) f! = Symbol(f, :!) - @eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing + @eval function MAK.initialize_output( + ::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm + ) + return nothing + end @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm) a, b = kroneckerfactors(ab) algA, algB = kroneckerfactors(alg) @@ -88,26 +92,97 @@ for f in (:eig_vals, :eigh_vals, :svd_vals) end end -for f in (:left_orth, :right_orth) - f! = Symbol(f, :!) - @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = nothing - @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) - a, b = kroneckerfactors(ab) - Fa = MAK.$f(a; kwargs..., kwargs1...) - Fb = MAK.$f(b; kwargs..., kwargs2...) - return Fa .⊗ Fb +# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 +# is merged. +for kind in ("polar", "qr", "svd") + @eval begin + function MAK.initialize_output( + ::typeof(left_orth!), a::AbstractKroneckerMatrix, + alg::MAK.LeftOrthAlgorithm{Symbol($kind)}, + ) + return nothing + end + function MAK.left_orth!( + ab::AbstractKroneckerMatrix, F, alg::MAK.LeftOrthAlgorithm{Symbol($kind)}; + kwargs1 = (;), kwargs2 = (;), kwargs..., + ) + a, b = kroneckerfactors(ab) + Fa = MAK.left_orth!(a; kwargs..., kwargs1...) + Fb = MAK.left_orth!(b; kwargs..., kwargs2...) + return Fa .⊗ Fb + end end end -for f in [:left_null, :right_null] - f! = Symbol(f, :!) - @eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = - nothing - @eval function MAK.$f!(ab::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...) - a, b = kroneckerfactors(ab) - Na = MAK.$f(a; kwargs..., kwargs1...) - Nb = MAK.$f(b; kwargs..., kwargs2...) - return Na ⊗ Nb +# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 +# is merged. +for kind in ("lq", "polar", "svd") + @eval begin + function MAK.initialize_output( + ::typeof(right_orth!), a::AbstractKroneckerMatrix, + alg::MAK.RightOrthAlgorithm{Symbol($kind)}, + ) + return nothing + end + function MAK.right_orth!( + ab::AbstractKroneckerMatrix, F, alg::MAK.RightOrthAlgorithm{Symbol($kind)}; + kwargs1 = (;), kwargs2 = (;), kwargs..., + ) + a, b = kroneckerfactors(ab) + Fa = MAK.right_orth!(a; kwargs..., kwargs1...) + Fb = MAK.right_orth!(b; kwargs..., kwargs2...) + return Fa .⊗ Fb + end + end +end + +# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 +# is merged. +for Alg in ( + :(MAK.LeftNullViaQR), + :(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm}), + :(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}), + ) + @eval begin + function MAK.initialize_output( + ::typeof(left_null!), a::AbstractKroneckerMatrix, alg::$Alg + ) + return nothing + end + function MAK.left_null!( + ab::AbstractKroneckerMatrix, F, alg::$Alg; + kwargs1 = (;), kwargs2 = (;), kwargs..., + ) + a, b = kroneckerfactors(ab) + Na = MAK.left_null!(a; kwargs..., kwargs1...) + Nb = MAK.left_null!(b; kwargs..., kwargs2...) + return Na ⊗ Nb + end + end +end + +# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104 +# is merged. +for Alg in ( + :(MAK.RightNullViaLQ), + :(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm}), + :(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}), + ) + @eval begin + function MAK.initialize_output( + ::typeof(right_null!), a::AbstractKroneckerMatrix, alg::$Alg + ) + return nothing + end + function MAK.right_null!( + ab::AbstractKroneckerMatrix, F, alg::$Alg; + kwargs1 = (;), kwargs2 = (;), kwargs..., + ) + a, b = kroneckerfactors(ab) + Na = MAK.right_null!(a; kwargs..., kwargs1...) + Nb = MAK.right_null!(b; kwargs..., kwargs2...) + return Na ⊗ Nb + end end end diff --git a/test/Project.toml b/test/Project.toml index b9d9be9..c344ab9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ GPUArraysCore = "0.2" JLArrays = "0.2, 0.3" KroneckerArrays = "0.3" LinearAlgebra = "1.10" -MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5" +MatrixAlgebraKit = "0.6" SafeTestsets = "0.1" StableRNGs = "1.0" Suppressor = "0.2"