Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.1.5"
version = "0.1.6"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -15,9 +15,11 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"

[compat]
Adapt = "4.3.0"
Adapt = "4.3"
BackendSelection = "0.1.6"
DataGraphs = "0.2.7"
Dictionaries = "0.4.5"
Expand All @@ -28,4 +30,6 @@ NamedDimsArrays = "0.8"
NamedGraphs = "0.6.9, 0.7"
SimpleTraits = "0.9.5"
SplitApplyCombine = "1.2.3"
TermInterface = "2"
WrappedUnions = "0.3"
julia = "1.10"
1 change: 1 addition & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module ITensorNetworksNext

include("lazynameddimsarrays.jl")
include("abstracttensornetwork.jl")
include("tensornetwork.jl")

Expand Down
182 changes: 182 additions & 0 deletions src/lazynameddimsarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
module LazyNamedDimsArrays

using WrappedUnions: @wrapped, unwrap
using NamedDimsArrays:
NamedDimsArrays,
AbstractNamedDimsArray,
AbstractNamedDimsArrayStyle,
dename,
inds

struct Prod{A}
factors::Vector{A}
end

@wrapped struct LazyNamedDimsArray{
T, A <: AbstractNamedDimsArray{T},
} <: AbstractNamedDimsArray{T, Any}
union::Union{A, Prod{LazyNamedDimsArray{T, A}}}
end

function NamedDimsArrays.inds(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return inds(unwrap(a))
elseif unwrap(a) isa Prod
return mapreduce(inds, symdiff, unwrap(a).factors)
else
return error("Variant not supported.")
end
end
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return dename(unwrap(a))
elseif unwrap(a) isa Prod
return dename(materialize(a), inds(a))
else
return error("Variant not supported.")
end
end

using Base.Broadcast: materialize
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return unwrap(a)
elseif unwrap(a) isa Prod
return prod(materialize, unwrap(a).factors)
else
return error("Variant not supported.")
end
end
Base.copy(a::LazyNamedDimsArray) = materialize(a)

function Base.:*(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return LazyNamedDimsArray(Prod([lazy(unwrap(a))]))
elseif unwrap(a) isa Prod
return a
else
return error("Variant not supported.")
end
end

function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
# Nested by default.
return LazyNamedDimsArray(Prod([a1, a2]))
end
function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
return error("Not implemented.")
end
function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
return error("Not implemented.")
end
function Base.:*(c::Number, a::LazyNamedDimsArray)
return error("Not implemented.")
end
function Base.:*(a::LazyNamedDimsArray, c::Number)
return error("Not implemented.")
end
function Base.:/(a::LazyNamedDimsArray, c::Number)
return error("Not implemented.")
end
function Base.:-(a::LazyNamedDimsArray)
return error("Not implemented.")
end

function LazyNamedDimsArray(a::AbstractNamedDimsArray)
return LazyNamedDimsArray{eltype(a), typeof(a)}(a)
end
function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A}
return LazyNamedDimsArray{T, A}(a)
end
function lazy(a::AbstractNamedDimsArray)
return LazyNamedDimsArray(a)
end

# Broadcasting
struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end
function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray})
return LazyNamedDimsArrayStyle()
end
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...)
return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.")
end
# Linear operations.
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2)
return a1 + a2
end
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2)
return a1 - a2
end
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a)
return c * a
end
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number)
return a * c
end
# Fix ambiguity error.
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number)
return a * b
end
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number)
return a / c
end
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
return -a
end

using TermInterface: TermInterface
# arguments, arity, children, head, iscall, operation
function TermInterface.arguments(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return error("No arguments.")
elseif unwrap(a) isa Prod
unwrap(a).factors
else
return error("Variant not supported.")
end
end
function TermInterface.children(a::LazyNamedDimsArray)
return TermInterface.arguments(a)
end
function TermInterface.head(a::LazyNamedDimsArray)
return TermInterface.operation(a)
end
function TermInterface.iscall(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return false
elseif unwrap(a) isa Prod
return true
else
return false
end
end
function TermInterface.isexpr(a::LazyNamedDimsArray)
return TermInterface.iscall(a)
end
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
if head ≡ prod
return LazyNamedDimsArray(Prod(args))
else
return error("Only product terms supported right now.")
end
end
function TermInterface.operation(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return error("No operation.")
elseif unwrap(a) isa Prod
prod
else
return error("Variant not supported.")
end
end
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
if unwrap(a) isa AbstractNamedDimsArray
return error("No arguments.")
elseif unwrap(a) isa Prod
return TermInterface.arguments(a)
else
return error("Variant not supported.")
end
end

end
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"

[compat]
Aqua = "0.8.14"
Expand All @@ -20,4 +22,6 @@ NamedDimsArrays = "0.8"
NamedGraphs = "0.6.8, 0.7"
SafeTestsets = "0.1"
Suppressor = "0.2.8"
TermInterface = "2"
Test = "1.10"
WrappedUnions = "0.3"
56 changes: 56 additions & 0 deletions test/test_lazynameddimsarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using Base.Broadcast: materialize
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy
using NamedDimsArrays: NamedDimsArray, inds, nameddims
using TermInterface:
arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments
using Test: @test, @test_throws, @testset
using WrappedUnions: unwrap

@testset "LazyNamedDimsArrays" begin
@testset "Basics" begin
a1 = nameddims(randn(2, 2), (:i, :j))
a2 = nameddims(randn(2, 2), (:j, :k))
a3 = nameddims(randn(2, 2), (:k, :l))
l1, l2, l3 = lazy.((a1, a2, a3))
for li in (l1, l2, l3)
@test li isa LazyNamedDimsArray
@test unwrap(li) isa NamedDimsArray
@test inds(li) == inds(unwrap(li))
@test copy(li) == unwrap(li)
@test materialize(li) == unwrap(li)
end
l = l1 * l2 * l3
@test copy(l) ≈ a1 * a2 * a3
@test materialize(l) ≈ a1 * a2 * a3
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
@test unwrap(l) isa Prod
@test unwrap(l).factors == [l1 * l2, l3]
end

@testset "TermInterface" begin
a1 = nameddims(randn(2, 2), (:i, :j))
a2 = nameddims(randn(2, 2), (:j, :k))
a3 = nameddims(randn(2, 2), (:k, :l))
l1, l2, l3 = lazy.((a1, a2, a3))

@test_throws ErrorException arguments(l1)
@test_throws ErrorException arity(l1)
@test_throws ErrorException children(l1)
@test_throws ErrorException head(l1)
@test !iscall(l1)
@test !isexpr(l1)
@test_throws ErrorException operation(l1)
@test_throws ErrorException sorted_arguments(l1)

l = l1 * l2 * l3
@test arguments(l) == [l1 * l2, l3]
@test arity(l) == 2
@test children(l) == [l1 * l2, l3]
@test head(l) ≡ prod
@test iscall(l)
@test isexpr(l)
@test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing)
@test operation(l) ≡ prod
@test sorted_arguments(l) == [l1 * l2, l3]
end
end
Loading