From 29802409a8df10353a2b49335dfaf37a4383bfe1 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 8 Feb 2024 17:30:32 -0500 Subject: [PATCH 1/4] add `edge_index.dtype` assertions (tests should fail) --- tests/test_adapter.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 3042a8d..c1e4cfc 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Set, Union import pytest +import torch from pandas import DataFrame from torch import Tensor, cat, long, tensor from torch_geometric.data import Data, HeteroData @@ -554,6 +555,7 @@ def test_adb_partial_to_pyg() -> None: assert type(pyg_g_new) is Data assert pyg_g["v0"].x.tolist() == pyg_g_new.x.tolist() assert pyg_g["v0"].y.tolist() == pyg_g_new.y.tolist() + assert pyg_g[e_t].edge_index.dtype == torch.int64 assert pyg_g[e_t].edge_index.tolist() == pyg_g_new.edge_index.tolist() assert pyg_g[e_t].edge_attr.tolist() == pyg_g_new.edge_attr.tolist() @@ -714,13 +716,17 @@ def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_permissive( graph = db.graph(name) v_cols: Set[str] = graph.vertex_collections() + assert len(v_cols) == 1 edge_definitions: List[Json] = graph.edge_definitions() e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions} + assert len(e_cols) == 1 for v_col in v_cols: vertex_collection = db.collection(v_col) vertex_collection.delete("0") + number_of_missing_edges = 32 # (i.e node 0 has 32 edges) + metagraph: ADBMetagraph = { "vertexCollections": {col: {} for col in v_cols}, "edgeCollections": {col: {} for col in e_cols}, @@ -729,7 +735,8 @@ def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_permissive( data = adapter.arangodb_to_pyg(name, metagraph=metagraph, strict=False) collection_count: int = db.collection(list(e_cols)[0]).count() - assert len(data.edge_index[0]) < collection_count + assert data.edge_index.dtype == 
torch.int64 + assert data.num_edges + number_of_missing_edges == collection_count db.delete_graph(name, drop_collections=True) @@ -1076,6 +1083,7 @@ def assert_adb_to_pyg( from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() + assert edge_data.edge_index.dtype == torch.int64 assert from_nodes == edge_data.edge_index[0].tolist() assert to_nodes == edge_data.edge_index[1].tolist() From 480b259a5f1cac3762e1ceb8a988d4668357b337 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Thu, 8 Feb 2024 17:42:16 -0500 Subject: [PATCH 2/4] fix: `edge_index.dtype` to `int` --- adbpyg_adapter/adapter.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/adbpyg_adapter/adapter.py b/adbpyg_adapter/adapter.py index 1c86515..3f6eebc 100644 --- a/adbpyg_adapter/adapter.py +++ b/adbpyg_adapter/adapter.py @@ -904,24 +904,28 @@ def __process_adb_edge_df( ] # 5. Map each ArangoDB from/to _key to the corresponding PyG node id - from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() - to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() + # NOTE: map() returns NaN for any _key missing from adb_map, which + # appears to be why int node ids come back as float (TODO confirm). + # So we fill NaN values with -1 and rely on astype(int) to convert + # the entire column back to int without any issues. Need to revisit... + from_n = et_df["from_key"].map(adb_map[from_col]).fillna(-1).astype(int) + to_n = et_df["to_key"].map(adb_map[to_col]).fillna(-1).astype(int) # 6. 
Set/Update the PyG Edge Index - edge_index = tensor([from_nodes, to_nodes]) - edge_data.edge_index = torch.cat( - (edge_data.get("edge_index", tensor([])), edge_index), dim=1 - ) + edge_index = tensor([from_n.tolist(), to_n.tolist()], dtype=torch.int64) + empty_tensor = torch.tensor([], dtype=torch.int64) + existing_edge_index = edge_data.get("edge_index", empty_tensor) + edge_data.edge_index = torch.cat((existing_edge_index, edge_index), dim=1) # 7. Deal with invalid edges - if torch.any(torch.isnan(edge_data.edge_index)): + if torch.any(edge_data.edge_index == -1): if strict: m = f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501 raise InvalidADBEdgesError(m) else: # Remove the invalid edges edge_data.edge_index = edge_data.edge_index[ - :, ~torch.any(edge_data.edge_index.isnan(), dim=0) + :, ~torch.any(edge_data.edge_index == -1, dim=0) ] # 8. Set the PyG Edge Data From 037fa4d332dca5fc686fcf2f0b691e9fb3aa693a Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Fri, 9 Feb 2024 09:09:40 -0500 Subject: [PATCH 3/4] fix: `edge_index` creation --- adbpyg_adapter/adapter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/adbpyg_adapter/adapter.py b/adbpyg_adapter/adapter.py index 3f6eebc..1ea90d5 100644 --- a/adbpyg_adapter/adapter.py +++ b/adbpyg_adapter/adapter.py @@ -5,6 +5,7 @@ from math import ceil from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union +import numpy as np import torch from arango.cursor import Cursor from arango.database import StandardDatabase @@ -912,7 +913,10 @@ def __process_adb_edge_df( to_n = et_df["to_key"].map(adb_map[to_col]).fillna(-1).astype(int) # 6. 
Set/Update the PyG Edge Index - edge_index = tensor([from_n.tolist(), to_n.tolist()], dtype=torch.int64) + edge_index = tensor( + np.array([from_n.to_numpy(), to_n.to_numpy()]), dtype=torch.int64 + ) + empty_tensor = torch.tensor([], dtype=torch.int64) existing_edge_index = edge_data.get("edge_index", empty_tensor) edge_data.edge_index = torch.cat((existing_edge_index, edge_index), dim=1) From c1a846b63f40139ae1380563fa4ea6e4f61ea5e8 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Fri, 9 Feb 2024 09:12:20 -0500 Subject: [PATCH 4/4] cleanup --- adbpyg_adapter/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/adbpyg_adapter/adapter.py b/adbpyg_adapter/adapter.py index 1ea90d5..b8b9dcb 100644 --- a/adbpyg_adapter/adapter.py +++ b/adbpyg_adapter/adapter.py @@ -917,9 +917,9 @@ def __process_adb_edge_df( np.array([from_n.to_numpy(), to_n.to_numpy()]), dtype=torch.int64 ) - empty_tensor = torch.tensor([], dtype=torch.int64) + empty_tensor = tensor([], dtype=torch.int64) existing_edge_index = edge_data.get("edge_index", empty_tensor) - edge_data.edge_index = torch.cat((existing_edge_index, edge_index), dim=1) + edge_data.edge_index = cat((existing_edge_index, edge_index), dim=1) # 7. Deal with invalid edges if torch.any(edge_data.edge_index == -1):