In [75]:
import pickle, numpy as np, glob

In [76]:
def make_arrays(num_samples, num_features):
  if num_samples:
    X = np.ndarray((num_samples, num_features), dtype=np.float32)
    y = np.ndarray(num_samples, dtype=np.float32)
  else:
    X, y = None, None
  return X, y

def merge_datasets(pickle_files):
    num_parts = len(pickle_files)

    X = None
    y = None
    for label, pickle_file in enumerate(pickle_files):       
        try:
            with open(pickle_file, 'rb') as f:
                sequence_data = pickle.load(f)
        except Exception as e:
          print('Unable to process data from', pickle_file, ':', e)
          raise

        if X is not None:
            X = np.vstack((X, sequence_data['X']))
            y = np.hstack((y, sequence_data['y']))
            
        else:
            X = sequence_data['X']
            y = sequence_data['y']
    
    return X,y

In [84]:
def shuffle_in_unison(X, y):
    assert len(X) == len(y)
    p = np.random.permutation(len(X))
    return X[p,:], y[p]

def split_data(X, y, train_size, valid_size, test_size):
    train_X, train_y = X[:train_size,:], y[:train_size]
    k = train_size + valid_size
    valid_X, valid_y = X[train_size:k,:], y[train_size:k]
    test_X, test_y = X[k:(k+test_size),:], y[k:(k+test_size)]
    
    return train_X, train_y, valid_X, valid_y, test_X, test_y

for label in ['5bins_edges', '200bins']:
    files = glob.glob('data/*_%s.pickle' % label)
    X, y = merge_datasets(files)
    X, y = shuffle_in_unison(X,y)
    print('X', X.shape, 'y', y.shape)
    pickle.dump({'X': X, 'y': y}, open('data/dataset_%s.pickle' % label, 'wb'))

    size = len(X)
    train_size = int(.7*size)
    valid_size = int(.15*size)
    test_size = size-train_size-valid_size

    train_X, train_y, valid_X, valid_y, test_X, test_y = split_data(X, y, train_size, valid_size, test_size)
    print(len(train_X), len(test_X), len(test_y))
    pickle.dump({'X': train_X, 'y': train_y}, open('data/train_%s.pickle' % label, 'wb'))
    pickle.dump({'X': valid_X, 'y': valid_y}, open('data/valid_%s.pickle' % label, 'wb'))
    pickle.dump({'X': test_X, 'y': test_y}, open('data/test_%s.pickle' % label, 'wb'))
    

X (69537, 88) y (69537,)
48675 10432 10432
X (69537, 1592) y (69537,)
48675 10432 10432
