Improve data loading performance for FoldComp Datasets by 10x (#313)
* add PSW to nonstandard residues

* improve insertion and non-standard residue handling

* refactor chain selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused verbosity arg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix chain selection in tests

* fix chain selection in tutorial notebook

* fix notebook chain selection

* fix chain selection typehint

* Update changelog

* Add NLW to non-standard residues

* add .ent support

* add entry for construction from dataframe

* add missing stage arg

* improve obsolete mapping retrieval to include entries with no replacement

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update changelog

* add transforms to foldcomp datasets

* fix jaxtyping syntax

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update changelog

* fix double application of transforms

* improve foldcomp data loading performance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused imports

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
a-r-j and pre-commit-ci[bot] committed May 1, 2023
1 parent e982aa1 commit bd29a29
Showing 2 changed files with 113 additions and 31 deletions.
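
The speed-up comes from two changes visible in the diff below: the FoldComp database is now opened with `decompress=False`, and each entry is converted directly from FoldComp's decoded dictionary into a `Protein` tensor by the new `fc_to_pyg` helper, rather than round-tripping through a PDB string, `PandasPdb`, and `Protein.from_dataframe`. A minimal usage sketch follows; the constructor arguments `root`, `database`, and `ids` are inferred from the attributes used in the diff, and the database name is illustrative:

from graphein.ml.datasets.foldcomp_dataset import FoldCompDataset

# Hypothetical example; exact constructor arguments are inferred from the diff.
ds = FoldCompDataset(
    root="data/foldcomp",            # directory the FoldComp database is downloaded to
    database="afdb_swissprot_v4",    # illustrative FoldComp database name
    ids=None,                        # or a list of accession ids to restrict the dataset to
    use_graphein=True,               # return graphein Protein objects rather than plain Data
    transform=None,                  # optionally graph transforms (composition added in #312)
)

protein = ds[0]               # decoded via foldcomp.get_data + fc_to_pyg; no PDB parsing
print(protein.coords.shape)   # (num_residues, 37, 3)
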
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -18,7 +18,7 @@
* Improved handling of non-standard residues in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Insertions retained by default in the `graphein.protein.tensor` module, i.e. `insertions=True` is now the default behaviour. [#307](https://github.com/a-r-j/graphein/pull/307)
* Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)

* Improve FoldComp data loading performance [#313](https://github.com/a-r-j/graphein/pull/313)

### 1.7.0 - UNRELEASED

142 changes: 112 additions & 30 deletions graphein/ml/datasets/foldcomp_dataset.py
@@ -10,17 +10,22 @@
import random
import shutil
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np
import torch
from loguru import logger as log
from sklearn.model_selection import train_test_split
from torch_geometric import transforms as T
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from graphein.protein.resi_atoms import (
ATOM_NUMBERING,
STANDARD_AMINO_ACID_MAPPING_1_TO_3,
STANDARD_AMINO_ACIDS,
)
from graphein.protein.tensor import Protein
from graphein.utils.dependencies import import_message

@@ -67,6 +72,58 @@
GraphTransform = Callable[[Union[Data, Protein]], Union[Data, Protein]]


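# Heavy-atom names for each standard residue; fc_to_pyg uses this table to expand a
# decoded FoldComp entry's sequence into per-atom names and per-residue atom counts,
# matching the layout of the decoded coordinates.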
ATOM_MAP = {
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
"ALA": ["N", "CA", "C", "O", "CB"],
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
"ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
"HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
"TYR": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"OH",
],
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
"PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
"GLY": ["N", "CA", "C", "O"],
"SER": ["N", "CA", "C", "O", "CB", "OG"],
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
"CYS": ["N", "CA", "C", "O", "CB", "SG"],
"TRP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"NE1",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
],
}


class FoldCompDataset(Dataset):
def __init__(
self,
@@ -112,17 +169,15 @@ def __init__(
self.use_graphein = use_graphein
self.transform = transform

_database_files = [
"$db",
"$db.dbtype",
"$db.index",
"$db.lookup",
"$db.source",
]
self.database_files = [
f.replace("$db", self.database) for f in _database_files
self._database_files = [
f"{self.database}",
f"{self.database}.dbtype",
f"{self.database}.index",
f"{self.database}.lookup",
f"{self.database}.source",
]
self._get_indices()

super().__init__(
root=self.root, transform=self.transform, pre_transform=None # type: ignore
)
@@ -146,7 +201,9 @@ def processed_file_names(self):
def download(self):
"""Downloads foldcomp database if not already downloaded."""

if not all(os.path.exists(self.root / f) for f in self.database_files):
if not all(
os.path.exists(self.root / f) for f in self._database_files
):
log.info(f"Downloading FoldComp dataset {self.database}...")
try:
foldcomp.setup(self.database)
@@ -156,7 +213,7 @@ def download(self):
log.info("Download complete.")
log.info("Moving files to raw directory...")

for f in self.database_files:
for f in self._database_files:
shutil.move(f, self.root)
else:
log.info(f"FoldComp database already downloaded: {self.root}.")
@@ -203,23 +260,47 @@ def process(self):
# Open the database
log.info("Opening database...")
if self.ids is not None:
self.db = foldcomp.open(self.root / self.database, ids=self.ids) # type: ignore
self.db = foldcomp.open(self.root / self.database, ids=self.ids, decompress=False) # type: ignore
else:
self.db = foldcomp.open(self.root / self.database) # type: ignore
self.db = foldcomp.open(self.root / self.database, decompress=False) # type: ignore
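# Note: decompress=False keeps entries in FoldComp's compressed format; get()
# decodes them on access with foldcomp.get_data, so no PDB text is generated
# or parsed during data loading.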

@staticmethod
def _parse_dataframe(pdb_string: str) -> pd.DataFrame:
"""Reads a PDB string into a Pandas dataframe."""
pdb: List[str] = pdb_string.split("\n")
return PandasPdb().read_pdb_from_list(pdb).df["ATOM"]

def process_pdb(self, pdb_string: str, name: str) -> Union[Protein, Data]:
"""Process a PDB string into a Graphein Protein object."""
df = self._parse_dataframe(pdb_string)
data = Protein().from_dataframe(df, id=name)
if not self.use_graphein:
data = data.to_data()
return data
def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
# Map sequence to 3-letter codes
res = [STANDARD_AMINO_ACID_MAPPING_1_TO_3[r] for r in data["residues"]]
residue_type = torch.tensor(
[STANDARD_AMINO_ACIDS.index(res) for res in data["residues"]],
)

# Get residue numbers
res_num = [i for i, _ in enumerate(res)]

# Get list of atom types
atom_types = []
atom_counts = []
for r in res:
atom_types += ATOM_MAP[r]
atom_counts.append(len(ATOM_MAP[r]))
atom_types += ["OXT"]
atom_counts[-1] += 1

# Get atom indices
atom_idx = np.array([ATOM_NUMBERING[atm] for atm in atom_types])

# Initialize coordinates
coords = np.ones((len(res), 37, 3)) * 1e-5

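# res_idx repeats each residue index once per atom, so res_idx[k] / atom_idx[k]
# give the residue and 37-slot position of the k-th decoded atom; the single
# fancy-indexed assignment below scatters all coordinates at once.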
res_idx = np.repeat(res_num, atom_counts)
coords[res_idx, atom_idx, :] = np.array(data["coordinates"])

return Protein(
coords=torch.from_numpy(coords).float(),
residues=res,
residue_id=[f"A:{m}:{str(n)}" for m, n in zip(res, res_num)],
chains=torch.zeros(len(res)),
residue_type=residue_type.long(),
id=name,
)

def len(self) -> int:
"""Returns length of the dataset"""
@@ -230,9 +311,10 @@ def get(self, idx) -> Union[Data, Protein]:
ID or its index."""
if isinstance(idx, str):
idx = self.protein_to_idx[idx]
name, pdb = self.db[idx]

return self.process_pdb(pdb, name)
name = self.idx_to_protein[idx]
data = foldcomp.get_data(self.db[idx]) # type: ignore
return self.fc_to_pyg(data, name)


class FoldCompLightningDataModule(L.LightningDataModule):
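
For readers unfamiliar with the indexing trick in `fc_to_pyg`, here is a self-contained toy sketch of the same scatter on a two-residue (GLY, ALA) example. The `atom_map` and `atom_numbering` dictionaries are illustrative stand-ins for graphein's `ATOM_MAP` and `ATOM_NUMBERING` tables:

import numpy as np

atom_map = {"GLY": ["N", "CA", "C", "O"], "ALA": ["N", "CA", "C", "O", "CB"]}
atom_numbering = {"N": 0, "CA": 1, "C": 2, "CB": 3, "O": 4}  # first few of the 37 slots

res = ["GLY", "ALA"]
atom_types, atom_counts = [], []
for r in res:
    atom_types += atom_map[r]
    atom_counts.append(len(atom_map[r]))

flat_coords = np.arange(sum(atom_counts) * 3, dtype=float).reshape(-1, 3)  # 9 atoms x 3

coords = np.ones((len(res), 37, 3)) * 1e-5                     # dense per-residue layout
res_idx = np.repeat(np.arange(len(res)), atom_counts)          # residue of each flat atom
atom_idx = np.array([atom_numbering[a] for a in atom_types])   # 37-slot of each flat atom
coords[res_idx, atom_idx, :] = flat_coords                     # one vectorised scatter

print(coords.shape)                       # (2, 37, 3)
print(coords[1, atom_numbering["CB"]])    # coordinates written for ALA's CB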
