This repository has been archived by the owner on Nov 28, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
created a test for the ResidueContacts feature class
- Loading branch information
Coos Baakman
committed
Mar 19, 2021
1 parent
1ac15a9
commit 3198b3c
Showing
6 changed files
with
366 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,364 @@ | ||
import logging | ||
import pdb2sql | ||
from deeprank.feature.FeatureClass import FeatureClass | ||
import re | ||
import os | ||
|
||
import numpy | ||
|
||
from deeprank.features.FeatureClass import FeatureClass | ||
from deeprank.parse.param import ParamParser | ||
from deeprank.parse.top import TopParser | ||
from deeprank.parse.patch import PatchParser | ||
from deeprank.models.patch import PatchActionType | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class ResidueSynonymCriteria:
    """Rule that decides whether a residue qualifies for an alternative name.

    A residue matches when its name equals ``residue_name`` (or the rule's
    ``residue_name`` is the wildcard ``'all'``), every atom listed in
    ``atoms_present`` occurs in the residue and none of the atoms listed in
    ``atoms_absent`` does.
    """

    def __init__(self, residue_name, atoms_present, atoms_absent):
        """Store the matching rule.

        Args:
            residue_name: residue name the rule applies to, or ``'all'`` to
                apply it to every residue.
            atoms_present: atom names that must all occur in the residue.
            atoms_absent: atom names that must not occur in the residue.
        """
        self.residue_name = residue_name
        self.atoms_present = atoms_present
        self.atoms_absent = atoms_absent

    def matches(self, residue_name, atom_names):
        """Tell whether a residue satisfies this rule.

        Args:
            residue_name: name of the residue under test.
            atom_names: container of the atom names found in that residue;
                only membership tests (``in``) are performed on it.

        Returns:
            bool: True when the residue satisfies the rule, False otherwise.
        """
        # 'all' is a wildcard: the rule then applies to any residue name.
        if self.residue_name != 'all' and residue_name != self.residue_name:
            return False

        if any(atom_name not in atom_names for atom_name in self.atoms_present):
            return False

        if any(atom_name in atom_names for atom_name in self.atoms_absent):
            return False

        return True
|
||
|
||
def get_squared_distance(pos1, pos2):
    """Return the squared Euclidean distance between two 3D positions.

    Only the first three components (indices 0, 1 and 2) of each position
    are used.

    Args:
        pos1: indexable sequence holding the x, y and z coordinates.
        pos2: indexable sequence holding the x, y and z coordinates.

    Returns:
        The sum of the squared coordinate differences.
    """
    return numpy.sum([numpy.square(pos1[i] - pos2[i]) for i in range(3)])
|
||
def get_distance(pos1, pos2):
    """Return the Euclidean distance between two 3D positions.

    Args:
        pos1: indexable sequence holding the x, y and z coordinates.
        pos2: indexable sequence holding the x, y and z coordinates.

    Returns:
        The square root of ``get_squared_distance(pos1, pos2)``.
    """
    return numpy.sqrt(get_squared_distance(pos1, pos2))
|
||
def wrap_values_in_lists(dict_):
    """Return a copy of a dictionary with every value wrapped in a list.

    Args:
        dict_: the dictionary to convert; it is not modified.

    Returns:
        dict: the same keys, each mapped to a one-element list holding the
        original value.
    """
    return {key: [value] for key, value in dict_.items()}
|
||
class ResidueContacts(FeatureClass):
    """Charge, Coulomb and van der Waals features around a single residue.

    The residue of interest is selected by chain id and residue number. The
    class reads force-field files (topology, parameters, patches), assigns
    charges and van der Waals parameters to every atom in the structure,
    finds the atoms within ``max_contact_distance`` of the residue, extends
    that selection to whole residues and sums the pairwise energies.

    Use it as a context manager so the pdb2sql connection gets closed::

        with ResidueContacts(pdb_path, chain_id, residue_number,
                             top_path, param_path, patch_path) as feature:
            feature.evaluate()
    """

    # Alternative (patch) residue names and the rule that selects each one.
    RESIDUE_SYNONYMS = {
        'PROP': ResidueSynonymCriteria('PRO', ['HT1', 'HT2'], []),
        'NTER': ResidueSynonymCriteria('all', ['HT1', 'HT2', 'HT3'], []),
        'CTER': ResidueSynonymCriteria('all', ['OXT'], []),
        'CTN': ResidueSynonymCriteria('all', ['NT', 'HT1', 'HT2'], []),
        'CYNH': ResidueSynonymCriteria('CYS', ['1SG'], ['2SG']),
        'DISU': ResidueSynonymCriteria('CYS', ['1SG', '2SG'], []),
        'HISE': ResidueSynonymCriteria('HIS', ['ND1', 'CE1', 'CD2', 'NE2', 'HE2'], ['HD1']),
        'HISD': ResidueSynonymCriteria('HIS', ['ND1', 'CE1', 'CD2', 'NE2', 'HD1'], ['HE2']),
    }

    # Database columns that together identify an atom / a residue.
    ATOM_KEY = ["chainID", "resSeq", "resName", "name"]
    RESIDUE_KEY = ["chainID", "resSeq", "resName"]

    EPSILON0 = 1.0
    COULOMB_CONSTANT = 332.0636

    # Shorter names for the two constants above, kept as aliases.
    EPS0 = EPSILON0
    C = COULOMB_CONSTANT

    # Between ON and OFF the van der Waals energy is smoothly switched off.
    VANDERWAALS_DISTANCE_OFF = 8.5
    VANDERWAALS_DISTANCE_ON = 6.5

    SQUARED_VANDERWAALS_DISTANCE_OFF = numpy.square(VANDERWAALS_DISTANCE_OFF)
    SQUARED_VANDERWAALS_DISTANCE_ON = numpy.square(VANDERWAALS_DISTANCE_ON)

    @staticmethod
    def get_alternative_residue_name(residue_name, atom_names):
        """Return the first synonym whose criteria the residue satisfies.

        Args:
            residue_name: name of the residue.
            atom_names: container of the atom names present in the residue.

        Returns:
            The matching key of ``RESIDUE_SYNONYMS`` (in dictionary order),
            or None when no criteria match.
        """
        for name, crit in ResidueContacts.RESIDUE_SYNONYMS.items():
            if crit.matches(residue_name, atom_names):
                return name

        return None

    @staticmethod
    def get_vanderwaals_energy(epsilon1, sigma1, epsilon2, sigma2, distance):
        """Return the switched 12-6 van der Waals energy of an atom pair.

        The pair parameters are the geometric mean of the epsilons and the
        arithmetic mean of the sigmas. The energy is multiplied by a
        prefactor that equals 1 below ``VANDERWAALS_DISTANCE_ON``, 0 above
        ``VANDERWAALS_DISTANCE_OFF`` and varies smoothly in between.

        Args:
            epsilon1, sigma1: van der Waals parameters of the first atom.
            epsilon2, sigma2: van der Waals parameters of the second atom.
            distance: distance between the two atoms; must be non-zero.

        Returns:
            The van der Waals energy of the pair.
        """
        average_epsilon = numpy.sqrt(epsilon1 * epsilon2)
        average_sigma = 0.5 * (sigma1 + sigma2)

        squared_off = ResidueContacts.SQUARED_VANDERWAALS_DISTANCE_OFF
        squared_on = ResidueContacts.SQUARED_VANDERWAALS_DISTANCE_ON

        if distance > ResidueContacts.VANDERWAALS_DISTANCE_OFF:
            prefactor = 0.0

        elif distance < ResidueContacts.VANDERWAALS_DISTANCE_ON:
            prefactor = 1.0

        else:
            # Evaluates to 1 at the "on" distance and to 0 at the "off" distance.
            squared_distance = numpy.square(distance)
            prefactor = (pow(squared_off - squared_distance, 2) *
                         (squared_off - squared_distance - 3 * (squared_on - squared_distance)) /
                         pow(squared_off - squared_on, 3))

        sigma_over_distance = average_sigma / distance

        return (4.0 * average_epsilon *
                (pow(sigma_over_distance, 12) - pow(sigma_over_distance, 6)) *
                prefactor)

    @staticmethod
    def get_coulomb_energy(charge1, charge2, distance, max_distance):
        """Return the damped Coulomb energy of an atom pair.

        The plain Coulomb term is multiplied by
        ``(1 - (distance / max_distance)^2)^2``, which is zero at
        ``distance == max_distance``.

        Args:
            charge1: charge of the first atom.
            charge2: charge of the second atom.
            distance: distance between the two atoms; must be non-zero.
            max_distance: distance at which the damping factor reaches zero.

        Returns:
            The Coulomb energy of the pair.
        """
        return (charge1 * charge2 * ResidueContacts.COULOMB_CONSTANT /
                (ResidueContacts.EPSILON0 * distance) * pow(1 - pow(distance / max_distance, 2), 2))

    def __init__(self, pdb_path, chain_id, residue_number,
                 top_path, param_path, patch_path,
                 max_contact_distance=8.5):
        """Store the inputs; no file is opened until the object is entered.

        Args:
            pdb_path: path of the PDB structure.
            chain_id: chain id of the residue of interest.
            residue_number: residue number of the residue of interest.
            top_path: path of the force-field topology file.
            param_path: path of the force-field parameter file.
            patch_path: path of the force-field patch file.
            max_contact_distance: atoms closer than this to any atom of the
                residue of interest count as contacts.
        """
        super().__init__("Atomic")

        self.pdb_path = pdb_path
        self.chain_id = chain_id
        self.residue_number = residue_number
        self.max_contact_distance = max_contact_distance

        self.top_path = top_path
        self.param_path = param_path
        self.patch_path = patch_path

    def __enter__(self):
        """Open the pdb2sql database for the structure and return self."""
        self.sqldb = pdb2sql.interface(self.pdb_path)

        # Second attribute name for the same connection, so the database is
        # opened (and later closed) only once.
        self.pdbsql = self.sqldb

        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the pdb2sql database; exceptions are not suppressed."""
        self.sqldb._close()

    def _read_top(self):
        """Load charges, atom types and known residue names from the topology file."""
        self._residue_charges = {}
        self._residue_atom_types = {}
        self._valid_residue_names = set()

        with open(self.top_path, 'rt') as f:
            for obj in TopParser.parse(f):
                key = (obj.residue_name, obj.atom_name)

                # store the charge
                self._residue_charges[key] = obj.kwargs['charge']

                # remember that this residue name occurs in the topology
                self._valid_residue_names.add(obj.residue_name)

                # dictionary for conversion name/type
                self._residue_atom_types[key] = obj.kwargs['type']

    def _read_param(self):
        """Load the van der Waals parameters from the parameter file."""
        with open(self.param_path, 'rt') as f:
            self._vanderwaals_parameters = ParamParser.parse(f)

    def _read_patch(self):
        """Load the charge and atom type overrides from the patch file."""
        self._patch_charge = {}
        self._patch_type = {}

        with open(self.patch_path, 'rt') as f:
            for action in PatchParser.parse(f):
                key = (action.selection.residue_type, action.selection.atom_name)

                # get the new charge
                # NOTE(review): raises KeyError for an action without a
                # 'CHARGE' entry - confirm that every parsed action has one.
                self._patch_charge[key] = action.kwargs['CHARGE']

                # get the new type if any
                if 'TYPE' in action.kwargs:
                    self._patch_type[key] = action.kwargs['TYPE']

    def _find_contact_atoms(self):
        """Select the atoms within ``max_contact_distance`` of the residue.

        Sets ``self._residue_atoms`` (row ids of the residue of interest)
        and ``self._contact_atoms`` (row ids of the other atoms that lie
        closer than the maximum distance to at least one residue atom).
        """
        self._residue_atoms = set(self.sqldb.get("rowID", chainID=self.chain_id, resSeq=self.residue_number))

        atomic_positions = {row[0]: numpy.array([row[1], row[2], row[3]])
                            for row in self.sqldb.get('rowID,x,y,z')}

        # compare squared distances, so no square root is needed per pair
        squared_max_distance = numpy.square(self.max_contact_distance)

        self._contact_atoms = set()
        for atom, position in atomic_positions.items():
            if atom in self._residue_atoms:
                continue  # we don't pair a residue with itself

            for residue_atom in self._residue_atoms:
                if get_squared_distance(position, atomic_positions[residue_atom]) < squared_max_distance:
                    self._contact_atoms.add(atom)
                    break

    def _extend_contact_to_residues(self):
        """Extend ``self._contact_atoms`` to all atoms of the touched residues."""
        if not self._contact_atoms:
            return  # nothing to extend, and no query with an empty selection

        # collect the residue numbers of the contact atoms, per chain
        per_chain = {}
        for chain_id, residue_number in self.sqldb.get("chainID,resSeq", rowID=list(self._contact_atoms)):
            per_chain.setdefault(chain_id, set()).add(residue_number)

        # take every atom of those residues
        extended_atoms = set()
        for chain_id, residue_numbers in per_chain.items():
            extended_atoms.update(self.sqldb.get("rowID", chainID=chain_id, resSeq=list(residue_numbers)))

        self._contact_atoms = extended_atoms

    def _get_atom_type(self, residue_name, alternative_residue_name, atom_name):
        """Return the atom type, preferring the patch over the topology.

        Returns:
            The type from the patch for the alternative residue name, else
            the type from the topology for the residue name, else None.
        """
        if (alternative_residue_name, atom_name) in self._patch_type:
            return self._patch_type[(alternative_residue_name, atom_name)]

        if (residue_name, atom_name) in self._residue_atom_types:
            return self._residue_atom_types[(residue_name, atom_name)]

        return None

    def _get_charge(self, residue_name, alternative_residue_name, atom_name):
        """Return the atom charge, preferring the patch over the topology.

        Returns:
            0.0 for residues that are not in the topology; otherwise the
            patch charge, else the topology charge, else 0.0 (with a
            warning in the log).
        """
        if residue_name not in self._valid_residue_names:
            return 0.0

        if (alternative_residue_name, atom_name) in self._patch_charge:
            return self._patch_charge[(alternative_residue_name, atom_name)]

        if (residue_name, atom_name) in self._residue_charges:
            return self._residue_charges[(residue_name, atom_name)]

        _log.warning("Atom type {} not found for {}/{}, set charge to 0.0"
                     .format(atom_name, residue_name, alternative_residue_name))

        return 0.0

    def _get_vanderwaals_parameters(self, residue_name, alternative_residue_name, atom_name, atom_type):
        """Return the ``(epsilon, sigma)`` pair that belongs to an atom type.

        Returns:
            ``(0.0, 0.0)`` when the residue is not in the topology or the
            atom type has no parameters; ``alternative_residue_name`` and
            ``atom_name`` are not used for the lookup.
        """
        if residue_name not in self._valid_residue_names:
            return (0.0, 0.0)

        if atom_type in self._vanderwaals_parameters:
            o = self._vanderwaals_parameters[atom_type]
            return (o.epsilon, o.sigma)

        return (0.0, 0.0)

    def _assign_parameters(self):
        """Compute the per-atom parameters and store them as database columns.

        Adds the columns 'CHARGE', 'eps', 'sig', 'type' and 'altRes'.
        """
        atomic_data = self.sqldb.get("rowID,name,chainID,resSeq,resName")
        count_atoms = len(atomic_data)

        # NOTE(review): the arrays below are indexed by rowID, which assumes
        # row ids run from 0 to count_atoms - 1 - confirm against pdb2sql.
        atomic_charges = numpy.zeros(count_atoms)
        atomic_epsilon = numpy.zeros(count_atoms)
        atomic_sigma = numpy.zeros(count_atoms)

        atomic_types = numpy.zeros(count_atoms, dtype='<U5')
        atomic_alternative_residue_names = numpy.zeros(count_atoms, dtype='<U5')

        # here, we map the atom names per residue
        residue_atom_names = {}
        for atom_nr, atom_name, chain_id, residue_number, residue_name in atomic_data:
            residue_atom_names.setdefault((chain_id, residue_number), set()).add(atom_name)

        # loop over all atoms
        for atom_nr, atom_name, chain_id, residue_number, residue_name in atomic_data:
            atoms_in_residue = residue_atom_names[(chain_id, residue_number)]

            alternative_residue_name = ResidueContacts.get_alternative_residue_name(residue_name, atoms_in_residue)
            atomic_alternative_residue_names[atom_nr] = alternative_residue_name

            atom_type = self._get_atom_type(residue_name, alternative_residue_name, atom_name)
            atomic_types[atom_nr] = atom_type

            atomic_charges[atom_nr] = self._get_charge(residue_name, alternative_residue_name, atom_name)

            epsilon, sigma = self._get_vanderwaals_parameters(residue_name, alternative_residue_name,
                                                              atom_name, atom_type)
            atomic_epsilon[atom_nr] = epsilon
            atomic_sigma[atom_nr] = sigma

        # put in sql
        self.sqldb.add_column('CHARGE')
        self.sqldb.update_column('CHARGE', atomic_charges)

        self.sqldb.add_column('eps')
        self.sqldb.update_column('eps', atomic_epsilon)

        self.sqldb.add_column('sig')
        self.sqldb.update_column('sig', atomic_sigma)

        self.sqldb.add_column('type', 'TEXT')
        self.sqldb.update_column('type', atomic_types)

        self.sqldb.add_column('altRes', 'TEXT')
        self.sqldb.update_column('altRes', atomic_alternative_residue_names)

    def _evaluate_physics(self):
        """Sum the pairwise energies between the residue and its contacts.

        Fills ``self.feature_data`` (keyed by the ``ATOM_KEY`` values of an
        atom) and ``self.feature_data_xyz`` (keyed by a 0/1 flag followed by
        the x, y, z position; 0 marks a contact atom, 1 an atom of the
        residue of interest) with the entries 'vdwaals', 'coulomb' and
        'charge'.
        """
        # NOTE(review): the query results are indexed by rowID below, which
        # assumes rowID equals the position in an unfiltered query - confirm.
        vanderwaals_data = self.sqldb.get('eps,sig')
        charge_data = self.sqldb.get('CHARGE')
        atom_info = self.sqldb.get(",".join(ResidueContacts.ATOM_KEY))
        atom_positions = self.sqldb.get('x,y,z')

        vanderwaals_per_atom = {}
        vanderwaals_per_position = {}
        coulomb_per_atom = {}
        coulomb_per_position = {}
        charge_per_atom = {}
        charge_per_position = {}

        for contact_atom in self._contact_atoms:  # loop over atoms that contact the residue of interest

            contact_epsilon, contact_sigma = vanderwaals_data[contact_atom]
            contact_charge = charge_data[contact_atom]
            contact_position = atom_positions[contact_atom]

            # dictionary keys must be hashable, so tuples rather than lists
            contact_atom_key = tuple(atom_info[contact_atom])
            contact_position_key = (0,) + tuple(contact_position)

            # set charge
            charge_per_atom[contact_atom_key] = contact_charge
            charge_per_position[contact_position_key] = contact_charge

            for residue_atom in self._residue_atoms:  # loop over atoms in the residue of interest

                residue_epsilon, residue_sigma = vanderwaals_data[residue_atom]
                residue_charge = charge_data[residue_atom]
                residue_position = atom_positions[residue_atom]

                residue_atom_key = tuple(atom_info[residue_atom])
                residue_position_key = (1,) + tuple(residue_position)

                # set charge
                charge_per_atom[residue_atom_key] = residue_charge
                charge_per_position[residue_position_key] = residue_charge

                distance = get_distance(contact_position, residue_position)
                if distance == 0.0:
                    # avoid dividing by zero for overlapping atoms
                    distance = 3.0

                # add on vanderwaals energy
                vanderwaals_energy = ResidueContacts.get_vanderwaals_energy(contact_epsilon, contact_sigma,
                                                                            residue_epsilon, residue_sigma,
                                                                            distance)

                vanderwaals_per_atom[contact_atom_key] = vanderwaals_per_atom.get(contact_atom_key, 0.0) + vanderwaals_energy
                vanderwaals_per_atom[residue_atom_key] = vanderwaals_per_atom.get(residue_atom_key, 0.0) + vanderwaals_energy
                vanderwaals_per_position[contact_position_key] = vanderwaals_per_position.get(contact_position_key, 0.0) + vanderwaals_energy
                vanderwaals_per_position[residue_position_key] = vanderwaals_per_position.get(residue_position_key, 0.0) + vanderwaals_energy

                # add on coulomb energy
                coulomb_energy = ResidueContacts.get_coulomb_energy(contact_charge, residue_charge,
                                                                    distance, self.max_contact_distance)

                coulomb_per_atom[contact_atom_key] = coulomb_per_atom.get(contact_atom_key, 0.0) + coulomb_energy
                coulomb_per_atom[residue_atom_key] = coulomb_per_atom.get(residue_atom_key, 0.0) + coulomb_energy
                coulomb_per_position[contact_position_key] = coulomb_per_position.get(contact_position_key, 0.0) + coulomb_energy
                coulomb_per_position[residue_position_key] = coulomb_per_position.get(residue_position_key, 0.0) + coulomb_energy

        # add to feature data (in list form)
        self.feature_data['vdwaals'] = wrap_values_in_lists(vanderwaals_per_atom)
        self.feature_data_xyz['vdwaals'] = wrap_values_in_lists(vanderwaals_per_position)
        self.feature_data['coulomb'] = wrap_values_in_lists(coulomb_per_atom)
        self.feature_data_xyz['coulomb'] = wrap_values_in_lists(coulomb_per_position)
        self.feature_data['charge'] = wrap_values_in_lists(charge_per_atom)
        self.feature_data_xyz['charge'] = wrap_values_in_lists(charge_per_position)

    def evaluate(self):
        """Run the whole pipeline; must be called inside the ``with`` block."""
        # force field -> per-atom parameters in the database
        self._read_top()
        self._read_param()
        self._read_patch()
        self._assign_parameters()

        # atoms (then whole residues) in contact with the residue of interest
        self._find_contact_atoms()
        self._extend_contact_to_residues()

        self._evaluate_physics()
||
|
||
def __compute_feature__(pdb_path, feature_group, raw_feature_group, chain_id, residue_number):
    """Compute the residue contact features and export them.

    The force-field files are read from the 'forcefield' directory that
    sits next to this module.

    Args:
        pdb_path: path of the PDB structure.
        feature_group: target passed to ``export_dataxyz_hdf5``
            (presumably an HDF5 group - confirm against the caller).
        raw_feature_group: target passed to ``export_data_hdf5``
            (presumably an HDF5 group - confirm against the caller).
        chain_id: chain id of the residue of interest.
        residue_number: residue number of the residue of interest.
    """
    forcefield_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'forcefield')
    top_path = os.path.join(forcefield_path, 'protein-allhdg5-4_new.top')
    param_path = os.path.join(forcefield_path, 'protein-allhdg5-4_new.param')
    patch_path = os.path.join(forcefield_path, 'patch.top')

    with ResidueContacts(pdb_path, chain_id, residue_number,
                         top_path, param_path, patch_path) as feature_object:

        feature_object.evaluate()

        # export in the hdf5 file
        feature_object.export_dataxyz_hdf5(feature_group)
        feature_object.export_data_hdf5(raw_feature_group)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.