In [None]:
import os,sys
import numpy as np
import chart_studio.plotly as py
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
from lartpcdataset import lartpcDataset
%load_ext autoreload
%autoreload 2

In [None]:
# Open the z-view data set. Meta data is requested together with each image
# (the cells below unpack every sample as an image plus a meta entry), and the
# data set's verbose output is switched off.
dataset = lartpcDataset(
    root="./data/z-view/",
    load_meta_data=True,
    verbose=False,
)

# One sample per batch, drawn in shuffled order.
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# Integer class label -> particle name. The position of each name in the tuple
# is its class label (0 = electron, ..., 4 = pion).
LABEL_NAME = dict(enumerate(("electron", "gamma", "muon", "proton", "pion")))

In [None]:
# Draw one batch from the loader and show its image as a heatmap.
batch = next(iter(loader))
data, labels = batch[0], batch[1]

# Layout of `data` when the data set was built with load_meta_data=True:
# the image first, the meta data second.
imgs, meta = data[0], data[1]

# Layout of `data` when the data set was built with load_meta_data=False:
#imgs = data
#meta = None

print("label: ",labels," :: ",LABEL_NAME[labels.item()])

# 256 evenly spaced coordinates (0, 1, ..., 255) for each heatmap axis.
xaxis = np.linspace(0, 256, num=256, endpoint=False)
yaxis = np.linspace(0, 256, num=256, endpoint=False)

# Heatmap trace spec; the color range is pinned to [0, 10] and the color bar
# is hidden.
imgplot = dict(
    type="heatmap",
    z=imgs.squeeze(),
    x=xaxis,
    y=yaxis,
    zmin=0.0,
    zmax=10.0,
    colorscale="Jet",
    showscale=False,
)

fig0 = go.Figure(data=[imgplot])
fig0.update_layout(width=400, height=400)
fig0.show()

In [None]:
# Collect up to NUM_IMGS_PER_CLASS example images (and one meta value each)
# for every class listed in LABEL_NAME.
#
# The loader is walked exactly once. Rebuilding the iterator for every sample
# (`next(iter(loader))` inside a loop) re-shuffles each time, can hand back the
# same image repeatedly, and never terminates if a class is absent from the
# data set.
NUM_IMGS_PER_CLASS = 9
NUM_IMGS_PER_ROW = 3

image_bank = {labelid: [] for labelid in LABEL_NAME}
mom_bank = {labelid: [] for labelid in LABEL_NAME}

xaxis = np.linspace( 0, 256, endpoint=False, num=256 )
yaxis = np.linspace( 0, 256, endpoint=False, num=256 )

for batch in loader:
    data = batch[0]
    label = batch[1].item()
    img = data[0]
    meta = data[1]

    # Labels outside LABEL_NAME have no bank; skip them instead of raising
    # a KeyError.
    bank = image_bank.get(label)
    if bank is not None and len(bank) < NUM_IMGS_PER_CLASS:
        bank.append( img )
        # Last entry of the sample's meta data; the plotting code below labels
        # this value as the momentum in MeV.
        mom_bank[label].append( meta[0][-1].item() )

    # Stop as soon as every class has its full set of images.
    if all(len(image_bank[labelid]) >= NUM_IMGS_PER_CLASS for labelid in LABEL_NAME):
        break
else:
    # The loader ran out before every bank was filled: say which classes are
    # short rather than looping forever.
    for labelid in LABEL_NAME:
        nfound = len(image_bank[labelid])
        if nfound < NUM_IMGS_PER_CLASS:
            print("WARNING: only %d of %d images found for class %s"
                  % (nfound, NUM_IMGS_PER_CLASS, LABEL_NAME[labelid]))
print("Image bank made")

# Plot the gathered images: one figure per class, laid out as a grid with
# NUM_IMGS_PER_ROW images per row. Each figure is also kept in class_figs,
# keyed by class label.
from plotly.subplots import make_subplots  # harmless repeat of the import in the first cell
class_figs = {}

# Ceiling division, so a class size that is not a multiple of the row width
# still gets a (partially filled) last row instead of an out-of-range row index.
NROWS = (NUM_IMGS_PER_CLASS + NUM_IMGS_PER_ROW - 1) // NUM_IMGS_PER_ROW
print("NROWS: ",NROWS)
for labelid in LABEL_NAME:
    if not image_bank[labelid]:
        print("No images collected for class %s; skipping its figure" % (LABEL_NAME[labelid]))
        continue

    # One subplot title per image, showing the value stored in mom_bank.
    subtitles = [ "p=%.2f MeV"%(x) for x in mom_bank[labelid] ]

    fig = make_subplots(rows=NROWS, cols=NUM_IMGS_PER_ROW, subplot_titles=subtitles )
    for n,img in enumerate(image_bank[labelid]):
        imgplot = {
            "type":"heatmap",
            "z":img.squeeze(),
            "x":xaxis,
            "y":yaxis,
            "zmin":0.0,
            "zmax":10.0,
            "colorscale":"Jet",
            "showscale":False
        }
        # Fill the grid row by row; plotly subplot indices are 1-based.
        row, col = divmod(n, NUM_IMGS_PER_ROW)
        fig.add_trace(imgplot, row=row+1, col=col+1)
    fig.update_layout(height=900, width=900, title_text="%s examples"%(LABEL_NAME[labelid]))
    class_figs[labelid] = fig
    fig.show()