In [1]:
#!/usr/bin/env python

import h5py
import os 
import argparse
import numpy as np



In [2]:
def parse_args(argv=None):
    """Parse the command-line arguments for the train/val/test splitter.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, exactly as before.
            Passing an explicit list allows the parser to be used where
            ``sys.argv`` does not hold this script's arguments (for example
            inside a notebook kernel).

    Returns:
        argparse.Namespace with the string attributes ``h5_file``,
        ``output_folder`` and ``indices_folder``.
    """
    parser = argparse.ArgumentParser(
        description="Separate Train and Validation from Test data")
    # The help text is a single literal: a backslash continuation inside the
    # string would embed the source indentation into the message.
    parser.add_argument("h5_file",
                        type=str,
                        help="Path to h5_file, must contain 'event_data'")
    parser.add_argument('output_folder', type=str,
                        help="Path to output folder.")
    parser.add_argument('indices_folder', type=str,
                        help="Path to indices folder")
    return parser.parse_args(argv)

def load_indices(indices_file):
    """Read a list of integer indices from a text file, one index per line.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of raising ``ValueError``.

    Args:
        indices_file: Path to a text file containing one integer per line.

    Returns:
        list[int]: The indices, in the order they appear in the file.

    Raises:
        ValueError: If a non-blank line cannot be parsed as an integer.
    """
    with open(indices_file, 'r') as f:
        # int() ignores surrounding whitespace, so only blank lines need to be
        # filtered out explicitly.
        return [int(line) for line in f if line.strip()]



In [3]:
class EasyDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``obj.name`` is equivalent to ``obj["name"]``. A missing name raises
    ``AttributeError`` on attribute access and deletion, so ``hasattr`` and
    ``getattr(obj, name, default)`` work as expected; item access
    (``obj["name"]``) still raises ``KeyError``.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict``."""
        super().__init__(*args, **kwargs)

    def __getattr__(self, name):
        """Return ``self[name]``; called only when normal lookup fails."""
        try:
            return self[name]
        except KeyError:
            # The attribute protocol expects AttributeError, not KeyError.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        """Store ``value`` under the key ``name``."""
        self[name] = value

    def __delattr__(self, name):
        """Remove the key ``name``."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
        
# Hard-coded settings for this run; the keys match the argument names
# declared in parse_args.
config = EasyDict(
    h5_file="/app/test_data/IWCDmPMT_4pi_fulltank_test_graphnet.h5",
    indices_folder="/app/test_data/IWCDmPMT_4pi_fulltank_test_splits/",
    output_folder="/app/test_data/split_h5_3",
)

In [4]:
# Each split has its own index file inside the indices folder; load them in
# the order test, train, val.
test_indices, train_indices, val_indices = (
    load_indices(os.path.join(config.indices_folder, split_file))
    for split_file in ("test.txt", "train.txt", "val.txt")
)


In [5]:
# Output names are the input file name with a suffix inserted before the
# extension.
basename, extension = os.path.splitext(os.path.basename(config.h5_file))
test_filename = "{}_test{}".format(basename, extension)
train_filename = "{}_trainval{}".format(basename, extension)

print(test_filename, train_filename)

IWCDmPMT_4pi_fulltank_test_graphnet_test.h5 IWCDmPMT_4pi_fulltank_test_graphnet_trainval.h5


In [6]:
# Create the destination directory; an already existing one is not an error.
os.makedirs(config.output_folder, exist_ok=True)

# Both output files are placed directly inside the output folder.
test_filepath, train_filepath = (
    os.path.join(config.output_folder, filename)
    for filename in (test_filename, train_filename)
)

print(test_filepath, train_filepath)

/app/test_data/split_h5_3/IWCDmPMT_4pi_fulltank_test_graphnet_test.h5 /app/test_data/split_h5_3/IWCDmPMT_4pi_fulltank_test_graphnet_trainval.h5


# One by One copy

In [7]:
def _copy_rows(infile, out_path, indices):
    """Copy the selected rows of every top-level entry of `infile` to a new file.

    For each key in `infile`, a dataset with the same name and dtype is
    created in the output file; its first dimension is ``len(indices)`` and
    its remaining dimensions match the source.

    Args:
        infile: An open, readable ``h5py.File``.
        out_path: Path of the HDF5 file to write (opened in 'w' mode).
        indices: Row indices along the first axis, in the order in which the
            rows should appear in the output.
    """
    row_count = len(indices)
    with h5py.File(out_path, 'w') as outfile:
        for key in list(infile.keys()):
            # Look the source dataset up once per key instead of once per row.
            source = infile[key]
            target = outfile.create_dataset(
                key,
                shape=(row_count,) + source.shape[1:],
                dtype=source.dtype)

            # Rows are copied one at a time, so `indices` does not need to be
            # sorted or free of duplicates.
            for out_row, in_row in enumerate(indices):
                target[out_row] = source[in_row]


with h5py.File(config.h5_file, 'r') as infile:
    print("Writing testing data to {}".format(test_filepath))
    _copy_rows(infile, test_filepath, test_indices)

    # The trainval file holds all training rows first, followed by all
    # validation rows.
    print("Writing training and validating data to {}".format(train_filepath))
    _copy_rows(infile, train_filepath, train_indices + val_indices)

Writing testing data to /app/test_data/split_h5_3/IWCDmPMT_4pi_fulltank_test_graphnet_test.h5
Writing training and validating data to /app/test_data/split_h5_3/IWCDmPMT_4pi_fulltank_test_graphnet_trainval.h5


# Sorted copy

In [8]:
# with h5py.File(config.h5_file, 'r') as infile:
#     keys = list(infile.keys())

#     with h5py.File(test_filepath, 'w') as outfile:
#         length = len(test_indices)
#         for key in keys:
#             original_shape = infile[key].shape
#             original_dtype = infile[key].dtype
#             new_shape = (length, ) + original_shape[1:]
            
#             sorted_indices = sorted(test_indices)
#             dataset = outfile.create_dataset(key, shape=new_shape, dtype=original_dtype,
#                                             data=infile[key][sorted_indices])
            
    
#     with h5py.File(train_filepath, 'w') as outfile:
#         length = len(train_indices) + len(val_indices)
#         for key in keys:
#             original_shape = infile[key].shape
#             original_dtype = infile[key].dtype
#             new_shape = (length, ) + original_shape[1:]
            
#             dataset = outfile.create_dataset(key, shape=new_shape, dtype=original_dtype)
            
#             sorted_indices = sorted(train_indices)
#             dataset[:len(train_indices)] = infile[key][sorted_indices]
#             sorted_indices = sorted(val_indices)
#             dataset[len(train_indices):] = infile[key][sorted_indices]

# Sorted batch copy

In [9]:
# with h5py.File(config.h5_file, 'r') as infile:
#     keys = list(infile.keys())

#     with h5py.File(test_filepath, 'w') as outfile:
#         length = len(test_indices)
#         for key in keys:
#             original_shape = infile[key].shape
#             original_dtype = infile[key].dtype
#             new_shape = (length, ) + original_shape[1:]
            
#             dataset = outfile.create_dataset(key, shape=new_shape, dtype=original_dtype)
            
#             sorted_indices = sorted(test_indices)
#             i = 0
#             while i < len(sorted_indices):
#                 end = min(len(sorted_indices), i+128)
#                 dataset[i:end] = infile[key][sorted_indices[i:end]]
#                 i = end
    
#     with h5py.File(train_filepath, 'w') as outfile:
#         length = len(train_indices) + len(val_indices)
#         for key in keys:
#             original_shape = infile[key].shape
#             original_dtype = infile[key].dtype
#             new_shape = (length, ) + original_shape[1:]
            
#             dataset = outfile.create_dataset(key, shape=new_shape, dtype=original_dtype)
            
#             sorted_indices = sorted(train_indices)
#             i = 0
#             while i < len(sorted_indices):
#                 end = min(len(sorted_indices), i+128)
#                 dataset[i:end] = infile[key][sorted_indices[i:end]]
#                 i = end
            
            
#             sorted_indices = sorted(val_indices)
#             i = 0
#             while i < len(sorted_indices):
#                 end = min(len(sorted_indices), i+128)
#                 dataset[i+len(train_indices):end+len(train_indices)] = infile[key][sorted_indices[i:end]]
#                 i = end           


# Splits

In [10]:
# splits_dir = os.path.join(config.output_folder, basename + "_splits")
# os.makedirs(splits_dir, exist_ok=True)

# with open(os.path.join(splits_dir, 'train.txt'), 'w') as f:
#     f.writelines(["{}\n".format(i) for i in range(len(train_indices))])
    
# with open(os.path.join(splits_dir, 'val.txt'), 'w') as f:
#     f.writelines(["{}\n".format(i) for i in range(len(train_indices), len(train_indices) + len(val_indices))])