Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
added the docstrings and some extra checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Coos Baakman committed Mar 24, 2021
1 parent be8d1d8 commit bb6f5d5
Showing 1 changed file with 168 additions and 12 deletions.
180 changes: 168 additions & 12 deletions deeprank/features/ResidueContacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,29 @@


class ResidueSynonymCriteria:
"""The ResidueSynonymCriteria is an object that holds the criteria
for a residue to have a certain synonym. It does not hold the synonym string itself however.
"""

def __init__(self, residue_name, atoms_present, atoms_absent):
"""Build new criteria
residue_name (string): the name of the residue
atoms_present (list of strings): the names of the atoms that should be present in the residue
atoms_absent (list of strings) the names of the atoms that should be absent in the residue
"""

self.residue_name = residue_name
self.atoms_present = atoms_present
self.atoms_absent = atoms_absent

def matches(self, residue_name, atom_names):
"""Check whether the given residue matches this set of criteria
residue_name (string): the name of the residue to match
atom_names (list of strings): the names of the atoms in the residue
"""

if self.residue_name != 'all' and residue_name != self.residue_name:
return False

Expand All @@ -36,15 +53,40 @@ def matches(self, residue_name, atom_names):


def get_squared_distance(pos1, pos2):
"""Get the squared distance between two positions.
This is less computationally expensive than the normal distance
and should be used if one does not necessarily need
the normal distance for a large set of coordinates.
pos1 (array of 3): the xyz coords of position 1
pos2 (array of 3): the xyz coords of position 2
"""

return numpy.sum(numpy.square(pos1 - pos2))

def get_distance(pos1, pos2):
"""Get the distance between two positions.
This is more computationally expensive than the squared distance
and should only be used when one really needs this distance.
pos1 (array of 3): the xyz coords of position 1
pos2 (array of 3): the xyz coords of position 2
"""

return numpy.sqrt(get_squared_distance(pos1, pos2))

def wrap_values_in_lists(dict_):
"""Wrap the dictionary's values in lists. This
appears to be necessary for the exported features to work.
dict_(dictionary): the dictionary that should be converted
"""

return {key: [value] for key,value in dict_.items()}

class _PhysicsStorage:
"A helper object that holds the physics values while summing them"

ATOM_KEY = ["chainID", "resSeq", "resName", "name"]

EPSILON0 = 1.0
Expand All @@ -58,13 +100,29 @@ class _PhysicsStorage:

@staticmethod
def _sum_up(dict_, key, value_to_add):
"""A helper function to sum the values of a dictionary.
dict_(dictionary): the dictionary to store the value in
key(hashable object): the key under which the value should be stored in dict_
value_to_add: the value to add onto the dictionary value
"""

if key not in dict_:
dict_[key] = 0.0

dict_[key] += value_to_add

@staticmethod
def get_vanderwaals_energy(epsilon1, sigma1, epsilon2, sigma2, distance):
"""The formula to calculate the vanderwaals energy for two atoms (atom 1 and atom 2)
epsilon1 (float): the vanderwaals epsilon parameter of atom 1
sigma1 (float): the vanderwaals sigma parameter of atom 1
epsilon2 (float): the vanderwaals epsilon parameter of atom 2
sigma2 (float): the vanderwaals sigma parameter of atom 2
distance (float): the vanderwaals distance between atom 1 and atom 2
"""

average_epsilon = numpy.sqrt(epsilon1 * epsilon2)
average_sigma = 0.5 * (sigma1 + sigma2)

Expand All @@ -83,17 +141,39 @@ def get_vanderwaals_energy(epsilon1, sigma1, epsilon2, sigma2, distance):

@staticmethod
def get_coulomb_energy(charge1, charge2, distance, max_distance):
"""The formula to calculate the coulomb energy for two atoms (atom 1 and atom 2)
charge1 (float): the charge of atom 1
charge2 (float): the charge of atom 2
distance (float): the vanderwaals distance between atom 1 and atom 2
max_distance (float): the max distance that was used to find atoms 1 and 2
"""

return (charge1 * charge2 * _PhysicsStorage.COULOMB_CONSTANT /
(_PhysicsStorage.EPSILON0 * distance) * pow(1 - pow(distance / max_distance, 2), 2))

def __init__(self, sqldb, max_distance):
def __init__(self, sqldb):
"""Build a new set of physical parameters
sqldb (pdb2sql): interface to the contents of a PDB file, with charges and vanderwaals parameters included.
"""

self._vanderwaals_parameters = sqldb.get('eps,sig')
self._charges = sqldb.get('CHARGE')
self._atom_info = sqldb.get(",".join(_PhysicsStorage.ATOM_KEY))
self._positions = sqldb.get('x,y,z')
self._positions = numpy.array(sqldb.get('x,y,z'))

if len(self._vanderwaals_parameters) == 0:
raise RuntimeError("vanderwaals parameters are empty, please run '_assign_parameters' first")

if len(self._charges) == 0:
raise RuntimeError("vanderwaals parameters are empty, please run '_assign_parameters' first")

self._max_distance = max_distance
if len(self._atom_info) == 0:
raise RuntimeError("atom info is empty, please create a ResidueContacts object first")

if len(self._positions) == 0:
raise RuntimeError("positions are empty, please create a ResidueContacts object first")

self._vanderwaals_per_atom = {}
self._vanderwaals_per_position = {}
Expand All @@ -102,9 +182,16 @@ def __init__(self, sqldb, max_distance):
self._coulomb_per_atom = {}
self._coulomb_per_position = {}

def include_pair(self, atom1, atom2):
position1 = tuple(self._positions[atom1])
position2 = tuple(self._positions[atom2])
def include_pair(self, atom1, atom2, max_distance):
"""Add a pair of atoms to the sum
atom1 (int): number of atom 1
atom2 (int): number of atom 2
max_distance (float): the max distance that was used to find the atoms
"""

position1 = self._positions[atom1]
position2 = self._positions[atom2]

epsilon1, sigma1 = self._vanderwaals_parameters[atom1]
epsilon2, sigma2 = self._vanderwaals_parameters[atom2]
Expand All @@ -114,14 +201,17 @@ def include_pair(self, atom1, atom2):

distance = get_distance(position1, position2)
if distance == 0.0:
distance = 3.0
raise ValueError("encountered two atoms {} and {} with distance zero".format(atom1, atom2))

vanderwaals_energy = _PhysicsStorage.get_vanderwaals_energy(epsilon1, sigma1, epsilon2, sigma2, distance)
coulomb_energy = _PhysicsStorage.get_coulomb_energy(charge1, charge2, distance, self._max_distance)
coulomb_energy = _PhysicsStorage.get_coulomb_energy(charge1, charge2, distance, max_distance)

atom1_key = tuple(self._atom_info[atom1])
atom2_key = tuple(self._atom_info[atom2])

position1 = tuple(position1)
position2 = tuple(position2)

self._charge_per_atom[atom1_key] = charge1
self._charge_per_atom[atom2_key] = charge2

Expand All @@ -138,6 +228,12 @@ def include_pair(self, atom1, atom2):
_PhysicsStorage._sum_up(self._coulomb_per_position, position2, coulomb_energy)

def add_to_features(self, feature_data, feature_data_xyz):
"""Convert the summed interactions to deeprank features and store them in the corresponding dictionaries
feature_data (dictionary): where the per atom features should be stored
feature_data_xyz (dictionary): where the per position features should be stored
"""

feature_data['vdwaals'] = wrap_values_in_lists(self._vanderwaals_per_atom)
feature_data_xyz['vdwaals'] = wrap_values_in_lists(self._vanderwaals_per_position)
feature_data['coulomb'] = wrap_values_in_lists(self._coulomb_per_atom)
Expand All @@ -147,7 +243,9 @@ def add_to_features(self, feature_data, feature_data_xyz):


class ResidueContacts(FeatureClass):
"A class that collects features that involve contacts between a residue and its surrounding atoms"

# This dictionary holds the data used to find residue alternative names:
RESIDUE_SYNONYMS = {'PROP': ResidueSynonymCriteria('PRO', ['HT1', 'HT2'], []),
'NTER': ResidueSynonymCriteria('all', ['HT1', 'HT2', 'HT3'], []),
'CTER': ResidueSynonymCriteria('all', ['OXT'], []),
Expand All @@ -157,10 +255,14 @@ class ResidueContacts(FeatureClass):
'HISE': ResidueSynonymCriteria('HIS', ['ND1', 'CE1', 'CD2', 'NE2', 'HE2'], ['HD1']),
'HISD': ResidueSynonymCriteria('HIS', ['ND1', 'CE1', 'CD2', 'NE2', 'HD1'], ['HE2'])}

RESIDUE_KEY = ["chainID", "resSeq", "resName"]

@staticmethod
def get_alternative_residue_name(residue_name, atom_names):
"""Get the alternative residue name, according to the static dictionary in this class
residue_name (string): the name of the residue
atom_names (list of strings): the names of the atoms in the residue
"""

for name, crit in ResidueContacts.RESIDUE_SYNONYMS.items():
if crit.matches(residue_name, atom_names):
return name
Expand All @@ -170,6 +272,16 @@ def get_alternative_residue_name(residue_name, atom_names):
def __init__(self, pdb_path, chain_id, residue_number,
top_path, param_path, patch_path,
max_contact_distance=8.5):
"""Build a new residue contacts feature object
pdb_path (string): where the pdb file is located on disk
chain_id (string): identifier of the residue's protein chain within the pdb file
residue_number (int): identifier of the residue within the protein chain
top_path (string): location of the top file on disk
param_path (string): location of the param file on disk
patch_path (string): location of the patch file on disk
max_contact_distance (float): the maximum distance allowed for two atoms to be considered a contact pair
"""

super().__init__("Atomic")

Expand All @@ -183,13 +295,19 @@ def __init__(self, pdb_path, chain_id, residue_number,
self.patch_path = patch_path

def __enter__(self):
"open the with-clause"

self.sqldb = pdb2sql.interface(self.pdb_path)
return self

def __exit__(self, exc_type, exc, tb):
"close the with-clause"

self.sqldb._close()

def _read_top(self):
"read the top file and store its data in memory"

self._residue_charges = {}
self._residue_atom_types = {}
self._valid_residue_names = set([])
Expand All @@ -207,10 +325,14 @@ def _read_top(self):
self._residue_atom_types[(obj.residue_name, obj.atom_name)] = obj.kwargs['type']

def _read_param(self):
"read the param file and store its data in memory"

with open(self.param_path, 'rt') as f:
self._vanderwaals_parameters = ParamParser.parse(f)

def _read_patch(self):
"read the patch file and store its data in memory"

self._patch_charge = {}
self._patch_type = {}

Expand All @@ -225,6 +347,8 @@ def _read_patch(self):
self._patch_type[(action.selection.residue_type, action.selection.atom_name)] = action.kwargs['TYPE']

def _find_contact_atoms(self):
"find out which atoms of the pdb file lie within the max distance of the residue"

self._residue_atoms = set(self.sqldb.get("rowID", chainID=self.chain_id, resSeq=self.residue_number))

atomic_postions = {r[0]: numpy.array([r[1], r[2], r[3]]) for r in self.sqldb.get('rowID,x,y,z')}
Expand All @@ -243,6 +367,8 @@ def _find_contact_atoms(self):
break

def _extend_contact_to_residues(self):
"find out of which residues the contact atoms are a part"

per_chain = {}
for chain_id, residue_number in self.sqldb.get("chainID,resSeq", rowID=list(self._contact_atoms)):
if chain_id not in per_chain:
Expand All @@ -256,6 +382,13 @@ def _extend_contact_to_residues(self):
self._contact_atoms.add(atom)

def _get_atom_type(self, residue_name, alternative_residue_name, atom_name):
"""Find the type name of the given atom, according to top and patch data
residue_name (string): the name of the residue that the atom is in
alternative_residue_name (string): the name of the residue, outputted from 'get_alternative_residue_name'
atom_name (string): the name of the atom itself
"""

if (alternative_residue_name, atom_name) in self._patch_type:
return self._patch_type[(alternative_residue_name, atom_name)]

Expand All @@ -266,6 +399,13 @@ def _get_atom_type(self, residue_name, alternative_residue_name, atom_name):
return None

def _get_charge(self, residue_name, alternative_residue_name, atom_name):
"""Find the charge of the atom, according to top and patch data
residue_name (string): the name of the residue that the atom is in
alternative_residue_name (string): the name of the residue, outputted from 'get_alternative_residue_name'
atom_name (string): the name of the atom itself
"""

if residue_name not in self._valid_residue_names:
return 0.0

Expand All @@ -282,6 +422,14 @@ def _get_charge(self, residue_name, alternative_residue_name, atom_name):
return 0.0

def _get_vanderwaals_parameters(self, residue_name, alternative_residue_name, atom_name, atom_type):
"""Find the vanderwaals parameters of the atom, according to param data
residue_name (string): the name of the residue that the atom is in
alternative_residue_name (string): the name of the residue, outputted from 'get_alternative_residue_name'
atom_name (string): the name of the atom itself
atom_type (string): output from '_get_atom_type'
"""

if residue_name not in self._valid_residue_names:
return (0.0, 0.0)

Expand All @@ -292,6 +440,8 @@ def _get_vanderwaals_parameters(self, residue_name, alternative_residue_name, at
return (0.0, 0.0)

def _assign_parameters(self):
"Get parameters from top, param and patch data and put them in the pdb2sql database"

atomic_data = self.sqldb.get("rowID,name,chainID,resSeq,resName")
count_atoms = len(atomic_data)

Expand Down Expand Up @@ -343,17 +493,23 @@ def _assign_parameters(self):
self.sqldb.update_column('altRes', atomic_alternative_residue_names)

def _evaluate_physics(self):
physics_storage = _PhysicsStorage(self.sqldb, self.max_contact_distance)
"""From the top, param and patch data,
calculate energies and charges per residue atom and surrounding atoms
and add them to the feature dictionaries"""

physics_storage = _PhysicsStorage(self.sqldb)

for contact_atom in self._contact_atoms: # loop over atoms that contact the residue of interest

for residue_atom in self._residue_atoms: # loop over atoms in the residue of iterest

physics_storage.include_pair(contact_atom, residue_atom)
physics_storage.include_pair(contact_atom, residue_atom, self.max_contact_distance)

physics_storage.add_to_features(self.feature_data, self.feature_data_xyz)

def evaluate(self):
"collect the features before calling 'export_dataxyz_hdf5' and 'export_data_hdf5' on this object"

self._read_top()
self._read_param()
self._read_patch()
Expand Down

0 comments on commit bb6f5d5

Please sign in to comment.