In [None]:
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.visualization import draw_box, draw_caption
from keras_retinanet.utils.colors import label_color
from keras_retinanet import models

import matplotlib.pyplot as plt
import cv2
import os
import numpy as np
from tqdm import tqdm
import configparser


In [None]:
class ObjectDetector:
    """Run a keras-retinanet model over images and save annotated images and text detections.

    Args:
        retinanet_weight_path: path passed to ``models.load_model``.
        class_names: class names indexed by the integer labels the model predicts.
        viz_path: directory where annotated images are written.
        text_path: directory where one ``<image stem>.txt`` detection file per image is written.
        backbone: backbone name passed to ``models.load_model`` as ``backbone_name``.
        palette: mapping from class name to the colour used to draw that class's boxes.
    """

    def __init__(self, retinanet_weight_path, class_names, viz_path, text_path, backbone, palette):
        self.class_names = class_names
        self.retinanet_weights = retinanet_weight_path
        self.viz_path = viz_path
        self.text_path = text_path
        self.palette = palette
        self.backbone = backbone

    def load_retinanet(self):
        """Load the weights and convert the model; sets ``self.model`` (required before inference)."""
        model = models.load_model(self.retinanet_weights, backbone_name=self.backbone)
        self.model = models.convert_model(model)

    def filter_bbox(self, boxes, scores, classes, score_threshold=0.3):
        """Keep the detections of the first batch item whose score is >= ``score_threshold``.

        Args:
            boxes, scores, classes: batched model outputs; only batch index 0 is used.
            score_threshold: minimum confidence for a detection to be kept.

        Returns:
            ``(boxes, scores, classes)`` for batch item 0, restricted to the kept detections.
            All three may be empty.
        """
        # A boolean mask does not depend on the scores being sorted, and (unlike searching
        # for the first below-threshold score) it works when every score, or none, passes.
        keep = np.asarray(scores[0]) >= score_threshold
        return np.asarray(boxes[0])[keep], np.asarray(scores[0])[keep], np.asarray(classes[0])[keep]

    def convert_detections_to_dict(self, boxes, scores, classes):
        """Group detections by class name.

        Returns:
            dict mapping every name in ``self.class_names`` to a list of
            ``[x1, y1, x2, y2, score]`` entries (an empty list when the class has no detections).
        """
        scores_map = {name: [] for name in self.class_names}
        for box, score, class_idx in zip(boxes, scores, classes):
            # The integer label indexes into class_names.
            scores_map[self.class_names[class_idx]].append([box[0], box[1], box[2], box[3], score])
        return scores_map

    def write_to_file(self, filename, scores_map):
        """Write detections to ``<text_path>/<stem of filename>.txt``.

        One line per detection: ``<class> <score> <x1> <y1> <x2> <y2>``. Classes without
        detections contribute no lines; the file is always created (or truncated).
        """
        out_path = os.path.join(self.text_path, os.path.splitext(filename)[0] + ".txt")
        # `with` guarantees the handle is closed even if a write raises.
        with open(out_path, "w") as fh:
            for class_name, detections in scores_map.items():
                for x1, y1, x2, y2, score in detections:
                    # str() (not format()) keeps the exact textual form of numpy scalars.
                    fields = [class_name, str(score), str(x1), str(y1), str(x2), str(y2)]
                    fh.write(" ".join(fields) + "\n")

    def perform_retinanet_inference(self, image_path, write_to_text=True, score_threshold=0.3):
        """Detect objects in one image, save an annotated copy and optionally a text file.

        The annotated image is written to ``<viz_path>/<basename of image_path>``. When
        ``write_to_text`` is true, detections also go to ``<text_path>/<stem>.txt``.

        Args:
            image_path: path of the image to process.
            write_to_text: whether to also write the text detection file.
            score_threshold: minimum confidence for a detection to be kept.

        Returns:
            The per-class detection dict produced by ``convert_detections_to_dict``.
        """
        image = read_image_bgr(image_path)
        # Untouched copy for drawing; `image` itself is preprocessed and resized below.
        annotated = image.copy()
        image = preprocess_image(image)
        image, scale = resize_image(image)
        boxes, scores, labels = self.model.predict_on_batch(np.expand_dims(image, axis=0))
        # Map box coordinates from the resized image back to the original image.
        boxes /= scale

        boxes, scores, labels = self.filter_bbox(boxes, scores, labels, score_threshold)
        scores_map = self.convert_detections_to_dict(boxes, scores, labels)
        # Retained from the original implementation so callers can inspect the last result.
        # `res` is a shallow copy, so its per-class lists are shared with `res1`.
        self.res = scores_map.copy()
        self.res1 = scores_map

        self._draw_detections(annotated, scores_map)

        # Derive output names from image_path itself (not from a global loop variable),
        # using os.path so any directory depth and path separator work.
        filename = os.path.basename(image_path)
        cv2.imwrite(os.path.join(self.viz_path, filename), annotated)

        if write_to_text:
            self.write_to_file(filename, scores_map)

        return scores_map

    def _draw_detections(self, image, scores_map):
        """Draw every box and its class name onto ``image`` in place."""
        for class_name, detections in scores_map.items():
            if not detections:
                # Skip before the palette lookup so classes without detections need no colour.
                continue
            color = self.palette[class_name]
            for x1, y1, x2, y2, _score in detections:
                cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                cv2.putText(image, class_name, (int(x1), int(y1) - 15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)


In [None]:
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list it parsed,
# so a missing config.ini would otherwise only surface later as a KeyError on 'model_params'.
if not config.read('config.ini'):
    raise FileNotFoundError("could not read config.ini (looked relative to the current working directory)")

# Box colour per class name. The colours are drawn with OpenCV onto the image loaded by
# read_image_bgr, so each tuple is interpreted in (B, G, R) order.
PALETTE = {"specularity":(255,0,0),"saturation":(0,255,0),"artifact":(0,0,255),"blur":(0,255,255),"contrast":(238,130,238),"bubbles":(211,0,148),"instrument":(0,0,0),"blood":(127,127,127)}

model_params = config['model_params']
RETINANET_WEIGHTS = model_params['retinanet_weights']
TEXT_PATH = model_params['text_path']
VIS_PATH = model_params['vis_path']
# Order matters: entry i is the name for model label i. Strip whitespace so "a, b" in the
# config yields "b" (matching the PALETTE keys) rather than " b".
CLASS_NAMES = [name.strip() for name in model_params['class_names'].split(",")]
IMAGE_PATH = model_params['image_path']
BACKBONE = model_params['backbone']

obj_det = ObjectDetector(RETINANET_WEIGHTS, CLASS_NAMES, VIS_PATH, TEXT_PATH, BACKBONE, PALETTE)

In [None]:
# Load the model weights (sets obj_det.model); this must run before any inference call below.
obj_det.load_retinanet()

In [None]:
# Only regular files (subdirectories would crash the image reader), in a deterministic order
# because os.listdir returns entries in arbitrary order.
files = sorted(
    name for name in os.listdir(IMAGE_PATH)
    if os.path.isfile(os.path.join(IMAGE_PATH, name))
)

# cv2.imwrite does not raise when the target directory is missing (it just returns False),
# so create both output directories up front.
os.makedirs(VIS_PATH, exist_ok=True)
os.makedirs(TEXT_PATH, exist_ok=True)

for file in tqdm(files):
    obj_det.perform_retinanet_inference(os.path.join(IMAGE_PATH, file))
    
