From a40d16effd14cf50be07484e6a862a3143e3c2c9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 10:17:26 -0400 Subject: [PATCH 1/5] [WIP] Start using TensorAlgebra.factorize --- Project.toml | 2 ++ src/quirks.jl | 64 +++++++++++++++++---------------------------- test/Project.toml | 2 ++ test/test_basics.jl | 7 +++++ 4 files changed, 35 insertions(+), 40 deletions(-) diff --git a/Project.toml b/Project.toml index 90dc4b7..bec70b9 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" UnallocatedArrays = "43c9e47c-e622-40fb-bf18-a09fc8c466b6" UnspecifiedTypes = "42b3faec-625b-4613-8ddc-352bf9672b8d" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" @@ -34,6 +35,7 @@ LinearAlgebra = "1.10" MapBroadcast = "0.1.5" NamedDimsArrays = "0.6" SparseArraysBase = "0.5" +TensorAlgebra = "0.2.10" UnallocatedArrays = "0.1.1" UnspecifiedTypes = "0.1.1" VectorInterface = "0.5" diff --git a/src/quirks.jl b/src/quirks.jl index c69606f..907a25e 100644 --- a/src/quirks.jl +++ b/src/quirks.jl @@ -20,46 +20,30 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a)) # TODO: Investigate this and see if we can get rid of it. Base.Broadcast.extrude(a::AbstractITensor) = a -# TODO: This is just a stand-in for truncated SVD -# that only makes use of `maxdim`, just to get some -# functionality running in `ITensorMPS.jl`. -# Define a proper truncated SVD in -# `MatrixAlgebra.jl`/`TensorAlgebra.jl`. -function svd_truncated(a::AbstractITensor, codomain_inds; maxdim) - U, S, V = svd(a, codomain_inds) - r = Base.OneTo(min(maxdim, minimum(Int.(size(S))))) - u = commonind(U, S) - v = commonind(V, S) - us = uniqueinds(U, S) - vs = uniqueinds(V, S) - U′ = U[(us .=> :)..., u => r] - S′ = S[u => r, v => r] - V′ = V[v => r, (vs .=> :)...] - return U′, S′, V′ +function translate_factorize_kwargs(; + # ITensors.jl kwargs. + ortho=nothing, + cutoff=nothing, + maxdim=nothing, + # MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs. + orth=nothing, + trunc=nothing, + kwargs..., +) + @show ortho, cutoff, maxdim + @show orth, trunc + @show kwargs + return error() end -using LinearAlgebra: qr, svd -# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`. -function factorize( - a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs... -) - # TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`? - # See https://github.com/ITensor/NamedDimsArrays.jl/issues/22. - codomain_inds′ = if ortho == "left" - intersect(inds(a), codomain_inds) - elseif ortho == "right" - setdiff(inds(a), codomain_inds) - else - error("Bad `ortho` input.") - end - F1, F2 = if isnothing(maxdim) && isnothing(cutoff) - qr(a, codomain_inds′) - else - U, S, V = svd_truncated(a, codomain_inds′; maxdim) - U, S * V - end - if ortho == "right" - F2, F1 = F1, F2 - end - return F1, F2, (; truncerr=zero(Bool),) +using TensorAlgebra: TensorAlgebra, factorize +function TensorAlgebra.factorize(a::AbstractITensor, codomain_inds, domain_inds; kwargs...) + return invoke( + factorize, + Tuple{AbstractNamedDimsArray,Any,Any}, + a, + codomain_inds, + domain_inds; + translate_factorize_kwargs(; kwargs...)..., + ) end diff --git a/test/Project.toml b/test/Project.toml index bac88fa..c557bb3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,11 +6,13 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/test_basics.jl b/test/test_basics.jl index 059d1e6..c81d6ae 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -19,6 +19,7 @@ using ITensorBase: using NamedDimsArrays: dename, name, named using SparseArraysBase: oneelement using SymmetrySectors: U1 +using LinearAlgebra: factorize using Test: @test, @test_broken, @test_throws, @testset @testset "ITensorBase" begin @@ -164,4 +165,10 @@ using Test: @test, @test_broken, @test_throws, @testset @test hasqns(j) @test hasqns(a) end + @testset "factorize" begin + i = Index(2) + j = Index(2) + a = randn(i, j) + x, y = factorize(a, (i,)) + end end From 9c8880459c1fc54e41e7b4f709ba9817097ad403 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 11 Apr 2025 18:04:24 -0400 Subject: [PATCH 2/5] Fix forwarding kwargs, add tests --- Project.toml | 2 +- src/quirks.jl | 16 +++++++++------- test/Project.toml | 1 + test/test_basics.jl | 24 ++++++++++++++++++++++-- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index eb78372..59db1ae 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ LinearAlgebra = "1.10" MapBroadcast = "0.1.5" NamedDimsArrays = "0.7" SparseArraysBase = "0.5" -TensorAlgebra = "0.2.10" +TensorAlgebra = "0.3" UnallocatedArrays = "0.1.1" UnspecifiedTypes = "0.1.1" VectorInterface = "0.5" diff --git a/src/quirks.jl b/src/quirks.jl index 907a25e..f502921 100644 --- a/src/quirks.jl +++ b/src/quirks.jl @@ -21,19 +21,21 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a)) Base.Broadcast.extrude(a::AbstractITensor) = a function translate_factorize_kwargs(; + # MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs. + orth=nothing, + rtol=nothing, + maxrank=nothing, # ITensors.jl kwargs. ortho=nothing, cutoff=nothing, maxdim=nothing, - # MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs. - orth=nothing, - trunc=nothing, kwargs..., ) - @show ortho, cutoff, maxdim - @show orth, trunc - @show kwargs - return error() + orth::Symbol = @something orth ortho :left + rtol = @something rtol cutoff Some(nothing) + maxrank = @something maxrank maxdim Some(nothing) + !isnothing(maxrank) && error("`maxrank` not supported yet.") + return filter(!isnothing, (; orth, rtol, maxrank, kwargs...)) end using TensorAlgebra: TensorAlgebra, factorize diff --git a/test/Project.toml b/test/Project.toml index b28b6fb..22d63e1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,4 +27,5 @@ SafeTestsets = "0.1" SparseArraysBase = "0.5" Suppressor = "0.2" SymmetrySectors = "0.1" +TensorAlgebra = "0.3" Test = "1.10" diff --git a/test/test_basics.jl b/test/test_basics.jl index c81d6ae..e69b5bd 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -22,6 +22,7 @@ using SymmetrySectors: U1 using LinearAlgebra: factorize using Test: @test, @test_broken, @test_throws, @testset +const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "ITensorBase" begin @testset "Basics" begin elt = Float64 @@ -165,10 +166,29 @@ using Test: @test, @test_broken, @test_throws, @testset @test hasqns(j) @test hasqns(a) end - @testset "factorize" begin + @testset "factorize" for elt in elts i = Index(2) j = Index(2) - a = randn(i, j) + a = randn(elt, i, j) x, y = factorize(a, (i,)) + @test a ≈ x * y + @test x isa ITensor + @test y isa ITensor + @test i ∈ inds(x) + @test j ∈ inds(y) + @test eltype(x) === elt + @test eltype(y) === elt + @test Int.(Tuple(size(x))) == (2, 2) + @test Int.(Tuple(size(y))) == (2, 2) + + i = Index(2) + j = Index(2) + a = randn(elt, i) * randn(elt, j) + for kwargs in ((; rtol=1e-2), (; cutoff=1e-2)) + x, y = factorize(a, (i,); kwargs...) + @test a ≈ x * y + @test Int.(Tuple(size(x))) == (2, 1) + @test Int.(Tuple(size(y))) == (1, 2) + end end end From 55bd2391716165c45738b3a20f03d93e519be5a8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 11 Apr 2025 18:05:06 -0400 Subject: [PATCH 3/5] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 59db1ae..1334477 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorBase" uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" authors = ["ITensor developers and contributors"] -version = "0.2.2" +version = "0.2.3" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" From 3652db4525555ae1522af6a23e602c46a3bef06f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 11 Apr 2025 18:06:00 -0400 Subject: [PATCH 4/5] Delete stale test dep --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 22d63e1..99007e3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,7 +6,6 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" From f03d8276655b27322eae98f934d0cfb410b8c149 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 11 Apr 2025 18:23:01 -0400 Subject: [PATCH 5/5] Fix for Julia 1.10 --- src/quirks.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/quirks.jl b/src/quirks.jl index f502921..7b0d44f 100644 --- a/src/quirks.jl +++ b/src/quirks.jl @@ -20,6 +20,11 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a)) # TODO: Investigate this and see if we can get rid of it. Base.Broadcast.extrude(a::AbstractITensor) = a +# See: https://github.com/JuliaLang/julia/blob/v1.11.4/base/namedtuple.jl#L269 +# `filter(f, ::NamedTuple)` is available in Julia v1.11, delete once +# we drop support for Julia v1.10. +filter_namedtuple(f, xs::NamedTuple) = xs[filter(k -> f(xs[k]), keys(xs))] + function translate_factorize_kwargs(; # MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs. orth=nothing, @@ -35,7 +40,7 @@ function translate_factorize_kwargs(; rtol = @something rtol cutoff Some(nothing) maxrank = @something maxrank maxdim Some(nothing) !isnothing(maxrank) && error("`maxrank` not supported yet.") - return filter(!isnothing, (; orth, rtol, maxrank, kwargs...)) + return filter_namedtuple(!isnothing, (; orth, rtol, maxrank, kwargs...)) end using TensorAlgebra: TensorAlgebra, factorize