In [None]:
# Reload imported modules automatically before each cell executes, so edits to the
# project code imported from src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
# Add the parent directory to the import path so the `src` package can be imported.
# NOTE(review): "../" is resolved against the kernel's working directory, so this
# assumes the notebook is launched from its own folder — confirm.
sys.path.append("../")

# NOTE(review): PATH, PennFudanDataset and get_dataloader are used further down but
# are not defined in this notebook, so they presumably come from these star imports.
# Confirm, and consider importing them by name so their origin is visible.
from src.constants import *
from src.training_utils.dataset import *
from src.training_utils.training import train_model, get_model_instance_segmentation

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image

from torchvision import transforms as T

In [None]:
# Annotation tables for the two splits, one CSV per split.
train_csv_path = f"{PATH}/data/tiles/train_cardinalidades_linux.csv"
test_csv_path = f"{PATH}/data/tiles/test_cardinalidades_linux.csv"

train_df = pd.read_csv(train_csv_path)
test_df = pd.read_csv(test_csv_path)

In [None]:
# The label encoding is hard-coded below; the CSV-driven alternative is kept
# here, disabled, for reference:
# le_dict = get_encoder_dict(CLASSES_CSV)
# le_dict

# Class name -> integer id used as the detection label.
le_dict = {
    'muchos_opcional': 2,
    'muchos_obligatorio': 1,
    'uno_opcional': 3,
    'uno_obligatorio': 4,
}

In [None]:
# Add the integer-encoded label column to both splits. A label name that is
# missing from le_dict raises KeyError instead of being silently dropped.
for split_df in (train_df, test_df):
    split_df['label_transformed'] = split_df['label'].apply(lambda name: le_dict[name])

In [None]:
def get_custom_transform(train):
    """Build the image transform pipeline for a dataset split.

    Args:
        train: when truthy, random horizontal and vertical flips (each with
            probability 0.5) are placed ahead of the tensor conversion.

    Returns:
        A ``T.Compose`` whose last step is always ``T.ToTensor()``.
    """
    augmentations = []
    if train:
        augmentations = [T.RandomHorizontalFlip(0.5), T.RandomVerticalFlip(0.5)]
    return T.Compose([*augmentations, T.ToTensor()])

In [None]:
# Directory passed to the datasets as `images_dir`.
# NOTE(review): presumably holds the image tiles referenced by the train/test CSVs — confirm.
IMAGES_DIR = f"{PATH}/data/tiles/image_slices"

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One class per entry in le_dict plus one extra index.
# NOTE(review): the extra index is presumably the background / unused id 0 — confirm.
num_classes = 1 + len(le_dict)

# Augmentation is currently switched off. To enable it, pass
# transforms=get_custom_transform(train=True) for the training set and
# transforms=get_custom_transform(train=False) for the test set.
dataset = PennFudanDataset(images_dir=IMAGES_DIR, csv=train_df)
dataset_test = PennFudanDataset(images_dir=IMAGES_DIR, csv=test_df)

In [None]:
# Training batches hold two samples and are shuffled.
data_loader = get_dataloader(
    dataset,
    shuffle=True,
    batch_size=2,
)
# The test loader yields one sample at a time in a fixed order.
data_loader_test = get_dataloader(
    dataset_test,
    shuffle=False,
    batch_size=1,
)

## Training model

In [None]:
# Gate for the training cell below: set to False to skip the train_model call.
train = True
# Forwarded to train_model as num_epochs.
epochs = 50

In [None]:
# Build the detector through the project helper; "retinanet" selects the architecture
# and num_classes is len(le_dict)+1 from the configuration cell above.
# NOTE(review): the helper's name mentions instance segmentation, while this notebook
# only reads boxes/scores/labels from the outputs — confirm which model types it supports.
model = get_model_instance_segmentation(num_classes=num_classes, model_type="retinanet")
# Move the model to the selected device. As the last expression of the cell, the
# value returned by .to() is echoed as the cell output.
model.to(device)

In [None]:
# Only parameters that still require gradients are handed to the training routine.
params = list(filter(lambda weight: weight.requires_grad, model.parameters()))
# override_path = f"{PATH}/data/models/model_best_test.pt"

In [None]:
# Training only runs when the `train` flag is set, so the notebook can be
# re-executed for evaluation without retraining.
if train:
    train_model(
        model=model,
        params=params,
        device=device,
        num_epochs=epochs,
        data_loader=data_loader,
        data_loader_test=data_loader_test,
    )

## Save model
https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [None]:
# Saving is currently disabled; the lines below are kept as a template.
# NOTE(review): the evaluation section later in this notebook loads
# `model_<name>_final.pt`, so that checkpoint must already exist on disk — confirm
# where it is produced before relying on a fresh run.
# model_name = model.__class__.__name__.lower()
# PATH_TO_SAVE_MODEL = f"{PATH}/data/models/model_{model_name}_final.pt"

# NOTE(review): `path_to_save`, `epoch` and `loss_value` are not defined anywhere in
# this notebook; the call probably needs PATH_TO_SAVE_MODEL and values reported by
# training — fix before re-enabling.
# save_model(path_to_save, model, epoch, loss_value)

## Testing

In [None]:
import cv2
import PIL
from IPython.display import display

def get_class_name(num_label, le_dict):
    """Translate an encoded label id back to its class name.

    Args:
        num_label: integer id to look up.
        le_dict: mapping of class name -> integer id.

    Returns:
        The class name whose id equals ``num_label``. If several names share
        the id, the one that appears last in ``le_dict`` wins.

    Raises:
        KeyError: if no class in ``le_dict`` maps to ``num_label``.
    """
    id_to_name = dict(zip(le_dict.values(), le_dict.keys()))
    return id_to_name[num_label]

def draw_bbox(img, xmin, ymin, xmax, ymax, score, label):
    """Draw one detection (caption plus rectangle) on an image array.

    The caption is "<class name> <score>", with the class name looked up in the
    module-level ``le_dict``; it is anchored at the box's top-left corner.

    Args:
        img: image array handed to OpenCV's drawing functions.
        xmin, ymin, xmax, ymax: box corners; truncated to int before drawing.
        score: value appended to the caption via ``str``.
        label: encoded class id of the detection.

    Returns:
        Whatever ``cv2.rectangle`` returns for the annotated image.
    """
    top_left = (int(xmin), int(ymin))
    bottom_right = (int(xmax), int(ymax))
    caption = ' '.join([get_class_name(label, le_dict), str(score)])

    # Caption first, then the 1-pixel box outline on top of it.
    img = cv2.putText(img, caption, top_left,
                      cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0,0,255), 1)
    return cv2.rectangle(img, top_left, bottom_right, (255,0,0), 1)

In [None]:
# Put the model in evaluation mode before running inference on the test set.
# The value returned by .eval() is echoed as the cell output.
model.eval()

In [None]:
# Detections whose (rounded) score falls below this value are not drawn.
SCORE_THRESHOLD = 0.5

# Built once instead of once per image.
to_pil = T.ToPILImage()

# Feed inputs on whatever device the model currently lives on; the model was moved
# to `device` earlier, while the dataset yields CPU tensors.
model_device = next(model.parameters()).device

for i in range(len(dataset_test)):
    tensor_image = dataset_test[i][0]

    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        predictions = model([tensor_image.to(model_device)])

    # Convert to an array once so OpenCV can draw on it and Image.fromarray always
    # receives an ndarray, even when no detection passes the threshold.
    image = np.array(to_pil(tensor_image))

    for prediction in predictions:
        for box, score, label in zip(prediction['boxes'], prediction['scores'], prediction['labels']):
            score = round(score.item(), 3)
            # `continue` rather than `break`: does not rely on the detections
            # being sorted by descending score.
            if score < SCORE_THRESHOLD:
                continue
            xmin, ymin, xmax, ymax = box.tolist()
            print(xmin, ymin, xmax, ymax)
            image = draw_bbox(image, xmin, ymin, xmax, ymax, score, label.item())
        display(Image.fromarray(image))

## Load the two final models & calculate AP for them
- https://torchmetrics.readthedocs.io/en/stable/classification/average_precision.html
- https://torchmetrics.readthedocs.io/en/stable/retrieval/map.html

In [None]:
# Rebuild the architecture with the same class count used for training
# (len(le_dict)+1), then restore the trained weights from the checkpoint.
model = get_model_instance_segmentation(len(le_dict)+1, "retinanet")
model_name = model.__class__.__name__.lower()

# Resolve the checkpoint under the project's PATH instead of a user-specific
# absolute path, matching the `{PATH}/data/models/...` convention used elsewhere
# in this notebook.
# NOTE(review): assumes PATH points at the project root that contains data/models — confirm.
PATH_TO_LOAD_MODEL = f"{PATH}/data/models/model_{model_name}_final.pt"

# map_location="cpu" lets a checkpoint written on a GPU machine load on a CPU-only
# host; load_state_dict then copies the values into the model's own parameters.
model_obj = torch.load(PATH_TO_LOAD_MODEL, map_location="cpu")
# The checkpoint is a dict; the weights live under 'model_state_dict'.
model.load_state_dict(model_obj['model_state_dict'])

In [None]:
# Put the freshly loaded model in evaluation mode before scoring the test set.
# The value returned by .eval() is echoed as the cell output.
model.eval()

In [None]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from pprint import pprint

In [None]:
# Scoring the whole test set as one batch kills the kernel, so images are run
# through the model one at a time.
predictions = []
targets = []
# Inference only: without no_grad every stored prediction would keep its autograd
# graph alive, inflating memory use across the loop.
with torch.no_grad():
    for i in range(len(dataset_test)):
        # Fetch the sample once so the image and its target come from the same
        # dataset access (and the item is not loaded three times).
        sample = dataset_test[i]
        image, target = sample[0], sample[1]
        # Each model call returns a list with one entry per input image; it is
        # stored as-is and unwrapped in the next cell.
        predictions.append(model([image]))
        targets.append(target)

In [None]:
# Each stored prediction is a one-element list (one dict per input image); unwrap it
# so `predictions` lines up one-to-one with `targets`. The isinstance guard keeps the
# cell idempotent: re-running it on already-unwrapped dicts leaves them untouched
# instead of failing on `p[0]`.
predictions = [p[0] if isinstance(p, (list, tuple)) else p for p in predictions]

In [None]:
# Mean average precision over bounding boxes given as (xmin, ymin, xmax, ymax),
# aggregated over all classes (class_metrics=False).
metric = MeanAveragePrecision(
    iou_type="bbox",
    box_format="xyxy",
    class_metrics=False,
    max_detection_thresholds=[100],
)
metric.update(predictions, targets)

map_results = metric.compute()
pprint(map_results)