Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
add preshuffle_seed
Browse files Browse the repository at this point in the history
  • Loading branch information
CunliangGeng committed Feb 26, 2019
1 parent 095ba31 commit 4ebdf67
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions deeprank/learn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(self,data_set,model,
sys.exit()

def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_batch_size = 10,
preshuffle = True,export_intermediate=True,num_workers=1,save_model='best',save_epoch='intermediate'):
preshuffle=True, preshuffle_seed=None, export_intermediate=True,num_workers=1,save_model='best',save_epoch='intermediate'):

"""Perform a simple training of the model. The data set is divided in training/validation sets.
Expand All @@ -251,6 +251,8 @@ def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_bat
preshuffle (bool, optional): preshuffle the dataset before dividing it
preshuffle_seed (int, optional): set random seed for preshuffle
export_intermediate (bool, optional): export data at intermediate epochs
num_workers (int, optional): number of workers to be used to prep the batch data
Expand Down Expand Up @@ -295,7 +297,7 @@ def train(self,nepoch=50, divide_trainset=None, hdf5='epoch_data.hdf5',train_bat

# divide the set in train+ valid and test
divide_trainset = divide_trainset or [0.8,0.2]
index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle)
index_train,index_valid,index_test = self._divide_dataset(divide_trainset,preshuffle, preshuffle_seed)

print(': %d confs. for training' %len(index_train))
print(': %d confs. for validation' %len(index_valid))
Expand Down Expand Up @@ -416,13 +418,14 @@ def load_data_params(self,filename):
self.data_set.transform = state['transform']
self.data_set.proj2D = state['proj2D']

def _divide_dataset(self,divide_set, preshuffle):
def _divide_dataset(self,divide_set, preshuffle, preshuffle_seed):

'''Divide the data set into training, validation and test sets according to the percentages in divide_set.
Args:
divide_set (list(float)): percentage used for training/validation/test
preshuffle (bool): shuffle the dataset before dividing it
preshuffle_seed (int, optional): set random seed for preshuffle
Returns:
list(int),list(int),list(int): Indices of the training/validation/test set
Expand All @@ -442,6 +445,9 @@ def _divide_dataset(self,divide_set, preshuffle):

# preshuffle
if preshuffle:
if preshuffle_seed is not None and not isinstance(preshuffle_seed, int):
preshuffle_seed = int(preshuffle_seed)
np.random.seed(preshuffle_seed)
np.random.shuffle(self.data_set.index_train)

# size of the subset for training
Expand Down

0 comments on commit 4ebdf67

Please sign in to comment.