Consistent conversion to undirected graphs (#301)
* Fix `convert_nx_to_pyg` to return undirected graph

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix symmetrization of edges of different kinds

* Clean

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix case when `edge_index` is not desired

* Test directed/undirected conversion consistency

* Update contributors

* Update CHANGELOG.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
anton-bushuiev and pre-commit-ci[bot] committed Apr 16, 2023
1 parent 649a490 commit 86313b4
Showing 4 changed files with 40 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@
* [ESM] - [#284](https://github.com/a-r-j/graphein/pull/284) - Wrapper for ESMFold batch folding & embedding.
* [Downloads] MMTF downloading now supported in download utilities. [#272](https://github.com/a-r-j/graphein/pull/272)

#### Improvements
* [Bugfix] - [#301](https://github.com/a-r-j/graphein/pull/301) Fixes the conversion of undirected NetworkX graph to directed PyG data.

#### API Changes
* The `pdb_path` argument to many functions (e.g. `graphein.protein.graphs.construct_graph`) has been renamed to `path` as this can now accept MMTF files in addition to PDB files.
* `Protein` tensors have coordinates renamed from `Protein.x` to `Protein.coords`. [#272](https://github.com/a-r-j/graphein/pull/272)
2 changes: 2 additions & 0 deletions docs/source/contributing/CONTRIBUTORS.md
@@ -27,3 +27,5 @@ Rico Meinl: [@ricomnl](https://github.com/ricomnl)
Alex Morehead [@amorehead](https://github.com/amorehead)

Aviv Korman [@avivko](https://github.com/avivko)

Anton Bushuiev [@anton-bushuiev](https://github.com/anton-bushuiev)
31 changes: 31 additions & 0 deletions graphein/ml/conversion.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
from loguru import logger as log
from torch_geometric.utils.undirected import to_undirected

from graphein.utils.dependencies import import_message

@@ -288,6 +289,12 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
                    data[key].append(value)

        # Add edge features
        edge_feature_names = list(G.edges(data=True))[0][2].keys()
        edge_feature_names = list(
            filter(
                lambda x: x in self.columns and x != "kind", edge_feature_names
            )
        )
        for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
            for key, value in feat_dict.items():
                key = str(key)
@@ -324,8 +331,32 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
                log.warning(e)
                pass

        # Construct PyG data
        data = Data.from_dict(data)
        data.num_nodes = G.number_of_nodes()

        # Symmetrize if undirected
        if not G.is_directed():
            # Edge index and edge features
            edge_index, edge_features = to_undirected(
                edge_index,
                [getattr(data, attr) for attr in edge_feature_names],
                data.num_nodes,
            )
            if "edge_index" in self.columns:
                data.edge_index = edge_index
            for attr, val in zip(edge_feature_names, edge_features):
                setattr(data, attr, val)

            # Edge indices of different kinds
            for kind in set(kind_strs):
                key = f"edge_index_{kind}"
                if key in self.columns:
                    edge_index_kind = to_undirected(
                        getattr(data, key), num_nodes=data.num_nodes
                    )
                    setattr(data, key, edge_index_kind)

        return data

    @staticmethod
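
The symmetrization step in the hunk above relies on `torch_geometric.utils.to_undirected`. As a rough, dependency-free illustration of the idea (not PyG's implementation), the sketch below adds the reverse of every edge and carries one edge feature along with it. The helper name `symmetrize_edges` and the toy `distance` values are hypothetical and do not come from the commit.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def symmetrize_edges(
    edges: List[Edge], features: List[float]
) -> Tuple[List[Edge], List[float]]:
    """Return the edge list with every reverse edge added, plus aligned features.

    Hypothetical helper for illustration only. If an edge already appears in
    both directions, the first feature seen for each direction is kept; PyG's
    `to_undirected` handles such duplicates through its own reduction instead.
    """
    merged: Dict[Edge, float] = {}
    for (src, dst), feat in zip(edges, features):
        merged.setdefault((src, dst), feat)  # keep the original direction
        merged.setdefault((dst, src), feat)  # add the mirrored edge
    ordered = sorted(merged)  # sort by (source, target) for a stable layout
    return ordered, [merged[edge] for edge in ordered]


# Toy input: each undirected edge is stored once, in one direction only.
edges = [(0, 1), (1, 2)]
distance = [1.5, 2.0]

sym_edges, sym_distance = symmetrize_edges(edges, distance)
print(sym_edges)     # every edge now appears in both directions
print(sym_distance)  # each mirrored edge reuses the feature of its original
```

In the commit itself, the same effect is obtained by passing the edge index together with a list of edge feature tensors to `to_undirected`, and then repeating the call for each `edge_index_{kind}` attribute.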
4 changes: 4 additions & 0 deletions tests/ml/test_conversion.py
@@ -12,6 +12,7 @@

try:
    import torch_geometric
    from torch_geometric.utils import is_undirected

    PYG_AVAIL = True
except ImportError:
@@ -82,3 +83,6 @@ def test_nx_to_pyg(pdb_code):
        data.edge_index.shape[1]
        == data.edge_index_inter.shape[1] + data.edge_index_intra.shape[1]
    )

    # Directed/undirected consistency
    assert g.is_directed() is not is_undirected(data.edge_index)
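
The added assertion compares two booleans with `is not`, so it passes exactly when the source graph is undirected and the converted edge index is symmetric, or when the source graph is directed and the edge index is not symmetric. The sketch below reproduces that logic on plain Python edge lists. `edge_list_is_undirected` and `conversion_is_consistent` are hypothetical stand-ins, not functions from graphein or PyG.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def edge_list_is_undirected(edges: List[Edge]) -> bool:
    """True if every edge (src, dst) is accompanied by its reverse (dst, src)."""
    edge_set = set(edges)
    return all((dst, src) in edge_set for src, dst in edge_set)


def conversion_is_consistent(graph_is_directed: bool, edges: List[Edge]) -> bool:
    """Mirror of the test's check: the two flags must disagree."""
    return graph_is_directed is not edge_list_is_undirected(edges)


one_way = [(0, 1), (1, 2)]                  # each edge stored in one direction
two_way = [(0, 1), (1, 0), (1, 2), (2, 1)]  # symmetrized edge list

# Undirected source graph: only the symmetrized edge list is consistent.
print(conversion_is_consistent(False, two_way))
print(conversion_is_consistent(False, one_way))

# Directed source graph: the one-way edge list is the consistent one.
print(conversion_is_consistent(True, one_way))
```

One consequence of this form of the check, under the assumptions of the sketch, is that a directed graph whose edges all happen to come in reciprocal pairs would also be flagged as inconsistent.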
