# Inference Model

### Import required libraries

In [None]:
import torch
import torchvision
import torch.nn as nn
import cv2
import numpy as np
import pandas as pd
import traitlets
from IPython.display import display
import ipywidgets.widgets as widgets
from jetbot import Camera, bgr8_to_jpeg
from jetbot import Robot
import torch.nn.functional as F
import time
import pickle
import sklearn
from sklearn.svm import OneClassSVM
import random
import collections


#from jetbot.zmq_camera import ZmqCamera

In [None]:
# Report the versions of the libraries the saved models depend on
for _lib in (torch, torchvision):
    print(_lib.__version__)
print('The scikit-learn version is {}.'.format(sklearn.__version__))

### Load models

In [None]:
class Net(nn.Module):
    """Small CNN with two conv/batch-norm/pool stages and three linear layers.

    The network maps a 3-channel image batch to 2 output logits. The first
    linear layer expects 24 feature maps of 53x53, which is what a 224x224
    input produces after the two 5x5 convolutions and 2x2 poolings.
    """

    # Number of features fed to the first fully connected layer.
    _FLAT_FEATURES = 24 * 53 * 53

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 12 channels
        self.conv1 = nn.Conv2d(3, 12, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(12)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 12 -> 24 channels
        self.conv2 = nn.Conv2d(12, 24, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head ending in 2 logits
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the raw (un-normalised) logits for the image batch ``x``."""
        features = self.pool1(F.relu(self.bn1(self.conv1(x))))
        features = self.pool2(F.relu(self.bn2(self.conv2(features))))
        # Collapse channel and spatial dimensions into one feature vector per image
        flat = features.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Target device for the models and the pre-processed camera frames (CUDA GPU)
device = torch.device('cuda')

In [None]:
# Model 1: read the saved weights, build the network on the target device,
# then copy the weights into it
state_dict1 = torch.load('best_model1_pruned.pth')
net1 = Net().to(device)
net1.load_state_dict(state_dict1)

In [None]:
# Model 2: read the saved weights, build the network on the target device,
# then copy the weights into it
state_dict2 = torch.load('best_model2_pruned.pth')
net2 = Net().to(device)
net2.load_state_dict(state_dict2)

In [None]:
# Model 3: read the saved weights, build the network on the target device,
# then copy the weights into it
state_dict3 = torch.load('best_model3_pruned.pth')
net3 = Net().to(device)
net3.load_state_dict(state_dict3)

### Create the preprocessing function

In [None]:
# Retrieve normalisation parameters
norm_param_df = pd.read_csv('TRG_DATASET_NORM_PARAM.csv')


def _avg_param(column):
    """Return the value of ``column`` in the single row whose Dataset is "AVG"."""
    return norm_param_df.loc[norm_param_df["Dataset"] == "AVG", column].item()


# Per-channel means
meanR = _avg_param("meanR")
meanG = _avg_param("meanG")
meanB = _avg_param("meanB")

# Per-channel standard deviations
stdR = _avg_param("stdR")
stdG = _avg_param("stdG")
stdB = _avg_param("stdB")


In [None]:
# Normalisation transform for the camera feed. The stored statistics are
# multiplied by 255 so they apply directly to raw 0-255 pixel values.
_channel_means = 255.0 * np.array([meanR, meanG, meanB])
_channel_stds = 255.0 * np.array([stdR, stdG, stdB])
normalize = torchvision.transforms.Normalize(_channel_means, _channel_stds)


def preprocess(camera_value):
    """Turn one BGR camera frame into a normalised 1-image batch on ``device``.

    The frame is converted to RGB, reordered from height/width/channel to
    channel/height/width, cast to float, normalised, moved to the device and
    given a leading batch dimension.
    """
    rgb = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)
    chw = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
    return normalize(chw).to(device).unsqueeze(0)

### Prepare SVM Model

In [None]:
# Load the previously trained SVM used for in/out-of-distribution detection.
# NOTE: unpickling can execute arbitrary code - only load a model file that
# comes from a trusted source.
# The context manager closes the file handle once the model is loaded.
with open('SVM_model.sav', 'rb') as _svm_file:
    SVM_model = pickle.load(_svm_file)

### Prepare interface

In [None]:
# Shared layouts: one per kind of widget
button_layout = widgets.Layout(width='128px', height='30px')
text_layout = widgets.Layout(width='500px', height='64px')
number_layout = widgets.Layout(width='150px', height='64px')

# Warning interface: message field plus its button
holdingpoint_warning_display = widgets.Text(description='Holding Point', layout=text_layout)
clear_button = widgets.Button(description='CLR', button_style='info', layout=button_layout)

# Test interface: in/out-of-distribution flag
ID_OOD_display = widgets.Text(description='ID/OOD', layout=number_layout)

# Test interface: predicted class of each model
(model1_prediction_display,
 model2_prediction_display,
 model3_prediction_display) = [
    widgets.IntText(description=label, layout=number_layout)
    for label in ('Pred. 1', 'Pred. 2', 'Pred. 3')
]

# Test interface: plain softmax confidence of each model
(model1_proba_display,
 model2_proba_display,
 model3_proba_display) = [
    widgets.FloatText(description=label, layout=number_layout)
    for label in ('Proba. 1', 'Proba. 2', 'Proba. 3')
]

# Test interface: temperature-scaled confidence of each model
model1_proba_TS_display = widgets.FloatText(description='Proba. 1 TS', layout=number_layout, color = 'red')
(model2_proba_TS_display,
 model3_proba_TS_display) = [
    widgets.FloatText(description=label, layout=number_layout)
    for label in ('Proba. 2 TS', 'Proba. 3 TS')
]

# Test interface: counters and ensemble decisions
(count_holdingpoint_display,
 final_prediction_display,
 final_prediction_TS_display,
 count_excluded_display,
 count_excluded_TS_display) = [
    widgets.IntText(description=label, layout=number_layout)
    for label in ('Count', 'Final Pred.', 'Final Pred. TS', 'Excl. %', 'Excl. TS %')
]

### Create Robot, Camera and Controller instances

In [None]:
# Create camera instance
# Frames are requested at 224x224 pixels, one frame per second.
camera = Camera.instance(width=224, height=224, fps=1)

In [None]:
# Image widget that shows the live camera feed
image = widgets.Image(format='jpeg', width=224, height=224)
# One-way link: every new camera frame is JPEG-encoded and pushed to the widget
camera_link = traitlets.dlink((camera, 'value'), (image, 'value'), transform=bgr8_to_jpeg)
#display(image, width = 224, height = 224)

In [None]:
# Connect controller to Jetbot
# index=0 selects the first controller; the widget is displayed in the notebook
controller = widgets.Controller(index=0) 
display(controller)

In [None]:
# Create robot instance
# The motor gains are passed at construction time to make sure the robot goes
# straight. Only one Robot is created, so the motor driver is opened once
# instead of building a default instance that is immediately discarded.
robot = Robot(left_motor_alpha = 1.1, right_motor_alpha = 1.0)

In [None]:
# Define controller functions to move the robot
def _eighth_speed(button_value):
    """Scale a button value to one eighth of its magnitude."""
    return button_value / 8


def _eighth_speed_reversed(button_value):
    """Scale a button value to one eighth of its magnitude, with the sign flipped."""
    return -button_value / 8


_left_motor = (robot.left_motor, 'value')
_right_motor = (robot.right_motor, 'value')

# Button 12: both motors forward
fwd_left_link = traitlets.dlink((controller.buttons[12], 'value'), _left_motor, transform=_eighth_speed)
fwd_right_link = traitlets.dlink((controller.buttons[12], 'value'), _right_motor, transform=_eighth_speed)

# Button 13: both motors backward
bck_left_link = traitlets.dlink((controller.buttons[13], 'value'), _left_motor, transform=_eighth_speed_reversed)
bck_right_link = traitlets.dlink((controller.buttons[13], 'value'), _right_motor, transform=_eighth_speed_reversed)

# Button 14: left motor backward, right motor forward (turn left)
turn_left_left_link = traitlets.dlink((controller.buttons[14], 'value'), _left_motor, transform=_eighth_speed_reversed)
turn_left_right_link = traitlets.dlink((controller.buttons[14], 'value'), _right_motor, transform=_eighth_speed)

# Button 15: left motor forward, right motor backward (turn right)
turn_right_left_link = traitlets.dlink((controller.buttons[15], 'value'), _left_motor, transform=_eighth_speed)
turn_right_right_link = traitlets.dlink((controller.buttons[15], 'value'), _right_motor, transform=_eighth_speed_reversed)

### Define update function

In [None]:
# Define function to compute the mean of each channel of each batch element
def get_mean_channels(batched_outputs):
    """Average every channel of every batch element down to a single value.

    Args:
        batched_outputs: Layer outputs indexed as [batch][channel][...].
            Usually a tensor with at least three dimensions; any nested
            iterable of tensors is also accepted.

    Returns:
        A CPU tensor with one row per batch element and one column per
        channel, holding the mean over all remaining dimensions.
    """
    if isinstance(batched_outputs, torch.Tensor) and batched_outputs.dim() >= 3:
        # Vectorised path: merge everything after the channel axis and reduce
        # it in one call instead of one Python-level mean() per channel.
        return batched_outputs.detach().flatten(start_dim=2).mean(dim=2).cpu()

    # Generic path for non-tensor (or low-rank) inputs: reduce channel by channel.
    channel_means = [
        [channel.mean() for channel in single_output]
        for single_output in batched_outputs
    ]
    return torch.tensor(channel_means)

In [None]:
# Reset all running counters
count_holdingpoint = count_total_cycles = count_excluded = count_excluded_TS = 0

# Confidence thresholds (plain softmax / temperature-scaled softmax), computed previously
threshold, threshold_TS = 0.97, 0.73

# Temperature scaling coefficient of each model, computed previously
temp_factor_net1, temp_factor_net2, temp_factor_net3 = 3.661, 4.289, 3.913

In [None]:
# Most recent outputs captured by forward hooks, keyed by the name given at registration
activation = {}


def get_activation(name):
    """Build a forward hook that stores the layer output under ``activation[name]``.

    The stored tensor is detached from the autograd graph and copied to the CPU.
    """
    def hook(_module, _inputs, result):
        activation[name] = result.detach().cpu()
    return hook


# Capture the activations of the Conv2 layer of model 1 on every forward pass
net1.conv2.register_forward_hook(get_activation('conv2'))

# Define helper that merges the three model predictions into one decision
def _confidence_vote(predictions, confident):
    """Combine per-model class predictions, ignoring models that are not confident.

    Args:
        predictions: The three predicted classes (0 or 1), one per model.
        confident: Three truthy/falsy flags, one per model, telling whether
            that model's confidence exceeded the threshold.

    Returns:
        1 or 0 when a decision is reached, -1 when the image is excluded:
        - three confident models: majority vote;
        - two confident models: their common class if they agree, else -1;
        - fewer than two confident models: -1.
    """
    votes = [int(pred) for pred, sure in zip(predictions, confident) if sure]
    if len(votes) < 2:
        return -1
    positives = sum(votes)
    if len(votes) == 3:
        return 1 if positives >= 2 else 0
    # Exactly two confident models: they must agree
    if positives == 2:
        return 1
    if positives == 0:
        return 0
    return -1


# Define function that will process the images
def update(change):
    """Process one camera frame: classify it, refresh the displays, update the counters.

    Args:
        change: Mapping whose 'new' entry holds the latest camera frame
            (the observer payload of the camera's 'value' trait).

    Side effects:
        Writes the per-model and ensemble results to the display widgets,
        updates the module-level counters, and stops the robot once class 0
        has been predicted on 5 or more consecutive in-distribution frames.
    """
    global count_holdingpoint, count_total_cycles, count_excluded, count_excluded_TS

    # Retrieve and pre-process the new live feed image
    x = preprocess(change['new'])

    with torch.no_grad():
        # Run the image through all three models (evaluation mode)
        net1.eval()
        out1 = net1(x)
        net2.eval()
        out2 = net2(x)
        net3.eval()
        out3 = net3(x)
        outputs = (out1, out2, out3)

        # Run model 1 conv2 layer activations (captured by the forward hook)
        # through the SVM model to determine if the image is in or out of distribution
        OOD_prediction = SVM_model.predict(get_mean_channels(activation['conv2']))
        in_distribution = int(OOD_prediction[0]) == 1
        ID_OOD_display.value = "ID" if in_distribution else "OOD"

        # "Raw" prediction: class with the largest logit
        predictions = [int(out.argmax(dim=1).item()) for out in outputs]

        # Confidence: largest softmax probability
        probas = [F.softmax(out, dim=1).max().item() for out in outputs]

        # Confidence after temperature scaling calibration
        temp_factors = (temp_factor_net1, temp_factor_net2, temp_factor_net3)
        probas_TS = [
            F.softmax(out / temp, dim=1).max().item()
            for out, temp in zip(outputs, temp_factors)
        ]

    # Display the per-model results
    model1_prediction_display.value = predictions[0]
    model2_prediction_display.value = predictions[1]
    model3_prediction_display.value = predictions[2]

    model1_proba_display.value = probas[0]
    model2_proba_display.value = probas[1]
    model3_proba_display.value = probas[2]

    model1_proba_TS_display.value = probas_TS[0]
    model2_proba_TS_display.value = probas_TS[1]
    model3_proba_TS_display.value = probas_TS[2]

    # Confidence filtering and voting (using softmax prediction)
    final_prediction = _confidence_vote(predictions, [p > threshold for p in probas])
    final_prediction_display.value = final_prediction

    # Confidence filtering and voting (using temperature scaling prediction)
    final_prediction_TS = _confidence_vote(predictions, [p > threshold_TS for p in probas_TS])
    final_prediction_TS_display.value = final_prediction_TS

    # Display Alerts/Warning depending on final prediction (softmax) and counter value
    if final_prediction == 0 and in_distribution:
        count_holdingpoint = count_holdingpoint + 1
        count_holdingpoint_display.value = count_holdingpoint
        if count_holdingpoint >= 5:
            holdingpoint_warning_display.value = "REACHING HOLDING POINT - STOP"
            robot.stop()
        else:
            holdingpoint_warning_display.value = "APPROACHING HOLDING POINT"
    else:
        holdingpoint_warning_display.value = " "
        count_holdingpoint = 0

    # Count in-distribution images and how many of them were excluded by each vote
    if in_distribution:
        count_total_cycles = count_total_cycles + 1
        if final_prediction_TS == -1:
            count_excluded_TS = count_excluded_TS + 1
        if final_prediction == -1:
            count_excluded = count_excluded + 1

    # Compute proportion of excluded images. Skipped until at least one
    # in-distribution image has been seen, to avoid dividing by zero.
    if count_total_cycles > 0:
        count_excluded_display.value = int(100 * count_excluded / count_total_cycles)
        count_excluded_TS_display.value = int(100 * count_excluded_TS / count_total_cycles)

    time.sleep(0.01)
   
update({'new': camera.value})  # we call the function once to initialize

### Start Demo

In [None]:
# Set count to 0
# (earlier calls to update() may already have incremented the holding-point counter)
count_holdingpoint = 0

# Start processing images
# update() is now called on every change of the camera's 'value' trait
camera.observe(update, names='value') 

In [None]:
# Display Alert/Warning field and metrics
display(image, width = 224, height = 224)
display(widgets.HBox([holdingpoint_warning_display, clear_button]))
display(widgets.HBox([ID_OOD_display]))
# One row per model: prediction, softmax confidence, temperature-scaled confidence.
# Each row shows the temperature-scaled value of its own model.
display(widgets.HBox([model1_prediction_display, model1_proba_display, model1_proba_TS_display]))
display(widgets.HBox([model2_prediction_display, model2_proba_display, model2_proba_TS_display, final_prediction_display, final_prediction_TS_display, count_holdingpoint_display]))
display(widgets.HBox([model3_prediction_display, model3_proba_display, model3_proba_TS_display, count_excluded_display, count_excluded_TS_display]))
# The CLR button detaches update() from the camera, which stops image processing
clear_button.on_click(lambda x: camera.unobserve(update, names='value'))

### Stop Demo

In [None]:
# Stop calling update() on new camera frames
camera.unobserve(update, names='value')

time.sleep(0.1)  # add a small sleep to make sure frames have finished processing

# Make sure the motors are stopped
robot.stop()

In [None]:
# Remove the camera -> image widget link created earlier
camera_link.unlink()  # don't stream to browser (will still run camera)