diff --git a/docs/source/api.rst b/docs/source/api.rst index c535383f3..d217e075b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -145,7 +145,7 @@ Other Algorithm Functions retworkx.transitivity retworkx.core_number retworkx.graph_greedy_color - retworkx.digraph_union + retworkx.union retworkx.metric_closure Generators @@ -242,6 +242,7 @@ the functions from the explicitly typed based on the data type. retworkx.digraph_transitivity retworkx.digraph_core_number retworkx.digraph_complement + retworkx.digraph_union retworkx.digraph_random_layout retworkx.digraph_bipartite_layout retworkx.digraph_circular_layout @@ -283,6 +284,7 @@ typed API based on the data type. retworkx.graph_transitivity retworkx.graph_core_number retworkx.graph_complement + retworkx.graph_union retworkx.graph_random_layout retworkx.graph_bipartite_layout retworkx.graph_circular_layout diff --git a/releasenotes/notes/bugfix-union-7da79789134a3028.yaml b/releasenotes/notes/bugfix-union-7da79789134a3028.yaml new file mode 100644 index 000000000..ae1da91e0 --- /dev/null +++ b/releasenotes/notes/bugfix-union-7da79789134a3028.yaml @@ -0,0 +1,33 @@ +--- +fixes: + - | + Previously, :func:`~retworkx.digraph_union` would incorrectly keep or delete edges + if argument ``merge_edges`` is set to true. This has been fixed and an edge from + the second graph will be skipped if both its endpoints were merged to nodes from + the first graph and these nodes already share an edge with equal weight data. + Fixed `#432 `__ +features: + - | + Add a new function :func:`~retworkx.graph_union` that returns the union + of two :class:`~retworkx.PyGraph` objects. This is the equivalent to + :func:`~retworkx.digraph_union` but for a :class:`~retworkx.PyGraph` + instead of for a :class:`~retworkx.PyDiGraph`. A new unified function + :func:`~retworkx.union` was also added that supports both + :class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph`. + For example: + + .. jupyter-execute:: + + import retworkx + from retworkx.visualization import mpl_draw + + first = retworkx.generators.path_graph(3, weights=["a_0", "node", "a_1"]) + second = retworkx.generators.cycle_graph(3, weights=["node", "b_0", "b_1"]) + graph = retworkx.graph_union(first, second, merge_nodes=True) + mpl_draw(graph) + - | + The kwargs ``merge_nodes`` and ``merge_edges`` of :func:`~retworkx.digraph_union` are + now optional and by default are set `False`. + - | + Add a new :meth:`~retworkx.PyGraph.find_node_by_weight` that finds the index + of a node given a specific weight. diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 7b2691ce6..a1b9f2f49 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -1633,3 +1633,64 @@ def _graph_vf2_mapping( induced=induced, call_limit=call_limit, ) + + +@functools.singledispatch +def union( + first, + second, + merge_nodes=False, + merge_edges=False, +): + """Return a new graph by forming a union from two input graph objects + + The algorithm in this function operates in three phases: + + 1. Add all the nodes from ``second`` into ``first``. operates in + :math:`\\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in + ``second``. + 2. Merge nodes from ``second`` over ``first`` given that: + + - The ``merge_nodes`` is ``True``. operates in :math:`\\mathcal{O}(n_1 n_2)`, + with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2` + the number of nodes in ``second`` + - The respective node in ``second`` and ``first`` share the same + weight/data payload. + + 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` + parameter is ``True`` and the respective edge in ``second`` and + ``first`` share the same weight/data payload they will be merged together. + + :param first: The first graph object + :param second: The second graph object + :param bool merge_nodes: If set to ``True`` nodes will be merged between + ``second`` and ``first`` if the weights are equal. Default: ``False``. + :param bool merge_edges: If set to ``True`` edges will be merged between + ``second`` and ``first`` if the weights are equal. Default: ``False``. + + :returns: A new graph object that is the union of ``second`` and + ``first``. It's worth noting the weight/data payload objects are + passed by reference from ``first`` and ``second`` to this new object. + :rtype: :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + """ + raise TypeError("Invalid Input Type %s for graph" % type(first)) + + +@union.register(PyDiGraph) +def _digraph_union( + first, + second, + merge_nodes=False, + merge_edges=False, +): + return digraph_union(first, second, merge_nodes=False, merge_edges=False) + + +@union.register(PyGraph) +def _graph_union( + first, + second, + merge_nodes=False, + merge_edges=False, +): + return graph_union(first, second, merge_nodes=False, merge_edges=False) diff --git a/src/digraph.rs b/src/digraph.rs index a6fa898f0..fa44f8625 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -50,8 +50,8 @@ use super::iterators::{ EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList, }; use super::{ - DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, NoSuitableNeighbors, - NodesRemoved, + find_node_by_weight, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, + NoSuitableNeighbors, NodesRemoved, }; use super::dag_algo::is_directed_acyclic_graph; @@ -1401,21 +1401,9 @@ impl PyDiGraph { &self, py: Python, obj: PyObject, - ) -> Option { - let mut index = None; - for node in self.graph.node_indices() { - let weight = self.graph.node_weight(node).unwrap(); - let weight_compare = |a: &PyAny, b: &PyAny| -> PyResult { - let res = a.compare(b)?; - Ok(res == Ordering::Equal) - }; - - if weight_compare(obj.as_ref(py), weight.as_ref(py)).unwrap() { - index = Some(node.index()); - break; - } - } - index + ) -> PyResult> { + find_node_by_weight(py, &self.graph, &obj) + .map(|node| node.map(|x| x.index())) } /// Merge two nodes in the graph. diff --git a/src/graph.rs b/src/graph.rs index 3d0ba9db7..a6582b630 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -35,7 +35,7 @@ use super::dot_utils::build_dot; use super::iterators::{ EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList, }; -use super::{NoEdgeBetweenNodes, NodesRemoved}; +use super::{find_node_by_weight, NoEdgeBetweenNodes, NodesRemoved}; use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; @@ -1020,6 +1020,26 @@ impl PyGraph { Ok(()) } + /// Find node within this graph given a specific weight + /// + /// This algorithm has a worst case of O(n) since it searches the node + /// indices in order. If there is more than one node in the graph with the + /// same weight only the first match (by node index) will be returned. + /// + /// :param obj: The weight to look for in the graph. + /// + /// :returns: the index of the first node in the graph that is equal to the + /// weight. If no match is found ``None`` will be returned. + /// :rtype: int + pub fn find_node_by_weight( + &self, + py: Python, + obj: PyObject, + ) -> PyResult> { + find_node_by_weight(py, &self.graph, &obj) + .map(|node| node.map(|x| x.index())) + } + /// Get the index and data for the neighbors of a node. /// /// This will return a dictionary where the keys are the node indexes of diff --git a/src/lib.rs b/src/lib.rs index f028c9cf3..120dc34b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,9 +62,12 @@ use petgraph::visit::{ Data, GraphBase, GraphProp, IntoEdgeReferences, IntoNodeIdentifiers, NodeCount, NodeIndexable, }; +use petgraph::EdgeType; use crate::generators::PyInit_generators; +type StablePyGraph = StableGraph; + pub trait NodesRemoved { fn nodes_removed(&self) -> bool; } @@ -130,6 +133,26 @@ fn weight_callable( } } +fn find_node_by_weight( + py: Python, + graph: &StablePyGraph, + obj: &PyObject, +) -> PyResult> { + let mut index = None; + for node in graph.node_indices() { + let weight = graph.node_weight(node).unwrap(); + if obj + .as_ref(py) + .rich_compare(weight, pyo3::basic::CompareOp::Eq)? + .is_true()? + { + index = Some(node); + break; + } + } + Ok(index) +} + // The provided node is invalid. create_exception!(retworkx, InvalidNode, PyException); // Performing this operation would result in trying to add a cycle to a DAG. @@ -171,6 +194,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_vf2_mapping))?; m.add_wrapped(wrap_pyfunction!(graph_vf2_mapping))?; m.add_wrapped(wrap_pyfunction!(digraph_union))?; + m.add_wrapped(wrap_pyfunction!(graph_union))?; m.add_wrapped(wrap_pyfunction!(topological_sort))?; m.add_wrapped(wrap_pyfunction!(descendants))?; m.add_wrapped(wrap_pyfunction!(ancestors))?; diff --git a/src/union.rs b/src/union.rs index 307bcfccb..61b4c68f6 100644 --- a/src/union.rs +++ b/src/union.rs @@ -10,139 +10,182 @@ // License for the specific language governing permissions and limitations // under the License. -use crate::{digraph, digraph::PyDiGraph}; -use hashbrown::{HashMap, HashSet}; -use petgraph::algo; -use petgraph::graph::EdgeIndex; +use crate::{digraph, find_node_by_weight, graph, StablePyGraph}; + +use petgraph::stable_graph::NodeIndex; +use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}; +use petgraph::{algo, EdgeType}; + use pyo3::prelude::*; use pyo3::Python; -use std::cmp::Ordering; -/// [Graph] Return a new PyDiGraph by forming a union from`a` and `b` graphs. -/// -/// The algorithm has three phases: -/// - adds all nodes from `b` to `a`. operates in O(n), n being number of nodes in `b`. -/// - merges nodes from `b` over `a` given that: -/// - `merge_nodes` is `true`. operates in O(n^2), n being number of nodes in `b`. -/// - respective node in`b` and `a` share the same weight -/// - adds all edges from `b` to `a`. -/// - `merge_edges` is `true` -/// - respective edge in`b` and `a` share the same weight -/// -/// with the same weight in graphs `a` and `b` and merged those nodes. -/// -/// The nodes from graph `b` will replace nodes from `a`. -/// -/// At this point, only `PyDiGraph` is supported. -fn _digraph_union( +#[derive(Copy, Clone)] +enum Entry { + Merged(T), + Added(T), + None, +} + +fn extract(x: Entry) -> T { + match x { + Entry::Merged(val) => val, + Entry::Added(val) => val, + Entry::None => panic!("Unexpected internal error: called `Entry::extract()` on a `None` value. Please file an issue at https://github.com/Qiskit/retworkx/issues/new/choose with the details on how you encountered this."), + } +} + +fn union( py: Python, - a: &PyDiGraph, - b: &PyDiGraph, + first: &StablePyGraph, + second: &StablePyGraph, merge_nodes: bool, merge_edges: bool, -) -> PyResult { - let first = &a.graph; - let second = &b.graph; - let mut combined = PyDiGraph { - graph: first.clone(), - cycle_state: algo::DfsSpace::default(), - check_cycle: false, - node_removed: false, - multigraph: true, - }; - let mut node_map = HashMap::with_capacity(second.node_count()); - let mut edge_map = HashSet::with_capacity(second.edge_count()); - - let compare_weights = |a: &PyAny, b: &PyAny| -> PyResult { - let res = a.compare(b)?; - Ok(res == Ordering::Equal) - }; +) -> PyResult> { + let mut out_graph = first.clone(); + let mut node_map: Vec> = + vec![Entry::None; second.node_bound()]; for node in second.node_indices() { - let node_index = combined.add_node(second[node].clone_ref(py))?; - node_map.insert(node.index(), node_index); - } - - for edge in b.weighted_edge_list(py).edges { - let source = edge.0; - let target = edge.1; - let edge_weight = edge.2; + let weight = &second[node]; + if merge_nodes { + if let Some(index) = find_node_by_weight(py, first, weight)? { + node_map[node.index()] = Entry::Merged(index); + continue; + } + } - let new_source = *node_map.get(&source).unwrap(); - let new_target = *node_map.get(&target).unwrap(); + let index = out_graph.add_node(weight.clone_ref(py)); + node_map[node.index()] = Entry::Added(index); + } - let edge_index = combined.add_edge( - new_source, - new_target, - edge_weight.clone_ref(py), - )?; + let weights_equal = |a: &PyObject, b: &PyObject| -> PyResult { + a.as_ref(py) + .rich_compare(b, pyo3::basic::CompareOp::Eq)? + .is_true() + }; - let edge_node = EdgeIndex::new(edge_index); + for edge in second.edge_references() { + let source = edge.source().index(); + let target = edge.target().index(); + let new_weight = edge.weight(); - if combined.has_edge(source, target) { - let w = combined.graph.edge_weight(edge_node).unwrap(); - if compare_weights(edge_weight.as_ref(py), w.as_ref(py)).unwrap() { - edge_map.insert(edge_node); + let mut found = false; + if merge_edges { + // if both endpoints were merged, + // check if need to skip the edge as well. + if let (Entry::Merged(new_source), Entry::Merged(new_target)) = + (node_map[source], node_map[target]) + { + for edge in first.edges(new_source) { + if edge.target() == new_target + && weights_equal(new_weight, edge.weight())? + { + found = true; + break; + } + } } } - } - if merge_nodes { - for node in second.node_indices() { - let weight = &second[node].clone_ref(py); - let index = a.find_node_by_weight(py, weight.clone_ref(py)); - - if index.is_some() { - let other_node = node_map.get(&node.index()); - combined.merge_nodes( - py, - *other_node.unwrap(), - index.unwrap(), - )?; - } + if !found { + let new_source = extract(node_map[source]); + let new_target = extract(node_map[target]); + out_graph.add_edge( + new_source, + new_target, + new_weight.clone_ref(py), + ); } } - if merge_edges { - for edge in edge_map { - combined.graph.remove_edge(edge); - } - } + Ok(out_graph) +} - Ok(combined) +/// Return a new PyGraph by forming a union from two input PyGraph objects +/// +/// The algorithm in this function operates in three phases: +/// +/// 1. Add all the nodes from ``second`` into ``first``. operates in +/// :math:`\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in +/// ``second``. +/// 2. Merge nodes from ``second`` over ``first`` given that: +/// +/// - The ``merge_nodes`` is ``True``. operates in :math:`\mathcal{O}(n_1 n_2)`, +/// with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2` +/// the number of nodes in ``second`` +/// - The respective node in ``second`` and ``first`` share the same +/// weight/data payload. +/// +/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` +/// parameter is ``True`` and the respective edge in ``second`` and +/// ``first`` share the same weight/data payload they will be merged together. +/// +/// :param PyGraph first: The first undirected graph object +/// :param PyGraph second: The second undirected graph object +/// :param bool merge_nodes: If set to ``True`` nodes will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. +/// :param bool merge_edges: If set to ``True`` edges will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. +/// +/// :returns: A new PyGraph object that is the union of ``second`` and +/// ``first``. It's worth noting the weight/data payload objects are +/// passed by reference from ``first`` and ``second`` to this new object. +/// :rtype: PyGraph +#[pyfunction(merge_nodes = false, merge_edges = false)] +#[pyo3( + text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)" +)] +fn graph_union( + py: Python, + first: &graph::PyGraph, + second: &graph::PyGraph, + merge_nodes: bool, + merge_edges: bool, +) -> PyResult { + let out_graph = + union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?; + + Ok(graph::PyGraph { + graph: out_graph, + node_removed: first.node_removed, + multigraph: true, + }) } /// Return a new PyDiGraph by forming a union from two input PyDiGraph objects /// /// The algorithm in this function operates in three phases: /// -/// 1. Add all the nodes from ``second`` into ``first``. operates in O(n), -/// with n being number of nodes in `b`. -/// 2. Merge nodes from ``second`` over ``first`` given that: +/// 1. Add all the nodes from ``second`` into ``first``. operates in +/// :math:`\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in +/// ``second``. +/// 2. Merge nodes from ``second`` over ``first`` given that: /// -/// - The ``merge_nodes`` is ``True``. operates in O(n^2), with n being the -/// number of nodes in ``second``. -/// - The respective node in ``second`` and ``first`` share the same -/// weight/data payload. +/// - The ``merge_nodes`` is ``True``. operates in :math:`\mathcal{O}(n_1 n_2)`, +/// with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2` +/// the number of nodes in ``second`` +/// - The respective node in ``second`` and ``first`` share the same +/// weight/data payload. /// -/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` -/// parameter is ``True`` and the respective edge in ``second`` and -/// first`` share the same weight/data payload they will be merged -/// together. +/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` +/// parameter is ``True`` and the respective edge in ``second`` and +/// ``first`` share the same weight/data payload they will be merged together. /// -/// :param PyDiGraph first: The first directed graph object -/// :param PyDiGraph second: The second directed graph object -/// :param bool merge_nodes: If set to ``True`` nodes will be merged between -/// ``second`` and ``first`` if the weights are equal. -/// :param bool merge_edges: If set to ``True`` edges will be merged between -/// ``second`` and ``first`` if the weights are equal. +/// :param PyDiGraph first: The first directed graph object +/// :param PyDiGraph second: The second directed graph object +/// :param bool merge_nodes: If set to ``True`` nodes will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. +/// :param bool merge_edges: If set to ``True`` edges will be merged between +/// ``second`` and ``first`` if the weights are equal. Default: ``False``. /// -/// :returns: A new PyDiGraph object that is the union of ``second`` and +/// :returns: A new PyDiGraph object that is the union of ``second`` and /// ``first``. It's worth noting the weight/data payload objects are /// passed by reference from ``first`` and ``second`` to this new object. -/// :rtype: PyDiGraph -#[pyfunction] -#[pyo3(text_signature = "(first, second, merge_nodes, merge_edges, /)")] +/// :rtype: PyDiGraph +#[pyfunction(merge_nodes = false, merge_edges = false)] +#[pyo3( + text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)" +)] fn digraph_union( py: Python, first: &digraph::PyDiGraph, @@ -150,6 +193,14 @@ fn digraph_union( merge_nodes: bool, merge_edges: bool, ) -> PyResult { - let res = _digraph_union(py, first, second, merge_nodes, merge_edges)?; - Ok(res) + let out_graph = + union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?; + + Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: first.node_removed, + multigraph: true, + }) } diff --git a/tests/digraph/test_union.py b/tests/digraph/test_union.py index d32b7a528..5ca44cd9c 100644 --- a/tests/digraph/test_union.py +++ b/tests/digraph/test_union.py @@ -31,23 +31,6 @@ def test_union_merge_all(self): self.assertTrue(retworkx.is_isomorphic(dag_a, dag_c)) - def test_union_basic_merge_edges_only(self): - dag_a = retworkx.PyDiGraph() - dag_b = retworkx.PyDiGraph() - - node_a = dag_a.add_node("a_1") - dag_a.add_child(node_a, "a_2", "e_1") - dag_a.add_child(node_a, "a_3", "e_2") - - node_b = dag_b.add_node("a_1") - dag_b.add_child(node_b, "a_2", "e_1") - dag_b.add_child(node_b, "a_3", "e_2") - - dag_c = retworkx.digraph_union(dag_a, dag_b, False, True) - - self.assertTrue(len(dag_c.edge_list()) == 2) - self.assertTrue(len(dag_c.nodes()) == 6) - def test_union_basic_merge_nodes_only(self): dag_a = retworkx.PyDiGraph() dag_b = retworkx.PyDiGraph() @@ -82,3 +65,47 @@ def test_union_basic_merge_none(self): self.assertTrue(len(dag_c.nodes()) == 6) self.assertTrue(len(dag_c.edge_list()) == 4) + + def test_union_mismatch_edge_weight(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")]) + + def test_union_node_hole(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + dummy = second.add_node("dummy") + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "a")]) + second.remove_node(dummy) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")]) + + def test_union_edge_between_merged_and_unmerged_nodes(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + nodes = second.add_nodes_from([0, 2]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 2, "b")]) diff --git a/tests/graph/test_union.py b/tests/graph/test_union.py new file mode 100644 index 000000000..b7e21fd46 --- /dev/null +++ b/tests/graph/test_union.py @@ -0,0 +1,88 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import retworkx + + +class TestUnion(unittest.TestCase): + def setUp(self): + self.graph = retworkx.PyGraph() + self.graph.add_nodes_from(["a_1", "a_2", "a_3"]) + self.graph.extend_from_weighted_edge_list( + [(0, 1, "e_1"), (1, 2, "e_2")] + ) + + def test_union_basic_merge_none(self): + final = retworkx.graph_union( + self.graph, self.graph, merge_nodes=False, merge_edges=False + ) + self.assertTrue(len(final.nodes()) == 6) + self.assertTrue(len(final.edge_list()) == 4) + + def test_union_merge_all(self): + final = retworkx.graph_union( + self.graph, self.graph, merge_nodes=True, merge_edges=True + ) + self.assertTrue(retworkx.is_isomorphic(final, self.graph)) + + def test_union_basic_merge_nodes_only(self): + final = retworkx.graph_union( + self.graph, self.graph, merge_nodes=True, merge_edges=False + ) + self.assertTrue(len(final.edge_list()) == 4) + self.assertTrue(len(final.get_all_edge_data(0, 1)) == 2) + self.assertTrue(len(final.nodes()) == 3) + + def test_union_mismatch_edge_weight(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")]) + + def test_union_node_hole(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + dummy = second.add_node("dummy") + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "a")]) + second.remove_node(dummy) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")]) + + def test_union_edge_between_merged_and_unmerged_nodes(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + nodes = second.add_nodes_from([0, 2]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 2, "b")])