In [None]:
import torch
import torchvision
import os
import pickle
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import torch.nn as nn
from PIL import Image
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from utils import imread, im_to_txt_path, isfile
from torchvision.transforms import transforms as T
import cv2
import random



# --- Directory layout -------------------------------------------------------
# Every path is built by string concatenation from `parent_dir`, so each one
# keeps its trailing '/' and can be prefixed directly onto a file name.
parent_dir = 'rois2/'
obj_dir = parent_dir + 'objects/'
img_dir = parent_dir + 'images/'
label_dir = parent_dir + 'labels/'
model_dir = parent_dir + 'models/'

obj_train_dir = obj_dir + 'train/'
obj_test_dir = obj_dir + 'test/'
false_positive_label_dir = parent_dir + 'tiles/false_positives/labels/'

# Output tree for the first-stage (detection) model.
result_dir = parent_dir + 'results/'
result_model1_dir = result_dir + 'model1/'
result_model1_test_dir = result_model1_dir + 'test/'

result_model1_test_labels_dir = result_model1_test_dir + 'labels/'
result_model1_test_images_dir = result_model1_test_dir + 'images/'

# os.makedirs creates missing intermediate directories, so creating the two
# leaf directories also creates results/, model1/ and test/.
os.makedirs(result_model1_test_labels_dir, exist_ok=True)
os.makedirs(result_model1_test_images_dir, exist_ok=True)

# Normalisation statistics; elsewhere in this notebook the dict is indexed by
# ROI name and each value is unpacked as a (mean, std) pair.
# NOTE: pickle.load can execute arbitrary code -- only load a mean_std.pkl that
# was produced by this project.
with open(img_dir + 'mean_std.pkl', 'rb') as mean_std_file:
    mean_std_dict = pickle.load(mean_std_file)
print(mean_std_dict)


In [None]:
# Stage 1: SSDlite320 / MobileNetV3-Large detector built with num_classes=2,
# placed on the GPU.
detection_model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=2)
detection_model = detection_model.to('cuda')

# Stage 2: EfficientNet-B1 classifier whose final linear layer is replaced by
# a 2-output head before the whole network is moved to the GPU.
classification_model = torchvision.models.efficientnet_b1()
classification_model.classifier[1] = nn.Linear(
    classification_model.classifier[1].in_features,
    2,
)
classification_model = classification_model.to('cuda')

In [None]:
def _box_passes_filters(x1, y1, x2, y2, roi_w, roi_h, box_width_thr, box_area_thr, box_ratio_thr):
    """Return True if an absolute-coordinate box survives the range, size and aspect-ratio filters."""
    # Range: the box must lie entirely inside the un-padded ROI.
    if x1 < 0 or y1 < 0 or x2 > roi_w or y2 > roi_h:
        return False

    box_w = x2 - x1
    box_h = y2 - y1
    # Degenerate boxes would otherwise divide by zero in the ratio test below.
    if box_w <= 0 or box_h <= 0:
        return False

    box_area = box_w * box_h
    if box_area < box_area_thr[0] or box_area > box_area_thr[1]:
        return False
    # The same [min, max] side-length bounds apply to both width and height.
    if box_w < box_width_thr[0] or box_w > box_width_thr[1]:
        return False
    if box_h < box_width_thr[0] or box_h > box_width_thr[1]:
        return False
    # Reject boxes that are too elongated in either direction.
    if box_w / box_h < box_ratio_thr or box_h / box_w < box_ratio_thr:
        return False
    return True


def get_model1_result(roi_fp, detect_model, save_dir = result_model1_test_labels_dir, window_size=512, stride=256, 
                      score_thr=0.5, nms_thr=0.5, intersect_thr=0.5,
                      box_width_thr = [20,500], box_area_thr = [200,150000], box_ratio_thr = 0.2, clean_up = True):
    """Run the detector over one ROI image with a sliding window and save the boxes.

    The ROI is zero-padded on the bottom/right, cut into `window_size` x
    `window_size` windows every `stride` pixels, and each column of windows is
    passed to `detect_model` as one batch. Surviving boxes are written to
    `save_dir + <roi name> + '.txt'`, one per line as ``x1 y1 x2 y2`` in
    integer ROI pixel coordinates.

    Args:
        roi_fp: Path to the ROI image. The output name is the part of the file
            name before the first '.'.
        detect_model: Detection model; called with a list of CHW float tensors
            and expected to return one dict per window with 'boxes' and
            'scores'. It is switched to eval mode.
        save_dir: Directory prefix (with trailing '/') for the result file.
        window_size: Side length of the square sliding window, in pixels.
        stride: Step between consecutive windows, in pixels.
        score_thr: Minimum score for a detection to be kept.
        nms_thr: IoU threshold for the per-window NMS.
        intersect_thr: Unused; kept so existing callers keep working.
        box_width_thr: [min, max] allowed side length (applied to width and height).
        box_area_thr: [min, max] allowed box area.
        box_ratio_thr: Minimum allowed width/height and height/width ratio.
        clean_up: If True, delete every '.txt' file under
            `result_model1_dir + 'train'` and `result_model1_dir + 'test'`
            before running.

    Returns:
        0 if the ROI image could not be loaded, otherwise None.

    Raises:
        ValueError: If `window_size` or `stride` is not positive.

    Notes:
        * WARNING: `clean_up` defaults to True and removes ALL earlier result
          files, including those written by previous calls. Pass
          `clean_up=False` when calling this in a loop over several ROIs.
        * NMS is applied per window only, so the same object can be reported
          more than once by overlapping windows.
        * Window pixels are converted with `.float()` and are not rescaled
          here -- TODO confirm this matches the value range the detector was
          trained on.
    """
    if window_size <= 0 or stride <= 0:
        # A non-positive stride would make the sliding-window loops spin forever.
        raise ValueError('window_size and stride must be positive')

    if clean_up:
        for split in ('train', 'test'):
            for dirpath, _dirnames, filenames in os.walk(result_model1_dir + split):
                for filename in filenames:
                    if filename.endswith(".txt"):
                        os.remove(os.path.join(dirpath, filename))

    detect_model.eval()
    roi_name = roi_fp.split('/')[-1].split('.')[0]
    result_file_path = save_dir + roi_name + '.txt'

    # Validate the image BEFORE touching its shape or creating the result file,
    # so a missing ROI neither crashes nor leaves an empty result behind.
    orig_roi = imread(roi_fp)
    if orig_roi is None:
        print('     Error: No roi file found')
        return 0

    roi_h = orig_roi.shape[0]
    roi_w = orig_roi.shape[1]
    # Pad bottom/right with black up to the next multiple of window_size
    # (a full extra window when the size is already a multiple).
    roi = cv2.copyMakeBorder(orig_roi, 0, window_size - roi_h % window_size, 0, window_size - roi_w % window_size,
                             cv2.BORDER_CONSTANT, value=[0, 0, 0])

    # Put the windows on whatever device the model lives on; fall back to the
    # previously hard-coded 'cuda' for a model without parameters.
    first_param = next(detect_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    print('     roi width: ', roi_w, ' roi height: ', roi_h)
    # `with` guarantees the result file is flushed and closed; no_grad avoids
    # building autograd graphs during inference.
    with open(result_file_path, 'w') as result_file, torch.no_grad():
        # Slide the window from top left to bottom right, one column per batch.
        start_x = 0
        while start_x <= roi_w:
            windows = []
            offsets = []

            start_y = 0
            while start_y <= roi_h:
                end_x = int(start_x + window_size)
                end_y = int(start_y + window_size)
                window = roi[start_y:end_y, start_x:end_x]
                # HWC array -> CHW float tensor.
                windows.append(torch.tensor(window).permute(2, 0, 1).float().to(device))
                offsets.append((start_x, start_y))
                start_y += stride

            preds = detect_model(windows)
            for (offset_x, offset_y), pred in zip(offsets, preds):
                keep = torchvision.ops.nms(pred['boxes'], pred['scores'], nms_thr)
                kept_boxes = pred['boxes'][keep]
                kept_scores = pred['scores'][keep]

                # Score threshold, then shift window-local boxes to ROI
                # coordinates and truncate to int in one batched operation
                # instead of one device sync per coordinate.
                kept_boxes = kept_boxes[kept_scores >= score_thr]
                shift = kept_boxes.new_tensor([offset_x, offset_y, offset_x, offset_y])
                abs_boxes = (kept_boxes + shift).to(torch.int64).tolist()

                for x1, y1, x2, y2 in abs_boxes:
                    if _box_passes_filters(x1, y1, x2, y2, roi_w, roi_h,
                                           box_width_thr, box_area_thr, box_ratio_thr):
                        result_file.write(f"{x1} {y1} {x2} {y2}\n")
            start_x += stride

In [None]:
class CustomDataset(Dataset):
    """Dataset of object crops cut from a single ROI image, one crop per box.

    Each item is a pair ``(crop, box)``: the crop is converted to a float32
    image scaled by ``v2.ToDtype(..., scale=True)``, normalised with the given
    mean/std, resized to 260x260 and moved to 'cuda'; ``box`` is the matching
    entry of ``labels_list``, returned unchanged.
    """

    def __init__(self, labels_list, roi_img, mean, std):
        """Store the boxes (indexed as [x1, y1, x2, y2]), the ROI image and the normalisation stats."""
        self.labels_list, self.roi_img = labels_list, roi_img
        self.mean, self.std = mean, std

    def __len__(self):
        """Number of boxes, i.e. number of crops."""
        return len(self.labels_list)

    def __getitem__(self, idx):
        """Return the preprocessed crop for box ``idx`` together with the box itself."""
        box = self.labels_list[idx]

        # Rows are indexed by y (box[1]:box[3]), columns by x (box[0]:box[2]).
        row_span = slice(box[1], box[3])
        col_span = slice(box[0], box[2])
        crop = self.roi_img[row_span, col_span]

        preprocess = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=self.mean, std=self.std),
            v2.Resize((260,260)),
        ])
        crop_tensor = preprocess(crop)

        return crop_tensor.to('cuda'), box

In [None]:
def generate_model1_result_dataloader(pred_labels_fp):
    """Build a DataLoader of object crops from a first-stage result file.

    Args:
        pred_labels_fp: Path to a text file with one box per line, given as
            whitespace-separated integers ``x1 y1 x2 y2``. The ROI name is the
            part of the file name before the first '.'; the image is read from
            ``img_dir + <roi name> + '.png'`` and the normalisation stats from
            ``mean_std_dict[<roi name>]``.

    Returns:
        A DataLoader (batch_size=32, no shuffling) over a CustomDataset that
        holds one int tensor ``[x1, y1, x2, y2]`` per box. The loader is empty
        when the file contains no boxes.

    Raises:
        ValueError: If a non-blank line has fewer than four fields or a
            non-integer coordinate.
        FileNotFoundError: If the ROI image could not be loaded.
        KeyError: If the ROI name is missing from ``mean_std_dict``.
    """
    roi_name = pred_labels_fp.split('/')[-1].split('.')[0]
    roi_path = img_dir + roi_name + '.png'

    labels_list = []
    # `with` closes the label file even if a line fails to parse.
    with open(pred_labels_fp, 'r') as labels_file:
        for line_no, line in enumerate(labels_file, start=1):
            fields = line.split()
            if not fields:
                # Blank / trailing empty lines are not boxes.
                continue
            if len(fields) < 4:
                raise ValueError(
                    f"{pred_labels_fp}:{line_no}: expected 4 box coordinates, got {len(fields)}"
                )
            x1, y1, x2, y2 = (int(tok) for tok in fields[:4])
            labels_list.append(torch.tensor([x1, y1, x2, y2], dtype=torch.int))

    roi_img = imread(roi_path)
    if roi_img is None:
        # Fail here rather than later inside a DataLoader iteration.
        raise FileNotFoundError(f"Could not load ROI image: {roi_path}")
    mean, std = mean_std_dict[roi_name]

    dataset = CustomDataset(labels_list, roi_img, mean, std)
    return DataLoader(dataset, batch_size=32, shuffle=False)
    

In [None]:
def get_model2_result(roi_path, loader, class_model):
    """Classify candidate crops and return the boxes predicted as class 1.

    Args:
        roi_path: Unused; kept so existing callers keep working. The ROI image
            is not needed here because `loader` already yields the crops.
        loader: Sized iterable yielding ``(imgs, boxes)`` batches, where
            ``boxes[k]`` is the box belonging to ``imgs[k]`` and supports
            ``.tolist()``.
        class_model: Classifier returning per-class scores of shape
            (batch, num_classes); it is switched to eval mode.

    Returns:
        A list with one ``boxes[k].tolist()`` entry for every crop whose
        argmax class is 1, in loader order.
    """
    class_model.eval()

    positive_preds = 0
    results = []
    num_batches = len(loader)
    for i, (imgs, boxes) in enumerate(loader):
        if (i + 1) % 10 == 0 or i + 1 == num_batches:
            print('     Processing batch[{}/{}]'.format(i+1, num_batches))

        with torch.no_grad():
            outputs = class_model(imgs)
        preds = torch.argmax(outputs, 1)

        # One transfer of the positive indices per batch instead of one
        # device sync per prediction.
        positive_idx = torch.nonzero(preds == 1).flatten().tolist()
        positive_preds += len(positive_idx)
        results.extend(boxes[k].tolist() for k in positive_idx)

    print('     Number of positive predictions: ', positive_preds)
    return results
                 
    