Detect through the whole ROIs, apply sliding window detection, to find objects/boxes

In [2]:
from utils import im_to_txt_path
import os
from os.path import isfile
import matplotlib.pyplot as plt
import torch
import cv2
from timer import *

In [3]:
parent_dir = "test/"
roi_dir = parent_dir + "images/"

roi_paths = []

for root, dirs, files in os.walk(roi_dir):
        for file in files:
            if file.lower().endswith(('.jpg','.png')):
                img_path = os.path.join(root, file)
                roi_paths.append(img_path)   
                
print(len(roi_paths))         

48


In [4]:
# # test on cropping the image / sliding window approach
# window_size = 512
# fill = [114, 114, 114]

# roi_fp = roi_paths[0]
# roi = cv2.imread(roi_fp)

# roi_w = roi.shape[1]
# roi_h = roi.shape[0]

# roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
# print('no border: ', roi.shape)
# # Add the border to the ROI image
# roi_img = cv2.copyMakeBorder(roi, 0, window_size - roi_h % window_size, 0, window_size - roi_w % window_size, cv2.BORDER_CONSTANT, value=fill)

# print('with border: ', roi_img.shape)

# plt.imshow(roi_img)
# plt.show()


# start_x = roi_w // window_size * window_size
# start_y = roi_h // window_size * window_size
# end_x = start_x + window_size
# end_y = start_y + window_size
# print(start_x, end_x)
# print(start_y, end_y)
# # crop = roi_img[0:512, 0:512]
# crop = roi_img[start_y:end_y,start_x:end_x]
# plt.imshow(crop)
# plt.show()

In [5]:
def calculate_iou(box1, box2):
    """
    Calculate the Intersection over Union (IoU) of two bounding boxes.
    
    Parameters:
    box1, box2: Arrays or lists in the format [x_min, y_min, x_max, y_max]
    
    Returns:
    float: IoU value
    """
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2
    
    # Calculate the intersection coordinates
    x_min = max(x1_min, x2_min)
    y_min = max(y1_min, y2_min)
    x_max = min(x1_max, x2_max)
    y_max = min(y1_max, y2_max)
    
    # Calculate the intersection area
    intersection_area = max(0, x_max - x_min) * max(0, y_max - y_min)
    
    # Calculate the area of both boxes
    box1_area = (x1_max - x1_min) * (y1_max - y1_min)
    box2_area = (x2_max - x2_min) * (y2_max - y2_min)
    
    # Calculate the union area
    union_area = box1_area + box2_area - intersection_area
    
    # Calculate the IoU
    iou = intersection_area / union_area
    
    return iou

In [6]:
import torchvision
import logging

import traceback

# Configure logging
# logging.basicConfig(level=logging.DEBUG)

def detect_on_roi(roi_fp, window_size = 512, stride = 256, fill = [114, 114, 114], iou_thr = 0.5):
    try:
        model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=2)
        model = model.to('cuda')
        model.load_state_dict(torch.load('rois2/models/model_epoch9.pth'))
        model.eval()
        
        logging.debug(f"Processing {roi_fp}")
        # load groundtruth labels
        true_labels = []
        if isfile(im_to_txt_path(roi_fp)):
            with open(im_to_txt_path(roi_fp), 'r') as f:
                for line in f:
                    temp_label = [int(x) for x in line.strip().split(' ')]
                    true_labels.append(temp_label)
        true_labels_detected = [0] * len(true_labels)
        roi = cv2.imread(roi_fp)
        
        if roi is None:
            logging.error(f"Failed to read {roi_fp}")
            return (0, 0)
        
        roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)/255
        roi_w = roi.shape[1]
        roi_h = roi.shape[0]
        roi = cv2.copyMakeBorder(roi, 0, window_size - roi_h % window_size, 0, window_size - roi_w % window_size, cv2.BORDER_CONSTANT, value=fill)
        
        print('     roi width: ', roi_w, ' roi height: ', roi_h)
        start_x = 0
        
        while start_x <= roi_w:
            windows = []
            offsets = []
            
            # print('     ', start_x)
            start_y = 0
            while start_y <= roi_h: # slide the window from top left to bottom right
                end_x = int(start_x + window_size)
                end_y = int(start_y + window_size)
                window = roi[start_y:end_y, start_x:end_x]
                window = torch.tensor(window).permute(2, 0, 1).float().to('cuda')
                
                windows.append(window)
                offsets.append([start_x, start_y])
                
                start_y += stride
            # print('     number of windows: ', len(windows))
                
            preds = model(windows)
        
            for idx, pred in enumerate(preds):
                offset = offsets[idx]
                filtered_idx = torchvision.ops.nms(pred['boxes'], pred['scores'], iou_thr)
                pred_boxes = pred['boxes'][filtered_idx]
                for box in (pred_boxes):
                    pred_box = [box[0] + offset[0], box[1] + offset[1], box[2] + offset[0], box[3] + offset[1]]
                    for i, true_label in enumerate(true_labels):
                        true_box = [true_label[1], true_label[2], true_label[3], true_label[4]]
                        if calculate_iou(pred_box, true_box) > iou_thr:
                            true_labels_detected[i] = 1
            
            start_x += stride
    except Exception as e:
        logging.error(f"Error processing {roi_fp}: {e}")
        traceback.print_exc()
        return (0, 0)

    # logging.debug(f"Completed {roi_fp}")
    return sum(true_labels_detected), len(true_labels_detected)
    


In [7]:
tot_cnt = 0
correct_cnt = 0
timer = Timer()
timer.start()
for i, roi_fp in enumerate(roi_paths):
    print('roi: ', i, '/', len(roi_paths))
    roi_tot_cnt = 0
    roi_correct_cnt = 0
    
    true_positives, total = detect_on_roi(roi_fp, window_size=512, stride=256, iou_thr=0.5)
    
    timer.stop()
    roi_tot_cnt += total
    roi_correct_cnt += true_positives
    if total > 0:
        print('     correct / total = ', roi_correct_cnt, '/', roi_tot_cnt, ' = ', true_positives/total, 'time: ', timer.elapsed_time())
    else:
        print('     no ground truth labels', 'time: ', timer.elapsed_time())
    
    tot_cnt += roi_tot_cnt
    correct_cnt += roi_correct_cnt

print('recall = ', correct_cnt/tot_cnt)

roi:  0 / 48
     roi width:  5284  roi height:  3311
     correct / total =  79 / 81  =  0.9753086419753086 time:  Elapsed time: 0 hours, 44 minutes, 29.21 seconds
roi:  1 / 48
     roi width:  4944  roi height:  2684
     correct / total =  9 / 9  =  1.0 time:  Elapsed time: 0 hours, 48 minutes, 30.54 seconds
roi:  2 / 48
     roi width:  2702  roi height:  4946
     no ground truth labels time:  Elapsed time: 0 hours, 48 minutes, 33.83 seconds
roi:  3 / 48
     roi width:  4364  roi height:  5594
     correct / total =  54 / 56  =  0.9642857142857143 time:  Elapsed time: 1 hours, 30 minutes, 36.12 seconds
roi:  4 / 48
     roi width:  5621  roi height:  4390
     correct / total =  13 / 14  =  0.9285714285714286 time:  Elapsed time: 1 hours, 44 minutes, 10.59 seconds
roi:  5 / 48
     roi width:  5065  roi height:  5487
     correct / total =  6 / 6  =  1.0 time:  Elapsed time: 1 hours, 51 minutes, 15.42 seconds
roi:  6 / 48
     roi width:  4972  roi height:  2644
     correct / to

KeyboardInterrupt: 

In [None]:
# from multiprocessing import Pool
# from tqdm import tqdm

# def multi_detect_on_roi(roi_fps, window_size=512, stride=256, fill=[114, 114, 114], iou_thr=0.5, num_proc=10):
#     tot_cnt = 0
#     correct_cnt = 0

#     with Pool(num_proc) as pool:
#         jobs = [
#             pool.apply_async(
#                 func=detect_on_roi,
#                 args=((roi_fp, window_size, stride, fill, iou_thr))
#             )
#             for roi_fp in roi_fps
#         ]
#         for job in tqdm(jobs):
#             try:
#                 true_positives, total = job.get(timeout=300)  # Add a timeout to detect stuck processes
#                 correct_cnt += true_positives
#                 tot_cnt += total
#             except Exception as e:
#                 logging.error(f"Error in job: {e}")
#     return correct_cnt, tot_cnt

# tot_cnt = 0
# correct_cnt = 0
# correct_cnt, tot_cnt = multi_detect_on_roi(roi_paths, window_size = 1024, stride = 512, fill = [114, 114, 114], iou_thr = 0.5, num_proc = 10)
# print('recall = ', correct_cnt, '/', tot_cnt, ' = ', correct_cnt/tot_cnt)