In [18]:
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

import seaborn as sns
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [19]:
def create_model(device, num_classes=2):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box-predictor head.

    The detector is created with ``pretrained=False`` and its stock box
    predictor is replaced by a ``FastRCNNPredictor`` producing ``num_classes``
    outputs.

    NOTE(review): the recorded output of the weight-loading cell shows a
    ResNet-50 checkpoint being downloaded despite ``pretrained=False``;
    presumably the backbone is still initialised from pretrained weights by
    default -- confirm against the installed torchvision version.

    Args:
        device: ``torch.device`` the model is moved to.
        num_classes: Number of outputs of the new predictor head. Defaults to
            2 (1 class + background), the value that used to be hard-coded.

    Returns:
        The model after ``.to(device)``.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    # The new head must accept the same feature width as the head it replaces.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model.to(device)

# Preprocessing pipeline applied to every input image before inference:
# first convert the image to a tensor, then normalise each of the three
# channels with the fixed mean / std values below.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transformations = transforms.Compose([_to_tensor, _normalize])

def show_image(pred):
    """Show an image with its bounding boxes drawn on it in an OpenCV window.

    The call blocks until a key is pressed in the window, then closes all
    OpenCV windows.

    Args:
        pred: dict with the keys
            'file'   -- path of the image to display, and
            'bboxes' -- iterable of ``[x_min, y_min, x_max, y_max]`` boxes.

    Raises:
        OSError: if the image at ``pred['file']`` cannot be read.
    """
    filename = pred['file']
    # The second argument of cv2.imread is an IMREAD_* flag, not a COLOR_*
    # conversion code. IMREAD_COLOR always yields a 3-channel BGR image, which
    # is the layout cv2.imshow expects, so no colour conversion is required.
    image = cv2.imread(filename, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image: {filename}")

    # seaborn palettes hold RGB floats in [0, 1]; OpenCV draws with BGR values
    # in [0, 255], so scale the colour and reverse its channel order.
    red, green, blue = sns.color_palette(None, 2)[1]
    color = (blue * 255, green * 255, red * 255)

    for bbox in pred['bboxes']:
        # cv2.rectangle requires integer pixel coordinates.
        x_min, y_min, x_max, y_max = (int(value) for value in bbox[:4])
        image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)

    cv2.imshow("test", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


In [20]:
# Path of the saved detector weights (loaded below as a state dict).
DETECTOR_MODEL_PATH = './models/detection_model.pt'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = create_model(device)

# Deserialize the weights onto the CPU, then copy them into the model.
detector_state = torch.load(DETECTOR_MODEL_PATH, map_location=torch.device('cpu'))
model.load_state_dict(detector_state)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to C:\Users\flash/.cache\torch\hub\checkpoints\resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




<All keys matched successfully>

In [124]:
def _bbox_area(bbox):
    """Return the area of an ``[x_min, y_min, x_max, y_max]`` box (0 if degenerate)."""
    return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1])


def _bbox_intersection_area(first, second):
    """Return the area shared by two ``[x_min, y_min, x_max, y_max]`` boxes."""
    width = min(first[2], second[2]) - max(first[0], second[0])
    height = min(first[3], second[3]) - max(first[1], second[1])
    return max(0, width) * max(0, height)


def postprocess_bbxes(pred, overlap_threshold=0.7):
    """Drop every box that mostly covers a smaller box.

    A box is removed when some other box has a strictly smaller (non-zero)
    area and more than ``overlap_threshold`` of that smaller box lies inside
    it. All remaining boxes are returned in their original order.

    Areas and intersections are computed arithmetically from the corner
    coordinates rather than by enumerating pixels, so the cost no longer grows
    with box size. Zero-area boxes are never used as the denominator, which
    previously raised ``ZeroDivisionError``.

    Args:
        pred: dict with the keys 'file' (passed through unchanged) and
            'bboxes' (list of ``[x_min, y_min, x_max, y_max]`` boxes).
        overlap_threshold: Fraction of the smaller box that must be covered
            for the larger box to be dropped. Defaults to 0.7, the value that
            used to be hard-coded.

    Returns:
        A new dict ``{'file': ..., 'bboxes': [...]}`` with the kept boxes;
        ``pred`` itself is not modified.
    """
    bboxes = pred['bboxes']
    areas = [_bbox_area(bbox) for bbox in bboxes]
    kept = []

    for i, bbox in enumerate(bboxes):
        covers_smaller_box = any(
            # Only strictly smaller, non-empty boxes can cause a removal.
            0 < areas[j] < areas[i]
            and _bbox_intersection_area(bbox, bboxes[j]) / areas[j] > overlap_threshold
            for j in range(len(bboxes))
        )
        if not covers_smaller_box:
            kept.append(bbox)

    return {'file': pred['file'], 'bboxes': kept}

In [125]:
preds = []
THRESHOLD_SCORE = 0.6  # detections scoring at or below this value are discarded
model.eval()
filename = './datasets/food/1_3317.png'

# PNG files can be RGBA, palette-based or grayscale, while the Normalize step
# is configured for exactly three channels, so force RGB before transforming.
# The context manager closes the file handle once the pixels have been read.
with Image.open(filename) as img:
    img_tensor = transformations(img.convert('RGB'))

with torch.no_grad():
    predictions = model([img_tensor.to(device)])
prediction = predictions[0]

pred = dict()
pred['file'] = filename
pred['bboxes'] = []

# Keep only confident detections, with box corners truncated to ints.
for box, score in zip(prediction['boxes'].tolist(), prediction['scores'].tolist()):
    if score > THRESHOLD_SCORE:
        pred['bboxes'].append([int(coord) for coord in box])

postprocessed_pred = postprocess_bbxes(pred)

preds.append(postprocessed_pred)
show_image(postprocessed_pred)