diff --git a/Project.toml b/Project.toml index a5193ad..aa26928 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.2.8" +version = "0.2.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 6a1e454..9b93fb0 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -357,11 +357,53 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray) return arg1(a) == arg1(b) && arg2(a) == arg2(b) end -# TODO: this definition doesn't fully retain the original meaning: -# ‖a - b‖ < atol could be true even if the following check isn't -function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...) - return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...) +# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2) +# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2)) +function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray) + a1, a2 = arg1(a), arg2(a) + b1, b2 = arg1(b), arg2(b) + diff1 = a1 - b1 + diff2 = a2 - b2 + # x = (a1 - b1) ⊗ a2 + # y = b1 ⊗ (a2 - b2) + # z = (a1 - b1) ⊗ (a2 - b2) + xx = norm(diff1)^2 * norm(a2)^2 + yy = norm(b1)^2 * norm(diff2)^2 + zz = norm(diff1)^2 * norm(diff2)^2 + xy = real(dot(diff1, b1) * dot(a2, diff2)) + xz = real(dot(diff1, diff1) * dot(a2, diff2)) + yz = real(dot(b1, diff1) * dot(diff2, diff2)) + # `abs` is used in case there are negative values due to floating point roundoff errors. + return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz))) +end + +using LinearAlgebra: dot, promote_leaf_eltypes +function Base.isapprox( + a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0, + rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol), + ) + a1, a2 = arg1(a), arg2(a) + b1, b2 = arg1(b), arg2(b) + if a1 == b1 + return isapprox(a2, b2; atol = atol / norm(a1), rtol) + elseif a2 == b2 + return isapprox(a1, b1; atol = atol / norm(a2), rtol) + else + # This could be defined as: + # ```julia + # d = KroneckerArrays.dist_kronecker(a, b) + # iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b))) + # ``` + # but that might have numerical precision issues so for now we just error. + throw( + ArgumentError( + "`isapprox` not implemented for KroneckerArrays where both arguments differ. " * + "In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`." + ) + ) + end end + function Base.iszero(a::AbstractKroneckerArray) return iszero(arg1(a)) || iszero(arg2(a)) end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index c0f08d5..c4dd865 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -58,7 +58,7 @@ function LinearAlgebra.tr(a::AbstractKroneckerArray) end using LinearAlgebra: norm -function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Int = 2) +function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2) return norm(arg1(a), p) * norm(arg2(a), p) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 4b8c8e8..6aa1068 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -4,20 +4,9 @@ using DerivableInterfaces: zero! using DiagonalArrays: diagonal using GPUArraysCore: @allowscalar using JLArrays: JLArray -using KroneckerArrays: - KroneckerArrays, - KroneckerArray, - KroneckerStyle, - CartesianProductUnitRange, - CartesianProductVector, - ⊗, - ×, - arg1, - arg2, - cartesianproduct, - cartesianrange, - kron_nd, - unproduct +using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle, + CartesianProductUnitRange, CartesianProductVector, ⊗, ×, arg1, arg2, cartesianproduct, + cartesianrange, kron_nd, unproduct using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -219,10 +208,11 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c = a.arg1 ⊗ b.arg2 + c = arg1(a) ⊗ arg2(b) U, S, V = svd(a) @test collect(U * diagonal(S) * V') ≈ collect(a) - @test svdvals(a) ≈ S + @test arg1(svdvals(a)) ≈ arg1(S) + @test arg2(svdvals(a)) ≈ arg2(S) @test sort(collect(S); rev = true) ≈ svdvals(collect(a)) @test collect(U'U) ≈ I @test collect(V * V') ≈ I @@ -246,4 +236,48 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test_throws ArgumentError $f($a) end end + + # isapprox + + rng = StableRNG(123) + a1 = randn(rng, elt, (2, 2)) + a = a1 ⊗ randn(rng, elt, (3, 3)) + b = a1 ⊗ randn(rng, elt, (3, 3)) + @test isapprox(a, b; atol = norm(a - b) * (1 + 2eps(real(elt)))) + @test !isapprox(a, b; atol = norm(a - b) * (1 - 2eps(real(elt)))) + @test isapprox( + a, b; + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt))) + ) + @test !isapprox( + a, b; + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt))) + ) + @test isapprox( + a, b; atol = norm(a - b) * (1 + 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt))) + ) + @test isapprox( + a, b; atol = norm(a - b) * (1 + 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt))) + ) + @test isapprox( + a, b; atol = norm(a - b) * (1 - 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt))) + ) + @test !isapprox( + a, b; atol = norm(a - b) * (1 - 2eps(real(elt))), + rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt))) + ) + + a = randn(elt, (2, 2)) ⊗ randn(elt, (3, 3)) + b = randn(elt, (2, 2)) ⊗ randn(elt, (3, 3)) + @test_throws ArgumentError isapprox(a, b) + + # KroneckerArrays.dist_kronecker + rng = StableRNG(123) + a = randn(rng, (100, 100)) ⊗ randn(rng, (100, 100)) + b = (arg1(a) + randn(rng, size(arg1(a))) / 10) ⊗ + (arg2(a) + randn(rng, size(arg2(a))) / 10) + @test KroneckerArrays.dist_kronecker(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2 end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 0767542..f62c492 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,25 +1,8 @@ -using KroneckerArrays: ⊗, arguments +using KroneckerArrays: ⊗, arg1, arg2 using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm -using MatrixAlgebraKit: - eig_full, - eig_trunc, - eig_vals, - eigh_full, - eigh_trunc, - eigh_vals, - left_null, - left_orth, - left_polar, - lq_compact, - lq_full, - qr_compact, - qr_full, - right_null, - right_orth, - right_polar, - svd_compact, - svd_full, - svd_trunc, +using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc, + eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact, + qr_full, right_null, right_orth, right_polar, svd_compact, svd_full, svd_trunc, svd_vals using Test: @test, @test_throws, @testset using TestExtras: @constinferred @@ -31,18 +14,26 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d, v = eig_full(a) - @test a * v ≈ v * d + av = a * v + vd = v * d + @test arg1(av) ≈ arg1(vd) + @test arg2(av) ≈ arg2(vd) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) @test_throws ArgumentError eig_trunc(a) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d = eig_vals(a) - @test d ≈ diag(eig_full(a)[1]) + d′ = diag(eig_full(a)[1]) + @test arg1(d) ≈ arg1(d′) + @test arg2(d) ≈ arg2(d′) a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) d, v = eigh_full(a) - @test a * v ≈ v * d + av = a * v + vd = v * d + @test arg1(av) ≈ arg1(vd) + @test arg2(av) ≈ arg2(vd) @test eltype(d) === real(elt) @test eltype(v) === elt @@ -56,22 +47,30 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_compact(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_full(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = lq_compact(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = lq_full(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) @@ -84,27 +83,37 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = left_orth(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = right_orth(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = left_polar(a) - @test u * c ≈ a + uc = u * c + @test arg1(uc) ≈ arg1(a) + @test arg2(uc) ≈ arg2(a) @test collect(u'u) ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c, u = right_polar(a) - @test c * u ≈ a + cu = c * u + @test arg1(cu) ≈ arg1(a) + @test arg2(cu) ≈ arg2(a) @test collect(u * u') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_compact(a) - @test u * s * v ≈ a + usv = u * s * v + @test arg1(usv) ≈ arg1(a) + @test arg2(usv) ≈ arg2(a) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt @@ -113,7 +122,9 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_full(a) - @test u * s * v ≈ a + usv = u * s * v + @test arg1(usv) ≈ arg1(a) + @test arg2(usv) ≈ arg2(a) @test eltype(u) === elt @test eltype(s) === real(elt) @test eltype(v) === elt @@ -125,5 +136,7 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) s = svd_vals(a) - @test s ≈ diag(svd_compact(a)[2]) + s′ = diag(svd_compact(a)[2]) + @test arg1(s) ≈ arg1(s′) + @test arg2(s) ≈ arg2(s′) end