Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.2.7"
version = "0.2.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module KroneckerArraysTensorAlgebraExt

using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, arg1, arg2
using TensorAlgebra:
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize

Expand All @@ -10,7 +10,7 @@ struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
end
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
function TensorAlgebra.FusionStyle(a::KroneckerArray)
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
end
function matricize_kronecker(
Expand Down
141 changes: 90 additions & 51 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,35 @@
"""
abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end

Abstract supertype for arrays that have a kronecker product structure,
i.e. that can be written as `AB = A ⊗ B`.
"""
abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end

const AbstractKroneckerVector{T} = AbstractKroneckerArray{T, 1}
const AbstractKroneckerMatrix{T} = AbstractKroneckerArray{T, 2}

@doc """
arg1(AB::AbstractKroneckerArray{T, N})

Extract the first factor (`A`) of the Kronecker array `AB = A ⊗ B`.
""" arg1

@doc """
arg2(AB::AbstractKroneckerArray{T, N})

Extract the second factor (`B`) of the Kronecker array `AB = A ⊗ B`.
""" arg2
Comment on lines +12 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was realizing that if we are using these functions outside of KroneckerArrays.jl, these names are a bit too general. We can of course call them as KroneckerArrays.arg1(a), etc. explicitly, or rename them to something more explicit like kron_arg1(a), etc. Any opinion on that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess arg1 and arg2 are also used for non-Kronecker data structures such as CartesianProduct so maybe kron_arg* isn't the best name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

factors(x, [i])?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But presumably better to change that in a separate PR, this is currently non-breaking

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that kind of change should be done in a separate PR. factors(x, [i]) is probably a better interface.


arg1type(x::AbstractKroneckerArray) = arg1type(typeof(x))
arg1type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg1type`.")
arg2type(x::AbstractKroneckerArray) = arg2type(typeof(x))
arg2type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg2type`.")

arguments(a::AbstractKroneckerArray) = (arg1(a), arg2(a))
arguments(a::AbstractKroneckerArray, n::Int) = arguments(a)[n]
argument_types(a::AbstractKroneckerArray) = argument_types(typeof(a))
Comment on lines +29 to +31
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in practice these aren't used so maybe they can be removed for now, I think I was undecided whether I would use arguments vs. arg1/arg2 but in practice I found arg1/arg2 were easier to use. That could be a separate PR of course, just bringing it up here since this PR reminded me of this issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your suggestion of factors(x, [i]) would address this issue.


function unwrap_array(a::AbstractArray)
p = parent(a)
p ≡ a && return a
Expand Down Expand Up @@ -26,7 +58,7 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
end

struct KroneckerArray{T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} <:
AbstractArray{T, N}
AbstractKroneckerArray{T, N}
arg1::A1
arg2::A2
end
Expand All @@ -48,6 +80,10 @@ const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = Kro

@inline arg1(a::KroneckerArray) = getfield(a, :arg1)
@inline arg2(a::KroneckerArray) = getfield(a, :arg2)
arg1type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A1
arg2type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A2

argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2)

function mutate_active_args!(f!, f, dest, src)
(isactive(arg1(dest)) || isactive(arg2(dest))) ||
Expand All @@ -66,7 +102,7 @@ function mutate_active_args!(f!, f, dest, src)
end

using Adapt: Adapt, adapt
function Adapt.adapt_structure(to, a::KroneckerArray)
function Adapt.adapt_structure(to, a::AbstractKroneckerArray)
# TODO: Is this a good definition? It is similar to
# the definition of `similar`.
return if isactive(arg1(a)) == isactive(arg2(a))
Expand All @@ -78,18 +114,22 @@ function Adapt.adapt_structure(to, a::KroneckerArray)
end
end

function Base.copy(a::KroneckerArray)
return copy(arg1(a)) ⊗ copy(arg2(a))
Base.copy(a::AbstractKroneckerArray) = copy(arg1(a)) ⊗ copy(arg2(a))
function Base.copy!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray)
return mutate_active_args!(copy!, copy, dest, src)
end

# TODO: copyto! is typically reserved for contiguous copies (i.e. also for copying from a
# vector into an array), it might be better to not define that here.
function Base.copyto!(dest::KroneckerArray{<:Any, N}, src::KroneckerArray{<:Any, N}) where {N}
return mutate_active_args!(copyto!, copy, dest, src)
end

function Base.convert(
::Type{KroneckerArray{T, N, A1, A2}}, a::KroneckerArray
) where {T, N, A1, A2}
return _convert(A1, arg1(a)) ⊗ _convert(A2, arg2(a))
::Type{KroneckerArray{T, N, A1, A2}}, a::AbstractKroneckerArray
)::KroneckerArray{T, N, A1, A2} where {T, N, A1, A2}
typeof(a) === KroneckerArray{T, N, A1, A2} && return a
return KroneckerArray(_convert(A1, arg1(a)), _convert(A2, arg2(a)))
end

# Promote the element type if needed.
Expand All @@ -98,7 +138,7 @@ end
maybe_promot_eltype(a, elt) = eltype(a) <: elt ? a : elt.(a)

function Base.similar(
a::KroneckerArray,
a::AbstractKroneckerArray,
elt::Type,
axs::Tuple{
CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}},
Expand All @@ -115,7 +155,7 @@ function Base.similar(
maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt, arg2.(axs))
end
end
function Base.similar(a::KroneckerArray, elt::Type)
function Base.similar(a::AbstractKroneckerArray, elt::Type)
# TODO: Is this a good definition?
return if isactive(arg1(a)) == isactive(arg2(a))
similar(arg1(a), elt) ⊗ similar(arg2(a), elt)
Expand All @@ -125,7 +165,7 @@ function Base.similar(a::KroneckerArray, elt::Type)
maybe_promot_eltype(arg1(a), elt) ⊗ similar(arg2(a), elt)
end
end
function Base.similar(a::KroneckerArray)
function Base.similar(a::AbstractKroneckerArray)
# TODO: Is this a good definition?
return if isactive(arg1(a)) == isactive(arg2(a))
similar(arg1(a)) ⊗ similar(arg2(a))
Expand All @@ -147,16 +187,18 @@ function Base.similar(
end

function Base.similar(
arrayt::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}},
::Type{ArrayT},
axs::Tuple{
CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}},
},
) where {A1, A2}
) where {ArrayT <: AbstractKroneckerArray}
A1, A2 = arg1type(ArrayT), arg2type(ArrayT)
return similar(A1, map(arg1, axs)) ⊗ similar(A2, map(arg2, axs))
end
function Base.similar(
::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, sz::Tuple{Int, Vararg{Int}}
) where {A1, A2}
::Type{ArrayT}, sz::Tuple{Int, Vararg{Int}}
) where {ArrayT <: AbstractKroneckerArray}
A1, A2 = arg1type(ArrayT), arg2type(ArrayT)
return similar(promote_type(A1, A2), sz)
end

Expand All @@ -169,15 +211,15 @@ function Base.similar(
return similar(arrayt, map(arg1, axs)) ⊗ similar(arrayt, map(arg2, axs))
end

function Base.permutedims(a::KroneckerArray, perm)
function Base.permutedims(a::AbstractKroneckerArray, perm)
return permutedims(arg1(a), perm) ⊗ permutedims(arg2(a), perm)
end
using DerivableInterfaces: DerivableInterfaces, permuteddims
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
function DerivableInterfaces.permuteddims(a::AbstractKroneckerArray, perm)
return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm)
end

function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
function Base.permutedims!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray, perm)
return mutate_active_args!(
(dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src
)
Expand Down Expand Up @@ -208,9 +250,10 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2)
kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2)

# Eagerly collect arguments to make more general on GPU.
Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
Base.collect(a::AbstractKroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
Base.collect(T::Type, a::AbstractKroneckerArray) = kron_nd(collect(T, arg1(a)), collect(T, arg2(a)))

function Base.zero(a::KroneckerArray)
function Base.zero(a::AbstractKroneckerArray)
return if isactive(arg1(a)) == isactive(arg2(a))
# TODO: Maybe this should zero both arguments?
# This is how `a * false` would behave.
Expand All @@ -223,35 +266,28 @@ function Base.zero(a::KroneckerArray)
end

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::KroneckerArray)
function DerivableInterfaces.zero!(a::AbstractKroneckerArray)
(isactive(arg1(a)) || isactive(arg2(a))) ||
error("Can't mutate immutable KroneckerArray.")
isactive(arg1(a)) && zero!(arg1(a))
isactive(arg2(a)) && zero!(arg2(a))
return a
end

function Base.Array{T, N}(a::KroneckerArray{S, N}) where {T, S, N}
return convert(Array{T, N}, collect(a))
function Base.Array{T, N}(a::AbstractKroneckerArray{S, N}) where {T, S, N}
return convert(Array{T, N}, collect(T, a))
end

function Base.size(a::KroneckerArray)
return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a))
end
Base.size(a::AbstractKroneckerArray) = size(arg1(a)) .* size(arg2(a))

function Base.axes(a::KroneckerArray)
function Base.axes(a::AbstractKroneckerArray)
return ntuple(ndims(a)) do dim
return CartesianProductUnitRange(
axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim))
)
end
end

arguments(a::KroneckerArray) = (arg1(a), arg2(a))
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
argument_types(a::KroneckerArray) = argument_types(typeof(a))
argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2)

function Base.print_array(io::IO, a::KroneckerArray)
Base.print_array(io, arg1(a))
println(io, "\n ⊗")
Expand Down Expand Up @@ -285,45 +321,48 @@ end

# Indexing logic.
function Base.to_indices(
a::KroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
a::AbstractKroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
)
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
return I1 .× I2
end

function Base.getindex(
a::KroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}
a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}
) where {N}
I′ = to_indices(a, I)
return arg1(a)[arg1.(I′)...] ⊗ arg2(a)[arg2.(I′)...]
end
# Fix ambigiuity error.
Base.getindex(a::KroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[]
Base.getindex(a::AbstractKroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[]

arg1(::Colon) = (:)
arg2(::Colon) = (:)
arg1(::Base.Slice) = (:)
arg2(::Base.Slice) = (:)
function Base.view(
a::KroneckerArray{<:Any, N},
a::AbstractKroneckerArray{<:Any, N},
I::Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N},
) where {N}
return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...)
end
function Base.view(a::KroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N}
function Base.view(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N}
return view(arg1(a), arg1.(I)...) ⊗ view(arg2(a), arg2.(I)...)
end
# Fix ambigiuity error.
Base.view(a::KroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a))
Base.view(a::AbstractKroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a))

function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
end
function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...)

# TODO: this definition doesn't fully retain the original meaning:
# ‖a - b‖ < atol could be true even if the following check isn't
Comment on lines +360 to +361
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I was a bit lazy with this definition but as you say we would need to be more careful about the tolerances. This is a case where it would be easier to special case for SectorArray.

Copy link
Member

@mtfishman mtfishman Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The best I could come up with is:

using LinearAlgebra: promote_leaf_eltypes
function Base.isapprox(
        a::AbstractKroneckerArray, b::AbstractKroneckerArray;
        atol::Real = 0,
        rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol),
        norm::Function = norm
    )
    a1, a2 = arg1(a), arg2(a)
    b1, b2 = arg1(b), arg2(b)
    # Approximation of:
    # norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2)
    #             = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2))
    diff1 = norm(a1 - b1)
    diff2 = norm(a2 - b2)
    d = diff1 * norm(a2) + norm(b1) * diff2 + diff1 * diff2
    return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
end

which would work for SectorArrays since a1 == b1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I'm fully following your derivation, I don't see how you can split the norms like that?
In principle there is the following formula:

$$|| a1 \otimes a2 - b1 \otimes b2 ||^2 = || a1 \otimes a2 ||^2 + || b1 \otimes b2 ||^2 - 2 Re \langle a1 \otimes a2, b1 \otimes b2 \rangle$$

But this doesn't behave well with finite precision, since we are basically subtracting equal magnitude numbers to attempt to find something of much smaller magnitude

function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...)
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
end
function Base.iszero(a::KroneckerArray)
function Base.iszero(a::AbstractKroneckerArray)
return iszero(arg1(a)) || iszero(arg2(a))
end
function Base.isreal(a::KroneckerArray)
Expand All @@ -335,17 +374,17 @@ function DiagonalArrays.diagonal(a::KroneckerArray)
return diagonal(arg1(a)) ⊗ diagonal(arg2(a))
end

Base.real(a::KroneckerArray{<:Real}) = a
function Base.real(a::KroneckerArray)
Base.real(a::AbstractKroneckerArray{<:Real}) = a
function Base.real(a::AbstractKroneckerArray)
if iszero(imag(arg1(a))) || iszero(imag(arg2(a)))
return real(arg1(a)) ⊗ real(arg2(a))
elseif iszero(real(arg1(a))) || iszero(real(arg2(a)))
return -(imag(arg1(a)) ⊗ imag(arg2(a)))
end
return real(arg1(a)) ⊗ real(arg2(a)) - imag(arg1(a)) ⊗ imag(arg2(a))
end
Base.imag(a::KroneckerArray{<:Real}) = zero(a)
function Base.imag(a::KroneckerArray)
Base.imag(a::AbstractKroneckerArray{<:Real}) = zero(a)
function Base.imag(a::AbstractKroneckerArray)
if iszero(imag(arg1(a))) || iszero(real(arg2(a)))
return real(arg1(a)) ⊗ imag(arg2(a))
elseif iszero(real(arg1(a))) || iszero(imag(arg2(a)))
Expand All @@ -356,14 +395,14 @@ end

for f in [:transpose, :adjoint, :inv]
@eval begin
function Base.$f(a::KroneckerArray)
function Base.$f(a::AbstractKroneckerArray)
return $f(arg1(a)) ⊗ $f(arg2(a))
end
end
end

function Base.reshape(
a::KroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
a::AbstractKroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
)
return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax))
end
Expand All @@ -383,8 +422,8 @@ end
function KroneckerStyle{N, A1, A2}(v::Val{M}) where {N, A1, A2, M}
return KroneckerStyle{M, typeof(A1)(v), typeof(A2)(v)}()
end
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any, N, A1, A2}}) where {N, A1, A2}
return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2))
function Base.BroadcastStyle(::Type{T}) where {T <: AbstractKroneckerArray}
return KroneckerStyle{ndims(T)}(BroadcastStyle(arg1type(T)), BroadcastStyle(arg2type(T)))
end
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
style_a = BroadcastStyle(arg1(style1), arg1(style2))
Expand All @@ -403,10 +442,10 @@ function Base.similar(
return a ⊗ b
end

function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
function Base.map(f, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...)
return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...)
end
function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...)
function Base.map!(f, dest::AbstractKroneckerArray, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...)
dest .= f.(a1, a_rest...)
return dest
end
Expand Down Expand Up @@ -438,7 +477,7 @@ end
function Base.copy(a::Summed{<:KroneckerStyle})
return copy(KroneckerBroadcast(a))
end
function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle})
function Base.copyto!(dest::AbstractKroneckerArray, a::Summed{<:KroneckerStyle})
return copyto!(dest, KroneckerBroadcast(a))
end

Expand Down
Loading
Loading