Fix jaxtyping syntax error (#312)
* add PSW to nonstandard residues

* improve insertion and non-standard residue handling

* refactor chain selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused verbosity arg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix chain selection in tests

* fix chain selection in tutorial notebook

* fix notebook chain selection

* fix chain selection typehint

* Update changelog

* Add NLW to non-standard residues

* add .ent support

* add entry for construction from dataframe

* add missing stage arg

* improve obsolete mapping retrieval to include entries with no replacement

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update changelog

* add transforms to foldcomp datasets

* fix jaxtyping syntax

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update changelog

* fix double application of transforms

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
a-r-j and pre-commit-ci[bot] committed Apr 27, 2023
1 parent a7e5d23 commit af2b2e0
Showing 3 changed files with 44 additions and 31 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -6,6 +6,7 @@

#### Bugfixes
* Adds missing `stage` parameter to `graphein.ml.datasets.foldcomp_data.FoldCompDataModule.setup()`. [#310](https://github.com/a-r-j/graphein/pull/310)
* Fixes incorrect jaxtyping syntax for variable size dimensions [#312](https://github.com/a-r-j/graphein/pull/312)

#### Other Changes
* Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
@@ -14,7 +15,7 @@
* Adds the ability to store a dictionary of HETATM positions in `Data`/`Protein` objects created in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Improved handling of non-standard residues in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Insertions retained by default in the `graphein.protein.tensor` module. I.e. `insertions=True` is now the default behaviour.[#307](https://github.com/a-r-j/graphein/pull/307)

* Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)


### 1.7.0 - UNRELEASED
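For context on the "transform composition" entry above: the pattern is to collapse an iterable (or mapping) of per-example callables into a single `torch_geometric.transforms.Compose` object. The sketch below is illustrative only; the helper name and the example transforms are assumptions rather than graphein API, and it simply mirrors the `_compose_transforms` helper added in the diff further down.

```python
# Minimal sketch, assuming only torch_geometric: collapse a dict or an
# iterable of callables into one Compose object that applies them in order.
from typing import Callable, Dict, Iterable, Union

from torch_geometric import transforms as T


def compose_transforms(
    transforms: Union[Dict[str, Callable], Iterable[Callable]]
) -> T.Compose:
    # Mappings are composed over their values; anything else is treated
    # as a plain iterable of callables.
    if isinstance(transforms, dict):
        return T.Compose(list(transforms.values()))
    return T.Compose(list(transforms))


# Usage: the composed object is itself a callable applied to each example.
# composed = compose_transforms([T.Center(), T.NormalizeScale()])
# data = composed(data)
```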
34 changes: 19 additions & 15 deletions graphein/ml/datasets/foldcomp_dataset.py
@@ -6,17 +6,17 @@
# Code Repository: https://github.com/a-r-j/graphein

import asyncio
import contextlib
import os
import random
import shutil
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, Iterable, List, Optional, Union

import pandas as pd
from biopandas.pdb import PandasPdb
from loguru import logger as log
from sklearn.model_selection import train_test_split
from torch_geometric import transforms as T
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm
@@ -76,7 +76,7 @@ def __init__(
exclude_ids: Optional[List[str]] = None,
fraction: float = 1.0,
use_graphein: bool = True,
transform: Optional[List[GraphTransform]] = None,
transform: Optional[T.BaseTransform] = None,
):
"""Dataset class for FoldComp databases.
@@ -124,7 +124,7 @@ def __init__(
]
self._get_indices()
super().__init__(
root=self.root, transform=None, pre_transform=None # type: ignore
root=self.root, transform=self.transform, pre_transform=None # type: ignore
)

@property
@@ -232,14 +232,7 @@ def get(self, idx) -> Union[Data, Protein]:
idx = self.protein_to_idx[idx]
name, pdb = self.db[idx]

out = self.process_pdb(pdb, name)

# Apply transforms, if any
if self.transform is not None:
for transform in self.transform:
out = transform(out)

return out
return self.process_pdb(pdb, name)


class FoldCompLightningDataModule(L.LightningDataModule):
@@ -252,7 +245,7 @@ def __init__(
train_split: Optional[Union[List[str], float]] = None,
val_split: Optional[Union[List[str], float]] = None,
test_split: Optional[Union[List[str], float]] = None,
transform: Optional[List[GraphTransform]] = None,
transform: Optional[Iterable[Callable]] = None,
num_workers: int = 4,
pin_memory: bool = True,
) -> None:
@@ -281,7 +274,7 @@ def __init__(
``Data``/``Protein`` object and return a transformed version.
The data object will be transformed before every access.
(default: ``None``).
:type transform: Optional[List[GraphTransform]]
:type transform: Optional[Iterable[Callable]]
:param num_workers: Number of workers to use for data loading, defaults
to ``4``.
:type num_workers: int, optional
@@ -297,7 +290,12 @@
self.train_split = train_split
self.val_split = val_split
self.test_split = test_split
self.transform = transform
self.transform = (
self._compose_transforms(transform)
if transform is not None
else None
)

if (
isinstance(train_split, float)
and isinstance(val_split, float)
@@ -311,6 +309,12 @@
self.num_workers = num_workers
self.pin_memory = pin_memory

def _compose_transforms(self, transforms: Iterable[Callable]) -> T.Compose:
try:
return T.Compose(list(transforms.values()))
except Exception:
return T.Compose(transforms)

def setup(self, stage: Optional[str] = None):
self.train_dataset()
self.val_dataset()
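The other change in this file removes the manual transform loop from `get()` and instead passes `self.transform` to the parent class. `torch_geometric.data.Dataset.__getitem__` already applies `transform` after calling `get()`, so applying it in both places ran every transform twice. A minimal sketch of the corrected pattern follows; the class and attribute names here are illustrative, not graphein's.

```python
# Illustrative sketch of the single-application pattern: the base Dataset
# owns `transform` and applies it once per access, so `get()` returns the
# raw object untouched.
from typing import Callable, List, Optional

from torch_geometric.data import Data, Dataset


class ExampleDataset(Dataset):
    def __init__(self, items: List[Data], transform: Optional[Callable] = None):
        self.items = items
        # Hand the transform to the base class instead of storing and
        # re-applying it manually inside get().
        super().__init__(root=None, transform=transform, pre_transform=None)

    def len(self) -> int:
        return len(self.items)

    def get(self, idx: int) -> Data:
        # No transform here; Dataset.__getitem__ applies self.transform.
        return self.items[idx]
```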
38 changes: 23 additions & 15 deletions graphein/protein/tensor/types.py
@@ -9,11 +9,11 @@
# Code Repository: https://github.com/a-r-j/graphein
from typing import NewType, Optional, Union

from jaxtyping import Float
from jaxtyping import Float, Int
from torch import Tensor

# Positions
AtomTensor = NewType("AtomTensor", Float[Tensor, "-1 37 3"])
AtomTensor = NewType("AtomTensor", Float[Tensor, "residues 37 3"])
"""
``torch.float[-1, 37, 3]``
@@ -24,7 +24,7 @@
.. seealso:: :class:`ResidueTensor` :class:`CoordTensor`
"""

BackboneTensor = NewType("BackboneTensor", Float[Tensor, "-1 4 3"])
BackboneTensor = NewType("BackboneTensor", Float[Tensor, "residues 4 3"])
"""
``torch.float[-1, 4, 3]``
@@ -49,7 +49,7 @@
"""


CoordTensor = NewType("CoordTensor", Float[Tensor, "-1 3"])
CoordTensor = NewType("CoordTensor", Float[Tensor, "nodes 3"])
"""
``torch.float[-1, 3]``
@@ -68,7 +68,9 @@
"""

# Represenations
BackboneFrameTensor = NewType("BackboneFrameTensor", Float[Tensor, "-1 3 3"])
BackboneFrameTensor = NewType(
"BackboneFrameTensor", Float[Tensor, "residues 3 3"]
)
"""
``torch.float[-1, 3, 3]``
@@ -89,9 +91,9 @@


# Rotations
EulerAngleTensor = NewType("EulerAngleTensor", Float[Tensor, "-1 3"])
EulerAngleTensor = NewType("EulerAngleTensor", Float[Tensor, "nodes 3"])

QuaternionTensor = NewType("QuaternionTensor", Float[Tensor, "-1 4"])
QuaternionTensor = NewType("QuaternionTensor", Float[Tensor, "nodes 4"])
"""
``torch.float[-1, 4]``
@@ -102,7 +104,7 @@
"""


TransformTensor = NewType("TransformTensor", Float[Tensor, "-1 4 4"])
TransformTensor = NewType("TransformTensor", Float[Tensor, "nodes 4 4"])


RotationMatrix2D = NewType("RotationMatrix2D", Float[Tensor, "2 2"])
@@ -135,7 +137,9 @@
"""


RotationMatrixTensor = NewType("RotationMatrixTensor", Float[Tensor, "-1 3 3"])
RotationMatrixTensor = NewType(
"RotationMatrixTensor", Float[Tensor, "nodes 3 3"]
)

RotationTensor = NewType(
"RotationTensor", Union[QuaternionTensor, RotationMatrixTensor]
@@ -144,7 +148,8 @@

# Angles
DihedralTensor = NewType(
"DihedralTensor", Union[Float[Tensor, "-1 3"], Float[Tensor, "-1 6"]]
"DihedralTensor",
Union[Float[Tensor, "residues 3"], Float[Tensor, "residues 6"]],
)
"""
``Union[torch.float[-1, 3], torch.float[-1, 6]]``
@@ -161,7 +166,8 @@
"""

TorsionTensor = NewType(
"TorsionTensor", Union[Float[Tensor, "-1 4"], Float[Tensor, "-1 8"]]
"TorsionTensor",
Union[Float[Tensor, "residues 4"], Float[Tensor, "residues 8"]],
)
"""
``Union[torch.float[-1, 4], torch.float[-1, 8]]``
@@ -177,7 +183,9 @@
"""

BackboneFrameTensor = NewType("BackboneFrameTensor", Float[Tensor, "-1 3 3"])
BackboneFrameTensor = NewType(
"BackboneFrameTensor", Float[Tensor, "residues 3 3"]
)
"""
``torch.float[-1, 3, 3]``
@@ -198,12 +206,12 @@
.. seealso:: :class:`BackboneFrameTensor`
"""

EdgeTensor = NewType("EdgeTensor", Float[Tensor, "2 -1"])
EdgeTensor = NewType("EdgeTensor", Int[Tensor, "2 edges"])


OrientationTensor = NewType("OrientationTensor", Float[Tensor, "-1 2 3"])
OrientationTensor = NewType("OrientationTensor", Float[Tensor, "nodes 2 3"])


ScalarTensor = NewType("ScalarTensor", Float[Tensor, "-1"])
ScalarTensor = NewType("ScalarTensor", Float[Tensor, "nodes"])

OptTensor = NewType("OptTensor", Optional[Tensor])
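
The substance of the types.py change is jaxtyping's dimension syntax: variable-size axes are written as symbolic names (e.g. "residues", "nodes", "edges") rather than -1, while integer literals remain fixed sizes. A minimal sketch of how such annotations read is below; the alpha-carbon index used in it is an assumption for illustration, not taken from this diff.

```python
# Minimal sketch of jaxtyping's named-dimension syntax: names are
# variable-size axes, integer literals are fixed sizes.
from jaxtyping import Float, Int
from torch import Tensor

# Same shape conventions as the aliases above: a variable number of
# residues, 37 atom slots, 3 coordinates; edge indices are integer pairs.
AtomLike = Float[Tensor, "residues 37 3"]
EdgeLike = Int[Tensor, "2 edges"]


def alpha_carbon_coords(
    coords: Float[Tensor, "residues 37 3"],
) -> Float[Tensor, "residues 3"]:
    # Assumes the CA atom sits at slot 1 of the 37-atom dimension
    # (illustrative; the actual ordering is defined elsewhere in graphein).
    return coords[:, 1, :]
```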
