From b987040ee1fe219526908b39127d1e7a3662b464 Mon Sep 17 00:00:00 2001
From: LilySnow
Date: Wed, 20 Mar 2019 19:06:46 +0100
Subject: [PATCH] updated _r bug-fix in DataSet.py and DataGenerator.py

---
 deeprank/generate/DataGenerator.py | 36 +++++++++++++++++++++---------
 deeprank/learn/DataSet.py          |  9 +++++---
 2 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/deeprank/generate/DataGenerator.py b/deeprank/generate/DataGenerator.py
index b61b0092..8cb9f604 100644
--- a/deeprank/generate/DataGenerator.py
+++ b/deeprank/generate/DataGenerator.py
@@ -9,6 +9,7 @@
 from deeprank.tools import pdb2sql
 from deeprank.generate import GridTools as gt
 from deeprank.generate import settings
+import re
 
 try:
     from tqdm import tqdm
@@ -275,12 +276,14 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_
                 molgrp.require_group('features')
                 molgrp.require_group('features_raw')
 
+                error_flag = False #error_flag => when False: success; when True: failed
                 if self.compute_features is not None:
-                    self._compute_features(self.compute_features,
+                    error_flag = self._compute_features(self.compute_features,
                                            molgrp['complex'][:],
                                            molgrp['features'],
                                            molgrp['features_raw'] )
 
+
                 ################################################
                 # add the targets
                 ################################################
@@ -350,6 +353,12 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_
                     molgrp.attrs['angle'] = angle
                     molgrp.attrs['center'] = center
 
+                if error_flag:
+                    #error_flag => when False: success; when True: failed
+                    self.feature_error += [mol_name] + mol_aug_name_list
+                    self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True)
+                    sys.stdout.flush()
+
             except Exception as inst:
 
                 self.feature_error += [mol_name] + mol_aug_name_list
@@ -361,9 +370,10 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_
         # remove the data where we had issues
         if remove_error:
             for mol in self.feature_error:
-                self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True)
+                #self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True)
                 _printif('removing %s from %s' %(mol,self.hdf5),self.debug)
                 del self.f5[mol]
+            sys.stdout.flush()
 
         # close the file
         self.f5.close()
@@ -405,8 +415,8 @@ def add_feature(self,prog_bar=True):
         fnames = f5.keys()
 
         # get the non rotated ones
-        fnames_original = list( filter(lambda x: '_r' not in x, fnames) )
-        fnames_augmented = list( filter(lambda x: '_r' in x, fnames) )
+        fnames_original = list( filter(lambda x: not re.search('_r\d+$',x), fnames) )
+        fnames_augmented = list( filter(lambda x: re.search('_r\d+$',x), fnames) )
 
         # computes the features of the original
         desc = '{:25s}'.format('Add features')
@@ -430,7 +440,7 @@ def add_feature(self,prog_bar=True):
             aug_molgrp = f5[cplx_name]
 
             # get the source group
-            mol_name = molgrp.name.split('_r')[0]
+            mol_name = re.split('_r\d+', molgrp.name)[0]
             src_molgrp = f5[mol_name]
 
             # get the rotation parameters
@@ -516,8 +526,8 @@ def add_target(self,prog_bar=False):
         fnames = f5.keys()
 
         # get the non rotated ones
-        fnames_original = list( filter(lambda x: '_r' not in x, fnames) )
-        fnames_augmented = list( filter(lambda x: '_r' in x, fnames) )
+        fnames_original = list( filter(lambda x: not re.search('_r\d+$',x), fnames) )
+        fnames_augmented = list( filter(lambda x: re.search('_r\d+$',x), fnames) )
 
         # compute the targets of the original
         desc = '{:25s}'.format('Add targets')
@@ -538,7 +548,7 @@ def add_target(self,prog_bar=False):
             aug_molgrp = f5[cplx_name]
 
             # get the source group
-            mol_name = molgrp.name.split('_r')[0]
+            mol_name = re.split('_r\d+', molgrp.name)[0]
             src_molgrp = f5[mol_name]
 
             # copy the targets to the augmented
@@ -1093,14 +1103,18 @@ def _compute_features(feat_list,pdb_data,featgrp,featgrp_raw):
         """Compute the features
 
         Args:
-            feat_list (list(str)): list of function name
-            pdb_data (bytes): PDB translated in btes
+            feat_list (list(str)): list of function name, e.g., ['deeprank.features.ResidueDensity', 'deeprank.features.PSSM_IC']
+            pdb_data (bytes): PDB translated in bytes
             featgrp (str): name of the group where to store the xyz feature
             featgrp_raw (str): name of the group where to store the raw feature
         """
+        error_flag = False # when False: success; when True: failed
        for feat in feat_list:
             feat_module = importlib.import_module(feat,package=None)
-            feat_module.__compute_feature__(pdb_data,featgrp,featgrp_raw)
+            error_flag = feat_module.__compute_feature__(pdb_data,featgrp,featgrp_raw)
+
+            if re.search('ResidueDensity', feat) and error_flag == True:
+                return error_flag
 
 
 #====================================================================================
diff --git a/deeprank/learn/DataSet.py b/deeprank/learn/DataSet.py
index e9a95711..29d00a0a 100644
--- a/deeprank/learn/DataSet.py
+++ b/deeprank/learn/DataSet.py
@@ -4,6 +4,7 @@
 import time
 import h5py
 import pickle
+import re
 
 from functools import partial
 
@@ -457,11 +458,12 @@ def _select_pdb(self, mol_names):
         """
 
         if self.use_rotation is not None:
-            fnames_original = list(filter(lambda x: '_r' not in x, mol_names))
+            fnames_original = list(filter(lambda x: not re.search('_r\d+$',x), mol_names))
             fnames_augmented = []
             if self.use_rotation > 0:
                 for i in range(self.use_rotation):
-                    fnames_augmented += list(filter(lambda x: '_r%03d' %(i+1) in x, mol_names))
+#                    fnames_augmented += list(filter(lambda x: '_r%03d' %(i+1) in x, mol_names))
+                    fnames_augmented += list(filter(lambda x: re.search('_r%03d$' %(i+1), x), mol_names))
             mol_names = fnames_original + fnames_augmented
         else:
             mol_names = fnames_original
@@ -1325,7 +1327,8 @@ def _featgrid(center,value,grid,npts):
 
         # shortcut for th center
         x0,y0,z0 = center
-        beta = 1.0
+        sigma = np.sqrt(1./2)
+        beta = 0.5/(sigma**2)
         cutoff = 5.*beta
         dd = np.sqrt( (grid[0]-x0)**2 + (grid[1]-y0)**2 + (grid[2]-z0)**2 )
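Note on the fix: the patch replaces the plain substring test `'_r' in name` with the
end-anchored regex `_r\d+$` when splitting original complexes from their
rotation-augmented copies. A minimal sketch of the difference, using hypothetical
molecule names that are not taken from the patch:

    import re

    # hypothetical HDF5 group names: three rotated copies plus one name that
    # merely contains the substring '_r'
    mol_names = ['1AK4', '1AK4_r001', '1AK4_r002', 'model_rank1', '2W83-AB_r005']

    # old behaviour: the substring test misclassifies 'model_rank1' as augmented
    old_original = [m for m in mol_names if '_r' not in m]           # ['1AK4']

    # new behaviour: only names ending in a rotation suffix count as augmented
    new_original  = [m for m in mol_names if not re.search(r'_r\d+$', m)]
    new_augmented = [m for m in mol_names if re.search(r'_r\d+$', m)]
    # new_original  -> ['1AK4', 'model_rank1']
    # new_augmented -> ['1AK4_r001', '1AK4_r002', '2W83-AB_r005']

For the _featgrid change, note that with sigma = sqrt(1/2) the new expression
beta = 0.5/sigma**2 still evaluates to the previous hard-coded value of 1.0, so the
numerical behaviour of the Gaussian mapping is unchanged; the rewrite only makes the
sigma parameterization explicit.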