Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move collect_bicolor_runs() to rustworkx-core #1186

Merged
merged 18 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions rustworkx-core/src/collect_bicolor_runs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use std::cmp::Eq;
use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

use petgraph::algo;
use petgraph::data::DataMap;
use petgraph::visit::Data;
use petgraph::visit::{
EdgeRef, GraphBase, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
NodeIndexable, Visitable,
};

/// Define custom error classes for collect_bicolor_runs
// TODO: clean up once the code compiles
#[derive(Debug)]
pub enum CollectBicolorError<E: Error> {
DAGWouldCycle,
CallableError(E), //placeholder, may remove if not used
}

impl<E: Error> Display for CollectBicolorError<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
CollectBicolorError::DAGWouldCycle => fmt_dag_would_cycle(f),
CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e),
}
}
}

impl<E: Error> Error for CollectBicolorError<E> {}

fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "The operation would introduce a cycle.")
}

fn fmt_callable_error<E: Error>(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result {
write!(f, "The function failed with: {:?}", inner)
}
/// Collect runs that match a filter function given edge colors
///
/// A bicolor run is a list of group of nodes connected by edges of exactly
/// two colors. In addition, all nodes in the group must match the given
/// condition. Each node in the graph can appear in only a single group
/// in the bicolor run.
///
/// :param PyDiGraph graph: The graph to find runs in
/// :param filter_fn: The filter function to use for matching nodes. It takes
/// in one argument, the node data payload/weight object, and will return a
/// boolean whether the node matches the conditions or not.
/// If it returns ``True``, it will continue the bicolor chain.
/// If it returns ``False``, it will stop the bicolor chain.
/// If it returns ``None`` it will skip that node.
/// :param color_fn: The function that gives the color of the edge. It takes
/// in one argument, the edge data payload/weight object, and will
/// return a non-negative integer, the edge color. If the color is None,
/// the edge is ignored.
///
/// :returns: a list of groups with exactly two edge colors, where each group
/// is a list of node data payload/weight for the nodes in the bicolor run
/// :rtype: list
pub fn collect_bicolor_runs<G, F, C, B, E>(
graph: G,
filter_fn: F,
color_fn: C,
) -> Result<Vec<Vec<G::NodeId>>, CollectBicolorError<E>>
//OG type: PyResult<Vec<Vec<PyObject>>>
where
E: Error,
// add Option to input type because of line 135
F: Fn(&Option<&<G as Data>::NodeWeight>) -> Result<Option<bool>, CollectBicolorError<E>>, //OG input: &PyObject, OG return: PyResult<Option<bool>>
C: Fn(&<G as Data>::EdgeWeight) -> Result<Option<usize>, CollectBicolorError<E>>, //OG input: &PyObject, OG return: PyResult<Option<usize>>
G: NodeIndexable
// can take node index type and convert to usize. It restricts node index type.
+ IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes
+ IntoNeighborsDirected // used in toposort
+ IntoEdgesDirected // used in line 138
+ Visitable // used in toposort
+ DataMap, // used to access node weights
<G as GraphBase>::NodeId: Eq + Hash,
{
let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new(); //OG type: Vec<Vec<PyObject>>
let mut block_id: Vec<Option<usize>> = Vec::new(); //OG type: Vec<Option<usize>>
let mut block_list: Vec<Vec<G::NodeId>> = Vec::new(); //OG type: Vec<Vec<PyObject>> -> return

let filter_node =
|node: &Option<&<G as Data>::NodeWeight>| -> Result<Option<bool>, CollectBicolorError<E>> {
filter_fn(node)
};

let color_edge =
|edge: &<G as Data>::EdgeWeight| -> Result<Option<usize>, CollectBicolorError<E>> {
color_fn(edge)
};

let nodes = match algo::toposort(&graph, None) {
Ok(nodes) => nodes,
Err(_err) => return Err(CollectBicolorError::DAGWouldCycle),
};

// Utility for ensuring pending_list has the color index
macro_rules! ensure_vector_has_index {
($pending_list: expr, $block_id: expr, $color: expr) => {
if $color >= $pending_list.len() {
$pending_list.resize($color + 1, Vec::new());
$block_id.resize($color + 1, None);
}
};
}

for node in nodes {
if let Some(is_match) = filter_node(&graph.node_weight(node))? {
let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing);

// Remove all edges that do not yield errors from color_fn
let colors = raw_edges
.map(|edge| {
let edge_weight = edge.weight();
color_edge(edge_weight)
})
.collect::<Result<Vec<Option<usize>>, _>>()?;

// Remove null edges from color_fn
let colors = colors.into_iter().flatten().collect::<Vec<usize>>();

// &NodeIndexable::from_index(&graph, node)
if colors.len() <= 2 && is_match {
if colors.len() == 1 {
let c0 = colors[0];
ensure_vector_has_index!(pending_list, block_id, c0);
if let Some(c0_block_id) = block_id[c0] {
block_list[c0_block_id].push(node);
} else {
pending_list[c0].push(node);
}
} else if colors.len() == 2 {
let c0 = colors[0];
let c1 = colors[1];
ensure_vector_has_index!(pending_list, block_id, c0);
ensure_vector_has_index!(pending_list, block_id, c1);

if block_id[c0].is_some()
&& block_id[c1].is_some()
&& block_id[c0] == block_id[c1]
{
block_list[block_id[c0].unwrap_or_default()].push(node);
} else {
let mut new_block: Vec<G::NodeId> =
Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1);

// Clears pending lits and add to new block
new_block.append(&mut pending_list[c0]);
new_block.append(&mut pending_list[c1]);

new_block.push(node);

// Create new block, assign its id to color pair
block_id[c0] = Some(block_list.len());
block_id[c1] = Some(block_list.len());
block_list.push(new_block);
}
}
} else {
for color in colors {
ensure_vector_has_index!(pending_list, block_id, color);
if let Some(color_block_id) = block_id[color] {
block_list[color_block_id].append(&mut pending_list[color]);
}
block_id[color] = None;
pending_list[color].clear();
}
}
}
}

Ok(block_list)
}
1 change: 1 addition & 0 deletions rustworkx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub type Result<T, E = Infallible> = core::result::Result<T, E>;
pub mod bipartite_coloring;
/// Module for centrality algorithms.
pub mod centrality;
pub mod collect_bicolor_runs;
ElePT marked this conversation as resolved.
Show resolved Hide resolved
/// Module for coloring algorithms.
pub mod coloring;
pub mod connectivity;
Expand Down
Loading