Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
created a test for the ResidueContacts feature class
Browse files Browse the repository at this point in the history
  • Loading branch information
Coos Baakman committed Mar 19, 2021
1 parent 1ac15a9 commit 3198b3c
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 39 deletions.
348 changes: 331 additions & 17 deletions deeprank/features/ResidueContacts.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,364 @@
import logging
import pdb2sql
from deeprank.feature.FeatureClass import FeatureClass
import re
import os

import numpy

from deeprank.features.FeatureClass import FeatureClass
from deeprank.parse.param import ParamParser
from deeprank.parse.top import TopParser
from deeprank.parse.patch import PatchParser
from deeprank.models.patch import PatchActionType

_log = logging.getLogger(__name__)


class ResidueSynonymCriteria:
    """Rule deciding whether a residue qualifies for an alternative name.

    A residue matches when its name equals ``residue_name`` (or the rule's
    ``residue_name`` is the wildcard ``'all'``), every atom listed in
    ``atoms_present`` occurs in the residue and none of the atoms listed in
    ``atoms_absent`` does.
    """

    def __init__(self, residue_name, atoms_present, atoms_absent):
        """Store the rule.

        Args:
            residue_name: residue name to match, or ``'all'`` for any residue.
            atoms_present: atom names that must all occur in the residue.
            atoms_absent: atom names that must not occur in the residue.
        """
        self.residue_name = residue_name
        self.atoms_present = atoms_present
        self.atoms_absent = atoms_absent

    def matches(self, residue_name, atom_names):
        """Return True if the given residue satisfies this rule.

        Args:
            residue_name: name of the residue under test.
            atom_names: container of the atom names found in that residue
                (anything supporting the ``in`` operator).

        Returns:
            bool: True when the residue name and the atom content both match.
        """
        if self.residue_name != 'all' and residue_name != self.residue_name:
            return False

        has_all_required = all(name in atom_names for name in self.atoms_present)
        has_forbidden = any(name in atom_names for name in self.atoms_absent)

        return has_all_required and not has_forbidden


def get_squared_distance(pos1, pos2):
    """Return the squared Euclidean distance between two 3D positions.

    Only the first three components of each position are used.

    Args:
        pos1: indexable holding at least three numeric coordinates.
        pos2: indexable holding at least three numeric coordinates.

    Returns:
        The sum of the squared per-axis differences.
    """
    return numpy.sum([numpy.square(pos1[i] - pos2[i]) for i in range(3)])

def get_distance(pos1, pos2):
    """Return the Euclidean distance between two 3D positions.

    Args:
        pos1: indexable holding at least three numeric coordinates.
        pos2: indexable holding at least three numeric coordinates.

    Returns:
        The square root of ``get_squared_distance(pos1, pos2)``.
    """
    return numpy.sqrt(get_squared_distance(pos1, pos2))

def wrap_values_in_lists(dict_):
    """Return a copy of ``dict_`` in which every value is wrapped in a list.

    Args:
        dict_: mapping of keys to single values.

    Returns:
        dict: same keys, each mapped to a one-element list holding the value.
    """
    return {key: [value] for key, value in dict_.items()}

class ResidueContacts(FeatureClass):
    """Charge, Coulomb and van der Waals features around one residue.

    The feature is computed between the atoms of a single residue (selected
    by chain id and residue number) and the atoms of all residues that have
    at least one atom within ``max_contact_distance`` of it.

    Use as a context manager so that the pdb2sql database is opened and
    closed around ``evaluate()``.
    """

    # Alternative (patched) residue names and the rule that selects each.
    RESIDUE_SYNONYMS = {'PROP': ResidueSynonymCriteria('PRO', ['HT1', 'HT2'], []),
                        'NTER': ResidueSynonymCriteria('all', ['HT1', 'HT2', 'HT3'], []),
                        'CTER': ResidueSynonymCriteria('all', ['OXT'], []),
                        'CTN': ResidueSynonymCriteria('all', ['NT', 'HT1', 'HT2'], []),
                        'CYNH': ResidueSynonymCriteria('CYS', ['1SG'], ['2SG']),
                        'DISU': ResidueSynonymCriteria('CYS', ['1SG', '2SG'], []),
                        'HISE': ResidueSynonymCriteria('HIS', ['ND1', 'CE1', 'CD2', 'NE2', 'HE2'], ['HD1']),
                        'HISD': ResidueSynonymCriteria('HIS', ['ND1', 'CE1', 'CD2', 'NE2', 'HD1'], ['HE2'])}

    # Database columns that identify an atom / a residue.
    ATOM_KEY = ["chainID", "resSeq", "resName", "name"]
    RESIDUE_KEY = ["chainID", "resSeq", "resName"]

    EPSILON0 = 1.0
    COULOMB_CONSTANT = 332.0636

    # Short names kept as aliases of the two constants above.
    EPS0 = EPSILON0
    C = COULOMB_CONSTANT

    # Beyond OFF the van der Waals energy is zero, below ON it is unscaled;
    # in between a switching prefactor is applied.
    VANDERWAALS_DISTANCE_OFF = 8.5
    VANDERWAALS_DISTANCE_ON = 6.5

    SQUARED_VANDERWAALS_DISTANCE_OFF = numpy.square(VANDERWAALS_DISTANCE_OFF)
    SQUARED_VANDERWAALS_DISTANCE_ON = numpy.square(VANDERWAALS_DISTANCE_ON)

    @staticmethod
    def get_alternative_residue_name(residue_name, atom_names):
        """Return the first synonym whose criteria match the residue.

        Args:
            residue_name: name of the residue.
            atom_names: container of the atom names present in the residue.

        Returns:
            The matching key of ``RESIDUE_SYNONYMS`` or None when no rule
            matches.
        """
        for name, crit in ResidueContacts.RESIDUE_SYNONYMS.items():
            if crit.matches(residue_name, atom_names):
                return name

        return None

    @staticmethod
    def get_vanderwaals_energy(epsilon1, sigma1, epsilon2, sigma2, distance):
        """Return the switched 12-6 van der Waals energy of an atom pair.

        Epsilon values are combined by geometric mean and sigma values by
        arithmetic mean.

        Args:
            epsilon1, sigma1: parameters of the first atom.
            epsilon2, sigma2: parameters of the second atom.
            distance: distance between the atoms; must be non-zero.

        Returns:
            The pair energy, scaled by the switching prefactor.
        """
        average_epsilon = numpy.sqrt(epsilon1 * epsilon2)
        average_sigma = 0.5 * (sigma1 + sigma2)

        squared_off = ResidueContacts.SQUARED_VANDERWAALS_DISTANCE_OFF
        squared_on = ResidueContacts.SQUARED_VANDERWAALS_DISTANCE_ON

        if distance > ResidueContacts.VANDERWAALS_DISTANCE_OFF:
            prefactor = 0.0

        elif distance < ResidueContacts.VANDERWAALS_DISTANCE_ON:
            prefactor = 1.0

        else:
            # smooth transition between the ON and OFF distances
            squared_distance = numpy.square(distance)
            prefactor = (pow(squared_off - squared_distance, 2) *
                         (squared_off - squared_distance - 3 * (squared_on - squared_distance)) /
                         pow(squared_off - squared_on, 3))

        sigma_over_distance = average_sigma / distance

        return 4.0 * average_epsilon * (pow(sigma_over_distance, 12) - pow(sigma_over_distance, 6)) * prefactor

    @staticmethod
    def get_coulomb_energy(charge1, charge2, distance, max_distance):
        """Return the damped Coulomb energy of an atom pair.

        Args:
            charge1: charge of the first atom.
            charge2: charge of the second atom.
            distance: distance between the atoms; must be non-zero.
            max_distance: distance at which the damping factor reaches zero.

        Returns:
            The pair energy.
        """
        return (charge1 * charge2 * ResidueContacts.COULOMB_CONSTANT /
                (ResidueContacts.EPSILON0 * distance) * pow(1 - pow(distance / max_distance, 2), 2))

    def __init__(self, pdb_path, chain_id, residue_number,
                 top_path, param_path, patch_path,
                 max_contact_distance=8.5):
        """Store the input paths and the residue selection.

        Args:
            pdb_path: path of the PDB file to load.
            chain_id: chain id of the residue of interest.
            residue_number: residue number of the residue of interest.
            top_path: path of the topology file (charges and atom types).
            param_path: path of the parameter file (van der Waals values).
            patch_path: path of the patch file (overrides per synonym).
            max_contact_distance: atoms closer than this to the residue are
                considered in contact with it.
        """
        super().__init__("Atomic")

        self.pdb_path = pdb_path
        self.chain_id = chain_id
        self.residue_number = residue_number
        self.max_contact_distance = max_contact_distance

        self.top_path = top_path
        self.param_path = param_path
        self.patch_path = patch_path

    def __enter__(self):
        """Open the structure database and return this object."""
        self.sqldb = pdb2sql.interface(self.pdb_path)

        # second attribute name for the same handle, so it is opened only once
        self.pdbsql = self.sqldb
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the structure database; exceptions are not suppressed."""
        self.sqldb._close()

    def _read_top(self):
        """Load per (residue name, atom name) charges and atom types."""
        self._residue_charges = {}
        self._residue_atom_types = {}
        self._valid_residue_names = set()

        with open(self.top_path, 'rt') as f:
            for obj in TopParser.parse(f):
                key = (obj.residue_name, obj.atom_name)

                # store the charge
                self._residue_charges[key] = obj.kwargs['charge']

                # remember which residue names the topology knows about
                self._valid_residue_names.add(obj.residue_name)

                # dictionary for conversion name/type
                self._residue_atom_types[key] = obj.kwargs['type']

    def _read_param(self):
        """Load the van der Waals parameters, looked up by atom type."""
        with open(self.param_path, 'rt') as f:
            self._vanderwaals_parameters = ParamParser.parse(f)

    def _read_patch(self):
        """Load charge and type overrides per (residue type, atom name)."""
        self._patch_charge = {}
        self._patch_type = {}

        with open(self.patch_path, 'rt') as f:
            for action in PatchParser.parse(f):
                key = (action.selection.residue_type, action.selection.atom_name)

                # get the new charge
                self._patch_charge[key] = action.kwargs['CHARGE']

                # get the new type if any
                if 'TYPE' in action.kwargs:
                    self._patch_type[key] = action.kwargs['TYPE']

    def _find_contact_atoms(self):
        """Collect the atoms within max_contact_distance of the residue.

        Sets ``self._residue_atoms`` (row ids of the residue of interest)
        and ``self._contact_atoms`` (row ids of other atoms close to it).
        """
        self._residue_atoms = set(self.sqldb.get("rowID", chainID=self.chain_id, resSeq=self.residue_number))

        atomic_positions = {r[0]: numpy.array([r[1], r[2], r[3]]) for r in self.sqldb.get('rowID,x,y,z')}

        squared_max_distance = numpy.square(self.max_contact_distance)

        self._contact_atoms = set()
        for atom, position in atomic_positions.items():
            if atom in self._residue_atoms:
                continue  # we don't pair a residue with itself

            for residue_atom in self._residue_atoms:
                residue_atom_position = atomic_positions[residue_atom]
                if get_squared_distance(position, residue_atom_position) < squared_max_distance:

                    self._contact_atoms.add(atom)
                    break

    def _extend_contact_to_residues(self):
        """Extend the contact atoms to all atoms of their residues."""
        if not self._contact_atoms:
            return  # nothing to extend, and no empty selection is sent to the database

        per_chain = {}
        for chain_id, residue_number in self.sqldb.get("chainID,resSeq", rowID=list(self._contact_atoms)):
            # setdefault stores the new set in the dictionary before adding to it
            per_chain.setdefault(chain_id, set()).add(residue_number)

        extended_atoms = set()
        for chain_id, residue_numbers in per_chain.items():
            # NOTE(review): passed as a list, assuming sqldb.get selects
            # multiple values from list arguments as it does for rowID above
            extended_atoms.update(self.sqldb.get("rowID", chainID=chain_id, resSeq=list(residue_numbers)))

        self._contact_atoms = extended_atoms

    def _get_atom_type(self, residue_name, alternative_residue_name, atom_name):
        """Return the atom type, preferring the patch over the topology.

        Returns None when neither source knows the atom.
        """
        if (alternative_residue_name, atom_name) in self._patch_type:
            return self._patch_type[(alternative_residue_name, atom_name)]

        elif (residue_name, atom_name) in self._residue_atom_types:
            return self._residue_atom_types[(residue_name, atom_name)]

        else:
            return None

    def _get_charge(self, residue_name, alternative_residue_name, atom_name):
        """Return the atom charge, preferring the patch over the topology.

        Returns 0.0 for residues unknown to the topology and, after logging
        a warning, for atoms that cannot be found.
        """
        if residue_name not in self._valid_residue_names:
            return 0.0

        if (alternative_residue_name, atom_name) in self._patch_charge:
            return self._patch_charge[(alternative_residue_name, atom_name)]

        elif (residue_name, atom_name) in self._residue_charges:
            return self._residue_charges[(residue_name, atom_name)]

        else:
            _log.warning("Atom type {} not found for {}/{}, set charge to 0.0"
                         .format(atom_name, residue_name, alternative_residue_name))

            return 0.0

    def _get_vanderwaals_parameters(self, residue_name, alternative_residue_name, atom_name, atom_type):
        """Return (epsilon, sigma) for the atom type.

        Returns (0.0, 0.0) for residues unknown to the topology or atom
        types missing from the parameter file.
        """
        if residue_name not in self._valid_residue_names:
            return (0.0, 0.0)

        if atom_type in self._vanderwaals_parameters:
            o = self._vanderwaals_parameters[atom_type]
            return (o.epsilon, o.sigma)
        else:
            return (0.0, 0.0)

    def _assign_parameters(self):
        """Add CHARGE, eps, sig, type and altRes columns to the database."""
        atomic_data = self.sqldb.get("rowID,name,chainID,resSeq,resName")
        count_atoms = len(atomic_data)

        # NOTE(review): the arrays are indexed by rowID, which assumes the
        # row ids returned by sqldb.get are 0 .. count_atoms - 1 -- confirm
        atomic_charges = numpy.zeros(count_atoms)
        atomic_epsilon = numpy.zeros(count_atoms)
        atomic_sigma = numpy.zeros(count_atoms)

        atomic_types = numpy.zeros(count_atoms, dtype='<U5')
        atomic_alternative_residue_names = numpy.zeros(count_atoms, dtype='<U5')

        # here, we map the atom names per residue
        residue_atom_names = {}
        for atom_nr, atom_name, chain_id, residue_number, residue_name in atomic_data:
            residue_atom_names.setdefault((chain_id, residue_number), set()).add(atom_name)

        # loop over all atoms
        for atom_nr, atom_name, chain_id, residue_number, residue_name in atomic_data:
            atoms_in_residue = residue_atom_names[(chain_id, residue_number)]

            alternative_residue_name = ResidueContacts.get_alternative_residue_name(residue_name, atoms_in_residue)
            atomic_alternative_residue_names[atom_nr] = alternative_residue_name

            atom_type = self._get_atom_type(residue_name, alternative_residue_name, atom_name)
            atomic_types[atom_nr] = atom_type

            atomic_charges[atom_nr] = self._get_charge(residue_name, alternative_residue_name, atom_name)

            epsilon, sigma = self._get_vanderwaals_parameters(residue_name, alternative_residue_name, atom_name, atom_type)
            atomic_epsilon[atom_nr] = epsilon
            atomic_sigma[atom_nr] = sigma

        # put in sql
        self.sqldb.add_column('CHARGE')
        self.sqldb.update_column('CHARGE', atomic_charges)

        self.sqldb.add_column('eps')
        self.sqldb.update_column('eps', atomic_epsilon)

        self.sqldb.add_column('sig')
        self.sqldb.update_column('sig', atomic_sigma)

        self.sqldb.add_column('type', 'TEXT')
        self.sqldb.update_column('type', atomic_types)

        self.sqldb.add_column('altRes', 'TEXT')
        self.sqldb.update_column('altRes', atomic_alternative_residue_names)

    def _evaluate_physics(self):
        """Compute charge, Coulomb and van der Waals values per atom.

        Results are stored in ``self.feature_data`` (keyed by the ATOM_KEY
        columns) and ``self.feature_data_xyz`` (keyed by a flag followed by
        x, y, z; the flag is 0 for contact atoms and 1 for atoms of the
        residue of interest), under 'vdwaals', 'coulomb' and 'charge'.
        """
        # NOTE(review): these full-table results are indexed by rowID, which
        # assumes row ids are zero-based positions in the result -- confirm
        vanderwaals_data = self.sqldb.get('eps,sig')
        charge_data = self.sqldb.get('CHARGE')
        atom_info = self.sqldb.get(",".join(ResidueContacts.ATOM_KEY))
        atom_positions = self.sqldb.get('x,y,z')

        vanderwaals_per_atom = {}
        vanderwaals_per_position = {}
        coulomb_per_atom = {}
        coulomb_per_position = {}
        charge_per_atom = {}
        charge_per_position = {}

        for contact_atom in self._contact_atoms:  # loop over atoms that contact the residue of interest

            contact_epsilon, contact_sigma = vanderwaals_data[contact_atom]
            contact_charge = charge_data[contact_atom]
            contact_position = atom_positions[contact_atom]

            # dictionary keys must be hashable, so tuples are used
            contact_atom_key = tuple(atom_info[contact_atom])
            contact_position_key = (0,) + tuple(contact_position)

            # set charge
            charge_per_atom[contact_atom_key] = contact_charge
            charge_per_position[contact_position_key] = contact_charge

            for residue_atom in self._residue_atoms:  # loop over atoms in the residue of interest

                residue_epsilon, residue_sigma = vanderwaals_data[residue_atom]
                residue_charge = charge_data[residue_atom]
                residue_position = atom_positions[residue_atom]

                residue_atom_key = tuple(atom_info[residue_atom])
                residue_position_key = (1,) + tuple(residue_position)

                # set charge
                charge_per_atom[residue_atom_key] = residue_charge
                charge_per_position[residue_position_key] = residue_charge

                distance = get_distance(contact_position, residue_position)
                if distance == 0.0:
                    # avoid dividing by zero for overlapping atoms
                    distance = 3.0

                # add on vanderwaals energy
                vanderwaals_energy = ResidueContacts.get_vanderwaals_energy(contact_epsilon, contact_sigma,
                                                                            residue_epsilon, residue_sigma,
                                                                            distance)

                vanderwaals_per_atom[contact_atom_key] = vanderwaals_per_atom.get(contact_atom_key, 0.0) + vanderwaals_energy
                vanderwaals_per_atom[residue_atom_key] = vanderwaals_per_atom.get(residue_atom_key, 0.0) + vanderwaals_energy
                vanderwaals_per_position[contact_position_key] = vanderwaals_per_position.get(contact_position_key, 0.0) + vanderwaals_energy
                vanderwaals_per_position[residue_position_key] = vanderwaals_per_position.get(residue_position_key, 0.0) + vanderwaals_energy

                # add on coulomb energy
                coulomb_energy = ResidueContacts.get_coulomb_energy(contact_charge, residue_charge,
                                                                    distance, self.max_contact_distance)

                coulomb_per_atom[contact_atom_key] = coulomb_per_atom.get(contact_atom_key, 0.0) + coulomb_energy
                coulomb_per_atom[residue_atom_key] = coulomb_per_atom.get(residue_atom_key, 0.0) + coulomb_energy
                coulomb_per_position[contact_position_key] = coulomb_per_position.get(contact_position_key, 0.0) + coulomb_energy
                coulomb_per_position[residue_position_key] = coulomb_per_position.get(residue_position_key, 0.0) + coulomb_energy

        # add to feature data (in list form)
        self.feature_data['vdwaals'] = wrap_values_in_lists(vanderwaals_per_atom)
        self.feature_data_xyz['vdwaals'] = wrap_values_in_lists(vanderwaals_per_position)
        self.feature_data['coulomb'] = wrap_values_in_lists(coulomb_per_atom)
        self.feature_data_xyz['coulomb'] = wrap_values_in_lists(coulomb_per_position)
        self.feature_data['charge'] = wrap_values_in_lists(charge_per_atom)
        self.feature_data_xyz['charge'] = wrap_values_in_lists(charge_per_position)

    def evaluate(self):
        """Run the whole pipeline; must be called inside the ``with`` block."""
        self._read_top()
        self._read_param()
        self._read_patch()
        self._assign_parameters()

        self._find_contact_atoms()
        self._extend_contact_to_residues()

        self._evaluate_physics()


def __compute_feature__(pdb_path, feature_group, raw_feature_group, chain_id, residue_number):
    """Compute the residue contact features and export them.

    The force field files are read from the 'forcefield' directory located
    next to this module.

    Args:
        pdb_path: path of the PDB file to load.
        feature_group: target passed to ``export_dataxyz_hdf5``.
        raw_feature_group: target passed to ``export_data_hdf5``.
        chain_id: chain id of the residue of interest.
        residue_number: residue number of the residue of interest.
    """
    forcefield_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'forcefield')
    top_path = os.path.join(forcefield_path, 'protein-allhdg5-4_new.top')
    param_path = os.path.join(forcefield_path, 'protein-allhdg5-4_new.param')
    patch_path = os.path.join(forcefield_path, 'patch.top')

    with ResidueContacts(pdb_path, chain_id, residue_number,
                         top_path, param_path, patch_path) as feature_object:

        feature_object.evaluate()

        # export in the hdf5 file
        feature_object.export_dataxyz_hdf5(feature_group)
        feature_object.export_data_hdf5(raw_feature_group)


10 changes: 0 additions & 10 deletions deeprank/models/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,3 @@ class PatchActionType(Enum):
MODIFY = 1
ADD = 2


class PatchResidueSelectionType(Enum):
    """Named residue selections, numbered 1 to 8.

    NOTE(review): presumably identifies which residues a patch action
    applies to -- confirm against the patch parser.
    """

    NTER = 1
    PROP = 2
    CTER = 3
    CTN = 4
    DISU = 5
    CYNH = 6
    HISE = 7
    HISD = 8
Loading

0 comments on commit 3198b3c

Please sign in to comment.