In [None]:
from __future__ import print_function
import os,sys
import chart_studio as cs
import chart_studio.plotly as py
import plotly.graph_objects as go
import numpy as np
from readparquet import UBParquetReader
from detectoroutline import DetectorOutline

In [None]:
# Optional: build the outline helper from the `lardly` package if available.
# The next cell rebuilds `detdata` from the local `detectoroutline` module, so
# this value is overwritten there; the import is guarded so that
# Restart & Run All does not fail when lardly is not installed.
try:
    import lardly
    detdata = lardly.DetectorOutline()
except (ImportError, AttributeError):
    detdata = DetectorOutline()

In [None]:
# load utility to draw TPC outline
# `detdata.getlines(color=...)` supplies the outline traces appended to each 3D plot below.
detdata = DetectorOutline()

# define some colors for particle types
# Row i is the RGB color (0-255) used for SSNet class index i.
# dtype=float: the old `np.float` alias was removed in NumPy 1.24.
ssnet_color_array = np.array( ((0,0,0),      # bg
                               (255,0,0),    # electron
                               (0,255,0),    # gamma
                               (0,0,255),    # muon
                               (255,0,255),  # pion
                               (0,255,255),  # proton
                               (0,0,0)),     # other
                               dtype=float )

# Map larcv particle-id codes onto the SSNet class indices used above.
larcvpid2class = {0:0,#unknown -> bg
                  1:0,#cosmic -> bg
                  2:0,#bnb -> bg
                  3:1,#eminus->electron
                  4:2,#gamma->gamma
                  5:2,#pizero->gamma
                  6:3,#muon->muon
                  7:4,#kaon->other  NOTE(review): 4 is the pion row above; "other" is row 6 -- confirm intent
                  8:4,#pion->pion
                  9:5,#proton->proton
                 }

In [None]:
# DATA FILES

# specify location where parquet file(s) live.
# Defaults to ./data/; set the UB_PARQUET_DATADIR environment variable to point
# the notebook at a different folder without editing it.
datafolder = os.environ.get("UB_PARQUET_DATADIR", "./data/")

reader = UBParquetReader(datafolder)
NENTRIES = reader.nentries  # entry count reported by the reader

In [None]:
# Get entry data
ENTRY = 0  # index of the entry to load; presumably valid for 0..NENTRIES-1 -- TODO confirm
data = reader.get_entry(ENTRY)  # later cells index `data` by column name

In [None]:
# Print Entry Info: event identifiers and neutrino truth for the loaded entry
for col in ('run', 'subrun', 'event', 'truepos', 'nu_energy',
            'nu_pid', 'nu_ccnc', 'nu_interaction', 'nu_geniemode'):
    print(col, ": ", data[col])

# make instance 2 id map: voxel instance label -> id (converted to Python values)
instance2id = {
    k.as_py(): i.as_py()
    for k, i in zip(data['voxinstancelist'], data['voxidlist'])
}

# Print every primary whose track id appears among the voxel instances,
# prefixed by the id that instance maps to.
for primdata in zip(data['primary_trackid'].as_py(),
                    data['primary_pid'].as_py(),
                    data['primary_start'].as_py(),
                    data['primary_mom'].as_py()):
    trackid = primdata[0]
    if trackid not in instance2id:
        continue
    print('[%d]' % instance2id[trackid], ": ", primdata)

In [None]:
# GENERATE PLOTLY OBJECTS FOR VISUALIZATION
plotdata = []  # not referenced by later cells in this notebook
nvoxels = data["voxcoord"].shape[0]
# We need to retrieve the 3d positions as floats. astype() returns a new array,
# so the shift below does not modify data["voxcoord"].
# (np.float64 replaces the `np.float` alias removed in NumPy 1.24.)
pos3d = data["voxcoord"].astype(np.float64)
# Shift the y coordinate by -117 (presumably to center the detector vertically -- TODO confirm)
pos3d[:,1] -= 117.0


In [None]:
# Plot Particle ID (semantic segmentation network, "SSNet") labels
# When True, only voxels with data['voxlabel']==1 are drawn; all other voxels
# are treated as "ghost" points and dropped.
no_ghost_points = True

ssnetlabels = data["voxssnet"]
unique_labels = np.unique(ssnetlabels)
print('unique labels: ',unique_labels)
print("ssnetlabels: ",ssnetlabels.shape)

# Per-voxel RGB color: a voxel with class index i gets ssnet_color_array[i].
# Every row of the color table is applied, including "other"; voxels whose label
# matches no row stay black.
ssnetcol = np.zeros((nvoxels,3))
for i in range(ssnet_color_array.shape[0]):
    ssnetcol[ ssnetlabels==i ] = ssnet_color_array[i,:]

if no_ghost_points:
    keep = data['voxlabel']==1
    pos = pos3d[keep]
    ssnetcol = ssnetcol[keep]
    # report the shape that remains after the cut (not the unfiltered pos3d)
    print("removed ghost voxels: ",pos.shape)
else:
    pos = pos3d


ssnetplot = {
    "type":"scatter3d",
    "x":pos[:,0],
    "y":pos[:,1],
    "z":pos[:,2],
    "mode":"markers",
    "name":"ssnet",
    "marker":{"color":ssnetcol,"size":1},
}

# TPC outline traces, computed once and reused below
detlines = detdata.getlines(color=(10,10,10))

# DATA
ssnet_plot_data = [ssnetplot] + detlines

# LAYOUT
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='SSNET LABELS',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 3},
        "camera": {"eye": {"x": 4.0, "y": 0.5, "z": -0.5},
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=ssnet_plot_data, layout=layout)
fig.show()

In [None]:
# Voxel visualization with KEYPOINT SCORE LABELS
# KPTYPE selects which row of data["voxkplabel"] is shown; KPNAMES gives its title.
KPTYPE = 0
KPNAMES = {0:"neutrino",
    1:"track starts",
    2:"track ends",
    3:"shower starts",
    4:"michel starts",
    5:"delta starts"}

# When True, only voxels with data['voxlabel']==1 are drawn.
no_ghost_points = True

labels = data["voxkplabel"]
print("kplabels: ",labels.shape)
# Per-voxel score for the selected keypoint type, used as a continuous color value.
# (Previously this assignment sat inside a loop that repeated it 6 times.)
kpcol = np.zeros(nvoxels)
kpcol[:] = labels[KPTYPE,:]

if no_ghost_points:
    keep = data['voxlabel']==1
    pos = pos3d[keep]
    kpcol = kpcol[keep]
    # report the shape that remains after the cut (not the unfiltered pos3d)
    print("removed ghost voxels: ",pos.shape)
else:
    pos = pos3d


ssnetplot = {
    "type":"scatter3d",
    "x":pos[:,0],
    "y":pos[:,1],
    "z":pos[:,2],
    "mode":"markers",
    "name":"keypoint",
    "marker":{"color":kpcol,"size":1,'colorscale':'Viridis'},
}

# TPC outline traces, computed once and reused below
detlines = detdata.getlines(color=(10,10,10))

# DATA
ssnet_plot_data = [ssnetplot] + detlines

# LAYOUT
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='KEYPOINT (%s) LABELS'%(KPNAMES[KPTYPE]),
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 3},
        "camera": {"eye": {"x": 2.0, "y": 0.5, "z": 0},
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=ssnet_plot_data, layout=layout)
fig.show()

In [None]:
# Plot Instance labels
# When True, only voxels with data['voxlabel']==1 are drawn.
no_ghost_points = True

labels = data["voxinstance"]
unique_labels = np.unique(labels)
print('unique labels: ',unique_labels)
# One random RGB color (0-255) per non-zero instance label. A fixed seed keeps
# the colors identical on every re-run. dtype=float replaces the removed
# `np.float` alias.
instance_rng = np.random.default_rng(0)
color = np.zeros( (pos3d.shape[0],3), dtype=float )
for i in unique_labels.tolist():
    if i==0:
        continue # keep bg instance black
    color[ labels==i ] = instance_rng.random(3)*255

print("labels: ",labels.shape)

if no_ghost_points:
    keep = data['voxlabel']==1
    pos = pos3d[keep]
    color = color[keep]
    # report the shape that remains after the cut (not the unfiltered pos3d)
    print("removed ghost voxels: ",pos.shape)
else:
    pos = pos3d


plot = {
    "type":"scatter3d",
    "x":pos[:,0],
    "y":pos[:,1],
    "z":pos[:,2],
    "mode":"markers",
    "name":"instance",
    "marker":{"color":color,"size":1},
}

# TPC outline traces, computed once and reused below
detlines = detdata.getlines(color=(10,10,10))

# DATA
plot_data = [plot] + detlines

# LAYOUT
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='INSTANCE LABELS',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 3},
        "camera": {"eye": {"x": 1, "y": 1, "z": 1},
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=plot_data, layout=layout)
fig.show()

In [None]:
# Plot Origin labels
# When True, only voxels with data['voxlabel']==1 are drawn.
no_ghost_points = True

labels = data["voxorigin"]
unique_labels = np.unique(labels)
print('unique labels: ',unique_labels)
unique_weights = np.unique(data['voxoriginweight'])
print('unique weights: ',unique_weights)

# Each voxel is colored by its origin label value through the Viridis colorscale.
if no_ghost_points:
    keep = data['voxlabel']==1
    pos = pos3d[keep]
    color = labels[keep]
    # report the shape that remains after the cut (not the unfiltered pos3d)
    print("removed ghost voxels: ",pos.shape)
else:
    pos = pos3d
    color = labels


plot = {
    "type":"scatter3d",
    "x":pos[:,0],
    "y":pos[:,1],
    "z":pos[:,2],
    "mode":"markers",
    "name":"origin",
    "marker":{"color":color,"size":1,"colorscale":'Viridis'},
}

# TPC outline traces, computed once and reused below
detlines = detdata.getlines(color=(10,10,10))

# DATA
plot_data = [plot] + detlines

# LAYOUT
axis_template = {
    "showbackground": True,
    "backgroundcolor": "rgba(100, 100, 100,0.5)",
    "gridcolor": "rgb(50, 50, 50)",
    "zerolinecolor": "rgb(0, 0, 0)",
}


layout = go.Layout(
    title='ORIGIN LABELS',
    autosize=True,
    hovermode='closest',
    showlegend=False,
    scene= {
        "xaxis": axis_template,
        "yaxis": axis_template,
        "zaxis": axis_template,
        "aspectratio": {"x": 1, "y": 1, "z": 3},
        "camera": {"eye": {"x": 1, "y": 1, "z": 1},
                   "up":dict(x=0, y=1, z=0)},
        "annotations": [],
    }
)

fig = go.Figure(data=plot_data, layout=layout)
fig.show()

In [None]:
# plot Y-PLANE wire image

# Stored as a sparse matrix; presumably each row is (row index, column index, value) -- TODO confirm.
# np.array() copies, so the clipping below no longer modifies data['wireimageplane2'].
img = np.array(data['wireimageplane2'])
img[:,2] = np.clip(img[:,2],0,100)  # clip pixel values to [0, 100] for display
print(img.shape)

# 3456 columns; 1008 rows labelled 2400, 2406, ..., 8442 (step 6).
xaxis = np.linspace( 0, 3456, endpoint=False, num=3456 )
yaxis = np.linspace( 2400, 8448, endpoint=False, num=1008 )
print(yaxis.shape)

# Scatter the sparse values into a dense (1008, 3456) array.
# astype(int) replaces the `np.int` alias removed in NumPy 1.24.
denseimg = np.zeros( (1008,3456) )
tmpindex = img.astype(int)
denseimg[ tmpindex[:,0], tmpindex[:,1] ] = img[:,2]

layout_yplane = go.Layout(
    title='Y-PLANE WIRE IMAGE',
    autosize=True,
    hovermode='closest',
    showlegend=False)

heatmap = {
    "type":"heatmap",
    "z":denseimg,
    "x":xaxis,
    "y":yaxis,
    "colorscale":"Jet",
}

fig = go.Figure(data=[heatmap],layout=layout_yplane)
fig.show()


In [None]:
# plot U-PLANE wire image

# Stored as a sparse matrix; presumably each row is (row index, column index, value) -- TODO confirm.
# np.array() copies, so the clipping below no longer modifies data['wireimageplane0'].
img = np.array(data['wireimageplane0'])
img[:,2] = np.clip(img[:,2],0,100)  # clip pixel values to [0, 100] for display
print(img.shape)

# 3456 columns; 1008 rows labelled 2400, 2406, ..., 8442 (step 6).
xaxis = np.linspace( 0, 3456, endpoint=False, num=3456 )
yaxis = np.linspace( 2400, 8448, endpoint=False, num=1008 )
print(yaxis.shape)

# Scatter the sparse values into a dense (1008, 3456) array.
# astype(int) replaces the `np.int` alias removed in NumPy 1.24.
denseimg = np.zeros( (1008,3456) )
tmpindex = img.astype(int)
denseimg[ tmpindex[:,0], tmpindex[:,1] ] = img[:,2]

layout_yplane = go.Layout(
    title='U-PLANE WIRE IMAGE',
    autosize=True,
    hovermode='closest',
    showlegend=False)

heatmap = {
    "type":"heatmap",
    "z":denseimg,
    "x":xaxis,
    "y":yaxis,
    "colorscale":"Jet",
}

fig = go.Figure(data=[heatmap],layout=layout_yplane)
fig.show()


In [None]:
# plot V-PLANE wire image

# Stored as a sparse matrix; presumably each row is (row index, column index, value) -- TODO confirm.
# np.array() copies, so the clipping below no longer modifies data['wireimageplane1'].
img = np.array(data['wireimageplane1'])
img[:,2] = np.clip(img[:,2],0,100)  # clip pixel values to [0, 100] for display
print(img.shape)

# 3456 columns; 1008 rows labelled 2400, 2406, ..., 8442 (step 6).
xaxis = np.linspace( 0, 3456, endpoint=False, num=3456 )
yaxis = np.linspace( 2400, 8448, endpoint=False, num=1008 )
print(yaxis.shape)

# Scatter the sparse values into a dense (1008, 3456) array.
# astype(int) replaces the `np.int` alias removed in NumPy 1.24.
denseimg = np.zeros( (1008,3456) )
tmpindex = img.astype(int)
denseimg[ tmpindex[:,0], tmpindex[:,1] ] = img[:,2]

layout_yplane = go.Layout(
    title='V-PLANE WIRE IMAGE',
    autosize=True,
    hovermode='closest',
    showlegend=False)

heatmap = {
    "type":"heatmap",
    "z":denseimg,
    "x":xaxis,
    "y":yaxis,
    "colorscale":"Jet",
}

fig = go.Figure(data=[heatmap],layout=layout_yplane)
fig.show()


In [None]:
# sign into plotly chart studio
# Credentials come from the environment rather than being typed into the notebook
# (api key: go to profile > settings > regenerate key).
# If either value is missing, set_credentials_file is not called, so any
# previously stored credentials are not overwritten with empty strings.
username = os.environ.get('PLOTLY_USERNAME', '')
api_key = os.environ.get('PLOTLY_API_KEY', '')
if username and api_key:
    cs.tools.set_credentials_file(username=username, api_key=api_key)
else:
    print('PLOTLY_USERNAME / PLOTLY_API_KEY not set; skipping chart studio sign-in')

In [None]:
# push plot to chart studio: can then embed on websites
# NOTE(review): `fig` is whatever figure the most recently executed cell built.
# When the notebook is run top to bottom, that is the V-PLANE wire image heatmap,
# not the SSNet 3D plot that the filename 'test_ssnet_3d' suggests -- rebuild or
# select the intended figure before uploading.
py.plot(fig, filename = 'test_ssnet_3d', auto_open=True)