In [None]:
# Before starting, import everything the notebook needs
import numpy as np
import os
import simplejson as json
import sys
import torch
import yaml

# Put the parent directory on sys.path (once) so the project-local
# `ssd` and `utils` modules imported below can be resolved.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from ssd.net import build_ssd
from utils import get_data_loader

In [None]:
# Set presentation settings
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches

with open('../plots/palette.json') as json_file:
    color_palette = json.load(json_file)
plt.style.use('../plots/ssdjet.mplstyle')

In [None]:
# Minimum value of a detection's column 0 (presumably its confidence score)
# for it to be kept in the inference cell below.
THRESHOLD = 0.2

# Load configuration
CONFIG_FILE = '../ssd-config.yml'
NET_CONFIG_FILE = '../net-config-last.yml'
MODEL = '../models/PF-Jet-SSD-tw.pth'

# Context managers close the YAML files as soon as they are parsed
# (bare open() calls leaked the file handles).
with open(CONFIG_FILE) as config_file:
    config = yaml.safe_load(config_file)
with open(NET_CONFIG_FILE) as net_config_file:
    net_config = yaml.safe_load(net_config_file)

ssd_settings = config['ssd_settings']
# ssd_settings is the same dict object as config['ssd_settings'], so this
# increment is visible through config too. Presumably it reserves a slot for
# the background class — TODO confirm against ssd.net.build_ssd.
ssd_settings['n_classes'] += 1
net_channels = net_config['network_channels']

In [None]:
# Initiate SSD and load weights.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (2.1+) and needs a CUDA device; from here on, newly created float
# tensors default to CUDA.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# NOTE(review): meaning of the leading positional argument 0 is not visible
# here — confirm against ssd.net.build_ssd.
net = build_ssd(0, config['ssd_settings'], net_channels, inference=True)
# load_weights signals success via its truthy return value. Fail loudly rather
# than silently running inference with a network whose weights did not load
# and that was never switched to eval mode.
if not net.load_weights(MODEL):
    raise RuntimeError(f'Could not load SSD weights from {MODEL}')
net.eval();

In [None]:
# Plotting calorimeter energy deposit image in color
def calorimeter_image(labels, baselines, data2d, titles=None,
                      save_path='../plots/Inference-Example'):
    """Draw two panels of jet boxes over the same calorimeter image.

    The left panel overlays ``labels`` with grey boxes. The right panel
    overlays ``baselines`` with red boxes. Both are drawn on top of ``data2d``.
    The figure is saved (unless ``save_path`` is None) and then shown.

    Args:
        labels: Iterable of rows whose first four entries are
            ``x_min, y_min, x_max, y_max``. These are multiplied by 340 (x) and
            360 (y), so they are presumably normalised to [0, 1] — TODO confirm.
            ``row[-2]`` is treated as a 1-based index into
            ``config['evaluation_pref']['names_classes']``.
        baselines: Rows in the same layout as ``labels``, drawn in the second panel.
        data2d: Image passed unchanged to ``Axes.imshow`` in both panels.
        titles: Up to two panel titles. Missing entries become empty strings.
        save_path: Target for ``Figure.savefig``. ``None`` skips saving.

    Relies on the notebook globals ``config`` and ``color_palette``.
    """
    offset = 15
    # Axis extents along eta (x) and phi (y); box coordinates are scaled by these.
    width, height = 340, 360

    # Always provide two titles. With the old ``titles=[]`` default, zip()
    # stopped immediately and the figure came out empty.
    titles = list(titles) if titles is not None else []
    titles += [''] * (2 - len(titles))

    colors = [color_palette['grey']['shade_50'], color_palette['red']['shade_600']]
    class_names = config['evaluation_pref']['names_classes']

    fig, axes = plt.subplots(1, 2, figsize=(6.0, 3.0))
    for ax, boxes, title, color in zip(axes, [labels, baselines], titles, colors):
        ax.set_title(title, loc='right')
        ax.set_xlim([-offset, width + offset])
        ax.set_xlabel(r'$\eta$', horizontalalignment='right', x=1.0)
        ax.set_ylim([-offset, height + offset])
        ax.set_ylabel(r'$\phi$', horizontalalignment='right', y=1.0)
        # imshow comes after the explicit limits, matching the original order.
        ax.imshow(data2d)

        for box in boxes:
            x_min = box[0] * width
            ax.add_patch(patches.Rectangle((x_min, box[1] * height),
                                           (box[2] - box[0]) * width,
                                           (box[3] - box[1]) * height,
                                           linewidth=0.5,
                                           edgecolor=color,
                                           facecolor='none'))
            # Class-name tag anchored at the box's (x_min, y_max) corner.
            ax.text(x_min, box[3] * height,
                    class_names[int(box[-2]) - 1],
                    weight='bold',
                    color='#000000',
                    bbox={'facecolor': color,
                          'alpha': 1.0,
                          'linewidth': 0.5})
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Build a data loader over the first entry of the configured test dataset.
# NOTE(review): the positional arguments 1 and 0 are passed through unchanged.
# They look like batch size and worker count — confirm against
# utils.get_data_loader.
test_source = config['dataset']['test'][0]
loader = get_data_loader(
    test_source,
    1,
    0,
    ssd_settings['input_dimensions'],
    ssd_settings['object_size'],
    shuffle=True,
    return_pt=True,
)

In [None]:
def detections_to_predictions(detections, threshold):
    """Flatten per-class detections into one table of rows that pass a threshold.

    Args:
        detections: Sequence indexed by class; ``detections[c]`` is a 2-D array.
            Column 0 is compared against ``threshold`` and columns 1-4 are
            copied as the box. Column 0 is presumably the score and columns 1-4
            the box corners — TODO confirm with ssd.net.
        threshold: Rows whose column 0 is below this value are dropped.

    Returns:
        ``np.ndarray`` of shape ``(n, 6)``. Columns are box columns 1-4, the
        class index ``c``, then column 0. If nothing passes, an empty
        ``(0, 6)`` array is returned.
    """
    rows = []
    for class_index, class_detections in enumerate(detections):
        # Boolean row indexing keeps whole rows; this is equivalent to the old
        # flatten / repeated-mask / reshape sequence.
        kept = class_detections[class_detections[:, 0] >= threshold]
        if len(kept):
            class_column = np.full((len(kept), 1), class_index)
            rows.append(np.hstack((kept[:, 1:5], class_column, kept[:, [0]])))
    # Stack once at the end instead of re-allocating inside the loop.
    return np.vstack(rows) if rows else np.empty((0, 6))


# Run the network on the first batch from the loader and plot truth vs prediction.
image, target = next(iter(loader))
with torch.no_grad():
    detections = net(image).cpu().numpy()
# Index 0 selects the single example in the batch.
predictions = detections_to_predictions(detections[0], THRESHOLD)

calorimeter_image(
    target[0].cpu().numpy(),
    predictions,
    np.transpose(image[0].cpu().numpy(), (2, 1, 0)),
    ['Truth', 'Prediction']
)