In [None]:
import h5py
import numpy as np
import os
import simplejson as json
import sys
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data as data

from torch.autograd import Variable

In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.ticker as tick

from matplotlib import cm
from matplotlib.colors import SymLogNorm
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# Shared colour palette used by every plot below; it is looked up as
# color_palette[<colour>][<shade>] (e.g. ['red']['shade_500']).
with open('./data/palette.json') as palette_file:
    color_palette = json.load(palette_file)

# Apply the project's Matplotlib style sheet to all subsequent figures.
plt.style.use('./plots/ssdjet.mplstyle')

In [None]:
# The project-local `ssd` package lives one directory up from this notebook;
# make it importable, adding the path only once across re-runs of this cell.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from ssd.generator import CalorimeterJetDataset
from ssd.layers.modules import MultiBoxLoss
from ssd.net import build_ssd

In [None]:
# When a GPU is present, make CUDA float tensors the default type so tensors
# created afterwards (presumably including the network's parameters) live on
# the device. NOTE(review): set_default_tensor_type is deprecated since
# PyTorch 2.1 — kept here for compatibility with the original setup.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Build the SSD detector in its 'test' phase and restore the trained weights.
# NOTE(review): num_classes=2 presumably means background + one jet class,
# matching the "one-class" checkpoint name — confirm against build_ssd.
net = build_ssd(qtype='full', num_classes=2, phase='test')
net.load_weights('./models/ssd-jet-one-class-full-2.pth')
# Switch to evaluation mode (last expression, so the module repr is shown).
net.eval()

In [None]:
# Root directory of the HDF5 jet samples. The default is a cluster-specific
# absolute path; set SSD_JET_DATA_SOURCE to point elsewhere instead of editing
# this cell.
DATA_SOURCE = os.environ.get('SSD_JET_DATA_SOURCE', '/mnt/home/apol/ceph/fast-three')

# Jet class names, indexed by the integer class ids carried in ground-truth
# and predicted boxes (see plot_calo_image).
CLASSES = ['b', 'h', 'W', 't', 'q']

In [None]:
def _draw_jet_box(ax, corners, tag, edgecolor, tag_facecolor):
    """Draw one jet bounding box and its text tag on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        corners: ``(xmin, ymin, xmax, ymax)`` in data (pixel) coordinates.
        tag: text placed at the ``(xmin, ymin)`` corner of the box.
        edgecolor: outline colour of the box.
        tag_facecolor: background colour of the text tag.

    Returns:
        The ``Rectangle`` patch that was added, usable as a legend handle.
    """
    xmin, ymin, xmax, ymax = corners
    jet = patches.Rectangle((xmin, ymin),
                            xmax - xmin,
                            ymax - ymin,
                            linewidth=1.2,
                            alpha=0.5,
                            edgecolor=edgecolor,
                            facecolor='none')
    ax.add_patch(jet)
    ax.text(xmin, ymin,
            tag,
            weight='bold',
            color=color_palette['grey']['shade_100'],
            bbox={'facecolor': tag_facecolor,
                  'alpha': 1.0})
    return jet


def plot_calo_image(pixels, gt=None, predictions=None, save_name=None):
    """Plot calorimeter energy deposits with optional ground-truth and predicted jets.

    One panel is drawn per entry along the first axis of ``pixels`` on a 1x2
    grid; the first panel is titled ECAL and the second HCAL, each with its own
    colour bar, a CMS tag and the hls4ml logo.

    Args:
        pixels: array indexed as (calorimeter, height, width). At most two
            calorimeters fit, since the subplot grid has two slots.
        gt: optional iterable of ground-truth boxes laid out as
            ``[xmin, ymin, xmax, ymax, class_id]``. Coordinates are scaled by the
            image width/height (presumably normalised to [0, 1]); ``class_id``
            indexes ``CLASSES``.
        predictions: optional iterable of predicted boxes laid out as
            ``[class_id, score, xmin, ymin, xmax, ymax]`` with coordinates scaled
            the same way as ``gt``.
        save_name: if given, the figure is saved to this path before it is shown.
    """
    # None defaults avoid Python's shared-mutable-default pitfall of ``=[]``.
    gt = [] if gt is None else gt
    predictions = [] if predictions is None else predictions

    fig = plt.figure(figsize=(16, 6))

    height, width, offset = pixels.shape[1], pixels.shape[2], 15
    # Loop-invariant work: one colour-scale maximum shared by all panels, and a
    # single read of the logo file instead of one per panel.
    vmax = pixels.max()
    logo_image = plt.imread('plots/hls4mllogo.jpg', format='jpg')

    for index, calorimeter in enumerate(pixels):
        ax = fig.add_subplot(1, 2, index + 1)
        ax.set_xlim([-offset, width + offset])
        # Raw strings: '\e', '\p' and '\d' are invalid escapes in plain literals
        # (a SyntaxWarning on recent Python); the string values are unchanged.
        ax.set_xlabel(r'$\eta$', horizontalalignment='right', x=1.0)
        ax.set_ylim([-offset, height + offset])
        ax.set_ylabel(r'$\phi [\degree$]', horizontalalignment='right', y=1.0)
        # Spine.set_smart_bounds was deprecated in Matplotlib 3.2 and removed in
        # 3.4; only call it where it still exists.
        for side in ('left', 'bottom'):
            spine = ax.spines[side]
            if hasattr(spine, 'set_smart_bounds'):
                spine.set_smart_bounds(True)

        # Show energy deposits on a symmetric-log scale shared across panels.
        im = ax.imshow(calorimeter,
                       norm=SymLogNorm(linthresh=0.03,
                                       vmin=0,
                                       vmax=vmax),
                       interpolation='nearest')

        # Add scale
        cbar = fig.colorbar(im, ax=ax, extend='max')
        cbar.set_label('$E_T$ [GeV]',
                       rotation=90,
                       horizontalalignment='right',
                       y=1.0)

        # Add ground truth
        gt_jets = []
        for box in gt:
            gt_jets.append(_draw_jet_box(
                ax,
                (box[0]*width, box[1]*height, box[2]*width, box[3]*height),
                '{0}'.format(CLASSES[int(box[4])]),
                color_palette['red']['shade_500'],
                color_palette['red']['shade_900']))

        # Draw the predicted boxes
        predicted_jets = []
        for box in predictions:
            predicted_jets.append(_draw_jet_box(
                ax,
                (box[2]*width, box[3]*height, box[4]*width, box[5]*height),
                '{0}: {1:.2f}'.format(CLASSES[int(box[0])], box[1]),
                color_palette['grey']['shade_500'],
                color_palette['grey']['shade_900']))

        # Build the legend once from explicit handles. Passing only labels would
        # tag the first patch on the axes, mislabelling a ground-truth box as
        # "Prediction" when both kinds are drawn.
        legend_handles, legend_labels = [], []
        if gt_jets:
            legend_handles.append(gt_jets[0])
            legend_labels.append('Ground Truth')
        if predicted_jets:
            legend_handles.append(predicted_jets[0])
            legend_labels.append('Prediction')
        if legend_handles:
            ax.legend(legend_handles, legend_labels,
                      loc='lower left', bbox_to_anchor=(0., -0.14))

        # Add CMS tag
        ax.text(0, 1, 'CMS',
                weight='bold',
                transform=ax.transAxes,
                color=color_palette['grey']['shade_900'],
                fontsize=14)

        # Add title: first panel is ECAL, the second HCAL.
        ax.text(0.85, 1, 'HCAL' if index else 'ECAL',
                transform=ax.transAxes,
                color=color_palette['grey']['shade_900'],
                fontsize=14)

        # Each axes needs its own OffsetImage artist; the pixel data is shared.
        # xybox is in data coordinates (AnnotationBbox defaults) and appears
        # tuned to the image size used here.
        logo = OffsetImage(logo_image, zoom=0.08)
        ab = AnnotationBbox(logo, [0, 0], xybox=(65, 381), frameon=False)
        ax.add_artist(ab)

    if save_name:
        fig.savefig(save_name)

    # Show plot
    plt.show()

In [None]:
# Open the RSGraviton sample read-only. The handle is passed straight to the
# dataset and left open — presumably the dataset reads from it lazily
# (TODO confirm against CalorimeterJetDataset).
train_dataset_path = '{}/RSGraviton_NARROW_1.h5'.format(DATA_SOURCE)
h5_train = h5py.File(train_dataset_path, 'r')
train_dataset = CalorimeterJetDataset(hdf5_dataset=h5_train)

# One event per batch, in file order, so the example event is always the same.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Pull the first event (batch_size=1) from the loader.
batch_iterator = iter(train_loader)
images, targets = next(batch_iterator)

# Drop the leading batch axis and draw the event with its ground-truth jets.
plot_calo_image(images.cpu().numpy()[0],
                gt=targets.cpu().numpy()[0],
                save_name='./plots/ssd-inference-example-gt')

In [None]:
# Run the detector on the example event. This is pure inference, so disable
# autograd to avoid building a graph (and holding its memory) for nothing.
with torch.no_grad():
    if torch.cuda.is_available():
        y = net(images.clone().detach().cuda())
    else:
        y = net(images)

# Bring the detections back to host memory as a numpy array; `.detach()`
# replaces the deprecated `.data` accessor. The next cell indexes this as
# 4-D — presumably (batch, class, top_k, score + box) — TODO confirm.
detections = y.detach().cpu().numpy()

In [None]:
# Minimum confidence for a detection to be drawn.
CONFIDENCE_THRESHOLD = 0.1

# Collect every detection that clears the threshold, prefixed with its class
# index. Each row becomes [class_id, score, xmin, ymin, xmax, ymax], the layout
# plot_calo_image expects for `predictions`. The loop stops at the first slot
# below threshold, so it assumes slots are sorted by descending score per
# class — NOTE(review): confirm against the detection layer.
coords = []
for i in range(detections.shape[1]):
    j = 0
    # Bound j by the number of slots: without it, a class whose every slot
    # clears the threshold would index past the end and raise IndexError.
    while j < detections.shape[2] and detections[0, i, j, 0] >= CONFIDENCE_THRESHOLD:
        coords.append(np.append(i, detections[0, i, j, :]))
        j += 1

# `.cpu()` keeps this working when `images` lives on the GPU, matching the
# ground-truth plot above.
plot_calo_image(images.cpu().numpy()[0],
                predictions=coords,
                save_name='./plots/ssd-inference-example-pred')