diff --git a/deeprank/__init__.py b/deeprank/__init__.py
index fc733e08..53a788f0 100644
--- a/deeprank/__init__.py
+++ b/deeprank/__init__.py
@@ -1,18 +1,15 @@
 # h5py complains since the last numpy update ...
 # the warning is
-#/home/nico/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
+# /home/nico/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 # from ._conv import register_converters as _register_converters
 import warnings
-warnings.simplefilter(action='ignore', category=FutureWarning)
-# generate the data
 from .generate import *
-
-# tools
 from .tools import *
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
 # deep learning
 # import torch fals on Travis
 #from .learn import *
-
-
diff --git a/deeprank/features/AtomicFeature.py b/deeprank/features/AtomicFeature.py
index 86f032af..0ea88cb5 100644
--- a/deeprank/features/AtomicFeature.py
+++ b/deeprank/features/AtomicFeature.py
@@ -1,16 +1,24 @@
 import os
+
 import numpy as np
-from deeprank.tools import pdb2sql
 from deeprank.features import FeatureClass
+from deeprank.tools import pdb2sql
 
-class AtomicFeature(FeatureClass):
 
-    def __init__(self,pdbfile,param_charge=None,param_vdw=None,patch_file=None,
-                 contact_distance=8.5, root_export = './',individual_directory=False,verbose=False):
+class AtomicFeature(FeatureClass):
 
-        '''
-        Compute the Coulomb, van der Waals interaction and charges
+    def __init__(
+            self,
+            pdbfile,
+            param_charge=None,
+            param_vdw=None,
+            patch_file=None,
+            contact_distance=8.5,
+            root_export='./',
+            individual_directory=False,
+            verbose=False):
+        """Compute the Coulomb, van der Waals interaction and charges.
 
         Args:
 
@@ -61,7 +69,7 @@ def __init__(self,pdbfile,param_charge=None,param_vdw=None,patch_file=None,
         >>>
         >>> # close the db
         >>> atfeat.sqldb.close()
-        '''
+        """
 
         super().__init__("Atomic")
 
@@ -87,7 +95,7 @@ def __init__(self,pdbfile,param_charge=None,param_vdw=None,patch_file=None,
         # read the force field
         self.read_charge_file()
 
-        if patch_file != None:
+        if patch_file is not None:
             self.patch = self.read_patch()
         else:
             self.patch = None
@@ -98,23 +106,21 @@ def __init__(self,pdbfile,param_charge=None,param_vdw=None,patch_file=None,
         # get the contact atoms
         self.get_contact_atoms()
 
-    #####################################################################################
+    ##########################################################################
     #
     #   READ INPUT FILES
     #
-    #####################################################################################
+    ##########################################################################
 
     def read_charge_file(self):
-
-        '''Read the .top file given in entry.
+        """Read the .top file given in entry.
 
         This function creates :
 
         - self.charge : dictionary  {(resname,atname):charge}
         - self.valid_resnames : list ['VAL','ALP', .....]
- self.at_name_type_convertor : dictionary {(resname,atname):attype} - - ''' + """ f = open(self.param_charge) data = f.readlines() @@ -131,66 +137,62 @@ def read_charge_file(self): # split the line words = l.split() - #get the resname/atname - res,atname = words[0],words[2] + # get the resname/atname + res, atname = words[0], words[2] # get the charge ind = l.find('charge=') - q = float(l[ind+7:ind+13]) + q = float(l[ind + 7:ind + 13]) # get the type attype = words[3].split('=')[-1] # store the charge - self.charge[(res,atname)] = q + self.charge[(res, atname)] = q # put the resname in a list so far resnames.append(res) # dictionary for conversion name/type - self.at_name_type_convertor[(res,atname)] = attype + self.at_name_type_convertor[(res, atname)] = attype self.valid_resnames = list(set(resnames)) - def read_patch(self): - - '''Read the patchfile. + """Read the patchfile. This function creates - self.patch_charge : Dicitionary {(resName,atName) : charge} - self.patch_type : Dicitionary {(resName,atName) : type} - - ''' + """ f = open(self.patch_file) data = f.readlines() f.close() - self.patch_charge,self.patch_type = {},{} + self.patch_charge, self.patch_type = {}, {} for l in data: # ignore comments - if l[0] != '#' and l[0] != '!' and len(l.split())>0: + if l[0] != '#' and l[0] != '!' and len(l.split()) > 0: words = l.split() # get the new charge ind = l.find('CHARGE=') - q = float(l[ind+7:ind+13]) - self.patch_charge [(words[0],words[3])] = q + q = float(l[ind + 7:ind + 13]) + self.patch_charge[(words[0], words[3])] = q # get the new type if any ind = l.find('TYPE=') if ind != -1: - type_ = l[ind+5:ind+9] - self.patch_type[(words[0],words[3])] = type_.strip() + type_ = l[ind + 5:ind + 9] + self.patch_type[(words[0], words[3])] = type_.strip() def read_vdw_file(self): - - ''' Read the .param file + """Read the .param file. The patch file must be of the form: @@ -203,7 +205,7 @@ def read_vdw_file(self): This function creates - self.vdw : dictionary {attype:[E1,S1]} - ''' + """ f = open(self.param_vdw) data = f.readlines() @@ -224,39 +226,42 @@ def read_vdw_file(self): if line[0][0] == '#': continue - self.vdw_param[line[1]] = list(map(float,line[2:4])) + self.vdw_param[line[1]] = list(map(float, line[2:4])) def get_contact_atoms(self): - """Get the contact atoms only select amino acids. The ligands are not considered. 
""" # position of the chains - xyz1 = np.array(self.sqldb.get('x,y,z',chainID='A')) - xyz2 = np.array(self.sqldb.get('x,y,z',chainID='B')) + xyz1 = np.array(self.sqldb.get('x,y,z', chainID='A')) + xyz2 = np.array(self.sqldb.get('x,y,z', chainID='B')) # rowID of the second chain - index_a = self.sqldb.get('rowID',chainID='A') - index_b = self.sqldb.get('rowID',chainID='B') + index_a = self.sqldb.get('rowID', chainID='A') + index_b = self.sqldb.get('rowID', chainID='B') # resName of the chains - resName1 = np.array(self.sqldb.get('resName',chainID='A')) - resName2 = np.array(self.sqldb.get('resName',chainID='B')) + resName1 = np.array(self.sqldb.get('resName', chainID='A')) + resName2 = np.array(self.sqldb.get('resName', chainID='B')) # declare the contact atoms self.contact_atoms_A = [] self.contact_atoms_B = [] - #The contact atom pairs only co ntains pairs of atoms that are + # The contact atom pairs only co ntains pairs of atoms that are # in contact self.contact_pairs = {} - for i,x0 in enumerate(xyz1): + for i, x0 in enumerate(xyz1): # compute the contact atoms - contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) < self.contact_distance)[0] + contacts = np.where( + np.sqrt( + np.sum( + (xyz2 - x0)**2, + 1)) < self.contact_distance)[0] # if we have contact atoms and resA is not a ligand if (len(contacts) > 0) and (resName1[i] in self.valid_resnames): @@ -264,26 +269,32 @@ def get_contact_atoms(self): # add i to the list # add the index of b if its resname is not a ligand self.contact_atoms_A += [index_a[i]] - self.contact_atoms_B += [index_b[k] for k in contacts if resName2[k] in self.valid_resnames] + self.contact_atoms_B += [index_b[k] + for k in contacts if resName2[k] in self.valid_resnames] # add the contact pairs to the list - self.contact_pairs[index_a[i]] = [index_b[k] for k in contacts if resName2[k] in self.valid_resnames] + self.contact_pairs[index_a[i]] = [index_b[k] + for k in contacts if resName2[k] in self.valid_resnames] # create a set of unique indexes self.contact_atoms_A = sorted(set(self.contact_atoms_A)) self.contact_atoms_B = sorted(set(self.contact_atoms_B)) # if no atoms were found - if len(self.contact_atoms_A)==0: + if len(self.contact_atoms_A) == 0: print('Warning : No contact atoms detected in atomicFeature') - def _extend_contact_to_residue(self): - """Extend the contact atoms to entire residue where one atom is contacting.""" + """Extend the contact atoms to entire residue where one atom is + contacting.""" # extract the data - dataA = self.sqldb.get('chainId,resName,resSeq',rowID=self.contact_atoms_A) - dataB = self.sqldb.get('chainId,resName,resSeq',rowID=self.contact_atoms_B) + dataA = self.sqldb.get( + 'chainId,resName,resSeq', + rowID=self.contact_atoms_A) + dataB = self.sqldb.get( + 'chainId,resName,resSeq', + rowID=self.contact_atoms_B) # create tuple cause we want to hash through it dataA = [tuple(x) for x in dataA] @@ -294,39 +305,43 @@ def _extend_contact_to_residue(self): resB = list(set(dataB)) # init the list - index_contact_A,index_contact_B = [],[] + index_contact_A, index_contact_B = [], [] # contact of chain A for resdata in resA: - chainID,resName,resSeq = resdata - index_contact_A += self.sqldb.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) + chainID, resName, resSeq = resdata + index_contact_A += self.sqldb.get('rowID', + chainID=chainID, + resName=resName, + resSeq=resSeq) # contact of chain B for resdata in resB: - chainID,resName,resSeq = resdata - index_contact_B += 
self.sqldb.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) + chainID, resName, resSeq = resdata + index_contact_B += self.sqldb.get('rowID', + chainID=chainID, + resName=resName, + resSeq=resSeq) # make sure that we don't have double (maybe optional) index_contact_A = sorted(set(index_contact_A)) index_contact_B = sorted(set(index_contact_B)) - return index_contact_A,index_contact_B + return index_contact_A, index_contact_B - - - ##################################################################################### + ########################################################################## # # Assign parameters # - ##################################################################################### + ########################################################################## def assign_parameters(self): + """Assign to each atom in the pdb its charge and vdw interchain + parameters. - '''Assign to each atom in the pdb its charge and vdw interchain parameters. - - Directly deals with the patch so that we don't loop over the residues - multiple times - ''' + Directly deals with the patch so that we don't loop over the + residues multiple times + """ # get all the resnumbers if self.verbose: @@ -334,8 +349,7 @@ def assign_parameters(self): data = self.sqldb.get('chainID,resSeq,resName') natom = len(data) - data = np.unique(np.array(data),axis=0) - + data = np.unique(np.array(data), axis=0) # declare the parameters for future insertion in SQL atcharge = np.zeros(natom) @@ -343,56 +357,61 @@ def assign_parameters(self): atsig = np.zeros(natom) # check - attype = np.zeros(natom,dtype='r_off] = 0. - pref[r r_off] = 0. + pref[r < r_on] = 1.0 return pref -##################################################################################### +########################################################################## # # THE MAIN FUNCTION CALLED IN THE INTERNAL FEATURE CALCULATOR # -##################################################################################### - -def __compute_feature__(pdb_data,featgrp,featgrp_raw): +########################################################################## - """Main function called in deeprank for the feature calculations +def __compute_feature__(pdb_data, featgrp, featgrp_raw): + """Main function called in deeprank for the feature calculations. Args: pdb_data (list(bytes)): pdb information @@ -981,9 +1036,9 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): FF = path + '/forcefield/' atfeat = AtomicFeature(pdb_data, - param_charge = FF + 'protein-allhdg5-4_new.top', - param_vdw = FF + 'protein-allhdg5-4_new.param', - patch_file = FF + 'patch.top') + param_charge=FF + 'protein-allhdg5-4_new.top', + param_vdw=FF + 'protein-allhdg5-4_new.param', + patch_file=FF + 'patch.top') atfeat.assign_parameters() diff --git a/deeprank/features/BSA.py b/deeprank/features/BSA.py index 57500664..09e2ba15 100644 --- a/deeprank/features/BSA.py +++ b/deeprank/features/BSA.py @@ -1,7 +1,9 @@ import os + import numpy as np -from deeprank.tools import pdb2sql + from deeprank.features import FeatureClass +from deeprank.tools import pdb2sql try: import freesasa @@ -9,11 +11,11 @@ except ImportError: print('Freesasa not found') -class BSA(FeatureClass): - def __init__(self,pdb_data,chainA='A',chainB='B'): +class BSA(FeatureClass): - '''Compute the burried surface area feature + def __init__(self, pdb_data, chainA='A', chainB='B'): + """Compute the burried surface area feature. Freesasa is required for this feature. 
@@ -48,12 +50,11 @@ def __init__(self,pdb_data,chainA='A',chainB='B'): >>> bsa.get_structure() >>> bsa.get_contact_residue_sasa() >>> bsa.sql.close() - - ''' + """ self.pdb_data = pdb_data self.sql = pdb2sql(pdb_data) - self.chains_label = [chainA,chainB] + self.chains_label = [chainA, chainB] self.feature_data = {} self.feature_data_xyz = {} @@ -64,73 +65,83 @@ def get_structure(self): """Get the pdb structure of the molecule.""" # we can have a str or a list of bytes as input - if isinstance(self.pdb_data,str): + if isinstance(self.pdb_data, str): self.complex = freesasa.Structure(self.pdb_data) else: self.complex = freesasa.Structure() atomdata = self.sql.get('name,resName,resSeq,chainID,x,y,z') - for atomName,residueName,residueNumber,chainLabel,x,y,z in atomdata: + for atomName, residueName, residueNumber, chainLabel, x, y, z in atomdata: atomName = '{:>2}'.format(atomName[0]) - self.complex.addAtom(atomName,residueName,residueNumber,chainLabel,x,y,z) + self.complex.addAtom( + atomName, residueName, residueNumber, chainLabel, x, y, z) self.result_complex = freesasa.calc(self.complex) self.chains = {} self.result_chains = {} for label in self.chains_label: self.chains[label] = freesasa.Structure() - atomdata = self.sql.get('name,resName,resSeq,chainID,x,y,z',chainID=label) - for atomName,residueName,residueNumber,chainLabel,x,y,z in atomdata: + atomdata = self.sql.get( + 'name,resName,resSeq,chainID,x,y,z', chainID=label) + for atomName, residueName, residueNumber, chainLabel, x, y, z in atomdata: atomName = '{:>2}'.format(atomName[0]) - self.chains[label].addAtom(atomName,residueName,residueNumber,chainLabel,x,y,z) + self.chains[label].addAtom( + atomName, residueName, residueNumber, chainLabel, x, y, z) self.result_chains[label] = freesasa.calc(self.chains[label]) - def get_contact_residue_sasa(self,cutoff=5.5): + def get_contact_residue_sasa(self, cutoff=5.5): """Compute the feature value.""" self.bsa_data = {} self.bsa_data_xyz = {} res = self.sql.get_contact_residue(cutoff=cutoff) - res = res[0]+res[1] + res = res[0] + res[1] for r in res: # define the selection string and the bsa for the complex - select_str = ('res, (resi %d) and (chain %s)' %(r[1],r[0]),) - asa_complex = freesasa.selectArea(select_str,self.complex,self.result_complex)['res'] + select_str = ('res, (resi %d) and (chain %s)' % (r[1], r[0]),) + asa_complex = freesasa.selectArea( + select_str, self.complex, self.result_complex)['res'] # define the selection string and the bsa for the isolated - select_str = ('res, resi %d' %r[1],) - asa_unbound = freesasa.selectArea(select_str,self.chains[r[0]],self.result_chains[r[0]])['res'] + select_str = ('res, resi %d' % r[1],) + asa_unbound = freesasa.selectArea( + select_str, self.chains[r[0]], self.result_chains[r[0]])['res'] # define the bsa - bsa = asa_unbound-asa_complex + bsa = asa_unbound - asa_complex # define the xyz key : (chain,x,y,z) - chain = {'A':0,'B':1}[r[0]] + chain = {'A': 0, 'B': 1}[r[0]] atcenter = 'CB' if r[2] == 'GLY': atcenter = 'CA' - xyz = self.sql.get('x,y,z',resSeq=r[1],chainID=r[0],name=atcenter)[0] + xyz = self.sql.get( + 'x,y,z', + resSeq=r[1], + chainID=r[0], + name=atcenter)[0] #xyz = np.mean(self.sql.get('x,y,z',resSeq=r[1],chainID=r[0]),0) - xyzkey = tuple([chain]+xyz) + xyzkey = tuple([chain] + xyz) # put the data in dict - self.bsa_data[r] = [bsa] - self.bsa_data_xyz[xyzkey] = [bsa] + self.bsa_data[r] = [bsa] + self.bsa_data_xyz[xyzkey] = [bsa] # pyt the data in dict self.feature_data['bsa'] = self.bsa_data self.feature_data_xyz['bsa'] = 
self.bsa_data_xyz -##################################################################################### +########################################################################## # # THE MAIN FUNCTION CALLED IN THE INTERNAL FEATURE CALCULATOR # -##################################################################################### +########################################################################## -def __compute_feature__(pdb_data,featgrp,featgrp_raw): + +def __compute_feature__(pdb_data, featgrp, featgrp_raw): # create the BSA instance bsa = BSA(pdb_data) @@ -149,16 +160,15 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): bsa.sql.close() -##################################################################################### +########################################################################## # # TEST THE CLASS # -##################################################################################### +########################################################################## if __name__ == '__main__': bsa = BSA('1AK4.pdb') bsa.get_structure() - #bsa.get_contact_residue_sasa() + # bsa.get_contact_residue_sasa() bsa.sql.close() - diff --git a/deeprank/features/FeatureClass.py b/deeprank/features/FeatureClass.py index 94070a3e..8144a360 100644 --- a/deeprank/features/FeatureClass.py +++ b/deeprank/features/FeatureClass.py @@ -1,10 +1,11 @@ import os + import numpy as np -class FeatureClass(object): - def __init__(self,feature_type): +class FeatureClass(object): + def __init__(self, feature_type): ''' Master class fron which all the other Feature classes should be derived.""" Each subclass must compute : @@ -28,31 +29,31 @@ def __init__(self,feature_type): self.export_directories = {} self.error = False - def export_data_hdf5(self,featgrp): + def export_data_hdf5(self, featgrp): """Export the data in human readable format in an HDF5 file group. - For **atomic features**, the format of the data must be : chainID resSeq resNum name [values] - For **residue features**, the format must be : chainID resSeq resNum [values] - """ # loop through the datadict and name - for name,data in self.feature_data.items(): + for name, data in self.feature_data.items(): ds = [] - for key,value in data.items(): + for key, value in data.items(): # residue based feature if len(key) == 3: # tags - feat = '{:>4}{:>10}{:>10}'.format(key[0],key[1],key[2]) + feat = '{:>4}{:>10}{:>10}'.format(key[0], key[1], key[2]) # atomic based features elif len(key) == 4: # tags - feat = '{:>4}{:>10}{:>10}{:>10}'.format(key[0],key[1],key[2],key[3]) + feat = '{:>4}{:>10}{:>10}{:>10}'.format( + key[0], key[1], key[2], key[3]) # values for v in value: @@ -62,20 +63,18 @@ def export_data_hdf5(self,featgrp): ds.append(feat) # put in the hdf5 file - if len(ds) ==0 : + if len(ds) == 0: self.error = True return - ds = np.array(ds).astype('|S'+str(len(ds[0]))) + ds = np.array(ds).astype('|S' + str(len(ds[0]))) # create the dataset - if name+'_raw' in featgrp: - old_data = featgrp[name+'_raw'] + if name + '_raw' in featgrp: + old_data = featgrp[name + '_raw'] old_data[...] = ds else: - featgrp.create_dataset(name+'_raw',data=ds) - - + featgrp.create_dataset(name + '_raw', data=ds) ######################################## # @@ -89,22 +88,22 @@ def export_data_hdf5(self,featgrp): # CON : only usefull for deeprank # ######################################## - def export_dataxyz_hdf5(self,featgrp): + + def export_dataxyz_hdf5(self, featgrp): """Export the data in xyz-val format in an HDF5 file group. 
For **atomic** and **residue** the format of the data must be : x y z [values] """ - # loop through the datadict and name - for name,data in self.feature_data_xyz.items(): + for name, data in self.feature_data_xyz.items(): # create the data set - ds = np.array([list(key)+value for key,value in data.items()]) + ds = np.array([list(key) + value for key, value in data.items()]) # create the dataset if name in featgrp: old = featgrp[name] old[...] = ds else: - featgrp.create_dataset(name,data=ds) + featgrp.create_dataset(name, data=ds) diff --git a/deeprank/features/FullPSSM.py b/deeprank/features/FullPSSM.py index 03775076..6caee431 100644 --- a/deeprank/features/FullPSSM.py +++ b/deeprank/features/FullPSSM.py @@ -1,29 +1,28 @@ import os -import numpy as np - +import sys from time import time -from deeprank.tools import pdb2sql -from deeprank.tools import SASA +import numpy as np + from deeprank.features import FeatureClass from deeprank.generate import settings -import sys +from deeprank.tools import SASA, pdb2sql + -printif = lambda string,cond: print(string) if cond else None +def printif(string, cond): return print(string) if cond else None -##################################################################################### +########################################################################## # # Definition of the class # -##################################################################################### +########################################################################## class FullPSSM(FeatureClass): - def __init__(self,mol_name=None,pdbfile=None,pssm_path=None, - debug=True,pssm_format='new'): - - '''Compute all the PSSM data. + def __init__(self, mol_name=None, pdbfile=None, pssm_path=None, + debug=True, pssm_format='new'): + """Compute all the PSSM data. 
Simply extracts all the PSSM information and store that into features @@ -41,7 +40,7 @@ def __init__(self,mol_name=None,pdbfile=None,pssm_path=None, >>> pssm.read_PSSM_data() >>> pssm.get_feature_value() >>> print(pssm.feature_data_xyz) - ''' + """ super().__init__("Residue") @@ -52,11 +51,30 @@ def __init__(self,mol_name=None,pdbfile=None,pssm_path=None, self.debug = debug self.pssm_format = pssm_format - if isinstance(pdbfile,str) and mol_name is None: + if isinstance(pdbfile, str) and mol_name is None: self.mol_name = os.path.splitext(pdbfile)[0] - res_names = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','LLE', - 'LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL'] + res_names = [ + 'ALA', + 'ARG', + 'ASN', + 'ASP', + 'CYS', + 'GLN', + 'GLU', + 'GLY', + 'HIS', + 'LLE', + 'LEU', + 'LYS', + 'MET', + 'PHE', + 'PRO', + 'SER', + 'THR', + 'TRP', + 'TYR', + 'VAL'] self.pssm_val_name = ['PSSM_' + n for n in res_names] for name in self.pssm_val_name: @@ -72,60 +90,86 @@ def read_PSSM_data(self): """Read the PSSM data into a dictionary.""" names = os.listdir(self.pssm_path) - fname = [n for n in names if n.find(self.molname)==0] - - + fname = [n for n in names if n.find(self.molname) == 0] # old format with one file for all data # and only pssm data if self.pssm_format == 'old': - if len(fname)>1: - raise ValueError('Multiple PSSM files found for %s in %s',self.mol_name,self.pssm_path) - if len(fname)==0: - raise FileNotFoundError('No PSSM file found for %s in %s',self.mol_name,self.pssm_path) + if len(fname) > 1: + raise ValueError( + 'Multiple PSSM files found for %s in %s', + self.mol_name, + self.pssm_path) + if len(fname) == 0: + raise FileNotFoundError( + 'No PSSM file found for %s in %s', + self.mol_name, + self.pssm_path) else: fname = fname[0] - f = open(self.pssm_path + '/' + fname,'rb') + f = open(self.pssm_path + '/' + fname, 'rb') data = f.readlines() f.close() - raw_data = list( map(lambda x: x.decode('utf-8').split(),data)) + raw_data = list(map(lambda x: x.decode('utf-8').split(), data)) - self.pssm_res_data = np.array(raw_data)[:,:3] - self.pssm_res_data = [ (r[0],int(r[1]),r[2]) for r in self.pssm_res_data ] - self.pssm_data = np.array(raw_data)[:,3:].astype(np.float) + self.pssm_res_data = np.array(raw_data)[:, :3] + self.pssm_res_data = [(r[0], int(r[1]), r[2]) + for r in self.pssm_res_data] + self.pssm_data = np.array(raw_data)[:, 3:].astype(np.float) # new format with 2 files (each chain has one file) # and aligned mapping and IC (i.e. 
the iScore format) elif self.pssm_format == 'new': - if len(fname)<2: - raise FileNotFoundError('Only one PSSM file found for %s in %s',self.mol_name,self.pssm_path) + if len(fname) < 2: + raise FileNotFoundError( + 'Only one PSSM file found for %s in %s', + self.mol_name, + self.pssm_path) # get chain name fname.sort() chain_names = [n.split('.')[1] for n in fname] resmap = { - 'A' : 'ALA', 'R' : 'ARG', 'N' : 'ASN', 'D' : 'ASP', 'C' : 'CYS', 'E' : 'GLU', 'Q' : 'GLN', - 'G' : 'GLY', 'H' : 'HIS', 'I' : 'ILE', 'L' : 'LEU', 'K' : 'LYS', 'M' : 'MET', 'F' : 'PHE', - 'P' : 'PRO', 'S' : 'SER', 'T' : 'THR', 'W' : 'TRP', 'Y' : 'TYR', 'V' : 'VAL', - 'B' : 'ASX', 'U' : 'SEC', 'Z' : 'GLX' - } + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'E': 'GLU', + 'Q': 'GLN', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', + 'B': 'ASX', + 'U': 'SEC', + 'Z': 'GLX'} iiter = 0 - for chainID, fn in zip(chain_names,fname): + for chainID, fn in zip(chain_names, fname): - f = open(self.pssm_path + '/' + fn,'rb') + f = open(self.pssm_path + '/' + fn, 'rb') data = f.readlines() f.close() - raw_data = list( map(lambda x: x.decode('utf-8').split(),data)) + raw_data = list(map(lambda x: x.decode('utf-8').split(), data)) - rd = np.array(raw_data)[1:,:2] - rd = [ (chainID,int(r[0]),resmap[r[1]]) for r in rd ] - pd = np.array(raw_data)[1:,4:-1].astype(np.float) + rd = np.array(raw_data)[1:, :2] + rd = [(chainID, int(r[0]), resmap[r[1]]) for r in rd] + pd = np.array(raw_data)[1:, 4:-1].astype(np.float) ''' rd: residue data @@ -136,8 +180,6 @@ def read_PSSM_data(self): [-2. -1. 3. ... -6. -4. -2.]] ''' - - if iiter == 0: self.pssm_res_data = rd self.pssm_data = pd @@ -145,24 +187,23 @@ def read_PSSM_data(self): else: self.pssm_res_data += rd - self.pssm_data = np.vstack((self.pssm_data,pd)) + self.pssm_data = np.vstack((self.pssm_data, pd)) self.pssm = dict(zip(self.pssm_res_data, self.pssm_data)) - def get_feature_value(self): """get the feature value.""" sql = pdb2sql(self.pdbfile) - xyz_info = sql.get('chainID,resSeq,resName',name='CB') - xyz_info += sql.get('chainID,resSeq,resName',name='CA',resName='GLY') + xyz_info = sql.get('chainID,resSeq,resName', name='CB') + xyz_info += sql.get('chainID,resSeq,resName', name='CA', resName='GLY') - xyz = sql.get('x,y,z',name='CB') - xyz += sql.get('x,y,z',name='CA',resName='GLY') + xyz = sql.get('x,y,z', name='CB') + xyz += sql.get('x,y,z', name='CA', resName='GLY') xyz_dict = {} - for pos,info in zip(xyz,xyz_info): + for pos, info in zip(xyz, xyz_info): xyz_dict[tuple(info)] = pos contact_residue = sql.get_contact_residue(cutoff=5.5) @@ -172,11 +213,11 @@ def get_feature_value(self): pssm_data_xyz = {} pssm_data = {} - if len(contact_residue) ==0: + if len(contact_residue) == 0: # if we have no contact atoms print("WARNING: contact residues NOT found.") - pssm_data_xyz[tuple([0,0.,0.,0.])] = [0.0] - pssm_data_xyz[tuple([1,0.,0.,0.])] = [0.0] + pssm_data_xyz[tuple([0, 0., 0., 0.])] = [0.0] + pssm_data_xyz[tuple([1, 0., 0., 0.])] = [0.0] else: @@ -189,11 +230,11 @@ def get_feature_value(self): if tuple(res) in xyz_dict: #res: ('A', 13, 'ASP') - chain = {'A':0,'B':1}[res[0]] + chain = {'A': 0, 'B': 1}[res[0]] key = tuple([chain] + xyz_dict[tuple(res)]) #key: (0, -19.346, 6.156, -3.44) - for name,value in zip(self.pssm_val_name,self.pssm[res]): + for name, value in zip(self.pssm_val_name, self.pssm[res]): # res: ('B', 573, 'HIS') # 
name: PSSM_ALA # value:-3.0 @@ -202,15 +243,16 @@ def get_feature_value(self): self.feature_data_xyz[name][key] = [value] else: - printif([tuple(res), ' not found in the pdbfile'],self.debug) + printif([tuple(res), ' not found in the pdbfile'], self.debug) -##################################################################################### +########################################################################## # # THE MAIN FUNCTION CALLED IN THE INTERNAL FEATURE CALCULATOR # -##################################################################################### +########################################################################## + -def __compute_feature__(pdb_data,featgrp,featgrp_raw): +def __compute_feature__(pdb_data, featgrp, featgrp_raw): if settings.__PATH_PSSM_SOURCE__ is None: path = os.path.dirname(os.path.realpath(__file__)) @@ -221,7 +263,7 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): mol_name = os.path.split(featgrp.name)[0] mol_name = mol_name.lstrip('/') - pssm = FullPSSM(mol_name,pdb_data,PSSM) + pssm = FullPSSM(mol_name, pdb_data, PSSM) # read the raw data pssm.read_PSSM_data() @@ -234,22 +276,18 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): pssm.export_data_hdf5(featgrp_raw) - - -##################################################################################### +########################################################################## # # IF WE JUST TEST THE CLASS # -##################################################################################### - +########################################################################## if __name__ == '__main__': t0 = time() path = '/home/nico/Documents/projects/deeprank/data/HADDOCK/BM4_dimers/PSSM_newformat/' - pssm = FullPSSM(mol_name = '1AK4', pdbfile='1AK4_100w.pdb',pssm_path=path) - + pssm = FullPSSM(mol_name='1AK4', pdbfile='1AK4_100w.pdb', pssm_path=path) # get the pssm smoothed sum score pssm.read_PSSM_data() pssm.get_feature_value() - print(' Time %f ms' %((time()-t0)*1000)) + print(' Time %f ms' % ((time() - t0) * 1000)) diff --git a/deeprank/features/NaivePSSM.py b/deeprank/features/NaivePSSM.py index 04e8c2fe..0792d22f 100644 --- a/deeprank/features/NaivePSSM.py +++ b/deeprank/features/NaivePSSM.py @@ -1,26 +1,32 @@ import os -import numpy as np - from time import time -from deeprank.tools import pdb2sql -from deeprank.tools import SASA +import numpy as np + from deeprank.features import FeatureClass +from deeprank.tools import SASA, pdb2sql + -printif = lambda string,cond: print(string) if cond else None +def printif(string, cond): return print(string) if cond else None -##################################################################################### +########################################################################## # # Definition of the class # -##################################################################################### +########################################################################## class NaivePSSM(FeatureClass): - def __init__(self,mol_name=None,pdbfile=None,pssm_path=None,nmask=17,nsmooth=3,debug=False): - - '''Compute compressed PSSM data. + def __init__( + self, + mol_name=None, + pdbfile=None, + pssm_path=None, + nmask=17, + nsmooth=3, + debug=False): + """Compute compressed PSSM data. 
The method is adapted from: Simplified Sequence-based method for ATP-binding prediction using contextual local evolutionary conservation @@ -46,7 +52,7 @@ def __init__(self,mol_name=None,pdbfile=None,pssm_path=None,nmask=17,nsmooth=3,d >>> pssm.process_pssm_data() >>> pssm.get_feature_value() >>> print(pssm.feature_data_xyz) - ''' + """ super().__init__("Residue") print("== Warning : Please don't use NaivePSSM as a feature it's very experimental") @@ -59,7 +65,7 @@ def __init__(self,mol_name=None,pdbfile=None,pssm_path=None,nmask=17,nsmooth=3,d self.nsmooth = nsmooth self.debug = debug - if isinstance(pdbfile,str) and mol_name is None: + if isinstance(pdbfile, str) and mol_name is None: self.mol_name = os.path.splitext(pdbfile)[0] def get_sasa(self): @@ -76,73 +82,78 @@ def read_PSSM_data(self): """Read the PSSM data.""" names = os.listdir(self.pssm_path) - fname = [n for n in names if n.find(self.molname)==0] - - if len(fname)>1: - raise ValueError('Multiple PSSM files found for %s in %s',self.mol_name,self.pssm_path) - if len(fname)==0: - raise FileNotFoundError('No PSSM file found for %s in %s',self.mol_name,self.pssm_path) + fname = [n for n in names if n.find(self.molname) == 0] + + if len(fname) > 1: + raise ValueError( + 'Multiple PSSM files found for %s in %s', + self.mol_name, + self.pssm_path) + if len(fname) == 0: + raise FileNotFoundError( + 'No PSSM file found for %s in %s', + self.mol_name, + self.pssm_path) else: fname = fname[0] - f = open(self.pssm_path + '/' + fname,'rb') + f = open(self.pssm_path + '/' + fname, 'rb') data = f.readlines() f.close() - raw_data = list( map(lambda x: x.decode('utf-8').split(),data)) + raw_data = list(map(lambda x: x.decode('utf-8').split(), data)) - self.res_data = np.array(raw_data)[:,:3] - self.res_data = [ (r[0],int(r[1]),r[2]) for r in self.res_data ] - self.pssm_data = np.array(raw_data)[:,3:].astype(np.float) + self.res_data = np.array(raw_data)[:, :3] + self.res_data = [(r[0], int(r[1]), r[2]) for r in self.res_data] + self.pssm_data = np.array(raw_data)[:, 3:].astype(np.float) def process_pssm_data(self): """Process the PSSM data.""" - self.pssm_data = self._mask_pssm(self.pssm_data,nmask=self.nmask) + self.pssm_data = self._mask_pssm(self.pssm_data, nmask=self.nmask) self.pssm_data = self._filter_pssm(self.pssm_data) - self.pssm_data = self._smooth_pssm(self.pssm_data,msmooth=self.nsmooth) - self.pssm_data = np.mean(self.pssm_data,1) - + self.pssm_data = self._smooth_pssm( + self.pssm_data, msmooth=self.nsmooth) + self.pssm_data = np.mean(self.pssm_data, 1) @staticmethod - def _mask_pssm(pssm_data,nmask=17): + def _mask_pssm(pssm_data, nmask=17): nres = len(pssm_data) masked_pssm = np.copy(pssm_data) for idata in range(nres): - istart = np.max([idata-nmask,0]) - iend = np.min([idata+nmask+1,nres]) - N = 1./(2*(iend-1-istart)) - masked_pssm[idata,:] -= N*np.sum( pssm_data[istart:iend,:],0 ) + istart = np.max([idata - nmask, 0]) + iend = np.min([idata + nmask + 1, nres]) + N = 1. 
/ (2 * (iend - 1 - istart)) + masked_pssm[idata, :] -= N * np.sum(pssm_data[istart:iend, :], 0) return masked_pssm @staticmethod def _filter_pssm(pssm_data): - pssm_data[pssm_data<=0] = 0 + pssm_data[pssm_data <= 0] = 0 return pssm_data @staticmethod - def _smooth_pssm(pssm_data,msmooth=3): + def _smooth_pssm(pssm_data, msmooth=3): nres = len(pssm_data) smoothed_pssm = np.copy(pssm_data) for idata in range(nres): - istart = np.max([idata-msmooth,0]) - iend = np.min([idata+msmooth+1,nres]) - N = 1./(2*(iend-1-istart)) - smoothed_pssm[idata,:] = N*np.sum( pssm_data[istart:iend,:],0 ) + istart = np.max([idata - msmooth, 0]) + iend = np.min([idata + msmooth + 1, nres]) + N = 1. / (2 * (iend - 1 - istart)) + smoothed_pssm[idata, :] = N * np.sum(pssm_data[istart:iend, :], 0) return smoothed_pssm - - def get_feature_value(self,contact_only=True): + def get_feature_value(self, contact_only=True): """get the feature value.""" sql = pdb2sql(self.pdbfile) - xyz_info = sql.get('chainID,resSeq,resName',name='CB') - xyz = sql.get('x,y,z',name='CB') + xyz_info = sql.get('chainID,resSeq,resName', name='CB') + xyz = sql.get('x,y,z', name='CB') xyz_dict = {} - for pos,info in zip(xyz,xyz_info): + for pos, info in zip(xyz, xyz_info): xyz_dict[tuple(info)] = pos contact_residue = sql.get_contact_residue(cutoff=5.5) @@ -152,37 +163,37 @@ def get_feature_value(self,contact_only=True): pssm_data_xyz = {} pssm_data = {} - for res,data in zip(self.res_data,self.pssm_data): + for res, data in zip(self.res_data, self.pssm_data): if contact_only and res not in contact_residue: continue if tuple(res) in xyz_dict: - chain = {'A':0,'B':1}[res[0]] + chain = {'A': 0, 'B': 1}[res[0]] key = tuple([chain] + xyz_dict[tuple(res)]) sasa = self.sasa[tuple(res)] - pssm_data[res] = [data*sasa] - pssm_data_xyz[key] = [data*sasa] + pssm_data[res] = [data * sasa] + pssm_data_xyz[key] = [data * sasa] else: - printif([tuple(res), ' not found in the pdbfile'],self.debug) + printif([tuple(res), ' not found in the pdbfile'], self.debug) # if we have no contact atoms if len(pssm_data_xyz) == 0: - pssm_data_xyz[tuple([0,0.,0.,0.])] = [0.0] - pssm_data_xyz[tuple([1,0.,0.,0.])] = [0.0] + pssm_data_xyz[tuple([0, 0., 0., 0.])] = [0.0] + pssm_data_xyz[tuple([1, 0., 0., 0.])] = [0.0] - self.feature_data['pssm'] = pssm_data + self.feature_data['pssm'] = pssm_data self.feature_data_xyz['pssm'] = pssm_data_xyz -##################################################################################### +########################################################################## # # THE MAIN FUNCTION CALLED IN THE INTERNAL FEATURE CALCULATOR # -##################################################################################### +########################################################################## -def __compute_feature__(pdb_data,featgrp,featgrp_raw): +def __compute_feature__(pdb_data, featgrp, featgrp_raw): if '__PATH_PSSM_SOURCE__' not in globals(): path = os.path.dirname(os.path.realpath(__file__)) @@ -193,7 +204,7 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): mol_name = os.path.split(featgrp.name)[0] mol_name = mol_name.lstrip('/') - pssm = NaivePSSM(mol_name,pdb_data,PSSM) + pssm = NaivePSSM(mol_name, pdb_data, PSSM) # get the sasa info pssm.get_sasa() @@ -212,19 +223,16 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): pssm.export_data_hdf5(featgrp_raw) - - -##################################################################################### +########################################################################## # # IF 
WE JUST TEST THE CLASS # -##################################################################################### - +########################################################################## if __name__ == '__main__': t0 = time() path = '/home/nico/Documents/projects/deeprank/data/HADDOCK/BM4_dimers/PSSM_newformat/' - pssm = NaivePSSM(mol_name = '2ABZ', pdbfile='2ABZ_1w.pdb',pssm_path=path) + pssm = NaivePSSM(mol_name='2ABZ', pdbfile='2ABZ_1w.pdb', pssm_path=path) # get the surface accessible solvent area pssm.get_sasa() @@ -234,4 +242,4 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): pssm.process_pssm_data() pssm.get_feature_value() print(pssm.feature_data_xyz) - print(' Time %f ms' %((time()-t0)*1000)) + print(' Time %f ms' % ((time() - t0) * 1000)) diff --git a/deeprank/features/PSSM/reformat_pssm.py b/deeprank/features/PSSM/reformat_pssm.py index 4ca88f36..03bbe7d3 100644 --- a/deeprank/features/PSSM/reformat_pssm.py +++ b/deeprank/features/PSSM/reformat_pssm.py @@ -1,98 +1,89 @@ -import numpy as np import sys +import numpy as np +def write_newfile(names_oldfile, name_newfile): -def write_newfile(names_oldfile,name_newfile): - - chainID = {0:'A',1:'B'} - resconv = { - 'A' : 'ALA', - 'R' : 'ARG', - 'N' : 'ASN', - 'D' : 'ASP', - 'C' : 'CYS', - 'E' : 'GLU', - 'Q' : 'GLN', - 'G' : 'GLY', - 'H' : 'HIS', - 'I' : 'ILE', - 'L' : 'LEU', - 'K' : 'LYS', - 'M' : 'MET', - 'F' : 'PHE', - 'P' : 'PRO', - 'S' : 'SER', - 'T' : 'THR', - 'W' : 'TRP', - 'Y' : 'TYR', - 'V' : 'VAL' - } - - # write the new file - new_file = open(name_newfile,'w') - + chainID = {0: 'A', 1: 'B'} + resconv = { + 'A': 'ALA', + 'R': 'ARG', + 'N' : 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'E': 'GLU', + 'Q': 'GLN', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL' + } - for ifile,f in enumerate(names_oldfile): + # write the new file + new_file = open(name_newfile, 'w') - # read the file - f = open(f,'r') - data = f.readlines()[4:-6] - f.close() + for ifile, f in enumerate(names_oldfile): - # write the new file - for l in data: - l = l.split() - if len(l)>0: + # read the file + f = open(f, 'r') + data = f.readlines()[4:-6] + f.close() - chain = chainID[ifile] - feat = '{:>4}'.format(chain) + # write the new file + for l in data: + l = l.split() + if len(l) > 0: - resNum = l[0] - feat += '{:>10}'.format(resNum) + chain = chainID[ifile] + feat = '{:>4}'.format(chain) - resName1 = l[2] - resName3 = resconv[resName1] - feat += '{:>10}'.format(resName3) + resNum = l[0] + feat += '{:>10}'.format(resNum) - feat += '\t' - values = map(int,l[3:23]) - feat += '\t'.join(map("{: 3d}".format,values)) + resName1 = l[2] + resName3 = resconv[resName1] + feat += '{:>10}'.format(resName3) - feat+= '\n' - new_file.write(feat) + feat += '\t' + values = map(int, l[3:23]) + feat += '\t'.join(map("{: 3d}".format, values)) - new_file.close() + feat += '\n' + new_file.write(feat) + new_file.close() oldfile_dir = '../PSSM/' #oldfiles = sp.check_output('ls %s/*PSSM' %(oldfile_dir),shell=True).decode('utf-8').split() -oldfiles = list(filter(lambda x: '.PSSM' in x,os.listdir(oldfile_dir))) +oldfiles = list(filter(lambda x: '.PSSM' in x, os.listdir(oldfile_dir))) oldfiles = [oldfile_dir + f for f in oldfiles] nfile = len(oldfiles) -oldfiles = np.array(oldfiles).reshape(int(nfile/2),2).tolist() - +oldfiles = np.array(oldfiles).reshape(int(nfile / 2), 2).tolist() for filenames in oldfiles: - print('process files\n\t%s\n\t%s' 
%(filenames[0],filenames[1])) - cplx_name = [] - cplx_name.append(filenames[0].split('/')[-1]) - cplx_name.append(filenames[1].split('/')[-1]) - cplx_name = list(set([cplx_name[0][:4],cplx_name[1][:4]])) - print(cplx_name) - if len(cplx_name)>1: - print('error' + cplx_name) - sys.exit() - - name_newfile = './'+cplx_name[0]+'.PSSM' - print('\nexport to \t%s\n' %(name_newfile)) - write_newfile(filenames,name_newfile) - - - - - + print('process files\n\t%s\n\t%s' % (filenames[0], filenames[1])) + cplx_name = [] + cplx_name.append(filenames[0].split('/')[-1]) + cplx_name.append(filenames[1].split('/')[-1]) + cplx_name = list(set([cplx_name[0][:4], cplx_name[1][:4]])) + print(cplx_name) + if len(cplx_name) > 1: + print('error' + cplx_name) + sys.exit() + + name_newfile = './' + cplx_name[0] + '.PSSM' + print('\nexport to \t%s\n' % (name_newfile)) + write_newfile(filenames, name_newfile) diff --git a/deeprank/features/PSSM_IC.py b/deeprank/features/PSSM_IC.py index d0cd3a95..32aa5b28 100644 --- a/deeprank/features/PSSM_IC.py +++ b/deeprank/features/PSSM_IC.py @@ -1,24 +1,31 @@ import os -import numpy as np - from time import time -from deeprank.tools import pdb2sql +import numpy as np + from deeprank.features import FeatureClass from deeprank.generate import settings +from deeprank.tools import pdb2sql + -printif = lambda string,cond: print(string) if cond else None +def printif(string, cond): return print(string) if cond else None -##################################################################################### +########################################################################## # # Definition of the class # -##################################################################################### +########################################################################## class PSSM_IC(FeatureClass): - def __init__(self,mol_name=None,pdbfile=None,pssmic_path=None,debug=False,pssm_format='new'): + def __init__( + self, + mol_name=None, + pdbfile=None, + pssmic_path=None, + debug=False, + pssm_format='new'): """Compute the information content of the PSSM. 
Args: @@ -51,56 +58,83 @@ def get_mol_name(mol_name): return mol_name.split('_')[0] def read_PSSMIC_data(self): - """ Read the PSSM data.""" + """Read the PSSM data.""" names = os.listdir(self.pssmic_path) - fname = [n for n in names if n.find(self.molname)==0] + fname = [n for n in names if n.find(self.molname) == 0] if self.pssm_format == 'old': - if len(fname)>1: - raise ValueError('Multiple PSSM files found for %s in %s',self.pdbname,self.pssmic_path) - if len(fname)==0: - raise FileNotFoundError('No PSSM file found for %s in %s',self.pdbname,self.pssmic_path) + if len(fname) > 1: + raise ValueError( + 'Multiple PSSM files found for %s in %s', + self.pdbname, + self.pssmic_path) + if len(fname) == 0: + raise FileNotFoundError( + 'No PSSM file found for %s in %s', + self.pdbname, + self.pssmic_path) else: fname = fname[0] - f = open(self.pssmic_path + '/' + fname,'rb') + f = open(self.pssmic_path + '/' + fname, 'rb') data = f.readlines() f.close() - raw_data = list( map(lambda x: x.decode('utf-8').split(),data)) + raw_data = list(map(lambda x: x.decode('utf-8').split(), data)) - self.res_data = np.array(raw_data)[:,:3] - self.res_data = [ (r[0],int(r[1]),r[2]) for r in self.res_data ] - self.pssmic_data = np.array(raw_data)[:,-1].astype(np.float) + self.res_data = np.array(raw_data)[:, :3] + self.res_data = [(r[0], int(r[1]), r[2]) for r in self.res_data] + self.pssmic_data = np.array(raw_data)[:, -1].astype(np.float) elif self.pssm_format == 'new': - if len(fname)<2: - raise FileNotFoundError('Only one PSSM file found for %s in %s',self.mol_name,self.pssmic_path) + if len(fname) < 2: + raise FileNotFoundError( + 'Only one PSSM file found for %s in %s', + self.mol_name, + self.pssmic_path) # get chain name fname.sort() chain_names = [n.split('.')[1] for n in fname] resmap = { - 'A' : 'ALA', 'R' : 'ARG', 'N' : 'ASN', 'D' : 'ASP', 'C' : 'CYS', 'E' : 'GLU', 'Q' : 'GLN', - 'G' : 'GLY', 'H' : 'HIS', 'I' : 'ILE', 'L' : 'LEU', 'K' : 'LYS', 'M' : 'MET', 'F' : 'PHE', - 'P' : 'PRO', 'S' : 'SER', 'T' : 'THR', 'W' : 'TRP', 'Y' : 'TYR', 'V' : 'VAL', - 'B' : 'ASX', 'U' : 'SEC', 'Z' : 'GLX' - } + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'E': 'GLU', + 'Q': 'GLN', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', + 'B': 'ASX', + 'U': 'SEC', + 'Z': 'GLX'} iiter = 0 - for chainID, fn in zip(chain_names,fname): + for chainID, fn in zip(chain_names, fname): - f = open(self.pssmic_path + '/' + fn,'rb') + f = open(self.pssmic_path + '/' + fn, 'rb') data = f.readlines() f.close() - raw_data = list( map(lambda x: x.decode('utf-8').split(),data)) + raw_data = list(map(lambda x: x.decode('utf-8').split(), data)) - rd = np.array(raw_data)[1:,:2] - rd = [ (chainID,int(r[0]),resmap[r[1]]) for r in rd ] - pd = np.array(raw_data)[1:,-1].astype(np.float) + rd = np.array(raw_data)[1:, :2] + rd = [(chainID, int(r[0]), resmap[r[1]]) for r in rd] + pd = np.array(raw_data)[1:, -1].astype(np.float) if iiter == 0: self.res_data = rd @@ -109,21 +143,20 @@ def read_PSSMIC_data(self): else: self.res_data += rd - self.pssmic_data = np.hstack((self.pssmic_data,pd)) - + self.pssmic_data = np.hstack((self.pssmic_data, pd)) - def get_feature_value(self,contact_only=True): + def get_feature_value(self, contact_only=True): """Compute the feature value.""" sql = pdb2sql(self.pdbfile) - xyz_info = sql.get('chainID,resSeq,resName',name='CB') - xyz_info += 
sql.get('chainID,resSeq,resName',name='CA',resName='GLY') + xyz_info = sql.get('chainID,resSeq,resName', name='CB') + xyz_info += sql.get('chainID,resSeq,resName', name='CA', resName='GLY') - xyz = sql.get('x,y,z',name='CB') - xyz += sql.get('x,y,z',name='CA',resName='GLY') + xyz = sql.get('x,y,z', name='CB') + xyz += sql.get('x,y,z', name='CA', resName='GLY') xyz_dict = {} - for pos,info in zip(xyz,xyz_info): + for pos, info in zip(xyz, xyz_info): xyz_dict[tuple(info)] = pos contact_residue = sql.get_contact_residue(cutoff=5.5) @@ -132,37 +165,35 @@ def get_feature_value(self,contact_only=True): pssm_data_xyz = {} pssm_data = {} - for res,data in zip(self.res_data,self.pssmic_data): + for res, data in zip(self.res_data, self.pssmic_data): if contact_only and res not in contact_residue: continue if tuple(res) in xyz_dict: - chain = {'A':0,'B':1}[res[0]] + chain = {'A': 0, 'B': 1}[res[0]] key = tuple([chain] + xyz_dict[tuple(res)]) pssm_data[res] = [data] pssm_data_xyz[key] = [data] else: - printif([tuple(res), ' not found in the pdbfile'],self.debug) + printif([tuple(res), ' not found in the pdbfile'], self.debug) # if we have no contact atoms if len(pssm_data_xyz) == 0: - pssm_data_xyz[tuple([0,0.,0.,0.])] = [0.0] - pssm_data_xyz[tuple([1,0.,0.,0.])] = [0.0] + pssm_data_xyz[tuple([0, 0., 0., 0.])] = [0.0] + pssm_data_xyz[tuple([1, 0., 0., 0.])] = [0.0] self.feature_data['pssm_ic'] = pssm_data self.feature_data_xyz['pssm_ic'] = pssm_data_xyz - -##################################################################################### +########################################################################## # # THE MAIN FUNCTION CALLED IN THE INTERNAL FEATURE CALCULATOR # -##################################################################################### - -def __compute_feature__(pdb_data,featgrp,featgrp_raw): +########################################################################## +def __compute_feature__(pdb_data, featgrp, featgrp_raw): if settings.__PATH_PSSM_SOURCE__ is None: path = os.path.dirname(os.path.realpath(__file__)) @@ -173,7 +204,7 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): mol_name = os.path.split(featgrp.name)[0] mol_name = mol_name.lstrip('/') - pssmic = PSSM_IC(mol_name,pdb_data,path) + pssmic = PSSM_IC(mol_name, pdb_data, path) # read the raw data pssmic.read_PSSMIC_data() @@ -186,21 +217,21 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): pssmic.export_data_hdf5(featgrp_raw) -##################################################################################### +########################################################################## # # IF WE JUST TEST THE CLASS # -##################################################################################### +########################################################################## if __name__ == '__main__': t0 = time() path = '/home/nico/Documents/projects/deeprank/data/HADDOCK/BM4_dimers/PSSM_IC/' - pssmic = PSSM_IC(mol_name = '1AK4', pdbfile='1AK4.pdb',pssmic_path=path) + pssmic = PSSM_IC(mol_name='1AK4', pdbfile='1AK4.pdb', pssmic_path=path) # get the pssm smoothed sum score pssmic.read_PSSMIC_data() pssmic.get_feature_value() print(pssmic.feature_data_xyz) - print(' Time %f ms' %((time()-t0)*1000)) \ No newline at end of file + print(' Time %f ms' % ((time() - t0) * 1000)) diff --git a/deeprank/features/PSSM_IC/extract_ic.py b/deeprank/features/PSSM_IC/extract_ic.py index 46558b87..f8514bf4 100644 --- a/deeprank/features/PSSM_IC/extract_ic.py +++ b/deeprank/features/PSSM_IC/extract_ic.py @@ 
-1,100 +1,92 @@ -import numpy as np import sys +import numpy as np +def write_newfile(names_oldfile, name_newfile): -def write_newfile(names_oldfile,name_newfile): - - chainID = {0:'A',1:'B'} - resconv = { - 'A' : 'ALA', - 'R' : 'ARG', - 'N' : 'ASN', - 'D' : 'ASP', - 'C' : 'CYS', - 'E' : 'GLU', - 'Q' : 'GLN', - 'G' : 'GLY', - 'H' : 'HIS', - 'I' : 'ILE', - 'L' : 'LEU', - 'K' : 'LYS', - 'M' : 'MET', - 'F' : 'PHE', - 'P' : 'PRO', - 'S' : 'SER', - 'T' : 'THR', - 'W' : 'TRP', - 'Y' : 'TYR', - 'V' : 'VAL' - } - - # write the new file - new_file = open(name_newfile,'w') - + chainID = {0: 'A', 1: 'B'} + resconv = { + 'A': 'ALA', + 'R': 'ARG', + 'N' : 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'E': 'GLU', + 'Q': 'GLN', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL' + } - for ifile,f in enumerate(names_oldfile): + # write the new file + new_file = open(name_newfile, 'w') - # read the file - f = open(f,'r') - data = f.readlines()[4:-6] - f.close() + for ifile, f in enumerate(names_oldfile): - # write the new file - for l in data: - l = l.split() - if len(l)>0: + # read the file + f = open(f, 'r') + data = f.readlines()[4:-6] + f.close() - chain = chainID[ifile] - feat = '{:>4}'.format(chain) + # write the new file + for l in data: + l = l.split() + if len(l) > 0: - resNum = l[0] - feat += '{:>10}'.format(resNum) + chain = chainID[ifile] + feat = '{:>4}'.format(chain) - resName1 = l[2] - resName3 = resconv[resName1] - feat += '{:>10}'.format(resName3) + resNum = l[0] + feat += '{:>10}'.format(resNum) - feat += '\t' - values = float(l[-2]) - feat += '\t{:>10}'.format(values) + resName1 = l[2] + resName3 = resconv[resName1] + feat += '{:>10}'.format(resName3) - feat+= '\n' - new_file.write(feat) + feat += '\t' + values = float(l[-2]) + feat += '\t{:>10}'.format(values) - new_file.close() + feat += '\n' + new_file.write(feat) + new_file.close() oldfile_dir = '../PSSM/' -oldfiles = list(filter(lambda x: '.PSSM' in x,os.listdir(oldfile_dir))) +oldfiles = list(filter(lambda x: '.PSSM' in x, os.listdir(oldfile_dir))) oldfiles = [oldfile_dir + f for f in oldfiles] #oldfiles = sp.check_output('ls %s/*PSSM' %(oldfile_dir),shell=True).decode('utf-8').split() nfile = len(oldfiles) -oldfiles = np.array(oldfiles).reshape(int(nfile/2),2).tolist() +oldfiles = np.array(oldfiles).reshape(int(nfile / 2), 2).tolist() for filenames in oldfiles: - print('process files\n\t%s\n\t%s' %(filenames[0],filenames[1])) - cplx_name = [] - cplx_name.append(filenames[0].split('/')[-1]) - cplx_name.append(filenames[1].split('/')[-1]) - cplx_name = list(set([cplx_name[0][:4],cplx_name[1][:4]])) - print(cplx_name) - - if len(cplx_name)>1: - print('error' + cplx_name) - sys.exit() - - name_newfile = './'+cplx_name[0]+'.PSSM_IC' - print('\nexport to \t%s\n' %(name_newfile)) - write_newfile(filenames,name_newfile) - - - + print('process files\n\t%s\n\t%s' % (filenames[0], filenames[1])) + cplx_name = [] + cplx_name.append(filenames[0].split('/')[-1]) + cplx_name.append(filenames[1].split('/')[-1]) + cplx_name = list(set([cplx_name[0][:4], cplx_name[1][:4]])) + print(cplx_name) + if len(cplx_name) > 1: + print('error' + cplx_name) + sys.exit() + name_newfile = './' + cplx_name[0] + '.PSSM_IC' + print('\nexport to \t%s\n' % (name_newfile)) + write_newfile(filenames, name_newfile) diff --git a/deeprank/features/ResidueDensity.py b/deeprank/features/ResidueDensity.py index a37240de..e0b22e86 100644 --- 
a/deeprank/features/ResidueDensity.py +++ b/deeprank/features/ResidueDensity.py @@ -1,14 +1,15 @@ -import numpy as np import itertools -from deeprank.tools import pdb2sql -from deeprank.features import FeatureClass import sys +import numpy as np + +from deeprank.features import FeatureClass +from deeprank.tools import pdb2sql class ResidueDensity(FeatureClass): - def __init__(self,pdb_data,chainA='A',chainB='B'): + def __init__(self, pdb_data, chainA='A', chainB='B'): """Compute the residue densities between polar/apolar/charged residues. Args : @@ -27,42 +28,58 @@ def __init__(self,pdb_data,chainA='A',chainB='B'): """ self.pdb_data = pdb_data - self.sql=pdb2sql(pdb_data) - self.chains_label = [chainA,chainB] + self.sql = pdb2sql(pdb_data) + self.chains_label = [chainA, chainB] self.feature_data = {} self.feature_data_xyz = {} - self.residue_types = {'CYS':'polar','HIS':'polar','ASN':'polar','GLN':'polar','SER':'polar','THR':'polar','TYR':'polar','TRP':'polar', - 'ALA':'apolar','PHE':'apolar','GLY':'apolar','ILE':'apolar','VAL':'apolar','MET':'apolar','PRO':'apolar','LEU':'apolar', - 'GLU':'charged','ASP':'charged','LYS':'charged','ARG':'charged'} - self.error = False # When True, feature calculation failed - - - - def get(self,cutoff=5.5): + self.residue_types = { + 'CYS': 'polar', + 'HIS': 'polar', + 'ASN': 'polar', + 'GLN': 'polar', + 'SER': 'polar', + 'THR': 'polar', + 'TYR': 'polar', + 'TRP': 'polar', + 'ALA': 'apolar', + 'PHE': 'apolar', + 'GLY': 'apolar', + 'ILE': 'apolar', + 'VAL': 'apolar', + 'MET': 'apolar', + 'PRO': 'apolar', + 'LEU': 'apolar', + 'GLU': 'charged', + 'ASP': 'charged', + 'LYS': 'charged', + 'ARG': 'charged'} + self.error = False # When True, feature calculation failed + + def get(self, cutoff=5.5): """Get the densities.""" res = self.sql.get_contact_residue(chain1=self.chains_label[0], chain2=self.chains_label[1], - cutoff = cutoff, + cutoff=cutoff, return_contact_pairs=True) - #if len(res) < 5: - # the interface is too small + # if len(res) < 5: + # the interface is too small # self.error = True # return - self.residue_densities = {} - for key,other_res in res.items(): + for key, other_res in res.items(): # some residues are not amino acids if key[2] not in self.residue_types: continue if key not in self.residue_densities: - self.residue_densities[key] = residue_pair(key,self.residue_types[key[2]]) + self.residue_densities[key] = residue_pair( + key, self.residue_types[key[2]]) self.residue_densities[key].density['total'] += len(other_res) for key2 in other_res: @@ -72,14 +89,17 @@ def get(self,cutoff=5.5): continue self.residue_densities[key].density[self.residue_types[key2[2]]] += 1 - self.residue_densities[key].connections[self.residue_types[key2[2]]].append(key2) + self.residue_densities[key].connections[self.residue_types[key2[2]]].append( + key2) if key2 not in self.residue_densities: - self.residue_densities[key2] = residue_pair(key2,self.residue_types[key2[2]]) + self.residue_densities[key2] = residue_pair( + key2, self.residue_types[key2[2]]) self.residue_densities[key2].density['total'] += 1 self.residue_densities[key2].density[self.residue_types[key[2]]] += 1 - self.residue_densities[key2].connections[self.residue_types[key[2]]].append(key) + self.residue_densities[key2].connections[self.residue_types[key[2]]].append( + key) # uncomment for debug # def _print(self): @@ -92,13 +112,14 @@ def extract_features(self): self.feature_data['RCD_total'] = {} self.feature_data_xyz['RCD_total'] = {} - restype = ['polar','apolar','charged'] - pairtype = [ 
'-'.join(p) for p in list(itertools.combinations_with_replacement(restype,2))] + restype = ['polar', 'apolar', 'charged'] + pairtype = [ + '-'.join(p) for p in list(itertools.combinations_with_replacement(restype, 2))] for p in pairtype: - self.feature_data['RCD_'+p] = {} - self.feature_data_xyz['RCD_'+p] = {} + self.feature_data['RCD_' + p] = {} + self.feature_data_xyz['RCD_' + p] = {} - for key,res in self.residue_densities.items(): + for key, res in self.residue_densities.items(): # total density in raw format self.feature_data['RCD_total'][key] = [res.density['total']] @@ -109,11 +130,16 @@ def extract_features(self): atcenter = 'CA' # get the xyz of the center atom - xyz = self.sql.get('x,y,z',resSeq=key[1],chainID=key[0],name=atcenter)[0] + xyz = self.sql.get( + 'x,y,z', + resSeq=key[1], + chainID=key[0], + name=atcenter)[0] #xyz = np.mean(self.sql.get('x,y,z',resSeq=key[1],chainID=key[0]),0).tolist() - xyz_key = tuple([{'A':0,'B':1}[key[0]]] + xyz) - self.feature_data_xyz['RCD_total'][xyz_key] = [res.density['total']] + xyz_key = tuple([{'A': 0, 'B': 1}[key[0]]] + xyz) + self.feature_data_xyz['RCD_total'][xyz_key] = [ + res.density['total']] # iterate through all the connection for r in restype: @@ -123,16 +149,16 @@ def extract_features(self): self.feature_data[pairtype][key] = [res.density[r]] self.feature_data_xyz[pairtype][xyz_key] = [res.density[r]] + class residue_pair(object): - def __init__(self,res,rtype): + def __init__(self, res, rtype): """Ancillary class that holds information for a given residue.""" self.res = res self.type = rtype - self.density = {'total':0,'polar':0,'apolar':0,'charged':0} - self.connections = {'polar':[],'apolar':[],'charged':[]} - + self.density = {'total': 0, 'polar': 0, 'apolar': 0, 'charged': 0} + self.connections = {'polar': [], 'apolar': [], 'charged': []} # Uncomment for debug # def print(self): @@ -153,15 +179,13 @@ def __init__(self,res,rtype): # print('') - - -##################################################################################### +########################################################################## # # THE MAIN FUNCTION CALLED IN THE INTERNAL FEATURE CALCULATOR # -##################################################################################### +########################################################################## -def __compute_feature__(pdb_data,featgrp,featgrp_raw): +def __compute_feature__(pdb_data, featgrp, featgrp_raw): error_flag = False @@ -169,7 +193,7 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): resdens = ResidueDensity(pdb_data) # get the densities - resdens.get(cutoff=5.5) # may set resdens.error to True + resdens.get(cutoff=5.5) # may set resdens.error to True if not resdens.error: # extract the features @@ -177,10 +201,9 @@ def __compute_feature__(pdb_data,featgrp,featgrp_raw): # export in the hdf5 file resdens.export_dataxyz_hdf5(featgrp) - resdens.export_data_hdf5(featgrp_raw) # may set resdens.error to True + resdens.export_data_hdf5(featgrp_raw) # may set resdens.error to True - if resdens.error == True: + if resdens.error: error_flag = True print("WARNING: Failed to calculate ResidueDensity. 
This might be caused by a very small interface.") return error_flag - diff --git a/deeprank/features/__init__.py b/deeprank/features/__init__.py index 8c88afff..4e1300e0 100644 --- a/deeprank/features/__init__.py +++ b/deeprank/features/__init__.py @@ -1,7 +1,7 @@ from .FeatureClass import FeatureClass +from .BSA import BSA from .AtomicFeature import AtomicFeature +from .FullPSSM import FullPSSM from .NaivePSSM import NaivePSSM from .PSSM_IC import PSSM_IC -from .BSA import BSA from .ResidueDensity import ResidueDensity -from .FullPSSM import FullPSSM \ No newline at end of file diff --git a/deeprank/generate/DataGenerator.py b/deeprank/generate/DataGenerator.py index 8cb9f604..7a70832f 100644 --- a/deeprank/generate/DataGenerator.py +++ b/deeprank/generate/DataGenerator.py @@ -1,15 +1,17 @@ +import importlib +import logging import os +import re import sys -import importlib -import numpy as np -import h5py from collections import OrderedDict -import logging -from deeprank.tools import pdb2sql + +import h5py +import numpy as np + from deeprank.generate import GridTools as gt from deeprank.generate import settings -import re +from deeprank.tools import pdb2sql try: from tqdm import tqdm @@ -23,13 +25,25 @@ def tqdm(x): except ImportError: pass -_printif = lambda string,cond: print(string) if cond else None + +def _printif(string, cond): return print(string) if cond else None + class DataGenerator(object): - def __init__(self,pdb_select=None,pdb_source=None,pdb_native=None,pssm_source=None, - compute_targets = None, compute_features = None, - data_augmentation=None, hdf5='database.h5',logger=None,debug=True,mpi_comm=None): + def __init__( + self, + pdb_select=None, + pdb_source=None, + pdb_native=None, + pssm_source=None, + compute_targets=None, + compute_features=None, + data_augmentation=None, + hdf5='database.h5', + logger=None, + debug=True, + mpi_comm=None): """Generate the data (features/targets/maps) required for deeprank. 
Args: @@ -63,23 +77,21 @@ def __init__(self,pdb_select=None,pdb_source=None,pdb_native=None,pssm_source=No >>> 'deeprank.features.PSSM_IC', >>> 'deeprank.features.BSA'], >>> hdf5=h5file) - """ settings.init() - self.pdb_select = pdb_select or [] - self.pdb_source = pdb_source or [] - self.pdb_native = pdb_native or [] + self.pdb_select = pdb_select or [] + self.pdb_source = pdb_source or [] + self.pdb_native = pdb_native or [] settings.__PATH_PSSM_SOURCE__ = pssm_source - self.data_augmentation = data_augmentation self.hdf5 = hdf5 - self.compute_targets = compute_targets + self.compute_targets = compute_targets self.compute_features = compute_features self.all_pdb = [] @@ -93,31 +105,34 @@ def __init__(self,pdb_select=None,pdb_source=None,pdb_native=None,pssm_source=No self.debug = debug # handle the pdb_select - if not isinstance(self.pdb_select,list): + if not isinstance(self.pdb_select, list): self.pdb_select = [self.pdb_select] # check that a source was given if self.pdb_source is None: - raise NotADirectoryError('You must provide one or several source directory where the pdbs are stored') + raise NotADirectoryError( + 'You must provide one or several source directory where the pdbs are stored') # handle the sources - if not isinstance(self.pdb_source,list): + if not isinstance(self.pdb_source, list): self.pdb_source = [self.pdb_source] # get all the conformation path for src in self.pdb_source: if os.path.isdir(src): - self.all_pdb += [os.path.join(src,fname) for fname in os.listdir(src) if fname.endswith('.pdb')] + self.all_pdb += [os.path.join(src, fname) + for fname in os.listdir(src) if fname.endswith('.pdb')] elif os.path.isfile(src): self.all_pdb.append(src) # handle the native - if not isinstance(self.pdb_native,list): + if not isinstance(self.pdb_native, list): self.pdb_native = [self.pdb_native] for src in self.pdb_native: if os.path.isdir(src): - self.all_native += [os.path.join(src,fname) for fname in os.listdir(src)] + self.all_native += [os.path.join(src, fname) + for fname in os.listdir(src)] if os.path.isfile(src): self.all_native.append(src) @@ -131,15 +146,19 @@ def __init__(self,pdb_select=None,pdb_source=None,pdb_native=None,pssm_source=No # MPI COMM self.mpi_comm = mpi_comm -#==================================================================================== +# ==================================================================================== # # CREATE THE DATABASE ALL AT ONCE IF ALL OPTIONS ARE GIVEN # -#==================================================================================== +# ==================================================================================== - def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_distance=8.5): - - '''Create the hdf5 file architecture and compute the features/targets. + def create_database( + self, + verbose=False, + remove_error=True, + prog_bar=False, + contact_distance=8.5): + """Create the hdf5 file architecture and compute the features/targets. 
Args: verbose (bool, optional): Print creation details @@ -167,7 +186,8 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ >>> >>> #create new files >>> database.create_database(prog_bar=True) - >>> ''' + >>> + """ # deals with the parallelization self.local_pdbs = self.pdb_path @@ -187,24 +207,24 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ self.local_pdbs = pdbs[0] # send to other procs - for iP in range(1,size): - self.mpi_comm.send(pdbs[iP],dest=iP,tag=11) + for iP in range(1, size): + self.mpi_comm.send(pdbs[iP], dest=iP, tag=11) else: # receive procs - self.local_pdbs = self.mpi_comm.recv(source=0,tag=11) + self.local_pdbs = self.mpi_comm.recv(source=0, tag=11) # change hdf5 name h5path, h5name = os.path.split(self.hdf5) - self.hdf5 = os.path.join(h5path, '%03d_' %rank + h5name) + self.hdf5 = os.path.join(h5path, '%03d_' % rank + h5name) # open the file - self.f5 = h5py.File(self.hdf5,'w') + self.f5 = h5py.File(self.hdf5, 'w') self.logger.info('Start Feature calculation') # get the local progress bar desc = '{:25s}'.format('Create database') - cplx_tqdm = tqdm(self.local_pdbs,desc=desc,disable = not prog_bar) + cplx_tqdm = tqdm(self.local_pdbs, desc=desc, disable=not prog_bar) if not prog_bar: print(desc, ':', self.hdf5) @@ -213,7 +233,7 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ for cplx in cplx_tqdm: cplx_tqdm.set_postfix(mol=os.path.basename(cplx)) - self.logger.debug('MOLECULE %s' %(cplx)) + self.logger.debug('MOLECULE %s' % (cplx)) try: @@ -238,11 +258,14 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ ref = cplx else: - if len(self.all_native)>0: + if len(self.all_native) > 0: - ref = list(filter(lambda x: ref_name in x,self.all_native)) + ref = list( + filter( + lambda x: ref_name in x, + self.all_native)) - if len(ref)>1: + if len(ref) > 1: raise ValueError('Multiple native nout found') if len(ref) == 0: raise ValueError('Native not found') @@ -257,16 +280,16 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ # talk a bit if verbose: - print('\n: Process complex %s' %(mol_name)) + print('\n: Process complex %s' % (mol_name)) # crete a subgroup for the molecule molgrp = self.f5.require_group(mol_name) molgrp.attrs['type'] = 'molecule' # add the ref and the complex - self._add_pdb(molgrp,cplx,'complex') + self._add_pdb(molgrp, cplx, 'complex') if ref is not None: - self._add_pdb(molgrp,ref,'native') + self._add_pdb(molgrp, ref, 'native') ################################################ # add the features @@ -276,13 +299,12 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ molgrp.require_group('features') molgrp.require_group('features_raw') - error_flag = False #error_flag => when False: success; when True: failed + error_flag = False # error_flag => when False: success; when True: failed if self.compute_features is not None: error_flag = self._compute_features(self.compute_features, - molgrp['complex'][:], - molgrp['features'], - molgrp['features_raw'] ) - + molgrp['complex'][:], + molgrp['features'], + molgrp['features_raw']) ################################################ # add the targets @@ -299,9 +321,9 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ # add the box center ################################################ molgrp.require_group('grid_points') - center = 
self._get_grid_center(molgrp['complex'][:],contact_distance) - molgrp['grid_points'].create_dataset('center',data=center) - + center = self._get_grid_center( + molgrp['complex'][:], contact_distance) + molgrp['grid_points'].create_dataset('center', data=center) ################################################ # DATA AUGMENTATION @@ -309,7 +331,12 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ # GET ALL THE NAMES if self.data_augmentation is not None: - mol_aug_name_list = [mol_name + '_r%03d' %(idir+1) for idir in range(self.data_augmentation)] + mol_aug_name_list = [ + mol_name + + '_r%03d' % + (idir + + 1) for idir in range( + self.data_augmentation)] else: mol_aug_name_list = [] @@ -326,26 +353,28 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ # copy the ref into it if ref is not None: - self._add_pdb(molgrp,ref,'native') + self._add_pdb(molgrp, ref, 'native') # get the rotation axis and angle - axis,angle = self._get_aug_rot() + axis, angle = self._get_aug_rot() # create the new pdb - center = self._add_aug_pdb(molgrp,cplx,'complex',axis,angle) + center = self._add_aug_pdb( + molgrp, cplx, 'complex', axis, angle) # copy the targets/features - self.f5.copy(mol_name+'/targets/', molgrp) - self.f5.copy(mol_name+'/features/', molgrp) + self.f5.copy(mol_name + '/targets/', molgrp) + self.f5.copy(mol_name + '/features/', molgrp) # rotate the feature - self._rotate_feature(molgrp,axis,angle,center) + self._rotate_feature(molgrp, axis, angle, center) # grid center molgrp.require_group('grid_points') - center = self._get_grid_center(molgrp['complex'][:],contact_distance) + center = self._get_grid_center( + molgrp['complex'][:], contact_distance) print(center) - molgrp['grid_points'].create_dataset('center',data=center) + molgrp['grid_points'].create_dataset('center', data=center) # store the axis/angl/center as attriutes # in case we need them later @@ -354,24 +383,30 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ molgrp.attrs['center'] = center if error_flag: - #error_flag => when False: success; when True: failed + # error_flag => when False: success; when True: failed self.feature_error += [mol_name] + mol_aug_name_list - self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True) + self.logger.warning( + 'Error during the feature calculation of %s' % + cplx, exc_info=True) sys.stdout.flush() except Exception as inst: self.feature_error += [mol_name] + mol_aug_name_list - self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True) - _printif('Error during the feature calculation of %s' %cplx,self.debug) - _printif(type(inst),self.debug) - _printif(inst.args,self.debug) + self.logger.warning( + 'Error during the feature calculation of %s' % + cplx, exc_info=True) + _printif( + 'Error during the feature calculation of %s' % + cplx, self.debug) + _printif(type(inst), self.debug) + _printif(inst.args, self.debug) # remove the data where we had issues if remove_error: for mol in self.feature_error: #self.logger.warning('Error during the feature calculation of %s' %cplx,exc_info=True) - _printif('removing %s from %s' %(mol,self.hdf5),self.debug) + _printif('removing %s from %s' % (mol, self.hdf5), self.debug) del self.f5[mol] sys.stdout.flush() @@ -379,18 +414,14 @@ def create_database(self,verbose=False,remove_error=True,prog_bar=False,contact_ self.f5.close() - - 
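create_database() above hands each MPI rank its own slice of the PDB list and a rank-prefixed copy of the HDF5 file. A minimal sketch of driving it that way, assuming mpi4py and that DataGenerator is importable from deeprank.generate; the directories, the output file name and the feature list are placeholders (the feature list is trimmed to the ResidueDensity module shown earlier):

# sketch only -- paths and file names below are placeholders
from mpi4py import MPI

from deeprank.generate import DataGenerator

comm = MPI.COMM_WORLD

database = DataGenerator(pdb_source='./decoys/',
                         pdb_native='./natives/',
                         compute_features=['deeprank.features.ResidueDensity'],
                         hdf5='database.h5',
                         mpi_comm=comm)

# rank 0 scatters the PDB list; every rank then writes its own file,
# e.g. 000_database.h5, 001_database.h5, ...
database.create_database(prog_bar=(comm.Get_rank() == 0))

Run under mpirun, each rank leaves its own file behind; nothing in the code shown merges them afterwards.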
-#==================================================================================== +# ==================================================================================== # # ADD FEATURES TO AN EXISTING DATASET # -#==================================================================================== - +# ==================================================================================== - def add_feature(self,prog_bar=True): - - ''' Add a feature to an existing hdf5 file + def add_feature(self, prog_bar=True): + """Add a feature to an existing hdf5 file. Args: prog_bar (bool, optional): use tqdm @@ -404,23 +435,37 @@ def add_feature(self,prog_bar=True): >>> hdf5=h5file) >>> >>> database.add_feature(prog_bar=True) - ''' + """ # check if file exists if not os.path.isfile(self.hdf5): - raise FileNotFoundError('File %s does not exists' %self.hdf5) + raise FileNotFoundError('File %s does not exists' % self.hdf5) # get the folder names - f5 = h5py.File(self.hdf5,'a') + f5 = h5py.File(self.hdf5, 'a') fnames = f5.keys() # get the non rotated ones - fnames_original = list( filter(lambda x: not re.search('_r\d+$',x), fnames) ) - fnames_augmented = list( filter(lambda x: re.search('_r\d+$',x), fnames) ) + fnames_original = list( + filter( + lambda x: not re.search( + r'_r\d+$', + x), + fnames)) + fnames_augmented = list( + filter( + lambda x: re.search( + r'_r\d+$', + x), + fnames)) # computes the features of the original desc = '{:25s}'.format('Add features') - for cplx_name in tqdm(fnames_original,desc=desc,ncols=100,disable = not prog_bar): + for cplx_name in tqdm( + fnames_original, + desc=desc, + ncols=100, + disable=not prog_bar): # molgrp molgrp = f5[cplx_name] @@ -430,8 +475,11 @@ def add_feature(self,prog_bar=True): molgrp.require_group('features_raw') if self.compute_features is not None: - self._compute_features(self.compute_features, molgrp['complex'][:],molgrp['features'],molgrp['features_raw'] ) - + self._compute_features( + self.compute_features, + molgrp['complex'][:], + molgrp['features'], + molgrp['features_raw']) # copy the data from the original to the augmented for cplx_name in fnames_augmented: @@ -440,7 +488,7 @@ def add_feature(self,prog_bar=True): aug_molgrp = f5[cplx_name] # get the source group - mol_name = re.split('_r\d+', molgrp.name)[0] + mol_name = re.split(r'_r\d+', molgrp.name)[0] src_molgrp = f5[mol_name] # get the rotation parameters @@ -452,28 +500,26 @@ def add_feature(self,prog_bar=True): for k in molgrp['features']: if k not in aug_molgrp['features']: - #copy - data = src_molgrp['features/'+k][:] + # copy + data = src_molgrp['features/' + k][:] aug_molgrp.require_group('features') - aug_molgrp.create_dataset("features/"+k,data=data) - - #rotate - self._rotate_feature(aug_molgrp,axis,angle,center,feat_name=[k]) + aug_molgrp.create_dataset("features/" + k, data=data) + # rotate + self._rotate_feature( + aug_molgrp, axis, angle, center, feat_name=[k]) # close the file f5.close() -#==================================================================================== +# ==================================================================================== # # ADD TARGETS TO AN EXISTING DATASET # -#==================================================================================== - +# ==================================================================================== - def add_unique_target(self,targdict): - - '''Add identical targets for all the complexes in the datafile. 
+ def add_unique_target(self, targdict): + """Add identical targets for all the complexes in the datafile. This is usefull if you want to add the binary class of all the complexes created from decoys or natives @@ -483,23 +529,21 @@ def add_unique_target(self,targdict): >>> database = DataGenerator(hdf5='1ak4.hdf5') >>> database.add_unique_target({'DOCKQ':1.0}) - ''' + """ # check if file exists if not os.path.isfile(self.hdf5): - raise FileNotFoundError('File %s does not exists' %self.hdf5) + raise FileNotFoundError('File %s does not exists' % self.hdf5) - f5 = h5py.File(self.hdf5,'a') + f5 = h5py.File(self.hdf5, 'a') for mol in list(f5.keys()): targrp = f5[mol].require_group('targets') - for name,value in targdict.items(): - targrp.create_dataset(name,data=np.array([value])) + for name, value in targdict.items(): + targrp.create_dataset(name, data=np.array([value])) f5.close() - - def add_target(self,prog_bar=False): - - ''' Add a target to an existing hdf5 file + def add_target(self, prog_bar=False): + """Add a target to an existing hdf5 file. Args: prog_bar (bool, optional): Use tqdm @@ -513,33 +557,50 @@ def add_target(self,prog_bar=False): >>> hdf5=h5file) >>> >>> database.add_target(prog_bar=True) - ''' + """ # check if file exists if not os.path.isfile(self.hdf5): - raise FileNotFoundError('File %s does not exists' %self.hdf5) + raise FileNotFoundError('File %s does not exists' % self.hdf5) # name of the hdf5 file - f5 = h5py.File(self.hdf5,'a') + f5 = h5py.File(self.hdf5, 'a') # get the folder names fnames = f5.keys() # get the non rotated ones - fnames_original = list( filter(lambda x: not re.search('_r\d+$',x), fnames) ) - fnames_augmented = list( filter(lambda x: re.search('_r\d+$',x), fnames) ) + fnames_original = list( + filter( + lambda x: not re.search( + r'_r\d+$', + x), + fnames)) + fnames_augmented = list( + filter( + lambda x: re.search( + r'_r\d+$', + x), + fnames)) # compute the targets of the original desc = '{:25s}'.format('Add targets') - for cplx_name in tqdm(fnames_original,desc=desc,ncols=100,disable = not prog_bar): + for cplx_name in tqdm( + fnames_original, + desc=desc, + ncols=100, + disable=not prog_bar): # group of the molecule molgrp = f5[cplx_name] # add the targets if self.compute_targets is not None: - self._compute_targets(self.compute_targets, molgrp['complex'][:],molgrp['targets']) + self._compute_targets( + self.compute_targets, + molgrp['complex'][:], + molgrp['targets']) # copy the targets of the original to the rotated for cplx_name in fnames_augmented: @@ -548,41 +609,42 @@ def add_target(self,prog_bar=False): aug_molgrp = f5[cplx_name] # get the source group - mol_name = re.split('_r\d+', molgrp.name)[0] + mol_name = re.split(r'_r\d+', molgrp.name)[0] src_molgrp = f5[mol_name] # copy the targets to the augmented for k in molgrp['targets']: if k not in aug_molgrp['targets']: - data = src_molgrp['targets/'+k][()] + data = src_molgrp['targets/' + k][()] aug_molgrp.require_group('targets') - aug_molgrp.create_dataset("targets/"+k,data=data) - - + aug_molgrp.create_dataset("targets/" + k, data=data) # close the file f5.close() - -#==================================================================================== +# ==================================================================================== # # PRECOMPUTE TEH GRID POINTS # -#==================================================================================== +# ==================================================================================== @staticmethod - def 
_get_grid_center(pdb,contact_distance): + def _get_grid_center(pdb, contact_distance): sqldb = pdb2sql(pdb) - xyz1 = np.array(sqldb.get('x,y,z',chainID='A')) - xyz2 = np.array(sqldb.get('x,y,z',chainID='B')) + xyz1 = np.array(sqldb.get('x,y,z', chainID='A')) + xyz2 = np.array(sqldb.get('x,y,z', chainID='B')) - index_b = sqldb.get('rowID',chainID='B') + index_b = sqldb.get('rowID', chainID='B') contact_atoms = [] - for i,x0 in enumerate(xyz1): - contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) < contact_distance)[0] + for i, x0 in enumerate(xyz1): + contacts = np.where( + np.sqrt( + np.sum( + (xyz2 - x0)**2, + 1)) < contact_distance)[0] if len(contacts) > 0: contact_atoms += [i] @@ -591,23 +653,33 @@ def _get_grid_center(pdb,contact_distance): # create a set of unique indexes contact_atoms = list(set(contact_atoms)) - - center_contact = np.mean(np.array(sqldb.get('x,y,z',rowID=contact_atoms)),0) + center_contact = np.mean( + np.array( + sqldb.get( + 'x,y,z', + rowID=contact_atoms)), + 0) sqldb.close() return center_contact - def precompute_grid(self,grid_info, contact_distance = 8.5, prog_bar = False,time=False,try_sparse=True): + def precompute_grid( + self, + grid_info, + contact_distance=8.5, + prog_bar=False, + time=False, + try_sparse=True): # name of the hdf5 file - f5 = h5py.File(self.hdf5,'a') + f5 = h5py.File(self.hdf5, 'a') # check all the input PDB files mol_names = f5.keys() # get the local progress bar desc = '{:25s}'.format('Precompute grid points') - mol_tqdm = tqdm(mol_names,desc=desc, disable = not prog_bar) + mol_tqdm = tqdm(mol_names, desc=desc, disable=not prog_bar) if not prog_bar: print(desc, ':', self.hdf5) @@ -620,34 +692,34 @@ def precompute_grid(self,grid_info, contact_distance = 8.5, prog_bar = False,tim # compute the data we want on the grid grid = gt.GridTools(molgrp=f5[mol], - number_of_points = grid_info['number_of_points'], - resolution = grid_info['resolution'], - hdf5_file = f5, - contact_distance = contact_distance, - time = time, - prog_bar = prog_bar, - try_sparse = try_sparse) + number_of_points=grid_info['number_of_points'], + resolution=grid_info['resolution'], + hdf5_file=f5, + contact_distance=contact_distance, + time=time, + prog_bar=prog_bar, + try_sparse=try_sparse) f5.close() -#==================================================================================== +# ==================================================================================== # # MAP THE FEATURES TO THE GRID # -#==================================================================================== +# ==================================================================================== + - def map_features(self,grid_info={}, - cuda=False,gpu_block=None, + def map_features(self, grid_info={}, + cuda=False, gpu_block=None, cuda_kernel='/kernel_map.c', - cuda_func_name = 'gaussian', + cuda_func_name='gaussian', try_sparse=True, - reset=False,use_tmpdir=False, + reset=False, use_tmpdir=False, time=False, - prog_bar=True,grid_prog_bar=False, + prog_bar=True, grid_prog_bar=False, remove_error=True): - - ''' Map the feature on a grid of points centered at the interface + """Map the feature on a grid of points centered at the interface. 
Args: grid_info (dict): Informaton for the grid see deeprank.generate.GridTool.py for details @@ -676,8 +748,7 @@ def map_features(self,grid_info={}, >>> } >>> >>> database.map_features(grid_info,try_sparse=True,time=False,prog_bar=True) - - ''' + """ # default CUDA cuda_func = None @@ -686,23 +757,23 @@ def map_features(self,grid_info={}, # disable CUDA when using MPI if self.mpi_comm is not None: if self.mpi_comm.Get_size() > 1: - if cuda == True: + if cuda: print('Warning : CUDA mapping disabled when using MPI') cuda = False # name of the hdf5 file - f5 = h5py.File(self.hdf5,'a') + f5 = h5py.File(self.hdf5, 'a') # check all the input PDB files mol_names = f5.keys() if len(mol_names) == 0: - _printif('No molecules found in %s' %self.hdf5,self.debug) + _printif('No molecules found in %s' % self.hdf5, self.debug) f5.close() return # fills in the grid data if not provided : default = NONE - grinfo = ['number_of_points','resolution'] + grinfo = ['number_of_points', 'resolution'] for gr in grinfo: if gr not in grid_info: grid_info[gr] = None @@ -715,21 +786,23 @@ def map_features(self,grid_info={}, # if we havent mapped anything yet or if we reset if 'mapped_features' not in list(f5[mol].keys()) or reset: - grid_info['feature'] = list(f5[mol+'/features'].keys()) + grid_info['feature'] = list(f5[mol + '/features'].keys()) # if we have already mapped stuff elif 'mapped_features' in list(f5[mol].keys()): # feature name - all_feat = list(f5[mol+'/features'].keys()) + all_feat = list(f5[mol + '/features'].keys()) # feature already mapped - mapped_feat = list(f5[mol+'/mapped_features/Feature_ind'].keys()) + mapped_feat = list( + f5[mol + '/mapped_features/Feature_ind'].keys()) # we select only the feture that were not mapped yet grid_info['feature'] = [] for feat_name in all_feat: - if not any(map(lambda x: x.startswith(feat_name+'_'), mapped_feat)): + if not any(map(lambda x: x.startswith( + feat_name + '_'), mapped_feat)): grid_info['feature'].append(feat_name) # by default we do not map atomic densities @@ -737,35 +810,35 @@ def map_features(self,grid_info={}, grid_info['atomic_densities'] = None # fills in the features mode if somes are missing : default = IND - modes = ['atomic_densities_mode','feature_mode'] + modes = ['atomic_densities_mode', 'feature_mode'] for m in modes: if m not in grid_info: grid_info[m] = 'ind' # sanity check for cuda - if cuda and gpu_block is None: # pragma: no cover + if cuda and gpu_block is None: # pragma: no cover print('Warning GPU block automatically set to 8 x 8 x 8') print('You can sepcify the block size with gpu_block=[n,m,k]') - gpu_block = [8,8,8] + gpu_block = [8, 8, 8] # initialize cuda - if cuda: # pragma: no cover + if cuda: # pragma: no cover # compile cuda module npts = grid_info['number_of_points'] res = grid_info['resolution'] - module = self.compile_cuda_kernel(cuda_kernel,npts,res) + module = self.compile_cuda_kernel(cuda_kernel, npts, res) # get the cuda function for the atomic/residue feature - cuda_func = self.get_cuda_function(module,cuda_func_name) + cuda_func = self.get_cuda_function(module, cuda_func_name) # get the cuda function for the atomic densties cuda_atomic_name = 'atomic_densities' - cuda_atomic = self.get_cuda_function(module,cuda_atomic_name) + cuda_atomic = self.get_cuda_function(module, cuda_atomic_name) # get the local progress bar desc = '{:25s}'.format('Map Features') - mol_tqdm = tqdm(mol_names,desc=desc,disable = not prog_bar) + mol_tqdm = tqdm(mol_names, desc=desc, disable=not prog_bar) if not prog_bar: print(desc, ':', 
self.hdf5) @@ -779,47 +852,49 @@ def map_features(self,grid_info={}, try: # compute the data we want on the grid - grid = gt.GridTools(molgrp=f5[mol], - number_of_points = grid_info['number_of_points'], - resolution = grid_info['resolution'], - atomic_densities = grid_info['atomic_densities'], - atomic_densities_mode = grid_info['atomic_densities_mode'], - feature = grid_info['feature'], - feature_mode = grid_info['feature_mode'], - cuda = cuda, - gpu_block = gpu_block, - cuda_func = cuda_func, - cuda_atomic = cuda_atomic, - hdf5_file = f5, - time=time, - prog_bar=grid_prog_bar, - try_sparse=try_sparse) - - except: + grid = gt.GridTools( + molgrp=f5[mol], + number_of_points=grid_info['number_of_points'], + resolution=grid_info['resolution'], + atomic_densities=grid_info['atomic_densities'], + atomic_densities_mode=grid_info['atomic_densities_mode'], + feature=grid_info['feature'], + feature_mode=grid_info['feature_mode'], + cuda=cuda, + gpu_block=gpu_block, + cuda_func=cuda_func, + cuda_atomic=cuda_atomic, + hdf5_file=f5, + time=time, + prog_bar=grid_prog_bar, + try_sparse=try_sparse) + + except BaseException: self.map_error.append(mol) - self.logger.warning('Error during the mapping of %s' %mol,exc_info=True) - _printif('Error during the mapping of %s' %mol,self.debug) + self.logger.warning( + 'Error during the mapping of %s' % + mol, exc_info=True) + _printif('Error during the mapping of %s' % mol, self.debug) # remove the molecule with issues if remove_error: for mol in self.map_error: - print('removing %s from %s' %(mol,self.hdf5)) + print('removing %s from %s' % (mol, self.hdf5)) del f5[mol] # close he hdf5 file f5.close() -#==================================================================================== +# ==================================================================================== # # REMOVE DATA FROM THE DATA SET # -#==================================================================================== - - def remove(self,feature=True,pdb=True,points=True,grid=False): +# ==================================================================================== - '''Remove data from the data set. + def remove(self, feature=True, pdb=True, points=True, grid=False): + """Remove data from the data set. Equivalent to the cleandata command line tool. 
Once the data has been removed from the file it is impossible to add new features/targets @@ -829,13 +904,12 @@ def remove(self,feature=True,pdb=True,points=True,grid=False): pdb (bool, optional): Remove the pdbs points (bool, optional): remove teh grid points grid (bool, optional): remove the maps + """ - ''' - - _printif('Remove features',self.debug) + _printif('Remove features', self.debug) # name of the hdf5 file - f5 = h5py.File(self.hdf5,'a') + f5 = h5py.File(self.hdf5, 'a') # get the folder names mol_names = f5.keys() @@ -857,23 +931,19 @@ def remove(self,feature=True,pdb=True,points=True,grid=False): f5.close() # reclaim the space - os.system('h5repack %s _tmp.h5py' %self.hdf5) - os.system('mv _tmp.h5py %s' %self.hdf5) - - - + os.system('h5repack %s _tmp.h5py' % self.hdf5) + os.system('mv _tmp.h5py %s' % self.hdf5) -#==================================================================================== +# ==================================================================================== # # Simply tune or test the kernel # -#==================================================================================== +# ==================================================================================== - def _tune_cuda_kernel(self,grid_info,cuda_kernel='kernel_map.c',func='gaussian'): # pragma: no cover - ''' - Tune the CUDA kernel using the kernel tuner + def _tune_cuda_kernel(self, grid_info, cuda_kernel='kernel_map.c', func='gaussian'): # pragma: no cover + """Tune the CUDA kernel using the kernel tuner http://benvanwerkhoven.github.io/kernel_tuner/ Args: @@ -883,44 +953,43 @@ def _tune_cuda_kernel(self,grid_info,cuda_kernel='kernel_map.c',func='gaussian') Raises: ValueError: If the tuner has not been used - ''' - + """ try: from kernel_tuner import tune_kernel - except: + except BaseException: print('Install the Kernel Tuner : \n \t\t pip install kernel_tuner') print('http://benvanwerkhoven.github.io/kernel_tuner/') # fills in the grid data if not provided : default = NONE - grinfo = ['number_of_points','resolution'] - for gr in grinfo: + grinfo = ['number_of_points', 'resolution'] + for gr in grinfo: if gr not in grid_info: raise ValueError('%s must be specified to tune the kernel') # define the grid center_contact = np.zeros(3) - nx,ny,nz = grid_info['number_of_points'] - dx,dy,dz = grid_info['resolution'] - lx,ly,lz = nx*dx,ny*dy,nz*dz + nx, ny, nz = grid_info['number_of_points'] + dx, dy, dz = grid_info['resolution'] + lx, ly, lz = nx * dx, ny * dy, nz * dz - x = np.linspace(0,lx,nx) - y = np.linspace(0,ly,ny) - z = np.linspace(0,lz,nz) + x = np.linspace(0, lx, nx) + y = np.linspace(0, ly, ny) + z = np.linspace(0, lz, nz) # create the dictionary containing the tune parameters tune_params = OrderedDict() - tune_params['block_size_x'] = [2,4,8,16,32] - tune_params['block_size_y'] = [2,4,8,16,32] - tune_params['block_size_z'] = [2,4,8,16,32] + tune_params['block_size_x'] = [2, 4, 8, 16, 32] + tune_params['block_size_y'] = [2, 4, 8, 16, 32] + tune_params['block_size_z'] = [2, 4, 8, 16, 32] # define the final grid grid = np.zeros(grid_info['number_of_points']) # arguments of the CUDA function - x0,y0,z0 = np.float32(0),np.float32(0),np.float32(0) + x0, y0, z0 = np.float32(0), np.float32(0), np.float32(0) alpha = np.float32(0) - args = [alpha,x0,y0,z0,x,y,z,grid] + args = [alpha, x0, y0, z0, x, y, z, grid] # dimensionality problem_size = grid_info['number_of_points'] @@ -931,23 +1000,30 @@ def _tune_cuda_kernel(self,grid_info,cuda_kernel='kernel_map.c',func='gaussian') npts = 
grid_info['number_of_points'] res = grid_info['resolution'] - kernel_code = kernel_code_template % {'nx' : npts[0], 'ny': npts[1], 'nz' : npts[2], 'RES' : np.max(res)} + kernel_code = kernel_code_template % { + 'nx': npts[0], + 'ny': npts[1], + 'nz': npts[2], + 'RES': np.max(res)} tunable_kernel = self._tunable_kernel(kernel_code) # tune - result = tune_kernel(func, tunable_kernel,problem_size,args,tune_params) + result = tune_kernel( + func, + tunable_kernel, + problem_size, + args, + tune_params) -#==================================================================================== +# ==================================================================================== # # Simply test the kernel # -#==================================================================================== - - def _test_cuda(self,grid_info,gpu_block=8,cuda_kernel='kernel_map.c',func='gaussian'): # pragma: no cover +# ==================================================================================== - ''' - Test the CUDA kernel + def _test_cuda(self, grid_info, gpu_block=8, cuda_kernel='kernel_map.c', func='gaussian'): # pragma: no cover + """Test the CUDA kernel. Args: grid_info (dict): Information for the grid definition @@ -957,67 +1033,68 @@ def _test_cuda(self,grid_info,gpu_block=8,cuda_kernel='kernel_map.c',func='gauss Raises: ValueError: If the kernel has not been installed - ''' + """ from time import time # fills in the grid data if not provided : default = NONE - grinfo = ['number_of_points','resolution'] - for gr in grinfo: + grinfo = ['number_of_points', 'resolution'] + for gr in grinfo: if gr not in grid_info: raise ValueError('%s must be specified to tune the kernel') # get the cuda function npts = grid_info['number_of_points'] res = grid_info['resolution'] - module = self._compile_cuda_kernel(cuda_kernel,npts,res) - cuda_func = self._get_cuda_function(module,func) + module = self._compile_cuda_kernel(cuda_kernel, npts, res) + cuda_func = self._get_cuda_function(module, func) # define the grid center_contact = np.zeros(3) - nx,ny,nz = grid_info['number_of_points'] - dx,dy,dz = grid_info['resolution'] - lx,ly,lz = nx*dx,ny*dy,nz*dz + nx, ny, nz = grid_info['number_of_points'] + dx, dy, dz = grid_info['resolution'] + lx, ly, lz = nx * dx, ny * dy, nz * dz # create the coordinate - x = np.linspace(0,lx,nx) - y = np.linspace(0,ly,ny) - z = np.linspace(0,lz,nz) + x = np.linspace(0, lx, nx) + y = np.linspace(0, ly, ny) + z = np.linspace(0, lz, nz) # book memp on the gpu x_gpu = gpuarray.to_gpu(x.astype(np.float32)) y_gpu = gpuarray.to_gpu(y.astype(np.float32)) z_gpu = gpuarray.to_gpu(z.astype(np.float32)) - grid_gpu = gpuarray.zeros(grid_info['number_of_points'],np.float32) + grid_gpu = gpuarray.zeros(grid_info['number_of_points'], np.float32) # make sure we have three block value - if not isinstance(gpu_block,list): - gpu_block = [gpu_block]*3 + if not isinstance(gpu_block, list): + gpu_block = [gpu_block] * 3 # get the grid - gpu_grid = [ int(np.ceil(n/b)) for b,n in zip(gpu_block,grid_info['number_of_points'])] + gpu_grid = [int(np.ceil(n / b)) + for b, n in zip(gpu_block, grid_info['number_of_points'])] print('GPU BLOCK :', gpu_block) print('GPU GRID :', gpu_grid) - xyz_center = np.random.rand(500,3).astype(np.float32) + xyz_center = np.random.rand(500, 3).astype(np.float32) alpha = np.float32(1) t0 = time() for xyz in xyz_center: - x0,y0,z0 = xyz - cuda_func(alpha,x0,y0,z0,x_gpu,y_gpu,z_gpu,grid_gpu, - block=tuple(gpu_block),grid=tuple(gpu_grid)) + x0, y0, z0 = xyz + cuda_func(alpha, x0, y0, z0, 
x_gpu, y_gpu, z_gpu, grid_gpu, + block=tuple(gpu_block), grid=tuple(gpu_grid)) - print('Done in : %f ms' %((time()-t0)*1000)) + print('Done in : %f ms' % ((time() - t0) * 1000)) -#==================================================================================== +# ==================================================================================== # # Routines needed to handle CUDA # -#==================================================================================== +# ==================================================================================== @staticmethod - def _compile_cuda_kernel(cuda_kernel,npts,res): # pragma: no cover - """Compile the cuda kernel + def _compile_cuda_kernel(cuda_kernel, npts, res): # pragma: no cover + """Compile the cuda kernel. Args: cuda_kernel (str): filename @@ -1030,15 +1107,19 @@ def _compile_cuda_kernel(cuda_kernel,npts,res): # pragma: no cover # get the cuda kernel path kernel = os.path.dirname(os.path.abspath(__file__)) + '/' + cuda_kernel kernel_code_template = open(kernel, 'r').read() - kernel_code = kernel_code_template % {'nx' : npts[0], 'ny': npts[1], 'nz' : npts[2], 'RES' : np.max(res)} + kernel_code = kernel_code_template % { + 'nx': npts[0], + 'ny': npts[1], + 'nz': npts[2], + 'RES': np.max(res)} # compile the kernel mod = compiler.SourceModule(kernel_code) return mod @staticmethod - def _get_cuda_function(module,func_name): # pragma: no cover - """Get a single function from the compiled kernel + def _get_cuda_function(module, func_name): # pragma: no cover + """Get a single function from the compiled kernel. Args: module (compiler.SourceModule): compiled kernel module @@ -1052,8 +1133,8 @@ def _get_cuda_function(module,func_name): # pragma: no cover # tranform the kernel to a tunable one @staticmethod - def _tunable_kernel(kernel): # pragma: no cover - """Make a tunale kernel + def _tunable_kernel(kernel): # pragma: no cover + """Make a tunale kernel. 
Args: kernel (str): String of the kernel @@ -1061,17 +1142,20 @@ def _tunable_kernel(kernel): # pragma: no cover Returns: TYPE: tunable kernel """ - switch_name = { 'blockDim.x' : 'block_size_x', 'blockDim.y' : 'block_size_y','blockDim.z' : 'block_size_z' } - for old,new in switch_name.items(): - kernel = kernel.replace(old,new) + switch_name = { + 'blockDim.x': 'block_size_x', + 'blockDim.y': 'block_size_y', + 'blockDim.z': 'block_size_z'} + for old, new in switch_name.items(): + kernel = kernel.replace(old, new) return kernel -#==================================================================================== +# ==================================================================================== # # FILTER DATASET # -#=================================================================================== +# =================================================================================== def _filter_cplx(self): """Filter the name of the complexes.""" @@ -1080,27 +1164,26 @@ def _filter_cplx(self): f = open(self.pdb_select) pdb_name = f.readlines() f.close() - pdb_name = [name.split()[0]+'.pdb' for name in pdb_name] + pdb_name = [name.split()[0] + '.pdb' for name in pdb_name] # create the filters tmp_path = [] for name in pdb_name: - tmp_path += list(filter(lambda x: name in x,self.pdb_path)) + tmp_path += list(filter(lambda x: name in x, self.pdb_path)) # update the pdb_path self.pdb_path = tmp_path - -#==================================================================================== +# ==================================================================================== # # FEATURES ROUTINES # -#==================================================================================== +# ==================================================================================== @staticmethod - def _compute_features(feat_list,pdb_data,featgrp,featgrp_raw): - """Compute the features + def _compute_features(feat_list, pdb_data, featgrp, featgrp_raw): + """Compute the features. Args: feat_list (list(str)): list of function name, e.g., ['deeprank.features.ResidueDensity', 'deeprank.features.PSSM_IC'] @@ -1108,25 +1191,26 @@ def _compute_features(feat_list,pdb_data,featgrp,featgrp_raw): featgrp (str): name of the group where to store the xyz feature featgrp_raw (str): name of the group where to store the raw feature """ - error_flag = False # when False: success; when True: failed + error_flag = False # when False: success; when True: failed for feat in feat_list: - feat_module = importlib.import_module(feat,package=None) - error_flag = feat_module.__compute_feature__(pdb_data,featgrp,featgrp_raw) + feat_module = importlib.import_module(feat, package=None) + error_flag = feat_module.__compute_feature__( + pdb_data, featgrp, featgrp_raw) - if re.search('ResidueDensity', feat) and error_flag == True: + if re.search('ResidueDensity', feat) and error_flag: return error_flag -#==================================================================================== +# ==================================================================================== # # TARGETS ROUTINES # -#==================================================================================== +# ==================================================================================== @staticmethod - def _compute_targets(targ_list,pdb_data,targrp): - """Compute the targets + def _compute_targets(targ_list, pdb_data, targrp): + """Compute the targets. 
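_compute_features() above imports every entry of compute_features with importlib and calls its module-level __compute_feature__(pdb_data, featgrp, featgrp_raw), treating the returned flag as an error indicator (and acting on it only for ResidueDensity). A minimal sketch of a drop-in module following that convention; the module name and the toy feature are invented, and a real module would also export an xyz-formatted dataset under featgrp so it can be mapped onto the grid:

# my_feature.py -- hypothetical module, listed in compute_features by its import path
import numpy as np


def __compute_feature__(pdb_data, featgrp, featgrp_raw):
    """Toy feature: store the number of ATOM records of the complex."""
    error_flag = False
    try:
        # pdb_data is the array of ATOM lines stored in molgrp['complex'][:]
        featgrp_raw.create_dataset('ATOM_COUNT', data=np.array([len(pdb_data)]))
    except Exception:
        error_flag = True
    return error_flag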
Args: targ_list (list(str)): list of function name @@ -1134,18 +1218,20 @@ def _compute_targets(targ_list,pdb_data,targrp): targrp (str): name of the group where to store the targets """ for targ in targ_list: - targ_module = importlib.import_module(targ,package=None) - targ_module.__compute_target__(pdb_data,targrp) + targ_module = importlib.import_module(targ, package=None) + targ_module.__compute_target__(pdb_data, targrp) -#==================================================================================== +# ==================================================================================== # # ADD PDB FILE # -#==================================================================================== +# ==================================================================================== + + @staticmethod - def _add_pdb(molgrp,pdbfile,name): - """ Add a pdb to a molgrp. + def _add_pdb(molgrp, pdbfile, name): + """Add a pdb to a molgrp. Args: molgrp (str): mopl group where tp add the pdb @@ -1153,22 +1239,25 @@ def _add_pdb(molgrp,pdbfile,name): name (str): dataset name in the hdf5 molgroup """ # read the pdb and extract the ATOM lines - with open(pdbfile,'r') as fi: - data = [line.split('\n')[0] for line in fi if line.startswith('ATOM')] + with open(pdbfile, 'r') as fi: + data = [line.split('\n')[0] + for line in fi if line.startswith('ATOM')] data = np.array(data).astype('|S73') - dataset = molgrp.create_dataset(name,data=data) + dataset = molgrp.create_dataset(name, data=data) -#==================================================================================== +# ==================================================================================== # # AUGMENTED DATA # -#==================================================================================== +# ==================================================================================== # add a rotated pdb structure to the database + + @staticmethod - def _add_aug_pdb(molgrp,pdbfile,name,axis,angle): - """Add augmented pdbs to the dataset + def _add_aug_pdb(molgrp, pdbfile, name, axis, angle): + """Add augmented pdbs to the dataset. 
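_add_pdb() above stores the complex as raw ATOM lines in a fixed-width byte-string dataset ('|S73'). A small sketch of reading them back for inspection, assuming a database already produced by create_database(); the file name is a placeholder, and GridTools.read_pdb() simply feeds the same array straight to pdb2sql:

import h5py

with h5py.File('database.h5', 'r') as f5:
    mol = list(f5.keys())[0]
    atom_lines = f5[mol]['complex'][:]   # array of b'ATOM ...' entries, dtype |S73
    print(atom_lines[0].decode())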
Args: molgrp (str): name of the molgroup @@ -1184,7 +1273,7 @@ def _add_aug_pdb(molgrp,pdbfile,name,axis,angle): sqldb = pdb2sql(pdbfile) # rotate the positions - center = sqldb.rotation_around_axis(axis,angle) + center = sqldb.rotation_around_axis(axis, angle) # get the data sqldata = sqldb.get('*') @@ -1200,32 +1289,32 @@ def _add_aug_pdb(molgrp,pdbfile,name,axis,angle): line += ' ' line += '{:^4}'.format(d[1]) # name line += '{:>1}'.format(d[2]) # altLoc - line += '{:>3}'.format(d[3]) #resname + line += '{:>3}'.format(d[3]) # resname line += ' ' line += '{:>1}'.format(d[4]) # chainID line += '{:>4}'.format(d[5]) # resSeq line += '{:>1}'.format(d[6]) # iCODE line += ' ' - line += '{: 8.3f}'.format(d[7]) #x - line += '{: 8.3f}'.format(d[8]) #y - line += '{: 8.3f}'.format(d[9]) #z + line += '{: 8.3f}'.format(d[7]) # x + line += '{: 8.3f}'.format(d[8]) # y + line += '{: 8.3f}'.format(d[9]) # z try: line += '{: 6.2f}'.format(d[10]) # occ line += '{: 6.2f}'.format(d[11]) # temp - except: + except BaseException: line += '{: 6.2f}'.format(0) # occ line += '{: 6.2f}'.format(0) # temp data.append(line) data = np.array(data).astype('|S73') - dataset = molgrp.create_dataset(name,data=data) + dataset = molgrp.create_dataset(name, data=data) return center # rotate th xyz-formatted feature in the database @staticmethod - def _rotate_feature(molgrp,axis,angle,center,feat_name='all'): - """Rotate the raw feature values + def _rotate_feature(molgrp, axis, angle, center, feat_name='all'): + """Rotate the raw feature values. Args: molgrp (str): name pf the molgrp @@ -1238,38 +1327,43 @@ def _rotate_feature(molgrp,axis,angle,center,feat_name='all'): feat = list(molgrp['features'].keys()) else: feat = feat_name - if not isinstance(feat,list): + if not isinstance(feat, list): feat = list(feat) for fn in feat: # extract the data - data = molgrp['features/'+fn][:] + data = molgrp['features/' + fn][:] # xyz - xyz = data[:,1:4] + xyz = data[:, 1:4] # get the data - ct,st = np.cos(angle),np.sin(angle) - ux,uy,uz = axis + ct, st = np.cos(angle), np.sin(angle) + ux, uy, uz = axis # definition of the rotation matrix # see https://en.wikipedia.org/wiki/Rotation_matrix - rot_mat = np.array([ - [ct + ux**2*(1-ct), ux*uy*(1-ct) - uz*st, ux*uz*(1-ct) + uy*st], - [uy*ux*(1-ct) + uz*st, ct + uy**2*(1-ct), uy*uz*(1-ct) - ux*st], - [uz*ux*(1-ct) - uy*st, uz*uy*(1-ct) + ux*st, ct + uz**2*(1-ct) ]]) + rot_mat = np.array([[ct + ux**2 * (1 - ct), + ux * uy * (1 - ct) - uz * st, + ux * uz * (1 - ct) + uy * st], + [uy * ux * (1 - ct) + uz * st, + ct + uy**2 * (1 - ct), + uy * uz * (1 - ct) - ux * st], + [uz * ux * (1 - ct) - uy * st, + uz * uy * (1 - ct) + ux * st, + ct + uz**2 * (1 - ct)]]) # apply the rotation - xyz = np.dot(rot_mat,(xyz-center).T).T + center + xyz = np.dot(rot_mat, (xyz - center).T).T + center # put back the data - data[:,1:4] = xyz + data[:, 1:4] = xyz # get rotation axis and angle @staticmethod def _get_aug_rot(): - """Get the rotation angle/axis + """Get the rotation angle/axis. 
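The matrix assembled in _rotate_feature() above is the standard axis-angle (Rodrigues) rotation matrix. A small numpy-only check, with an arbitrary unit axis, that it is a proper rotation which leaves the axis fixed:

import numpy as np

axis = np.array([0.36, 0.48, 0.80])   # any unit vector
angle = 0.7
ct, st = np.cos(angle), np.sin(angle)
ux, uy, uz = axis

rot_mat = np.array([[ct + ux**2 * (1 - ct), ux * uy * (1 - ct) - uz * st, ux * uz * (1 - ct) + uy * st],
                    [uy * ux * (1 - ct) + uz * st, ct + uy**2 * (1 - ct), uy * uz * (1 - ct) - ux * st],
                    [uz * ux * (1 - ct) - uy * st, uz * uy * (1 - ct) + ux * st, ct + uz**2 * (1 - ct)]])

assert np.allclose(rot_mat @ rot_mat.T, np.eye(3))   # orthogonal
assert np.isclose(np.linalg.det(rot_mat), 1.0)       # proper rotation, no reflection
assert np.allclose(rot_mat @ axis, axis)             # the rotation axis is unchanged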
Returns: list(float): axis of rotation @@ -1278,11 +1372,16 @@ def _get_aug_rot(): # define the axis # uniform distribution on a sphere # http://mathworld.wolfram.com/SpherePointPicking.html - u1,u2 = np.random.rand(),np.random.rand() - teta,phi = np.arccos(2*u1-1),2*np.pi*u2 - axis = [np.sin(teta)*np.cos(phi),np.sin(teta)*np.sin(phi),np.cos(teta)] + u1, u2 = np.random.rand(), np.random.rand() + teta, phi = np.arccos(2 * u1 - 1), 2 * np.pi * u2 + axis = [ + np.sin(teta) * + np.cos(phi), + np.sin(teta) * + np.sin(phi), + np.cos(teta)] # and the rotation angle - angle = -np.pi + np.pi*np.random.rand() + angle = -np.pi + np.pi * np.random.rand() - return axis,angle + return axis, angle diff --git a/deeprank/generate/GridTools.py b/deeprank/generate/GridTools.py index 76bc073d..a598bd02 100644 --- a/deeprank/generate/GridTools.py +++ b/deeprank/generate/GridTools.py @@ -1,37 +1,38 @@ -import numpy as np -import subprocess as sp -import os, sys import itertools -from scipy.signal import bspline +import logging +import os +import subprocess as sp +import sys from collections import OrderedDict from time import time -import logging -from deeprank.tools import pdb2sql -from deeprank.tools import sparse +import numpy as np +from scipy.signal import bspline + +from deeprank.tools import pdb2sql, sparse try: from tqdm import tqdm -except ImportError : +except ImportError: def tqdm(x): return x -printif = lambda string,cond: print(string) if cond else None -# the main gridtool class -class GridTools(object): +def printif(string, cond): return print(string) if cond else None - def __init__(self, molgrp, - number_of_points = 30,resolution = 1., - atomic_densities=None, atomic_densities_mode='ind', - feature = None, feature_mode ='ind', - contact_distance = 8.5, hdf5_file=None, - cuda=False, gpu_block=None, cuda_func=None, cuda_atomic=None, - prog_bar = False,time=False,try_sparse=True,logger=None): +# the main gridtool class +class GridTools(object): + def __init__(self, molgrp, + number_of_points=30, resolution=1., + atomic_densities=None, atomic_densities_mode='ind', + feature=None, feature_mode='ind', + contact_distance=8.5, hdf5_file=None, + cuda=False, gpu_block=None, cuda_func=None, cuda_atomic=None, + prog_bar=False, time=False, try_sparse=True, logger=None): """Map the feature of a complex on the grid. 
Args: @@ -73,25 +74,25 @@ def __init__(self, molgrp, self.try_sparse = try_sparse # export to HDF5 file - self.hdf5.require_group(self.mol_basename+'/features/') + self.hdf5.require_group(self.mol_basename + '/features/') # parameter of the grid if number_of_points is not None: - if not isinstance(number_of_points,list): - number_of_points = [number_of_points]*3 + if not isinstance(number_of_points, list): + number_of_points = [number_of_points] * 3 self.npts = np.array(number_of_points).astype('int') if resolution is not None: - if not isinstance(resolution,list): - resolution = [resolution]*3 - self.res = np.array(resolution) + if not isinstance(resolution, list): + resolution = [resolution] * 3 + self.res = np.array(resolution) # cuda support self.cuda = cuda - if self.cuda: # pragma: no cover + if self.cuda: # pragma: no cover self.gpu_block = gpu_block - self.gpu_grid = [ int(np.ceil(n/b)) for b,n in zip(self.gpu_block,self.npts)] - + self.gpu_grid = [int(np.ceil(n / b)) + for b, n in zip(self.gpu_block, self.npts)] # parameter of the atomic system self.atom_xyz = None @@ -129,19 +130,20 @@ def __init__(self, molgrp, # if we already have an output containing the grid # we update the existing features _update_ = False - if self.mol_basename+'/grid_points/x' in self.hdf5: + if self.mol_basename + '/grid_points/x' in self.hdf5: _update_ = True if _update_: - printif('\n= Updating grid data for %s' %(self.mol_basename),self.time) + printif( + '\n= Updating grid data for %s' % + (self.mol_basename), self.time) self.update_feature() else: - printif('\n= Creating grid and grid data for %s' %(self.mol_basename),self.time) + printif( + '\n= Creating grid and grid data for %s' % + (self.mol_basename), self.time) self.create_new_data() - - - ################################################################ def create_new_data(self): @@ -150,7 +152,7 @@ def create_new_data(self): # get the position/atom type .. 
of the complex self.read_pdb() - #get the contact atoms + # get the contact atoms self.get_contact_atoms() # define the grid @@ -159,7 +161,7 @@ def create_new_data(self): # save the grid points self.export_grid_points() - #map the features + # map the features self.add_all_features() # if we wnat the atomic densisties @@ -178,20 +180,23 @@ def update_feature(self): # read the grid from the hdf5 if self.hdf5 is not None: - grid = self.hdf5.get(self.mol_basename+'/grid_points/') - self.x,self.y,self.z = grid['x'][()],grid['y'][()],grid['z'][()] + grid = self.hdf5.get(self.mol_basename + '/grid_points/') + self.x, self.y, self.z = grid['x'][( + )], grid['y'][()], grid['z'][()] # or read the grid points from file else: - grid = np.load(self.export_path+'/grid_points.npz') - self.x,self.y,self.z = grid['x'], grid['y'], grid['z'] + grid = np.load(self.export_path + '/grid_points.npz') + self.x, self.y, self.z = grid['x'], grid['y'], grid['z'] # create the grid - self.ygrid,self.xgrid,self.zgrid = np.meshgrid(self.y,self.x,self.z) + self.ygrid, self.xgrid, self.zgrid = np.meshgrid( + self.y, self.x, self.z) # set the resolution/dimension - self.npts = np.array([len(self.x),len(self.y),len(self.z)]) - self.res = np.array([self.x[1]-self.x[0],self.y[1]-self.y[0],self.z[1]-self.z[0]]) + self.npts = np.array([len(self.x), len(self.y), len(self.z)]) + self.res = np.array( + [self.x[1] - self.x[0], self.y[1] - self.y[0], self.z[1] - self.z[0]]) # map the features self.add_all_features() @@ -204,25 +209,28 @@ def update_feature(self): ################################################################ - def read_pdb(self): - """ Create a sql databse for the pdb.""" + """Create a sql databse for the pdb.""" self.sqldb = pdb2sql(self.molgrp['complex'][:]) - # get the contact atoms + def get_contact_atoms(self): """Get the contact atoms.""" - xyz1 = np.array(self.sqldb.get('x,y,z',chainID='A')) - xyz2 = np.array(self.sqldb.get('x,y,z',chainID='B')) + xyz1 = np.array(self.sqldb.get('x,y,z', chainID='A')) + xyz2 = np.array(self.sqldb.get('x,y,z', chainID='B')) - index_b = self.sqldb.get('rowID',chainID='B') + index_b = self.sqldb.get('rowID', chainID='B') self.contact_atoms = [] - for i,x0 in enumerate(xyz1): - contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) < self.contact_distance)[0] + for i, x0 in enumerate(xyz1): + contacts = np.where( + np.sqrt( + np.sum( + (xyz2 - x0)**2, + 1)) < self.contact_distance)[0] if len(contacts) > 0: self.contact_atoms += [i] @@ -232,9 +240,12 @@ def get_contact_atoms(self): self.contact_atoms = list(set(self.contact_atoms)) # get the mean xyz position - self.center_contact = np.mean(np.array(self.sqldb.get('x,y,z',rowID=self.contact_atoms)),0) - - + self.center_contact = np.mean( + np.array( + self.sqldb.get( + 'x,y,z', + rowID=self.contact_atoms)), + 0) ################################################################ # shortcut to add all the feature a @@ -242,10 +253,11 @@ def get_contact_atoms(self): ################################################################ # add all the residue features to the data + def add_all_features(self): """Add all the features toa given molecule.""" - #map the features + # map the features if self.feature is not None: # map the residue features @@ -253,12 +265,12 @@ def add_all_features(self): # save to hdf5 if specfied t0 = time() - printif('-- Save Features to HDF5',self.time) - self.hdf5_grid_data(dict_data,'Feature_%s' %( self.feature_mode)) - printif(' Total %f ms' %((time()-t0)*1000),self.time) - + printif('-- Save Features to 
HDF5', self.time) + self.hdf5_grid_data(dict_data, 'Feature_%s' % (self.feature_mode)) + printif(' Total %f ms' % ((time() - t0) * 1000), self.time) # add all the atomic densities to the data + def add_all_atomic_densities(self): """Add all atomic densities.""" @@ -270,11 +282,12 @@ def add_all_atomic_densities(self): # save to hdf5 t0 = time() - printif('-- Save Atomic Densities to HDF5',self.time) - self.hdf5_grid_data(self.atdens,'AtomicDensities_%s' %(self.atomic_densities_mode)) - printif(' Total %f ms' %((time()-t0)*1000),self.time) - - + printif('-- Save Atomic Densities to HDF5', self.time) + self.hdf5_grid_data( + self.atdens, + 'AtomicDensities_%s' % + (self.atomic_densities_mode)) + printif(' Total %f ms' % ((time() - t0) * 1000), self.time) ################################################################ # define the grid points @@ -287,25 +300,26 @@ def add_all_atomic_densities(self): def define_grid_points(self): """Define the grid points.""" - printif('-- Define %dx%dx%d grid ' %(self.npts[0],self.npts[1],self.npts[2]),self.time) - printif('-- Resolution of %1.2fx%1.2fx%1.2f Angs' %(self.res[0],self.res[1],self.res[2]),self.time) - + printif('-- Define %dx%dx%d grid ' % + (self.npts[0], self.npts[1], self.npts[2]), self.time) + printif('-- Resolution of %1.2fx%1.2fx%1.2f Angs' % + (self.res[0], self.res[1], self.res[2]), self.time) - halfdim = 0.5*(self.npts*self.res) + halfdim = 0.5 * (self.npts * self.res) center = self.center_contact - low_lim = center-halfdim - hgh_lim = low_lim + self.res*(np.array(self.npts)-1) - - self.x = np.linspace(low_lim[0],hgh_lim[0],self.npts[0]) - self.y = np.linspace(low_lim[1],hgh_lim[1],self.npts[1]) - self.z = np.linspace(low_lim[2],hgh_lim[2],self.npts[2]) + low_lim = center - halfdim + hgh_lim = low_lim + self.res * (np.array(self.npts) - 1) + self.x = np.linspace(low_lim[0], hgh_lim[0], self.npts[0]) + self.y = np.linspace(low_lim[1], hgh_lim[1], self.npts[1]) + self.z = np.linspace(low_lim[2], hgh_lim[2], self.npts[2]) # there is something fishy about the meshgrid 3d # the axis are a bit screwy .... # i dont quite get why the ordering is like that - self.ygrid,self.xgrid,self.zgrid = np.meshgrid(self.y,self.x,self.z) + self.ygrid, self.xgrid, self.zgrid = np.meshgrid( + self.y, self.x, self.z) ################################################################ # Atomic densities @@ -313,8 +327,8 @@ def define_grid_points(self): ################################################################ # compute all the atomic densities data - def map_atomic_densities(self,only_contact=True): - """Map the atomic densities to the grid + def map_atomic_densities(self, only_contact=True): + """Map the atomic densities to the grid. 
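On the "fishy" meshgrid ordering in define_grid_points(): with numpy's default indexing='xy' the first two inputs are swapped in the output shape, so unpacking np.meshgrid(y, x, z) as (ygrid, xgrid, zgrid) is equivalent to the explicit 'ij' form. A short self-contained check (the axis lengths are arbitrary):

import numpy as np

x, y, z = np.linspace(0., 29., 30), np.linspace(0., 19., 20), np.linspace(0., 9., 10)

ygrid, xgrid, zgrid = np.meshgrid(y, x, z)         # as written in define_grid_points
xg, yg, zg = np.meshgrid(x, y, z, indexing='ij')   # equivalent, explicit ordering

assert xgrid.shape == (len(x), len(y), len(z))
assert all(np.array_equal(a, b) for a, b in [(xgrid, xg), (ygrid, yg), (zgrid, zg)])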
Args: only_contact (bool, optional): Map only the contact atoms @@ -323,58 +337,79 @@ def map_atomic_densities(self,only_contact=True): ImportError: Description """ mode = self.atomic_densities_mode - printif('-- Map atomic densities on %dx%dx%d grid (mode=%s)'%(self.npts[0],self.npts[1],self.npts[2],mode),self.time) + printif('-- Map atomic densities on %dx%dx%d grid (mode=%s)' % + (self.npts[0], self.npts[1], self.npts[2], mode), self.time) # prepare the cuda memory - if self.cuda: # pragma: no cover + if self.cuda: # pragma: no cover # try to import pycuda try: from pycuda import driver, compiler, gpuarray, tools import pycuda.autoinit - except: + except BaseException: raise ImportError("Error when importing pyCuda in GridTools") # book mem on the gpu x_gpu = gpuarray.to_gpu(self.x.astype(np.float32)) y_gpu = gpuarray.to_gpu(self.y.astype(np.float32)) z_gpu = gpuarray.to_gpu(self.z.astype(np.float32)) - grid_gpu = gpuarray.zeros(self.npts,np.float32) + grid_gpu = gpuarray.zeros(self.npts, np.float32) # get the contact atoms if only_contact: index = self.sqldb.get_contact_atoms() # loop over all the data we want - for atomtype,vdw_rad in self.local_tqdm(self.atomic_densities.items()): + for atomtype, vdw_rad in self.local_tqdm( + self.atomic_densities.items()): t0 = time() # get the contact atom that of the correct type on both chains if only_contact: #index = self.sqldb.get_contact_atoms() - xyzA = np.array(self.sqldb.get('x,y,z',rowID=index[0],name=atomtype)) - xyzB = np.array(self.sqldb.get('x,y,z',rowID=index[1],name=atomtype)) + xyzA = np.array( + self.sqldb.get( + 'x,y,z', + rowID=index[0], + name=atomtype)) + xyzB = np.array( + self.sqldb.get( + 'x,y,z', + rowID=index[1], + name=atomtype)) else: # get the atom that are of the correct type on both chains - xyzA = np.array(self.sqldb.get('x,y,z',chainID='A',name=atomtype)) - xyzB = np.array(self.sqldb.get('x,y,z',chainID='B',name=atomtype)) - - tprocess = time()-t0 + xyzA = np.array( + self.sqldb.get( + 'x,y,z', + chainID='A', + name=atomtype)) + xyzB = np.array( + self.sqldb.get( + 'x,y,z', + chainID='B', + name=atomtype)) + + tprocess = time() - t0 t0 = time() # if we use CUDA - if self.cuda: # pragma: no cover + if self.cuda: # pragma: no cover # reset the grid grid_gpu *= 0 # get the atomic densities of chain A for pos in xyzA: - x0,y0,z0 = pos.astype(np.float32) + x0, y0, z0 = pos.astype(np.float32) vdw = np.float32(vdw_rad) - self.cuda_atomic(vdw,x0,y0,z0,x_gpu,y_gpu,z_gpu,grid_gpu,block=tuple(self.gpu_block),grid=tuple(self.gpu_grid)) + self.cuda_atomic( + vdw, x0, y0, z0, x_gpu, y_gpu, z_gpu, grid_gpu, block=tuple( + self.gpu_block), grid=tuple( + self.gpu_grid)) atdensA = grid_gpu.get() # reset the grid @@ -382,9 +417,12 @@ def map_atomic_densities(self,only_contact=True): # get the atomic densities of chain B for pos in xyzB: - x0,y0,z0 = pos.astype(np.float32) + x0, y0, z0 = pos.astype(np.float32) vdw = np.float32(vdw_rad) - self.cuda_atomic(vdw,x0,y0,z0,x_gpu,y_gpu,z_gpu,grid_gpu,block=tuple(self.gpu_block),grid=tuple(self.gpu_grid)) + self.cuda_atomic( + vdw, x0, y0, z0, x_gpu, y_gpu, z_gpu, grid_gpu, block=tuple( + self.gpu_block), grid=tuple( + self.gpu_grid)) atdensB = grid_gpu.get() # if we don't use CUDA @@ -396,37 +434,35 @@ def map_atomic_densities(self,only_contact=True): # run on the atoms for pos in xyzA: - atdensA += self.densgrid(pos,vdw_rad) + atdensA += self.densgrid(pos, vdw_rad) # run on the atoms for pos in xyzB: - atdensB += self.densgrid(pos,vdw_rad) + atdensB += self.densgrid(pos, vdw_rad) # create 
the final grid : A - B - if mode=='diff': - self.atdens[atomtype] = atdensA-atdensB + if mode == 'diff': + self.atdens[atomtype] = atdensA - atdensB # create the final grid : A + B - elif mode=='sum': - self.atdens[atomtype] = atdensA+atdensB + elif mode == 'sum': + self.atdens[atomtype] = atdensA + atdensB # create the final grid : A and B - elif mode=='ind': - self.atdens[atomtype+'_chainA'] = atdensA - self.atdens[atomtype+'_chainB'] = atdensB + elif mode == 'ind': + self.atdens[atomtype + '_chainA'] = atdensA + self.atdens[atomtype + '_chainB'] = atdensB else: - print('Error: Atomic density mode %s not recognized' %mode) + print('Error: Atomic density mode %s not recognized' % mode) sys.exit() - tgrid = time()-t0 - printif(' Process time %f ms' %(tprocess*1000),self.time) - printif(' Grid time %f ms' %(tgrid*1000),self.time) + tgrid = time() - t0 + printif(' Process time %f ms' % (tprocess * 1000), self.time) + printif(' Grid time %f ms' % (tgrid * 1000), self.time) # compute the atomic denisties on the grid - def densgrid(self,center,vdw_radius): - - ''' Function to map individual atomic density on the grid. - + def densgrid(self, center, vdw_radius): + """Function to map individual atomic density on the grid. The formula is equation (1) of the Koes paper Protein-Ligand Scoring with Convolutional NN Arxiv:1612.02751v1 @@ -437,13 +473,16 @@ def densgrid(self,center,vdw_radius): Returns: TYPE: np.array (mapped density) - ''' + """ - x0,y0,z0 = center - dd = np.sqrt( (self.xgrid-x0)**2 + (self.ygrid-y0)**2 + (self.zgrid-z0)**2 ) + x0, y0, z0 = center + dd = np.sqrt((self.xgrid - x0)**2 + (self.ygrid - y0) + ** 2 + (self.zgrid - z0)**2) dgrid = np.zeros(self.npts) - dgrid[dd<vdw_radius] = np.exp(-2*dd[dd<vdw_radius]**2/vdw_radius**2) - dgrid[ (dd >=vdw_radius) & (dd<1.5*vdw_radius)] = 4./np.e**2/vdw_radius**2*dd[ (dd >=vdw_radius) & (dd<1.5*vdw_radius)]**2 - 12./np.e**2/vdw_radius*dd[ (dd >=vdw_radius) & (dd<1.5*vdw_radius)] + 9./np.e**2 + dgrid[dd < vdw_radius] = np.exp(-2 * + dd[dd < vdw_radius]**2 / vdw_radius**2) + dgrid[(dd >= vdw_radius) & (dd < 1.5 * vdw_radius)] = 4. / np.e**2 / vdw_radius**2 * dd[(dd >= vdw_radius) & (dd < + 1.5 * vdw_radius)]**2 - 12. / np.e**2 / vdw_radius * dd[(dd >= vdw_radius) & (dd < 1.5 * vdw_radius)] + 9. / np.e**2 return dgrid ################################################################ @@ -454,8 +493,7 @@ def densgrid(self,center,vdw_radius): # map residue a feature on the grid def map_features(self, featlist, transform=None): - - '''Map individual feature to the grid.
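For readability, the piecewise density evaluated by densgrid (equation (1) of Koes et al., arXiv:1612.02751) can be restated as a standalone function; this is only an illustrative rewrite of the expressions above, not part of the patch:

    import numpy as np

    def atomic_density(d, vdw):
        """Gaussian core for d < vdw, quadratic tail up to 1.5*vdw, zero beyond."""
        rho = np.zeros_like(d, dtype=float)
        core = d < vdw
        tail = (d >= vdw) & (d < 1.5 * vdw)
        rho[core] = np.exp(-2. * d[core]**2 / vdw**2)
        rho[tail] = ((4. / np.e**2 / vdw**2) * d[tail]**2
                     - (12. / np.e**2 / vdw) * d[tail] + 9. / np.e**2)
        return rho

    # the two branches meet at d = vdw (value 1/e**2) and the tail vanishes at 1.5*vdw
    assert np.allclose(atomic_density(np.array([1.0, 1.5]), 1.0), [1. / np.e**2, 0.0])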
For residue based feature the feature file must be of the format chainID residue_name(3-letter) residue_number [values] @@ -473,41 +511,50 @@ def map_features(self, featlist, transform=None): Raises: ImportError: Description ValueError: Description - ''' + """ # declare the total dictionary dict_data = {} # prepare the cuda memory - if self.cuda: # pragma: no cover + if self.cuda: # pragma: no cover # try to import pycuda try: from pycuda import driver, compiler, gpuarray, tools import pycuda.autoinit - except: + except BaseException: raise ImportError("Error when importing pyCuda in GridTools") # book mem on the gpu x_gpu = gpuarray.to_gpu(self.x.astype(np.float32)) y_gpu = gpuarray.to_gpu(self.y.astype(np.float32)) z_gpu = gpuarray.to_gpu(self.z.astype(np.float32)) - grid_gpu = gpuarray.zeros(self.npts,np.float32) + grid_gpu = gpuarray.zeros(self.npts, np.float32) # loop over all the features required for feature_name in featlist: - - printif('-- Map %s on %dx%dx%d grid ' %(feature_name,self.npts[0],self.npts[1],self.npts[2]),self.time) + printif( + '-- Map %s on %dx%dx%d grid ' % + (feature_name, + self.npts[0], + self.npts[1], + self.npts[2]), + self.time) # read the data featgrp = self.molgrp['features'] if feature_name in featgrp.keys(): data = featgrp[feature_name][:] else: - print('Error Feature not found \n\tPossible features : ' + ' | '.join(featgrp.keys()) ) - raise ValueError('feature %s not found in the file' %(feature_name)) - + print( + 'Error Feature not found \n\tPossible features : ' + + ' | '.join( + featgrp.keys())) + raise ValueError( + 'feature %s not found in the file' % + (feature_name)) # detect if we have a xyz format # or a byte format @@ -516,15 +563,15 @@ def map_features(self, featlist, transform=None): # xyz : 4 (chain x y z) # byte - residue : 3 (chain resSeq resName) # byte - atomic : 4 (chain resSeq resName name) - if not isinstance(data[0],bytes): + if not isinstance(data[0], bytes): feature_type = 'xyz' ntext = 4 else: - try : + try: float(data[0].split()[3]) feature_type = 'residue' ntext = 3 - except: + except BaseException: feature_type = 'atomic' ntext = 4 @@ -533,12 +580,12 @@ def map_features(self, featlist, transform=None): # get the data on the first line if feature_type != 'xyz': data_test = data[0].split() - data_test = list(map(float,data_test[ntext:])) + data_test = list(map(float, data_test[ntext:])) else: - data_test = data[0,ntext:] + data_test = data[0, ntext:] # define the length of the output - if transform == None: + if transform is None: nFeat = len(data_test) elif callable(transform): nFeat = len(transform(data_test)) @@ -550,20 +597,23 @@ def map_features(self, featlist, transform=None): # that will in fine holds all the data if nFeat == 1: if self.feature_mode == 'ind': - dict_data[feature_name+'_chainA'] = np.zeros(self.npts) - dict_data[feature_name+'_chainB'] = np.zeros(self.npts) + dict_data[feature_name + '_chainA'] = np.zeros(self.npts) + dict_data[feature_name + '_chainB'] = np.zeros(self.npts) else: dict_data[feature_name] = np.zeros(self.npts) else: for iF in range(nFeat): if self.feature_mode == 'ind': - dict_data[feature_name+'_chainA_%03d' %iF] = np.zeros(self.npts) - dict_data[feature_name+'_chainB_%03d' %iF] = np.zeros(self.npts) + dict_data[feature_name + '_chainA_%03d' % + iF] = np.zeros(self.npts) + dict_data[feature_name + '_chainB_%03d' % + iF] = np.zeros(self.npts) else: - dict_data[feature_name+'_%03d' %iF] = np.zeros(self.npts) + dict_data[feature_name + '_%03d' % + iF] = np.zeros(self.npts) # rest the grid 
and get the x y z values - if self.cuda: # pragma: no cover + if self.cuda: # pragma: no cover grid_gpu *= 0 # timing @@ -577,7 +627,7 @@ def map_features(self, featlist, transform=None): # i.e chain x y z values if feature_type == 'xyz': - chain = ['A','B'][int(line[0])] + chain = ['A', 'B'][int(line[0])] pos = line[1:ntext] feat_values = np.array(line[ntext:]) @@ -589,7 +639,7 @@ def map_features(self, featlist, transform=None): line = line.decode('utf-8').split() # get the position of the resnumber - chain,resName,resNum = line[0],line[1],line[2] + chain, resName, resNum = line[0], line[1], line[2] # get the atom name for atomic data if feature_type == 'atomic': @@ -597,29 +647,52 @@ def map_features(self, featlist, transform=None): # get the position if feature_type == 'residue': - pos = np.mean(np.array(self.sqldb.get('x,y,z',chainID=chain,resSeq=resNum)),0) - sql_resName = list(set(self.sqldb.get('resName',chainID=chain,resSeq=resNum))) + pos = np.mean( + np.array( + self.sqldb.get( + 'x,y,z', + chainID=chain, + resSeq=resNum)), + 0) + sql_resName = list( + set(self.sqldb.get('resName', chainID=chain, resSeq=resNum))) else: - pos = np.array(self.sqldb.get('x,y,z',chainID=chain,resSeq=resNum,name=atName))[0] - sql_resName = list(set(self.sqldb.get('resName',chainID=chain,resSeq=resNum,name=atName))) + pos = np.array( + self.sqldb.get( + 'x,y,z', + chainID=chain, + resSeq=resNum, + name=atName))[0] + sql_resName = list( + set(self.sqldb.get('resName', chainID=chain, resSeq=resNum, name=atName))) # check if the resname correspond if len(sql_resName) == 0: print('Error : SQL query returned empty list') - print('Tip : Make sure the parameter file %s' %(feature_file)) - print('Tip : corresponds to the pdb file %s' %(self.sqldb.pdbfile)) + print( + 'Tip : Make sure the parameter file %s' % + (feature_file)) + print( + 'Tip : corresponds to the pdb file %s' % + (self.sqldb.pdbfile)) sys.exit() else: sql_resName = sql_resName[0] if resName != sql_resName: - print('Residue Name Error in the Feature file %s' %(feature_file)) - print('Feature File : chain %s resNum %s resName %s' %(chain,resNum, resName)) - print('SQL data : chain %s resNum %s resName %s' %(chain,resNum, sql_resName)) + print( + 'Residue Name Error in the Feature file %s' % + (feature_file)) + print( + 'Feature File : chain %s resNum %s resName %s' % + (chain, resNum, resName)) + print( + 'SQL data : chain %s resNum %s resName %s' % + (chain, resNum, sql_resName)) sys.exit() # get the values of the feature(s) for thsi residue - feat_values = np.array(list(map(float,line[ntext:]))) + feat_values = np.array(list(map(float, line[ntext:]))) # postporcess the data if callable(transform): @@ -628,46 +701,51 @@ def map_features(self, featlist, transform=None): # handle the mode fname = feature_name if self.feature_mode == "diff": - coeff = {'A':1,'B':-1}[chain] + coeff = {'A': 1, 'B': -1}[chain] else: coeff = 1 if self.feature_mode == "ind": fname = feature_name + "_chain" + chain - tprocess += time()-t0 + tprocess += time() - t0 t0 = time() # map this feature(s) on the grid(s) if not self.cuda: if nFeat == 1: - dict_data[fname] += coeff*self.featgrid(pos,feat_values) + dict_data[fname] += coeff * \ + self.featgrid(pos, feat_values) else: for iF in range(nFeat): - dict_data[fname+'_%03d' %iF] += coeff*self.featgrid(pos,feat_values[iF]) + dict_data[fname + '_%03d' % iF] += coeff * \ + self.featgrid(pos, feat_values[iF]) # try to use cuda to speed it up - else: # pragma: no cover + else: # pragma: no cover if nFeat == 1: - x0,y0,z0 = 
pos.astype(np.float32) - alpha = np.float32(coeff*feat_values) - self.cuda_func(alpha,x0,y0,z0,x_gpu,y_gpu,z_gpu,grid_gpu,block=tuple(self.gpu_block),grid=tuple(self.gpu_grid)) + x0, y0, z0 = pos.astype(np.float32) + alpha = np.float32(coeff * feat_values) + self.cuda_func( + alpha, x0, y0, z0, x_gpu, y_gpu, z_gpu, grid_gpu, block=tuple( + self.gpu_block), grid=tuple( + self.gpu_grid)) else: - raise ValueError('CUDA only possible for single-valued features so far') + raise ValueError( + 'CUDA only possible for single-valued features so far') - tgrid += time()-t0 + tgrid += time() - t0 - if self.cuda: # pragma: no cover + if self.cuda: # pragma: no cover dict_data[fname] = grid_gpu.get() driver.Context.synchronize() - - printif(' Process time %f ms' %(tprocess*1000),self.time) - printif(' Grid time %f ms' %(tgrid*1000),self.time) + printif(' Process time %f ms' % (tprocess * 1000), self.time) + printif(' Grid time %f ms' % (tgrid * 1000), self.time) return dict_data # compute the a given feature on the grid - def featgrid(self,center,value,type_='fast_gaussian'): - '''Map an individual feature (atomic or residue) on the grid + def featgrid(self, center, value, type_='fast_gaussian'): + """Map an individual feature (atomic or residue) on the grid. Args: center (list(float)): position of the feature center @@ -679,43 +757,52 @@ def featgrid(self,center,value,type_='fast_gaussian'): Raises: ValueError: Description - ''' + """ # shortcut for th center - x0,y0,z0 = center - sigma = np.sqrt(1./2) - beta = 0.5/(sigma**2) + x0, y0, z0 = center + sigma = np.sqrt(1. / 2) + beta = 0.5 / (sigma**2) # simple Gaussian if type_ == 'gaussian': - dd = np.sqrt( (self.xgrid-x0)**2 + (self.ygrid-y0)**2 + (self.zgrid-z0)**2 ) - dd = value*np.exp(-beta*dd) + dd = np.sqrt((self.xgrid - x0)**2 + (self.ygrid - y0) + ** 2 + (self.zgrid - z0)**2) + dd = value * np.exp(-beta * dd) return dd # fast gaussian elif type_ == 'fast_gaussian': - cutoff = 5.*beta + cutoff = 5. 
* beta - dd = np.sqrt( (self.xgrid-x0)**2 + (self.ygrid-y0)**2 + (self.zgrid-z0)**2 ) + dd = np.sqrt((self.xgrid - x0)**2 + (self.ygrid - y0) + ** 2 + (self.zgrid - z0)**2) dgrid = np.zeros(self.npts) - dgrid[dd>> norm = NormalizeData('1ak4.hdf5') >>> norm.get() - """ self.fname = fname - self.parameters = {'features':{},'targets':{}} + self.parameters = {'features': {}, 'targets': {}} self.shape = shape self.fexport = os.path.splitext(self.fname)[0] + '_norm.pckl' self.skip_feature = [] self.skip_target = [] - - def get(self): """Get the normalization and write them to file.""" @@ -44,31 +44,29 @@ def get(self): self._process_data() self._export_data() - def _load(self): """Load data from already existing normalization file.""" if os.path.isfile(self.fexport): - f = open(self.fexport,'rb') + f = open(self.fexport, 'rb') self.parameters = pickle.load(f) f.close() - for _,feat_name in self.parameters['features'].items(): - for name,_ in feat_name.items(): + for _, feat_name in self.parameters['features'].items(): + for name, _ in feat_name.items(): self.skip_feature.append(name) for target in self.parameters['targets'].keys(): self.skip_target.append(target) - def _extract_shape(self): """Get the shape of the data in the hdf5 file.""" if self.shape is not None: return - f5 = h5py.File(self.fname,'r') + f5 = h5py.File(self.fname, 'r') mol = list(f5.keys())[0] mol_data = f5.get(mol) @@ -77,26 +75,27 @@ def _extract_shape(self): nx = mol_data['grid_points']['x'].shape[0] ny = mol_data['grid_points']['y'].shape[0] nz = mol_data['grid_points']['z'].shape[0] - self.shape=(nx,ny,nz) + self.shape = (nx, ny, nz) else: - raise ValueError('Impossible to determine sparse grid shape.\\n Specify argument grid_shape=(x,y,z)') + raise ValueError( + 'Impossible to determine sparse grid shape.\\n Specify argument grid_shape=(x,y,z)') def _extract_data(self): """Extract the data from the different maps.""" - f5 = h5py.File(self.fname,'r') + f5 = h5py.File(self.fname, 'r') mol_names = list(f5.keys()) self.nmol = len(mol_names) # loop over the molecules for mol in mol_names: - #get the mapped features group - data_group = f5.get(mol+'/mapped_features/') + # get the mapped features group + data_group = f5.get(mol + '/mapped_features/') # loop over all the feature types - for feat_types,feat_names in data_group.items(): + for feat_types, feat_names in data_group.items(): # if feature type not in param add if feat_types not in self.parameters['features']: @@ -111,10 +110,11 @@ def _extract_data(self): # create the param if it doesn't already exists if name not in self.parameters['features'][feat_types]: - self.parameters['features'][feat_types][name] = NormParam() + self.parameters['features'][feat_types][name] = NormParam( + ) # load the matrix - feat_data = data_group[feat_types+'/'+name] + feat_data = data_group[feat_types + '/' + name] if feat_data.attrs['sparse']: mat = sparse.FLANgrid(sparse=True, index=feat_data['index'][:], @@ -124,13 +124,14 @@ def _extract_data(self): mat = feat_data['value'][:] # add the parameter (mean and var) - self.parameters['features'][feat_types][name].add(np.mean(mat),np.var(mat)) + self.parameters['features'][feat_types][name].add( + np.mean(mat), np.var(mat)) # get the target groups - target_group = f5.get(mol+'/targets') + target_group = f5.get(mol + '/targets') # loop over all the targets - for tname,tval in target_group.items(): + for tname, tval in target_group.items(): # we skip the already computed target if tname in self.skip_target: @@ -147,15 +148,16 @@ def 
_extract_data(self): def _process_data(self): """Compute the standard deviation of the data.""" - for feat_types,feat_dict in self.parameters['features'].items(): + for feat_types, feat_dict in self.parameters['features'].items(): for feat in feat_dict: - self.parameters['features'][feat_types][feat].process(self.nmol) + self.parameters['features'][feat_types][feat].process( + self.nmol) def _export_data(self): """Pickle the data to file.""" - f = open(self.fexport,'wb') - pickle.dump(self.parameters,f) + f = open(self.fexport, 'wb') + pickle.dump(self.parameters, f) f.close() @@ -184,18 +186,19 @@ def __init__(self, std=0, mean=0, var=0, sqmean=0): self.var = var self.sqmean = sqmean - def add(self,mean,var): - """ Add the mean value, sqmean and variance of a new molecule to the corresponding attributes.""" + def add(self, mean, var): + """Add the mean value, sqmean and variance of a new molecule to the + corresponding attributes.""" self.mean += mean self.sqmean += mean**2 self.var += var - def process(self,n): + def process(self, n): """Compute the standard deviation of the ensemble.""" # normalize the mean and var - self.mean /= n - self.var /= n + self.mean /= n + self.var /= n self.sqmean /= n # get the std @@ -204,6 +207,7 @@ def process(self,n): self.std -= self.mean**2 self.std = np.sqrt(self.std) + class MinMaxParam(object): """Compute the min/max of an ensenble of data. @@ -213,18 +217,17 @@ class MinMaxParam(object): Args: minv (float, optional): minimal value maxv (float, optional): maximal value - """ - def __init__(self,minv=None,maxv=None): + def __init__(self, minv=None, maxv=None): self.min = minv self.max = maxv - def update(self,val): + def update(self, val): if self.min is None: self.min = val self.max = val else: - self.min = min(self.min,val) - self.max = max(self.max,val) + self.min = min(self.min, val) + self.max = max(self.max, val) diff --git a/deeprank/generate/__init__.py b/deeprank/generate/__init__.py index bad83c15..e927ce96 100644 --- a/deeprank/generate/__init__.py +++ b/deeprank/generate/__init__.py @@ -1,3 +1,3 @@ from .DataGenerator import DataGenerator from .GridTools import GridTools -from .NormalizeData import NormalizeData, NormParam,MinMaxParam \ No newline at end of file +from .NormalizeData import MinMaxParam, NormalizeData, NormParam diff --git a/deeprank/generate/settings.py b/deeprank/generate/settings.py index 76fd4180..63fd55d4 100644 --- a/deeprank/generate/settings.py +++ b/deeprank/generate/settings.py @@ -1,2 +1,2 @@ def init(): - global __PATH_PSSM_FILES__ \ No newline at end of file + global __PATH_PSSM_FILES__ diff --git a/deeprank/learn/DataSet.py b/deeprank/learn/DataSet.py index ec1b13f0..592a4c6c 100644 --- a/deeprank/learn/DataSet.py +++ b/deeprank/learn/DataSet.py @@ -1,41 +1,38 @@ -import os import glob -import sys -import time -import h5py +import os import pickle import re - +import sys +import time from functools import partial -from torch import FloatTensor - +import h5py import numpy as np - -from deeprank.generate import NormParam, MinMaxParam, NormalizeData -from deeprank.tools import sparse, pdb2sql from tqdm import tqdm +from deeprank.generate import MinMaxParam, NormalizeData, NormParam +from deeprank.tools import pdb2sql, sparse +from torch import FloatTensor + # import torch.utils.data as data_utils # The class used to subclass data_utils.Dataset # but that conflict with Sphinx that couldn't build the API # It's apparently not necessary though and works without subclassing -class DataSet(): - - def 
__init__(self,train_database, valid_database=None, test_database = None, - mapfly = True, grid_info = None, - use_rotation = None, - select_feature = 'all', select_target = 'DOCKQ', - normalize_features = True, normalize_targets = True, - target_ordering = None, - dict_filter = None, pair_chain_feature = None, - transform_to_2D = False, projection = 0, - grid_shape = None, - clip_features = True, clip_factor = 1.5, - tqdm = False, process = True): +class DataSet(): + def __init__(self, train_database, valid_database=None, test_database=None, + mapfly=True, grid_info=None, + use_rotation=None, + select_feature='all', select_target='DOCKQ', + normalize_features=True, normalize_targets=True, + target_ordering=None, + dict_filter=None, pair_chain_feature=None, + transform_to_2D=False, projection=0, + grid_shape=None, + clip_features=True, clip_factor=1.5, + tqdm=False, process=True): '''Generates the dataset needed for pytorch. This class hanldes the data generated by deeprank.generate to be used in the deep learning part of DeepRank. To create an instance you must provide quite a few arguments. @@ -128,7 +125,7 @@ def __init__(self,train_database, valid_database=None, test_database = None, # features/targets selection self.select_feature = select_feature - self.select_target = select_target + self.select_target = select_target # map generation self.mapfly = mapfly @@ -175,7 +172,7 @@ def __init__(self,train_database, valid_database=None, test_database = None, self._get_target_ordering(target_ordering) # print the progress bar or not - self.tqdm=tqdm + self.tqdm = tqdm # process the data if process: @@ -185,48 +182,48 @@ def __init__(self,train_database, valid_database=None, test_database = None, def _get_database_name(database): if database is not None: - if not isinstance(database,list): + if not isinstance(database, list): database = [database] filenames = [] for db in database: filenames += glob.glob(db) - else : + else: filenames = None return filenames - def process_dataset(self): """Process the data set. - Done by default. However must be turned off when one want to test a pretrained model. This can be done - by setting ``process=False`` in the creation of the ``DataSet`` instance. + + Done by default. However must be turned off when one want to + test a pretrained model. This can be done by setting + ``process=False`` in the creation of the ``DataSet`` instance. 
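To illustrate the process flag mentioned above: when reloading a pretrained model the heavy processing is skipped at construction time and triggered later. A minimal sketch with placeholder file names:

    # '1ak4.hdf5' is a placeholder, as in the examples used elsewhere in the code
    data_set = DataSet('1ak4.hdf5',
                       select_feature='all',
                       select_target='DOCKQ',
                       normalize_features=True,
                       normalize_targets=True,
                       process=False)      # defer checking/indexing/normalization

    # ... later, once grid_shape / select_feature are restored from the saved model:
    data_set.process_dataset()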
""" print('\n') - print('='*40) + print('=' * 40) print('=\t DeepRank Data Set') print('=') - print('=\t Training data' ) + print('=\t Training data') for f in self.train_database: - print('=\t ->',f) + print('=\t ->', f) print('=') if self.valid_database is not None: - print('=\t Validation data' ) + print('=\t Validation data') for f in self.valid_database: - print('=\t ->',f) + print('=\t ->', f) print('=') if self.test_database is not None: - print('=\t Test data' ) + print('=\t Test data') for f in self.test_database: - print('=\t ->',f) + print('=\t ->', f) print('=') - print('='*40,'\n') + print('=' * 40, '\n') sys.stdout.flush() - # check if the files are ok self.check_hdf5_files(self.train_database) @@ -236,7 +233,6 @@ def process_dataset(self): if self.test_database is not None: self.test_database = self.check_hdf5_files(self.test_database) - # create the indexing system # alows to associate each mol to an index # and get fname and mol name from the index @@ -258,7 +254,7 @@ def process_dataset(self): self.get_input_shape() # get the target ordering - #self._get_target_ordering() + # self._get_target_ordering() # get renormalization factor if self.normalize_features or self.normalize_targets: @@ -267,21 +263,21 @@ def process_dataset(self): else: self.get_norm() - print('\n') print(" Data Set Info") - print(' Training set : %d conformations' %self.ntrain) + print(' Training set : %d conformations' % self.ntrain) if self.data_augmentation is not None: - print(' Augmentation : %d rotations' %self.data_augmentation) - - print(' Validation set : %d conformations' %self.nvalid) - print(' Test set : %d conformations' %(self.ntest)) - print(' Number of channels : %d' %self.input_shape[0]) - print(' Grid Size : %d x %d x %d' %(self.data_shape[1],self.data_shape[2],self.data_shape[3])) + print( + ' Augmentation : %d rotations' % + self.data_augmentation) + + print(' Validation set : %d conformations' % self.nvalid) + print(' Test set : %d conformations' % (self.ntest)) + print(' Number of channels : %d' % self.input_shape[0]) + print(' Grid Size : %d x %d x %d' % + (self.data_shape[1], self.data_shape[2], self.data_shape[3])) sys.stdout.flush() - - def __len__(self): """Get the length of the dataset Returns: @@ -289,9 +285,9 @@ def __len__(self): """ return len(self.index_complexes) - - def __getitem__(self,index): + def __getitem__(self, index): """Get one item from its unique index. 
+ Args: index (int): index of the complex Returns: @@ -300,13 +296,13 @@ def __getitem__(self,index): debug_time = False t0 = time.time() - fname,mol,angle,axis = self.index_complexes[index] + fname, mol, angle, axis = self.index_complexes[index] t0 = time.time() if self.mapfly: - feature, target = self.map_one_molecule(fname,mol,angle,axis) + feature, target = self.map_one_molecule(fname, mol, angle, axis) else: - feature, target = self.load_one_molecule(fname,mol) + feature, target = self.load_one_molecule(fname, mol) if self.clip_features: feature = self._clip_feature(feature) @@ -318,13 +314,12 @@ def __getitem__(self,index): target = self._normalize_target(target) if self.pair_chain_feature: - feature = self.make_feature_pair(feature,self.pair_chain_feature) + feature = self.make_feature_pair(feature, self.pair_chain_feature) if self.transform: - feature = self.convert2d(feature,self.proj2D) - - return {'mol':[fname,mol],'feature':feature,'target':target} + feature = self.convert2d(feature, self.proj2D) + return {'mol': [fname, mol], 'feature': feature, 'target': target} @staticmethod def check_hdf5_files(database): @@ -334,33 +329,34 @@ def check_hdf5_files(database): remove_file = [] for fname in database: try: - f = h5py.File(fname,'r') + f = h5py.File(fname, 'r') mol_names = list(f.keys()) if len(mol_names) == 0: - print(' -> %s is empty ' %fname) + print(' -> %s is empty ' % fname) remove_file.append(fname) f.close() - except: - print(' -> %s is corrputed ' %fname) + except BaseException: + print(' -> %s is corrputed ' % fname) remove_file.append(fname) for name in remove_file: database.remove(name) return database - def create_index_molecules(self): - '''Create the indexing of each molecule in the dataset. - Create the indexing: [ ('1ak4.hdf5,1AK4_100w),...,('1fqj.hdf5,1FGJ_400w)] - This allows to refer to one complex with its index in the list - ''' + """Create the indexing of each molecule in the dataset. 
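Because __getitem__ returns a plain dict of numpy arrays, the class behaves as a map-style dataset for a standard torch DataLoader even without subclassing (as the comment at the top of the file notes). A minimal sketch, assuming a data_set built as above:

    from torch.utils.data import DataLoader

    loader = DataLoader(data_set, batch_size=8, shuffle=True)
    for batch in loader:
        names = batch['mol']                        # [file names, complex names]
        x, y = batch['feature'], batch['target']    # (8, C, nx, ny, nz), (8, 1)
        break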
+ + Create the indexing: [ + ('1ak4.hdf5,1AK4_100w),...,('1fqj.hdf5,1FGJ_400w)] This allows + to refer to one complex with its index in the list + """ print(" Processing data set") self.index_complexes = [] desc = '{:25s}'.format(' Train dataset') if self.tqdm: - data_tqdm = tqdm(self.train_database,desc=desc,file=sys.stdout) + data_tqdm = tqdm(self.train_database, desc=desc, file=sys.stdout) else: print(' Train dataset') data_tqdm = self.train_database @@ -370,17 +366,17 @@ def create_index_molecules(self): if self.tqdm: data_tqdm.set_postfix(mol=os.path.basename(fdata)) try: - fh5 = h5py.File(fdata,'r') + fh5 = h5py.File(fdata, 'r') mol_names = list(fh5.keys()) mol_names = self._select_pdb(mol_names) for k in mol_names: if self.filter(fh5[k]): #print (f"\tmol {k} passed {self.dict_filter}") - self.index_complexes += [(fdata,k,None,None)] + self.index_complexes += [(fdata, k, None, None)] for irot in range(self.data_augmentation): axis, angle = self._get_aug_rot() - self.index_complexes += [(fdata,k,angle,axis)] + self.index_complexes += [(fdata, k, angle, axis)] fh5.close() except Exception as inst: @@ -394,7 +390,10 @@ def create_index_molecules(self): desc = '{:25s}'.format(' Validation dataset') if self.tqdm: - data_tqdm = tqdm(self.valid_database,desc=desc,file=sys.stdout) + data_tqdm = tqdm( + self.valid_database, + desc=desc, + file=sys.stdout) else: data_tqdm = self.valid_database print(' Validation dataset') @@ -404,24 +403,27 @@ def create_index_molecules(self): if self.tqdm: data_tqdm.set_postfix(mol=os.path.basename(fdata)) try: - fh5 = h5py.File(fdata,'r') + fh5 = h5py.File(fdata, 'r') mol_names = list(fh5.keys()) mol_names = self._select_pdb(mol_names) - self.index_complexes += [(fdata,k,None,None) for k in mol_names] + self.index_complexes += [(fdata, k, None, None) + for k in mol_names] fh5.close() - except: - print('\t\t-->Ignore File : '+fdata) + except BaseException: + print('\t\t-->Ignore File : ' + fdata) self.ntot = len(self.index_complexes) - self.index_valid = list(range(self.ntrain,self.ntot)) + self.index_valid = list(range(self.ntrain, self.ntot)) self.nvalid = self.ntot - self.ntrain - if self.test_database is not None: desc = '{:25s}'.format(' Test dataset') if self.tqdm: - data_tqdm = tqdm(self.test_database,desc=desc,file=sys.stdout) + data_tqdm = tqdm( + self.test_database, + desc=desc, + file=sys.stdout) else: data_tqdm = self.test_database print(' Test dataset') @@ -431,24 +433,24 @@ def create_index_molecules(self): if self.tqdm: data_tqdm.set_postfix(mol=os.path.basename(fdata)) try: - fh5 = h5py.File(fdata,'r') + fh5 = h5py.File(fdata, 'r') mol_names = list(fh5.keys()) mol_names = self._select_pdb(mol_names) # that's the master #self.index_complexes += [(fdata,k) for k in mol_names] # thats what i had in issue25 - self.index_complexes += [(fdata,k,None,None) for k in mol_names] + self.index_complexes += [(fdata, k, None, None) + for k in mol_names] fh5.close() - except: - print('\t\t-->Ignore File : '+fdata) + except BaseException: + print('\t\t-->Ignore File : ' + fdata) self.ntot = len(self.index_complexes) - self.index_test = list(range(self.ntrain + self.nvalid ,self.ntot)) + self.index_test = list(range(self.ntrain + self.nvalid, self.ntot)) self.ntest = self.ntot - self.ntrain - self.nvalid - def _select_pdb(self, mol_names): - """Select complexes + """Select complexes. 
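The _select_pdb call used while indexing keeps the unrotated complexes and, when use_rotation is set, the first N augmented copies, identified by their _rXXX suffix. A quick standalone check of the two regular expressions it relies on (shown just below):

    import re

    names = ['1AK4_100w', '1AK4_100w_r001', '1AK4_100w_r002']
    originals = [n for n in names if not re.search(r'_r\d+$', n)]
    rot_001 = [n for n in names if re.search('_r%03d$' % 1, n)]
    # originals == ['1AK4_100w'], rot_001 == ['1AK4_100w_r001']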
Args: mol_names (list): list of complex names @@ -458,21 +460,31 @@ def _select_pdb(self, mol_names): """ if self.use_rotation is not None: - fnames_original = list(filter(lambda x: not re.search('_r\d+$',x), mol_names)) + fnames_original = list( + filter( + lambda x: not re.search( + r'_r\d+$', + x), + mol_names)) fnames_augmented = [] if self.use_rotation > 0: for i in range(self.use_rotation): -# fnames_augmented += list(filter(lambda x: '_r%03d' %(i+1) in x, mol_names)) - fnames_augmented += list(filter(lambda x: re.search('_r%03d$' %(i+1), x), mol_names)) + # fnames_augmented += list(filter(lambda x: '_r%03d' %(i+1) in x, mol_names)) + fnames_augmented += list( + filter( + lambda x: re.search( + '_r%03d$' % + (i + 1), x), mol_names)) mol_names = fnames_original + fnames_augmented else: mol_names = fnames_original return mol_names + def filter(self, molgrp): + """Filter the molecule according to a dictionary, e.g., + dict_filter={'DOCKQ':'>0.1', 'IRMSD':'<=4 or >10'}). - def filter(self,molgrp): - '''Filter the molecule according to a dictionary, e.g., dict_filter={'DOCKQ':'>0.1', 'IRMSD':'<=4 or >10'}). The filter is based on the attribute self.dict_filter that must be either of the form: { 'name' : cond } or None Args: @@ -481,24 +493,24 @@ def filter(self,molgrp): bool: True if we keep the complex False otherwise Raises: ValueError: If an unsuported condition is provided - ''' + """ if self.dict_filter is None: return True - for cond_name,cond_vals in self.dict_filter.items(): + for cond_name, cond_vals in self.dict_filter.items(): try: - val = molgrp['targets/'+cond_name][()] + val = molgrp['targets/' + cond_name][()] except KeyError: - print(' :Filter %s not found for mol %s' %(cond_name,mol)) + print(' :Filter %s not found for mol %s' % (cond_name, mol)) # if we have a string it's more complicated - if isinstance(cond_vals,str): + if isinstance(cond_vals, str): - ops = ['>','<','=='] + ops = ['>', '<', '=='] new_cond_vals = cond_vals for o in ops: - new_cond_vals = new_cond_vals.replace(o,'val'+o) + new_cond_vals = new_cond_vals.replace(o, 'val' + o) if not eval(new_cond_vals): return False else: @@ -507,7 +519,6 @@ def filter(self,molgrp): return True def get_mapped_feature_name(self): - ''' Create the dictionary with actual feature_type : [feature names] Add _chainA, _chainB to each feature names if we have individual storage @@ -516,10 +527,10 @@ def get_mapped_feature_name(self): ''' # open a h5 file in case we need it - f5 = h5py.File(self.train_database[0],'r') + f5 = h5py.File(self.train_database[0], 'r') mol_name = list(f5.keys())[0] mapped_data = f5.get(mol_name + '/mapped_features/') - chain_tags = ['_chainA','_chainB'] + chain_tags = ['_chainA', '_chainB'] # if we select all the features if self.select_feature == "all": @@ -528,19 +539,20 @@ def get_mapped_feature_name(self): self.select_feature = {} # loop over the feat types and add all the feat_names - for feat_type,feat_names in mapped_data.items(): + for feat_type, feat_names in mapped_data.items(): self.select_feature[feat_type] = [name for name in feat_names] # if a selection was made else: # we loop over the input dict - for feat_type,feat_names in self.select_feature.items(): + for feat_type, feat_names in self.select_feature.items(): # if for a given type we need all the feature if feat_names == 'all': if feat_type in mapped_data: - self.select_feature[feat_type] = list(mapped_data[feat_type].keys()) + self.select_feature[feat_type] = list( + mapped_data[feat_type].keys()) else: self.print_possible_features() 
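The string conditions accepted by dict_filter are rewritten around the target value and then passed to eval, as in the filter method above. A standalone sketch of that rewriting:

    val = 3.2                                    # e.g. molgrp['targets/IRMSD'][()]
    cond = '<=4 or >10'                          # condition string from dict_filter
    for op in ['>', '<', '==']:
        cond = cond.replace(op, 'val' + op)      # -> 'val<=4 or val>10'
    keep = eval(cond)                            # True -> the conformation is kept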
raise KeyError('Feature type %s not found') @@ -566,15 +578,18 @@ def get_mapped_feature_name(self): # we check the matches and add them if '*' in name: match = name.split('*')[0] - possible_names = list(mapped_data[feat_type].keys()) - match_names = [n for n in possible_names if n.startswith(match)] + possible_names = list( + mapped_data[feat_type].keys()) + match_names = [ + n for n in possible_names if n.startswith(match)] self.select_feature[feat_type] += match_names # if we don't have a wild card we append # _chainA and _chainB # to the list else: - self.select_feature[feat_type] += [name+tag for tag in chain_tags] + self.select_feature[feat_type] += [name + + tag for tag in chain_tags] # if there is a chain tag in the name # (we probably relaod a pretrained model) @@ -588,7 +603,6 @@ def get_mapped_feature_name(self): f5.close() def get_raw_feature_name(self): - ''' Create the dictionary with actual feature_type : [feature names] @@ -598,7 +612,7 @@ def get_raw_feature_name(self): ''' # open a h5 file in case we need it - f5 = h5py.File(self.train_database[0],'r') + f5 = h5py.File(self.train_database[0], 'r') mol_name = list(f5.keys())[0] mol_data = f5.get(mol_name) raw_data = f5.get(mol_name + '/features/') @@ -609,21 +623,25 @@ def get_raw_feature_name(self): # redefine dict self.select_feature = {} print("Select atomic densities for CA, C, N, O") - self.select_feature['AtomicDensities'] = {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52} - self.select_feature['Features'] = [name for name in raw_data.keys()] + self.select_feature['AtomicDensities'] = { + 'CA': 1.7, 'C': 1.7, 'N': 1.55, 'O': 1.52} + self.select_feature['Features'] = [ + name for name in raw_data.keys()] # if a selection was made else: # we loop over the input dict - for feat_type,feat_names in self.select_feature.items(): + for feat_type, feat_names in self.select_feature.items(): # if for a given type we need all the feature if feat_names == 'all': if feat_type == 'AtomicDensities': - self.select_feature['AtomicDensities'] = {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52} + self.select_feature['AtomicDensities'] = { + 'CA': 1.7, 'C': 1.7, 'N': 1.55, 'O': 1.52} elif feat_type in mol_data: - self.select_feature[feat_type] = list(mapped_data[feat_type].keys()) + self.select_feature[feat_type] = list( + mapped_data[feat_type].keys()) else: raise KeyError('Feature type %s not found') @@ -634,14 +652,16 @@ def get_raw_feature_name(self): # So then we shouldn't add the tags else: if feat_type == 'AtomicDensities': - assert isinstance(self.select_feature['AtomicDensities'],dict) + assert isinstance( + self.select_feature['AtomicDensities'], dict) else: self.select_feature[feat_type] = [] for name in feat_names: if '*' in name: match = name.split('*')[0] possible_names = list(raw_data.keys()) - match_names = [n for n in possible_names if n.startswith(match)] + match_names = [ + n for n in possible_names if n.startswith(match)] self.select_feature[feat_type] += match_names else: self.select_feature[feat_type] += [name] @@ -651,85 +671,84 @@ def get_raw_feature_name(self): def print_possible_features(self): """Print the possible features in the group.""" - f5 = h5py.File(self.train_database[0],'r') + f5 = h5py.File(self.train_database[0], 'r') mol_name = list(f5.keys())[0] mapgrp = f5.get(mol_name + '/mapped_features/') print('\nPossible Features:') - print('-'*20) + print('-' * 20) for feat_type in list(mapgrp.keys()): - print('== %s' %feat_type) + print('== %s' % feat_type) for fname in list(mapgrp[feat_type].keys()): - print(' -- %s' 
%fname) + print(' -- %s' % fname) if self.select_feature is not None: print('\nYour selection was:') - for feat_type,feat in self.select_feature.items(): + for feat_type, feat in self.select_feature.items(): if feat_type not in list(mapgrp.keys()): print('== \x1b[0;37;41m' + feat_type + '\x1b[0m') else: - print('== %s' %feat_type) - if isinstance(feat,str): - print(' -- %s' %feat) - if isinstance(feat,list): + print('== %s' % feat_type) + if isinstance(feat, str): + print(' -- %s' % feat) + if isinstance(feat, list): for f in feat: - print(' -- %s' %f) + print(' -- %s' % f) print("You don't need to specify _chainA _chainB for each feature. The code will append it automatically") def get_pairing_feature(self): - """Creates the index of paired features. - """ + """Creates the index of paired features.""" if self.pair_chain_feature: self.pair_indexes = [] start = 0 - for feat_type,feat_names in self.select_feature.items(): + for feat_type, feat_names in self.select_feature.items(): nfeat = len(feat_names) if '_ind' in feat_type: - self.pair_indexes += [ [i,i+1] for i in range(start,start+nfeat,2)] + self.pair_indexes += [[i, i + 1] + for i in range(start, start + nfeat, 2)] else: - self.pair_indexes += [ [i] for i in range(start,start+nfeat)] + self.pair_indexes += [[i] + for i in range(start, start + nfeat)] start += nfeat def get_input_shape(self): - """Get the size of the data and input. + Reminder : self.data_shape : shape of the raw 3d data set self.input_shape : input size of the CNN (potentially after 2d transformation) """ - fname = self.train_database[0] if self.mapfly: - feature,_ = self.map_one_molecule(fname) + feature, _ = self.map_one_molecule(fname) else: - feature,_ = self.load_one_molecule(fname) + feature, _ = self.load_one_molecule(fname) self.data_shape = feature.shape if self.pair_chain_feature: - feature = self.make_feature_pair(feature,self.pair_chain_feature) + feature = self.make_feature_pair(feature, self.pair_chain_feature) if self.transform: - feature = self.convert2d(feature,self.proj2D) + feature = self.convert2d(feature, self.proj2D) self.input_shape = feature.shape - def get_grid_shape(self): + """Get the shape of the matrices. - '''Get the shape of the matrices. 
Raises: ValueError: If no grid shape is provided or is present in the HDF5 file - ''' + """ if self.mapfly is False: fname = self.train_database[0] - fh5 = h5py.File(fname,'r') + fh5 = h5py.File(fname, 'r') mol = list(fh5.keys())[0] # get the mol @@ -742,10 +761,11 @@ def get_grid_shape(self): nx = mol_data['grid_points']['x'].shape[0] ny = mol_data['grid_points']['y'].shape[0] nz = mol_data['grid_points']['z'].shape[0] - self.grid_shape = (nx,ny,nz) + self.grid_shape = (nx, ny, nz) else: - raise ValueError('Impossible to determine sparse grid shape.\n Specify argument grid_shape=(x,y,z)') + raise ValueError( + 'Impossible to determine sparse grid shape.\n Specify argument grid_shape=(x,y,z)') fh5.close() @@ -753,14 +773,11 @@ def get_grid_shape(self): self.grid_shape = self.grid_info['number_of_points'] else: - raise Warning('Impossible to determine sparse grid shape.\nIf you are not loading a pretrained model, specify grid_shape or grid_info') - - - - + raise Warning( + 'Impossible to determine sparse grid shape.\nIf you are not loading a pretrained model, specify grid_shape or grid_info') def compute_norm(self): - """ compute the normalization factors.""" + """compute the normalization factors.""" print(" Normalization factor :") @@ -771,13 +788,13 @@ def compute_norm(self): # get the feature/target if self.mapfly: - feature,target = self.map_one_molecule(fname,mol=molname) + feature, target = self.map_one_molecule(fname, mol=molname) else: - feature,target = self.load_one_molecule(fname,mol=molname) + feature, target = self.load_one_molecule(fname, mol=molname) # create the norm isntances at the first passage if first: - self.param_norm = {'features':[],'targets':None} + self.param_norm = {'features': [], 'targets': None} for ifeat in range(feature.shape[0]): self.param_norm['features'].append(NormParam()) self.param_norm['targets'] = MinMaxParam() @@ -785,11 +802,14 @@ def compute_norm(self): # update the norm instances for ifeat, mat in enumerate(feature): - self.param_norm['features'][ifeat].add(np.mean(mat),np.var(mat)) + self.param_norm['features'][ifeat].add( + np.mean(mat), np.var(mat)) self.param_norm['targets'].update(target) # process the std of the features and make array for fast access - nfeat, ncomplex = len(self.param_norm['features']), len(self.index_complexes) + nfeat, ncomplex = len( + self.param_norm['features']), len( + self.index_complexes) self.feature_mean, self.feature_std = [], [] for ifeat in range(nfeat): @@ -806,18 +826,17 @@ def compute_norm(self): self.target_min = self.param_norm['targets'].min[0] self.target_max = self.param_norm['targets'].max[0] - print(self.target_min,self.target_max) + print(self.target_min, self.target_max) def get_norm(self): - """Get the normalization values for the features. 
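When mapping on the fly without precomputed grid points, grid_info supplies the geometry; get_grid (further down) reads the keys 'number_of_points' and 'resolution'. Illustrative values, not defaults:

    grid_info = {'number_of_points': [30, 30, 30],   # nx, ny, nz
                 'resolution': [1.0, 1.0, 1.0]}      # Angstrom per voxel
    data_set = DataSet('1ak4.hdf5', mapfly=True, grid_info=grid_info,
                       select_target='DOCKQ')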
- """ + """Get the normalization values for the features.""" print(" Normalization factor :") # declare the dict of class instance # where we'll store the normalization parameter - self.param_norm = {'features':{},'targets':{}} - for feat_type,feat_names in self.select_feature.items(): + self.param_norm = {'features': {}, 'targets': {}} + for feat_type, feat_names in self.select_feature.items(): self.param_norm['features'][feat_type] = {} for name in feat_names: self.param_norm['features'][feat_type][name] = NormParam() @@ -827,41 +846,42 @@ def get_norm(self): self._read_norm() # make array for fast access - self.feature_mean,self.feature_std = [],[] - for feat_type,feat_names in self.select_feature.items(): + self.feature_mean, self.feature_std = [], [] + for feat_type, feat_names in self.select_feature.items(): for name in feat_names: - self.feature_mean.append(self.param_norm['features'][feat_type][name].mean) - self.feature_std.append(self.param_norm['features'][feat_type][name].std) + self.feature_mean.append( + self.param_norm['features'][feat_type][name].mean) + self.feature_std.append( + self.param_norm['features'][feat_type][name].std) self.target_min = self.param_norm['targets'][self.select_target].min self.target_max = self.param_norm['targets'][self.select_target].max def _read_norm(self): - """Read or create the normalization file for the complex. - """ + """Read or create the normalization file for the complex.""" # loop through all the filename for f5 in self.train_database: # get the precalculated data - fdata = os.path.splitext(f5)[0]+'_norm.pckl' + fdata = os.path.splitext(f5)[0] + '_norm.pckl' # if the file doesn't exist we create it if not os.path.isfile(fdata): print(" Computing norm for ", f5) - norm = NormalizeData(f5,shape=self.grid_shape) + norm = NormalizeData(f5, shape=self.grid_shape) norm.get() # read the data - data = pickle.load(open(fdata,'rb')) + data = pickle.load(open(fdata, 'rb')) # handle the features - for feat_type,feat_names in self.select_feature.items(): + for feat_type, feat_names in self.select_feature.items(): for name in feat_names: mean = data['features'][feat_type][name].mean var = data['features'][feat_type][name].var if var == 0: - print(' : STD is null for %s in %s' %(name,f5)) - self.param_norm['features'][feat_type][name].add(mean,var) + print(' : STD is null for %s in %s' % (name, f5)) + self.param_norm['features'][feat_type][name].add(mean, var) # handle the target minv = data['targets'][self.select_target].min @@ -871,22 +891,25 @@ def _read_norm(self): # process the std nfile = len(self.train_database) - for feat_types,feat_dict in self.param_norm['features'].items(): + for feat_types, feat_dict in self.param_norm['features'].items(): for feat in feat_dict: self.param_norm['features'][feat_types][feat].process(nfile) if self.param_norm['features'][feat_types][feat].std == 0: - print(' Final STD Null for %s/%s. Changed it to 1' %(feat_types,feat)) + print( + ' Final STD Null for %s/%s. Changed it to 1' % + (feat_types, feat)) self.param_norm['features'][feat_types][feat].std = 1 def _get_target_ordering(self, order): """Determine if ordering of the target. 
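The per-feature standard deviation assembled through NormParam and _read_norm amounts to the law of total variance: average the per-file variances, add the average squared means, and subtract the squared average mean. A standalone numeric check (equal-sized chunks):

    import numpy as np

    chunks = [np.random.randn(100) * s + m for s, m in [(1.0, 0.0), (2.0, 3.0)]]
    means = [c.mean() for c in chunks]
    varis = [c.var() for c in chunks]

    pooled_var = np.mean(varis) + np.mean(np.square(means)) - np.mean(means)**2
    assert np.isclose(pooled_var, np.concatenate(chunks).var())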
- This can be lower the better or higher the better - If it can't determine the ordering 'lower' is assumed + + This can be lower the better or higher the better If it can't + determine the ordering 'lower' is assumed """ - lower_list = ['IRMSD','LRMSD','HADDOCK'] - higher_list = ['DOCKQ','Fnat'] - NA_list = ['binary_class','BIN_CLASS', 'class'] + lower_list = ['IRMSD', 'LRMSD', 'HADDOCK'] + higher_list = ['DOCKQ', 'Fnat'] + NA_list = ['binary_class', 'BIN_CLASS', 'class'] if order is not None: self.target_ordering = order @@ -901,22 +924,24 @@ def _get_target_ordering(self, order): print(' Target ordering unidentified. lower assumed') self.target_ordering = 'lower' - def backtransform_target(self,data): + def backtransform_target(self, data): """Returns the values of the target after de-normalization. + Args: data (list(float)): normalized data Returns: list(float): un-normalized data """ - #print(data) - #print(self.target_max) + # print(data) + # print(self.target_max) #data = FloatTensor(data) data *= self.target_max data += self.target_min - return data #.numpy() + return data # .numpy() - def _normalize_target(self,target): + def _normalize_target(self, target): """Normalize the values of the targets. + Args: target (list(float)): raw data Returns: @@ -927,8 +952,9 @@ def _normalize_target(self,target): target /= self.target_max return target - def _normalize_feature(self,feature): + def _normalize_feature(self, feature): """Normalize the values of the features. + Args: feature (np.array): raw feature values Returns: @@ -936,10 +962,11 @@ def _normalize_feature(self,feature): """ for ic in range(self.data_shape[0]): - feature[ic] = (feature[ic]-self.feature_mean[ic])/self.feature_std[ic] + feature[ic] = (feature[ic] - self.feature_mean[ic] + ) / self.feature_std[ic] return feature - def _clip_feature(self,feature): + def _clip_feature(self, feature): """Clip the value of the features at +/- mean + clip_factor * std. Args: feature (np.array): raw feature values @@ -949,15 +976,17 @@ def _clip_feature(self,feature): w = self.clip_factor for ic in range(self.data_shape[0]): - minv = self.feature_mean[ic] - w*self.feature_std[ic] - maxv = self.feature_mean[ic] + w*self.feature_std[ic] - feature[ic] = np.clip(feature[ic],minv,maxv) + minv = self.feature_mean[ic] - w * self.feature_std[ic] + maxv = self.feature_mean[ic] + w * self.feature_std[ic] + feature[ic] = np.clip(feature[ic], minv, maxv) #feature[ic] = self._mad_based_outliers(feature[ic],minv,maxv) return feature @staticmethod def _mad_based_outliers(points, minv, maxv, thresh=3.5): - """Mean absolute deviation based outlier detection. (Experimental). + """Mean absolute deviation based outlier detection. + + (Experimental). Args: points (np.array): raw input data minv (float): Minimum (negative) value requested @@ -977,26 +1006,27 @@ def _mad_based_outliers(points, minv, maxv, thresh=3.5): modified_z_score = 0.6745 * diff / med_abs_deviation mask_outliers = modified_z_score > thresh - mask_max = np.abs(points-maxv) < np.abs(points-minv) - mask_min = np.abs(points-maxv) > np.abs(points-minv) + mask_max = np.abs(points - maxv) < np.abs(points - minv) + mask_min = np.abs(points - maxv) > np.abs(points - minv) points[mask_max * mask_outliers] = maxv points[mask_min * mask_outliers] = minv return points - def load_one_molecule(self,fname,mol=None): - '''Load the feature/target of a single molecule. + def load_one_molecule(self, fname, mol=None): + """Load the feature/target of a single molecule. 
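The experimental _mad_based_outliers helper relies on the modified z-score, where the factor 0.6745 (the 0.75 quantile of the standard normal) makes the median absolute deviation comparable to a standard deviation for normal data. In compact form (an illustration, not the exact code above):

    import numpy as np

    def modified_z_score(points):
        med = np.median(points)
        mad = np.median(np.abs(points - med))   # median absolute deviation
        return 0.6745 * np.abs(points - med) / mad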
+ Args: fname (str): hdf5 file name mol (None or str, optional): name of the complex in the hdf5 Returns: np.array,float: features, targets - ''' + """ t0 = time.time() outtype = 'float32' - fh5 = h5py.File(fname,'r') + fh5 = h5py.File(fname, 'r') if mol is None: mol = list(fh5.keys())[0] @@ -1006,15 +1036,18 @@ def load_one_molecule(self,fname,mol=None): # get the features feature = [] - for feat_type,feat_names in self.select_feature.items(): + for feat_type, feat_names in self.select_feature.items(): # see if the feature exists - if 'mapped_features/'+feat_type in mol_data.keys(): - feat_dict = mol_data.get('mapped_features/'+feat_type) + if 'mapped_features/' + feat_type in mol_data.keys(): + feat_dict = mol_data.get('mapped_features/' + feat_type) else: - print('Feature type %s not found in file %s for molecule %s' %(feat_type,fname,mol)) - print('Possible feature types are : ' + '\n\t'.join(list(mol_data['mapped_features'].keys()))) - raise ValueError(feat_type,' not supported') + print( + 'Feature type %s not found in file %s for molecule %s' % + (feat_type, fname, mol)) + print('Possible feature types are : ' + + '\n\t'.join(list(mol_data['mapped_features'].keys()))) + raise ValueError(feat_type, ' not supported') # loop through all the desired feat names for name in feat_names: @@ -1023,9 +1056,12 @@ def load_one_molecule(self,fname,mol=None): try: data = feat_dict[name] except KeyError: - print('Feature %s not found in file %s for mol %s and feature type %s' %(name,fname,mol,feat_type)) - print('Possible feature are : ' + '\n\t'.join(list(mol_data['mapped_features/'+feat_type].keys()))) - + print( + 'Feature %s not found in file %s for mol %s and feature type %s' % + (name, fname, mol, feat_type)) + print('Possible feature are : ' + + '\n\t'.join(list(mol_data['mapped_features/' + + feat_type].keys()))) # check its sparse attribute # if true get a FLAN @@ -1042,21 +1078,21 @@ def load_one_molecule(self,fname,mol=None): feature.append(mat) # get the target value - target = mol_data.get('targets/'+self.select_target)[()] + target = mol_data.get('targets/' + self.select_target)[()] # close fh5.close() - print(' --> Load one molecule %f sec.' %(time.time()-t0)) - #sys.exit() + print(' --> Load one molecule %f sec.' % (time.time() - t0)) + # sys.exit() # make sure all the feature have exact same type # if they don't collate_fn in the creation of the minibatch will fail. # Note returning torch.FloatTensor makes each epoch twice longer ... - return np.array(feature).astype(outtype),np.array([target]).astype(outtype) - + return np.array(feature).astype( + outtype), np.array([target]).astype(outtype) - def map_one_molecule(self,fname,mol=None,angle=None,axis=None): - '''Map the feature and load feature/target of a single molecule. + def map_one_molecule(self, fname, mol=None, angle=None, axis=None): + """Map the feature and load feature/target of a single molecule. 
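Called directly, load_one_molecule returns the stacked channel maps and the scalar target of one conformation; an illustrative call with placeholder names, assuming a data_set as sketched earlier:

    feature, target = data_set.load_one_molecule('1ak4.hdf5', mol='1AK4_100w')
    # feature: float32 array of shape (n_channels, nx, ny, nz); target: shape (1,)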
Args: fname (str): hdf5 file name @@ -1064,11 +1100,11 @@ def map_one_molecule(self,fname,mol=None,angle=None,axis=None): Returns: np.array,float: features, targets - ''' + """ t0 = time.time() outtype = 'float32' - fh5 = h5py.File(fname,'r') + fh5 = h5py.File(fname, 'r') if mol is None: mol = list(fh5.keys())[0] @@ -1079,33 +1115,36 @@ def map_one_molecule(self,fname,mol=None,angle=None,axis=None): # get the features feature = [] - for feat_type,feat_names in self.select_feature.items(): + for feat_type, feat_names in self.select_feature.items(): if feat_type == 'AtomicDensities': - densities = self.map_atomic_densities(feat_names, mol_data, grid, npts, angle, axis) + densities = self.map_atomic_densities( + feat_names, mol_data, grid, npts, angle, axis) feature += densities elif feat_type == 'Features': - data = self.map_feature(feat_names, mol_data, grid, npts, angle, axis) + data = self.map_feature( + feat_names, mol_data, grid, npts, angle, axis) feature += data # get the target value - target = mol_data.get('targets/'+self.select_target)[()] + target = mol_data.get('targets/' + self.select_target)[()] # close fh5.close() - print(' --> Map one molecule %f sec.' %(time.time()-t0)) - #sys.exit() + print(' --> Map one molecule %f sec.' % (time.time() - t0)) + # sys.exit() # make sure all the feature have exact same type # if they don't collate_fn in the creation of the minibatch will fail. # Note returning torch.FloatTensor makes each epoch twice longer ... - return np.array(feature).astype(outtype),np.array([target]).astype(outtype) - + return np.array(feature).astype( + outtype), np.array([target]).astype(outtype) @staticmethod - def convert2d(feature,proj2d): - '''Convert the 3D volumetric feature to a 2D planar data set. + def convert2d(feature, proj2d): + """Convert the 3D volumetric feature to a 2D planar data set. + proj2d specifies the dimension that we want to consider as channel for example for proj2d = 0 the 2D images are in the yz plane and the stack along the x dimension is considered as extra channels @@ -1114,20 +1153,21 @@ def convert2d(feature,proj2d): proj2d (int): projection Returns: np.array: projected features - ''' - nc,nx,ny,nz = feature.shape - if proj2d==0: - feature = feature.reshape(-1,1,ny,nz).squeeze() - elif proj2d==1: - feature = feature.reshape(-1,nx,1,nz).squeeze() - elif proj2d==2: - feature = feature.reshape(-1,nx,ny,1).squeeze() + """ + nc, nx, ny, nz = feature.shape + if proj2d == 0: + feature = feature.reshape(-1, 1, ny, nz).squeeze() + elif proj2d == 1: + feature = feature.reshape(-1, nx, 1, nz).squeeze() + elif proj2d == 2: + feature = feature.reshape(-1, nx, ny, 1).squeeze() return feature @staticmethod - def make_feature_pair(feature,op): + def make_feature_pair(feature, op): """Pair the features of both chains. 
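convert2d folds one spatial axis into the channel dimension, turning the (C, nx, ny, nz) volume into a stack of 2-D planes; for proj2d=0 the planes lie in yz and the x slices become extra channels:

    import numpy as np

    feature = np.zeros((4, 30, 30, 30))                  # (channels, nx, ny, nz)
    planes = feature.reshape(-1, 1, 30, 30).squeeze()    # (4*30, 30, 30): yz planes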
+ Args: feature (np.array): raw features op (callable): function to combine the features @@ -1138,19 +1178,18 @@ def make_feature_pair(feature,op): """ if not callable(op): - raise ValueError('Operation not callable',op) + raise ValueError('Operation not callable', op) nFeat = len(feature) - pair_indexes = list(np.arange(nFeat).reshape(int(nFeat/2),2)) + pair_indexes = list(np.arange(nFeat).reshape(int(nFeat / 2), 2)) outtype = feature.dtype new_feat = [] for ind in pair_indexes: - new_feat.append(op(feature[ind[0],...],feature[ind[1],...])) + new_feat.append(op(feature[ind[0], ...], feature[ind[1], ...])) return np.array(new_feat).astype(outtype) - def get_grid(self, mol_data): if self.grid_info is None: @@ -1161,7 +1200,7 @@ def get_grid(self, mol_data): y = mol_data['grid_points/y'][()] z = mol_data['grid_points/z'][()] - except: + except BaseException: raise ValueError("Grid points not found in the data file") @@ -1171,26 +1210,31 @@ def get_grid(self, mol_data): npts = np.array(self.grid_info['number_of_points']) res = np.array(self.grid_info['resolution']) - halfdim = 0.5*(npts*res) + halfdim = 0.5 * (npts * res) - low_lim = center-halfdim - hgh_lim = low_lim + res*(npts-1) - - x = np.linspace(low_lim[0],hgh_lim[0],npts[0]) - y = np.linspace(low_lim[1],hgh_lim[1],npts[1]) - z = np.linspace(low_lim[2],hgh_lim[2],npts[2]) + low_lim = center - halfdim + hgh_lim = low_lim + res * (npts - 1) + x = np.linspace(low_lim[0], hgh_lim[0], npts[0]) + y = np.linspace(low_lim[1], hgh_lim[1], npts[1]) + z = np.linspace(low_lim[2], hgh_lim[2], npts[2]) # there is stil something strange # with the ordering of the grid # also noted in GridTools define_grid_points() - y,x,z = np.meshgrid(y,x,z) - grid = (x,y,z) - npts = (len(x),len(y),len(z)) + y, x, z = np.meshgrid(y, x, z) + grid = (x, y, z) + npts = (len(x), len(y), len(z)) return grid, npts - - def map_atomic_densities(self,feat_names, mol_data, grid, npts, angle, axis): + def map_atomic_densities( + self, + feat_names, + mol_data, + grid, + npts, + angle, + axis): t0 = time.time() sql = pdb2sql(mol_data['complex'][()]) @@ -1200,18 +1244,18 @@ def map_atomic_densities(self,feat_names, mol_data, grid, npts, angle, axis): center = [np.mean(g) for g in grid] densities = [] - for atomtype,vdw_rad in feat_names.items(): + for atomtype, vdw_rad in feat_names.items(): start = time.time() # get pos of the contact atoms of correct type - xyzA = np.array(sql.get('x,y,z',rowID=index[0],name=atomtype)) - xyzB = np.array(sql.get('x,y,z',rowID=index[1],name=atomtype)) + xyzA = np.array(sql.get('x,y,z', rowID=index[0], name=atomtype)) + xyzB = np.array(sql.get('x,y,z', rowID=index[1], name=atomtype)) # rotate if necessary if angle is not None: - xyzA = self._rotate_coord(xyzA,center,angle,axis) - xyzB = self._rotate_coord(xyzB,center,angle,axis) + xyzA = self._rotate_coord(xyzA, center, angle, axis) + xyzB = self._rotate_coord(xyzB, center, angle, axis) # init the grid atdensA = np.zeros(npts) @@ -1219,23 +1263,21 @@ def map_atomic_densities(self,feat_names, mol_data, grid, npts, angle, axis): # run on the atoms for pos in xyzA: - atdensA += self._densgrid(pos,vdw_rad,grid,npts) + atdensA += self._densgrid(pos, vdw_rad, grid, npts) # run on the atoms for pos in xyzB: - atdensB += self._densgrid(pos,vdw_rad,grid,npts) + atdensB += self._densgrid(pos, vdw_rad, grid, npts) - densities += [atdensA,atdensB] + densities += [atdensA, atdensB] #print(' __ Map single atomic density %s %f' %(atomtype, time.time()-start)) sql.close() #print(' __ Total Atomic Densities : %f' 
%(time.time()-t0)) return densities @staticmethod - def _densgrid(center,vdw_radius,grid,npts): - - ''' Function to map individual atomic density on the grid. - + def _densgrid(center, vdw_radius, grid, npts): + """Function to map individual atomic density on the grid. The formula is equation (1) of the Koes paper Protein-Ligand Scoring with Convolutional NN Arxiv:1612.02751v1 @@ -1246,16 +1288,18 @@ def _densgrid(center,vdw_radius,grid,npts): Returns: TYPE: np.array (mapped density) - ''' + """ - x0,y0,z0 = center - dd = np.sqrt( (grid[0]-x0)**2 + (grid[1]-y0)**2 + (grid[2]-z0)**2 ) + x0, y0, z0 = center + dd = np.sqrt((grid[0] - x0)**2 + (grid[1] - y0)**2 + (grid[2] - z0)**2) dgrid = np.zeros(npts) - dgrid[dd=vdw_radius) & (dd<1.5*vdw_radius)] = 4./np.e**2/vdw_radius**2*dd[ (dd >=vdw_radius) & (dd<1.5*vdw_radius)]**2 - 12./np.e**2/vdw_radius*dd[ (dd >=vdw_radius) & (dd<1.5*vdw_radius)] + 9./np.e**2 + dgrid[dd < vdw_radius] = np.exp(-2 * + dd[dd < vdw_radius]**2 / vdw_radius**2) + dgrid[(dd >= vdw_radius) & (dd < 1.5 * vdw_radius)] = 4. / np.e**2 / vdw_radius**2 * dd[(dd >= vdw_radius) & (dd < + 1.5 * vdw_radius)]**2 - 12. / np.e**2 / vdw_radius * dd[(dd >= vdw_radius) & (dd < 1.5 * vdw_radius)] + 9. / np.e**2 return dgrid - def map_feature(self,feat_names, mol_data, grid, npts, angle, axis): + def map_feature(self, feat_names, mol_data, grid, npts, angle, axis): __vectorize__ = False @@ -1266,38 +1310,39 @@ def map_feature(self,feat_names, mol_data, grid, npts, angle, axis): if __vectorize__: tprep = time.time() - pfunc = partial(self._featgrid,grid=grid,npts=npts) - vmap = np.vectorize(pfunc,signature='(n),()->(p,p,p)') + pfunc = partial(self._featgrid, grid=grid, npts=npts) + vmap = np.vectorize(pfunc, signature='(n),()->(p,p,p)') #print(' __ Prepare function %f' %(time.time()-t0)) feat = [] for name in feat_names: + tmp_feat_ser = [np.zeros(npts), np.zeros(npts)] + tmp_feat_vect = [np.zeros(npts), np.zeros(npts)] + data = np.array(mol_data['features/' + name][()]) - tmp_feat_ser = [np.zeros(npts),np.zeros(npts)] - tmp_feat_vect = [np.zeros(npts),np.zeros(npts)] - data = np.array(mol_data['features/'+name][()]) - - chain = data[:,0] - pos = data[:,1:4] - feat_value = data[:,4] + chain = data[:, 0] + pos = data[:, 1:4] + feat_value = data[:, 4] if angle is not None: - pos = self._rotate_coord(pos,center,angle,axis) + pos = self._rotate_coord(pos, center, angle, axis) start = time.time() - if __vectorize__ == True or __vectorize__ == 'both': + if __vectorize__ or __vectorize__ == 'both': - for chainID in [0,1]: - tmp_feat_vect[chainID] = np.sum(vmap(pos[chain==chainID,:],feat_value[chain==chainID]),0) + for chainID in [0, 1]: + tmp_feat_vect[chainID] = np.sum( + vmap(pos[chain == chainID, :], feat_value[chain == chainID]), 0) if __vectorize__ == False or __vectorize__ == 'both': - for chainID,xyz,val in zip(chain,pos,feat_value): - tmp_feat_ser[int(chainID)] += self._featgrid(xyz,val,grid,npts) + for chainID, xyz, val in zip(chain, pos, feat_value): + tmp_feat_ser[int(chainID) + ] += self._featgrid(xyz, val, grid, npts) if __vectorize__ == 'both': - assert np.allclose(tmp_feat_ser,tmp_feat_vect) + assert np.allclose(tmp_feat_ser, tmp_feat_vect) if __vectorize__: feat += tmp_feat_vect @@ -1309,8 +1354,8 @@ def map_feature(self,feat_names, mol_data, grid, npts, angle, axis): return feat @staticmethod - def _featgrid(center,value,grid,npts): - '''Map an individual feature (atomic or residue) on the grid + def _featgrid(center, value, grid, npts): + """Map an individual feature (atomic or 
residue) on the grid. Args: center (list(float)): position of the feature center @@ -1322,30 +1367,30 @@ def _featgrid(center,value,grid,npts): Raises: ValueError: Description - ''' + """ # shortcut for th center - x0,y0,z0 = center + x0, y0, z0 = center - sigma = np.sqrt(1./2) - beta = 0.5/(sigma**2) - cutoff = 5.*beta + sigma = np.sqrt(1. / 2) + beta = 0.5 / (sigma**2) + cutoff = 5. * beta - dd = np.sqrt( (grid[0]-x0)**2 + (grid[1]-y0)**2 + (grid[2]-z0)**2 ) + dd = np.sqrt((grid[0] - x0)**2 + (grid[1] - y0)**2 + (grid[2] - z0)**2) - dd[ddcutoff] = 0 + dd[dd < cutoff] = value * np.exp(-beta * dd[dd < cutoff]) + dd[dd > cutoff] = 0 #dgrid = np.zeros(npts) #dgrid[dd 0: self.cuda = True - if self.ngpu == 0 and self.cuda : + if self.ngpu == 0 and self.cuda: self.ngpu = 1 - - #------------------------------------------ + # ------------------------------------------ # Regression or classifiation - #------------------------------------------ + # ------------------------------------------ # task to accomplish self.task = task # Set the loss functiom - if self.task=='reg': + if self.task == 'reg': self.criterion = nn.MSELoss(reduction='sum') self._plot_scatter = self._plot_scatter_reg - elif self.task=='class': + elif self.task == 'class': self.criterion = nn.CrossEntropyLoss(reduction='sum') self._plot_scatter = self._plot_boxplot_class self.data_set.normalize_targets = False else: - raise ValueError("Task " + self.task +"not recognized.\nOptions are \n\t 'reg': regression \n\t 'class': classifiation\n\n") + raise ValueError( + "Task " + + self.task + + "not recognized.\nOptions are \n\t 'reg': regression \n\t 'class': classifiation\n\n") - #------------------------------------------ + # ------------------------------------------ # Output - #------------------------------------------ + # ------------------------------------------ # plot or not plot self.plot = plot @@ -192,9 +193,9 @@ def __init__(self,data_set,model, if not os.path.isdir(self.outdir): os.mkdir(outdir) - #------------------------------------------ + # ------------------------------------------ # Network - #------------------------------------------ + # ------------------------------------------ # load the model self.net = model(self.data_set.input_shape) @@ -205,13 +206,16 @@ def __init__(self,data_set,model, device = torch.device("cuda") # PyTorch v0.4.0 else: device = torch.device("cpu") - summary(self.net.to(device), self.data_set.input_shape, device = device.type) + summary( + self.net.to(device), + self.data_set.input_shape, + device=device.type) sys.stdout.flush() # load parameters of pretrained model if provided if self.pretrained_model: - ## a prefix 'module.' is added to parameter names if torch.nn.DataParallel was used - ## https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel + # a prefix 'module.' 
is added to parameter names if torch.nn.DataParallel was used + # https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel if self.state['cuda']: for paramname in list(self.state['state_dict'].keys()): paramname_new = paramname.lstrip('module.') @@ -219,53 +223,69 @@ def __init__(self,data_set,model, del self.state['state_dict'][paramname] self.load_model_params() - #multi-gpu - if self.ngpu>1: + # multi-gpu + if self.ngpu > 1: ids = [i for i in range(self.ngpu)] - self.net = nn.DataParallel(self.net,device_ids=ids).cuda() + self.net = nn.DataParallel(self.net, device_ids=ids).cuda() # cuda compatible elif self.cuda: self.net = self.net.cuda() # set the optimizer - self.optimizer = optim.SGD(self.net.parameters(),lr=0.005,momentum=0.9,weight_decay=0.001) + self.optimizer = optim.SGD( + self.net.parameters(), + lr=0.005, + momentum=0.9, + weight_decay=0.001) if self.pretrained_model: self.load_optimizer_params() - - #------------------------------------------ + # ------------------------------------------ # print - #------------------------------------------ + # ------------------------------------------ print('\n') - print('='*40) + print('=' * 40) print('=\t Convolution Neural Network') - print('=\t model : %s' %model_type) - print('=\t CNN : %s' %model.__name__) + print('=\t model : %s' % model_type) + print('=\t CNN : %s' % model.__name__) - for feat_type,feat_names in self.data_set.select_feature.items(): - print('=\t features : %s' %(feat_type)) + for feat_type, feat_names in self.data_set.select_feature.items(): + print('=\t features : %s' % (feat_type)) for name in feat_names: - print('=\t\t %s' %(name)) + print('=\t\t %s' % (name)) if self.data_set.pair_chain_feature is not None: - print('=\t Pair : %s' %self.data_set.pair_chain_feature.__name__) - print('=\t targets : %s' %self.data_set.select_target) - print('=\t CUDA : %s' %str(self.cuda)) + print( + '=\t Pair : %s' % + self.data_set.pair_chain_feature.__name__) + print('=\t targets : %s' % self.data_set.select_target) + print('=\t CUDA : %s' % str(self.cuda)) if self.cuda: - print('=\t nGPU : %d' %self.ngpu) - print('='*40,'\n') + print('=\t nGPU : %d' % self.ngpu) + print('=' * 40, '\n') # check if CUDA works if self.cuda and not torch.cuda.is_available(): - print(' --> CUDA not deteceted : Make sure that CUDA is installed and that you are running on GPUs') + print( + ' --> CUDA not deteceted : Make sure that CUDA is installed and that you are running on GPUs') print(' --> To turn CUDA of set cuda=False in NeuralNet') print(' --> Aborting the experiment \n\n') sys.exit() - def train(self,nepoch=50, divide_trainset= None, hdf5='epoch_data.hdf5',train_batch_size = 10, - preshuffle=True, preshuffle_seed=None, export_intermediate=True,num_workers=1,save_model='best',save_epoch='intermediate'): - - """Perform a simple training of the model. The data set is divided in training/validation sets. + def train( + self, + nepoch=50, + divide_trainset=None, + hdf5='epoch_data.hdf5', + train_batch_size=10, + preshuffle=True, + preshuffle_seed=None, + export_intermediate=True, + num_workers=1, + save_model='best', + save_epoch='intermediate'): + """Perform a simple training of the model. The data set is divided in + training/validation sets. 
Args: @@ -308,34 +328,33 @@ def train(self,nepoch=50, divide_trainset= None, hdf5='epoch_data.hdf5',train_ba >>> model.train(nepoch = 50,divide_trainset=0.8, train_batch_size = 5,num_workers=0) >>> # save the model >>> model.save_model() - """ - print('\n: Batch Size : %d' %train_batch_size) + print('\n: Batch Size : %d' % train_batch_size) if self.cuda: - print(': NGPU : %d' %self.ngpu) + print(': NGPU : %d' % self.ngpu) # hdf5 support - fname =self.outdir+'/'+hdf5 - self.f5 = h5py.File(fname,'w') + fname = self.outdir + '/' + hdf5 + self.f5 = h5py.File(fname, 'w') # divide the set in train+ valid and test if divide_trainset is not None: # if divide_trainset is not None - index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed) + index_train, index_valid, index_test = self._divide_dataset( + divide_trainset, preshuffle, preshuffle_seed) else: index_train = self.data_set.index_train index_valid = self.data_set.index_valid index_test = self.data_set.index_test - - print(': %d confs. for training' %len(index_train)) - print(': %d confs. for validation' %len(index_valid)) - print(': %d confs. for testing' %len(index_test)) + print(': %d confs. for training' % len(index_train)) + print(': %d confs. for validation' % len(index_valid)) + print(': %d confs. for testing' % len(index_test)) # train the model t0 = time.time() - self._train(index_train,index_valid,index_test, + self._train(index_train, index_valid, index_test, nepoch=nepoch, train_batch_size=train_batch_size, export_intermediate=export_intermediate, @@ -343,7 +362,10 @@ def train(self,nepoch=50, divide_trainset= None, hdf5='epoch_data.hdf5',train_ba save_epoch=save_epoch, save_model=save_model) self.f5.close() - print(' --> Training done in ', self.convertSeconds2Days(time.time()-t0)) + print( + ' --> Training done in ', + self.convertSeconds2Days( + time.time() - t0)) # save the model self.save_model(filename='last_model.pth.tar') @@ -360,8 +382,7 @@ def convertSeconds2Days(time): minutes = time // 60 time %= 60 seconds = time - return '%02d-%02d:%02d:%02d'%(day,hour,minutes,seconds) - + return '%02d-%02d:%02d:%02d' % (day, hour, minutes, seconds) def test(self, hdf5='test_data.hdf5'): """Test a predefined model on a new dataset. @@ -376,11 +397,10 @@ def test(self, hdf5='test_data.hdf5'): Args: hdf5 (str, optional): hdf5 file to store the test results - """ # output - fname = self.outdir+'/'+hdf5 - self.f5 = h5py.File(fname,'w') + fname = self.outdir + '/' + hdf5 + self.f5 = h5py.File(fname, 'w') # load pretrained model to get task and criterion self.load_nn_params() @@ -388,84 +408,74 @@ def test(self, hdf5='test_data.hdf5'): # load data index = list(range(self.data_set.__len__())) sampler = data_utils.sampler.SubsetRandomSampler(index) - loader = data_utils.DataLoader(self.data_set,sampler=sampler) + loader = data_utils.DataLoader(self.data_set, sampler=sampler) # do test self.data = {} - _, self.data['test'] = self._epoch(loader,train_model=False) + _, self.data['test'] = self._epoch(loader, train_model=False) if self.task == 'reg': - self._plot_scatter_reg(self.outdir+'/prediction.png') - self.plot_hit_rate(self.outdir+'/hitrate.png') + self._plot_scatter_reg(self.outdir + '/prediction.png') + self.plot_hit_rate(self.outdir + '/hitrate.png') self._export_epoch_hdf5(0, self.data) self.f5.close() - - def save_model(self,filename='model.pth.tar'): - - """save the model to disk + def save_model(self, filename='model.pth.tar'): + """save the model to disk. 
Args: filename (str, optional): name of the file """ filename = self.outdir + '/' + filename - state = {'state_dict' : self.net.state_dict(), - 'optimizer' : self.optimizer.state_dict(), - 'normalize_targets' : self.data_set.normalize_targets, - 'normalize_features' : self.data_set.normalize_features, - 'select_feature' : self.data_set.select_feature, - 'select_target' : self.data_set.select_target, - 'target_ordering' : self.data_set.target_ordering, - 'pair_chain_feature' : self.data_set.pair_chain_feature, - 'dict_filter' : self.data_set.dict_filter, - 'transform' : self.data_set.transform, - 'proj2D' : self.data_set.proj2D, - 'clip_features' : self.data_set.clip_features, - 'clip_factor' : self.data_set.clip_factor, - 'grid_shape' : self.data_set.grid_shape, - 'grid_info' : self.data_set.grid_info, - 'mapfly' : self.data_set.mapfly, - 'task' : self.task, - 'criterion' : self.criterion, - 'cuda' : self.cuda + state = {'state_dict': self.net.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'normalize_targets': self.data_set.normalize_targets, + 'normalize_features': self.data_set.normalize_features, + 'select_feature': self.data_set.select_feature, + 'select_target': self.data_set.select_target, + 'target_ordering': self.data_set.target_ordering, + 'pair_chain_feature': self.data_set.pair_chain_feature, + 'dict_filter': self.data_set.dict_filter, + 'transform': self.data_set.transform, + 'proj2D': self.data_set.proj2D, + 'clip_features': self.data_set.clip_features, + 'clip_factor': self.data_set.clip_factor, + 'grid_shape': self.data_set.grid_shape, + 'grid_info': self.data_set.grid_info, + 'mapfly': self.data_set.mapfly, + 'task': self.task, + 'criterion': self.criterion, + 'cuda': self.cuda } if self.data_set.normalize_features: - state['feature_mean'] = self.data_set.feature_mean - state['feature_std' ] = self.data_set.feature_std + state['feature_mean'] = self.data_set.feature_mean + state['feature_std'] = self.data_set.feature_std if self.data_set.normalize_targets: - state['target_min'] = self.data_set.target_min - state['target_max'] = self.data_set.target_max - - torch.save(state,filename) + state['target_min'] = self.data_set.target_min + state['target_max'] = self.data_set.target_max + torch.save(state, filename) def load_model_params(self): - """Get model parameters from a saved model. - """ + """Get model parameters from a saved model.""" self.net.load_state_dict(self.state['state_dict']) - def load_optimizer_params(self): - """Get optimizer parameters from a saved model. - """ + """Get optimizer parameters from a saved model.""" self.optimizer.load_state_dict(self.state['optimizer']) - def load_nn_params(self): - """Get NeuralNet parameters from a saved model. - """ + """Get NeuralNet parameters from a saved model.""" self.task = self.state['task'] self.criterion = self.state['criterion'] - def load_data_params(self): - '''Get dataset parameters from a saved model. 
- ''' + """Get dataset parameters from a saved model.""" self.data_set.select_feature = self.state['select_feature'] - self.data_set.select_target = self.state['select_target'] + self.data_set.select_target = self.state['select_target'] self.data_set.pair_chain_feature = self.state['pair_chain_feature'] self.data_set.dict_filter = self.state['dict_filter'] @@ -489,10 +499,9 @@ def load_data_params(self): self.data_set.mapfly = self.state['mapfly'] self.data_set.grid_info = self.state['grid_info'] - - def _divide_dataset(self,divide_set, preshuffle, preshuffle_seed): - - '''Divide the data set in a training validation and test according to the percentage in divide_set. + def _divide_dataset(self, divide_set, preshuffle, preshuffle_seed): + """Divide the data set in a training validation and test according to + the percentage in divide_set. Args: divide_set (list(float)): percentage used for training/validation/test @@ -501,50 +510,49 @@ def _divide_dataset(self,divide_set, preshuffle, preshuffle_seed): Returns: list(int),list(int),list(int): Indices of the training/validation/test set - ''' + """ # if user only provided one number # we assume it's the training percentage - if not isinstance(divide_set,list): - divide_set = [divide_set, 1.-divide_set] + if not isinstance(divide_set, list): + divide_set = [divide_set, 1. - divide_set] # if user provided 3 number and testset if len(divide_set) == 3 and self.data_set.test_database is not None: - divide_set = [divide_set[0],1.-divide_set[0]] + divide_set = [divide_set[0], 1. - divide_set[0]] print(' : test data set AND test in training set detected') - print(' : Divide training set as %f train %f valid' %(divide_set[0],divide_set[1])) + print( + ' : Divide training set as %f train %f valid' % + (divide_set[0], divide_set[1])) print(' : Keep test set for testing') - # preshuffle if preshuffle: - if preshuffle_seed is not None and not isinstance(preshuffle_seed, int): + if preshuffle_seed is not None and not isinstance( + preshuffle_seed, int): preshuffle_seed = int(preshuffle_seed) np.random.seed(preshuffle_seed) np.random.shuffle(self.data_set.index_train) # size of the subset for training - ntrain = int( np.ceil(float(self.data_set.ntrain)*divide_set[0]) ) - nvalid = int( np.floor(float(self.data_set.ntrain)*divide_set[1]) ) + ntrain = int(np.ceil(float(self.data_set.ntrain) * divide_set[0])) + nvalid = int(np.floor(float(self.data_set.ntrain) * divide_set[1])) # indexes train and valid index_train = self.data_set.index_train[:ntrain] - index_valid = self.data_set.index_train[ntrain:ntrain+nvalid] + index_valid = self.data_set.index_train[ntrain:ntrain + nvalid] # index of test depending of the situation - if len(divide_set)==3: - index_test = self.data_set.index_train[ntrain+nvalid:] + if len(divide_set) == 3: + index_test = self.data_set.index_train[ntrain + nvalid:] else: index_test = self.data_set.index_test - return index_train,index_valid,index_test - - - - def _train(self,index_train,index_valid,index_test, - nepoch = 50,train_batch_size = 5, - export_intermediate=False,num_workers=1, - save_epoch='intermediate',save_model='best'): + return index_train, index_valid, index_test + def _train(self, index_train, index_valid, index_test, + nepoch=50, train_batch_size=5, + export_intermediate=False, num_workers=1, + save_epoch='intermediate', save_model='best'): """Train the model. 
Args: @@ -563,7 +571,7 @@ def _train(self,index_train,index_valid,index_test, """ # printing options - nprint = np.max([1,int(nepoch/10)]) + nprint = np.max([1, int(nepoch / 10)]) # store the length of the training set ntrain = len(index_train) @@ -579,11 +587,11 @@ def _train(self,index_train,index_valid,index_test, test_sampler = data_utils.sampler.SubsetRandomSampler(index_test) # get if we test as well - _valid_ = len(valid_sampler.indices)>0 - _test_ = len(test_sampler.indices)>0 + _valid_ = len(valid_sampler.indices) > 0 + _test_ = len(test_sampler.indices) > 0 # containers for the losses - self.losses={'train': []} + self.losses = {'train': []} if _valid_: self.losses['valid'] = [] if _test_: @@ -593,30 +601,51 @@ def _train(self,index_train,index_valid,index_test, if self.save_classmetrics: self.classmetrics = {} for i in self.metricnames: - self.classmetrics[i] = {'train':[]} + self.classmetrics[i] = {'train': []} if _valid_: self.classmetrics[i]['valid'] = [] if _test_: self.classmetrics[i]['test'] = [] # create the loaders - train_loader = data_utils.DataLoader(self.data_set,batch_size=train_batch_size,sampler=train_sampler,pin_memory=pin,num_workers=num_workers,shuffle=False,drop_last=False) + train_loader = data_utils.DataLoader( + self.data_set, + batch_size=train_batch_size, + sampler=train_sampler, + pin_memory=pin, + num_workers=num_workers, + shuffle=False, + drop_last=False) if _valid_: - valid_loader = data_utils.DataLoader(self.data_set,batch_size=train_batch_size,sampler=valid_sampler,pin_memory=pin,num_workers=num_workers,shuffle=False,drop_last=False) + valid_loader = data_utils.DataLoader( + self.data_set, + batch_size=train_batch_size, + sampler=valid_sampler, + pin_memory=pin, + num_workers=num_workers, + shuffle=False, + drop_last=False) if _test_: - test_loader = data_utils.DataLoader(self.data_set,batch_size=train_batch_size,sampler=test_sampler,pin_memory=pin,num_workers=num_workers,shuffle=False,drop_last=False) + test_loader = data_utils.DataLoader( + self.data_set, + batch_size=train_batch_size, + sampler=test_sampler, + pin_memory=pin, + num_workers=num_workers, + shuffle=False, + drop_last=False) # min error to kee ptrack of the best model. 
min_error = {'train': float('Inf'), 'valid': float('Inf'), - 'test' : float('Inf')} + 'test': float('Inf')} # training loop av_time = 0.0 self.data = {} for epoch in range(nepoch): - print('\n: epoch %03d / %03d ' %(epoch,nepoch) + '-'*45) + print('\n: epoch %03d / %03d ' % (epoch, nepoch) + '-' * 45) t0 = time.time() # validate the model @@ -625,93 +654,104 @@ def _train(self,index_train,index_valid,index_test, sys.stdout.flush() print(f"\n\t=> validate the model\n") - self.valid_loss,self.data['valid'] = self._epoch(valid_loader,train_model=False) + self.valid_loss, self.data['valid'] = self._epoch( + valid_loader, train_model=False) self.losses['valid'].append(self.valid_loss) if self.save_classmetrics: for i in self.metricnames: - self.classmetrics[i]['valid'].append(self.data['valid'][i]) + self.classmetrics[i]['valid'].append( + self.data['valid'][i]) # test the model if _test_: sys.stdout.flush() print(f"\n\t=> test the model\n") - test_loss,self.data['test'] = self._epoch(test_loader,train_model=False) + test_loss, self.data['test'] = self._epoch( + test_loader, train_model=False) self.losses['test'].append(test_loss) if self.save_classmetrics: for i in self.metricnames: - self.classmetrics[i]['test'].append(self.data['test'][i]) + self.classmetrics[i]['test'].append( + self.data['test'][i]) # train the model sys.stdout.flush() print(f"\n\t=> train the model\n") - self.train_loss,self.data['train'] = self._epoch(train_loader,train_model=True) + self.train_loss, self.data['train'] = self._epoch( + train_loader, train_model=True) self.losses['train'].append(self.train_loss) if self.save_classmetrics: for i in self.metricnames: self.classmetrics[i]['train'].append(self.data['train'][i]) # talk a bit about losse - print(' train loss : %1.3e' %(self.train_loss)) + print(' train loss : %1.3e' % (self.train_loss)) if _valid_: - print(' valid loss : %1.3e' %(self.valid_loss)) + print(' valid loss : %1.3e' % (self.valid_loss)) if _test_: - print(' test loss : %1.3e' %(test_loss)) + print(' test loss : %1.3e' % (test_loss)) # timer - elapsed = time.time()-t0 + elapsed = time.time() - t0 print(' epoch done in :', self.convertSeconds2Days(elapsed)) # remaining time av_time += elapsed - nremain = nepoch-(epoch+1) - remaining_time = av_time/(epoch+1)*nremain - print(' remaining time :', time.strftime('%H:%M:%S', time.gmtime(remaining_time))) + nremain = nepoch - (epoch + 1) + remaining_time = av_time / (epoch + 1) * nremain + print( + ' remaining time :', + time.strftime( + '%H:%M:%S', + time.gmtime(remaining_time))) # save the best model - for mode in ['train','valid','test']: - if not mode in self.losses: + for mode in ['train', 'valid', 'test']: + if mode not in self.losses: continue if self.losses[mode][-1] < min_error[mode]: - self.save_model(filename="best_{}_model.pth.tar".format(mode)) + self.save_model( + filename="best_{}_model.pth.tar".format(mode)) min_error[mode] = self.losses[mode][-1] - #save all the model if required + # save all the model if required if save_model == 'all': - self.save_model(filename="model_epoch_%04d.pth.tar" %epoch) + self.save_model(filename="model_epoch_%04d.pth.tar" % epoch) # plot and save epoch - if (export_intermediate and epoch%nprint == nprint-1) or epoch==0 or epoch==nepoch-1: + if (export_intermediate and epoch % + nprint == nprint - 1) or epoch == 0 or epoch == nepoch - 1: if self.plot: - figname = self.outdir+"/prediction_%04d.png" %epoch + figname = self.outdir + "/prediction_%04d.png" % epoch self._plot_scatter(figname) if self.save_hitrate: - 
figname = self.outdir+"/hitrate_%04d.png" %epoch + figname = self.outdir + "/hitrate_%04d.png" % epoch self.plot_hit_rate(figname) - self._export_epoch_hdf5(epoch,self.data) + self._export_epoch_hdf5(epoch, self.data) elif save_epoch == 'all': - #self._compute_hitrate() - self._export_epoch_hdf5(epoch,self.data) + # self._compute_hitrate() + self._export_epoch_hdf5(epoch, self.data) sys.stdout.flush() # plot the losses - self._export_losses(self.outdir+'/'+'losses.png') + self._export_losses(self.outdir + '/' + 'losses.png') # plot classification metrics if self.save_classmetrics: for i in self.metricnames: self._export_metrics(i) - return torch.cat([param.data.view(-1) for param in self.net.parameters()],0) - - def _epoch(self,data_loader,train_model): + return torch.cat([param.data.view(-1) + for param in self.net.parameters()], 0) + def _epoch(self, data_loader, train_model): """Perform one single epoch iteration over a data loader. Args: @@ -725,7 +765,7 @@ def _epoch(self,data_loader,train_model): # variables of the epoch running_loss = 0 - data = {'outputs':[],'targets':[],'mol':[]} + data = {'outputs': [], 'targets': [], 'mol': []} if self.save_hitrate: data['hit'] = None @@ -737,7 +777,7 @@ def _epoch(self,data_loader,train_model): debug_time = False time_learn = 0 - #set train/eval mode + # set train/eval mode self.net.train(mode=train_model) mini_batch = 0 @@ -753,7 +793,7 @@ def _epoch(self,data_loader,train_model): mol = d['mol'] # transform the data - inputs,targets = self._get_variables(inputs,targets) + inputs, targets = self._get_variables(inputs, targets) # zero gradient tlearn0 = time.time() @@ -762,12 +802,12 @@ def _epoch(self,data_loader,train_model): outputs = self.net(inputs) # class complains about the shape ... - if self.task=='class': + if self.task == 'class': targets = targets.view(-1) # evaluate loss - loss = self.criterion(outputs,targets) - running_loss += loss.data.item() # pytorch1 compatible + loss = self.criterion(outputs, targets) + running_loss += loss.data.item() # pytorch1 compatible n += len(inputs) # zero + backward + step @@ -775,29 +815,31 @@ def _epoch(self,data_loader,train_model): self.optimizer.zero_grad() loss.backward() self.optimizer.step() - time_learn += time.time()-tlearn0 + time_learn += time.time() - tlearn0 # get the outputs for export if self.cuda: - data['outputs'] += outputs.data.cpu().numpy().tolist() + data['outputs'] += outputs.data.cpu().numpy().tolist() data['targets'] += targets.data.cpu().numpy().tolist() else: - data['outputs'] += outputs.data.numpy().tolist() + data['outputs'] += outputs.data.numpy().tolist() data['targets'] += targets.data.numpy().tolist() - fname,molname = mol[0],mol[1] - data['mol'] += [ (f,m) for f,m in zip(fname,molname)] + fname, molname = mol[0], mol[1] + data['mol'] += [(f, m) for f, m in zip(fname, molname)] # transform the output back if self.data_set.normalize_targets: - data['outputs'] = self.data_set.backtransform_target(np.array(data['outputs']))#.flatten()) - data['targets'] = self.data_set.backtransform_target(np.array(data['targets']))#.flatten()) + data['outputs'] = self.data_set.backtransform_target( + np.array(data['outputs'])) # .flatten()) + data['targets'] = self.data_set.backtransform_target( + np.array(data['targets'])) # .flatten()) else: - data['outputs'] = np.array(data['outputs'])#.flatten() - data['targets'] = np.array(data['targets'])#.flatten() + data['outputs'] = np.array(data['outputs']) # .flatten() + data['targets'] = np.array(data['targets']) # .flatten() # make np for 
export - data['mol'] = np.array(data['mol'],dtype=object) + data['mol'] = np.array(data['mol'], dtype=object) # get the relevance of the ranking if self.save_hitrate: @@ -816,11 +858,9 @@ def _epoch(self,data_loader,train_model): return running_loss, data - - def _get_variables(self,inputs,targets): + def _get_variables(self, inputs, targets): # xue: why not put this step to DataSet.py? - - '''Convert the feature/target in torch.Variables. + """Convert the feature/target in torch.Variables. The format is different for regression where the targets are float and classification where they are int. @@ -832,40 +872,41 @@ def _get_variables(self,inputs,targets): Returns: torch.Variable: features torch.Variable: target values - ''' + """ # if cuda is available if self.cuda: inputs = inputs.cuda(non_blocking=True) targets = targets.cuda(non_blocking=True) - # get the varialbe as float by default - inputs,targets = Variable(inputs).float(),Variable(targets).float() + inputs, targets = Variable(inputs).float(), Variable(targets).float() # change the targets to long for classification if self.task == 'class': - targets = targets.long() - - return inputs,targets + targets = targets.long() + return inputs, targets - def _export_losses(self,figname): - - '''Plot the losses vs the epoch + def _export_losses(self, figname): + """Plot the losses vs the epoch. Args: figname (str): name of the file where to export the figure - ''' + """ print('\n --> Loss Plot') - color_plot = ['red','blue','green'] - labels = ['Train','Valid','Test'] + color_plot = ['red', 'blue', 'green'] + labels = ['Train', 'Valid', 'Test'] - fig,ax = plt.subplots() - for ik,name in enumerate(self.losses): - plt.plot(np.array(self.losses[name]),c=color_plot[ik],label=labels[ik]) + fig, ax = plt.subplots() + for ik, name in enumerate(self.losses): + plt.plot( + np.array( + self.losses[name]), + c=color_plot[ik], + label=labels[ik]) legend = ax.legend(loc='upper left') ax.set_xlabel('Epoch') @@ -876,21 +917,20 @@ def _export_losses(self,figname): grp = self.f5.create_group('/losses/') grp.attrs['type'] = 'losses' - for k,v in self.losses.items(): - grp.create_dataset(k,data=v) - + for k, v in self.losses.items(): + grp.create_dataset(k, data=v) def _export_metrics(self, metricname): - print('\n --> %s Plot' %(metricname.upper())) + print('\n --> %s Plot' % (metricname.upper())) - color_plot = ['red','blue','green'] - labels = ['Train','Valid','Test'] + color_plot = ['red', 'blue', 'green'] + labels = ['Train', 'Valid', 'Test'] data = self.classmetrics[metricname] - fig,ax = plt.subplots() - for ik,name in enumerate(data): - plt.plot(np.array(data[name]),c=color_plot[ik],label=labels[ik]) + fig, ax = plt.subplots() + for ik, name in enumerate(data): + plt.plot(np.array(data[name]), c=color_plot[ik], label=labels[ik]) legend = ax.legend(loc='upper left') ax.set_xlabel('Epoch') @@ -902,32 +942,28 @@ def _export_metrics(self, metricname): grp = self.f5.create_group(metricname) grp.attrs['type'] = metricname - for k,v in data.items(): - grp.create_dataset(k,data=v) - - - def _plot_scatter_reg(self,figname): + for k, v in data.items(): + grp.create_dataset(k, data=v) - '''Plot a scatter plots of predictions VS targets. + def _plot_scatter_reg(self, figname): + """Plot a scatter plots of predictions VS targets. 
Useful to visualize the performance of the training algorithm Args: figname (str): filename - - ''' + """ # abort if we don't want to plot if self.plot is False: return - print('\n --> Scatter Plot : ', figname, '\n') - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] + color_plot = {'train': 'red', 'valid': 'blue', 'test': 'green'} + labels = ['train', 'valid', 'test'] - fig,ax = plt.subplots() + fig, ax = plt.subplots() xvalues = np.array([]) yvalues = np.array([]) @@ -939,42 +975,40 @@ def _plot_scatter_reg(self,figname): targ = self.data[l]['targets'].flatten() out = self.data[l]['outputs'].flatten() - xvalues = np.append(xvalues,targ) - yvalues = np.append(yvalues,out) + xvalues = np.append(xvalues, targ) + yvalues = np.append(yvalues, out) - ax.scatter(targ,out,c = color_plot[l],label=l) + ax.scatter(targ, out, c=color_plot[l], label=l) legend = ax.legend(loc='upper left') ax.set_xlabel('Targets') ax.set_ylabel('Predictions') - values = np.append(xvalues,yvalues) - border = 0.1 * (values.max()-values.min()) - ax.plot([values.min()-border,values.max()+border],[values.min()-border,values.max()+border]) + values = np.append(xvalues, yvalues) + border = 0.1 * (values.max() - values.min()) + ax.plot([values.min() - border, values.max() + border], + [values.min() - border, values.max() + border]) fig.savefig(figname) plt.close() - def _plot_boxplot_class(self,figname): - - ''' - Plot a boxplot of predictions VS targets useful ' - to visualize the performance of the training algorithm - This is only usefull in classification tasks + def _plot_boxplot_class(self, figname): + """Plot a boxplot of predictions VS targets useful ' to visualize the + performance of the training algorithm This is only usefull in + classification tasks. Args: figname (str): filename - - ''' + """ # abort if we don't want to plot - if self.plot == False: + if not self.plot: return print('\n --> Box Plot : ', figname, '\n') - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] + color_plot = {'train': 'red', 'valid': 'blue', 'test': 'green'} + labels = ['train', 'valid', 'test'] nwin = len(self.data) @@ -989,11 +1023,11 @@ def _plot_boxplot_class(self,figname): out = self.data[l]['outputs'] data = [[], []] - confusion=[[0, 0], [0, 0]] - for pts,t in zip(out,tar): + confusion = [[0, 0], [0, 0]] + for pts, t in zip(out, tar): r = F.softmax(torch.FloatTensor(pts), dim=0).data.numpy() data[t].append(r[1]) - confusion[t][bool(r[1]>0.5)] += 1 + confusion[t][bool(r[1] > 0.5)] += 1 #print(" {:5s}: {:s}".format(l,str(confusion))) @@ -1005,10 +1039,8 @@ def _plot_boxplot_class(self,figname): fig.savefig(figname, bbox_inches='tight') plt.close() - - def plot_hit_rate(self,figname): - - '''Plot the hit rate of the different training/valid/test sets + def plot_hit_rate(self, figname): + """Plot the hit rate of the different training/valid/test sets. The hit rate is defined as: the percentage of positive decoys that are included among the top m decoys. 
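As a side note on the definition just above: a minimal NumPy sketch of that hit-rate curve, assuming `relevance` is a hypothetical 0/1 array marking the positive decoys in model-ranked order. This only illustrates the stated definition; it is not the actual `rankingMetrics.hitrate` implementation used in the code.

import numpy as np

def hitrate_sketch(relevance):
    # 'relevance' is a hypothetical 0/1 mask of positive decoys, already sorted
    # by the model's score; entry m-1 of the result is the fraction of all
    # positives recovered among the top m decoys.
    relevance = np.asarray(relevance, dtype=float)
    n_pos = relevance.sum()
    if n_pos == 0:
        return np.zeros_like(relevance)
    return np.cumsum(relevance) / n_pos

# e.g. 2 positives among 5 ranked decoys:
# hitrate_sketch([1, 0, 0, 1, 0])  ->  [0.5, 0.5, 0.5, 1.0, 1.0]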
@@ -1017,29 +1049,34 @@ def plot_hit_rate(self,figname): Args: figname (str): filename for the plot irmsd_thr (float, optional): threshold for 'good' models - - ''' + """ if self.plot is False: return print('\n --> Hit Rate :', figname, '\n') - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] + color_plot = {'train': 'red', 'valid': 'blue', 'test': 'green'} + labels = ['train', 'valid', 'test'] # compute the hitrate - #self._compute_hitrate(irmsd_thr=irmsd_thr) + # self._compute_hitrate(irmsd_thr=irmsd_thr) # plot - fig,ax = plt.subplots() + fig, ax = plt.subplots() for l in labels: if l in self.data: if 'hit' in self.data[l]: hitrate = rankingMetrics.hitrate(self.data[l]['hit']) m = len(hitrate) - x = np.linspace(0,100,m) - plt.plot(x,hitrate,c = color_plot[l],label=l+' M=%d' %m) + x = np.linspace(0, 100, m) + plt.plot( + x, + hitrate, + c=color_plot[l], + label=l + + ' M=%d' % + m) legend = ax.legend(loc='upper left') ax.set_xlabel('Top M (%)') ax.set_ylabel('Hit Rate') @@ -1051,10 +1088,9 @@ def plot_hit_rate(self,figname): fig.savefig(figname) plt.close() - def _compute_hitrate(self,irmsd_thr = 4.0): - + def _compute_hitrate(self, irmsd_thr=4.0): - labels = ['train','valid','test'] + labels = ['train', 'valid', 'test'] self.hitrate = {} # get the target ordering @@ -1071,15 +1107,16 @@ def _compute_hitrate(self,irmsd_thr = 4.0): # get the irmsd irmsd = [] - for fname,mol in self.data[l]['mol']: + for fname, mol in self.data[l]['mol']: - f5 = h5py.File(fname,'r') - irmsd.append(f5[mol+'/targets/IRMSD'][()]) + f5 = h5py.File(fname, 'r') + irmsd.append(f5[mol + '/targets/IRMSD'][()]) f5.close() # sort the data if self.task == 'class': - out = F.softmax(torch.FloatTensor(out), dim=1).data.numpy()[:,1] + out = F.softmax(torch.FloatTensor( + out), dim=1).data.numpy()[:, 1] ind_sort = np.argsort(out) if not inverse: @@ -1089,20 +1126,22 @@ def _compute_hitrate(self,irmsd_thr = 4.0): irmsd = np.array(irmsd)[ind_sort] # make a binary list out of that - binary_recomendation = (irmsd<=irmsd_thr).astype('int') + binary_recomendation = (irmsd <= irmsd_thr).astype('int') # number of recommended hit npos = np.sum(binary_recomendation) if npos == 0: npos = len(irmsd) - print('Warning : Non positive decoys found in %s for hitrate plot' % l) + print( + 'Warning : Non positive decoys found in %s for hitrate plot' % + l) # get the hitrate - self.data[l]['hitrate'] = rankingMetrics.hitrate(binary_recomendation,npos) + self.data[l]['hitrate'] = rankingMetrics.hitrate( + binary_recomendation, npos) self.data[l]['relevance'] = binary_recomendation - - def _get_relevance(self,data,irmsd_thr = 4.0): + def _get_relevance(self, data, irmsd_thr=4.0): # get the target ordering inverse = self.data_set.target_ordering == 'lower' @@ -1114,15 +1153,15 @@ def _get_relevance(self,data,irmsd_thr = 4.0): # get the irmsd irmsd = [] - for fname,mol in data['mol']: + for fname, mol in data['mol']: - f5 = h5py.File(fname,'r') - irmsd.append(f5[mol+'/targets/IRMSD'][()]) + f5 = h5py.File(fname, 'r') + irmsd.append(f5[mol + '/targets/IRMSD'][()]) f5.close() # sort the data if self.task == 'class': - out = F.softmax(torch.FloatTensor(out), dim=1).data.numpy()[:,1] + out = F.softmax(torch.FloatTensor(out), dim=1).data.numpy()[:, 1] ind_sort = np.argsort(out) if not inverse: @@ -1132,8 +1171,7 @@ def _get_relevance(self,data,irmsd_thr = 4.0): irmsd = np.array(irmsd)[ind_sort] # make a binary list out of that - return (irmsd<=irmsd_thr).astype('int') - + return (irmsd <= 
irmsd_thr).astype('int') def _get_classmetrics(self, data, metricname): @@ -1157,17 +1195,15 @@ def _get_classmetrics(self, data, metricname): else: return None - @staticmethod def _get_binclass_prediction(data): out = data['outputs'] probility = F.softmax(torch.FloatTensor(out), dim=1).data.numpy() - pred = probility[:,0] <= probility[:,1] + pred = probility[:, 0] <= probility[:, 1] return pred.astype(int) - - def _export_epoch_hdf5(self,epoch,data): + def _export_epoch_hdf5(self, epoch, data): """Export the epoch data to the hdf5 file. Export the data of a given epoch in train/valid/test group. @@ -1179,7 +1215,7 @@ def _export_epoch_hdf5(self,epoch,data): """ # create a group - grp_name = 'epoch_%04d' %epoch + grp_name = 'epoch_%04d' % epoch grp = self.f5.create_group(grp_name) # create attribute for DeepXplroer @@ -1187,7 +1223,7 @@ def _export_epoch_hdf5(self,epoch,data): grp.attrs['task'] = self.task # loop over the pass_type : train/valid/test - for pass_type,pass_data in data.items(): + for pass_type, pass_data in data.items(): # we don't want to breack the process in case of issue try: @@ -1196,17 +1232,18 @@ def _export_epoch_hdf5(self,epoch,data): sg = grp.create_group(pass_type) # loop over the data : target/output/molname - for data_name,data_value in pass_data.items(): + for data_name, data_value in pass_data.items(): # mol name is a bit different # since there are strings if data_name == 'mol': string_dt = h5py.special_dtype(vlen=str) - sg.create_dataset(data_name,data=data_value,dtype=string_dt) + sg.create_dataset( + data_name, data=data_value, dtype=string_dt) # output/target values else: - sg.create_dataset(data_name,data=data_value) + sg.create_dataset(data_name, data=data_value) except TypeError: print('Epoch Error export') diff --git a/deeprank/learn/__init__.py b/deeprank/learn/__init__.py index 40f33c38..436417f3 100644 --- a/deeprank/learn/__init__.py +++ b/deeprank/learn/__init__.py @@ -1,5 +1,5 @@ # deep learning +from .modelGenerator import NetworkGenerator from .DataSet import DataSet from .NeuralNet import NeuralNet -from .modelGenerator import NetworkGenerator -from .metaqnn import MetaQNN +from .metaqnn import MetaQNN \ No newline at end of file diff --git a/deeprank/learn/classMetrics.py b/deeprank/learn/classMetrics.py index c47982df..5729a5d6 100644 --- a/deeprank/learn/classMetrics.py +++ b/deeprank/learn/classMetrics.py @@ -1,6 +1,7 @@ -import numpy as np import warnings +import numpy as np + # info # https://en.wikipedia.org/wiki/Precision_and_recall @@ -17,11 +18,12 @@ def sensitivity(yp, yt): """ tp = true_positive(yp, yt) p = positive(yt) - if p==0: - tpr=float('inf') - warnings.warn('Number of positive cases is 0, TPR or sensitivity is assigned as inf') + if p == 0: + tpr = float('inf') + warnings.warn( + 'Number of positive cases is 0, TPR or sensitivity is assigned as inf') else: - tpr = tp/p + tpr = tp / p return tpr @@ -38,10 +40,11 @@ def specificity(yp, yt): tn = true_negative(yp, yt) n = negative(yt) if n == 0: - warnings.warn('Number of negative cases is 0, TNR or sepcificity is assigned as inf') + warnings.warn( + 'Number of negative cases is 0, TNR or sepcificity is assigned as inf') tnr = float('inf') else: - tnr = tn/n + tnr = tn / n return tnr @@ -57,16 +60,17 @@ def precision(yp, yt): """ tp = true_positive(yp, yt) fp = false_positive(yp, yt) - if tp+fp == 0: - warnings.warn('Total number of true positive and false positive cases is 0, PPV or precision is assigned as inf') + if tp + fp == 0: + warnings.warn( + 'Total number of true 
positive and false positive cases is 0, PPV or precision is assigned as inf') ppv = float('inf') else: - ppv = tp/(tp+fp) + ppv = tp / (tp + fp) return ppv def accuracy(yp, yt): - """Accuracy + """Accuracy. Args: yp (array): predictions @@ -79,12 +83,12 @@ def accuracy(yp, yt): tn = true_negative(yp, yt) p = positive(yt) n = negative(yt) - acc = (tp+tn)/(p+n) + acc = (tp + tn) / (p + n) return acc def F1(yp, yt): - """F1 score + """F1 score. Args: yp (array): predictions @@ -96,60 +100,60 @@ def F1(yp, yt): tp = true_positive(yp, yt) fp = false_positive(yp, yt) fn = false_negative(yp, yt) - f1 = 2*tp/(2*tp+fp+fn) + f1 = 2 * tp / (2 * tp + fp + fn) return f1 def true_positive(yp, yt): - """number of true positive cases + """number of true positive cases. Args: yp (array): predictions yt (array): targets """ yp, yt = _to_bool(yp), _to_bool(yt) - tp = np.logical_and(yp==True, yt==True) + tp = np.logical_and(yp, yt) return(np.sum(tp)) def true_negative(yp, yt): - """number of true negative cases + """number of true negative cases. Args: yp (array): predictions yt (array): targets """ yp, yt = _to_bool(yp), _to_bool(yt) - tn = np.logical_and(yp==False, yt==False) + tn = np.logical_and(yp == False, yt == False) return(np.sum(tn)) def false_positive(yp, yt): - """number of false positive cases + """number of false positive cases. Args: yp (array): predictions yt (array): targets """ yp, yt = _to_bool(yp), _to_bool(yt) - fp = np.logical_and(yp==True, yt==False) + fp = np.logical_and(yp, yt == False) return(np.sum(fp)) def false_negative(yp, yt): - """number of false false cases + """number of false false cases. Args: yp (array): predictions yt (array): targets """ yp, yt = _to_bool(yp), _to_bool(yt) - fn = np.logical_and(yp==False, yt==True) + fn = np.logical_and(yp == False, yt == True) return(np.sum(fn)) def positive(yt): - """The number of real positive cases + """The number of real positive cases. Args: yt (array): targets @@ -159,7 +163,7 @@ def positive(yt): def negative(yt): - """The nunber of real negative cases + """The nunber of real negative cases. Args: yt (array): targets @@ -169,7 +173,7 @@ def negative(yt): def _to_bool(x): - """convert array values to boolean values + """convert array values to boolean values. 
Args: x (array): values should be 0 or 1 diff --git a/deeprank/learn/metaqnn.py b/deeprank/learn/metaqnn.py index 872044ed..1227f8cb 100644 --- a/deeprank/learn/metaqnn.py +++ b/deeprank/learn/metaqnn.py @@ -1,19 +1,26 @@ -import numpy as np import pickle -import torch.optim as optim -from deeprank.learn import NetworkGenerator,NeuralNet +import numpy as np + import deeprank.learn.modelGenerator +import torch.optim as optim +from deeprank.learn import NetworkGenerator, NeuralNet + class saved_model(object): - def __init__(self,conv_layers_params=None,fc_layers_params=None,reward=None): + def __init__( + self, + conv_layers_params=None, + fc_layers_params=None, + reward=None): self.conv_layers_params = conv_layers_params self.fc_layers_params = fc_layers_params self.reward = reward + class MetaQNN(object): - def __init__(self,final_dim=1): + def __init__(self, final_dim=1): # names self.model_name = 'conv3d' @@ -24,32 +31,32 @@ def __init__(self,final_dim=1): self.memory = [] # max number of layers - self.num_conv_layers = range(1,11) - self.num_fc_layers = range(1,5) + self.num_conv_layers = range(1, 11) + self.num_fc_layers = range(1, 5) - #types of layers possible - self.conv_types = ['conv','dropout','pool'] + # types of layers possible + self.conv_types = ['conv', 'dropout', 'pool'] # types of post processing # must be in torch.nn.functional - self.post_types = [None,'relu'] + self.post_types = [None, 'relu'] # params of conv layers self.conv_params = {} - self.conv_params['output_size'] = range(1,10) - self.conv_params['kernel_size'] = range(2,5) + self.conv_params['output_size'] = range(1, 10) + self.conv_params['kernel_size'] = range(2, 5) # params of pool layers self.pool_params = {} - self.pool_params['kernel_size'] = range(2,5) + self.pool_params['kernel_size'] = range(2, 5) # params of the dropout layers self.dropout_params = {} - self.dropout_params['percent'] = np.linspace(0.1,0.9,9) + self.dropout_params['percent'] = np.linspace(0.1, 0.9, 9) # params of the fc layers self.fc_params = {} - self.fc_params['output_size'] = [2**i for i in range(4,11)] + self.fc_params['output_size'] = [2**i for i in range(4, 11)] # store the current layers/reward self.conv_layers = [] @@ -81,16 +88,16 @@ def store_model(self): fc_layers_params.append(layer.__get_params__()) self.memory.append(saved_model(conv_layers_params=conv_layers_params, - fc_layers_params=fc_layers_params), - reward=self.reward) + fc_layers_params=fc_layers_params), + reward=self.reward) ######################################### # # save the the entire memory to disk # ######################################### - def pickle_memory(self,fname='memory.pkl'): - pickle.dump(self.memory,open(fname,"wb")) + def pickle_memory(self, fname='memory.pkl'): + pickle.dump(self.memory, open(fname, "wb")) ######################################### # @@ -99,7 +106,7 @@ def pickle_memory(self,fname='memory.pkl'): ######################################### def write_model(self): model_generator = NetworkGenerator(name=self.model_name, - fname =self.file_name, + fname=self.file_name, conv_layers=self.conv_layers, fc_layers=self.fc_layers) model_generator.print() @@ -115,7 +122,7 @@ def get_new_random_model(self): print('QNN : Generate new model') # number of conv/fc layers nconv = np.random.choice(self.num_conv_layers) - nfc = np.random.choice(self.num_fc_layers) + nfc = np.random.choice(self.num_fc_layers) # generate the conv layers self.conv_layers = [] @@ -133,9 +140,8 @@ def get_new_random_model(self): # write the model to file 
self.write_model() - # pick a layer type - def _init_conv_layer_random(self,ilayer): + def _init_conv_layer_random(self, ilayer): # determine wih type of layer we want # first layer is a conv @@ -144,7 +150,7 @@ def _init_conv_layer_random(self,ilayer): name = self.conv_types[0] # if rpevious layer is pool, next can't be pool - elif self.conv_layers[ilayer-1].__name__ == 'pool': + elif self.conv_layers[ilayer - 1].__name__ == 'pool': name = np.random.choice(self.conv_types[:-1]) # else it can be anything @@ -159,81 +165,86 @@ def _init_conv_layer_random(self,ilayer): params['name'] = name if ilayer == 0: - params['input_size'] = -1 #fixed by input shape + params['input_size'] = -1 # fixed by input shape else: - for isearch in range(ilayer-1,-1,-1): + for isearch in range(ilayer - 1, -1, -1): if self.conv_layers[isearch].__name__ == 'conv': params['input_size'] = self.conv_layers[isearch].output_size break - params['output_size'] = np.random.choice(self.conv_params['output_size']) - params['kernel_size'] = np.random.choice(self.conv_params['kernel_size']) + params['output_size'] = np.random.choice( + self.conv_params['output_size']) + params['kernel_size'] = np.random.choice( + self.conv_params['kernel_size']) params['post'] = np.random.choice(self.post_types) if name == 'pool': params = {} params['name'] = name - params['kernel_size'] = np.random.choice(self.pool_params['kernel_size']) + params['kernel_size'] = np.random.choice( + self.pool_params['kernel_size']) params['post'] = np.random.choice(self.post_types) - if name == 'dropout': params = {} params['name'] = name - params['percent'] = np.random.choice(self.dropout_params['percent']) + params['percent'] = np.random.choice( + self.dropout_params['percent']) # create the current layer class instance # and initialize if with the __init_from_dict__() method - current_layer = getattr(deeprank.learn.modelGenerator,params['name'])() + current_layer = getattr( + deeprank.learn.modelGenerator, + params['name'])() current_layer.__init_from_dict__(params) self.conv_layers.append(current_layer) - def _init_fc_layer_random(self,ilayer): + def _init_fc_layer_random(self, ilayer): # init the parms of the layer # each layer type has its own params # the output/input size matching is done automatically - name = 'fc' # so far only fc layer here + name = 'fc' # so far only fc layer here params = {} params['name'] = name if ilayer == 0: - params['input_size'] = -1 # fixed by the conv layers + params['input_size'] = -1 # fixed by the conv layers else: - params['input_size'] = self.fc_layers[ilayer-1].output_size + params['input_size'] = self.fc_layers[ilayer - 1].output_size params['output_size'] = np.random.choice(self.fc_params['output_size']) params['post'] = np.random.choice(self.post_types) - - current_layer = getattr(deeprank.learn.modelGenerator,params['name'])() + current_layer = getattr( + deeprank.learn.modelGenerator, + params['name'])() current_layer.__init_from_dict__(params) self.fc_layers.append(current_layer) # load the data set in memory only once - def load_dataset(self,database,feature='all',target='DOCKQ'): + def load_dataset(self, database, feature='all', target='DOCKQ'): print('QNN : Load data set') self.data_set = DataSet(database, - select_feature=feature, - select_target=target, - normalize_features=True, - normalize_targets=True) + select_feature=feature, + select_target=target, + normalize_features=True, + normalize_targets=True) self.data_set.load() - def train_model(self,cuda=False,ngpu=0): + def train_model(self, cuda=False, 
ngpu=0): print('QNN : Train model') from .model3d import cnn # create the ConvNet - model = NeuralNet(self.data_set,cnn,plot=False,cuda=cuda,ngpu=ngpu) + model = NeuralNet(self.data_set, cnn, plot=False, cuda=cuda, ngpu=ngpu) # fix optimizer model.optimizer = optim.SGD(model.net.parameters(), - lr=0.001,momentum=0.9,weight_decay=0.005) + lr=0.001, momentum=0.9, weight_decay=0.005) # train and save reward - model.train(nepoch = 20) + model.train(nepoch=20) self.reward = model.test_loss - diff --git a/deeprank/learn/model2d.py b/deeprank/learn/model2d.py index a24e655c..34c6ce2d 100644 --- a/deeprank/learn/model2d.py +++ b/deeprank/learn/model2d.py @@ -1,9 +1,8 @@ import torch -from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F - +from torch.autograd import Variable ###################################################################### # @@ -11,47 +10,48 @@ # ###################################################################### -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Network Structure -#---------------------------------------------------------------------- -#conv layer 0: conv | input -1 output 4 kernel 2 post relu -#conv layer 1: pool | kernel 2 post None -#conv layer 2: conv | input 4 output 5 kernel 2 post relu -#conv layer 3: pool | kernel 2 post None -#fc layer 0: fc | input -1 output 84 post relu -#fc layer 1: fc | input 84 output 1 post None -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- +# conv layer 0: conv | input -1 output 4 kernel 2 post relu +# conv layer 1: pool | kernel 2 post None +# conv layer 2: conv | input 4 output 5 kernel 2 post relu +# conv layer 3: pool | kernel 2 post None +# fc layer 0: fc | input -1 output 84 post relu +# fc layer 1: fc | input 84 output 1 post None +# ---------------------------------------------------------------------- + class cnn(nn.Module): - def __init__(self,input_shape): - super(cnn,self).__init__() + def __init__(self, input_shape): + super(cnn, self).__init__() - self.convlayer2D_000 = nn.Conv2d(input_shape[0],4,kernel_size=2) - self.convlayer2D_001 = nn.MaxPool2d((2,2)) - self.convlayer2D_002 = nn.Conv2d(4,2,kernel_size=2) - self.convlayer2D_003 = nn.MaxPool2d((2,2)) + self.convlayer2D_000 = nn.Conv2d(input_shape[0], 4, kernel_size=2) + self.convlayer2D_001 = nn.MaxPool2d((2, 2)) + self.convlayer2D_002 = nn.Conv2d(4, 2, kernel_size=2) + self.convlayer2D_003 = nn.MaxPool2d((2, 2)) size = self._get_conv_output(input_shape) - self.fclayer2D_000 = nn.Linear(size,84) - self.fclayer2D_001 = nn.Linear(84,1) + self.fclayer2D_000 = nn.Linear(size, 84) + self.fclayer2D_001 = nn.Linear(84, 1) - def _get_conv_output(self,shape): - inp = Variable(torch.rand(1,*shape)) + def _get_conv_output(self, shape): + inp = Variable(torch.rand(1, *shape)) out = self._forward_features(inp) - return out.data.view(1,-1).size(1) + return out.data.view(1, -1).size(1) - def _forward_features(self,x): + def _forward_features(self, x): x = F.relu(self.convlayer2D_000(x)) x = self.convlayer2D_001(x) x = F.relu(self.convlayer2D_002(x)) x = self.convlayer2D_003(x) return x - def forward(self,x): + def forward(self, x): x = self._forward_features(x) - x = x.view(x.size(0),-1) + x = x.view(x.size(0), -1) x = F.relu(self.fclayer2D_000(x)) x = self.fclayer2D_001(x) - return x \ No newline at end of file + return x diff --git 
a/deeprank/learn/model3d.py b/deeprank/learn/model3d.py index 26663072..c3330888 100644 --- a/deeprank/learn/model3d.py +++ b/deeprank/learn/model3d.py @@ -1,8 +1,7 @@ import torch -from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F - +from torch.autograd import Variable ###################################################################### # @@ -10,103 +9,101 @@ # ###################################################################### -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Network Structure -#---------------------------------------------------------------------- -#conv layer 0: conv | input -1 output 4 kernel 2 post relu -#conv layer 1: pool | kernel 2 post None -#conv layer 2: conv | input 4 output 5 kernel 2 post relu -#conv layer 3: pool | kernel 2 post None -#fc layer 0: fc | input -1 output 84 post relu -#fc layer 1: fc | input 84 output 1 post None -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- +# conv layer 0: conv | input -1 output 4 kernel 2 post relu +# conv layer 1: pool | kernel 2 post None +# conv layer 2: conv | input 4 output 5 kernel 2 post relu +# conv layer 3: pool | kernel 2 post None +# fc layer 0: fc | input -1 output 84 post relu +# fc layer 1: fc | input 84 output 1 post None +# ---------------------------------------------------------------------- + class cnn_reg(nn.Module): - def __init__(self,input_shape): - super(cnn_reg,self).__init__() + def __init__(self, input_shape): + super(cnn_reg, self).__init__() - self.convlayer_000 = nn.Conv3d(input_shape[0],4,kernel_size=2) - self.convlayer_001 = nn.MaxPool3d((2,2,2)) - self.convlayer_002 = nn.Conv3d(4,5,kernel_size=2) - self.convlayer_003 = nn.MaxPool3d((2,2,2)) + self.convlayer_000 = nn.Conv3d(input_shape[0], 4, kernel_size=2) + self.convlayer_001 = nn.MaxPool3d((2, 2, 2)) + self.convlayer_002 = nn.Conv3d(4, 5, kernel_size=2) + self.convlayer_003 = nn.MaxPool3d((2, 2, 2)) size = self._get_conv_output(input_shape) - self.fclayer_000 = nn.Linear(size,84) - self.fclayer_001 = nn.Linear(84,1) + self.fclayer_000 = nn.Linear(size, 84) + self.fclayer_001 = nn.Linear(84, 1) - - def _get_conv_output(self,shape): + def _get_conv_output(self, shape): num_data_points = 2 - inp = Variable(torch.rand(num_data_points,*shape)) + inp = Variable(torch.rand(num_data_points, *shape)) out = self._forward_features(inp) - return out.data.view(num_data_points,-1).size(1) + return out.data.view(num_data_points, -1).size(1) - def _forward_features(self,x): + def _forward_features(self, x): x = F.relu(self.convlayer_000(x)) x = self.convlayer_001(x) x = F.relu(self.convlayer_002(x)) x = self.convlayer_003(x) return x - def forward(self,x): + def forward(self, x): x = self._forward_features(x) - x = x.view(x.size(0),-1) + x = x.view(x.size(0), -1) x = F.relu(self.fclayer_000(x)) x = self.fclayer_001(x) return x - ###################################################################### # # Model automatically generated by modelGenerator # ###################################################################### -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Network Structure -#---------------------------------------------------------------------- -#conv layer 0: conv | input -1 output 4 kernel 2 post relu 
-#conv layer 1: pool | kernel 2 post None -#conv layer 2: conv | input 4 output 5 kernel 2 post relu -#conv layer 3: pool | kernel 2 post None -#fc layer 0: fc | input -1 output 84 post relu -#fc layer 1: fc | input 84 output 1 post None -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- +# conv layer 0: conv | input -1 output 4 kernel 2 post relu +# conv layer 1: pool | kernel 2 post None +# conv layer 2: conv | input 4 output 5 kernel 2 post relu +# conv layer 3: pool | kernel 2 post None +# fc layer 0: fc | input -1 output 84 post relu +# fc layer 1: fc | input 84 output 1 post None +# ---------------------------------------------------------------------- class cnn_class(nn.Module): - def __init__(self,input_shape): - super(cnn_class,self).__init__() + def __init__(self, input_shape): + super(cnn_class, self).__init__() - self.convlayer_000 = nn.Conv3d(input_shape[0],4,kernel_size=2) - self.convlayer_001 = nn.MaxPool3d((2,2,2)) - self.convlayer_002 = nn.Conv3d(4,5,kernel_size=2) - self.convlayer_003 = nn.MaxPool3d((2,2,2)) + self.convlayer_000 = nn.Conv3d(input_shape[0], 4, kernel_size=2) + self.convlayer_001 = nn.MaxPool3d((2, 2, 2)) + self.convlayer_002 = nn.Conv3d(4, 5, kernel_size=2) + self.convlayer_003 = nn.MaxPool3d((2, 2, 2)) size = self._get_conv_output(input_shape) - self.fclayer_000 = nn.Linear(size,84) - self.fclayer_001 = nn.Linear(84,2) - + self.fclayer_000 = nn.Linear(size, 84) + self.fclayer_001 = nn.Linear(84, 2) - def _get_conv_output(self,shape): - inp = Variable(torch.rand(1,*shape)) + def _get_conv_output(self, shape): + inp = Variable(torch.rand(1, *shape)) out = self._forward_features(inp) - return out.data.view(1,-1).size(1) + return out.data.view(1, -1).size(1) - def _forward_features(self,x): + def _forward_features(self, x): x = F.relu(self.convlayer_000(x)) x = self.convlayer_001(x) x = F.relu(self.convlayer_002(x)) x = self.convlayer_003(x) return x - def forward(self,x): + def forward(self, x): x = self._forward_features(x) - x = x.view(x.size(0),-1) + x = x.view(x.size(0), -1) x = F.relu(self.fclayer_000(x)) x = self.fclayer_001(x) return x diff --git a/deeprank/learn/modelGenerator.py b/deeprank/learn/modelGenerator.py index 1fba2eda..7e3ea127 100644 --- a/deeprank/learn/modelGenerator.py +++ b/deeprank/learn/modelGenerator.py @@ -1,5 +1,8 @@ -import numpy as np import ast + +import numpy as np + + ################################# # # MODEL GENERATOR @@ -7,7 +10,12 @@ ################################# class NetworkGenerator(object): - def __init__(self,name='_tmp_model_',fname='_tmp_model_.py',conv_layers=None,fc_layers=None): + def __init__( + self, + name='_tmp_model_', + fname='_tmp_model_.py', + conv_layers=None, + fc_layers=None): """Automatic generation of NN files. 
This class allows for automatic generation of python file containing the definition of @@ -34,7 +42,6 @@ def __init__(self,name='_tmp_model_',fname='_tmp_model_.py',conv_layers=None,fc_ >>> MG = NetworkGenerator(name='test',fname='model_test.py',conv_layers=conv_layers,fc_layers=fc_layers) >>> MG.print() >>> MG.write() - """ # name of the model self.name = name @@ -50,32 +57,32 @@ def __init__(self,name='_tmp_model_',fname='_tmp_model_.py',conv_layers=None,fc_ self.final_dim = 1 # possible number of randomly generated conv/fc layers - self.num_conv_layers = range(1,11) - self.num_fc_layers = range(1,5) + self.num_conv_layers = range(1, 11) + self.num_fc_layers = range(1, 5) # possible types of conv layers - self.conv_types = ['conv','dropout','pool'] + self.conv_types = ['conv', 'dropout', 'pool'] # conv parameters self.conv_params = {} - self.conv_params['output_size'] = range(1,10) - self.conv_params['kernel_size'] = range(2,5) + self.conv_params['output_size'] = range(1, 10) + self.conv_params['kernel_size'] = range(2, 5) # pool parameters self.pool_params = {} - self.pool_params['kernel_size'] = range(2,5) + self.pool_params['kernel_size'] = range(2, 5) # params of the dropout layers self.dropout_params = {} - self.dropout_params['percent'] = np.linspace(0.1,0.9,9) + self.dropout_params['percent'] = np.linspace(0.1, 0.9, 9) # params for the automatic generation of fc layers self.fc_params = {} - self.fc_params['output_size'] = [2**i for i in range(4,11)] + self.fc_params['output_size'] = [2**i for i in range(4, 11)] # types of post processing # must be in torch.nn.functional - self.post_types = [None,'relu'] + self.post_types = [None, 'relu'] ####################################### # @@ -87,7 +94,7 @@ def __init__(self,name='_tmp_model_',fname='_tmp_model_.py',conv_layers=None,fc_ def write(self): """Write the model to file.""" - f = open(self.fname,'w') + f = open(self.fname, 'w') self._write_import(f) self._write_definition(f) self._write_init(f) @@ -110,33 +117,33 @@ def _write_import(fhandle): fhandle.write(modules) # comment and such - def _write_definition(self,fhandle): + def _write_definition(self, fhandle): ndash = 70 - fhandle.write('#'*ndash+'\n#\n') + fhandle.write('#' * ndash + '\n#\n') fhandle.write('# Model automatically generated by modelGenerator\n') - fhandle.write('#\n'+'#'*ndash+'\n\n') + fhandle.write('#\n' + '#' * ndash + '\n\n') - fhandle.write('#'+'-'*ndash+'\n') + fhandle.write('#' + '-' * ndash + '\n') fhandle.write('# Network Structure\n') - fhandle.write('#'+'-'*ndash+'\n') - for ilayer,layer in enumerate(self.conv_layers): - fhandle.write('%s\n' %layer.__human_readable_str__(ilayer)) - for ilayer,layer in enumerate(self.fc_layers): - fhandle.write('%s\n' %layer.__human_readable_str__(ilayer)) - fhandle.write('#'+'-'*ndash+'\n\n') + fhandle.write('#' + '-' * ndash + '\n') + for ilayer, layer in enumerate(self.conv_layers): + fhandle.write('%s\n' % layer.__human_readable_str__(ilayer)) + for ilayer, layer in enumerate(self.fc_layers): + fhandle.write('%s\n' % layer.__human_readable_str__(ilayer)) + fhandle.write('#' + '-' * ndash + '\n\n') # initialization of the model # here all the layers are defined - def _write_init(self,fhandle): + def _write_init(self, fhandle): fhandle.write('class ' + self.name + '(nn.Module):\n') fhandle.write('\n') fhandle.write('\tdef __init__(self,input_shape):\n') - fhandle.write('\t\tsuper(%s,self).__init__()\n' %(self.name)) + fhandle.write('\t\tsuper(%s,self).__init__()\n' % (self.name)) fhandle.write('\n') # write the conv 
layer - for ilayer,layer in enumerate(self.conv_layers): - fhandle.write('\t\t%s\n' %layer.__def_str__(ilayer)) + for ilayer, layer in enumerate(self.conv_layers): + fhandle.write('\t\t%s\n' % layer.__def_str__(ilayer)) # the size determination between the conv and fc blocks fhandle.write('\n') @@ -144,8 +151,8 @@ def _write_init(self,fhandle): fhandle.write('\n') # write the fc layers - for ilayer,layer in enumerate(self.fc_layers): - fhandle.write('\t\t%s\n' %layer.__def_str__(ilayer)) + for ilayer, layer in enumerate(self.fc_layers): + fhandle.write('\t\t%s\n' % layer.__def_str__(ilayer)) fhandle.write('\n') @staticmethod @@ -162,22 +169,22 @@ def _write_conv_output(fhandle): # forward feature conatining all the conv layers # here all the conv layers are defined - def _write_forward_feature(self,fhandle): + def _write_forward_feature(self, fhandle): fhandle.write('\tdef _forward_features(self,x):\n') - for ilayer,layer in enumerate(self.conv_layers): - fhandle.write('\t\t%s\n' %layer.__use_str__(ilayer)) + for ilayer, layer in enumerate(self.conv_layers): + fhandle.write('\t\t%s\n' % layer.__use_str__(ilayer)) fhandle.write('\t\treturn x\n') fhandle.write('\n') # total forward pass # here _forward_features is used # and all the fc layers are used - def _write_forward(self,fhandle): + def _write_forward(self, fhandle): fhandle.write('\tdef forward(self,x):\n') fhandle.write('\t\tx = self._forward_features(x)\n') fhandle.write('\t\tx = x.view(x.size(0),-1)\n') - for ilayer,layer in enumerate(self.fc_layers): - fhandle.write('\t\t%s\n' %layer.__use_str__(ilayer)) + for ilayer, layer in enumerate(self.fc_layers): + fhandle.write('\t\t%s\n' % layer.__use_str__(ilayer)) fhandle.write('\t\treturn x\n') fhandle.write('\n') @@ -187,14 +194,14 @@ def print(self): ndash = 70 - print('#'+'-'*ndash) + print('#' + '-' * ndash) print('# Network Structure') - print('#'+'-'*ndash) - for ilayer,layer in enumerate(self.conv_layers): - print('%s' %layer.__human_readable_str__(ilayer)) - for ilayer,layer in enumerate(self.fc_layers): - print('%s' %layer.__human_readable_str__(ilayer)) - print('#'+'-'*ndash+'\n') + print('#' + '-' * ndash) + for ilayer, layer in enumerate(self.conv_layers): + print('%s' % layer.__human_readable_str__(ilayer)) + for ilayer, layer in enumerate(self.fc_layers): + print('%s' % layer.__human_readable_str__(ilayer)) + print('#' + '-' * ndash + '\n') ######################################### # @@ -206,7 +213,7 @@ def get_new_random_model(self): # number of conv/fc layers nconv = np.random.choice(self.num_conv_layers) - nfc = np.random.choice(self.num_fc_layers) + nfc = np.random.choice(self.num_fc_layers) # generate the conv layers self.conv_layers = [] @@ -221,9 +228,8 @@ def get_new_random_model(self): # fix the final dimension self.fc_layers[-1].output_size = self.final_dim - # pick a layer type - def _init_conv_layer_random(self,ilayer): + def _init_conv_layer_random(self, ilayer): # determine wih type of layer we want # first layer is a conv @@ -232,7 +238,7 @@ def _init_conv_layer_random(self,ilayer): name = self.conv_types[0] # if rpevious layer is pool, next can't be pool - elif self.conv_layers[ilayer-1].__name__ == 'pool': + elif self.conv_layers[ilayer - 1].__name__ == 'pool': name = np.random.choice(self.conv_types[:-1]) # else it can be anything @@ -247,28 +253,31 @@ def _init_conv_layer_random(self,ilayer): params['name'] = name if ilayer == 0: - params['input_size'] = -1 #fixed by input shape + params['input_size'] = -1 # fixed by input shape else: - for isearch in 
range(ilayer-1,-1,-1): + for isearch in range(ilayer - 1, -1, -1): if self.conv_layers[isearch].__name__ == 'conv': params['input_size'] = self.conv_layers[isearch].output_size break - params['output_size'] = np.random.choice(self.conv_params['output_size']) - params['kernel_size'] = np.random.choice(self.conv_params['kernel_size']) + params['output_size'] = np.random.choice( + self.conv_params['output_size']) + params['kernel_size'] = np.random.choice( + self.conv_params['kernel_size']) params['post'] = np.random.choice(self.post_types) if name == 'pool': params = {} params['name'] = name - params['kernel_size'] = np.random.choice(self.pool_params['kernel_size']) + params['kernel_size'] = np.random.choice( + self.pool_params['kernel_size']) params['post'] = np.random.choice(self.post_types) - if name == 'dropout': params = {} params['name'] = name - params['percent'] = np.random.choice(self.dropout_params['percent']) + params['percent'] = np.random.choice( + self.dropout_params['percent']) # create the current layer class instance # and initialize if with the __init_from_dict__() method @@ -276,23 +285,22 @@ def _init_conv_layer_random(self,ilayer): current_layer.__init_from_dict__(params) self.conv_layers.append(current_layer) - def _init_fc_layer_random(self,ilayer): + def _init_fc_layer_random(self, ilayer): # init the parms of the layer # each layer type has its own params # the output/input size matching is done automatically - name = 'fc' # so far only fc layer here + name = 'fc' # so far only fc layer here params = {} params['name'] = name if ilayer == 0: - params['input_size'] = -1 # fixed by the conv layers + params['input_size'] = -1 # fixed by the conv layers else: - params['input_size'] = self.fc_layers[ilayer-1].output_size + params['input_size'] = self.fc_layers[ilayer - 1].output_size params['output_size'] = np.random.choice(self.fc_params['output_size']) params['post'] = np.random.choice(self.post_types) - current_layer = ast.list_eval(params['name'])() current_layer.__init_from_dict__(params) self.fc_layers.append(current_layer) @@ -303,7 +311,12 @@ def _init_fc_layer_random(self,ilayer): ################################# class conv(object): - def __init__(self,input_size=-1,output_size=None,kernel_size=None,post=None): + def __init__( + self, + input_size=-1, + output_size=None, + kernel_size=None, + post=None): """Wrapper around the convolutional layer. 
Args: @@ -318,25 +331,26 @@ def __init__(self,input_size=-1,output_size=None,kernel_size=None,post=None): """ self.__name__ = 'conv' self.input_size = input_size - self.output_size = output_size + self.output_size = output_size self.kernel_size = kernel_size self.post = post - - def __def_str__(self,ilayer): + def __def_str__(self, ilayer): if ilayer == 0: - return 'self.convlayer_%03d = nn.Conv3d(input_shape[0],%d,kernel_size=%d)' %(ilayer,self.output_size,self.kernel_size) + return 'self.convlayer_%03d = nn.Conv3d(input_shape[0],%d,kernel_size=%d)' % ( + ilayer, self.output_size, self.kernel_size) else: - return 'self.convlayer_%03d = nn.Conv3d(%d,%d,kernel_size=%d)' %(ilayer,self.input_size,self.output_size,self.kernel_size) + return 'self.convlayer_%03d = nn.Conv3d(%d,%d,kernel_size=%d)' % ( + ilayer, self.input_size, self.output_size, self.kernel_size) - def __use_str__(self,ilayer): + def __use_str__(self, ilayer): if self.post is None: - return 'x = self.convlayer_%03d(x)' %ilayer - elif isinstance(self.post,str): - return 'x = F.%s(self.convlayer_%03d(x))' %(self.post,ilayer) + return 'x = self.convlayer_%03d(x)' % ilayer + elif isinstance(self.post, str): + return 'x = F.%s(self.convlayer_%03d(x))' % (self.post, ilayer) else: - print('Error with post processing of conv layer %d' %ilayer) - return 'x = self.convlayer_%03d(x)' %ilayer + print('Error with post processing of conv layer %d' % ilayer) + return 'x = self.convlayer_%03d(x)' % ilayer def __get_params__(self): params = {} @@ -347,21 +361,24 @@ def __get_params__(self): params['post'] = self.post return params - def __init_from_dict__(self,params): + def __init_from_dict__(self, params): self.input_size = params['input_size'] - self.output_size = params['output_size'] + self.output_size = params['output_size'] self.kernel_size = params['kernel_size'] self.post = params['post'] - def __human_readable_str__(self,ilayer): - return '#conv layer % 3d: conv | input % 2d output % 2d kernel % 2d post %s' %(ilayer,self.input_size,self.output_size,self.kernel_size,self.post) + def __human_readable_str__(self, ilayer): + return '#conv layer % 3d: conv | input % 2d output % 2d kernel % 2d post %s' % ( + ilayer, self.input_size, self.output_size, self.kernel_size, self.post) ################################# # POOL layer ################################# + + class pool(object): - def __init__(self,kernel_size=None,post=None): + def __init__(self, kernel_size=None, post=None): """Wrapper around the pool layer. 
Args: @@ -376,17 +393,18 @@ def __init__(self,kernel_size=None,post=None): self.kernel_size = kernel_size self.post = post - def __def_str__(self,ilayer): - return 'self.convlayer_%03d = nn.MaxPool3d((%d,%d,%d))' %(ilayer,self.kernel_size,self.kernel_size,self.kernel_size) + def __def_str__(self, ilayer): + return 'self.convlayer_%03d = nn.MaxPool3d((%d,%d,%d))' % ( + ilayer, self.kernel_size, self.kernel_size, self.kernel_size) - def __use_str__(self,ilayer): + def __use_str__(self, ilayer): if self.post is None: - return 'x = self.convlayer_%03d(x)' %ilayer - elif isinstance(self.post,str): - return 'x = F.%s(self.convlayer_%03d(x))' %(self.post,ilayer) + return 'x = self.convlayer_%03d(x)' % ilayer + elif isinstance(self.post, str): + return 'x = F.%s(self.convlayer_%03d(x))' % (self.post, ilayer) else: - print('Error with post processing of conv layer %d' %ilayer) - return 'x = self.convlayer_%03d(x)' %ilayer + print('Error with post processing of conv layer %d' % ilayer) + return 'x = self.convlayer_%03d(x)' % ilayer def __get_params__(self): params = {} @@ -395,19 +413,22 @@ def __get_params__(self): params['post'] = self.post return params - def __init_from_dict__(self,params): + def __init_from_dict__(self, params): self.kernel_size = params['kernel_size'] self.post = params['post'] - def __human_readable_str__(self,ilayer): - return '#conv layer % 3d: pool | kernel % 2d post %s' %(ilayer,self.kernel_size,self.post) + def __human_readable_str__(self, ilayer): + return '#conv layer % 3d: pool | kernel % 2d post %s' % ( + ilayer, self.kernel_size, self.post) ################################# # dropout layer ################################# + + class dropout(object): - def __init__(self,percent=0.5): + def __init__(self, percent=0.5): """Wrapper around the dropout layer layer. Args: @@ -418,14 +439,15 @@ def __init__(self,percent=0.5): >>> fc_layers.append(dropout(precent=0.25)) """ self.__name__ = 'dropout' - self.percent=percent + self.percent = percent - def __def_str__(self,ilayer): - return 'self.convlayer_%03d = nn.Dropout3d(%0.1f)' %(ilayer,self.percent) + def __def_str__(self, ilayer): + return 'self.convlayer_%03d = nn.Dropout3d(%0.1f)' % ( + ilayer, self.percent) @staticmethod def __use_str__(ilayer): - return 'x = self.convlayer_%03d(x)' %ilayer + return 'x = self.convlayer_%03d(x)' % ilayer def __get_params__(self): params = {} @@ -433,18 +455,21 @@ def __get_params__(self): params['percent'] = self.percent return params - def __init_from_dict__(self,params): + def __init_from_dict__(self, params): self.percent = params['percent'] - def __human_readable_str__(self,ilayer): - return '#conv layer % 3d: drop | percent %0.1f' %(ilayer,self.percent) + def __human_readable_str__(self, ilayer): + return '#conv layer % 3d: drop | percent %0.1f' % ( + ilayer, self.percent) ################################# # fully connected layer ################################# + + class fc(object): - def __init__(self,input_size=-1,output_size=None,post=None): + def __init__(self, input_size=-1, output_size=None, post=None): """Wrapper around the fully conneceted layer. 
Args: @@ -461,20 +486,22 @@ def __init__(self,input_size=-1,output_size=None,post=None): self.output_size = output_size self.post = post - def __def_str__(self,ilayer): + def __def_str__(self, ilayer): if ilayer == 0: - return 'self.fclayer_%03d = nn.Linear(size,%d)' %(ilayer,self.output_size) + return 'self.fclayer_%03d = nn.Linear(size,%d)' % ( + ilayer, self.output_size) else: - return 'self.fclayer_%03d = nn.Linear(%d,%d)' %(ilayer,self.input_size,self.output_size) + return 'self.fclayer_%03d = nn.Linear(%d,%d)' % ( + ilayer, self.input_size, self.output_size) - def __use_str__(self,ilayer): + def __use_str__(self, ilayer): if self.post is None: - return 'x = self.fclayer_%03d(x)' %ilayer - elif isinstance(self.post,str): - return 'x = F.%s(self.fclayer_%03d(x))' %(self.post,ilayer) + return 'x = self.fclayer_%03d(x)' % ilayer + elif isinstance(self.post, str): + return 'x = F.%s(self.fclayer_%03d(x))' % (self.post, ilayer) else: - print('Error with post processing of conv layer %d' %ilayer) - return 'x = self.fclayer_%03d(x)' %ilayer + print('Error with post processing of conv layer %d' % ilayer) + return 'x = self.fclayer_%03d(x)' % ilayer def __get_params__(self): params = {} @@ -484,29 +511,38 @@ def __get_params__(self): params['post'] = self.post return params - def __init_from_dict__(self,params): + def __init_from_dict__(self, params): self.input_size = params['input_size'] - self.output_size = params['output_size'] + self.output_size = params['output_size'] self.post = params['post'] - def __human_readable_str__(self,ilayer): - return '#fc layer % 3d: fc | input % 2d output % 2d post %s' %(ilayer,self.input_size,self.output_size,self.post) + def __human_readable_str__(self, ilayer): + return '#fc layer % 3d: fc | input % 2d output % 2d post %s' % ( + ilayer, self.input_size, self.output_size, self.post) -if __name__== '__main__': - +if __name__ == '__main__': conv_layers = [] - conv_layers.append(conv(output_size=4,kernel_size=2,post='relu')) + conv_layers.append(conv(output_size=4, kernel_size=2, post='relu')) conv_layers.append(pool(kernel_size=2)) - conv_layers.append(conv(input_size=4,output_size=5,kernel_size=2,post='relu')) + conv_layers.append( + conv( + input_size=4, + output_size=5, + kernel_size=2, + post='relu')) conv_layers.append(pool(kernel_size=2)) fc_layers = [] - fc_layers.append(fc(output_size=84,post='relu')) - fc_layers.append(fc(input_size=84,output_size=1)) - - MG = NetworkGenerator(name='test',fname='model_test.py',conv_layers=conv_layers,fc_layers=fc_layers) + fc_layers.append(fc(output_size=84, post='relu')) + fc_layers.append(fc(input_size=84, output_size=1)) + + MG = NetworkGenerator( + name='test', + fname='model_test.py', + conv_layers=conv_layers, + fc_layers=fc_layers) MG.print() MG.write() diff --git a/deeprank/learn/rankingMetrics.py b/deeprank/learn/rankingMetrics.py index d783a76a..fdc240af 100644 --- a/deeprank/learn/rankingMetrics.py +++ b/deeprank/learn/rankingMetrics.py @@ -1,4 +1,4 @@ -"""Information Retrieval metrics +"""Information Retrieval metrics. Useful Resources: http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt @@ -11,8 +11,8 @@ def hitrate(rs): - """Hit rate Basically equivalent to the recall@k - First element is rank 1, Relevance is binray + """Hit rate Basically equivalent to the recall@k First element is rank 1, + Relevance is binray. 
Example: @@ -55,8 +55,7 @@ def avprec(rs): def recall(rs, nr): - """recall rate - First element is rank 1, Relevance is binray + """recall rate First element is rank 1, Relevance is binray. Example: @@ -74,11 +73,11 @@ def recall(rs, nr): recall (int): recall value """ - return np.sum(rs)/nr + return np.sum(rs) / nr def mean_reciprocal_rank(rs): - """Score is reciprocal of the rank of the first relevant item + """Score is reciprocal of the rank of the first relevant item. First element is 'rank 1'. Relevance is binary (nonzero is relevant). @@ -105,7 +104,7 @@ def mean_reciprocal_rank(rs): def r_precision(r): - """Score is precision after all relevant documents have been retrieved + """Score is precision after all relevant documents have been retrieved. Relevance is binary (nonzero is relevant). @@ -195,7 +194,7 @@ def average_precision(r): def mean_average_precision(rs): - """Score is mean average precision + """Score is mean average precision. Relevance is binary (nonzero is relevant). diff --git a/deeprank/targets/binary_class.py b/deeprank/targets/binary_class.py index 9b4efe9e..16739491 100644 --- a/deeprank/targets/binary_class.py +++ b/deeprank/targets/binary_class.py @@ -1,16 +1,16 @@ import numpy as np + def __compute_target__(decoy, targrp): - """ - pdb_data (bytes): PDB translated in bytes - targrp (h5 file handle): name of the group where to store the targets + """pdb_data (bytes): PDB translated in bytes targrp (h5 file handle): name + of the group where to store the targets. - e.g., - f = h5py.File('1LFD.hdf5') - targrp = f['1LFD_9w/targets'] + e.g., + f = h5py.File('1LFD.hdf5') + targrp = f['1LFD_9w/targets'] - list(targrp) - ['DOCKQ', 'FNAT', 'IRMSD', 'LRMSD'] + list(targrp) + ['DOCKQ', 'FNAT', 'IRMSD', 'LRMSD'] """ irmsd_thr = 4 @@ -24,10 +24,10 @@ def __compute_target__(decoy, targrp): del targrp[target_name] if targrp['IRMSD'][()] <= irmsd_thr: - print (f"This is a hit (irmsd <= {irmsd_thr} A). {molname} -> irmsd: {targrp['IRMSD'][()]}") + print( + f"This is a hit (irmsd <= {irmsd_thr} A). {molname} -> irmsd: {targrp['IRMSD'][()]}") classID = 1 else: classID = 0 - targrp.create_dataset('BIN_CLASS',data=np.array(classID)) - + targrp.create_dataset('BIN_CLASS', data=np.array(classID)) diff --git a/deeprank/targets/dockQ.py b/deeprank/targets/dockQ.py index 8be62fb5..0bdb1ec8 100644 --- a/deeprank/targets/dockQ.py +++ b/deeprank/targets/dockQ.py @@ -1,8 +1,11 @@ -from deeprank.tools.StructureSimilarity import StructureSimilarity import os + import numpy as np -def __compute_target__(decoy,targrp): +from deeprank.tools.StructureSimilarity import StructureSimilarity + + +def __compute_target__(decoy, targrp): # fet the mol group molgrp = targrp.parent @@ -14,7 +17,7 @@ def __compute_target__(decoy,targrp): if not os.path.isdir(ZONE): os.mkdir(ZONE) - for target_name in ['LRMSD','IRMSD','FNAT','DOCKQ']: + for target_name in ['LRMSD', 'IRMSD', 'FNAT', 'DOCKQ']: if target_name in targrp.keys(): del targrp[target_name] @@ -23,34 +26,33 @@ def __compute_target__(decoy,targrp): # lrmsd = irmsd = 0 | fnat = dockq = 1 print(f"WARNING: {molname} has no '_' indicating it is a bound complex. 
Assign 0, 0, 1 and 1 for lrmsd, irmsd, fnat, DockQ, respectively") - targrp.create_dataset('LRMSD',data=np.array(0.0)) - targrp.create_dataset('IRMSD',data=np.array(0.0)) + targrp.create_dataset('LRMSD', data=np.array(0.0)) + targrp.create_dataset('IRMSD', data=np.array(0.0)) targrp.create_dataset('FNAT', data=np.array(1.0)) - targrp.create_dataset('DOCKQ',data=np.array(1.0)) + targrp.create_dataset('DOCKQ', data=np.array(1.0)) # or it's a decoy else: # compute the izone/lzone/ref_pairs molname = molname.split('_')[0] - lzone = ZONE + molname+'.lzone' - izone = ZONE + molname+'.izone' + lzone = ZONE + molname + '.lzone' + izone = ZONE + molname + '.izone' ref_pairs = ZONE + molname + '.ref_pairs' # init the class decoy = molgrp['complex'][:] ref = molgrp['native'][:] - sim = StructureSimilarity(decoy,ref) + sim = StructureSimilarity(decoy, ref) - lrmsd = sim.compute_lrmsd_fast(method='svd',lzone=lzone) - targrp.create_dataset('LRMSD',data=np.array(lrmsd)) + lrmsd = sim.compute_lrmsd_fast(method='svd', lzone=lzone) + targrp.create_dataset('LRMSD', data=np.array(lrmsd)) - irmsd = sim.compute_irmsd_fast(method='svd',izone=izone) - targrp.create_dataset('IRMSD',data=np.array(irmsd)) + irmsd = sim.compute_irmsd_fast(method='svd', izone=izone) + targrp.create_dataset('IRMSD', data=np.array(irmsd)) Fnat = sim.compute_Fnat_fast(ref_pairs=ref_pairs) - targrp.create_dataset('FNAT',data=np.array(Fnat)) - - dockQ = sim.compute_DockQScore(Fnat,lrmsd,irmsd) - targrp.create_dataset('DOCKQ',data=np.array(dockQ)) + targrp.create_dataset('FNAT', data=np.array(Fnat)) + dockQ = sim.compute_DockQScore(Fnat, lrmsd, irmsd) + targrp.create_dataset('DOCKQ', data=np.array(dockQ)) diff --git a/deeprank/tools/StructureSimilarity.py b/deeprank/tools/StructureSimilarity.py index 72fcb412..0ee95413 100644 --- a/deeprank/tools/StructureSimilarity.py +++ b/deeprank/tools/StructureSimilarity.py @@ -1,11 +1,17 @@ +import os +import pickle + import numpy as np + from deeprank.tools import pdb2sql -import os,pickle -_printif = lambda string,cond: print(string) if cond else None + + +def _printif(string, cond): return print(string) if cond else None + class StructureSimilarity(object): - def __init__(self,decoy,ref,verbose=False): + def __init__(self, decoy, ref, verbose=False): """Compute the structure similarity between different molecules. This class allows to compute the i-RMSD, L-RMSD, Fnat and DockQ score of a given conformation. @@ -32,14 +38,13 @@ def __init__(self,decoy,ref,verbose=False): >>> Fnat = sim.compute_Fnat_pdb2sql() >>> Fnat_fast = sim.compute_Fnat_fast(ref_pairs='1AK4.ref_pairs') >>> dockQ = sim.compute_DockQScore(Fnat_fast,lrmsd_fast,irmsd_fast) - """ self.decoy = decoy self.ref = ref self.verbose = verbose - def compute_lrmsd_fast(self,lzone=None,method='svd',check=True): + def compute_lrmsd_fast(self, lzone=None, method='svd', check=True): """Fast routine to compute the L-RMSD. This routine parse the PDB directly without using pdb2sql. 
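The L-RMSD routine reformatted in the next hunks follows a standard superposition recipe: centre the backbone atoms of the longest chain of both decoy and reference on their own centroids, compute the rotation that best superimposes the two long chains (Kabsch/SVD by default, quaternions as an alternative), apply that rotation to the shorter chain, and report the RMSD over the shorter chain. A minimal NumPy sketch of that recipe, with illustrative helper names rather than DeepRank's actual API:

    import numpy as np

    def kabsch(P, Q):
        # rotation R (3x3) that best maps the centred cloud P onto Q
        H = P.T @ Q
        U, S, Vt = np.linalg.svd(H)
        d = np.sign(np.linalg.det(Vt.T @ U.T))      # avoid an improper rotation
        return Vt.T @ np.diag([1.0, 1.0, d]) @ U.T

    def rmsd(P, Q):
        return np.sqrt(((P - Q) ** 2).sum() / len(P))

    def lrmsd(decoy_long, decoy_short, ref_long, ref_short):
        # centre everything on the long-chain centroids
        decoy_short = decoy_short - decoy_long.mean(axis=0)
        decoy_long = decoy_long - decoy_long.mean(axis=0)
        ref_short = ref_short - ref_long.mean(axis=0)
        ref_long = ref_long - ref_long.mean(axis=0)
        # superimpose on the long chain, score on the short one
        R = kabsch(decoy_long, ref_long)
        return rmsd(decoy_short @ R.T, ref_short)

The i-RMSD and pdb2sql variants further down reuse the same centring, rotation and RMSD helpers; only the atom selection (interface zone versus ligand chain) differs.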
@@ -62,7 +67,7 @@ def compute_lrmsd_fast(self,lzone=None,method='svd',check=True): if lzone is None: resData = self.compute_lzone(save_file=False) elif not os.path.isfile(lzone): - self.compute_lzone(save_file=True,filename=lzone) + self.compute_lzone(save_file=True, filename=lzone) resData = self.read_zone(lzone) else: resData = self.read_zone(lzone) @@ -76,13 +81,15 @@ def compute_lrmsd_fast(self,lzone=None,method='svd',check=True): ################################################## if check: - data_decoy_long, data_decoy_short = self.read_data_zone(self.decoy,resData,return_not_in_zone=True) - data_ref_long, data_ref_short = self.read_data_zone(self.ref,resData,return_not_in_zone=True) + data_decoy_long, data_decoy_short = self.read_data_zone( + self.decoy, resData, return_not_in_zone=True) + data_ref_long, data_ref_short = self.read_data_zone( + self.ref, resData, return_not_in_zone=True) - atom_decoy_long = [ data[:3] for data in data_decoy_long ] - atom_ref_long = [ data[:3] for data in data_ref_long ] + atom_decoy_long = [data[:3] for data in data_decoy_long] + atom_ref_long = [data[:3] for data in data_ref_long] - xyz_decoy_long, xyz_ref_long = [],[] + xyz_decoy_long, xyz_ref_long = [], [] for ind_decoy, at in enumerate(atom_decoy_long): try: @@ -93,9 +100,9 @@ def compute_lrmsd_fast(self,lzone=None,method='svd',check=True): pass atom_decoy_short = [data[:3] for data in data_decoy_short] - atom_ref_short = [data[:3] for data in data_ref_short] + atom_ref_short = [data[:3] for data in data_ref_short] - xyz_decoy_short, xyz_ref_short = [],[] + xyz_decoy_short, xyz_ref_short = [], [] for ind_decoy, at in enumerate(atom_decoy_short): try: ind_ref = atom_ref_short.index(at) @@ -104,37 +111,40 @@ def compute_lrmsd_fast(self,lzone=None,method='svd',check=True): except ValueError: pass - # extract the xyz else: print('WARNING : The atom order have not been checked. 
Switch to check=True or continue at your own risk') - xyz_decoy_long,xyz_decoy_short = self.read_xyz_zone(self.decoy,resData,return_not_in_zone=True) - xyz_ref_long,xyz_ref_short = self.read_xyz_zone(self.ref,resData,return_not_in_zone=True) + xyz_decoy_long, xyz_decoy_short = self.read_xyz_zone( + self.decoy, resData, return_not_in_zone=True) + xyz_ref_long, xyz_ref_short = self.read_xyz_zone( + self.ref, resData, return_not_in_zone=True) # get the translation so that both A chains are centered tr_decoy = self.get_trans_vect(xyz_decoy_long) tr_ref = self.get_trans_vect(xyz_ref_long) # translate everything for 1 - xyz_decoy_short = self.translation(xyz_decoy_short,tr_decoy) - xyz_decoy_long = self.translation(xyz_decoy_long,tr_decoy) + xyz_decoy_short = self.translation(xyz_decoy_short, tr_decoy) + xyz_decoy_long = self.translation(xyz_decoy_long, tr_decoy) # translate everuthing for 2 - xyz_ref_short = self.translation(xyz_ref_short,tr_ref) - xyz_ref_long = self.translation(xyz_ref_long,tr_ref) + xyz_ref_short = self.translation(xyz_ref_short, tr_ref) + xyz_ref_long = self.translation(xyz_ref_long, tr_ref) # get the ideql rotation matrix # to superimpose the A chains - U = self.get_rotation_matrix(xyz_decoy_long,xyz_ref_long,method=method) + U = self.get_rotation_matrix( + xyz_decoy_long, xyz_ref_long, method=method) # rotate the entire fragment - xyz_decoy_short = self.rotation_matrix(xyz_decoy_short,U,center=False) + xyz_decoy_short = self.rotation_matrix( + xyz_decoy_short, U, center=False) # compute the RMSD - return self.get_rmsd(xyz_decoy_short,xyz_ref_short) + return self.get_rmsd(xyz_decoy_short, xyz_ref_short) - def compute_lzone(self,save_file=True,filename=None): - """Compute the zone for L-RMSD calculation + def compute_lzone(self, save_file=True, filename=None): + """Compute the zone for L-RMSD calculation. 
Args: save_file (bool, optional): save the zone file @@ -144,16 +154,19 @@ def compute_lzone(self,save_file=True,filename=None): dict: definition of the zone """ sql_ref = pdb2sql(self.ref) - nA = len(sql_ref.get('x,y,z',chainID='A')) - nB = len(sql_ref.get('x,y,z',chainID='B')) + nA = len(sql_ref.get('x,y,z', chainID='A')) + nB = len(sql_ref.get('x,y,z', chainID='B')) # detect which chain is the longest long_chain = 'A' - if nAnB: + nA, nB = len(xyz_decoy_A), len(xyz_decoy_B) + if nA > nB: xyz_decoy_long = xyz_decoy_A xyz_ref_long = xyz_ref_A @@ -496,23 +536,24 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'): tr_ref = self.get_trans_vect(xyz_ref_long) # translate everything for 1 - xyz_decoy_short = self.translation(xyz_decoy_short,tr_decoy) - xyz_decoy_long = self.translation(xyz_decoy_long,tr_decoy) + xyz_decoy_short = self.translation(xyz_decoy_short, tr_decoy) + xyz_decoy_long = self.translation(xyz_decoy_long, tr_decoy) # translate everuthing for 2 - xyz_ref_short = self.translation(xyz_ref_short,tr_ref) - xyz_ref_long = self.translation(xyz_ref_long,tr_ref) + xyz_ref_short = self.translation(xyz_ref_short, tr_ref) + xyz_ref_long = self.translation(xyz_ref_long, tr_ref) # get the ideal rotation matrix # to superimpose the A chains - U = self.get_rotation_matrix(xyz_decoy_long,xyz_ref_long,method=method) + U = self.get_rotation_matrix( + xyz_decoy_long, xyz_ref_long, method=method) # rotate the entire fragment - xyz_decoy_short = self.rotation_matrix(xyz_decoy_short,U,center=False) - + xyz_decoy_short = self.rotation_matrix( + xyz_decoy_short, U, center=False) # compute the RMSD - lrmsd = self.get_rmsd(xyz_decoy_short,xyz_ref_short) + lrmsd = self.get_rmsd(xyz_decoy_short, xyz_ref_short) # export the pdb for verifiactions if exportpath is not None: @@ -522,19 +563,19 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'): xyz_ref = np.array(sql_ref.get('x,y,z')) # translate - xyz_ref = self.translation(xyz_ref,tr_ref) - xyz_decoy = self.translation(xyz_decoy,tr_decoy) + xyz_ref = self.translation(xyz_ref, tr_ref) + xyz_decoy = self.translation(xyz_decoy, tr_decoy) # rotate decoy - xyz_decoy= self.rotation_matrix(xyz_decoy,U,center=False) + xyz_decoy = self.rotation_matrix(xyz_decoy, U, center=False) # update the sql database sql_decoy.update_xyz(xyz_decoy) sql_ref.update_xyz(xyz_ref) # export - sql_decoy.exportpdb(exportpath+'/lrmsd_decoy.pdb') - sql_ref.exportpdb(exportpath+'/lrmsd_aligned.pdb') + sql_decoy.exportpdb(exportpath + '/lrmsd_decoy.pdb') + sql_ref.exportpdb(exportpath + '/lrmsd_aligned.pdb') # close the db sql_decoy.close() @@ -543,8 +584,8 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'): return lrmsd @staticmethod - def get_identical_atoms(db1,db2,chain): - """Return that atoms shared by both databse for a specific chain + def get_identical_atoms(db1, db2, chain): + """Return that atoms shared by both databse for a specific chain. 
Args: db1 (TYPE): pdb2sql database of the first conformation @@ -555,8 +596,8 @@ def get_identical_atoms(db1,db2,chain): list, list: list of xyz for both database """ # get data - data1 = db1.get('chainID,resSeq,name',chainID=chain) - data2 = db2.get('chainID,resSeq,name',chainID=chain) + data1 = db1.get('chainID,resSeq,name', chainID=chain) + data2 = db2.get('chainID,resSeq,name', chainID=chain) # tuplify data1 = [tuple(d1) for d1 in data1] @@ -566,16 +607,21 @@ def get_identical_atoms(db1,db2,chain): shared_data = list(set(data1).intersection(data2)) # get the xyz - xyz1,xyz2 = [],[] + xyz1, xyz2 = [], [] for data in shared_data: query = 'SELECT x,y,z from ATOM WHERE chainID=? AND resSeq=? and name=?' - xyz1.append(list(list(db1.c.execute(query,data))[0])) - xyz2.append(list(list(db2.c.execute(query,data))[0])) + xyz1.append(list(list(db1.c.execute(query, data))[0])) + xyz2.append(list(list(db2.c.execute(query, data))[0])) - return xyz1,xyz2 + return xyz1, xyz2 - def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None): - """Slow method to compute the i-rmsd + def compute_irmsd_pdb2sql( + self, + cutoff=10, + method='svd', + izone=None, + exportpath=None): + """Slow method to compute the i-rmsd. Require the precalculation of the izone. A dedicated routine is implemented to comoute the izone if izone is not given in argument the routine will compute them automatically @@ -603,16 +649,18 @@ def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None # get the contact atoms if izone is None: - contact_ref = sql_ref.get_contact_atoms(cutoff=cutoff,extend_to_residue=True, - return_only_backbone_atoms=False) - index_contact_ref = contact_ref[0]+contact_ref[1] + contact_ref = sql_ref.get_contact_atoms( + cutoff=cutoff, extend_to_residue=True, return_only_backbone_atoms=False) + index_contact_ref = contact_ref[0] + contact_ref[1] else: - index_contact_ref = self.get_izone_rowID(sql_ref,izone,return_only_backbone_atoms=False) - + index_contact_ref = self.get_izone_rowID( + sql_ref, izone, return_only_backbone_atoms=False) # get the xyz and atom identifier of the decoy contact atoms - xyz_contact_ref = sql_ref.get('x,y,z',rowID=index_contact_ref) - data_contact_ref = sql_ref.get('chainID,resSeq,resName,name',rowID=index_contact_ref) + xyz_contact_ref = sql_ref.get('x,y,z', rowID=index_contact_ref) + data_contact_ref = sql_ref.get( + 'chainID,resSeq,resName,name', + rowID=index_contact_ref) # get the xyz and atom indeitifier of the reference xyz_decoy = sql_decoy.get('x,y,z') @@ -625,7 +673,7 @@ def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None xyz_contact_decoy = [] index_contact_decoy = [] clean_ref = False - for iat,atom in enumerate(data_contact_ref): + for iat, atom in enumerate(data_contact_ref): try: index = data_decoy.index(atom) @@ -638,47 +686,55 @@ def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None # clean the xyz if clean_ref: - xyz_contact_ref = [xyz for xyz in xyz_contact_ref if xyz is not None] - index_contact_ref = [ind for ind in index_contact_ref if ind is not None] - + xyz_contact_ref = [ + xyz for xyz in xyz_contact_ref if xyz is not None] + index_contact_ref = [ + ind for ind in index_contact_ref if ind is not None] # check that we still have atoms in both chains - chain_decoy = list(set(sql_decoy.get('chainID',rowID=index_contact_decoy))) - chain_ref = list(set(sql_ref.get('chainID',rowID=index_contact_ref))) - - if len(chain_decoy)<1 or len(chain_ref)<1: - raise 
ValueError('Error in i-rmsd: only one chain represented in one chain') + chain_decoy = list( + set(sql_decoy.get('chainID', rowID=index_contact_decoy))) + chain_ref = list(set(sql_ref.get('chainID', rowID=index_contact_ref))) + if len(chain_decoy) < 1 or len(chain_ref) < 1: + raise ValueError( + 'Error in i-rmsd: only one chain represented in one chain') # get the translation so that both A chains are centered tr_decoy = self.get_trans_vect(xyz_contact_decoy) - tr_ref = self.get_trans_vect(xyz_contact_ref) + tr_ref = self.get_trans_vect(xyz_contact_ref) # translate everything - xyz_contact_decoy = self.translation(xyz_contact_decoy,tr_decoy) - xyz_contact_ref = self.translation(xyz_contact_ref,tr_ref) + xyz_contact_decoy = self.translation(xyz_contact_decoy, tr_decoy) + xyz_contact_ref = self.translation(xyz_contact_ref, tr_ref) # get the ideql rotation matrix # to superimpose the A chains - U = self.get_rotation_matrix(xyz_contact_decoy,xyz_contact_ref,method=method) + U = self.get_rotation_matrix( + xyz_contact_decoy, + xyz_contact_ref, + method=method) # rotate the entire fragment - xyz_contact_decoy = self.rotation_matrix(xyz_contact_decoy,U,center=False) + xyz_contact_decoy = self.rotation_matrix( + xyz_contact_decoy, U, center=False) # compute the RMSD - irmsd = self.get_rmsd(xyz_contact_decoy,xyz_contact_ref) - - + irmsd = self.get_rmsd(xyz_contact_decoy, xyz_contact_ref) # export the pdb for verifiactions if exportpath is not None: # update the sql database - sql_decoy.update_xyz(xyz_contact_decoy,index=index_contact_decoy) - sql_ref.update_xyz(xyz_contact_ref,index=index_contact_ref) + sql_decoy.update_xyz(xyz_contact_decoy, index=index_contact_decoy) + sql_ref.update_xyz(xyz_contact_ref, index=index_contact_ref) - sql_decoy.exportpdb(exportpath+'/irmsd_decoy.pdb',index=index_contact_decoy) - sql_ref.exportpdb(exportpath+'/irmsd_ref.pdb',index=index_contact_ref) + sql_decoy.exportpdb( + exportpath + '/irmsd_decoy.pdb', + index=index_contact_decoy) + sql_ref.exportpdb( + exportpath + '/irmsd_ref.pdb', + index=index_contact_ref) # close the db sql_decoy.close() @@ -687,8 +743,8 @@ def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None return irmsd @staticmethod - def get_izone_rowID(sql,izone,return_only_backbone_atoms=True): - """Compute the index of the izone atoms + def get_izone_rowID(sql, izone, return_only_backbone_atoms=True): + """Compute the index of the izone atoms. 
Args: sql (pdb2sql): database of the conformation @@ -703,18 +759,17 @@ def get_izone_rowID(sql,izone,return_only_backbone_atoms=True): """ # read the file if not os.path.isfile(izone): - raise FileNotFoundError('i-zone file not found',izone) + raise FileNotFoundError('i-zone file not found', izone) - with open(izone,'r') as f: - data=f.readlines() + with open(izone, 'r') as f: + data = f.readlines() # get the data out of it resData = {} for line in data: res = line.split()[1].split('-')[0] - chainID,resSeq = res[0],int(res[1:]) - + chainID, resSeq = res[0], int(res[1:]) if chainID not in resData.keys(): resData[chainID] = [] @@ -724,17 +779,23 @@ def get_izone_rowID(sql,izone,return_only_backbone_atoms=True): # get the rowID index_contact = [] - for chainID,resSeq in resData.items(): + for chainID, resSeq in resData.items(): if return_only_backbone_atoms: - index_contact += sql.get('rowID',chainID=chainID,resSeq=resSeq, - name=['C','CA','N','O']) + index_contact += sql.get('rowID', + chainID=chainID, + resSeq=resSeq, + name=['C', + 'CA', + 'N', + 'O']) else: - index_contact += sql.get('rowID',chainID=chainID,resSeq=resSeq) + index_contact += sql.get('rowID', + chainID=chainID, resSeq=resSeq) return index_contact - def compute_Fnat_pdb2sql(self,cutoff=5.0): - """Slow method to compute the FNAT usign pdb2sql + def compute_Fnat_pdb2sql(self, cutoff=5.0): + """Slow method to compute the FNAT usign pdb2sql. Args: cutoff (float, optional): cutoff for the contact atoms @@ -747,50 +808,43 @@ def compute_Fnat_pdb2sql(self,cutoff=5.0): sql_ref = pdb2sql(self.ref) # get the contact atoms of the decoy - residue_pairs_decoy = sql_decoy.get_contact_residue(cutoff=cutoff, - return_contact_pairs=True, - excludeH=True) + residue_pairs_decoy = sql_decoy.get_contact_residue( + cutoff=cutoff, return_contact_pairs=True, excludeH=True) # get the contact atoms of the ref - residue_pairs_ref = sql_ref.get_contact_residue(cutoff=cutoff, - return_contact_pairs=True, - excludeH=True) - + residue_pairs_ref = sql_ref.get_contact_residue( + cutoff=cutoff, return_contact_pairs=True, excludeH=True) # form the pair data data_pair_decoy = [] - for resA,resB_list in residue_pairs_decoy.items(): - data_pair_decoy += [ (resA,resB) for resB in resB_list ] + for resA, resB_list in residue_pairs_decoy.items(): + data_pair_decoy += [(resA, resB) for resB in resB_list] # form the pair data data_pair_ref = [] - for resA,resB_list in residue_pairs_ref.items(): - data_pair_ref += [ (resA,resB) for resB in resB_list ] - + for resA, resB_list in residue_pairs_ref.items(): + data_pair_ref += [(resA, resB) for resB in resB_list] # find the umber of residue that ref and decoys hace in common nCommon = len(set(data_pair_ref).intersection(data_pair_decoy)) # normalize - Fnat = nCommon/len(data_pair_ref) + Fnat = nCommon / len(data_pair_ref) sql_decoy.close() sql_ref.close() return Fnat - - - ################################################################################################ + ########################################################################## # # HELPER ROUTINES TO HANDLE THE ZONE FILES # - ################################################################################################# - + ########################################################################## @staticmethod - def read_xyz_zone(pdb_file,resData,return_not_in_zone=False): - """Read the xyz of the zone atoms + def read_xyz_zone(pdb_file, resData, return_not_in_zone=False): + """Read the xyz of the zone atoms. 
Args: pdb_file (str): filename containing the pdb of the molecule @@ -801,7 +855,7 @@ def read_xyz_zone(pdb_file,resData,return_not_in_zone=False): list(float): XYZ of the atoms in the zone """ # read the ref file - with open(pdb_file,'r') as f: + with open(pdb_file, 'r') as f: data = f.readlines() # get the xyz of the @@ -825,25 +879,25 @@ def read_xyz_zone(pdb_file,resData,return_not_in_zone=False): if chainID in resData.keys(): - if resSeq in resData[chainID] and name in ['C','CA','N','O']: - xyz_in_zone.append([x,y,z]) + if resSeq in resData[chainID] and name in [ + 'C', 'CA', 'N', 'O']: + xyz_in_zone.append([x, y, z]) - elif resSeq not in resData[chainID] and name in ['C','CA','N','O']: - xyz_not_in_zone.append([x,y,z]) + elif resSeq not in resData[chainID] and name in ['C', 'CA', 'N', 'O']: + xyz_not_in_zone.append([x, y, z]) else: - if name in ['C','CA','N','O']: - xyz_not_in_zone.append([x,y,z]) + if name in ['C', 'CA', 'N', 'O']: + xyz_not_in_zone.append([x, y, z]) if return_not_in_zone: - return xyz_in_zone,xyz_not_in_zone + return xyz_in_zone, xyz_not_in_zone else: return xyz_in_zone - @staticmethod - def read_data_zone(pdb_file,resData,return_not_in_zone=False): + def read_data_zone(pdb_file, resData, return_not_in_zone=False): """Read the data of the atoms in the zone. Args: @@ -855,11 +909,11 @@ def read_data_zone(pdb_file,resData,return_not_in_zone=False): list(float): data of the atoms in the zone """ # read the ref file - if isinstance(pdb_file,str) and os.path.isfile(pdb_file): - with open(pdb_file,'r') as f: + if isinstance(pdb_file, str) and os.path.isfile(pdb_file): + with open(pdb_file, 'r') as f: data = f.readlines() - elif isinstance(pdb_file,np.ndarray): - data = [l.decode('utf-8') for l in pdb_file] + elif isinstance(pdb_file, np.ndarray): + data = [l.decode('utf-8') for l in pdb_file] # get the xyz of the data_in_zone = [] @@ -882,23 +936,25 @@ def read_data_zone(pdb_file,resData,return_not_in_zone=False): if chainID in resData.keys(): - if resSeq in resData[chainID] and name in ['C','CA','N','O']: - data_in_zone.append([chainID,resSeq,name,x,y,z]) + if resSeq in resData[chainID] and name in [ + 'C', 'CA', 'N', 'O']: + data_in_zone.append([chainID, resSeq, name, x, y, z]) - elif resSeq not in resData[chainID] and name in ['C','CA','N','O']: - data_not_in_zone.append([chainID,resSeq,name,x,y,z]) + elif resSeq not in resData[chainID] and name in ['C', 'CA', 'N', 'O']: + data_not_in_zone.append( + [chainID, resSeq, name, x, y, z]) else: - if name in ['C','CA','N','O']: - data_not_in_zone.append([chainID,resSeq,name,x,y,z]) + if name in ['C', 'CA', 'N', 'O']: + data_not_in_zone.append( + [chainID, resSeq, name, x, y, z]) if return_not_in_zone: - return data_in_zone,data_not_in_zone + return data_in_zone, data_not_in_zone else: return data_in_zone - @staticmethod def read_zone(zone_file): """Read the zone file. 
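Both zone readers above perform the same selection: keep only backbone atoms ('C', 'CA', 'N', 'O') and split them into atoms whose residue is listed in the zone and atoms outside it, where the zone is a plain dictionary mapping a chain ID to residue numbers. A minimal sketch of that step, assuming the PDB lines have already been parsed into (chainID, resSeq, name, x, y, z) tuples; the helper name and tuple layout are illustrative, not DeepRank's API:

    BACKBONE = {'C', 'CA', 'N', 'O'}

    def split_by_zone(atoms, zone):
        # atoms: iterable of (chainID, resSeq, name, x, y, z) tuples
        # zone:  dict such as {'A': [4, 5, 12], 'B': [7]}
        xyz_in, xyz_out = [], []
        for chain, resseq, name, x, y, z in atoms:
            if name not in BACKBONE:
                continue                  # only backbone atoms are compared
            if resseq in zone.get(chain, []):
                xyz_in.append([x, y, z])
            else:
                xyz_out.append([x, y, z])
        return xyz_in, xyz_out

With return_not_in_zone=True the real routines return both lists, otherwise only the in-zone part; read_data_zone also keeps the (chainID, resSeq, name) identifiers so that decoy and reference atoms can be matched by identity before the RMSD is computed.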
@@ -914,10 +970,10 @@ def read_zone(zone_file): """ # read the izone file if not os.path.isfile(zone_file): - raise FileNotFoundError('zone file not found',zone_file) + raise FileNotFoundError('zone file not found', zone_file) - with open(zone_file,'r') as f: - data=f.readlines() + with open(zone_file, 'r') as f: + data = f.readlines() # get the data out of it resData = {} @@ -934,12 +990,12 @@ def read_zone(zone_file): # we have e.g res = [A4,A4] if len(res) == 2: res = res[0] - chainID,resSeq = res[0],int(res[1:]) + chainID, resSeq = res[0], int(res[1:]) # if the resnum was negative was negtive # we have e.g res = [A,4,A,4] elif len(res) == 4: - chainID,resSeq = res[0],-int(res[1]) + chainID, resSeq = res[0], -int(res[1]) if chainID not in resData.keys(): resData[chainID] = [] @@ -948,7 +1004,6 @@ def read_zone(zone_file): return resData - ################################################################### # # ROUTINES TO ACTUALY ALIGN THE MOLECULES @@ -956,8 +1011,8 @@ def read_zone(zone_file): ################################################################### @staticmethod - def compute_DockQScore(Fnat,lrmsd,irmsd,d1=8.5,d2=1.5): - """Compute the DockQ Score + def compute_DockQScore(Fnat, lrmsd, irmsd, d1=8.5, d2=1.5): + """Compute the DockQ Score. Args: Fnat (float): Fnat value @@ -966,14 +1021,14 @@ def compute_DockQScore(Fnat,lrmsd,irmsd,d1=8.5,d2=1.5): d1 (float, optional): first coefficient for the DockQ calculations d2 (float, optional): second coefficient for the DockQ calculations """ - def scale_rms(rms,d): - return(1./(1+(rms/d)**2)) + def scale_rms(rms, d): + return(1. / (1 + (rms / d)**2)) - return 1./3 * ( Fnat + scale_rms(lrmsd,d1) + scale_rms(irmsd,d2)) + return 1. / 3 * (Fnat + scale_rms(lrmsd, d1) + scale_rms(irmsd, d2)) @staticmethod - def get_rmsd(P,Q): - """compute the RMSD + def get_rmsd(P, Q): + """compute the RMSD. Args: P (np.array(nx3)): position of the points in the first molecule @@ -983,11 +1038,11 @@ def get_rmsd(P,Q): float: RMSD value """ n = len(P) - return np.sqrt(1./n*np.sum((P-Q)**2)) + return np.sqrt(1. / n * np.sum((P - Q)**2)) @staticmethod def get_trans_vect(P): - """Get the translationv vector to the origin + """Get the translationv vector to the origin. Args: P (np.array(nx3)): position of the points in the molecule @@ -995,27 +1050,28 @@ def get_trans_vect(P): Returns: float: minus mean value of the xyz columns """ - return -np.mean(P,0) - + return -np.mean(P, 0) # main switch for the rotation matrix # add new methods here if necessary - def get_rotation_matrix(self,P,Q,method='svd'): + def get_rotation_matrix(self, P, Q, method='svd'): # get the matrix with Kabsh method - if method.lower()=='svd': - return self.get_rotation_matrix_Kabsh(P,Q) + if method.lower() == 'svd': + return self.get_rotation_matrix_Kabsh(P, Q) # or with the quaternion method - elif method.lower()=='quaternion': - return self.get_rotation_matrix_quaternion(P,Q) + elif method.lower() == 'quaternion': + return self.get_rotation_matrix_quaternion(P, Q) else: - raise ValueError('%s is not a valid method for rmsd alignement.\n Options are svd or quaternions' %method) + raise ValueError( + '%s is not a valid method for rmsd alignement.\n Options are svd or quaternions' % + method) @staticmethod - def get_rotation_matrix_Kabsh(P,Q): - '''Get the rotation matrix to aligh two point clouds. + def get_rotation_matrix_Kabsh(P, Q): + """Get the rotation matrix to aligh two point clouds. 
The method is based on th Kabsh approach https://cnx.org/contents/HV-RsdwL@23/Molecular-Distance-Measures @@ -1029,7 +1085,7 @@ def get_rotation_matrix_Kabsh(P,Q): Raises: ValueError: matrix have different sizes - ''' + """ pshape = P.shape qshape = Q.shape @@ -1037,20 +1093,21 @@ def get_rotation_matrix_Kabsh(P,Q): if pshape[0] == qshape[0]: npts = pshape[0] else: - raise ValueError("Matrix don't have the same number of points",P.shape,Q.shape) - + raise ValueError( + "Matrix don't have the same number of points", + P.shape, + Q.shape) - p0,q0 = np.abs(np.mean(P,0)),np.abs(np.mean(Q,0)) + p0, q0 = np.abs(np.mean(P, 0)), np.abs(np.mean(Q, 0)) eps = 1E-6 if any(p0 > eps) or any(q0 > eps): - raise ValueError('You must center the fragment first',p0,q0) - + raise ValueError('You must center the fragment first', p0, q0) # form the covariance matrix - A = np.dot(P.T,Q)/npts + A = np.dot(P.T, Q) / npts # SVD the matrix - V,S,W = np.linalg.svd(A) + V, S, W = np.linalg.svd(A) # the W matrix returned here is # already its transpose @@ -1058,20 +1115,20 @@ def get_rotation_matrix_Kabsh(P,Q): W = W.T # determinant - d = np.linalg.det(np.dot(W,V.T)) + d = np.linalg.det(np.dot(W, V.T)) # form the U matrix Id = np.eye(3) if d < 0: - Id[2,2] = -1 + Id[2, 2] = -1 - U = np.dot(W,np.dot(Id,V.T)) + U = np.dot(W, np.dot(Id, V.T)) return U @staticmethod - def get_rotation_matrix_quaternion(P,Q): - '''Get the rotation matrix to aligh two point clouds + def get_rotation_matrix_quaternion(P, Q): + """Get the rotation matrix to aligh two point clouds. The method is based on the quaternion approach http://www.ams.stonybrook.edu/~coutsias/papers/rmsd17.pdf @@ -1085,71 +1142,73 @@ def get_rotation_matrix_quaternion(P,Q): Raises: ValueError: matrix have different sizes - ''' + """ pshape = P.shape qshape = Q.shape if pshape[0] != qshape[0]: - raise ValueError("Matrix don't have the same number of points",P.shape,Q.shape) + raise ValueError( + "Matrix don't have the same number of points", + P.shape, + Q.shape) - p0,q0 = np.abs(np.mean(P,0)),np.abs(np.mean(Q,0)) + p0, q0 = np.abs(np.mean(P, 0)), np.abs(np.mean(Q, 0)) eps = 1E-6 if any(p0 > eps) or any(q0 > eps): - raise ValueError('You must center the fragment first',p0,q0) + raise ValueError('You must center the fragment first', p0, q0) # form the correlation matrix - R = np.dot(P.T,Q) + R = np.dot(P.T, Q) # form the F matrix (eq. 
10 of ref[1]) - F = np.zeros((4,4)) + F = np.zeros((4, 4)) - F[0,0] = np.trace(R) - F[0,1] = R[1,2]-R[2,1] - F[0,2] = R[2,0]-R[0,2] - F[0,3] = R[0,1]-R[1,0] + F[0, 0] = np.trace(R) + F[0, 1] = R[1, 2] - R[2, 1] + F[0, 2] = R[2, 0] - R[0, 2] + F[0, 3] = R[0, 1] - R[1, 0] - F[1,0] = R[1,2]-R[2,1] - F[1,1] = R[0,0]-R[1,1]-R[2,2] - F[1,2] = R[0,1]+R[1,0] - F[1,3] = R[0,2]+R[2,0] + F[1, 0] = R[1, 2] - R[2, 1] + F[1, 1] = R[0, 0] - R[1, 1] - R[2, 2] + F[1, 2] = R[0, 1] + R[1, 0] + F[1, 3] = R[0, 2] + R[2, 0] - F[2,0] = R[2,0]-R[0,2] - F[2,1] = R[0,1]+R[1,0] - F[2,2] =-R[0,0]+R[1,1]-R[2,2] - F[2,3] = R[1,2]+R[2,1] + F[2, 0] = R[2, 0] - R[0, 2] + F[2, 1] = R[0, 1] + R[1, 0] + F[2, 2] = -R[0, 0] + R[1, 1] - R[2, 2] + F[2, 3] = R[1, 2] + R[2, 1] - F[3,0] = R[0,1]-R[1,0] - F[3,1] = R[0,2]+R[2,0] - F[3,2] = R[1,2]+R[2,1] - F[3,3] =-R[0,0]-R[1,1]+R[2,2] + F[3, 0] = R[0, 1] - R[1, 0] + F[3, 1] = R[0, 2] + R[2, 0] + F[3, 2] = R[1, 2] + R[2, 1] + F[3, 3] = -R[0, 0] - R[1, 1] + R[2, 2] # diagonalize it - l,U = np.linalg.eig(F) + l, U = np.linalg.eig(F) # extract the eigenvect of the highest eigenvalues indmax = np.argmax(l) - q0,q1,q2,q3 = U[:,indmax] + q0, q1, q2, q3 = U[:, indmax] # form the rotation matrix (eq. 33 ref[1]) - U = np.zeros((3,3)) - - U[0,0] = q0**2+q1**2-q2**2-q3**2 - U[0,1] = 2*(q1*q2-q0*q3) - U[0,2] = 2*(q1*q3+q0*q2) - U[1,1] = 2*(q1*q2+q0*q3) - U[1,2] = q0**2-q1**2+q2*2-q3**2 - U[1,2] = 2*(q2*q3-q0*q1) - U[2,0] = 2*(q1*q3-q0*q2) - U[2,1] = 2*(q2*q3+q0*q1) - U[2,2] = q0**2-q1**2-q2**2+q3**2 + U = np.zeros((3, 3)) + + U[0, 0] = q0**2 + q1**2 - q2**2 - q3**2 + U[0, 1] = 2 * (q1 * q2 - q0 * q3) + U[0, 2] = 2 * (q1 * q3 + q0 * q2) + U[1, 1] = 2 * (q1 * q2 + q0 * q3) + U[1, 2] = q0**2 - q1**2 + q2 * 2 - q3**2 + U[1, 2] = 2 * (q2 * q3 - q0 * q1) + U[2, 0] = 2 * (q1 * q3 - q0 * q2) + U[2, 1] = 2 * (q2 * q3 + q0 * q1) + U[2, 2] = q0**2 - q1**2 - q2**2 + q3**2 return U - @staticmethod - def translation(xyz,vect): - """Translate a fragment + def translation(xyz, vect): + """Translate a fragment. Args: xyz (np.array): position of the fragment @@ -1161,7 +1220,7 @@ def translation(xyz,vect): return xyz + vect @staticmethod - def rotation_around_axis(xyz,axis,angle): + def rotation_around_axis(xyz, axis, angle): """Rotate a fragment around an axis. Args: @@ -1174,25 +1233,30 @@ def rotation_around_axis(xyz,axis,angle): """ # get the data - ct,st = np.cos(angle),np.sin(angle) - ux,uy,uz = axis + ct, st = np.cos(angle), np.sin(angle) + ux, uy, uz = axis # get the center of the molecule - xyz0 = np.mean(xyz,0) + xyz0 = np.mean(xyz, 0) # definition of the rotation matrix # see https://en.wikipedia.org/wiki/Rotation_matrix - rot_mat = np.array([ - [ct + ux**2*(1-ct), ux*uy*(1-ct) - uz*st, ux*uz*(1-ct) + uy*st], - [uy*ux*(1-ct) + uz*st, ct + uy**2*(1-ct), uy*uz*(1-ct) - ux*st], - [uz*ux*(1-ct) - uy*st, uz*uy*(1-ct) + ux*st, ct + uz**2*(1-ct) ]]) + rot_mat = np.array([[ct + ux**2 * (1 - ct), + ux * uy * (1 - ct) - uz * st, + ux * uz * (1 - ct) + uy * st], + [uy * ux * (1 - ct) + uz * st, + ct + uy**2 * (1 - ct), + uy * uz * (1 - ct) - ux * st], + [uz * ux * (1 - ct) - uy * st, + uz * uy * (1 - ct) + ux * st, + ct + uz**2 * (1 - ct)]]) # apply the rotation - return np.dot(rot_mat,(xyz-xyz0).T).T + xyz0 + return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 @staticmethod - def rotation_euler(xyz,alpha,beta,gamma): - """Rotate a fragment from Euler rotation angle + def rotation_euler(xyz, alpha, beta, gamma): + """Rotate a fragment from Euler rotation angle. 
Args: xyz (np.array): original positions @@ -1204,28 +1268,26 @@ def rotation_euler(xyz,alpha,beta,gamma): np.array: Rotated positions """ # precompute the trig - ca,sa = np.cos(alpha),np.sin(alpha) - cb,sb = np.cos(beta),np.sin(beta) - cg,sg = np.cos(gamma),np.sin(gamma) - + ca, sa = np.cos(alpha), np.sin(alpha) + cb, sb = np.cos(beta), np.sin(beta) + cg, sg = np.cos(gamma), np.sin(gamma) # get the center of the molecule - xyz0 = np.mean(xyz,0) + xyz0 = np.mean(xyz, 0) # rotation matrices - rx = np.array([[1,0,0],[0,ca,-sa],[0,sa,ca]]) - ry = np.array([[cb,0,sb],[0,1,0],[-sb,0,cb]]) - rz = np.array([[cg,-sg,0],[sg,cs,0],[0,0,1]]) + rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]]) + ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]]) + rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]]) - rot_mat = np.dot(rx,np.dot(ry,rz)) + rot_mat = np.dot(rx, np.dot(ry, rz)) # apply the rotation - return np.dot(rot_mat,(xyz-xyz0).T).T + xyz0 - + return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 @staticmethod - def rotation_matrix(xyz,rot_mat,center=True): - """Rotate a fragment from a roation matrix + def rotation_matrix(xyz, rot_mat, center=True): + """Rotate a fragment with a rotation matrix. Args: xyz (np.array): original positions @@ -1237,10 +1299,9 @@ def rotation_matrix(xyz,rot_mat,center=True): """ if center: xyz0 = np.mean(xyz) - return np.dot(rot_mat,(xyz-xyz0).T).T + xyz0 + return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 else: - return np.dot(rot_mat,(xyz).T).T - + return np.dot(rot_mat, (xyz).T).T # if __name__ == '__main__': diff --git a/deeprank/tools/__init__.py b/deeprank/tools/__init__.py index 55e21b85..8c750024 100644 --- a/deeprank/tools/__init__.py +++ b/deeprank/tools/__init__.py @@ -1,6 +1,5 @@ from .pdb2sql import pdb2sql -from .StructureSimilarity import StructureSimilarity from .sasa import SASA from .sparse import * - +from .StructureSimilarity import StructureSimilarity diff --git a/deeprank/tools/pdb2sql.py b/deeprank/tools/pdb2sql.py index 5d0f01a5..e1e27111 100644 --- a/deeprank/tools/pdb2sql.py +++ b/deeprank/tools/pdb2sql.py @@ -1,17 +1,21 @@ +import os import sqlite3 import subprocess as sp -import os -import numpy as np from time import time +import numpy as np class pdb2sql(object): - - def __init__(self,pdbfile,sqlfile=None,fix_chainID=True,verbose=False,no_extra=True): - - '''Create a SQL data base for a PDB file. + def __init__( + self, + pdbfile, + sqlfile=None, + fix_chainID=True, + verbose=False, + no_extra=True): + """Create a SQL database for a PDB file. This allows to easily parse and extract information of the PDB using SQL queries. This is a local version of the pdb2sql tool (https://github.com/DeepRank/pdb2sql).
Of pdb2sql is further developped as @@ -49,8 +53,7 @@ def __init__(self,pdbfile,sqlfile=None,fix_chainID=True,verbose=False,no_extra=T >>> >>> # close the database >>> db.close() - - ''' + """ self.pdbfile = pdbfile self.sqlfile = sqlfile self.is_valid = True @@ -61,7 +64,7 @@ def __init__(self,pdbfile,sqlfile=None,fix_chainID=True,verbose=False,no_extra=T self._create_sql() # backbone type - self.backbone_type = ['C','CA','N','O'] + self.backbone_type = ['C', 'CA', 'N', 'O'] # hard limit for the number of SQL varaibles self.SQLITE_LIMIT_VARIABLE_NUMBER = 999 @@ -71,11 +74,11 @@ def __init__(self,pdbfile,sqlfile=None,fix_chainID=True,verbose=False,no_extra=T if fix_chainID: self._fix_chainID() - ################################################################################## + ########################################################################## # # CREATION AND PRINTING # - ################################################################################## + ########################################################################## def _create_sql(self): """Create the sql database.""" @@ -86,38 +89,39 @@ def _create_sql(self): if self.verbose: print('-- Create SQLite3 database') - #name of the table + # name of the table #table = 'ATOM' # column names and types - self.col = {'serial' : 'INT', - 'name' : 'TEXT', - 'altLoc' : 'TEXT', - 'resName' : 'TEXT', - 'chainID' : 'TEXT', - 'resSeq' : 'INT', - 'iCode' : 'TEXT', - 'x' : 'REAL', - 'y' : 'REAL', - 'z' : 'REAL', - 'occ' : 'REAL', - 'temp' : 'REAL'} + self.col = {'serial': 'INT', + 'name': 'TEXT', + 'altLoc': 'TEXT', + 'resName': 'TEXT', + 'chainID': 'TEXT', + 'resSeq': 'INT', + 'iCode': 'TEXT', + 'x': 'REAL', + 'y': 'REAL', + 'z': 'REAL', + 'occ': 'REAL', + 'temp': 'REAL'} # delimtier of the column format - # taken from http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM + # taken from + # http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM self.delimiter = { - 'serial' : [6,11], - 'name' : [12,16], - 'altLoc' : [16,17], - 'resName' :[17,20], - 'chainID' :[21,22], - 'resSeq' :[22,26], - 'iCode' :[26,26], - 'x' :[30,38], - 'y' :[38,46], - 'z' :[46,54], - 'occ' :[54,60], - 'temp' :[60,66]} + 'serial': [6, 11], + 'name': [12, 16], + 'altLoc': [16, 17], + 'resName': [17, 20], + 'chainID': [21, 22], + 'resSeq': [22, 26], + 'iCode': [26, 26], + 'x': [30, 38], + 'y': [38, 46], + 'z': [46, 54], + 'occ': [54, 60], + 'temp': [60, 66]} if self.no_extra: del self.col['occ'] @@ -127,7 +131,6 @@ def _create_sql(self): ncol = len(self.col) ndel = len(self.delimiter) - # open the data base # if we do not specify a db name # the db is only in RAM @@ -135,20 +138,20 @@ def _create_sql(self): # https://stackoverflow.com/questions/764710/sqlite-performance-benchmark-why-is-memory-so-slow-only-1-5x-as-fast-as-d if self.sqlfile is None: self.conn = sqlite3.connect(':memory:') - + # or we create a new db file else: if os.path.isfile(sqlfile): - sp.call('rm %s' %sqlfile,shell=True) + sp.call('rm %s' % sqlfile, shell=True) self.conn = sqlite3.connect(sqlfile) self.c = self.conn.cursor() # intialize the header/placeholder - header,qm = '','' - for ic,(colname,coltype) in enumerate(self.col.items()): - header += '{cn} {ct}'.format(cn=colname,ct=coltype) + header, qm = '', '' + for ic, (colname, coltype) in enumerate(self.col.items()): + header += '{cn} {ct}'.format(cn=colname, ct=coltype) qm += '?' 
- if ic < ncol-1: + if ic < ncol - 1: header += ', ' qm += ',' @@ -156,7 +159,6 @@ def _create_sql(self): query = 'CREATE TABLE ATOM ({hd})'.format(hd=header) self.c.execute(query) - # read the pdb file # this is dangerous if there are ATOM written in the comment part # which happends often @@ -170,17 +172,18 @@ def _create_sql(self): # RMK we go through the data twice here. Once to read the ATOM line and once to parse the data ... # we could do better than that. But the most time consuming step seems to be the CREATE TABLE query # if we path a file we read it - if isinstance(pdbfile,str): + if isinstance(pdbfile, str): if os.path.isfile(pdbfile): - with open(pdbfile,'r') as fi: - data = [line.split('\n')[0] for line in fi if line.startswith('ATOM')] + with open(pdbfile, 'r') as fi: + data = [line.split('\n')[0] + for line in fi if line.startswith('ATOM')] else: - raise FileNotFoundError('File %s was not found',pdbfile) + raise FileNotFoundError('File %s was not found', pdbfile) # if we pass a list as for h5py read/write # we directly use that - elif isinstance(pdbfile,np.ndarray): - data = [l.decode('utf-8') for l in pdbfile.tolist()] + elif isinstance(pdbfile, np.ndarray): + data = [l.decode('utf-8') for l in pdbfile.tolist()] # if we cant read it else: @@ -188,7 +191,7 @@ def _create_sql(self): raise ValueError('PDB data not recognized') # if there is no ATOM in the file - if len(data)==1 and data[0]=='': + if len(data) == 1 and data[0] == '': print("-- Error : No ATOM in the pdb file.") self.is_valid = False return @@ -196,11 +199,11 @@ def _create_sql(self): # haddock chain ID fix del_copy = self.delimiter.copy() if data[0][del_copy['chainID'][0]] == ' ': - del_copy['chainID'] = [72,73] + del_copy['chainID'] = [72, 73] # get all the data data_atom = [] - for iatom,atom in enumerate(data): + for iatom, atom in enumerate(data): # sometimes we still have an empty line somewhere if len(atom) == 0: @@ -208,7 +211,7 @@ def _create_sql(self): # browse all attribute of each atom at = () - for ik,(colname,coltype) in enumerate(self.col.items()): + for ik, (colname, coltype) in enumerate(self.col.items()): # get the piece of data data = atom[del_copy[colname][0]:del_copy[colname][1]].strip() @@ -221,15 +224,15 @@ def _create_sql(self): # append keep the comma !! # we need proper tuple - at +=(data,) + at += (data,) # append data_atom.append(at) - # push in the database - self.c.executemany('INSERT INTO ATOM VALUES ({qm})'.format(qm=qm),data_atom) - + self.c.executemany( + 'INSERT INTO ATOM VALUES ({qm})'.format( + qm=qm), data_atom) def _fix_chainID(self): """Fix the chain ID if necessary. @@ -243,30 +246,31 @@ def _fix_chainID(self): data = self.get('chainID') natom = len(data) - #get uniques + # get uniques chainID = [] for c in data: if c not in chainID: chainID.append(c) - if chainID == ['A','B']: + if chainID == ['A', 'B']: return - if len(chainID)>26: - print("Warning more than 26 chains have been detected. This is so far not supported") + if len(chainID) > 26: + print( + "Warning more than 26 chains have been detected. 
This is so far not supported") sys.exit() # declare the new names newID = [''] * natom # fill in the new names - for ic,chain in enumerate(chainID): - index = self.get('rowID',chainID=chain) + for ic, chain in enumerate(chainID): + index = self.get('rowID', chainID=chain) for ind in index: newID[ind] = ascii_uppercase[ic] # update the new name - self.update_column('chainID',newID) + self.update_column('chainID', newID) # get the names of the columns def get_colnames(self): @@ -277,14 +281,14 @@ def get_colnames(self): names = list(map(lambda x: x[0], cd.description)) print('\trowID') for n in names: - print('\t'+n) + print('\t' + n) # print the database def prettyprint(self): """Print the database with pandas.""" import pandas.io.sql as psql - df = psql.read_sql("SELECT * FROM ATOM",self.conn) + df = psql.read_sql("SELECT * FROM ATOM", self.conn) print(df) def uglyprint(self): @@ -294,8 +298,7 @@ def uglyprint(self): ctmp.execute("SELECT * FROM ATOM") print(ctmp.fetchall()) - - ############################################################################################ + ########################################################################## # # GET FUNCTIONS # @@ -303,12 +306,10 @@ def uglyprint(self): # get_contact_atoms() -> return a list of rowID for the contact atoms # get_contact_residue() -> return a list of resSeq for the contact residue # - ############################################################################################### - + ########################################################################## - def get(self,atnames,**kwargs): - - '''Get data from the sql database. + def get(self, atnames, **kwargs): + """Get data from the sql database. Get the values of specified attributes for a specific selection. @@ -348,17 +349,18 @@ def get(self,atnames,**kwargs): >>> db = pdb2sql(filename) >>> xyz = db.get('x,y,z',name = ['CA', 'CB']) >>> xyz = db.get('x,y,z',chainID='A',resName=['VAL', 'LEU']) - - ''' + """ # the asked keys keys = kwargs.keys() # check if the column exists try: - self.c.execute("SELECT EXISTS(SELECT {an} FROM ATOM)".format(an=atnames)) - except: - print('Error column %s not found in the database' %atnames) + self.c.execute( + "SELECT EXISTS(SELECT {an} FROM ATOM)".format( + an=atnames)) + except BaseException: + print('Error column %s not found in the database' % atnames) self.get_colnames() return @@ -367,7 +369,7 @@ def get(self,atnames,**kwargs): query = 'SELECT {an} FROM ATOM'.format(an=atnames) data = [list(row) for row in self.c.execute(query)] - ############################################################################ + ####################################################################### # GENERIC QUERY # # the only one we need @@ -376,7 +378,7 @@ def get(self,atnames,**kwargs): # AND is assumed between different keys # OR is assumed for the different values of a given key # - ############################################################################## + ####################################################################### else: # check that all the keys exists @@ -386,9 +388,11 @@ def get(self,atnames,**kwargs): k = k[3:] try: - self.c.execute("SELECT EXISTS(SELECT {an} FROM ATOM)".format(an=k)) - except: - print('Error column %s not found in the database' %k) + self.c.execute( + "SELECT EXISTS(SELECT {an} FROM ATOM)".format( + an=k)) + except BaseException: + print('Error column %s not found in the database' % k) self.get_colnames() return @@ -398,8 +402,7 @@ def get(self,atnames,**kwargs): vals = () # iterate through the kwargs - 
for ik,(k,v) in enumerate(kwargs.items()): - + for ik, (k, v) in enumerate(kwargs.items()): # deals with negative conditions if k.startswith('no_'): @@ -411,7 +414,7 @@ def get(self,atnames,**kwargs): # get if we have an array or a scalar # and build the value tuple for the sql query # deal with the indexing issue if rowID is required - if isinstance(v,list): + if isinstance(v, list): nv = len(v) @@ -420,24 +423,25 @@ def get(self,atnames,**kwargs): # that is 999. The limit is here set to 950 # so that we can have multiple conditions with a total number # of values inferior to 999 - if nv>self.max_sql_values: + if nv > self.max_sql_values: # cut in chunck chunck_size = self.max_sql_values - vchunck = [v[i:i+chunck_size] for i in range(0,nv,chunck_size)] + vchunck = [v[i:i + chunck_size] + for i in range(0, nv, chunck_size)] data = [] for v in vchunck: new_kwargs = kwargs.copy() new_kwargs[k] = v - data += self.get(atnames,**new_kwargs) + data += self.get(atnames, **new_kwargs) return data - #otherwithe we just go on + # otherwithe we just go on else: if k == 'rowID': - vals = vals + tuple([iv+1 for iv in v]) + vals = vals + tuple([iv + 1 for iv in v]) else: vals = vals + tuple(v) @@ -445,37 +449,38 @@ def get(self,atnames,**kwargs): nv = 1 if k == 'rowID': - vals = vals + (v+1,) + vals = vals + (v + 1,) else: vals = vals + (v,) # create the condition for that key - conditions.append(k + neg + ' in (' + ','.join('?'*nv) + ')') + conditions.append(k + neg + ' in (' + ','.join('?' * nv) + ')') # stitch the conditions and append to the query query += ' AND '.join(conditions) # error if vals is too long - if len(vals)>self.SQLITE_LIMIT_VARIABLE_NUMBER: + if len(vals) > self.SQLITE_LIMIT_VARIABLE_NUMBER: print('\nError : SQL Queries can only handle a total of 999 values') - print(' : The current query has %d values' %len(vals)) + print(' : The current query has %d values' % len(vals)) print(' : Hence it will fails.') - print(' : You are in a rare situation where MULTIPLE conditions have') + print( + ' : You are in a rare situation where MULTIPLE conditions have') print(' : have a combined number of values that are too large') print(' : These conditions are:') ntot = 0 - for k,v in kwargs.items(): - print(' : --> %10s : %d values' %(k,len(v))) + for k, v in kwargs.items(): + print(' : --> %10s : %d values' % (k, len(v))) ntot += len(v) - print(' : --> %10s : %d values' %('Total',ntot)) + print(' : --> %10s : %d values' % ('Total', ntot)) print(' : Try to decrease self.max_sql_values in pdb2sql.py\n') raise ValueError('Too many SQL variables') # query the sql database and return the answer in a list - data = [list(row) for row in self.c.execute(query,vals)] + data = [list(row) for row in self.c.execute(query, vals)] # empty data - if len(data)==0: + if len(data) == 0: print('Warning sqldb.get returned an empty') return data @@ -483,17 +488,17 @@ def get(self,atnames,**kwargs): # if atnames == 'rowID': if 'rowID' in atnames: index = atnames.split(',').index('rowID') - for i,_ in enumerate(data): + for i, _ in enumerate(data): data[i][index] -= 1 # postporcess the output of the SQl query # flatten it if each els is of size 1 - if len(data[0])==1: + if len(data[0]) == 1: data = [d[0] for d in data] return data - ############################################################################ + ########################################################################## # # get the contact atoms # @@ -502,11 +507,17 @@ def get(self,atnames,**kwargs): # and possbily other submodules to do other things # that 
will leave only the get / put methods in the main class # - ############################################################################# - def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B', - extend_to_residue=False,only_backbone_atoms=False, - excludeH=False,return_only_backbone_atoms=False,return_contact_pairs=False): - + ########################################################################## + def get_contact_atoms( + self, + cutoff=8.5, + chain1='A', + chain2='B', + extend_to_residue=False, + only_backbone_atoms=False, + excludeH=False, + return_only_backbone_atoms=False, + return_contact_pairs=False): """Get contact atoms of the interface. The cutoff distance is by default 8.5 Angs but can be changed at will. A few more options allows to @@ -530,47 +541,56 @@ def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B', >>> db = pdb2sql(filename) >>> db.get_contact_atoms(cutoff=5.0,return_contact_pairs=True) - """ # xyz of the chains - xyz1 = np.array(self.get('x,y,z',chainID=chain1)) - xyz2 = np.array(self.get('x,y,z',chainID=chain2)) + xyz1 = np.array(self.get('x,y,z', chainID=chain1)) + xyz2 = np.array(self.get('x,y,z', chainID=chain2)) # index of b - index2 = self.get('rowID',chainID=chain2) + index2 = self.get('rowID', chainID=chain2) # resName of the chains - resName1 = np.array(self.get('resName',chainID=chain1)) + resName1 = np.array(self.get('resName', chainID=chain1)) #resName2 = np.array(self.get('resName',chainID=chain2)) # atomnames of the chains - atName1 = np.array(self.get('name',chainID=chain1)) - atName2 = np.array(self.get('name',chainID=chain2)) - + atName1 = np.array(self.get('name', chainID=chain1)) + atName2 = np.array(self.get('name', chainID=chain2)) # loop through the first chain # TO DO : loop through the smallest chain instead ... 
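        # Note (illustrative sketch only): the same contact search can be
        # written in one vectorized step, e.g.
        #     dist = np.linalg.norm(xyz1[:, None, :] - xyz2[None, :, :], axis=2)
        #     pairs = np.argwhere(dist <= cutoff)
        # at the cost of materializing the full n1 x n2 distance matrix.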
- index_contact_1,index_contact_2 = [],[] + index_contact_1, index_contact_2 = [], [] index_contact_pairs = {} - for i,x0 in enumerate(xyz1): + for i, x0 in enumerate(xyz1): # compute the contact atoms - contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) <= cutoff )[0] + contacts = np.where( + np.sqrt( + np.sum( + (xyz2 - x0)**2, + 1)) <= cutoff)[0] # exclude the H if required if excludeH and atName1[i][0] == 'H': continue - if len(contacts)>0 and any([not only_backbone_atoms, atName1[i] in self.backbone_type]): + if len(contacts) > 0 and any( + [not only_backbone_atoms, atName1[i] in self.backbone_type]): # the contact atoms index_contact_1 += [i] - index_contact_2 += [index2[k] for k in contacts if ( any( [atName2[k] in self.backbone_type, not only_backbone_atoms]) and not (excludeH and atName2[k][0]=='H') ) ] + index_contact_2 += [index2[k] for k in contacts if (any( + [atName2[k] in self.backbone_type, not only_backbone_atoms]) and not (excludeH and atName2[k][0] == 'H'))] # the pairs - pairs = [index2[k] for k in contacts if any( [atName2[k] in self.backbone_type, not only_backbone_atoms] ) and not (excludeH and atName2[k][0]=='H') ] + pairs = [ + index2[k] for k in contacts if any( + [ + atName2[k] in self.backbone_type, + not only_backbone_atoms]) and not ( + excludeH and atName2[k][0] == 'H')] if len(pairs) > 0: index_contact_pairs[i] = pairs @@ -579,13 +599,13 @@ def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B', index_contact_2 = sorted(set(index_contact_2)) # if no atoms were found - if len(index_contact_1)==0: + if len(index_contact_1) == 0: print('Warning : No contact atoms detected in pdb2sql') # extend the list to entire residue if extend_to_residue: - index_contact_1,index_contact_2 = self._extend_contact_to_residue(index_contact_1,index_contact_2,only_backbone_atoms) - + index_contact_1, index_contact_2 = self._extend_contact_to_residue( + index_contact_1, index_contact_2, only_backbone_atoms) # filter only the backbone atoms if return_only_backbone_atoms and not only_backbone_atoms: @@ -595,15 +615,18 @@ def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B', atNames = np.array(self.get('name')) # change the index_contacts - index_contact_1 = [ ind for ind in index_contact_1 if atNames[ind] in self.backbone_type ] - index_contact_2 = [ ind for ind in index_contact_2 if atNames[ind] in self.backbone_type ] + index_contact_1 = [ + ind for ind in index_contact_1 if atNames[ind] in self.backbone_type] + index_contact_2 = [ + ind for ind in index_contact_2 if atNames[ind] in self.backbone_type] # change the contact pairs tmp_dict = {} - for ind1,ind2_list in index_contact_pairs.items(): + for ind1, ind2_list in index_contact_pairs.items(): if atNames[ind1] in self.backbone_type: - tmp_dict[ind1] = [ind2 for ind2 in ind2_list if atNames[ind2] in self.backbone_type] + tmp_dict[ind1] = [ + ind2 for ind2 in ind2_list if atNames[ind2] in self.backbone_type] index_contact_pairs = tmp_dict @@ -611,14 +634,14 @@ def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B', if return_contact_pairs: return index_contact_pairs else: - return index_contact_1,index_contact_2 + return index_contact_1, index_contact_2 # extend the contact atoms to the residue - def _extend_contact_to_residue(self,index1,index2,only_backbone_atoms): + def _extend_contact_to_residue(self, index1, index2, only_backbone_atoms): # extract the data - dataA = self.get('chainId,resName,resSeq',rowID=index1) - dataB = self.get('chainId,resName,resSeq',rowID=index2) + dataA = 
self.get('chainId,resName,resSeq', rowID=index1) + dataB = self.get('chainId,resName,resSeq', rowID=index2) # create tuple cause we want to hash through it #dataA = list(map(lambda x: tuple(x),dataA)) @@ -631,39 +654,59 @@ def _extend_contact_to_residue(self,index1,index2,only_backbone_atoms): resB = list(set(dataB)) # init the list - index_contact_A,index_contact_B = [],[] + index_contact_A, index_contact_B = [], [] # contact of chain A for resdata in resA: - chainID,resName,resSeq = resdata + chainID, resName, resSeq = resdata if only_backbone_atoms: - index_contact_A += self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq,name=self.backbone_type) + index_contact_A += self.get('rowID', + chainID=chainID, + resName=resName, + resSeq=resSeq, + name=self.backbone_type) else: - index_contact_A += self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) + index_contact_A += self.get('rowID', + chainID=chainID, + resName=resName, + resSeq=resSeq) # contact of chain B for resdata in resB: - chainID,resName,resSeq = resdata + chainID, resName, resSeq = resdata if only_backbone_atoms: - index_contact_B += self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq,name=self.backbone_type) + index_contact_B += self.get('rowID', + chainID=chainID, + resName=resName, + resSeq=resSeq, + name=self.backbone_type) else: - index_contact_B += self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) + index_contact_B += self.get('rowID', + chainID=chainID, + resName=resName, + resSeq=resSeq) # make sure that we don't have double (maybe optional) index_contact_A = sorted(set(index_contact_A)) index_contact_B = sorted(set(index_contact_B)) - return index_contact_A,index_contact_B - + return index_contact_A, index_contact_B # get the contact residue - def get_contact_residue(self,cutoff=8.5,chain1='A',chain2='B',excludeH=False, - only_backbone_atoms=False,return_contact_pairs=False): - - """Get contact residues of the interface. The cutoff distance is by default 8.5 Angs but can be changed at will. A few more options allows to - precisely define how the contact residues are identified and returned. + def get_contact_residue( + self, + cutoff=8.5, + chain1='A', + chain2='B', + excludeH=False, + only_backbone_atoms=False, + return_contact_pairs=False): + """Get contact residues of the interface. The cutoff distance is by + default 8.5 Angs but can be changed at will. A few more options allows + to precisely define how the contact residues are identified and + returned. 
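        Each contact residue is reported as a (chainID, resSeq, resName) tuple,
        so an illustrative call and return value look like:

        >>> resA, resB = db.get_contact_residue(cutoff=5.5)
        >>> resA[0]
        ('A', 10, 'VAL')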
Args: cutoff (float): cutoff for contact atoms (default 8.5) @@ -681,7 +724,6 @@ def get_contact_residue(self,cutoff=8.5,chain1='A',chain2='B',excludeH=False, >>> db = pdb2sql(filename) >>> db.get_contact_residue(cutoff=5.0,return_contact_pairs=True) - """ # get the contact atoms if return_contact_pairs: @@ -690,50 +732,59 @@ def get_contact_residue(self,cutoff=8.5,chain1='A',chain2='B',excludeH=False, residue_contact_pairs = {} # get the contact atom pairs - atom_pairs = self.get_contact_atoms(cutoff=cutoff,chain1=chain1,chain2=chain2, - only_backbone_atoms=only_backbone_atoms, - excludeH=excludeH, - return_contact_pairs=True) + atom_pairs = self.get_contact_atoms( + cutoff=cutoff, + chain1=chain1, + chain2=chain2, + only_backbone_atoms=only_backbone_atoms, + excludeH=excludeH, + return_contact_pairs=True) # loop over the atom pair dict - for iat1,atoms2 in atom_pairs.items(): + for iat1, atoms2 in atom_pairs.items(): # get the res info of the current atom - data1 = tuple(self.get('chainID,resSeq,resName',rowID=[iat1])[0]) + data1 = tuple( + self.get( + 'chainID,resSeq,resName', + rowID=[iat1])[0]) # create a new entry in the dict if necessary if data1 not in residue_contact_pairs: residue_contact_pairs[data1] = set() # get the res info of the atom in the other chain - data2 = self.get('chainID,resSeq,resName',rowID=atoms2) + data2 = self.get('chainID,resSeq,resName', rowID=atoms2) # store that in the dict without double for resData in data2: residue_contact_pairs[data1].add(tuple(resData)) for resData in residue_contact_pairs.keys(): - residue_contact_pairs[resData] = sorted(residue_contact_pairs[resData]) + residue_contact_pairs[resData] = sorted( + residue_contact_pairs[resData]) return residue_contact_pairs else: # get the contact atoms - contact_atoms = self.get_contact_atoms(cutoff=cutoff,chain1=chain1,chain2=chain2,return_contact_pairs=False) + contact_atoms = self.get_contact_atoms( + cutoff=cutoff, chain1=chain1, chain2=chain2, return_contact_pairs=False) # get the residue info - data1 = self.get('chainID,resSeq,resName',rowID=contact_atoms[0]) - data2 = self.get('chainID,resSeq,resName',rowID=contact_atoms[1]) + data1 = self.get('chainID,resSeq,resName', rowID=contact_atoms[0]) + data2 = self.get('chainID,resSeq,resName', rowID=contact_atoms[1]) # take only unique - residue_contact_A = sorted(set([tuple(resData) for resData in data1])) - residue_contact_B = sorted(set([tuple(resData) for resData in data2])) + residue_contact_A = sorted( + set([tuple(resData) for resData in data1])) + residue_contact_B = sorted( + set([tuple(resData) for resData in data2])) - return residue_contact_A,residue_contact_B + return residue_contact_A, residue_contact_B - - ############################################################################################ + ########################################################################## # # PUT FUNCTONS AND ASSOCIATED # @@ -742,22 +793,22 @@ def get_contact_residue(self,cutoff=8.5,chain1='A',chain2='B',excludeH=False, # update_xyz() -> update_xyz of the pdb # put() -> put a value in a column # - ############################################################################################### - + ########################################################################## - def add_column(self,colname,coltype='FLOAT',default=0): - '''Add an extra column to the ATOM table. + def add_column(self, colname, coltype='FLOAT', default=0): + """Add an extra column to the ATOM table. 
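        Example (illustrative; the column name is arbitrary):

        >>> db.add_column('CHARGE', coltype='FLOAT', default=0)
        >>> db.put('CHARGE', 1.25, index=[1])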
Args: colname (str): name of the column coltype (str, optional): type of the column data (default FLOAT) default (int, optional): default value to fill in the column (default 0.0) - ''' + """ - query = "ALTER TABLE ATOM ADD COLUMN '%s' %s DEFAULT %s" %(colname,coltype,str(default)) + query = "ALTER TABLE ATOM ADD COLUMN '%s' %s DEFAULT %s" % ( + colname, coltype, str(default)) self.c.execute(query) - def update(self,attribute,values,**kwargs): + def update(self, attribute, values, **kwargs): """Update multiple columns in the data. Args: @@ -780,13 +831,15 @@ def update(self,attribute,values,**kwargs): # check if the column exists try: - self.c.execute("SELECT EXISTS(SELECT {an} FROM ATOM)".format(an=attribute)) - except: - print('Error column %s not found in the database' %attribute) + self.c.execute( + "SELECT EXISTS(SELECT {an} FROM ATOM)".format( + an=attribute)) + except BaseException: + print('Error column %s not found in the database' % attribute) self.get_colnames() raise ValueError('Attribute name not recognized') - #if len(kwargs) == 0: + # if len(kwargs) == 0: # raise ValueError('Update without kwargs seem to be buggy. Use rowID=list(range(natom)) instead') # handle the multi model cases @@ -804,48 +857,48 @@ def update(self,attribute,values,**kwargs): if ',' in attribute: attribute = attribute.split(',') - if not isinstance(attribute,list): + if not isinstance(attribute, list): attribute = [attribute] - # check the size natt = len(attribute) nrow = len(values) ncol = len(values[0]) if natt != ncol: - raise ValueError('Number of attribute incompatible with the number of columns in the data') - + raise ValueError( + 'Number of attribute incompatible with the number of columns in the data') # get the row ID of the selection - rowID = self.get('rowID',**kwargs) + rowID = self.get('rowID', **kwargs) nselect = len(rowID) if nselect != nrow: - raise ValueError('Number of data values incompatible with the given conditions') + raise ValueError( + 'Number of data values incompatible with the given conditions') # prepare the query query = 'UPDATE ATOM SET ' - query = query + ', '.join(map(lambda x: x+'=?',attribute)) - #if len(kwargs)>0: # why did I do that ... + query = query + ', '.join(map(lambda x: x + '=?', attribute)) + # if len(kwargs)>0: # why did I do that ... query = query + ' WHERE rowID=?' # prepare the data data = [] - for i,val in enumerate(values): + for i, val in enumerate(values): - tmp_data = [ v for v in val ] + tmp_data = [v for v in val] - #if len(kwargs)>0: Same here why did I do that ? + # if len(kwargs)>0: Same here why did I do that ? # here the conversion of the indexes is a bit annoying - tmp_data += [rowID[i]+1] + tmp_data += [rowID[i] + 1] data.append(tmp_data) - self.c.executemany(query,data) + self.c.executemany(query, data) - def update_column(self,colname,values,index=None): - '''Update a single column. + def update_column(self, colname, values, index=None): + """Update a single column. Args: colname (str): name of the column to update @@ -854,19 +907,20 @@ def update_column(self,colname,values,index=None): Example: >>> db.update_column('x',np.random.rand(10),index=list(range(10))) - ''' + """ - if index==None: - data = [ [v,i+1] for i,v in enumerate(values) ] + if index is None: + data = [[v, i + 1] for i, v in enumerate(values)] else: - data = [ [v,ind] for v,ind in zip(values,index)] # shouldn't that be ind+1 ? + # shouldn't that be ind+1 ? + data = [[v, ind] for v, ind in zip(values, index)] query = 'UPDATE ATOM SET {cn}=? 
WHERE rowID=?'.format(cn=colname) - self.c.executemany(query,data) - #self.conn.commit() + self.c.executemany(query, data) + # self.conn.commit() - def update_xyz(self,xyz,index=None): - '''Update the xyz information. + def update_xyz(self, xyz, index=None): + """Update the xyz information. Update the positions of the atoms selected if index=None the position of all the atoms are changed @@ -880,18 +934,21 @@ def update_xyz(self,xyz,index=None): >>> index = list(range(n)) >>> vals = np.random.rand(n,3) >>> db.update_xyz(vals,index=index) - ''' + """ - if index==None: - data = [ [pos[0],pos[1],pos[2],i+1] for i,pos in enumerate(xyz) ] + if index is None: + data = [[pos[0], pos[1], pos[2], i + 1] + for i, pos in enumerate(xyz)] else: - data = [ [pos[0],pos[1],pos[2],ind+1] for pos,ind in zip(xyz,index)] + data = [[pos[0], pos[1], pos[2], ind + 1] + for pos, ind in zip(xyz, index)] query = 'UPDATE ATOM SET x=?, y=?, z=? WHERE rowID=?' - self.c.executemany(query,data) + self.c.executemany(query, data) - def put(self,colname,value,**kwargs): - """ Update the value of the attribute with value specified with possible selection. + def put(self, colname, value, **kwargs): + """Update the value of the attribute with value specified with possible + selection. Args: colname (str) : must be a valid attribute name. @@ -912,33 +969,36 @@ def put(self,colname,value,**kwargs): >>> db.put('CHARGE',1.25,index=[1]) >>> db.close() """ - arguments = {'where' : "String e.g 'chainID = 'A''", - 'index' : "Array e.g. [27,28,30]", - 'name' : "'CA' atome name", - 'query' : "SQL query e.g. 'WHERE chainID='B' AND resName='ASP' "} + arguments = { + 'where': "String e.g 'chainID = 'A''", + 'index': "Array e.g. [27,28,30]", + 'name': "'CA' atome name", + 'query': "SQL query e.g. 'WHERE chainID='B' AND resName='ASP' "} # the asked keys keys = kwargs.keys() # if we have more than one key we kill it - if len(keys)>1 : - print('You can only specify 1 conditional statement for the pdb2sql.put function') + if len(keys) > 1: + print( + 'You can only specify 1 conditional statement for the pdb2sql.put function') return # check if the column exists try: - self.c.execute("SELECT EXISTS(SELECT {an} FROM ATOM)".format(an=colname)) - except: - print('Error column %s not found in the database' %colname) + self.c.execute( + "SELECT EXISTS(SELECT {an} FROM ATOM)".format( + an=colname)) + except BaseException: + print('Error column %s not found in the database' % colname) self.get_colnames() return - # if we have 0 key we take the entire db if len(kwargs) == 0: query = 'UPDATE ATOM SET {cn}=?'.format(cn=colname) value = tuple([value]) - self.c.execute(query,value) + self.c.execute(query, value) return # otherwise we have only one key @@ -947,48 +1007,51 @@ def put(self,colname,value,**kwargs): # select which key we have if key == 'where': - query = 'UPDATE ATOM SET {cn}=? WHERE {cond}'.format(cn=colname,cond=cond) + query = 'UPDATE ATOM SET {cn}=? WHERE {cond}'.format( + cn=colname, cond=cond) value = tuple([value]) - self.c.execute(query,value) + self.c.execute(query, value) - elif key == 'name' : - values = tuple([value,cond]) + elif key == 'name': + values = tuple([value, cond]) query = 'UPDATE ATOM SET {cn}=? WHERE name=?'.format(cn=colname) - self.c.execute(query,values) + self.c.execute(query, values) - elif key == 'index' : - values = tuple([value] + [v+1 for v in cond]) + elif key == 'index': + values = tuple([value] + [v + 1 for v in cond]) qm = ','.join(['?' for i in range(len(cond))]) - query = 'UPDATE ATOM SET {cn}=? 
WHERE rowID in ({qm})'.format(cn=colname,qm=qm) - self.c.execute(query,values) + query = 'UPDATE ATOM SET {cn}=? WHERE rowID in ({qm})'.format( + cn=colname, qm=qm) + self.c.execute(query, values) - elif key == 'query' : - query = 'UPDATE ATOM SET {cn}=? {c1}'.format(cn=colname,c1=cond) + elif key == 'query': + query = 'UPDATE ATOM SET {cn}=? {c1}'.format(cn=colname, c1=cond) value = tuple([value]) - self.c.execute(query,value) + self.c.execute(query, value) else: - print('Error arguments %s not supported in pdb2sql.get()\nOptions are:\n' %(key)) - for posskey,possvalue in arguments.items(): + print( + 'Error arguments %s not supported in pdb2sql.get()\nOptions are:\n' % + (key)) + for posskey, possvalue in arguments.items(): print('\t' + posskey + '\t\t' + possvalue) return - - - ############################################################################################ + ########################################################################## # # COMMIT, EXPORT, CLOSE FUNCTIONS # - ############################################################################################### + ########################################################################## # comit changes + def commit(self): """Commit the database.""" self.conn.commit() # export to pdb file - def exportpdb(self,fname,**kwargs): - """Export a PDB file with kwargs selection + def exportpdb(self, fname, **kwargs): + """Export a PDB file with kwargs selection. Args: fname (str): Name of the file @@ -998,30 +1061,29 @@ def exportpdb(self,fname,**kwargs): >>> db = pdb2sql('1AK4.pdb') >>> db.exportpdb('CA.pdb',name='CA') - """ # get the data - data = self.get('*',**kwargs) + data = self.get('*', **kwargs) # write each line # the PDB format is pretty strict # http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM - f = open(fname,'w') + f = open(fname, 'w') for d in data: line = 'ATOM ' line += '{:>5}'.format(d[0]) # serial line += ' ' line += '{:^4}'.format(d[1]) # name line += '{:>1}'.format(d[2]) # altLoc - line += '{:>3}'.format(d[3]) #resname + line += '{:>3}'.format(d[3]) # resname line += ' ' line += '{:>1}'.format(d[4]) # chainID line += '{:>4}'.format(d[5]) # resSeq line += '{:>1}'.format(d[6]) # iCODE line += ' ' - line += '{: 8.3f}'.format(d[7]) #x - line += '{: 8.3f}'.format(d[8]) #y - line += '{: 8.3f}'.format(d[9]) #z + line += '{: 8.3f}'.format(d[7]) # x + line += '{: 8.3f}'.format(d[8]) # y + line += '{: 8.3f}'.format(d[9]) # z if not self.no_extra: line += '{: 6.2f}'.format(d[10]) # occ @@ -1034,9 +1096,8 @@ def exportpdb(self,fname,**kwargs): # close f.close() - # close the database - def close(self,rmdb = True): + def close(self, rmdb=True): """Close the database. Args: @@ -1050,20 +1111,19 @@ def close(self,rmdb = True): if rmdb: self.conn.close() - os.system('rm %s' %(self.sqlfile)) + os.system('rm %s' % (self.sqlfile)) else: self.commit() self.conn.close() - ############################################################################ + ########################################################################## # # Transform the position of the molecule # - ############################################################################## + ########################################################################## - - def translation(self,vect,**kwargs): - """Translate a part or all of the molecule + def translation(self, vect, **kwargs): + """Translate a part or all of the molecule. 
Args: vect (np.array): translation vector @@ -1074,12 +1134,12 @@ def translation(self,vect,**kwargs): >>> vect = np.random.rand(3) >>> db.translation(vect, chainID = 'A') """ - xyz = self.get('x,y,z',**kwargs) + xyz = self.get('x,y,z', **kwargs) xyz += vect - self.update('x,y,z',xyz,**kwargs) + self.update('x,y,z', xyz, **kwargs) - def rotation_around_axis(self,axis,angle,**kwargs): - """Rotate a part or all of the molecule around a specified axis + def rotation_around_axis(self, axis, angle, **kwargs): + """Rotate a part or all of the molecule around a specified axis. Args: axis (np.array): axis of rotation @@ -1092,30 +1152,35 @@ def rotation_around_axis(self,axis,angle,**kwargs): >>> angle = np.random.rand() >>> db.rotation_around_axis(axis, angle, chainID = 'B') """ - xyz = self.get('x,y,z',**kwargs) + xyz = self.get('x,y,z', **kwargs) # get the data - ct,st = np.cos(angle),np.sin(angle) - ux,uy,uz = axis + ct, st = np.cos(angle), np.sin(angle) + ux, uy, uz = axis # get the center of the molecule - xyz0 = np.mean(xyz,0) + xyz0 = np.mean(xyz, 0) # definition of the rotation matrix # see https://en.wikipedia.org/wiki/Rotation_matrix - rot_mat = np.array([ - [ct + ux**2*(1-ct), ux*uy*(1-ct) - uz*st, ux*uz*(1-ct) + uy*st], - [uy*ux*(1-ct) + uz*st, ct + uy**2*(1-ct), uy*uz*(1-ct) - ux*st], - [uz*ux*(1-ct) - uy*st, uz*uy*(1-ct) + ux*st, ct + uz**2*(1-ct) ]]) + rot_mat = np.array([[ct + ux**2 * (1 - ct), + ux * uy * (1 - ct) - uz * st, + ux * uz * (1 - ct) + uy * st], + [uy * ux * (1 - ct) + uz * st, + ct + uy**2 * (1 - ct), + uy * uz * (1 - ct) - ux * st], + [uz * ux * (1 - ct) - uy * st, + uz * uy * (1 - ct) + ux * st, + ct + uz**2 * (1 - ct)]]) # apply the rotation - xyz = np.dot(rot_mat,(xyz-xyz0).T).T + xyz0 - self.update('x,y,z',xyz,**kwargs) + xyz = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 + self.update('x,y,z', xyz, **kwargs) return xyz0 - def rotation_euler(self,alpha,beta,gamma,**kwargs): - """Rotate a part or all of the molecule from Euler rotation axis + def rotation_euler(self, alpha, beta, gamma, **kwargs): + """Rotate a part or all of the molecule from Euler rotation axis. 
Args: alpha (float): angle of rotation around the x axis @@ -1128,30 +1193,29 @@ def rotation_euler(self,alpha,beta,gamma,**kwargs): >>> a,b,c = np.random.rand(3) >>> db.rotation_euler(a,b,c,resName='VAL') """ - xyz = self.get('x,y,z',**kwargs) + xyz = self.get('x,y,z', **kwargs) # precomte the trig - ca,sa = np.cos(alpha),np.sin(alpha) - cb,sb = np.cos(beta),np.sin(beta) - cg,sg = np.cos(gamma),np.sin(gamma) - + ca, sa = np.cos(alpha), np.sin(alpha) + cb, sb = np.cos(beta), np.sin(beta) + cg, sg = np.cos(gamma), np.sin(gamma) # get the center of the molecule - xyz0 = np.mean(xyz,0) + xyz0 = np.mean(xyz, 0) # rotation matrices - rx = np.array([[1,0,0],[0,ca,-sa],[0,sa,ca]]) - ry = np.array([[cb,0,sb],[0,1,0],[-sb,0,cb]]) - rz = np.array([[cg,-sg,0],[sg,cg,0],[0,0,1]]) - rot_mat = np.dot(rx,np.dot(ry,rz)) + rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]]) + ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]]) + rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]]) + rot_mat = np.dot(rx, np.dot(ry, rz)) # apply the rotation - xyz = np.dot(rot_mat,(xyz-xyz0).T).T + xyz0 + xyz = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 - self.update('x,y,z',xyz,**kwargs) + self.update('x,y,z', xyz, **kwargs) - def rotation_matrix(self,rot_mat,center=True,**kwargs): - """Rotate a part or all of the molecule from a rotation matrix + def rotation_matrix(self, rot_mat, center=True, **kwargs): + """Rotate a part or all of the molecule from a rotation matrix. Args: rot_mat (np.array): 3x3 rotation matrix @@ -1163,11 +1227,11 @@ def rotation_matrix(self,rot_mat,center=True,**kwargs): >>> mat = np.random.rand(3,3) >>> db.rotation_matrix(mat,chainID='A') """ - xyz = self.get('x,y,z',**kwargs) + xyz = self.get('x,y,z', **kwargs) if center: xyz0 = np.mean(xyz) - xyz = np.dot(rot_mat,(xyz-xyz0).T).T + xyz0 + xyz = np.dot(rot_mat, (xyz - xyz0).T).T + xyz0 else: - xyz = np.dot(rot_mat,(xyz).T).T - self.update('x,y,z',xyz,**kwargs) + xyz = np.dot(rot_mat, (xyz).T).T + self.update('x,y,z', xyz, **kwargs) diff --git a/deeprank/tools/sasa.py b/deeprank/tools/sasa.py index f8d38717..75624641 100644 --- a/deeprank/tools/sasa.py +++ b/deeprank/tools/sasa.py @@ -1,10 +1,11 @@ import numpy as np + from deeprank.tools import pdb2sql -class SASA(object): - def __init__(self,pdbfile): +class SASA(object): + def __init__(self, pdbfile): """Simple class that computes Surface Accessible Solvent Area. The method follows some of the approaches presented in : @@ -19,14 +20,12 @@ def __init__(self,pdbfile): Args: pdbfile (str): PDB file of the conformation - """ self.pdbfile = pdbfile - def get_center(self,chainA='A',chainB='B',center='cb'): - - '''Get the center of the resiudes. + def get_center(self, chainA='A', chainB='B', center='cb'): + """Get the center of the resiudes. 
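        A short usage sketch (illustrative; the PDB file name is arbitrary):

        >>> sasa = SASA('1AK4.pdb')
        >>> sasa.get_center(chainA='A', chainB='B', center='cb')
        >>> centers = sasa.xyz['A']   # per-residue centers of chain A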
center = cb --> the center is located on the carbon beta of each residue center = 'center' --> average position of all atoms of the residue @@ -39,32 +38,31 @@ def get_center(self,chainA='A',chainB='B',center='cb'): center = 'center' --> average position of all atoms of the residue Raises: ValueError: If the center is not recpgnized - ''' + """ if center == 'center': - self.get_residue_center(chainA=chainA,chainB=chainB) + self.get_residue_center(chainA=chainA, chainB=chainB) elif center == 'cb': - self.get_residue_carbon_beta(chainA=chainA,chainB=chainB) + self.get_residue_carbon_beta(chainA=chainA, chainB=chainB) else: - raise ValueError('Options %s not recognized in SASA.get_center' %center) - - - def get_residue_center(self,chainA='A',chainB='B'): + raise ValueError( + 'Options %s not recognized in SASA.get_center' % + center) - '''Compute the average position of all the residues. + def get_residue_center(self, chainA='A', chainB='B'): + """Compute the average position of all the residues. Args: chainA (str, optional): Name of the first chain chainB (str, optional): Name of the second chain - ''' + """ sql = pdb2sql(self.pdbfile) - resA = np.array(sql.get('resSeq,resName',chainID=chainA)) - resB = np.array(sql.get('resSeq,resName',chainID=chainB)) - + resA = np.array(sql.get('resSeq,resName', chainID=chainA)) + resB = np.array(sql.get('resSeq,resName', chainID=chainB)) - resSeqA = np.unique(resA[:,0].astype(np.int)) - resSeqB = np.unique(resB[:,0].astype(np.int)) + resSeqA = np.unique(resA[:, 0].astype(np.int)) + resSeqB = np.unique(resB[:, 0].astype(np.int)) self.xyz = {} #self.xyz[chainA] = [ np.mean( resA[np.argwhere(resA[:,0].astype(np.int)==r),2:],0 ).astype(np.float).tolist()[0] for r in resSeqA ] @@ -72,55 +70,68 @@ def get_residue_center(self,chainA='A',chainB='B'): self.xyz[chainA] = [] for r in resSeqA: - xyz = sql.get('x,y,z',chainID=chainA,resSeq=str(r)) + xyz = sql.get('x,y,z', chainID=chainA, resSeq=str(r)) self.xyz[chainA].append(np.mean(xyz)) self.xyz[chainB] = [] for r in resSeqB: - xyz = sql.get('x,y,z',chainID=chainB,resSeq=str(r)) + xyz = sql.get('x,y,z', chainID=chainB, resSeq=str(r)) self.xyz[chainA].append(np.mean(xyz)) self.resinfo = {} self.resinfo[chainA] = [] - for r in resA[:,:2]: + for r in resA[:, :2]: if tuple(r) not in self.resinfo[chainA]: self.resinfo[chainA].append(tuple(r)) self.resinfo[chainB] = [] - for r in resB[:,:2]: + for r in resB[:, :2]: if tuple(r) not in self.resinfo[chainB]: self.resinfo[chainB].append(tuple(r)) sql.close() - def get_residue_carbon_beta(self,chainA='A',chainB='B'): - - '''Extract the position of the carbon beta of each residue. + def get_residue_carbon_beta(self, chainA='A', chainB='B'): + """Extract the position of the carbon beta of each residue. 
Args: chainA (str, optional): Name of the first chain chainB (str, optional): Name of the second chain - ''' + """ sql = pdb2sql(self.pdbfile) - resA = np.array(sql.get('resSeq,resName,x,y,z',name='CB',chainID=chainA)) - resB = np.array(sql.get('resSeq,resName,x,y,z',name='CB',chainID=chainB)) + resA = np.array( + sql.get( + 'resSeq,resName,x,y,z', + name='CB', + chainID=chainA)) + resB = np.array( + sql.get( + 'resSeq,resName,x,y,z', + name='CB', + chainID=chainB)) sql.close() - assert len(resA[:,0].astype(np.int).tolist()) == len(np.unique(resA[:,0].astype(np.int)).tolist()) - assert len(resB[:,0].astype(np.int).tolist()) == len(np.unique(resB[:,0].astype(np.int)).tolist()) + assert len(resA[:, 0].astype(np.int).tolist()) == len( + np.unique(resA[:, 0].astype(np.int)).tolist()) + assert len(resB[:, 0].astype(np.int).tolist()) == len( + np.unique(resB[:, 0].astype(np.int)).tolist()) self.xyz = {} - self.xyz[chainA] = resA[:,2:].astype(np.float) - self.xyz[chainB] = resB[:,2:].astype(np.float) + self.xyz[chainA] = resA[:, 2:].astype(np.float) + self.xyz[chainB] = resB[:, 2:].astype(np.float) self.resinfo = {} - self.resinfo[chainA] = resA[:,:2] - self.resinfo[chainB] = resB[:,:2] - - def neighbor_vector(self,lbound=3.3,ubound=11.1,chainA='A',chainB='B',center='cb'): - - - '''Compute teh SASA folowing the neighbour vector approach. + self.resinfo[chainA] = resA[:, :2] + self.resinfo[chainB] = resB[:, :2] + + def neighbor_vector( + self, + lbound=3.3, + ubound=11.1, + chainA='A', + chainB='B', + center='cb'): + """Compute teh SASA folowing the neighbour vector approach. The method is based on Eq on page 1097 of https://link.springer.com/article/10.1007%2Fs00894-009-0454-9 @@ -133,42 +144,47 @@ def neighbor_vector(self,lbound=3.3,ubound=11.1,chainA='A',chainB='B',center='cb Returns: dict: neighbouring vectors - ''' + """ # get the center - self.get_center(chainA=chainA,chainB=chainB,center=center) + self.get_center(chainA=chainA, chainB=chainB, center=center) NV = {} - for chain in [chainA,chainB]: + for chain in [chainA, chainB]: - for i,xyz in enumerate(self.xyz[chain]): + for i, xyz in enumerate(self.xyz[chain]): - vect = self.xyz[chain]-xyz - dist = np.sqrt(np.sum((self.xyz[chain]-xyz)**2,1)) + vect = self.xyz[chain] - xyz + dist = np.sqrt(np.sum((self.xyz[chain] - xyz)**2, 1)) - dist = np.delete(dist,i,0) - vect = np.delete(vect,i,0) + dist = np.delete(dist, i, 0) + vect = np.delete(vect, i, 0) - vect /= np.linalg.norm(vect,axis=1).reshape(-1,1) + vect /= np.linalg.norm(vect, axis=1).reshape(-1, 1) - weight = self.neighbor_weight(dist,lbound=lbound,ubound=ubound).reshape(-1,1) + weight = self.neighbor_weight( + dist, lbound=lbound, ubound=ubound).reshape(-1, 1) vect *= weight - vect = np.sum(vect,0) + vect = np.sum(vect, 0) vect /= np.sum(weight) - resSeq,resName = self.resinfo[chain][i].tolist() - key = tuple([chain,int(resSeq),resName]) - value = np.linalg.norm(vect) - NV[key] = value + resSeq, resName = self.resinfo[chain][i].tolist() + key = tuple([chain, int(resSeq), resName]) + value = np.linalg.norm(vect) + NV[key] = value return NV - - def neighbor_count(self,lbound=4.0,ubound=11.4,chainA='A',chainB='B',center='cb'): - - '''Compute the neighbourhood count of each residue. + def neighbor_count( + self, + lbound=4.0, + ubound=11.4, + chainA='A', + chainB='B', + center='cb'): + """Compute the neighbourhood count of each residue. 
The method is based on Eq on page 1097 of https://link.springer.com/article/10.1007%2Fs00894-009-0454-9 @@ -181,28 +197,28 @@ def neighbor_count(self,lbound=4.0,ubound=11.4,chainA='A',chainB='B',center='cb' Returns: dict: Neighborhood count - ''' + """ # get the center - self.get_center(chainA=chainA,chainB=chainB,center=center) + self.get_center(chainA=chainA, chainB=chainB, center=center) # dict of NC NC = {} - for chain in [chainA,chainB]: + for chain in [chainA, chainB]: - for i,xyz in enumerate(self.xyz[chain]): - dist = np.sqrt(np.sum((self.xyz[chain]-xyz)**2,1)) - resSeq,resName = self.resinfo[chain][i].tolist() - key = tuple([chain,int(resSeq),resName]) - value = np.sum(self.neighbor_weight(dist,lbound,ubound)) - NC[key] = value + for i, xyz in enumerate(self.xyz[chain]): + dist = np.sqrt(np.sum((self.xyz[chain] - xyz)**2, 1)) + resSeq, resName = self.resinfo[chain][i].tolist() + key = tuple([chain, int(resSeq), resName]) + value = np.sum(self.neighbor_weight(dist, lbound, ubound)) + NC[key] = value return NC @staticmethod - def neighbor_weight(dist,lbound,ubound): - """Neighboor weight + def neighbor_weight(dist, lbound, ubound): + """Neighboor weight. Args: dist (np.array): distance from neighboors @@ -212,10 +228,11 @@ def neighbor_weight(dist,lbound,ubound): Returns: float: distance """ - ind = np.argwhere( (dist>lbound) & (dist=ubound] = 0 + ind = np.argwhere((dist > lbound) & (dist < ubound)) + dist[ind] = 0.5 * \ + (np.cos(np.pi * (dist[ind] - lbound) / (ubound - lbound)) + 1) + dist[dist <= lbound] = 1 + dist[dist >= ubound] = 0 return dist @@ -223,4 +240,4 @@ def neighbor_weight(dist,lbound,ubound): sasa = SASA('1AK4_1w.pdb') NV = sasa.neighbor_vector() - print(NV) \ No newline at end of file + print(NV) diff --git a/deeprank/tools/sparse.py b/deeprank/tools/sparse.py index 73726dc5..722e6c55 100644 --- a/deeprank/tools/sparse.py +++ b/deeprank/tools/sparse.py @@ -1,11 +1,13 @@ import numpy as np -_printif = lambda string,cond: print(string) if cond else None + + +def _printif(string, cond): return print(string) if cond else None class FLANgrid(object): - def __init__(self,sparse=None,index=None,value=None,shape=None): - """Flat Array sparse matrix + def __init__(self, sparse=None, index=None, value=None, shape=None): + """Flat Array sparse matrix. Args: sparse (bool, optional): Sparse or Not @@ -13,32 +15,32 @@ def __init__(self,sparse=None,index=None,value=None,shape=None): value (list(float), optional): values of the non-zero elements shape (3x3 array, optional): Shape of the matrix """ - self.sparse=sparse + self.sparse = sparse self.index = index self.value = value self.shape = shape - def from_dense(self,data,beta=None,debug=False): - '''Create a sparse matrix from a dense one. + def from_dense(self, data, beta=None, debug=False): + """Create a sparse matrix from a dense one. 
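        A round-trip sketch (illustrative only):

        >>> grid = FLANgrid()
        >>> dense = np.zeros((10, 10, 10))
        >>> dense[1, 2, 3] = 1.5
        >>> grid.from_dense(dense, beta=1e-2)
        >>> np.allclose(grid.to_dense(), dense)
        True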
Args: data (np.array): Dense matrix beta (float, optional): threshold to determine if a sparse rep is valuable debug (bool, optional): print debug information - ''' + """ if beta is not None: - thr = beta*np.mean(np.abs(data)) - index = np.argwhere(np.abs(data)>thr) - value = data[np.abs(data)>thr].reshape(-1,1) + thr = beta * np.mean(np.abs(data)) + index = np.argwhere(np.abs(data) > thr) + value = data[np.abs(data) > thr].reshape(-1, 1) else: - index = np.argwhere(data!=0) - value = data[data!=0].reshape(-1,1) + index = np.argwhere(data != 0) + value = data[data != 0].reshape(-1, 1) self.shape = data.shape # we can probably have different grid size # hence differnent index range to handle - if np.prod(data.shape) < 2**16-1: + if np.prod(data.shape) < 2**16 - 1: index_type = np.uint16 ind_byte = 16 else: @@ -46,28 +48,31 @@ def from_dense(self,data,beta=None,debug=False): ind_byte = 32 # memory requirements - mem_sparse = int(len(index)*ind_byte + len(index) * 32) - mem_dense = int(np.prod(data.shape)*32) + mem_sparse = int(len(index) * ind_byte + len(index) * 32) + mem_dense = int(np.prod(data.shape) * 32) # decide if we store sparse or not # if enough elements are sparse if mem_sparse < mem_dense: - _printif('--> FLAN sparse %d bits/%d bits' %(mem_sparse,mem_dense),debug) + _printif( + '--> FLAN sparse %d bits/%d bits' % + (mem_sparse, mem_dense), debug) self.sparse = True self.index = self._get_single_index_array(index).astype(index_type) - self.value= value.astype(np.float32) - + self.value = value.astype(np.float32) else: - _printif('--> FLAN dense %d bits/%d bits' %(mem_sparse,mem_dense),debug) + _printif( + '--> FLAN dense %d bits/%d bits' % + (mem_sparse, mem_dense), debug) self.sparse = False - self.index=None - self.value=data.astype(np.float32) + self.index = None + self.value = data.astype(np.float32) - def to_dense(self,shape=None): - """Create a dense matrix + def to_dense(self, shape=None): + """Create a dense matrix. Args: shape (3x3 array, optional): Shape of the matrix @@ -86,10 +91,10 @@ def to_dense(self,shape=None): shape = self.shape data = np.zeros(np.prod(self.shape)) - data[self.index] = self.value[:,0] + data[self.index] = self.value[:, 0] return data.reshape(self.shape) - def _get_single_index(self,index): + def _get_single_index(self, index): """Get the single index for a single element. # get the index can be used with a map @@ -106,12 +111,12 @@ def _get_single_index(self,index): assert ndim == len(self.shape) ind = index[-1] - for i in range(ndim-1): - ind += index[i] * np.prod(self.shape[i+1:]) + for i in range(ndim - 1): + ind += index[i] * np.prod(self.shape[i + 1:]) return ind - def _get_single_index_array(self,index): - """Get the single index for multiple elements + def _get_single_index_array(self, index): + """Get the single index for multiple elements. 
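        For a grid of shape (nx, ny, nz) this is the usual row-major flattening:
        with shape (10, 10, 10) the voxel (1, 2, 3) maps to 1*100 + 2*10 + 3 = 123.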
# get the index can be used with a map # self.index = np.array( list( map(self._get_single_index,index) ) ).astype(index_type) @@ -124,11 +129,11 @@ def _get_single_index_array(self,index): list(int): index """ - single_ind = index[:,-1] + single_ind = index[:, -1] ndim = index.shape[-1] assert ndim == len(self.shape) - for i in range(ndim-1): - single_ind += index[:,i] * np.prod(self.shape[i+1:]) + for i in range(ndim - 1): + single_ind += index[:, i] * np.prod(self.shape[i + 1:]) return single_ind diff --git a/deeprank/utils/add_binaryClass.py b/deeprank/utils/add_binaryClass.py index ada137c8..e4aeafa2 100755 --- a/deeprank/utils/add_binaryClass.py +++ b/deeprank/utils/add_binaryClass.py @@ -2,23 +2,27 @@ # This script can be used to create/correct target values -import deeprank.generate.DataGenerator as DataGenerator -import os -import numpy as np import glob +import os from time import time +import numpy as np + +import deeprank.generate.DataGenerator as DataGenerator + path = './' -database = [ f for f in glob.glob(path + '*.hdf5') ] +database = [f for f in glob.glob(path + '*.hdf5')] -print (database) +print(database) # create binary target for hdf5_FL in database: - print("Add binary class to %s" %hdf5_FL) - data_set = DataGenerator(compute_targets = ['deeprank.targets.binary_class'], hdf5=hdf5_FL) + print("Add binary class to %s" % hdf5_FL) + data_set = DataGenerator( + compute_targets=['deeprank.targets.binary_class'], + hdf5=hdf5_FL) t0 = time() data_set.add_target(prog_bar=True) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) diff --git a/deeprank/utils/cal_hitrate_successrate.py b/deeprank/utils/cal_hitrate_successrate.py index 74bd9683..012ebb4c 100644 --- a/deeprank/utils/cal_hitrate_successrate.py +++ b/deeprank/utils/cal_hitrate_successrate.py @@ -1,11 +1,11 @@ import numpy as np import pandas as pd + from deeprank.learn import rankingMetrics def evaluate(data): - ''' - Calculate success rate and hit rate. + """Calculate success rate and hit rate. data: a data frame. @@ -25,8 +25,7 @@ def evaluate(data): train 1ZHI 1 0.2 1 0.3 where success =[0, 0, 1, 1, 1,...]: starting from rank 3 this case is a success - - ''' + """ out_df = pd.DataFrame() labels = data.label.unique() # ['train', 'test', 'valid'] @@ -56,7 +55,8 @@ def evaluate(data): caseIDs.extend([caseID] * len(df_one_case)) # hitrate = df_sorted['target'].apply(rankingMetrics.hitrate) # df_sorted['target']: class IDs for each model - # success = hitrate.apply(rankingMetrics.success) # success =[0, 0, 1, 1, 1,...]: starting from rank 3 this case is a success + # success = hitrate.apply(rankingMetrics.success) # success =[0, 0, + # 1, 1, 1,...]: starting from rank 3 this case is a success out_df_tmp['label'] = [l] * len(df) # train, valid or test out_df_tmp['caseID'] = caseIDs @@ -69,8 +69,7 @@ def evaluate(data): def ave_evaluate(data): - ''' - Calculate the average of each column over all cases. + """Calculate the average of each column over all cases. 
INPUT: data = @@ -103,8 +102,7 @@ def ave_evaluate(data): test 5ACD 0.0 0.0 0.0 0.0 test 5ACD 1.0 1.0 1.0 1.0 - - ''' + """ new_data = pd.DataFrame() for l, perf_per_case in data.groupby('label'): @@ -128,7 +126,7 @@ def ave_evaluate(data): perf_ave[col] = perf_ave[col][0:top_N] + \ np.array(perf_case[col][0:top_N]) - perf_ave[col] = perf_ave[col]/num_cases + perf_ave[col] = perf_ave[col] / num_cases new_data = pd.concat([new_data, perf_ave]) @@ -136,11 +134,9 @@ def ave_evaluate(data): def add_rank(df): - ''' - INPUT (a data frame): - label success_DR hitRate_DR success_HS hitRate_HS - Test 0.0 0.000000 0.0 0.000000 - Test 0.0 0.000000 1.0 0.012821 + """INPUT (a data frame): label success_DR hitRate_DR success_HS + hitRate_HS Test 0.0 0.000000 0.0 0.000000 Test + 0.0 0.000000 1.0 0.012821. Train 0.0 0.000000 1.0 0.012821 Train 0.0 0.000000 1.0 0.025641 @@ -152,15 +148,14 @@ def add_rank(df): Train 0.0 0.000000 1.0 0.012821 0.002846 Train 0.0 0.000000 1.0 0.025641 0.003795 - - ''' + """ # -- add the 'rank' column to df rank = [] for _, df_per_label in df.groupby('label'): num_mol = len(df_per_label) rank_raw = np.array(range(num_mol)) + 1 - rank.extend(rank_raw/num_mol) + rank.extend(rank_raw / num_mol) df['rank'] = rank df['label'] = pd.Categorical(df['label'], categories=[ diff --git a/deeprank/utils/cleandata.py b/deeprank/utils/cleandata.py index aab14863..265bb012 100755 --- a/deeprank/utils/cleandata.py +++ b/deeprank/utils/cleandata.py @@ -1,11 +1,13 @@ #!/usr/bin/env python -import h5py import os -def clean_dataset(fname,feature=True,pdb=True,points=True,grid=False): +import h5py + + +def clean_dataset(fname, feature=True, pdb=True, points=True, grid=False): # name of the hdf5 file - f5 = h5py.File(fname,'a') + f5 = h5py.File(fname, 'a') # get the folder names mol_names = f5.keys() @@ -26,24 +28,40 @@ def clean_dataset(fname,feature=True,pdb=True,points=True,grid=False): f5.close() - os.system('h5repack %s _tmp.h5py' %fname) - os.system('mv _tmp.h5py %s' %fname) + os.system('h5repack %s _tmp.h5py' % fname) + os.system('mv _tmp.h5py %s' % fname) + if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='remove data from a hdf5 data set') - parser.add_argument('hdf5', help="hdf5 file storing the data set",default=None) - parser.add_argument('--keep_feature', action='store_true',help="keep the features") - parser.add_argument('--keep_pdb', action='store_true',help="keep the pdbs") - parser.add_argument('--keep_pts',action='store_true',help="keep the coordinates of the grid points") - parser.add_argument('--rm_grid',action='store_true',help='remove the mapped feaures on the grids') + parser = argparse.ArgumentParser( + description='remove data from a hdf5 data set') + parser.add_argument( + 'hdf5', + help="hdf5 file storing the data set", + default=None) + parser.add_argument( + '--keep_feature', + action='store_true', + help="keep the features") + parser.add_argument( + '--keep_pdb', + action='store_true', + help="keep the pdbs") + parser.add_argument( + '--keep_pts', + action='store_true', + help="keep the coordinates of the grid points") + parser.add_argument( + '--rm_grid', + action='store_true', + help='remove the mapped feaures on the grids') args = parser.parse_args() clean_dataset(args.hdf5, - feature = not args.keep_feature, - pdb = not args.keep_pdb, - points = not args.keep_pts, - grid = args.rm_grid ) - + feature=not args.keep_feature, + pdb=not args.keep_pdb, + points=not args.keep_pts, + grid=args.rm_grid) diff --git 
a/deeprank/utils/get_h5subset.py b/deeprank/utils/get_h5subset.py index be6b4a51..9d73fb4f 100755 --- a/deeprank/utils/get_h5subset.py +++ b/deeprank/utils/get_h5subset.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -""" -Extract first N groups of a hdf5 to a new hdf5 file. +"""Extract first N groups of a hdf5 to a new hdf5 file. Usage: python {0} Example: python {0} ./001_1GPW.hdf5 ./001_1GPW_sub10.hdf5 10 """ import sys + import h5py USAGE = __doc__.format(__file__) diff --git a/deeprank/utils/launch.py b/deeprank/utils/launch.py index 55596266..9a06e668 100755 --- a/deeprank/utils/launch.py +++ b/deeprank/utils/launch.py @@ -1,9 +1,10 @@ #!/usr/bin/env python -from deeprank.generate import * import os from time import time + from cleandata import * +from deeprank.generate import * ########################################################################## # @@ -19,90 +20,156 @@ BM4 = '/home/deep/projects/deeprank/data/HADDOCK/BM4/' -def generate(LIST_NAME,clean=False): +def generate(LIST_NAME, clean=False): for NAME in LIST_NAME: print(NAME) # sources to assemble the data base - pdb_source = [BM4 + 'decoys_pdbFLs/'+NAME+'/water/'] - pdb_native = [BM4 + 'BM4_dimers_bound/pdbFLs_ori'] - - #init the data assembler - database = DataGenerator(pdb_source=pdb_source,pdb_native=pdb_native,data_augmentation=None, - compute_targets = ['deeprank.tools.targets.dockQ'], - compute_features = ['deeprank.tools.features.atomic', - 'deeprank.tools.features.pssm'], - hdf5=NAME + '.hdf5', - ) + pdb_source = [BM4 + 'decoys_pdbFLs/' + NAME + '/water/'] + pdb_native = [BM4 + 'BM4_dimers_bound/pdbFLs_ori'] + + # init the data assembler + database = DataGenerator( + pdb_source=pdb_source, + pdb_native=pdb_native, + data_augmentation=None, + compute_targets=['deeprank.tools.targets.dockQ'], + compute_features=[ + 'deeprank.tools.features.atomic', + 'deeprank.tools.features.pssm'], + hdf5=NAME + '.hdf5', + ) if not os.path.isfile(database.hdf5): t0 = time() print('{:25s}'.format('Create new database') + database.hdf5) database.create_database() - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) else: print('{:25s}'.format('Use existing database') + database.hdf5) # map the features grid_info = { - 'number_of_points' : [30,30,30], - 'resolution' : [1.,1.,1.], - 'atomic_densities' : {'CA':3.5,'CB':3.5,'N':3.5,'O':3.5,'C':3.5}, - 'atomic_densities_mode' : 'diff', - 'feature_mode': 'sum' - } - - t0 =time() + 'number_of_points': [ + 30, + 30, + 30], + 'resolution': [ + 1., + 1., + 1.], + 'atomic_densities': { + 'CA': 3.5, + 'CB': 3.5, + 'N': 3.5, + 'O': 3.5, + 'C': 3.5}, + 'atomic_densities_mode': 'diff', + 'feature_mode': 'sum'} + + t0 = time() print('{:25s}'.format('Map features in database') + database.hdf5) - database.map_features(grid_info,time=False,try_sparse=True,cuda=True,gpu_block=[8,8,8]) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + database.map_features( + grid_info, + time=False, + try_sparse=True, + cuda=True, + gpu_block=[ + 8, + 8, + 8]) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) # clean the data file if clean: t0 = time() print('{:25s}'.format('Clean datafile') + database.hdf5) clean_dataset(database.hdf5) - print(' '*25 + '--> Done is %f s.' %(time()-t0)) + print(' ' * 25 + '--> Done is %f s.' 
% (time() - t0)) + if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='launch multiple HDF5 calculations') - parser.add_argument('-s','--status',action='store_true',help='Only list the directory') - parser.add_argument('-d','--device', help="GPU device to use",default='1',type=str) - parser.add_argument('-m','--mol',nargs='+',help='name of the molecule to process',default=None,type=str) - parser.add_argument('-i','--init',help="index of the first molecule to process",default=0,type=int) - parser.add_argument('-f','--final',help="index of the last molecule to process",default=0,type=int) - parser.add_argument('--clean',help="Clean the datafiles",action='store_true') + parser = argparse.ArgumentParser( + description='launch multiple HDF5 calculations') + parser.add_argument( + '-s', + '--status', + action='store_true', + help='Only list the directory') + parser.add_argument( + '-d', + '--device', + help="GPU device to use", + default='1', + type=str) + parser.add_argument( + '-m', + '--mol', + nargs='+', + help='name of the molecule to process', + default=None, + type=str) + parser.add_argument( + '-i', + '--init', + help="index of the first molecule to process", + default=0, + type=int) + parser.add_argument( + '-f', + '--final', + help="index of the last molecule to process", + default=0, + type=int) + parser.add_argument( + '--clean', + help="Clean the datafiles", + action='store_true') args = parser.parse_args() # get the names of the directories - names = os.listdir(BM4+'decoys_pdbFLs/') + names = os.listdir(BM4 + 'decoys_pdbFLs/') # remove some files # as stated in the README some complex don't have a water stage - remove_file = ['README','2H7V','1F6M','1ZLI','1IBR','1R8S','1Y64'] + remove_file = ['README', '2H7V', '1F6M', '1ZLI', '1IBR', '1R8S', '1Y64'] for r in remove_file: names.remove(r) # get the names of thehdf5 already there hdf5 = list(filter(lambda x: '.hdf5' in x, os.listdir())) - status = [ 'Done' if n+'.hdf5' in hdf5 else '' for n in names ] - size = [ "{:5.2f}".format(os.path.getsize(n+'.hdf5')/1E9) if n+'.hdf5' in hdf5 else '' for n in names ] + status = ['Done' if n + '.hdf5' in hdf5 else '' for n in names] + size = [ + "{:5.2f}".format( + os.path.getsize( + n + + '.hdf5') / + 1E9) if n + + '.hdf5' in hdf5 else '' for n in names] # list the dir and their status if args.status: - print('\n'+'='*50+'\n= Current status of the datase \n'+'='*50) - for i,(n,s,w) in enumerate(zip(names,status,size)): + print( + '\n' + + '=' * + 50 + + '\n= Current status of the datase \n' + + '=' * + 50) + for i, (n, s, w) in enumerate(zip(names, status, size)): if w == '': - print('% 4d: %6s %5s %s' %(i,n,s,w)) + print('% 4d: %6s %5s %s' % (i, n, s, w)) else: - print('% 4d: %6s %5s %s GB' %(i,n,s,w)) - print('-'*50) - print(': Status --> %4.3f %% done' %(status.count('Done')/len(status)*100)) - print(': Mem Tot --> %4.3f GB\n' %sum(list(map(lambda x: float(x),filter(lambda x: len(x)>0,size))))) + print('% 4d: %6s %5s %s GB' % (i, n, s, w)) + print('-' * 50) + print( + ': Status --> %4.3f %% done' % (status.count('Done') / len(status) * 100)) + print(': Mem Tot --> %4.3f GB\n' % sum(list(map(lambda x: float(x), + filter(lambda x: len(x) > 0, size))))) # compute the data else: @@ -110,13 +177,10 @@ def generate(LIST_NAME,clean=False): if args.mol is not None: MOL = args.mol else: - MOL = names[args.init:args.final+1] + MOL = names[args.init:args.final + 1] # set the cuda device #os.environ['CUDA_DEVICE'] = args.device # generate the data - 
generate(MOL,clean=args.clean) - - - + generate(MOL, clean=args.clean) diff --git a/deeprank/utils/plot_utils.py b/deeprank/utils/plot_utils.py index 4228ab69..84456eed 100755 --- a/deeprank/utils/plot_utils.py +++ b/deeprank/utils/plot_utils.py @@ -1,26 +1,24 @@ # 1. plot prediction scores for class 0 and 1 using two-panel box plots # 2. hit rate plot # 3. success rate plot -import numpy as np -import h5py -import sys -import torch -import torch.nn.functional as F -import pandas as pd import re +import sys +import warnings from itertools import zip_longest -from cal_hitrate_successrate import evaluate -from cal_hitrate_successrate import ave_evaluate -from cal_hitrate_successrate import add_rank +import h5py +import numpy as np +import pandas as pd -import warnings +import rpy2.robjects as ro +import torch +import torch.nn.functional as F +from cal_hitrate_successrate import add_rank, ave_evaluate, evaluate from rpy2.rinterface import RRuntimeWarning -warnings.filterwarnings("ignore", category=RRuntimeWarning) - -from rpy2.robjects.lib.ggplot2 import * from rpy2.robjects import pandas2ri -import rpy2.robjects as ro +from rpy2.robjects.lib.ggplot2 import * + +warnings.filterwarnings("ignore", category=RRuntimeWarning) def zip_equal(*iterables): @@ -32,12 +30,10 @@ def zip_equal(*iterables): yield combo -def plot_boxplot(df,figname=None,inverse = False): - - ''' - Plot a boxplot of predictions vs. targets. Useful - to visualize the performance of the training algorithm. - This is only useful in classification tasks. +def plot_boxplot(df, figname=None, inverse=False): + """Plot a boxplot of predictions vs. targets. Useful to visualize the + performance of the training algorithm. This is only useful in + classification tasks. INPUT (pd.DataFrame): @@ -45,100 +41,114 @@ def plot_boxplot(df,figname=None,inverse = False): Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 Test 1AVX_ranair-it0_6223 0 0.511688 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 - ''' + """ print('\n --> Box Plot : ', figname, '\n') data = df font_size = 20 - #line = "#1F3552" + # line = "#1F3552" - text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + text_style = element_text(size=font_size, family="Tahoma", face="bold") - colormap_raw =[['0','ivory3'], - ['1','steelblue']] + colormap_raw = [['0', 'ivory3'], + ['1', 'steelblue']] colormap = ro.StrVector([elt[1] for elt in colormap_raw]) colormap.names = ro.StrVector([elt[0] for elt in colormap_raw]) - p= ggplot(data) + \ - aes_string(x='target', y='DR' , fill='target' ) + \ - geom_boxplot( width = 0.2, alpha = 0.7) + \ - facet_grid(ro.Formula('.~label')) +\ - scale_fill_manual(values = colormap ) + \ - theme_bw() +\ - theme(**{'plot.title' : text_style, - 'text': text_style, - 'axis.title': text_style, - 'axis.text.x': element_text(size = font_size), - 'legend.position': 'right'} ) +\ - scale_x_discrete(name = "Target") + p = ggplot(data) + \ + aes_string(x='target', y='DR', fill='target') + \ + geom_boxplot(width=0.2, alpha=0.7) + \ + facet_grid(ro.Formula('.~label')) +\ + scale_fill_manual(values=colormap) + \ + theme_bw() +\ + theme(**{'plot.title': text_style, + 'text': text_style, + 'axis.title': text_style, + 'axis.text.x': element_text(size=font_size), + 'legend.position': 'right'}) +\ + scale_x_discrete(name="Target") # p.plot() - ggplot2.ggsave(figname, dpi = 100) + ggplot2.ggsave(figname, dpi=100) return p def 
read_epoch_data(DR_h5FL, epoch): - ''' - # read epoch data into a data frame + """# read epoch data into a data frame. OUTPUT (pd.DataFrame): label modelID target DR sourceFL 0 Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 1 Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 - ''' + """ - #-- 1. read deeprank output data for the specific epoch - h5 = h5py.File(DR_h5FL,'r') + # -- 1. read deeprank output data for the specific epoch + h5 = h5py.File(DR_h5FL, 'r') if epoch is None: - print (f"epoch is not provided. Use the last epoch data.") + print(f"epoch is not provided. Use the last epoch data.") keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] + last_epoch_key = list(filter(lambda x: 'epoch_' in x, keys))[-1] else: - last_epoch_key = 'epoch_%04d' %epoch + last_epoch_key = 'epoch_%04d' % epoch if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) + print( + 'Incorrect epcoh name\n Possible options are: ' + + ' '.join( + list( + h5.keys()))) h5.close() return data = h5[last_epoch_key] - - #-- 2. convert into pd.DataFrame - labels = list(data) # labels = ['train', 'test', 'valid'] + # -- 2. convert into pd.DataFrame + labels = list(data) # labels = ['train', 'test', 'valid'] # write a dataframe of DR and label to_plot = pd.DataFrame() for l in labels: # l = train, valid or test - source_hdf5FLs = data[l]['mol'][:,0] - modelIDs = list(data[l]['mol'][:,1]) + source_hdf5FLs = data[l]['mol'][:, 0] + modelIDs = list(data[l]['mol'][:, 1]) DR_rawOut = data[l]['outputs'] - DR = F.softmax(torch.FloatTensor(DR_rawOut), dim = 1) - DR = np.array(DR[:,0]) # the probability of a model being negative + DR = F.softmax(torch.FloatTensor(DR_rawOut), dim=1) + DR = np.array(DR[:, 0]) # the probability of a model being negative - targets = data[l]['targets'][()] + targets = data[l]['targets'][()] targets = targets.astype(np.str) - to_plot_tmp = pd.DataFrame(list(zip_equal(source_hdf5FLs, modelIDs, targets, DR)), columns = ['sourceFL', 'modelID', 'target', 'DR']) + to_plot_tmp = pd.DataFrame( + list( + zip_equal( + source_hdf5FLs, + modelIDs, + targets, + DR)), + columns=[ + 'sourceFL', + 'modelID', + 'target', + 'DR']) to_plot_tmp['label'] = l.capitalize() to_plot = to_plot.append(to_plot_tmp) - to_plot['target'] = pd.Categorical(to_plot['target'], categories=['0', '1']) - to_plot['label'] = pd.Categorical(to_plot['label'], categories=['Train', 'Valid', 'Test']) + to_plot['target'] = pd.Categorical( + to_plot['target'], categories=['0', '1']) + to_plot['label'] = pd.Categorical( + to_plot['label'], categories=[ + 'Train', 'Valid', 'Test']) cols = ['label', 'modelID', 'target', 'DR', 'sourceFL'] to_plot = to_plot[cols] - return to_plot -def merge_HS_DR(DR_df, haddockS): - ''' - INPUT 1 (DR_df: a data frame): +def merge_HS_DR(DR_df, haddockS): + """INPUT 1 (DR_df: a data frame): label modelID target DR sourceFL 0 Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 @@ -153,22 +163,20 @@ def merge_HS_DR(DR_df, haddockS): Test 1ZHI 1ZHI_294w 0 9.758 -19.3448 Test 1ZHI 1ZHI_89w 1 17.535 -11.2127 Train 1ACB 1ACB_9w 1 14.535 -19.2127 - ''' + """ - - #-- merge HS with DR predictions, model IDs and class IDs + # -- merge HS with DR predictions, model IDs and class IDs modelIDs = DR_df['modelID'] HS, idx_keep = get_HS(modelIDs, haddockS) - data = DR_df.iloc[idx_keep,:].copy() + data = DR_df.iloc[idx_keep, :].copy() 
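A compact sketch of the score conversion performed in read_epoch_data above: the raw two-class outputs are turned into probabilities with a softmax, and column 0 (the probability of a model being negative) is kept as the DR score:

import numpy as np
import torch
import torch.nn.functional as F

raw_outputs = np.array([[0.2, -0.1],     # illustrative network outputs, one row per model
                        [-1.3, 2.0]])
prob = F.softmax(torch.FloatTensor(raw_outputs), dim=1)
DR = prob[:, 0].numpy()                  # probability of being negative; a low DR suggests a hit
print(DR)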
data['HS'] = HS data['caseID'] = [re.split('_', x)[0] for x in data['modelID']] - - #-- reorder columns + # -- reorder columns col_ori = data.columns col = ['label', 'caseID', 'modelID', 'target', 'sourceFL'] - col.extend( [x for x in col_ori if x not in col]) + col.extend([x for x in col_ori if x not in col]) data = data[col] return data @@ -183,80 +191,80 @@ def read_haddockScoreFL(HS_h5FL): stats['haddock-score'] = {} # stats['i-RMSD'] = {} - modelIDs = [ re.sub('.pdb','',x) for x in data['modelID'] ] # remove .pdb from model ID + modelIDs = [re.sub('.pdb', '', x) + for x in data['modelID']] # remove .pdb from model ID stats['haddock-score'] = dict(zip_equal(modelIDs, data['haddock-score'])) -# stats['i-RMSD'] = dict(zip(modelIDs, data['i-RMSD'])) # some i-RMSDs are wrong!!! Reported an issue. +# stats['i-RMSD'] = dict(zip(modelIDs, data['i-RMSD'])) # some i-RMSDs are +# wrong!!! Reported an issue. return stats + def plot_DR_iRMSD(df, figname=None): - ''' - Plot a scatter plot of DeepRank score vs. iRMSD for train, valid and test + """Plot a scatter plot of DeepRank score vs. iRMSD for train, valid and + test. INPUT (a data frame): label caseID modelID target sourceFL DR irmsd HS Test 1AVX 1AVX_ranair-it0_5286 0 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.503823 25.189108 6.980802 Test 1AVX 1AVX_ti5-itw_354w 1 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.502845 3.668682 -95.158100 - - ''' + """ print('\n --> Scatter plot of DR vs. iRMSD:', figname, '\n') # plot font_size = 16 - text_style = element_text(size = font_size, family = "Tahoma", face = "bold") - p = ggplot(df) + aes_string(y = 'irmsd', x = 'DR') +\ - facet_grid(ro.Formula('.~label')) + \ - geom_point(alpha = 0.5) + \ - theme_bw() +\ - theme(**{'plot.title' : text_style, - 'text': text_style, - 'axis.title': text_style, - 'axis.text.x': element_text(size = font_size + 2), - 'axis.text.y': element_text(size = font_size + 2)} ) + \ - scale_y_continuous(name = "i-RMSD") - - #p.plot() - ggplot2.ggsave(figname, height = 7 , width = 7 * 1.5, dpi = 100) - return p + text_style = element_text(size=font_size, family="Tahoma", face="bold") + p = ggplot(df) + aes_string(y='irmsd', x='DR') +\ + facet_grid(ro.Formula('.~label')) + \ + geom_point(alpha=0.5) + \ + theme_bw() +\ + theme(**{'plot.title': text_style, + 'text': text_style, + 'axis.title': text_style, + 'axis.text.x': element_text(size=font_size + 2), + 'axis.text.y': element_text(size=font_size + 2)}) + \ + scale_y_continuous(name="i-RMSD") + # p.plot() + ggplot2.ggsave(figname, height=7, width=7 * 1.5, dpi=100) + return p def plot_HS_iRMSD(df, figname=None): - ''' - Plot a scatter plot of HS vs. iRMSD for train, valid and test + """Plot a scatter plot of HS vs. iRMSD for train, valid and test. INPUT (a data frame): label caseID modelID target sourceFL DR irmsd HS Test 1AVX 1AVX_ranair-it0_5286 0 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.503823 25.189108 6.980802 Test 1AVX 1AVX_ti5-itw_354w 1 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.502845 3.668682 -95.158100 - - ''' + """ print('\n --> Scatter plot of HS vs. 
iRMSD:', figname, '\n') # plot font_size = 16 - text_style = element_text(size = font_size, family = "Tahoma", face = "bold") - p= ggplot(df) + aes_string(y = 'irmsd', x = 'HS') +\ - facet_grid(ro.Formula('.~label')) + \ - geom_point(alpha = 0.5) + \ - theme_bw() +\ - theme(**{'plot.title' : text_style, - 'text': text_style, - 'axis.title': text_style, - 'axis.text.x': element_text(size = font_size + 2), - 'axis.text.y': element_text(size = font_size + 2)} ) + \ - scale_y_continuous(name = "i-RMSD") - - #p.plot() - ggplot2.ggsave(figname, height = 7 , width = 7 * 1.5, dpi=100) + text_style = element_text(size=font_size, family="Tahoma", face="bold") + p = ggplot(df) + aes_string(y='irmsd', x='HS') +\ + facet_grid(ro.Formula('.~label')) + \ + geom_point(alpha=0.5) + \ + theme_bw() +\ + theme(**{'plot.title': text_style, + 'text': text_style, + 'axis.title': text_style, + 'axis.text.x': element_text(size=font_size + 2), + 'axis.text.y': element_text(size=font_size + 2)}) + \ + scale_y_continuous(name="i-RMSD") + + # p.plot() + ggplot2.ggsave(figname, height=7, width=7 * 1.5, dpi=100) return p -def plot_successRate_hitRate (df, figname=None,inverse = False): - '''Plot the hit rate and success_rate of the different training/valid/test sets with HS (haddock scores) +def plot_successRate_hitRate(df, figname=None, inverse=False): + """Plot the hit rate and success_rate of the different training/valid/test + sets with HS (haddock scores) The hit rate is defined as: the percentage of positive decoys that are included among the top m decoys. @@ -287,16 +295,14 @@ def plot_successRate_hitRate (df, figname=None,inverse = False): 1ACB 1 0.2 1 0.4 ... 2. Calculate success rate and hit rate over all cases. + """ - - ''' - - #-- 1. calculate success rate and hit rate + # -- 1. calculate success rate and hit rate performance_per_case = evaluate(df) performance_ave = ave_evaluate(performance_per_case) performance_ave = add_rank(performance_ave) - #-- 2. plot + # -- 2. 
plot plot_evaluation(performance_ave, figname) @@ -312,20 +318,18 @@ def plot_evaluation(df, figname): ''' - #---------- hit rate plot ------- + # ---------- hit rate plot ------- figname1 = figname + '.hitRate.png' print(f'\n --> Hit Rate plot:', figname1, '\n') hit_rate_plot(df) - ggplot2.ggsave(figname1, height = 7 , width = 7 * 1.2, dpi = 100) - + ggplot2.ggsave(figname1, height=7, width=7 * 1.2, dpi=100) - #---------- success rate plot ------- + # ---------- success rate plot ------- figname2 = figname + '.successRate.png' print(f'\n --> Success Rate plot:', figname2, '\n') success_rate_plot(df) - ggplot2.ggsave(figname2, height = 7 , width = 7 * 1.2, dpi=100) - + ggplot2.ggsave(figname2, height=7, width=7 * 1.2, dpi=100) def hit_rate_plot(df): @@ -340,80 +344,83 @@ def hit_rate_plot(df): ''' - #-- melt df + # -- melt df df_melt = pd.melt(df, id_vars=['label', 'rank']) idx1 = df_melt.variable.str.contains('^hitRate') - df_tmp = df_melt.loc[idx1,:].copy() + df_tmp = df_melt.loc[idx1, :].copy() df_tmp.columns = ['Sets', 'rank', 'Methods', 'hit_rate'] tmp = list(df_tmp['Methods']) - df_tmp.loc[:,'Methods']= [re.sub('hitRate_','',x) for x in tmp] # success_DR -> DR + df_tmp.loc[:, 'Methods'] = [ + re.sub('hitRate_', '', x) for x in tmp] # success_DR -> DR font_size = 20 - breaks = pd.to_numeric(np.arange(0,1.01,0.25)) - xlabels = list(map(lambda x: str('%d' % (x*100)) + ' % ', np.arange(0,1.01,0.25)) ) - text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + breaks = pd.to_numeric(np.arange(0, 1.01, 0.25)) + xlabels = list(map(lambda x: str('%d' % (x * 100)) + + ' % ', np.arange(0, 1.01, 0.25))) + text_style = element_text(size=font_size, family="Tahoma", face="bold") p = ggplot(df_tmp) + \ - aes_string(x='rank', y = 'hit_rate', color='Sets', linetype= 'Methods') + \ - geom_line(size=1) + \ - labs(**{'x': 'Top models (%)', 'y': 'Hit Rate'}) + \ - theme_bw() + \ - theme(**{'legend.position': 'right', - 'plot.title': text_style, - 'text': text_style, - 'axis.text.x': element_text(size = font_size), - 'axis.text.y': element_text(size = font_size)}) +\ - scale_x_continuous(**{'breaks':breaks, 'labels': xlabels}) + aes_string(x='rank', y='hit_rate', color='Sets', linetype='Methods') + \ + geom_line(size=1) + \ + labs(**{'x': 'Top models (%)', 'y': 'Hit Rate'}) + \ + theme_bw() + \ + theme(**{'legend.position': 'right', + 'plot.title': text_style, + 'text': text_style, + 'axis.text.x': element_text(size=font_size), + 'axis.text.y': element_text(size=font_size)}) +\ + scale_x_continuous(**{'breaks': breaks, 'labels': xlabels}) return p + def success_rate_plot(df): - ''' - # INPUT: a pandas data frame - label success_HS hitRate_HS success_DR hitRate_DR - 0 valid 1.0 1.0 0.0 0.0 - 1 valid 0.0 1.0 0.0 0.0 - ''' + """# INPUT: a pandas data frame label success_HS hitRate_HS success_DR + hitRate_DR 0 valid 1.0 1.0 0.0 0.0 1 + valid 0.0 1.0 0.0 0.0.""" - #-- add the 'rank' column to df + # -- add the 'rank' column to df rank = [] for _, df_per_label in df.groupby('label'): num_mol = len(df_per_label) - rank_raw = np.array(range(num_mol )) + 1 - rank.extend(rank_raw/num_mol ) + rank_raw = np.array(range(num_mol)) + 1 + rank.extend(rank_raw / num_mol) df['rank'] = rank - #-- melt df + # -- melt df df_melt = pd.melt(df, id_vars=['label', 'rank']) idx1 = df_melt.variable.str.contains('^success_') - df_tmp = df_melt.loc[idx1,:].copy() + df_tmp = df_melt.loc[idx1, :].copy() df_tmp.columns = ['Sets', 'rank', 'Methods', 'success_rate'] tmp = list(df_tmp['Methods']) - 
df_tmp.loc[:,'Methods']= [re.sub('success_','',x) for x in tmp] # success_DR -> DR + df_tmp.loc[:, 'Methods'] = [ + re.sub('success_', '', x) for x in tmp] # success_DR -> DR font_size = 20 - breaks = pd.to_numeric(np.arange(0,1.01,0.25)) - xlabels = list(map(lambda x: str('%d' % (x*100)) + ' % ', np.arange(0,1.01,0.25)) ) - text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + breaks = pd.to_numeric(np.arange(0, 1.01, 0.25)) + xlabels = list(map(lambda x: str('%d' % (x * 100)) + + ' % ', np.arange(0, 1.01, 0.25))) + text_style = element_text(size=font_size, family="Tahoma", face="bold") p = ggplot(df_tmp) + \ - aes_string(x='rank', y = 'success_rate', color='Sets', linetype= 'Methods') + \ - geom_line(size=1) + \ - labs(**{'x': 'Top models (%)', 'y': 'Success Rate'}) + \ - theme_bw() + \ - theme(**{'legend.position': 'right', - 'plot.title': text_style, - 'text': text_style, - 'axis.text.x': element_text(size = font_size), - 'axis.text.y': element_text(size = font_size)}) +\ - scale_x_continuous(**{'breaks':breaks, 'labels': xlabels}) + aes_string(x='rank', y='success_rate', color='Sets', linetype='Methods') + \ + geom_line(size=1) + \ + labs(**{'x': 'Top models (%)', 'y': 'Success Rate'}) + \ + theme_bw() + \ + theme(**{'legend.position': 'right', + 'plot.title': text_style, + 'text': text_style, + 'axis.text.x': element_text(size=font_size), + 'axis.text.y': element_text(size=font_size)}) +\ + scale_x_continuous(**{'breaks': breaks, 'labels': xlabels}) # p.plot() return p -def get_irmsd( source_hdf5, modelIDs): + +def get_irmsd(source_hdf5, modelIDs): irmsd = [] for h5FL, modelID in zip_equal(source_hdf5, modelIDs): @@ -424,9 +431,8 @@ def get_irmsd( source_hdf5, modelIDs): return irmsd - -def get_HS(modelIDs,haddockS): - HS=[] +def get_HS(modelIDs, haddockS): + HS = [] idx_keep = [] for idx, modelID in enumerate(modelIDs): @@ -435,8 +441,8 @@ def get_HS(modelIDs,haddockS): idx_keep.append(idx) return HS, idx_keep -def add_irmsd(df): +def add_irmsd(df): ''' INPUT (a data frame): df: @@ -460,9 +466,7 @@ def add_irmsd(df): return df - def prepare_df(deeprank_h5FL, HS_h5FL, epoch): - ''' OUTPUT: a data frame: @@ -470,7 +474,7 @@ def prepare_df(deeprank_h5FL, HS_h5FL, epoch): Test 1AVX 1AVX_ranair-it0_5286 0 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.503823 25.189108 6.980802 Test 1AVX 1AVX_ti5-itw_354w 1 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.502845 3.668682 -95.158100 ''' - #-- read deeprank_h5FL epoch data into pd.DataFrame (DR_df) + # -- read deeprank_h5FL epoch data into pd.DataFrame (DR_df) DR_df = read_epoch_data(deeprank_h5FL, epoch) ''' @@ -482,15 +486,15 @@ def prepare_df(deeprank_h5FL, HS_h5FL, epoch): 2 Test 1AVX_ranair-it0_6223 0 0.511688 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 ''' - #-- add iRMSD column to DR_df + # -- add iRMSD column to DR_df DR_df = add_irmsd(DR_df) - #-- report the number of hits for train/valid/test + # -- report the number of hits for train/valid/test hit_statistics(DR_df) - #-- add HS to DR_df (note: bound complexes do not have HS) + # -- add HS to DR_df (note: bound complexes do not have HS) stats = read_haddockScoreFL(HS_h5FL) - haddockS = stats['haddock-score']# haddockS[modelID] = score + haddockS = stats['haddock-score'] # haddockS[modelID] = score DR_HS_df = merge_HS_DR(DR_df, haddockS) ''' @@ -504,50 +508,53 @@ def prepare_df(deeprank_h5FL, HS_h5FL, epoch): return DR_HS_df -def hit_statistics(df): - ''' - Report the number of hits for Train, valid and test. 
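A toy illustration of the reshaping done in the two plotting helpers above: a normalised per-label 'rank' column is added (rows are assumed to be already grouped by label, as in the real data), then pd.melt turns the wide table into the long format that ggplot facets on:

import numpy as np
import pandas as pd

df = pd.DataFrame({'label':      ['Test'] * 3 + ['Train'] * 3,
                   'success_DR': [0, 1, 1, 0, 0, 1],
                   'success_HS': [1, 1, 1, 0, 1, 1]})

rank = []
for _, df_per_label in df.groupby('label'):
    n = len(df_per_label)
    rank.extend((np.arange(n) + 1) / n)      # 1/n, 2/n, ..., 1.0 within each label
df['rank'] = rank

df_melt = pd.melt(df, id_vars=['label', 'rank'])
df_melt = df_melt[df_melt.variable.str.contains('^success_')].copy()
df_melt.columns = ['Sets', 'rank', 'Methods', 'success_rate']
print(df_melt.head())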
+def hit_statistics(df): + """Report the number of hits for Train, valid and test. INPUT (a data frame): label modelID target DR sourceFL irmsd Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 25.189108 Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 3.668682 - ''' + """ labels = ['Train', 'Valid', 'Test'] grouped = df.groupby('label') - #-- 1. count num_hit based on i-rmsd - num_hits = grouped['irmsd'].apply(lambda x: len(x[x<=4])) + # -- 1. count num_hit based on i-rmsd + num_hits = grouped['irmsd'].apply(lambda x: len(x[x <= 4])) num_models = grouped.apply(len) for label in labels: - print(f"According to 'i-RMSD' -> num of hits for {label}: {num_hits[label]} out of {num_models[label]} models") + print( + f"According to 'i-RMSD' -> num of hits for {label}: {num_hits[label]} out of {num_models[label]} models") print("") - #-- 2. count num_hit based on the 'target' column - num_hits = grouped['target'].apply(lambda x: len(x[x=='1'])) + # -- 2. count num_hit based on the 'target' column + num_hits = grouped['target'].apply(lambda x: len(x[x == '1'])) num_models = grouped.apply(len) for label in labels: - print(f"According to 'targets' -> num of hits for {label}: {num_hits[label]} out of {num_models[label]} models") + print( + f"According to 'targets' -> num of hits for {label}: {num_hits[label]} out of {num_models[label]} models") print("") - #-- 3. report num_cases_wo_hit + # -- 3. report num_cases_wo_hit df_tmp = df.copy() df_tmp['caseID'] = df['modelID'].apply(get_caseID) grouped = df_tmp.groupby(['label', 'caseID']) - num_hits = grouped['target'].apply(lambda x: len(x[x==1])) + num_hits = grouped['target'].apply(lambda x: len(x[x == 1])) grp = num_hits.groupby('label') num_cases_total = grp.apply(lambda x: len(x)) - num_cases_wo_hit = grp.apply(lambda x: len(x==0)) + num_cases_wo_hit = grp.apply(lambda x: len(x == 0)) for label in labels: - print(f"According to 'targets' -> {num_cases_wo_hit[label]} out of {num_cases_total[label]} cases do not have any hits for {label}") + print( + f"According to 'targets' -> {num_cases_wo_hit[label]} out of {num_cases_total[label]} cases do not have any hits for {label}") print("") + def get_caseID(modelID): # modelID = 1AVX_ranair-it0_5286 # caseID = 1AVX @@ -556,12 +563,14 @@ def get_caseID(modelID): caseID = tmp[0] return caseID -def main(HS_h5FL= '/home/lixue/DBs/BM5-haddock24/stats/stats.h5'): - if len(sys.argv) !=4: + +def main(HS_h5FL='/home/lixue/DBs/BM5-haddock24/stats/stats.h5'): + if len(sys.argv) != 4: print(f"Usage: python {sys.argv[0]} epoch_data.hdf5 epoch fig_name") sys.exit() - deeprank_h5FL = sys.argv[1] #the output h5 file from deeprank: 'epoch_data.hdf5' - epoch = int(sys.argv[2]) # 9 + # the output h5 file from deeprank: 'epoch_data.hdf5' + deeprank_h5FL = sys.argv[1] + epoch = int(sys.argv[2]) # 9 figname = sys.argv[3] pandas2ri.activate() @@ -569,12 +578,24 @@ def main(HS_h5FL= '/home/lixue/DBs/BM5-haddock24/stats/stats.h5'): df = prepare_df(deeprank_h5FL, HS_h5FL, epoch) #-- plot - plot_HS_iRMSD(df, figname=figname + '.epo' + str(epoch) +'.irsmd_HS.png') + plot_HS_iRMSD(df, figname=figname + '.epo' + str(epoch) + '.irsmd_HS.png') plot_DR_iRMSD(df, figname=figname + '.epo' + str(epoch) + '.irsmd_DR.png') - plot_boxplot(df, figname=figname + '.epo' + str(epoch) + '.boxplot.png',inverse = False) - plot_successRate_hitRate(df[['label', 'caseID', 'modelID', 'target', 'DR','HS']].copy(), figname=figname + '.epo' + str(epoch) ,inverse = False) + plot_boxplot( + df, 
+ figname=figname + + '.epo' + + str(epoch) + + '.boxplot.png', + inverse=False) + plot_successRate_hitRate(df[['label', + 'caseID', + 'modelID', + 'target', + 'DR', + 'HS']].copy(), + figname=figname + '.epo' + str(epoch), + inverse=False) + if __name__ == '__main__': main() - - diff --git a/deeprank/utils/run_slurmFLs.py b/deeprank/utils/run_slurmFLs.py index e77647b8..4fc0bbeb 100755 --- a/deeprank/utils/run_slurmFLs.py +++ b/deeprank/utils/run_slurmFLs.py @@ -2,28 +2,25 @@ # Li Xue # 20-Feb-2019 10:50 -''' -Split multiple jobs into batches and submit to cartesius. +"""Split multiple jobs into batches and submit to cartesius. INPUT: a file that contains all the jobs, for example, python /projects/0/deeprank/change_BIN_CLASS.py /projects/000_1ACB.hdf5 & python /projects/0/deeprank/change_BIN_CLASS.py /projects/000_1AK4.hdf5 & ... - -''' -import re -import os +""" import glob +import os +import re import subprocess -from shlex import quote -from shlex import split import time +from shlex import quote, split -logDIR='/projects/0/deeprank/BM5/scripts/slurm/change_BINCLASS/hdf5_withGridFeature' -slurmDIR=logDIR -num_cores = 24 # run 24 cores for each slurm job -batch_size = num_cores # number of jobs per slurm file +logDIR = '/projects/0/deeprank/BM5/scripts/slurm/change_BINCLASS/hdf5_withGridFeature' +slurmDIR = logDIR +num_cores = 24 # run 24 cores for each slurm job +batch_size = num_cores # number of jobs per slurm file def write_slurmscript(all_job_FL, batch_size, slurmDIR='tmp', logDIRi='tmp'): @@ -31,7 +28,7 @@ def write_slurmscript(all_job_FL, batch_size, slurmDIR='tmp', logDIRi='tmp'): all_job_FL = quote(all_job_FL) slurmDIR = quote(slurmDIR) - #- split all_jobs.sh into mutliple files + # - split all_jobs.sh into mutliple files command = f'cp {all_job_FL} {slurmDIR}' command = split(command) subprocess.check_call(command) @@ -42,7 +39,7 @@ def write_slurmscript(all_job_FL, batch_size, slurmDIR='tmp', logDIRi='tmp'): subprocess.check_call(command) # subprocess.check_call(['split' , '-a', '3' ,'-d', f'-l {batch_size}', '--additional-suffix=.slurm' ,f"{slurmDIR}/{all_job_FL}", f"{slurmDIR}/batch"]) - #-- add slurm header and tail to each file + # -- add slurm header and tail to each file batchID = 0 for slurmFL in glob.glob(f'{slurmDIR}/batch*'): @@ -51,17 +48,18 @@ def write_slurmscript(all_job_FL, batch_size, slurmDIR='tmp', logDIRi='tmp'): write_slurm_tail(slurmFL) print(slurmFL + ' generated ') -def submit_slurmscript(slurm_dir, batch_size = 100): + +def submit_slurmscript(slurm_dir, batch_size=100): # submit slurm scripts in batches # each batch waits for the previous batch to finish first. slu_FLs = glob.glob(slurm_dir + "/*.slurm") - jobIDs=[] - newjobIDs=[] + jobIDs = [] + newjobIDs = [] num = 0 for slu_FL in slu_FLs: - outFL=os.path.splitext(slu_FL)[0] + '.out' + outFL = os.path.splitext(slu_FL)[0] + '.out' if os.path.isfile(outFL): print(f"{outFL} exists. 
Skip submitting slurm file.") @@ -70,8 +68,8 @@ def submit_slurmscript(slurm_dir, batch_size = 100): num = num + 1 if num <= batch_size: -# command = ['sbatch', slu_FLs[i] ] -# print (" ".join(command)) + # command = ['sbatch', slu_FLs[i] ] + # print (" ".join(command)) slu_FL = quote(slu_FL) command = f'sbatch {slu_FL}' print(command) @@ -79,29 +77,31 @@ def submit_slurmscript(slurm_dir, batch_size = 100): jobID = subprocess.check_output(command) jobID = re.findall(r'\d+', str(jobID)) jobID = jobID[0] - print (num) - print (jobID) - newjobIDs.append(jobID) # these IDs will used for dependency=afterany + print(num) + print(jobID) + # these IDs will used for dependency=afterany + newjobIDs.append(jobID) - if num >batch_size: -# command=['sbatch', '--dependency=afterany:'+ ":".join(jobIDs), slu_FLs[i] ] -# print (" ".join(command)) + if num > batch_size: + # command=['sbatch', '--dependency=afterany:'+ ":".join(jobIDs), slu_FLs[i] ] + # print (" ".join(command)) - command = 'sbatch --dependency=afterany:' + ':'.join(jobIDs) + f'{slu_FLs[i]}' + command = 'sbatch --dependency=afterany:' + \ + ':'.join(jobIDs) + f'{slu_FLs[i]}' print(command) command = split(command) jobID = subprocess.check_output(command) jobID = re.findall(r'\d+', str(jobID)) jobID = jobID[0] - print (num) - print (jobID) + print(num) + print(jobID) newjobIDs.append(jobID) - if num%batch_size ==0: - print (newjobIDs) - jobIDs=newjobIDs - newjobIDs=[] - print ("------------- new batch --------- \n") + if num % batch_size == 0: + print(newjobIDs) + jobIDs = newjobIDs + newjobIDs = [] + print("------------- new batch --------- \n") time.sleep(1) @@ -129,6 +129,7 @@ def submit_slurmscript(slurm_dir, batch_size = 100): # time.sleep(5) # + def parse_jobID(jobID): # input: b'Submitted batch job 5442433\n ' # output: 5442433 @@ -136,10 +137,11 @@ def parse_jobID(jobID): jobID = jobID[0] return (jobID) + def write_slurm_header(slurmFL, batchID, batch_size, logFL): - #- 1. prepare the header string - header='' + # - 1. prepare the header string + header = '' header = header + "#!/usr/bin/bash\n" header = header + "#SBATCH -p normal\n" @@ -159,15 +161,14 @@ def write_slurm_header(slurmFL, batchID, batch_size, logFL): """ header = header + common_part - - #- 2. add the header to slurmFL - f = open(slurmFL,'r') + # - 2. 
add the header to slurmFL + f = open(slurmFL, 'r') content = f.readlines() f.close() content.insert(0, header) - f = open(slurmFL,'w') + f = open(slurmFL, 'w') f.write(''.join(content)) f.close() @@ -185,7 +186,7 @@ def write_slurm_tail(slurmFL): echo "total runtime: $runtime sec" """ - f = open(slurmFL,'a+') + f = open(slurmFL, 'a+') f.write(tail) f.close() @@ -196,4 +197,3 @@ def write_slurm_tail(slurmFL): write_slurmscript('all_jobs.sh', batch_size, slurmDIR, logDIR) #submit_slurmscript(slurmDIR, 200) - diff --git a/deeprank/utils/visualize3Ddata.py b/deeprank/utils/visualize3Ddata.py index bf4e8fb6..5adfbed8 100755 --- a/deeprank/utils/visualize3Ddata.py +++ b/deeprank/utils/visualize3Ddata.py @@ -1,18 +1,17 @@ #!/usr/bin/env python -import numpy as np -import subprocess as sp import os +import subprocess as sp + import h5py -from deeprank.tools import pdb2sql -from deeprank.tools import sparse +import numpy as np -def visualize3Ddata(hdf5=None,mol_name=None,out=None): +from deeprank.tools import pdb2sql, sparse - ''' - This function can be used to generate cube files for the visualization of the mapped - data in VMD +def visualize3Ddata(hdf5=None, mol_name=None, out=None): + """This function can be used to generate cube files for the visualization + of the mapped data in VMD. Usage python generate_cube_files.py @@ -31,8 +30,7 @@ def visualize3Ddata(hdf5=None,mol_name=None,out=None): quick vizualisation of the data by typing vmd -e .vmd - ''' - + """ outdir = out @@ -46,14 +44,14 @@ def visualize3Ddata(hdf5=None,mol_name=None,out=None): os.mkdir(outdir) try: - f5 = h5py.File(hdf5,'r') - except: - raise FileNotFoundError('HDF5 file %s could not be opened' %hdf5) + f5 = h5py.File(hdf5, 'r') + except BaseException: + raise FileNotFoundError('HDF5 file %s could not be opened' % hdf5) try: molgrp = f5[mol_name] - except: - raise LookupError('Molecule %s not found in %s' %(mol_name,hdf5)) + except BaseException: + raise LookupError('Molecule %s not found in %s' % (mol_name, hdf5)) # create the pdb file sqldb = pdb2sql(molgrp['complex'][:]) @@ -65,7 +63,7 @@ def visualize3Ddata(hdf5=None,mol_name=None,out=None): grid['x'] = molgrp['grid_points/x'][:] grid['y'] = molgrp['grid_points/y'][:] grid['z'] = molgrp['grid_points/z'][:] - shape = (len(grid['x']),len(grid['y']),len(grid['z'])) + shape = (len(grid['x']), len(grid['y']), len(grid['z'])) # deals with the features mapgrp = molgrp['mapped_features'] @@ -79,55 +77,59 @@ def visualize3Ddata(hdf5=None,mol_name=None,out=None): for ff in featgrp.keys(): subgrp = featgrp[ff] if not subgrp.attrs['sparse']: - data_dict[ff] = subgrp['value'][:] + data_dict[ff] = subgrp['value'][:] else: - spg = sparse.FLANgrid(sparse=True,index=subgrp['index'][:],value=subgrp['value'][:],shape=shape) - data_dict[ff] = spg.to_dense() + spg = sparse.FLANgrid( + sparse=True, + index=subgrp['index'][:], + value=subgrp['value'][:], + shape=shape) + data_dict[ff] = spg.to_dense() # export the cube file - export_cube_files(data_dict,data_name,grid,outdir) + export_cube_files(data_dict, data_name, grid, outdir) f5.close() -def export_cube_files(data_dict,data_name,grid,export_path): +def export_cube_files(data_dict, data_name, grid, export_path): - print('-- Export %s data to %s' %(data_name,export_path)) + print('-- Export %s data to %s' % (data_name, export_path)) bohr2ang = 0.52918 # individual axis of the grid - x,y,z = grid['x'],grid['y'],grid['z'] + x, y, z = grid['x'], grid['y'], grid['z'] # extract grid_info - npts = np.array([len(x),len(y),len(z)]) - res = 
np.array([x[1]-x[0],y[1]-y[0],z[1]-z[0]]) + npts = np.array([len(x), len(y), len(z)]) + res = np.array([x[1] - x[0], y[1] - y[0], z[1] - z[0]]) # the cuve file is apparently give in bohr - xmin,ymin,zmin = np.min(x)/bohr2ang,np.min(y)/bohr2ang,np.min(z)/bohr2ang - scale_res = res/bohr2ang + xmin, ymin, zmin = np.min(x) / bohr2ang, np.min(y) / \ + bohr2ang, np.min(z) / bohr2ang + scale_res = res / bohr2ang # export files for visualization - for key,values in data_dict.items(): + for key, values in data_dict.items(): - fname = export_path + data_name + '_%s' %(key) + '.cube' - f = open(fname,'w') + fname = export_path + data_name + '_%s' % (key) + '.cube' + f = open(fname, 'w') f.write('CUBE FILE\n') f.write("OUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z\n") - f.write("%5i %11.6f %11.6f %11.6f\n" % (1,xmin,ymin,zmin)) - f.write("%5i %11.6f %11.6f %11.6f\n" % (npts[0],scale_res[0],0,0)) - f.write("%5i %11.6f %11.6f %11.6f\n" % (npts[1],0,scale_res[1],0)) - f.write("%5i %11.6f %11.6f %11.6f\n" % (npts[2],0,0,scale_res[2])) - + f.write("%5i %11.6f %11.6f %11.6f\n" % (1, xmin, ymin, zmin)) + f.write("%5i %11.6f %11.6f %11.6f\n" % (npts[0], scale_res[0], 0, 0)) + f.write("%5i %11.6f %11.6f %11.6f\n" % (npts[1], 0, scale_res[1], 0)) + f.write("%5i %11.6f %11.6f %11.6f\n" % (npts[2], 0, 0, scale_res[2])) # the cube file require 1 atom - f.write("%5i %11.6f %11.6f %11.6f %11.6f\n" % (0,0,0,0,0)) + f.write("%5i %11.6f %11.6f %11.6f %11.6f\n" % (0, 0, 0, 0, 0)) last_char_check = True for i in range(npts[0]): for j in range(npts[1]): for k in range(npts[2]): - f.write(" %11.5e" % values[i,j,k]) + f.write(" %11.5e" % values[i, j, k]) last_char_check = True if k % 6 == 5: f.write("\n") @@ -136,31 +138,36 @@ def export_cube_files(data_dict,data_name,grid,export_path): f.write("\n") f.close() - # export VMD script if cube format is required fname = export_path + data_name + '.vmd' - f = open(fname,'w') + f = open(fname, 'w') f.write('# can be executed with vmd -e viz_mol.vmd\n\n') # write all the cube file in one given molecule keys = list(data_dict.keys()) - write_molspec_vmd(f, data_name +'_%s.cube' %(keys[0]),'VolumeSlice','Volume') - for idata in range(1,len(keys)): - f.write('mol addfile ' + data_name +'_%s.cube\n' %(keys[idata])) + write_molspec_vmd( + f, + data_name + + '_%s.cube' % + (keys[0]), + 'VolumeSlice', + 'Volume') + for idata in range(1, len(keys)): + f.write('mol addfile ' + data_name + '_%s.cube\n' % (keys[idata])) f.write('mol rename top ' + data_name) # load the complex - write_molspec_vmd(f,'complex.pdb','Cartoon','Chain') + write_molspec_vmd(f, 'complex.pdb', 'Cartoon', 'Chain') f.close() # quick shortcut for writting the vmd file -def write_molspec_vmd(f,name,rep,color): - f.write('\nmol new %s\n' %name) - f.write('mol delrep 0 top\nmol representation %s\n' %rep) +def write_molspec_vmd(f, name, rep, color): + f.write('\nmol new %s\n' % name) + f.write('mol delrep 0 top\nmol representation %s\n' % rep) if color is not None: - f.write('mol color %s \n' %color) + f.write('mol color %s \n' % color) f.write('mol addrep top\n\n') @@ -168,10 +175,20 @@ def write_molspec_vmd(f,name,rep,color): import argparse - parser = argparse.ArgumentParser(description='export the grid data in cube format') - parser.add_argument('hdf5', help="hdf5 file storing the data set",default=None) - parser.add_argument('mol_name',help="name of the molecule in the hdf5",default=None) - parser.add_argument('-out',help="name of the directory where to output the files",default=None) + parser = argparse.ArgumentParser( 
+ description='export the grid data in cube format') + parser.add_argument( + 'hdf5', + help="hdf5 file storing the data set", + default=None) + parser.add_argument( + 'mol_name', + help="name of the molecule in the hdf5", + default=None) + parser.add_argument( + '-out', + help="name of the directory where to output the files", + default=None) args = parser.parse_args() # shortcut @@ -180,4 +197,4 @@ def write_molspec_vmd(f,name,rep,color): out = args.out # lauch the tool - visualize3Ddata(hdf5=hdf5,mol_name=mol_name,out=out) + visualize3Ddata(hdf5=hdf5, mol_name=mol_name, out=out) diff --git a/docs/conf.py b/docs/conf.py index da55b21a..ea80390e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,11 +25,24 @@ class Mock(MagicMock): @classmethod def __getattr__(cls, name): - return MagicMock() - -MOCK_MODULES = ['numpy', 'scipy','h5py','scipy.signal','torch','torch.utils', - 'torch.utils.data', 'matplotlib','matplotlib.pyplot','torch.autograd','torch.nn', - 'torch.optim','torch.cuda','tqdm'] + return MagicMock() + + +MOCK_MODULES = [ + 'numpy', + 'scipy', + 'h5py', + 'scipy.signal', + 'torch', + 'torch.utils', + 'torch.utils.data', + 'matplotlib', + 'matplotlib.pyplot', + 'torch.autograd', + 'torch.nn', + 'torch.optim', + 'torch.cuda', + 'tqdm'] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) sys.path.insert(0, os.path.abspath('../')) @@ -196,12 +209,11 @@ def __getattr__(cls, name): ] - # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { 'python': ('https://docs.python.org/', None), 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'pytorch' :('http://pytorch.org/docs/0.3.1/',None), + 'pytorch': ('http://pytorch.org/docs/0.3.1/', None), } @@ -214,4 +226,4 @@ def __getattr__(cls, name): # return skip # def setup(app): -# app.connect("autodoc-skip-member", skip) \ No newline at end of file +# app.connect("autodoc-skip-member", skip) diff --git a/example/generate_dataset.py b/example/generate_dataset.py index 14dc3ef7..a9949026 100644 --- a/example/generate_dataset.py +++ b/example/generate_dataset.py @@ -1,19 +1,18 @@ from deeprank.generate import * from mpi4py import MPI - comm = MPI.COMM_WORLD # name of the hdf5 to generate h5file = './hdf5/1ak4.hdf5' # for each hdf5 file where to find the pdbs -pdb_source = '../test/1AK4/decoys/' +pdb_source = '../test/1AK4/decoys/' # where to find the native conformations # pdb_native is only used to calculate i-RMSD, dockQ and so on. 
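The docs/conf.py hunk above keeps the established MagicMock trick for heavy optional dependencies; a stripped-down version of that pattern (assuming MagicMock comes from unittest.mock, with the module list shortened here for illustration) is:

import sys
from unittest.mock import MagicMock

class Mock(MagicMock):
    @classmethod
    def __getattr__(cls, name):
        return MagicMock()

MOCK_MODULES = ['torch', 'h5py']             # illustrative subset of the real list
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

import torch                                 # now resolves to the mock, so autodoc can import deeprank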
-#The native pdb files will not be saved in the hdf5 file -pdb_native = '../test/1AK4/native/' +# The native pdb files will not be saved in the hdf5 file +pdb_native = '../test/1AK4/native/' # where to find the pssm @@ -21,17 +20,22 @@ # initialize the database -database = DataGenerator(pdb_source=pdb_source, - pdb_native=pdb_native, - pssm_source=pssm_source, - data_augmentation = 0, - compute_targets = ['deeprank.targets.dockQ','deeprank.targets.binary_class'], - compute_features = ['deeprank.features.AtomicFeature', - 'deeprank.features.FullPSSM', - 'deeprank.features.PSSM_IC', - 'deeprank.features.BSA', - 'deeprank.features.ResidueDensity'], - hdf5=h5file,mpi_comm=comm) +database = DataGenerator( + pdb_source=pdb_source, + pdb_native=pdb_native, + pssm_source=pssm_source, + data_augmentation=0, + compute_targets=[ + 'deeprank.targets.dockQ', + 'deeprank.targets.binary_class'], + compute_features=[ + 'deeprank.features.AtomicFeature', + 'deeprank.features.FullPSSM', + 'deeprank.features.PSSM_IC', + 'deeprank.features.BSA', + 'deeprank.features.ResidueDensity'], + hdf5=h5file, + mpi_comm=comm) # create the database @@ -41,11 +45,11 @@ # define the 3D grid -#grid_info = { +# grid_info = { # 'number_of_points' : [30,30,30], # 'resolution' : [1.,1.,1.], # 'atomic_densities' : {'CA':3.5,'N':3.5,'O':3.5,'C':3.5}, -#} +# } # generate the grid #print('{:25s}'.format('Generate the grid') + database.hdf5) @@ -59,4 +63,3 @@ # print('{:25s}'.format('Normalization') + database.hdf5) # norm = NormalizeData(database.hdf5) # norm.get() - diff --git a/example/learn.py b/example/learn.py index 57d90681..5227de4d 100644 --- a/example/learn.py +++ b/example/learn.py @@ -1,15 +1,15 @@ -import os import glob +import os + import numpy as np from deeprank.learn import * +# -- for classification +from deeprank.learn.model3d import cnn_class as cnn3d -#-- for regression +# -- for regression #from deeprank.learn.model3d import cnn_reg as cnn3d -#-- for classification -from deeprank.learn.model3d import cnn_class as cnn3d - database = './hdf5/*1ak4.hdf5' out = './out' @@ -17,35 +17,43 @@ # clean the output dir out = './out_3d' if os.path.isdir(out): - for f in glob.glob(out+'/*'): - os.remove(f) - os.removedirs(out) + for f in glob.glob(out + '/*'): + os.remove(f) + os.removedirs(out) # declare the dataset instance data_set = DataSet(database, - valid_database = None, - test_database = None, - mapfly=True, - use_rotation=5, - grid_info = {'number_of_points':[10,10,10], 'resolution' : [3,3,3]}, - - select_feature={'AtomicDensities' : {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52}, - 'Features' : ['coulomb','vdwaals','charge','PSSM_*'] }, - - #select_target='DOCKQ', # regression - select_target='BIN_CLASS', # classification - tqdm=True, - normalize_features = False, - normalize_targets=False, - clip_features=False, - pair_chain_feature=np.add, - dict_filter={'DOCKQ':'<1.'}) + valid_database=None, + test_database=None, + mapfly=True, + use_rotation=5, + grid_info={ + 'number_of_points': [ + 10, 10, 10], 'resolution': [ + 3, 3, 3]}, + + select_feature={'AtomicDensities': {'CA': 1.7, 'C': 1.7, 'N': 1.55, 'O': 1.52}, + 'Features': ['coulomb', 'vdwaals', 'charge', 'PSSM_*']}, + + # select_target='DOCKQ', # regression + select_target='BIN_CLASS', # classification + tqdm=True, + normalize_features=False, + normalize_targets=False, + clip_features=False, + pair_chain_feature=np.add, + dict_filter={'DOCKQ': '<1.'}) # create the network -model = NeuralNet(data_set,cnn3d,model_type='3d',task='class', - 
cuda=False,plot=True,outdir=out) +model = NeuralNet(data_set, cnn3d, model_type='3d', task='class', + cuda=False, plot=True, outdir=out) # start the training -model.train(nepoch = 3, divide_trainset = None, train_batch_size = 5, num_workers=0, save_model='all') +model.train( + nepoch=3, + divide_trainset=None, + train_batch_size=5, + num_workers=0, + save_model='all') diff --git a/example/learn_batch.py b/example/learn_batch.py index 4a442296..8b8b1c27 100644 --- a/example/learn_batch.py +++ b/example/learn_batch.py @@ -1,10 +1,12 @@ -import os import glob +import os +import re +import sys +from math import * + import numpy as np + from deeprank.learn import * -from math import * -import sys -import re from deeprank.learn.model3d import cnn_class as cnn3d from torch import optim @@ -15,24 +17,35 @@ # os.environ["CUDA_VISIBLE_DEVICES"] = "0" -def divide_data(hdf5_DIR, caseID_FL, portion=[0.8,0.1,0.1], random = True, write_to_file = True): + +def divide_data( + hdf5_DIR, + caseID_FL, + portion=[ + 0.8, + 0.1, + 0.1], + random=True, + write_to_file=True): # INPUT: the dir that stores all hdf5 data (training, validation, and test) # OUPUT: randomly divide them into train, validation, and test at the caseID-level. Return the filenames. - # write_to_file: True then write the files of trainSet.txt, valiatonSet.txt and testSet.txt + # write_to_file: True then write the files of trainSet.txt, + # valiatonSet.txt and testSet.txt if sum(portion) > 1: - sys.exit("Error: The sum of portions for train/validatoin/test is larger than 1!") + sys.exit( + "Error: The sum of portions for train/validatoin/test is larger than 1!") if len(portion) != 3: sys.exit("Error: the length of portions has to be 3.") - caseIDs = np.array(read_listFL(caseID_FL)) - train_caseIDs, valid_caseIDs, test_caseIDs = random_split(caseIDs, portion, random = random) + train_caseIDs, valid_caseIDs, test_caseIDs = random_split( + caseIDs, portion, random=random) - print (f"\nnum of training cases: {len(train_caseIDs)}") - print (f"num of validation cases: {len(valid_caseIDs)}") - print (f"num of test cases: {len(test_caseIDs)}\n") + print(f"\nnum of training cases: {len(train_caseIDs)}") + print(f"num of validation cases: {len(valid_caseIDs)}") + print(f"num of test cases: {len(test_caseIDs)}\n") train_database = get_hdf5FLs(train_caseIDs, hdf5_DIR) valid_database = get_hdf5FLs(valid_caseIDs, hdf5_DIR) @@ -40,9 +53,14 @@ def divide_data(hdf5_DIR, caseID_FL, portion=[0.8,0.1,0.1], random = True, write if write_to_file is True: outDIR = os.getcwd() - write_train_valid_testFLs (train_database, valid_database, test_database, outDIR) + write_train_valid_testFLs( + train_database, + valid_database, + test_database, + outDIR) return train_database, valid_database, test_database + def get_hdf5FLs(caseIDs, hdf5_DIR): hdf5_FLs = [] @@ -51,18 +69,19 @@ def get_hdf5FLs(caseIDs, hdf5_DIR): return hdf5_FLs + def read_listFL(listFL): - f = open(listFL,'r') + f = open(listFL, 'r') caseIDs = f.readlines() f.close() - caseIDs = [ x.strip() for x in caseIDs if not re.search('^#', x)] - print (f"{len(caseIDs)} cases read from {listFL}") + caseIDs = [x.strip() for x in caseIDs if not re.search('^#', x)] + print(f"{len(caseIDs)} cases read from {listFL}") return caseIDs -def random_split(array, portion, random = True): +def random_split(array, portion, random=True): # array: np.array. 
Can be a list of caseIDs or a list of hdf5 file names if random is True: @@ -77,13 +96,17 @@ def random_split(array, portion, random = True): n_test = floor(n_cases * portion[2]) train = array[:n_train] - valid = array[n_train:n_train+n_valid] - test = array[n_train + n_valid: n_train + n_valid + n_test] + valid = array[n_train:n_train + n_valid] + test = array[n_train + n_valid: n_train + n_valid + n_test] return train, valid, test -def write_train_valid_testFLs (train_database, valid_database, test_database, outDIR): +def write_train_valid_testFLs( + train_database, + valid_database, + test_database, + outDIR): trainID_FL = f"{outDIR}/trainIDs.txt" validID_FL = f"{outDIR}/validIDs.txt" testID_FL = f"{outDIR}/testIDs.txt" @@ -94,7 +117,7 @@ def write_train_valid_testFLs (train_database, valid_database, test_database, ou for outFL, database in zip(outFLs, databases): if database is not True: - np.savetxt(outFL, database, delimiter = "\n", fmt = "%s") + np.savetxt(outFL, database, delimiter="\n", fmt="%s") print(f"{outFL} generated.") @@ -103,50 +126,57 @@ def main(): out = './out' hdf5_DIR = './hdf5' caseID_FL = 'caseIDs.txt' - train_database, valid_database, test_database = \ - divide_data(hdf5_DIR = hdf5_DIR,caseID_FL = caseID_FL, portion = [0.2,0.1,0.1], random = False) + train_database, valid_database, test_database = divide_data( + hdf5_DIR=hdf5_DIR, caseID_FL=caseID_FL, portion=[0.2, 0.1, 0.1], random=False) # clean the output dir out = './out_3d' if os.path.isdir(out): - for f in glob.glob(out+'/*'): + for f in glob.glob(out + '/*'): os.remove(f) os.removedirs(out) - - # declare the dataset instance - data_set = DataSet(train_database = train_database, - valid_database = valid_database, - test_database = test_database, - mapfly=True, - use_rotation=0, - grid_info = {'number_of_points':[6,6,6], 'resolution' : [5,5,5]}, - - # select_feature={'AtomicDensities' : {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52}, - # 'Features' : ['coulomb','vdwaals','charge','PSSM_*'] }, - # select_feature = 'all', - select_feature = {'Features':['PSSM_*']}, - select_target='BIN_CLASS', - tqdm=True, - normalize_features = False, - normalize_targets=False, - clip_features=False, - pair_chain_feature=np.add, - dict_filter={'DOCKQ':'>0.01', 'IRMSD':'<=4 or >10'}) + data_set = DataSet(train_database=train_database, + valid_database=valid_database, + test_database=test_database, + mapfly=True, + use_rotation=0, + grid_info={ + 'number_of_points': [ + 6, 6, 6], 'resolution': [ + 5, 5, 5]}, + + # select_feature={'AtomicDensities' : {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52}, + # 'Features' : ['coulomb','vdwaals','charge','PSSM_*'] }, + # select_feature = 'all', + select_feature={'Features': ['PSSM_*']}, + select_target='BIN_CLASS', + tqdm=True, + normalize_features=False, + normalize_targets=False, + clip_features=False, + pair_chain_feature=np.add, + dict_filter={'DOCKQ': '>0.01', 'IRMSD': '<=4 or >10'}) # create the network - model = NeuralNet(data_set,cnn3d,model_type='3d',task='class', - cuda=False,plot=True,outdir=out) + model = NeuralNet(data_set, cnn3d, model_type='3d', task='class', + cuda=False, plot=True, outdir=out) #model = NeuralNet(data_set, model3d.cnn,cuda=True,ngpu=1,plot=False, task='class') # change the optimizer (optional) model.optimizer = optim.SGD(model.net.parameters(), - lr=0.0001,momentum=0.9,weight_decay=0.00001) + lr=0.0001, momentum=0.9, weight_decay=0.00001) # start the training - model.train(nepoch = 2, divide_trainset = None, train_batch_size = 50, num_workers=8, save_model='all') + 
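divide_data above splits the data at the case level; a self-contained sketch of the underlying split (the shuffle is an assumption, since the elided lines of random_split are not shown in this hunk) could look like:

import numpy as np
from math import floor

def random_split(array, portion, random=True):
    # array: caseIDs or hdf5 file names; portion: train/valid/test fractions
    array = np.asarray(array)
    if random:
        np.random.shuffle(array)             # assumed behaviour of the elided lines
    n = len(array)
    n_train = floor(n * portion[0])
    n_valid = floor(n * portion[1])
    n_test = floor(n * portion[2])
    train = array[:n_train]
    valid = array[n_train:n_train + n_valid]
    test = array[n_train + n_valid:n_train + n_valid + n_test]
    return train, valid, test

train, valid, test = random_split([f'case{i:03d}' for i in range(10)], [0.6, 0.2, 0.2])
print(len(train), len(valid), len(test))     # 6 2 2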
model.train( + nepoch=2, + divide_trainset=None, + train_batch_size=50, + num_workers=8, + save_model='all') + if __name__ == '__main__': main() diff --git a/example/learn_batch_new.py b/example/learn_batch_new.py index f8ba609b..05907f94 100644 --- a/example/learn_batch_new.py +++ b/example/learn_batch_new.py @@ -1,13 +1,15 @@ -import os import glob +import os +import pdb +import re +import sys +from math import * + import numpy as np + from deeprank.learn import * -from math import * -import sys -import re from model3d import cnn_class as cnn3d from torch import optim -import pdb """ An example to do cross-validation 3d_cnn at the case level @@ -16,39 +18,55 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0" -def divide_data(hdf5_DIR, caseID_FL, portion=[0.8,0.1,0.1], random =True, write_to_file = True): + +def divide_data( + hdf5_DIR, + caseID_FL, + portion=[ + 0.8, + 0.1, + 0.1], + random=True, + write_to_file=True): # INPUT: the dir that stores all hdf5 data (training, validation, and test) # OUPUT: randomly divide them into train, validation, and test at the caseID-level. Return the filenames. - # write_to_file: True then write the files of trainSet.txt, valiatonSet.txt and testSet.txt + # write_to_file: True then write the files of trainSet.txt, + # valiatonSet.txt and testSet.txt if sum(portion) > 1: - sys.exit("Error: The sum of portions for train/validatoin/test is larger than 1!") + sys.exit( + "Error: The sum of portions for train/validatoin/test is larger than 1!") if len(portion) != 3: sys.exit("Error: the length of portions has to be 3.") - caseIDs = np.array(read_listFL(caseID_FL)) - train_caseIDs, valid_caseIDs, test_caseIDs = random_split(caseIDs, portion, random = random) + train_caseIDs, valid_caseIDs, test_caseIDs = random_split( + caseIDs, portion, random=random) - print (f"\nnum of training cases: {len(train_caseIDs)}") - print (f"num of validation cases: {len(valid_caseIDs)}") - print (f"num of test cases: {len(test_caseIDs)}\n") + print(f"\nnum of training cases: {len(train_caseIDs)}") + print(f"num of validation cases: {len(valid_caseIDs)}") + print(f"num of test cases: {len(test_caseIDs)}\n") train_database = get_hdf5FLs(train_caseIDs, hdf5_DIR) valid_database = get_hdf5FLs(valid_caseIDs, hdf5_DIR) test_database = get_hdf5FLs(test_caseIDs, hdf5_DIR) - print (f"\nnum of training hdf5 files: {len(train_database)}") - print (f"num of validation hdf5 files: {len(valid_database)}") - print (f"num of test hdf5 files: {len(test_database)}\n") + print(f"\nnum of training hdf5 files: {len(train_database)}") + print(f"num of validation hdf5 files: {len(valid_database)}") + print(f"num of test hdf5 files: {len(test_database)}\n") if write_to_file is True: #outDIR = hdf5_DIR outDIR = os.getcwd() - write_train_valid_testFLs (train_database, valid_database, test_database, outDIR) + write_train_valid_testFLs( + train_database, + valid_database, + test_database, + outDIR) return train_database, valid_database, test_database + def get_hdf5FLs(caseIDs, hdf5_DIR): hdf5_FLs = [] @@ -57,19 +75,23 @@ def get_hdf5FLs(caseIDs, hdf5_DIR): return hdf5_FLs + def read_listFL(listFL): - f = open(listFL,'r') + f = open(listFL, 'r') caseIDs = f.readlines() f.close() - caseIDs = [ x.strip() for x in caseIDs if not re.search('^#', x) and not re.search('^\s*$',x) ] + caseIDs = [ + x.strip() for x in caseIDs if not re.search( + '^#', x) and not re.search( + r'^\s*$', x)] - print (f"{len(caseIDs)} cases read from {listFL}") + print(f"{len(caseIDs)} cases read from {listFL}") return caseIDs -def 
random_split(array, portion, random = True): +def random_split(array, portion, random=True): # array: np.array. Can be a list of caseIDs or a list of hdf5 file names if random is False: @@ -86,13 +108,17 @@ def random_split(array, portion, random = True): n_test = floor(n_cases * portion[2]) train = array[:n_train] - valid = array[n_train:n_train+n_valid] - test = array[n_train + n_valid: n_train + n_valid + n_test] + valid = array[n_train:n_train + n_valid] + test = array[n_train + n_valid: n_train + n_valid + n_test] return train, valid, test -def write_train_valid_testFLs (train_database, valid_database, test_database, outDIR): +def write_train_valid_testFLs( + train_database, + valid_database, + test_database, + outDIR): trainID_FL = f"{outDIR}/trainIDs.txt" validID_FL = f"{outDIR}/validIDs.txt" testID_FL = f"{outDIR}/testIDs.txt" @@ -103,72 +129,71 @@ def write_train_valid_testFLs (train_database, valid_database, test_database, ou for outFL, database in zip(outFLs, databases): if database is not True: - np.savetxt(outFL, database, delimiter = "\n", fmt = "%s") + np.savetxt(outFL, database, delimiter="\n", fmt="%s") print(f"{outFL} generated.") def main(): - hdf5_DIR = '/projects/0/deeprank/BM5/hdf5' # stores all *.hdf5 files + hdf5_DIR = '/projects/0/deeprank/BM5/hdf5' # stores all *.hdf5 files caseID_FL = '/projects/0/deeprank/BM5/caseID_dimers.lst' # hdf5_DIR = '/projects/0/deeprank/BM5/hdf5' # caseID_FL = '/projects/0/deeprank/BM5/caseID_dimers.lst' - train_database, valid_database, test_database = \ - divide_data(hdf5_DIR = hdf5_DIR,caseID_FL = caseID_FL, portion = [0.6,0.1,0.1], random = False) + train_database, valid_database, test_database = divide_data( + hdf5_DIR=hdf5_DIR, caseID_FL=caseID_FL, portion=[0.6, 0.1, 0.1], random=False) # clean the output dir out = './out' if os.path.isdir(out): - for f in glob.glob(out+'/*'): + for f in glob.glob(out + '/*'): os.remove(f) os.removedirs(out) - - # declare the dataset instance - data_set = DataSet(train_database = train_database, - valid_database = valid_database, - test_database = test_database, - mapfly=False, - use_rotation=0, - grid_info = {'number_of_points':[6, 6, 6], 'resolution' : [5,5,5]}, - - # select_feature={'AtomicDensities' : {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52}, - # 'Features' : ['coulomb','vdwaals','charge','PSSM_*'] }, - #select_feature = 'all', - select_feature = {'Feature_ind':['coulomb']}, - select_target='BIN_CLASS', - tqdm=True, - normalize_features = False, - normalize_targets=False, - clip_features=False, - pair_chain_feature=np.add, - dict_filter={'DOCKQ':'>0.01', 'IRMSD':'<=4 or >10'}) + data_set = DataSet(train_database=train_database, + valid_database=valid_database, + test_database=test_database, + mapfly=False, + use_rotation=0, + grid_info={ + 'number_of_points': [ + 6, 6, 6], 'resolution': [ + 5, 5, 5]}, + + # select_feature={'AtomicDensities' : {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52}, + # 'Features' : ['coulomb','vdwaals','charge','PSSM_*'] }, + #select_feature = 'all', + select_feature={'Feature_ind': ['coulomb']}, + select_target='BIN_CLASS', + tqdm=True, + normalize_features=False, + normalize_targets=False, + clip_features=False, + pair_chain_feature=np.add, + dict_filter={'DOCKQ': '>0.01', 'IRMSD': '<=4 or >10'}) # create the networkt model = NeuralNet(data_set=data_set, - model=cnn3d, - model_type='3d', - task='class', - pretrained_model=None, - cuda=True, - ngpu=1, - plot=True, - save_hitrate=True, - save_classmetrics=True, - outdir=out) - - + model=cnn3d, + model_type='3d', + 
task='class', + pretrained_model=None, + cuda=True, + ngpu=1, + plot=True, + save_hitrate=True, + save_classmetrics=True, + outdir=out) # change the optimizer (optional) model.optimizer = optim.SGD(model.net.parameters(), - lr=0.0001,momentum=0.9,weight_decay=0.0001) + lr=0.0001, momentum=0.9, weight_decay=0.0001) # start the training model.train(nepoch=1, - preshuffle = True, - preshuffle_seed = 2019, + preshuffle=True, + preshuffle_seed=2019, divide_trainset=None, train_batch_size=10, num_workers=6, @@ -177,5 +202,6 @@ def main(): hdf5='xue_epoch_data.hdf5' ) + if __name__ == '__main__': main() diff --git a/example/model_250619.py b/example/model_250619.py index b51b98e3..68be188e 100644 --- a/example/model_250619.py +++ b/example/model_250619.py @@ -1,46 +1,48 @@ +import pdb + import torch -from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F -import pdb - +from torch.autograd import Variable class cnn_class(nn.Module): - def __init__(self,input_shape): + def __init__(self, input_shape): # input_shape: (C, W, H, D) - super(cnn_class,self).__init__() + super(cnn_class, self).__init__() self.bn0 = nn.BatchNorm3d(input_shape[0]) - self.conv1 = nn.Conv3d(in_channels = input_shape[0], out_channels = 6,kernel_size=5) + self.conv1 = nn.Conv3d( + in_channels=input_shape[0], + out_channels=6, + kernel_size=5) self.bn1 = nn.BatchNorm3d(6) - self.mp1 = nn.MaxPool3d((3,3,3)) - self.conv2 = nn.Conv3d(in_channels = 6, out_channels = 6,kernel_size=3) + self.mp1 = nn.MaxPool3d((3, 3, 3)) + self.conv2 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=3) self.bn2 = nn.BatchNorm3d(6) - self.mp2 = nn.MaxPool3d((3,3,3)) + self.mp2 = nn.MaxPool3d((3, 3, 3)) size = self._get_conv_outputSize(input_shape) - self.fc_1 = nn.Linear(in_features=size,out_features = 4) + self.fc_1 = nn.Linear(in_features=size, out_features=4) self.bn3 = nn.BatchNorm1d(4) - self.fc_2 = nn.Linear(4,2) - + self.fc_2 = nn.Linear(4, 2) - def _get_conv_outputSize(self,shape): + def _get_conv_outputSize(self, shape): num_data_points = 10 - inp = Variable(torch.rand(num_data_points,*shape)) + inp = Variable(torch.rand(num_data_points, *shape)) out = self._forward_features(inp) - return out.data.view(num_data_points,-1).size(1) + return out.data.view(num_data_points, -1).size(1) + + def _forward_features(self, x): + x = F.max_pool3d(F.relu(self.bn1(self.conv1(x))), 3) + x = F.max_pool3d(F.relu(self.bn2(self.conv2(x))), 3) - def _forward_features(self,x): - x = F.max_pool3d(F.relu(self.bn1(self.conv1(x))),3) - x = F.max_pool3d(F.relu(self.bn2(self.conv2(x))),3) - '''' x = F.relu(self.conv3(x)) x = self.bn3(x) - + x = F.relu(self.conv4(x)) x = self.pool4(x) x = self.bn4(x) @@ -51,9 +53,9 @@ def _forward_features(self,x): ''' return x - def forward(self,x): + def forward(self, x): x = self._forward_features(x) - x = x.view(x.size(0),-1) + x = x.view(x.size(0), -1) x = F.relu(self.fc_1(x)) x = self.bn3(x) x = self.fc_2(x) diff --git a/scripts/cleandata.py b/scripts/cleandata.py index df69bcfa..ad033051 100755 --- a/scripts/cleandata.py +++ b/scripts/cleandata.py @@ -1,57 +1,74 @@ #!/usr/bin/env python -import deeprank.generate -import h5py import os -def clean_dataset(fname,feature=True,pdb=True,points=True,grid=False): - - # name of the hdf5 file - f5 = h5py.File(fname,'a') - - # get the folder names - mol_names = f5.keys() - - for name in mol_names: - - mol_grp = f5[name] - - if feature and 'features' in mol_grp: - del mol_grp['features'] - if pdb and 'complex' in mol_grp and 'native' in mol_grp: - del 
mol_grp['complex'] - del mol_grp['native'] - if points and 'grid_points' in mol_grp: - del mol_grp['grid_points'] - if grid and 'mapped_features' in mol_grp: - del mol_grp['mapped_features'] - - f5.close() - - os.system('h5repack %s _tmp.h5py' %fname) - os.system('mv _tmp.h5py %s' %fname) +import h5py -if __name__ == '__main__': +import deeprank.generate + + +def clean_dataset(fname, feature=True, pdb=True, points=True, grid=False): + + # name of the hdf5 file + f5 = h5py.File(fname, 'a') - import argparse - import os + # get the folder names + mol_names = f5.keys() - parser = argparse.ArgumentParser(description='remove data from a hdf5 data set') - parser.add_argument('hdf5', help="hdf5 file storing the data set",default=None) - parser.add_argument('--keep_feature', action='store_true',help="keep the features") - parser.add_argument('--keep_pdb', action='store_true',help="keep the pdbs") - parser.add_argument('--keep_pts',action='store_true',help="keep the coordinates of the grid points") - parser.add_argument('--rm_grid',action='store_true',help='remove the mapped feaures on the grids') - args = parser.parse_args() + for name in mol_names: + + mol_grp = f5[name] + + if feature and 'features' in mol_grp: + del mol_grp['features'] + if pdb and 'complex' in mol_grp and 'native' in mol_grp: + del mol_grp['complex'] + del mol_grp['native'] + if points and 'grid_points' in mol_grp: + del mol_grp['grid_points'] + if grid and 'mapped_features' in mol_grp: + del mol_grp['mapped_features'] + + f5.close() + + os.system('h5repack %s _tmp.h5py' % fname) + os.system('mv _tmp.h5py %s' % fname) + + +if __name__ == '__main__': - clean_dataset(args.hdf5, - feature = not args.keep_feature, - pdb = not args.keep_pdb, - points = not args.keep_pts, - grid = args.rm_grid ) + import argparse + import os - #os.system('h5repack %s _tmp.h5py' %args.hdf5) - #os.system('mv _tmp.h5py %s' %args.hdf5) + parser = argparse.ArgumentParser( + description='remove data from a hdf5 data set') + parser.add_argument( + 'hdf5', + help="hdf5 file storing the data set", + default=None) + parser.add_argument( + '--keep_feature', + action='store_true', + help="keep the features") + parser.add_argument( + '--keep_pdb', + action='store_true', + help="keep the pdbs") + parser.add_argument( + '--keep_pts', + action='store_true', + help="keep the coordinates of the grid points") + parser.add_argument( + '--rm_grid', + action='store_true', + help='remove the mapped feaures on the grids') + args = parser.parse_args() + clean_dataset(args.hdf5, + feature=not args.keep_feature, + pdb=not args.keep_pdb, + points=not args.keep_pts, + grid=args.rm_grid) - + #os.system('h5repack %s _tmp.h5py' %args.hdf5) + #os.system('mv _tmp.h5py %s' %args.hdf5) diff --git a/scripts/launch.py b/scripts/launch.py index 348541d0..c20420de 100755 --- a/scripts/launch.py +++ b/scripts/launch.py @@ -1,10 +1,12 @@ #!/usr/bin/env python -from deeprank.generate import * import os from time import time + import numpy as np + from cleandata import * +from deeprank.generate import * ########################################################################## # @@ -20,34 +22,38 @@ BM4 = '' -def generate(LIST_NAME,clean=False): +def generate(LIST_NAME, clean=False): for NAME in LIST_NAME: print(NAME) # sources to assemble the data base - pdb_source = [BM4 + 'decoys_pdbFLs/'+NAME+'/water/'] - pdb_native = [BM4 + 'BM4_dimers_bound/pdbFLs_ori'] - - #init the data assembler - database = DataGenerator(pdb_source=pdb_source, - pdb_native=pdb_native, - data_augmentation=None, - 
compute_targets = ['deeprank.targets.dockQ','deeprank.targets.binary_class'],
-                                 compute_features = ['deeprank.features.AtomicFeature',
-                                                     'deeprank.features.NaivePSSM',
-                                                     'deeprank.features.FullPSSM',
-                                                     'deeprank.features.PSSM_IC',
-                                                     'deeprank.features.BSA',
-                                                     'deeprank.features.ResidueDensity'],
-                                 hdf5=NAME + '.hdf5',
-                                 )
+        pdb_source = [BM4 + 'decoys_pdbFLs/' + NAME + '/water/']
+        pdb_native = [BM4 + 'BM4_dimers_bound/pdbFLs_ori']
+
+        # init the data assembler
+        database = DataGenerator(
+            pdb_source=pdb_source,
+            pdb_native=pdb_native,
+            data_augmentation=None,
+            compute_targets=[
+                'deeprank.targets.dockQ',
+                'deeprank.targets.binary_class'],
+            compute_features=[
+                'deeprank.features.AtomicFeature',
+                'deeprank.features.NaivePSSM',
+                'deeprank.features.FullPSSM',
+                'deeprank.features.PSSM_IC',
+                'deeprank.features.BSA',
+                'deeprank.features.ResidueDensity'],
+            hdf5=NAME + '.hdf5',
+        )
 
         if not os.path.isfile(database.hdf5):
             t0 = time()
             print('{:25s}'.format('Create new database') + database.hdf5)
             database.create_database()
-            print(' '*25 + '--> Done in %f s.' %(time()-t0))
+            print(' ' * 25 + '--> Done in %f s.' % (time() - t0))
         else:
             print('{:25s}'.format('Use existing database') + database.hdf5)
@@ -68,46 +74,90 @@ def generate(LIST_NAME,clean=False):
             t0 = time()
             print('{:25s}'.format('Clean datafile') + database.hdf5)
             clean_dataset(database.hdf5)
-            print(' '*25 + '--> Done is %f s.' %(time()-t0))
+            print(' ' * 25 + '--> Done in %f s.' % (time() - t0))
+
 
 if __name__ == '__main__':
 
     import argparse
-    parser = argparse.ArgumentParser(description='launch multiple HDF5 calculations')
-    parser.add_argument('-s','--status',action='store_true',help='Only list the directory')
-    parser.add_argument('-d','--device', help="GPU device to use",default='1',type=str)
-    parser.add_argument('-m','--mol',nargs='+',help='name of the molecule to process',default=None,type=str)
-    parser.add_argument('-i','--init',help="index of the first molecule to process",default=0,type=int)
-    parser.add_argument('-f','--final',help="index of the last molecule to process",default=0,type=int)
-    parser.add_argument('--clean',help="Clean the datafiles",action='store_true')
+    parser = argparse.ArgumentParser(
+        description='launch multiple HDF5 calculations')
+    parser.add_argument(
+        '-s',
+        '--status',
+        action='store_true',
+        help='Only list the directory')
+    parser.add_argument(
+        '-d',
+        '--device',
+        help="GPU device to use",
+        default='1',
+        type=str)
+    parser.add_argument(
+        '-m',
+        '--mol',
+        nargs='+',
+        help='name of the molecule to process',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '-i',
+        '--init',
+        help="index of the first molecule to process",
+        default=0,
+        type=int)
+    parser.add_argument(
+        '-f',
+        '--final',
+        help="index of the last molecule to process",
+        default=0,
+        type=int)
+    parser.add_argument(
+        '--clean',
+        help="Clean the datafiles",
+        action='store_true')
     args = parser.parse_args()
 
     # get the names of the directories
-    names = np.sort(os.listdir(BM4+'decoys_pdbFLs/')).tolist()
+    names = np.sort(os.listdir(BM4 + 'decoys_pdbFLs/')).tolist()
 
     # remove some files
     # as stated in the README some complex don't have a water stage
-    remove_file = ['README','2H7V','1F6M','1ZLI','1IBR','1R8S','1Y64']
+    remove_file = ['README', '2H7V', '1F6M', '1ZLI', '1IBR', '1R8S', '1Y64']
     for r in remove_file:
        names.remove(r)
 
     # get the names of thehdf5 already there
     hdf5 = list(filter(lambda x: '.hdf5' in x, os.listdir()))
-    status = [ 'Done' if n+'.hdf5' in hdf5 else '' for n in names ]
-    size = [ "{:5.2f}".format(os.path.getsize(n+'.hdf5')/1E9) if n+'.hdf5' in hdf5 else '' for n in names ]
+    status = ['Done' if n + '.hdf5' in hdf5 else '' for n in names]
+    size = [
+        "{:5.2f}".format(
+            os.path.getsize(
+                n +
+                '.hdf5') /
+            1E9) if n +
+        '.hdf5' in hdf5 else '' for n in names]
 
     # list the dir and their status
     if args.status:
-        print('\n'+'='*50+'\n= Current status of the datase \n'+'='*50)
-        for i,(n,s,w) in enumerate(zip(names,status,size)):
+        print(
+            '\n' +
+            '=' *
+            50 +
+            '\n= Current status of the dataset \n' +
+            '=' *
+            50)
+        for i, (n, s, w) in enumerate(zip(names, status, size)):
             if w == '':
-                print('% 4d: %6s %5s %s' %(i,n,s,w))
+                print('% 4d: %6s %5s %s' % (i, n, s, w))
             else:
-                print('% 4d: %6s %5s %s GB' %(i,n,s,w))
-            print('-'*50)
-            print(': Status --> %4.3f %% done' %(status.count('Done')/len(status)*100))
-            print(': Mem Tot --> %4.3f GB\n' %sum(list(map(lambda x: float(x),filter(lambda x: len(x)>0,size)))))
+                print('% 4d: %6s %5s %s GB' % (i, n, s, w))
+        print('-' * 50)
+        print(
+            ': Status --> %4.3f %% done' % (status.count('Done') / len(status) * 100))
+        print(': Mem Tot --> %4.3f GB\n' % sum(list(map(lambda x: float(x),
+                                                        filter(lambda x: len(x) > 0, size)))))
 
     # compute the data
     else:
@@ -115,13 +165,10 @@ def generate(LIST_NAME,clean=False):
         if args.mol is not None:
             MOL = args.mol
         else:
-            MOL = names[args.init:args.final+1]
+            MOL = names[args.init:args.final + 1]
 
         # set the cuda device
         #os.environ['CUDA_DEVICE'] = args.device
 
         # generate the data
-        generate(MOL,clean=args.clean)
-
-
-
+        generate(MOL, clean=args.clean)
diff --git a/scripts/simple_generate.py b/scripts/simple_generate.py
index 70f6ee58..cfc3ba77 100644
--- a/scripts/simple_generate.py
+++ b/scripts/simple_generate.py
@@ -1,28 +1,32 @@
-from deeprank.generate import *
 import os
 from time import time
 
+from deeprank.generate import *
+
 """Test the data generation process."""
 
 h5file = './1ak4.hdf5'
-pdb_source = './decoys_pdbFLs/1AK4/water/'
-pdb_native = './bound_pdb/'
+pdb_source = './decoys_pdbFLs/1AK4/water/'
+pdb_native = './bound_pdb/'
 
-database = DataGenerator(pdb_source=pdb_src,pdb_native=pdb_native,
-                         data_augmentation = 0,
-                         compute_targets = ['deeprank.targets.dockQ','deeprank.targets.binary_class'],
-                         compute_features = ['deeprank.features.AtomicFeature',
-                                             'deeprank.features.NaivePSSM',
-                                             'deeprank.features.PSSM_IC',
-                                             'deeprank.features.BSA',
-                                             'deeprank.features.FullPSSM',
-                                             'deeprank.features.ResidueDensity'],
-                         hdf5=h5file)
+database = DataGenerator(
+    pdb_source=pdb_source,
+    pdb_native=pdb_native,
+    data_augmentation=0,
+    compute_targets=[
+        'deeprank.targets.dockQ',
+        'deeprank.targets.binary_class'],
+    compute_features=[
+        'deeprank.features.AtomicFeature',
+        'deeprank.features.NaivePSSM',
+        'deeprank.features.PSSM_IC',
+        'deeprank.features.BSA',
+        'deeprank.features.FullPSSM',
+        'deeprank.features.ResidueDensity'],
+    hdf5=h5file)
 
-#create new files
+# create new files
+t0 = time()  # start the timer used in the report below
 print('{:25s}'.format('Create new database') + database.hdf5)
 database.create_database(prog_bar=True)
-print(' '*25 + '--> Done in %f s.' %(time()-t0))
-
+print(' ' * 25 + '--> Done in %f s.' 
% (time() - t0)) diff --git a/test/2OUL/test.py b/test/2OUL/test.py index f9eeeaa4..39428b23 100644 --- a/test/2OUL/test.py +++ b/test/2OUL/test.py @@ -1,8 +1,9 @@ -from deeprank.features import AtomicFeature +import unittest + import numpy as np import pkg_resources -import unittest +from deeprank.features import AtomicFeature # in case you change the ref don't forget to: # - comment the first line (E0=1) @@ -16,18 +17,20 @@ # get the force field included in deeprank # if another FF has been used to compute the ref # change also this path to the correct one -FF = pkg_resources.resource_filename('deeprank.features','') + '/forcefield/' +FF = pkg_resources.resource_filename('deeprank.features', '') + '/forcefield/' # declare the feature calculator instance -atfeat = AtomicFeature(pdb,fix_chainID=True, - param_charge = FF + 'protein-allhdg5-4_new.top', - param_vdw = FF + 'protein-allhdg5-4_new.param', - patch_file = FF + 'patch.top') +atfeat = AtomicFeature(pdb, fix_chainID=True, + param_charge=FF + 'protein-allhdg5-4_new.top', + param_vdw=FF + 'protein-allhdg5-4_new.param', + patch_file=FF + 'patch.top') # assign parameters atfeat.assign_parameters() # only compute the pair interactions here -atfeat.evaluate_pair_interaction(save_interactions=test_name,print_interactions=True) +atfeat.evaluate_pair_interaction( + save_interactions=test_name, + print_interactions=True) # # make sure that the other properties are not crashing @@ -40,6 +43,3 @@ # # close the db # atfeat.sqldb.close() - - - diff --git a/test/test_atomic_features.py b/test/test_atomic_features.py index a79ed88e..7990b2f3 100644 --- a/test/test_atomic_features.py +++ b/test/test_atomic_features.py @@ -1,7 +1,10 @@ -from deeprank.features import AtomicFeature +import unittest + import numpy as np import pkg_resources -import unittest + +from deeprank.features import AtomicFeature + class TestAtomicFeature(unittest.TestCase): """Test StructureSimialrity.""" @@ -20,13 +23,14 @@ def test_atomic_haddock(): # get the force field included in deeprank # if another FF has been used to compute the ref # change also this path to the correct one - FF = pkg_resources.resource_filename('deeprank.features','') + '/forcefield/' + FF = pkg_resources.resource_filename( + 'deeprank.features', '') + '/forcefield/' # declare the feature calculator instance atfeat = AtomicFeature(pdb, - param_charge = FF + 'protein-allhdg5-4_new.top', - param_vdw = FF + 'protein-allhdg5-4_new.param', - patch_file = FF + 'patch.top') + param_charge=FF + 'protein-allhdg5-4_new.top', + param_vdw=FF + 'protein-allhdg5-4_new.param', + patch_file=FF + 'patch.top') # assign parameters atfeat.assign_parameters() @@ -36,55 +40,60 @@ def test_atomic_haddock(): # read the files f = open(REF) ref = f.readlines() - ref = [r for r in ref if not r.startswith('#') and not r.startswith('Total') and len(r.split())>0] + ref = [r for r in ref if not r.startswith( + '#') and not r.startswith('Total') and len(r.split()) > 0] f.close() - f=open(REF) + f = open(REF) ref_tot = f.readlines() ref_tot = [r for r in ref_tot if r.startswith('Total')] f.close() # read the test - f=open(test_name) + f = open(test_name) test = f.readlines() - test = [t for t in test if len(t.split())>0 and not t.startswith('Total')] + test = [ + t for t in test if len( + t.split()) > 0 and not t.startswith('Total')] f.close() - f=open(test_name) + f = open(test_name) test_tot = f.readlines() test_tot = [t for t in test_tot if t.startswith('Total')] f.close() # compare files nint = 0 - for ltest,lref in zip(test,ref): 
+ for ltest, lref in zip(test, ref): ltest = ltest.split() lref = lref.split() - - at_test = ( (ltest[0],ltest[1],ltest[2],ltest[3]),(ltest[4],ltest[5],ltest[6],ltest[7]) ) - at_ref = ( (lref[1] ,lref[0] ,lref[2] ,lref[3]) ,(lref[5] ,lref[4] ,lref[6] ,lref[7]) ) + at_test = ((ltest[0], ltest[1], ltest[2], ltest[3]), + (ltest[4], ltest[5], ltest[6], ltest[7])) + at_ref = ((lref[1], lref[0], lref[2], lref[3]), + (lref[5], lref[4], lref[6], lref[7])) if not at_test == at_ref: raise AssertionError() dtest = np.array(float(ltest[8])) - dref = np.array(float(lref[8])) - if not np.allclose(dtest,dref,rtol = 1E-3,atol=1E-3): + dref = np.array(float(lref[8])) + if not np.allclose(dtest, dref, rtol=1E-3, atol=1E-3): raise AssertionError() val_test = np.array(ltest[9:11]).astype('float64') - val_ref = np.array(lref[9:11]).astype('float64') - if not np.allclose(val_ref,val_test,atol=1E-6): + val_ref = np.array(lref[9:11]).astype('float64') + if not np.allclose(val_ref, val_test, atol=1E-6): raise AssertionError() nint += 1 - Etest= np.array([float(test_tot[0].split()[3]),float(test_tot[1].split()[3])]) - Eref = np.array([float(ref_tot[0].split()[3]),float(ref_tot[1].split()[3])]) - if not np.allclose(Etest,Eref): + Etest = np.array([float(test_tot[0].split()[3]), + float(test_tot[1].split()[3])]) + Eref = np.array([float(ref_tot[0].split()[3]), + float(ref_tot[1].split()[3])]) + if not np.allclose(Etest, Eref): raise AssertionError() - # make sure that the other properties are not crashing atfeat.compute_coulomb_interchain_only(contact_only=True) atfeat.compute_coulomb_interchain_only(contact_only=False) @@ -96,8 +105,6 @@ def test_atomic_haddock(): # close the db atfeat.sqldb.close() - - @staticmethod def test_atomic_zdock(): @@ -112,20 +119,20 @@ def test_atomic_zdock(): # get the force field included in deeprank # if another FF has been used to compute the ref # change also this path to the correct one - FF = pkg_resources.resource_filename('deeprank.features','') + '/forcefield/' + FF = pkg_resources.resource_filename( + 'deeprank.features', '') + '/forcefield/' # declare the feature calculator instance atfeat = AtomicFeature(pdb, - param_charge = FF + 'protein-allhdg5-4_new.top', - param_vdw = FF + 'protein-allhdg5-4_new.param', - patch_file = FF + 'patch.top') + param_charge=FF + 'protein-allhdg5-4_new.top', + param_vdw=FF + 'protein-allhdg5-4_new.param', + patch_file=FF + 'patch.top') # assign parameters atfeat.assign_parameters() # only compute the pair interactions here atfeat.evaluate_pair_interaction(save_interactions=test_name) - # make sure that the other properties are not crashing atfeat.compute_coulomb_interchain_only(contact_only=True) atfeat.compute_coulomb_interchain_only(contact_only=False) @@ -137,6 +144,6 @@ def test_atomic_zdock(): # close the db atfeat.sqldb.close() + if __name__ == '__main__': unittest.main() - diff --git a/test/test_generate.py b/test/test_generate.py index 1d0ffd32..c0736456 100644 --- a/test/test_generate.py +++ b/test/test_generate.py @@ -1,8 +1,9 @@ -import unittest -from deeprank.generate import * import os +import unittest from time import time +from deeprank.generate import * + """ Some requirement of the naming of the files: @@ -11,115 +12,137 @@ 3. 
pssm file name should have this format: 2w83-AB.A.pssm (caseID.chainID.pssm or caseID.chainID.pdb.pssm) """ + class TestGenerateData(unittest.TestCase): """Test the data generation process.""" - h5file = ['./1ak4.hdf5','native.hdf5'] - pdb_source = ['./1AK4/decoys/','./1AK4/native/'] - pdb_native = ['./1AK4/native/'] # pdb_native is only used to calculate i-RMSD, dockQ and so on. The native pdb files will not be saved in the hdf5 file + h5file = ['./1ak4.hdf5', 'native.hdf5'] + pdb_source = ['./1AK4/decoys/', './1AK4/native/'] + # pdb_native is only used to calculate i-RMSD, dockQ and so on. The native + # pdb files will not be saved in the hdf5 file + pdb_native = ['./1AK4/native/'] def test_1_generate(self): """Generate the database.""" # clean old files - files = ['1ak4.hdf5','1ak4_norm.pckl','native.hdf5','native_norm.pckl'] + files = [ + '1ak4.hdf5', + '1ak4_norm.pckl', + 'native.hdf5', + 'native_norm.pckl'] for f in files: if os.path.isfile(f): os.remove(f) - #init the data assembler - for h5,src in zip(self.h5file,self.pdb_source): - - database = DataGenerator(pdb_source=src, - pdb_native=self.pdb_native, - pssm_source='./1AK4/pssm_new/', - data_augmentation = 1, - compute_targets = ['deeprank.targets.dockQ','deeprank.targets.binary_class'], - compute_features = ['deeprank.features.AtomicFeature', - 'deeprank.features.FullPSSM', - 'deeprank.features.PSSM_IC', - 'deeprank.features.BSA', - 'deeprank.features.ResidueDensity'], - hdf5=h5) - - #create new files + # init the data assembler + for h5, src in zip(self.h5file, self.pdb_source): + + database = DataGenerator( + pdb_source=src, + pdb_native=self.pdb_native, + pssm_source='./1AK4/pssm_new/', + data_augmentation=1, + compute_targets=[ + 'deeprank.targets.dockQ', + 'deeprank.targets.binary_class'], + compute_features=[ + 'deeprank.features.AtomicFeature', + 'deeprank.features.FullPSSM', + 'deeprank.features.PSSM_IC', + 'deeprank.features.BSA', + 'deeprank.features.ResidueDensity'], + hdf5=h5) + + # create new files if not os.path.isfile(database.hdf5): t0 = time() print('{:25s}'.format('Create new database') + database.hdf5) database.create_database(prog_bar=True) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) else: print('{:25s}'.format('Use existing database') + database.hdf5) # map the features grid_info = { - 'number_of_points' : [30,30,30], - 'resolution' : [1.,1.,1.], - 'atomic_densities' : {'CA':3.5,'N':3.5,'O':3.5,'C':3.5}, + 'number_of_points': [30, 30, 30], + 'resolution': [1., 1., 1.], + 'atomic_densities': {'CA': 3.5, 'N': 3.5, 'O': 3.5, 'C': 3.5}, } t0 = time() print('{:25s}'.format('Map features in database') + database.hdf5) - database.map_features(grid_info,try_sparse=True, time=False, prog_bar=True) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + database.map_features( + grid_info, + try_sparse=True, + time=False, + prog_bar=True) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) # get the normalization t0 = time() print('{:25s}'.format('Normalization') + database.hdf5) norm = NormalizeData(h5) norm.get() - print(' '*25 + '--> Done in %f s.' %(time()-t0)) - + print(' ' * 25 + '--> Done in %f s.' 
% (time() - t0)) def test_2_add_target(self): """Add a target (e.g., class labels) to the database.""" for h5 in self.h5file: - #init the data assembler - database = DataGenerator(compute_targets = ['deeprank.targets.binary_class'], - hdf5=h5) + # init the data assembler + database = DataGenerator( + compute_targets=['deeprank.targets.binary_class'], hdf5=h5) t0 = time() - print('{:25s}'.format('Add new target in database') + database.hdf5) + print( + '{:25s}'.format('Add new target in database') + + database.hdf5) database.add_target(prog_bar=True) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) - + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) def test_3_add_unique_target(self): """"Add a unique target to all the confs.""" for h5 in self.h5file: database = DataGenerator(hdf5=h5) - database.add_unique_target({'XX':1.0}) + database.add_unique_target({'XX': 1.0}) def test_4_add_feature(self): """Add a feature to the database.""" for h5 in self.h5file: - #init the data assembler - database = DataGenerator(pdb_source=None, pdb_native=None, data_augmentation=None, - pssm_source='./1AK4/pssm_new/', - compute_features = ['deeprank.features.FullPSSM'], hdf5=h5) + # init the data assembler + database = DataGenerator( + pdb_source=None, + pdb_native=None, + data_augmentation=None, + pssm_source='./1AK4/pssm_new/', + compute_features=['deeprank.features.FullPSSM'], + hdf5=h5) t0 = time() - print('{:25s}'.format('Add new feature in database') + database.hdf5) + print( + '{:25s}'.format('Add new feature in database') + + database.hdf5) database.add_feature(prog_bar=True) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) t0 = time() - print('{:25s}'.format('Map new feature in database') + database.hdf5) - database.map_features(try_sparse=True,time=False,prog_bar=True) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + print( + '{:25s}'.format('Map new feature in database') + + database.hdf5) + database.map_features(try_sparse=True, time=False, prog_bar=True) + print(' ' * 25 + '--> Done in %f s.' % (time() - t0)) # get the normalization t0 = time() print('{:25s}'.format('Normalization') + database.hdf5) norm = NormalizeData(h5) norm.get() - print(' '*25 + '--> Done in %f s.' %(time()-t0)) + print(' ' * 25 + '--> Done in %f s.' 
% (time() - t0)) + if __name__ == "__main__": unittest.main() - - diff --git a/test/test_generate_cuda.py b/test/test_generate_cuda.py index 73c1dac3..438d55f3 100644 --- a/test/test_generate_cuda.py +++ b/test/test_generate_cuda.py @@ -1,72 +1,80 @@ -import unittest -import sys import os -from deeprank.generate import * +import sys +import unittest from time import time +from deeprank.generate import * + try: import pycuda skip = False -except: +except BaseException: skip = True + class TestGenerateCUDA(unittest.TestCase): tune = False test = False - gpu_block = [8,8,8] + gpu_block = [8, 8, 8] h5file = '1ak4_cuda.hdf5' # sources to assemble the data base - pdb_source = ['./1AK4/decoys/'] - pdb_native = ['./1AK4/native/'] + pdb_source = ['./1AK4/decoys/'] + pdb_native = ['./1AK4/native/'] - @unittest.skipIf(skip,"torch fails on Travis") + @unittest.skipIf(skip, "torch fails on Travis") @staticmethod def test_generate_cuda(): - #init the data assembler - database = DataGenerator(pdb_source=self.pdb_source,pdb_native=self.pdb_native, - compute_targets = ['deeprank.targets.dockQ'], - compute_features = ['deeprank.features.AtomicFeature', - 'deeprank.features.NaivePSSM', - 'deeprank.features.PSSM_IC', - 'deeprank.features.BSA'], - hdf5=self.h5file) + # init the data assembler + database = DataGenerator( + pdb_source=self.pdb_source, + pdb_native=self.pdb_native, + compute_targets=['deeprank.targets.dockQ'], + compute_features=[ + 'deeprank.features.AtomicFeature', + 'deeprank.features.NaivePSSM', + 'deeprank.features.PSSM_IC', + 'deeprank.features.BSA'], + hdf5=self.h5file) # map the features grid_info = { - 'number_of_points' : [30,30,30], - 'resolution' : [1.,1.,1.], - 'atomic_densities' : {'CA':3.5,'N':3.5,'O':3.5,'C':3.5}, + 'number_of_points': [30, 30, 30], + 'resolution': [1., 1., 1.], + 'atomic_densities': {'CA': 3.5, 'N': 3.5, 'O': 3.5, 'C': 3.5}, } # tune the kernel if self.tune: - database.tune_cuda_kernel(grid_info,func='gaussian') + database.tune_cuda_kernel(grid_info, func='gaussian') # test thekernel elif self.test: - database.test_cuda(grid_info,self.gpu_block,func='gaussian') + database.test_cuda(grid_info, self.gpu_block, func='gaussian') # compute features else: - #create new files + # create new files if not os.path.isfile(database.hdf5): t0 = time() - print('\nCreate new database : %s' %database.hdf5) + print('\nCreate new database : %s' % database.hdf5) database.create_database() - print('--> Done in %f s.' %(time()-t0)) + print('--> Done in %f s.' % (time() - t0)) else: - print('\nUse existing database : %s' %database.hdf5) + print('\nUse existing database : %s' % database.hdf5) # map these data t0 = time() - print('\nMap features in database : %s' %database.hdf5) - database.map_features(grid_info,try_sparse=True,time=False,cuda=True,gpu_block=self.gpu_block) - print(' '*25 + '--> Done in %f s.' %(time()-t0)) - + print('\nMap features in database : %s' % database.hdf5) + database.map_features( + grid_info, + try_sparse=True, + time=False, + cuda=True, + gpu_block=self.gpu_block) + print(' ' * 25 + '--> Done in %f s.' 
% (time() - t0)) if __name__ == "__main__": unittest.main() - diff --git a/test/test_hitrate_successrate.py b/test/test_hitrate_successrate.py index 9ce70bf8..6c36eff8 100644 --- a/test/test_hitrate_successrate.py +++ b/test/test_hitrate_successrate.py @@ -1,7 +1,9 @@ import unittest -from deeprank.utils.cal_hitrate_successrate import evaluate + import pandas as pd +from deeprank.utils.cal_hitrate_successrate import evaluate + """ Some requirement of the naming of the files: diff --git a/test/test_learn.py b/test/test_learn.py index 79981357..4f292c20 100644 --- a/test/test_learn.py +++ b/test/test_learn.py @@ -1,231 +1,282 @@ +import glob import os import unittest -import glob -import numpy as np -try: - from deeprank.learn import * - from deeprank.learn.model3d import cnn_reg as cnn3d - from deeprank.learn.model3d import cnn_class as cnn3d_class - from deeprank.learn.model2d import cnn as cnn2d - skip=False -except: - skip=True +import numpy as np +try: + from deeprank.learn import * + from deeprank.learn.model3d import cnn_reg as cnn3d + from deeprank.learn.model3d import cnn_class as cnn3d_class + from deeprank.learn.model2d import cnn as cnn2d + skip = False +except BaseException: + skip = True # all the import torch fails on TRAVIS # so we can only exectute this test locally class TestLearn(unittest.TestCase): - @unittest.skipIf(skip,"torch fails on Travis") - @staticmethod - def test_learn_3d_reg_mapfly(): - """Use a 3D CNN for regularization.""" - - #adress of the database - database = '1ak4.hdf5' - if not os.path.isfile(database): - raise FileNotFoundError('Database %s not found. Make sure to run test_generate before') - - # clean the output dir - out = './out_3d_fly' - if os.path.isdir(out): - for f in glob.glob(out+'/*'): - os.remove(f) - os.removedirs(out) - - # declare the dataset instance - data_set = DataSet(database, - test_database = None, - mapfly = True, - use_rotation=1, - grid_info={'number_of_points':(10,10,10), 'resolution':(3,3,3)}, - select_feature={'AtomicDensities' : {'CA':1.7, 'C':1.7, 'N':1.55, 'O':1.52}, - 'Features' : ['coulomb','vdwaals','charge','PSSM_*'] }, - select_target='DOCKQ', - tqdm=True, - normalize_features = False, - normalize_targets=False, - clip_features=False, - pair_chain_feature=np.add, - dict_filter={'DOCKQ': '<1'}) - #dict_filter={'IRMSD':'<4. or >10.'}) - - - # create the networkt - model = NeuralNet(data_set,cnn3d,model_type='3d',task='reg', - cuda=False,plot=True,outdir=out) - - # start the training - model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5,num_workers=0) - - - - - @unittest.skipIf(skip,"torch fails on Travis") - @staticmethod - def test_learn_3d_reg(): - """Use a 3D CNN for regularization.""" - - #adress of the database - train_database = '1ak4.hdf5' - if not os.path.isfile(train_database): - raise FileNotFoundError('Database %s not found. 
Make sure to run test_generate before', train_database) - - # clean the output dir - out = './out_3d_reg' - if os.path.isdir(out): - for f in glob.glob(out+'/*'): - os.remove(f) - os.removedirs(out) - - # declare the dataset instance - data_set = DataSet(train_database = train_database, - valid_database = None, - test_database = None, - mapfly = False, - use_rotation=2, - grid_shape=(30,30,30), - select_feature={'AtomicDensities_ind' : 'all', - 'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] }, - select_target='DOCKQ', - tqdm=True, - normalize_features = True, - normalize_targets=True, - clip_features=False, - pair_chain_feature=np.add, - dict_filter={'DOCKQ':'<1.'}) - #dict_filter={'IRMSD':'<4. or >10.'}) - - - # create the networkt - model = NeuralNet(data_set,cnn3d,model_type='3d',task='reg', - cuda=False,plot=True,outdir=out) - - # start the training - model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5,num_workers=0, save_model='all') - - - - @unittest.skipIf(skip,"Torch fails on Travis") - @staticmethod - def test_learn_3d_class(): - """Use a 3D CNN for regularization.""" - - #adress of the database - database = ['1ak4.hdf5','native.hdf5'] - - # clean the output dir - out = './out_3d_class' - if os.path.isdir(out): - for f in glob.glob(out+'/*'): - os.remove(f) - os.removedirs(out) - - # declare the dataset instance - data_set = DataSet(train_database = database, - valid_database = None, - test_database = None, - mapfly = False, - grid_shape=(30,30,30), - - select_feature={'AtomicDensities_ind' : 'all', - 'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] }, - select_target='BIN_CLASS',tqdm=True, - normalize_features = True, - normalize_targets=False, - clip_features=False, - pair_chain_feature=np.add) - - - # create the networkt - model = NeuralNet(data_set,cnn3d_class,model_type='3d',task='class', - cuda=False,plot=True,outdir=out) - - # start the training - model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5,num_workers=0,save_epoch='all') - - - - - - - @unittest.skipIf(skip,"torch fails on Travis") - @staticmethod - def test_learn_2d_reg(): - """Use a 2D CNN for regularization.""" - - #adress of the database - database = '1ak4.hdf5' - - # clean the output dir - out = './out_2d/' - if os.path.isdir(out): - for f in glob.glob(out+'/*'): - os.remove(f) - os.removedirs(out) - - if not os.path.isfile(database): - raise FileNotFoundError('Database %s not found. Make sure to run test_generate before') - - # declare the dataset instance - data_set = DataSet(train_database = database, - valid_database = None, - test_database = None, - mapfly = False, - select_feature={'AtomicDensities_ind' : 'all', - 'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] }, - select_target='DOCKQ',tqdm=True, - normalize_features = True, - normalize_targets=True, - clip_features=False, - pair_chain_feature=np.add, - dict_filter={'IRMSD':'<4. or >10.'}) - - - # create the network - model = NeuralNet(data_set,cnn2d,model_type='2d',task='reg', - cuda=False,plot=True,outdir=out) - - # start the training - model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5,num_workers=0) - - - - - - - - @unittest.skipIf(skip,"torch fails on Travis") - @staticmethod - def test_transfer(): - - #adress of the database - database = '1ak4.hdf5' - - if not os.path.isfile(database): - raise FileNotFoundError('Database %s not found. 
Make sure to run test_generate before') - - # clean the output dir - out = './out_test/' - if os.path.isdir(out): - for f in glob.glob(out+'/*'): - os.remove(f) - os.removedirs(out) - - # create the network - model_name = './out_3d_fly/last_model.pth.tar' - model = NeuralNet(database,cnn3d,pretrained_model=model_name,outdir=out) - model.test() - + @unittest.skipIf(skip, "torch fails on Travis") + @staticmethod + def test_learn_3d_reg_mapfly(): + """Use a 3D CNN for regularization.""" + + # adress of the database + database = '1ak4.hdf5' + if not os.path.isfile(database): + raise FileNotFoundError( + 'Database %s not found. Make sure to run test_generate before') + + # clean the output dir + out = './out_3d_fly' + if os.path.isdir(out): + for f in glob.glob(out + '/*'): + os.remove(f) + os.removedirs(out) + + # declare the dataset instance + data_set = DataSet( + database, + test_database=None, + mapfly=True, + use_rotation=1, + grid_info={ + 'number_of_points': ( + 10, + 10, + 10), + 'resolution': ( + 3, + 3, + 3)}, + select_feature={ + 'AtomicDensities': { + 'CA': 1.7, + 'C': 1.7, + 'N': 1.55, + 'O': 1.52}, + 'Features': [ + 'coulomb', + 'vdwaals', + 'charge', + 'PSSM_*']}, + select_target='DOCKQ', + tqdm=True, + normalize_features=False, + normalize_targets=False, + clip_features=False, + pair_chain_feature=np.add, + dict_filter={ + 'DOCKQ': '<1'}) + # dict_filter={'IRMSD':'<4. or >10.'}) + + # create the networkt + model = NeuralNet(data_set, cnn3d, model_type='3d', task='reg', + cuda=False, plot=True, outdir=out) + + # start the training + model.train( + nepoch=5, + divide_trainset=0.8, + train_batch_size=5, + num_workers=0) + + @unittest.skipIf(skip, "torch fails on Travis") + @staticmethod + def test_learn_3d_reg(): + """Use a 3D CNN for regularization.""" + + # adress of the database + train_database = '1ak4.hdf5' + if not os.path.isfile(train_database): + raise FileNotFoundError( + 'Database %s not found. Make sure to run test_generate before', + train_database) + + # clean the output dir + out = './out_3d_reg' + if os.path.isdir(out): + for f in glob.glob(out + '/*'): + os.remove(f) + os.removedirs(out) + + # declare the dataset instance + data_set = DataSet( + train_database=train_database, + valid_database=None, + test_database=None, + mapfly=False, + use_rotation=2, + grid_shape=( + 30, + 30, + 30), + select_feature={ + 'AtomicDensities_ind': 'all', + 'Feature_ind': [ + 'coulomb', + 'vdwaals', + 'charge', + 'PSSM_*']}, + select_target='DOCKQ', + tqdm=True, + normalize_features=True, + normalize_targets=True, + clip_features=False, + pair_chain_feature=np.add, + dict_filter={ + 'DOCKQ': '<1.'}) + # dict_filter={'IRMSD':'<4. 
or >10.'}) + + # create the networkt + model = NeuralNet(data_set, cnn3d, model_type='3d', task='reg', + cuda=False, plot=True, outdir=out) + + # start the training + model.train( + nepoch=5, + divide_trainset=0.8, + train_batch_size=5, + num_workers=0, + save_model='all') + + @unittest.skipIf(skip, "Torch fails on Travis") + @staticmethod + def test_learn_3d_class(): + """Use a 3D CNN for regularization.""" + + # adress of the database + database = ['1ak4.hdf5', 'native.hdf5'] + + # clean the output dir + out = './out_3d_class' + if os.path.isdir(out): + for f in glob.glob(out + '/*'): + os.remove(f) + os.removedirs(out) + + # declare the dataset instance + data_set = DataSet( + train_database=database, + valid_database=None, + test_database=None, + mapfly=False, + grid_shape=( + 30, + 30, + 30), + select_feature={ + 'AtomicDensities_ind': 'all', + 'Feature_ind': [ + 'coulomb', + 'vdwaals', + 'charge', + 'PSSM_*']}, + select_target='BIN_CLASS', + tqdm=True, + normalize_features=True, + normalize_targets=False, + clip_features=False, + pair_chain_feature=np.add) + + # create the networkt + model = NeuralNet(data_set, cnn3d_class, model_type='3d', task='class', + cuda=False, plot=True, outdir=out) + + # start the training + model.train( + nepoch=5, + divide_trainset=0.8, + train_batch_size=5, + num_workers=0, + save_epoch='all') + + @unittest.skipIf(skip, "torch fails on Travis") + @staticmethod + def test_learn_2d_reg(): + """Use a 2D CNN for regularization.""" + + # adress of the database + database = '1ak4.hdf5' + + # clean the output dir + out = './out_2d/' + if os.path.isdir(out): + for f in glob.glob(out + '/*'): + os.remove(f) + os.removedirs(out) + + if not os.path.isfile(database): + raise FileNotFoundError( + 'Database %s not found. Make sure to run test_generate before') + + # declare the dataset instance + data_set = DataSet( + train_database=database, + valid_database=None, + test_database=None, + mapfly=False, + select_feature={ + 'AtomicDensities_ind': 'all', + 'Feature_ind': [ + 'coulomb', + 'vdwaals', + 'charge', + 'PSSM_*']}, + select_target='DOCKQ', + tqdm=True, + normalize_features=True, + normalize_targets=True, + clip_features=False, + pair_chain_feature=np.add, + dict_filter={ + 'IRMSD': '<4. or >10.'}) + + # create the network + model = NeuralNet(data_set, cnn2d, model_type='2d', task='reg', + cuda=False, plot=True, outdir=out) + + # start the training + model.train( + nepoch=5, + divide_trainset=0.8, + train_batch_size=5, + num_workers=0) + + @unittest.skipIf(skip, "torch fails on Travis") + @staticmethod + def test_transfer(): + + # adress of the database + database = '1ak4.hdf5' + + if not os.path.isfile(database): + raise FileNotFoundError( + 'Database %s not found. 
Make sure to run test_generate before') + + # clean the output dir + out = './out_test/' + if os.path.isdir(out): + for f in glob.glob(out + '/*'): + os.remove(f) + os.removedirs(out) + + # create the network + model_name = './out_3d_fly/last_model.pth.tar' + model = NeuralNet( + database, + cnn3d, + pretrained_model=model_name, + outdir=out) + model.test() if __name__ == "__main__": - TestLearn.test_learn_3d_reg_mapfly() - TestLearn.test_learn_3d_reg() - TestLearn.test_learn_3d_class() - TestLearn.test_learn_2d_reg() - TestLearn.test_transfer() - + TestLearn.test_learn_3d_reg_mapfly() + TestLearn.test_learn_3d_reg() + TestLearn.test_learn_3d_class() + TestLearn.test_learn_2d_reg() + TestLearn.test_transfer() diff --git a/test/test_pdb2sql.py b/test/test_pdb2sql.py index 79580763..6937c0d1 100644 --- a/test/test_pdb2sql.py +++ b/test/test_pdb2sql.py @@ -1,22 +1,25 @@ import unittest + import numpy as np + from deeprank.tools import pdb2sql + class TestPDB2SQL(unittest.TestCase): """Test PDB2SQL.""" def test_read(self): """Read a pdb and create a sql db.""" - #db.prettyprint() + # db.prettyprint() self.db.get_colnames() - self.db.exportpdb('chainA.pdb', chainID = 'A') + self.db.exportpdb('chainA.pdb', chainID='A') def test_get(self): """Test get with large number of index.""" index = list(range(1200)) - self.db.get('x,y,z', rowID = index) + self.db.get('x,y,z', rowID=index) @unittest.expectedFailure def test_get_fails(self): @@ -24,7 +27,7 @@ def test_get_fails(self): index_res = list(range(100)) index_atoms = list(range(1200)) - self.db.get('x,y,z', resSeq = index_res, rowID = index_atoms) + self.db.get('x,y,z', resSeq=index_res, rowID=index_atoms) def test_add_column(self): """Add a new column to the db and change its values.""" @@ -34,7 +37,7 @@ def test_add_column(self): n = 100 q = np.random.rand(n) ind = list(range(n)) - self.db.update_column('CHARGE', q, index = ind) + self.db.update_column('CHARGE', q, index=ind) def test_update(self): """Update the database.""" @@ -42,30 +45,30 @@ def test_update(self): n = 200 index = list(range(n)) vals = np.random.rand(n, 3) - self.db.update('x,y,z',vals, rowID = index) + self.db.update('x,y,z', vals, rowID=index) self.db.prettyprint() - self.db.update_xyz(vals, index = index) + self.db.update_xyz(vals, index=index) def test_update_all(self): xyz = self.db.get('x,y,z') - self.db.update('x,y,z',xyz) + self.db.update('x,y,z', xyz) self.db.prettyprint() def test_manip(self): """Manipualte part of the protein.""" vect = np.random.rand(3) - self.db.translation(vect, chainID = 'A') + self.db.translation(vect, chainID='A') axis = np.random.rand(3) angle = np.random.rand() - self.db.rotation_around_axis(axis, angle, chainID = 'B') + self.db.rotation_around_axis(axis, angle, chainID='B') - a,b,c = np.random.rand(3) - self.db.rotation_euler(a,b,c,resName='VAL') + a, b, c = np.random.rand(3) + self.db.rotation_euler(a, b, c, resName='VAL') - mat = np.random.rand(3,3) - self.db.rotation_matrix(mat,chainID='A') + mat = np.random.rand(3, 3) + self.db.rotation_matrix(mat, chainID='A') def setUp(self): mol = './1AK4/decoys/1AK4_cm-it0_745.pdb' @@ -74,5 +77,6 @@ def setUp(self): def tearDown(self): self.db.close() + if __name__ == '__main__': unittest.main() diff --git a/test/test_rmsd.py b/test/test_rmsd.py index 5cf97fb5..ecf9434d 100644 --- a/test/test_rmsd.py +++ b/test/test_rmsd.py @@ -1,7 +1,10 @@ +import os import unittest + import numpy as np + from deeprank.tools import StructureSimilarity -import os + class TestStructureSimilarity(unittest.TestCase): 
"""Test StructureSimialrity.""" @@ -13,21 +16,30 @@ def test_rmsd(): # specify wich data to us MOL = './1AK4/' decoys = MOL + '/decoys/' - ref = MOL + '/native/1AK4.pdb' - data = MOL + '/haddock_data/' + ref = MOL + '/native/1AK4.pdb' + data = MOL + '/haddock_data/' # get the list of decoy names - decoy_list = [decoys+'/'+n for n in list(filter(lambda x: '.pdb' in x, os.listdir(decoys)))] + decoy_list = [ + decoys + + '/' + + n for n in list( + filter( + lambda x: '.pdb' in x, + os.listdir(decoys)))] # reference data used to compare ours haddock_data = {} - haddock_files = [data+'1AK4.Fnat',data+'1AK4.lrmsd',data+'1AK4.irmsd'] + haddock_files = [ + data + '1AK4.Fnat', + data + '1AK4.lrmsd', + data + '1AK4.irmsd'] # extract the data from the haddock files - for i,fname in enumerate(haddock_files): + for i, fname in enumerate(haddock_files): # read the file - f = open(fname,'r') + f = open(fname, 'r') data = f.readlines() data = [d.split() for d in data if not d.startswith('#')] f.close() @@ -39,58 +51,56 @@ def test_rmsd(): haddock_data[mol_name] = np.zeros(3) haddock_data[mol_name][i] = float(line[1]) - # init all the data handlers nconf = len(haddock_data) - deep = np.zeros((nconf,3)) - hdk = np.zeros((nconf,3)) + deep = np.zeros((nconf, 3)) + hdk = np.zeros((nconf, 3)) # compute the data with deeprank deep_data = {} - for i,decoy in enumerate(decoy_list): + for i, decoy in enumerate(decoy_list): - sim = StructureSimilarity(decoy,ref) - lrmsd = sim.compute_lrmsd_fast(method='svd',lzone='1AK4.lzone') - irmsd = sim.compute_irmsd_fast(method='svd',izone='1AK4.izone') + sim = StructureSimilarity(decoy, ref) + lrmsd = sim.compute_lrmsd_fast(method='svd', lzone='1AK4.lzone') + irmsd = sim.compute_irmsd_fast(method='svd', izone='1AK4.izone') fnat = sim.compute_Fnat_fast(ref_pairs='1AK4.refpairs') mol_name = decoy.split('/')[-1].split('.')[0] - deep_data[mol_name] = [fnat,lrmsd,irmsd] - deep[i,:] = deep_data[mol_name] - hdk[i,:] = haddock_data[mol_name] + deep_data[mol_name] = [fnat, lrmsd, irmsd] + deep[i, :] = deep_data[mol_name] + hdk[i, :] = haddock_data[mol_name] # print the deltas - delta = np.max(np.abs(deep-hdk),0) + delta = np.max(np.abs(deep - hdk), 0) # assert the data - if not np.all(delta<[1E-3,1,1E-3]): + if not np.all(delta < [1E-3, 1, 1E-3]): raise AssertionError() - @staticmethod def test_slow(): - """Compute IRMSD/LRMSD from pdb2sql methd to make sure it doesn't crash.""" + """Compute IRMSD/LRMSD from pdb2sql methd to make sure it doesn't + crash.""" # specify wich data to us MOL = './1AK4/' decoy = MOL + '/decoys/1AK4_cm-it0_745.pdb' - ref = MOL + '/native/1AK4.pdb' + ref = MOL + '/native/1AK4.pdb' - sim = StructureSimilarity(decoy,ref) + sim = StructureSimilarity(decoy, ref) trash = sim.compute_lrmsd_pdb2sql(method='svd') trash = sim.compute_irmsd_pdb2sql(method='svd') trash = sim.compute_Fnat_pdb2sql() print(trash) - def setUp(self): """Setup the test by removing old files.""" - files = ['1AK4.lzone','1AK4.izone','1AK4.refpairs'] + files = ['1AK4.lzone', '1AK4.izone', '1AK4.refpairs'] for f in files: if os.path.isfile(f): os.remove(f) + if __name__ == '__main__': unittest.main() - diff --git a/test/test_tools.py b/test/test_tools.py index a99d5da9..17a945ea 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -1,6 +1,6 @@ import unittest -from deeprank.tools import pdb2sql -from deeprank.tools import SASA + +from deeprank.tools import SASA, pdb2sql class TestTools(unittest.TestCase): @@ -33,5 +33,6 @@ def test_sasa(): sasa.get_residue_center() sasa.neighbor_count() + 
 if __name__ == '__main__':
     unittest.main()
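Usage note: the clean_dataset helper reformatted in scripts/cleandata.py above can also be called directly from Python instead of through its argparse interface. A minimal sketch, assuming a local './1ak4.hdf5' such as the one produced by test/test_generate.py, that the scripts/ directory is importable (e.g. on PYTHONPATH), and that the h5repack tool the function shells out to is on the PATH:

    from cleandata import clean_dataset

    # delete the raw features, the pdb copies and the grid points,
    # but keep the mapped feature grids (grid=False)
    clean_dataset('./1ak4.hdf5', feature=True, pdb=True, points=True, grid=False)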