In [8]:
import sys, os, tqdm, glob, os.path as osp
import h5py, numpy as np, pandas as pd
import torch, torch_geometric as tg, multiprocessing as mp
from uuid import uuid4
if '/scratch' not in sys.path: sys.path.append('/scratch')

### Load graph from file

In [1]:
def get_graph(f, idx):
    """Return the graph-table rows that belong to one event.

    Args:
        f: Open HDF5 file (or any mapping with the same layout) containing an
            'event_table' and a 'graph_table' group. Every column of both
            groups is read through its first column, i.e. as ``data[row, 0]``.
        idx: Row number of the event inside 'event_table'.

    Returns:
        pandas.DataFrame with one column per key of 'graph_table' and one row
        per entry whose (run, subrun, event) matches the selected event.

    Raises:
        IndexError: if ``idx`` is past the last row of the event table.
    """
    # The event table contains one entry per event
    evt = f['event_table']
    # The group itself has no shape, so the size must come from one of its columns
    n_events = evt['event'].shape[0]
    if idx >= n_events:
        raise IndexError(f'Graph {idx} larger than file size {n_events}')
    graph = f['graph_table']

    # Get event information
    run = evt['run'][idx].squeeze()
    subrun = evt['subrun'][idx].squeeze()
    event = evt['event'][idx].squeeze()

    # Pull out all graph nodes associated with that event
    cut = (graph['run'][:, 0] == run) & (graph['subrun'][:, 0] == subrun) & (graph['event'][:, 0] == event)

    columns = list(graph.keys())
    # Stacking the columns into one array coerces them to a common dtype
    values = np.array([graph[key][cut, 0] for key in columns]).T
    return pd.DataFrame(values, columns=columns)

In [2]:
def get_particle_tree(f, idx):
    """Return the particle-table rows that belong to one event.

    Args:
        f: Open HDF5 file (or any mapping with the same layout) containing an
            'event_table' and a 'particle_table' group. Every column of both
            groups is read through its first column, i.e. as ``data[row, 0]``.
        idx: Row number of the event inside 'event_table'.

    Returns:
        pandas.DataFrame with one column per key of 'particle_table' and one
        row per particle whose (run, subrun, event) matches the selected event.

    Raises:
        IndexError: if ``idx`` is past the last row of the event table.
    """
    # The event table contains one entry per event
    evt = f['event_table']
    # The group itself has no shape, so the size must come from one of its columns
    n_events = evt['event'].shape[0]
    if idx >= n_events:
        raise IndexError(f'Event {idx} larger than file size {n_events}')
    part = f['particle_table']

    # Get event information
    run = evt['run'][idx].squeeze()
    subrun = evt['subrun'][idx].squeeze()
    event = evt['event'][idx].squeeze()

    # Pull out all particles associated with that event
    cut = (part['run'][:, 0] == run) & (part['subrun'][:, 0] == subrun) & (part['event'][:, 0] == event)

    columns = list(part.keys())
    # Stacking the columns into one array coerces them to a common dtype
    values = np.array([part[key][cut, 0] for key in columns]).T
    return pd.DataFrame(values, columns=columns)

### Define edges

In [3]:
def create_edges(df, max_wire_gap=5, max_time_gap=50):
    """Build directed edges between nearby hits.

    An edge (start, end) is created when the end hit lies on a strictly higher
    wire that is at most ``max_wire_gap`` wires away, and the two hits differ
    in time by strictly less than ``max_time_gap``.

    Args:
        df: DataFrame with 'wire' and 'time' columns. Edge endpoints are the
            index labels of ``df``.
        max_wire_gap: Largest allowed wire difference (inclusive).
        max_time_gap: Exclusive upper bound on the absolute time difference.

    Returns:
        numpy array of shape (2, n_edges); row 0 holds the start labels and
        row 1 the end labels. With no edges the result has shape (2, 0).
    """
    wire = df['wire'].to_numpy()
    time = df['time'].to_numpy()
    labels = df.index.to_numpy()

    starts, ends = [], []
    for pos in range(len(labels)):
        wire_gap = wire - wire[pos]
        near = (wire_gap > 0) & (wire_gap <= max_wire_gap) & (np.abs(time - time[pos]) < max_time_gap)
        targets = labels[near]
        starts.append(np.full(targets.shape, labels[pos]))
        ends.append(targets)

    # An empty frame gives nothing to concatenate; still return a well-formed edge list
    if not starts:
        return np.empty((2, 0), dtype=np.int64)
    return np.vstack([np.concatenate(starts), np.concatenate(ends)])

### Get truth

In [4]:
def get_truth(graph, part, edges):
    """Compute a truth label for every edge.

    An edge whose two hits carry the same ``true_id`` receives the ``truth``
    value of that particle; an edge joining hits with different ``true_id``
    values receives the label 3.

    Args:
        graph: DataFrame of hits with a 'true_id' column, indexed by the
            labels used in ``edges``.
        part: DataFrame of particles with 'id' and 'truth' columns.
        edges: Array of shape (2, n_edges) holding start and end hit labels.

    Returns:
        numpy array with one label per edge, in edge order.

    Raises:
        KeyError: if both hits of an edge share a ``true_id`` that has no
            entry in ``part``.
    """
    # Build the id -> truth lookup once instead of filtering the table per edge
    label_of = dict(zip(part['id'], part['truth']))
    true_id = graph.true_id

    truth = []
    for e_in, e_out in edges.T:
        id_in = true_id[e_in]
        id_out = true_id[e_out]
        if id_in != id_out:
            # Hits from different particles: false edge
            truth.append(3)
            continue
        if id_in not in label_of:
            raise KeyError(f'No particle with id {id_in} in particle table')
        truth.append(label_of[id_in])

    return np.array(truth).T

### Testing

We want a mapping from each primary particle to its type, and then to apply that type to every particle descended from that primary, so that each hit can be labelled by the type of its primary.

In [5]:
def get_primaries(df):
    """Tag every particle with its primary ancestor and a truth class.

    A primary is a particle whose ``parent_id`` is 0. Each particle is walked
    back along ``parent_id`` until a primary is reached. The truth class is
    derived from the absolute value of the primary's ``type``:
    11 -> 0 (EM shower), 13 -> 1 (muon track), anything else -> 2.

    Args:
        df: DataFrame with 'id', 'parent_id' and 'type' columns. It is
            modified in place: 'primary' and 'truth' columns are added.

    Returns:
        The same DataFrame object, with the two new columns.

    Raises:
        ValueError: if a parent chain refers to an id missing from the table,
            or if the parent chain loops back on itself.
    """
    # Dict lookups avoid re-filtering the whole frame at every step of the walk
    parent_of = dict(zip(df['id'], df['parent_id']))
    type_of = dict(zip(df['id'], df['type']))

    def _find_primary(particle_id):
        # Walk back along the particle hierarchy until a primary is found
        visited = set()
        current = particle_id
        while True:
            if current not in parent_of:
                raise ValueError(f'Particle {current} is not in the particle table')
            if current in visited:
                raise ValueError(f'Cyclic parent chain at particle {current}')
            visited.add(current)
            parent = parent_of[current]
            # If we've found a primary then quit
            if parent == 0:
                return current
            current = parent

    # Get the primary particle ID for each particle in the hierarchy
    primaries = [_find_primary(particle_id) for particle_id in df['id']]
    df['primary'] = primaries

    # Go from primary ID to primary type, then to the truth class
    types = []
    for primary in primaries:
        code = abs(type_of[primary])
        if code == 11:
            types.append(0)  # EM shower
        elif code == 13:
            types.append(1)  # muon track
        else:
            types.append(2)  # Everything else

    df['truth'] = types

    return df

### Processing a single file
This function loops over a single HDF5 file and processes it into input PyTorch files.

In [6]:
def process(file, out_dir='/data/hit2d/processed-flav', min_hits=50):
    """Convert every event of one HDF5 file into per-plane graph files.

    For each event and each of the planes 0, 1 and 2, the hits of that plane
    are turned into node features, edges and edge truth labels, wrapped in a
    ``tg.data.Data`` object and saved with ``torch.save`` under a random
    UUID file name.

    Args:
        file: Path of the input HDF5 file.
        out_dir: Directory that receives the ``.pt`` files; created if missing.
        min_hits: Planes with fewer hits than this are skipped.

    Returns:
        None. The results are the files written to ``out_dir``.
    """
    feature_columns = ['plane', 'wire', 'time',
                       'tpc', 'rawplane', 'rawwire', 'rawtime',
                       'integral', 'rms']
    os.makedirs(out_dir, exist_ok=True)

    with h5py.File(file, 'r') as f:
        for idx in range(f['event_table']['event'].shape[0]):
            graph = get_graph(f, idx)
            part = get_particle_tree(f, idx)

            part = get_primaries(part)

            for plane in range(3):
                graph_plane = graph[(graph.plane == plane)].reset_index(drop=True)

                # Apply the hit cut before the expensive edge and truth construction
                if graph_plane.shape[0] < min_hits: continue

                edges = create_edges(graph_plane)
                truth = get_truth(graph_plane, part, edges)

                x = graph_plane.loc[:, feature_columns].values

                # NOTE(review): the arrays are stored as NumPy arrays, not torch
                # tensors; confirm the downstream loader expects that.
                data = {'x': x, 'edge_index': edges, 'y': truth}
                data = tg.data.Data(**data)

                torch.save(data, osp.join(out_dir, f'{uuid4()}.pt'))

### Processing the dataset
We get a list of all the H5 files, and then map them onto a pool of worker processes. The pool call is currently commented out, so only the first fluxswap file is processed.

In [9]:
# Collect the input HDF5 files from the nonswap and fluxswap directories
nonswap = glob.glob('/data/hit2d/nonswap/*.h5')
fluxswap = glob.glob('/data/hit2d/fluxswap/*.h5')
files = nonswap + fluxswap

# Full parallel run over every file (currently disabled):
# with mp.Pool(processes=50) as pool: pool.map(process, files)
# Process only the first fluxswap file; raises IndexError if the glob matched nothing
process(fluxswap[0])

KeyboardInterrupt: 

### Remove low-hit graphs

In [8]:
def check_file(filename):
    """Delete a saved graph file when it holds fewer than 5 nodes.

    Args:
        filename: Path of a file readable by ``torch.load`` whose loaded
            object exposes an ``x`` attribute with a ``shape``.
    """
    n_nodes = torch.load(filename).x.shape[0]
    if n_nodes >= 5:
        return
    os.remove(filename)

# NOTE(review): process() saves into /data/hit2d/processed-flav/, but this scans
# /data/hit2d/processed/ — confirm which directory is intended before running,
# since check_file deletes files.
processed_files = glob.glob('/data/hit2d/processed/*.pt')
# Run check_file over every matched file using a pool of 50 worker processes
with mp.Pool(processes=50) as p: p.map(check_file, processed_files)

### Workspace
What's in the HDF5 file? Is there a true particle table, and if so, how is it structured?

In [11]:
# Open the first nonswap file and list the columns of its particle table
nonswap = glob.glob('/data/hit2d/nonswap/*.h5')
# Raises IndexError if the glob matched nothing
name = nonswap[0]
with h5py.File(name, 'r') as file:
    print(file['particle_table'].keys())

<KeysViewHDF5 ['end_process', 'end_x', 'end_y', 'end_z', 'event', 'id', 'momentum', 'parent_id', 'run', 'start_process', 'start_x', 'start_y', 'start_z', 'subrun', 'type']>
