In [1]:
import cv2
import time
import os
import tensorflow as tf
import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file
import warnings

In [2]:
# Seed NumPy's global (legacy) RNG so that the np.random.uniform colour draws
# made in Detector.readClasses come out the same on every run.
np.random.seed(seed=333)

In [3]:
class Detector:
    """Object detector wrapping a TensorFlow SavedModel.

    ``readClasses`` must be called (for ``classes_list``/``color_list``) and
    ``downloadModel`` + ``loadModel`` (for ``model``) before ``predictImage``,
    ``predictFolder`` or ``createBoundingBox`` are used.
    """

    def __init__(self):
        pass

    def readClasses(self, classes_file_path):
        """Load class labels, one per line, and assign each a random colour.

        Args:
            classes_file_path: Path to a text file with one label per line.
                The line index is used directly as the class id reported by
                the model (NOTE(review): confirm line 0 lines up with the
                model's first class id, e.g. via a background placeholder).
        """
        with open(classes_file_path, 'r') as f:
            self.classes_list = f.read().splitlines()

        # One colour row per class, components drawn uniformly from [0, 255).
        self.color_list = np.random.uniform(low = 0, high = 255, size = (len(self.classes_list), 3))
        print(f'Class list len: {len(self.classes_list)}')
        print(f'Color list len: {len(self.color_list)}')

    def downloadModel(self, model_url):
        """Download and extract a model archive into ./pretrained_models/checkpoints.

        ``self.model_name`` is set to the archive's file name up to its first
        '.', which is the directory name ``loadModel`` looks for.
        """
        file_name = os.path.basename(model_url)
        self.model_name = file_name[:file_name.index('.')]

        self.cache_dir = "./pretrained_models"

        os.makedirs(self.cache_dir, exist_ok=True)

        get_file(fname=file_name,
                 origin=model_url,
                 cache_dir=self.cache_dir,
                 cache_subdir="checkpoints",
                 extract=True)

    def loadModel(self):
        """Load the SavedModel extracted by ``downloadModel`` into ``self.model``."""
        print(f'Loading model: {self.model_name}')
        tf.keras.backend.clear_session()
        self.model = tf.saved_model.load(os.path.join(self.cache_dir, "checkpoints", self.model_name, "saved_model"))

        print(f'Model {self.model_name} loaded successfully...')

    def predictImage(self, image_path, threshold = 0.5, save = 0):
        """Run detection on a single image file.

        The image is first resized so that its longer side is 800 px.

        Args:
            image_path: Path of the image to read with cv2.imread.
            threshold: Forwarded to ``createBoundingBox``.
            save: When truthy, return the results instead of displaying them.

        Returns:
            ``(annotated_image, detected_objects)`` when ``save`` is truthy;
            otherwise the annotated image is shown in a window, the call
            blocks until a key is pressed, and None is returned.

        Raises:
            ValueError: If the file cannot be read as an image.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports unreadable / non-image files by returning None.
            raise ValueError(f'Could not read image: {image_path}')
        (h, w) = image.shape[:2]
        if h > w:
            image = image_resize(image, height = 800)
        else:
            image = image_resize(image, width = 800)

        if save:
            return self.createBoundingBox(image, threshold, save = 1)

        bbox_image = self.createBoundingBox(image, threshold)
        cv2.imshow("Result", bbox_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def predictFolder(self, folder_path, threshold = 0.5, save = 0):
        """Run ``predictImage`` on every regular file directly inside a folder.

        Args:
            folder_path: Directory to scan (not recursive; sub-directories,
                including ``output_folder`` itself, are skipped).
            threshold: Forwarded to ``predictImage``.
            save: When truthy, each annotated image is written as
                ``image <n>.jpg`` into ``<folder_path>/output_folder`` and a
                ``<name>, Output: <detections>`` line is appended to
                ``output_folder/output.txt``; any failure on a file is printed
                and that file skipped. When falsy, each image is displayed via
                ``predictImage`` and unreadable files are skipped.
        """
        if save:
            new_folder_path = os.path.join(folder_path, "output_folder")
            os.makedirs(new_folder_path, exist_ok=True)

            text_file_path = os.path.join(new_folder_path, 'output.txt')
            # Truncate output.txt so results from an earlier run are not kept.
            open(text_file_path, 'w').close()
            count = 0
        print(folder_path)
        # os.listdir order is arbitrary; sorting makes the processing order
        # (and therefore the output numbering) reproducible.
        for file_name in sorted(os.listdir(folder_path)):
            image = os.path.join(folder_path, file_name)
            if not os.path.isfile(image):
                continue
            print(image)
            if not save:
                try:
                    self.predictImage(image, threshold)
                except ValueError as err:
                    print(f'Skipping {image}: {err}')
                continue

            count += 1
            try:
                bbox_image, detected_objects = self.predictImage(image, threshold, save = 1)
            except Exception as err:
                # Best effort across a whole folder: report and move on.
                print(f'Skipping {image}: {err}')
                continue

            image_name = f'image {count}.jpg'
            image_file_path = os.path.join(new_folder_path, image_name)
            cv2.imwrite(image_file_path, bbox_image)

            with open(text_file_path, 'a') as file:
                file.write(f'{image_name}, Output: {detected_objects}\n')

    @staticmethod
    def _toPixelBox(bbox, img_w, img_h):
        """Convert a normalized (ymin, xmin, ymax, xmax) box to integer pixel (xmin, ymin, xmax, ymax)."""
        ymin, xmin, ymax, xmax = bbox
        # x coordinates scale with the width, y coordinates with the height.
        return (int(xmin * img_w), int(ymin * img_h), int(xmax * img_w), int(ymax * img_h))

    @staticmethod
    def _addDetection(detected_objects, label, confidence):
        """Record one detection: a bare value for the first hit of a label, a list once it repeats."""
        if label not in detected_objects:
            detected_objects[label] = confidence
        elif isinstance(detected_objects[label], list):
            detected_objects[label].append(confidence)
        else:
            detected_objects[label] = [detected_objects[label], confidence]

    @staticmethod
    def _drawLabelledBox(image, box, text, color):
        """Draw a 1-px box plus a filled label strip with white text, in place on ``image``."""
        xmin, ymin, xmax, ymax = box
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color = color, thickness=1)
        if (ymin - 10) < 10:
            # Too close to the top edge for a strip above the box: draw it inside.
            cv2.rectangle(image, (xmin, ymin), (xmax, ymin + 20), color, -1)
            cv2.putText(image, text, (xmin, ymin + 12), cv2.FONT_HERSHEY_PLAIN, 1,
                        (255, 255, 255), 2)
        else:
            cv2.rectangle(image, (xmin, ymin), (xmax, ymin - 20), color, -1)
            cv2.putText(image, text, (xmin, ymin - 7), cv2.FONT_HERSHEY_PLAIN, 1,
                        (255, 255, 255), 2)

    def createBoundingBox(self, image, threshold = 0.5, save = 0):
        """Run the model on ``image`` and draw the surviving detections onto it.

        Args:
            image: Image array in BGR channel order (it is converted with
                COLOR_BGR2RGB before inference). It is drawn on in place.
            threshold: Used as BOTH the NMS score threshold and the NMS IoU
                threshold.
            save: When truthy, also return the detection summary.

        Returns:
            ``image`` (the same array, annotated), or ``(image, detected_objects)``
            when ``save`` is truthy. ``detected_objects`` maps a label to its
            confidence percentage, or to a list of them if the label occurs
            more than once.
        """
        input_tensor = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB)
        input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.uint8)
        input_tensor = input_tensor[tf.newaxis, ...]  # add batch dimension

        detections = self.model(input_tensor)

        bboxes = detections["detection_boxes"][0].numpy()
        class_indices = detections["detection_classes"][0].numpy().astype(np.int32)
        class_scores = detections["detection_scores"][0].numpy()

        img_h, img_w = image.shape[:2]

        bbox_idx = tf.image.non_max_suppression(bboxes, class_scores, max_output_size = 50,
                                                iou_threshold=threshold, score_threshold=threshold)
        detected_objects = {}
        for i in bbox_idx.numpy():
            class_index = class_indices[i]
            class_label_text = self.classes_list[class_index]
            class_color = self.color_list[class_index]
            class_confidence = round(100*class_scores[i])

            self._addDetection(detected_objects, class_label_text, class_confidence)

            # .tolist() yields Python floats so the scaling is done in double precision.
            box = self._toPixelBox(bboxes[i].tolist(), img_w, img_h)
            self._drawLabelledBox(image, box, f'{class_label_text}: {class_confidence}', class_color)

        if save:
            return image, detected_objects
        return image
            
            
        

In [None]:
def image_resize(image, width = None, height = None, inter = None):
    """Resize ``image`` to a target width or height, preserving aspect ratio.

    Args:
        image: Array whose first two dimensions are (height, width).
        width: Target width in pixels. Takes precedence if both are given.
        height: Target height in pixels; used only when ``width`` is None.
        inter: OpenCV interpolation flag; defaults to ``cv2.INTER_AREA``.

    Returns:
        The resized image, or ``image`` itself when both ``width`` and
        ``height`` are None.
    """
    if width is None and height is None:
        # No target size requested: return the input as-is.
        return image
    if inter is None:
        inter = cv2.INTER_AREA

    (h, w) = image.shape[:2]

    # cv2.resize takes dsize as (width, height); max(1, ...) keeps the derived
    # side from collapsing to 0 px for extreme aspect ratios.
    if width is None:
        r = height / float(h)
        dim = (max(1, int(w * r)), height)
    else:
        r = width / float(w)
        dim = (width, max(1, int(h * r)))

    return cv2.resize(image, dim, interpolation = inter)