From 16b60ba2af34010ea9dc4c5245a53cde4b1b50c3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Nov 2025 16:16:24 -0500 Subject: [PATCH 1/9] Better definition of isapprox --- Project.toml | 2 +- src/kroneckerarray.jl | 20 ++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index de46232..aa26928 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.2.7" +version = "0.2.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 28328a7..3e6f595 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -320,9 +320,25 @@ Base.view(a::KroneckerArray{<:Any, 0}) = view(arg1(a)) ⊗ view(arg2(a)) function Base.:(==)(a::KroneckerArray, b::KroneckerArray) return arg1(a) == arg1(b) && arg2(a) == arg2(b) end -function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...) - return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...) + +using LinearAlgebra: promote_leaf_eltypes +function Base.isapprox( + a::KroneckerArray, b::KroneckerArray; + atol::Real = 0, + rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol), + norm::Function = norm + ) + a1, a2 = arg1(a), arg2(a) + b1, b2 = arg1(b), arg2(b) + # Approximation of: + # norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) + # = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) + diff1 = norm(a1 - b1) + diff2 = norm(a2 - b2) + d = diff1 * norm(a2) + norm(b1) * diff2 + diff1 * diff2 + return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) end + function Base.iszero(a::KroneckerArray) return iszero(arg1(a)) || iszero(arg2(a)) end From db344d55551e0e597a2517985d4d1dee14f7d5f1 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 4 Nov 2025 16:23:37 -0500 Subject: [PATCH 2/9] Fix merge issue --- src/kroneckerarray.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index a4d5f75..1b9bbab 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -375,8 +375,6 @@ function Base.isapprox( return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) end -function Base.iszero(a::KroneckerArray) - function Base.iszero(a::AbstractKroneckerArray) return iszero(arg1(a)) || iszero(arg2(a)) end From add9b069f308a0627667d25c41005343f3fd2e48 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Nov 2025 17:16:02 -0500 Subject: [PATCH 3/9] Include cross terms --- src/kroneckerarray.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 1b9bbab..a052dba 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -357,7 +357,7 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray) return arg1(a) == arg1(b) && arg2(a) == arg2(b) end -using LinearAlgebra: promote_leaf_eltypes +using LinearAlgebra: dot, promote_leaf_eltypes function Base.isapprox( a::KroneckerArray, b::KroneckerArray; atol::Real = 0, @@ -369,9 +369,9 @@ function Base.isapprox( # Approximation of: # norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) # = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) - diff1 = norm(a1 - b1) - diff2 = norm(a2 - b2) - d = diff1 * norm(a2) + norm(b1) * diff2 + diff1 * diff2 + diff1 = a1 - b1 + diff2 = a2 - b2 + d = sqrt(norm(diff1)^2 * norm(a2)^2 + norm(b1)^2 * norm(diff2)^2 + 2 * real(dot(diff1, b1) * dot(b2, diff2))) return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) end From 451cd32d44a6701032f06b64bb64e928ce698e85 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Nov 2025 17:18:20 -0500 Subject: [PATCH 4/9] Abstract definition --- src/kroneckerarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index a052dba..7f9ab64 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -359,7 +359,7 @@ end using LinearAlgebra: dot, promote_leaf_eltypes function Base.isapprox( - a::KroneckerArray, b::KroneckerArray; + a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0, rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol), norm::Function = norm From 61e24208daabc631b2774c7a50cb92c050ad20c2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Nov 2025 18:20:09 -0500 Subject: [PATCH 5/9] Stricter isapprox, update tests --- src/kroneckerarray.jl | 37 +++++++++++++--- test/test_basics.jl | 28 ++++++------- test/test_matrixalgebrakit.jl | 79 ++++++++++++++++++++--------------- 3 files changed, 88 insertions(+), 56 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 7f9ab64..04d0d72 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -357,6 +357,25 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray) return arg1(a) == arg1(b) && arg2(a) == arg2(b) end +# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) +# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) +function dist(a::AbstractKroneckerArray, b::AbstractKroneckerArray) + a1, a2 = arg1(a), arg2(a) + b1, b2 = arg1(b), arg2(b) + diff1 = a1 - b1 + diff2 = a2 - b2 + # x = (a1 - b1) ⊗ a2 + # y = b1 ⊗ (a2 - b2) + # z = (a1 - b1) ⊗ (a2 - b2) + xx = norm(diff1)^2 * norm(a2)^2 + yy = norm(b1)^2 * norm(diff2)^2 + zz = norm(diff1)^2 * norm(diff2)^2 + xy = real(dot(diff1, b1) * dot(a2, diff2)) + xz = real(dot(diff1, diff1) * dot(a2, diff2)) + yz = real(dot(b1, diff1) * dot(diff2, diff2)) + return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz))) +end + using LinearAlgebra: dot, promote_leaf_eltypes function Base.isapprox( a::AbstractKroneckerArray, b::AbstractKroneckerArray; @@ -366,12 +385,18 @@ function Base.isapprox( ) a1, a2 = arg1(a), arg2(a) b1, b2 = arg1(b), arg2(b) - # Approximation of: - # norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) - # = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) - diff1 = a1 - b1 - diff2 = a2 - b2 - d = sqrt(norm(diff1)^2 * norm(a2)^2 + norm(b1)^2 * norm(diff2)^2 + 2 * real(dot(diff1, b1) * dot(b2, diff2))) + d = if a1 == b1 + norm(b1) * norm(a2 - b2) + elseif a2 == b2 + norm(a1 - b1) * norm(b2) + else + # This could be defined as `KroneckerArrays.dist(a, b)`, but that might have + # numerical precision issues so for now we just error. + error( + "`isapprox` not implemented for KroneckerArrays where both arguments differ. " * + "In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`." + ) + end return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 4b8c8e8..ce0863a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -4,20 +4,9 @@ using DerivableInterfaces: zero! using DiagonalArrays: diagonal using GPUArraysCore: @allowscalar using JLArrays: JLArray -using KroneckerArrays: - KroneckerArrays, - KroneckerArray, - KroneckerStyle, - CartesianProductUnitRange, - CartesianProductVector, - ⊗, - ×, - arg1, - arg2, - cartesianproduct, - cartesianrange, - kron_nd, - unproduct +using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle, + CartesianProductUnitRange, CartesianProductVector, ⊗, ×, arg1, arg2, cartesianproduct, + cartesianrange, kron_nd, unproduct using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -219,10 +208,11 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c = a.arg1 ⊗ b.arg2 + c = arg1(a) ⊗ arg2(b) U, S, V = svd(a) @test collect(U * diagonal(S) * V') ≈ collect(a) - @test svdvals(a) ≈ S + @test arg1(svdvals(a)) ≈ arg1(S) + @test arg2(svdvals(a)) ≈ arg2(S) @test sort(collect(S); rev = true) ≈ svdvals(collect(a)) @test collect(U'U) ≈ I @test collect(V * V') ≈ I @@ -246,4 +236,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test_throws ArgumentError $f($a) end end + + # KroneckerArrays.dist + rng = StableRNG(123) + a = randn(rng, 100, 100) ⊗ randn(rng, 100, 100) + b = (arg1(a) + 1.0e-1 * randn(rng, size(arg1(a)))) ⊗ (arg2(a) + 1.0e-1 * randn(rng, size(arg2(a)))) + @test KroneckerArrays.dist(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2 end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 0767542..4af64ca 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,25 +1,8 @@ -using KroneckerArrays: ⊗, arguments +using KroneckerArrays: ⊗, arg1, arg2 using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm -using MatrixAlgebraKit: - eig_full, - eig_trunc, - eig_vals, - eigh_full, - eigh_trunc, - eigh_vals, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - qr_compact, - qr_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full, - svd_trunc, +using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc, + eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact, + qr_full, right_null, right_orth, right_polar, svd_compact, svd_full, svd_trunc, svd_vals using Test: @test, @test_throws, @testset using TestExtras: @constinferred @@ -31,18 +14,26 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d, v = eig_full(a) - @test a * v ≈ v * d + av = a * v + vd = v * d + @test arg1(av) ≈ arg1(vd) + @test arg2(av) ≈ arg2(vd) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) @test_throws ArgumentError eig_trunc(a) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d = eig_vals(a) - @test d ≈ diag(eig_full(a)[1]) + d′ = diag(eig_full(a)[1]) + @test arg1(d) ≈ arg1(d′) + @test arg2(d) ≈ arg2(d′) a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) d, v = eigh_full(a) - @test a * v ≈ v * d + av = a * v + vd = v * d + @test arg1(av) ≈ arg1(vd) + @test arg2(av) ≈ arg2(vd) @test eltype(d) === real(elt) @test eltype(v) === elt @@ -56,22 +47,30 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_compact(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_full(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = lq_compact(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = lq_full(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) @@ -84,27 +83,37 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = left_orth(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = right_orth(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = left_polar(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = right_polar(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_compact(a) - @test u * s * v ≈ a + usv = u * s * v + @test arg1(usv) ≈ arg1(a) + @test arg2(usv) ≈ arg2(a) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt @@ -113,7 +122,9 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_full(a) - @test u * s * v ≈ a + usv = u * s * v + @test arg1(usv) ≈ arg1(a) + @test arg2(usv) ≈ arg2(a) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt From 4464fdedab43722dc628f05318c54dadaa31ec53 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Nov 2025 18:29:25 -0500 Subject: [PATCH 6/9] Improve style --- src/kroneckerarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 04d0d72..5d54929 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -386,7 +386,7 @@ function Base.isapprox( a1, a2 = arg1(a), arg2(a) b1, b2 = arg1(b), arg2(b) d = if a1 == b1 - norm(b1) * norm(a2 - b2) + norm(a1) * norm(a2 - b2) elseif a2 == b2 norm(a1 - b1) * norm(b2) else From 7559a548b9abdd343b03a76d23811771be53a1e2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Nov 2025 18:41:06 -0500 Subject: [PATCH 7/9] Fix tests --- src/kroneckerarray.jl | 4 ++-- test/test_basics.jl | 2 +- test/test_matrixalgebrakit.jl | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 5d54929..4a2639d 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -359,7 +359,7 @@ end # norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) # = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) -function dist(a::AbstractKroneckerArray, b::AbstractKroneckerArray) +function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray) a1, a2 = arg1(a), arg2(a) b1, b2 = arg1(b), arg2(b) diff1 = a1 - b1 @@ -390,7 +390,7 @@ function Base.isapprox( elseif a2 == b2 norm(a1 - b1) * norm(b2) else - # This could be defined as `KroneckerArrays.dist(a, b)`, but that might have + # This could be defined as `KroneckerArrays.dist_kronecker(a, b)`, but that might have # numerical precision issues so for now we just error. error( "`isapprox` not implemented for KroneckerArrays where both arguments differ. " * diff --git a/test/test_basics.jl b/test/test_basics.jl index ce0863a..fc12c3b 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -241,5 +241,5 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) a = randn(rng, 100, 100) ⊗ randn(rng, 100, 100) b = (arg1(a) + 1.0e-1 * randn(rng, size(arg1(a)))) ⊗ (arg2(a) + 1.0e-1 * randn(rng, size(arg2(a)))) - @test KroneckerArrays.dist(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2 + @test KroneckerArrays.dist_kronecker(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2 end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 4af64ca..f62c492 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -136,5 +136,7 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) s = svd_vals(a) - @test s ≈ diag(svd_compact(a)[2]) + s′ = diag(svd_compact(a)[2]) + @test arg1(s) ≈ arg1(s′) + @test arg2(s) ≈ arg2(s′) end From 5bb5d0eb18beeedff6eec1179dc442c5bcaedc7d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 5 Nov 2025 10:04:24 -0500 Subject: [PATCH 8/9] Write isapprox in terms of isapprox of the factors --- src/kroneckerarray.jl | 27 ++++++++++++++------------ src/linearalgebra.jl | 2 +- test/test_basics.jl | 44 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 57 insertions(+), 16 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 4a2639d..9b1a293 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -378,26 +378,29 @@ end using LinearAlgebra: dot, promote_leaf_eltypes function Base.isapprox( - a::AbstractKroneckerArray, b::AbstractKroneckerArray; - atol::Real = 0, + a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0, rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol), - norm::Function = norm ) a1, a2 = arg1(a), arg2(a) b1, b2 = arg1(b), arg2(b) - d = if a1 == b1 - norm(a1) * norm(a2 - b2) + if a1 == b1 + return isapprox(a2, b2; atol = atol / norm(a1), rtol) elseif a2 == b2 - norm(a1 - b1) * norm(b2) + return isapprox(a1, b1; atol = atol / norm(a2), rtol) else - # This could be defined as `KroneckerArrays.dist_kronecker(a, b)`, but that might have - # numerical precision issues so for now we just error. - error( - "`isapprox` not implemented for KroneckerArrays where both arguments differ. " * - "In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`." + # This could be defined as: + # ```julia + # d = KroneckerArrays.dist_kronecker(a, b) + # iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) + # ``` + # but that might have numerical precision issues so for now we just error. + throw( + ArgumentError( + "`isapprox` not implemented for KroneckerArrays where both arguments differ. " * + "In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`." + ) ) end - return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) end function Base.iszero(a::AbstractKroneckerArray) diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index c0f08d5..c4dd865 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -58,7 +58,7 @@ function LinearAlgebra.tr(a::AbstractKroneckerArray) end using LinearAlgebra: norm -function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Int = 2) +function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2) return norm(arg1(a), p) * norm(arg2(a), p) end diff --git a/test/test_basics.jl b/test/test_basics.jl index fc12c3b..6aa1068 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -237,9 +237,47 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) end end - # KroneckerArrays.dist + # isapprox + + rng = StableRNG(123) + a1 = randn(rng, elt, (2, 2)) + a = a1 ⊗ randn(rng, elt, (3, 3)) + b = a1 ⊗ randn(rng, elt, (3, 3)) + @test isapprox(a, b; atol = norm(a - b) * (1 + 2eps(real(elt)))) + @test !isapprox(a, b; atol = norm(a - b) * (1 - 2eps(real(elt)))) + @test isapprox( + a, b; + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt))) + ) + @test !isapprox( + a, b; + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt))) + ) + @test isapprox( + a, b; atol = norm(a - b) * (1 + 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt))) + ) + @test isapprox( + a, b; atol = norm(a - b) * (1 + 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt))) + ) + @test isapprox( + a, b; atol = norm(a - b) * (1 - 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt))) + ) + @test !isapprox( + a, b; atol = norm(a - b) * (1 - 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt))) + ) + + a = randn(elt, (2, 2)) ⊗ randn(elt, (3, 3)) + b = randn(elt, (2, 2)) ⊗ randn(elt, (3, 3)) + @test_throws ArgumentError isapprox(a, b) + + # KroneckerArrays.dist_kronecker rng = StableRNG(123) - a = randn(rng, 100, 100) ⊗ randn(rng, 100, 100) - b = (arg1(a) + 1.0e-1 * randn(rng, size(arg1(a)))) ⊗ (arg2(a) + 1.0e-1 * randn(rng, size(arg2(a)))) + a = randn(rng, (100, 100)) ⊗ randn(rng, (100, 100)) + b = (arg1(a) + randn(rng, size(arg1(a))) / 10) ⊗ + (arg2(a) + randn(rng, size(arg2(a))) / 10) @test KroneckerArrays.dist_kronecker(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2 end From 3af6e96ff7676e1b91c1c62027b4025de3905ac6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 5 Nov 2025 10:14:29 -0500 Subject: [PATCH 9/9] Comment about abs --- src/kroneckerarray.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 9b1a293..9b93fb0 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -373,6 +373,7 @@ function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray) xy = real(dot(diff1, b1) * dot(a2, diff2)) xz = real(dot(diff1, diff1) * dot(a2, diff2)) yz = real(dot(b1, diff1) * dot(diff2, diff2)) + # `abs` is used in case there are negative values due to floating point roundoff errors. return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz))) end