Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>, Joseph Tindall <jtindall@flatironinstitute.org> and contributors"]
version = "0.15.2"
version = "0.15.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down Expand Up @@ -56,7 +56,7 @@ Adapt = "4"
Combinatorics = "1"
Compat = "3, 4"
ConstructionBase = "1.6.0"
DataGraphs = "0.2.3"
DataGraphs = "0.2.13"
DataStructures = "0.18, 0.19"
Dictionaries = "0.4"
Distributions = "0.25.86"
Expand All @@ -70,7 +70,7 @@ IterTools = "1.4.0"
KrylovKit = "0.6, 0.7, 0.8, 0.9, 0.10"
MacroTools = "0.5"
NDTensors = "0.3, 0.4"
NamedGraphs = "0.7.1"
NamedGraphs = "0.8.2"
OMEinsumContractionOrders = "0.8.3, 0.9, 1"
Observers = "0.2.4"
SerializedElementArrays = "0.1"
Expand Down
129 changes: 78 additions & 51 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ using ITensors: dir
using NamedGraphs.PartitionedGraphs:
PartitionedGraphs,
PartitionedGraph,
PartitionVertex,
boundary_partitionedges,
partitionvertices,
partitionedges,
QuotientVertex,
boundary_quotientedges,
quotientvertices,
quotientedges,
unpartitioned_graph
using SimpleTraits: SimpleTraits, Not, @traitfn
using NamedGraphs.SimilarType: SimilarType
Expand All @@ -25,7 +25,8 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
end
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))

#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex
# messages
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
Expand All @@ -52,7 +53,7 @@ end
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_message(
bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
bpc::AbstractBeliefPropagationCache, edge::QuotientEdge; kwargs...
)
return not_implemented()
end
Expand All @@ -66,14 +67,17 @@ end
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
return not_implemented()
end
function region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...)
function region_scalar(bpc::AbstractBeliefPropagationCache, pv::QuotientVertex; kwargs...)
return not_implemented()
end
function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...)
function region_scalar(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge; kwargs...)
return not_implemented()
end
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()
PartitionedGraphs.quotientedges(bpc::AbstractBeliefPropagationCache) = not_implemented()
function PartitionedGraphs.partitioned_vertices(bpc::AbstractBeliefPropagationCache)
return not_implemented()
end

default_bp_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented()
Expand All @@ -87,21 +91,23 @@ function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
end

function factors(
bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:PartitionVertex}
bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:QuotientVertex}
)
return factors(bpc, vertices(bpc, partition_verts))
end

function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex)
function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::QuotientVertex)
return factors(bpc, [partition_vertex])
end

function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs = partitions(bpc); kwargs...)
function vertex_scalars(
bpc::AbstractBeliefPropagationCache, pvs = partitions(bpc); kwargs...
)
return map(pv -> region_scalar(bpc, pv; kwargs...), pvs)
end

function edge_scalars(
bpc::AbstractBeliefPropagationCache, pes = partitionedges(bpc); kwargs...
bpc::AbstractBeliefPropagationCache, pes = quotientedges(bpc); kwargs...
)
return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
end
Expand All @@ -112,16 +118,16 @@ end

function incoming_messages(
bpc::AbstractBeliefPropagationCache,
partition_vertices::Vector{<:PartitionVertex};
partition_vertices::Vector{<:QuotientVertex};
ignore_edges = (),
)
bpes = boundary_partitionedges(bpc, partition_vertices; dir = :in)
bpes = boundary_quotientedges(bpc, partition_vertices; dir = :in)
ms = messages(bpc, setdiff(bpes, ignore_edges))
return reduce(vcat, ms; init = ITensor[])
end

function incoming_messages(
bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs...
bpc::AbstractBeliefPropagationCache, partition_vertex::QuotientVertex; kwargs...
)
return incoming_messages(bpc, [partition_vertex]; kwargs...)
end
Expand Down Expand Up @@ -157,21 +163,40 @@ function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache)
end

#Forward from partitioned graph
for f in [
:(PartitionedGraphs.partitionedge),
:(PartitionedGraphs.partitionvertices),
:(PartitionedGraphs.partitions_graph),
:(PartitionedGraphs.vertices),
:(PartitionedGraphs.boundary_partitionedges),
]
@eval begin
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
return $f(partitioned_tensornetwork(bpc), args...; kwargs...)
end
end
using Graphs: Graphs, vertices
function Graphs.vertices(bpc::AbstractBeliefPropagationCache)
return vertices(partitioned_tensornetwork(bpc))
end
function PartitionedGraphs.quotient_graph(bpc::AbstractBeliefPropagationCache)
return PartitionedGraphs.quotient_graph(partitioned_tensornetwork(bpc))
end
function PartitionedGraphs.quotientedge(
bpc::AbstractBeliefPropagationCache, edge::AbstractEdge
)
return PartitionedGraphs.quotientedge(partitioned_tensornetwork(bpc), edge)
end
function PartitionedGraphs.quotientvertices(bpc::AbstractBeliefPropagationCache)
return PartitionedGraphs.quotientvertices(partitioned_tensornetwork(bpc))
end
function PartitionedGraphs.quotientvertices(bpc::AbstractBeliefPropagationCache, vs)
return PartitionedGraphs.quotientvertices(partitioned_tensornetwork(bpc), vs)
end
function PartitionedGraphs.boundary_quotientedges(
bpc::AbstractBeliefPropagationCache, quotientvertices; kwargs...
)
return PartitionedGraphs.boundary_quotientedges(
partitioned_tensornetwork(bpc), quotientvertices; kwargs...
)
end
function PartitionedGraphs.boundary_quotientedges(
bpc::AbstractBeliefPropagationCache, quotientvertex::QuotientVertex; kwargs...
)
return PartitionedGraphs.boundary_quotientedges(
partitioned_tensornetwork(bpc), quotientvertex; kwargs...
)
end

function linkinds(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
function linkinds(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge)
return linkinds(partitioned_tensornetwork(bpc), pe)
end

Expand All @@ -195,62 +220,64 @@ function update_factor(bpc, vertex, factor)
return bpc
end

function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...)
function message(bpc::AbstractBeliefPropagationCache, edge::QuotientEdge; kwargs...)
mts = messages(bpc)
return get(() -> default_message(bpc, edge; kwargs...), mts, edge)
end
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
return map(edge -> message(bpc, edge; kwargs...), edges)
end
function set_messages!(bpc::AbstractBeliefPropagationCache, partitionedges_messages)
function set_messages!(bpc::AbstractBeliefPropagationCache, quotientedges_messages)
ms = messages(bpc)
for pe in eachindex(partitionedges_messages)
for pe in eachindex(quotientedges_messages)
# TODO: Add a check that this preserves the graph structure.
set!(ms, pe, partitionedges_messages[pe])
set!(ms, pe, quotientedges_messages[pe])
end
return bpc
end
function set_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
function set_message!(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge, message)
ms = messages(bpc)
set!(ms, pe, message)
return bpc
end

function set_messages(bpc::AbstractBeliefPropagationCache, partitionedges_messages)
function set_messages(bpc::AbstractBeliefPropagationCache, quotientedges_messages)
bpc = copy(bpc)
return set_messages!(bpc, partitionedges_messages)
return set_messages!(bpc, quotientedges_messages)
end
function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
function set_message(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge, message)
bpc = copy(bpc)
return set_message!(bpc, pe, message)
end
function delete_messages!(
bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge} = keys(messages(bpc))
bpc::AbstractBeliefPropagationCache,
pes::Vector{<:QuotientEdge} = keys(messages(bpc)),
)
ms = messages(bpc)
for pe in pes
delete!(ms, pe)
end
return bpc
end
function delete_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
function delete_message!(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge)
return delete_messages!(bpc, [pe])
end
function delete_messages(
bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge} = keys(messages(bpc))
bpc::AbstractBeliefPropagationCache,
pes::Vector{<:QuotientEdge} = keys(messages(bpc)),
)
bpc = copy(bpc)
return delete_messages!(bpc, pes)
end
function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
function delete_message(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge)
return delete_messages(bpc, [pe])
end

function updated_message(
alg::Algorithm"contract", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
alg::Algorithm"contract", bpc::AbstractBeliefPropagationCache, edge::QuotientEdge
)
vertex = src(edge)
incoming_ms = incoming_messages(bpc, vertex; ignore_edges = PartitionEdge[reverse(edge)])
incoming_ms = incoming_messages(bpc, vertex; ignore_edges = QuotientEdge[reverse(edge)])
state = factors(bpc, vertex)
contract_list = ITensor[incoming_ms; state]
sequence = contraction_sequence(contract_list; alg = alg.kwargs.sequence_alg)
Expand All @@ -263,10 +290,10 @@ function updated_message(
end

function updated_message(
alg::Algorithm"adapt_update", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
alg::Algorithm"adapt_update", bpc::AbstractBeliefPropagationCache, edge::QuotientEdge
)
incoming_pes = setdiff(
boundary_partitionedges(bpc, [src(edge)]; dir = :in), [reverse(edge)]
boundary_quotientedges(bpc, [src(edge)]; dir = :in), [reverse(edge)]
)
adapted_bpc = adapt_messages(alg.kwargs.adapt, bpc, incoming_pes)
adapted_bpc = adapt_factors(alg.kwargs.adapt, bpc, vertices(bpc, src(edge)))
Expand All @@ -277,15 +304,15 @@ end

function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
edge::QuotientEdge;
alg = default_message_update_alg(bpc),
kwargs...,
)
return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bpc, edge)
end

function update_message(
message_update_alg::Algorithm, bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
message_update_alg::Algorithm, bpc::AbstractBeliefPropagationCache, edge::QuotientEdge
)
return set_message(bpc, edge, updated_message(message_update_alg, bpc, edge))
end
Expand Down Expand Up @@ -318,7 +345,7 @@ mts relevant to that group.
function update_iteration(
alg::Algorithm"bp",
bpc::AbstractBeliefPropagationCache,
edge_groups::Vector{<:Vector{<:PartitionEdge}};
edge_groups::Vector{<:Vector{<:QuotientEdge}};
(update_diff!) = nothing,
)
new_mts = empty(messages(bpc))
Expand Down Expand Up @@ -357,13 +384,13 @@ function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bp
end

function rescale_messages(
bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
bp_cache::AbstractBeliefPropagationCache, quotientedge::QuotientEdge
)
return rescale_messages(bp_cache, [partitionedge])
return rescale_messages(bp_cache, [quotientedge])
end

function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
return rescale_messages(bp_cache, partitionedges(bp_cache))
return rescale_messages(bp_cache, quotientedges(bp_cache))
end

function rescale_partitions(
Expand Down
Loading
Loading