In [1]:
from application.src.models.faster_rcnn import load_model
from application.src.config import TEST_DIR
from torchvision.transforms import transforms
from application.src.models.custom_dataset import test_dataset
from PIL import Image
import os
import heapq
import copy

In [2]:
# Preprocessing pipeline: PIL image -> tensor via ToTensor. Kept inside a
# Compose so further transforms can be appended without touching call sites.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# Detector under evaluation, built by the project's load_model helper.
model = load_model()

In [3]:
# Run the detector on every .jpg under TEST_DIR and collect, per image, the
# annotated labels, predicted labels, scores and boxes returned by model.verify.
image_index = 0
all_correct_labels = []
all_predicted_labels = []
all_scores = []
all_boxes = []

for subdir, dirs, files in os.walk(TEST_DIR):
    for file in files:
        if not file.endswith('.jpg'):
            continue
        img_path = os.path.join(subdir, file)
        # The context manager closes the file handle once the pixels are read.
        # convert('RGB') guarantees 3 channels even for grayscale/CMYK JPEGs,
        # which would otherwise break the reshape to a single-image batch.
        with Image.open(img_path) as PIL_image:
            test_image = transform(PIL_image.convert('RGB'))
        # Add the batch dimension: (C, H, W) -> (1, C, H, W).
        test_image = test_image.unsqueeze(0)
        # NOTE(review): the annotation is looked up by position, which assumes
        # test_dataset enumerates images in the same order os.walk yields them
        # here. os.walk/os.listdir order is filesystem-dependent -- confirm, or
        # look the target up by file name instead.
        _, target = test_dataset[image_index]
        correct_labels, predicted_labels, scores, boxes = model.verify(test_image, target)
        all_correct_labels.append(correct_labels)
        all_predicted_labels.append(predicted_labels)
        all_scores.append(scores)
        all_boxes.append(boxes)
        image_index += 1

print(f'No of images in path: {image_index}')
print(f'Lengths: {len(all_correct_labels)}, {len(all_predicted_labels)}, {len(all_scores)}, {len(all_boxes)}')

No of images in path: 215
Lengths: 215, 215, 215, 215


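### Accuracy

Label-level metrics: each image's predicted labels are compared with that image's annotated labels; bounding boxes are not checked.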

In [4]:
def check_guesses(list_predicted_labels, list_correct_labels):
    """Count label matches and false positives across images, ignoring scores and boxes.

    Each predicted label is matched against the still-unmatched annotated labels
    of the same image, so an annotated label can be consumed at most once
    (multiset matching).

    Args:
        list_predicted_labels: per-image iterables of predicted labels.
        list_correct_labels: per-image lists of annotated labels, aligned by
            index with ``list_predicted_labels``. Not modified by this function.

    Returns:
        Tuple ``(guesses, false_pos, baseline)``:
            guesses   -- predicted labels that matched an annotated label;
            false_pos -- predicted labels with no remaining annotated match;
            baseline  -- total annotated labels over the images visited.
    """
    guesses = 0
    false_pos = 0
    baseline = 0
    for i, labels in enumerate(list_predicted_labels):
        # Shallow copy: matched labels are removed from `remaining` only, so
        # the caller's annotations survive and the function can be re-run.
        remaining = copy.copy(list_correct_labels[i])
        baseline += len(remaining)
        for label in labels:
            if label in remaining:
                guesses += 1
                remaining.remove(label)
            else:
                false_pos += 1
    return guesses, false_pos, baseline

In [5]:
def check_hits(list_predicted_labels, list_correct_labels, list_scores):
    """Count annotated labels recovered by each image's top-scoring predictions.

    For an image with k annotated labels, only its k highest-scoring predictions
    are considered; each one matching a still-unmatched annotated label counts
    as one hit (an annotated label can be matched at most once).

    Args:
        list_predicted_labels: per-image sequences of predicted labels.
        list_correct_labels: per-image lists of annotated labels, aligned by
            index with ``list_predicted_labels``. Not modified by this function.
        list_scores: per-image sequences of prediction scores, aligned
            element-wise with the predicted labels.

    Returns:
        Total number of hits over all images.
    """
    hits = 0
    for i, labels in enumerate(list_predicted_labels):
        scores = list_scores[i]
        # Shallow copy so the caller's annotations are not consumed.
        remaining = copy.copy(list_correct_labels[i])
        # Rank by score only; on tied scores the earlier prediction wins
        # instead of falling back to comparing the label values themselves.
        top_predictions = heapq.nlargest(
            len(remaining), zip(scores, labels), key=lambda pair: pair[0]
        )
        for _, label in top_predictions:
            if label in remaining:
                hits += 1
                remaining.remove(label)
    return hits

In [6]:
# These indicators do not take into account if the bounding boxes match
list_correct_labels = copy.deepcopy(all_correct_labels)
list_predicted_labels = copy.deepcopy(all_predicted_labels)
list_scores = copy.deepcopy(all_scores)
guesses, false_pos, baseline = check_guesses(list_predicted_labels, list_correct_labels)

list_correct_labels = copy.deepcopy(all_correct_labels)
hits = check_hits(list_predicted_labels, list_correct_labels, list_scores)

print(f'Correct predictions: {hits}, Percent of baseline: {round((hits / baseline) * 100, 0)}')
print(f'Uncertain predictions: {guesses}, Percent of baseline: {round((guesses / baseline) * 100, 0)}')
print(f'False predictions: {false_pos}, Percent of baseline: {round((false_pos / baseline) * 100, 0)}')
print(f'Baseline (number of annotated signs to find): {baseline}')

Correct predictions: 444, Percent of baseline: 86.0
Uncertain predictions: 483, Percent of baseline: 93.0
False predictions: 508, Percent of baseline: 98.0
Baseline (number of annotated signs to find): 517


In [7]:
# Total predicted labels vs. total annotated labels over all test images.
# (The loop variable must not be named `list`: that would shadow the builtin
# for the rest of the kernel session and break later `list(...)` calls.)
total_predictions = sum(len(labels) for labels in all_predicted_labels)
print(total_predictions)

total_annotations = sum(len(labels) for labels in all_correct_labels)
print(total_annotations)

991
517
