In [None]:
from __future__ import print_function
import os,sys
import torch
import ROOT as rt
from larlite import larlite
from ROOT import larutil
import chart_studio.plotly as py
import plotly.graph_objects as go
import numpy as np
import lardly
# Add the parent directory to the module search path before importing larmatch_dataset
# (presumably that is where the local larmatch_dataset.py lives -- confirm).
sys.path.append("../")
import larmatch_dataset
from larmatch_dataset import larmatchDataset
# Last expression of the cell: shows the file path of the larmatch_dataset module
# that was actually imported, as a check on which copy is in use.
larmatch_dataset.__file__

In [None]:
# LOAD THE FILE/DATASET LOADER
# Open the test ROOT file as a larmatchDataset and wrap it in a torch DataLoader.
# Later cells read truth keys (e.g. 'larmatch_truth') from the batch, hence load_truth=True.
BATCH_SIZE = 1

test = larmatchDataset(
    filelist=["../test.root"],
    load_truth=True,
)
nentries = len(test)
print("NENTRIES: ", nentries)

# The dataset supplies its own collate function for assembling a batch.
loader = torch.utils.data.DataLoader(
    test,
    batch_size=BATCH_SIZE,
    collate_fn=larmatchDataset.collate_fn,
)

In [None]:
# SETUP THE GEOMETRY
# Selects the detector in larlite, then fetches the geometry and detector-properties
# objects (used when building 3D positions) and the detector outline traces that the
# plotting cells append to every figure.

# Set the detector in larlite
# (choose the detector by changing which `detid` line is uncommented)
#detid = larlite.geo.kICARUS
detid = larlite.geo.kMicroBooNE
#detid = larlite.geo.kSBND
larutil.LArUtilConfig.SetDetector(detid)

# Get geometry class
# NOTE(review): these are fetched after SetDetector; presumably that order is
# required for them to describe the chosen detector -- confirm.
geo = larlite.larutil.Geometry.GetME()
detp = larutil.DetectorProperties.GetME()

# Get detector outlines from lardly
# NOTE(review): this import lives here rather than in the import cell at the top.
from lardly.detectors.getdetectoroutlines import DetectorOutlineFromLarlite
detoutline = DetectorOutlineFromLarlite(detid)
# detlines is concatenated with the scatter traces (sp_plots+detlines) in each plot cell.
detlines = detoutline.getlines()

In [None]:
# DEFINE PARTICLE ID COLORS
# One row per particle-ID label: (label, display name, RGB color).
# Both lookup dicts are derived from this single table so a label's name and
# color cannot drift apart.
# NOTE(review): the earlier color-table comments called label 1 "Cosmic" and
# label 2 "BNB", while the name table maps them to "delta" and "nolabel" --
# confirm which convention the truth labels actually follow.
_particle_rows = (
    (0, "nolabel",  (0, 0, 0)),
    (1, "delta",    (255, 125, 50)),
    (2, "nolabel",  (0, 0, 0)),
    (3, "electron", (255, 0, 0)),
    (4, "gamma",    (0, 255, 0)),
    (5, "pi0",      (0, 125, 125)),
    (6, "muon",     (155, 0, 155)),
    (7, "kaon",     (255, 255, 0)),
    (8, "pion",     (255, 165, 0)),
    (9, "proton",   (0, 0, 255)),
)

# label -> (r, g, b) tuple with 0-255 components
particle_id_color = {label: rgb for label, _name, rgb in _particle_rows}

# label -> short display name
particle_id_name = {label: name for label, name, _rgb in _particle_rows}

del _particle_rows  # keep the notebook namespace limited to the two lookup dicts

# RGB colors, one row per class named in the row comments, stored as float64.
kp_color_array = np.array(
    [
        [255,   0,   0],  # nu
        [  0, 255,   0],  # track-start
        [  0,   0, 255],  # track-end
        [255,   0, 255],  # shower
        [  0, 255, 255],  # shower-michel
        [255, 255,   0],  # shower-delta
    ],
    dtype=np.float64,
)

In [None]:
# Get some data
# Pull the first batch from the loader and print a summary of every entry in it:
# the available keys, then the shape (numpy arrays) or the type (anything else).
batch = next(iter(loader))
for ib, data in enumerate(batch):
    print("BATCH[%d]"%(ib))
    print(" keys: ", data.keys())
    for name, d in data.items():
        is_array = type(d) is np.ndarray
        tag = "-[array]: " if is_array else "-[non-array]: "
        info = d.shape if is_array else type(d)
        print("  ", name, tag, info)
    # Compare the raw coord_0 row count with its number of distinct rows.
    coord0 = data['coord_0']
    print(coord0.shape)
    print(np.unique(coord0, axis=0).shape)

In [None]:
# Get 3D positions
# Fills `pos`, an (N,3) array with one row per entry of batch[0]['matchtriplet_v'].
# Columns 1 and 2 come from geo.ChannelsIntersect; column 0 is overwritten with the
# value returned by detp.ConvertTicksToX.
pos = np.zeros( (batch[0]['matchtriplet_v'].shape[0],3) )
# Scratch vector handed to geo.ChannelsIntersect on every iteration.
vec = rt.TVector3(0,0,0)
for i in range( pos.shape[0] ):
    # trip[0], trip[1], trip[2] are used as row indices into coord_0, coord_1, coord_2.
    trip = batch[0]['matchtriplet_v'][i]
    # Column 1 of each coord array is what gets passed on to PlaneWireToChannel
    # (presumably the wire number on that plane -- confirm).
    w1 = batch[0]['coord_0'][trip[0],1]
    w2 = batch[0]['coord_1'][trip[1],1]
    w3 = batch[0]['coord_2'][trip[2],1]
    # NOTE(review): 6 and 2400 are unexplained constants; they look like a row->tick
    # scale factor and offset applied to column 0 of coord_0 -- confirm and name them.
    tick = batch[0]['coord_0'][trip[0],0]*6+2400
    #print(w1,",",w2,",",w3)
    ch1 = geo.PlaneWireToChannel(int(w1),0,0,0 )
    ch2 = geo.PlaneWireToChannel(int(w2),1,0,0 )
    ch3 = geo.PlaneWireToChannel(int(w3),2,0,0 )
    # NOTE(review): only ch1 and ch2 are used; ch3 is computed but never read, and
    # the return value of ChannelsIntersect is ignored here.
    geo.ChannelsIntersect(ch1,ch2,vec)
    for v in range(3):
        pos[i,v] = vec[v]
    # Replace the first component with the value derived from the tick.
    pos[i,0] = detp.ConvertTicksToX(tick,0,0,0)
print("made position array: ",pos.shape)

In [None]:
# PLOT SPACEPOINT TRUTH
# 3D scatter of the spacepoints, colored by their 'larmatch_truth' value,
# drawn together with the detector outline.
TRUE_ONLY=False  # True: draw only the points whose truth value equals 1

truth = batch[0]['larmatch_truth']
if not TRUE_ONLY:
    xpos = pos
else:
    xpos = pos[truth==1,:]
    truth = batch[0]['larmatch_truth'][truth==1]

plot = go.Scatter3d(
    x=xpos[:,0],
    y=xpos[:,1],
    z=xpos[:,2],
    mode="markers",
    name="",
    marker=dict(size=1.0, opacity=0.5, color=truth, colorscale="Viridis"),
)
sp_plots = [plot]

# Shared styling applied to all three scene axes.
axis_template = dict(
    showbackground=True,
    backgroundcolor="rgba(100, 100, 100,0.5)",
    gridcolor="rgb(50, 50, 50)",
    zerolinecolor="rgb(0, 0, 0)",
)

layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene=dict(
        xaxis=axis_template,
        yaxis=axis_template,
        zaxis=axis_template,
        aspectratio=dict(x=1, y=1, z=2),
        camera=dict(
            eye=dict(x=-3, y=0.1, z=0.0),
            center=dict(x=0, y=0, z=0),
            up=dict(x=0, y=1, z=0),
        ),
        annotations=[],
    ),
)

# Scatter trace first, then the detector outline traces.
fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT SPACEPOINT WEIGHTS
# 3D scatter of the spacepoints, colored by their 'larmatch_weight' value,
# drawn together with the detector outline.
TRUE_ONLY=True  # True: keep only spacepoints whose 'larmatch_truth' value equals 1

sp_plots = []
lmtruth =  batch[0]['larmatch_truth']
lmweight =  batch[0]['larmatch_weight']
if TRUE_ONLY:
    # Build the selection from this cell's own lmtruth, not from the `truth`
    # variable left behind by the spacepoint-truth cell: that variable is undefined
    # on a fresh kernel if that cell was skipped, and is already filtered (wrong
    # length for indexing `pos`) if that cell ran with TRUE_ONLY=True.
    xpos = pos[lmtruth==1,:]
    lmweight = lmweight[lmtruth==1]
    # Filter lmtruth last, since the two lines above still need the full-length mask.
    lmtruth = lmtruth[lmtruth==1]
else:
    xpos = pos

# Range of the weights that are actually plotted (after the optional filter).
print("minweight: ",np.min(lmweight))
print("maxweight: ",np.max(lmweight))

plot = go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                    mode="markers", name="",
                    marker={"size":1.0,"opacity":0.5,"color":lmweight,"colorscale":"Viridis"} )
sp_plots.append(plot)

# Shared styling applied to all three scene axes.
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

# Scatter trace first, then the detector outline traces.
fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT Keypoint LABELS
# 3D scatter of the spacepoints, colored by row KPTYPE of 'keypoint_truth'.
KPTYPE=0         # which row of batch[0]['keypoint_truth'] to display
TRUE_ONLY=True   # True: keep only spacepoints whose 'larmatch_truth' value equals 1

lmtruth = batch[0]['larmatch_truth']
kplabel = batch[0]['keypoint_truth'][KPTYPE,:]
print("kplabel: ", kplabel.shape)

if not TRUE_ONLY:
    xpos = pos
else:
    xpos = pos[lmtruth==1,:]
    kplabel = kplabel[lmtruth==1]

plot = go.Scatter3d(
    x=xpos[:,0],
    y=xpos[:,1],
    z=xpos[:,2],
    mode="markers",
    name="",
    marker=dict(size=1.0, opacity=0.5, color=kplabel, colorscale="Viridis"),
)
sp_plots = [plot]

# Shared styling applied to all three scene axes.
axis_template = dict(
    showbackground=True,
    backgroundcolor="rgba(100, 100, 100,0.5)",
    gridcolor="rgb(50, 50, 50)",
    zerolinecolor="rgb(0, 0, 0)",
)

layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene=dict(
        xaxis=axis_template,
        yaxis=axis_template,
        zaxis=axis_template,
        aspectratio=dict(x=1, y=1, z=2),
        camera=dict(
            eye=dict(x=-3, y=0.1, z=0.0),
            center=dict(x=0, y=0, z=0),
            up=dict(x=0, y=1, z=0),
        ),
        annotations=[],
    ),
)

# Scatter trace first, then the detector outline traces.
fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT Keypoint WEIGHTS
# 3D scatter of the spacepoints, colored by row KPTYPE of 'keypoint_weight'.
# Note: the variable is still called `kplabel`, but in this cell it holds weights.
KPTYPE=0         # which row of batch[0]['keypoint_weight'] to display
TRUE_ONLY=True   # True: keep only spacepoints whose 'larmatch_truth' value equals 1

lmtruth = batch[0]['larmatch_truth']
kplabel = batch[0]['keypoint_weight'][KPTYPE,:]
print("kplabel: ", kplabel.shape)
# The min/max below are taken over all points, before the optional filter is applied.
print("min weight: ", np.min(kplabel))
print("max weight: ", np.max(kplabel))

if not TRUE_ONLY:
    xpos = pos
else:
    xpos = pos[lmtruth==1,:]
    kplabel = kplabel[lmtruth==1]

plot = go.Scatter3d(
    x=xpos[:,0],
    y=xpos[:,1],
    z=xpos[:,2],
    mode="markers",
    name="",
    marker=dict(size=1.0, opacity=0.5, color=kplabel, colorscale="Viridis"),
)
sp_plots = [plot]

# Shared styling applied to all three scene axes.
axis_template = dict(
    showbackground=True,
    backgroundcolor="rgba(100, 100, 100,0.5)",
    gridcolor="rgb(50, 50, 50)",
    zerolinecolor="rgb(0, 0, 0)",
)

layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene=dict(
        xaxis=axis_template,
        yaxis=axis_template,
        zaxis=axis_template,
        aspectratio=dict(x=1, y=1, z=2),
        camera=dict(
            eye=dict(x=-3, y=0.1, z=0.0),
            center=dict(x=0, y=0, z=0),
            up=dict(x=0, y=1, z=0),
        ),
        annotations=[],
    ),
)

# Scatter trace first, then the detector outline traces.
fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT SSNET LABELS
# 3D scatter of the spacepoints with one trace per distinct 'ssnet_truth' value,
# each drawn in the color that particle_id_color assigns to that value.
TRUE_ONLY=True  # True: keep only spacepoints whose 'larmatch_truth' value equals 1
# NOTE(review): PLOT is not read anywhere in this cell; kept only so that any code
# elsewhere that refers to it keeps working.
PLOT = 'ssnet_label_t'

lmtruth = batch[0]['larmatch_truth']
# Honor the TRUE_ONLY flag (this branch used to be hard-coded to `if True:`, so
# changing the flag had no effect).
if TRUE_ONLY:
    xpos = pos[lmtruth==1]
    sslabel = batch[0]['ssnet_truth'][lmtruth==1]
else:
    xpos = pos
    sslabel = batch[0]['ssnet_truth']

uniqueclasses = np.unique(sslabel)
sp_plots = []
for c in uniqueclasses:
    print(c)
    # particle_id_color[c] is an (r,g,b) tuple, so it can feed the format string directly.
    zcolor = "rgb(%d,%d,%d)"%particle_id_color[c]
    cmask = sslabel==c  # points carrying this label; computed once per class
    plot = go.Scatter3d( x=xpos[cmask,0], y=xpos[cmask,1], z=xpos[cmask,2],
                        mode="markers", name="%d"%(c),
                        marker={"size":1.0,"opacity":0.5,"color":zcolor} )
    sp_plots.append(plot)

# Shared styling applied to all three scene axes.
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

# Per-class scatter traces first, then the detector outline traces.
fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT SSNET WEIGHT
# 3D scatter of the spacepoints, colored by their 'ssnet_weight' value.
TRUE_ONLY=True  # True: keep only spacepoints whose 'larmatch_truth' value equals 1
SSCLASS = 1     # NOTE(review): not read anywhere in this cell

lmtruth = batch[0]['larmatch_truth']
if not TRUE_ONLY:
    xpos = pos
    ssweight = batch[0]['ssnet_weight']
else:
    xpos = pos[lmtruth==1]
    ssweight = batch[0]['ssnet_weight'][lmtruth==1]

plot = go.Scatter3d(
    x=xpos[:,0],
    y=xpos[:,1],
    z=xpos[:,2],
    mode="markers",
    name="",
    marker=dict(size=1.0, opacity=0.5, color=ssweight, colorscale="Viridis"),
)
sp_plots = [plot]

# Range of the weights that are actually plotted (after the optional filter).
print("min weight: ", np.min(ssweight))
print("max weight: ", np.max(ssweight))

# Shared styling applied to all three scene axes.
axis_template = dict(
    showbackground=True,
    backgroundcolor="rgba(100, 100, 100,0.5)",
    gridcolor="rgb(50, 50, 50)",
    zerolinecolor="rgb(0, 0, 0)",
)

layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene=dict(
        xaxis=axis_template,
        yaxis=axis_template,
        zaxis=axis_template,
        aspectratio=dict(x=1, y=1, z=2),
        camera=dict(
            eye=dict(x=-3, y=0.1, z=0.0),
            center=dict(x=0, y=0, z=0),
            up=dict(x=0, y=1, z=0),
        ),
        annotations=[],
    ),
)

# Scatter trace first, then the detector outline traces.
fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()