In [1]:
import numpy as np
import h5py
import random
import sys
import os
sys.path.append("/data")
sys.path.append("..")

In [2]:
mg_file = h5py.File("/data/22Mg/point_clouds/simulated/output_digi_HDF_Mg22_Ne20pp_8MeV.h5")
o_file = h5py.File("/data/16O/point_clouds/simulated/output_digi_HDF_2Body_2T.h5")
mg_keys = list(mg_file.keys())
o_keys = list(o_file.keys())
print(f"{len(mg_keys)} 22Mg events")
print(f"{len(o_keys)} 16O events")

10000 22Mg events
10000 16O events


In [3]:
mg_lens = np.ndarray((len(mg_keys)), dtype=int)
o_lens = np.ndarray((len(o_keys)), dtype=int)

for i, k in enumerate(mg_keys):
    mg_lens[i] = len(mg_file[k])
for i, k in enumerate(o_keys):
    o_lens[i] = len(o_file[k])

In [4]:
print(os.getcwd())

/home/DAVIDSON/bewagner/data/22Mg_16O_combo/simulated


In [5]:
MIN_LEN = 50
MAX_LEN = 750
RNG = np.random.default_rng()

if not os.path.exists("/home/DAVIDSON/bewagner/data/variable_length/22Mg"):
    os.mkdir("/home/DAVIDSON/bewagner/data/variable_length/22Mg")
for i, k in enumerate(mg_file):
    if mg_lens[i] < MIN_LEN or mg_lens[i] > MAX_LEN:
        continue
    event = np.ndarray((mg_lens[i], 4))
    for idx, p in enumerate(mg_file[k]):
        event[idx, 0] = p[0]
        event[idx, 1] = p[1]
        event[idx, 2] = p[2]
        event[idx, 3] = p[4]
    name = f"/home/DAVIDSON/bewagner/data/variable_length/22Mg/{random.getrandbits(128):032x}.npy"
    np.save(name, event)

if not os.path.exists("/home/DAVIDSON/bewagner/data/variable_length/16O"):
    os.mkdir("/home/DAVIDSON/bewagner/data/variable_length/16O")
for i, k in enumerate(o_file):
    if o_lens[i] < MIN_LEN or o_lens[i] > MAX_LEN:
        continue
    event = np.ndarray((o_lens[i], 4))
    for idx, p in enumerate(o_file[k]):
        event[idx, 0] = p[0]
        event[idx, 1] = p[1]
        event[idx, 2] = p[2]
        event[idx, 3] = p[4]
    name = f"/home/DAVIDSON/bewagner/data/variable_length/16O/{random.getrandbits(128):032x}.npy"
    np.save(name, event)


In [6]:
mg_names = os.listdir("/home/DAVIDSON/bewagner/data/variable_length/22Mg")
o_names = os.listdir("/home/DAVIDSON/bewagner/data/variable_length/16O")

In [7]:
name_arr = np.ndarray((len(mg_names) + len(o_names)), dtype=[('hash', 'object'), ('exp', 'object')])

i = 0
for f in mg_names:
    name_arr[i] = (f.split('.')[0], "22Mg")
    i += 1

for f in o_names:
    name_arr[i] = (f.split('.')[0], "16O")
    i += 1

print(name_arr[:5])

[('a0000976afc6b723c88d004aa54d4fd9', '22Mg')
 ('ef5e3d9663f7100a1c290b889e7d3a6e', '22Mg')
 ('bbcab25a488512a699f6f674948e50f7', '22Mg')
 ('3c0a82aaf497d4046e4547eabb04c179', '22Mg')
 ('2649434426a06def6cb129cc7eabedea', '22Mg')]


In [8]:
RNG.shuffle(name_arr)
print(name_arr[:5])

[('c780bd55fc310243264c0c476b90120a', '22Mg')
 ('90307d622ecd8e83fa43457b3e7c9558', '22Mg')
 ('e5f14c34ae354946e7c84deae645087a', '22Mg')
 ('9953160d36d34b149820e911d28dfd3a', '22Mg')
 ('93d33fc02be90918a8c95cefa5428df4', '22Mg')]


In [9]:
num_train = int(0.6 * len(name_arr))
num_val = int(0.2 * len(name_arr))

train = name_arr[:num_train]
val = name_arr[num_train:num_train+num_val]
test = name_arr[num_train+num_val:]

print(len(train), len(val), len(test))
print((len(train) + len(val) + len(test)) == len(name_arr))

4845 1615 1615
True


In [None]:
with open("../../variable_length/category_file.json", 'w') as json:

    json.write("[\n")

    json.write("\t{\n")
    
    json.write("\t\t\"experiment\": \"22Mg\",\n")
    json.write("\t\t\"train\": [\n")

    for event in train[:-1]:
        if event['exp'] != "22Mg":
            continue
        json.write(f"\t\t\t\"{event['hash']}\",\n")
    if train[-1]['exp'] == "22Mg":
        json.write(f"\t\t\t\"{train[-1]['hash']}\"\n")
    json.write("\t\t],\n")

    json.write("\t\t\"val\": [\n")
    for event in val[:-1]:
        if event['exp'] != "22Mg":
            continue
        json.write(f"\t\t\t\"{event['hash']}\",\n")
    if val[-1]['exp'] == "22Mg":
        json.write(f"\t\t\t\"{val[-1]['hash']}\"\n")
    json.write("\t\t],\n")

    json.write("\t\t\"test\": [\n")
    for event in test[:-1]:
        if event['exp'] != "22Mg":
            continue
        json.write(f"\t\t\t\"{event['hash']}\",\n")
    if test[-1]['exp'] == "22Mg":
        json.write(f"\t\t\t\"{test[-1]['hash']}\"\n")
    json.write("\t\t]\n")

    json.write("\t},\n")

    json.write("\t{\n")
    json.write("\t\t\"experiment\": \"16O\",\n")
    json.write("\t\t\"train\": [\n")

    for event in train[:-1]:
        if event['exp'] != "16O":
            continue
        json.write(f"\t\t\t\"{event['hash']}\",\n")
    if train[-1]['exp'] == '16O':
        json.write(f"\t\t\t\"{train[-1]['hash']}\"\n")
    json.write("\t\t],\n")

    json.write("\t\t\"val\": [\n")
    for event in val[:-1]:
        if event['exp'] != "16O":
            continue
        json.write(f"\t\t\t\"{event['hash']}\",\n")
    if val[-1]['exp'] == '16O':
        json.write(f"\t\t\t\"{val[-1]['hash']}\"\n")
    json.write("\t\t],\n")

    json.write("\t\t\"test\": [\n")
    for event in test[:-1]:
        if event['exp'] != "16O":
            continue
        json.write(f"\t\t\t\"{event['hash']}\",\n")
    if test[-1]['exp'] == '16O':
        json.write(f"\t\t\t\"{test[-1]['hash']}\"\n")
    json.write("\t\t]\n")

    json.write("\t}\n")

    json.write("]")

In [18]:
import json
with open("./variable_length/category_file.json", 'r') as j:
    jason = json.load(j)

print(jason)

[{'experiment': '22Mg', 'train': ['c780bd55fc310243264c0c476b90120a', '90307d622ecd8e83fa43457b3e7c9558', 'e5f14c34ae354946e7c84deae645087a', '9953160d36d34b149820e911d28dfd3a', '93d33fc02be90918a8c95cefa5428df4', '0907df1c3ada20aee2f56ea060c68e1b', 'd8b8f17208a7da715770965b976540df', 'eb19a634009b915651ca96179b4b14d7', '3cdc01685996d32d29974a48303c0c1c', 'efae20333815566e393f737d057e3f7f', '2f4956433a69cb8f40e68675ee4847ee', '007bbe3a4a6d1ba7469b2409aac39e9f', '5a318c756ad991254f6baf8ede70db52', '830d6ffc8ba1118eb525aed1b7a2616a', 'c70db2b3d1771ae3cd9040b5f4648a61', '65da9b1ddcaff3ac8b36482d91da5c65', 'e4ece5616ccf454f89d05351b38c66f4', '5b33d853022dc5f5cecfd12dffe55e74', '873d38724739b2d6ba315f4d12324a09', '3f7403d217a8e757a7d4916f4a72e762', '98adbaa787793b0536945a59987d4709', 'f47854dc18a523d311d9e1d9bbdf3fc7', '69437a52623e5908e0cdbd8ace209666', '205cb0bec3e30a67db71a8c05db97f11', 'd42677cec3cb90d3b4962dc9cdff66d5', '0469bbb1897147760a28c451207af380', '982de87f5d205367c9531e566bf03