From b1794f762037ea24afcc092a5556a4e7e4c6da51 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 19 Feb 2025 13:48:41 -0500 Subject: [PATCH] Check shapes in ITensor constructor --- Project.toml | 4 ++-- src/abstractitensor.jl | 5 +++++ test/test_basics.jl | 8 ++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ec930c4..1ee259a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorBase" uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" authors = ["ITensor developers and contributors"] -version = "0.1.15" +version = "0.1.16" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -32,7 +32,7 @@ FillArrays = "1.13.0" GradedUnitRanges = "0.1.4" LinearAlgebra = "1.10" MapBroadcast = "0.1.5" -NamedDimsArrays = "0.4" +NamedDimsArrays = "0.4.5" SparseArraysBase = "0.2.11" UnallocatedArrays = "0.1.1" UnspecifiedTypes = "0.1.1" diff --git a/src/abstractitensor.jl b/src/abstractitensor.jl index e6a825f..8a7060f 100644 --- a/src/abstractitensor.jl +++ b/src/abstractitensor.jl @@ -80,6 +80,11 @@ end mutable struct ITensor <: AbstractITensor parent::AbstractArray nameddimsindices + function ITensor(parent::AbstractArray, dims) + # This checks the shapes of the inputs. + nameddimsindices = NamedDimsArrays.to_nameddimsindices(parent, dims) + return new(parent, nameddimsindices) + end end Base.parent(a::ITensor) = a.parent NamedDimsArrays.nameddimsindices(a::ITensor) = a.nameddimsindices diff --git a/test/test_basics.jl b/test/test_basics.jl index 376d8c9..f23b031 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -18,12 +18,13 @@ using ITensorBase: using NamedDimsArrays: dename, name, named using SparseArraysBase: oneelement using SymmetrySectors: U1 -using Test: @test, @test_broken, @testset +using Test: @test, @test_broken, @test_throws, @testset @testset "ITensorBase" begin @testset "Basics" begin + elt = Float64 i, j = Index.((2, 2)) - x = randn(2, 2) + x = randn(elt, 2, 2) for a in (ITensor(x, i, j), ITensor(x, (i, j))) @test dename(a) == x @test plev(i) == 0 @@ -34,6 +35,9 @@ using Test: @test, @test_broken, @testset @test issetequal(inds(a′), (prime(i), prime(j))) end + @test_throws ErrorException ITensor(randn(elt, 2, 2), Index.((2, 3))) + @test_throws ErrorException ITensor(randn(elt, 4), Index.((2, 2))) + i = Index(2) i = settag(i, "X", "x") @test hastag(i, "X")