Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cpdb support to speed up parsing #323

Merged
merged 37 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
146b2af
add cpdb to speed up parsing
Jun 26, 2023
1ae45e3
reorder requirements
Jun 26, 2023
34b34c6
reorder requirements
a-r-j Jun 26, 2023
28d8b54
Add blank columns to write PDBs
a-r-j Jul 4, 2023
28ea60e
Update changelog
a-r-j Jul 4, 2023
f13eb4a
update range indexing
a-r-j Jul 4, 2023
3c5db51
add blank segment_id column if necessary
a-r-j Jul 6, 2023
eae189a
Merge branch 'master' into cpdb
a-r-j Aug 20, 2023
984cbbe
pin cpdb version
a-r-j Aug 31, 2023
ea4f82b
update test
a-r-j Aug 31, 2023
0ecf1a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
6053e20
updates to save_pdb function
a-r-j Aug 31, 2023
8702990
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
0f06c9d
Merge branch 'master' into cpdb
Oct 26, 2023
45b020e
fix broken pdb writer
a-r-j Oct 28, 2023
d6647be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2023
b7a0384
unpin numpy dependency
a-r-j Oct 28, 2023
6069a09
add missing numpy import
a-r-j Oct 28, 2023
f7fde76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2023
284cc4c
resolve test dtype
a-r-j Oct 28, 2023
a5f8959
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2023
62252b3
fix test dtype
a-r-j Oct 28, 2023
4174aec
fix type error in charge writing
a-r-j Oct 28, 2023
d455902
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2023
4bc1bad
fix test syntax error
a-r-j Oct 28, 2023
7d869c5
format charge correctly
a-r-j Oct 28, 2023
b874cf9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2023
cc4fff8
Merge branch 'master' into cpdb
a-r-j Feb 6, 2024
d45de4f
modify tests to use CPDB
a-r-j Feb 6, 2024
ee06903
fix syntax error
a-r-j Feb 6, 2024
a8da5df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
a73e13c
Merge branch 'master' into cpdb
Feb 7, 2024
9f4eda1
fix column drops
a-r-j Feb 7, 2024
0807d5c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2024
5d429aa
Merge branch 'master' into cpdb
a-r-j Aug 3, 2024
5f9bf2c
fix remaining tests after adding CPDB parser backend
a-r-j Aug 3, 2024
787006d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ pandas<2.0.0
biopandas>=0.5.1
biopython
bioservices>=1.10.0
cpdb-protein==0.2.0
cython
deepdiff
loguru
looseversion
matplotlib>=3.4.3
multipledispatch
networkx
numpy<1.24.0
numpy
pandas
plotly
pydantic
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ https://github.com/a-r-j/graphein/pull/334

#### Other Changes

- Uses [`cpdb`](https://github.com/a-r-j/CPDB) as default PDB file parser for improved performance. [#323](https://github.com/a-r-j/graphein/pull/323).
- Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)
- Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
- Adds support for `.ent` files to `graphein.protein.graphs.read_pdb_to_dataframe`. [#310](https://github.com/a-r-j/graphein/pull/310)
Expand Down
30 changes: 20 additions & 10 deletions graphein/protein/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cpdb
import networkx as nx
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -109,32 +110,41 @@ def read_pdb_to_dataframe(
or path.endswith(".pdb.gz")
or path.endswith(".ent")
):
atomic_df = PandasPdb().read_pdb(path)
atomic_df = cpdb.parse(path)
elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
atomic_df = PandasMmtf().read_mmtf(path)
atomic_df = atomic_df.get_model(model_index)
atomic_df = pd.concat(
[atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
)
elif (
path.endswith(".cif")
or path.endswith(".cif.gz")
or path.endswith(".mmcif")
or path.endswith(".mmcif.gz")
):
atomic_df = PandasMmcif().read_mmcif(path)
atomic_df = atomic_df.get_model(model_index)
atomic_df = atomic_df.convert_to_pandas_pdb()
atomic_df = pd.concat(
[atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
)
else:
raise ValueError(
f"File {path} must be either .pdb(.gz), .mmtf(.gz), .(mm)cif(.gz) or .ent, not {path.split('.')[-1]}"
)
elif uniprot_id is not None:
atomic_df = PandasPdb().fetch_pdb(
uniprot_id=uniprot_id, source="alphafold2-v3"
)
atomic_df = cpdb.parse(uniprot_id=uniprot_id)
else:
atomic_df = PandasPdb().fetch_pdb(pdb_code)
atomic_df = atomic_df.get_model(model_index)
if len(atomic_df.df["ATOM"]) == 0:
atomic_df = cpdb.parse(pdb_code=pdb_code)

if "model_idx" in atomic_df.columns:
atomic_df = atomic_df.loc[atomic_df["model_idx"] == model_index]

if len(atomic_df) == 0:
raise ValueError(f"No model found for index: {model_index}")
if isinstance(atomic_df, PandasMmcif):
atomic_df = atomic_df.convert_to_pandas_pdb()
return pd.concat([atomic_df.df["ATOM"], atomic_df.df["HETATM"]])

return atomic_df


def label_node_id(
Expand Down
86 changes: 74 additions & 12 deletions graphein/protein/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from urllib.request import urlopen

import networkx as nx
import numpy as np
import pandas as pd
import requests
import wget
Expand All @@ -25,6 +26,30 @@

from .resi_atoms import BACKBONE_ATOMS, RESI_THREE_TO_1

pdb_df_columns = [
"record_name",
"atom_number",
"blank_1",
"atom_name",
"alt_loc",
"residue_name",
"blank_2",
"chain_id",
"residue_number",
"insertion",
"blank_3",
"x_coord",
"y_coord",
"z_coord",
"occupancy",
"b_factor",
"blank_4",
"segment_id",
"element_symbol",
"charge",
"line_idx",
]


class ProteinGraphConfigurationError(Exception):
"""
Expand Down Expand Up @@ -418,12 +443,27 @@ def save_graph_to_pdb(
:type gz: bool
"""
ppd = PandasPdb()
atom_df = filter_dataframe(
g.graph["pdb_df"], "record_name", ["ATOM"], boolean=True
)
hetatm_df = filter_dataframe(
g.graph["pdb_df"], "record_name", ["HETATM"], boolean=True
)

df = g.graph["pdb_df"].copy()
# format charge correctly
df.charge = pd.to_numeric(df.charge, errors="coerce")

# Add blank columns
blank_cols = [
"blank_1",
"blank_2",
"blank_3",
"blank_4",
"segment_id",
]
for col in blank_cols:
if col not in df.columns:
df[col] = ""
df["line_idx"] = list(range(1, len(df) + 1))
df = df[pdb_df_columns]
atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)

if atoms:
ppd.df["ATOM"] = atom_df
if hetatms:
Expand All @@ -448,9 +488,22 @@ def save_pdb_df_to_pdb(
:param gz: Whether to gzip the file. Defaults to ``False``.
:type gz: bool
"""
df = df.copy()
# format charge correctly
df.charge = pd.to_numeric(df.charge, errors="coerce")
df.alt_loc = df.alt_loc.fillna(" ")
blank_cols = ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]
for col in blank_cols:
if col not in df.columns:
df[col] = ""
df["line_idx"] = list(range(1, len(df) + 1))
df = df[pdb_df_columns]

atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)

ppd = PandasPdb()

if atoms:
ppd.df["ATOM"] = atom_df
if hetatms:
Expand Down Expand Up @@ -481,12 +534,21 @@ def save_rgroup_df_to_pdb(
:type gz: bool
"""
ppd = PandasPdb()
atom_df = filter_dataframe(
g.graph["rgroup_df"], "record_name", ["ATOM"], boolean=True
)
hetatm_df = filter_dataframe(
g.graph["rgroup_df"], "record_name", ["HETATM"], boolean=True
)
df = g.graph["rgroup_df"].copy()

# format charge correctly
df.charge = pd.to_numeric(df.charge, errors="coerce")

blank_cols = ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]
for col in blank_cols:
if col not in df.columns:
df[col] = [""] * len(df)
df["line_idx"] = list(range(1, len(df) + 1))
df = df[pdb_df_columns]

atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)

if atoms:
ppd.df["ATOM"] = atom_df
if hetatms:
Expand Down
2 changes: 1 addition & 1 deletion tests/protein/tensor/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
def test_save_and_load_protein():
a = Protein().from_pdb_code("4hhb")
torch.save(a, "4hhb.pt")
b = torch.load("4hhb.pt")
b = torch.load("4hhb.pt", weights_only=False)
assert a == b
1 change: 0 additions & 1 deletion tests/protein/tensor/test_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,3 @@ def test_dist_mat_to_coords():
assert torch.allclose(d, torch.cdist(X, X), atol=1e-4)
X_aligned = kabsch(X, coords)
assert torch.allclose(coords, X_aligned, atol=1e-4)
return coords, X, X_aligned
5 changes: 4 additions & 1 deletion tests/protein/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,10 @@ def test_alt_loc_exclusion():
):
config.alt_locs = opt
g = construct_graph(config=config, pdb_code="2VVI")
assert np.array_equal(g.nodes[node_id]["coords"], expected_coords)
assert np.array_equal(
g.nodes[node_id]["coords"],
np.array(expected_coords, dtype=np.float32),
)


def test_alt_loc_inclusion():
Expand Down
38 changes: 29 additions & 9 deletions tests/protein/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,24 @@ def test_save_graph_to_pdb():
# Check file exists
assert os.path.isfile("/tmp/test_graph.pdb")

# Check for equivalence between saved and existing DFs.
# We drop the line_idx columns as these will be renumbered
graph_df = (
g.graph["pdb_df"]
.drop(
[
"node_id",
"residue_id",
],
axis=1,
)
.reset_index(drop=True)
)

a.reset_index(drop=True, inplace=True)
a = a[graph_df.columns] # Reorder columns

assert_frame_equal(
a.drop(["line_idx"], axis=1),
g.graph["pdb_df"].drop(["line_idx", "node_id", "residue_id"], axis=1),
a,
graph_df,
)
h = construct_graph(path="/tmp/test_graph.pdb")

Expand All @@ -48,10 +61,17 @@ def test_save_pdb_df_to_pdb():
# Check file exists
assert os.path.isfile("/tmp/test_graph.pdb")

# We drop the line_idx columns as these will be renumbered
assert_frame_equal(
a.drop(["line_idx"], axis=1),
g.graph["pdb_df"].drop(["line_idx", "node_id", "residue_id"], axis=1),
a,
g.graph["pdb_df"]
.drop(
[
"node_id",
"residue_id",
],
axis=1,
)
.reset_index(drop=True),
)

# Now check for raw, unprocessed DF
Expand All @@ -73,10 +93,10 @@ def test_save_rgroup_df_to_pdb():

# We drop the line_idx columns as these will be renumbered
assert_frame_equal(
a.drop(["line_idx"], axis=1),
a,
filter_dataframe(
g.graph["rgroup_df"], "record_name", ["HETATM"], False
).drop(["line_idx", "node_id", "residue_id"], axis=1),
).drop(["node_id", "residue_id"], axis=1),
)


Expand Down
Loading