Skip to content

Commit

Permalink
Fixes digraph_union if merge_edges is set to true. (#439)
Browse files Browse the repository at this point in the history
* Fixes `digraph_union` if `merge_edges` is set to true.
Previously, `digraph_union` would falsely keep or delete
edges if `merge_edges` is set to true. This commit fixes
the logic of `digraph_union` to skip an edge from the second
graph if both its endpoints were merged to nodes from the
first graph and these nodes already share an edge with equal
weight data. At the same time, a new function `graph_union`
was added that returns the union of two `PyGraph`s.
Closes #432.

* increase test cov

* improve release note and message in panic exception

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* add dispatch function `union` and implement `find_node_by_weight` for `PyGraph`

* lint

* Release note fixes

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
  • Loading branch information
georgios-ts and mtreinish committed Sep 8, 2021
1 parent f2f3a09 commit 466b884
Show file tree
Hide file tree
Showing 9 changed files with 434 additions and 140 deletions.
4 changes: 3 additions & 1 deletion docs/source/api.rst
Expand Up @@ -145,7 +145,7 @@ Other Algorithm Functions
retworkx.transitivity
retworkx.core_number
retworkx.graph_greedy_color
retworkx.digraph_union
retworkx.union
retworkx.metric_closure

Generators
Expand Down Expand Up @@ -242,6 +242,7 @@ the functions from the explicitly typed based on the data type.
retworkx.digraph_transitivity
retworkx.digraph_core_number
retworkx.digraph_complement
retworkx.digraph_union
retworkx.digraph_random_layout
retworkx.digraph_bipartite_layout
retworkx.digraph_circular_layout
Expand Down Expand Up @@ -283,6 +284,7 @@ typed API based on the data type.
retworkx.graph_transitivity
retworkx.graph_core_number
retworkx.graph_complement
retworkx.graph_union
retworkx.graph_random_layout
retworkx.graph_bipartite_layout
retworkx.graph_circular_layout
Expand Down
33 changes: 33 additions & 0 deletions releasenotes/notes/bugfix-union-7da79789134a3028.yaml
@@ -0,0 +1,33 @@
---
fixes:
- |
Previously, :func:`~retworkx.digraph_union` would incorrectly keep or delete edges
if argument ``merge_edges`` is set to true. This has been fixed and an edge from
the second graph will be skipped if both its endpoints were merged to nodes from
the first graph and these nodes already share an edge with equal weight data.
Fixed `#432 <https://github.com/Qiskit/retworkx/issues/432>`__
features:
- |
Add a new function :func:`~retworkx.graph_union` that returns the union
of two :class:`~retworkx.PyGraph` objects. This is the equivalent to
:func:`~retworkx.digraph_union` but for a :class:`~retworkx.PyGraph`
instead of for a :class:`~retworkx.PyDiGraph`. A new unified function
:func:`~retworkx.union` was also added that supports both
:class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph`.
For example:
.. jupyter-execute::
import retworkx
from retworkx.visualization import mpl_draw
first = retworkx.generators.path_graph(3, weights=["a_0", "node", "a_1"])
second = retworkx.generators.cycle_graph(3, weights=["node", "b_0", "b_1"])
graph = retworkx.graph_union(first, second, merge_nodes=True)
mpl_draw(graph)
- |
The kwargs ``merge_nodes`` and ``merge_edges`` of :func:`~retworkx.digraph_union` are
now optional and by default are set `False`.
- |
Add a new :meth:`~retworkx.PyGraph.find_node_by_weight` that finds the index
of a node given a specific weight.
61 changes: 61 additions & 0 deletions retworkx/__init__.py
Expand Up @@ -1633,3 +1633,64 @@ def _graph_vf2_mapping(
induced=induced,
call_limit=call_limit,
)


@functools.singledispatch
def union(
first,
second,
merge_nodes=False,
merge_edges=False,
):
"""Return a new graph by forming a union from two input graph objects
The algorithm in this function operates in three phases:
1. Add all the nodes from ``second`` into ``first``. operates in
:math:`\\mathcal{O}(n_2)`, with :math:`n_2` being number of nodes in
``second``.
2. Merge nodes from ``second`` over ``first`` given that:
- The ``merge_nodes`` is ``True``. operates in :math:`\\mathcal{O}(n_1 n_2)`,
with :math:`n_1` being the number of nodes in ``first`` and :math:`n_2`
the number of nodes in ``second``
- The respective node in ``second`` and ``first`` share the same
weight/data payload.
3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges``
parameter is ``True`` and the respective edge in ``second`` and
``first`` share the same weight/data payload they will be merged together.
:param first: The first graph object
:param second: The second graph object
:param bool merge_nodes: If set to ``True`` nodes will be merged between
``second`` and ``first`` if the weights are equal. Default: ``False``.
:param bool merge_edges: If set to ``True`` edges will be merged between
``second`` and ``first`` if the weights are equal. Default: ``False``.
:returns: A new graph object that is the union of ``second`` and
``first``. It's worth noting the weight/data payload objects are
passed by reference from ``first`` and ``second`` to this new object.
:rtype: :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph`
"""
raise TypeError("Invalid Input Type %s for graph" % type(first))


@union.register(PyDiGraph)
def _digraph_union(
first,
second,
merge_nodes=False,
merge_edges=False,
):
return digraph_union(first, second, merge_nodes=False, merge_edges=False)


@union.register(PyGraph)
def _graph_union(
first,
second,
merge_nodes=False,
merge_edges=False,
):
return graph_union(first, second, merge_nodes=False, merge_edges=False)
22 changes: 5 additions & 17 deletions src/digraph.rs
Expand Up @@ -50,8 +50,8 @@ use super::iterators::{
EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList,
};
use super::{
DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, NoSuitableNeighbors,
NodesRemoved,
find_node_by_weight, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes,
NoSuitableNeighbors, NodesRemoved,
};

use super::dag_algo::is_directed_acyclic_graph;
Expand Down Expand Up @@ -1401,21 +1401,9 @@ impl PyDiGraph {
&self,
py: Python,
obj: PyObject,
) -> Option<usize> {
let mut index = None;
for node in self.graph.node_indices() {
let weight = self.graph.node_weight(node).unwrap();
let weight_compare = |a: &PyAny, b: &PyAny| -> PyResult<bool> {
let res = a.compare(b)?;
Ok(res == Ordering::Equal)
};

if weight_compare(obj.as_ref(py), weight.as_ref(py)).unwrap() {
index = Some(node.index());
break;
}
}
index
) -> PyResult<Option<usize>> {
find_node_by_weight(py, &self.graph, &obj)
.map(|node| node.map(|x| x.index()))
}

/// Merge two nodes in the graph.
Expand Down
22 changes: 21 additions & 1 deletion src/graph.rs
Expand Up @@ -35,7 +35,7 @@ use super::dot_utils::build_dot;
use super::iterators::{
EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList,
};
use super::{NoEdgeBetweenNodes, NodesRemoved};
use super::{find_node_by_weight, NoEdgeBetweenNodes, NodesRemoved};

use petgraph::algo;
use petgraph::graph::{EdgeIndex, NodeIndex};
Expand Down Expand Up @@ -1020,6 +1020,26 @@ impl PyGraph {
Ok(())
}

/// Find node within this graph given a specific weight
///
/// This algorithm has a worst case of O(n) since it searches the node
/// indices in order. If there is more than one node in the graph with the
/// same weight only the first match (by node index) will be returned.
///
/// :param obj: The weight to look for in the graph.
///
/// :returns: the index of the first node in the graph that is equal to the
/// weight. If no match is found ``None`` will be returned.
/// :rtype: int
pub fn find_node_by_weight(
&self,
py: Python,
obj: PyObject,
) -> PyResult<Option<usize>> {
find_node_by_weight(py, &self.graph, &obj)
.map(|node| node.map(|x| x.index()))
}

/// Get the index and data for the neighbors of a node.
///
/// This will return a dictionary where the keys are the node indexes of
Expand Down
24 changes: 24 additions & 0 deletions src/lib.rs
Expand Up @@ -62,9 +62,12 @@ use petgraph::visit::{
Data, GraphBase, GraphProp, IntoEdgeReferences, IntoNodeIdentifiers,
NodeCount, NodeIndexable,
};
use petgraph::EdgeType;

use crate::generators::PyInit_generators;

type StablePyGraph<Ty> = StableGraph<PyObject, PyObject, Ty>;

pub trait NodesRemoved {
fn nodes_removed(&self) -> bool;
}
Expand Down Expand Up @@ -130,6 +133,26 @@ fn weight_callable(
}
}

fn find_node_by_weight<Ty: EdgeType>(
py: Python,
graph: &StablePyGraph<Ty>,
obj: &PyObject,
) -> PyResult<Option<NodeIndex>> {
let mut index = None;
for node in graph.node_indices() {
let weight = graph.node_weight(node).unwrap();
if obj
.as_ref(py)
.rich_compare(weight, pyo3::basic::CompareOp::Eq)?
.is_true()?
{
index = Some(node);
break;
}
}
Ok(index)
}

// The provided node is invalid.
create_exception!(retworkx, InvalidNode, PyException);
// Performing this operation would result in trying to add a cycle to a DAG.
Expand Down Expand Up @@ -171,6 +194,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_vf2_mapping))?;
m.add_wrapped(wrap_pyfunction!(graph_vf2_mapping))?;
m.add_wrapped(wrap_pyfunction!(digraph_union))?;
m.add_wrapped(wrap_pyfunction!(graph_union))?;
m.add_wrapped(wrap_pyfunction!(topological_sort))?;
m.add_wrapped(wrap_pyfunction!(descendants))?;
m.add_wrapped(wrap_pyfunction!(ancestors))?;
Expand Down

0 comments on commit 466b884

Please sign in to comment.