diff --git a/.gitignore b/.gitignore index e498d9e8..74118c67 100644 --- a/.gitignore +++ b/.gitignore @@ -21,9 +21,8 @@ example/*.hdf5 example/*.pdb # some test file -test/out_2d -test/out_3d -test/out_3d_class +test/out_2d* +test/out_3d* test/out_test test/*.pckl test/*.hdf5 diff --git a/.travis.yml b/.travis.yml index fac2b986..19a19273 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ before_install: # pytest - conda install -c anaconda pytest - conda install -c conda-forge pytest-cov - - conda install python=3.6 + - conda install python=3.7 # codacy-coverage - pip install -q --upgrade pip diff --git a/README.md b/README.md index e44ebc4d..ae6cd8b1 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,12 @@ **Deep Learning for ranking protein-protein conformations** [![Build Status](https://secure.travis-ci.org/DeepRank/deeprank.svg?branch=master)](https://travis-ci.org/DeepRank/deeprank) -[![Codacy Badge](https://api.codacy.com/project/badge/Grade/9252e59633cf46a7ada0c3c614c175ea)](https://www.codacy.com/app/NicoRenaud/deeprank?utm_source=github.com&utm_medium=referral&utm_content=DeepRank/deeprank&utm_campaign=Badge_Grade) +[![Codacy Badge](https://api.codacy.com/project/badge/Grade/9252e59633cf46a7ada0c3c614c175ea)](https://www.codacy.com/app/NicoRenaud/deeprank?utm_source=github.com&utm_medium=referral&utm_content=DeepRank/deeprank&utm_campaign=Badge_Grade) [![Documentation Status](https://readthedocs.org/projects/deeprank/badge/?version=latest)](http://deeprank.readthedocs.io/?badge=latest) [![Coverage Status](https://coveralls.io/repos/github/DeepRank/deeprank/badge.svg?branch=master)](https://coveralls.io/github/DeepRank/deeprank?branch=master) The documentation of the module can be found on readthedocs : -http://deeprank.readthedocs.io/en/latest/ +<http://deeprank.readthedocs.io/en/latest/> ![alt-text](./pics/deeprank.png) @@ -16,22 +16,22 @@ http://deeprank.readthedocs.io/en/latest/ Minimal information to install the module - * clone the repository `git clone https://github.com/DeepRank/deeprank.git` - * go there `cd deeprank` - * install the module `pip install -e ./` - * go int the test dir `cd test` - * run the test suite `pytest` - +- clone the repository `git clone https://github.com/DeepRank/deeprank.git` +- go there `cd deeprank` +- install the module `pip install -e ./` +- go into the test dir `cd test` +- run the test suite `pytest` ## 2 . Tutorial -We give here the tutorial like introduction to the DeepRank machinery. More informatoin can be found in the documentation http://deeprank.readthedocs.io/en/latest/. We quickly illsutrate here the two main steps of Deeprank : - * the generation of the data - * running deep leaning experiments. +We give here a tutorial-like introduction to the DeepRank machinery. More information can be found in the documentation <http://deeprank.readthedocs.io/en/latest/>. We quickly illustrate here the two main steps of DeepRank: + +- the generation of the data +- running deep learning experiments. ### A . Generate the data set (using MPI) -The generation of the data require only require PDBs files of decoys and their native and the PSSM if needed. All the features/targets and mapped features onto grid points will be auomatically calculated and store in a HDF5 file. +The generation of the data only requires the PDB files of the decoys and their natives, plus the PSSM files if needed. All the features/targets and mapped features onto grid points will be automatically calculated and stored in an HDF5 file. 
```python from deeprank.generate import * @@ -79,39 +79,36 @@ grid_info = { This script can be executed using for example 4 MPI processes with the command: ``` -NP=4 -mpiexec -n $NP python generate.py + NP=4 + mpiexec -n $NP python generate.py ``` - -In the first part of the script we define the path where to find the PDBs of the decoys and natives that we want to have in the dataset. All the .pdb files present in *pdb_source* will be used in the dataset. We need to specify where to find the native conformations to be able to compute RMSD and the dockQ score. For each pdb file detected in *pdb_source*, the code will try to find a native conformation in *pdb_native*. +In the first part of the script we define the paths where to find the PDBs of the decoys and natives that we want to have in the dataset. All the .pdb files present in _pdb_source_ will be used in the dataset. We need to specify where to find the native conformations to be able to compute the RMSD and the dockQ score. For each pdb file detected in _pdb_source_, the code will try to find a native conformation in _pdb_native_. We then initialize the `DataGenerator` object. This object (defined in `deeprank/generate/DataGenerator.py`) needs a few input parameters: - * pdb_source : where to find the pdb to include in the dataset - * pdb_native : where to find the corresponding native conformations - * compute_targets : list of modules used to compute the targets - * compute_features : list of modules used to compute the features - * hdf5 : Name of the HDF5 file to store the data set +- pdb_source : where to find the pdbs to include in the dataset +- pdb_native : where to find the corresponding native conformations +- compute_targets : list of modules used to compute the targets +- compute_features : list of modules used to compute the features +- hdf5 : name of the HDF5 file to store the data set We then create the database with the command `database.create_database()`. This function automatically creates an HDF5 file where each pdb has its own group. In each group we can find the pdb of the complex and its native form, the calculated features and the calculated targets. We can now map the features to a grid. This is done via the command `database.map_features()`. As you can see this method requires a dictionary as input. The dictionary contains the instructions to map the data. - * number_of_points: the number of points in each direction - * resolution : the resolution in Angs - * atomic_densities : {'atom_name' : vvdw_radius} the atomic densities required +- number_of_points: the number of points in each direction +- resolution : the resolution in Angstrom +- atomic_densities : {'atom_name' : vdw_radius} the atomic densities required The atomic densities are mapped following the [protein-ligand paper](https://arxiv.org/abs/1612.02751). The other features are mapped to the grid points using a Gaussian function (other modes are possible but somewhat hard-coded). #### Visualization of the mapped features -To explore the HDf5 file and vizualize the features you can use the dedicated browser https://github.com/DeepRank/DeepXplorer. This tool saloows to dig through the hdf5 file and to directly generate the files required to vizualie the features in VMD or PyMol. An iPython comsole is also embedded to analyze the feature values, plot them etc .... - +To explore the HDF5 file and visualize the features you can use the dedicated browser <https://github.com/DeepRank/DeepXplorer>. This tool allows you to dig through the HDF5 file and to directly generate the files required to visualize the features in VMD or PyMol. An iPython console is also embedded to analyze the feature values, plot them, etc. 
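For a quick programmatic look at the same file, a minimal sketch using `h5py` directly can also help; the file name and the exact group names here are assumptions based on the layout described above, not a fixed DeepRank API:

```python
import h5py

# open a data set produced by DataGenerator (file name is a placeholder)
with h5py.File('1ak4.hdf5', 'r') as f5:
    # each complex is stored in its own top-level group
    for mol_name in list(f5.keys())[:3]:
        molgrp = f5[mol_name]
        # e.g. the complex/native pdb data and the computed features/targets
        print(mol_name, '->', list(molgrp.keys()))
```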
### B . Deep Learning The HDF5 files generated above can be used as input for deep learning experiments. You can take a look at the file `test/test_learn.py` for some examples. We give here a quick overview of the process. - ```python from deeprank.learn import * from deeprank.learn.model3d import cnn as cnn3d @@ -145,9 +142,6 @@ model.optimizer = optim.SGD(model.net.parameters(), model.train(nepoch = 50,divide_trainset=0.8, train_batch_size = 5,num_workers=0) ``` - - In the first part of the script we create a Torch database from the HDF5 file. We can specify one or several HDF5 files and even select some conformations using the `dict_filter` argument. Other options of `DataSet` can be used to specify the features/targets, the normalization, etc. We then create a `NeuralNet` instance that takes the dataset as input argument. Several options are available to specify the task to do, the GPU use, etc. We then simply have to train the model. Simple! - diff --git a/deeprank/features/AtomicFeature.py b/deeprank/features/AtomicFeature.py index 55ac33c2..eb20f322 100644 --- a/deeprank/features/AtomicFeature.py +++ b/deeprank/features/AtomicFeature.py @@ -2,9 +2,9 @@ import warnings import numpy as np +import pdb2sql from deeprank.features import FeatureClass -from deeprank.tools import pdb2sql class AtomicFeature(FeatureClass): @@ -81,7 +81,7 @@ def __init__(self, pdbfile, param_charge=None, param_vdw=None, self.atom_key = 'chainID, resSeq, resName, name' # read the pdb as an sql - self.sqldb = pdb2sql(self.pdbfile) + self.sqldb = pdb2sql.pdb2sql(self.pdbfile) # read the force field self.read_charge_file() diff --git a/deeprank/features/BSA.py b/deeprank/features/BSA.py index 8a52f587..cdf5ac03 100644 --- a/deeprank/features/BSA.py +++ b/deeprank/features/BSA.py @@ -1,7 +1,8 @@ import warnings +import pdb2sql + from deeprank.features import FeatureClass -from deeprank.tools import pdb2sql try: import freesasa @@ -33,7 +34,7 @@ def __init__(self, pdb_data, chainA='A', chainB='B'): >>> bsa.sql.close() """ self.pdb_data = pdb_data - self.sql = pdb2sql(pdb_data) + self.sql = pdb2sql.interface(pdb_data) self.chains_label = [chainA, chainB] self.feature_data = {} @@ -83,9 +84,8 @@ def get_contact_residue_sasa(self, cutoff=5.5): self.bsa_data = {} self.bsa_data_xyz = {} - # res = ([chain1 residues], [chain2 residues]) - ctc_res = self.sql.get_contact_residue(cutoff=cutoff) - ctc_res = ctc_res[0] + ctc_res[1] + ctc_res = self.sql.get_contact_residues(cutoff=cutoff) + ctc_res = ctc_res["A"] + ctc_res["B"] # handle with small interface or no interface total_res = len(ctc_res) diff --git a/deeprank/features/FullPSSM.py b/deeprank/features/FullPSSM.py index 86bca24f..a8715662 100644 --- a/deeprank/features/FullPSSM.py +++ b/deeprank/features/FullPSSM.py @@ -2,10 +2,10 @@ import warnings import numpy as np +import pdb2sql from deeprank import config from deeprank.features import FeatureClass -from deeprank.tools import pdb2sql ######################################################################## # @@ -163,7 +163,7 @@ def read_PSSM_data(self): def get_feature_value(self, cutoff=5.5): """get the feature value.""" - sql = pdb2sql(self.pdb_file) + sql = pdb2sql.interface(self.pdb_file) # set anchors for all residues and get their xyz xyz_info = sql.get('chainID,resSeq,resName', name='CB') 
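`BSA`, `FullPSSM`, `NaivePSSM` and `ResidueDensity` all follow the same migration: the bundled `deeprank.tools.pdb2sql` class is replaced by the standalone `pdb2sql` package, and the contact queries now return a dict keyed by chain ID instead of a tuple of per-chain lists. A minimal sketch of the new call pattern, with a placeholder file name:

```python
import pdb2sql

# interface() adds contact queries on top of the plain pdb2sql class
sql = pdb2sql.interface('complex.pdb')  # placeholder pdb file

# contact residues within 5.5 A of the other chain,
# returned as {"A": [...], "B": [...]} rather than the old tuple
ctc_res = sql.get_contact_residues(cutoff=5.5)
residues = ctc_res['A'] + ctc_res['B']

sql.close()
```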
@@ -178,10 +178,10 @@ def get_feature_value(self, cutoff=5.5): xyz_dict[tuple(info)] = pos # get interface contact residues - # ctc_res = ([chain 1 residues], [chain2 residues]) - ctc_res = sql.get_contact_residue(cutoff=cutoff) + # ctc_res = {"A":[chain 1 residues], "B": [chain2 residues]} + ctc_res = sql.get_contact_residues(cutoff=cutoff) sql.close() - ctc_res = ctc_res[0] + ctc_res[1] + ctc_res = ctc_res["A"] + ctc_res["B"] # handle with small interface or no interface total_res = len(ctc_res) diff --git a/deeprank/features/NaivePSSM.py b/deeprank/features/NaivePSSM.py index 0792d22f..603131df 100644 --- a/deeprank/features/NaivePSSM.py +++ b/deeprank/features/NaivePSSM.py @@ -2,9 +2,10 @@ from time import time import numpy as np +import pdb2sql from deeprank.features import FeatureClass -from deeprank.tools import SASA, pdb2sql +from deeprank.tools import SASA def printif(string, cond): return print(string) if cond else None @@ -148,7 +149,7 @@ def _smooth_pssm(pssm_data, msmooth=3): def get_feature_value(self, contact_only=True): """get the feature value.""" - sql = pdb2sql(self.pdbfile) + sql = pdb2sql.interface(self.pdbfile) xyz_info = sql.get('chainID,resSeq,resName', name='CB') xyz = sql.get('x,y,z', name='CB') @@ -157,7 +158,7 @@ def get_feature_value(self, contact_only=True): xyz_dict[tuple(info)] = pos contact_residue = sql.get_contact_residue(cutoff=5.5) - contact_residue = contact_residue[0] + contact_residue[1] + contact_residue = contact_residue["A"] + contact_residue["B"] sql.close() pssm_data_xyz = {} diff --git a/deeprank/features/ResidueDensity.py b/deeprank/features/ResidueDensity.py index ddcc22ad..bf784a04 100644 --- a/deeprank/features/ResidueDensity.py +++ b/deeprank/features/ResidueDensity.py @@ -1,8 +1,8 @@ import itertools import warnings +import pdb2sql from deeprank.features import FeatureClass -from deeprank.tools import pdb2sql from deeprank import config @@ -23,7 +23,7 @@ def __init__(self, pdb_data, chainA='A', chainB='B'): """ self.pdb_data = pdb_data - self.sql = pdb2sql(pdb_data) + self.sql = pdb2sql.interface(pdb_data) self.chains_label = [chainA, chainB] self.feature_data = {} @@ -40,7 +40,7 @@ def get(self, cutoff=5.5): # res = {('chainA,resSeq,resName'): set( # ('chainB,res1Seq,res1Name), # ('chainB,res2Seq,res2Name'))} - res = self.sql.get_contact_residue(chain1=self.chains_label[0], + res = self.sql.get_contact_residues(chain1=self.chains_label[0], chain2=self.chains_label[1], cutoff=cutoff, return_contact_pairs=True) diff --git a/deeprank/generate/DataGenerator.py b/deeprank/generate/DataGenerator.py index 921b8249..47361ae2 100644 --- a/deeprank/generate/DataGenerator.py +++ b/deeprank/generate/DataGenerator.py @@ -11,7 +11,7 @@ from deeprank import config from deeprank.config import logger from deeprank.generate import GridTools as gt -from deeprank.tools import pdb2sql +import pdb2sql try: from tqdm import tqdm @@ -155,7 +155,8 @@ def create_database( verbose=False, remove_error=True, prog_bar=False, - contact_distance=8.5): + contact_distance=8.5, + random_seed=None): """Create the hdf5 file architecture and compute the features/targets. Args: @@ -163,6 +164,7 @@ def create_database( remove_error (bool, optional): remove the groups that errored prog_bar (bool, optional): use tqdm contact_distance (float): contact distance cutoff, defaults to 8.5Å + random_seed (int): random seed for getting rotation axis and angle Raises: ValueError: If creation of the group errored. 
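The `DataGenerator` hunks below replace the class's private rotation helpers (`_get_aug_rot`, `_rotate_xyz`) with the corresponding `pdb2sql.transform` functions. A short sketch of the two helpers as they are called in this diff, with signatures inferred from the call sites and arbitrary coordinates:

```python
import numpy as np
import pdb2sql

# reproducible rotation axis (unit vector) and angle in radians
axis, angle = pdb2sql.transform.get_rot_axis_angle(17)  # 17: arbitrary seed

# rotate a coordinate set around that axis, about a given center
xyz = np.random.rand(10, 3)
center = np.mean(xyz, 0)
xyz_rot = pdb2sql.transform.rot_xyz_around_axis(xyz, axis, angle, center)
```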
@@ -400,7 +402,7 @@ create_database( self._add_pdb(molgrp, ref, 'native') # get the rotation axis and angle - axis, angle = self._get_aug_rot() + axis, angle = pdb2sql.transform.get_rot_axis_angle(random_seed) # create the new pdb and get molecule center mol_center = self._add_aug_pdb( @@ -415,7 +417,7 @@ create_database( # grid center molgrp.require_group('grid_points') - center = DataGenerator._rotate_xyz( + center = pdb2sql.transform.rot_xyz_around_axis( self.f5[mol_name + '/grid_points/center'], axis, angle, mol_center) @@ -713,10 +715,15 @@ def add_target(self, prog_bar=False): @staticmethod def _get_grid_center(pdb, contact_distance): - sqldb = pdb2sql(pdb) + sqldb = pdb2sql.interface(pdb) contact_atoms = sqldb.get_contact_atoms(cutoff=contact_distance) - contact_atoms = list(set(contact_atoms[0] + contact_atoms[1])) + + tmp = [] + for i in contact_atoms.values(): + tmp.extend(i) + contact_atoms = list(set(tmp)) + center_contact = np.mean( np.array(sqldb.get('x,y,z', rowID=contact_atoms)), 0) @@ -1330,48 +1337,23 @@ def _add_aug_pdb(molgrp, pdbfile, name, axis, angle): list(float): center of the molecule """ # create the sqldb and extract positions - sqldb = pdb2sql(pdbfile) - - # rotate the positions and get molecule center - center = sqldb.rotation_around_axis(axis, angle) - - # get the data - sqldata = sqldb.get('*') + sqldb = pdb2sql.pdb2sql(pdbfile) - # close the db - sqldb.close() + # rotate the positions + pdb2sql.transform.rot_axis(sqldb, axis, angle) - # TODO the output does not obey PDB format - # TODO should not strip them! - # export the data to h5 - data = [] - for d in sqldata: - line = 'ATOM ' - line += '{:>5}'.format(d[0]) # serial - line += ' ' - line += '{:^4}'.format(d[1]) # name - line += '{:>1}'.format(d[2]) # altLoc - line += '{:>3}'.format(d[3]) # resname - line += ' ' - line += '{:>1}'.format(d[4]) # chainID - line += '{:>4}'.format(d[5]) # resSeq - line += '{:>1}'.format(d[6]) # iCODE - line += ' ' - line += '{: 8.3f}'.format(d[7]) # x - line += '{: 8.3f}'.format(d[8]) # y - line += '{: 8.3f}'.format(d[9]) # z - # TODO add the element - try: - line += '{: 6.2f}'.format(d[10]) # occ - line += '{: 6.2f}'.format(d[11]) # temp - except BaseException: - line += '{: 6.2f}'.format(0) # occ - line += '{: 6.2f}'.format(0) # temp - data.append(line) + # get molecule center + xyz = sqldb.get('x,y,z') + center = np.mean(xyz, 0) + # get the pdb-format data + data = sqldb.sql2pdb() data = np.array(data).astype('|S78') molgrp.create_dataset(name, data=data) + # close the db + sqldb.close() + return center # rotate the xyz-formatted feature in the database @@ -1403,68 +1385,7 @@ def _rotate_feature(molgrp, axis, angle, center, feat_name='all'): xyz = data[:, 1:4] # get rotated xyz - xyz_rot = DataGenerator._rotate_xyz(xyz, axis, angle, center) + xyz_rot = pdb2sql.transform.rot_xyz_around_axis(xyz, axis, angle, center) # put back the data - data[:, 1:4] = xyz_rot - - # rotate xyz - - @staticmethod - def _rotate_xyz(xyz, axis, angle, center): - """Get the rotated xyz. 
- - Args: - xyz(np.array): original xyz coordinates - axis (list(float)): axis of rotation - angle (float): angle of rotation - center (list(float)): center of rotation - - Returns: - np.array: rotated xyz coordinates - """ - - # get the data - ct, st = np.cos(angle), np.sin(angle) - ux, uy, uz = axis - - # definition of the rotation matrix - # see https://en.wikipedia.org/wiki/Rotation_matrix - rot_mat = np.array([[ct + ux ** 2 * (1 - ct), - ux * uy * (1 - ct) - uz * st, - ux * uz * (1 - ct) + uy * st], - [uy * ux * (1 - ct) + uz * st, - ct + uy ** 2 * (1 - ct), - uy * uz * (1 - ct) - ux * st], - [uz * ux * (1 - ct) - uy * st, - uz * uy * (1 - ct) + ux * st, - ct + uz ** 2 * (1 - ct)]]) - - # apply the rotation - xyz_rot = np.dot(rot_mat, (xyz - center).T).T + center - - return xyz_rot - - # get rotation axis and angle - - @staticmethod - def _get_aug_rot(): - """Get the rotation angle/axis. - - Returns: - list(float): axis of rotation - float: angle of rotation - """ - # define the axis - # uniform distribution on a sphere - # http://mathworld.wolfram.com/SpherePointPicking.html - u1, u2 = np.random.rand(), np.random.rand() - teta, phi = np.arccos(2 * u1 - 1), 2 * np.pi * u2 - axis = [np.sin(teta) * np.cos(phi), - np.sin(teta) * np.sin(phi), - np.cos(teta)] - - # and the rotation angle - angle = -np.pi + np.pi * np.random.rand() - - return axis, angle + molgrp['features/' + fn][:, 1:4] = xyz_rot diff --git a/deeprank/generate/GridTools.py b/deeprank/generate/GridTools.py index a3d0aeed..421851b1 100644 --- a/deeprank/generate/GridTools.py +++ b/deeprank/generate/GridTools.py @@ -5,9 +5,10 @@ import numpy as np from scipy.signal import bspline +import pdb2sql from deeprank.config import logger -from deeprank.tools import pdb2sql, sparse +from deeprank.tools import sparse try: from tqdm import tqdm @@ -205,7 +206,7 @@ def update_feature(self): def read_pdb(self): """Create a sql database for the pdb.""" - self.sqldb = pdb2sql(self.molgrp['complex'][()]) + self.sqldb = pdb2sql.interface(self.molgrp['complex'][()]) # get the contact atoms and interface center def get_contact_center(self): @@ -214,12 +215,14 @@ contact_atoms = self.sqldb.get_contact_atoms( cutoff=self.contact_distance) - # create a set of unique indexes - self.contact_atoms = list(set(contact_atoms[0] + contact_atoms[1])) + tmp = [] + for i in contact_atoms.values(): + tmp.extend(i) + contact_atoms = list(set(tmp)) # get interface center self.center_contact = np.mean( - np.array(self.sqldb.get('x,y,z', rowID=self.contact_atoms)), 0) + np.array(self.sqldb.get('x,y,z', rowID=contact_atoms)), 0) ################################################################ # shortcut to add all the feature a @@ -332,8 +335,8 @@ def map_atomic_densities(self, only_contact=True): if only_contact: index = self.sqldb.get_contact_atoms(cutoff=self.contact_distance) else: - index = (self.sqldb.get('rowID', chainID='A'), - self.sqldb.get('rowID', chainID='B')) + index = {"A": self.sqldb.get('rowID', chainID='A'), + "B": self.sqldb.get('rowID', chainID='B')} # loop over all the data we want for elementtype, vdw_rad in self.local_tqdm( @@ -342,9 +345,9 @@ t0 = time() xyzA = np.array(self.sqldb.get( - 'x,y,z', rowID=index[0], element=elementtype)) + 'x,y,z', rowID=index["A"], element=elementtype)) xyzB = np.array(self.sqldb.get( - 'x,y,z', rowID=index[1], element=elementtype)) + 'x,y,z', rowID=index["B"], element=elementtype)) tprocess = time() - t0 diff --git 
a/deeprank/learn/DataSet.py b/deeprank/learn/DataSet.py index a97d0673..91049f72 100644 --- a/deeprank/learn/DataSet.py +++ b/deeprank/learn/DataSet.py @@ -9,11 +9,12 @@ import h5py import numpy as np from tqdm import tqdm +import pdb2sql from deeprank import config from deeprank.config import logger from deeprank.generate import MinMaxParam, NormalizeData, NormParam -from deeprank.tools import pdb2sql, sparse +from deeprank.tools import sparse # import torch.utils.data as data_utils # The class used to subclass data_utils.Dataset @@ -33,7 +34,9 @@ def __init__(self, train_database, valid_database=None, test_database=None, transform_to_2D=False, projection=0, grid_shape=None, clip_features=True, clip_factor=1.5, - tqdm=False, process=True): + rotation_seed=None, + tqdm=False, + process=True): '''Generates the dataset needed for pytorch. This class handles the data generated by deeprank.generate to be @@ -111,6 +114,8 @@ def __init__(self, train_database, valid_database=None, test_database=None, tqdm (bool, optional): Print the progress bar process (bool, optional): Actually process the data set. Must be set to False when reusing a model for testing + rotation_seed (int, optional): random seed for getting rotation + axis and angle. Examples: >>> from deeprank.learn import * @@ -189,6 +194,9 @@ def __init__(self, train_database, valid_database=None, test_database=None, # print the progress bar or not self.tqdm = tqdm + # set random seed + self.rotation_seed = rotation_seed + # process the data if process: self.process_dataset() @@ -394,7 +402,8 @@ def create_index_molecules(self): if self.filter(fh5[k]): self.index_complexes += [(fdata, k, None, None)] for irot in range(self.data_augmentation): - axis, angle = self._get_aug_rot() + axis, angle = pdb2sql.transform.get_rot_axis_angle( + self.rotation_seed) self.index_complexes += [(fdata, k, angle, axis)] fh5.close() except Exception: @@ -1279,7 +1288,7 @@ def map_atomic_densities( list : atomic densities of each atom type on each chain """ - sql = pdb2sql(mol_data['complex'][()]) + sql = pdb2sql.interface(mol_data['complex'][()]) index = sql.get_contact_atoms() if angle is not None: @@ -1290,17 +1299,17 @@ # get pos of the contact atoms of correct type xyzA = np.array(sql.get( - 'x,y,z', rowID=index[0], element=elementtype)) + 'x,y,z', rowID=index['A'], element=elementtype)) xyzB = np.array(sql.get( - 'x,y,z', rowID=index[1], element=elementtype)) + 'x,y,z', rowID=index['B'], element=elementtype)) # rotate if necessary if angle is not None: if xyzA != np.array([]): - xyzA = self._rotate_coord(xyzA, center, angle, axis) + xyzA = pdb2sql.transform.rot_xyz_around_axis(xyzA, axis, angle, center) if xyzB != np.array([]): - xyzB = self._rotate_coord(xyzB, center, angle, axis) + xyzB = pdb2sql.transform.rot_xyz_around_axis(xyzB, axis, angle, center) # init the grid atdensA = np.zeros(npts) @@ -1373,7 +1382,7 @@ def map_feature(self, feat_names, mol_data, grid, npts, angle, axis): feat_value = data[:, 4] if angle is not None: - pos = self._rotate_coord(pos, center, angle, axis) + pos = pdb2sql.transform.rot_xyz_around_axis(pos, axis, angle, center) if __vectorize__ or __vectorize__ == 'both': @@ -1431,56 +1440,4 @@ def _featgrid(center, value, grid, npts): #dgrid[dd>> from deeprank.tools import StructureSimilarity - >>> decoy = '1AK4_1w.pdb' - >>> ref = '1AK4.pdb' - >>> sim = StructureSimilarity(decoy,ref) - >>> irmsd_fast = sim.compute_irmsd_fast(method='svd', - ... 
izone='1AK4.izone') - >>> irmsd = sim.compute_irmsd_pdb2sql(method='svd', - ... izone='1AK4.izone') - >>> lrmsd_fast = sim.compute_lrmsd_fast(method='svd', - ... lzone='1AK4.lzone',check=True) - >>> lrmsd = sim.compute_lrmsd_pdb2sql(exportpath=None, - ... method='svd') - >>> Fnat = sim.compute_Fnat_pdb2sql() - >>> Fnat_fast = sim.compute_fnat_fast( - ... ref_pairs='1AK4.ref_pairs') - >>> dockQ = sim.compute_DockQScore(Fnat_fast, - ... lrmsd_fast,irmsd_fast) - """ - - self.decoy = decoy - self.ref = ref - self.verbose = verbose - - def compute_lrmsd_fast(self, lzone=None, method='svd', check=True): - """Fast routine to compute the L-RMSD. - - This routine parse the PDB directly without using pdb2sql. - - L-RMSD is computed by aligning the longest chain of the decoy to - the one of the reference and computing the RMSD of the shortest - chain between decoy and reference. See reference: - DockQ: A Quality Measure for Protein-Protein Docking Models - https://doi.org/10.1371/journal.pone.0161879 - - Args: - lzone (None, optional): name of the file containing the zone - definition. If None the file will be calculated first. - method (str, optional): Method to align the fragments, - 'svd' or 'quaternion'. - check (bool, optional): Check if the sequences are aligned - and fix it if not. Defaults to True. - - Returns: - float: L-RMSD value of the conformation - """ - - # create/read the lzone file - if lzone is None: - resData = self.compute_lzone(save_file=False) - elif not os.path.isfile(lzone): - resData = self.compute_lzone(save_file=True, filename=lzone) - else: - resData = self.read_zone(lzone) - - ################################################## - # the check make sure that all the - # atoms are in the correct order - # I STRONGLY discourage turning the check off - # it actually reorder the xyz data of the native/decoy - # to match. - ################################################## - if check: - - # Note: - # 1. read_data_zone returns in_zone and not_in_zone - # which means the in_zone only defines the zone for fitting - # but not for rmsd calculation. - # 2. the decoy and ref pdb must have consitent residue - # numbering, otherwise e.g. shifted numbering can also give - # results which is totally wrong, because the code here does - # not do sequence alignment. - data_decoy_long, data_decoy_short = self.read_data_zone( - self.decoy, resData, return_not_in_zone=True) - data_ref_long, data_ref_short = self.read_data_zone( - self.ref, resData, return_not_in_zone=True) - - atom_decoy_long = [data[:3] for data in data_decoy_long] - atom_ref_long = [data[:3] for data in data_ref_long] - - xyz_decoy_long, xyz_ref_long = [], [] - for ind_decoy, at in enumerate(atom_decoy_long): - - try: - ind_ref = atom_ref_long.index(at) - xyz_decoy_long.append(data_decoy_long[ind_decoy][3:]) - xyz_ref_long.append(data_ref_long[ind_ref][3:]) - except ValueError: - warnings.warn( - f'Decoy atom {at} not found in reference pdb') - - atom_decoy_short = [data[:3] for data in data_decoy_short] - atom_ref_short = [data[:3] for data in data_ref_short] - - xyz_decoy_short, xyz_ref_short = [], [] - for ind_decoy, at in enumerate(atom_decoy_short): - try: - ind_ref = atom_ref_short.index(at) - xyz_decoy_short.append(data_decoy_short[ind_decoy][3:]) - xyz_ref_short.append(data_ref_short[ind_ref][3:]) - except ValueError: - warnings.warn( - f'Decoy atom {at} not found in reference pdb') - - # extract the xyz - else: - warnings.warn( - f'WARNING: The atom order have not been checked.' 
- f'Switch to check=True or continue at your own risk' - ) - xyz_decoy_long, xyz_decoy_short = self.read_xyz_zone( - self.decoy, resData, return_not_in_zone=True) - xyz_ref_long, xyz_ref_short = self.read_xyz_zone( - self.ref, resData, return_not_in_zone=True) - - # get the translation so that both A chains are centered - tr_decoy = self.get_trans_vect(xyz_decoy_long) - tr_ref = self.get_trans_vect(xyz_ref_long) - - # translate everything for 1 - xyz_decoy_short = self.translation(xyz_decoy_short, tr_decoy) - xyz_decoy_long = self.translation(xyz_decoy_long, tr_decoy) - - # translate everuthing for 2 - xyz_ref_short = self.translation(xyz_ref_short, tr_ref) - xyz_ref_long = self.translation(xyz_ref_long, tr_ref) - - # get the ideql rotation matrix - # to superimpose the A chains - U = self.get_rotation_matrix( - xyz_decoy_long, xyz_ref_long, method=method) - - # rotate the entire fragment - xyz_decoy_short = self.rotation_matrix( - xyz_decoy_short, U, center=False) - - # compute the RMSD - return self.get_rmsd(xyz_decoy_short, xyz_ref_short) - - def compute_lzone(self, save_file=True, filename=None): - """Compute the zone for L-RMSD calculation. - - Note: - It only provides the zone of long chain(s) which is used for - fitting. The zone used for calculating RMSD is defined in - the function `compute_lrmsd_fast`. - - Args: - save_file (bool, optional): save the zone file - filename (str, optional): name of the file - - Returns: - dict: definition of the zone. - """ - sql_ref = pdb2sql(self.ref) - nA = len(sql_ref.get('x,y,z', chainID='A')) - nB = len(sql_ref.get('x,y,z', chainID='B')) - - # detect which chain is the longest - long_chain = 'A' - if nA < nB: - long_chain = 'B' - - # extract data about the residue - data_test = [tuple(data) for data in sql_ref.get( - 'chainID,resSeq', - chainID=long_chain)] - data_test = sorted(set(data_test)) - - # close the sql - sql_ref.close() - - if save_file: - if filename is None: - f = open(self.ref.split('.')[0] + '.lzone', 'w') - else: - f = open(filename, 'w') - for res in data_test: - chain = res[0] - num = res[1] - f.write('zone %s%d-%s%d\n' % (chain, num, chain, num)) - f.close() - - resData = {} - for res in data_test: - chain = res[0] - num = res[1] - - if chain not in resData.keys(): - resData[chain] = [] - resData[chain].append(num) - - return resData - - def compute_irmsd_fast( - self, - izone=None, - method='svd', - cutoff=10.0, - check=True): - """Fast method to compute the i-rmsd. - - i-RMSD is computed by selecting the backbone atoms of reference - interface that is defined as any pair of heavy atoms from two - chains within 10Å of each other. - Align these backbone atoms as best as possible with their - coutner part in the decoy and compute the RMSD. See reference: - DockQ: A Quality Measure for Protein-Protein Docking Models - https://doi.org/10.1371/journal.pone.0161879 - - Args: - izone (None, optional): file name of the zone. - if None the zones will be calculated automatically. - method (str, optional): Method to align the fragments, - 'svd' or 'quaternion'. - cutoff (float, optional): cutoff for the contact atoms - check (bool, optional): Check if the sequences are aligned - and fix it if not. Should be True. 
- - Returns: - float: i-RMSD value of the conformation - """ - - # read the izone file - if izone is None: - resData = self.compute_izone(cutoff, save_file=False) - elif not os.path.isfile(izone): - resData = self.compute_izone(cutoff, save_file=True, - filename=izone) - else: - resData = self.read_zone(izone) - - ################################################## - # the check make sure that all the - # atoms are in the correct order - # I STRONGLY discourage turning the check off - # it actually fixes the order - ################################################## - if check: - - data_decoy = self.read_data_zone( - self.decoy, resData, return_not_in_zone=False) - data_ref = self.read_data_zone( - self.ref, resData, return_not_in_zone=False) - - atom_decoy = [data[:3] for data in data_decoy] - atom_ref = [data[:3] for data in data_ref] - - xyz_contact_decoy, xyz_contact_ref = [], [] - for ind_decoy, at in enumerate(atom_decoy): - try: - ind_ref = atom_ref.index(at) - xyz_contact_decoy.append(data_decoy[ind_decoy][3:]) - xyz_contact_ref.append(data_ref[ind_ref][3:]) - except ValueError: - warnings.warn( - f'Decoy atom {at} not found in reference pdb') - - # extract the xyz - else: - warnings.warn( - f'WARNING: The atom order have not been checked.' - f'Switch to check=True or continue at your own risk' - ) - xyz_contact_decoy = self.read_xyz_zone(self.decoy, resData) - xyz_contact_ref = self.read_xyz_zone(self.ref, resData) - - # get the translation so that both A chains are centered - tr_decoy = self.get_trans_vect(xyz_contact_decoy) - tr_ref = self.get_trans_vect(xyz_contact_ref) - - # translate everything - xyz_contact_decoy = self.translation(xyz_contact_decoy, tr_decoy) - xyz_contact_ref = self.translation(xyz_contact_ref, tr_ref) - - # get the ideql rotation matrix - # to superimpose the A chains - U = self.get_rotation_matrix( - xyz_contact_decoy, - xyz_contact_ref, - method=method) - - # rotate the entire fragment - xyz_contact_decoy = self.rotation_matrix( - xyz_contact_decoy, U, center=False) - - # return the RMSD - return self.get_rmsd(xyz_contact_decoy, xyz_contact_ref) - - def compute_izone(self, cutoff=5.0, save_file=True, filename=None): - """Compute the zones for i-rmsd calculationss. - - Args: - cutoff (float, optional): cutoff for the contact atoms - save_file (bool, optional): svae file containing the zone - filename (str, optional): filename - - Returns: - dict: i-zone definition - """ - - sql_ref = pdb2sql(self.ref) - contact_ref = sql_ref.get_contact_atoms( - cutoff=cutoff, - extend_to_residue=True, - return_only_backbone_atoms=True) - index_contact_ref = contact_ref[0] + contact_ref[1] - - # get the xyz and atom identifier of the decoy contact atoms - #xyz_contact_ref = sql_ref.get('x,y,z',rowID=index_contact_ref) - data_test = [ - tuple(data) for data in sql_ref.get( - 'chainID,resSeq', - rowID=index_contact_ref)] - data_test = sorted(set(data_test)) - - # close the sql - sql_ref.close() - - if save_file: - if filename is None: - f = open(self.ref.split('.')[0] + '.izone', 'w') - else: - f = open(filename, 'w') - - for res in data_test: - chain = res[0] - num = res[1] - f.write('zone %s%d-%s%d\n' % (chain, num, chain, num)) - f.close() - - resData = {} - for res in data_test: - chain = res[0] - num = res[1] - - if chain not in resData.keys(): - resData[chain] = [] - resData[chain].append(num) - - return resData - - def compute_fnat_fast(self, ref_pairs=None, cutoff=5): - """Compute the FNAT of the conformation. 
- - Fnat is the fraction of reference interface contacts preserved - in the interface of decoy. The interface is defined as any pair - of heavy atoms from two chains within 5Å of each other. - - Args: - ref_pairs (str, optional): file name describing the pairs - cutoff (int, optional): cutoff for the contact atoms - - Returns: - float: FNAT value - - Raises: - ValueError: if the decoy file is not found - """ - # read the ref_pairs file - if ref_pairs is None: - residue_pairs_ref = self.compute_residue_pairs_ref( - cutoff, save_file=False) - elif not os.path.isfile(ref_pairs): - residue_pairs_ref = self.compute_residue_pairs_ref( - cutoff, save_file=True, filename=ref_pairs) - else: - f = open(ref_pairs, 'rb') - residue_pairs_ref = pickle.load(f) - f.close() - - # create a dict of the decoy data - if isinstance(self.decoy, str) and os.path.isfile(self.decoy): - with open(self.decoy, 'r') as f: - data_decoy = f.readlines() - decoy_name = os.path.basename(self.decoy) - elif isinstance(self.decoy, np.ndarray): - data_decoy = [l.decode('utf-8') for l in self.decoy] - decoy_name = 'decoy' - else: - raise ValueError('Decoy not found in FNAT calculation.') - - # read the decoy data - residue_xyz = {} - residue_name = {} - - # go through all the lines - # that starts with ATOM - for line in data_decoy: - - if line.startswith('ATOM'): - - # chain ID - chainID = line[21] - if chainID == ' ': - chainID = line[72] - - # atom info - resSeq = int(line[22:26]) - resName = line[17:20].strip() - name = line[12:16].strip() - - # position - x, y, z = float(line[30:38]), float( - line[38:46]), float(line[46:54]) - - # dict entry - key = (chainID, resSeq, resName) - - # create the dict entry if necessary - if key not in residue_xyz.keys(): - residue_xyz[key] = [] - residue_name[key] = [] - - # we exclude the Hydrogens from the search - if name[0] != 'H': - residue_xyz[key].append([x, y, z]) - residue_name[key].append(name) - - # loop over the residue pairs of the - # and increment common if an atom pair is close enough - nCommon, nTotal = 0, 0 - for resA, resB_list in residue_pairs_ref.items(): - if resA in residue_xyz: - xyzA = residue_xyz[resA] - for resB in resB_list: - if resB in residue_xyz.keys(): - xyzB = residue_xyz[resB] - dist_min = np.min(np.array( - [np.sqrt(np.sum((np.array(p1) - np.array(p2)) ** 2)) - for p1 in xyzA for p2 in xyzB])) - if dist_min <= cutoff: - nCommon += 1 - nTotal += 1 - else: - msg = f'\t FNAT: not find residue: {resA} in {decoy_name}' - warnings.warn(msg) - - # normalize - return nCommon / nTotal - - def compute_residue_pairs_ref( - self, - cutoff=5.0, - save_file=True, - filename=None): - """Compute the residue pair on the reference conformation. - - Args: - cutoff (float, optional): cutoff for the contact atoms - save_file (bool, optional): save the file containing the - residue pairs - filename (None, optional): filename - - Returns: - dict: defintition of the residue pairs - """ - sql_ref = pdb2sql(self.ref) - residue_pairs_ref = sql_ref.get_contact_residue( - cutoff=cutoff, return_contact_pairs=True, excludeH=True) - sql_ref.close() - - if save_file: - if filename is None: - f = open(self.ref.split('.')[0] + 'residue_contact_pairs.pckl', - 'wb') - else: - f = open(filename, 'wb') - - # save as pickle - pickle.dump(residue_pairs_ref, f) - f.close() - - return residue_pairs_ref - - def compute_lrmsd_pdb2sql(self, exportpath=None, method='svd'): - """Slow routine to compute the L-RMSD. - - This routine parse the PDB directly using pdb2sql. 
- - L-RMSD is computed by aligning the longest chain of the decoy to - the one of the reference and computing the RMSD of the shortest - chain between decoy and reference. See reference: - DockQ: A Quality Measure for Protein-Protein Docking Models - https://doi.org/10.1371/journal.pone.0161879 - - Args: - exportpath (str, optional): file name where the aligned pdbs - are exported. - method (str, optional): Method to align the fragments, - 'svd' or 'quaternion'. - - Returns: - float: L-RMSD value of the conformation - """ - - # create the sql - sql_decoy = pdb2sql(self.decoy, sqlfile='decoy.db') - sql_ref = pdb2sql(self.ref, sqlfile='ref.db') - - # extract the pos of chains A - xyz_decoy_A = np.array(sql_decoy.get('x,y,z', chainID='A')) - xyz_ref_A = np.array(sql_ref.get('x,y,z', chainID='A')) - - # extract the pos of chains B - xyz_decoy_B = np.array(sql_decoy.get('x,y,z', chainID='B')) - xyz_ref_B = np.array(sql_ref.get('x,y,z', chainID='B')) - - # check the lengthes - if len(xyz_decoy_A) != len(xyz_ref_A): - xyz_decoy_A, xyz_ref_A = self.get_identical_atoms( - sql_decoy, sql_ref, 'A') - - if len(xyz_decoy_B) != len(xyz_ref_B): - xyz_decoy_B, xyz_ref_B = self.get_identical_atoms( - sql_decoy, sql_ref, 'B') - - # detect which chain is the longest - nA, nB = len(xyz_decoy_A), len(xyz_decoy_B) - if nA > nB: - xyz_decoy_long = xyz_decoy_A - xyz_ref_long = xyz_ref_A - - xyz_decoy_short = xyz_decoy_B - xyz_ref_short = xyz_ref_B - - else: - xyz_decoy_long = xyz_decoy_B - xyz_ref_long = xyz_ref_B - - xyz_decoy_short = xyz_decoy_A - xyz_ref_short = xyz_ref_A - - # get the translation so that both A chains are centered - tr_decoy = self.get_trans_vect(xyz_decoy_long) - tr_ref = self.get_trans_vect(xyz_ref_long) - - # translate everything for 1 - xyz_decoy_short = self.translation(xyz_decoy_short, tr_decoy) - xyz_decoy_long = self.translation(xyz_decoy_long, tr_decoy) - - # translate everuthing for 2 - xyz_ref_short = self.translation(xyz_ref_short, tr_ref) - xyz_ref_long = self.translation(xyz_ref_long, tr_ref) - - # get the ideal rotation matrix - # to superimpose the A chains - U = self.get_rotation_matrix( - xyz_decoy_long, xyz_ref_long, method=method) - - # rotate the entire fragment - xyz_decoy_short = self.rotation_matrix( - xyz_decoy_short, U, center=False) - - # compute the RMSD - lrmsd = self.get_rmsd(xyz_decoy_short, xyz_ref_short) - - # export the pdb for verifiactions - if exportpath is not None: - - # extract the pos of the dimer - xyz_decoy = np.array(sql_decoy.get('x,y,z')) - xyz_ref = np.array(sql_ref.get('x,y,z')) - - # translate - xyz_ref = self.translation(xyz_ref, tr_ref) - xyz_decoy = self.translation(xyz_decoy, tr_decoy) - - # rotate decoy - xyz_decoy = self.rotation_matrix(xyz_decoy, U, center=False) - - # update the sql database - sql_decoy.update_xyz(xyz_decoy) - sql_ref.update_xyz(xyz_ref) - - # export - sql_decoy.exportpdb(exportpath + '/lrmsd_decoy.pdb') - sql_ref.exportpdb(exportpath + '/lrmsd_aligned.pdb') - - # close the db - sql_decoy.close() - sql_ref.close() - - return lrmsd - - @staticmethod - def get_identical_atoms(db1, db2, chain): - """Return that atoms shared by both databse for a specific chain. 
- - Args: - db1 (TYPE): pdb2sql database of the first conformation - db2 (TYPE): pdb2sql database of the 2nd conformation - chain (str): chain name - - Returns: - list, list: list of xyz for both database - """ - # get data - data1 = db1.get('chainID,resSeq,name', chainID=chain) - data2 = db2.get('chainID,resSeq,name', chainID=chain) - - # tuplify - data1 = [tuple(d1) for d1 in data1] - data2 = [tuple(d2) for d2 in data2] - - # get the intersection - shared_data = list(set(data1).intersection(data2)) - - # get the xyz - xyz1, xyz2 = [], [] - for data in shared_data: - query = 'SELECT x,y,z from ATOM WHERE chainID=? AND resSeq=? and name=?' - xyz1.append(list(list(db1.c.execute(query, data))[0])) - xyz2.append(list(list(db2.c.execute(query, data))[0])) - - return xyz1, xyz2 - - def compute_irmsd_pdb2sql( - self, - cutoff=10, - method='svd', - izone=None, - exportpath=None): - """Slow method to compute the i-rmsd. - - i-RMSD is computed by selecting the backbone atoms of reference - interface that is defined as any pair of heavy atoms from two - chains within 10Å of each other. - Align these backbone atoms as best as possible with their - coutner part in the decoy and compute the RMSD. See reference: - DockQ: A Quality Measure for Protein-Protein Docking Models - https://doi.org/10.1371/journal.pone.0161879 - - Args: - izone (None, optional): file name of the zone. - if None the zones will be calculated first. - method (str, optional): Method to align the fragments, - 'svd' or 'quaternion'. - cutoff (float, optional): cutoff for the contact atoms - exportpath (str, optional): file name where the aligned pdbs - are exported. - - Returns: - float: i-RMSD value of the conformation - """ - - # create thes sql - sql_decoy = pdb2sql(self.decoy) - sql_ref = pdb2sql(self.ref) - - # get the contact atoms - if izone is None: - contact_ref = sql_ref.get_contact_atoms( - cutoff=cutoff, - extend_to_residue=True, - return_only_backbone_atoms=False) - index_contact_ref = contact_ref[0] + contact_ref[1] - else: - index_contact_ref = self.get_izone_rowID( - sql_ref, izone, return_only_backbone_atoms=False) - - # get the xyz and atom identifier of the decoy contact atoms - xyz_contact_ref = sql_ref.get('x,y,z', rowID=index_contact_ref) - data_contact_ref = sql_ref.get( - 'chainID,resSeq,resName,name', - rowID=index_contact_ref) - - # get the xyz and atom indeitifier of the reference - xyz_decoy = sql_decoy.get('x,y,z') - data_decoy = sql_decoy.get('chainID,resSeq,resName,name') - - # loop through the ref label - # check if the atom is in the decoy - # if yes -> add xyz to xyz_contact_decoy - # if no -> remove the corresponding to xyz_contact_ref - xyz_contact_decoy = [] - index_contact_decoy = [] - clean_ref = False - for iat, atom in enumerate(data_contact_ref): - - try: - index = data_decoy.index(atom) - index_contact_decoy.append(index) - xyz_contact_decoy.append(xyz_decoy[index]) - except Exception: - xyz_contact_ref[iat] = None - index_contact_ref[iat] = None - clean_ref = True - - # clean the xyz - if clean_ref: - xyz_contact_ref = [ - xyz for xyz in xyz_contact_ref if xyz is not None] - index_contact_ref = [ - ind for ind in index_contact_ref if ind is not None] - - # check that we still have atoms in both chains - chain_decoy = list( - set(sql_decoy.get('chainID', rowID=index_contact_decoy))) - chain_ref = list(set(sql_ref.get('chainID', rowID=index_contact_ref))) - - if len(chain_decoy) < 1 or len(chain_ref) < 1: - raise ValueError( - 'Error in i-rmsd: only one chain represented in one chain') - - # 
get the translation so that both A chains are centered - tr_decoy = self.get_trans_vect(xyz_contact_decoy) - tr_ref = self.get_trans_vect(xyz_contact_ref) - - # translate everything - xyz_contact_decoy = self.translation(xyz_contact_decoy, tr_decoy) - xyz_contact_ref = self.translation(xyz_contact_ref, tr_ref) - - # get the ideql rotation matrix - # to superimpose the A chains - U = self.get_rotation_matrix( - xyz_contact_decoy, - xyz_contact_ref, - method=method) - - # rotate the entire fragment - xyz_contact_decoy = self.rotation_matrix( - xyz_contact_decoy, U, center=False) - - # compute the RMSD - irmsd = self.get_rmsd(xyz_contact_decoy, xyz_contact_ref) - - # export the pdb for verifiactions - if exportpath is not None: - - # update the sql database - sql_decoy.update_xyz(xyz_contact_decoy, index=index_contact_decoy) - sql_ref.update_xyz(xyz_contact_ref, index=index_contact_ref) - - sql_decoy.exportpdb( - exportpath + '/irmsd_decoy.pdb', - index=index_contact_decoy) - sql_ref.exportpdb( - exportpath + '/irmsd_ref.pdb', - index=index_contact_ref) - - # close the db - sql_decoy.close() - sql_ref.close() - - return irmsd - - @staticmethod - def get_izone_rowID(sql, izone, return_only_backbone_atoms=True): - """Compute the index of the izone atoms. - - Args: - sql (pdb2sql): database of the conformation - izone (str): filename to store the zone - return_only_backbone_atoms (bool, optional): Returns only - the backbone atoms - - Returns: - lis(int): index of the atoms in the zone - - Raises: - FileNotFoundError: if the izone file is not found - """ - # read the file - if not os.path.isfile(izone): - raise FileNotFoundError('i-zone file not found', izone) - - with open(izone, 'r') as f: - data = f.readlines() - - # get the data out of it - resData = {} - for line in data: - - res = line.split()[1].split('-')[0] - chainID, resSeq = res[0], int(res[1:]) - - if chainID not in resData.keys(): - resData[chainID] = [] - - resData[chainID].append(resSeq) - - # get the rowID - index_contact = [] - - for chainID, resSeq in resData.items(): - if return_only_backbone_atoms: - index_contact += sql.get('rowID', - chainID=chainID, - resSeq=resSeq, - name=['C', 'CA', 'N', 'O']) - else: - index_contact += sql.get('rowID', - chainID=chainID, resSeq=resSeq) - - return index_contact - - def compute_Fnat_pdb2sql(self, cutoff=5.0): - """Slow method to compute the FNAT usign pdb2sql. 
- - Args: - cutoff (float, optional): cutoff for the contact atoms - - Returns: - float: Fnat value for the conformation - """ - # create the sql - sql_decoy = pdb2sql(self.decoy) - sql_ref = pdb2sql(self.ref) - - # get the contact atoms of the decoy - residue_pairs_decoy = sql_decoy.get_contact_residue( - cutoff=cutoff, return_contact_pairs=True, excludeH=True) - - # get the contact atoms of the ref - residue_pairs_ref = sql_ref.get_contact_residue( - cutoff=cutoff, return_contact_pairs=True, excludeH=True) - - # form the pair data - data_pair_decoy = [] - for resA, resB_list in residue_pairs_decoy.items(): - data_pair_decoy += [(resA, resB) for resB in resB_list] - - # form the pair data - data_pair_ref = [] - for resA, resB_list in residue_pairs_ref.items(): - data_pair_ref += [(resA, resB) for resB in resB_list] - - # find the umber of residue that ref and decoys hace in common - nCommon = len(set(data_pair_ref).intersection(data_pair_decoy)) - - # normalize - Fnat = nCommon / len(data_pair_ref) - - sql_decoy.close() - sql_ref.close() - - return Fnat - - #################################################################### - # - # HELPER ROUTINES TO HANDLE THE ZONE FILES - # - #################################################################### - - @staticmethod - def read_xyz_zone(pdb_file, resData, return_not_in_zone=False): - """Read the xyz of the zone atoms. - - Args: - pdb_file (str): filename containing the pdb of the molecule - resData (dict): information about the residues - return_not_in_zone (bool, optional): Do we return the atoms - not in the zone - - Returns: - list(float): XYZ of the atoms in the zone - """ - # read the ref file - with open(pdb_file, 'r') as f: - data = f.readlines() - - # get the xyz of the - xyz_in_zone = [] - xyz_not_in_zone = [] - - for line in data: - - if line.startswith('ATOM'): - - chainID = line[21] - if chainID == ' ': - chainID = line[72] - - resSeq = int(line[22:26]) - name = line[12:16].strip() - - x = float(line[30:38]) - y = float(line[38:46]) - z = float(line[46:54]) - - if chainID in resData.keys(): - - if resSeq in resData[chainID] and name in [ - 'C', 'CA', 'N', 'O']: - xyz_in_zone.append([x, y, z]) - - elif resSeq not in resData[chainID] and name in [ - 'C', 'CA', 'N', 'O']: - xyz_not_in_zone.append([x, y, z]) - - else: - if name in ['C', 'CA', 'N', 'O']: - xyz_not_in_zone.append([x, y, z]) - - if return_not_in_zone: - return xyz_in_zone, xyz_not_in_zone - - else: - return xyz_in_zone - - @staticmethod - def read_data_zone(pdb_file, resData, return_not_in_zone=False): - """Read the data of the atoms in the zone. 
- - Args: - pdb_file (str): filename containing the pdb of the molecule - resData (dict): information about the residues - return_not_in_zone (bool, optional): Do we return the atoms - not in the zone - - Returns: - list(float): data of the atoms in the zone - """ - # read the ref file - if isinstance(pdb_file, str) and os.path.isfile(pdb_file): - with open(pdb_file, 'r') as f: - data = f.readlines() - elif isinstance(pdb_file, np.ndarray): - data = [l.decode('utf-8') for l in pdb_file] - - # get the xyz of the - data_in_zone = [] - data_not_in_zone = [] - - for line in data: - - if line.startswith('ATOM'): - - chainID = line[21] - if chainID == ' ': - chainID = line[72] - - resSeq = int(line[22:26]) - name = line[12:16].strip() - - x = float(line[30:38]) - y = float(line[38:46]) - z = float(line[46:54]) - - if chainID in resData.keys(): - - if resSeq in resData[chainID] and name in [ - 'C', 'CA', 'N', 'O']: - data_in_zone.append([chainID, resSeq, name, x, y, z]) - - elif resSeq not in resData[chainID] and name in [ - 'C', 'CA', 'N', 'O']: - data_not_in_zone.append( - [chainID, resSeq, name, x, y, z]) - - else: - if name in ['C', 'CA', 'N', 'O']: - data_not_in_zone.append( - [chainID, resSeq, name, x, y, z]) - - if return_not_in_zone: - return data_in_zone, data_not_in_zone - else: - return data_in_zone - - @staticmethod - def read_zone(zone_file): - """Read the zone file. - - Args: - zone_file (str): name of the file - - Returns: - dict: Info about the residues in the zone - - Raises: - FileNotFoundError: if the zone file is not found - """ - - # read the zone file - if not os.path.isfile(zone_file): - raise FileNotFoundError('zone file not found', zone_file) - - with open(zone_file, 'r') as f: - data = f.readlines() - - # get the data out of it - resData = {} - for line in data: - - # line = zone A4-A4 for positive resNum - # or line = zone A-4-A-4 for negative resNum - # that happens for example in 2OUL - - # split the line - res = line.split()[1].split('-') - - # if the resnum was positive - # we have e.g res = [A4,A4] - if len(res) == 2: - res = res[0] - chainID, resSeq = res[0], int(res[1:]) - - # if the resnum was negative was negtive - # we have e.g res = [A,4,A,4] - elif len(res) == 4: - chainID, resSeq = res[0], -int(res[1]) - - if chainID not in resData.keys(): - resData[chainID] = [] - - resData[chainID].append(resSeq) - - return resData - - ################################################################### - # - # ROUTINES TO ACTUALY ALIGN THE MOLECULES - # - ################################################################### - - @staticmethod - def compute_DockQScore(Fnat, lrmsd, irmsd, d1=8.5, d2=1.5): - """Compute the DockQ Score. - - Args: - Fnat (float): Fnat value - lrmsd (float): lrmsd value - irmsd (float): irmsd value - d1 (float, optional): first coefficient for the DockQ - calculations - d2 (float, optional): second coefficient for the DockQ - calculations - - Returns: - float: dockQ value - """ - def scale_rms(rms, d): - return(1. / (1 + (rms / d)**2)) - - return 1. / 3 * (Fnat + scale_rms(lrmsd, d1) + scale_rms(irmsd, d2)) - - @staticmethod - def get_rmsd(P, Q): - """compute the RMSD. - - Args: - P (np.array(nx3)): position of the points in the first - molecule - Q (np.array(nx3)): position of the points in the second - molecule - - Returns: - float: RMSD value - """ - n = len(P) - return np.sqrt(1. / n * np.sum((P - Q)**2)) - - @staticmethod - def get_trans_vect(P): - """Get the translationv vector to the origin. 
- - Args: - P (np.array(nx3)): position of the points in the molecule - - Returns: - float: minus mean value of the xyz columns - """ - return -np.mean(P, 0) - - # main switch for the rotation matrix - # add new methods here if necessary - def get_rotation_matrix(self, P, Q, method='svd'): - - # get the matrix with Kabsh method - if method.lower() == 'svd': - mat = self.get_rotation_matrix_Kabsh(P, Q) - - # or with the quaternion method - elif method.lower() == 'quaternion': - mat = self.get_rotation_matrix_quaternion(P, Q) - - else: - raise ValueError( - f'{method} is not a valid method for rmsd alignement. ' - f'Options are svd or quaternions') - - return mat - - @staticmethod - def get_rotation_matrix_Kabsh(P, Q): - """Get the rotation matrix to aligh two point clouds. - - The method is based on th Kabsh approach - https://cnx.org/contents/HV-RsdwL@23/Molecular-Distance-Measures - - Args: - P (np.array): xyz of the first point cloud - Q (np.array): xyz of the second point cloud - - Returns: - np.array: rotation matrix - - Raises: - ValueError: matrix have different sizes - """ - - pshape = P.shape - qshape = Q.shape - - if pshape[0] == qshape[0]: - npts = pshape[0] - else: - raise ValueError( - "Matrix don't have the same number of points", - P.shape, - Q.shape) - - p0, q0 = np.abs(np.mean(P, 0)), np.abs(np.mean(Q, 0)) - eps = 1E-6 - if any(p0 > eps) or any(q0 > eps): - raise ValueError('You must center the fragment first', p0, q0) - - # form the covariance matrix - A = np.dot(P.T, Q) / npts - - # SVD the matrix - V, S, W = np.linalg.svd(A) - - # the W matrix returned here is - # already its transpose - # https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.linalg.svd.html - W = W.T - - # determinant - d = np.linalg.det(np.dot(W, V.T)) - - # form the U matrix - Id = np.eye(3) - if d < 0: - Id[2, 2] = -1 - - U = np.dot(W, np.dot(Id, V.T)) - - return U - - @staticmethod - def get_rotation_matrix_quaternion(P, Q): - """Get the rotation matrix to aligh two point clouds. - - The method is based on the quaternion approach - http://www.ams.stonybrook.edu/~coutsias/papers/rmsd17.pdf - - Args: - P (np.array): xyz of the first point cloud - Q (np.array): xyz of the second point cloud - - Returns: - np.array: rotation matrix - - Raises: - ValueError: matrix have different sizes - """ - - pshape = P.shape - qshape = Q.shape - - if pshape[0] != qshape[0]: - raise ValueError( - "Matrix don't have the same number of points", - P.shape, - Q.shape) - - p0, q0 = np.abs(np.mean(P, 0)), np.abs(np.mean(Q, 0)) - eps = 1E-6 - if any(p0 > eps) or any(q0 > eps): - raise ValueError('You must center the fragment first', p0, q0) - - # form the correlation matrix - R = np.dot(P.T, Q) - - # form the F matrix (eq. 10 of ref[1]) - F = np.zeros((4, 4)) - - F[0, 0] = np.trace(R) - F[0, 1] = R[1, 2] - R[2, 1] - F[0, 2] = R[2, 0] - R[0, 2] - F[0, 3] = R[0, 1] - R[1, 0] - - F[1, 0] = R[1, 2] - R[2, 1] - F[1, 1] = R[0, 0] - R[1, 1] - R[2, 2] - F[1, 2] = R[0, 1] + R[1, 0] - F[1, 3] = R[0, 2] + R[2, 0] - - F[2, 0] = R[2, 0] - R[0, 2] - F[2, 1] = R[0, 1] + R[1, 0] - F[2, 2] = -R[0, 0] + R[1, 1] - R[2, 2] - F[2, 3] = R[1, 2] + R[2, 1] - - F[3, 0] = R[0, 1] - R[1, 0] - F[3, 1] = R[0, 2] + R[2, 0] - F[3, 2] = R[1, 2] + R[2, 1] - F[3, 3] = -R[0, 0] - R[1, 1] + R[2, 2] - - # diagonalize it - l, U = np.linalg.eig(F) - - # extract the eigenvect of the highest eigenvalues - indmax = np.argmax(l) - q0, q1, q2, q3 = U[:, indmax] - - # form the rotation matrix (eq. 
33 ref[1]) - U = np.zeros((3, 3)) - - U[0, 0] = q0**2 + q1**2 - q2**2 - q3**2 - U[0, 1] = 2 * (q1 * q2 - q0 * q3) - U[0, 2] = 2 * (q1 * q3 + q0 * q2) - U[1, 1] = 2 * (q1 * q2 + q0 * q3) - U[1, 2] = q0**2 - q1**2 + q2 * 2 - q3**2 - U[1, 2] = 2 * (q2 * q3 - q0 * q1) - U[2, 0] = 2 * (q1 * q3 - q0 * q2) - U[2, 1] = 2 * (q2 * q3 + q0 * q1) - U[2, 2] = q0**2 - q1**2 - q2**2 + q3**2 - - return U - - @staticmethod - def translation(xyz, vect): - """Translate a fragment. - - Args: - xyz (np.array): position of the fragment - vect (np.array): translation vector - - Returns: - np.array: translated positions - """ - return xyz + vect - - @staticmethod - def rotation_around_axis(xyz, axis, angle): - """Rotate a fragment around an axis. - - Args: - xyz (np.array): original positions - axis (np.array): axis of rotation - angle (float): angle of rotation (radians) - - Returns: - np.array: Rotated positions - """ - - # get the data - ct, st = np.cos(angle), np.sin(angle) - ux, uy, uz = axis - - # get the center of the molecule - xyz0 = np.mean(xyz, 0) - - # definition of the rotation matrix - # see https://en.wikipedia.org/wiki/Rotation_matrix - rot_mat = np.array([[ct + ux**2 * (1 - ct), - ux * uy * (1 - ct) - uz * st, - ux * uz * (1 - ct) + uy * st], - [uy * ux * (1 - ct) + uz * st, - ct + uy**2 * (1 - ct), - uy * uz * (1 - ct) - ux * st], - [uz * ux * (1 - ct) - uy * st, - uz * uy * (1 - ct) + ux * st, - ct + uz**2 * (1 - ct)]]) - - # apply the rotation - return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 - - @staticmethod - def rotation_euler(xyz, alpha, beta, gamma): - """Rotate a fragment from Euler rotation angle. - - Args: - xyz (np.array): original positions - alpha (float): rotation angle around the x axis - beta (float): rotation angle around the x axis - gamma (float): rotation angle around the x axis - - Returns: - np.array: Rotated positions - """ - # precompute the trig - ca, sa = np.cos(alpha), np.sin(alpha) - cb, sb = np.cos(beta), np.sin(beta) - cg, sg = np.cos(gamma), np.sin(gamma) - - # get the center of the molecule - xyz0 = np.mean(xyz, 0) - - # rotation matrices - rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]]) - ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]]) - rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]]) - - rot_mat = np.dot(rx, np.dot(ry, rz)) - - # apply the rotation - return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 - - @staticmethod - def rotation_matrix(xyz, rot_mat, center=True): - """Rotate a fragment from a roation matrix. 
-
-    @staticmethod
-    def translation(xyz, vect):
-        """Translate a fragment.
-
-        Args:
-            xyz (np.array): position of the fragment
-            vect (np.array): translation vector
-
-        Returns:
-            np.array: translated positions
-        """
-        return xyz + vect
-
-    @staticmethod
-    def rotation_around_axis(xyz, axis, angle):
-        """Rotate a fragment around an axis.
-
-        Args:
-            xyz (np.array): original positions
-            axis (np.array): axis of rotation
-            angle (float): angle of rotation (radians)
-
-        Returns:
-            np.array: Rotated positions
-        """
-
-        # get the data
-        ct, st = np.cos(angle), np.sin(angle)
-        ux, uy, uz = axis
-
-        # get the center of the molecule
-        xyz0 = np.mean(xyz, 0)
-
-        # definition of the rotation matrix
-        # see https://en.wikipedia.org/wiki/Rotation_matrix
-        rot_mat = np.array([[ct + ux**2 * (1 - ct),
-                             ux * uy * (1 - ct) - uz * st,
-                             ux * uz * (1 - ct) + uy * st],
-                            [uy * ux * (1 - ct) + uz * st,
-                             ct + uy**2 * (1 - ct),
-                             uy * uz * (1 - ct) - ux * st],
-                            [uz * ux * (1 - ct) - uy * st,
-                             uz * uy * (1 - ct) + ux * st,
-                             ct + uz**2 * (1 - ct)]])
-
-        # apply the rotation
-        return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
-
-    @staticmethod
-    def rotation_euler(xyz, alpha, beta, gamma):
-        """Rotate a fragment from Euler rotation angles.
-
-        Args:
-            xyz (np.array): original positions
-            alpha (float): rotation angle around the x axis
-            beta (float): rotation angle around the y axis
-            gamma (float): rotation angle around the z axis
-
-        Returns:
-            np.array: Rotated positions
-        """
-        # precompute the trig
-        ca, sa = np.cos(alpha), np.sin(alpha)
-        cb, sb = np.cos(beta), np.sin(beta)
-        cg, sg = np.cos(gamma), np.sin(gamma)
-
-        # get the center of the molecule
-        xyz0 = np.mean(xyz, 0)
-
-        # rotation matrices
-        rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]])
-        ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]])
-        rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]])
-
-        rot_mat = np.dot(rx, np.dot(ry, rz))
-
-        # apply the rotation
-        return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
-
-    @staticmethod
-    def rotation_matrix(xyz, rot_mat, center=True):
-        """Rotate a fragment from a rotation matrix.
-
-        Args:
-            xyz (np.array): original positions
-            rot_mat (np.array): rotation matrix
-            center (bool, optional): Center the fragment before rotation
-
-        Returns:
-            np.array: rotated positions
-        """
-        if center:
-            xyz0 = np.mean(xyz, 0)
-            mat = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
-        else:
-            mat = np.dot(rot_mat, (xyz).T).T
-
-        return mat
-
-
-# if __name__ == '__main__':
-#     import time
-#     BM4 = '/home/nico/Documents/projects/deeprank/data/HADDOCK/BM4_dimers/'
-#     decoy = BM4 + 'decoys_pdbFLs/1AK4/water/1AK4_1w.pdb'
-#     ref = BM4 + 'BM4_dimers_bound/pdbFLs_ori/1AK4.pdb'
-
-#     sim = StructureSimilarity(decoy,ref)
-
-#     #----------------------------------------------------------------------
-
-#     t0 = time.time()
-#     irmsd_fast = sim.compute_irmsd_fast(method='svd',izone='1AK4.izone')
-#     t1 = time.time()-t0
-#     print('\nIRMSD TIME FAST %f in %f sec' %(irmsd_fast,t1))
-
-#     t0 = time.time()
-#     irmsd = sim.compute_irmsd_pdb2sql(method='svd',izone='1AK4.izone')
-#     t1 = time.time()-t0
-#     print('IRMSD TIME SQL %f in %f sec' %(irmsd,t1))
-
-#     #----------------------------------------------------------------------
-
-#     t0 = time.time()
-#     lrmsd_fast = sim.compute_lrmsd_fast(method='svd',lzone='1AK4.lzone',check=True)
-#     t1 = time.time()-t0
-#     print('\nLRMSD TIME FAST %f in %f sec' %(lrmsd_fast,t1))
-
-#     t0 = time.time()
-#     lrmsd = sim.compute_lrmsd_pdb2sql(exportpath=None,method='svd')
-#     t1 = time.time()-t0
-#     print('LRMSD TIME SQL %f in %f sec' %(lrmsd,t1))
-
-#     #----------------------------------------------------------------------
-
-#     t0 = time.time()
-#     Fnat = sim.compute_Fnat_pdb2sql()
-#     t1 = time.time()-t0
-#     print('\nFNAT TIME SQL %f in %f sec' %(Fnat,t1))
-
-#     t0 = time.time()
-#     Fnat_fast = sim.compute_fnat_fast(ref_pairs='1AK4.ref_pairs')
-#     t1 = time.time()-t0
-#     print('FNAT TIME FAST %f in %f sec' %(Fnat_fast,t1))
-
-#     #----------------------------------------------------------------------
-
-#     dockQ = sim.compute_DockQScore(Fnat_fast,lrmsd_fast,irmsd_fast)
-#     print('\nDockQ %f' %dockQ )
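
This entire module is removed because the same machinery is now maintained in the standalone pdb2sql package, which setup.py pins later in this diff. A minimal sketch of the replacement workflow, using illustrative file names and the method names that appear in the updated tests below; check the standalone package for the authoritative API:

```python
from pdb2sql import StructureSimilarity

# decoy.pdb / native.pdb are placeholders for your own structures
sim = StructureSimilarity('decoy.pdb', 'native.pdb')

lrmsd = sim.compute_lrmsd_pdb2sql(method='svd')
irmsd = sim.compute_irmsd_pdb2sql(method='svd')
fnat = sim.compute_fnat_pdb2sql()
print(lrmsd, irmsd, fnat)
```
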
diff --git a/deeprank/tools/__init__.py b/deeprank/tools/__init__.py
index 8c750024..b8cef67e 100644
--- a/deeprank/tools/__init__.py
+++ b/deeprank/tools/__init__.py
@@ -1,5 +1,3 @@
-from .pdb2sql import pdb2sql
 from .sasa import SASA
 from .sparse import *
-from .StructureSimilarity import StructureSimilarity
diff --git a/deeprank/tools/pdb2sql.py b/deeprank/tools/pdb2sql.py
deleted file mode 100644
index ca1b2522..00000000
--- a/deeprank/tools/pdb2sql.py
+++ /dev/null
@@ -1,1290 +0,0 @@
-import os
-import sqlite3
-import subprocess as sp
-import sys
-import warnings
-
-import numpy as np
-
-from deeprank.config import logger
-
-
-class pdb2sql(object):
-
-    def __init__(self,
-                 pdbfile,
-                 sqlfile=None,
-                 fix_chainID=True,
-                 verbose=False,
-                 no_extra=True):
-        """Create a SQL database for a PDB file.
-
-        This allows to easily parse and extract information of the PDB
-        using SQL queries. This is a local version of the pdb2sql tool
-        (https://github.com/DeepRank/pdb2sql). pdb2sql is further
-        developed as a standalone package; we should use that library
-        directly.
-
-        Args:
-            pdbfile (str or list(bytes)): name of the pdbfile or
-                list of bytes containing the pdb data
-            sqlfile (str, optional): name of the sqlfile.
-                By default it is created in memory only.
-            fix_chainID (bool, optional): check if the names of the chains
-                are A,B,C, .... and fix them if not.
-            verbose (bool): print verbose output.
-            no_extra (bool): remove the occupancy and tempFactor columns or not
-
-        Examples:
-            >>> # create the sql
-            >>> db = pdb2sql('1AK4_100w.pdb')
-            >>>
-            >>> # print the database
-            >>> db.prettyprint()
-            >>>
-            >>> # get the names of the columns
-            >>> db.get_colnames()
-            >>>
-            >>> # extract the xyz position of the atoms with name CB
-            >>> xyz = db.get('*',index=[0,1,2,3])
-            >>> print(xyz)
-            >>>
-            >>> xyz = db.get('rowID',where="resName='VAL'")
-            >>> print(xyz)
-            >>>
-            >>> db.add_column('CHARGE','FLOAT')
-            >>> db.put('CHARGE',0.1)
-            >>> db.prettyprint()
-            >>>
-            >>> db.exportpdb('chainA.pdb',where="chainID='A'")
-            >>>
-            >>> # close the database
-            >>> db.close()
-        """
-        self.pdbfile = pdbfile
-        self.sqlfile = sqlfile
-        self.is_valid = True
-        self.verbose = verbose
-        self.no_extra = no_extra
-
-        # create the database
-        self._create_sql()
-
-        # backbone type
-        self.backbone_type = ['C', 'CA', 'N', 'O']
-
-        # hard limit for the number of SQL variables
-        self.SQLITE_LIMIT_VARIABLE_NUMBER = 999
-        self.max_sql_values = 950
-
-        # fix the chain ID
-        if fix_chainID:
-            self._fix_chainID()
-
-        # a few constants
-        self.residue_key = 'chainID,resSeq,resName'
-        # self.atom_key = 'chainID,resSeq,resName,name'
-
-    ####################################################################
-    #
-    #   CREATION AND PRINTING
-    #
-    ####################################################################
-
-    def _create_sql(self):
-        """Create the sql database."""
-
-        pdbfile = self.pdbfile
-        sqlfile = self.sqlfile
-
-        if self.verbose:
-            logger.info('-- Create SQLite3 database')
-
-        # name of the table
-        # table = 'ATOM'
-
-        # column names and types
-        self.col = {'serial': 'INT',
-                    'name': 'TEXT',
-                    'altLoc': 'TEXT',
-                    'resName': 'TEXT',
-                    'chainID': 'TEXT',
-                    'resSeq': 'INT',
-                    'iCode': 'TEXT',
-                    'x': 'REAL',
-                    'y': 'REAL',
-                    'z': 'REAL',
-                    'occ': 'REAL',
-                    'temp': 'REAL',
-                    'element': 'TEXT'}
-
-        # delimiter of the column format
-        # taken from
-        # http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
-        self.delimiter = {
-            'serial': [6, 11],
-            'name': [12, 16],
-            'altLoc': [16, 17],
-            'resName': [17, 20],
-            'chainID': [21, 22],
-            'resSeq': [22, 26],
-            'iCode': [26, 27],
-            'x': [30, 38],
-            'y': [38, 46],
-            'z': [46, 54],
-            'occ': [54, 60],
-            'temp': [60, 66],
-            'element': [76, 78]}
-
-        if self.no_extra:
-            del self.col['occ']
-            del self.col['temp']
-
-        # size of the things
-        ncol = len(self.col)
-
-        # open the database
-        # if we do not specify a db name
-        # the db is only in RAM
-        # there might be little advantage to use memory
-        # https://stackoverflow.com/questions/764710/
-        if self.sqlfile is None:
-            self.conn = sqlite3.connect(':memory:')
-
-        # or we create a new db file
-        else:
-            if os.path.isfile(sqlfile):
-                sp.call('rm %s' % sqlfile, shell=True)
-            self.conn = sqlite3.connect(sqlfile)
-        self.c = self.conn.cursor()
-
-        # initialize the header/placeholder
-        header, qm = '', ''
-        for ic, (colname, coltype) in enumerate(self.col.items()):
-            header += f'{colname} {coltype}'
-            qm += '?'
-            if ic < ncol - 1:
-                header += ', '
-                qm += ','
-
-        # create the table
-        query = f'CREATE TABLE ATOM ({header})'
-        self.c.execute(query)
-
-        # read the pdb file
-        # this is dangerous if there are ATOM written in the comment part
-        # which happens often
-        # data = sp.check_output("awk '/ATOM/' %s" %pdbfile,shell=True).decode('utf8').split('\n')
-
-        # a safer version consists of matching against the first field
-        # won't work on windows
-        # data = sp.check_output("awk '$1 ~ /^ATOM/' %s" %pdbfile,shell=True).decode('utf8').split('\n')
-
-        # a pure python way
-        # RMK we go through the data twice here. Once to read the ATOM lines and once to parse the data ...
-        # we could do better than that. But the most time consuming step seems to be the CREATE TABLE query
-        # if we pass a file we read it
-        if isinstance(pdbfile, str):
-            if os.path.isfile(pdbfile):
-                with open(pdbfile, 'r') as fi:
-                    data = [line.split('\n')[0]
-                            for line in fi if line.startswith('ATOM')]
-            else:
-                raise FileNotFoundError(f'PDB file {pdbfile} not found')
-
-        # if we pass a list as for h5py read/write
-        # we directly use that
-        elif isinstance(pdbfile, np.ndarray):
-            data = [l.decode('utf-8') for l in pdbfile.tolist()]
-
-        # if we can't read it
-        else:
-            raise ValueError(f'PDB data not recognized: {pdbfile}')
-
-        # if there is no ATOM in the file
-        if len(data) == 1 and data[0] == '':
-            self.is_valid = False
-            raise ValueError(f'No ATOM found in the pdb data {pdbfile}')
-
-        # haddock chain ID fix
-        del_copy = self.delimiter.copy()
-        if data[0][del_copy['chainID'][0]] == ' ':
-            del_copy['chainID'] = [72, 73]
-
-        # get all the data
-        data_atom = []
-        for line in data:
-
-            # sometimes we still have an empty line somewhere
-            if len(line) == 0:
-                continue
-
-            # browse all attributes of each atom
-            at = ()
-            for (colname, coltype) in self.col.items():
-
-                # get the piece of data
-                data_col = line[del_copy[colname][0]:
-                                del_copy[colname][1]].strip()
-
-                # convert it if necessary
-                if coltype == 'INT':
-                    data_col = int(data_col)
-                elif coltype == 'REAL':
-                    data_col = float(data_col)
-
-                # get the element if it does not exist
-                if colname == "element" and not data_col:
-                    data_col = pdb2sql._get_element(line)
-
-                # append, keep the comma !!
-                # we need a proper tuple
-                at += (data_col,)
-
-            # append
-            data_atom.append(at)
-
-        # push in the database
-        self.c.executemany(f'INSERT INTO ATOM VALUES ({qm})', data_atom)
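
`_create_sql` leans entirely on the fixed-column PDB format: every field is a character slice defined by the `delimiter` dictionary above. A standalone sketch of that parsing step on a hand-made ATOM line (the line and the field subset are illustrative):

```python
# each field of an ATOM record is a fixed character slice
line = "ATOM      1  N   LEU A   1      11.940  34.945  -7.599  1.00  0.00           N"

delimiter = {'serial': [6, 11], 'name': [12, 16], 'resName': [17, 20],
             'chainID': [21, 22], 'resSeq': [22, 26],
             'x': [30, 38], 'y': [38, 46], 'z': [46, 54]}

atom = {field: line[start:stop].strip()
        for field, (start, stop) in delimiter.items()}
atom['serial'] = int(atom['serial'])
atom['resSeq'] = int(atom['resSeq'])
for k in 'xyz':
    atom[k] = float(atom[k])

print(atom)  # {'serial': 1, 'name': 'N', 'resName': 'LEU', 'chainID': 'A', ...}
```
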
-
-    @staticmethod
-    def _get_element(pdb_line):
-        """Get the element type from the atom type of a pdb line.
-
-        Notes:
-            The atom type occupies the 13-16th columns of a PDB line.
-            http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
-            Four situations exist:
-                13 14 15 16
-                    C  A        The element is C
-                C  A            The element is Ca
-                1  H  G         The element is H
-                H  E  2  1      The element is H
-
-        Args:
-            pdb_line(str): one PDB ATOM line
-
-        Returns:
-            [str]: element name
-        """
-
-        first_char = pdb_line[12].strip()
-        last_char = pdb_line[15].strip()
-        if first_char:
-            if first_char in "0123456789":
-                elem = pdb_line[13]
-            elif first_char == "H" and last_char:
-                elem = "H"
-            else:
-                elem = pdb_line[12:14]
-        else:
-            elem = pdb_line[13]
-        return elem
-
-    def _fix_chainID(self):
-        """Fix the chain IDs if necessary.
-
-        Replace the chain IDs by A,B,C,D, ..... in that order.
-        """
-
-        from string import ascii_uppercase
-
-        # get the current names
-        data = self.get('chainID')
-        natom = len(data)
-
-        # get uniques
-        chainID = []
-        for c in data:
-            if c not in chainID:
-                chainID.append(c)
-
-        if chainID == ['A', 'B']:
-            return
-
-        if len(chainID) > 26:
-            warnings.warn(
-                f"More than 26 chains have been detected. "
-                f"This is so far not supported")
-            sys.exit()
-
-        # declare the new names
-        newID = [''] * natom
-
-        # fill in the new names
-        for ic, chain in enumerate(chainID):
-            index = self.get('rowID', chainID=chain)
-            for ind in index:
-                newID[ind] = ascii_uppercase[ic]
-
-        # update the new names
-        self.update_column('chainID', newID)
-
-    # get the names of the columns
-    def get_colnames(self):
-        """Print the column names of the database."""
-
-        cd = self.conn.execute('select * from atom')
-        print('Possible column names are:')
-        names = list(map(lambda x: x[0], cd.description))
-        print('\trowID')
-        for n in names:
-            print('\t' + n)
-
-    # print the database
-    def prettyprint(self):
-        """Print the database with pandas."""
-
-        import pandas.io.sql as psql
-        df = psql.read_sql("SELECT * FROM ATOM", self.conn)
-        print(df)
-
-    def uglyprint(self):
-        """Raw print of the database."""
-
-        ctmp = self.conn.cursor()
-        ctmp.execute("SELECT * FROM ATOM")
-        print(ctmp.fetchall())
-
-    ####################################################################
-    #
-    #   GET FUNCTIONS
-    #
-    #   get(attribute,selection) -> return the attribute(s) value(s)
-    #                               for the given selection
-    #   get_contact_atoms()      -> return a list of rowID
-    #                               for the contact atoms
-    #   get_contact_residue()    -> return a list of resSeq
-    #                               for the contact residues
-    #
-    ####################################################################
-
-    def get(self, atnames, **kwargs):
-        """Get data from the sql database.
-
-        Get the values of the specified attributes for a given selection.
-
-        Args:
-            atnames (str): attribute names. They can be printed via
-                get_colnames():
-                - serial
-                - name
-                - altLoc
-                - resName
-                - chainID
-                - resSeq
-                - iCode
-                - x/y/z
-                Several attributes can be specified at once, e.g. 'x,y,z'
-
-            kwargs: Several options are possible to select atoms.
-                Each column can be used as a keyword argument.
-                Several keywords can be combined assuming an AND
-                logical combination.
-                None : return the entire table
-                chainID = 'A' select chain from name
-                resIndex = [1,2,3] select residues from index
-                resName = ['VAL', 'LEU'] select residues from name
-                name = ['CA', 'N'] select atoms from names
-                rowID = [1,2,3] select atoms from index
-
-        Returns:
-            np.array: Numpy array containing the requested data.
-
-        Examples:
-            >>> db = pdb2sql(filename)
-            >>> xyz = db.get('x,y,z',name = ['CA', 'CB'])
-            >>> xyz = db.get('x,y,z',chainID='A',resName=['VAL', 'LEU'])
-        """
-
-        # the asked keys
-        keys = kwargs.keys()
-
-        # check if the column exists
-        try:
-            self.c.execute(f"SELECT EXISTS(SELECT {atnames} FROM ATOM)")
-        except BaseException:
-            logger.error(
-                f"Column {atnames} not found in the database")
-            self.get_colnames()
-            return
-
-        # if we have 0 keys we take the entire db
-        if len(kwargs) == 0:
-            query = 'SELECT {an} FROM ATOM'.format(an=atnames)
-            data = [list(row) for row in self.c.execute(query)]
-
-        ################################################################
-        # GENERIC QUERY
-        #
-        # the only one we need
-        # each key must be a valid column
-        # each value may be a single value or an array
-        # AND is assumed between different keys
-        # OR is assumed for the different values of a given key
-        #
-        ################################################################
-        else:
-
-            # check that all the keys exist
-            for k in keys:
-
-                if k.startswith('no_'):
-                    k = k[3:]
-
-                try:
-                    self.c.execute(f"SELECT EXISTS(SELECT {k} FROM ATOM)")
-                except BaseException:
-                    logger.error(f'Column {k} not found in the database')
-                    self.get_colnames()
-                    return
-
-            # form the query and the tuple value
-            query = 'SELECT {an} FROM ATOM WHERE '.format(an=atnames)
-            conditions = []
-            vals = ()
-
-            # iterate through the kwargs
-            for (k, v) in kwargs.items():
-
-                # deal with negative conditions
-                if k.startswith('no_'):
-                    k = k[3:]
-                    neg = ' NOT'
-                else:
-                    neg = ''
-
-                # check if we have an array or a scalar
-                # and build the value tuple for the sql query
-                # deal with the indexing issue if rowID is required
-                if isinstance(v, list):
-
-                    nv = len(v)
-
-                    # if we have a large number of values
-                    # we must cut that in pieces because SQL has a hard limit
-                    # that is 999. The limit is here set to 950
-                    # so that we can have multiple conditions with a total number
-                    # of values inferior to 999
-                    if nv > self.max_sql_values:
-
-                        # cut in chunks
-                        chunck_size = self.max_sql_values
-                        vchunck = [v[i:i + chunck_size]
-                                   for i in range(0, nv, chunck_size)]
-
-                        data = []
-                        for v in vchunck:
-                            new_kwargs = kwargs.copy()
-                            new_kwargs[k] = v
-                            data += self.get(atnames, **new_kwargs)
-                        return data
-
-                    # otherwise we just go on
-                    else:
-
-                        if k == 'rowID':
-                            vals = vals + tuple([iv + 1 for iv in v])
-                        else:
-                            vals = vals + tuple(v)
-
-                else:
-
-                    nv = 1
-                    if k == 'rowID':
-                        vals = vals + (v + 1,)
-                    else:
-                        vals = vals + (v,)
-
-                # create the condition for that key
-                conditions.append(k + neg + ' in (' + ','.join('?' * nv) + ')')
-
-            # stitch the conditions and append to the query
-            query += ' AND '.join(conditions)
-
-            # error if vals is too long
-            if len(vals) > self.SQLITE_LIMIT_VARIABLE_NUMBER:
-                logger.error(
-                    f'\n SQL queries can only handle a total of 999 values.'
-                    f'\n The current query has {len(vals)} values'
-                    f'\n Hence it fails.'
-                    f'\n You are in a rare situation where MULTIPLE '
-                    f'conditions have a combined number of values that is '
-                    f'too large'
-                    f'\n These conditions are:')
-                ntot = 0
-                for k, v in kwargs.items():
-                    logger.error(f'\n : --> {k:10s} : {len(v)} values.')
-                    ntot += len(v)
-                logger.error(f'\n : --> %10s : %d values' % ('Total', ntot))
-                logger.error(
-                    f'\n : Try to decrease self.max_sql_values '
-                    f'in pdb2sql.py\n')
-                raise ValueError('Too many SQL variables')
-
-            # query the sql database and return the answer in a list
-            data = [list(row) for row in self.c.execute(query, vals)]
-
-        # empty data
-        if len(data) == 0:
-            warnings.warn('sqldb.get returned an empty list')
-            return data
-
-        # fix the python <--> sql indexes
-        # if atnames == 'rowID':
-        if 'rowID' in atnames:
-            index = atnames.split(',').index('rowID')
-            for i, _ in enumerate(data):
-                data[i][index] -= 1
-
-        # postprocess the output of the SQL query
-        # flatten it if each element is of size 1
-        if len(data[0]) == 1:
-            data = [d[0] for d in data]
-
-        return data
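
The chunking logic in `get()` works around SQLite's default ceiling of 999 bound variables per statement: a long value list is split into pieces of at most `max_sql_values` and the query is re-issued per chunk. The same idea in isolation, on a toy table with a hypothetical helper:

```python
import sqlite3

def select_in_chunks(cursor, values, chunk=950):
    """Split a long IN (...) list so no single query exceeds
    SQLite's default limit of 999 bound variables."""
    rows = []
    for i in range(0, len(values), chunk):
        part = values[i:i + chunk]
        qmarks = ','.join('?' * len(part))
        rows += cursor.execute(
            f'SELECT rowid, name FROM ATOM WHERE rowid IN ({qmarks})',
            part).fetchall()
    return rows

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE ATOM (name TEXT)')
conn.executemany('INSERT INTO ATOM VALUES (?)',
                 [(f'AT{i}',) for i in range(3000)])
print(len(select_in_chunks(conn.cursor(), list(range(1, 2001)))))  # 2000
```
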
-
-    ####################################################################
-    #
-    #   get the contact atoms
-    #
-    #   we should have an entire module called pdb2sql
-    #   with a submodule pdb2sql.interface that finds
-    #   contact atoms/residues,
-    #   and possibly other submodules to do other things
-    #   that will leave only the get / put methods in the main class
-    #
-    ####################################################################
-
-    def get_contact_atoms(self,
-                          cutoff=8.5,
-                          chain1='A',
-                          chain2='B',
-                          extend_to_residue=False,
-                          only_backbone_atoms=False,
-                          excludeH=False,
-                          return_only_backbone_atoms=False,
-                          return_contact_pairs=False):
-        """Get the contact atoms of the interface.
-
-        The cutoff distance is by default 8.5 Angs but can be changed
-        at will. A few more options allow to precisely define
-        how the contact atoms are identified and returned.
-
-        Args:
-            cutoff (float): cutoff for contact atoms (default 8.5)
-            chain1 (str): name of the first chain
-            chain2 (str): name of the second chain
-            extend_to_residue (bool): extend the contact atoms to
-                entire residues.
-            only_backbone_atoms (bool): consider only backbone atoms
-            excludeH (bool): exclude hydrogen atoms
-            return_only_backbone_atoms (bool): only return backbone atoms
-            return_contact_pairs (bool): return the contact pairs
-                instead of contact atoms.
-
-        Raises:
-            ValueError: contact atoms not found.
-
-        Returns:
-            np.array: index of the contact atoms
-
-        Examples:
-            >>> db = pdb2sql(filename)
-            >>> db.get_contact_atoms(cutoff=5.0,return_contact_pairs=True)
-        """
-        # xyz of the chains
-        xyz1 = np.array(self.get('x,y,z', chainID=chain1))
-        xyz2 = np.array(self.get('x,y,z', chainID=chain2))
-
-        # index of chain 2
-        index2 = self.get('rowID', chainID=chain2)
-
-        # resName of the chains
-        # resName1 = np.array(self.get('resName', chainID=chain1))
-        # resName2 = np.array(self.get('resName',chainID=chain2))
-
-        # atom names of the chains
-        atName1 = np.array(self.get('name', chainID=chain1))
-        atName2 = np.array(self.get('name', chainID=chain2))
-
-        # loop through the first chain
-        # TODO: loop through the smallest chain instead ...
-        index_contact_1, index_contact_2 = [], []
-        index_contact_pairs = {}
-
-        for i, x0 in enumerate(xyz1):
-
-            # compute the contact atoms
-            contacts = np.where(
-                np.sqrt(np.sum((xyz2 - x0)**2, 1)) <= cutoff)[0]
-
-            # exclude the H if required
-            if excludeH and atName1[i][0] == 'H':
-                continue
-
-            if len(contacts) > 0 and any([not only_backbone_atoms,
-                                          atName1[i] in self.backbone_type]):
-
-                # the contact atoms
-                index_contact_1 += [i]
-                index_contact_2 += [
-                    index2[k] for k in contacts if
-                    (any([atName2[k] in self.backbone_type,
-                          not only_backbone_atoms])
-                     and not (excludeH and atName2[k][0] == 'H'))]
-
-                # the pairs
-                pairs = [
-                    index2[k] for k in contacts if
-                    any([atName2[k] in self.backbone_type,
-                         not only_backbone_atoms])
-                    and not (excludeH and atName2[k][0] == 'H')]
-                if len(pairs) > 0:
-                    index_contact_pairs[i] = pairs
-
-        # get uniques
-        index_contact_1 = sorted(set(index_contact_1))
-        index_contact_2 = sorted(set(index_contact_2))
-
-        # if no atoms were found
-        if len(index_contact_1) == 0:
-            warnings.warn(f"No contact atoms found with cutoff {cutoff}Å")
-
-        # extend the list to entire residues
-        if extend_to_residue:
-            index_contact_1, index_contact_2 = self._extend_contact_to_residue(
-                index_contact_1, index_contact_2, only_backbone_atoms)
-
-        # filter only the backbone atoms
-        if return_only_backbone_atoms and not only_backbone_atoms:
-
-            # get all the names
-            # there are better ways to do that !
-            atNames = np.array(self.get('name'))
-
-            # change the index_contacts
-            index_contact_1 = [ind for ind in index_contact_1
-                               if atNames[ind] in self.backbone_type]
-            index_contact_2 = [ind for ind in index_contact_2
-                               if atNames[ind] in self.backbone_type]
-
-            # change the contact pairs
-            tmp_dict = {}
-            for ind1, ind2_list in index_contact_pairs.items():
-
-                if atNames[ind1] in self.backbone_type:
-                    tmp_dict[ind1] = [ind2 for ind2 in ind2_list
-                                      if atNames[ind2] in self.backbone_type]
-
-            index_contact_pairs = tmp_dict
-
-        # not sure that's the best way of dealing with that
-        if return_contact_pairs:
-            return index_contact_pairs
-        else:
-            return index_contact_1, index_contact_2
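
The contact search above loops over chain 1 and measures the distances to all of chain 2 with NumPy. The core geometric test can also be written as one vectorized distance matrix; the helper and toy coordinates below are illustrative. A full distance matrix is simpler but uses O(n1*n2) memory, unlike the per-atom loop above:

```python
import numpy as np

def contact_indices(xyz1, xyz2, cutoff=8.5):
    """Indices of atoms in xyz1 / xyz2 closer than cutoff to the other chain."""
    # (n1, n2) matrix of pairwise distances via broadcasting
    dist = np.linalg.norm(xyz1[:, None, :] - xyz2[None, :, :], axis=2)
    i1, i2 = np.where(dist <= cutoff)
    return np.unique(i1), np.unique(i2)

xyz1 = np.random.rand(50, 3) * 20
xyz2 = np.random.rand(60, 3) * 20 + 5.0
idx1, idx2 = contact_indices(xyz1, xyz2, cutoff=8.5)
print(len(idx1), len(idx2))
```
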
-
-    # extend the contact atoms to the residues
-    def _extend_contact_to_residue(self, index1, index2, only_backbone_atoms):
-
-        # extract the data
-        dataA = self.get(self.residue_key, rowID=index1)
-        dataB = self.get(self.residue_key, rowID=index2)
-
-        # create tuples cause we want to hash through them
-        # dataA = list(map(lambda x: tuple(x),dataA))
-        # dataB = list(map(lambda x: tuple(x),dataB))
-        dataA = [tuple(x) for x in dataA]
-        dataB = [tuple(x) for x in dataB]
-
-        # extract uniques
-        resA = list(set(dataA))
-        resB = list(set(dataB))
-
-        # init the lists
-        index_contact_A, index_contact_B = [], []
-
-        # contacts of chain A
-        for resdata in resA:
-            chainID, resSeq, resName = resdata
-
-            if only_backbone_atoms:
-                index_contact_A += self.get('rowID',
-                                            chainID=chainID,
-                                            resName=resName,
-                                            resSeq=resSeq,
-                                            name=self.backbone_type)
-            else:
-                index_contact_A += self.get('rowID',
-                                            chainID=chainID,
-                                            resName=resName,
-                                            resSeq=resSeq)
-
-        # contacts of chain B
-        for resdata in resB:
-            chainID, resSeq, resName = resdata
-
-            if only_backbone_atoms:
-                index_contact_B += self.get('rowID',
-                                            chainID=chainID,
-                                            resName=resName,
-                                            resSeq=resSeq,
-                                            name=self.backbone_type)
-            else:
-                index_contact_B += self.get('rowID',
-                                            chainID=chainID,
-                                            resName=resName,
-                                            resSeq=resSeq)
-
-        # make sure that we don't have doubles (maybe optional)
-        index_contact_A = sorted(set(index_contact_A))
-        index_contact_B = sorted(set(index_contact_B))
-
-        return index_contact_A, index_contact_B
-
-    # get the contact residues
-    def get_contact_residue(self,
-                            cutoff=8.5,
-                            chain1='A',
-                            chain2='B',
-                            excludeH=False,
-                            only_backbone_atoms=False,
-                            return_contact_pairs=False):
-        """Get the contact residues of the interface.
-
-        The cutoff distance is by default 8.5 Angs but can be changed
-        at will. A few more options allow to precisely define how
-        the contact residues are identified and returned.
-
-        Args:
-            cutoff (float): cutoff for contact atoms (default 8.5)
-            chain1 (str): name of the first chain
-            chain2 (str): name of the second chain
-            only_backbone_atoms (bool): consider only backbone atoms
-            excludeH (bool): exclude hydrogen atoms
-            return_contact_pairs (bool): return the contact pairs
-                instead of contact atoms
-
-        Returns:
-            np.array: index of the contact residues
-
-        Examples:
-            >>> db = pdb2sql(filename)
-            >>> db.get_contact_residue(cutoff=5.0,
-            ...                        return_contact_pairs=True)
-        """
-        # get the contact atoms
-        if return_contact_pairs:
-
-            # declare the dict
-            residue_contact_pairs = {}
-
-            # get the contact atom pairs
-            atom_pairs = self.get_contact_atoms(
-                cutoff=cutoff, chain1=chain1, chain2=chain2,
-                only_backbone_atoms=only_backbone_atoms,
-                excludeH=excludeH,
-                return_contact_pairs=True)
-
-            # loop over the atom pair dict
-            for iat1, atoms2 in atom_pairs.items():
-
-                # get the res info of the current atom
-                data1 = tuple(self.get(self.residue_key, rowID=[iat1])[0])
-
-                # create a new entry in the dict if necessary
-                if data1 not in residue_contact_pairs:
-                    residue_contact_pairs[data1] = set()
-
-                # get the res info of the atoms in the other chain
-                data2 = self.get(self.residue_key, rowID=atoms2)
-
-                # store that in the dict without doubles
-                for resData in data2:
-                    residue_contact_pairs[data1].add(tuple(resData))
-
-            for resData in residue_contact_pairs.keys():
-                residue_contact_pairs[resData] = sorted(
-                    residue_contact_pairs[resData])
-
-            return residue_contact_pairs
-
-        else:
-
-            # get the contact atoms
-            contact_atoms = self.get_contact_atoms(
-                cutoff=cutoff, chain1=chain1, chain2=chain2,
-                return_contact_pairs=False)
-
-            # get the residue info
-            data1 = self.get(self.residue_key, rowID=contact_atoms[0])
-            data2 = self.get(self.residue_key, rowID=contact_atoms[1])
-
-            # take only uniques
-            residue_contact_A = sorted(
-                set([tuple(resData) for resData in data1]))
-            residue_contact_B = sorted(
-                set([tuple(resData) for resData in data2]))
-
-            return residue_contact_A, residue_contact_B
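
The aggregation in `get_contact_residue()` boils down to collapsing atom-level contact pairs into residue-level pairs via sets. A toy sketch of that step, with made-up indexes and residue keys:

```python
# atom index -> contact atoms on the other chain (illustrative values)
atom_pairs = {0: [10, 11], 1: [11], 5: [12]}
# atom index -> (chainID, resSeq, resName) residue key (illustrative values)
atom2res = {0: ('A', 1, 'LEU'), 1: ('A', 1, 'LEU'), 5: ('A', 2, 'VAL'),
            10: ('B', 7, 'GLY'), 11: ('B', 8, 'ASP'), 12: ('B', 8, 'ASP')}

residue_pairs = {}
for iat, partners in atom_pairs.items():
    res = atom2res[iat]
    # the set removes duplicate residue pairs automatically
    residue_pairs.setdefault(res, set()).update(atom2res[j] for j in partners)

print({k: sorted(v) for k, v in residue_pairs.items()})
# {('A', 1, 'LEU'): [('B', 7, 'GLY'), ('B', 8, 'ASP')], ('A', 2, 'VAL'): [('B', 8, 'ASP')]}
```
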
-
-    ####################################################################
-    #
-    #   PUT FUNCTIONS AND ASSOCIATED
-    #
-    #   add_column()    -> add a column
-    #   update_column() -> update the values of one column
-    #   update_xyz()    -> update the xyz of the pdb
-    #   put()           -> put a value in a column
-    #
-    ####################################################################
-
-    def add_column(self, colname, coltype='FLOAT', default=0):
-        """Add an extra column to the ATOM table.
-
-        Args:
-            colname (str): name of the column
-            coltype (str, optional): type of the column data
-                (default FLOAT)
-            default (int, optional): default value to fill in the column
-                (default 0.0)
-        """
-
-        query = "ALTER TABLE ATOM ADD COLUMN '%s' %s DEFAULT %s" % (
-            colname, coltype, str(default))
-        self.c.execute(query)
-
-    def update(self, attribute, values, **kwargs):
-        """Update multiple columns of the data.
-
-        Args:
-            attribute (str): comma separated attribute names: 'x,y,z'
-            values (np.array): new values for the attributes
-            **kwargs: selection of the rows to update.
-
-        Raises:
-            ValueError: if sizes mismatch between values, conditions
-                and attribute names
-
-        Examples:
-            >>> n = 200
-            >>> index = list(range(n))
-            >>> vals = np.random.rand(n,3)
-            >>> db.update('x,y,z',vals,rowID=index)
-        """
-
-        # the asked keys
-        # keys = kwargs.keys()
-
-        # check if the column exists
-        try:
-            self.c.execute(f"SELECT EXISTS(SELECT {attribute} FROM ATOM)")
-        except BaseException:
-            logger.error(f'Column {attribute} not found in the database')
-            self.get_colnames()
-            raise ValueError(f'Attribute name {attribute} not recognized')
-
-        # if len(kwargs) == 0:
-        #     raise ValueError(f'Update without kwargs seems to be buggy.'
-        #                      f' Use rowID=list(range(natom)) instead')
-
-        # handle the multi model cases
-        # this is still in dev and not necessary
-        # for deeprank.
-        # We will have to deal with that if we
-        # release pdb2sql as a standalone
-        # if 'model' not in keys and self.nModel > 0:
-        #     for iModel in range(self.nModel):
-        #         kwargs['model'] = iModel
-        #         self.update(attribute,values,**kwargs)
-        #     return
-
-        # parse the attribute
-        if ',' in attribute:
-            attribute = attribute.split(',')
-
-        if not isinstance(attribute, list):
-            attribute = [attribute]
-
-        # check the sizes
-        natt = len(attribute)
-        nrow = len(values)
-        ncol = len(values[0])
-
-        if natt != ncol:
-            raise ValueError(
-                f'Number of attributes incompatible with '
-                f'the number of columns in the data')
-
-        # get the row IDs of the selection
-        rowID = self.get('rowID', **kwargs)
-        nselect = len(rowID)
-
-        if nselect != nrow:
-            raise ValueError(
-                'Number of data values incompatible with the given conditions')
-
-        # prepare the query
-        query = 'UPDATE ATOM SET '
-        query = query + ', '.join(map(lambda x: x + '=?', attribute))
-        # if len(kwargs)>0: # why did I do that ...
-        query = query + ' WHERE rowID=?'
-
-        # prepare the data
-        data = []
-        for i, val in enumerate(values):
-
-            tmp_data = [v for v in val]
-
-            # if len(kwargs)>0: same here, why did I do that ?
-            # here the conversion of the indexes is a bit annoying
-            tmp_data += [rowID[i] + 1]
-
-            data.append(tmp_data)
-
-        self.c.executemany(query, data)
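
`update()` compiles one parametrized UPDATE statement and pushes all rows through `executemany`, appending the 1-based SQLite rowID to each value tuple. The same pattern in isolation, on a toy table:

```python
import sqlite3
import numpy as np

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE ATOM (x REAL, y REAL, z REAL)')
conn.executemany('INSERT INTO ATOM VALUES (?,?,?)', np.zeros((5, 3)).tolist())

new_xyz = np.random.rand(5, 3)
# one (x, y, z, rowid) tuple per row; sqlite rowids are 1-based
rows = [[x, y, z, i + 1] for i, (x, y, z) in enumerate(new_xyz)]
conn.executemany('UPDATE ATOM SET x=?, y=?, z=? WHERE rowid=?', rows)
print(conn.execute('SELECT * FROM ATOM').fetchall())
```
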
-
-    def update_column(self, colname, values, index=None):
-        """Update a single column.
-
-        Args:
-            colname (str): name of the column to update
-            values (list): new values of the column
-            index (None, optional): indexes of the rows to update
-                (default all)
-
-        Examples:
-            >>> db.update_column('x', np.random.rand(10),
-            ...                  index=list(range(10)))
-        """
-
-        if index is None:
-            data = [[v, i + 1] for i, v in enumerate(values)]
-        else:
-            # shouldn't that be ind+1 ?
-            data = [[v, ind] for v, ind in zip(values, index)]
-
-        query = 'UPDATE ATOM SET {cn}=? WHERE rowID=?'.format(cn=colname)
-        self.c.executemany(query, data)
-        # self.conn.commit()
-
-    def update_xyz(self, xyz, index=None):
-        """Update the xyz information.
-
-        Update the positions of the selected atoms.
-        If index=None the positions of all the atoms are changed.
-
-        Args:
-            xyz (np.array): new xyz positions
-            index (None, list(int)): indexes of the atoms to move
-
-        Examples:
-            >>> n = 200
-            >>> index = list(range(n))
-            >>> vals = np.random.rand(n,3)
-            >>> db.update_xyz(vals,index=index)
-        """
-
-        if index is None:
-            data = [[pos[0], pos[1], pos[2], i + 1]
-                    for i, pos in enumerate(xyz)]
-        else:
-            data = [[pos[0], pos[1], pos[2], ind + 1]
-                    for pos, ind in zip(xyz, index)]
-
-        query = 'UPDATE ATOM SET x=?, y=?, z=? WHERE rowID=?'
-        self.c.executemany(query, data)
-
-    def put(self, colname, value, **kwargs):
-        """Update the value of the attribute with the value specified,
-        with a possible selection.
-
-        Args:
-            colname (str): must be a valid attribute name.
-                You can get these names via get_colnames():
-                serial, name, altLoc, resName, chainID, resSeq,
-                iCode, x, y, z, occ, temp
-                You can specify more than one attribute name at once,
-                e.g. 'x,y,z'
-
-            keyword args: Several options are possible
-                None : put the value in the entire column
-                index = [0,1,2,3] in only these indexes (not serial)
-                where = "chainID='B'" only for this chain
-                query = general SQL Query
-
-        Examples:
-            >>> db = pdb2sql(filename)
-            >>> db.add_column('CHARGE')
-            >>> db.put('CHARGE',1.25,index=[1])
-            >>> db.close()
-        """
-        arguments = {
-            'where': "String e.g 'chainID = 'A''",
-            'index': "Array e.g. [27,28,30]",
-            'name': "'CA' atom name",
-            'query': "SQL query e.g. 'WHERE chainID='B' AND resName='ASP' "}
-
-        # the asked keys
-        keys = kwargs.keys()
-
-        # if we have more than one key we kill it
-        if len(keys) > 1:
-            logger.error(f'You can only specify 1 conditional statement '
-                         f'for the pdb2sql.put function')
-            return
-
-        # check if the column exists
-        try:
-            self.c.execute(f"SELECT EXISTS(SELECT {colname} FROM ATOM)")
-        except BaseException:
-            logger.error(f'Column {colname} not found in the database')
-            self.get_colnames()
-            return
-
-        # if we have 0 keys we take the entire db
-        if len(kwargs) == 0:
-            query = f'UPDATE ATOM SET {colname}=?'
-            value = tuple([value])
-            self.c.execute(query, value)
-            return
-
-        # otherwise we have only one key
-        key = list(keys)[0]
-        cond = kwargs[key]
-
-        # select which key we have
-        if key == 'where':
-            query = f'UPDATE ATOM SET {colname}=? WHERE {cond}'
-            value = tuple([value])
-            self.c.execute(query, value)
-
-        elif key == 'name':
-            values = tuple([value, cond])
-            query = f'UPDATE ATOM SET {colname}=? WHERE name=?'
-            self.c.execute(query, values)
-
-        elif key == 'index':
-            values = tuple([value] + [v + 1 for v in cond])
-            qm = ','.join(['?' for i in range(len(cond))])
-            query = f'UPDATE ATOM SET {colname}=? WHERE rowID in ({qm})'
-            self.c.execute(query, values)
-
-        elif key == 'query':
-            query = f'UPDATE ATOM SET {colname}=? {cond}'
-            value = tuple([value])
-            self.c.execute(query, value)
-
-        else:
-            logger.error(
-                f'Error: argument {key} not supported in pdb2sql.put().'
-                f'\nOptions are:\n')
-            for posskey, possvalue in arguments.items():
-                logger.error(f'\t{posskey}\t\t{possvalue}')
-            return
-
-    ####################################################################
-    #
-    #   COMMIT, EXPORT, CLOSE FUNCTIONS
-    #
-    ####################################################################
-
-    # commit changes
-    def commit(self):
-        """Commit the database."""
-        self.conn.commit()
-
-    # export to pdb file
-    def exportpdb(self, fname, **kwargs):
-        """Export a PDB file with kwargs selection.
-
-        Args:
-            fname (str): Name of the file
-            **kwargs: Selection (see pdb2sql.get())
-
-        Examples:
-            >>> db = pdb2sql('1AK4.pdb')
-            >>> db.exportpdb('CA.pdb',name='CA')
-        """
-        # get the data
-        data = self.get('*', **kwargs)
-
-        # write each line
-        # the PDB format is pretty strict
-        # http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
-        # TODO make sure the atom type is output at the correct position.
-        # TODO use exportpdb in DataGenerator
-        f = open(fname, 'w')
-        for d in data:
-            line = 'ATOM  '
-            line += '{:>5}'.format(d[0])  # serial
-            line += ' '
-            line += '{:^4}'.format(d[1])  # name
-            line += '{:>1}'.format(d[2])  # altLoc
-            line += '{:>3}'.format(d[3])  # resName
-            line += ' '
-            line += '{:>1}'.format(d[4])  # chainID
-            line += '{:>4}'.format(d[5])  # resSeq
-            line += '{:>1}'.format(d[6])  # iCode
-            line += '   '
-            line += '{: 8.3f}'.format(d[7])  # x
-            line += '{: 8.3f}'.format(d[8])  # y
-            line += '{: 8.3f}'.format(d[9])  # z
-
-            if not self.no_extra:
-                line += '{: 6.2f}'.format(d[10])  # occ
-                line += '{: 6.2f}'.format(d[11])  # temp
-
-            line += '\n'
-
-            f.write(line)
-
-        # close
-        f.close()
-
-    # close the database
-    def close(self, rmdb=True):
-        """Close the database.
-
-        Args:
-            rmdb (bool, optional): Remove the database file
-        """
-
-        if self.sqlfile is None:
-            self.conn.close()
-
-        else:
-
-            if rmdb:
-                self.conn.close()
-                os.system('rm %s' % (self.sqlfile))
-            else:
-                self.commit()
-                self.conn.close()
-
-    ####################################################################
-    #
-    #   Transform the position of the molecule
-    #
-    ####################################################################
-
-    def translation(self, vect, **kwargs):
-        """Translate a part of or all the molecule.
-
-        Args:
-            vect (np.array): translation vector
-            **kwargs: keyword arguments to select the atoms.
-                See pdb2sql.get()
-
-        Examples:
-            >>> vect = np.random.rand(3)
-            >>> db.translation(vect, chainID = 'A')
-        """
-        xyz = self.get('x,y,z', **kwargs)
-        xyz += vect
-        self.update('x,y,z', xyz, **kwargs)
-
-    def rotation_around_axis(self, axis, angle, **kwargs):
-        """Rotate the molecule around a specified axis.
-
-        Args:
-            axis (np.array): axis of rotation
-            angle (float): angle of rotation in radians
-            **kwargs: keyword arguments to select the atoms.
-                See pdb2sql.get()
-
-        Returns:
-            np.array: center of the molecule
-
-        Examples:
-            >>> axis = np.random.rand(3)
-            >>> angle = np.random.rand()
-            >>> db.rotation_around_axis(axis, angle, chainID = 'B')
-        """
-        xyz = self.get('x,y,z', **kwargs)
-
-        # get the data
-        ct, st = np.cos(angle), np.sin(angle)
-        ux, uy, uz = axis
-
-        # get the center of the molecule
-        xyz0 = np.mean(xyz, 0)
-
-        # definition of the rotation matrix
-        # see https://en.wikipedia.org/wiki/Rotation_matrix
-        rot_mat = np.array([
-            [ct + ux**2 * (1 - ct), ux * uy * (1 - ct) - uz * st, ux * uz * (1 - ct) + uy * st],
-            [uy * ux * (1 - ct) + uz * st, ct + uy**2 * (1 - ct), uy * uz * (1 - ct) - ux * st],
-            [uz * ux * (1 - ct) - uy * st, uz * uy * (1 - ct) + ux * st, ct + uz**2 * (1 - ct)]])
-
-        # apply the rotation
-        xyz = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
-        self.update('x,y,z', xyz, **kwargs)
-
-        return xyz0
-
-    def rotation_euler(self, alpha, beta, gamma, **kwargs):
-        """Rotate a part of or all the molecule from Euler rotation angles.
-
-        Args:
-            alpha (float): angle of rotation around the x axis
-            beta (float): angle of rotation around the y axis
-            gamma (float): angle of rotation around the z axis
-            **kwargs: keyword arguments to select the atoms.
-                See pdb2sql.get()
-
-        Examples:
-            >>> a,b,c = np.random.rand(3)
-            >>> db.rotation_euler(a,b,c,resName='VAL')
-        """
-        xyz = self.get('x,y,z', **kwargs)
-
-        # precompute the trig
-        ca, sa = np.cos(alpha), np.sin(alpha)
-        cb, sb = np.cos(beta), np.sin(beta)
-        cg, sg = np.cos(gamma), np.sin(gamma)
-
-        # get the center of the molecule
-        xyz0 = np.mean(xyz, 0)
-
-        # rotation matrices
-        rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]])
-        ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]])
-        rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]])
-        rot_mat = np.dot(rx, np.dot(ry, rz))
-
-        # apply the rotation
-        xyz = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
-
-        self.update('x,y,z', xyz, **kwargs)
-
-    def rotation_matrix(self, rot_mat, center=True, **kwargs):
-        """Rotate a part of or all the molecule from a rotation matrix.
-
-        Args:
-            rot_mat (np.array): 3x3 rotation matrix
-            center (bool, optional): center the molecule before
-                applying the rotation.
-            **kwargs: keyword arguments to select the atoms.
-                See pdb2sql.get()
-
-        Examples:
-            >>> mat = np.random.rand(3,3)
-            >>> db.rotation_matrix(mat,chainID='A')
-        """
-        xyz = self.get('x,y,z', **kwargs)
-
-        if center:
-            xyz0 = np.mean(xyz, 0)
-            xyz = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
-        else:
-            xyz = np.dot(rot_mat, (xyz).T).T
-        self.update('x,y,z', xyz, **kwargs)
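
With the local class gone, the callers patched in the remainder of this diff switch to the standalone package: `import pdb2sql` and `pdb2sql.pdb2sql(...)`, as the sasa.py and visualize3Ddata.py hunks below show. A minimal sketch of the migrated call pattern, assuming the standalone class mirrors the local `get()`/`close()` API (file name illustrative):

```python
import pdb2sql

db = pdb2sql.pdb2sql('1AK4_100w.pdb')   # illustrative file name
xyz = db.get('x,y,z', chainID='A')
db.close()
```
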
diff --git a/deeprank/tools/sasa.py b/deeprank/tools/sasa.py
index c246a20c..2616513f 100644
--- a/deeprank/tools/sasa.py
+++ b/deeprank/tools/sasa.py
@@ -1,6 +1,5 @@
 import numpy as np
-
-from deeprank.tools import pdb2sql
+import pdb2sql
 
 
 class SASA(object):
@@ -55,7 +54,7 @@ def get_residue_center(self, chainA='A', chainB='B'):
             chainB (str, optional): Name of the second chain
         """
 
-        sql = pdb2sql(self.pdbfile)
+        sql = pdb2sql.pdb2sql(self.pdbfile)
 
         resA = np.array(sql.get('resSeq,resName', chainID=chainA))
         resB = np.array(sql.get('resSeq,resName', chainID=chainB))
@@ -94,7 +93,7 @@ def get_residue_carbon_beta(self, chainA='A', chainB='B'):
             chainB (str, optional): Name of the second chain
         """
 
-        sql = pdb2sql(self.pdbfile)
+        sql = pdb2sql.pdb2sql(self.pdbfile)
         resA = np.array(
             sql.get(
                 'resSeq,resName,x,y,z',
diff --git a/deeprank/utils/visualize3Ddata.py b/deeprank/utils/visualize3Ddata.py
index 5adfbed8..096d4c11 100755
--- a/deeprank/utils/visualize3Ddata.py
+++ b/deeprank/utils/visualize3Ddata.py
@@ -5,8 +5,9 @@
 
 import h5py
 import numpy as np
+import pdb2sql
 
-from deeprank.tools import pdb2sql, sparse
+from deeprank.tools import sparse
 
 
 def visualize3Ddata(hdf5=None, mol_name=None, out=None):
@@ -54,7 +55,7 @@ def visualize3Ddata(hdf5=None, mol_name=None, out=None):
         raise LookupError('Molecule %s not found in %s' % (mol_name, hdf5))
 
     # create the pdb file
-    sqldb = pdb2sql(molgrp['complex'][:])
+    sqldb = pdb2sql.pdb2sql(molgrp['complex'][:])
     sqldb.exportpdb(outdir + '/complex.pdb')
     sqldb.close()
diff --git a/setup.py b/setup.py
index f608aebf..4ca1bd35 100644
--- a/setup.py
+++ b/setup.py
@@ -17,8 +17,10 @@
         'pandas',
         'matplotlib',
         'torchsummary',
-        'freesasa',
-        'torch'
+        'torch',
+        'pdb2sql >= 0.2.1',
+        'freesasa==2.0.3.post7;platform_system=="Linux"',
+        'freesasa==2.0.3.post6;platform_system=="Darwin"'
     ],
 
     extras_require={
diff --git a/test/test_atomic_features.py b/test/test_atomic_features.py
index 13ec692d..6f13734e 100644
--- a/test/test_atomic_features.py
+++ b/test/test_atomic_features.py
@@ -105,44 +105,44 @@ def test_atomic_haddock():
 
         # close the db
         atfeat.sqldb.close()
 
-    @staticmethod
-    def test_atomic_zdock():
-
-        # in case you change the ref don't forget to:
-        # - comment the first line (E0=1)
-        # - uncomment the last two lines (Total = ...)
-        # - use the corresponding PDB file to test
-        #REF = './1AK4/atomic_features/ref_1AK4_100w.dat'
-        pdb = './2OUL/decoys/2OUL_1.pdb'
-        test_name = './2OUL/atomic_features/test_2OUL_1.dat'
-
-        # get the force field included in deeprank
-        # if another FF has been used to compute the ref
-        # change also this path to the correct one
-        FF = pkg_resources.resource_filename(
-            'deeprank.features', '') + '/forcefield/'
-
-        # declare the feature calculator instance
-        atfeat = AtomicFeature(pdb,
-                               param_charge=FF + 'protein-allhdg5-4_new.top',
-                               param_vdw=FF + 'protein-allhdg5-4_new.param',
-                               patch_file=FF + 'patch.top')
-        # assign parameters
-        atfeat.assign_parameters()
-
-        # only compute the pair interactions here
-        atfeat.evaluate_pair_interaction(save_interactions=test_name)
-
-        # make sure that the other properties are not crashing
-        atfeat.compute_coulomb_interchain_only(contact_only=True)
-        atfeat.compute_coulomb_interchain_only(contact_only=False)
-
-        # make sure that the other properties are not crashing
-        atfeat.compute_vdw_interchain_only(contact_only=True)
-        atfeat.compute_vdw_interchain_only(contact_only=False)
-
-        # close the db
-        atfeat.sqldb.close()
+    # @staticmethod
+    # def test_atomic_zdock():
+
+    #     # in case you change the ref don't forget to:
+    #     # - comment the first line (E0=1)
+    #     # - uncomment the last two lines (Total = ...)
+    #     # - use the corresponding PDB file to test
+    #     #REF = './1AK4/atomic_features/ref_1AK4_100w.dat'
+    #     pdb = './2OUL/decoys/2OUL_1.pdb'
+    #     test_name = './2OUL/atomic_features/test_2OUL_1.dat'
+
+    #     # get the force field included in deeprank
+    #     # if another FF has been used to compute the ref
+    #     # change also this path to the correct one
+    #     FF = pkg_resources.resource_filename(
+    #         'deeprank.features', '') + '/forcefield/'
+
+    #     # declare the feature calculator instance
+    #     atfeat = AtomicFeature(pdb,
+    #                            param_charge=FF + 'protein-allhdg5-4_new.top',
+    #                            param_vdw=FF + 'protein-allhdg5-4_new.param',
+    #                            patch_file=FF + 'patch.top')
+    #     # assign parameters
+    #     atfeat.assign_parameters()
+
+    #     # only compute the pair interactions here
+    #     atfeat.evaluate_pair_interaction(save_interactions=test_name)
+
+    #     # make sure that the other properties are not crashing
+    #     atfeat.compute_coulomb_interchain_only(contact_only=True)
+    #     atfeat.compute_coulomb_interchain_only(contact_only=False)
+
+    #     # make sure that the other properties are not crashing
+    #     atfeat.compute_vdw_interchain_only(contact_only=True)
+    #     atfeat.compute_vdw_interchain_only(contact_only=False)
+
+    #     # close the db
+    #     atfeat.sqldb.close()
 
 
 if __name__ == '__main__':
diff --git a/test/test_generate.py b/test/test_generate.py
index bfc25e9a..f8ca2a98 100644
--- a/test/test_generate.py
+++ b/test/test_generate.py
@@ -1,6 +1,5 @@
 import os
 import unittest
-import numpy as np
 from time import time
 
 from deeprank.generate import *
@@ -17,9 +16,6 @@ class TestGenerateData(unittest.TestCase):
     """Test the data generation process."""
 
-    # set random seed to make results repeatable
-    np.random.seed(2019)
-
     h5file = ['./1ak4.hdf5', 'native.hdf5']
     pdb_source = ['./1AK4/decoys/', './1AK4/native/']
     # pdb_native is only used to calculate i-RMSD, dockQ and so on. The native
@@ -62,7 +58,7 @@ def test_1_generate(self):
         if not os.path.isfile(database.hdf5):
             t0 = time()
             print('{:25s}'.format('Create new database') + database.hdf5)
-            database.create_database(prog_bar=True)
+            database.create_database(prog_bar=True, random_seed=2019)
             print(' ' * 25 + '--> Done in %f s.' % (time() - t0))
         else:
             print('{:25s}'.format('Use existing database') + database.hdf5)
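
Note the reproducibility change in the test_generate.py hunk above: instead of seeding NumPy globally at class level, the seed now travels with the generation call itself. A sketch, assuming a `DataGenerator` configured as in the project README (the feature/target module lists are elided placeholders):

```python
from deeprank.generate import DataGenerator

database = DataGenerator(pdb_source='./1AK4/decoys/',
                         pdb_native='./1AK4/native/',
                         compute_targets=[...],   # placeholder, see the README
                         compute_features=[...],  # placeholder, see the README
                         hdf5='./1ak4.hdf5')
database.create_database(prog_bar=True, random_seed=2019)
```
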
diff --git a/test/test_pdb2sql.py b/test/test_pdb2sql.py
deleted file mode 100644
index 6937c0d1..00000000
--- a/test/test_pdb2sql.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import unittest
-
-import numpy as np
-
-from deeprank.tools import pdb2sql
-
-
-class TestPDB2SQL(unittest.TestCase):
-    """Test PDB2SQL."""
-
-    def test_read(self):
-        """Read a pdb and create a sql db."""
-
-        # db.prettyprint()
-        self.db.get_colnames()
-        self.db.exportpdb('chainA.pdb', chainID='A')
-
-    def test_get(self):
-        """Test get with a large number of indexes."""
-
-        index = list(range(1200))
-        self.db.get('x,y,z', rowID=index)
-
-    @unittest.expectedFailure
-    def test_get_fails(self):
-        """Test get with a too large number of conditions."""
-
-        index_res = list(range(100))
-        index_atoms = list(range(1200))
-        self.db.get('x,y,z', resSeq=index_res, rowID=index_atoms)
-
-    def test_add_column(self):
-        """Add a new column to the db and change its values."""
-
-        self.db.add_column('CHARGE', 'FLOAT')
-        self.db.put('CHARGE', 0.1)
-        n = 100
-        q = np.random.rand(n)
-        ind = list(range(n))
-        self.db.update_column('CHARGE', q, index=ind)
-
-    def test_update(self):
-        """Update the database."""
-
-        n = 200
-        index = list(range(n))
-        vals = np.random.rand(n, 3)
-        self.db.update('x,y,z', vals, rowID=index)
-        self.db.prettyprint()
-        self.db.update_xyz(vals, index=index)
-
-    def test_update_all(self):
-        xyz = self.db.get('x,y,z')
-        self.db.update('x,y,z', xyz)
-        self.db.prettyprint()
-
-    def test_manip(self):
-        """Manipulate part of the protein."""
-
-        vect = np.random.rand(3)
-        self.db.translation(vect, chainID='A')
-
-        axis = np.random.rand(3)
-        angle = np.random.rand()
-        self.db.rotation_around_axis(axis, angle, chainID='B')
-
-        a, b, c = np.random.rand(3)
-        self.db.rotation_euler(a, b, c, resName='VAL')
-
-        mat = np.random.rand(3, 3)
-        self.db.rotation_matrix(mat, chainID='A')
-
-    def setUp(self):
-        mol = './1AK4/decoys/1AK4_cm-it0_745.pdb'
-        self.db = pdb2sql(mol)
-
-    def tearDown(self):
-        self.db.close()
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/test_rmsd.py b/test/test_rmsd.py
index 640a934e..cfb8b381 100644
--- a/test/test_rmsd.py
+++ b/test/test_rmsd.py
@@ -2,8 +2,7 @@
 import unittest
 
 import numpy as np
-
-from deeprank.tools import StructureSimilarity
+from pdb2sql import StructureSimilarity
 
 
 class TestStructureSimilarity(unittest.TestCase):
@@ -90,7 +89,7 @@ def test_slow():
         sim = StructureSimilarity(decoy, ref)
         trash = sim.compute_lrmsd_pdb2sql(method='svd')
         trash = sim.compute_irmsd_pdb2sql(method='svd')
-        trash = sim.compute_Fnat_pdb2sql()
+        trash = sim.compute_fnat_pdb2sql()
         print(trash)
 
     def setUp(self):
@@ -103,4 +102,4 @@ def setUp(self):
 
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main()
\ No newline at end of file
diff --git a/test/test_tools.py b/test/test_tools.py
index 17a945ea..8610142d 100644
--- a/test/test_tools.py
+++ b/test/test_tools.py
@@ -1,27 +1,11 @@
 import unittest
 
-from deeprank.tools import SASA, pdb2sql
+from deeprank.tools import SASA
 
 
 class TestTools(unittest.TestCase):
     """Test StructureSimilarity."""
 
-    @staticmethod
-    def test_pdb2sql():
-        """Test the pdb2sql module."""
-
-        # create the sql db
-        pdb = './1AK4/decoys/1AK4_cm-it0_745.pdb'
-        db = pdb2sql(pdb)
-        db._fix_chainID()
-
-        # get the column names
-        db.get_colnames()
-
-        # print
-        db.prettyprint()
-        db.uglyprint()
-
     @staticmethod
     def test_sasa():
         """Test the SASA module."""