From 4724646949ed152d1475961e4f77ad782e01bef9 Mon Sep 17 00:00:00 2001 From: derbuihan Date: Wed, 4 May 2022 06:00:39 +0900 Subject: [PATCH] Fix implementation of tensor product for undirected graphs (#600) * fix implementation of tensor product for undirected graphs * support for a multigraph Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- src/tensor_product.rs | 23 ++++++++++++-- tests/digraph/test_tensor_product.py | 40 ++++++++++++++++++++++++ tests/graph/test_tensor_product.py | 46 ++++++++++++++++++++++++++-- 3 files changed, 103 insertions(+), 6 deletions(-) diff --git a/src/tensor_product.rs b/src/tensor_product.rs index 6f24c6b3d..527a1f665 100644 --- a/src/tensor_product.rs +++ b/src/tensor_product.rs @@ -65,14 +65,31 @@ fn tensor_product( ) .into_py(py), ); + } + } + if undirected { + for edge_first in first.edge_references() { + for edge_second in second.edge_references() { + if edge_first.source() == edge_first.target() + || edge_second.source() == edge_second.target() + { + continue; + } + + let source = hash_nodes + .get(&(edge_first.source(), edge_second.target())) + .unwrap(); + + let target = hash_nodes + .get(&(edge_first.target(), edge_second.source())) + .unwrap(); - if undirected { final_graph.add_edge( - *target, *source, + *target, ( - edge_second.weight().clone_ref(py), edge_first.weight().clone_ref(py), + edge_second.weight().clone_ref(py), ) .into_py(py), ); diff --git a/tests/digraph/test_tensor_product.py b/tests/digraph/test_tensor_product.py index 89012123f..8e38674c6 100644 --- a/tests/digraph/test_tensor_product.py +++ b/tests/digraph/test_tensor_product.py @@ -68,3 +68,43 @@ def test_directed_edge_weights_tensor(self): graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2) self.assertEqual([("w_1", "w_2")], graph_product.edges()) + + def test_multi_graph_1(self): + graph_1 = retworkx.generators.directed_path_graph(2) + graph_1.add_edge(0, 1, None) + graph_2 = retworkx.generators.directed_path_graph(2) + + graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 3)] + self.assertEqual(graph_product.num_edges(), 2) + self.assertEqual(graph_product.edge_list(), expected_edges) + + def test_multi_graph_2(self): + graph_1 = retworkx.generators.directed_path_graph(2) + graph_1.add_edge(0, 0, None) + graph_2 = retworkx.generators.directed_path_graph(2) + + graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 1)] + self.assertEqual(graph_product.num_edges(), 2) + self.assertEqual(graph_product.edge_list(), expected_edges) + + def test_multi_graph_3(self): + graph_1 = retworkx.generators.directed_path_graph(2) + graph_2 = retworkx.generators.directed_path_graph(2) + graph_2.add_edge(0, 1, None) + + graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 3)] + self.assertEqual(graph_product.num_edges(), 2) + self.assertEqual(graph_product.edge_list(), expected_edges) + + def test_multi_graph_4(self): + graph_1 = retworkx.generators.directed_path_graph(2) + graph_2 = retworkx.generators.directed_path_graph(2) + graph_2.add_edge(0, 0, None) + + graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 2)] + self.assertEqual(graph_product.num_edges(), 2) + self.assertEqual(graph_product.edge_list(), expected_edges) diff --git a/tests/graph/test_tensor_product.py b/tests/graph/test_tensor_product.py index b8775c1ad..9d282b0bb 100644 --- a/tests/graph/test_tensor_product.py +++ b/tests/graph/test_tensor_product.py @@ -32,7 +32,7 @@ def test_path_2_tensor_path_2(self): expected_node_map = {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3} self.assertEqual(node_map, expected_node_map) - expected_edges = [(0, 3), (3, 0)] + expected_edges = [(0, 3), (1, 2)] self.assertEqual(graph_product.num_nodes(), 4) self.assertEqual(graph_product.num_edges(), 2) self.assertEqual(graph_product.edge_list(), expected_edges) @@ -45,7 +45,7 @@ def test_path_2_tensor_path_3(self): expected_node_map = {(0, 1): 1, (1, 0): 3, (0, 0): 0, (1, 2): 5, (0, 2): 2, (1, 1): 4} self.assertEqual(dict(node_map), expected_node_map) - expected_edges = [(0, 4), (4, 0), (1, 5), (5, 1)] + expected_edges = [(0, 4), (1, 5), (1, 3), (2, 4)] self.assertEqual(graph_product.num_nodes(), 6) self.assertEqual(graph_product.num_edges(), 4) self.assertEqual(graph_product.edge_list(), expected_edges) @@ -69,4 +69,44 @@ def test_edge_weights_tensor(self): graph_2.add_edge(0, 1, "w_2") graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2) - self.assertEqual([("w_1", "w_2"), ("w_2", "w_1")], graph_product.edges()) + self.assertEqual([("w_1", "w_2"), ("w_1", "w_2")], graph_product.edges()) + + def test_multi_graph_1(self): + graph_1 = retworkx.generators.path_graph(2) + graph_1.add_edge(0, 1, None) + graph_2 = retworkx.generators.path_graph(2) + + graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 3), (1, 2), (1, 2)] + self.assertEqual(graph_product.num_edges(), 4) + self.assertEqual(graph_product.edge_list(), expected_edges) + + def test_multi_graph_2(self): + graph_1 = retworkx.generators.path_graph(2) + graph_1.add_edge(0, 0, None) + graph_2 = retworkx.generators.path_graph(2) + + graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 1), (1, 2)] + self.assertEqual(graph_product.num_edges(), 3) + self.assertEqual(graph_product.edge_list(), expected_edges) + + def test_multi_graph_3(self): + graph_1 = retworkx.generators.path_graph(2) + graph_2 = retworkx.generators.path_graph(2) + graph_2.add_edge(0, 1, None) + + graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 3), (1, 2), (1, 2)] + self.assertEqual(graph_product.num_edges(), 4) + self.assertEqual(graph_product.edge_list(), expected_edges) + + def test_multi_graph_4(self): + graph_1 = retworkx.generators.path_graph(2) + graph_2 = retworkx.generators.path_graph(2) + graph_2.add_edge(0, 0, None) + + graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2) + expected_edges = [(0, 3), (0, 2), (1, 2)] + self.assertEqual(graph_product.num_edges(), 3) + self.assertEqual(graph_product.edge_list(), expected_edges)