Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions to parse node link JSON #1091

Merged
merged 8 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api/serialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ Serialization

rustworkx.node_link_json
rustworkx.read_graphml
rustworkx.from_node_link_json_file
rustworkx.parse_node_link_json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added two new functions, :func:`.parse_node_link_json_file` and
:func:`.parse_node_link_json_str`, which are used to parse a node link json
and generate a graph object from it.
3 changes: 3 additions & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ from .rustworkx import NoSuitableNeighbors as NoSuitableNeighbors
from .rustworkx import NullGraph as NullGraph
from .rustworkx import NegativeCycle as NegativeCycle
from .rustworkx import JSONSerializationError as JSONSerializationError
from .rustworkx import JSONDeserializationError as JSONDeserializationError
from .rustworkx import FailedToConverge as FailedToConverge
from .rustworkx import InvalidMapping as InvalidMapping
from .rustworkx import GraphNotBipartite as GraphNotBipartite
Expand Down Expand Up @@ -140,6 +141,8 @@ from .rustworkx import directed_random_bipartite_graph as directed_random_bipart
from .rustworkx import read_graphml as read_graphml
from .rustworkx import digraph_node_link_json as digraph_node_link_json
from .rustworkx import graph_node_link_json as graph_node_link_json
from .rustworkx import from_node_link_json_file as from_node_link_json_file
from .rustworkx import parse_node_link_json as parse_node_link_json
from .rustworkx import digraph_bellman_ford_shortest_paths as digraph_bellman_ford_shortest_paths
from .rustworkx import graph_bellman_ford_shortest_paths as graph_bellman_ford_shortest_paths
from .rustworkx import (
Expand Down
13 changes: 13 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class NoSuitableNeighbors(Exception): ...
class NullGraph(Exception): ...
class NegativeCycle(Exception): ...
class JSONSerializationError(Exception): ...
class JSONDeserializationError(Exception): ...
class FailedToConverge(Exception): ...
class InvalidMapping(Exception): ...
class GraphNotBipartite(Exception): ...
Expand Down Expand Up @@ -640,6 +641,18 @@ def graph_node_link_json(
node_attrs: Callable[[_S], str] | None = ...,
edge_attrs: Callable[[_T], str] | None = ...,
) -> str | None: ...
def parse_node_link_json(
data: str,
graph_attrs: Callable[[dict[str, str]], Any] | None = ...,
node_attrs: Callable[[dict[str, str]], _S] | None = ...,
edge_attrs: Callable[[dict[str, str]], _T] | None = ...,
) -> PyDiGraph[_S, _T] | PyGraph[_S, _T]: ...
def from_node_link_json_file(
path: str,
graph_attrs: Callable[[dict[str, str]], Any] | None = ...,
node_attrs: Callable[[dict[str, str]], _S] | None = ...,
edge_attrs: Callable[[dict[str, str]], _T] | None = ...,
) -> PyDiGraph[_S, _T] | PyGraph[_S, _T]: ...

# Shortest Path

Expand Down
156 changes: 155 additions & 1 deletion src/json/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,165 @@

mod node_link_data;

use crate::{digraph, graph};
use std::fs::File;
use std::io::BufReader;

use crate::{digraph, graph, JSONDeserializationError, StablePyGraph};
use petgraph::{algo, Directed, Undirected};

use pyo3::prelude::*;
use pyo3::Python;

/// Parse a node-link format JSON file to generate a graph
///
/// :param path str: The path to the JSON file to load
/// :param graph_attrs: An optional callable that will be passed a dictionary
/// with string keys and string values and is expected to return a Python
/// object to use for :attr:`~.PyGraph.attrs` attribute of the output graph.
/// If not specified the dictionary with string keys and string values will
/// be used as the value for ``attrs``.
/// :param node_attrs: An optional callable that will be passed a dictionary with
/// string keys and string values representing the data payload
/// for each node in the graph and is expected to return a Python object to
/// use for the data payload of the node. If not specified the dictionary with
/// string keys and string values will be used for the nodes' data payload.
/// :param edge_attrs: An optional callable that will be passed a dictionary with
/// string keys and string values representing the data payload
/// for each edge in the graph and is expected to return a Python object to
/// use for the data payload of the node. If not specified the dictionary with
/// string keys and string values will be used for the edge' data payload.
///
/// :returns: The graph represented by the node link JSON
/// :rtype: PyGraph | PyDiGraph
#[pyfunction]
pub fn from_node_link_json_file(
py: Python,
path: &str,
graph_attrs: Option<PyObject>,
node_attrs: Option<PyObject>,
edge_attrs: Option<PyObject>,
) -> PyResult<PyObject> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let graph: node_link_data::GraphInput = match serde_json::from_reader(reader) {
Ok(v) => v,
Err(e) => {
return Err(JSONDeserializationError::new_err(format!(
"JSON Deserialization Error {}",
e
)));
}
};
let attrs: PyObject = match graph.attrs {
Some(ref attrs) => match graph_attrs {
Some(ref callback) => callback.call1(py, (attrs.clone(),))?,
None => attrs.to_object(py),
},
None => py.None(),
};
let multigraph = graph.multigraph;

Ok(if graph.directed {
let mut inner_graph: StablePyGraph<Directed> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
digraph::PyDiGraph {
graph: inner_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph,
attrs,
}
.into_py(py)
} else {
let mut inner_graph: StablePyGraph<Undirected> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;

graph::PyGraph {
graph: inner_graph,
node_removed: false,
multigraph,
attrs,
}
.into_py(py)
})
}

/// Parse a node-link format JSON str to generate a graph
///
/// :param data str: The JSON str to parse
/// :param graph_attrs: An optional callable that will be passed a dictionary
/// with string keys and string values and is expected to return a Python
/// object to use for :attr:`~.PyGraph.attrs` attribute of the output graph.
/// If not specified the dictionary with string keys and string values will
/// be used as the value for ``attrs``.
/// :param node_attrs: An optional callable that will be passed a dictionary with
/// string keys and string values representing the data payload
/// for each node in the graph and is expected to return a Python object to
/// use for the data payload of the node. If not specified the dictionary with
/// string keys and string values will be used for the nodes' data payload.
/// :param edge_attrs: An optional callable that will be passed a dictionary with
/// string keys and string values representing the data payload
/// for each edge in the graph and is expected to return a Python object to
/// use for the data payload of the node. If not specified the dictionary with
/// string keys and string values will be used for the edge' data payload.
///
/// :returns: The graph represented by the node link JSON
/// :rtype: PyGraph | PyDiGraph
#[pyfunction]
pub fn parse_node_link_json(
py: Python,
data: &str,
graph_attrs: Option<PyObject>,
node_attrs: Option<PyObject>,
edge_attrs: Option<PyObject>,
) -> PyResult<PyObject> {
let graph: node_link_data::GraphInput = match serde_json::from_str(data) {
Ok(v) => v,
Err(e) => {
return Err(JSONDeserializationError::new_err(format!(
"JSON Deserialization Error {}",
e
)));
}
};
let attrs: PyObject = match graph.attrs {
Some(ref attrs) => match graph_attrs {
Some(ref callback) => callback.call1(py, (attrs.clone(),))?,
None => attrs.to_object(py),
},
None => py.None(),
};
let multigraph = graph.multigraph;
Ok(if graph.directed {
let mut inner_graph: StablePyGraph<Directed> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
digraph::PyDiGraph {
graph: inner_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph,
attrs,
}
.into_py(py)
} else {
let mut inner_graph: StablePyGraph<Undirected> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
graph::PyGraph {
graph: inner_graph,
node_removed: false,
multigraph,
attrs,
}
.into_py(py)
})
}

/// Generate a JSON object representing a :class:`~.PyDiGraph` in a node-link format
///
/// :param PyDiGraph graph: The graph to generate the JSON for
Expand Down
85 changes: 74 additions & 11 deletions src/json/node_link_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
use std::collections::BTreeMap;
use std::fs::File;

use hashbrown::HashMap;

use serde::{Deserialize, Serialize};

use pyo3::prelude::*;
Expand All @@ -23,31 +25,92 @@ use petgraph::visit::IntoEdgeReferences;
use petgraph::EdgeType;

use crate::JSONSerializationError;
use crate::NodeIndex;
use crate::StablePyGraph;

#[derive(Serialize, Deserialize)]
struct Graph {
directed: bool,
multigraph: bool,
attrs: Option<BTreeMap<String, String>>,
nodes: Vec<Node>,
links: Vec<Link>,
#[derive(Serialize)]
pub struct Graph {
pub directed: bool,
pub multigraph: bool,
pub attrs: Option<BTreeMap<String, String>>,
pub nodes: Vec<Node>,
pub links: Vec<Link>,
}

#[derive(Deserialize)]
pub struct GraphInput {
pub directed: bool,
pub multigraph: bool,
pub attrs: Option<BTreeMap<String, String>>,
pub nodes: Vec<NodeInput>,
pub links: Vec<LinkInput>,
}

#[derive(Serialize, Deserialize)]
struct Node {
#[derive(Serialize)]
pub struct Node {
id: usize,
data: Option<BTreeMap<String, String>>,
}

#[derive(Serialize, Deserialize)]
struct Link {
#[derive(Deserialize)]
pub struct NodeInput {
id: Option<usize>,
data: Option<BTreeMap<String, String>>,
}

#[derive(Deserialize)]
pub struct LinkInput {
source: usize,
target: usize,
#[allow(dead_code)]
id: Option<usize>,
data: Option<BTreeMap<String, String>>,
}

#[derive(Serialize)]
pub struct Link {
source: usize,
target: usize,
id: usize,
data: Option<BTreeMap<String, String>>,
}

#[allow(clippy::too_many_arguments)]
pub fn parse_node_link_data<Ty: EdgeType>(
py: &Python,
graph: GraphInput,
out_graph: &mut StablePyGraph<Ty>,
node_attrs: Option<PyObject>,
edge_attrs: Option<PyObject>,
) -> PyResult<()> {
let mut id_mapping: HashMap<usize, NodeIndex> = HashMap::with_capacity(graph.nodes.len());
for node in graph.nodes {
let payload = match node.data {
Some(data) => match node_attrs {
Some(ref callback) => callback.call1(*py, (data,))?,
None => data.to_object(*py),
},
None => py.None(),
};
let id = out_graph.add_node(payload);
match node.id {
Some(input_id) => id_mapping.insert(input_id, id),
None => id_mapping.insert(id.index(), id),
};
}
for edge in graph.links {
let data = match edge.data {
Some(data) => match edge_attrs {
Some(ref callback) => callback.call1(*py, (data,))?,
None => data.to_object(*py),
},
None => py.None(),
};
out_graph.add_edge(id_mapping[&edge.source], id_mapping[&edge.target], data);
}
Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn node_link_data<Ty: EdgeType>(
py: Python,
Expand Down
8 changes: 8 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ import_exception!(rustworkx.visit, PruneSearch);
import_exception!(rustworkx.visit, StopSearch);
// JSON Error
create_exception!(rustworkx, JSONSerializationError, PyException);
// JSON Error
create_exception!(rustworkx, JSONDeserializationError, PyException);
// Negative Cycle found on shortest-path algorithm
create_exception!(rustworkx, NegativeCycle, PyException);
// Failed to Converge on a solution
Expand Down Expand Up @@ -408,6 +410,10 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
"GraphNotBipartite",
py.get_type_bound::<GraphNotBipartite>(),
)?;
m.add(
"JSONDeserializationError",
py.get_type_bound::<JSONDeserializationError>(),
)?;
m.add_wrapped(wrap_pyfunction!(bfs_successors))?;
m.add_wrapped(wrap_pyfunction!(bfs_predecessors))?;
m.add_wrapped(wrap_pyfunction!(graph_bfs_search))?;
Expand Down Expand Up @@ -585,6 +591,8 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(read_graphml))?;
m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?;
m.add_wrapped(wrap_pyfunction!(parse_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(pagerank))?;
m.add_wrapped(wrap_pyfunction!(hits))?;
m.add_class::<digraph::PyDiGraph>()?;
Expand Down
Loading