In [2]:
import sys, os, tqdm, glob, os.path as osp
import h5py, numpy as np, pandas as pd, numba
import torch, torch_geometric as tg, multiprocessing as mp
from uuid import uuid4
if '/scratch' not in sys.path: sys.path.append('/scratch')

### Load graph from file

In [43]:
# @timer
def get_graph(f, idx):
    '''Return every row of the graph table that belongs to one event.

    Parameters
    ----------
    f : open HDF5 file (or any mapping with the same layout)
        Must provide 'event_table' and 'graph_table', each a mapping from
        column name to a 2D dataset with one row per entry.
    idx : int
        Row of the event table identifying the event to extract.

    Returns
    -------
    pandas.DataFrame
        One row per graph-table entry of the event, one column per
        graph-table key.

    Raises
    ------
    IndexError
        If ``idx`` is beyond the last row of the event table.
    '''

    # The event table contains one entry per event
    evt = f['event_table']
    n_events = evt['event'].shape[0]
    if idx >= n_events:
        # The table itself is only indexed by column name here, so the row
        # count has to come from one of its columns rather than from the table.
        raise IndexError(f'Graph {idx} larger than file size {n_events}')
    graph = f['graph_table']

    # Get event information
    run = evt['run'][idx].squeeze()
    subrun = evt['subrun'][idx].squeeze()
    event = evt['event'][idx].squeeze()

    # Pull out all graph nodes associated with that event
    cut = (graph['run'][:,0] == run) & (graph['subrun'][:,0] == subrun) & (graph['event'][:,0] == event)

    # Stacking the columns into a single array gives them one common dtype.
    keys = list(graph.keys())
    return pd.DataFrame(np.array([graph[key][cut,0] for key in keys]).T, columns=keys)

In [44]:
# @timer
def get_particle_tree(f, idx):
    '''Return every row of the particle table that belongs to one event.

    Parameters
    ----------
    f : open HDF5 file (or any mapping with the same layout)
        Must provide 'event_table' and 'particle_table', each a mapping from
        column name to a 2D dataset with one row per entry.
    idx : int
        Row of the event table identifying the event to extract.

    Returns
    -------
    pandas.DataFrame
        One row per particle of the event, one column per particle-table key.

    Raises
    ------
    IndexError
        If ``idx`` is beyond the last row of the event table.
    '''

    # The event table contains one entry per event
    evt = f['event_table']
    n_events = evt['event'].shape[0]
    if idx >= n_events:
        # The table itself is only indexed by column name here, so the row
        # count has to come from one of its columns rather than from the table.
        raise IndexError(f'Graph {idx} larger than file size {n_events}')
    part = f['particle_table']

    # Get event information
    run = evt['run'][idx].squeeze()
    subrun = evt['subrun'][idx].squeeze()
    event = evt['event'][idx].squeeze()

    # Pull out all particles associated with that event
    cut = (part['run'][:,0] == run) & (part['subrun'][:,0] == subrun) & (part['event'][:,0] == event)

    # Stacking the columns into a single array gives them one common dtype.
    keys = list(part.keys())
    return pd.DataFrame(np.array([part[key][cut,0] for key in keys]).T, columns=keys)

### Define edges

In [59]:
@numba.jit
def _create_edges(wire, time):
    edges = []
    for start, (ws, ts) in enumerate(zip(wire, time)):
        for end, (we, te) in enumerate(zip(wire, time)):
            if start == end: continue # no self loops
            elif ws >= we or we-ws > 5: continue # within 5 wires
            elif abs(te-ts) > 50: continue # within 50 time ticks
            edges.append((start, end))
    return np.array(edges).T

def create_edges(df):
    '''Build the edge index for the hits in ``df`` from its wire and time columns.'''
    # Hand plain Python lists to the edge builder, wire first and time second.
    wire, time = (df[column].tolist() for column in ('wire', 'time'))
    return _create_edges(wire, time)

### Manipulate dataframes and get truth

In [60]:
# @timer
def get_primaries(df):
    '''Attach the primary ancestor and a truth class to every particle.

    Adds two columns to ``df`` in place and returns it:

    * ``primary`` - id of the particle reached by following ``parent_id``
      upwards until a particle whose parent is 0 is found. If a parent id is
      not present in the table, the walk stops at the last particle that is.
    * ``truth`` - class derived from the type of that primary:
      1 for |type| == 11, 2 for |type| == 13, 3 for anything else.

    Parameters
    ----------
    df : pandas.DataFrame
        Particle table with ``id``, ``parent_id`` and ``type`` columns.

    Returns
    -------
    pandas.DataFrame
        The same frame, with ``primary`` and ``truth`` columns added.

    Raises
    ------
    ValueError
        If the parent links form a cycle.
    '''
    particle_ids = df['id'].to_list()
    parent_of = dict(zip(particle_ids, df['parent_id'].to_list()))
    type_of = dict(zip(particle_ids, df['type'].to_list()))

    # Every particle on an already-walked chain shares the same primary, so
    # remember the answers instead of re-walking the hierarchy each time.
    primary_of = {}

    def find_primary(particle_id):
        chain = []
        visited = set()
        current = particle_id
        while current not in primary_of:
            if current in visited:
                raise ValueError(f'Cycle in particle hierarchy at id {current}')
            visited.add(current)
            chain.append(current)
            parent = parent_of[current]
            # A parent of 0 marks a primary; a parent that is not in the table
            # cannot be followed any further, so stop here as well.
            if parent == 0 or parent not in parent_of:
                primary_of[current] = current
                break
            current = parent
        root = primary_of[current]
        for walked in chain:
            primary_of[walked] = root
        return root

    # Resolve once per row so the result always lines up with the frame.
    primaries = [find_primary(particle_id) for particle_id in particle_ids]
    df['primary'] = primaries

    def truth_class(particle_type):
        code = abs(particle_type)
        if code == 11: return 1  # EM shower
        if code == 13: return 2  # muon track
        return 3                 # Everything else

    # Go from primary ID to primary type, then to the truth class
    df['truth'] = [truth_class(type_of[primary]) for primary in primaries]

    return df

In [77]:
# @numba.jit
def _add_to_graph(true_id, truth, primary, graph_id):
    '''Take truth information from particle-wise dataframe, and add to hit-wise dataframe'''
    truth_dict = { key: val for key, val in zip(true_id, truth) }
    primary_dict = { key: val for key, val in zip(true_id, primary) }
    graph_truth = [ truth_dict[id] for id in graph_id ]
    graph_primary = [ primary_dict[id] for id in graph_id ]
    return graph_truth, graph_primary

# @timer
def add_to_graph(df_graph, df_part):
    '''Copy per-particle truth and primary labels onto the hit-wise dataframe.

    The hit frame is modified in place (``truth`` and ``primary`` columns are
    written) and also returned.
    '''
    hit_truth, hit_primary = _add_to_graph(
        df_part['id'].tolist(),
        df_part['truth'].tolist(),
        df_part['primary'].tolist(),
        df_graph['true_id'].tolist(),
    )
    df_graph['truth'] = hit_truth
    df_graph['primary'] = hit_primary
    return df_graph

In [70]:
# @timer
def get_truth(graph, part, edges):
    '''Label every edge with the class of the object it belongs to, or 0.

    An edge whose start hit is an EM shower hit (truth == 1) is a true edge
    when both hits share the same primary; any other edge is a true edge when
    both hits share the same true particle id. True edges get the truth class
    of the start hit, all others get 0.

    Not compiled with numba: the function takes DataFrames, which numba has
    no native support for.

    Parameters
    ----------
    graph : pandas.DataFrame
        Hits with ``truth``, ``true_id`` and ``primary`` columns; edge indices
        refer to row positions of this frame.
    part : unused
        Kept so existing callers keep working.
    edges : array of shape (2, n_edges)
        Start indices in row 0, end indices in row 1.

    Returns
    -------
    numpy.ndarray
        One label per edge.
    '''
    hit_truth = graph['truth'].to_numpy()
    hit_true_id = graph['true_id'].to_numpy()
    hit_primary = graph['primary'].to_numpy()

    edges = np.asarray(edges)
    if edges.size == 0:
        # Nothing to label; this also covers an edge array that is not 2D.
        return np.array([])

    e_in, e_out = edges[0], edges[1]
    is_shower = hit_truth[e_in] == 1
    same_primary = hit_primary[e_in] == hit_primary[e_out]
    same_particle = hit_true_id[e_in] == hit_true_id[e_out]

    # Showers are grouped by primary, everything else by individual particle.
    is_true_edge = np.where(is_shower, same_primary, same_particle)
    return np.where(is_true_edge, hit_truth[e_in], 0)

### Processing a single file
This function loops over every event in a single HDF5 file and, for each plane with enough hits, saves one PyTorch Geometric graph file.

In [49]:
def timer(func):
    '''Decorator that prints how long each call to ``func`` takes.

    The wrapped function keeps its name and docstring and returns whatever
    ``func`` returns. Nothing is printed if ``func`` raises.
    '''
    from functools import wraps
    from timeit import default_timer

    @wraps(func)  # keep __name__/__doc__ of the decorated function
    def wrapper_timer(*args, **kwargs):
        start = default_timer()
        value = func(*args, **kwargs)
        print('executing', func.__name__, 'took', default_timer()-start, 'seconds')
        return value
    return wrapper_timer

In [81]:
# @timer
def process(file, out_dir='/data/hit2d/processed-flav-shower', min_hits=50):
    '''Convert every event of one HDF5 file into per-plane graph files.

    For each event the hits and particles are loaded, truth labels are
    attached to the hits, and for each of the three planes with at least
    ``min_hits`` hits a ``torch_geometric`` ``Data`` object (node features,
    edge index, edge labels) is saved under a random UUID file name.

    Parameters
    ----------
    file : str
        Path of the input HDF5 file.
    out_dir : str
        Directory the ``.pt`` files are written to; created if missing.
    min_hits : int
        Planes with fewer hits than this are skipped.
    '''
    # Saving fails if the target directory does not exist. exist_ok keeps this
    # safe when several worker processes start at the same time.
    os.makedirs(out_dir, exist_ok=True)

    feature_columns = ['plane', 'wire', 'time',
                       'tpc', 'rawplane', 'rawwire', 'rawtime',
                       'integral', 'rms']

    with h5py.File(file, 'r') as f:
        n_events = f['event_table']['event'].shape[0]
        for idx in range(n_events):
            graph = get_graph(f, idx)
            part = get_particle_tree(f, idx)
            part = get_primaries(part)
            graph = add_to_graph(graph, part)

            for plane in range(3):
                # Edge indices are row positions, so renumber hits per plane.
                graph_plane = graph[(graph.plane==plane)].reset_index(drop=True)
                if graph_plane.shape[0] < min_hits: continue

                # Edges come from the wire/time proximity rule in create_edges.
                edges = create_edges(graph_plane)
                truth = get_truth(graph_plane, part, edges)

                x = graph_plane.loc[:, feature_columns].values

                # NOTE(review): x, edge_index and y are stored as numpy arrays,
                # not torch tensors - confirm that is what consumers expect.
                data = tg.data.Data(x=x, edge_index=edges, y=truth)

                torch.save(data, osp.join(out_dir, f'{uuid4()}.pt'))

### Processing the dataset

We collect the paths of all the HDF5 files and map them over a pool of worker processes, each of which runs `process` on one file at a time.

In [83]:
# Gather the input files of both samples into one work list
nonswap = glob.glob('/data/hit2d/nonswap/*.h5')
fluxswap = glob.glob('/data/hit2d/fluxswap/*.h5')
files = [*nonswap, *fluxswap]

# Run process() on every file using a pool of 50 worker processes
with mp.Pool(processes=50) as pool:
    pool.map(process, files)
# process(fluxswap[0])

KeyError: 0.0

In [80]:
!ls /data/hit2d/processed-flav-shower/*.pt | wc -l

ls: cannot access '/data/hit2d/processed-flav-shower/*.pt': No such file or directory
0


### Workspace
What's in the HDF5 file? Is there a true particle table, and if so, how is it structured?

In [14]:
# Open the first nonswap file found and look at what the particle table stores.
nonswap = glob.glob('/data/hit2d/nonswap/*.h5')
name = nonswap[0]
with h5py.File(name, 'r') as file:
    # Read the whole 'type' column of the particle table and print it
    print(file['particle_table']['type'][()])
#     print(file['particle_table']['trie_id']

[[  13]
 [2112]
 [2212]
 ...
 [  11]
 [  11]
 [  11]]


In [None]:
import matplotlib.pyplot as plt
import matplotlib.collections as mc

# One colour per edge class; class 0 is drawn in light grey.
colours = ['gainsboro', 'red', 'green', 'blue', 'yellow' ]

# Saving a figure fails if the target directory is missing, so create it first.
plot_dir = 'plots/flav-shower-delaunay'
os.makedirs(plot_dir, exist_ok=True)

for j, file in enumerate(glob.glob('/data/hit2d/processed-flav-shower-delaunay/*.pt')):

    data = torch.load(file)
    # Columns 1, 2 and 3 of x are read as wire, time and tpc
    wire = data.x[:,1]
    time = data.x[:,2]
    tpc = data.x[:,3]

    # One line segment per edge, from the start hit to the end hit
    lines = [ [ [ wire[edge[0]], time[edge[0]] ], [ wire[edge[1]], time[edge[1]] ] ] for edge in data.edge_index.T ]

    # Edge plot: group the segments by edge class
    lines_class = [ [] for _ in colours ]
    for l, y in zip(lines, data.y):
        # int() turns the label into a valid list index whatever scalar type it has
        lines_class[int(y)].append(l)
    lcs = []
    for i in range(len(colours)): lcs.append(mc.LineCollection(lines_class[i], colors=colours[i], linewidths=2, zorder=1))
    fig, ax = plt.subplots(figsize=[16,9])
    for lc in lcs: ax.add_collection(lc)
    ax.autoscale()
    plt.tight_layout()
    plt.savefig(f'{plot_dir}/evt{j+1}_edges.png')
    # Close this figure explicitly so figures do not pile up across the loop
    plt.close(fig)

In [None]:
!rm /data/hit2d/processed-flav/*.pt
!rm plots/flav/*

In [None]:
!ls /data/hit2d/processed-flav | wc -l

In [None]:
!mkdir -p /data/hit2d/processed-flav-debug

In [None]:
# Delete every saved graph whose label array y has more than one dimension.
# NOTE(review): this globs 'processed-flav-showers' (plural), while process()
# writes to 'processed-flav-shower' - confirm the intended directory.
for name in glob.glob('/data/hit2d/processed-flav-showers/*.pt'):
    data = torch.load(name)
    y_shape = data.y.shape
    if len(y_shape) <= 1:
        continue
    print('shape of y is', y_shape)
    os.remove(name)