In [1]:
import torch
import torchvision
import os
import pickle
import shutil
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import torch.nn as nn
from PIL import Image
from torchcam.methods import SmoothGradCAMpp
from utils import imread, isfile, im_to_txt_path, imwrite
from timer import *
from torchvision.transforms import transforms as T
import cv2
from dataset_operations import clip_value

# Directory layout. Every directory constant ends with '/' because later cells build
# paths by plain string concatenation.
parent_dir = 'rois2/'
obj_dir = parent_dir + 'objects/'
img_dir = parent_dir + 'images/'
label_dir = parent_dir + 'labels/'
model_dir = parent_dir + 'models/'

result_dir = parent_dir + 'results/'
model1_result_dir = result_dir + 'model1/'

obj_train_dir = obj_dir + 'train/'
obj_val_dir = obj_dir + 'val/'

# os.makedirs creates intermediate directories, so listing the leaves is enough.
# 'test' is included because the test-ROI detection pass writes its results there.
for _leaf_dir in (model1_result_dir + 'train',
                  model1_result_dir + 'val',
                  model1_result_dir + 'test',
                  obj_train_dir,
                  obj_val_dir):
    os.makedirs(_leaf_dir, exist_ok=True)

In [2]:
def _load_gt_object_map(roi_fp, roi_h, roi_w):
    """Rasterise the ROI's ground-truth boxes into a (roi_h, roi_w) float tensor.

    Pixels inside a ground-truth box hold that box's area (in pixels); all other
    pixels are 0. Where boxes overlap, the box listed later in the label file wins.
    Each label line is read as whitespace-separated integers whose fields 1..4 are
    x1 y1 x2 y2 (field 0 is ignored); blank lines are skipped. If no label file
    exists the map is all zeros.
    """
    obj_map = torch.zeros((roi_h, roi_w))
    label_fp = im_to_txt_path(roi_fp)
    if not isfile(label_fp):
        return obj_map
    with open(label_fp, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            x1, y1, x2, y2 = (int(v) for v in fields[1:5])
            obj_map[y1:y2, x1:x2] = (y2 - y1) * (x2 - x1)
    return obj_map


def _passes_box_filters(box_w, box_h, box_width_thr, box_area_thr, box_ratio_thr):
    """Return True if a box_w x box_h box satisfies the side, area and aspect-ratio limits.

    Both sides must lie in [box_width_thr[0], box_width_thr[1]], the area in
    [box_area_thr[0], box_area_thr[1]], and both w/h and h/w must be >= box_ratio_thr.
    Degenerate boxes (a side <= 0) are rejected up front, which also keeps the ratio
    divisions safe when the configured minimums are 0.
    """
    if box_w <= 0 or box_h <= 0:
        return False
    area = box_w * box_h
    return (box_area_thr[0] <= area <= box_area_thr[1]
            and box_width_thr[0] <= box_w <= box_width_thr[1]
            and box_width_thr[0] <= box_h <= box_width_thr[1]
            and box_w / box_h >= box_ratio_thr
            and box_h / box_w >= box_ratio_thr)


def _label_detection(obj_map, x1, y1, x2, y2, intersect_thr):
    """Return 1 if the predicted box sufficiently overlaps ground truth, else 0.

    intersect_area counts pixels of the box covered by any ground-truth box, and
    obj_area is the largest ground-truth area found under the box. The box is
    labelled 1 when pred_area >= (pred_area + obj_area - intersect_area) * intersect_thr,
    i.e. pred_area / union >= intersect_thr. With intersect_thr == 1 this means the
    ground-truth box lies entirely inside the prediction.
    NOTE(review): this is not IoU (IoU would use intersect_area / union), although the
    original comment called it "iou thr". Confirm which criterion is intended.
    """
    covered = obj_map[y1:y2, x1:x2]
    intersect_area = (covered > 0).sum().item()
    if intersect_area == 0:
        return 0
    pred_area = (x2 - x1) * (y2 - y1)
    obj_area = covered.max().item()
    return 1 if pred_area >= (pred_area + obj_area - intersect_area) * intersect_thr else 0


def gather_detections(roi_fp, model, save_dir = model1_result_dir, window_size=512, stride=256, 
                      score_thr=0.5, nms_thr=0.5, intersect_thr=0.5,
                      box_width_thr = [20,500], box_area_thr = [200,150000], box_ratio_thr = 0.2, clean_up = True):
    """Run `model` over one ROI image with a sliding window and save filtered detections.

    The ROI is zero-padded on the bottom/right so that every window (top-left corners
    every `stride` pixels from 0 up to and including the ROI width/height) is a full
    window_size x window_size tile. Each window's predictions go through per-window
    NMS, a score threshold and size/aspect-ratio filters. Surviving boxes are shifted
    back to ROI coordinates, labelled against the ground truth, and written to
    `<save_dir>/<roi_name>.txt`, one `x1 y1 x2 y2 obj_type` line per box. The file is
    written in one go after all windows are processed, and always closed.

    Args:
        roi_fp: path to the ROI image; its label file is located via im_to_txt_path.
        model: torchvision-style detection model, called with a list of CHW float
            tensors and returning a list of dicts with 'boxes' and 'scores'.
        save_dir: directory for the result file (created if missing).
        window_size, stride: sliding-window geometry in pixels.
        score_thr: detections scoring below this are dropped.
        nms_thr: IoU threshold for torchvision.ops.nms, applied per window.
        intersect_thr: overlap threshold used to label a box 1 (see _label_detection).
        box_width_thr: [min, max] allowed box width and height.
        box_area_thr: [min, max] allowed box area.
        box_ratio_thr: minimum allowed w/h and h/w ratio.
        clean_up: if True, delete this ROI's existing result file before processing,
            so a failed run never leaves a stale file behind. Only this ROI's file is
            touched: the function is called once per ROI, and removing anything else
            would destroy results already written for other ROIs.

    Returns:
        0 if the ROI image could not be read, otherwise None.
    """
    model.eval()
    device = next(model.parameters()).device
    roi_name = roi_fp.split('/')[-1].split('.')[0]
    os.makedirs(save_dir, exist_ok=True)
    result_file_path = os.path.join(save_dir, roi_name + '.txt')

    if clean_up and os.path.exists(result_file_path):
        os.remove(result_file_path)

    orig_roi = imread(roi_fp)
    if orig_roi is None:
        print('     Error: No roi file found')
        return 0
    roi_h, roi_w = orig_roi.shape[:2]

    # The last window starts at the largest multiple of `stride` <= the ROI size; pad
    # so that window is complete. Padding only to a multiple of window_size would leave
    # trailing windows truncated and non-square.
    pad_h = max(0, (roi_h // stride) * stride + window_size - roi_h)
    pad_w = max(0, (roi_w // stride) * stride + window_size - roi_w)
    roi = cv2.copyMakeBorder(orig_roi, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=[0, 0, 0])

    obj_map = _load_gt_object_map(roi_fp, roi_h, roi_w)

    print('     roi width: ', roi_w, ' roi height: ', roi_h)
    detections = []
    # Process one column of windows (fixed start_x) per model call.
    for start_x in range(0, roi_w + 1, stride):
        offsets = [(start_x, start_y) for start_y in range(0, roi_h + 1, stride)]
        # NOTE(review): pixels are passed as unscaled floats (no /255). Confirm this
        # matches the preprocessing the model was trained with.
        windows = [
            torch.tensor(roi[y:y + window_size, x:x + window_size]).permute(2, 0, 1).float().to(device)
            for x, y in offsets
        ]
        with torch.no_grad():
            preds = model(windows)

        for (off_x, off_y), pred in zip(offsets, preds):
            keep = torchvision.ops.nms(pred['boxes'], pred['scores'], nms_thr)
            # Move kept results to CPU once instead of syncing per scalar.
            kept_boxes = pred['boxes'][keep].cpu()
            kept_scores = pred['scores'][keep].cpu()

            for score, box in zip(kept_scores, kept_boxes):
                if score < score_thr:
                    continue

                x1 = int(box[0] + off_x)
                y1 = int(box[1] + off_y)
                x2 = int(box[2] + off_x)
                y2 = int(box[3] + off_y)

                if not _passes_box_filters(x2 - x1, y2 - y1, box_width_thr, box_area_thr, box_ratio_thr):
                    continue

                obj_type = _label_detection(obj_map, x1, y1, x2, y2, intersect_thr)
                detections.append(f"{x1} {y1} {x2} {y2} {obj_type}\n")

    with open(result_file_path, 'w') as result_file:
        result_file.writelines(detections)

In [3]:
# Build the SSDlite detector with num_classes=2 and load the fine-tuned round-2 weights.
DEVICE = 'cuda'
detect_model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=2).to(DEVICE)
# weights_only=True restricts unpickling to tensors/containers (plain torch.load can
# execute arbitrary pickled code). map_location loads straight onto the target device.
detect_state_dict = torch.load(model_dir + 'round2_2/round2_model_epoch6.pth',
                               map_location=DEVICE, weights_only=True)
detect_model.load_state_dict(detect_state_dict)

<All keys matched successfully>

In [6]:
# Detection pass over the training ROIs. An ROI whose result file already exists and is
# non-empty is skipped, so re-running this cell resumes an interrupted pass.
train_save_dir = model1_result_dir + 'train/'
with open(img_dir + 'train_rois.txt', 'r') as f:
    # Skip blank lines so they do not become empty ROI paths.
    train_roi_paths = [line.strip() for line in f if line.strip()]

timer = Timer()
for i, roi_path in enumerate(train_roi_paths):
    file_name = train_save_dir + roi_path.split('/')[-1].split('.')[0] + '.txt'

    if isfile(file_name) and os.path.getsize(file_name) > 0:
        print('ROI [', i+1, '/', len(train_roi_paths),'] Already tested, skipping...')
        continue

    print('ROI [', i+1, '/', len(train_roi_paths),'] Processing ', roi_path)

    timer.start()
    # clean_up=False guarantees the call never deletes result files written for earlier
    # ROIs; the skip check above depends on them.
    gather_detections(roi_path, detect_model, save_dir=train_save_dir,
                      nms_thr=0.1, score_thr=0.5, intersect_thr=1, clean_up=False)
    timer.stop()
    print('     Time to process ', roi_path, ': ', timer.elapsed_time())

ROI [ 1 / 334 ] Processing  rois2/images/expert1_id-6381471f7f8a5e686a52765f_left-36214_top-41262_right-39265_bottom-46232.png
     roi width:  3051  roi height:  4970
     Time to process  rois2/images/expert1_id-6381471f7f8a5e686a52765f_left-36214_top-41262_right-39265_bottom-46232.png :  0 hours, 0 minutes, 11.55 seconds
ROI [ 2 / 334 ] Processing  rois2/images/expert1_id-6381471f7f8a5e686a52765f_left-36627_top-47431_right-39464_bottom-52356.png
     roi width:  2837  roi height:  4925
     Time to process  rois2/images/expert1_id-6381471f7f8a5e686a52765f_left-36627_top-47431_right-39464_bottom-52356.png :  0 hours, 0 minutes, 11.56 seconds
ROI [ 3 / 334 ] Processing  rois2/images/expert1_id-6381471f7f8a5e686a527706_left-75035_top-24698_right-80405_bottom-28610.png
     roi width:  5370  roi height:  3912
     Time to process  rois2/images/expert1_id-6381471f7f8a5e686a527706_left-75035_top-24698_right-80405_bottom-28610.png :  0 hours, 0 minutes, 17.39 seconds
ROI [ 4 / 334 ] Proces

In [7]:
# Detection pass over the test ROIs; resumable in the same way as the train pass.
# NOTE(review): nms_thr/score_thr here (0.5/0.1) are the reverse of the train pass
# (0.1/0.5), and intersect_thr is left at its default, unlike the train pass. Also, the
# crop step later reads model1_result_dir + 'val/', not 'test/'. Confirm all of this is
# intended.
test_save_dir = model1_result_dir + 'test/'
os.makedirs(test_save_dir, exist_ok=True)
with open(img_dir + 'test_rois.txt', 'r') as f:
    # Skip blank lines so they do not become empty ROI paths.
    test_roi_paths = [line.strip() for line in f if line.strip()]

timer = Timer()
for i, roi_path in enumerate(test_roi_paths):
    # The '/' separator matters: without it the path never matches the file that
    # gather_detections writes, and the skip check never fires.
    file_name = test_save_dir + roi_path.split('/')[-1].split('.')[0] + '.txt'

    if isfile(file_name) and os.path.getsize(file_name) > 0:
        print('ROI [', i+1, '/', len(test_roi_paths),'] Already tested, skipping...')
        continue

    print('ROI [', i+1, '/', len(test_roi_paths),'] Processing ', roi_path)

    timer.start()
    # clean_up=False guarantees the call never deletes result files written for earlier ROIs.
    gather_detections(roi_path, detect_model, save_dir=test_save_dir,
                      nms_thr=0.5, score_thr=0.1, clean_up=False)
    timer.stop()
    print('     Time to process: ', timer.elapsed_time())

ROI [ 1 / 71 ] Processing  rois2/images/expert4_id-6381475a7f8a5e686a52cd62_left-62410_top-42230_right-67996_bottom-47544.png
     roi width:  5586  roi height:  5314
     Time to process:  0 hours, 0 minutes, 23.38 seconds
ROI [ 2 / 71 ] Processing  rois2/images/expert3_id-638147157f8a5e686a5266b6_left-89851_top-19583_right-95249_bottom-23467.png
     roi width:  5398  roi height:  3884
     Time to process:  0 hours, 0 minutes, 18.62 seconds
ROI [ 3 / 71 ] Processing  rois2/images/novice4_id-640b6f43a7dbca00a13b1d79_left-68935_top-27767_right-74578_bottom-32500.png
     roi width:  5643  roi height:  4733
     Time to process:  0 hours, 0 minutes, 23.28 seconds
ROI [ 4 / 71 ] Processing  rois2/images/expert3_id-638147197f8a5e686a526b96_left-64449_top-32580_right-69312_bottom-38239.png
     roi width:  4863  roi height:  5659
     Time to process:  0 hours, 0 minutes, 22.84 seconds
ROI [ 5 / 71 ] Processing  rois2/images/expert2_id-638147437f8a5e686a52abc1_left-46904_top-41335_right-5

In [2]:
# clean up the directory
for dir, subdirs, files in os.walk(obj_train_dir):
    for file in files:
        os.remove(os.path.join(dir, file))

for dir, subdirs, files in os.walk(obj_val_dir):
    for file in files:
        os.remove(os.path.join(dir, file))            

if os.path.exists(obj_train_dir):       
    for dir in os.listdir(obj_train_dir):
        os.removedirs(obj_train_dir + dir)
        
if os.path.exists(obj_val_dir):
    for dir in os.listdir(obj_val_dir):
        os.removedirs(obj_val_dir + dir)


In [3]:
def crop_detected_objects(read_dir, save_dir, min_side=20, min_area=200, max_area=150000):
    """Crop every detection listed in the result files under `read_dir` and save it as PNG.

    For each `<roi_name>.txt` found (recursively) under `read_dir`, the image
    `img_dir + <roi_name>.png` is loaded. Each listed box is clipped to the image
    bounds, filtered by size, and the crop is written to
    `<save_dir>/<obj_type>/<roi_name>__<x1>_<y1>_<x2>_<y2>.png`.

    Each line is parsed from its LAST five fields (`x1 y1 x2 y2 obj_type`). This
    accepts the 5-field lines that gather_detections writes as well as lines with
    extra leading fields. Blank/short lines are skipped.

    Args:
        read_dir: directory containing per-ROI detection `.txt` files.
        save_dir: output root; one sub-directory per obj_type is created under it.
        min_side: crops with a side shorter than this (after clipping) are skipped.
        min_area, max_area: crops whose area is outside [min_area, max_area] are skipped.
    """
    for dirpath, _, filenames in os.walk(read_dir):
        result_files = [name for name in filenames if name.endswith('.txt')]
        for i, filename in enumerate(result_files):
            roi_name = filename.split('.')[0]
            roi_img = imread(img_dir + roi_name + '.png')
            if roi_img is None:
                print('ROI image not found for', filename, '- skipping')
                continue
            roi_img = roi_img.copy()
            h, w = roi_img.shape[:2]

            with open(os.path.join(dirpath, filename), 'r') as f:
                for line in f:
                    fields = line.split()
                    if len(fields) < 5:
                        continue
                    *_, x1, y1, x2, y2, obj_type = fields
                    x1, y1 = clip_value(int(x1), 0, w), clip_value(int(y1), 0, h)
                    x2, y2 = clip_value(int(x2), 0, w), clip_value(int(y2), 0, h)

                    if x1 >= x2 or y1 >= y2:
                        continue

                    crop_w, crop_h = x2 - x1, y2 - y1
                    if crop_w < min_side or crop_h < min_side or not (min_area <= crop_w * crop_h <= max_area):
                        continue

                    out_dir = os.path.join(save_dir, obj_type)
                    os.makedirs(out_dir, exist_ok=True)
                    out_name = f'{roi_name}__{x1}_{y1}_{x2}_{y2}.png'
                    imwrite(os.path.join(out_dir, out_name), roi_img[y1:y2, x1:x2])

            print('ROI processed [{}/{}]'.format(i + 1, len(result_files)))
                    

In [4]:
# Crop the training-split detections into obj_train_dir/<obj_type>/.
crop_detected_objects(read_dir=model1_result_dir + 'train/', save_dir=obj_train_dir)

ROI processed [1/334]
ROI processed [2/334]
ROI processed [3/334]
ROI processed [4/334]
ROI processed [5/334]
ROI processed [6/334]
ROI processed [7/334]
ROI processed [8/334]
ROI processed [9/334]
ROI processed [10/334]
ROI processed [11/334]
ROI processed [12/334]
ROI processed [13/334]
ROI processed [14/334]
ROI processed [15/334]
ROI processed [16/334]
ROI processed [17/334]
ROI processed [18/334]
ROI processed [19/334]
ROI processed [20/334]
ROI processed [21/334]
ROI processed [22/334]
ROI processed [23/334]
ROI processed [24/334]
ROI processed [25/334]
ROI processed [26/334]
ROI processed [27/334]
ROI processed [28/334]
ROI processed [29/334]
ROI processed [30/334]
ROI processed [31/334]
ROI processed [32/334]
ROI processed [33/334]
ROI processed [34/334]
ROI processed [35/334]
ROI processed [36/334]
ROI processed [37/334]
ROI processed [38/334]
ROI processed [39/334]
ROI processed [40/334]
ROI processed [41/334]
ROI processed [42/334]
ROI processed [43/334]
ROI processed [44/33

In [5]:
# Crop the validation-split detections into obj_val_dir/<obj_type>/.
# NOTE(review): the test-ROI detection pass writes to model1_result_dir + 'test/', not
# 'val/'. Confirm which directory actually holds the held-out results.
crop_detected_objects(read_dir=model1_result_dir + 'val/', save_dir=obj_val_dir)

ROI processed [1/71]
ROI processed [2/71]
ROI processed [3/71]
ROI processed [4/71]
ROI processed [5/71]
ROI processed [6/71]
ROI processed [7/71]
ROI processed [8/71]
ROI processed [9/71]
ROI processed [10/71]
ROI processed [11/71]
ROI processed [12/71]
ROI processed [13/71]
ROI processed [14/71]
ROI processed [15/71]
ROI processed [16/71]
ROI processed [17/71]
ROI processed [18/71]
ROI processed [19/71]
ROI processed [20/71]
ROI processed [21/71]
ROI processed [22/71]
ROI processed [23/71]
ROI processed [24/71]
ROI processed [25/71]
ROI processed [26/71]
ROI processed [27/71]
ROI processed [28/71]
ROI processed [29/71]
ROI processed [30/71]
ROI processed [31/71]
ROI processed [32/71]
ROI processed [33/71]
ROI processed [34/71]
ROI processed [35/71]
ROI processed [36/71]
ROI processed [37/71]
ROI processed [38/71]
ROI processed [39/71]
ROI processed [40/71]
ROI processed [41/71]
ROI processed [42/71]
ROI processed [43/71]
ROI processed [44/71]
ROI processed [45/71]
ROI processed [46/7