Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new helper functions to PDBManager #322

Merged
merged 14 commits into from
May 25, 2023
114 changes: 100 additions & 14 deletions graphein/ml/datasets/pdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,64 @@ def experiment_type(
if update:
self.df = df
return df

def experiment_types(
    self,
    types: List[str] = ["diffraction"],
    splits: Optional[List[str]] = None,
    update: bool = False,
) -> pd.DataFrame:
    """
    Keep only the molecules whose experiment type is one of ``types``.

    Recognised experiment types are
    [``diffraction``, ``NMR``, ``EM``, ``other``].

    :param types: Experiment types to keep, defaults to
        ``["diffraction"]``.
    :type types: List[str], optional
    :param splits: Names of splits for which to perform the operation,
        defaults to ``None``.
    :type splits: Optional[List[str]], optional
    :param update: Whether to store the selection on ``self.df``,
        defaults to ``False``.
    :type update: bool, optional

    :return: DataFrame of selected molecules.
    :rtype: pd.DataFrame
    """
    candidates = self.get_splits(splits)
    # Row-wise membership test against the requested experiment types.
    has_requested_type = candidates.experiment_type.isin(types)
    selected = candidates[has_requested_type]

    if not update:
        return selected

    self.df = selected
    return selected

def name(
    self,
    substrings: List[str],
    splits: Optional[List[str]] = None,
    update: bool = False,
) -> pd.DataFrame:
    """
    Select molecules by substrings present in their names:
    e.g., [``DNA``, ``RNA``]

    A molecule is selected when its name contains at least one of the
    given substrings. Substrings are matched literally and
    case-sensitively; regular-expression metacharacters such as ``+``,
    ``(`` or ``.`` have no special meaning. Molecules without a name are
    never selected, and an empty list of substrings selects nothing.

    :param substrings: Substrings to be found within the name field of
        each molecule.
    :type substrings: List[str]
    :param splits: Names of splits for which to perform the operation,
        defaults to ``None``.
    :type splits: Optional[List[str]], optional
    :param update: Whether to modify the DataFrame in place, defaults to
        ``False``.
    :type update: bool, optional

    :return: DataFrame of selected molecules.
    :rtype: pd.DataFrame
    """
    # Local import keeps this method self-contained.
    import re

    splits_df = self.get_splits(splits)

    if substrings:
        # Escape each substring so the alternation matches literal text
        # instead of interpreting user input as a regular expression.
        pattern = "|".join(re.escape(substring) for substring in substrings)
        # ``na=False`` turns missing names into "no match"; without it the
        # mask contains NA and boolean indexing raises.
        mask = splits_df["name"].str.contains(pattern, regex=True, na=False)
    else:
        # An empty alternation would match every name; "contains any of
        # nothing" is treated as no match instead.
        mask = pd.Series(False, index=splits_df.index, dtype=bool)

    df = splits_df.loc[mask]

    if update:
        self.df = df
    return df

def compare_length(
self,
Expand Down Expand Up @@ -1055,6 +1113,33 @@ def remove_non_standard_alphabet_sequences(
if update:
self.df = df
return df

def select_complexes_with_grouped_molecule_types(
    self,
    molecule_types_to_group: List[str],
    splits: Optional[List[str]] = None,
    update: bool = False,
) -> pd.DataFrame:
    """
    Select complexes containing at least one instance of each
    provided molecule type.

    Rows are grouped by their ``pdb`` value, and a group is kept in full
    only when its ``molecule_type`` values cover every requested type.

    :param molecule_types_to_group: Names of molecule types by which to
        assemble complexes.
    :type molecule_types_to_group: List[str]
    :param splits: Names of splits for which to perform the operation,
        defaults to ``None``.
    :type splits: Optional[List[str]], optional
    :param update: Whether to update the DataFrame in place, defaults to
        ``False``.
    :type update: bool, optional

    :return: DataFrame containing only complexes with at least one instance
        of each provided molecule type.
    :rtype: pd.DataFrame
    """
    # Build the set once so each group needs a single subset test rather
    # than one array scan per requested type.
    required_types = set(molecule_types_to_group)

    splits_df = self.get_splits(splits)
    df = splits_df.groupby("pdb").filter(
        lambda group: required_types.issubset(group["molecule_type"])
    )

    if update:
        self.df = df
    # Return the selection, as documented and as the sibling selectors do.
    return df

def split_df_proportionally(
self,
Expand Down Expand Up @@ -1693,7 +1778,7 @@ def merge_pdb_chain_groups(self, group: DataFrameGroupBy) -> pd.DataFrame:

def select_pdb_by_criterion(
self, pdb: PandasPdb, field: str, field_values: List[Any]
) -> PandasPdb:
) -> Optional[PandasPdb]:
"""Filter a PDB using a field selection.

:param pdb: The PDB object to filter by a field.
Expand All @@ -1704,18 +1789,18 @@ def select_pdb_by_criterion(
the PDB.
:type field_values: List[Any]

:return: The filtered PDB object.
:rtype: PandasPdb
:return: The filtered PDB object or instead `None` to signify
that no atoms within the PDB object were found after filtering.
:rtype: Optional[PandasPdb], optional
"""
for key in pdb.df:
if field in pdb.df[key]:
filtered_pdb = pdb.df[key][
pdb.df[key][field].isin(field_values)
]
if "ATOM" in key:
assert (
len(filtered_pdb) > 0
), "Filtered DataFrame must contain atoms."
if "ATOM" in key and len(filtered_pdb) == 0:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be being a little pedantic here but I see two possible edge cases with this:

  1. Sometimes protein atoms are stored as HETATM records. This typically happens for modified residues, but it also occurs as an abuse of the PDB format to suit some niche way of storing structure data.
  2. Similarly, what if the desired selection is actually the HETATM data? Protein–nucleic acid or protein–peptide complexes may store the ligand as HETATM records.

Copy link
Contributor Author

@amorehead amorehead May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem. Points 1 and 2 didn't come to my mind when I originally implemented this, and I kept expanding on it until now. In light of these points, I think it makes more sense to avoid skipping such DataFrames altogether. It will then be the user's responsibility to "vet" the exported PDB files for these kinds of edge cases in their selected PDBs. Better yet, we can simply issue a warning to users that no "standard" atoms were found post-filtering. However, we would then still export the PDB complex as requested.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've modified this logic to only issue a warning in such an edge case now.

log.warning("Filtered DataFrame does not contain any atoms. Skipping DataFrame...")
return None
pdb.df[key] = filtered_pdb
return pdb

Expand All @@ -1727,7 +1812,7 @@ def write_out_pdb_chain_groups(
split: str,
merge_fn: Callable,
atom_df_name: str = "ATOM",
max_num_chains_per_pdb_code: int = 1,
max_num_chains_per_pdb_code: int = -1,
models: List[int] = [1],
):
"""Record groups of PDB codes and associated chains
Expand All @@ -1748,7 +1833,7 @@ def write_out_pdb_chain_groups(
ATOM entries within a PandasPdb object.
:type atom_df_name: str, defaults to ``ATOM``
:param max_num_chains_per_pdb_code: Maximum number of chains
to collate into a matching PDB file.
to collate into a matching PDB file, defaults to ``-1``.
:type max_num_chains_per_pdb_code: int, optional
:param models: List of indices of models from which to extract chains,
defaults to ``[1]``.
Expand Down Expand Up @@ -1806,15 +1891,16 @@ def write_out_pdb_chain_groups(
pdb, "chain_id", chains
)
# export selected chains within the same PDB file
pdb_chains.to_pdb(str(output_pdb_filepath))
if pdb_chains:
pdb_chains.to_pdb(str(output_pdb_filepath))

def write_df_pdbs(
self,
pdb_dir: str,
df: pd.DataFrame,
out_dir: str = "collated_pdb",
splits: Optional[List[str]] = None,
max_num_chains_per_pdb_code: int = 1,
max_num_chains_per_pdb_code: int = -1,
models: List[int] = [1],
):
"""Write the given selection as a collection of PDB files.
Expand All @@ -1831,7 +1917,7 @@ def write_df_pdbs(
defaults to ``None``.
:type splits: Optional[List[str]], optional
:param max_num_chains_per_pdb_code: Maximum number of chains
to collate into a matching PDB file.
to collate into a matching PDB file, defaults to ``-1``.
:type max_num_chains_per_pdb_code: int, optional
:param models: List of indices of models from which to extract chains,
defaults to ``[1]``.
Expand Down Expand Up @@ -1867,7 +1953,7 @@ def export_pdbs(
self,
pdb_dir: str,
splits: Optional[List[str]] = None,
max_num_chains_per_pdb_code: int = 1,
max_num_chains_per_pdb_code: int = -1,
models: List[int] = [1],
force: bool = False,
):
Expand All @@ -1879,7 +1965,7 @@ def export_pdbs(
defaults to ``None``.
:type splits: Optional[List[str]], optional
:param max_num_chains_per_pdb_code: Maximum number of chains
to collate into a matching PDB file.
to collate into a matching PDB file, defaults to ``-1``.
a-r-j marked this conversation as resolved.
Show resolved Hide resolved
:type max_num_chains_per_pdb_code: int, optional
:param models: List of indices of models from which to extract chains,
defaults to ``[1]``.
Expand Down