This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Merge f3ca2d8 into bff0070
LilySnow committed Mar 4, 2019
2 parents bff0070 + f3ca2d8 commit 2371511
Showing 4 changed files with 102 additions and 43 deletions.
92 changes: 67 additions & 25 deletions deeprank/learn/DataSet.py
@@ -19,7 +19,7 @@

class DataSet():

def __init__(self,database, test_database = None,
def __init__(self,train_database, valid_database = None, test_database = None,
use_rotation = None,
select_feature = 'all', select_target = 'DOCKQ',
normalize_features = True, normalize_targets = True,
@@ -35,8 +35,8 @@ def __init__(self,database, test_database = None,
part of DeepRank. To create an instance you must provide quite a few arguments.
Example:
>>> from deeprank.learn import *
>>> database = '1ak4.hdf5'
>>> data_set = DataSet(database,
>>> train_database = '1ak4.hdf5'
>>> data_set = DataSet(train_database, valid_database = None,
>>> test_database = None,
>>> grid_shape=(30,30,30),
>>> select_feature = {
@@ -50,8 +50,10 @@ def __init__(self,database, test_database = None,
>>> dict_filter={'IRMSD':'<4. or >10.'},
>>> process = True)
Args:
database (list(str)): names of the hdf5 files used for the training/validation
train_database (list(str)): names of the hdf5 files used for the training/validation
Example : ['1AK4.hdf5','1B7W.hdf5',...]
valid_database (list(str)): names of the hdf5 files used for the validation
Example : ['1ACB.hdf5','4JHF.hdf5',...]
test_database (list(str)): names of the hdf5 files used for the test
Example : ['7CEI.hdf5']
use_rotation (int): number of rotations to use.
@@ -91,9 +93,13 @@ def __init__(self,database, test_database = None,
'''

# allow for multiple databases
self.database = database
if not isinstance(database,list):
self.database = [database]
self.train_database = train_database
if not isinstance(train_database,list):
self.train_database = [train_database]

self.valid_database = valid_database
if not isinstance(valid_database,list):
self.valid_database = [valid_database]

# allow for multiple databases
self.test_database = test_database
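
The constructor applies the same wrap-a-bare-path-into-a-list treatment to each database argument. A minimal sketch of the pattern, with illustrative file names:

def as_database_list(database):
    # a bare path is wrapped so downstream code can always iterate;
    # a list passes through unchanged
    if not isinstance(database, list):
        return [database]
    return database

assert as_database_list('1AK4.hdf5') == ['1AK4.hdf5']
assert as_database_list(['1AK4.hdf5', '1B7W.hdf5']) == ['1AK4.hdf5', '1B7W.hdf5']

Note that, as written in the diff, a valid_database of None is wrapped into [None] as well; the unreadable entry is only weeded out later by the integrity check in check_hdf5_files.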
@@ -152,7 +158,7 @@ def process_dataset(self):
print('=\t DeepRank Data Set')
print('=')
print('=\t Training data' )
for f in self.database:
for f in self.train_database:
print('=\t ->',f)
print('=')
if self.test_database is not None:
@@ -165,7 +171,13 @@


# check if the files are ok
self.check_hdf5_files()
self.train_database = self.check_hdf5_files(self.train_database)

if self.valid_database is not None:
self.valid_database = self.check_hdf5_files(self.valid_database)

if self.test_database is not None:
self.test_database = self.check_hdf5_files(self.test_database)

# create the indexing system
# allows associating each mol with an index
@@ -195,7 +207,8 @@ def process_dataset(self):
print('\n')
print(" Data Set Info")
print(' Training set : %d conformations' %self.ntrain)
print(' Test set : %d conformations' %(self.ntot-self.ntrain))
print(' Validation set : %d conformations' %self.nvalid)
print(' Test set : %d conformations' %(self.ntest))
print(' Number of channels : %d' %self.input_shape[0])
print(' Grid Size : %d x %d x %d' %(self.data_shape[1],self.data_shape[2],self.data_shape[3]))
sys.stdout.flush()
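
With the three-way split, the summary block now reports the validation and test sets separately; with illustrative numbers the output looks like:

 Data Set Info
 Training set : 300 conformations
 Validation set : 100 conformations
 Test set : 100 conformations
 Number of channels : 10
 Grid Size : 30 x 30 x 30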
@@ -241,12 +254,13 @@ def __getitem__(self,index):
return {'mol':[fname,mol],'feature':feature,'target':target}


def check_hdf5_files(self):
@staticmethod
def check_hdf5_files(database):
"""Check if the data contained in the hdf5 file is ok."""

print(" Checking dataset Integrity")
remove_file = []
for fname in self.database:
for fname in database:
try:
f = h5py.File(fname,'r')
mol_names = list(f.keys())
@@ -259,7 +273,8 @@ def check_hdf5_files(self):
remove_file.append(fname)

for name in remove_file:
self.database.remove(name)
database.remove(name)
return database
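
Because check_hdf5_files is now a staticmethod that takes the file list as an argument and returns it, each of the three splits can be vetted independently. A short usage sketch, assuming the files exist locally (names illustrative):

from deeprank.learn import DataSet

# files that cannot be opened, or that fail the integrity check, are
# printed as 'Ignore File'; the list is pruned in place and returned
train_files = DataSet.check_hdf5_files(['1AK4.hdf5', '1B7W.hdf5'])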


def create_index_molecules(self):
@@ -273,10 +288,10 @@

desc = '{:25s}'.format(' Train dataset')
if self.tqdm:
data_tqdm = tqdm(self.database,desc=desc,file=sys.stdout)
data_tqdm = tqdm(self.train_database,desc=desc,file=sys.stdout)
else:
print(' Train dataset')
data_tqdm = self.database
data_tqdm = self.train_database
sys.stdout.flush()

for fdata in data_tqdm:
@@ -287,8 +302,7 @@
mol_names = list(fh5.keys())
mol_names = self._select_pdb(mol_names)
for k in mol_names:
if self.filter(fh5[k]):
self.index_complexes += [(fdata,k)]
self.index_complexes += [(fdata,k)]
fh5.close()
except Exception as inst:
print('\t\t-->Ignore File : ' + fdata)
Expand All @@ -297,6 +311,33 @@ def create_index_molecules(self):
self.ntrain = len(self.index_complexes)
self.index_train = list(range(self.ntrain))

if self.valid_database is not None:

desc = '{:25s}'.format(' Validation dataset')
if self.tqdm:
data_tqdm = tqdm(self.valid_database,desc=desc,file=sys.stdout)
else:
data_tqdm = self.valid_database
print(' Validation dataset')
sys.stdout.flush()

for fdata in data_tqdm:
if self.tqdm:
data_tqdm.set_postfix(mol=os.path.basename(fdata))
try:
fh5 = h5py.File(fdata,'r')
mol_names = list(fh5.keys())
mol_names = self._select_pdb(mol_names)
self.index_complexes += [(fdata,k) for k in mol_names]
fh5.close()
except:
print('\t\t-->Ignore File : '+fdata)

self.ntot = len(self.index_complexes)
self.index_valid = list(range(self.ntrain,self.ntot))
self.nvalid = self.ntot - self.ntrain


if self.test_database is not None:

desc = '{:25s}'.format(' Test dataset')
@@ -313,14 +354,15 @@
try:
fh5 = h5py.File(fdata,'r')
mol_names = list(fh5.keys())
mol_names = selef._select_pdb(mol_names)
mol_names = self._select_pdb(mol_names)
self.index_complexes += [(fdata,k) for k in mol_names]
fh5.close()
except:
print('\t\t-->Ignore File : '+fdata)

self.ntot = len(self.index_complexes)
self.index_test = list(range(self.ntrain,self.ntot))
self.index_test = list(range(self.ntrain + self.nvalid, self.ntot))
self.ntest = self.ntot - self.ntrain - self.nvalid
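
The three index lists are consecutive slices of index_complexes: training entries come first, then validation, then test. A self-contained sketch of the bookkeeping (assuming nvalid and ntest default to 0 elsewhere when the corresponding database is absent):

# illustrative layout after create_index_molecules()
index_complexes = [('train.hdf5', 'molA'), ('train.hdf5', 'molB'),  # training
                   ('valid.hdf5', 'molC'),                          # validation
                   ('test.hdf5', 'molD')]                           # test
ntrain, nvalid = 2, 1
ntot = len(index_complexes)

index_train = list(range(ntrain))                   # [0, 1]
index_valid = list(range(ntrain, ntrain + nvalid))  # [2]
index_test = list(range(ntrain + nvalid, ntot))     # [3]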


def _select_pdb(self, mol_names):
@@ -391,7 +433,7 @@ def get_feature_name(self):
'''

# open a h5 file in case we need it
f5 = h5py.File(self.database[0],'r')
f5 = h5py.File(self.train_database[0],'r')
mol_name = list(f5.keys())[0]
mapped_data = f5.get(mol_name + '/mapped_features/')
chain_tags = ['_chainA','_chainB']
@@ -465,7 +507,7 @@ def get_feature_name(self):
def print_possible_features(self):
"""Print the possible features in the group."""

f5 = h5py.File(self.database[0],'r')
f5 = h5py.File(self.train_database[0],'r')
mol_name = list(f5.keys())[0]
mapgrp = f5.get(mol_name + '/mapped_features/')

@@ -515,7 +557,7 @@ def get_input_shape(self):
self.input_shape : input size of the CNN (potentially after 2d transformation)
"""

fname = self.database[0]
fname = self.train_database[0]
feature,_ = self.load_one_molecule(fname)
self.data_shape = feature.shape

@@ -535,7 +577,7 @@ def get_grid_shape(self):
ValueError: If no grid shape is provided or is present in the HDF5 file
'''

fname = self.database[0]
fname = self.train_database[0]
fh5 = h5py.File(fname,'r')
mol = list(fh5.keys())[0]

@@ -588,7 +630,7 @@ def _read_norm(self):
"""Read or create the normalization file for the complex.
"""
# loop through all the filename
for f5 in self.database:
for f5 in self.train_database:

# get the precalculated data
fdata = os.path.splitext(f5)[0]+'_norm.pckl'
@@ -618,7 +660,7 @@ def _read_norm(self):
self.param_norm['targets'][self.select_target].update(maxv)

# process the std
nfile = len(self.database)
nfile = len(self.train_database)
for feat_types,feat_dict in self.param_norm['features'].items():
for feat in feat_dict:
self.param_norm['features'][feat_types][feat].process(nfile)
14 changes: 11 additions & 3 deletions deeprank/learn/NeuralNet.py
@@ -3,6 +3,8 @@
import os
import time
import h5py
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
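
Selecting the non-interactive 'agg' backend lets training jobs run on headless machines with no display server; figures are still written to disk. The backend must be chosen before the first import of pyplot, hence the placement. A minimal sketch of the behaviour this enables:

import matplotlib
matplotlib.use('agg')             # must precede the pyplot import
import matplotlib.pyplot as plt

plt.plot([0, 1], [0, 1])
plt.savefig('losses.png')         # works without a display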
@@ -234,7 +236,7 @@ def __init__(self,data_set,model,
print(' --> Aborting the experiment \n\n')
sys.exit()

def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_batch_size = 10,
def train(self,nepoch=50, divide_trainset= None, hdf5='epoch_data.hdf5',train_batch_size = 10,
preshuffle=True, preshuffle_seed=None, export_intermediate=True,num_workers=1,save_model='best',save_epoch='intermediate'):

"""Perform a simple training of the model. The data set is divided in training/validation sets.
@@ -296,8 +298,14 @@ def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_bat
self.f5 = h5py.File(fname,'w')

# divide the set into train, valid and test
divide_trainset = divide_trainset or [0.8,0.2]
index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)
if divide_trainset is not None:
# a fraction (or a list of fractions) was given: split the training set
index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)
else:
index_train = self.data_set.index_train
index_valid = self.data_set.index_valid
index_test = self.data_set.index_test


print(': %d confs. for training' %len(index_train))
print(': %d confs. for validation' %len(index_valid))
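train() now distinguishes two splitting modes. A hedged sketch of the call patterns, assuming a NeuralNet instance named model built as in example/learn.py below:

# fractions: carve the validation (and test) subsets out of the training
# data, as interpreted by _divide_dataset
model.train(nepoch=5, divide_trainset=[0.8, 0.2])

# None: keep the splits pre-assigned through DataSet(train_database=...,
# valid_database=..., test_database=...)
model.train(nepoch=5, divide_trainset=None)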
22 changes: 14 additions & 8 deletions example/learn.py
@@ -5,24 +5,30 @@
from deeprank.learn import *
from deeprank.learn.model3d import cnn as cnn3d

database = '1ak4.hdf5'
out = './out'
# address of the databases
train_database = './data/1AVX.hdf5'
valid_database = './data/1BVN.hdf5'
test_database = './data/1DFJ.hdf5'


# make sure the databases are there
database = '1ak4.hdf5'
if not os.path.isfile(database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')

for database in [train_database, valid_database, test_database]:
if not os.path.isfile(database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % database)

# clean the output dir
out = './out_3d'
if os.path.isdir(out):
for f in glob.glob(out+'/*'):
os.remove(f)
os.removedirs(out)


# declare the dataset instance
data_set = DataSet(database,
test_database = None,
data_set = DataSet(train_database = train_database,
valid_database=valid_database,
test_database = test_database,
grid_shape=(30,30,30),
select_feature={'AtomicDensities_ind' : 'all',
'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] },
Expand All @@ -39,4 +45,4 @@
cuda=False,plot=True,outdir=out)

# start the training
model.train(nepoch = 5,divide_trainset=0.8, train_batch_size = 5, num_workers=0, save_model='all')
model.train(nepoch = 5, divide_trainset = None, train_batch_size = 5, num_workers=0, save_model='all')
17 changes: 10 additions & 7 deletions test/test_learn.py
@@ -21,9 +21,9 @@ def test_learn_3d_reg():
"""Use a 3D CNN for regularization."""

# address of the database
database = '1ak4.hdf5'
if not os.path.isfile(database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before')
train_database = '1ak4.hdf5'
if not os.path.isfile(train_database):
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % train_database)

# clean the output dir
out = './out_3d'
@@ -33,7 +33,8 @@
os.removedirs(out)

# declare the dataset instance
data_set = DataSet(database,
data_set = DataSet(train_database = train_database,
valid_database = None,
test_database = None,
grid_shape=(30,30,30),
select_feature={'AtomicDensities_ind' : 'all',
@@ -72,7 +73,8 @@ def test_learn_2d_reg():
raise FileNotFoundError('Database %s not found. Make sure to run test_generate before' % database)

# declare the dataset instance
data_set = DataSet(database,
data_set = DataSet(train_database = database,
valid_database = database,
test_database = database,
select_feature={'AtomicDensities_ind' : 'all',
'Feature_ind' : ['coulomb','vdwaals','charge','PSSM_*'] },
@@ -129,7 +131,8 @@ def test_learn_3d_class():
os.removedirs(out)

# declare the dataset instance
data_set = DataSet(database,
data_set = DataSet(train_database = database,
valid_database = None,
test_database = None,
grid_shape=(30,30,30),
select_feature={'AtomicDensities_ind' : 'all',
@@ -153,4 +156,4 @@
#TestLearn.test_learn_3d_reg()
#TestLearn.test_learn_3d_class()
TestLearn.test_learn_2d_reg()
#TestLearn.test_transfer()
#TestLearn.test_transfer()
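
Reusing the same 1ak4.hdf5 file for the train, validation and test databases, as test_learn_2d_reg does, exercises the new API but leaks data between splits; a real experiment should keep the three sets disjoint, as in example/learn.py above.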
