From 45187eb70da9391b04a613647ab0f242a95f7073 Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 5 Nov 2025 17:28:03 +0100 Subject: [PATCH 1/8] Add OG comparison --- graphix/measurements.py | 2 +- graphix/opengraph.py | 93 +++++++++++++++++++++++++++++------------ tests/test_opengraph.py | 36 +++++++++++++++- 3 files changed, 102 insertions(+), 29 deletions(-) diff --git a/graphix/measurements.py b/graphix/measurements.py index 104796fa..f1f41c2d 100644 --- a/graphix/measurements.py +++ b/graphix/measurements.py @@ -44,7 +44,7 @@ class Measurement(AbstractPlanarMeasurement): Attributes ---------- - angle : Expressionor Float + angle : ExpressionOrFloat The angle of the measurement in units of :math:`\pi`. Should be between [0, 2). plane : graphix.fundamentals.Plane The measurement plane. diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 540aab3d..2b993033 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -83,32 +83,6 @@ def __post_init__(self) -> None: if len(outputs) != len(self.output_nodes): raise ValueError("Output nodes contain duplicates.") - # TODO: Up docstrings and generalise to any type - def isclose( - self: OpenGraph[Measurement], other: OpenGraph[Measurement], rel_tol: float = 1e-09, abs_tol: float = 0.0 - ) -> bool: - """Return `True` if two open graphs implement approximately the same unitary operator. - - Ensures the structure of the graphs are the same and all - measurement angles are sufficiently close. - - This doesn't check they are equal up to an isomorphism. - - """ - if not nx.utils.graphs_equal(self.graph, other.graph): - return False - - if self.input_nodes != other.input_nodes or self.output_nodes != other.output_nodes: - return False - - if set(self.measurements.keys()) != set(other.measurements.keys()): - return False - - return all( - m.isclose(other.measurements[node], rel_tol=rel_tol, abs_tol=abs_tol) - for node, m in self.measurements.items() - ) - def to_pattern(self: OpenGraph[Measurement]) -> Pattern: """Extract a deterministic pattern from an `OpenGraph[Measurement]` if it exists. @@ -140,6 +114,63 @@ def to_pattern(self: OpenGraph[Measurement]) -> Pattern: raise OpenGraphError("The open graph does not have flow. It does not support a deterministic pattern.") + def __eq__(self, other: object) -> bool: + """Check if two open graphs are equal. + + Parameters + ---------- + other : object + + Returns + ------- + bool + ``True`` if the two open graphs are equal. + + Notes + ----- + This method verifies the open graphs have: + - Truly equal underlying graphs (not up to an isomorphism). + - Equal input and output nodes. + - Same measurement planes or axis. It does not compare measurement angles (for that, see :func:`OpenGraph.isclose`). + """ + if isinstance(other, OpenGraph): + return _compare_opengraph_structure(self, other) and all( + m1.to_plane_or_axis() == m2.to_plane_or_axis() + for m1, m2 in zip(self.measurements.values(), other.measurements.values(), strict=False) + ) + + return False + + def isclose( + self: OpenGraph[Measurement], other: OpenGraph[Measurement], rel_tol: float = 1e-09, abs_tol: float = 0.0 + ) -> bool: + """Check if two open graphs of `Measurement` type are similar. + + Parameters + ---------- + other : OpenGraph[Measurement] + rel_tol : float + Relative tolerance. Optional, defaults to ``1e-09``. + abs_tol : float + Absolute tolerance. Optional, defaults to ``0.0``. + + Returns + ------- + bool + ``True`` if the two open graphs are approximately equal. + + Notes + ----- + This method verifies the open graphs have: + - Truly equal underlying graphs (not up to an isomorphism). + - Equal input and output nodes. + - Same measurement planes and approximately equal measurement angles. + """ + return _compare_opengraph_structure(self, other) and all( + m.isclose(other.measurements[node], rel_tol=rel_tol, abs_tol=abs_tol) + for node, m in self.measurements.items() + ) + def neighbors(self, nodes: Collection[int]) -> set[int]: """Return the set containing the neighborhood of a set of nodes in the open graph. @@ -405,5 +436,15 @@ def merge_ports(p1: Iterable[int], p2: Iterable[int]) -> list[int]: return OpenGraph(g, inputs, outputs, measurements), mapping_complete +def _compare_opengraph_structure(og_1: OpenGraph[_M_co], og_2: OpenGraph[_M_co]) -> bool: + if not nx.utils.graphs_equal(og_1.graph, og_2.graph): + return False + + if og_1.input_nodes != og_2.input_nodes or og_1.output_nodes != og_2.output_nodes: + return False + + return set(og_1.measurements.keys()) == set(og_2.measurements.keys()) + + class OpenGraphError(Exception): """Exception subclass to handle open graphs errors.""" diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index bb0955a5..cfea8575 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -13,7 +13,7 @@ import pytest from graphix.command import E -from graphix.fundamentals import Plane +from graphix.fundamentals import Axis, Plane from graphix.measurements import Measurement from graphix.opengraph import OpenGraph, OpenGraphError from graphix.pattern import Pattern @@ -633,8 +633,40 @@ def test_from_to_pattern(self, fx_rng: Generator) -> None: state = pattern.simulate_pattern(input_state=PlanarState(plane, alpha)) assert np.abs(np.dot(state.flatten().conjugate(), state_ref.flatten())) == pytest.approx(1) + def test_eq(self) -> None: + og_1 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Plane.XY), + ) + og_2 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Axis.X), + ) + assert og_1 == og_1 # noqa: PLR0124 + assert og_1 != og_2 + assert og_2 == og_2 # noqa: PLR0124 + + def test_isclose(self) -> None: + og_1 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Measurement(0.1, Plane.XY)), + ) + og_2 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Measurement(0.15, Plane.XY)), + ) + assert og_1.isclose(og_2, abs_tol=0.1) + assert not og_1.isclose(og_2) + -# TODO: Add test `OpenGraph.is_close` # TODO: rewrite as parametric tests # Tests composition of two graphs From 877be37fd7aa1e1edf8d9f9b93268841a225847f Mon Sep 17 00:00:00 2001 From: matulni Date: Tue, 25 Nov 2025 10:12:09 +0100 Subject: [PATCH 2/8] Add tests og compare --- graphix/opengraph.py | 14 ++++++++- tests/test_opengraph.py | 66 +++++++++++++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 2b993033..b4480456 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -136,7 +136,7 @@ def __eq__(self, other: object) -> bool: if isinstance(other, OpenGraph): return _compare_opengraph_structure(self, other) and all( m1.to_plane_or_axis() == m2.to_plane_or_axis() - for m1, m2 in zip(self.measurements.values(), other.measurements.values(), strict=False) + for m1, m2 in zip(self.measurements.values(), other.measurements.values(), strict=True) ) return False @@ -437,6 +437,18 @@ def merge_ports(p1: Iterable[int], p2: Iterable[int]) -> list[int]: def _compare_opengraph_structure(og_1: OpenGraph[_M_co], og_2: OpenGraph[_M_co]) -> bool: + """Compare the underlying structure of two open graphs. + + Parameters + ---------- + og_1 : OpenGraph[_M_co] + og_2 : OpenGraph[_M_co] + + Returns + ------- + bool + ``True`` if both open graphs have the same underlying structure. + """ if not nx.utils.graphs_equal(og_1.graph, og_2.graph): return False diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index cfea8575..6ddce5bb 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -15,13 +15,13 @@ from graphix.command import E from graphix.fundamentals import Axis, Plane from graphix.measurements import Measurement -from graphix.opengraph import OpenGraph, OpenGraphError +from graphix.opengraph import OpenGraph, OpenGraphError, _M_co from graphix.pattern import Pattern from graphix.random_objects import rand_circuit from graphix.states import PlanarState if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Sequence from numpy.random import Generator @@ -633,19 +633,55 @@ def test_from_to_pattern(self, fx_rng: Generator) -> None: state = pattern.simulate_pattern(input_state=PlanarState(plane, alpha)) assert np.abs(np.dot(state.flatten().conjugate(), state_ref.flatten())) == pytest.approx(1) - def test_eq(self) -> None: - og_1 = OpenGraph( - graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), - input_nodes=[0], - output_nodes=[3], - measurements=dict.fromkeys(range(3), Plane.XY), - ) - og_2 = OpenGraph( - graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), - input_nodes=[0], - output_nodes=[3], - measurements=dict.fromkeys(range(3), Axis.X), - ) + @pytest.mark.parametrize( + "test_case", + [ + ( + OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Plane.XY), + ), + OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Axis.X), + ), + ), + ( + OpenGraph( + graph=nx.Graph([(0, 1)]), + input_nodes=[0], + output_nodes=[], + measurements=dict.fromkeys(range(2), Plane.XY), + ), + OpenGraph( + graph=nx.Graph([(0, 1)]), + input_nodes=[0, 1], + output_nodes=[], + measurements=dict.fromkeys(range(2), Plane.XY), + ), + ), + ( + OpenGraph( + graph=nx.Graph([(0, 1)]), + input_nodes=[0], + output_nodes=[1], + measurements={0: Measurement(0.6, Plane.XY)}, + ), + OpenGraph( + graph=nx.Graph([(0, 2)]), + input_nodes=[0], + output_nodes=[2], + measurements={0: Measurement(0.6, Plane.XY)}, + ), + ), + ], + ) + def test_eq(self, test_case: Sequence[tuple[OpenGraph[_M_co], OpenGraph[_M_co]]]) -> None: + og_1, og_2 = test_case assert og_1 == og_1 # noqa: PLR0124 assert og_1 != og_2 assert og_2 == og_2 # noqa: PLR0124 From 00d07a71b121655014ac9c9f2c96d6b47f00c1e4 Mon Sep 17 00:00:00 2001 From: matulni Date: Tue, 25 Nov 2025 10:27:53 +0100 Subject: [PATCH 3/8] Add og compose --- graphix/opengraph.py | 39 ++-- tests/test_opengraph.py | 493 ++++++++++++++++++++++------------------ 2 files changed, 298 insertions(+), 234 deletions(-) diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 540aab3d..bb360aae 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -11,12 +11,12 @@ from graphix.flow._find_gpflow import AlgebraicOpenGraph, PlanarAlgebraicOpenGraph, compute_correction_matrix from graphix.flow.core import GFlow, PauliFlow from graphix.fundamentals import AbstractMeasurement, AbstractPlanarMeasurement +from graphix.measurements import Measurement if TYPE_CHECKING: from collections.abc import Collection, Iterable, Mapping, Sequence from graphix.flow.core import CausalFlow - from graphix.measurements import Measurement from graphix.pattern import Pattern # TODO: Maybe move these definitions to graphix.fundamentals and graphix.measurements ? Now they are redefined in graphix.flow._find_gpflow, not very elegant. @@ -329,25 +329,22 @@ def find_pauli_flow(self: OpenGraph[_M_co]) -> PauliFlow[_M_co] | None: correction_matrix ) # The constructor returns `None` if the correction matrix is not compatible with any partial order on the open graph. - # TODO: Generalise `compose` to any type of OpenGraph - def compose( - self: OpenGraph[Measurement], other: OpenGraph[Measurement], mapping: Mapping[int, int] - ) -> tuple[OpenGraph[Measurement], dict[int, int]]: - r"""Compose two open graphs by merging subsets of nodes from `self` and `other`, and relabeling the nodes of `other` that were not merged. + def compose(self, other: OpenGraph[_M_co], mapping: Mapping[int, int]) -> tuple[OpenGraph[_M_co], dict[int, int]]: + r"""Compose two open graphs by merging subsets of nodes from ``self`` and ``other``, and relabeling the nodes of ``other`` that were not merged. Parameters ---------- - other : OpenGraph - Open graph to be composed with `self`. + other : OpenGraph[_M_co] + Open graph to be composed with ``self``. mapping: dict[int, int] - Partial relabelling of the nodes in `other`, with `keys` and `values` denoting the old and new node labels, respectively. + Partial relabelling of the nodes in ``other``, with ``keys`` and ``values`` denoting the old and new node labels, respectively. Returns ------- - og: OpenGraph - composed open graph + og: OpenGraph[_M_co] + Composed open graph. mapping_complete: dict[int, int] - Complete relabelling of the nodes in `other`, with `keys` and `values` denoting the old and new node label, respectively. + Complete relabelling of the nodes in ``other``, with ``keys`` and ``values`` denoting the old and new node label, respectively. Notes ----- @@ -368,13 +365,19 @@ def compose( raise ValueError("Keys of mapping must be correspond to nodes of other.") if len(mapping) != len(set(mapping.values())): raise ValueError("Values in mapping contain duplicates.") + + def equal_measurements(vm: AbstractMeasurement, um: AbstractMeasurement) -> bool: + return vm.isclose(um) if isinstance(vm, Measurement) and isinstance(um, Measurement) else vm == um + for v, u in mapping.items(): - if ( - (vm := other.measurements.get(v)) is not None - and (um := self.measurements.get(u)) is not None - and not vm.isclose(um) - ): - raise ValueError(f"Attempted to merge nodes {v}:{u} but have different measurements") + vm = other.measurements.get(v) + um = self.measurements.get(u) + + if vm is None or um is None: + continue + + if not equal_measurements(vm, um): + raise OpenGraphError(f"Attempted to merge nodes with different measurements: {v, vm} -> {u, um}.") shift = max(*self.graph.nodes, *mapping.values()) + 1 diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index bb0955a5..2acab621 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -6,6 +6,7 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, NamedTuple import networkx as nx @@ -25,6 +26,8 @@ from numpy.random import Generator + from graphix.fundamentals import AbstractMeasurement + class OpenGraphFlowTestCase(NamedTuple): og: OpenGraph[Measurement] @@ -544,6 +547,245 @@ def _og_19() -> OpenGraphFlowTestCase: return OpenGraphFlowTestCase(og, has_cflow=False, has_gflow=False, has_pflow=True) +class OpenGraphComposeTestCase(NamedTuple): + og1: OpenGraph[AbstractMeasurement] + og2: OpenGraph[AbstractMeasurement] + og_ref: OpenGraph[AbstractMeasurement] + mapping: dict[int, int] + comparison_method: Callable[..., bool] = ( + OpenGraph.__eq__ + ) # Replace by `OpenGraph.isclose` if `OpenGraph` is of type `Measurement`. + + +# Parallel composition +def _compose_0() -> OpenGraphComposeTestCase: + """Generate composition test. + + Graph 1 + [1] -- (2) + + Graph 2 = Graph 1 + + Mapping: 1 -> 100, 2 -> 200 + + Expected graph + [1] -- (2) + + [100] -- (200) + """ + g: nx.Graph[int] = nx.Graph([(1, 2)]) + inputs = [1] + outputs = [2] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og1 = OpenGraph(g, inputs, outputs, meas) + og2 = OpenGraph(g, inputs, outputs, meas) + og_ref = OpenGraph( + nx.Graph([(1, 2), (100, 200)]), + input_nodes=[1, 100], + output_nodes=[2, 200], + measurements={1: Measurement(0, Plane.XY), 100: Measurement(0, Plane.XY)}, + ) + + mapping = {1: 100, 2: 200} + + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + + +# Series composition +def _compose_1() -> OpenGraphComposeTestCase: + """Generate composition test. + + Graph 1 + [0] -- 17 -- (23) + | + [3] -- 4 -- (13) + + Graph 2 + [6] -- 17 -- (1) + | | + [7] -- 4 -- (2) + + Mapping: 6 -> 23, 7 -> 13, 1 -> 100, 2 -> 200, 17 -> 90 + + Expected graph + [0] -- 17 -- 23 -- 90 -- (100) + | | | + [3] -- 4 -- 13 -- 201 -- (200) + """ + g: nx.Graph[int] = nx.Graph([(0, 17), (17, 23), (17, 4), (3, 4), (4, 13)]) + inputs = [0, 3] + outputs = [13, 23] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og1 = OpenGraph(g, inputs, outputs, meas) + + g = nx.Graph([(6, 7), (6, 17), (17, 1), (7, 4), (17, 4), (4, 2)]) + inputs = [6, 7] + outputs = [1, 2] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og2 = OpenGraph(g, inputs, outputs, meas) + + mapping = {6: 23, 7: 13, 1: 100, 2: 200, 17: 90} + + g = nx.Graph( + [(0, 17), (17, 23), (17, 4), (3, 4), (4, 13), (23, 13), (23, 90), (13, 201), (90, 201), (90, 100), (201, 200)] + ) + inputs = [0, 3] + outputs = [100, 200] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og_ref = OpenGraph(g, inputs, outputs, meas) + + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + + +# Full overlap +def _compose_2() -> OpenGraphComposeTestCase: + """Generate composition test. + + Graph 1 + [0] -- 17 -- (23) + | + [3] -- 4 -- (13) + + Graph 2 = Graph 1 + + Mapping: 0 -> 0, 3 -> 3, 17 -> 17, 4 -> 4, 23 -> 23, 13 -> 13 + + Expected graph = Graph 1 + """ + g: nx.Graph[int] + g = nx.Graph([(0, 17), (17, 23), (17, 4), (3, 4), (4, 13)]) + inputs = [0, 3] + outputs = [13, 23] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og1 = OpenGraph(g, inputs, outputs, meas) + og2 = OpenGraph(g, inputs, outputs, meas) + og_ref = OpenGraph(g, inputs, outputs, meas) + + mapping = {i: i for i in g.nodes} + + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + + +# Overlap inputs/outputs +def _compose_3() -> OpenGraphComposeTestCase: + """Generate composition test. + + Graph 1 + ([17]) -- (3) + | + [18] + + Graph 2 + [1] -- 2 -- (3) + + Mapping: 1 -> 17, 3 -> 300 + + Expected graph + (300) -- 301 -- [17] -- (3) + | + [18] + """ + g: nx.Graph[int] = nx.Graph([(18, 17), (17, 3)]) + inputs = [17, 18] + outputs = [3, 17] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og1 = OpenGraph(g, inputs, outputs, meas) + + g = nx.Graph([(1, 2), (2, 3)]) + inputs = [1] + outputs = [3] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og2 = OpenGraph(g, inputs, outputs, meas) + + mapping = {1: 17, 3: 300} + + g = nx.Graph([(18, 17), (17, 3), (17, 301), (301, 300)]) + inputs = [17, 18] # the input character of node 17 is kept because node 1 (in G2) is an input. + outputs = [3, 300] # the output character of node 17 is lost because node 1 (in G2) is not an output + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og_ref = OpenGraph(g, inputs, outputs, meas) + + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + + +# Inverse series composition +def _compose_4() -> OpenGraphComposeTestCase: + """Generate composition test. + + Graph 1 + [1] -- (2) + | + [3] + + Graph 2 + [3] -- (4) + + Mapping: 4 -> 1, 3 -> 300 + + Expected graph + [300] -- 1 -- (2) + | + [3] + """ + g: nx.Graph[int] = nx.Graph([(1, 2), (1, 3)]) + inputs = [1, 3] + outputs = [2] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og1 = OpenGraph(g, inputs, outputs, meas) + + g = nx.Graph([(3, 4)]) + inputs = [3] + outputs = [4] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og2 = OpenGraph(g, inputs, outputs, meas) + + mapping = {4: 1, 3: 300} + + g = nx.Graph([(1, 2), (1, 3), (1, 300)]) + inputs = [3, 300] + outputs = [2] + meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} + og_ref = OpenGraph(g, inputs, outputs, meas) + + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + + +def _compose_5() -> OpenGraphComposeTestCase: + """Generate composition test. + + Graph 1 + [1] -- (2) + + Graph 2 = Graph 1 + + Mapping: 1 -> 2 + + Expected graph + [1] -- 2 -- (3) + + """ + g: nx.Graph[int] = nx.Graph([(1, 2)]) + inputs = [1] + outputs = [2] + meas = dict.fromkeys(g.nodes - set(outputs), Plane.XY) + og1 = OpenGraph(g, inputs, outputs, meas) + og2 = OpenGraph(g, inputs, outputs, meas) + og_ref = OpenGraph( + nx.Graph([(1, 2), (2, 3)]), input_nodes=[1], output_nodes=[3], measurements={1: Plane.XY, 2: Plane.XY} + ) + + mapping = {1: 2} + + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) + + +def prepare_test_og_compose() -> list[OpenGraphComposeTestCase]: + n_og_samples = 6 + test_cases: list[OpenGraphComposeTestCase] = [globals()[f"_compose_{i}"]() for i in range(n_og_samples)] + + return test_cases + + def check_determinism(pattern: Pattern, fx_rng: Generator, n_shots: int = 3) -> bool: """Verify if the input pattern is deterministic.""" for plane in {Plane.XY, Plane.XZ, Plane.YZ}: @@ -633,219 +875,38 @@ def test_from_to_pattern(self, fx_rng: Generator) -> None: state = pattern.simulate_pattern(input_state=PlanarState(plane, alpha)) assert np.abs(np.dot(state.flatten().conjugate(), state_ref.flatten())) == pytest.approx(1) - -# TODO: Add test `OpenGraph.is_close` -# TODO: rewrite as parametric tests - -# Tests composition of two graphs - - -# Parallel composition -def test_compose_1() -> None: - # Graph 1 - # [1] -- (2) - # - # Graph 2 = Graph 1 - # - # Mapping: 1 -> 100, 2 -> 200 - # - # Expected graph - # [1] -- (2) - # - # [100] -- (200) - - g: nx.Graph[int] - g = nx.Graph([(1, 2)]) - inputs = [1] - outputs = [2] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_1 = OpenGraph(g, inputs, outputs, meas) - - mapping = {1: 100, 2: 200} - - og, mapping_complete = og_1.compose(og_1, mapping) - - expected_graph: nx.Graph[int] - expected_graph = nx.Graph([(1, 2), (100, 200)]) - assert nx.is_isomorphic(og.graph, expected_graph) - assert og.input_nodes == [1, 100] - assert og.output_nodes == [2, 200] - - outputs_c = {i for i in og.graph.nodes if i not in og.output_nodes} - assert og.measurements.keys() == outputs_c - assert mapping.keys() <= mapping_complete.keys() - assert set(mapping.values()) <= set(mapping_complete.values()) - - -# Series composition -def test_compose_2() -> None: - # Graph 1 - # [0] -- 17 -- (23) - # | - # [3] -- 4 -- (13) - # - # Graph 2 - # [6] -- 17 -- (1) - # | | - # [7] -- 4 -- (2) - # - # Mapping: 6 -> 23, 7 -> 13, 1 -> 100, 2 -> 200 - # - # Expected graph - # [0] -- 17 -- 23 -- o -- (100) - # | | | - # [3] -- 4 -- 13 -- o -- (200) - - g: nx.Graph[int] - g = nx.Graph([(0, 17), (17, 23), (17, 4), (3, 4), (4, 13)]) - inputs = [0, 3] - outputs = [13, 23] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_1 = OpenGraph(g, inputs, outputs, meas) - - g = nx.Graph([(6, 7), (6, 17), (17, 1), (7, 4), (17, 4), (4, 2)]) - inputs = [6, 7] - outputs = [1, 2] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_2 = OpenGraph(g, inputs, outputs, meas) - - mapping = {6: 23, 7: 13, 1: 100, 2: 200} - - og, mapping_complete = og_1.compose(og_2, mapping) - - expected_graph: nx.Graph[int] - expected_graph = nx.Graph( - [(0, 17), (17, 23), (17, 4), (3, 4), (4, 13), (23, 13), (23, 1), (13, 2), (1, 2), (1, 100), (2, 200)] - ) - assert nx.is_isomorphic(og.graph, expected_graph) - assert og.input_nodes == [0, 3] - assert og.output_nodes == [100, 200] - - outputs_c = {i for i in og.graph.nodes if i not in og.output_nodes} - assert og.measurements.keys() == outputs_c - assert mapping.keys() <= mapping_complete.keys() - assert set(mapping.values()) <= set(mapping_complete.values()) - - -# Full overlap -def test_compose_3() -> None: - # Graph 1 - # [0] -- 17 -- (23) - # | - # [3] -- 4 -- (13) - # - # Graph 2 = Graph 1 - # - # Mapping: 0 -> 0, 3 -> 3, 17 -> 17, 4 -> 4, 23 -> 23, 13 -> 13 - # - # Expected graph = Graph 1 - - g: nx.Graph[int] - g = nx.Graph([(0, 17), (17, 23), (17, 4), (3, 4), (4, 13)]) - inputs = [0, 3] - outputs = [13, 23] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_1 = OpenGraph(g, inputs, outputs, meas) - - mapping = {i: i for i in g.nodes} - - og, mapping_complete = og_1.compose(og_1, mapping) - - assert og.isclose(og_1) - assert mapping.keys() <= mapping_complete.keys() - assert set(mapping.values()) <= set(mapping_complete.values()) - - -# Overlap inputs/outputs -def test_compose_4() -> None: - # Graph 1 - # ([17]) -- (3) - # | - # [18] - # - # Graph 2 - # [1] -- 2 -- (3) - # - # Mapping: 1 -> 17, 3 -> 300 - # - # Expected graph - # (300) -- 2 -- [17] -- (3) - # | - # [18] - - g: nx.Graph[int] - g = nx.Graph([(18, 17), (17, 3)]) - inputs = [17, 18] - outputs = [3, 17] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_1 = OpenGraph(g, inputs, outputs, meas) - - g = nx.Graph([(1, 2), (2, 3)]) - inputs = [1] - outputs = [3] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_2 = OpenGraph(g, inputs, outputs, meas) - - mapping = {1: 17, 3: 300} - - og, mapping_complete = og_1.compose(og_2, mapping) - - expected_graph: nx.Graph[int] - expected_graph = nx.Graph([(18, 17), (17, 3), (17, 2), (2, 300)]) - assert nx.is_isomorphic(og.graph, expected_graph) - assert og.input_nodes == [17, 18] # the input character of node 17 is kept because node 1 (in G2) is an input - assert og.output_nodes == [ - 3, - 300, - ] # the output character of node 17 is lost because node 1 (in G2) is not an output - - outputs_c = {i for i in og.graph.nodes if i not in og.output_nodes} - assert og.measurements.keys() == outputs_c - assert mapping.keys() <= mapping_complete.keys() - assert set(mapping.values()) <= set(mapping_complete.values()) - - -# Inverse series composition -def test_compose_5() -> None: - # Graph 1 - # [1] -- (2) - # | - # [3] - # - # Graph 2 - # [3] -- (4) - # - # Mapping: 4 -> 1, 3 -> 300 - # - # Expected graph - # [300] -- 1 -- (2) - # | - # [3] - - g: nx.Graph[int] - g = nx.Graph([(1, 2), (1, 3)]) - inputs = [1, 3] - outputs = [2] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_1 = OpenGraph(g, inputs, outputs, meas) - - g = nx.Graph([(3, 4)]) - inputs = [3] - outputs = [4] - meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} - og_2 = OpenGraph(g, inputs, outputs, meas) - - mapping = {4: 1, 3: 300} - - og, mapping_complete = og_1.compose(og_2, mapping) - - expected_graph: nx.Graph[int] - expected_graph = nx.Graph([(1, 2), (1, 3), (1, 300)]) - assert nx.is_isomorphic(og.graph, expected_graph) - assert og.input_nodes == [3, 300] - assert og.output_nodes == [2] - - outputs_c = {i for i in og.graph.nodes if i not in og.output_nodes} - assert og.measurements.keys() == outputs_c - assert mapping.keys() <= mapping_complete.keys() - assert set(mapping.values()) <= set(mapping_complete.values()) + @pytest.mark.parametrize("test_case", prepare_test_og_compose()) + def test_compose(self, test_case: OpenGraphComposeTestCase) -> None: + og1, og2, og_ref, mapping, compare = test_case + + og, mapping_complete = og1.compose(og2, mapping) + + assert compare(og, og_ref) + assert mapping.keys() <= mapping_complete.keys() + assert set(mapping.values()) <= set(mapping_complete.values()) + + def test_compose_exception(self) -> None: + g: nx.Graph[int] = nx.Graph([(0, 1)]) + inputs = [0] + outputs = [1] + mapping = {0: 0, 1: 1} + + og1 = OpenGraph(g, inputs, outputs, measurements={0: Measurement(0, Plane.XY)}) + og2 = OpenGraph(g, inputs, outputs, measurements={0: Measurement(0.5, Plane.XY)}) + + with pytest.raises( + OpenGraphError, + match=re.escape( + "Attempted to merge nodes with different measurements: (0, Measurement(angle=0.5, plane=Plane.XY)) -> (0, Measurement(angle=0, plane=Plane.XY))." + ), + ): + og1.compose(og2, mapping) + + og3 = OpenGraph(g, inputs, outputs, measurements={0: Plane.XY}) + og4 = OpenGraph(g, inputs, outputs, measurements={0: Plane.XZ}) + + with pytest.raises( + OpenGraphError, + match=re.escape("Attempted to merge nodes with different measurements: (0, Plane.XZ) -> (0, Plane.XY)."), + ): + og3.compose(og4, mapping) From bd55a4e9ff26fc85f4980854c43b077a9d679cba Mon Sep 17 00:00:00 2001 From: matulni Date: Tue, 25 Nov 2025 10:43:43 +0100 Subject: [PATCH 4/8] Remove globals call from test_opengraph --- tests/test_opengraph.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index f151705e..8f905368 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -37,6 +37,7 @@ class OpenGraphFlowTestCase(NamedTuple): OPEN_GRAPH_FLOW_TEST_CASES: list[OpenGraphFlowTestCase] = [] +OPEN_GRAPH_COMPOSE_TEST_CASES: list[OpenGraphComposeTestCase] = [] def register_open_graph_flow_test_case( @@ -46,6 +47,13 @@ def register_open_graph_flow_test_case( return test_case +def register_open_graph_compose_test_case( + test_case: Callable[[], OpenGraphComposeTestCase], +) -> Callable[[], OpenGraphComposeTestCase]: + OPEN_GRAPH_COMPOSE_TEST_CASES.append(test_case()) + return test_case + + @register_open_graph_flow_test_case def _og_0() -> OpenGraphFlowTestCase: """Generate open graph. @@ -558,6 +566,7 @@ class OpenGraphComposeTestCase(NamedTuple): # Parallel composition +@register_open_graph_compose_test_case def _compose_0() -> OpenGraphComposeTestCase: """Generate composition test. @@ -592,6 +601,7 @@ def _compose_0() -> OpenGraphComposeTestCase: # Series composition +@register_open_graph_compose_test_case def _compose_1() -> OpenGraphComposeTestCase: """Generate composition test. @@ -638,6 +648,7 @@ def _compose_1() -> OpenGraphComposeTestCase: # Full overlap +@register_open_graph_compose_test_case def _compose_2() -> OpenGraphComposeTestCase: """Generate composition test. @@ -667,6 +678,7 @@ def _compose_2() -> OpenGraphComposeTestCase: # Overlap inputs/outputs +@register_open_graph_compose_test_case def _compose_3() -> OpenGraphComposeTestCase: """Generate composition test. @@ -709,6 +721,7 @@ def _compose_3() -> OpenGraphComposeTestCase: # Inverse series composition +@register_open_graph_compose_test_case def _compose_4() -> OpenGraphComposeTestCase: """Generate composition test. @@ -750,6 +763,7 @@ def _compose_4() -> OpenGraphComposeTestCase: return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) +@register_open_graph_compose_test_case def _compose_5() -> OpenGraphComposeTestCase: """Generate composition test. @@ -779,13 +793,6 @@ def _compose_5() -> OpenGraphComposeTestCase: return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) -def prepare_test_og_compose() -> list[OpenGraphComposeTestCase]: - n_og_samples = 6 - test_cases: list[OpenGraphComposeTestCase] = [globals()[f"_compose_{i}"]() for i in range(n_og_samples)] - - return test_cases - - def check_determinism(pattern: Pattern, fx_rng: Generator, n_shots: int = 3) -> bool: """Verify if the input pattern is deterministic.""" for plane in {Plane.XY, Plane.XZ, Plane.YZ}: @@ -875,13 +882,6 @@ def test_from_to_pattern(self, fx_rng: Generator) -> None: state = pattern.simulate_pattern(input_state=PlanarState(plane, alpha)) assert np.abs(np.dot(state.flatten().conjugate(), state_ref.flatten())) == pytest.approx(1) -<<<<<<< HEAD - @pytest.mark.parametrize("test_case", prepare_test_og_compose()) - def test_compose(self, test_case: OpenGraphComposeTestCase) -> None: - og1, og2, og_ref, mapping, compare = test_case - - og, mapping_complete = og1.compose(og2, mapping) -======= @pytest.mark.parametrize( "test_case", [ @@ -951,10 +951,10 @@ def test_isclose(self) -> None: assert og_1.isclose(og_2, abs_tol=0.1) assert not og_1.isclose(og_2) - -# TODO: rewrite as parametric tests ->>>>>>> rf_og_compare - + @pytest.mark.parametrize("test_case", OPEN_GRAPH_COMPOSE_TEST_CASES) + def test_compose(self, test_case: OpenGraphComposeTestCase) -> None: + og1, og2, og_ref, mapping, compare = test_case + og, mapping_complete = og1.compose(og2, mapping) assert compare(og, og_ref) assert mapping.keys() <= mapping_complete.keys() assert set(mapping.values()) <= set(mapping_complete.values()) From b3c30042771bd6b28396f0e03337dac6a8ef68da Mon Sep 17 00:00:00 2001 From: matulni Date: Wed, 26 Nov 2025 17:39:43 +0100 Subject: [PATCH 5/8] Add comments Thierry's review --- graphix/opengraph.py | 55 ++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/graphix/opengraph.py b/graphix/opengraph.py index b4480456..32a444ce 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -134,7 +134,7 @@ def __eq__(self, other: object) -> bool: - Same measurement planes or axis. It does not compare measurement angles (for that, see :func:`OpenGraph.isclose`). """ if isinstance(other, OpenGraph): - return _compare_opengraph_structure(self, other) and all( + return self.is_equal_structurally(other) and all( m1.to_plane_or_axis() == m2.to_plane_or_axis() for m1, m2 in zip(self.measurements.values(), other.measurements.values(), strict=True) ) @@ -144,7 +144,7 @@ def __eq__(self, other: object) -> bool: def isclose( self: OpenGraph[Measurement], other: OpenGraph[Measurement], rel_tol: float = 1e-09, abs_tol: float = 0.0 ) -> bool: - """Check if two open graphs of `Measurement` type are similar. + """Check if two open graphs of `Measurement` type are equal within a given tolerance. Parameters ---------- @@ -166,11 +166,38 @@ def isclose( - Equal input and output nodes. - Same measurement planes and approximately equal measurement angles. """ - return _compare_opengraph_structure(self, other) and all( + return self.is_equal_structurally(other) and all( m.isclose(other.measurements[node], rel_tol=rel_tol, abs_tol=abs_tol) for node, m in self.measurements.items() ) + def is_equal_structurally(self, other: OpenGraph[_M_co]) -> bool: + """Compare the underlying structure of two open graphs. + + Parameters + ---------- + other : OpenGraph[_M_co] + + Returns + ------- + bool + ``True`` if ``self`` and ``og`` have the same structure. + + Notes + ----- + This method verifies the open graphs have: + - Truly equal underlying graphs (not up to an isomorphism). + - Equal input and output nodes. + """ + if ( + not nx.utils.graphs_equal(self.graph, other.graph) + or self.input_nodes != other.input_nodes + or other.output_nodes != other.output_nodes + ): + return False + + return set(self.measurements.keys()) == set(other.measurements.keys()) + def neighbors(self, nodes: Collection[int]) -> set[int]: """Return the set containing the neighborhood of a set of nodes in the open graph. @@ -436,27 +463,5 @@ def merge_ports(p1: Iterable[int], p2: Iterable[int]) -> list[int]: return OpenGraph(g, inputs, outputs, measurements), mapping_complete -def _compare_opengraph_structure(og_1: OpenGraph[_M_co], og_2: OpenGraph[_M_co]) -> bool: - """Compare the underlying structure of two open graphs. - - Parameters - ---------- - og_1 : OpenGraph[_M_co] - og_2 : OpenGraph[_M_co] - - Returns - ------- - bool - ``True`` if both open graphs have the same underlying structure. - """ - if not nx.utils.graphs_equal(og_1.graph, og_2.graph): - return False - - if og_1.input_nodes != og_2.input_nodes or og_1.output_nodes != og_2.output_nodes: - return False - - return set(og_1.measurements.keys()) == set(og_2.measurements.keys()) - - class OpenGraphError(Exception): """Exception subclass to handle open graphs errors.""" From ecbbacbfa746929bd855e0f52ffb6aead52972a3 Mon Sep 17 00:00:00 2001 From: matulni Date: Thu, 27 Nov 2025 10:21:48 +0100 Subject: [PATCH 6/8] Up meas is close and og comparison --- graphix/fundamentals.py | 23 +++++++ graphix/measurements.py | 43 +++++++++--- graphix/opengraph.py | 45 +++--------- tests/test_fundamentals.py | 15 ++++ tests/test_measurements.py | 14 ++++ tests/test_opengraph.py | 137 ++++++++++++++++++++++--------------- 6 files changed, 177 insertions(+), 100 deletions(-) create mode 100644 tests/test_measurements.py diff --git a/graphix/fundamentals.py b/graphix/fundamentals.py index 08e92fdd..4d4f2ca7 100644 --- a/graphix/fundamentals.py +++ b/graphix/fundamentals.py @@ -235,6 +235,29 @@ def to_plane_or_axis(self) -> Plane | Axis: Plane | Axis """ + @abstractmethod + def isclose(self, other: AbstractMeasurement, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this measurement is close to another. + + Subclasses should implement a notion of “closeness” between two measurements, comparing measurement-specific attributes. The default comparison for ``float`` values involves checking equality within given relative or absolute tolerances. + + Parameters + ---------- + other : AbstractMeasurement + The measurement to compare against. + rel_tol : float, optional + Relative tolerance for determining closeness. Relevant for comparing angles in the `Measurement` subclass. Default is ``1e-9``. + abs_tol : float, optional + Absolute tolerance for determining closeness. Relevant for comparing angles in the `Measurement` subclass. Default is ``0.0``. + + Returns + ------- + bool + ``True`` if this measurement is considered close to ``other`` according + to the subclass's comparison rules; ``False`` otherwise. + """ + return self == other + class AbstractPlanarMeasurement(AbstractMeasurement): """Abstract base class for planar measurement objects. diff --git a/graphix/measurements.py b/graphix/measurements.py index f1f41c2d..f6a7bd56 100644 --- a/graphix/measurements.py +++ b/graphix/measurements.py @@ -11,8 +11,11 @@ TypeAlias, ) +# override introduced in Python 3.12 +from typing_extensions import override + from graphix import utils -from graphix.fundamentals import AbstractPlanarMeasurement, Axis, Plane, Sign +from graphix.fundamentals import AbstractMeasurement, AbstractPlanarMeasurement, Axis, Plane, Sign # Ruff suggests to move this import to a type-checking block, but dataclass requires it here from graphix.parameter import ExpressionOrFloat # noqa: TC001 @@ -53,11 +56,31 @@ class Measurement(AbstractPlanarMeasurement): angle: ExpressionOrFloat plane: Plane - def isclose(self, other: Measurement, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: - """Compare if two measurements have the same plane and their angles are close. + @override + def isclose(self, other: AbstractMeasurement, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether two measurements are close in angle and share the same plane. + + This method compares the angle of the current measurement with that of + another measurement, using :func:`math.isclose` when both angles are floats. + The planes must match exactly for the measurements to be considered close. + + Parameters + ---------- + other : AbstractMeasurement + The measurement to compare against. + rel_tol : float, optional + Relative tolerance for comparing angles, passed to :func:`math.isclose`. Default is ``1e-9``. + abs_tol : float, optional + Absolute tolerance for comparing angles, passed to :func:`math.isclose`. Default is ``0.0``. - Example + Returns ------- + bool + ``True`` if both measurements lie in the same plane and their angles + are equal or close within the given tolerances; ``False`` otherwise. + + Examples + -------- >>> from graphix.measurements import Measurement >>> from graphix.fundamentals import Plane >>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.XY)) @@ -67,11 +90,13 @@ def isclose(self, other: Measurement, rel_tol: float = 1e-09, abs_tol: float = 0 >>> Measurement(0.1, Plane.XY).isclose(Measurement(0.0, Plane.XY)) False """ - return ( - math.isclose(self.angle, other.angle, rel_tol=rel_tol, abs_tol=abs_tol) - if isinstance(self.angle, float) and isinstance(other.angle, float) - else self.angle == other.angle - ) and self.plane == other.plane + if isinstance(other, Measurement): + return ( + math.isclose(self.angle, other.angle, rel_tol=rel_tol, abs_tol=abs_tol) + if isinstance(self.angle, float) and isinstance(other.angle, float) + else self.angle == other.angle + ) and self.plane == other.plane + return False def to_plane_or_axis(self) -> Plane | Axis: """Return the measurements's plane or axis. diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 32a444ce..24346e85 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -114,41 +114,12 @@ def to_pattern(self: OpenGraph[Measurement]) -> Pattern: raise OpenGraphError("The open graph does not have flow. It does not support a deterministic pattern.") - def __eq__(self, other: object) -> bool: - """Check if two open graphs are equal. + def isclose(self, other: OpenGraph[_M_co], rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Check if two open graphs are equal within a given tolerance. Parameters ---------- - other : object - - Returns - ------- - bool - ``True`` if the two open graphs are equal. - - Notes - ----- - This method verifies the open graphs have: - - Truly equal underlying graphs (not up to an isomorphism). - - Equal input and output nodes. - - Same measurement planes or axis. It does not compare measurement angles (for that, see :func:`OpenGraph.isclose`). - """ - if isinstance(other, OpenGraph): - return self.is_equal_structurally(other) and all( - m1.to_plane_or_axis() == m2.to_plane_or_axis() - for m1, m2 in zip(self.measurements.values(), other.measurements.values(), strict=True) - ) - - return False - - def isclose( - self: OpenGraph[Measurement], other: OpenGraph[Measurement], rel_tol: float = 1e-09, abs_tol: float = 0.0 - ) -> bool: - """Check if two open graphs of `Measurement` type are equal within a given tolerance. - - Parameters - ---------- - other : OpenGraph[Measurement] + other : OpenGraph[_M_co] rel_tol : float Relative tolerance. Optional, defaults to ``1e-09``. abs_tol : float @@ -164,19 +135,21 @@ def isclose( This method verifies the open graphs have: - Truly equal underlying graphs (not up to an isomorphism). - Equal input and output nodes. - - Same measurement planes and approximately equal measurement angles. + - Same measurement planes or axes and approximately equal measurement angles if the open graph is of parametric type `Measurement`. + + The static typer does not allow comparing the structure of two open graphs with different parametric type. """ return self.is_equal_structurally(other) and all( m.isclose(other.measurements[node], rel_tol=rel_tol, abs_tol=abs_tol) for node, m in self.measurements.items() ) - def is_equal_structurally(self, other: OpenGraph[_M_co]) -> bool: + def is_equal_structurally(self, other: OpenGraph[AbstractMeasurement]) -> bool: """Compare the underlying structure of two open graphs. Parameters ---------- - other : OpenGraph[_M_co] + other : OpenGraph[AbstractMeasurement] Returns ------- @@ -188,6 +161,8 @@ def is_equal_structurally(self, other: OpenGraph[_M_co]) -> bool: This method verifies the open graphs have: - Truly equal underlying graphs (not up to an isomorphism). - Equal input and output nodes. + + The static typer allows comparing the structure of two open graphs with different parametric type. """ if ( not nx.utils.graphs_equal(self.graph, other.graph) diff --git a/tests/test_fundamentals.py b/tests/test_fundamentals.py index a127a83e..74a81d4d 100644 --- a/tests/test_fundamentals.py +++ b/tests/test_fundamentals.py @@ -161,3 +161,18 @@ def test_from_axes_ng(self) -> None: Plane.from_axes(Axis.Y, Axis.Y) with pytest.raises(ValueError): Plane.from_axes(Axis.Z, Axis.Z) + + def test_isclose(self) -> None: + for p1, p2 in itertools.combinations(Plane, 2): + assert not p1.isclose(p2) + + for a1, a2 in itertools.combinations(Axis, 2): + assert not a1.isclose(a2) + + for p in Plane: + assert p.isclose(p) + for a in Axis: + assert not p.isclose(a) + + for a in Axis: + assert a.isclose(a) diff --git a/tests/test_measurements.py b/tests/test_measurements.py new file mode 100644 index 00000000..a4386b88 --- /dev/null +++ b/tests/test_measurements.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from graphix.fundamentals import Plane +from graphix.measurements import Measurement + + +class TestMeasurement: + def test_isclose(self) -> None: + m1 = Measurement(0.1, Plane.XY) + m2 = Measurement(0.15, Plane.XY) + + assert not m1.isclose(m2) + assert not m1.isclose(Plane.XY) + assert m1.isclose(m2, abs_tol=0.1) diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index 6ddce5bb..bf2eb58c 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -15,13 +15,13 @@ from graphix.command import E from graphix.fundamentals import Axis, Plane from graphix.measurements import Measurement -from graphix.opengraph import OpenGraph, OpenGraphError, _M_co +from graphix.opengraph import OpenGraph, OpenGraphError from graphix.pattern import Pattern from graphix.random_objects import rand_circuit from graphix.states import PlanarState if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable from numpy.random import Generator @@ -633,60 +633,7 @@ def test_from_to_pattern(self, fx_rng: Generator) -> None: state = pattern.simulate_pattern(input_state=PlanarState(plane, alpha)) assert np.abs(np.dot(state.flatten().conjugate(), state_ref.flatten())) == pytest.approx(1) - @pytest.mark.parametrize( - "test_case", - [ - ( - OpenGraph( - graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), - input_nodes=[0], - output_nodes=[3], - measurements=dict.fromkeys(range(3), Plane.XY), - ), - OpenGraph( - graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), - input_nodes=[0], - output_nodes=[3], - measurements=dict.fromkeys(range(3), Axis.X), - ), - ), - ( - OpenGraph( - graph=nx.Graph([(0, 1)]), - input_nodes=[0], - output_nodes=[], - measurements=dict.fromkeys(range(2), Plane.XY), - ), - OpenGraph( - graph=nx.Graph([(0, 1)]), - input_nodes=[0, 1], - output_nodes=[], - measurements=dict.fromkeys(range(2), Plane.XY), - ), - ), - ( - OpenGraph( - graph=nx.Graph([(0, 1)]), - input_nodes=[0], - output_nodes=[1], - measurements={0: Measurement(0.6, Plane.XY)}, - ), - OpenGraph( - graph=nx.Graph([(0, 2)]), - input_nodes=[0], - output_nodes=[2], - measurements={0: Measurement(0.6, Plane.XY)}, - ), - ), - ], - ) - def test_eq(self, test_case: Sequence[tuple[OpenGraph[_M_co], OpenGraph[_M_co]]]) -> None: - og_1, og_2 = test_case - assert og_1 == og_1 # noqa: PLR0124 - assert og_1 != og_2 - assert og_2 == og_2 # noqa: PLR0124 - - def test_isclose(self) -> None: + def test_isclose_measurement(self) -> None: og_1 = OpenGraph( graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), input_nodes=[0], @@ -699,8 +646,86 @@ def test_isclose(self) -> None: output_nodes=[3], measurements=dict.fromkeys(range(3), Measurement(0.15, Plane.XY)), ) + og_3 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3), (0, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Measurement(0.15, Plane.XY)), + ) assert og_1.isclose(og_2, abs_tol=0.1) assert not og_1.isclose(og_2) + assert not og_2.isclose(og_3) + + def test_isclose_plane(self) -> None: + og_1 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Plane.XY), + ) + og_2 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Plane.XZ), + ) + + assert not og_1.isclose(og_2) + assert og_1.isclose(og_1) + + def test_isclose_axis(self) -> None: + og_1 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Axis.X), + ) + og_2 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Axis.Y), + ) + + assert not og_1.isclose(og_2) + assert og_1.isclose(og_1) + assert og_2.isclose(og_2) + + def test_is_equal_structurally(self) -> None: + og_1 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Measurement(0.15, Plane.XY)), + ) + og_2 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Measurement(0.1, Plane.XY)), + ) + og_3 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Plane.XY), + ) + og_4 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Axis.X), + ) + og_5 = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3), (0, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Axis.X), + ) + assert og_1.is_equal_structurally(og_2) + assert og_1.is_equal_structurally(og_3) + assert og_1.is_equal_structurally(og_4) + assert not og_1.is_equal_structurally(og_5) # TODO: rewrite as parametric tests From 5ca8698a7569162ef1a0c17cf08a0ec504bf5f2e Mon Sep 17 00:00:00 2001 From: matulni Date: Thu, 27 Nov 2025 10:24:50 +0100 Subject: [PATCH 7/8] Fix docstring --- graphix/opengraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 24346e85..39f5cdf5 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -137,7 +137,7 @@ def isclose(self, other: OpenGraph[_M_co], rel_tol: float = 1e-09, abs_tol: floa - Equal input and output nodes. - Same measurement planes or axes and approximately equal measurement angles if the open graph is of parametric type `Measurement`. - The static typer does not allow comparing the structure of two open graphs with different parametric type. + The static typer does not allow an ``isclose`` comparison of two open graphs with different parametric type. For a structural comparison, see :func:`OpenGraph.is_equal_structurally`. """ return self.is_equal_structurally(other) and all( m.isclose(other.measurements[node], rel_tol=rel_tol, abs_tol=abs_tol) From 4e4260e24bf2714745bfcfc81fecf6ea6fc929d7 Mon Sep 17 00:00:00 2001 From: matulni Date: Thu, 27 Nov 2025 10:36:41 +0100 Subject: [PATCH 8/8] Up compose with measurement isclose --- graphix/opengraph.py | 17 ++++++----------- tests/test_opengraph.py | 17 +++++++---------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 2dc8b3b2..bc1bf873 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -11,12 +11,12 @@ from graphix.flow._find_gpflow import AlgebraicOpenGraph, PlanarAlgebraicOpenGraph, compute_correction_matrix from graphix.flow.core import GFlow, PauliFlow from graphix.fundamentals import AbstractMeasurement, AbstractPlanarMeasurement -from graphix.measurements import Measurement if TYPE_CHECKING: from collections.abc import Collection, Iterable, Mapping, Sequence from graphix.flow.core import CausalFlow + from graphix.measurements import Measurement from graphix.pattern import Pattern # TODO: Maybe move these definitions to graphix.fundamentals and graphix.measurements ? Now they are redefined in graphix.flow._find_gpflow, not very elegant. @@ -399,17 +399,12 @@ def compose(self, other: OpenGraph[_M_co], mapping: Mapping[int, int]) -> tuple[ if len(mapping) != len(set(mapping.values())): raise ValueError("Values in mapping contain duplicates.") - def equal_measurements(vm: AbstractMeasurement, um: AbstractMeasurement) -> bool: - return vm.isclose(um) if isinstance(vm, Measurement) and isinstance(um, Measurement) else vm == um - for v, u in mapping.items(): - vm = other.measurements.get(v) - um = self.measurements.get(u) - - if vm is None or um is None: - continue - - if not equal_measurements(vm, um): + if ( + (vm := other.measurements.get(v)) is not None + and (um := self.measurements.get(u)) is not None + and not vm.isclose(um) + ): raise OpenGraphError(f"Attempted to merge nodes with different measurements: {v, vm} -> {u, um}.") shift = max(*self.graph.nodes, *mapping.values()) + 1 diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index 0bd4b0a6..f86eced4 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -560,9 +560,6 @@ class OpenGraphComposeTestCase(NamedTuple): og2: OpenGraph[AbstractMeasurement] og_ref: OpenGraph[AbstractMeasurement] mapping: dict[int, int] - comparison_method: Callable[..., bool] = ( - OpenGraph.__eq__ - ) # Replace by `OpenGraph.isclose` if `OpenGraph` is of type `Measurement`. # Parallel composition @@ -597,7 +594,7 @@ def _compose_0() -> OpenGraphComposeTestCase: mapping = {1: 100, 2: 200} - return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) # Series composition @@ -644,7 +641,7 @@ def _compose_1() -> OpenGraphComposeTestCase: meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} og_ref = OpenGraph(g, inputs, outputs, meas) - return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) # Full overlap @@ -674,7 +671,7 @@ def _compose_2() -> OpenGraphComposeTestCase: mapping = {i: i for i in g.nodes} - return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) # Overlap inputs/outputs @@ -717,7 +714,7 @@ def _compose_3() -> OpenGraphComposeTestCase: meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} og_ref = OpenGraph(g, inputs, outputs, meas) - return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) # Inverse series composition @@ -760,7 +757,7 @@ def _compose_4() -> OpenGraphComposeTestCase: meas = {i: Measurement(0, Plane.XY) for i in g.nodes - set(outputs)} og_ref = OpenGraph(g, inputs, outputs, meas) - return OpenGraphComposeTestCase(og1, og2, og_ref, mapping, OpenGraph.isclose) + return OpenGraphComposeTestCase(og1, og2, og_ref, mapping) @register_open_graph_compose_test_case @@ -978,9 +975,9 @@ def test_is_equal_structurally(self) -> None: @pytest.mark.parametrize("test_case", OPEN_GRAPH_COMPOSE_TEST_CASES) def test_compose(self, test_case: OpenGraphComposeTestCase) -> None: - og1, og2, og_ref, mapping, compare = test_case + og1, og2, og_ref, mapping = test_case og, mapping_complete = og1.compose(og2, mapping) - assert compare(og, og_ref) + assert og.isclose(og_ref) assert mapping.keys() <= mapping_complete.keys() assert set(mapping.values()) <= set(mapping_complete.values())