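# Dataset Index Generation
Generates the event indices for the train, validation and test splits. Events are split by source file (the dataset's `root_files` entries), so all events from one file land in the same split, and each label contributes the same number of test and validation files.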

In [None]:
import h5py
# from progressbar import *
import re
import numpy as np

## Options

In [None]:
# Number of files *per label* reserved for the test and validation splits;
# every remaining file of that label is assigned to training.
n_test_files, n_val_files = 400, 100
labels = (0, 1)  # label values whose files are split (compared against event_labels)

--- Logging error ---
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 461, in dispatch_queue
    await self.process_one()
  File "/opt/conda/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 450, in process_one
    await dispatch(*args)
TypeError: object NoneType can't be used in 'await' expression

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/logging/__init__.py", line 1089, in emit
    self.flush()
  File "/opt/conda/lib/python3.8/logging/__init__.py", line 1069, in flush
    self.stream.flush()
BrokenPipeError: [Errno 32] Broken pipe
Call stack:
  File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.8/site-packages/ipykernel_laun

## Load dataset

In [2]:
# HDF5 file holding the full dataset. It is opened read-only and left open so
# the following cells can read its datasets through `f`.
data_path = "/scratch/jgao/data/HKHybrid/HKHybrid_e-gamma_E0to1000MeV_unif-pos-R3240-y3287cm_4pi-dir_6Mevts_w_mPMT.hdf5"
f = h5py.File(data_path, mode="r")

In [3]:
# Read both per-event arrays fully into memory. They are used later as
# parallel arrays (a label mask on event_labels indexes root_files).
event_labels = np.array(f['labels'])
root_files = np.array(f['root_files']).astype(str)  # str dtype so names compare/sort as text

n_loaded_events = len(event_labels)
print(n_loaded_events)

NameError: name 'np' is not defined

## Find the files of each label and the event indices in each file

In [5]:
def atoi(text: str):
    """Return ``int(text)`` if ``text`` is a run of decimal digits, else ``text`` unchanged.

    ``str.isdecimal`` is used instead of ``str.isdigit``. ``isdigit`` is also True
    for characters such as superscripts ("²") that ``int()`` cannot parse, which
    made this raise ``ValueError``. ``isdecimal`` accepts exactly the Unicode
    decimal digits (category Nd) that both ``int()`` and the regex ``\\d`` accept.
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text: str) -> list:
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    The text is split on runs of decimal digits, and each digit run becomes an
    int, so e.g. "file2" sorts before "file10". ``re.split`` with one capturing
    group always yields a str first and then alternates str/digit-run. Keys of
    different strings therefore compare str-to-str and int-to-int position by
    position.
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]

In [6]:
# Files containing events of each label, sorted in natural (human) order so that
# e.g. "..._2" precedes "..._10"; this order decides which files go to which split.
files_in_labels = {l: sorted(set(root_files[event_labels == l]), key=natural_keys) for l in labels}

# Map each file to the range of event indices it occupies, built from the first
# occurrence index and the event count that np.unique reports for each file.
idxs_in_files = {fname: range(i, i + c)
                 for fname, i, c in zip(*np.unique(root_files, return_index=True, return_counts=True))}

# The (first index, count) ranges are only correct if every file's events are
# stored consecutively. Verify that instead of silently producing wrong indices.
assert all((root_files[r.start:r.stop] == fname).all() for fname, r in idxs_in_files.items()), \
    "events of at least one file are not stored contiguously; index ranges would be wrong"

In [7]:
# Summarise each label: number of files and total number of events in them.
# The loop variable must not be named `f`. That would overwrite the open
# h5py.File handle `f` from the "Load dataset" section, so re-running the
# array-loading cell would fail.
for label, label_files in files_in_labels.items():
    n_label_events = sum(len(idxs_in_files[fname]) for fname in label_files)
    print(f"label {label} has {len(label_files)} files and {n_label_events} indices")

label 0 has 1000 files and  3000000 indices
label 1 has 1000 files and  3000000 indices


## Create the splits

In [8]:
# Per-label file budgets, applied to each label's naturally-sorted file list:
# the first n_test_files go to test, the next n_val_files to validation, and
# all remaining files to training. Fail loudly if a label has too few files,
# instead of silently producing a short or empty split for it.
assert all(len(files_in_labels[l]) > n_test_files + n_val_files for l in labels), \
    "not enough files per label to fill test/val and leave files for training"

split_files = {
    name: [fname for l in labels for fname in files_in_labels[l][file_slice]]
    for name, file_slice in (
        ("test_idxs", slice(None, n_test_files)),
        ("val_idxs", slice(n_test_files, n_test_files + n_val_files)),
        ("train_idxs", slice(n_test_files + n_val_files, None)),
    )
}
# Expand each split's file list into the event indices of those files.
split_idxs = {k: [i for fname in v for i in idxs_in_files[fname]] for k, v in split_files.items()}

In [9]:
# Report the size of each split, both in files and in event indices.
for s in split_files:
    n_files, n_idxs = len(split_files[s]), len(split_idxs[s])
    print(s, "has", n_files, "files and", n_idxs, "indices")

test_idxs has 800 files and 2400000 indices
val_idxs has 200 files and 600000 indices
train_idxs has 1000 files and 3000000 indices


In [10]:
# Verify that all events are uniquely accounted for
all_indices = np.concatenate(list(split_idxs.values()))
print(len(event_labels))
print(len(all_indices))
print(len(set(all_indices)))

6000000
6000000
6000000


## Save file

In [11]:
# Write one array per split (keyed test_idxs / val_idxs / train_idxs) into a single .npz archive.
idx_output_path = '/scratch/jgao/data/HKHybrid/e-_n_gamma_idxs.npz'
np.savez(idx_output_path, **split_idxs)