# Scopo del notebook: realizzare una demo in tempo reale che processi un video dalla videocamera e mostri le predizioni

#### Costanti

In [1]:
# Path of the trained weights used for the predictions.
NN_PATH = "trained/resnet50_out8_best.pth"
# Directory whose sub-directories name the classes to predict.
LABELS_PATH = "data/images_scraped/"
# Side length, in pixels, of the square viewfinder cropped from each camera frame.
ssz = 200

### Imports

In [None]:
import torch, torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Resize, ToPILImage, CenterCrop, Normalize, Compose
from torchvision.transforms.functional import to_grayscale, to_tensor, rotate, hflip, crop
import matplotlib.pyplot as plt

import random
import pandas as pd
from torchvision.io import read_image

from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from sklearn.metrics import classification_report, confusion_matrix

from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import seaborn as sn
import numpy as np

import io
from PIL import Image
from PIL.features import pilinfo

from copy import deepcopy


import traceback
import warnings
warnings.filterwarnings("error")

import cv2
from numpy import asarray
import copy
import argparse

import os
import matplotlib.patches as patches

### Caricamento del modello

In [None]:
# Instantiate a ResNet-50 from the torchvision hub repository (tag v0.10.0)
# without pretrained weights; the trained weights are loaded in the next cell.
loaded = torch.hub.load(
    repo_or_dir='pytorch/vision:v0.10.0',
    model='resnet50',
    pretrained=False,
)
# Swap the classification head for a 2048 -> 8 linear layer (one output per class).
loaded.fc = torch.nn.Linear(in_features=2048, out_features=8)

In [None]:
# Load the trained weights into the network.
# map_location='cpu' lets a checkpoint that was saved from a GPU be opened on a
# machine without CUDA; the model itself is never moved off the CPU in this notebook.
# NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.
loaded.load_state_dict(torch.load(NN_PATH, map_location='cpu'))
# Switch to evaluation mode (changes the behaviour of layers such as BatchNorm/Dropout).
# The trailing ';' keeps the very long model repr out of the cell output.
loaded.eval();

#### Caricamento classi da predire

In [None]:
# Every sub-directory of LABELS_PATH is treated as one class.
# NOTE(review): os.listdir() returns entries in arbitrary order, so the
# index -> name mapping below depends on the filesystem. It must match the index
# assignment used at training time -- TODO confirm (and sort here if training
# used alphabetical class indices).
only_dirs = [
    entry
    for entry in os.listdir(LABELS_PATH)
    if os.path.isdir(os.path.join(LABELS_PATH, entry))
]

diz = dict(enumerate(only_dirs))  # diz  [key=LABEL_INDEX, value=LABEL_NAME]
diz2 = dict.fromkeys(diz, 0)      # diz2 [key=LABEL_INDEX, value=PROB_PREDICTION]
nocl = len(only_dirs)             # number of classes

print(diz.values())

#### Funzione che prende l'immagine, ritaglia il mirino e la prepara alla rete neurale

In [None]:
def preprocess(im, x0, y0):
    """Cut the viewfinder square out of ``im`` and prepare it for the network.

    Args:
        im: image handed to torchvision's ``crop`` (the launcher passes the
            tensor produced by ``to_tensor``).
        x0: left column of the viewfinder square.
        y0: top row of the viewfinder square.

    Returns:
        The ``ssz`` x ``ssz`` crop after ``Resize(300)``, ``CenterCrop(300)``
        and channel-wise ``Normalize``.
    """
    pipeline = Compose([
        Resize(300),
        CenterCrop(300),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # crop() takes (top, left, height, width): rows first, then columns.
    viewfinder = crop(im, top=y0, left=x0, height=ssz, width=ssz)
    return pipeline(viewfinder)

### Funzione che effettua predizioni e crea grafico a barre

In [None]:
def analyze_and_drawbar(frame):
    """Classify one preprocessed frame and render the confidences as a bar chart.

    Args:
        frame: preprocessed image tensor without a batch dimension; a leading
            batch axis is added here before the network is called.

    Returns:
        PIL.Image.Image: the horizontal bar chart, rendered through an
        in-memory JPEG.

    Side effects:
        Overwrites the entries of the module-level ``diz2`` dictionary with the
        latest per-class percentages.
    """
    # Inference only: no_grad() avoids building an autograd graph on every frame.
    with torch.no_grad():
        out = loaded(frame[None, :])  # add a leading dimension to simulate a batch
        # activations -> percentages for the single element of the batch
        perc = torch.nn.functional.softmax(out, dim=1)[0] * 100
        _, indices = torch.sort(out, descending=True)  # class indices, best first

    # store the percentage of each of the (up to) 8 highest-ranked classes
    for idx in indices[0][:8]:
        diz2[idx.item()] = perc[idx].item()

    keys = list(diz.keys())
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(16, 8))
    try:
        # Draw on the explicit Axes instead of pyplot's "current" axes: the caller
        # has its own figure open while this function runs.
        ax.set_xlim(0, 100)
        # look the values up by key so bars and tick labels cannot get misaligned
        ax.barh(keys, [diz2[k] for k in keys])
        ax.xaxis.set_ticks(np.linspace(0, 100, 11))
        ax.set_yticks(keys, labels=list(diz.values()))
        ax.grid(axis='x')
        ax.tick_params(axis='both', labelsize=25)
        ax.set_xlabel('confidence', fontsize=20, fontweight='bold')
        ax.set_ylabel('labels', fontsize=20, fontweight='bold')
        ax.set_title('predictions', fontsize=25, fontweight='bold')

        # render the chart into an in-memory image
        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg', bbox_inches="tight", dpi=120)
    finally:
        # close exactly this figure, also when drawing fails, so figures do not
        # accumulate across frames
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)

### Launcher della demo live

In [None]:
parser = argparse.ArgumentParser(description='Car logo realtime demo.')
parser.add_argument('--fullscreen', action='store_true', help='run in fullscreen')
# parse_known_args() ignores argv entries it does not recognise instead of exiting.
args, unknown = parser.parse_known_args()
# NOTE(review): args.fullscreen is parsed but never consulted -- the window is
# always switched to fullscreen below. Confirm whether the flag should gate that.

# Camera smoke test: open and release once before the loop (kept from the
# original, where it is described as avoiding a crash inside the loop).
cap = cv2.VideoCapture(0)
cap.release()

# Window setup: name it, move it to the top-left corner and make it fullscreen.
# namedWindow() takes a window *flag*; WINDOW_NORMAL is required for the
# fullscreen property to be changeable.
windowname = "Car logo Live Demo"
cv2.namedWindow(windowname, cv2.WINDOW_NORMAL)
cv2.moveWindow(windowname, 0, 0)
cv2.setWindowProperty(windowname, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)

# Constants used to undo Normalize() when showing the target view; they mirror
# the mean/std used in preprocess().
norm_std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
norm_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)

try:
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError('unable to open camera 0')

    while True:
        ret, img = cap.read()
        if not ret or img is None:
            # camera unplugged / stream ended: stop instead of crashing in cvtColor
            print('no frame received from the camera, stopping the demo')
            break

        # OpenCV delivers BGR; convert to RGB and mirror horizontally
        im = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        im = cv2.flip(im, 1)

        # viewfinder position: an ssz x ssz square centred in the frame
        h, w = im.shape[0:2]
        x0 = w // 2 - ssz // 2
        y0 = h // 2 - ssz // 2

        im_tensor = to_tensor(im)  # HxWxC uint8 array -> CxHxW float tensor
        rect = patches.Rectangle((x0, y0), ssz, ssz, linewidth=3,
                                 edgecolor='r', facecolor="none")  # viewfinder outline

        # build the dashboard: camera view | target view on top, bar chart below
        fig = plt.figure(figsize=(16, 16), constrained_layout=True)
        try:
            spec = fig.add_gridspec(nrows=2, ncols=2)
            ax_camera = fig.add_subplot(spec[0, 0])
            ax_target = fig.add_subplot(spec[0, 1])
            ax_barh = fig.add_subplot(spec[1, :])
            ax_barh.axis("off")

            # camera image with the viewfinder drawn on it
            ax_camera.imshow(im_tensor.permute(1, 2, 0))
            ax_camera.add_patch(rect)
            ax_camera.set_title('camera view', fontsize=25, fontweight='bold')
            ax_camera.axis('off')

            # network input, de-normalised so that it is viewable
            frame = preprocess(im_tensor, x0, y0)
            frame_plt = frame * norm_std + norm_mean
            ax_target.imshow(frame_plt.permute(1, 2, 0))
            ax_target.set_title('target view', fontsize=25, fontweight='bold')
            ax_target.axis('off')

            # predict and draw the confidence chart
            img_barplt = analyze_and_drawbar(frame)
            ax_barh.imshow(img_barplt)

            # capture the whole dashboard as an in-memory image
            buf = io.BytesIO()
            fig.savefig(buf, format='jpeg', bbox_inches='tight')
        finally:
            # close exactly this figure, even on error, so figures do not pile up
            plt.close(fig)

        # decode straight to a uint8 RGB array (no float round trip needed)
        buf.seek(0)
        full_screen_img = np.array(Image.open(buf).convert('RGB'))

        # show the dashboard in the live window (OpenCV expects BGR)
        cv2.imshow(windowname, cv2.cvtColor(full_screen_img, cv2.COLOR_RGB2BGR))

        # pressing Q ends the demo
        key = cv2.waitKey(20)
        if key & 0xFF == ord('q'):
            break

except Exception:
    print('error occured during thw while loop!')
    # show what actually went wrong instead of silently swallowing it
    traceback.print_exc()
finally:
    # Runs on normal exit, on errors and on a kernel interrupt, so the camera
    # and the window are always released exactly once.
    cap.release()
    cv2.startWindowThread()
    cv2.destroyAllWindows()