In [None]:
# Evan Racah
import h5py
from sklearn.preprocessing import StandardScaler
import numpy as np
import pickle
import glob
import os
from util.helper_fxns import center, scale
from operator import mul

In [None]:
def load_ibd_pairs(path, train_frac=0.5, valid_frac=0.25, tot_num_pairs=-1):
    '''Load up the hdf5 file given into a set of numpy arrays suitable for
    convnets.

    Reads the ``ibd_pair_data`` dataset from ``path`` and splits its first
    ``npairs`` rows into consecutive train / validation / test slices.

    Parameters
    ----------
    path : str
        Path to the HDF5 file.
    train_frac, valid_frac : float
        Fraction of the selected rows assigned to the training and
        validation sets; the remaining rows go to the test set.
    tot_num_pairs : int
        Number of rows (taken from the start of the dataset) to use.
        ``-1`` (default) uses every row; values larger than the dataset
        are clamped to its length.

    Returns
    -------
    tuple of numpy.ndarray
        ``(train, valid, test)``. Each set has shape
        ``(n_pairs, nchannels, xsize, ysize)`` where
        ``(nchannels, xsize, ysize) = (4, 8, 24)``.
    '''
    imageshape = (4, 8, 24)
    # Only the leading columns of each row are pixels; the rest is metadata.
    nfeatures = int(np.prod(imageshape))

    # The context manager guarantees the file handle is released.
    with h5py.File(path, 'r') as h5file:
        h5set = h5file['ibd_pair_data']
        available = h5set.shape[0]

        if tot_num_pairs == -1:
            npairs = available
        else:
            # Clamp so the slices below never ask for rows that don't exist.
            npairs = min(tot_num_pairs, available)
        ntrain = int(train_frac * npairs)
        nvalid = int(valid_frac * npairs)

        # Every split is bounded by npairs, so the test set holds exactly
        # npairs - ntrain - nvalid rows even when tot_num_pairs is given.
        bounds = [(0, ntrain),
                  (ntrain, ntrain + nvalid),
                  (ntrain + nvalid, npairs)]
        splits = []
        for start, stop in bounds:
            # Read only the pixel columns, then materialise into memory
            # before the file is closed.
            pixels = np.asarray(h5set[start:stop, :nfeatures])
            splits.append(pixels.reshape(stop - start, *imageshape))

    train, valid, test = splits
    return (train, valid, test)


def get_ibd_data(path_prefix="/global/homes/s/skohn/ml/dayabay-data-conversion/extract_ibd", mode='standardize',
                tot_num_pairs=-1, file_index=0):
    '''Load one ``ibd_yasu_<start>_<end>.h5`` file and normalize its splits.

    Parameters
    ----------
    path_prefix : str
        Directory containing the ``ibd_yasu_*.h5`` files.
    mode : str
        Passed through to ``scale`` as its ``mode`` keyword.
    tot_num_pairs : int
        Passed through to ``load_ibd_pairs``; ``-1`` uses every row.
    file_index : int
        Selects the file ``ibd_yasu_{file_index*10000}_{(file_index+1)*10000-1}.h5``.
        Defaults to the first file (``ibd_yasu_0_9999.h5``).

    Returns
    -------
    tuple
        ``(train, val, test)`` after ``center`` and ``scale`` have been
        applied to each.
    '''
    first = file_index * 10000
    last = (file_index + 1) * 10000 - 1
    h5file = os.path.join(path_prefix, "ibd_yasu_%d_%d.h5" % (first, last))

    # load_ibd_pairs returns (train, valid, test) in that order; unpack to
    # match so the validation and test sets are not swapped.
    train, val, test = load_ibd_pairs(h5file, tot_num_pairs=tot_num_pairs)

    # NOTE(review): the return values of center/scale are discarded, so they
    # are assumed to modify the arrays in place — confirm in util.helper_fxns.
    for split in (train, val, test):
        center(split)
    for split in (train, val, test):
        scale(split, 1, mode=mode)

    return train, val, test