diff --git a/Project.toml b/Project.toml index bc1104b..1334477 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorBase" uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" authors = ["ITensor developers and contributors"] -version = "0.2.2" +version = "0.2.3" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -10,6 +10,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" UnallocatedArrays = "43c9e47c-e622-40fb-bf18-a09fc8c466b6" UnspecifiedTypes = "42b3faec-625b-4613-8ddc-352bf9672b8d" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" @@ -34,6 +35,7 @@ LinearAlgebra = "1.10" MapBroadcast = "0.1.5" NamedDimsArrays = "0.7" SparseArraysBase = "0.5" +TensorAlgebra = "0.3" UnallocatedArrays = "0.1.1" UnspecifiedTypes = "0.1.1" VectorInterface = "0.5" diff --git a/src/quirks.jl b/src/quirks.jl index c69606f..7b0d44f 100644 --- a/src/quirks.jl +++ b/src/quirks.jl @@ -20,46 +20,37 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a)) # TODO: Investigate this and see if we can get rid of it. Base.Broadcast.extrude(a::AbstractITensor) = a -# TODO: This is just a stand-in for truncated SVD -# that only makes use of `maxdim`, just to get some -# functionality running in `ITensorMPS.jl`. -# Define a proper truncated SVD in -# `MatrixAlgebra.jl`/`TensorAlgebra.jl`. -function svd_truncated(a::AbstractITensor, codomain_inds; maxdim) - U, S, V = svd(a, codomain_inds) - r = Base.OneTo(min(maxdim, minimum(Int.(size(S))))) - u = commonind(U, S) - v = commonind(V, S) - us = uniqueinds(U, S) - vs = uniqueinds(V, S) - U′ = U[(us .=> :)..., u => r] - S′ = S[u => r, v => r] - V′ = V[v => r, (vs .=> :)...] - return U′, S′, V′ -end +# See: https://github.com/JuliaLang/julia/blob/v1.11.4/base/namedtuple.jl#L269 +# `filter(f, ::NamedTuple)` is available in Julia v1.11, delete once +# we drop support for Julia v1.10. +filter_namedtuple(f, xs::NamedTuple) = xs[filter(k -> f(xs[k]), keys(xs))] -using LinearAlgebra: qr, svd -# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`. -function factorize( - a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs... +function translate_factorize_kwargs(; + # MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs. + orth=nothing, + rtol=nothing, + maxrank=nothing, + # ITensors.jl kwargs. + ortho=nothing, + cutoff=nothing, + maxdim=nothing, + kwargs..., ) - # TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`? - # See https://github.com/ITensor/NamedDimsArrays.jl/issues/22. - codomain_inds′ = if ortho == "left" - intersect(inds(a), codomain_inds) - elseif ortho == "right" - setdiff(inds(a), codomain_inds) - else - error("Bad `ortho` input.") - end - F1, F2 = if isnothing(maxdim) && isnothing(cutoff) - qr(a, codomain_inds′) - else - U, S, V = svd_truncated(a, codomain_inds′; maxdim) - U, S * V - end - if ortho == "right" - F2, F1 = F1, F2 - end - return F1, F2, (; truncerr=zero(Bool),) + orth::Symbol = @something orth ortho :left + rtol = @something rtol cutoff Some(nothing) + maxrank = @something maxrank maxdim Some(nothing) + !isnothing(maxrank) && error("`maxrank` not supported yet.") + return filter_namedtuple(!isnothing, (; orth, rtol, maxrank, kwargs...)) +end + +using TensorAlgebra: TensorAlgebra, factorize +function TensorAlgebra.factorize(a::AbstractITensor, codomain_inds, domain_inds; kwargs...) + return invoke( + factorize, + Tuple{AbstractNamedDimsArray,Any,Any}, + a, + codomain_inds, + domain_inds; + translate_factorize_kwargs(; kwargs...)..., + ) end diff --git a/test/Project.toml b/test/Project.toml index 2a687d6..99007e3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -25,4 +26,5 @@ SafeTestsets = "0.1" SparseArraysBase = "0.5" Suppressor = "0.2" SymmetrySectors = "0.1" +TensorAlgebra = "0.3" Test = "1.10" diff --git a/test/test_basics.jl b/test/test_basics.jl index 059d1e6..e69b5bd 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -19,8 +19,10 @@ using ITensorBase: using NamedDimsArrays: dename, name, named using SparseArraysBase: oneelement using SymmetrySectors: U1 +using LinearAlgebra: factorize using Test: @test, @test_broken, @test_throws, @testset +const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "ITensorBase" begin @testset "Basics" begin elt = Float64 @@ -164,4 +166,29 @@ using Test: @test, @test_broken, @test_throws, @testset @test hasqns(j) @test hasqns(a) end + @testset "factorize" for elt in elts + i = Index(2) + j = Index(2) + a = randn(elt, i, j) + x, y = factorize(a, (i,)) + @test a ≈ x * y + @test x isa ITensor + @test y isa ITensor + @test i ∈ inds(x) + @test j ∈ inds(y) + @test eltype(x) === elt + @test eltype(y) === elt + @test Int.(Tuple(size(x))) == (2, 2) + @test Int.(Tuple(size(y))) == (2, 2) + + i = Index(2) + j = Index(2) + a = randn(elt, i) * randn(elt, j) + for kwargs in ((; rtol=1e-2), (; cutoff=1e-2)) + x, y = factorize(a, (i,); kwargs...) + @test a ≈ x * y + @test Int.(Tuple(size(x))) == (2, 1) + @test Int.(Tuple(size(y))) == (1, 2) + end + end end