Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
updated _r bug-fix in DataSet.py and DataGenerator.py
Browse files: browse the repository at this point in the history
  • Loading branch information
LilySnow committed Mar 20, 2019
1 parent 7c14696 commit b987040
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
36 changes: 25 additions & 11 deletions deeprank/generate/DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from deeprank.tools import pdb2sql
from deeprank.generate import GridTools as gt
from deeprank.generate import settings
import re

try:
from tqdm import tqdm
Expand Down Expand Up @@ -275,12 +276,14 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_
molgrp.require_group('features')
molgrp.require_group('features_raw')

error_flag = False #error_flag => when False: success; when True: failed
if self.compute_features is not None:
self._compute_features(self.compute_features,
error_flag = self._compute_features(self.compute_features,
molgrp['complex'][:],
molgrp['features'],
molgrp['features_raw'] )


################################################
# add the targets
################################################
Expand Down Expand Up @@ -350,6 +353,12 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_
molgrp.attrs['angle'] = angle
molgrp.attrs['center'] = center

if error_flag:
#error_flag => when False: success; when True: failed
self.feature_error += [mol_name] + mol_aug_name_list
self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True)
sys.stdout.flush()

except Exception as inst:

self.feature_error += [mol_name] + mol_aug_name_list
Expand All @@ -361,9 +370,10 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_
# remove the data where we had issues
if remove_error:
for mol in self.feature_error:
self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True)
#self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True)
_printif('removing %s from %s' %(mol,self.hdf5),self.debug)
del self.f5[mol]
sys.stdout.flush()

# close the file
self.f5.close()
Expand Down Expand Up @@ -405,8 +415,8 @@ def add_feature(self,prog_bar=True):
fnames = f5.keys()

# get the non rotated ones
fnames_original = list( filter(lambda x: '_r' not in x, fnames) )
fnames_augmented = list( filter(lambda x: '_r' in x, fnames) )
fnames_original = list( filter(lambda x: not re.search('_r\d+$',x), fnames) )
fnames_augmented = list( filter(lambda x: re.search('_r\d+$',x), fnames) )

# computes the features of the original
desc = '{:25s}'.format('Add features')
Expand All @@ -430,7 +440,7 @@ def add_feature(self,prog_bar=True):
aug_molgrp = f5[cplx_name]

# get the source group
mol_name = molgrp.name.split('_r')[0]
mol_name = re.split('_r\d+', molgrp.name)[0]
src_molgrp = f5[mol_name]

# get the rotation parameters
Expand Down Expand Up @@ -516,8 +526,8 @@ def add_target(self,prog_bar=False):
fnames = f5.keys()

# get the non rotated ones
fnames_original = list( filter(lambda x: '_r' not in x, fnames) )
fnames_augmented = list( filter(lambda x: '_r' in x, fnames) )
fnames_original = list( filter(lambda x: not re.search('_r\d+$',x), fnames) )
fnames_augmented = list( filter(lambda x: re.search('_r\d+$',x), fnames) )

# compute the targets of the original
desc = '{:25s}'.format('Add targets')
Expand All @@ -538,7 +548,7 @@ def add_target(self,prog_bar=False):
aug_molgrp = f5[cplx_name]

# get the source group
mol_name = molgrp.name.split('_r')[0]
mol_name = re.split('_r\d+', molgrp.name)[0]
src_molgrp = f5[mol_name]

# copy the targets to the augmented
Expand Down Expand Up @@ -1093,14 +1103,18 @@ def _compute_features(feat_list,pdb_data,featgrp,featgrp_raw):
"""Compute the features
Args:
feat_list (list(str)): list of function name
pdb_data (bytes): PDB translated in btes
feat_list (list(str)): list of function name, e.g., ['deeprank.features.ResidueDensity', 'deeprank.features.PSSM_IC']
pdb_data (bytes): PDB translated in bytes
featgrp (str): name of the group where to store the xyz feature
featgrp_raw (str): name of the group where to store the raw feature
"""
error_flag = False # when False: success; when True: failed
for feat in feat_list:
feat_module = importlib.import_module(feat,package=None)
feat_module.__compute_feature__(pdb_data,featgrp,featgrp_raw)
error_flag = feat_module.__compute_feature__(pdb_data,featgrp,featgrp_raw)

if re.search('ResidueDensity', feat) and error_flag == True:
return error_flag


#====================================================================================
Expand Down
9 changes: 6 additions & 3 deletions deeprank/learn/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import h5py
import pickle
import re

from functools import partial

Expand Down Expand Up @@ -457,11 +458,12 @@ def _select_pdb(self, mol_names):
"""

if self.use_rotation is not None:
fnames_original = list(filter(lambda x: '_r' not in x, mol_names))
fnames_original = list(filter(lambda x: not re.search('_r\d+$',x), mol_names))
fnames_augmented = []
if self.use_rotation > 0:
for i in range(self.use_rotation):
fnames_augmented += list(filter(lambda x: '_r%03d' %(i+1) in x, mol_names))
# fnames_augmented += list(filter(lambda x: '_r%03d' %(i+1) in x, mol_names))
fnames_augmented += list(filter(lambda x: re.search('_r%03d$' %(i+1), x), mol_names))
mol_names = fnames_original + fnames_augmented
else:
mol_names = fnames_original
Expand Down Expand Up @@ -1325,7 +1327,8 @@ def _featgrid(center,value,grid,npts):
# shortcut for the center
x0,y0,z0 = center

beta = 1.0
sigma = np.sqrt(1./2)
beta = 0.5/(sigma**2)
cutoff = 5.*beta

dd = np.sqrt( (grid[0]-x0)**2 + (grid[1]-y0)**2 + (grid[2]-z0)**2 )
Expand Down

0 comments on commit b987040

Please sign in to comment.