In [7]:
import sys, os, tqdm, glob, os.path as osp
import h5py, numpy as np, pandas as pd
import torch, torch_geometric as tg, multiprocessing as mp
from uuid import uuid4
if '/scratch' not in sys.path: sys.path.append('/scratch')

### Load graph from file

In [87]:
def get_graph(f, idx):
    """Return all graph nodes belonging to one event as a DataFrame.

    Parameters
    ----------
    f : mapping
        Open HDF5 file (or any mapping with the same layout) that holds an
        ``event_table`` group and a ``graph_table`` group. Both groups must
        provide ``run``, ``subrun`` and ``event`` datasets; the graph-table
        datasets are indexed as two-dimensional with the values in column 0.
    idx : int
        Row of the event table to look up.

    Returns
    -------
    pandas.DataFrame
        One row per graph node of the selected event and one column per
        dataset of ``graph_table``. The columns are stacked into a single
        NumPy array first, so they all share one common dtype.

    Raises
    ------
    IndexError
        If ``idx`` is beyond the number of events in the file.
    """
    # The event table contains one entry per event
    evt = f['event_table']
    # The group itself has no shape; the event count lives on its datasets
    n_events = evt['event'].shape[0]
    if idx >= n_events:
        raise IndexError(f'Graph {idx} larger than file size {n_events}')
    graph = f['graph_table']

    # Get event information
    run = evt['run'][idx].squeeze()
    subrun = evt['subrun'][idx].squeeze()
    event = evt['event'][idx].squeeze()

    # Pull out all graph nodes associated with that event
    cut = (graph['run'][:,0] == run) & (graph['subrun'][:,0] == subrun) & (graph['event'][:,0] == event)

    keys = list(graph.keys())
    return pd.DataFrame(np.array([graph[key][cut,0] for key in keys]).T, columns=keys)

### Define edges

In [88]:
def create_edges(df):
    edges = []
    for idx, node in df.iterrows():
        start = idx
        cut_wire = (df.wire - node.wire > 0) & (df.wire - node.wire <= 5)
        cut_time = (abs(df.time-node.time) < 50)
        end = df[cut_wire & cut_time].index[:]
        for e in end: edges.append((start, e))
    return np.array(edges).T

### Get truth

In [89]:
def get_truth(graph, edges):
    """Label each edge by whether both of its endpoints share a ``true_id``.

    Parameters
    ----------
    graph : pandas.DataFrame
        Node table with a ``true_id`` column; edge endpoints are looked up
        by index label.
    edges : array-like
        Edge index of shape ``(2, n_edges)``: start labels in row 0 and end
        labels in row 1. An array with no elements is accepted in any shape.

    Returns
    -------
    numpy.ndarray
        Boolean array of length ``n_edges``; ``True`` where the two endpoint
        nodes have equal ``true_id``.
    """
    edges = np.asarray(edges)
    # Covers a (2, 0) edge index as well as the 1-D array np.array([]) gives
    if edges.size == 0:
        return np.zeros(0, dtype=bool)

    true_id = graph.true_id
    id_in = true_id.loc[edges[0]].to_numpy()
    id_out = true_id.loc[edges[1]].to_numpy()
    return id_in == id_out

This function loops over the events in a single HDF5 input file and writes one PyTorch graph file per event and plane.

In [92]:
def process(file, out_dir='/data/hit2d/processed'):
    """Convert every event of one HDF5 file into per-plane graph files.

    For each row of the file's event table, the event's nodes are loaded,
    split by ``plane`` (values 0, 1 and 2), and each plane is turned into a
    ``torch_geometric`` ``Data`` object that is saved under a random UUID
    file name.

    Parameters
    ----------
    file : str
        Path of the HDF5 file to read.
    out_dir : str, optional
        Directory that receives the ``.pt`` files. It is created if it does
        not exist yet. Defaults to ``/data/hit2d/processed``.

    Returns
    -------
    None
    """
    features = ['plane', 'wire', 'time',
                'tpc', 'rawplane', 'rawwire', 'rawtime',
                'integral', 'rms']

    # Fail early on an unusable output location instead of at the first save
    os.makedirs(out_dir, exist_ok=True)

    with h5py.File(file, 'r') as f:
        n_events = f['event_table']['event'].shape[0]
        for idx in range(n_events):
            graph = get_graph(f, idx)
            for plane in range(3):
                # Re-number the nodes so edge labels are positions within the plane
                graph_plane = graph[graph.plane == plane].reset_index(drop=True)
                edges = create_edges(graph_plane)
                truth = get_truth(graph_plane, edges)

                x = graph_plane.loc[:, features].values
                true_id = graph_plane['true_id'].values

                # NOTE(review): the fields are stored as NumPy arrays, not torch
                # tensors - confirm this is what the downstream loader expects.
                data = tg.data.Data(x=x, edge_index=edges, y=truth, true_id=true_id)

                # Random file names: planes with very few nodes are saved as well
                torch.save(data, osp.join(out_dir, f'{uuid4()}.pt'))

In [None]:
# Collect the input HDF5 files from both source directories
nonswap = glob.glob('/data/hit2d/nonswap/*.h5')
fluxswap = glob.glob('/data/hit2d/fluxswap/*.h5')
files = [*nonswap, *fluxswap]

# Convert the files in parallel, one input file per task
with mp.Pool(processes=50) as pool:
    pool.map(process, files)

### Remove low-hit graphs

In [8]:
def check_file(filename, min_hits=5):
    """Delete a processed graph file that has too few nodes.

    Parameters
    ----------
    filename : str
        Path of a file written with ``torch.save``; the loaded object must
        expose its node features as ``x``.
    min_hits : int, optional
        Files whose ``x`` has fewer rows than this are removed.
        Defaults to 5.

    Returns
    -------
    None
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # files produced by this notebook, never at untrusted input.
    data = torch.load(filename)
    if data.x.shape[0] < min_hits:
        os.remove(filename)

# Run the low-hit check over every processed graph file, in parallel
processed_files = glob.glob('/data/hit2d/processed/*.pt')
with mp.Pool(processes=50) as p:
    p.map(check_file, processed_files)