From 466b884f60b7413d6dcebd6ab7f56131dc4ee9cb Mon Sep 17 00:00:00 2001 From: georgios-ts <45130028+georgios-ts@users.noreply.github.com> Date: Wed, 8 Sep 2021 19:02:01 +0300 Subject: [PATCH] Fixes `digraph_union` if `merge_edges` is set to true. (#439) * Fixes `digraph_union` if `merge_edges` is set to true. Previously, `digraph_union` would falsely keep or delete edges if `merge_edges` is set to true. This commit fixes the logic of `digraph_union` to skip an edge from the second graph if both its endpoints were merged to nodes from the first graph and these nodes already share an edge with equal weight data. At the same time, a new function `graph_union` was added that returns the union of two `PyGraph`s. Closes #432. * increase test cov * improve release note and message in panic exception Co-authored-by: Matthew Treinish * add dispatch function `union` and implement `find_node_by_weight` for `PyGraph` * lint * Release note fixes Co-authored-by: Matthew Treinish --- docs/source/api.rst | 4 +- .../notes/bugfix-union-7da79789134a3028.yaml | 33 +++ retworkx/__init__.py | 61 +++++ src/digraph.rs | 22 +- src/graph.rs | 22 +- src/lib.rs | 24 ++ src/union.rs | 259 +++++++++++------- tests/digraph/test_union.py | 61 +++-- tests/graph/test_union.py | 88 ++++++ 9 files changed, 434 insertions(+), 140 deletions(-) create mode 100644 releasenotes/notes/bugfix-union-7da79789134a3028.yaml create mode 100644 tests/graph/test_union.py diff --git a/docs/source/api.rst b/docs/source/api.rst index c535383f3..d217e075b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -145,7 +145,7 @@ Other Algorithm Functions retworkx.transitivity retworkx.core_number retworkx.graph_greedy_color - retworkx.digraph_union + retworkx.union retworkx.metric_closure Generators @@ -242,6 +242,7 @@ the functions from the explicitly typed based on the data type. retworkx.digraph_transitivity retworkx.digraph_core_number retworkx.digraph_complement + retworkx.digraph_union retworkx.digraph_random_layout retworkx.digraph_bipartite_layout retworkx.digraph_circular_layout @@ -283,6 +284,7 @@ typed API based on the data type. retworkx.graph_transitivity retworkx.graph_core_number retworkx.graph_complement + retworkx.graph_union retworkx.graph_random_layout retworkx.graph_bipartite_layout retworkx.graph_circular_layout diff --git a/releasenotes/notes/bugfix-union-7da79789134a3028.yaml b/releasenotes/notes/bugfix-union-7da79789134a3028.yaml new file mode 100644 index 000000000..ae1da91e0 --- /dev/null +++ b/releasenotes/notes/bugfix-union-7da79789134a3028.yaml @@ -0,0 +1,33 @@ +--- +fixes: + - | + Previously, :func:`~retworkx.digraph_union` would incorrectly keep or delete edges + if argument ``merge_edges`` is set to true. This has been fixed and an edge from + the second graph will be skipped if both its endpoints were merged to nodes from + the first graph and these nodes already share an edge with equal weight data. + Fixed `#432 `__ +features: + - | + Add a new function :func:`~retworkx.graph_union` that returns the union + of two :class:`~retworkx.PyGraph` objects. This is the equivalent to + :func:`~retworkx.digraph_union` but for a :class:`~retworkx.PyGraph` + instead of for a :class:`~retworkx.PyDiGraph`. A new unified function + :func:`~retworkx.union` was also added that supports both + :class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph`. + For example: + + .. jupyter-execute:: + + import retworkx + from retworkx.visualization import mpl_draw + + first = retworkx.generators.path_graph(3, weights=["a_0", "node", "a_1"]) + second = retworkx.generators.cycle_graph(3, weights=["node", "b_0", "b_1"]) + graph = retworkx.graph_union(first, second, merge_nodes=True) + mpl_draw(graph) + - | + The kwargs ``merge_nodes`` and ``merge_edges`` of :func:`~retworkx.digraph_union` are + now optional and by default are set `False`. + - | + Add a new :meth:`~retworkx.PyGraph.find_node_by_weight` that finds the index + of a node given a specific weight. diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 7b2691ce6..a1b9f2f49 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -1633,3 +1633,64 @@ def _graph_vf2_mapping( induced=induced, call_limit=call_limit, ) + + +@functools.singledispatch +def union( + first, + second, + merge_nodes=False, + merge_edges=False, +): + """Return a new graph by forming a union from two input graph objects + + The algorithm in this function operates in three phases: + + 1. Add all the nodes from ``second`` into ``first``. operates in + :math:`\\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in + ``second``. + 2. Merge nodes from ``second`` over ``first`` given that: + + - The ``merge_nodes`` is ``True``. operates in :math:`\\mathcal{O}(n_1 n_2)`, + with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2` + the number of nodes in ``second`` + - The respective node in ``second`` and ``first`` share the same + weight/data payload. + + 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` + parameter is ``True`` and the respective edge in ``second`` and + ``first`` share the same weight/data payload they will be merged together. + + :param first: The first graph object + :param second: The second graph object + :param bool merge_nodes: If set to ``True`` nodes will be merged between + ``second`` and ``first`` if the weights are equal. Default: ``False``. + :param bool merge_edges: If set to ``True`` edges will be merged between + ``second`` and ``first`` if the weights are equal. Default: ``False``. + + :returns: A new graph object that is the union of ``second`` and + ``first``. It's worth noting the weight/data payload objects are + passed by reference from ``first`` and ``second`` to this new object. + :rtype: :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + """ + raise TypeError("Invalid Input Type %s for graph" % type(first)) + + +@union.register(PyDiGraph) +def _digraph_union( + first, + second, + merge_nodes=False, + merge_edges=False, +): + return digraph_union(first, second, merge_nodes=False, merge_edges=False) + + +@union.register(PyGraph) +def _graph_union( + first, + second, + merge_nodes=False, + merge_edges=False, +): + return graph_union(first, second, merge_nodes=False, merge_edges=False) diff --git a/src/digraph.rs b/src/digraph.rs index a6fa898f0..fa44f8625 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -50,8 +50,8 @@ use super::iterators::{ EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList, }; use super::{ - DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, NoSuitableNeighbors, - NodesRemoved, + find_node_by_weight, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, + NoSuitableNeighbors, NodesRemoved, }; use super::dag_algo::is_directed_acyclic_graph; @@ -1401,21 +1401,9 @@ impl PyDiGraph { &self, py: Python, obj: PyObject, - ) -> Option { - let mut index = None; - for node in self.graph.node_indices() { - let weight = self.graph.node_weight(node).unwrap(); - let weight_compare = |a: &PyAny, b: &PyAny| -> PyResult { - let res = a.compare(b)?; - Ok(res == Ordering::Equal) - }; - - if weight_compare(obj.as_ref(py), weight.as_ref(py)).unwrap() { - index = Some(node.index()); - break; - } - } - index + ) -> PyResult> { + find_node_by_weight(py, &self.graph, &obj) + .map(|node| node.map(|x| x.index())) } /// Merge two nodes in the graph. diff --git a/src/graph.rs b/src/graph.rs index 3d0ba9db7..a6582b630 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -35,7 +35,7 @@ use super::dot_utils::build_dot; use super::iterators::{ EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList, }; -use super::{NoEdgeBetweenNodes, NodesRemoved}; +use super::{find_node_by_weight, NoEdgeBetweenNodes, NodesRemoved}; use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; @@ -1020,6 +1020,26 @@ impl PyGraph { Ok(()) } + /// Find node within this graph given a specific weight + /// + /// This algorithm has a worst case of O(n) since it searches the node + /// indices in order. If there is more than one node in the graph with the + /// same weight only the first match (by node index) will be returned. + /// + /// :param obj: The weight to look for in the graph. + /// + /// :returns: the index of the first node in the graph that is equal to the + /// weight. If no match is found ``None`` will be returned. + /// :rtype: int + pub fn find_node_by_weight( + &self, + py: Python, + obj: PyObject, + ) -> PyResult> { + find_node_by_weight(py, &self.graph, &obj) + .map(|node| node.map(|x| x.index())) + } + /// Get the index and data for the neighbors of a node. /// /// This will return a dictionary where the keys are the node indexes of diff --git a/src/lib.rs b/src/lib.rs index f028c9cf3..120dc34b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,9 +62,12 @@ use petgraph::visit::{ Data, GraphBase, GraphProp, IntoEdgeReferences, IntoNodeIdentifiers, NodeCount, NodeIndexable, }; +use petgraph::EdgeType; use crate::generators::PyInit_generators; +type StablePyGraph = StableGraph; + pub trait NodesRemoved { fn nodes_removed(&self) -> bool; } @@ -130,6 +133,26 @@ fn weight_callable( } } +fn find_node_by_weight( + py: Python, + graph: &StablePyGraph, + obj: &PyObject, +) -> PyResult> { + let mut index = None; + for node in graph.node_indices() { + let weight = graph.node_weight(node).unwrap(); + if obj + .as_ref(py) + .rich_compare(weight, pyo3::basic::CompareOp::Eq)? + .is_true()? + { + index = Some(node); + break; + } + } + Ok(index) +} + // The provided node is invalid. create_exception!(retworkx, InvalidNode, PyException); // Performing this operation would result in trying to add a cycle to a DAG. @@ -171,6 +194,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_vf2_mapping))?; m.add_wrapped(wrap_pyfunction!(graph_vf2_mapping))?; m.add_wrapped(wrap_pyfunction!(digraph_union))?; + m.add_wrapped(wrap_pyfunction!(graph_union))?; m.add_wrapped(wrap_pyfunction!(topological_sort))?; m.add_wrapped(wrap_pyfunction!(descendants))?; m.add_wrapped(wrap_pyfunction!(ancestors))?; diff --git a/src/union.rs b/src/union.rs index 307bcfccb..61b4c68f6 100644 --- a/src/union.rs +++ b/src/union.rs @@ -10,139 +10,182 @@ // License for the specific language governing permissions and limitations // under the License. -use crate::{digraph, digraph::PyDiGraph}; -use hashbrown::{HashMap, HashSet}; -use petgraph::algo; -use petgraph::graph::EdgeIndex; +use crate::{digraph, find_node_by_weight, graph, StablePyGraph}; + +use petgraph::stable_graph::NodeIndex; +use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}; +use petgraph::{algo, EdgeType}; + use pyo3::prelude::*; use pyo3::Python; -use std::cmp::Ordering; -/// [Graph] Return a new PyDiGraph by forming a union from`a` and `b` graphs. -/// -/// The algorithm has three phases: -/// - adds all nodes from `b` to `a`. operates in O(n), n being number of nodes in `b`. -/// - merges nodes from `b` over `a` given that: -/// - `merge_nodes` is `true`. operates in O(n^2), n being number of nodes in `b`. -/// - respective node in`b` and `a` share the same weight -/// - adds all edges from `b` to `a`. -/// - `merge_edges` is `true` -/// - respective edge in`b` and `a` share the same weight -/// -/// with the same weight in graphs `a` and `b` and merged those nodes. -/// -/// The nodes from graph `b` will replace nodes from `a`. -/// -/// At this point, only `PyDiGraph` is supported. -fn _digraph_union( +#[derive(Copy, Clone)] +enum Entry { + Merged(T), + Added(T), + None, +} + +fn extract(x: Entry) -> T { + match x { + Entry::Merged(val) => val, + Entry::Added(val) => val, + Entry::None => panic!("Unexpected internal error: called `Entry::extract()` on a `None` value. Please file an issue at https://github.com/Qiskit/retworkx/issues/new/choose with the details on how you encountered this."), + } +} + +fn union( py: Python, - a: &PyDiGraph, - b: &PyDiGraph, + first: &StablePyGraph, + second: &StablePyGraph, merge_nodes: bool, merge_edges: bool, -) -> PyResult { - let first = &a.graph; - let second = &b.graph; - let mut combined = PyDiGraph { - graph: first.clone(), - cycle_state: algo::DfsSpace::default(), - check_cycle: false, - node_removed: false, - multigraph: true, - }; - let mut node_map = HashMap::with_capacity(second.node_count()); - let mut edge_map = HashSet::with_capacity(second.edge_count()); - - let compare_weights = |a: &PyAny, b: &PyAny| -> PyResult { - let res = a.compare(b)?; - Ok(res == Ordering::Equal) - }; +) -> PyResult> { + let mut out_graph = first.clone(); + let mut node_map: Vec> = + vec![Entry::None; second.node_bound()]; for node in second.node_indices() { - let node_index = combined.add_node(second[node].clone_ref(py))?; - node_map.insert(node.index(), node_index); - } - - for edge in b.weighted_edge_list(py).edges { - let source = edge.0; - let target = edge.1; - let edge_weight = edge.2; + let weight = &second[node]; + if merge_nodes { + if let Some(index) = find_node_by_weight(py, first, weight)? { + node_map[node.index()] = Entry::Merged(index); + continue; + } + } - let new_source = *node_map.get(&source).unwrap(); - let new_target = *node_map.get(&target).unwrap(); + let index = out_graph.add_node(weight.clone_ref(py)); + node_map[node.index()] = Entry::Added(index); + } - let edge_index = combined.add_edge( - new_source, - new_target, - edge_weight.clone_ref(py), - )?; + let weights_equal = |a: &PyObject, b: &PyObject| -> PyResult { + a.as_ref(py) + .rich_compare(b, pyo3::basic::CompareOp::Eq)? + .is_true() + }; - let edge_node = EdgeIndex::new(edge_index); + for edge in second.edge_references() { + let source = edge.source().index(); + let target = edge.target().index(); + let new_weight = edge.weight(); - if combined.has_edge(source, target) { - let w = combined.graph.edge_weight(edge_node).unwrap(); - if compare_weights(edge_weight.as_ref(py), w.as_ref(py)).unwrap() { - edge_map.insert(edge_node); + let mut found = false; + if merge_edges { + // if both endpoints were merged, + // check if need to skip the edge as well. + if let (Entry::Merged(new_source), Entry::Merged(new_target)) = + (node_map[source], node_map[target]) + { + for edge in first.edges(new_source) { + if edge.target() == new_target + && weights_equal(new_weight, edge.weight())? + { + found = true; + break; + } + } } } - } - if merge_nodes { - for node in second.node_indices() { - let weight = &second[node].clone_ref(py); - let index = a.find_node_by_weight(py, weight.clone_ref(py)); - - if index.is_some() { - let other_node = node_map.get(&node.index()); - combined.merge_nodes( - py, - *other_node.unwrap(), - index.unwrap(), - )?; - } + if !found { + let new_source = extract(node_map[source]); + let new_target = extract(node_map[target]); + out_graph.add_edge( + new_source, + new_target, + new_weight.clone_ref(py), + ); } } - if merge_edges { - for edge in edge_map { - combined.graph.remove_edge(edge); - } - } + Ok(out_graph) +} - Ok(combined) +/// Return a new PyGraph by forming a union from two input PyGraph objects +/// +/// The algorithm in this function operates in three phases: +/// +/// 1. Add all the nodes from ``second`` into ``first``. operates in +/// :math:`\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in +/// ``second``. +/// 2. Merge nodes from ``second`` over ``first`` given that: +/// +/// - The ``merge_nodes`` is ``True``. operates in :math:`\mathcal{O}(n_1 n_2)`, +/// with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2` +/// the number of nodes in ``second`` +/// - The respective node in ``second`` and ``first`` share the same +/// weight/data payload. +/// +/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` +/// parameter is ``True`` and the respective edge in ``second`` and +/// ``first`` share the same weight/data payload they will be merged together. +/// +/// :param PyGraph first: The first undirected graph object +/// :param PyGraph second: The second undirected graph object +/// :param bool merge_nodes: If set to ``True`` nodes will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. +/// :param bool merge_edges: If set to ``True`` edges will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. +/// +/// :returns: A new PyGraph object that is the union of ``second`` and +/// ``first``. It's worth noting the weight/data payload objects are +/// passed by reference from ``first`` and ``second`` to this new object. +/// :rtype: PyGraph +#[pyfunction(merge_nodes = false, merge_edges = false)] +#[pyo3( + text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)" +)] +fn graph_union( + py: Python, + first: &graph::PyGraph, + second: &graph::PyGraph, + merge_nodes: bool, + merge_edges: bool, +) -> PyResult { + let out_graph = + union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?; + + Ok(graph::PyGraph { + graph: out_graph, + node_removed: first.node_removed, + multigraph: true, + }) } /// Return a new PyDiGraph by forming a union from two input PyDiGraph objects /// /// The algorithm in this function operates in three phases: /// -/// 1. Add all the nodes from ``second`` into ``first``. operates in O(n), -/// with n being number of nodes in `b`. -/// 2. Merge nodes from ``second`` over ``first`` given that: +/// 1. Add all the nodes from ``second`` into ``first``. operates in +/// :math:`\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in +/// ``second``. +/// 2. Merge nodes from ``second`` over ``first`` given that: /// -/// - The ``merge_nodes`` is ``True``. operates in O(n^2), with n being the -/// number of nodes in ``second``. -/// - The respective node in ``second`` and ``first`` share the same -/// weight/data payload. +/// - The ``merge_nodes`` is ``True``. operates in :math:`\mathcal{O}(n_1 n_2)`, +/// with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2` +/// the number of nodes in ``second`` +/// - The respective node in ``second`` and ``first`` share the same +/// weight/data payload. /// -/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` -/// parameter is ``True`` and the respective edge in ``second`` and -/// first`` share the same weight/data payload they will be merged -/// together. +/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` +/// parameter is ``True`` and the respective edge in ``second`` and +/// ``first`` share the same weight/data payload they will be merged together. /// -/// :param PyDiGraph first: The first directed graph object -/// :param PyDiGraph second: The second directed graph object -/// :param bool merge_nodes: If set to ``True`` nodes will be merged between -/// ``second`` and ``first`` if the weights are equal. -/// :param bool merge_edges: If set to ``True`` edges will be merged between -/// ``second`` and ``first`` if the weights are equal. +/// :param PyDiGraph first: The first directed graph object +/// :param PyDiGraph second: The second directed graph object +/// :param bool merge_nodes: If set to ``True`` nodes will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. +/// :param bool merge_edges: If set to ``True`` edges will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. /// -/// :returns: A new PyDiGraph object that is the union of ``second`` and +/// :returns: A new PyDiGraph object that is the union of ``second`` and /// ``first``. It's worth noting the weight/data payload objects are /// passed by reference from ``first`` and ``second`` to this new object. -/// :rtype: PyDiGraph -#[pyfunction] -#[pyo3(text_signature = "(first, second, merge_nodes, merge_edges, /)")] +/// :rtype: PyDiGraph +#[pyfunction(merge_nodes = false, merge_edges = false)] +#[pyo3( + text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)" +)] fn digraph_union( py: Python, first: &digraph::PyDiGraph, @@ -150,6 +193,14 @@ fn digraph_union( merge_nodes: bool, merge_edges: bool, ) -> PyResult { - let res = _digraph_union(py, first, second, merge_nodes, merge_edges)?; - Ok(res) + let out_graph = + union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?; + + Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: first.node_removed, + multigraph: true, + }) } diff --git a/tests/digraph/test_union.py b/tests/digraph/test_union.py index d32b7a528..5ca44cd9c 100644 --- a/tests/digraph/test_union.py +++ b/tests/digraph/test_union.py @@ -31,23 +31,6 @@ def test_union_merge_all(self): self.assertTrue(retworkx.is_isomorphic(dag_a, dag_c)) - def test_union_basic_merge_edges_only(self): - dag_a = retworkx.PyDiGraph() - dag_b = retworkx.PyDiGraph() - - node_a = dag_a.add_node("a_1") - dag_a.add_child(node_a, "a_2", "e_1") - dag_a.add_child(node_a, "a_3", "e_2") - - node_b = dag_b.add_node("a_1") - dag_b.add_child(node_b, "a_2", "e_1") - dag_b.add_child(node_b, "a_3", "e_2") - - dag_c = retworkx.digraph_union(dag_a, dag_b, False, True) - - self.assertTrue(len(dag_c.edge_list()) == 2) - self.assertTrue(len(dag_c.nodes()) == 6) - def test_union_basic_merge_nodes_only(self): dag_a = retworkx.PyDiGraph() dag_b = retworkx.PyDiGraph() @@ -82,3 +65,47 @@ def test_union_basic_merge_none(self): self.assertTrue(len(dag_c.nodes()) == 6) self.assertTrue(len(dag_c.edge_list()) == 4) + + def test_union_mismatch_edge_weight(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")]) + + def test_union_node_hole(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + dummy = second.add_node("dummy") + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "a")]) + second.remove_node(dummy) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")]) + + def test_union_edge_between_merged_and_unmerged_nodes(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + nodes = second.add_nodes_from([0, 2]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 2, "b")]) diff --git a/tests/graph/test_union.py b/tests/graph/test_union.py new file mode 100644 index 000000000..b7e21fd46 --- /dev/null +++ b/tests/graph/test_union.py @@ -0,0 +1,88 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import retworkx + + +class TestUnion(unittest.TestCase): + def setUp(self): + self.graph = retworkx.PyGraph() + self.graph.add_nodes_from(["a_1", "a_2", "a_3"]) + self.graph.extend_from_weighted_edge_list( + [(0, 1, "e_1"), (1, 2, "e_2")] + ) + + def test_union_basic_merge_none(self): + final = retworkx.graph_union( + self.graph, self.graph, merge_nodes=False, merge_edges=False + ) + self.assertTrue(len(final.nodes()) == 6) + self.assertTrue(len(final.edge_list()) == 4) + + def test_union_merge_all(self): + final = retworkx.graph_union( + self.graph, self.graph, merge_nodes=True, merge_edges=True + ) + self.assertTrue(retworkx.is_isomorphic(final, self.graph)) + + def test_union_basic_merge_nodes_only(self): + final = retworkx.graph_union( + self.graph, self.graph, merge_nodes=True, merge_edges=False + ) + self.assertTrue(len(final.edge_list()) == 4) + self.assertTrue(len(final.get_all_edge_data(0, 1)) == 2) + self.assertTrue(len(final.nodes()) == 3) + + def test_union_mismatch_edge_weight(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")]) + + def test_union_node_hole(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + dummy = second.add_node("dummy") + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "a")]) + second.remove_node(dummy) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")]) + + def test_union_edge_between_merged_and_unmerged_nodes(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + nodes = second.add_nodes_from([0, 2]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 2, "b")])