Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pickle/deepcopy not preserve original edge indices #589

Merged
merged 8 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
fixes:
- |
Fixed an issue when using ``copy.deepcopy()`` on :class:`~.PyDiGraph` and
:class:`~.PyGraph` objects when there were removed edges from the graph
object. Previously, if there were any holes in the edge indices caused by
the removal the output copy of the graph object would incorrectly have
flatten the indices. This has been corrected so that the edge indices are
recreated exactly after a ``deepcopy()``.
Fixed `#585 <https://github.com/Qiskit/rustworkx/issues/585>`__
231 changes: 165 additions & 66 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;

use petgraph::visit::{
GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable,
EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered,
Visitable,
};

Expand Down Expand Up @@ -298,97 +298,196 @@ impl PyDiGraph {
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let mut nodes: Vec<PyObject> = Vec::with_capacity(self.graph.node_count());
let mut edges: Vec<PyObject> = Vec::with_capacity(self.graph.edge_bound());

// save nodes to a list along with its index
for node_idx in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_idx).unwrap();
nodes.push((node_idx.index(), node_data).to_object(py));
}

// edges are saved with none (deleted edges) instead of their index to save space
for i in 0..self.graph.edge_bound() {
let idx = EdgeIndex::new(i);
let edge = match self.graph.edge_weight(idx) {
Some(edge_w) => {
let endpoints = self.graph.edge_endpoints(idx).unwrap();
(endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py)
}
None => py.None(),
};
edges.push(edge);
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
}

let out_dict = PyDict::new(py);
let node_dict = PyDict::new(py);
let mut out_list: Vec<PyObject> = Vec::with_capacity(self.graph.edge_count());
out_dict.set_item("nodes", node_dict)?;
let nodes_lst: PyObject = PyList::new(py, nodes).into();
let edges_lst: PyObject = PyList::new(py, edges).into();
out_dict.set_item("nodes", nodes_lst)?;
out_dict.set_item("edges", edges_lst)?;
out_dict.set_item("nodes_removed", self.node_removed)?;
out_dict.set_item("multigraph", self.multigraph)?;
out_dict.set_item("attrs", self.attrs.clone_ref(py))?;
out_dict.set_item("check_cycle", self.check_cycle)?;
let dir = petgraph::Direction::Incoming;
for node_index in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_index).unwrap();
node_dict.set_item(node_index.index(), node_data)?;
for edge in self.graph.edges_directed(node_index, dir) {
let edge_w = edge.weight();
let triplet = (edge.source().index(), edge.target().index(), edge_w).to_object(py);
out_list.push(triplet);
}
}
let py_out_list: PyObject = PyList::new(py, out_list).into();
out_dict.set_item("edges", py_out_list)?;
Ok(out_dict.into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let dict_state = state.downcast::<PyDict>(py)?;
let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::<PyList>()?;
let edges_lst = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
self.graph = StablePyGraph::<Directed>::new();
let dict_state = state.downcast::<PyDict>(py)?;

let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let nodes_removed_raw = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?;
self.node_removed = nodes_removed_raw.extract()?;
let multigraph_raw = dict_state
self.multigraph = dict_state
.get_item("multigraph")
.unwrap()
.downcast::<PyBool>()?;
self.multigraph = multigraph_raw.extract()?;
.downcast::<PyBool>()?
.extract()?;
self.node_removed = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?
.extract()?;
let attrs = match dict_state.get_item("attrs") {
Some(attr) => attr.into(),
None => py.None(),
};
self.attrs = attrs;
let check_cycle_raw = dict_state
self.check_cycle = dict_state
.get_item("check_cycle")
.unwrap()
.downcast::<PyBool>()?;
self.check_cycle = check_cycle_raw.extract()?;
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
node_indices.push(tmp_index.extract()?);
}
if node_indices.is_empty() {
.downcast::<PyBool>()?
.extract()?;

// graph is empty, stop early
if nodes_lst.is_empty() {
return Ok(());
}
let max_index: usize = *node_indices.iter().max().unwrap();
if max_index + 1 != node_indices.len() {
self.node_removed = true;
}
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {

if !self.node_removed {
for item in nodes_lst.iter() {
let node_w = item
.downcast::<PyTuple>()
.unwrap()
.get_item(1)
.unwrap()
.extract()
.unwrap();
self.graph.add_node(node_w);
}
} else if nodes_lst.len() == 1 {
// graph has only one node, handle logic here to save one if in the loop later
let item = nodes_lst
.get_item(0)
.unwrap()
.downcast::<PyTuple>()
.unwrap();
let node_idx: usize = item.get_item(0).unwrap().extract().unwrap();
let node_w = item.get_item(1).unwrap().extract().unwrap();

for _i in 0..node_idx {
self.graph.add_node(py.None());
}
self.graph.add_node(node_w);
for i in 0..node_idx {
self.graph.remove_node(NodeIndex::new(i));
}
} else {
let last_item = nodes_lst
.get_item(nodes_lst.len() - 1)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

// use a pointer to iter the node list
let mut pointer = 0;
let mut next_node_idx: usize = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();

// list of temporary nodes that will be removed later to re-create holes
let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap();
let mut tmp_nodes: Vec<NodeIndex> =
Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len());

for i in 0..nodes_lst.len() + 1 {
if i < next_node_idx {
// node does not exist
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
} else {
// add node to the graph, and update the next available node index
let item = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

let node_w = item.get_item(1).unwrap().extract().unwrap();
self.graph.add_node(node_w);
pointer += 1;
if pointer < nodes_lst.len() {
next_node_idx = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
}
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0)?.downcast::<PyLong>()?;
let p_index: usize = raw_p_index.extract()?;
let raw_c_index = edge.get_item(1)?.downcast::<PyLong>()?;
let c_index: usize = raw_c_index.extract()?;
let edge_data = edge.get_item(2)?;
self.graph.add_edge(
NodeIndex::new(p_index),
NodeIndex::new(c_index),
edge_data.into(),
);
}
// Remove any temporary nodes we added
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
}

// to ensure O(1) on edge deletion, use a temporary node to store missing edges
let tmp_node = self.graph.add_node(py.None());

for item in edges_lst {
if item.is_none() {
// add a temporary edge that will be deleted later to re-create the hole
self.graph.add_edge(tmp_node, tmp_node, py.None());
} else {
let triple = item.downcast::<PyTuple>().unwrap();
let edge_p: usize = triple
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_c: usize = triple
.get_item(1)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_w = triple.get_item(2).unwrap().extract().unwrap();
self.graph
.add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w);
}
}

// remove the temporary node will remove all deleted edges in bulk,
// the cost is equal to the number of edges
self.graph.remove_node(tmp_node);

Ok(())
}

Expand Down