# Exercise 09: fastai_onnx_gradcam

***Search for the "TASK XX" headings and fill in the missing lines (marked with `...` in the code).***

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JozefResetar/computer_vision/blob/main/8_fastai_onnx_gradcam.ipynb)


## Install necessary dependencies
###  on Google Colab
1. Use GPU. Runtime -> Change runtime type -> Hardware accelerator -> GPU
2. Install dependencies from the following cell
3. Runtime -> Restart runtime

In [None]:
# Colab setup: light-the-torch (`ltt`) installs torch/torchvision, then fastai, Weights & Biases and ONNX Runtime.
# stdout of the first three commands is appended to /.tmp to keep the cell output short.
!pip install light-the-torch >> /.tmp
!ltt install torch torchvision >> /.tmp
!pip install fastai --upgrade >> /.tmp
!pip install wandb
!pip install onnxruntime

### On a local computer

In [None]:
conda install -c fastai -c pytorch fastai
pip install wandb
pip install onnxruntime
pip install ipywidgets
pip install matplotlib

## Download dataset and import necessary libraries

In [None]:
from ipywidgets import IntSlider
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from fastai.vision.widgets import *
from fastai.vision.all import *
from fastai.callback.wandb import *
import wandb
import onnxruntime as ort
import pandas as pd
# Start a Weights & Biases run; the WandbCallback attached to the learner further down logs training to it.
wandb.init(anonymous='allow') # use wandb without an account 
# wandb.init() # uncomment if you are a registered user

## Create Dataloader

In [None]:
# DOWNLOAD DATASET
# untar_data fetches and extracts the PETS archive; `path` is the extraction directory (images live in path/"images").
path = untar_data(URLs.PETS)

In [None]:
# All image files of the dataset; each file name encodes its class label (see get_y below).
files = get_image_files(path/"images")
# Side length in pixels used for training, ONNX export and the Grad-CAM overlay.
SIZE = 224
files

### TASK 01: The file name encodes the class. Fix the one-line `get_y` function so that it transforms a file name such as 'Abyssinian_106.jpg' into its class string 'abyssinian'.


In [None]:
def get_y(f):
    """Return the class label encoded in a PETS image file name.

    The label is everything before the last '_' (the remainder is the image number), lower-cased:
    'Abyssinian_106.jpg' -> 'abyssinian', 'great_pyrenees_12.jpg' -> 'great_pyrenees'.

    Args:
        f: path of the image file (``pathlib.Path`` or ``str``).
    """
    from pathlib import Path  # local import keeps the label function self-contained
    return Path(f).name.rsplit('_', 1)[0].lower()
# seed fixes the random train/validation split so results are comparable between runs and architectures.
dls = ImageDataLoaders.from_name_func(path, fnames=files, label_func=get_y, item_tfms=Resize(SIZE), seed=42) # add num_workers=0 when running on windows
classes = dls.vocab
dls.show_batch()

In [None]:
# Sanity check for TASK 01: get_y must map the file names onto exactly 37 distinct classes.
assert len(dls.train_ds.vocab) == 37, "Not correct number of classes"

### (OPTIONAL TASK) Change resnet18 to another architecture. Does the model improve with a different architecture?


In [None]:
# List the names available in fastai's `models` namespace, e.g. to pick an alternative to resnet18.
dir(models)

## TRAIN on pretrained model


In [None]:
# vision_learner replaces the deprecated cnn_learner. SaveModelCallback tracks the epoch with the lowest
# validation error_rate and reloads those weights after training.
learn = vision_learner(dls, resnet18, metrics=[error_rate, accuracy], cbs=[WandbCallback(), SaveModelCallback(monitor='error_rate', comp=np.less)])
# learn = vision_learner(dls, ..., metrics=[error_rate, accuracy], cbs=[WandbCallback(), SaveModelCallback(monitor='error_rate', comp=np.less)]) 


In [None]:
# Learning-rate finder: plots loss against learning rate to help choose the rate passed to fine_tune below.
learn.lr_find()

In [None]:
# 1 epoch with the pretrained body frozen, then 5 epochs with the whole model unfrozen; 3e-2 is the base learning rate.
learn.fine_tune(5, 3e-2, freeze_epochs=1)

In [None]:
# INFERENCE
# Show predictions of the trained model on a batch of validation images.
learn.show_results()
# Serialise the Learner to classifier.pkl; re-run this after each experiment so the file matches the current model.
learn.export(fname='classifier.pkl')

In [None]:
interp = ClassificationInterpretation.from_learner(learn)
losses, idxs = interp.top_losses()
# Sanity check: one loss and one index per validation item. A bare comparison in the middle of a cell is
# silently discarded, so assert it instead.
assert len(dls.valid_ds) == len(losses) == len(idxs)
interp.plot_confusion_matrix(figsize=(7, 7))

### TASK 02: Based on the confusion matrix, compute the sensitivity and specificity of EACH class
***
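For a multi-class problem, treat each class in turn as the positive class and all remaining classes as negatives (one-vs-rest). Sensitivity (true positive rate) measures the proportion of positives that are correctly identified. It is defined as follows: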
$$sensitivity = \frac{true\_positive}{true\_positive + false\_negative}$$
Specificity (True Negative rate) measures the proportion of negatives that are correctly identified. It is defined as follows:
$$specificity = \frac{true\_negative}{true\_negative + false\_positive}$$
***
https://en.wikipedia.org/wiki/Sensitivity_and_specificity

HINT: https://stackoverflow.com/questions/55635406/how-to-calculate-multiclass-overall-accuracy-sensitivity-and-specificity

In [None]:
# Raw counts: row i = actual class, column j = predicted class (same class order as `classes`).
interp.confusion_matrix()

In [None]:
# Per-class (one-vs-rest) sensitivity and specificity from the confusion matrix (rows = actual, columns = predicted).
cm = interp.confusion_matrix()
tp = np.diag(cm)
fn = cm.sum(axis=1) - tp   # actually class c, predicted as another class
fp = cm.sum(axis=0) - tp   # predicted as class c, actually another class
tn = cm.sum() - (tp + fn + fp)
# A class missing from the validation set gives 0/0; report NaN instead of emitting a warning.
with np.errstate(divide='ignore', invalid='ignore'):
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
pd.DataFrame({'sensitivity': sensitivity, 'specificity': specificity}, index=list(classes))

### TASK 03: Compare the inference times of `learn.predict` and the ONNX Runtime session. You can use Jupyter's `%timeit` magic, see: https://ipython.readthedocs.io/en/stable/interactive/magics.html

In [None]:
# ORIGINAL INFERENCE
learn.predict(files[0])
...

In [None]:
# Trace the trained model with a dummy input and write it to classifier.onnx.
# The dummy input is created on the model's own device so the export also works without a GPU.
onnx_device = next(learn.model.parameters()).device
torch.onnx.export(
    learn.model,
    torch.randn(1, 3, SIZE, SIZE, device=onnx_device), # The content of the tensor does not matter, but it needs to have the correct shape.
                                                       # In particular, the first axis must be the batch size, even if it is 1.
    "classifier.onnx",
    input_names=["image"],
    output_names=["diagnosis"]
)

In [None]:
# ONNX INFERENCE
session = ort.InferenceSession('classifier.onnx')
input_name = session.get_inputs()[0].name
img = PILImage.create(files[0])
session.run(None, {input_name: np.array(img.resize((SIZE,SIZE))).reshape(1,3,SIZE,SIZE).astype(np.float32)})
...

## Grad CAM

### TASK 04: Debug the `get_cam_map` function with Jupyter's `%debug` magic
* `up` and `down` move through the call stack,
* `list` shows the code around the current line,
* `help` lists more options,
* also try to print some variables in the current call context.

In [None]:
# grad cam
def get_cam_map(model, x, cls, layer=-2):
    """Compute a Grad-CAM heat map for class index `cls` on the first item of batch `x`.

    Args:
        model: model whose body is ``model[0]``, an indexable container of layers (as in fastai vision models).
        x: input batch; it is moved to the model's device and only its first item is used for the map.
        cls: index of the output class whose score is back-propagated.
        layer: index into ``model[0]`` of the layer whose activations and gradients are combined.

    Returns:
        The channel-weighted sum of that layer's activations, detached and on the CPU
        (a 2-D map assuming the layer outputs (batch, C, H, W) — TODO confirm for other architectures).
    """
    device = next(model.parameters()).device
    # Backward hook: grad_output is a tuple holding the gradient of the layer output, hence o[0].
    with Hook(model[0][layer], lambda m, i, o: o[0].detach().clone(), is_forward=False) as hookg:
        with Hook(model[0][layer], lambda m, i, o: o.detach().clone(), is_forward=True) as hook:
            output = model.eval()(x.to(device))
            act = hook.stored
        output[0, cls].backward()
        grad = hookg.stored
    # Per-channel weights: gradients averaged over the two spatial axes.
    w = grad[0].mean(dim=[1, 2], keepdim=True)
    cam_map = (w * act[0]).sum(0)
    return cam_map.detach().cpu()

# merge image with heatmap
def merge_img_cam(image, cam_map):
    """Render `image` resized to SIZE x SIZE with `cam_map` overlaid as a semi-transparent heat map.

    Drawing happens off-screen on an Agg canvas (dpi 100, so the figure is SIZE x SIZE pixels); the rendered
    RGBA buffer is returned as a NumPy array.
    """
    resized = image.resize((SIZE, SIZE))
    side_inches = SIZE / 100
    figure = Figure(figsize=(side_inches, side_inches), dpi=100)
    agg_canvas = FigureCanvasAgg(figure)  # attach an off-screen canvas to the figure
    axes = figure.subplots()

    show_image(resized, ctx=axes)
    axes.axis('off')
    # `extent` stretches the heat map over the whole image; bilinear interpolation smooths it.
    axes.imshow(cam_map, alpha=0.6, extent=(0, SIZE, SIZE, 0), interpolation='bilinear', cmap='magma')

    agg_canvas.draw()
    return np.asarray(agg_canvas.buffer_rgba())


## Simple site with widgets

In [None]:
# Demo UI widgets: upload button, output areas for the uploaded image and the Grad-CAM overlay,
# the prediction label, the Grad-CAM target-class dropdown and the layer slider.
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
grad_pl = widgets.Output()
lbl_pred = widgets.Label('Prediction: -; Probability: -')
grad_classes = Dropdown(options=classes, index=0)
grad_layer_slider = IntSlider(min=-5, max=-2, step=1, value=-2)  # negative index into learn.model[0] (see get_cam_map)

def prepare_image(img):
    """Open `img` (anything PILImage.create accepts), resize it to SIZE x SIZE and wrap the result as a PILImage."""
    resized = PILImage.create(img).resize((SIZE, SIZE))
    pixels = np.array(resized)
    return PILImage.create(pixels)

def on_class_slider_change(change):
    """Redraw the Grad-CAM panel for the last upload when the target class or the layer changes."""
    grad_pl.clear_output()
    if not btn_upload.data:  # nothing uploaded yet -> nothing to explain (avoids IndexError)
        return
    img = prepare_image(btn_upload.data[-1])
    x, = first(learn.dls.test_dl([img], rm_type_tfms=None, num_workers=0))
    with grad_pl:
        cam_map = get_cam_map(model=learn.model, x=x, cls=grad_classes.index, layer=grad_layer_slider.value)
        display(Image.fromarray(merge_img_cam(img, cam_map)))


def on_click(change):
    """Handle a new upload: show the image, predict its class and draw Grad-CAM for the selected class and layer."""
    out_pl.clear_output()
    grad_pl.clear_output()
    if not btn_upload.data:  # upload widget was cleared -> nothing to classify
        return
    img = prepare_image(btn_upload.data[-1])
    with out_pl:
        display(img)
    pred, pred_idx, probs = learn.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {str(probs[pred_idx].numpy())}'
    x, = first(learn.dls.test_dl([img], rm_type_tfms=None, num_workers=0))

    with grad_pl:
        # Honour the layer currently selected on the slider, consistent with on_class_slider_change.
        cam_map = get_cam_map(model=learn.model, x=x, cls=grad_classes.index, layer=grad_layer_slider.value)
        display(Image.fromarray(merge_img_cam(img, cam_map)))

# A new upload triggers prediction + Grad-CAM; changing the class or the layer only redraws the heat map.
btn_upload.observe(on_click, names=['data'])
# Observe only 'value': without `names`, the callback fires for every changed trait (index, label, value, ...)
# and Grad-CAM would be recomputed several times per user interaction.
grad_classes.observe(on_class_slider_change, names='value')
grad_layer_slider.observe(on_class_slider_change, names='value')

display(VBox([widgets.Label('Classify your cat or dog breed!'), 
              btn_upload, 
              out_pl, 
              lbl_pred, 
              HBox([widgets.Label('Grad Class Activation Map for class: '), grad_classes]), 
              HBox([widgets.Label('Grad Class Activation Map for model layer: '), grad_layer_slider]), 
              grad_pl]))

In [None]:
# Post-mortem debugger for TASK 04: run this cell right after another cell raised an exception to inspect its call stack.
%debug