Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PyG 2.4+ #350

Merged
merged 17 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]
torch: [1.12.0, 1.13.0, 2.0.0]
python-version: [3.8, 3.9, "3.10"]
torch: [1.13.0, 2.0.0, 2.1.0]
#include:
# - torch: 1.6.0
# torchvision: 0.7.0
Expand Down Expand Up @@ -60,8 +60,10 @@ jobs:
# run: conda env create -n graphein-dev python=${{ matrix.python-version }}
#- name: Activate Conda Environment
# run: source activate graphein-dev
- name: Install Boost 1.73.0 (for DSSP)
run: conda install -c anaconda libboost=1.73.0
- name: Install DSSP
run: conda install -c salilab dssp
run: conda install dssp -c salilab
- name: Install mmseqs
run: mamba install -c conda-forge -c bioconda mmseqs2
- name: Install PyTorch
Expand Down
1 change: 1 addition & 0 deletions .requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ biopython
bioservices>=1.10.0
deepdiff
loguru
looseversion
matplotlib>=3.4.3
multipledispatch
networkx
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
### 1.7.4 - 24/10/2023

* Adds support for PyG 2.4+ ([#350](https://github.com/a-r-j/graphein/pull/350))

### 1.7.3 - 30/08/2023

* Fixes edge case in FoldComp database download if target directory has same name as database ([#339](https://github.com/a-r-j/graphein/pull/339))
Expand Down
62 changes: 53 additions & 9 deletions graphein/protein/tensor/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# Code Repository: https://github.com/a-r-j/graphein
from typing import Any, Callable, List, Optional, Tuple, Union

import looseversion
import pandas as pd
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import torch_geometric
from biopandas.pdb import PandasPdb
from loguru import logger as log
from torch_geometric.data import Batch, Data
Expand Down Expand Up @@ -64,6 +66,8 @@
TorsionTensor,
)

PYG_VERSION = looseversion.LooseVersion(torch_geometric.__version__)


class Protein(Data):
"""A data object describing a homogeneous graph. ``Protein`` inherits from
Expand Down Expand Up @@ -249,7 +253,12 @@ def from_data(self, data: Data) -> "Protein":
:return: ``Protein`` object containing the same keys and values
:rtype: Protein
"""
keys = data.keys
keys = (
data.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else data.keys
)

for key in keys:
setattr(self, key, getattr(data, key))
return self
Expand All @@ -271,7 +280,12 @@ def to_data(self) -> Data:
:rtype: Data
"""
data = Data()
for i in self.keys:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for i in keys:
setattr(data, i, getattr(self, i))
return data

Expand Down Expand Up @@ -732,7 +746,13 @@ def has_complete_backbone(self) -> bool:

def __eq__(self, __o: object) -> bool:
# sourcery skip: merge-duplicate-blocks, merge-else-if-into-elif
for i in self.keys:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)

for i in keys:
attr_self = getattr(self, i)
attr_other = getattr(__o, i)

Expand Down Expand Up @@ -760,9 +780,15 @@ def plot_distance_matrix(
return plot_distance_matrix(x)

def plot_dihedrals(self) -> go.Figure:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)

dh = (
dihedrals(self.coords)
if "dihedrals" not in self.keys
if "dihedrals" not in keys
else self.dihedrals
)
return plot_dihedrals(dh)
Expand Down Expand Up @@ -833,7 +859,12 @@ def __init__(
def from_batch(
self, batch: Batch, fill_value: float = 1e-5
) -> "ProteinBatch":
for key in batch.keys:
keys = (
batch.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else batch.keys
)
for key in keys:
setattr(self, key, getattr(batch, key))

if hasattr(batch, "_slice_dict"):
Expand Down Expand Up @@ -930,7 +961,11 @@ def from_pdb_files(
def to_batch(self) -> Batch:
"""Returns the ProteinBatch as a torch_geometric.data.Batch object."""
batch = Batch()
keys = self.keys
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for key in keys:
setattr(batch, key, getattr(self, key))
return batch
Expand Down Expand Up @@ -1190,8 +1225,12 @@ def to_protein_list(self) -> List["Protein"]:
proteins = [Protein() for _ in range(self.num_graphs)]

# Iterate over attributes
for k in self.keys:
print(k)
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for k in keys:
# Get attribute
attr = getattr(self, k)
# Skip ptr
Expand All @@ -1218,7 +1257,12 @@ def to_protein_list(self) -> List["Protein"]:

def __eq__(self, __o: object) -> bool:
# sourcery skip: merge-duplicate-blocks, merge-else-if-into-elif
for i in self.keys:
keys = (
self.keys()
if PYG_VERSION >= looseversion.LooseVersion("2.4.0")
else self.keys
)
for i in keys:
attr_self = getattr(self, i)
attr_other = getattr(__o, i)

Expand Down
8 changes: 4 additions & 4 deletions tests/protein/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_construct_graph():
def test_construct_graph_with_dssp():
"""Makes sure protein graphs can be constructed with dssp

Uses both a pdb code (6REW) and a local pdb file to do so.
Uses both a pdb code (6YC3) and a local pdb file to do so.
"""
dssp_config_functions = {
"edge_construction_functions": [
Expand All @@ -129,10 +129,10 @@ def test_construct_graph_with_dssp():
dssp_prot_config = ProteinGraphConfig(**dssp_config_functions)

g_pdb = construct_graph(
config=dssp_prot_config, pdb_code="6rew"
) # should download 6rew.pdb to pdb_dir
config=dssp_prot_config, pdb_code="6yc3"
) # should download 6yc3.pdb to pdb_dir

assert g_pdb.graph["pdb_code"] == "6rew"
assert g_pdb.graph["pdb_code"] == "6yc3"
assert g_pdb.graph["path"] is None
assert g_pdb.graph["name"] == g_pdb.graph["pdb_code"]
assert len(g_pdb.graph["dssp_df"]) == 1365
Expand Down