Skip to content

Commit

Permalink
Fix implementation of tensor product for undirected graphs (#600)
Browse files Browse the repository at this point in the history
* fix implementation of tensor product for undirected graphs

* support for a multigraph

Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed May 3, 2022
1 parent 9101dbc commit 4724646
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 6 deletions.
23 changes: 20 additions & 3 deletions src/tensor_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,31 @@ fn tensor_product<Ty: EdgeType>(
)
.into_py(py),
);
}
}
if undirected {
for edge_first in first.edge_references() {
for edge_second in second.edge_references() {
if edge_first.source() == edge_first.target()
|| edge_second.source() == edge_second.target()
{
continue;
}

let source = hash_nodes
.get(&(edge_first.source(), edge_second.target()))
.unwrap();

let target = hash_nodes
.get(&(edge_first.target(), edge_second.source()))
.unwrap();

if undirected {
final_graph.add_edge(
*target,
*source,
*target,
(
edge_second.weight().clone_ref(py),
edge_first.weight().clone_ref(py),
edge_second.weight().clone_ref(py),
)
.into_py(py),
);
Expand Down
40 changes: 40 additions & 0 deletions tests/digraph/test_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,43 @@ def test_directed_edge_weights_tensor(self):

graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2)
self.assertEqual([("w_1", "w_2")], graph_product.edges())

def test_multi_graph_1(self):
graph_1 = retworkx.generators.directed_path_graph(2)
graph_1.add_edge(0, 1, None)
graph_2 = retworkx.generators.directed_path_graph(2)

graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 3)]
self.assertEqual(graph_product.num_edges(), 2)
self.assertEqual(graph_product.edge_list(), expected_edges)

def test_multi_graph_2(self):
graph_1 = retworkx.generators.directed_path_graph(2)
graph_1.add_edge(0, 0, None)
graph_2 = retworkx.generators.directed_path_graph(2)

graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 1)]
self.assertEqual(graph_product.num_edges(), 2)
self.assertEqual(graph_product.edge_list(), expected_edges)

def test_multi_graph_3(self):
graph_1 = retworkx.generators.directed_path_graph(2)
graph_2 = retworkx.generators.directed_path_graph(2)
graph_2.add_edge(0, 1, None)

graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 3)]
self.assertEqual(graph_product.num_edges(), 2)
self.assertEqual(graph_product.edge_list(), expected_edges)

def test_multi_graph_4(self):
graph_1 = retworkx.generators.directed_path_graph(2)
graph_2 = retworkx.generators.directed_path_graph(2)
graph_2.add_edge(0, 0, None)

graph_product, _ = retworkx.digraph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 2)]
self.assertEqual(graph_product.num_edges(), 2)
self.assertEqual(graph_product.edge_list(), expected_edges)
46 changes: 43 additions & 3 deletions tests/graph/test_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_path_2_tensor_path_2(self):
expected_node_map = {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}
self.assertEqual(node_map, expected_node_map)

expected_edges = [(0, 3), (3, 0)]
expected_edges = [(0, 3), (1, 2)]
self.assertEqual(graph_product.num_nodes(), 4)
self.assertEqual(graph_product.num_edges(), 2)
self.assertEqual(graph_product.edge_list(), expected_edges)
Expand All @@ -45,7 +45,7 @@ def test_path_2_tensor_path_3(self):
expected_node_map = {(0, 1): 1, (1, 0): 3, (0, 0): 0, (1, 2): 5, (0, 2): 2, (1, 1): 4}
self.assertEqual(dict(node_map), expected_node_map)

expected_edges = [(0, 4), (4, 0), (1, 5), (5, 1)]
expected_edges = [(0, 4), (1, 5), (1, 3), (2, 4)]
self.assertEqual(graph_product.num_nodes(), 6)
self.assertEqual(graph_product.num_edges(), 4)
self.assertEqual(graph_product.edge_list(), expected_edges)
Expand All @@ -69,4 +69,44 @@ def test_edge_weights_tensor(self):
graph_2.add_edge(0, 1, "w_2")

graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2)
self.assertEqual([("w_1", "w_2"), ("w_2", "w_1")], graph_product.edges())
self.assertEqual([("w_1", "w_2"), ("w_1", "w_2")], graph_product.edges())

def test_multi_graph_1(self):
graph_1 = retworkx.generators.path_graph(2)
graph_1.add_edge(0, 1, None)
graph_2 = retworkx.generators.path_graph(2)

graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 3), (1, 2), (1, 2)]
self.assertEqual(graph_product.num_edges(), 4)
self.assertEqual(graph_product.edge_list(), expected_edges)

def test_multi_graph_2(self):
graph_1 = retworkx.generators.path_graph(2)
graph_1.add_edge(0, 0, None)
graph_2 = retworkx.generators.path_graph(2)

graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 1), (1, 2)]
self.assertEqual(graph_product.num_edges(), 3)
self.assertEqual(graph_product.edge_list(), expected_edges)

def test_multi_graph_3(self):
graph_1 = retworkx.generators.path_graph(2)
graph_2 = retworkx.generators.path_graph(2)
graph_2.add_edge(0, 1, None)

graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 3), (1, 2), (1, 2)]
self.assertEqual(graph_product.num_edges(), 4)
self.assertEqual(graph_product.edge_list(), expected_edges)

def test_multi_graph_4(self):
graph_1 = retworkx.generators.path_graph(2)
graph_2 = retworkx.generators.path_graph(2)
graph_2.add_edge(0, 0, None)

graph_product, _ = retworkx.graph_tensor_product(graph_1, graph_2)
expected_edges = [(0, 3), (0, 2), (1, 2)]
self.assertEqual(graph_product.num_edges(), 3)
self.assertEqual(graph_product.edge_list(), expected_edges)

0 comments on commit 4724646

Please sign in to comment.