# Reusing a Torchjit Model in Python



The model we provide is a [torchjit](https://pytorch.org/docs/stable/jit.html) model (a TorchScript module), which means that you can also load and run it directly with the Torch library.

In this tutorial, we will learn how to use our pretrained model with Torch and how to extract the features it generates.
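If you are new to TorchScript, note that `torch.jit.load` reads archives written by `torch.jit.save`. These archives bundle the weights and the compiled code, so the model can be reloaded without its original Python class. The round trip below is a minimal sketch on a toy module. `TinyNet` is hypothetical and unrelated to the provided model.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy module, used only to illustrate the TorchScript round trip."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 4, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


net = TinyNet().eval()
scripted = torch.jit.script(net)  # compile the module to TorchScript

with tempfile.TemporaryDirectory() as tmp:
    toy_path = os.path.join(tmp, "tiny.ptjit")
    torch.jit.save(scripted, toy_path)
    # map_location="cpu" loads every tensor on the CPU, even if the model was saved from a GPU
    loaded = torch.jit.load(toy_path, map_location="cpu")

    x = torch.randn(1, 2, 8, 8)
    with torch.no_grad():
        # the reloaded module computes the same function as the original Python module
        same = torch.allclose(net(x), loaded(x))

print(same)
```

The reloaded object is a `torch.jit.ScriptModule`, so you use it like any other `nn.Module`. That is why the pretrained model below needs nothing beyond `torch.jit.load`.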


In [None]:
%matplotlib inline
import os
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = [12, 8]
import torch

from metavision_ml.data import CDProcessorIterator

In [None]:
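jit_path = os.path.join(os.getcwd(), "red_histogram_05_2020/model.ptjit")

# load the TorchScript model on the CPU (the device used for preprocessing below)
# and switch it to inference mode
model = torch.jit.load(jit_path, map_location="cpu").eval()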

Printing the loaded [torchjit](https://pytorch.org/docs/stable/jit.html) model shows the structure of its submodules:

In [None]:
print(model)
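Submodules of a TorchScript model remain reachable as attributes after loading, which is what makes a call such as `model.feature_extractor(...)` possible. To see which names are available on a loaded model, you can use `named_children()`, which lists the top-level submodules. The sketch below uses a hypothetical `Detector` with a `feature_extractor` child to mirror that structure. The real model's children may be named and organized differently.

```python
import io

import torch
from torch import nn


class Detector(nn.Module):
    """Hypothetical stand-in: a feature extractor followed by a prediction head."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(nn.Conv2d(2, 8, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(8, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.feature_extractor(x))


buffer = io.BytesIO()
torch.jit.save(torch.jit.script(Detector()), buffer)  # save to memory instead of a file
buffer.seek(0)
toy_model = torch.jit.load(buffer)

# list the top-level submodules exposed by the loaded TorchScript module
child_names = [name for name, _ in toy_model.named_children()]
print(child_names)

# a compiled submodule can be called on its own, like model.feature_extractor in this tutorial
with torch.no_grad():
    toy_features = toy_model.feature_extractor(torch.zeros(1, 2, 16, 16))
print(toy_features.shape)
```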

In [None]:
input_path = "driving_sample.raw"
# if the file doesn't exist, it will be downloaded from Prophesee's public sample server 
from metavision_core.utils import get_sample

get_sample(input_path, folder=".")

In [None]:
delta_t = 50000  # duration of each time slice, in microseconds (50 ms)
# The processor iterator combines the events iterator with the preprocessing function ("histo" here)
proc_iterator = CDProcessorIterator(input_path, "histo", delta_t=delta_t, num_tbins=1,
                                    preprocess_kwargs={"max_incr_per_pixel": 5},
                                    device=torch.device('cpu'), height=None, width=None)

# take the first preprocessed time slice
input_tensor = next(iter(proc_iterator))
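The `"histo"` preprocessing turns each `delta_t`-long slice of events into an image. The sketch below is a rough mental model only, not Metavision's actual implementation, which may for instance normalize the result. It counts events per pixel and per polarity in one time slice and clips each count at `max_incr_per_pixel`. The event layout (`x`, `y`, `p`, `t` fields, with `t` in microseconds), the (OFF, ON) channel order, and the helper names `events_to_histo` and `slice_events` are all assumptions made for illustration.

```python
import numpy as np


def slice_events(events, t_start, delta_t):
    """Hypothetical helper: keep the events whose timestamp falls in [t_start, t_start + delta_t)."""
    mask = (events["t"] >= t_start) & (events["t"] < t_start + delta_t)
    return events[mask]


def events_to_histo(events, height, width, max_incr_per_pixel=5):
    """Hypothetical sketch: 2-channel (OFF, ON) event count image, clipped per pixel."""
    histo = np.zeros((2, height, width), dtype=np.float32)
    # one count per event at (polarity, y, x); np.add.at accumulates repeated indices correctly
    np.add.at(histo, (events["p"], events["y"], events["x"]), 1)
    return np.clip(histo, 0, max_incr_per_pixel)


event_dtype = np.dtype([("x", np.int64), ("y", np.int64), ("p", np.int64), ("t", np.int64)])
# seven synthetic ON events at pixel (x=1, y=0) and one OFF event at (x=2, y=3)
toy_events = np.array([(1, 0, 1, 10 * i) for i in range(7)] + [(2, 3, 0, 100)], dtype=event_dtype)

toy_slice = slice_events(toy_events, t_start=0, delta_t=50_000)
toy_histo = events_to_histo(toy_slice, height=4, width=4)
print(toy_histo.shape, toy_histo[1, 0, 1], toy_histo[0, 3, 2])
```

The seven ON events at the same pixel are clipped to 5. This is the purpose of a cap like `max_incr_per_pixel`: it keeps a few very active pixels from dominating the input.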

We can now extract the feature maps:

In [None]:
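# add a leading dimension of size 1 and run only the feature extractor part of the network
with torch.no_grad():
    feature_maps = model.feature_extractor(input_tensor[None, ...])
for feature_map in feature_maps:
    print(feature_map.shape)
# convert to NumPy arrays for plotting
feature_maps = [feature_map.numpy() for feature_map in feature_maps]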

Our detection network produces feature maps at several resolutions: `feature_extractor` returns one tensor per scale, and each feature map corresponds to one channel of these output tensors of the _feature extractor_ network.
These feature maps are the features that our network "considers" the best for the detection task. By analogy, a human looking for cars might search for headlights or wheels; our network relies on these feature maps instead.
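The following sketch makes the indexing used below concrete. The output tensors are presumably shaped `(batch, channels, height, width)`, which is consistent with `feature_maps[0][0, :14]` selecting channels 0 to 13 of the first returned scale. The toy extractor is hypothetical. It only reproduces that output convention with three stride-2 convolutions. The real network's depth, channel counts, and scale ordering may differ.

```python
from typing import List

import torch
from torch import nn


class ToyMultiScaleExtractor(nn.Module):
    """Hypothetical extractor: each stride-2 stage halves the resolution and is returned."""

    def __init__(self, in_channels: int = 2):
        super().__init__()
        self.stages = nn.ModuleList([
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
        ])

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outputs = []
        for stage in self.stages:
            x = torch.relu(stage(x))
            outputs.append(x)  # keep every intermediate resolution
        return outputs


extractor = ToyMultiScaleExtractor().eval()
with torch.no_grad():
    toy_maps = extractor(torch.randn(1, 2, 64, 96))

for scale, fmap in enumerate(toy_maps):
    print(scale, tuple(fmap.shape))

# one feature map = one channel of one scale, for one element of the leading dimension
single_toy_map = toy_maps[0][0, 3]
print(tuple(single_toy_map.shape))
```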

We can now visualize some of these feature maps: negative values (features suggesting that the object is not a car) are shown in blue, and positive values (features suggesting that the object is a car) in red.
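This blue/red reading holds only if the colormap is centered on zero. By default, `imshow` maps the data minimum to the lowest color and the maximum to the highest, wherever zero happens to fall. One way to pin zero to the neutral color is matplotlib's `TwoSlopeNorm`. The sketch below applies it to a small synthetic array standing in for a feature map.

```python
import numpy as np
from matplotlib.colors import TwoSlopeNorm

# synthetic feature map with an asymmetric range: small negatives, larger positives
fmap = np.array([[-0.5, -0.2, 0.0], [0.1, 1.0, 2.0]])

# map vmin..0 to 0..0.5 and 0..vmax to 0.5..1, so 0 always lands on the colormap's midpoint
norm = TwoSlopeNorm(vmin=fmap.min(), vcenter=0.0, vmax=fmap.max())
mapped = np.asarray(norm(fmap))
print(mapped)

# in a plotting loop, the norm would be passed as:
# plt.imshow(fmap, cmap="coolwarm", norm=norm)
```

`TwoSlopeNorm` requires `vmin < vcenter < vmax`. A map whose values all have the same sign therefore needs a fallback, such as symmetric limits `vmin=-m, vmax=m` with `m = np.abs(fmap).max()`.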

In [None]:
def remove_outliers(array):
    """remove outlier values for better visualization"""
    filtered_array = array.copy()
    absolute_value = np.abs(filtered_array)
    mean = absolute_value.mean()
    std = absolute_value.std()
    filtered_array[absolute_value > mean + 3 * std] = 0
    return filtered_array

plt.rcParams['figure.figsize'] = [6, 4]
# as a reminder, we first visualize the input of the neural network 
plt.imshow(proc_iterator.show(time_bin=0))
plt.title("Neural network input histogram")
plt.show()

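for index, feature_map in enumerate(feature_maps[0][0, :14]):
    feature_map = remove_outliers(feature_map)
    # symmetric color limits so that 0 maps to the neutral color of "coolwarm"
    vmax = np.abs(feature_map).max()
    plt.imshow(feature_map, cmap="coolwarm", vmin=-vmax, vmax=vmax)
    plt.title("feature map number {}".format(index))
    plt.show()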

Of the first 14 feature maps, number 0 looks interesting: it seems to respond positively to the car in this example. Let's take a closer look over several time slices:

In [None]:
from itertools import islice
plt.rcParams['figure.figsize'] = [12, 8]

for input_tensor in islice(proc_iterator, 4):
    # we extract this particular feature map: first scale, first element, channel 0
    with torch.no_grad():
        single_feature_map = model.feature_extractor(input_tensor[None, ...])[0][0, 0]

    # display the feature map alongside the input events
    _, (ax1, ax2) = plt.subplots(1, 2)
    feature_map = remove_outliers(single_feature_map.numpy())
    vmax = np.abs(feature_map).max()  # center the colormap on 0
    ax1.imshow(feature_map, cmap='coolwarm', vmin=-vmax, vmax=vmax)
    ax1.set_title("feature map number 0")
    ax2.imshow(proc_iterator.show(time_bin=0))
    ax2.set_title("input histogram")
    plt.show()
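Because this feature map has a lower resolution than the input histogram, a side-by-side comparison is only approximate. One option is to upsample the map to the input size with `torch.nn.functional.interpolate` and blend it over the input image. The sketch below shows this with synthetic tensors. The sizes and the `alpha` value are illustrative assumptions. With the real data, you would pass `single_feature_map` and the image returned by `proc_iterator.show(time_bin=0)` instead.

```python
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt

# synthetic stand-ins: a low-resolution feature map and a full-resolution RGB input image
low_res_map = torch.randn(15, 20)
input_image = np.zeros((120, 160, 3), dtype=np.uint8)

# interpolate expects (batch, channels, height, width), hence the two added dimensions
upsampled = F.interpolate(low_res_map[None, None], size=input_image.shape[:2],
                          mode="bilinear", align_corners=False)[0, 0].numpy()
print(upsampled.shape)

vmax = float(np.abs(upsampled).max())  # symmetric limits keep 0 at the neutral color
plt.imshow(input_image)
plt.imshow(upsampled, cmap="coolwarm", vmin=-vmax, vmax=vmax, alpha=0.5)  # semi-transparent overlay
plt.title("Feature map upsampled and overlaid on the input")
plt.show()
```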