In [1]:
###################################################################################################
#
# Copyright (C) 2022 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################

import os
import sys

import numpy as np
import torch

import matplotlib.patches as patches
import matplotlib.pyplot as plt

# Make the repository root ('..') and its 'models' directory importable, both
# resolved against the current working directory, so the project-local imports
# that follow can be found.
sys.path.extend(
    os.path.join(os.getcwd(), relative_dir) for relative_dir in ('..', '../models/')
)

from collections import OrderedDict

from PIL import Image

import ai8x
from datasets import helen
from utils import parse_obj_detection_yaml

# The module name "ai85net-tinierssd" contains a hyphen, which is not a valid
# Python identifier, so a regular `import` statement cannot be used; load it by
# its string name and bind it to an identifier-safe alias instead.
ai85net_tinierssd = __import__("ai85net-tinierssd")

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset root directory, relative to the notebook's working directory.
data_path = '../data/'

class Args:
    """Minimal argument holder handed to the dataset loader.

    The notebook passes an instance of this class, together with the data
    path, to the dataset factory in place of a full command-line namespace.

    Parameters
    ----------
    act_mode_8bit :
        Stored unchanged as ``self.act_mode_8bit``.
        NOTE(review): presumably selects 8-bit activation normalization in the
        dataset transform -- confirm against the dataset module.
    truncate_testset : optional
        Stored unchanged as ``self.truncate_testset``. Defaults to ``False``,
        the value that was previously hard-coded.
    """

    def __init__(self, act_mode_8bit, truncate_testset=False):
        self.act_mode_8bit = act_mode_8bit
        self.truncate_testset = truncate_testset

# Argument holder for the dataset factory, with 8-bit activation mode disabled.
args = Args(act_mode_8bit=False)

# Request only the test split; the first element of the returned pair is
# discarded.
_, test_set = helen.HELEN_74_get_datasets(
    (data_path, args),
    load_train=False,
    load_test=True,
)

Test dataset length: 1083



In [3]:
# Number of classes the detector head is built for.
num_classes = 2

# Configure ai8x before building the model (the cell output reports
# "MAX78000, simulate=False" for these arguments).
ai8x.set_device(85, False, False)

model = ai85net_tinierssd.ai85tinierssd(num_classes=num_classes, device=device)

# Run training first so that a checkpoint exists. The original note points at
# scripts/train_svhn_tinierssd.sh.
# NOTE(review): confirm that this is the right script for the HELEN model.
#
# Other checkpoints that were tried previously:
#   '../logs/COMBINED_FACE_MIRRORED/ai85-tinierssd-helen-combined-aug-q.pth.tar'
#   '../logs/CORRECTED_COMBINED_HELEN_MIRRORED/checkpoint.pth.tar'
#   '../logs/COMBINED_TWO_EYES/best.pth.tar'
checkpoint_path = '../logs/HIGHEST_SO_FAR/best.pth.tar'

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['state_dict']

# Checkpoints saved from a wrapped model (e.g. torch.nn.DataParallel) prefix
# every parameter name with 'module.'.
wrapper_prefix = 'module.'
is_multi_gpu = all(key.startswith(wrapper_prefix) for key in state_dict)

if is_multi_gpu:
    # Strip only the leading prefix. A plain str.replace would also mangle keys
    # that contain 'module.' further in (e.g. 'submodule.weight').
    new_state_dict = OrderedDict(
        (key[len(wrapper_prefix):], value) for key, value in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
else:
    model.load_state_dict(state_dict)

model = model.to(device)

Configuring device: MAX78000, simulate=False.


In [4]:
# Load the detection hyper-parameters from YAML. The resulting mapping has a
# 'nms' section (min_score, max_overlap, top_k) that is read at inference time.
obj_detection_params_yaml_file = '../parameters/obj_detection_params_svhn.yaml'
obj_detection_params = parse_obj_detection_yaml.parse(
    obj_detection_params_yaml_file
)

{'multi_box_loss': {'alpha': 2, 'neg_pos_ratio': 3}, 'nms': {'min_score': 0.2, 'max_overlap': 0.3, 'top_k': 20}}


In [11]:
# Index of the test-set sample to inspect. It is pinned so the cell is
# reproducible; use `np.random.randint(len(test_set))` instead to browse
# random samples.
selected_idx = 883

img, (boxes, lbls) = test_set[selected_idx]
img = img.to(device)

# Convert the model input (channels-first) to a channels-last uint8 image for
# display and saving. The clip keeps a value of exactly +1.0 from wrapping
# around when cast to uint8.
# NOTE(review): assumes the dataset normalizes pixels to roughly [-1, 1) -- confirm.
img_to_plot = (128 * (img.detach().cpu().numpy() + 1)).clip(0, 255).astype(np.uint8)
img_to_plot = img_to_plot.transpose([1, 2, 0])

# Inference only: switch layers such as batch norm to evaluation behaviour and
# skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    img_model = img.unsqueeze(0)  # add the batch dimension
    locs, scores = model(img_model)

    all_images_boxes, all_images_labels, all_images_scores = \
        model.detect_objects(locs, scores,
                             min_score=obj_detection_params['nms']['min_score'],
                             max_overlap=obj_detection_params['nms']['max_overlap'],
                             top_k=obj_detection_params['nms']['top_k'])

# Save the displayable uint8 image. PIL cannot build an image from the raw
# torch tensor `img`, so the converted array is used.
img_ = Image.fromarray(img_to_plot)
img_.save(f"img_{selected_idx}_true.png")

16428

In [None]:
# Only predictions scoring above this threshold are drawn.
MIN_DISPLAY_SCORE = 0.7
# At most this many ground-truth boxes are drawn (the first ones in the sample).
MAX_TRUTH_BOXES = 2

fig, ax = plt.subplots(1)
ax.imshow(img_to_plot)

ax.tick_params(labelsize=16)

subplot_title = "Test set item: " + str(selected_idx)
ax.set_title(subplot_title, fontsize=20)

# Box coordinates are scaled by the first entry of the dataset's resize size
# for both axes.
# NOTE(review): this assumes a square resize and fractional box coordinates -- confirm.
box_scale = test_set.resize_size[0]

# Truth boxes (red). Slicing instead of indexing avoids an IndexError when a
# sample carries fewer than MAX_TRUTH_BOXES boxes.
truth_boxes_resized = [[box_coord * box_scale for box_coord in box] for box in boxes]
for bb in truth_boxes_resized[:MAX_TRUTH_BOXES]:
    rect = patches.Rectangle((bb[0], bb[1]), bb[2] - bb[0], bb[3] - bb[1], linewidth=3,
                             edgecolor='r', facecolor="none")
    ax.add_patch(rect)

# Predicted boxes (blue), one list of boxes per image in the batch.
boxes_resized = [[box_coord * box_scale for box_coord in box.detach().cpu().numpy()]
                 for box in all_images_boxes]

# Label 0 is treated as background and skipped below.
# NOTE(review): the 10 -> 0 remapping looks like a leftover from a 10-digit
# dataset; with num_classes = 2 it should never trigger -- confirm.
detected_labels = [val.item() if val.item() != 10 else 0 for val in all_images_labels[0]]

for bb, label, score_tensor in zip(boxes_resized[0], detected_labels, all_images_scores[0]):
    score = score_tensor.item()
    if label != 0 and score > MIN_DISPLAY_SCORE:
        rect = patches.Rectangle((bb[0], bb[1]), bb[2] - bb[0], bb[3] - bb[1], linewidth=3,
                                 edgecolor='b', facecolor="none")

        ax.text(bb[0], bb[1], f'eye: {score:.2f}', verticalalignment='center',
                color='white', fontsize=14, weight='bold')
        ax.add_patch(rect)

plt.show()

# test
# 55, 929, 951, 182, 93, 1034, 565, 883, 931, 327

# with nose
# 487