Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
50 changes: 46 additions & 4 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,53 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
end

# TODO: this definition doesn't fully retain the original meaning:
# ‖a - b‖ < atol could be true even if the following check isn't
function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...)
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2)
# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2))
function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
a1, a2 = arg1(a), arg2(a)
b1, b2 = arg1(b), arg2(b)
diff1 = a1 - b1
diff2 = a2 - b2
# x = (a1 - b1) ⊗ a2
# y = b1 ⊗ (a2 - b2)
# z = (a1 - b1) ⊗ (a2 - b2)
xx = norm(diff1)^2 * norm(a2)^2
yy = norm(b1)^2 * norm(diff2)^2
zz = norm(diff1)^2 * norm(diff2)^2
xy = real(dot(diff1, b1) * dot(a2, diff2))
xz = real(dot(diff1, diff1) * dot(a2, diff2))
yz = real(dot(b1, diff1) * dot(diff2, diff2))
# `abs` is used in case there are negative values due to floating point roundoff errors.
return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz)))
end

using LinearAlgebra: dot, promote_leaf_eltypes
function Base.isapprox(
a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0,
rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol),
)
a1, a2 = arg1(a), arg2(a)
b1, b2 = arg1(b), arg2(b)
if a1 == b1
return isapprox(a2, b2; atol = atol / norm(a1), rtol)
elseif a2 == b2
return isapprox(a1, b1; atol = atol / norm(a2), rtol)
else
# This could be defined as:
# ```julia
# d = KroneckerArrays.dist_kronecker(a, b)
# iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
# ```
# but that might have numerical precision issues so for now we just error.
throw(
ArgumentError(
"`isapprox` not implemented for KroneckerArrays where both arguments differ. " *
"In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`."
)
)
end
end

function Base.iszero(a::AbstractKroneckerArray)
return iszero(arg1(a)) || iszero(arg2(a))
end
Expand Down
2 changes: 1 addition & 1 deletion src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function LinearAlgebra.tr(a::AbstractKroneckerArray)
end

using LinearAlgebra: norm
function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Int = 2)
function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2)
return norm(arg1(a), p) * norm(arg2(a), p)
end

Expand Down
66 changes: 50 additions & 16 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,9 @@ using DerivableInterfaces: zero!
using DiagonalArrays: diagonal
using GPUArraysCore: @allowscalar
using JLArrays: JLArray
using KroneckerArrays:
KroneckerArrays,
KroneckerArray,
KroneckerStyle,
CartesianProductUnitRange,
CartesianProductVector,
⊗,
×,
arg1,
arg2,
cartesianproduct,
cartesianrange,
kron_nd,
unproduct
using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle,
CartesianProductUnitRange, CartesianProductVector, ⊗, ×, arg1, arg2, cartesianproduct,
cartesianrange, kron_nd, unproduct
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr
using StableRNGs: StableRNG
using Test: @test, @test_broken, @test_throws, @testset
Expand Down Expand Up @@ -219,10 +208,11 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c = a.arg1 ⊗ b.arg2
c = arg1(a) ⊗ arg2(b)
U, S, V = svd(a)
@test collect(U * diagonal(S) * V') ≈ collect(a)
@test svdvals(a) ≈ S
@test arg1(svdvals(a)) ≈ arg1(S)
@test arg2(svdvals(a)) ≈ arg2(S)
@test sort(collect(S); rev = true) ≈ svdvals(collect(a))
@test collect(U'U) ≈ I
@test collect(V * V') ≈ I
Expand All @@ -246,4 +236,48 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test_throws ArgumentError $f($a)
end
end

# isapprox

rng = StableRNG(123)
a1 = randn(rng, elt, (2, 2))
a = a1 ⊗ randn(rng, elt, (3, 3))
b = a1 ⊗ randn(rng, elt, (3, 3))
@test isapprox(a, b; atol = norm(a - b) * (1 + 2eps(real(elt))))
@test !isapprox(a, b; atol = norm(a - b) * (1 - 2eps(real(elt))))
@test isapprox(
a, b;
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
)
@test !isapprox(
a, b;
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
)
@test isapprox(
a, b; atol = norm(a - b) * (1 + 2eps(real(elt))),
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
)
@test isapprox(
a, b; atol = norm(a - b) * (1 + 2eps(real(elt))),
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
)
@test isapprox(
a, b; atol = norm(a - b) * (1 - 2eps(real(elt))),
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
)
@test !isapprox(
a, b; atol = norm(a - b) * (1 - 2eps(real(elt))),
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
)

a = randn(elt, (2, 2)) ⊗ randn(elt, (3, 3))
b = randn(elt, (2, 2)) ⊗ randn(elt, (3, 3))
@test_throws ArgumentError isapprox(a, b)

# KroneckerArrays.dist_kronecker
rng = StableRNG(123)
a = randn(rng, (100, 100)) ⊗ randn(rng, (100, 100))
b = (arg1(a) + randn(rng, size(arg1(a))) / 10) ⊗
(arg2(a) + randn(rng, size(arg2(a))) / 10)
@test KroneckerArrays.dist_kronecker(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2
end
83 changes: 48 additions & 35 deletions test/test_matrixalgebrakit.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,8 @@
using KroneckerArrays: ⊗, arguments
using KroneckerArrays: ⊗, arg1, arg2
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
using MatrixAlgebraKit:
eig_full,
eig_trunc,
eig_vals,
eigh_full,
eigh_trunc,
eigh_vals,
left_null,
left_orth,
left_polar,
lq_compact,
lq_full,
qr_compact,
qr_full,
right_null,
right_orth,
right_polar,
svd_compact,
svd_full,
svd_trunc,
using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc,
eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact,
qr_full, right_null, right_orth, right_polar, svd_compact, svd_full, svd_trunc,
svd_vals
using Test: @test, @test_throws, @testset
using TestExtras: @constinferred
Expand All @@ -31,18 +14,26 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
d, v = eig_full(a)
@test a * v ≈ v * d
av = a * v
vd = v * d
@test arg1(av) ≈ arg1(vd)
@test arg2(av) ≈ arg2(vd)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
@test_throws ArgumentError eig_trunc(a)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
d = eig_vals(a)
@test d ≈ diag(eig_full(a)[1])
d′ = diag(eig_full(a)[1])
@test arg1(d) ≈ arg1(d′)
@test arg2(d) ≈ arg2(d′)

a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3))
d, v = eigh_full(a)
@test a * v ≈ v * d
av = a * v
vd = v * d
@test arg1(av) ≈ arg1(vd)
@test arg2(av) ≈ arg2(vd)
@test eltype(d) === real(elt)
@test eltype(v) === elt

Expand All @@ -56,22 +47,30 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = qr_compact(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = qr_full(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = lq_compact(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = lq_full(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3)
Expand All @@ -84,27 +83,37 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = left_orth(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = right_orth(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = left_polar(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = right_polar(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, s, v = svd_compact(a)
@test u * s * v ≈ a
usv = u * s * v
@test arg1(usv) ≈ arg1(a)
@test arg2(usv) ≈ arg2(a)
@test eltype(u) === elt
@test eltype(s) === real(elt)
@test eltype(v) === elt
Expand All @@ -113,7 +122,9 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, s, v = svd_full(a)
@test u * s * v ≈ a
usv = u * s * v
@test arg1(usv) ≈ arg1(a)
@test arg2(usv) ≈ arg2(a)
@test eltype(u) === elt
@test eltype(s) === real(elt)
@test eltype(v) === elt
Expand All @@ -125,5 +136,7 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
s = svd_vals(a)
@test s ≈ diag(svd_compact(a)[2])
s′ = diag(svd_compact(a)[2])
@test arg1(s) ≈ arg1(s′)
@test arg2(s) ≈ arg2(s′)
end
Loading