In [None]:

# CNN 
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn


import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
import time
import random


import torchvision
import torchvision.transforms as transforms


from pycocotools.coco import COCO
import cv2


# Let cuDNN autotune convolution algorithms for the input shapes it sees.
# NOTE(review): nothing below moves the models or tensors to a GPU, so this is likely a no-op here.
torch.backends.cudnn.benchmark = True
# Switch matplotlib to interactive mode.
plt.ion()

In [None]:
# Dataset location and the split whose annotations are indexed below.
cocoRoot, dataType = "./Data/Coco/", "val2017"

annFile = os.path.join(cocoRoot, f'annotations/instances_{dataType}.json')
print(f'Annotation file: {annFile}')

# pycocotools index over the split's images, annotations and categories
coco = COCO(annFile)
coco

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def plot_image_with_annotations(coco, cocoRoot, dataType, imgId, ax=None, color='r'):
    """Draw a COCO image together with its ground-truth boxes and category names.

    Parameters
    ----------
    coco : COCO
        Loaded ``pycocotools`` API object for the split.
    cocoRoot : str
        Dataset root; the image is read from ``<cocoRoot>/<dataType>/<file_name>``.
    dataType : str
        Split directory name, e.g. ``"val2017"``.
    imgId : int
        COCO image id to draw.
    ax : matplotlib Axes, optional
        Panel to draw into (e.g. one cell of a subplot grid). When ``None`` the
        current pyplot axes is used and the figure is shown immediately.
    color : str, optional
        Matplotlib color for boxes, label backgrounds and the title (default red).

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the image file.
    """
    imgInfo = coco.loadImgs(imgId)[0]
    imPath = os.path.join(cocoRoot, dataType, imgInfo['file_name'])
    im = cv2.imread(imPath)
    if im is None:
        # cv2.imread reports a missing/unreadable file by returning None instead of raising;
        # without this check the failure only shows up as an obscure cvtColor error.
        raise FileNotFoundError(f"Could not read image: {imPath}")
    # OpenCV decodes to BGR channel order; matplotlib expects RGB.
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # Ground-truth annotations for this image: boxes and category ids (there are no scores).
    anns = coco.loadAnns(coco.getAnnIds(imgIds=imgInfo['id']))
    all_labels = set()

    # Draw on the requested panel, or on pyplot's current axes for a standalone plot.
    target = plt.gca() if ax is None else ax
    for ann in anns:
        # COCO boxes are [x, y, width, height] with (x, y) the TOP-left corner
        # (image y grows downward), which is exactly what Rectangle's anchor expects here.
        x, y, w, h = ann['bbox']
        label = coco.loadCats(ann['category_id'])[0]["name"]
        all_labels.add(label)
        target.add_patch(Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor='none'))
        target.text(x, y, f'{label}', fontsize=10, color='w', backgroundcolor=color)

    if ax is None:
        plt.imshow(im)
        plt.axis('off')
        plt.title(f'Annotations: {all_labels}', color=color)
        plt.show()
    else:
        ax.axis('off')
        ax.set_title(f'Annotations: {all_labels}', color=color, loc='center', pad=20)
        ax.imshow(im)


# Sanity check on one fixed image: index 10 of the id list, i.e. the ELEVENTH image.
imgIds = coco.getImgIds()
imgId = imgIds[10]

plot_image_with_annotations(coco, cocoRoot, dataType, imgId)

In [None]:
def random_select(coco, cocoRoot, dataType, num_images=10, seed=None):
    """Pick ``num_images`` distinct random image ids from the split and plot each one.

    Parameters
    ----------
    coco : COCO
        Loaded ``pycocotools`` API object; candidate ids come from ``coco.getImgIds()``.
    cocoRoot, dataType : str
        Forwarded to ``plot_image_with_annotations`` to locate the image files.
    num_images : int
        How many ids to sample (without replacement).
    seed : int, optional
        When given, sampling uses a private ``random.Random(seed)`` so the same ids
        come back on every run without disturbing the global RNG. When ``None``
        (default) the module-level ``random`` state is used, as before.

    Returns
    -------
    list
        The sampled image ids, in the order they were plotted.

    Raises
    ------
    ValueError
        If ``num_images`` is larger than the number of images in the split.
    """
    imgIds = coco.getImgIds()
    sampler = random if seed is None else random.Random(seed)
    selected_imgIds = sampler.sample(imgIds, num_images)

    for imgId in selected_imgIds:
        plot_image_with_annotations(coco, cocoRoot, dataType, imgId)

    return selected_imgIds
    
# Seed Python's RNG so the same images are sampled (and compared below) on every run.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

valid_ids = random_select(coco, cocoRoot, dataType, num_images=10)
valid_ids

In [None]:

# Pretrained Faster R-CNN detectors, switched to inference mode.
# `weights` must be a member of the matching FasterRCNN_*_Weights enum (or its bare name such
# as "DEFAULT"); the enum class itself or a mistyped "Enum.MEMBER" string is rejected.
model_res = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)
model_res.eval()

model_mobile = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    weights=torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
)
model_mobile.eval();  # trailing ';' suppresses the very long module repr in the notebook

In [None]:
from PIL import Image

def load_image(imgIdx):
    """Load the COCO image with id ``imgIdx`` from the configured split as an RGB PIL image.

    Uses the notebook-level ``coco``, ``cocoRoot`` and ``dataType`` to locate the file
    and prints the resolved path.

    Raises
    ------
    OSError
        Errors from ``Image.open`` (e.g. ``FileNotFoundError`` for a missing file) are
        propagated unchanged so the real cause stays visible.
    """
    imgInfo = coco.loadImgs(imgIdx)[0]
    imPath = os.path.join(cocoRoot, dataType, imgInfo['file_name'])
    print(imPath)

    # Decode inside the context manager so the file handle is closed; convert() returns an
    # independent image. Forcing RGB means single-channel (e.g. grayscale) files still yield
    # the 3-channel input the detectors are built for.
    with Image.open(imPath) as img:
        return img.convert("RGB")


def pil2tensor(pil_image):
    """Turn a PIL image into a float tensor of shape ``(1, C, H, W)`` divided by 255.

    The leading singleton dimension packages the image as a batch of one.
    """
    chw = torchvision.transforms.functional.pil_to_tensor(pil_image)
    return chw[None] / 255.0

In [None]:
 
# Run both detectors on every sampled image. No gradients are needed for inference, so
# torch.no_grad() avoids building autograd graphs (saves memory and time).
predictions_res = []
predictions_mobile = []

with torch.no_grad():
    for i in valid_ids:
        print(i)
        img_as_tensor = pil2tensor(load_image(i))
        predictions_res.append(model_res(img_as_tensor))
        predictions_mobile.append(model_mobile(img_as_tensor))

In [None]:
def filter_valid_boxes(predictions, threshold=0.8):
    """Keep only the detections whose confidence reaches ``threshold``.

    ``predictions`` holds one model output per image; each output is a list whose
    first element is a dict with parallel ``"boxes"``, ``"labels"`` and ``"scores"``
    sequences. Returns, per image, the list of ``(box, label, score)`` tuples with
    ``score >= threshold``, in their original order.
    """
    def _confident(detections):
        rows = zip(detections["boxes"], detections["labels"], detections["scores"])
        return [row for row in rows if row[2] >= threshold]

    return [_confident(per_image[0]) for per_image in predictions]


# Minimum detection confidence for a box to be drawn and scored (shared by both models).
SCORE_THRESHOLD = 0.8

valid_boxes_res = filter_valid_boxes(predictions_res, threshold=SCORE_THRESHOLD)
valid_boxes_mobile = filter_valid_boxes(predictions_mobile, threshold=SCORE_THRESHOLD)

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os


 
def _coco_xywh_to_xyxy(anns):
    """Stack the ``"bbox"`` entries of COCO annotations into an ``(N, 4)`` float32 tensor of corners.

    COCO stores boxes as ``[x, y, width, height]``; ``torchvision.ops.box_iou`` needs
    ``[x1, y1, x2, y2]``. An empty ``anns`` gives a ``(0, 4)`` tensor rather than the
    1-D empty tensor that ``torch.tensor([])`` would produce.
    """
    boxes = torch.tensor([ann["bbox"] for ann in anns], dtype=torch.float32).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]  # (x, y) + (w, h) -> (x2, y2)
    return boxes


def _mean_best_iou(gt_xyxy, pred_xyxy):
    """Mean, over ground-truth boxes, of each box's highest IoU with any predicted box.

    Returns ``nan`` when there is no ground truth (the statistic is undefined) and
    ``0.0`` when there is ground truth but no prediction (every best IoU is zero).
    """
    if len(gt_xyxy) == 0:
        return float("nan")
    if len(pred_xyxy) == 0:
        return 0.0
    iou = torchvision.ops.box_iou(gt_xyxy, pred_xyxy)  # (num_gt, num_pred)
    return iou.max(dim=1).values.mean().item()


def display_annotated_results(imgId, valid_boxes, model_name, color='g', ax=None):
    """Draw one model's detections on a COCO image and score them against the ground truth.

    Parameters
    ----------
    imgId : int
        COCO image id; the image and its annotations are looked up through the
        notebook-level ``coco``, ``cocoRoot`` and ``dataType``.
    valid_boxes : list of (box, label, score)
        Detections for this image as produced by ``filter_valid_boxes``; ``box`` is
        ``[x1, y1, x2, y2]`` and ``label`` is a COCO category id.
    model_name : str
        Shown in the panel title.
    color : str
        Matplotlib color for boxes, label backgrounds and the title.
    ax : matplotlib Axes, optional
        Panel to draw into; when ``None`` the current pyplot axes is used and the
        figure is shown immediately.

    Returns
    -------
    float
        Mean over ground-truth boxes of the best IoU reached by any detection;
        ``0.0`` if there are no detections, ``nan`` if the image has no annotations.
    """
    imgInfo = coco.loadImgs(imgId)[0]
    image_path = os.path.join(cocoRoot, dataType, imgInfo['file_name'])
    # Decode eagerly so the file handle is closed; RGB keeps single-channel images from
    # being rendered through matplotlib's default colormap.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    anns = coco.loadAnns(coco.getAnnIds(imgIds=imgInfo['id']))
    gt_xyxy = _coco_xywh_to_xyxy(anns)
    if valid_boxes:
        pred_xyxy = torch.stack([box.detach().cpu().float() for box, _, _ in valid_boxes])
    else:
        # torch.stack([]) raises, so represent "no detections" as an empty (0, 4) tensor.
        pred_xyxy = torch.zeros((0, 4))
    avg_iou = _mean_best_iou(gt_xyxy, pred_xyxy)

    # Draw on the requested panel, or on pyplot's current axes for a standalone plot.
    target = plt.gca() if ax is None else ax
    all_labels = set()
    for box, label, _score in valid_boxes:
        name = coco.loadCats(label.item())[0]["name"]
        all_labels.add(name)
        x, y, x2, y2 = box.detach().cpu().tolist()
        target.add_patch(Rectangle((x, y), x2 - x, y2 - y, linewidth=2, edgecolor=color, facecolor='none'))
        target.text(x, y, name, fontsize=10, color='w', backgroundcolor=color)

    target.axis('off')
    target.set_title(f'{model_name}: {all_labels} \n IoU: {avg_iou:.4f}', color=color)
    target.imshow(image)
    if ax is None:
        plt.show()

    return avg_iou


res_iou = []
mobile_iou = []

# One figure per sampled image: MobileNet detections | ground truth | ResNet detections.
for img_id, boxes_mobile, boxes_res in zip(valid_ids, valid_boxes_mobile, valid_boxes_res):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    plot_image_with_annotations(coco, cocoRoot, dataType, img_id, ax=axs[1])
    mobile_iou.append(display_annotated_results(img_id, boxes_mobile, "mobile", color='g', ax=axs[0]))
    res_iou.append(display_annotated_results(img_id, boxes_res, "ResNet", color='b', ax=axs[2]))

    fig.tight_layout()
    # Render and release each figure now instead of accumulating all of them until the cell ends.
    plt.show()

# nanmean ignores images whose IoU is undefined (nan); otherwise it equals np.mean.
print("ResNet: Avg.", np.nanmean(res_iou), "; each IoU:", res_iou)
print("MobileNet: Avg.", np.nanmean(mobile_iou), "; each IoU:", mobile_iou)

In [None]:
torchvision.ops.box_iou(bbox_tlist_anns, bbox_tlist_model) 

In [None]:

avg_iou = np.mean([t.cpu().detach().numpy().max() for t in iou]) 