# Deep Fruit Vision Evaluation

Our detection, ripeness, and defect models were each trained and tested on different datasets. This notebook evaluates the end-to-end performance of Deep Fruit Vision on a separate, hand-labelled test dataset.

In [1]:
import os
import shutil

from tqdm.notebook import tqdm
import torch
from torch.utils.data import Subset
from deepfruitvision import DeepFruitVision
from modules.datasets import EnsembleDataset, save_dataset
from torchvision.ops import box_iou

import matplotlib
from matplotlib import pyplot as plt

import numpy as np

%matplotlib inline

In [2]:
best_classification_weights = os.path.join('weights', 'detection', 'best_classification.pt')
best_detection_weights = os.path.join('weights', 'detection', 'best_detection.pt')
ripeness_weights = os.path.join('weights', 'ripeness', 'ripeness_model_fine_tuned')
defect_weights = os.path.join('weights', 'defect', 'defect_model_fine_tuned')

We use the `EnsembleDataset` class to load images and labels from the test dataset. A small part of the ensemble dataset (`num_fine_tune_samples` images) is reserved for fine-tuning our models, and the rest is used for testing. A fixed random seed ensures that we get the same split across runs.
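The split below relies on NumPy's global random state. As a stdlib-only sketch of the same idea (it uses Python's `random.Random` instead of NumPy, so the exact indices will not match the notebook's), a reproducible fine-tune/test split could look like this. `split_indices` is a hypothetical helper, not part of `modules.datasets`.

```python
import random


def split_indices(num_samples, num_fine_tune, seed):
    """Hypothetical helper: split range(num_samples) into (fine_tune, test) index lists."""
    rng = random.Random(seed)       # private generator, so the global random state is untouched
    indices = list(range(num_samples))
    rng.shuffle(indices)            # seeded in-place permutation
    return indices[:num_fine_tune], indices[num_fine_tune:]


fine_tune, test = split_indices(20, 5, seed=123)
# re-running with the same seed reproduces the same split
assert (fine_tune, test) == split_indices(20, 5, seed=123)
print(len(fine_tune), len(test))  # 5 15
```

Using a dedicated generator rather than seeding the global one keeps the split reproducible even if other code draws random numbers in between.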

In [3]:
eval_dir = os.path.join('dataset', 'deepfruitvision_eval')
detection_eval_img_dir = os.path.join(eval_dir, 'images')
detection_eval_label_dir = os.path.join(eval_dir, 'labels')

os.makedirs(detection_eval_img_dir, exist_ok=True)
os.makedirs(detection_eval_label_dir, exist_ok=True)

seed = 123
np.random.seed(seed)
num_fine_tune_samples = 50

In [4]:
ensemble_dataset = EnsembleDataset('dataset', for_yolov5=True)

random_indices = np.random.permutation(len(ensemble_dataset))
fine_tune_indices = random_indices[:num_fine_tune_samples]
test_indices = random_indices[num_fine_tune_samples:]

In [5]:
yolov5_test_dataset = Subset(ensemble_dataset, test_indices)
save_dataset(yolov5_test_dataset, detection_eval_img_dir, detection_eval_label_dir) # save the test split of the ensemble dataset with Yolo-v5 labels to the disk

ensemble_dataset.for_yolov5 = False
deepfruitvision_test_dataset = Subset(ensemble_dataset, test_indices) # then just get the test split of the ensemble dataset

Saving dataset to dataset\deepfruitvision_eval\images and dataset\deepfruitvision_eval\labels: 100%|██████████| 105/105 [00:01<00:00, 87.31it/s]


## Detection mAP

The detection (Yolo-v5) model is responsible for both detecting and classifying fruits, so we can use Yolo-v5's built-in evaluation script. We provide two Yolo-v5 models: one is better at detection (accurate bounding boxes, but it has difficulty classifying fruits), and the other is better at classification (it classifies fruits well, but its bounding boxes are less accurate). We evaluate both models and compare their performance.
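`val.py` reports per-class average precision (AP) and its mean over classes (mAP), at IoU 0.5 and averaged over IoU 0.5:0.95. As a rough illustration of what AP measures, here is a minimal all-point-interpolated AP for a single class. It assumes the detections are already sorted by descending confidence and already marked as true or false positives at a fixed IoU threshold. To our knowledge, YOLOv5's own implementation defaults to 101-point interpolation, so its numbers can differ slightly from this sketch. `average_precision` is a hypothetical helper.

```python
def average_precision(tp_flags, num_ground_truths):
    """All-point interpolated AP for one class.

    tp_flags: list of bools, one per detection, sorted by descending confidence
    (True = true positive, False = false positive).
    """
    if num_ground_truths == 0:
        return 0.0
    tp = fp = 0
    recalls, precisions = [], []
    for is_tp in tp_flags:
        if is_tp:
            tp += 1
        else:
            fp += 1
        recalls.append(tp / num_ground_truths)
        precisions.append(tp / (tp + fp))

    # sentinel points so the curve spans recall 0 to 1
    mrec = [0.0] + recalls + [1.0]
    mpre = [1.0] + precisions + [0.0]

    # precision envelope: make precision non-increasing from right to left
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])

    # area under the step curve (steps with no recall change add zero)
    return sum((mrec[i] - mrec[i - 1]) * mpre[i] for i in range(1, len(mrec)))


print(average_precision([True, False, True], num_ground_truths=2))  # about 0.8333
```

mAP is then the mean of this value over the fruit classes.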

The next two cells evaluate each Yolo-v5 model on the combined detection and classification task.

In [None]:
%run yolov5/val.py --data fine_tune_apple_papaya_mango.yaml --weights {best_detection_weights} --img 416 --task test

In [None]:
%run yolov5/val.py --data fine_tune_apple_papaya_mango.yaml --weights {best_classification_weights} --img 416 --task test

The `best classification` weights appear to do much better than the `best detection` weights, but that's because the `best detection` model often mixes up the classes. If we ignore the classes and judge the models on detection alone (the `--single-cls` flag treats the dataset as a single class), we can see how good the `best detection` model is at finding fruit.
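Conceptually, single-class evaluation collapses every class ID into one, so only box localization counts toward mAP. A minimal sketch of that relabelling on YOLO-format label lines (`class x_center y_center width height`, all normalized). `to_single_class` is hypothetical and only mirrors the idea; it is not how `val.py` implements `--single-cls` internally.

```python
def to_single_class(label_lines):
    """Rewrite YOLO label lines so every object has class 0."""
    collapsed = []
    for line in label_lines:
        fields = line.split()
        if not fields:          # skip blank lines
            continue
        fields[0] = "0"         # keep the box, drop the class
        collapsed.append(" ".join(fields))
    return collapsed


labels = ["2 0.50 0.40 0.20 0.30", "1 0.10 0.15 0.05 0.08"]
print(to_single_class(labels))
# ['0 0.50 0.40 0.20 0.30', '0 0.10 0.15 0.05 0.08']
```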

In [None]:
%run yolov5/val.py --data fine_tune_apple_papaya_mango.yaml --weights {best_detection_weights} --img 416 --task test --single-cls

In [None]:
%run yolov5/val.py --data fine_tune_apple_papaya_mango.yaml --weights {best_classification_weights} --img 416 --task test --single-cls

Now we can delete the Yolo-v5 evaluation directory, since we no longer need it.

In [None]:
# also, clean up the yolov5 eval dataset
shutil.rmtree(eval_dir)

## Ensemble Classification Accuracy

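Now we evaluate the accuracy of the ensemble classification model. Since we only care about harvestability accuracy, we ignore any predicted or ground-truth bounding box whose width or height is not larger than `fruit_vision.min_bounding_box_size`, because such boxes are too small to be classified.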
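The evaluation loop further down drops small boxes, matches each ground-truth box to the predicted box with the highest IoU, and counts the match as correct when the IoU exceeds 0.5. A stdlib-only sketch of that logic, with boxes as normalized `(xmin, ymin, xmax, ymax)` tuples. `iou`, `large_enough`, and `match_boxes` are hypothetical helpers; like the notebook, this does not enforce one-to-one matching, so two true boxes may match the same prediction.

```python
def iou(a, b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def large_enough(box, min_size):
    """True if both the width and the height exceed min_size."""
    return box[2] - box[0] > min_size and box[3] - box[1] > min_size


def match_boxes(true_boxes, pred_boxes, threshold=0.5):
    """For each true box, return the index of its best prediction, or None if IoU <= threshold."""
    matches = []
    for t in true_boxes:
        if not pred_boxes:
            matches.append(None)
            continue
        scores = [iou(t, p) for p in pred_boxes]
        best = max(range(len(scores)), key=scores.__getitem__)
        matches.append(best if scores[best] > threshold else None)
    return matches


true_boxes = [(0.1, 0.1, 0.3, 0.3), (0.6, 0.6, 0.9, 0.9)]
pred_boxes = [(0.12, 0.1, 0.3, 0.32), (0.0, 0.6, 0.2, 0.8)]
print(match_boxes(true_boxes, pred_boxes))  # [0, None]
```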

In [6]:
fruit_vision = DeepFruitVision(best_classification_weights, ripeness_weights, defect_weights)

YOLOv5  2022-11-21 Python-3.10.5 torch-1.12.0 CUDA:0 (NVIDIA GeForce GTX 1070, 8192MiB)

Fusing layers... 
Model summary: 157 layers, 7018216 parameters, 0 gradients, 15.8 GFLOPs
Adding AutoShape... 


In [7]:
# use Matplotlib to display the true (green) and predicted (red) bounding boxes
# true bounding boxes are a list of dicts with x, y, w, and h keys (normalized; x, y is treated as the top-left corner)
# predicted bounding boxes are a list of dicts with xmin, ymin, xmax, and ymax keys (normalized, not pixel values)
def display_bboxes(img, true_bboxes, pred_bboxes):
    img_h, img_w, _ = img.shape
    fig, ax = plt.subplots(1)
    ax.imshow(img)
    for bbox in true_bboxes:
        x = bbox['x'] * img_w
        y = bbox['y'] * img_h
        w = bbox['w'] * img_w
        h = bbox['h'] * img_h
        rect = matplotlib.patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='g', facecolor='none')
        ax.add_patch(rect)
    for bbox in pred_bboxes:
        x = bbox['xmin'] * img_w
        y = bbox['ymin'] * img_h
        w = (bbox['xmax'] - bbox['xmin']) * img_w
        h = (bbox['ymax'] - bbox['ymin']) * img_h
        rect = matplotlib.patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
    plt.show()

for i in range(1): # only show the first test image
    img, true_boxes = deepfruitvision_test_dataset[i]
    pred_boxes = fruit_vision.get_harvestability(img)
    display_bboxes(img, true_boxes, pred_boxes)
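The two box formats used above differ: ground-truth boxes carry `x`, `y`, `w`, `h` (treated in this notebook as the top-left corner plus the size), while predictions carry `xmin`, `ymin`, `xmax`, `ymax`. A small sketch of the conversion the notebook performs inline, plus scaling to pixels. `xywh_to_xyxy` and `to_pixels` are hypothetical helpers.

```python
def xywh_to_xyxy(box):
    """Convert a dict with top-left x, y and size w, h into (xmin, ymin, xmax, ymax)."""
    return (box['x'], box['y'], box['x'] + box['w'], box['y'] + box['h'])


def to_pixels(xyxy, img_w, img_h):
    """Scale a normalized (xmin, ymin, xmax, ymax) box to integer pixel coordinates."""
    xmin, ymin, xmax, ymax = xyxy
    return (round(xmin * img_w), round(ymin * img_h), round(xmax * img_w), round(ymax * img_h))


true_box = {'x': 0.25, 'y': 0.5, 'w': 0.25, 'h': 0.125}
xyxy = xywh_to_xyxy(true_box)
print(xyxy)                        # (0.25, 0.5, 0.5, 0.625)
print(to_pixels(xyxy, 640, 480))   # (160, 240, 320, 300)
```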

In [8]:
import cv2
import tensorflow as tf

def get_ripeness(img, true_box):
    # get the fruit from the image
    img_h, img_w, _ = img.shape
    x, y, w, h = int(true_box['x'] * img_w), int(true_box['y'] * img_h), int(true_box['w'] * img_w), int(true_box['h'] * img_h)
    cropped_img = img[y:y+h, x:x+w]

    # resize the image to 224x224
    cropped_img = cv2.resize(cropped_img, (224, 224))
    # convert it to a float32 tf tensor (note: no normalization is applied here)
    cropped_img = tf.convert_to_tensor(cropped_img, dtype=tf.float32)

    # add a batch dimension
    cropped_img = tf.expand_dims(cropped_img, axis=0)
    # feed it to deepfruitvision's ripeness model
    ripeness = fruit_vision.ripeness_module.get_ripeness_predictions(cropped_img)
    return ripeness[0]
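`get_ripeness` slices the image directly with the converted pixel coordinates. If a box ever touched or crossed the image border, a negative start index would wrap around in NumPy slicing, and an empty crop would make `cv2.resize` fail. A sketch of clamping the crop window first; `crop_window` is a hypothetical helper, not part of DeepFruitVision.

```python
def crop_window(box, img_w, img_h):
    """Pixel bounds (x0, y0, x1, y1) for a normalized top-left/size box, clamped to the image.

    Returns None if the clamped window is empty.
    """
    x0 = max(0, int(box['x'] * img_w))
    y0 = max(0, int(box['y'] * img_h))
    x1 = min(img_w, int((box['x'] + box['w']) * img_w))
    y1 = min(img_h, int((box['y'] + box['h']) * img_h))
    if x1 <= x0 or y1 <= y0:
        return None
    return x0, y0, x1, y1


# a box that starts left of the image and extends past its bottom edge
print(crop_window({'x': -0.125, 'y': 0.5, 'w': 0.375, 'h': 0.75}, 100, 200))  # (0, 100, 25, 200)
```

The returned window would then be used as `img[y0:y1, x0:x1]`, skipping the box when the result is `None`.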


In [9]:
min_box_size = fruit_vision.min_bounding_box_size

total_boxes = 0
total_correct = 0

correct_bbox = 0
correct_ripeness = 0
correct_ripeness_true_bbox = 0
correct_defect = 0

for frame, true_boxes in tqdm(deepfruitvision_test_dataset, desc='Evaluating DeepFruitVision', total=len(deepfruitvision_test_dataset)):
    predicted_boxes = fruit_vision.get_harvestability(frame)

    # ignore any boxes in the predicted and true boxes that are too small to be classified
    predicted_boxes = [box for box in predicted_boxes if box['xmax'] - box['xmin'] > min_box_size and box['ymax'] - box['ymin'] > min_box_size]
    # the true boxes are dicts with normalized x, y (top-left corner), w, and h keys
    true_boxes = [box for box in true_boxes if box['w'] > min_box_size and box['h'] > min_box_size]

    if len(true_boxes) == 0: # if there are no boxes large enough to be classified, then we can't evaluate this frame
        continue

    total_boxes += len(true_boxes)

    if len(predicted_boxes) == 0: # if there are no predicted boxes, then we automatically get 0 correct
        continue
    
    # now convert the predicted and true boxes to tensors so we can use torchvision's box_iou function
    predicted_boxes_tensor = torch.tensor([[box['xmin'], box['ymin'], box['xmax'], box['ymax']] for box in predicted_boxes])
    # each box has to have the format [x1, y1, x2, y2] where x1 < x2 and y1 < y2
    true_boxes_tensor = torch.tensor([[box['x'], box['y'], box['x'] + box['w'], box['y'] + box['h']] for box in true_boxes])

    # this returns a tensor [num_true_boxes, num_predicted_boxes] that we can use to determine which of the true boxes have a corresponding predicted box
    iou = box_iou(true_boxes_tensor, predicted_boxes_tensor)

    # the max of each row is the IoU of the best-matching predicted box for that true box
    max_ious, max_iou_indices = torch.max(iou, dim=1)

    for i, (max_iou, max_iou_index) in enumerate(zip(max_ious, max_iou_indices)):
        max_iou_index = int(max_iou_index) # convert the 0-d tensor to a plain list index

        # ripeness predicted from the ground-truth crop, independent of the detector
        ripeness_true_bbox = get_ripeness(frame, true_boxes[i])

        if ripeness_true_bbox[0] == true_boxes[i]['ripeness']:
            correct_ripeness_true_bbox += 1

        if max_iou > 0.5: # if the IoU is greater than 0.5, then we consider it a correct prediction
            correct_bbox += 1

            if true_boxes[i]['ripeness'] == predicted_boxes[max_iou_index]['ripeness'][0]:
                correct_ripeness += 1

            if true_boxes[i]['defect'] == predicted_boxes[max_iou_index]['defect'][0]:
                correct_defect += 1

            true_ensemble_label = true_boxes[i]['ensemble']
            predicted_harvestability_label = predicted_boxes[max_iou_index]['harvestability']

            if true_ensemble_label == predicted_harvestability_label:
                total_correct += 1

print(f'Got {total_correct} out of {total_boxes} correct ({total_correct / total_boxes * 100:.2f}%)')
print(f'Got {correct_bbox} out of {total_boxes} correct bounding boxes ({correct_bbox / total_boxes * 100:.2f}%)')
print(f'Got {correct_ripeness} out of {total_boxes} correct ripeness labels ({correct_ripeness / total_boxes * 100:.2f}%)')
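# this metric runs the ripeness model on the ground-truth crops, so it does not depend on the detector's boxes
print(f'Got {correct_ripeness_true_bbox} out of {total_boxes} correct ripeness labels when the bounding box was correct ({correct_ripeness_true_bbox / total_boxes * 100:.2f}%)')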
print(f'Got {correct_defect} out of {total_boxes} correct defect labels ({correct_defect / total_boxes * 100:.2f}%)')


Got 264 out of 338 correct (78.11%)
Got 320 out of 338 correct bounding boxes (94.67%)
Got 274 out of 338 correct ripeness labels (81.07%)
Got 278 out of 338 correct ripeness labels when the bounding box was correct (82.25%)
Got 292 out of 338 correct defect labels (86.39%)
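All five percentages above are computed over every ground-truth box (338). The harvestability, ripeness, and defect counts from the predicted boxes are only incremented when the IoU exceeds 0.5, so those rates also absorb localization misses. Dividing the same counts by the number of correctly localized boxes (320) separates the two effects. A small sketch of that arithmetic using the counts printed above; `rate` is a hypothetical helper.

```python
def rate(correct, total):
    """Percentage of correct predictions, rounded to two decimals as in the printed report."""
    return round(correct / total * 100, 2)


total_boxes, correct_bbox = 338, 320  # counts from the evaluation output above
counts = {'harvestability': 264, 'ripeness': 274, 'defect': 292}

for name, correct in counts.items():
    print(f'{name}: {rate(correct, total_boxes):.2f}% of all boxes, '
          f'{rate(correct, correct_bbox):.2f}% of correctly localized boxes')
```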
