In [3]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os

In [4]:
# Open one preprocessed HDF5 file read-only and show which datasets it holds.
file = os.path.join("/home/pablo/github/DeepIceLearning/h5_final3",
                    "File_375.h5")
data = h5py.File(file, mode="r")
data.keys()

<KeysViewHDF5 ['IC_charge', 'IC_charge_100ns', 'IC_charge_10ns', 'IC_charge_500ns', 'IC_charge_50ns', 'IC_charge_last', 'IC_diff', 'IC_first_charge', 'IC_mean', 'IC_mult', 'IC_num_pulses', 'IC_pulse_0_01_pct_charge_quantile', 'IC_pulse_0_03_pct_charge_quantile', 'IC_pulse_0_05_pct_charge_quantile', 'IC_pulse_0_08_pct_charge_quantile', 'IC_pulse_0_11_pct_charge_quantile', 'IC_pulse_0_15_pct_charge_quantile', 'IC_pulse_0_2_pct_charge_quantile', 'IC_pulse_0_5_pct_charge_quantile', 'IC_pulse_0_8_pct_charge_quantile', 'IC_time_first', 'IC_time_last', 'IC_time_spread', 'IC_time_std', 'IC_time_weighted_median', 'IC_var', 'reco_vals']>

In [2]:
def _batch_inputs(in_data, rows, reco_vals, arr_size, in_branches,
                  inp_variables):
    """Build one input array per input branch for the rows in ``rows``."""
    inp_data = []
    for (_, shape), variables in zip(in_branches, inp_variables):
        batch_input = np.zeros((arr_size,) + shape)
        for j, (name, transform) in enumerate(variables):
            if name in in_data:
                # Variable stored as its own dataset: squeeze singleton axes,
                # then pad back to at least 4 dims so it fills the
                # [batch, ..., ..., ..., variable] slot.
                pre_data = np.array(np.squeeze(in_data[name][rows]), ndmin=4)
                batch_input[:, :, :, :, j] = np.atleast_1d(transform(pre_data))
            else:
                # Otherwise it is read as a field of the 'reco_vals' table.
                pre_data = np.squeeze(reco_vals[name])
                batch_input[:, j] = transform(pre_data)
        inp_data.append(batch_input)
    return inp_data


def _batch_outputs(reco_vals, arr_size, out_branches, out_variables):
    """Build one target array per output branch from 'reco_vals' fields."""
    out_data = []
    for (_, shape), variables in zip(out_branches, out_variables):
        full_shape = (arr_size,) + shape
        batch_output = np.zeros(full_shape)
        for j, (name, transform) in enumerate(variables):
            pre_data = np.squeeze(reco_vals[name])
            if len(variables) == 1:
                # A single variable fills the whole branch output.
                batch_output[:] = np.reshape(transform(pre_data), full_shape)
            else:
                batch_output[:, j] = transform(pre_data)
        out_data.append(batch_output)
    return out_data


def _batch_weights(reco_vals, arr_size, n_branches, weighting_function,
                   mask_func):
    """Return one freshly computed per-sample weight array per output branch."""
    weights = []
    for _ in range(n_branches):
        if weighting_function is not None:
            tweights = weighting_function(reco_vals)
        else:
            tweights = np.ones(arr_size)
        if mask_func is not None:
            tweights[mask_func(reco_vals)] = 0
        weights.append(tweights)
    return weights


def generator_v2(batch_size, file_handlers, inds, inp_shape_dict,
                 inp_transformations, out_shape_dict, out_transformations,
                 weighting_function=None, use_data=False, equal_len=False,
                 mask_func=None, valid=False):
    """Yield training batches for the neural network, read from HDF5 files.

    Loads the input and output data of one batch at a time and applies the
    transformations defined in the network definition file. Files are
    visited in order. When every file has been consumed, the file list and
    its matching index ranges are shuffled and iteration starts over, so the
    generator never terminates.

    Arguments:
    batch_size: maximum number of samples per batch (per GPU)
    file_handlers: list of HDF5 file paths used for the training,
                   i.e. ['/path/to/file/A', 'path/to/file/B']
    inds: one (start, stop) index range per file, i.e. [(0,1000), (0,2000)];
          rows outside [start, stop) are never read
    inp_shape_dict: dictionary with the input shape for each branch,
                    taken from inp_shape_dict[branch]['general']
    inp_transformations: per-branch dictionary mapping input variable name
                         to the function applied to it
    out_shape_dict: dictionary with the output shape for each branch,
                    taken from out_shape_dict[branch]['general']
    out_transformations: per-branch dictionary mapping output variable name
                         to the function applied to it
    weighting_function: optional callable taking the batch's 'reco_vals'
                        rows and returning per-sample weights; if None,
                        all weights are 1
    use_data: if True, only the input batches are yielded
    equal_len: if True, the trailing partial batch of each range is dropped
               and the next file is opened instead (a range shorter than
               batch_size still yields one partial batch)
    mask_func: optional callable taking the 'reco_vals' rows and returning a
               mask/index of samples whose weight is set to 0
    valid: unused; kept for interface compatibility

    Yields:
    inp_data if use_data is True, otherwise (inp_data, out_data, weights),
    each a list with one array per input/output branch.
    """
    import time  # not imported anywhere else in this notebook

    print('Run with inds {}'.format(inds))
    file_handlers = np.array(file_handlers)
    inds = np.array(inds)
    in_branches = [(branch, inp_shape_dict[branch]['general'])
                   for branch in inp_shape_dict]
    out_branches = [(branch, out_shape_dict[branch]['general'])
                    for branch in out_shape_dict]
    inp_variables = [[(var, inp_transformations[name][var])
                      for var in inp_transformations[name]]
                     for name, _ in in_branches]
    out_variables = [[(var, out_transformations[name][var])
                      for var in out_transformations[name]]
                     for name, _ in out_branches]

    cur_file = 0
    ind_lo = inds[0][0]
    # Clamp to the range end so a short range never reads foreign rows.
    ind_hi = min(ind_lo + batch_size, inds[0][1])
    in_data = h5py.File(file_handlers[0], 'r')
    try:
        f_reco_vals = in_data['reco_vals']
        t0 = time.time()
        num_batches = 1

        while True:
            arr_size = min(batch_size, ind_hi - ind_lo)
            rows = slice(ind_lo, ind_hi)
            reco_vals = f_reco_vals[rows]

            inp_data = _batch_inputs(in_data, rows, reco_vals, arr_size,
                                     in_branches, inp_variables)
            if not use_data:
                out_data = _batch_outputs(reco_vals, arr_size,
                                          out_branches, out_variables)
                weights = _batch_weights(reco_vals, arr_size,
                                         len(out_branches),
                                         weighting_function, mask_func)

            # Prepare next loop
            ind_lo += batch_size
            ind_hi += batch_size
            range_end = inds[cur_file][1]
            if ind_lo >= range_end or (equal_len and ind_hi > range_end):
                cur_file += 1
                if cur_file == len(file_handlers):
                    cur_file = 0
                    new_inds = np.random.permutation(len(file_handlers))
                    print('Shuffle filelist...')
                    file_handlers = file_handlers[new_inds]
                    inds = inds[new_inds]
                t1 = time.time()
                print('\n Open File: {} \n'.format(file_handlers[cur_file]))
                print('\n Average Time per Batch: {}s \n'.format(
                    (t1 - t0) / num_batches))
                t0 = time.time()
                num_batches = 1
                in_data.close()
                in_data = h5py.File(file_handlers[cur_file], 'r')
                f_reco_vals = in_data['reco_vals']
                ind_lo = inds[cur_file][0]
                # The new range may be shorter than one batch: clamp it too.
                ind_hi = min(ind_lo + batch_size, inds[cur_file][1])
            elif ind_hi > range_end:
                ind_hi = range_end

            # Yield Result
            num_batches += 1
            if use_data:
                yield inp_data
            else:
                yield (inp_data, out_data, weights)
    finally:
        # Runs when the generator is closed or garbage-collected, so the
        # currently open HDF5 file is not leaked.
        in_data.close()
