View flash-match data stored in Petastorm format (dataset written with PySpark)

In [None]:
import ROOT as rt
import numpy as np
from larcv import larcv
from larflow import larflow
import chart_studio as cs
import chart_studio.plotly as py
import plotly.graph_objects as go
from ctypes import c_int
# NOTE(review): loads larcv's Python utility layer; presumably required before any
# larcv <-> numpy conversion helpers are used. Nothing later in this view calls
# larcv directly -- confirm whether this is still needed.
larcv.load_pyutil()

In [None]:
# lardly: helper library used here to build plotly traces for the detector
import lardly
from lardly import DetectorOutline
# builds plotly traces for the optical-detector (opdet) data
from lardly.ubdl.ubplot_opdet import make_opdet_plot

# Line traces outlining the TPC, drawn in near-black
detdata = DetectorOutline()
detlines = detdata.getlines(color=(10, 10, 10))

# PARTICLE LABEL COLORS -- RGB triples, every channel in 0-255.
# Keys follow the ROI type enum in larcv/core/DataFormat/DataFormatTypes.h:
#     kROIUnknown=0, ///< LArbys
#     kROICosmic,    ///< Cosmics
#     kROIBNB,       ///< BNB
#     kROIEminus,    ///< Electron
#     kROIGamma,     ///< Gamma
#     kROIPizero,    ///< Pi0
#     kROIMuminus,   ///< Muon
#     kROIKminus,    ///< Kaon
#     kROIPiminus,   ///< Charged Pion
#     kROIProton,    ///< Proton
#     kROITypeMax    ///< enum element counter
ssnetcolor = {0: np.array((0, 0, 0)),        # kROIUnknown
              1: np.array((255, 0, 0)),      # kROICosmic (not used)
              2: np.array((0, 255, 0)),      # kROIBNB (not used)
              3: np.array((0, 0, 255)),      # kROIEminus (e-/e+)
              4: np.array((255, 0, 255)),    # kROIGamma
              5: np.array((0, 255, 255)),    # kROIPizero
              6: np.array((255, 255, 0)),    # kROIMuminus (mu-/mu+)
              # green was 300 (out of range); 255 is what a CSS rgb() clamp would
              # render, whereas a uint8 cast would have wrapped it to 44
              7: np.array((123, 255, 10)),   # kROIKminus (k+/k-)
              8: np.array((204, 204, 255)),  # kROIPiminus (pi+/pi-)
              9: np.array((255, 165, 0))}    # kROIProton

# Colors keyed by keypoint class index (class names per the comments below)
kpcolors = {0: np.array((255, 0, 0)),    # nu
            1: np.array((0, 255, 0)),    # track-start
            2: np.array((0, 0, 255)),    # track-end
            3: np.array((255, 255, 0)),  # shower
            4: np.array((0, 255, 255)),  # michel
            5: np.array((255, 0, 255))}  # delta


In [None]:
# Build the Petastorm-backed PyTorch data loader for the flash-match dataset
from flashmatchdata import make_dataloader

# Dataset location given as a URL; 'file://' selects a local directory
dataset_folder = 'file:///tmp/test_flash_dataset'

# A single pass, unshuffled, one entry per batch
loader = make_dataloader(
    dataset_folder,
    num_epochs=1,
    shuffle_rows=False,
    batch_size=1,
)
loader_iter = iter(loader)


In [None]:

# Pull the next entry from the loader. The loader uses batch_size=1, so every
# field below is indexed with [0] to take the single entry in the batch.
batch = next(loader_iter)
print(batch.keys())
match_label = 'Event %d, Match Index %d'%(batch['event'][0],batch['matchindex'][0])
print(match_label)

# Flash PE values, handed to lardly's opdet plotter below.
flashpe = batch['flashpe'][0]
print(flashpe.shape)

# Charge: voxel coordinates (scaled to cm below) and per-voxel features.
coord = batch['coord'][0]
feat  = batch['feat'][0]
# Feature column 2 sets the marker color of the charge points.
# NOTE(review): confirm what quantity feature column 2 holds.
featz = feat[:,2]

# Convert voxel indices to positions in cm.
voxel_size_cm = 5.0
# NOTE(review): presumably undoes a y-shift applied when the dataset was
# voxelized; confirm against the dataset writer.
y_offset_cm = 120.0
pos = coord*voxel_size_cm
pos[:,1] -= y_offset_cm

# Optical-detector traces plus a freshly built set of TPC outline traces.
opdet_traces = make_opdet_plot( flashpe )
detlines = detdata.getlines(color=(10,10,10))
plot_traces = detlines + opdet_traces

# PLOT CHARGE

plot_q = {
    "type":"scatter3d",
    "x": pos[:,0],
    "y": pos[:,1],
    "z": pos[:,2],
    "mode":"markers",
    "name":"Charge",
    "marker":{"color":featz,"size":2,"opacity":0.5,'colorscale':'Viridis'},
}
plot_traces += [plot_q]

# LAYOUT
# Shared styling applied to all three 3D axes.
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    # This cell shows flash PE + charge for one match, not particle instance
    # labels, so title the figure with the entry being displayed.
    title='Flash Match Data: ' + match_label,
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        # z is stretched 3x relative to x and y
        "aspectratio": {"x": 1, "y": 1, "z": 3},
        # view from the -x side, with +y pointing up on screen
        "camera": {"eye": {"x": -2, "y": 0.25, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=plot_traces, layout=layout)
fig.show()
