diff --git a/deeprank/learn/DataSet.py b/deeprank/learn/DataSet.py
index 237e3fbd..35c7d6a2 100644
--- a/deeprank/learn/DataSet.py
+++ b/deeprank/learn/DataSet.py
@@ -19,7 +19,7 @@ class DataSet():
 
-    def __init__(self,database, test_database = None,
+    def __init__(self,train_database, valid_database = None, test_database = None,
                  use_rotation = None,
                  select_feature = 'all', select_target = 'DOCKQ',
                  normalize_features = True, normalize_targets = True,
@@ -35,8 +35,8 @@ def __init__(self,database, test_database = None,
            part of DeepRank. To create an instance you must provide quite a
            few arguments. Example:
 
            >>> from deeprank.learn import *
-            >>> database = '1ak4.hdf5'
-            >>> data_set = DataSet(database,
+            >>> train_database = '1ak4.hdf5'
+            >>> data_set = DataSet(train_database, valid_database = None,
            >>>                    test_database = None,
            >>>                    grid_shape=(30,30,30),
            >>>                    select_feature = {
@@ -50,8 +50,10 @@ def __init__(self,database, test_database = None,
            >>>                    dict_filter={'IRMSD':'<4. or >10.'},
            >>>                    process = True)
 
        Args:
-            database (list(str)): names of the hdf5 files used for the training/validation
+            train_database (list(str)): names of the hdf5 files used for the training
                Example : ['1AK4.hdf5','1B7W.hdf5',...]
+            valid_database (list(str)): names of the hdf5 files used for the validation
+                Example : ['1ACB.hdf5','4JHF.hdf5',...]
            test_database (list(str)): names of the hdf5 files used for the test
                Example : ['7CEI.hdf5']
            use_rotation (int): number of rotations to use.
@@ -91,9 +93,13 @@ def __init__(self,database, test_database = None,
        '''
 
        # allow for multiple database
-        self.database = database
-        if not isinstance(database,list):
-            self.database = [database]
+        self.train_database = train_database
+        if not isinstance(train_database,list):
+            self.train_database = [train_database]
+
+        self.valid_database = valid_database
+        if not isinstance(valid_database,list):
+            self.valid_database = [valid_database]
 
        # allow for multiple database
        self.test_database = test_database
@@ -152,7 +158,7 @@ def process_dataset(self):
        print('=\t DeepRank Data Set')
        print('=')
        print('=\t Training data' )
-        for f in self.database:
+        for f in self.train_database:
            print('=\t ->',f)
        print('=')
        if self.test_database is not None:
@@ -165,7 +171,13 @@
 
        # check if the files are ok
-        self.check_hdf5_files()
+        self.train_database = self.check_hdf5_files(self.train_database)
+
+        if self.valid_database is not None:
+            self.valid_database = self.check_hdf5_files(self.valid_database)
+
+        if self.test_database is not None:
+            self.test_database = self.check_hdf5_files(self.test_database)
 
        # create the indexing system
        # alows to associate each mol to an index
 
@@ -195,7 +207,8 @@ def process_dataset(self):
        print('\n')
        print("   Data Set Info")
        print('   Training set       : %d conformations' %self.ntrain)
-        print('   Test set           : %d conformations' %(self.ntot-self.ntrain))
+        print('   Validation set     : %d conformations' %self.nvalid)
+        print('   Test set           : %d conformations' %(self.ntest))
        print('   Number of channels : %d' %self.input_shape[0])
        print('   Grid Size          : %d x %d x %d' %(self.data_shape[1],self.data_shape[2],self.data_shape[3]))
        sys.stdout.flush()
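
Note on the bookkeeping above: the train, validation and test databases are
indexed back to back in index_complexes, so ntrain/nvalid/ntest map to three
contiguous slices. A minimal standalone sketch of that layout (illustrative
numbers, not code from this patch):

    # hypothetical sizes, mirroring the logic of create_index_molecules()
    ntrain, nvalid, ntest = 800, 100, 100
    ntot = ntrain + nvalid + ntest
    index_train = list(range(ntrain))                   # 0 .. 799
    index_valid = list(range(ntrain, ntrain + nvalid))  # 800 .. 899
    index_test  = list(range(ntrain + nvalid, ntot))    # 900 .. 999
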
@@ -241,12 +254,13 @@ def __getitem__(self,index):
 
        return {'mol':[fname,mol],'feature':feature,'target':target}
 
-    def check_hdf5_files(self):
+    @staticmethod
+    def check_hdf5_files(database):
        """Check if the data contained in the hdf5 file is ok."""
 
        print("   Checking dataset Integrity")
        remove_file = []
-        for fname in self.database:
+        for fname in database:
            try:
                f = h5py.File(fname,'r')
                mol_names = list(f.keys())
@@ -259,7 +273,8 @@ def check_hdf5_files(self):
                remove_file.append(fname)
 
        for name in remove_file:
-            self.database.remove(name)
+            database.remove(name)
+        return database
 
 
    def create_index_molecules(self):
@@ -273,10 +288,10 @@ def create_index_molecules(self):
 
        desc = '{:25s}'.format('   Train dataset')
        if self.tqdm:
-            data_tqdm = tqdm(self.database,desc=desc,file=sys.stdout)
+            data_tqdm = tqdm(self.train_database,desc=desc,file=sys.stdout)
        else:
            print('   Train dataset')
-            data_tqdm = self.database
+            data_tqdm = self.train_database
        sys.stdout.flush()
 
        for fdata in data_tqdm:
@@ -287,8 +302,7 @@
                mol_names = list(fh5.keys())
                mol_names = self._select_pdb(mol_names)
                for k in mol_names:
-                    if self.filter(fh5[k]):
-                        self.index_complexes += [(fdata,k)]
+                    self.index_complexes += [(fdata,k)]
                fh5.close()
            except Exception as inst:
                print('\t\t-->Ignore File : ' + fdata)
@@ -297,6 +311,38 @@
        self.ntrain = len(self.index_complexes)
        self.index_train = list(range(self.ntrain))
+        # defaults so the valid/test attributes always exist,
+        # even when no valid/test database is given
+        self.ntot = self.ntrain
+        self.nvalid, self.index_valid = 0, []
+        self.ntest, self.index_test = 0, []
+
+        if self.valid_database is not None:
+
+            desc = '{:25s}'.format('   Validation dataset')
+            if self.tqdm:
+                data_tqdm = tqdm(self.valid_database,desc=desc,file=sys.stdout)
+            else:
+                print('   Validation dataset')
+                data_tqdm = self.valid_database
+            sys.stdout.flush()
+
+            for fdata in data_tqdm:
+                if self.tqdm:
+                    data_tqdm.set_postfix(mol=os.path.basename(fdata))
+                try:
+                    fh5 = h5py.File(fdata,'r')
+                    mol_names = list(fh5.keys())
+                    mol_names = self._select_pdb(mol_names)
+                    self.index_complexes += [(fdata,k) for k in mol_names]
+                    fh5.close()
+                except:
+                    print('\t\t-->Ignore File : '+fdata)
+
+            self.ntot = len(self.index_complexes)
+            self.index_valid = list(range(self.ntrain,self.ntot))
+            self.nvalid = self.ntot - self.ntrain
+
 
        if self.test_database is not None:
 
            desc = '{:25s}'.format('   Test dataset')
@@ -313,14 +359,15 @@
            try:
                fh5 = h5py.File(fdata,'r')
                mol_names = list(fh5.keys())
-                mol_names = selef._select_pdb(mol_names)
+                mol_names = self._select_pdb(mol_names)
                self.index_complexes += [(fdata,k) for k in mol_names]
                fh5.close()
            except:
                print('\t\t-->Ignore File : '+fdata)
 
        self.ntot = len(self.index_complexes)
-        self.index_test = list(range(self.ntrain,self.ntot))
+        self.index_test = list(range(self.ntrain + self.nvalid, self.ntot))
+        self.ntest = self.ntot - self.ntrain - self.nvalid
 
 
    def _select_pdb(self, mol_names):
@@ -391,7 +438,7 @@ def get_feature_name(self):
        '''
 
        # open a h5 file in case we need it
-        f5 = h5py.File(self.database[0],'r')
+        f5 = h5py.File(self.train_database[0],'r')
        mol_name = list(f5.keys())[0]
        mapped_data = f5.get(mol_name + '/mapped_features/')
        chain_tags = ['_chainA','_chainB']
@@ -465,7 +512,7 @@ def get_feature_name(self):
    def print_possible_features(self):
        """Print the possible features in the group."""
 
-        f5 = h5py.File(self.database[0],'r')
+        f5 = h5py.File(self.train_database[0],'r')
        mol_name = list(f5.keys())[0]
        mapgrp = f5.get(mol_name + '/mapped_features/')
@@ -515,7 +562,7 @@ def get_input_shape(self):
            self.input_shape : input size of the CNN (potentially after 2d transformation)
        """
 
-        fname = self.database[0]
+        fname = self.train_database[0]
        feature,_ = self.load_one_molecule(fname)
        self.data_shape = feature.shape
@@ -535,7 +582,7 @@ def get_grid_shape(self):
            ValueError: If no grid shape is provided or is present in the HDF5 file
        '''
 
-        fname = self.database[0]
+        fname = self.train_database[0]
        fh5 = h5py.File(fname,'r')
        mol = list(fh5.keys())[0]
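
Note: check_hdf5_files() is now a staticmethod that takes a file list and
returns it with unreadable entries removed, which is why each of the three
databases above can be validated independently. A minimal usage sketch
(hypothetical file names):

    # corrupt or unreadable files are dropped from the returned list
    train_files = DataSet.check_hdf5_files(['1AVX.hdf5', 'corrupt.hdf5'])
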
@@ -588,7 +635,7 @@ def _read_norm(self):
        """Read or create the normalization file for the complex. """
 
        # loop through all the filename
-        for f5 in self.database:
+        for f5 in self.train_database:
 
            # get the precalculated data
            fdata = os.path.splitext(f5)[0]+'_norm.pckl'
@@ -618,7 +665,7 @@ def _read_norm(self):
                self.param_norm['targets'][self.select_target].update(maxv)
 
        # process the std
-        nfile = len(self.database)
+        nfile = len(self.train_database)
        for feat_types,feat_dict in self.param_norm['features'].items():
            for feat in feat_dict:
                self.param_norm['features'][feat_types][feat].process(nfile)
diff --git a/deeprank/learn/NeuralNet.py b/deeprank/learn/NeuralNet.py
index 29366fb6..190e2d26 100644
--- a/deeprank/learn/NeuralNet.py
+++ b/deeprank/learn/NeuralNet.py
@@ -3,6 +3,8 @@
 import os
 import time
 import h5py
+import matplotlib
+matplotlib.use('agg')
 import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
 import numpy as np
@@ -234,7 +236,7 @@ def __init__(self,data_set,model,
            print(' --> Aborting the experiment \n\n')
            sys.exit()
 
-    def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_batch_size = 10,
+    def train(self,nepoch=50, divide_trainset= None, hdf5='epoch_data.hdf5',train_batch_size = 10,
              preshuffle=True, preshuffle_seed=None, export_intermediate=True,num_workers=1,save_model='best',save_epoch='intermediate'):
        """Perform a simple training of the model. The data set is divided in training/validation sets.
@@ -296,8 +298,15 @@ def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_bat
        self.f5 = h5py.File(fname,'w')
 
        # divide the set in train+ valid and test
-        divide_trainset = divide_trainset or [0.8,0.2]
-        index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)
+        if divide_trainset is not None:
+            # split the training database randomly with the given fractions
+            index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)
+        else:
+            # reuse the train/valid/test split computed by the DataSet
+            index_train = self.data_set.index_train
+            index_valid = self.data_set.index_valid
+            index_test = self.data_set.index_test
+
        print(': %d confs. for training' %len(index_train))
        print(': %d confs. for validation' %len(index_valid))
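
Note: train() now supports two split modes. With a fraction argument it keeps
the old behaviour and splits the training database randomly; with
divide_trainset=None it uses the index lists prepared by the DataSet. A short
sketch of both call styles (hypothetical `model` built as in the examples
below):

    # random split of the training database, as before (old default fractions)
    model.train(nepoch=5, divide_trainset=[0.8, 0.2])

    # split taken from the DataSet (explicit valid/test databases)
    model.train(nepoch=5, divide_trainset=None)
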
diff --git a/example/learn.py b/example/learn.py
index c18b07ad..06dc2dd0 100644
--- a/example/learn.py
+++ b/example/learn.py
@@ -5,15 +5,20 @@
 from deeprank.learn import *
 from deeprank.learn.model3d import cnn as cnn3d
 
-database = '1ak4.hdf5'
-out = './out'
+# address of the databases
+train_database = './data/1AVX.hdf5'
+valid_database = './data/1BVN.hdf5'
+test_database = './data/1DFJ.hdf5'
+
 # make sure the databse is there
-database = '1ak4.hdf5'
-if not os.path.isfile(database):
-    raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')
+
+for database in [train_database, valid_database, test_database]:
+    if not os.path.isfile(database):
+        raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % database)
 
 # clean the output dir
+out = './out_3d'
 if os.path.isdir(out):
     for f in glob.glob(out+'/*'):
         os.remove(f)
@@ -21,8 +26,9 @@
 
 # declare the dataset instance
-data_set = DataSet(database,
-                   test_database = None,
+data_set = DataSet(train_database = train_database,
+                   valid_database = valid_database,
+                   test_database = test_database,
                   grid_shape=(30,30,30),
                   select_feature={'AtomicDensities_ind' : 'all',
                                   'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] },
@@ -39,4 +45,4 @@
                  cuda=False,plot=True,outdir=out)
 
 # start the training
-model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5, num_workers=0, save_model='all')
\ No newline at end of file
+model.train(nepoch = 5, divide_trainset = None, train_batch_size = 5, num_workers=0, save_model='all')
diff --git a/test/test_learn.py b/test/test_learn.py
index 40093699..7fd1a572 100644
--- a/test/test_learn.py
+++ b/test/test_learn.py
@@ -21,9 +21,9 @@ def test_learn_3d_reg():
        """Use a 3D CNN for regularization."""
 
        #adress of the database
-        database = '1ak4.hdf5'
-        if not os.path.isfile(database):
-            raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')
+        train_database = '1ak4.hdf5'
+        if not os.path.isfile(train_database):
+            raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % train_database)
 
        # clean the output dir
        out = './out_3d'
@@ -33,7 +33,8 @@ def test_learn_3d_reg():
            os.removedirs(out)
 
        # declare the dataset instance
-        data_set = DataSet(database,
+        data_set = DataSet(train_database = train_database,
+                           valid_database = None,
                   test_database = None,
                   grid_shape=(30,30,30),
                   select_feature={'AtomicDensities_ind' : 'all',
@@ -72,7 +73,8 @@ def test_learn_2d_reg():
            raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')
 
        # declare the dataset instance
-        data_set = DataSet(database,
+        data_set = DataSet(train_database = database,
+                           valid_database = database,
                   test_database = database,
                   select_feature={'AtomicDensities_ind' : 'all',
                                   'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] },
@@ -129,7 +131,8 @@ def test_learn_3d_class():
            os.removedirs(out)
 
        # declare the dataset instance
-        data_set = DataSet(database,
+        data_set = DataSet(train_database = database,
+                           valid_database = None,
                   test_database = None,
                   grid_shape=(30,30,30),
                   select_feature={'AtomicDensities_ind' : 'all',
@@ -153,4 +156,4 @@ def test_learn_3d_class():
    #TestLearn.test_learn_3d_reg()
    #TestLearn.test_learn_3d_class()
    TestLearn.test_learn_2d_reg()
-    #TestLearn.test_transfer()
\ No newline at end of file
+    #TestLearn.test_transfer()
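
Note: a quick sanity check for the new three-way split is that the index
ranges partition the dataset exactly. A hypothetical assertion one could add
to the tests (not part of this patch; data_set is any processed DataSet):

    assert data_set.ntot == data_set.ntrain + data_set.nvalid + data_set.ntest
    assert data_set.index_valid == list(range(data_set.ntrain,
                                              data_set.ntrain + data_set.nvalid))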