<a href="https://colab.research.google.com/github/TedDeVriesLentsch/ComputerVision_Group20/blob/main/detection_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project - Seminar Computer Vision by Deep Learning (CS4245) 2020/2021

Group Number: 20

Student 1: Stan Zwinkels

Student 2: Ted de Vries Lentsch

Date: June 14, 2021

## Check available GPU

In [None]:
# list the GPU(s) that are visible to this runtime
!nvidia-smi -L

## Settings

In [None]:
# run the training cell; also enables the random flips for the training dataset
DO_TRAIN                = False
# add Gaussian noise to the training images (only has an effect when DO_TRAIN is True)
DO_TRAIN_NOISE          = False
# add color jitter to the training images (only has an effect when DO_TRAIN is True)
DO_TRAIN_COLORJITTER    = False
# load saved model parameters and run the model on the test dataset
DO_TEST                 = False
# search the decision threshold on the validation dataset (only used when DO_TEST is True),
# otherwise a decision threshold of 0.5 is used
DO_OPTIMIZE_THRESHOLD   = False

## Import necessary libraries

In [None]:
# standard libraries
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import time
from matplotlib.lines import Line2D

# data processing
from skimage.transform import resize

# widgets
import ipywidgets

# Pytorch
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F

## Download and import standard TorchVision files

In [None]:
%%shell
# install pycocotools
pip install -q -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

# clone TorchVision repository to use some files from references/detection
# (the clone is skipped when the checkout already exists, so the cell can be re-run)
[ -d vision ] || git clone --quiet https://github.com/pytorch/vision.git
cd vision
git checkout --quiet v0.3.0

# copy the reference files to the parent (working) directory so they can be imported
cp references/detection/coco_eval.py ../
cp references/detection/coco_utils.py ../
cp references/detection/engine.py ../
cp references/detection/transforms.py ../
cp references/detection/utils.py ../

In [None]:
from engine import train_one_epoch

## Determine device

In [None]:
# use the first GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('The device is: {}'.format(device))

## Load dataset from Kaggle

In [None]:
# install the Kaggle command-line client that is used below to download the dataset
!pip install -q kaggle

Go to your Kaggle account, scroll to the `API` section and click `Expire API Token` to remove previous tokens. Then click on `Create New API Token` and the file `kaggle.json` is downloaded. Run the cell below and select the `kaggle.json` that has been downloaded from the Kaggle account settings page.

In [None]:
from google.colab import files
# opens a file picker in the notebook; select the downloaded `kaggle.json` here
files.upload()

Place the `kaggle.json` file at `~/.kaggle/kaggle.json`, which is where the Kaggle API looks for the token, and restrict its permissions so that only the owner can read it.

In [None]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

Download the public dataset from Kaggle and unzip the dataset.

In [None]:
!kaggle datasets download -d teddevrieslentsch/morado-5may
!unzip -qq morado-5may.zip -d morado_5may
!rm morado-5may.zip

## Connect to Drive

In [None]:
from google.colab import drive
# mount Google Drive under /content/drive (the model parameters are saved there)
drive.mount('/content/drive')

Create folder on your Drive to save model parameters after training.

In [None]:
# make sure the folder for the saved model parameters exists
os.makedirs('drive/MyDrive/model_parameters', exist_ok=True)

## 1. Explore dataset

### 1.1. Load data
Below, the names of the images and annotations are obtained.

In [None]:
# paths
dataset_name  = 'morado_5may'
path_img      = '{}/images'.format(dataset_name)
path_annot    = '{}/annotations'.format(dataset_name)

# get image and annotation names
img_names     = list(sorted(os.listdir(path_img)))
annot_names   = list(sorted(os.listdir(path_annot)))

print('The dataset has {} images and {} annotations.'.format(len(img_names), len(annot_names)))

### 1.2. Images
Below, the images are displayed with an interactive widget. The size of every image is (H, W, C) = (4032, 3024, 3).

In [None]:
def plot_img(path_img, img_name):
    """Display one image of the dataset with axes in pixels.

    path_img: directory that contains the images.
    img_name: file name of the image inside `path_img`.

    The image is rotated with `np.rot90(..., -1)` before it is shown; the axes
    are fixed to a width of 3024 and a height of 4032 pixels.
    """
    img_width, img_height = 3024, 4032
    image = np.rot90(plt.imread('{}/{}'.format(path_img, img_name)), -1)

    plt.figure(figsize=(10, 10))
    plt.imshow(image, zorder=-10)
    plt.xlabel('Width of image (pixels)', fontsize=20)
    plt.ylabel('Height of image (pixels)', fontsize=20)
    plt.xticks(np.linspace(0, img_width, 9), fontsize=16)
    plt.yticks(np.linspace(0, img_height, 9), fontsize=16)
    plt.xlim(0, img_width)
    # reversed limits keep the origin in the top-left corner
    plt.ylim(img_height, 0)
    plt.show()

In [None]:
# interactive viewer: the selected index determines which image is displayed
ipywidgets.interact(lambda idx: plot_img(path_img, img_names[idx]), idx=range(len(img_names)))

### 1.3. Annotations
Below, the images and annotations from the dataset are displayed with an interactive widget. The size of every image is (H, W, C) = (4032, 3024, 3) and an array with annotations contains rows of the form [x_min, y_min, x_max, y_max, label]. The 'raw' flowers are indicated with a blue rectangle and the 'ripe' flowers with a red rectangle.

In [None]:
def plot_img_and_annots(path_img, img_name, path_annot, annot_name):
    """Display one image together with its annotated bounding boxes.

    path_img / img_name: directory and file name of the image.
    path_annot / annot_name: directory and file name of the annotation CSV
        (no header, rows are [x_min, y_min, x_max, y_max, label]).

    Boxes labelled 'raw' are drawn in blue, boxes labelled 'ripe' in red;
    rows with any other label are not drawn.
    """
    img_width, img_height = 3024, 4032
    box_colors = {1: 'blue', 2: 'red'}      # label id -> rectangle color

    image       = np.rot90(plt.imread('{}/{}'.format(path_img, img_name)), -1)
    annot_table = pd.read_csv('{}/{}'.format(path_annot, annot_name), sep=',', header=None)
    annots      = annot_table.replace({'raw': 1, 'ripe': 2}).to_numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(image, zorder=-10)
    for row in annots:
        color = box_colors.get(row[4])
        if color is None:
            continue
        # closed outline of the rectangle: the first corner is repeated at the end
        xs = [row[0], row[2], row[2], row[0], row[0]]
        ys = [row[1], row[1], row[3], row[3], row[1]]
        plt.plot(xs, ys, color=color, linewidth=3, zorder=10)
    plt.xlabel('Width of image (pixels)', fontsize=20)
    plt.ylabel('Height of image (pixels)', fontsize=20)
    plt.xticks(np.linspace(0, img_width, 9), fontsize=16)
    plt.yticks(np.linspace(0, img_height, 9), fontsize=16)
    plt.xlim(0, img_width)
    plt.ylim(img_height, 0)

    # proxy lines that only serve as legend entries
    legend_handles = [Line2D([0], [0], color='blue', linewidth=3, label='raw'),
                      Line2D([0], [0], color='red', linewidth=3, label='ripe')]
    plt.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
    plt.show()

In [None]:
# interactive viewer: image and annotation are paired by their position in the sorted name lists
ipywidgets.interact(lambda idx: plot_img_and_annots(path_img, img_names[idx], path_annot, annot_names[idx]), idx=range(len(img_names)))

### 1.4. Resized image
Below, the resized images and annotations from the dataset are displayed with an interactive widget. The images are resized because the Faster R-CNN implementation that is used works with input images whose height and width lie roughly in the range [800, 1333]. The original images have a size of (H, W, C) = (4032, 3024, 3). We have chosen to resize the images so that the height-to-width ratio remains the same. The size of each resized image is (H, W, C) = (1200, 900, 3) and an array with annotations contains rows of the form [x_min, y_min, x_max, y_max, label]. The 'raw' flowers are indicated with a blue rectangle and the 'ripe' flowers with a red rectangle.

In [None]:
def plot_img_and_annots_resized(path_img, img_name, path_annot, annot_name):
    """Display one image, resized to 1200x900 pixels, with its rescaled bounding boxes.

    path_img / img_name: directory and file name of the image.
    path_annot / annot_name: directory and file name of the annotation CSV
        (no header, rows are [x_min, y_min, x_max, y_max, label]).

    The box coordinates are converted to float before scaling, so the scaled
    values are not truncated when the CSV is parsed as integers. The scale
    factors are derived from the size of the (rotated) image instead of being
    hard-coded for 4032x3024 images.
    """
    out_height, out_width = 1200, 900
    label_ids  = {'raw': 1, 'ripe': 2}
    box_colors = {1: 'blue', 2: 'red'}      # label id -> rectangle color

    img_rotated = np.rot90(plt.imread('{}/{}'.format(path_img, img_name)), -1)
    img_resized = resize(img_rotated, (out_height, out_width))

    annot  = pd.read_csv('{}/{}'.format(path_annot, annot_name), sep=',', header=None)
    # x-coordinates scale with the width ratio, y-coordinates with the height ratio
    scale_x = out_width/img_rotated.shape[1]
    scale_y = out_height/img_rotated.shape[0]
    boxes   = annot.iloc[:, 0:4].to_numpy(dtype=float)*np.array([scale_x, scale_y, scale_x, scale_y])
    # map the label names to ids; values that are not label names are kept as they are
    labels  = annot.iloc[:, 4].map(lambda value: label_ids.get(value, value)).to_numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(img_resized, zorder=-10)
    for box, label in zip(boxes, labels):
        color = box_colors.get(label)
        if color is None:
            continue
        # closed outline of the rectangle: the first corner is repeated at the end
        xs = [box[0], box[2], box[2], box[0], box[0]]
        ys = [box[1], box[1], box[3], box[3], box[1]]
        plt.plot(xs, ys, color=color, linewidth=3, zorder=10)
    plt.xlabel('Width of image (pixels)', fontsize=20)
    plt.ylabel('Height of image (pixels)', fontsize=20)
    plt.xticks(np.linspace(0, out_width, 7), fontsize=16)
    plt.yticks(np.linspace(0, out_height, 9), fontsize=16)
    plt.xlim(0, out_width)
    plt.ylim(out_height, 0)
    line_raw  = Line2D([0], [0], color='blue', linewidth=3, label='raw')
    line_ripe = Line2D([0], [0], color='red', linewidth=3, label='ripe')
    plt.legend(handles=[line_raw, line_ripe], loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
    plt.show()

In [None]:
# interactive viewer for the resized images; image and annotation are paired by index
ipywidgets.interact(lambda idx: plot_img_and_annots_resized(path_img, img_names[idx], path_annot, annot_names[idx]), idx=range(len(img_names)))

## 2. Dataset class

In [None]:
class MoradoDataset(torch.utils.data.Dataset):
    """Detection dataset that yields (image, target) pairs for the given paths.

    root: dataset directory with the sub-directories `images` and `annotations`.
    path_numbers: path numbers to include; the path number is parsed from the
        file name (`morado_5may_<path number>_...`). Defaults to no paths.
    transforms: optional callable `(img, target) -> (img, target)`.

    Images and annotations are paired by their position in the sorted, filtered
    name lists, so both directories are expected to hold matching files.
    """

    def __init__(self, root, path_numbers=None, transforms=None):
        if path_numbers is None:
            path_numbers = []

        self.root       = root                                                          # directory to dataset
        self.transforms = transforms                                                    # transform input data
        self.imgs       = self.get_names('{}/images'.format(root), path_numbers)        # image file names
        self.annots     = self.get_names('{}/annotations'.format(root), path_numbers)   # annotation file names
        self.classes    = ['background', 'raw', 'ripe']                                 # classes
        self.height     = 1200                                                          # (in pixels)
        self.width      = 900                                                           # (in pixels)
        self.sc_factor  = 1200/4032                                                     # scale factor

    def get_names(self, directory, path_numbers):
        """Return the sorted file names in `directory` whose path number is in `path_numbers`."""
        target_names = []

        for name in sorted(os.listdir(directory)):
            path_number = int(name.split('morado_5may_')[1].split('_')[0])
            if path_number in path_numbers:
                target_names.append(name)

        return target_names

    def __getitem__(self, idx):
        """Load, rotate and resize image `idx` and build its detection target.

        Returns the image and a dict with the keys `boxes` (float tensor with
        rows [x_min, y_min, x_max, y_max] in resized pixels), `labels` (int64),
        `image_id`, `area` and `iscrowd`.
        """
        img_path   = '{}/images/{}'.format(self.root, self.imgs[idx])                   # image path
        annot_path = '{}/annotations/{}'.format(self.root, self.annots[idx])            # annotation path

        img   = np.rot90(plt.imread(img_path), -1)                                      # image (rotated)
        img   = resize(img, (self.height, self.width))                                  # resized image
        annot = pd.read_csv(annot_path, sep=',', header=None)                           # annotation

        # The coordinates are converted to float BEFORE scaling: writing the scaled values
        # back into an integer array would truncate them, and an object array (possible
        # when the label column holds strings) cannot be converted to a tensor.
        label_ids = {'raw': 1, 'ripe': 2}
        boxes_np  = annot.iloc[:, 0:4].to_numpy(dtype=np.float64)*self.sc_factor        # resized annotation
        labels_np = annot.iloc[:, 4].map(lambda value: label_ids.get(value, value)).to_numpy(dtype=np.int64)

        boxes    = torch.as_tensor(boxes_np, dtype=torch.float)                         # boxes
        labels   = torch.as_tensor(labels_np, dtype=torch.int64)                        # labels
        image_id = torch.tensor([idx])                                                  # image id
        area     = (boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0])                      # area
        iscrowd  = torch.zeros((len(annot)), dtype=torch.int64)                         # is crowd (set to False)

        target = {
            'boxes':    boxes,
            'labels':   labels,
            'image_id': image_id,
            'area':     area,
            'iscrowd':  iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.imgs)

## 3. Transform

In [None]:
class ToTensor(object):
    """Transform that turns the image into a float32 tensor; the target is passed through."""

    def __call__(self, img, target):
        # the array is copied first; to_tensor then receives a fresh array
        tensor = F.to_tensor(img.copy())
        return tensor.float(), target

In [None]:
class RandomHorizontalFlip(object):
    """Mirror the image and its bounding boxes left-right with a given probability.

    prob_threshold: the flip is applied when a random number in [0, 1) is smaller than this value.
    """

    def __init__(self, prob_threshold):
        self.prob_threshold = prob_threshold

    def __call__(self, img, target):
        if not (random.random()<self.prob_threshold):
            return img, target

        img_width = img.shape[2]                                # width is the last axis of the image tensor
        flipped   = img.flip(-1)
        boxes     = target['boxes']
        # mirrored x_min is width-x_max and mirrored x_max is width-x_min (updated in place)
        boxes[:,[0,2]]  = img_width-boxes[:,[2,0]]
        target['boxes'] = boxes
        return flipped, target

In [None]:
class RandomVerticalFlip(object):
    """Mirror the image and its bounding boxes top-bottom with a given probability.

    prob_threshold: the flip is applied when a random number in [0, 1) is smaller than this value.
    """

    def __init__(self, prob_threshold):
        self.prob_threshold = prob_threshold

    def __call__(self, img, target):
        if not (random.random()<self.prob_threshold):
            return img, target

        img_height = img.shape[1]                               # height is the second axis of the image tensor
        flipped    = img.flip([1])
        boxes      = target['boxes']
        # mirrored y_min is height-y_max and mirrored y_max is height-y_min (updated in place)
        boxes[:,[1,3]]  = img_height-boxes[:,[3,1]]
        target['boxes'] = boxes
        return flipped, target

In [None]:
class AddGaussianNoise(object):
    """Add Gaussian noise to the image and clamp the result to [0, 1].

    mean: offset that is added to every pixel.
    std: standard deviation of the noise.
    """

    def __init__(self, mean=0, std=0.3):
        self.mean = mean
        self.std  = std

    def __call__(self, img, target):
        # standard normal samples, scaled to the requested standard deviation
        noise = torch.randn(img.size())*self.std
        noisy = img + self.mean + noise
        return torch.clamp(noisy, min=0, max=1), target

In [None]:
class ColorJitter(object):
    """Randomly change the contrast, saturation and hue of the image; the target is passed through.

    The arguments are forwarded to `torchvision.transforms.ColorJitter`. The
    defaults are the values that were previously hard-coded. The torchvision
    transform is created once instead of on every call.
    """

    def __init__(self, brightness=0, contrast=0.3, saturation=0.3, hue=0.2):
        self.color_jitter = torchvision.transforms.ColorJitter(brightness=brightness,
                                                               contrast=contrast,
                                                               saturation=saturation,
                                                               hue=hue)

    def __call__(self, img, target):
        img = self.color_jitter(img)
        return img, target

In [None]:
class Compose(object):
    """Chain several `(img, target) -> (img, target)` transforms.

    transforms: sequence of callables that are applied in the given order.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target):
        # the output of every step is the input of the next step
        for step in self.transforms:
            result      = step(img, target)
            img, target = result
        return img, target

In [None]:
# Note: mean/std normalization is done by model
def get_transform(train=False, train_noise=False, train_colorjitter=False):
    """Build the transform pipeline for a dataset.

    train: add the random horizontal and vertical flips (data augmentation).
    train_noise: additionally add Gaussian noise; only used when `train` is True.
    train_colorjitter: additionally add color jitter (contrast, saturation, hue);
        only used when `train` is True.

    Returns a `Compose` object whose first step always converts the image to a tensor.
    """
    # conversion of the image to a tensor is needed for every dataset
    pipeline = [ToTensor()]

    if train:
        # flips mirror the image and its bounding boxes, each with probability 0.5
        pipeline += [RandomHorizontalFlip(0.5), RandomVerticalFlip(0.5)]

        if train_noise:
            pipeline.append(AddGaussianNoise())

        if train_colorjitter:
            pipeline.append(ColorJitter())

    return Compose(pipeline)

## 4. Model

In [None]:
def get_model(num_classes):
    """Create a Faster R-CNN model (pre-trained on COCO) with a new box predictor.

    num_classes: number of output classes of the new predictor, including the
        background class (3 for background + 'raw' + 'ripe').

    The argument is now respected; previously it was overwritten with a
    hard-coded value of 3 inside the function.
    """
    # load Faster R-CNN model (pre-trained on COCO)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the classifier with one that has num_classes classes
    # head: part of the network that uses ROI feature vector to predict cls score and bounding box (2 sibling linear layers)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

## 5. Create dataloader

In [None]:
# seed PyTorch's random number generator for reproducibility
# NOTE(review): the flip transforms draw from Python's `random` module, which is not
# seeded here -- confirm whether this matters for the reproducibility of the training
torch.manual_seed(1)

# path numbers for the datasets (there are 10 paths); the split is made per path
path_numbers_train  = [1, 2, 3, 4, 5, 6]
path_numbers_val    = [7, 8]
path_numbers_test   = [9, 10]

# create train, validation, and test dataset
# (only the train dataset can get augmentations; validation and test only convert to tensor)
dataset_train   = MoradoDataset('morado_5may', path_numbers_train, get_transform(DO_TRAIN, DO_TRAIN_NOISE, DO_TRAIN_COLORJITTER))
dataset_val     = MoradoDataset('morado_5may', path_numbers_val, get_transform())
dataset_test    = MoradoDataset('morado_5may', path_numbers_test, get_transform())

# utility function for data loader
# convert batch from [(img1, dict1), (img2, dict2)] into ((img1, img2), (dict1, dict2))
def collate_fn(batch):
    return tuple(zip(*batch))

# create the train, validation, and test data loaders (only the train loader shuffles)
data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=2, shuffle=True, num_workers=1, collate_fn=collate_fn)
data_loader_val   = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn)
data_loader_test  = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn)

## 6. Create model

In [None]:
# there are 3 different classes (background + 'raw' + 'ripe')
num_classes = 3

# build the model and move it to the selected device
model = get_model(num_classes).to(device)

## 7. Train and save model

In [None]:
def evaluate_model(model, dataset_val, device):
    """Return the average loss of the model over a dataset.

    model: detection model that returns a dictionary of losses when it is
        called with images and targets in train mode.
    dataset_val: iterable with a length that yields (img, target) pairs.
    device: device to which every image and target tensor is moved.

    The model is left in train mode afterwards.
    """
    # the loss dictionary is only returned in train mode
    model.train()

    running_loss = 0
    with torch.no_grad():       # no gradients are needed to measure the loss
        for sample_img, sample_target in dataset_val:
            target_dev = {key: value.to(device) for key, value in sample_target.items()}
            losses     = model([sample_img.to(device)], [target_dev])

            # total loss of this sample is the sum of the individual loss terms
            running_loss += sum(losses.values()).item()

    return running_loss/len(dataset_val)

In [None]:
if DO_TRAIN:
    # start time of the training is used for the names of the saved files
    saving_time = int(time.time())
    model_path  = 'drive/MyDrive/model_parameters/model_{}.pt'.format(saving_time)
    info_path   = 'drive/MyDrive/model_parameters/model_info_{}.csv'.format(saving_time)

    # training variables (only parameters that require gradients are optimized)
    params          = [p for p in model.parameters() if p.requires_grad]

    NUM_EPOCHS      = 15
    LEARNING_RATE   = 0.005
    MOMENTUM        = 0.9
    WEIGHT_DECAY    = 0.0005
    STEP_SIZE       = 3
    GAMMA           = 0.1

    optimizer       = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    lr_scheduler    = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

    # names of the transforms that are ACTUALLY applied to the training dataset
    # (taken from the dataset itself, so the noise and color jitter settings are included)
    transform_names = [type(transform).__name__ for transform in dataset_train.transforms.transforms]

    # early stopping implementation
    best_val_loss = float('inf')
    patience      = 3
    patience_cnt  = 0
    model_saved   = False

    for epoch in range(NUM_EPOCHS):
        # train for one epoch
        train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=30)

        # update the learning rate
        lr_scheduler.step()

        # evaluate on the validation dataset
        val_loss = evaluate_model(model, dataset_val, device)
        print('The validation loss is {:.4f}'.format(val_loss))
        print()

        # apply early stopping
        if val_loss<best_val_loss:
            patience_cnt  = 0
            best_val_loss = val_loss

            # save model (only the parameters of the best epoch so far are kept)
            torch.save(model.state_dict(), model_path)
            model_saved = True

            # save train info
            columns           = ['LOSS','EPOCH','NUM_EPOCHS','LEARNING_RATE','MOMENTUM','WEIGHT_DECAY','STEP_SIZE','GAMMA','TRANSFORMS']
            train_variables   = np.array([[val_loss, epoch+1, NUM_EPOCHS, LEARNING_RATE, MOMENTUM, WEIGHT_DECAY, STEP_SIZE, GAMMA, 0]])
            model_info        = pd.DataFrame(train_variables, columns=columns).astype(object)
            model_info.at[0,'TRANSFORMS'] = transform_names
            model_info.to_csv(info_path, index=False, header=True)
        else:
            patience_cnt += 1

            if patience_cnt==patience:
                # stop training: no improvement for `patience` consecutive epochs
                break

    if model_saved:
        print('The model has been saved as {}'.format(model_path))
    else:
        # happens when the validation loss never improved on infinity (e.g. it was NaN)
        print('No model has been saved because the validation loss never improved.')

## 8. Load model

In [None]:
if DO_TEST:
    if not DO_TRAIN:
        # timestamp in the file name of a previously trained model (model_<timestamp>.pt);
        # it has to be filled in by hand when the model was not trained in this session
        saving_time = None
        if saving_time is None:
            raise ValueError('Set `saving_time` to the timestamp of a saved model '
                             '(drive/MyDrive/model_parameters/model_<timestamp>.pt) or enable DO_TRAIN.')

    # map_location makes it possible to load parameters that were saved on a GPU on a CPU-only runtime
    model_path = 'drive/MyDrive/model_parameters/model_{}.pt'.format(saving_time)
    model.load_state_dict(torch.load(model_path, map_location=device))

## 9. Test model and show predictions

In [None]:
def do_nms(boxes, labels, scores, IoU_threshold=0.3):
    """Apply non maximum suppression and return the remaining boxes, labels and scores.

    boxes, labels, scores: predictions for one image (indexed along the first axis).
    IoU_threshold: threshold that is passed to `torchvision.ops.nms`.

    The labels are not passed to the suppression, so it is applied to all
    boxes together and not per class.
    """
    selected = torchvision.ops.nms(boxes, scores, IoU_threshold)
    return boxes[selected], labels[selected], scores[selected]

In [None]:
def test_model(model, device, dataset):
    """Run the model on every sample of a dataset and collect ground truth and predictions.

    model: detection model; it is switched to evaluation mode and moved to `device`.
    device: device on which the predictions are made.
    dataset: iterable that yields (img, target) pairs, with `img` an image
        tensor and `target` a dict with the keys 'boxes' and 'labels'.

    Returns six lists with one numpy array per sample: images (channels last),
    ground truth boxes, ground truth labels, and the predicted boxes, labels
    and scores that remain after non maximum suppression.

    Gradients are disabled with `torch.no_grad()` only. The parameters are no
    longer frozen permanently, so the model can still be trained after a test.
    """
    # switch model to evaluation mode
    model.eval()

    # model to device
    model = model.to(device)

    # empty lists for the ground truth and predicted arrays
    Is           = []
    boxes_gts    = []
    labels_gts   = []
    boxes_preds  = []
    labels_preds = []
    scores_preds = []

    # evaluate dataset without tracking gradients
    with torch.no_grad():
        for img, target in dataset:
            # predict
            prediction = model([img.to(device)])[0]

            # append image (channels last for plotting)
            Is.append(img.permute(1, 2, 0).cpu().numpy())

            # extract and append the ground truth boxes and labels
            boxes_gts.append(target['boxes'].cpu().numpy())
            labels_gts.append(target['labels'].cpu().numpy())

            # apply non maximum suppression (IoU_threshold=0.2 corresponds with 33 percent overlap for two equally sized squares)
            boxes_nms, labels_nms, scores_nms = do_nms(prediction['boxes'],
                                                       prediction['labels'],
                                                       prediction['scores'],
                                                       IoU_threshold=0.2)

            # append the predicted boxes, labels, and scores that follow from the non maximum suppression
            boxes_preds.append(boxes_nms.cpu().numpy())
            labels_preds.append(labels_nms.cpu().numpy())
            scores_preds.append(scores_nms.cpu().numpy())

    return Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds

In [None]:
def plot_ground_truth_and_prediction(I, boxes_gt, labels_gt, boxes_pred, labels_pred, scores_pred, decision_threshold):
    """Show the ground truth boxes (left) and the predicted boxes (right) on the same image.

    I: image to display.
    boxes_gt, labels_gt: ground truth boxes [x_min, y_min, x_max, y_max] and labels.
    boxes_pred, labels_pred, scores_pred: predicted boxes, labels and scores.
    decision_threshold: only predictions with a score above this value are drawn.

    Label 1 ('raw') is drawn in blue and label 2 ('ripe') in red; other labels are not drawn.
    """
    keep        = scores_pred>decision_threshold
    boxes_pred  = boxes_pred[keep]
    labels_pred = labels_pred[keep]

    # label id -> (color, zorder); 'ripe' boxes are drawn on top of 'raw' boxes
    box_styles = {1: ('blue', 10), 2: ('red', 20)}

    figure, ax = plt.subplots(1, 2, figsize=(8, 8))
    figure.tight_layout()
    figure.set_figwidth(13)

    panels = ((ax[0], boxes_gt, labels_gt, 'Ground truth'),
              (ax[1], boxes_pred, labels_pred, 'Predicted'))

    for axis, boxes, labels, title in panels:
        axis.imshow(I, zorder=-10)
        for i in range(len(boxes)):
            bbox  = boxes[i]
            style = box_styles.get(labels[i])
            if style is None:
                continue
            # closed outline of the rectangle: the first corner is repeated at the end
            xs = [bbox[0], bbox[2], bbox[2], bbox[0], bbox[0]]
            ys = [bbox[1], bbox[1], bbox[3], bbox[3], bbox[1]]
            axis.plot(xs, ys, color=style[0], linewidth=3, zorder=style[1])
        axis.set_title(title, fontsize=24)
        axis.set_xlabel('Width of image (pixels)', fontsize=20)
        axis.set_ylabel('Height of image (pixels)', fontsize=20)
        axis.set_xticks(np.linspace(0, 900, 7))
        axis.set_yticks(np.linspace(0, 1200, 9))
        axis.tick_params(axis='both', which='major', labelsize=16)
        axis.set_xlim(0, 900)
        axis.set_ylim(1200, 0)

    # proxy lines that only serve as legend entries
    legend_handles = [Line2D([0], [0], color='blue', linewidth=3, label='raw'),
                      Line2D([0], [0], color='red', linewidth=3, label='ripe')]
    plt.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)

    # hide the axis labels that are not on the outside of the figure
    for axis in ax.flat:
        axis.label_outer()

    plt.show()

In [None]:
if DO_TEST:
    # predictions for the test dataset; every returned list holds one entry per test image
    Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds = test_model(model, device, dataset_test)

In [None]:
if DO_TEST:
    # interactive viewer: select a test image and a decision threshold (slider from 0 to 1 in steps of 0.05)
    ipywidgets.interact(lambda idx, threshold: plot_ground_truth_and_prediction(Is[idx],
                                                                     boxes_gts[idx], 
                                                                     labels_gts[idx], 
                                                                     boxes_preds[idx], 
                                                                     labels_preds[idx], 
                                                                     scores_preds[idx],
                                                                     decision_threshold=threshold), 
                        idx=range(len(Is)),
                        threshold=ipywidgets.FloatSlider(min=0, max=1, step=0.05))

## 10. Optimize decision threshold

In [None]:
def classify_boxes(boxes_ripe_gt, boxes_ripe_pred, IoU_threshold=0.5):
    """Classify ground truth and predicted boxes as TP, FP or FN.

    boxes_ripe_gt: ground truth boxes, rows are [x_min, y_min, x_max, y_max].
    boxes_ripe_pred: predicted boxes, rows are [x_min, y_min, x_max, y_max].
    IoU_threshold: a match needs an IoU above this value.

    A ground truth box gives a TP when its best overlapping prediction also has
    this ground truth box as its best overlap and their IoU exceeds the
    threshold; otherwise it gives a FN. Every prediction that is not part of a
    TP gives a FP.

    Returns a DataFrame with the columns x_min, y_min, x_max, y_max, idx_gt,
    idx_pred and type, sorted on type. TP and FP rows hold the predicted box,
    FN rows hold the ground truth box. An index that does not apply is None.

    The rows are collected in a list and the DataFrame is built once, because
    `DataFrame.append` is not available in recent pandas versions.
    """
    dataframe_columns = ['x_min', 'y_min', 'x_max', 'y_max', 'idx_gt', 'idx_pred', 'type']

    boxes_gt   = np.asarray(boxes_ripe_gt, dtype=float).reshape(-1, 4)
    boxes_pred = np.asarray(boxes_ripe_pred, dtype=float).reshape(-1, 4)
    num_gt, num_pred = len(boxes_gt), len(boxes_pred)

    records      = []       # one row per classified box
    matched_pred = set()    # indices of the predictions that are a TP

    if num_gt>0 and num_pred>0:
        # iou_values[i,j] is the IoU of prediction i and ground truth box j
        iou_values       = _pairwise_iou(boxes_pred, boxes_gt)
        best_pred_per_gt = np.argmax(iou_values, axis=0)
        best_gt_per_pred = np.argmax(iou_values, axis=1)

        for idx_gt in range(num_gt):
            idx_pred  = int(best_pred_per_gt[idx_gt])
            is_mutual = best_gt_per_pred[idx_pred]==idx_gt
            if is_mutual and iou_values[idx_pred,idx_gt]>IoU_threshold:
                records.append(boxes_pred[idx_pred].tolist() + [idx_gt, idx_pred, 'TP'])
                matched_pred.add(idx_pred)
            else:
                records.append(boxes_gt[idx_gt].tolist() + [idx_gt, None, 'FN'])
    else:
        # without predictions every ground truth box is missed
        for idx_gt in range(num_gt):
            records.append(boxes_gt[idx_gt].tolist() + [idx_gt, None, 'FN'])

    # every prediction that is not matched to a ground truth box is a FP
    for idx_pred in range(num_pred):
        if idx_pred not in matched_pred:
            records.append(boxes_pred[idx_pred].tolist() + [None, idx_pred, 'FP'])

    # dtype=object keeps None and the integer indices as they are
    boxes_result = pd.DataFrame(records, columns=dataframe_columns, dtype=object)
    boxes_result = boxes_result.sort_values(['type'], kind='stable').reset_index(drop=True)

    return boxes_result


def _pairwise_iou(boxes_a, boxes_b):
    """Return the matrix with the IoU of every box in `boxes_a` (rows) and `boxes_b` (columns)."""
    x1 = np.maximum(boxes_a[:,None,0], boxes_b[None,:,0])
    y1 = np.maximum(boxes_a[:,None,1], boxes_b[None,:,1])
    x2 = np.minimum(boxes_a[:,None,2], boxes_b[None,:,2])
    y2 = np.minimum(boxes_a[:,None,3], boxes_b[None,:,3])

    # boxes that only touch (or do not overlap at all) get an IoU of 0
    overlaps     = (x1<x2) & (y1<y2)
    intersection = np.where(overlaps, (x2-x1)*(y2-y1), 0.0)

    area_a = (boxes_a[:,2]-boxes_a[:,0])*(boxes_a[:,3]-boxes_a[:,1])
    area_b = (boxes_b[:,2]-boxes_b[:,0])*(boxes_b[:,3]-boxes_b[:,1])
    union  = area_a[:,None]+area_b[None,:]-intersection

    # the denominator is replaced by 1 where there is no overlap to avoid a division by zero
    return np.where(overlaps, intersection/np.where(overlaps, union, 1.0), 0.0)

In [None]:
def calculate_boxes_results(Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds, IoU_threshold, decision_threshold):
    """Classify the 'ripe' boxes of every image and return one result per image.

    Is: images; only the number of images is used.
    boxes_gts, labels_gts: ground truth boxes and labels per image.
    boxes_preds, labels_preds, scores_preds: predicted boxes, labels and scores per image.
    IoU_threshold: threshold that is passed to `classify_boxes`.
    decision_threshold: only predictions with a score above this value are used.

    Returns a list with the output of `classify_boxes` for every image.
    """
    ripe_label    = 2
    boxes_results = []

    for img_idx in range(len(Is)):
        # drop the predictions with a score that is too low
        confident    = scores_preds[img_idx]>decision_threshold
        kept_boxes   = boxes_preds[img_idx][confident]
        kept_labels  = labels_preds[img_idx][confident]

        # only the 'ripe' class is evaluated
        ripe_gt      = boxes_gts[img_idx][labels_gts[img_idx]==ripe_label]
        ripe_pred    = kept_boxes[kept_labels==ripe_label]

        boxes_results.append(classify_boxes(ripe_gt, ripe_pred, IoU_threshold))

    return boxes_results

In [None]:
def calculate_positives_negatives(boxes_result):
    """Count the TP, FP and FN rows of a classification result.

    boxes_result: DataFrame with a column 'type' that holds 'TP', 'FP' or 'FN'.

    Returns the tuple (num_TP, num_FP, num_FN); other values in 'type' are ignored.
    """
    box_types = boxes_result['type'].tolist()
    return box_types.count('TP'), box_types.count('FP'), box_types.count('FN')

In [None]:
def calculate_f1_score(num_TP, num_FP, num_FN):
    """Calculate the F1 score, precision and recall from the TP, FP and FN counts.

    Returns the tuple (f1_score, precision, recall). All three values are NaN
    when there are no true positives, because the scores are undefined then.
    """
    num_predicted = num_TP+num_FP       # all boxes that were predicted
    num_actual    = num_TP+num_FN       # all boxes that should have been found

    if not (num_predicted>0 and num_actual>0 and num_TP>0):
        return np.nan, np.nan, np.nan

    precision = num_TP/num_predicted
    recall    = num_TP/num_actual
    f1_score  = 2*precision*recall/(precision+recall)
    return f1_score, precision, recall

In [None]:
def calculate_optimal_decision_threshold(model, device, dataset, IoU_threshold=0.5):
    """Sweep the decision threshold and pick the value with the highest F1 score.

    ``test_model`` is called once on ``dataset``; its outputs are then
    re-evaluated for 101 evenly spaced thresholds in [0, 1]. For each
    threshold the TP/FP/FN counts of all images are summed and turned into a
    single F1 score. Thresholds for which the F1 score is NaN are dropped
    before the maximum is taken.

    Args:
        model: forwarded to ``test_model``.
        device: forwarded to ``test_model``.
        dataset: forwarded to ``test_model``.
        IoU_threshold: IoU threshold forwarded to ``calculate_boxes_results``
            (defaults to 0.5, the value that used to be hard-coded).

    Returns:
        tuple ``(decision_thresholds, f1_scores, decision_threshold, f1_score)``:
        the thresholds with a defined F1 score (array), the matching F1
        scores (array), the threshold with the highest F1 score and that
        highest F1 score.

    Raises:
        ValueError: if the F1 score is NaN for every threshold, so no optimum
            exists.
    """
    candidate_thresholds = np.linspace(0, 1, 101)

    # predictions for dataset (computed once; only the score filter changes below)
    Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds = test_model(model, device, dataset)

    # calculate f1 score for every decision threshold
    f1_scores = []
    for threshold in candidate_thresholds.tolist():
        boxes_results = calculate_boxes_results(Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds, IoU_threshold=IoU_threshold, decision_threshold=threshold)

        # TP, FP, and FN summed over all images
        total_num_TP, total_num_FP, total_num_FN = 0, 0, 0
        for boxes_result in boxes_results:
            num_TP, num_FP, num_FN = calculate_positives_negatives(boxes_result)
            total_num_TP += num_TP
            total_num_FP += num_FP
            total_num_FN += num_FN

        f1_score, _, _ = calculate_f1_score(total_num_TP, total_num_FP, total_num_FN)
        f1_scores.append(f1_score)

    # keep only the thresholds with a defined f1 score
    f1_scores = np.array(f1_scores, dtype=float)
    is_defined = ~np.isnan(f1_scores)
    if not is_defined.any():
        # np.argmax on an empty array would fail with an unhelpful message
        raise ValueError('The F1 score is undefined for every decision threshold; '
                         'cannot determine an optimal decision threshold.')

    decision_thresholds = candidate_thresholds[is_defined]
    f1_scores = f1_scores[is_defined]

    # calculate optimal value
    best_idx = np.argmax(f1_scores)
    decision_threshold = decision_thresholds[best_idx]
    f1_score = f1_scores[best_idx]

    return decision_thresholds, f1_scores, decision_threshold, f1_score

In [None]:
if DO_TEST and DO_OPTIMIZE_THRESHOLD:
    # tune the decision threshold on dataset_val
    decision_thresholds, f1_scores, decision_threshold, f1_score = calculate_optimal_decision_threshold(model, device, dataset_val)
    print('The optimal decision threshold is {}'.format(decision_threshold))
elif DO_TEST:
    # no tuning requested: use a fixed decision threshold
    decision_threshold = 0.5

In [None]:
if DO_TEST and DO_OPTIMIZE_THRESHOLD:
    # F1 score as a function of the decision threshold, with the optimum highlighted
    plt.figure(figsize=(12, 7))
    plt.plot(decision_thresholds, f1_scores, color='blue', linewidth=3, zorder=10)
    plt.scatter(decision_threshold, f1_score, s=100, color='red', zorder=20, label='Optimal decision threshold')

    # top of the y-axis: best F1 plus 0.1, truncated to one decimal, capped at 1
    y_axis_top = min(int(10*(max(f1_scores)+0.1))/10, 1)

    # x-axis: the full threshold range
    plt.xlabel('Decision threshold', fontsize=20)
    plt.xticks(np.linspace(0, 1, 11), fontsize=16)
    plt.xlim(0, 1)

    # y-axis: one tick every 0.1 up to y_axis_top
    plt.ylabel('F1 score', fontsize=20)
    plt.yticks(np.linspace(0, y_axis_top, int(10*y_axis_top+1)), fontsize=16)
    plt.ylim(0, y_axis_top)

    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
    plt.grid(zorder=-10)
    plt.show()

## 11. Calculate and show results

In [None]:
def plot_img_and_results(I, boxes_result):
    """Show an image with its classified boxes and save the figure as 'plot1.png'.

    Every row of ``boxes_result`` whose 'type' is 'TP', 'FP' or 'FN' is drawn
    as a rectangle outline (gold, red and blue respectively); rows with any
    other type are not drawn. The title reports the three counts.

    Args:
        I: image passed to ``plt.imshow``.
        boxes_result: table with the columns 'x_min', 'y_min', 'x_max',
            'y_max' and 'type'; rows are read with ``.loc[i]`` for
            ``i`` in ``range(len(boxes_result))``, so the index has to be
            0..n-1.

    Returns:
        None. The figure is written to 'plot1.png' and then shown.
    """
    num_TP, num_FP, num_FN = calculate_positives_negatives(boxes_result)

    # one colour per box type, shared by the rectangles and the legend
    type_colors = {'TP': 'gold', 'FP': 'red', 'FN': 'blue'}

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(I, zorder=-10)
    for i in range(len(boxes_result)):
        color = type_colors.get(boxes_result.loc[i, 'type'])
        if color is None:
            continue  # unknown box type: nothing to draw
        coor = boxes_result.loc[i, ['x_min', 'y_min', 'x_max', 'y_max']].to_numpy()
        # closed outline: the first corner is repeated at the end
        rectangle_points = np.array([[coor[0], coor[2], coor[2], coor[0], coor[0]],
                                     [coor[1], coor[1], coor[3], coor[3], coor[1]]])
        plt.plot(rectangle_points[0,:], rectangle_points[1,:], color=color, linewidth=3, zorder=10)

    plt.title('Result: {} TP, {} FP, {} FN'.format(num_TP, num_FP, num_FN), fontsize=24)
    plt.xlabel('Width of image (pixels)', fontsize=20)
    plt.ylabel('Height of image (pixels)', fontsize=20)
    # NOTE(review): the axes are fixed to 900 x 1200 pixels; presumably the size
    # every image is resized to -- confirm against the dataset code
    plt.xticks(np.linspace(0, 900, 7), fontsize=16)
    plt.yticks(np.linspace(0, 1200, 9), fontsize=16)
    plt.xlim(0, 900)
    plt.ylim(1200, 0)  # inverted so that the origin is in the top-left corner

    legend_handles = [Line2D([0], [0], color=color, linewidth=3, label=box_type)
                      for box_type, color in type_colors.items()]
    plt.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)

    # save BEFORE showing: with the inline backend plt.show() finalises the
    # figure, so a later plt.savefig would write an empty image
    fig.savefig('plot1.png')
    plt.show()

In [None]:
# Evaluate the model on dataset_test: classify the boxes of every image, sum
# the TP/FP/FN counts over all images and print F1 score, precision and recall.
# `decision_threshold` is the value set by the threshold-selection cell above.
if DO_TEST:
    # IoU_threshold=0.5 corresponds with 67 percent overlap for two equally sized squares
    # (IoU = overlap / (2*area - overlap) = 0.5  ->  overlap = 2/3 of the area)
    IoU_threshold   = 0.5
    
    # boxes_results: one classified-box table per test image
    Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds = test_model(model, device, dataset_test)
    boxes_results = calculate_boxes_results(Is, boxes_gts, labels_gts, boxes_preds, labels_preds, scores_preds, IoU_threshold, decision_threshold)

    # TP, FP, and FN summed over all images, so the metrics below are computed
    # on the pooled counts rather than averaged per image
    total_num_TP, total_num_FP, total_num_FN = 0, 0, 0
    for boxes_result in boxes_results:
        num_TP, num_FP, num_FN = calculate_positives_negatives(boxes_result)
        total_num_TP += num_TP
        total_num_FP += num_FP
        total_num_FN += num_FN

    # results (all three are NaN when there is no true positive at all)
    f1_score, precision, recall = calculate_f1_score(total_num_TP, total_num_FP, total_num_FN)

    print('The model has a f1_score of {:.3f} on the test dataset.'.format(f1_score))
    print('The model has a precision of {:.3f} on the test dataset.'.format(precision))
    print('The model has a recall of {:.3f} on the test dataset.'.format(recall))

In [None]:
if DO_TEST:
    def _show_result(idx):
        """Plot the test image at position idx together with its classified boxes."""
        plot_img_and_results(Is[idx], boxes_results[idx])

    # one selectable entry per test image
    ipywidgets.interact(_show_result, idx=range(len(Is)))