Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent conversion to undirected graphs #301

Merged
merged 13 commits into from
Apr 16, 2023
Merged
30 changes: 30 additions & 0 deletions graphein/ml/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch
from loguru import logger as log
from torch_geometric.utils.undirected import to_undirected

from graphein.utils.dependencies import import_message

Expand Down Expand Up @@ -288,6 +289,12 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
data[key].append(value)

# Add edge features
edge_feature_names = list(G.edges(data=True))[0][2].keys()
edge_feature_names = list(
filter(
lambda x: x in self.columns and x != "kind", edge_feature_names
)
)
for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
for key, value in feat_dict.items():
key = str(key)
Expand Down Expand Up @@ -324,8 +331,31 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
log.warning(e)
pass

# Construct PyG data
data = Data.from_dict(data)
data.num_nodes = G.number_of_nodes()

# Symmetrize if undirected
if not G.is_directed():
# Edge index and edge features
edge_index, edge_features = to_undirected(
data.edge_index,
[getattr(data, attr) for attr in edge_feature_names],
data.num_nodes,
)
data.edge_index = edge_index
for attr, val in zip(edge_feature_names, edge_features):
setattr(data, attr, val)

# Edge indices of different kinds
for kind in set(kind_strs):
key = f"edge_index_{kind}"
if key in self.columns:
edge_index_kind = to_undirected(
getattr(data, key), num_nodes=data.num_nodes
)
setattr(data, key, edge_index_kind)

return data

@staticmethod
Expand Down