Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added new methods, :meth:`.PyDiGraph.substitute_subgraph` and
:meth:`.PyGraph.substitute_subgraph`, which are used to replace
a subgraph in a graph object with an external graph.
172 changes: 158 additions & 14 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::io::{BufReader, BufWriter};
use std::str;

use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
use indexmap::{IndexMap, IndexSet};

use rustworkx_core::dictmap::*;

Expand Down Expand Up @@ -226,9 +226,10 @@ impl PyDiGraph {
p_index: NodeIndex,
c_index: NodeIndex,
edge: PyObject,
force: bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name force for this parameter seems a bit misleading to me. In the context of this method's name, I would expect that force would mean to add the edge, even if it violates properties of the graph (e.g. introduces cycles). I'm also not a fan of us needing to pass a new parameter at all of the callsites for _add_edge.

Instead, what would you think about adding a new method called something like _error_if_would_cycle that takes the indices in question, and contains code moved from _add_edge which does the cycle check and error construction:

fn _error_if_would_cycle(&mut self, p_index: NodeIndex, c_index: NodeIndex) -> PyResult<()> {
    // Only check for a cycle (by running has_path_connecting) if
    // the new edge could potentially add a cycle
    let cycle_check_required = is_cycle_check_required(self, p_index, c_index);
    let state = Some(&mut self.cycle_state);
    if cycle_check_required
        && algo::has_path_connecting(&self.graph, c_index, p_index, state)
    {
        return Err(DAGWouldCycle::new_err("Adding an edge would cycle"));
    }
    Ok()
}

We could then call this from _add_edge if self.check_cycle is true, which preserves the current expectations of callers, and call it manually followed by a call to add_edge_no_cycle_check from methods like substitute_subgraph which offer cycle detection irrespective of whether or not self.check_cycle is enabled.

) -> PyResult<usize> {
// Only check for cycles if instance attribute is set to true
if self.check_cycle {
if self.check_cycle || force {
// Only check for a cycle (by running has_path_connecting) if
// the new edge could potentially add a cycle
let cycle_check_required = is_cycle_check_required(self, p_index, c_index);
Expand Down Expand Up @@ -269,11 +270,11 @@ impl PyDiGraph {
.collect::<Vec<(NodeIndex, EdgeIndex, PyObject)>>();
for (other_index, edge_index, weight) in edges {
if direction {
self._add_edge(node_between_index, index, weight.clone_ref(py))?;
self._add_edge(index, other_index, weight.clone_ref(py))?;
self._add_edge(node_between_index, index, weight.clone_ref(py), false)?;
self._add_edge(index, other_index, weight.clone_ref(py), false)?;
} else {
self._add_edge(other_index, index, weight.clone_ref(py))?;
self._add_edge(index, node_between_index, weight.clone_ref(py))?;
self._add_edge(other_index, index, weight.clone_ref(py), false)?;
self._add_edge(index, node_between_index, weight.clone_ref(py), false)?;
}
self.graph.remove_edge(edge_index);
}
Expand Down Expand Up @@ -1029,7 +1030,7 @@ impl PyDiGraph {
}
}
for (source, target, weight) in edge_list {
self._add_edge(source, target, weight)?;
self._add_edge(source, target, weight, false)?;
}
self.graph.remove_node(index);
self.node_removed = true;
Expand Down Expand Up @@ -1061,7 +1062,7 @@ impl PyDiGraph {
"One of the endpoints of the edge does not exist in graph",
));
}
let out_index = self._add_edge(p_index, c_index, edge)?;
let out_index = self._add_edge(p_index, c_index, edge, false)?;
Ok(out_index)
}

Expand Down Expand Up @@ -1131,7 +1132,12 @@ impl PyDiGraph {
while max_index >= self.node_count() {
self.graph.add_node(py.None());
}
self._add_edge(NodeIndex::new(source), NodeIndex::new(target), py.None())?;
self._add_edge(
NodeIndex::new(source),
NodeIndex::new(target),
py.None(),
false,
)?;
}
Ok(())
}
Expand All @@ -1156,7 +1162,12 @@ impl PyDiGraph {
while max_index >= self.node_count() {
self.graph.add_node(py.None());
}
self._add_edge(NodeIndex::new(source), NodeIndex::new(target), weight)?;
self._add_edge(
NodeIndex::new(source),
NodeIndex::new(target),
weight,
false,
)?;
}
Ok(())
}
Expand Down Expand Up @@ -2264,7 +2275,7 @@ impl PyDiGraph {
let new_p_index = new_node_map.get(&edge.source()).unwrap();
let new_c_index = new_node_map.get(&edge.target()).unwrap();
let weight = weight_transform_callable(py, &edge_map_func, edge.weight())?;
self._add_edge(*new_p_index, *new_c_index, weight)?;
self._add_edge(*new_p_index, *new_c_index, weight, false)?;
}
// Add edges from map
for (this_index, (index, weight)) in node_map.iter() {
Expand All @@ -2273,6 +2284,7 @@ impl PyDiGraph {
NodeIndex::new(*this_index),
*new_index,
weight.clone_ref(py),
false,
)?;
}
let out_dict = PyDict::new(py);
Expand Down Expand Up @@ -2378,6 +2390,7 @@ impl PyDiGraph {
NodeIndex::new(out_map[&edge.source().index()]),
NodeIndex::new(out_map[&edge.target().index()]),
weight_map_fn(edge.weight(), &edge_weight_map)?,
false,
)?;
}
// Add edges to/from node to nodes in other
Expand Down Expand Up @@ -2405,7 +2418,7 @@ impl PyDiGraph {
},
None => continue,
};
self._add_edge(source, target_out, weight)?;
self._add_edge(source, target_out, weight, false)?;
}
for (source, target, weight) in out_edges {
let old_index = map_fn(source.index(), target.index(), &weight)?;
Expand All @@ -2421,7 +2434,7 @@ impl PyDiGraph {
},
None => continue,
};
self._add_edge(source_out, target, weight)?;
self._add_edge(source_out, target, weight, false)?;
}
// Remove node
self.remove_node(node_index.index())?;
Expand Down Expand Up @@ -2605,6 +2618,137 @@ impl PyDiGraph {
}
}

/// Substitute a subgraph in the graph with a different subgraph
///
/// This is used to replace a subgraph in this graph with another graph. A similar result
/// can be achieved by combining :meth:`~.PyDiGraph.contract_nodes` and
/// :meth:`~.PyDiGraph.substitute_node_with_subgraph`.
///
/// :param list nodes: A list of nodes in this graph representing the subgraph
/// to be removed.
/// :param PyDiGraph other: The subgraph to replace ``nodes`` with
/// :param dict input_node_map: The mapping of node indices from ``nodes`` to a node
/// in ``subgraph``. This is used for incoming and outgoing edges into the removed
/// subgraph. This will replace any edges connected to a node in ``nodes`` with the
/// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this
/// mapping.
/// :param callable edge_weight_map: An optional callable object that when
/// used will receive an edge's weight/data payload from ``subgraph`` and
/// will return an object to use as the weight for a newly created edge
/// after the edge is mapped from ``other``. If not specified the weight
/// from the edge in ``other`` will be copied by reference and used.
///
/// :param bool cycle_check: To check and raise if the substitution would introduce a cycle.
/// If set to ``True`` or :attr:`.check_cycle` is set to ``True`` when a cycle would be
/// added a :class:`~.DAGWouldCycle` exception will be raised. However, in this case the
/// state of the graph will be partially modified through the internal steps required for the
/// substitution. If your intent is to detect and use the graph if a
/// cycle were to be detected, you should make a copy of the graph
/// (see :meth:`.copy`) prior to calling this method so you have a
/// copy of the input graph to use.
///
/// :returns: A mapping of node indices in ``other`` to the new node index in this graph
/// :rtype: NodeMap
///
/// :raises DAGWouldCycle: If ``cycle_check`` or the :attr:`.check_cycle` attribute are set to
/// ``True`` and a cycle would be introduced by the substitution.
#[pyo3(signature=(nodes, other, input_node_map, edge_weight_map=None, cycle_check=false))]
pub fn substitute_subgraph(
&mut self,
py: Python,
nodes: Vec<usize>,
other: &PyDiGraph,
input_node_map: HashMap<usize, usize>,
edge_weight_map: Option<PyObject>,
cycle_check: bool,
) -> PyResult<NodeMap> {
let mut in_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new();
let mut out_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new();
let mut node_map: IndexMap<usize, usize, ahash::RandomState> =
IndexMap::with_capacity_and_hasher(
other.graph.node_count(),
ahash::RandomState::default(),
);
let removed_nodes: HashSet<NodeIndex> = nodes.iter().map(|n| NodeIndex::new(*n)).collect();

let weight_map_fn = |obj: &PyObject, weight_fn: &Option<PyObject>| -> PyResult<PyObject> {
match weight_fn {
Some(weight_fn) => weight_fn.call1(py, (obj,)),
None => Ok(obj.clone_ref(py)),
}
};
for node in nodes {
let index = NodeIndex::new(node);
in_nodes.extend(
self.graph
.edges_directed(index, petgraph::Direction::Incoming)
.filter_map(|edge| {
if !removed_nodes.contains(&edge.source()) {
Some((edge.source(), edge.target(), edge.weight().clone_ref(py)))
} else {
None
}
}),
);
out_nodes.extend(
self.graph
.edges_directed(index, petgraph::Direction::Outgoing)
.filter_map(|edge| {
if !removed_nodes.contains(&edge.target()) {
Some((edge.source(), edge.target(), edge.weight().clone_ref(py)))
} else {
None
}
}),
);
self.graph.remove_node(index);
}
for node in other.graph.node_indices() {
let weight = other.graph.node_weight(node).unwrap();
let new_index = self.graph.add_node(weight.clone_ref(py));
node_map.insert(node.index(), new_index.index());
}
for edge in other.graph.edge_references() {
let new_source = node_map[edge.source().index()];
let new_target = node_map[edge.target().index()];
self._add_edge(
NodeIndex::new(new_source),
NodeIndex::new(new_target),
weight_map_fn(edge.weight(), &edge_weight_map)?,
cycle_check,
)?;
}
for edge in out_nodes {
let old_source = edge.0;
let new_source = match input_node_map.get(&old_source.index()) {
Some(new_source) => NodeIndex::new(node_map[new_source]),
None => {
let missing_index = old_source.index();
return Err(PyIndexError::new_err(format!(
"Input node {} not found in io_node_map",
missing_index
)));
}
};
self._add_edge(new_source, edge.1, edge.2, cycle_check)?;
}
for edge in in_nodes {
let old_target = edge.1;
let new_target = match input_node_map.get(&old_target.index()) {
Some(new_target) => NodeIndex::new(node_map[new_target]),
None => {
let missing_index = old_target.index();
return Err(PyIndexError::new_err(format!(
"Output node {} not found in io_node_map",
missing_index
)));
}
};
self._add_edge(edge.0, new_target, edge.2, cycle_check)?;
}
Ok(NodeMap { node_map })
}

/// Return a new PyDiGraph object for an edge induced subgraph of this graph
///
/// The induced subgraph contains each edge in `edge_list` and each node
Expand Down Expand Up @@ -2716,7 +2860,7 @@ impl PyDiGraph {
Some(callback) => callback.call1(py, (forward_weight,))?,
None => forward_weight.clone_ref(py),
};
self._add_edge(*edge_target, *edge_source, weight)?;
self._add_edge(*edge_target, *edge_source, weight, false)?;
}
}
Ok(())
Expand Down
Loading