In [None]:
from __future__ import print_function
import os,sys
import torch
import ROOT as rt
from larlite import larlite
from ROOT import larutil
import chart_studio.plotly as py
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import lardly
sys.path.append("../") 
from larmatch.larmatch_dataset import larmatchDataset
from larmatch.larmatch_mp_dataloader import larmatchMultiProcessDataloader
from larmatch.utils.common import prepare_me_sparsetensor

In [None]:
# LOAD THE FILE/DATASET LOADER
# Read the loader configuration from YAML and build a multi-process loader over the
# files listed in ../test.list.
import yaml

# Context manager so the config file handle is closed as soon as it has been parsed.
with open("../config/test_loader.yaml", 'r') as stream:
    toplevel_config = yaml.load(stream, Loader=yaml.FullLoader)
config = toplevel_config["TRAIN_DATALOADER_CONFIG"]
config["INPUT_FILE"] = "../test.list"  # override the input file list given in the YAML

BATCH_SIZE=4
loader = larmatchMultiProcessDataloader(config,
                                        BATCH_SIZE,
                                        num_workers=2,
                                        prefetch_batches=1,
                                        collate_fn=larmatchDataset.collate_fn)

nentries = loader.nentries
print("NENTRIES: ",nentries)

In [None]:
# SETUP THE GEOMETRY
# Select the detector and fetch the larlite geometry / detector-properties objects that
# later cells use to turn (wire, tick) image coordinates into 3D positions.

# Set the detector in larlite.
# Uncomment one alternative to switch detector; the same `detid` is passed to lardly below
# so the drawn outline matches the selected geometry.
# NOTE(review): presumably the GetME() singletons below are configured by this call, so it
# should run before them — confirm against larutil.
#detid = larlite.geo.kICARUS
detid = larlite.geo.kMicroBooNE
#detid = larlite.geo.kSBND
larutil.LArUtilConfig.SetDetector(detid)

# Get geometry class
# `geo` is used later for PlaneWireToChannel / ChannelsIntersect; `detp` for ConvertTicksToX.
geo = larlite.larutil.Geometry.GetME()
detp = larutil.DetectorProperties.GetME()

# Get detector outlines from lardly
# `detlines` is concatenated with the spacepoint traces as figure data in the 3D plot cells.
from lardly.detectors.getdetectoroutlines import DetectorOutlineFromLarlite
detoutline = DetectorOutlineFromLarlite(detid)
detlines = detoutline.getlines()

In [None]:
# DEFINE PARTICLE ID COLORS
# Lookup tables used by the plotting cells: an RGB colour / display name per integer label.

# RGB colour per particle-ID label.
# NOTE(review): particle_id_name below names label 1 "delta" and label 2 "nolabel", while the
# comments here say Cosmic / BNB — confirm which naming matches the dataset's labels.
particle_id_color = {0:(0,0,0),      # no label
                     1:(255,125,50), # Cosmic
                     2:(0,0,0),      # BNB
                     3:(255,0,0),    # electron
                     4:(0,255,0),    # gamma
                     5:(0,125,125),  # pi0
                     6:(155,0,155),  # muon
                     7:(255,255,0),  # kaon
                     8:(255,165,0),  # pion
                     9:(0,0,255)}    # proton

# RGB colour per SSNet class label.
ssnet_id_color = {0:(0,0,0),      # no label
                  1:(255,125,50), # electron
                  2:(0,255,0),    # gamma
                  3:(0,0,255),    # muon
                  4:(0,125,125),  # pion
                  5:(155,0,155),  # proton
                  6:(255,255,0)}  # other

# Display name per particle-ID label (same keys as particle_id_color).
particle_id_name = {0:"nolabel",
                    1:"delta",
                    2:"nolabel",
                    3:"electron",
                    4:"gamma",
                    5:"pi0",
                    6:"muon",
                    7:"kaon",
                    8:"pion",
                    9:"proton"}

# RGB colour per keypoint type, one row per type in the order listed.
kp_color_array = np.array( ((255,0,0),    # nu
                            (0,255,0),    # track-start
                            (0,0,255),    # track-end
                            (255,0,255),  # shower
                            (0,255,255),  # shower-michel
                            (255,255,0)), # shower-delta
                          dtype=np.float64 )

In [None]:
# Get some data, set BATCH INDEX
# Pull a single batch from the loader, convert it to sparse tensors on DEVICE and report the
# shape of every truth / weight tensor. BATCHINDEX picks the batch entry later cells plot.
DEVICE = torch.device("cpu")
BATCHINDEX = 0
batch = next(iter(loader))
print(batch.keys())
(wireplane_sparsetensors,
 matchtriplet_v,
 batch_truth,
 batch_weight) = prepare_me_sparsetensor( batch, DEVICE, verbose=True )
print("ENTRIES: ",batch["tree_entry"])
print(matchtriplet_v.shape)
# Truth tensors first, then weight tensors — same order as they were unpacked.
for tensor_dict in (batch_truth, batch_weight):
    for name in tensor_dict:
        print(name,": ",tensor_dict[name].shape)

In [None]:
# Get 3D positions
# One 3D point per candidate triplet of the selected batch entry:
#   y, z : intersection of the plane-0 and plane-1 wires of the triplet
#   x    : drift position from the tick of the plane-0 pixel
# The plane-2 wire is not needed for the position, so it is not looked up.

# Image row -> tick conversion used by this dataset: tick = row*TICKS_PER_ROW + TICK_OFFSET.
# NOTE(review): values carried over from the original cell; confirm against the image
# downsampling/cropping used when the dataset was made.
TICKS_PER_ROW = 6
TICK_OFFSET = 2400

triplets = matchtriplet_v[BATCHINDEX]
pos = np.zeros( (triplets.shape[0],3) )

coord_0 = batch["coord_0"][BATCHINDEX]
coord_1 = batch["coord_1"][BATCHINDEX]

vec = rt.TVector3(0,0,0)
for i in range( pos.shape[0] ):
    trip = triplets[i]
    w1 = coord_0[trip[0],1]
    w2 = coord_1[trip[1],1]
    tick = float(coord_0[trip[0],0])*TICKS_PER_ROW + TICK_OFFSET
    ch1 = geo.PlaneWireToChannel(int(w1),0,0,0 )
    ch2 = geo.PlaneWireToChannel(int(w2),1,0,0 )
    # NOTE(review): the return value of ChannelsIntersect is ignored; if it can report failure,
    # `vec` would still hold the previous triplet's intersection — confirm.
    geo.ChannelsIntersect(ch1,ch2,vec)
    pos[i,0] = detp.ConvertTicksToX(tick,0,0,0)
    pos[i,1] = vec[1]
    pos[i,2] = vec[2]
print("made position array: ",pos.shape)

In [None]:
# PLOT SPACEPOINT TRUTH
# Colour every candidate spacepoint of the selected batch entry by its larmatch truth value.
TRUE_ONLY=False  # if True, only draw spacepoints whose truth value is > 0.5

sp_plots = []
truth = batch_truth['lm'][BATCHINDEX]
print("truth tensor: ",truth.shape)
# One value per spacepoint, so one mask can index both the rows of `pos` and `truth`
# (a trailing singleton dimension would break numpy boolean indexing of `pos`).
truth = truth.reshape(-1)

# Counts are taken over all spacepoints, before any TRUE_ONLY filtering.
ntruth = int((truth==1).sum())
nfalse = int((truth==0).sum())
ratio = float(ntruth)/float(nfalse) if nfalse>0 else float("nan")
print("ntruth=",ntruth," nfalse=",nfalse," ratio: ",ratio)

if TRUE_ONLY:
    # torch.Tensor has no `.get`; `.gt(0.5)` is the element-wise "greater than" mask.
    true_mask = truth.gt(0.5)
    xpos = pos[true_mask.cpu().numpy(),:]
    truth = truth[true_mask]
else:
    xpos = pos

plot = go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                    text=np.arange(0,xpos.shape[0]),
                    mode="markers", name="",
                    marker={"size":1.0,"opacity":0.5,
                            "color":truth.cpu().numpy(),
                            "colorscale":"Viridis"} )
sp_plots.append(plot)

axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(0, 0, 0,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT LARMATCH LABELS TRUTH
# Colour spacepoints by the larmatch label 'lm' (drawn on a fixed 0..1 colour range).
# 'lm-hardlabel' (==1 for true matches) is used for the TRUE_ONLY selection.
TRUE_ONLY=False   # keep only spacepoints whose hard label == 1
LABEL_CUT=False   # keep only spacepoints whose 'lm' label > 0.5

sp_plots = []
lmlabel = batch_truth['lm'][BATCHINDEX].reshape(-1)
truth   = batch_truth['lm-hardlabel'][BATCHINDEX].reshape(-1)
print("lmlabel: ",lmlabel.shape)
print("truth (hardlabel): ",truth.shape)

xpos = pos
if TRUE_ONLY:
    # Filter the positions AND every per-spacepoint array with the same mask so the
    # plotted colours stay aligned with the plotted points.
    true_mask = truth==1
    xpos    = xpos[true_mask.cpu().numpy(), :]
    lmlabel = lmlabel[true_mask]
    truth   = truth[true_mask]

if LABEL_CUT:
    cut_mask = lmlabel>0.5
    xpos    = xpos[cut_mask.cpu().numpy(), :]
    lmlabel = lmlabel[cut_mask]

# torch.min / torch.max raise on empty tensors, which the cuts above can produce.
if lmlabel.numel()>0:
    print("min: ",torch.min(lmlabel))
    print("max: ",torch.max(lmlabel))
else:
    print("no spacepoints left after cuts")

plot = go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                    text=np.arange(0,xpos.shape[0]),
                    mode="markers", name="",
                    marker={"size":1.0,"opacity":0.5,
                            "color":lmlabel.cpu().numpy(),
                            "cmin":0.0,
                            "cmax":1.0,
                            "colorscale":"Viridis"} )
sp_plots.append(plot)

axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(0, 0, 0,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT SPACEPOINT WEIGHTS
# Colour spacepoints by their larmatch weight and print summary statistics of the larmatch
# and SSNet class weights for the selected batch entry.
TRUE_ONLY=False  # keep only spacepoints whose larmatch truth is > 0.5

sp_plots = []
lmtruth  = batch_truth["lm"][BATCHINDEX].squeeze()
lmweight = batch_weight["lm"][BATCHINDEX].squeeze()
ssclsweight = batch_weight["ssnet"][BATCHINDEX].squeeze()
if TRUE_ONLY:
    # Only the weight tensors loaded above are filtered.
    true_mask = lmtruth.gt(0.5)
    xpos = pos[true_mask.cpu().numpy(),:]
    lmweight = lmweight[true_mask]
    ssclsweight = ssclsweight[true_mask]
else:
    xpos = pos

# Example counts are over all spacepoints, independent of TRUE_ONLY.
num_lm_pos = lmtruth.gt(0.5).sum().cpu().item()
num_lm_neg = lmtruth.lt(0.5).sum().cpu().item()
print("num positive examples: ",num_lm_pos)
print("num negative examples: ",num_lm_neg)

# The minimum is taken over non-zero weights only; torch.min raises if there are none.
nonzero_lmweight = lmweight[lmweight>0]
if nonzero_lmweight.numel()>0:
    print("lm weight min: ",torch.min(nonzero_lmweight))
else:
    print("lm weight min: ","no positive weights")
print("lm weight max: ",torch.max(lmweight))
print("ssnet class min: ",torch.min(ssclsweight))
print("ssnet class max: ",torch.max(ssclsweight))

plot = go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                    mode="markers", name="",
                    marker={"size":1.0,"opacity":0.5,
                            "color":lmweight.cpu().numpy(),
                            "colorscale":"Viridis"} )
sp_plots.append(plot)

axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT Keypoint LABELS
# Colour spacepoints by the keypoint truth score of keypoint class KPTYPE.
KPTYPE=0                    # which row of the keypoint truth array to draw
TRUE_ONLY=False             # keep only spacepoints with larmatch truth == 1
USE_KEYPOINT_SAMPLER=False  # restrict to the indices in `kpallidx` (must be defined beforehand)

# NOTE(review): other cells read the batch as a string-keyed dict (batch["coord_0"],
# batch_truth[...]); `batch[0][...]` looks like an older loader API and may raise KeyError —
# TODO port these reads to batch_truth / batch_weight.
lmtruth = batch[0]['larmatch_truth']
kplabel = batch[0]['keypoint_truth'][KPTYPE,:]
print("kplabel: ",kplabel.shape)
print("NPOS: ",(kplabel>0.05).sum())

# Always start from the full position array, never from an `xpos` left behind by
# whichever plotting cell happened to run before this one.
xpos = pos

if USE_KEYPOINT_SAMPLER:
    # NOTE: `kpallidx` is not defined anywhere in this notebook; create it before enabling.
    xpos = xpos[kpallidx[:],:]
    kplabel = kplabel[kpallidx[:]]
    lmtruth = lmtruth[kpallidx[:]]
    print("post-keypoint sampler")
    print("kplabel: ",kplabel.shape)
    print("xpos: ",xpos.shape)
    print("lmtruth: ",lmtruth.shape)
    print("NPOS: ",(kplabel>0.2).sum())

sp_plots = []
if TRUE_ONLY:
    true_mask = lmtruth==1
    xpos = xpos[true_mask,:]
    kplabel = kplabel[true_mask]

plot = go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                    mode="markers", name="",
                    marker={"size":1.0,"opacity":1.0,"color":kplabel,"colorscale":"Viridis"} )
sp_plots.append(plot)

axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(10, 10, 10,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT Keypoint WEIGHTS
# Colour spacepoints by the keypoint weight of keypoint class KPTYPE and print its
# min / max / sum. TRUE_ONLY restricts the plot to spacepoints with larmatch truth == 1.
KPTYPE = 0
TRUE_ONLY = True
# NOTE(review): `batch[0][...]` differs from the dict-style access used in other cells
# (batch["coord_0"]); it may belong to an older loader API — TODO confirm the keys.
lmtruth = batch[0]['larmatch_truth']
kplabel = batch[0]['keypoint_weight'][KPTYPE,:]
print("kplabel: ",kplabel.shape)
for stat_label, reducer in (("min weight: ", np.min),
                            ("max weight: ", np.max),
                            ("weight sum: ", np.sum)):
    print(stat_label, reducer(kplabel))

if TRUE_ONLY:
    is_true = lmtruth==1
    xpos, kplabel = pos[is_true,:], kplabel[is_true]
else:
    xpos = pos

sp_plots = [
    go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                  mode="markers", name="",
                  marker=dict(size=1.0, opacity=0.5, color=kplabel, colorscale="Viridis") )
]

axis_template = dict(showbackground=True,
                     backgroundcolor="rgba(100, 100, 100,0.5)",
                     gridcolor="rgb(50, 50, 50)",
                     zerolinecolor="rgb(0, 0, 0)")

camera = dict(eye=dict(x=-3, y=0.1, z=0.0),
              center=dict(x=0, y=0, z=0),
              up=dict(x=0, y=1, z=0))
scene = dict(xaxis=axis_template,
             yaxis=axis_template,
             zaxis=axis_template,
             aspectratio=dict(x=1, y=1, z=2),
             camera=camera,
             annotations=[])
layout = go.Layout(title='DETECTOR TPC',
                   autosize=True,
                   hovermode='closest',
                   showlegend=False,
                   scene=scene)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT SSNET LABEL
# One 3D trace per SSNet class present in the entry, coloured with ssnet_id_color.
# The legend is enabled so the class -> colour mapping is visible.
TRUE_ONLY=False  # keep only spacepoints with larmatch truth == 1

# NOTE(review): `batch[0][...]` differs from the dict-style access used in other cells
# (batch["coord_0"]); it may belong to an older loader API — TODO confirm the keys.
lmtruth = batch[0]['larmatch_truth']
if TRUE_ONLY:
    xpos = pos[lmtruth==1]
    sslabel = batch[0]['ssnet_truth'][lmtruth==1]
else:
    xpos = pos
    sslabel = batch[0]['ssnet_truth']

uniqueclasses = np.unique(sslabel)
sp_plots = []
for c in uniqueclasses:
    in_class = sslabel==c  # computed once and reused for x, y and z
    print("num in class[",c,"]: ",in_class.sum())
    red, green, blue = ssnet_id_color[int(c)]
    zcolor = "rgb(%d,%d,%d)"%(red,green,blue)
    plot = go.Scatter3d( x=xpos[in_class,0], y=xpos[in_class,1], z=xpos[in_class,2],
                        mode="markers", name="%d"%(c),
                        marker={"size":1.0,"opacity":0.5,"color":zcolor} )
    sp_plots.append(plot)

axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='DETECTOR TPC',
    autosize=True,
    hovermode='closest',
    showlegend=True,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 2},
        "camera": {"eye": {"x": -3, "y": 0.1, "z": 0.0},
                   "center":dict(x=0, y=0, z=0),
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT SSNET WEIGHT
# Colour spacepoints by their SSNet class weight and print its range.
# TRUE_ONLY restricts the plot to spacepoints with larmatch truth == 1.
TRUE_ONLY = False

sp_plots = []

# NOTE(review): `batch[0][...]` differs from the dict-style access used in other cells
# (batch["coord_0"]); it may belong to an older loader API — TODO confirm the keys.
lmtruth = batch[0]['larmatch_truth']
ssweight = batch[0]['ssnet_class_weight']
xpos = pos
if TRUE_ONLY:
    keep = lmtruth==1
    xpos, ssweight = pos[keep], ssweight[keep]

sp_plots.append(
    go.Scatter3d( x=xpos[:,0], y=xpos[:,1], z=xpos[:,2],
                  mode="markers", name="",
                  marker=dict(size=1.0, opacity=0.5, color=ssweight, colorscale="Viridis") )
)

print("min weight: ",np.min(ssweight))
print("max weight: ",np.max(ssweight))

axis_template = dict(showbackground=True,
                     backgroundcolor="rgba(100, 100, 100,0.5)",
                     gridcolor="rgb(50, 50, 50)",
                     zerolinecolor="rgb(0, 0, 0)")

camera = dict(eye=dict(x=-3, y=0.1, z=0.0),
              center=dict(x=0, y=0, z=0),
              up=dict(x=0, y=1, z=0))
scene = dict(xaxis=axis_template,
             yaxis=axis_template,
             zaxis=axis_template,
             aspectratio=dict(x=1, y=1, z=2),
             camera=camera,
             annotations=[])
layout = go.Layout(title='DETECTOR TPC',
                   autosize=True,
                   hovermode='closest',
                   showlegend=False,
                   scene=scene)

fig = go.Figure(data=sp_plots+detlines, layout=layout)
fig.show()

In [None]:
# PLOT Sparse Images
# Rebuild the dense image of each wire plane from its sparse (coord, feat) pixels and draw it
# as a heatmap. img_v (one dense array per plane) is also used by the triplet-crop cell.
INPUT_SHAPE=(1024,3584) # (rows, cols) of the dense image given to the model
BACKGROUND=-1.0         # value of pixels that have no sparse entry

img_v = []
plane_v = []
for PLANE in range(3):
    # NOTE(review): other cells index the batch as batch["coord_0"][BATCHINDEX]; `batch[0]`
    # may belong to an older loader API — TODO confirm the keys.
    coord = batch[0]['coord_%d'%(PLANE)]
    feat  = batch[0]['feat_%d'%(PLANE)]
    print("coord: ",coord.shape)
    print("feat: ",feat.shape)
    img = np.full( INPUT_SHAPE, BACKGROUND )
    # coord[:,0] indexes image rows, coord[:,1] image columns.
    img[ coord[:,0], coord[:,1] ] = feat[:,0]

    # Plotly draws z[i][j] at (x[j], y[i]), so x must span the image COLUMNS and y the ROWS.
    xaxis = np.arange( INPUT_SHAPE[1], dtype=np.float64 )
    yaxis = np.arange( INPUT_SHAPE[0], dtype=np.float64 )

    heatmap = {
        "type":"heatmap",
        "z":img,
        "x":xaxis,
        "y":yaxis,
        "zmin":-1.0,
        "zmax":5.0,
        "colorscale":"Jet",
        }
    img_v.append(img)
    plane_v.append(heatmap)

fig0 = go.Figure(data=[plane_v[0]])
fig0.show()
fig1 = go.Figure(data=[plane_v[1]])
fig1.show()
fig2 = go.Figure(data=[plane_v[2]])
fig2.show()

In [None]:
# CHECK TRIPLETS BY CROPPING AROUND IMAGE
# Pick a random triplet (retrying up to MAX_TRIES times to find a true one) and show, for each
# plane, a crop of the dense image centred on the triplet's pixel. The marker is white for a
# true triplet and black for a false one. Requires img_v / INPUT_SHAPE from the image cell.
CROP_SIZE=(16,16) # half-size (rows, cols) of the crop window around the triplet pixel
MAX_TRIES=20      # random draws allowed when looking for a true triplet

# NOTE(review): `batch[0][...]` differs from the dict-style access used in other cells
# (batch["coord_0"]); it may belong to an older loader API — TODO confirm the keys.
TRIPLET_TRUTH = None
ntry = 0
while TRIPLET_TRUTH!=1 and ntry<MAX_TRIES:
    TRIPLET_IDX = np.random.randint(0,batch[0]["matchtriplet_v"].shape[0])
    TRIPLET_TRUTH = batch[0]["larmatch_truth"][TRIPLET_IDX]
    ntry += 1

print("IDX: ",TRIPLET_IDX)
trip = batch[0]["matchtriplet_v"][TRIPLET_IDX,:]
print(trip," ",trip.shape)
print("TRIPLET TRUTH: ",TRIPLET_TRUTH)
marker_color = "rgb(0,0,0)" if TRIPLET_TRUTH==0 else "rgb(255,255,255)"

crop_v = []
plotcrop_v = []
marker_v = []
for PLANE in range(3):
    print("PLANE ",PLANE)
    coord = batch[0]['coord_%d'%(PLANE)][trip[PLANE],:]
    print("coord: ",coord)
    print("feat: ",batch[0]['feat_%d'%(PLANE)][trip[PLANE]])
    row = int(coord[0])
    col = int(coord[1])
    # Symmetric window [centre-CROP, centre+CROP] (inclusive), clipped to the image bounds.
    xmin = max(row-CROP_SIZE[0],0)
    xmax = min(row+CROP_SIZE[0]+1,INPUT_SHAPE[0])
    print("X: ",xmin," ",xmax," ",xmax-xmin)
    ymin = max(col-CROP_SIZE[1],0)
    ymax = min(col+CROP_SIZE[1]+1,INPUT_SHAPE[1])
    print("Y: ",ymin," ",ymax," ",ymax-ymin)

    crop = img_v[PLANE][xmin:xmax,ymin:ymax]
    # One integer coordinate per crop column (heatmap x) and per crop row (heatmap y), so the
    # heatmap cells sit on the image pixels and the marker lands on the triplet's pixel.
    xaxis = np.arange( ymin, ymax )
    yaxis = np.arange( xmin, xmax )
    heatmap = {
        "type":"heatmap",
        "name":"plane%d"%(PLANE),
        "z":crop,
        "x":xaxis,
        "y":yaxis,
        "zmin":0,
        "zmax":4.0,
        "colorscale":"Jet",
        }
    marker = {
        "type":"scatter",
        "name":None,
        "x":[col],
        "y":[row],
        "marker_symbol":[300],
        "mode":"markers",
        "marker":{"size":20,"color":marker_color}
    }
    plotcrop_v.append(heatmap)
    marker_v.append(marker)
    crop_v.append(crop)


fig = make_subplots(rows=1, cols=3, horizontal_spacing=0.01, shared_yaxes=True)
for PLANE in range(3):
    fig.add_trace(plotcrop_v[PLANE],row=1,col=PLANE+1 )
    fig.add_trace(marker_v[PLANE],row=1,col=PLANE+1 )
fig.show()