From 91dac0668a99667f2ca54b6eeef129127be04e6d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 7 Oct 2025 00:13:58 -0400 Subject: [PATCH 1/3] Change Prod to Mul --- Project.toml | 2 +- src/lazynameddimsarrays.jl | 162 +++++++++++++++++-------------- test/test_lazynameddimsarrays.jl | 25 +++-- 3 files changed, 106 insertions(+), 83 deletions(-) diff --git a/Project.toml b/Project.toml index b527a53..b77134e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 04eca26..3561cb6 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -7,42 +7,108 @@ using NamedDimsArrays: AbstractNamedDimsArrayStyle, dename, inds +using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments -struct Prod{A} - factors::Vector{A} -end +struct Mul{A} + arguments::Vector{A} +end +TermInterface.arguments(m::Mul) = getfield(m, :arguments) +TermInterface.children(m::Mul) = arguments(m) +TermInterface.head(m::Mul) = operation(m) +TermInterface.iscall(m::Mul) = true +TermInterface.isexpr(m::Mul) = iscall(m) +TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) +TermInterface.operation(m::Mul) = * +TermInterface.sorted_arguments(m::Mul) = arguments(m) +TermInterface.sorted_children(m::Mul) = sorted_arguments(a) @wrapped struct LazyNamedDimsArray{ T, A <: AbstractNamedDimsArray{T}, } <: AbstractNamedDimsArray{T, Any} - union::Union{A, Prod{LazyNamedDimsArray{T, A}}} + union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end function NamedDimsArrays.inds(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return inds(unwrap(a)) - elseif unwrap(a) isa Prod - return mapreduce(inds, symdiff, unwrap(a).factors) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return inds(u) + elseif u isa Mul + return mapreduce(inds, symdiff, arguments(u)) else return error("Variant not supported.") end end function NamedDimsArrays.dename(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return dename(unwrap(a)) - elseif unwrap(a) isa Prod + u = unwrap(a) + if u isa AbstractNamedDimsArray + return dename(u) + elseif u isa Mul return dename(materialize(a), inds(a)) else return error("Variant not supported.") end end +function TermInterface.arguments(a::LazyNamedDimsArray) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return error("No arguments.") + elseif u isa Mul + return arguments(u) + else + return error("Variant not supported.") + end +end +function TermInterface.children(a::LazyNamedDimsArray) + return arguments(a) +end +function TermInterface.head(a::LazyNamedDimsArray) + return operation(a) +end +function TermInterface.iscall(a::LazyNamedDimsArray) + return iscall(unwrap(a)) +end +function TermInterface.isexpr(a::LazyNamedDimsArray) + return iscall(a) +end +function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) + if head ≡ * + return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) + else + return error("Only product terms supported right now.") + end +end +function TermInterface.operation(a::LazyNamedDimsArray) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return error("No operation.") + elseif u isa Mul + return operation(u) + else + return error("Variant not supported.") + end +end +function TermInterface.sorted_arguments(a::LazyNamedDimsArray) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return error("No arguments.") + elseif u isa Mul + return sorted_arguments(u) + else + return error("Variant not supported.") + end +end +function TermInterface.sorted_children(a::LazyNamedDimsArray) + return sorted_arguments(a) +end + using Base.Broadcast: materialize function Base.Broadcast.materialize(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return unwrap(a) - elseif unwrap(a) isa Prod - return prod(materialize, unwrap(a).factors) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return u + elseif u isa Mul + return mapfoldl(materialize, operation(u), arguments(u)) else return error("Variant not supported.") end @@ -50,9 +116,10 @@ end Base.copy(a::LazyNamedDimsArray) = materialize(a) function Base.:*(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return LazyNamedDimsArray(Prod([lazy(unwrap(a))])) - elseif unwrap(a) isa Prod + u = unwrap(a) + if u isa AbstractNamedDimsArray + return LazyNamedDimsArray(Mul([lazy(u)])) + elseif u isa Mul return a else return error("Variant not supported.") @@ -61,7 +128,7 @@ end function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) # Nested by default. - return LazyNamedDimsArray(Prod([a1, a2])) + return LazyNamedDimsArray(Mul([a1, a2])) end function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) return error("Not implemented.") @@ -85,7 +152,7 @@ end function LazyNamedDimsArray(a::AbstractNamedDimsArray) return LazyNamedDimsArray{eltype(a), typeof(a)}(a) end -function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A} +function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} return LazyNamedDimsArray{T, A}(a) end function lazy(a::AbstractNamedDimsArray) @@ -124,59 +191,4 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) return -a end -using TermInterface: TermInterface -# arguments, arity, children, head, iscall, operation -function TermInterface.arguments(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return error("No arguments.") - elseif unwrap(a) isa Prod - unwrap(a).factors - else - return error("Variant not supported.") - end -end -function TermInterface.children(a::LazyNamedDimsArray) - return TermInterface.arguments(a) -end -function TermInterface.head(a::LazyNamedDimsArray) - return TermInterface.operation(a) -end -function TermInterface.iscall(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return false - elseif unwrap(a) isa Prod - return true - else - return false - end -end -function TermInterface.isexpr(a::LazyNamedDimsArray) - return TermInterface.iscall(a) -end -function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) - if head ≡ prod - return LazyNamedDimsArray(Prod(args)) - else - return error("Only product terms supported right now.") - end -end -function TermInterface.operation(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return error("No operation.") - elseif unwrap(a) isa Prod - prod - else - return error("Variant not supported.") - end -end -function TermInterface.sorted_arguments(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return error("No arguments.") - elseif unwrap(a) isa Prod - return TermInterface.arguments(a) - else - return error("Variant not supported.") - end -end - end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 958c191..e6f3106 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,8 +1,17 @@ using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy using NamedDimsArrays: NamedDimsArray, inds, nameddims using TermInterface: - arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments + arguments, + arity, + children, + head, + iscall, + isexpr, + maketerm, + operation, + sorted_arguments, + sorted_children using Test: @test, @test_throws, @testset using WrappedUnions: unwrap @@ -23,8 +32,8 @@ using WrappedUnions: unwrap @test copy(l) ≈ a1 * a2 * a3 @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) - @test unwrap(l) isa Prod - @test unwrap(l).factors == [l1 * l2, l3] + @test unwrap(l) isa Mul + @test unwrap(l).arguments == [l1 * l2, l3] end @testset "TermInterface" begin @@ -41,16 +50,18 @@ using WrappedUnions: unwrap @test !isexpr(l1) @test_throws ErrorException operation(l1) @test_throws ErrorException sorted_arguments(l1) + @test_throws ErrorException sorted_children(l1) l = l1 * l2 * l3 @test arguments(l) == [l1 * l2, l3] @test arity(l) == 2 @test children(l) == [l1 * l2, l3] - @test head(l) ≡ prod + @test head(l) ≡ * @test iscall(l) @test isexpr(l) - @test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing) - @test operation(l) ≡ prod + @test l == maketerm(LazyNamedDimsArray, *, [l1 * l2, l3], nothing) + @test operation(l) ≡ * @test sorted_arguments(l) == [l1 * l2, l3] + @test sorted_children(l) == [l1 * l2, l3] end end From 35e534c15aa2698b742cca474f4d332f7158e941 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 7 Oct 2025 00:21:36 -0400 Subject: [PATCH 2/3] Update tests --- test/test_lazynameddimsarrays.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index e6f3106..f95a3ed 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -33,7 +33,8 @@ using WrappedUnions: unwrap @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) @test unwrap(l) isa Mul - @test unwrap(l).arguments == [l1 * l2, l3] + @test operation(unwrap(l)) ≡ * + @test arguments(unwrap(l)) == [l1 * l2, l3] end @testset "TermInterface" begin From 5f35ee233e45d17e9b63529144e9716179cb37d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 7 Oct 2025 00:22:21 -0400 Subject: [PATCH 3/3] More tests --- test/test_lazynameddimsarrays.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index f95a3ed..4c38c5e 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -33,6 +33,8 @@ using WrappedUnions: unwrap @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) @test unwrap(l) isa Mul + @test unwrap(l).arguments == [l1 * l2, l3] + # TermInterface.jl @test operation(unwrap(l)) ≡ * @test arguments(unwrap(l)) == [l1 * l2, l3] end