This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Added validation set option #52

Merged (3 commits) on Mar 4, 2019
92 changes: 67 additions & 25 deletions deeprank/learn/DataSet.py
@@ -19,7 +19,7 @@

class DataSet():

def __init__(self,database, test_database = None,
def __init__(self,train_database, valid_database = None, test_database = None,
use_rotation = None,
select_feature = 'all', select_target = 'DOCKQ',
normalize_features = True, normalize_targets = True,
@@ -35,8 +35,8 @@ def __init__(self,database, test_database = None,
part of DeepRank. To create an instance you must provide quite a few arguments.
Example:
>>> from deeprank.learn import *
>>> database = '1ak4.hdf5'
>>> data_set = DataSet(database,
>>> train_database = '1ak4.hdf5'
>>> data_set = DataSet(train_database, valid_database = None,
>>> test_database = None,
>>> grid_shape=(30,30,30),
>>> select_feature = {
@@ -50,8 +50,10 @@
>>> dict_filter={'IRMSD':'<4. or >10.'},
>>> process = True)
Args:
database (list(str)): names of the hdf5 files used for the training/validation
train_database (list(str)): names of the hdf5 files used for the training/validation
Example : ['1AK4.hdf5','1B7W.hdf5',...]
valid_database (list(str)): names of the hdf5 files used for the validation
Example : ['1ACB.hdf5','4JHF.hdf5',...]
test_database (list(str)): names of the hdf5 files used for the test
Example : ['7CEI.hdf5']
use_rotation (int): number of rotations to use.
@@ -91,9 +93,13 @@ def __init__(self,database, test_database = None,
'''

# allow for multiple database
self.database = database
if not isinstance(database,list):
self.database = [database]
self.train_database = train_database
if not isinstance(train_database,list):
self.train_database = [train_database]

self.valid_database = valid_database
# avoid wrapping None into [None] when no validation set is given
if valid_database is not None and not isinstance(valid_database,list):
self.valid_database = [valid_database]

# allow for multiple database
self.test_database = test_database
@@ -152,7 +158,7 @@ def process_dataset(self):
print('=\t DeepRank Data Set')
print('=')
print('=\t Training data' )
for f in self.database:
for f in self.train_database:
print('=\t ->',f)
print('=')
if self.test_database is not None:
@@ -165,7 +171,13 @@ def process_dataset(self):


# check if the files are ok
self.check_hdf5_files()
self.train_database = self.check_hdf5_files(self.train_database)

if self.valid_database is not None:
self.valid_database = self.check_hdf5_files(self.valid_database)

if self.test_database is not None:
self.test_database = self.check_hdf5_files(self.test_database)

# create the indexing system
# allows associating each mol with an index
@@ -195,7 +207,8 @@ def process_dataset(self):
print('\n')
print(" Data Set Info")
print(' Training set : %d conformations' %self.ntrain)
print(' Test set : %d conformations' %(self.ntot-self.ntrain))
print(' Validation set : %d conformations' %self.nvalid)
print(' Test set : %d conformations' %(self.ntest))
print(' Number of channels : %d' %self.input_shape[0])
print(' Grid Size : %d x %d x %d' %(self.data_shape[1],self.data_shape[2],self.data_shape[3]))
sys.stdout.flush()
@@ -241,12 +254,13 @@ def __getitem__(self,index):
return {'mol':[fname,mol],'feature':feature,'target':target}


def check_hdf5_files(self):
@staticmethod
def check_hdf5_files(database):
"""Check if the data contained in the hdf5 file is ok."""

print(" Checking dataset Integrity")
remove_file = []
for fname in self.database:
for fname in database:
try:
f = h5py.File(fname,'r')
mol_names = list(f.keys())
@@ -259,7 +273,8 @@ def check_hdf5_files(self):
remove_file.append(fname)

for name in remove_file:
self.database.remove(name)
database.remove(name)
return database


def create_index_molecules(self):
@@ -273,10 +288,10 @@

desc = '{:25s}'.format(' Train dataset')
if self.tqdm:
data_tqdm = tqdm(self.database,desc=desc,file=sys.stdout)
data_tqdm = tqdm(self.train_database,desc=desc,file=sys.stdout)
else:
print(' Train dataset')
data_tqdm = self.database
data_tqdm = self.train_database
sys.stdout.flush()

for fdata in data_tqdm:
@@ -287,8 +302,7 @@
mol_names = list(fh5.keys())
mol_names = self._select_pdb(mol_names)
for k in mol_names:
if self.filter(fh5[k]):
self.index_complexes += [(fdata,k)]
self.index_complexes += [(fdata,k)]
fh5.close()
except Exception as inst:
print('\t\t-->Ignore File : ' + fdata)
@@ -297,6 +311,33 @@
self.ntrain = len(self.index_complexes)
self.index_train = list(range(self.ntrain))
# defaults so process_dataset can report counts even when
# no validation/test database is given
self.ntot = self.ntrain
self.nvalid = 0
self.ntest = 0

if self.valid_database is not None:

desc = '{:25s}'.format(' Validation dataset')
if self.tqdm:
data_tqdm = tqdm(self.valid_database,desc=desc,file=sys.stdout)
else:
data_tqdm = self.valid_database
print(' Validation dataset')
sys.stdout.flush()

for fdata in data_tqdm:
if self.tqdm:
data_tqdm.set_postfix(mol=os.path.basename(fdata))
try:
fh5 = h5py.File(fdata,'r')
mol_names = list(fh5.keys())
mol_names = self._select_pdb(mol_names)
self.index_complexes += [(fdata,k) for k in mol_names]
fh5.close()
except:
print('\t\t-->Ignore File : '+fdata)

self.ntot = len(self.index_complexes)
self.index_valid = list(range(self.ntrain,self.ntot))
self.nvalid = self.ntot - self.ntrain


if self.test_database is not None:

desc = '{:25s}'.format(' Test dataset')
@@ -313,14 +354,15 @@ def create_index_molecules(self):
try:
fh5 = h5py.File(fdata,'r')
mol_names = list(fh5.keys())
mol_names = selef._select_pdb(mol_names)
mol_names = self._select_pdb(mol_names)
self.index_complexes += [(fdata,k) for k in mol_names]
fh5.close()
except:
print('\t\t-->Ignore File : '+fdata)

self.ntot = len(self.index_complexes)
self.index_test = list(range(self.ntrain,self.ntot))
self.index_test = list(range(self.ntrain + self.nvalid ,self.ntot))
self.ntest = self.ntot - self.ntrain - self.nvalid


def _select_pdb(self, mol_names):
@@ -391,7 +433,7 @@ def get_feature_name(self):
'''

# open a h5 file in case we need it
f5 = h5py.File(self.database[0],'r')
f5 = h5py.File(self.train_database[0],'r')
mol_name = list(f5.keys())[0]
mapped_data = f5.get(mol_name + '/mapped_features/')
chain_tags = ['_chainA','_chainB']
@@ -465,7 +507,7 @@ def get_feature_name(self):
def print_possible_features(self):
"""Print the possible features in the group."""

f5 = h5py.File(self.database[0],'r')
f5 = h5py.File(self.train_database[0],'r')
mol_name = list(f5.keys())[0]
mapgrp = f5.get(mol_name + '/mapped_features/')

@@ -515,7 +557,7 @@ def get_input_shape(self):
self.input_shape : input size of the CNN (potentially after 2d transformation)
"""

fname = self.database[0]
fname = self.train_database[0]
feature,_ = self.load_one_molecule(fname)
self.data_shape = feature.shape

@@ -535,7 +577,7 @@ def get_grid_shape(self):
ValueError: If no grid shape is provided or is present in the HDF5 file
'''

fname = self.database[0]
fname = self.train_database[0]
fh5 = h5py.File(fname,'r')
mol = list(fh5.keys())[0]

@@ -588,7 +630,7 @@ def _read_norm(self):
"""Read or create the normalization file for the complex.
"""
# loop through all the filename
for f5 in self.database:
for f5 in self.train_database:

# get the precalculated data
fdata = os.path.splitext(f5)[0]+'_norm.pckl'
@@ -618,7 +660,7 @@ def _read_norm(self):
self.param_norm['targets'][self.select_target].update(maxv)

# process the std
nfile = len(self.database)
nfile = len(self.train_database)
for feat_types,feat_dict in self.param_norm['features'].items():
for feat in feat_dict:
self.param_norm['features'][feat_types][feat].process(nfile)
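Note on the new index bookkeeping: create_index_molecules now lays the three subsets out contiguously in index_complexes, training first, then validation, then test. A minimal sketch of the resulting invariants, assuming all three databases are provided (the file names here are hypothetical; the DataSet signature and attribute names are taken from the diff above):

from deeprank.learn import DataSet

# hypothetical mapped DeepRank HDF5 files
data_set = DataSet(train_database='train.hdf5',
                   valid_database='valid.hdf5',
                   test_database='test.hdf5',
                   grid_shape=(30,30,30),
                   process=True)

# the three subsets partition the contiguous index list
assert data_set.ntot == data_set.ntrain + data_set.nvalid + data_set.ntest
assert data_set.index_train == list(range(data_set.ntrain))
assert data_set.index_valid == list(range(data_set.ntrain, data_set.ntrain + data_set.nvalid))
assert data_set.index_test == list(range(data_set.ntrain + data_set.nvalid, data_set.ntot))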
14 changes: 11 additions & 3 deletions deeprank/learn/NeuralNet.py
@@ -3,6 +3,8 @@
import os
import time
import h5py
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
@@ -234,7 +236,7 @@ def __init__(self,data_set,model,
print(' --> Aborting the experiment \n\n')
sys.exit()

def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_batch_size = 10,
def train(self,nepoch=50, divide_trainset= None, hdf5='epoch_data.hdf5',train_batch_size = 10,
preshuffle=True, preshuffle_seed=None, export_intermediate=True,num_workers=1,save_model='best',save_epoch='intermediate'):

"""Perform a simple training of the model. The data set is divided in training/validation sets.
@@ -296,8 +298,14 @@ def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_batch_size = 10,
self.f5 = h5py.File(fname,'w')

# divide the set into train + valid and test
divide_trainset = divide_trainset or [0.8,0.2]
index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)
if divide_trainset is not None:
# split fractions were given: divide the data set accordingly
index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)
else:
index_train = self.data_set.index_train
index_valid = self.data_set.index_valid
index_test = self.data_set.index_test


print(': %d confs. for training' %len(index_train))
print(': %d confs. for validation' %len(index_valid))
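The net effect in train(): the fraction-based split is still available but is now opt-in instead of being the default fallback. A hedged sketch of the two call patterns (the model object and the fraction values are illustrative, not part of the PR):

# Option 1: split one pool of conformations by fractions, as the removed
# `divide_trainset or [0.8,0.2]` fallback used to do implicitly:
model.train(nepoch=50, divide_trainset=[0.8, 0.2], train_batch_size=10)

# Option 2 (new): keep divide_trainset=None and reuse the index_train /
# index_valid / index_test lists built by DataSet from separate files:
model.train(nepoch=50, divide_trainset=None, train_batch_size=10)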
22 changes: 14 additions & 8 deletions example/learn.py
@@ -5,24 +5,30 @@
from deeprank.learn import *
from deeprank.learn.model3d import cnn as cnn3d

database = '1ak4.hdf5'
out = './out'
# address of the databases
train_database = './data/1AVX.hdf5'
valid_database = './data/1BVN.hdf5'
test_database = './data/1DFJ.hdf5'


# make sure the databases are there
database = '1ak4.hdf5'
if not os.path.isfile(database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')

for database in [train_database, valid_database, test_database]:
if not os.path.isfile(database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % database)

# clean the output dir
out = './out_3d'
if os.path.isdir(out):
for f in glob.glob(out+'/*'):
os.remove(f)
os.removedirs(out)


# declare the dataset instance
data_set = DataSet(database,
test_database = None,
data_set = DataSet(train_database = train_database,
valid_database=valid_database,
test_database = test_database,
grid_shape=(30,30,30),
select_feature={'AtomicDensities_ind' : 'all',
'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] },
@@ -39,4 +45,4 @@
cuda=False,plot=True,outdir=out)

# start the training
model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5, num_workers=0, save_model='all')
model.train(nepoch = 5, divide_trainset = None, train_batch_size = 5, num_workers=0, save_model='all')
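With divide_trainset=None, the conformation counts reported during processing come straight from the file boundaries. A sketch of the expected 'Data Set Info' block printed by DataSet.process_dataset (the counts are placeholders that depend on the contents of the hypothetical ./data/*.hdf5 files):

 Data Set Info
 Training set : <ntrain> conformations
 Validation set : <nvalid> conformations
 Test set : <ntest> conformations
 Number of channels : <channels>
 Grid Size : 30 x 30 x 30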
17 changes: 10 additions & 7 deletions test/test_learn.py
@@ -21,9 +21,9 @@ def test_learn_3d_reg():
"""Use a 3D CNN for regularization."""

# address of the database
database = '1ak4.hdf5'
if not os.path.isfile(database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')
train_database = '1ak4.hdf5'
if not os.path.isfile(train_database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % train_database)

# clean the output dir
out = './out_3d'
@@ -33,7 +33,8 @@
os.removedirs(out)

# declare the dataset instance
data_set = DataSet(database,
data_set = DataSet(train_database = train_database,
valid_database = None,
test_database = None,
grid_shape=(30,30,30),
select_feature={'AtomicDensities_ind' : 'all',
@@ -72,7 +73,8 @@ def test_learn_2d_reg():
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % database)

# declare the dataset instance
data_set = DataSet(database,
data_set = DataSet(train_database = database,
valid_database = database,
test_database = database,
select_feature={'AtomicDensities_ind' : 'all',
'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] },
@@ -129,7 +131,8 @@ def test_learn_3d_class():
os.removedirs(out)

# declare the dataset instance
data_set = DataSet(database,
data_set = DataSet(train_database = database,
valid_database = None,
test_database = None,
grid_shape=(30,30,30),
select_feature={'AtomicDensities_ind' : 'all',
@@ -153,4 +156,4 @@
#TestLearn.test_learn_3d_reg()
#TestLearn.test_learn_3d_class()
TestLearn.test_learn_2d_reg()
#TestLearn.test_transfer()
#TestLearn.test_transfer()