Skip to content

Commit

Permalink
fix: edge_index
Browse files Browse the repository at this point in the history
ref: #31
  • Loading branch information
aMahanna committed Feb 9, 2024
1 parent 77b7a81 commit 67378e8
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions adbpyg_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from math import ceil
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
from arango.cursor import Cursor
from arango.database import StandardDatabase
Expand Down Expand Up @@ -1032,24 +1033,32 @@ def __process_adb_edge_df(
]

# 5. Map each ArangoDB from/to _key to the corresponding PyG node id
from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist()
to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist()
# NOTE: map() is somehow converting int values to float...
# So we rely on astype(int) to convert the float back to int,
# but we also fill NaN values with -1 so that we can convert
# the entire column to int without any issues. Need to revisit...
from_n = et_df["from_key"].map(adb_map[from_col]).fillna(-1).astype(int)
to_n = et_df["to_key"].map(adb_map[to_col]).fillna(-1).astype(int)

# 6. Set/Update the PyG Edge Index
edge_data: EdgeStorage = data if is_homogeneous else data[edge_type]
existing_ei = edge_data.get("edge_index", tensor([]))
new_ei = tensor([from_nodes, to_nodes])
edge_data.edge_index = torch.cat((existing_ei, new_ei), dim=1)
empty_tensor = tensor([], dtype=torch.int64)
existing_edge_index = edge_data.get("edge_index", empty_tensor)
new_edge_index = tensor(
np.array([from_n.to_numpy(), to_n.to_numpy()]), dtype=torch.int64
)

edge_data.edge_index = cat((existing_edge_index, new_edge_index), dim=1)

# 7. Deal with invalid edges
if torch.any(torch.isnan(edge_data.edge_index)):
if torch.any(edge_data.edge_index == -1):
if strict:
m = f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501
raise InvalidADBEdgesError(m)
else:
# Remove the invalid edges
edge_data.edge_index = edge_data.edge_index[
:, ~torch.any(edge_data.edge_index.isnan(), dim=0)
:, ~torch.any(edge_data.edge_index == -1, dim=0)
]

# 8. Set the PyG Edge Data
Expand Down

0 comments on commit 67378e8

Please sign in to comment.