load the models and weights

In [1]:
import torch
import torchvision
import os
import pickle
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F 
from torchvision import datasets
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import torch.nn as nn
from PIL import Image
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from utils import imread, im_to_txt_path, isfile
from torchvision.transforms import transforms as T
import cv2
import random



# Directory layout. Every path keeps its trailing '/' because the rest of the
# notebook builds file names by plain string concatenation.
parent_dir = 'rois2/'
obj_dir = parent_dir + 'objects/'
img_dir = parent_dir + 'images/'
label_dir = parent_dir + 'labels/'
model_dir = parent_dir + 'models/'

obj_train_dir = obj_dir + 'train/'
obj_test_dir = obj_dir + 'test/'
false_positive_label_dir = parent_dir + 'tiles/false_positives/labels/'

result_dir = parent_dir + 'results/'
result_model1_dir = result_dir + 'model1/'
result_model1_test_dir = result_model1_dir + 'test/'

result_model1_test_labels_dir = result_model1_test_dir + 'labels/'
result_model1_test_images_dir = result_model1_test_dir + 'images/'

# Make sure every output directory exists before anything is written to it.
for _result_subdir in (
    result_dir,
    result_model1_dir,
    result_model1_test_dir,
    result_model1_test_labels_dir,
    result_model1_test_images_dir,
):
    os.makedirs(_result_subdir, exist_ok=True)

# Per-ROI normalisation statistics: ROI name -> (mean, std).
# NOTE(review): pickle executes arbitrary code while loading -- only load
# mean_std.pkl files that were produced by this project.
with open(img_dir + 'mean_std.pkl', 'rb') as _mean_std_file:
    mean_std_dict = pickle.load(_mean_std_file)
print(mean_std_dict)


{'expert1_id-6381471f7f8a5e686a52765f_left-36214_top-41262_right-39265_bottom-46232': (array([0.88391932, 0.86455804, 0.84128348]), array([0.08340014, 0.09904173, 0.12370593])), 'expert1_id-6381471f7f8a5e686a52765f_left-36627_top-47431_right-39464_bottom-52356': (array([0.8841721 , 0.86190506, 0.83281352]), array([0.08451419, 0.10180634, 0.1258441 ])), 'expert1_id-6381471f7f8a5e686a527706_left-75035_top-24698_right-80405_bottom-28610': (array([0.86614397, 0.84570188, 0.82592797]), array([0.10124425, 0.11657547, 0.1331032 ])), 'expert1_id-6381471f7f8a5e686a5277c2_left-68060_top-39334_right-73669_bottom-44184': (array([0.89545478, 0.87718601, 0.86096679]), array([0.06975292, 0.08177992, 0.09553303])), 'expert1_id-6381471f7f8a5e686a52781f_left-70416_top-29678_right-75686_bottom-33272': (array([0.81283723, 0.75033681, 0.66633584]), array([0.12331021, 0.1447023 , 0.17060996])), 'expert1_id-638147207f8a5e686a527923_left-84307_top-28200_right-89535_bottom-31674': (array([0.71946232, 0.6813055

In [None]:
# Number of entries (one per ROI name) in the normalisation-statistics table.
len(mean_std_dict)

In [3]:
def _read_true_boxes(roi_fp):
    """Return the ground-truth boxes ``[x1, y1, x2, y2]`` stored next to ``roi_fp``.

    Each non-empty line of the label file must hold at least five
    whitespace-separated integers; the first one is ignored (presumably a
    class id -- TODO confirm) and the next four are the box corners.
    A missing label file yields an empty list.
    """
    label_path = im_to_txt_path(roi_fp)
    boxes = []
    if isfile(label_path):
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                values = [int(x) for x in fields]
                boxes.append([values[1], values[2], values[3], values[4]])
    return boxes


def _passes_size_filters(width, height, box_width_thr, box_area_thr, box_ratio_thr):
    """Return True when a ``width`` x ``height`` box satisfies all geometric limits."""
    area = width * height
    if area < box_area_thr[0] or area > box_area_thr[1]:
        return False
    if width < box_width_thr[0] or width > box_width_thr[1]:
        return False
    if height < box_width_thr[0] or height > box_width_thr[1]:
        return False
    # Same test as ``width / height < thr or height / width < thr`` but written
    # multiplicatively so a degenerate (zero-sized) box cannot divide by zero.
    if width < box_ratio_thr * height or height < box_ratio_thr * width:
        return False
    return True


def get_model1_result(roi_fp, detect_model, save_dir = result_model1_test_labels_dir, window_size=512, stride=256, 
                      score_thr=0.5, nms_thr=0.5, softmax_thr = 0.5,
                      box_width_thr = (20,500), box_area_thr = (200,150000), box_ratio_thr = 0.2, 
                      clean_up = True, output_false_neg_num = False):
    """Run the detector over one ROI with a sliding window and save the kept boxes.

    The ROI is zero-padded on the right/bottom, cut into ``window_size`` tiles
    every ``stride`` pixels, and each column of tiles is sent to
    ``detect_model`` as one batch. Predictions are filtered by NMS (per tile),
    score, ROI bounds and box geometry, then written to
    ``save_dir + <roi name> + '.txt'`` as ``x1 y1 x2 y2`` lines in ROI
    coordinates.

    Args:
        roi_fp: path of the ROI image; the ROI name is the file name up to its
            first '.'.
        detect_model: detection model called with a list of CxHxW float
            tensors and returning dicts with ``'boxes'`` and ``'scores'``.
        save_dir: output directory (with trailing '/').
        window_size: tile edge length in pixels.
        stride: step between consecutive tiles in pixels.
        score_thr: predictions scoring below this are dropped.
        nms_thr: IoU threshold for per-tile NMS; also reused as the overlap
            threshold when matching predictions to ground truth.
        softmax_thr: unused; kept for interface compatibility.
        box_width_thr: (min, max) allowed box width and height.
        box_area_thr: (min, max) allowed box area.
        box_ratio_thr: minimum allowed short-side / long-side ratio.
        clean_up: when True, delete every ``.txt`` file under
            ``result_model1_dir + 'train'`` and ``result_model1_dir + 'test'``
            before running.
        output_false_neg_num: when True, compare the kept boxes with the
            ground-truth labels and return the number of unmatched objects.

    Returns:
        The number of ground-truth objects no kept prediction matched when
        ``output_false_neg_num`` is True, otherwise ``-1``. Returns ``0`` if the
        ROI image could not be read.
    """
    if clean_up:
        for split in ('train', 'test'):
            for dirpath, _dirnames, filenames in os.walk(result_model1_dir + split):
                for filename in filenames:
                    if filename.endswith(".txt"):
                        os.remove(os.path.join(dirpath, filename))

    detect_model.eval()
    roi_name = roi_fp.split('/')[-1].split('.')[0]
    result_file_path = save_dir + roi_name + '.txt'

    orig_roi = imread(roi_fp)
    # Check before touching .shape: the original tested the padded image, i.e.
    # only after a missing ROI would already have raised.
    # NOTE(review): assumes imread signals failure by returning None -- confirm.
    if orig_roi is None:
        print('     Error: No roi file found')
        # Still leave an (empty) result file so downstream readers find one.
        with open(result_file_path, 'w'):
            pass
        return 0

    roi_w = orig_roi.shape[1]
    roi_h = orig_roi.shape[0]
    # Pad right/bottom up to the next multiple of window_size (a full extra
    # window when the size is already a multiple).
    roi = cv2.copyMakeBorder(orig_roi, 0, window_size - roi_h % window_size, 0, window_size - roi_w % window_size, cv2.BORDER_CONSTANT, value=[0, 0, 0])

    # Ground truth is only needed for the false-negative count.
    true_obj_found = []
    obj_map = None
    obj_index_map = None
    if output_false_neg_num:
        true_boxes = _read_true_boxes(roi_fp)
        true_obj_found = [0] * len(true_boxes)
        # obj_map holds, per pixel, the area of the ground-truth box covering
        # it; obj_index_map holds that box's 1-based index (0 = background).
        # Later boxes overwrite earlier ones where they overlap.
        obj_map = torch.zeros((roi_h, roi_w))
        obj_index_map = torch.zeros((roi_h, roi_w))
        for i, box in enumerate(true_boxes):
            obj_map[box[1]:box[3], box[0]:box[2]] = (box[3] - box[1]) * (box[2] - box[0])
            obj_index_map[box[1]:box[3], box[0]:box[2]] = i + 1

    print('     roi width: ', roi_w, ' roi height: ', roi_h)

    # Run on whatever device the model lives on (falls back to the previous
    # hard-coded 'cuda' for a parameter-less model).
    first_param = next(detect_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # no_grad: pure inference, so no autograd graph should be kept per batch.
    with open(result_file_path, 'w') as result_file, torch.no_grad():
        # slide the window from top left to bottom right, one column per batch
        for start_x in range(0, roi_w + 1, stride):
            offsets = [(start_x, start_y) for start_y in range(0, roi_h + 1, stride)]
            windows = [
                torch.tensor(roi[y:y + window_size, x:x + window_size]).permute(2, 0, 1).float().to(device)
                for x, y in offsets
            ]

            preds = detect_model(windows)

            for (off_x, off_y), pred in zip(offsets, preds):
                filtered_idx = torchvision.ops.nms(pred['boxes'], pred['scores'], nms_thr)
                pred_boxes = pred['boxes'][filtered_idx]
                pred_scores = pred['scores'][filtered_idx]

                # score thr (same predicate as skipping `score < score_thr`)
                pred_boxes = pred_boxes[~(pred_scores < score_thr)]
                if pred_boxes.numel() == 0:
                    continue

                # Shift to ROI coordinates and truncate toward zero like int();
                # done for the whole tile at once so there is a single
                # device-to-host transfer instead of one per coordinate.
                shift = torch.tensor([off_x, off_y, off_x, off_y], device=pred_boxes.device)
                shifted_boxes = (pred_boxes + shift).long().tolist()

                for pred_box_x1, pred_box_y1, pred_box_x2, pred_box_y2 in shifted_boxes:
                    # range thr: drop boxes reaching into the padding
                    if pred_box_x1 < 0 or pred_box_y1 < 0 or pred_box_x2 > roi_w or pred_box_y2 > roi_h:
                        continue

                    # size thr
                    if not _passes_size_filters(pred_box_x2 - pred_box_x1, pred_box_y2 - pred_box_y1,
                                                box_width_thr, box_area_thr, box_ratio_thr):
                        continue

                    result_file.write(f"{pred_box_x1} {pred_box_y1} {pred_box_x2} {pred_box_y2}\n")

                    if not output_false_neg_num:
                        continue

                    # intersect thr: approximate IoU against the ground truth
                    # under the prediction, using the pixel maps built above.
                    # NOTE(review): the matched object is the highest-index box
                    # in the region while obj_area is the largest area there;
                    # the two can refer to different objects when several
                    # ground-truth boxes overlap the prediction.
                    region = obj_map[pred_box_y1:pred_box_y2, pred_box_x1:pred_box_x2]
                    intersect_area = (region > 0).sum().item()
                    if intersect_area > 0:
                        pred_area = (pred_box_x2 - pred_box_x1) * (pred_box_y2 - pred_box_y1)
                        obj_area = torch.max(region).item()
                        if intersect_area >= (obj_area + pred_area - intersect_area) * nms_thr:
                            output_index = int(torch.max(obj_index_map[pred_box_y1:pred_box_y2, pred_box_x1:pred_box_x2]).item()) - 1
                            true_obj_found[output_index] = 1

    if output_false_neg_num:
        return true_obj_found.count(0)
    return -1

In [4]:
class CustomDataset(Dataset):
    """Dataset of normalised 260x260 crops cut from one ROI image.

    Each item is the crop ``roi_img[y1:y2, x1:x2]`` for one box
    ``(x1, y1, x2, y2)`` in ``labels_list``, paired with that box.
    """

    def __init__(self, labels_list, roi_img, mean, std, device='cuda'):
        """Store the boxes and image and build the preprocessing pipeline.

        Args:
            labels_list: sequence of boxes ``(x1, y1, x2, y2)``.
            roi_img: image array indexed as ``[row, column]``.
            mean: per-channel mean handed to ``v2.Normalize``.
            std: per-channel standard deviation handed to ``v2.Normalize``.
            device: device the returned crop is moved to (default keeps the
                previous hard-coded ``'cuda'``).
        """
        self.labels_list = labels_list
        self.roi_img = roi_img
        self.mean = mean
        self.std = std
        self.device = device
        # Built once here instead of on every __getitem__ call.
        self.transf = v2.Compose([
            v2.ToImage(), 
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=self.mean, std=self.std),
            v2.Resize((260,260)),
        ])

    def __len__(self):
        """Return the number of boxes."""
        return len(self.labels_list)

    def __getitem__(self, idx):
        """Return ``(crop_tensor_on_device, box)`` for the box at ``idx``."""
        box = self.labels_list[idx]

        obj_img = self.roi_img[box[1]:box[3],box[0]:box[2]]
        output = self.transf(obj_img)

        return output.to(self.device), box

In [5]:
def generate_model1_result_dataloader(pred_labels_fp):
    """Build a DataLoader over the boxes stored in a stage-1 result file.

    Args:
        pred_labels_fp: path of a result file holding one ``x1 y1 x2 y2`` box
            per line. The ROI name is the file name up to its first '.'; the
            image is read from ``img_dir + <roi name> + '.png'`` and its
            normalisation statistics from ``mean_std_dict[<roi name>]``.

    Returns:
        A non-shuffled ``DataLoader`` (batch size 32) over a ``CustomDataset``
        of the listed boxes. An empty result file gives an empty loader.
    """
    roi_name = pred_labels_fp.split('/')[-1].split('.')[0]
    roi_path = img_dir + roi_name + '.png'

    labels_list = []
    with open(pred_labels_fp, 'r') as labels_file:
        for line in labels_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            x1, y1, x2, y2 = (int(value) for value in fields[:4])
            labels_list.append(torch.tensor([x1, y1, x2, y2], dtype=torch.int))

    roi_img = imread(roi_path)
    mean, std = mean_std_dict[roi_name]

    dataset = CustomDataset(labels_list, roi_img, mean, std)
    return DataLoader(dataset, batch_size=32, shuffle=False)
    

In [6]:
def check_pred_boxes(pred_boxes, true_boxes, img_tensor, iou_thr):
    """Flag each predicted box that overlaps the ground truth well enough.

    A pixel map the size of the image is painted with the area of every
    ground-truth box (later boxes overwrite earlier ones). For a prediction,
    the overlap is the number of painted pixels under it and the object area
    is the largest painted value under it; the prediction counts as correct
    when ``overlap >= (object_area + prediction_area - overlap) * iou_thr``.

    Args:
        pred_boxes: iterable of ``(x1, y1, x2, y2)`` predictions.
        true_boxes: iterable of ``(x1, y1, x2, y2)`` ground-truth boxes.
        img_tensor: only its shape is used; ``shape[1]`` is taken as the
            height and ``shape[2]`` as the width (channel-first layout).
        iou_thr: overlap ratio required for a match.

    Returns:
        A list with one entry per prediction: 1 if matched, else 0.
    """
    height, width = img_tensor.shape[1], img_tensor.shape[2]
    area_map = torch.zeros((height, width))
    for truth in true_boxes:
        left, top, right, bottom = (int(v) for v in truth)
        area_map[top:bottom, left:right] = (bottom - top) * (right - left)

    verdicts = []
    for candidate in pred_boxes:
        left, top, right, bottom = (int(v) for v in candidate)
        patch = area_map[top:bottom, left:right]
        overlap = (patch > 0).sum().item()
        matched = 0
        if overlap > 0:
            truth_area = torch.max(patch).item()
            candidate_area = (bottom - top) * (right - left)
            if overlap >= (truth_area + candidate_area - overlap) * iou_thr:
                matched = 1
        verdicts.append(matched)
    return verdicts

In [7]:
def get_model2_result(roi_path, roi_model1_result_loader, class_model, output_softmax_score = False):
    """Classify the stage-1 candidate boxes of one ROI and keep the positives.

    Args:
        roi_path: path of the ROI image; only read when
            ``output_softmax_score`` is True (for its size and its label file).
        roi_model1_result_loader: iterable of ``(imgs, boxes)`` batches.
        class_model: classifier returning one row of class logits per image;
            class index 1 is treated as positive.
        output_softmax_score: when True, also score the classifier against the
            ground truth: each candidate's target is 1 if ``check_pred_boxes``
            matches it to a labelled object (threshold 0.5), else 0.

    Returns:
        ``results`` -- the boxes (lists of ints) predicted as class 1 -- when
        ``output_softmax_score`` is False; otherwise ``(results, loss)`` where
        ``loss`` is the mean cross-entropy over all candidates, or ``0.0`` when
        the loader produced no candidates.
    """
    true_objects = []
    img_tensor = None
    if output_softmax_score:
        # imread yields height x width x channels (see how the detector stage
        # indexes it), but check_pred_boxes reads shape[1]/shape[2] as
        # height/width, so hand it a channel-first view.
        img_tensor = torch.tensor(imread(roi_path)).permute(2, 0, 1)
        roi_label_path = im_to_txt_path(roi_path)
        if os.path.exists(roi_label_path):
            with open(roi_label_path, 'r') as roi_label_file:
                for line in roi_label_file:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    # first field is skipped (presumably a class id -- TODO confirm)
                    true_objects.append([int(x) for x in fields][1:])

    # Run on the classifier's own device instead of assuming 'cuda'.
    first_param = next(class_model.parameters(), None)
    device = first_param.device if first_param is not None else None

    class_model.eval()
    results = []
    logit_batches = []
    target_batches = []

    with torch.no_grad():
        for imgs, boxes in roi_model1_result_loader:
            if device is not None:
                imgs = imgs.to(device)
            outputs = class_model(imgs)
            preds = torch.argmax(outputs, 1)

            if output_softmax_score:
                logit_batches.append(outputs)
                target_batches.append(torch.tensor(
                    check_pred_boxes(boxes, true_objects, img_tensor, 0.5), dtype=torch.long))

            # One host transfer per batch, then keep the boxes predicted positive.
            for box, pred in zip(boxes, preds.tolist()):
                if pred == 1:
                    results.append(box.tolist())

    # The loss is only defined (and only needed) when scoring was requested;
    # computing it unconditionally used to fail on the empty lists.
    if not output_softmax_score:
        return results
    if not logit_batches:
        return results, 0.0

    logits = torch.cat(logit_batches)
    targets = torch.cat(target_batches).to(logits.device)
    loss = nn.CrossEntropyLoss()(logits, targets)
    return results, loss.item()
                 
    

In [8]:
def plot_results(roi_path, model2_results):
    """Show the ROI twice: with its ground-truth boxes, then with the predictions.

    Args:
        roi_path: path of the ROI image; its label file is located with
            ``im_to_txt_path`` and may be absent.
        model2_results: iterable of predicted ``(x1, y1, x2, y2)`` boxes.
    """
    def _draw_and_show(canvas, boxes):
        # Outline every box on the canvas, then display it in a new figure.
        for left, top, right, bottom in boxes:
            cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 0), 2)
        plt.figure(dpi=300)
        plt.imshow(canvas)
        plt.show()

    truth_canvas = imread(roi_path)

    label_path = im_to_txt_path(roi_path)
    annotated = []
    if isfile(label_path):
        with open(label_path, 'r') as handle:
            for row in handle:
                fields = [int(token) for token in row.strip().split(' ')]
                # fields[0] is skipped; fields 1-4 are the box corners
                annotated.append([fields[1], fields[2], fields[3], fields[4]])

    print('Number of true objects:', len(annotated))
    _draw_and_show(truth_canvas, annotated)

    print('Number of predicted objects:', len(model2_results))
    # Reload the image so the ground-truth outlines do not show up here.
    _draw_and_show(imread(roi_path), model2_results)

In [9]:
# Paths of the test ROIs, one per line of test_rois.txt (blank lines ignored).
with open(img_dir + 'test_rois.txt', 'r') as rois_list_file:
    test_rois = [line.strip() for line in rois_list_file if line.strip()]

# randomly pick a roi to test
input_roi_path = random.choice(test_rois)

In [10]:
# Stage 1: SSDlite detector with two output classes, placed on the GPU.
detection_model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=2)
detection_model.to('cuda')
# detection_model.load_state_dict(torch.load('rois2/models/round2/2024_09_17 penalty-1/round2_model_epoch3.pth'))

# Stage 2: EfficientNet-B1 whose final linear layer is replaced by a
# two-class head; the whole network is then moved to the GPU.
classification_model = torchvision.models.efficientnet_b1()
classification_model.classifier[1] = nn.Linear(classification_model.classifier[1].in_features, 2)
classification_model.to('cuda')
# classification_model.load_state_dict(torch.load('rois2/results/model2/round3_model_epoch1.pth'))

In [None]:
# Sweep: for every NMS threshold, run the (fixed) detector once over the test
# ROIs, then evaluate each classifier checkpoint on the resulting candidates.
nms_list = [0.1, 0.2, 0.3, 0.4, 0.5]
epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

detection_model.load_state_dict(torch.load('rois2/models/round2/2024_09_17 penalty-1/round2_model_epoch10.pth'))
for nms in nms_list:
    fn_list = [None] * len(epoch_list)
    class_loss_list = [None] * len(epoch_list)

    # The detector does not change with the classifier epoch, so its result
    # files and false-negative count are produced once per NMS threshold.
    # clean_up=False keeps the files written for earlier ROIs in this pass;
    # each ROI's file is overwritten anyway because it is opened in 'w' mode.
    fn_score = 0
    for input_roi_path in test_rois:
        fn_score += get_model1_result(input_roi_path, detection_model, result_model1_test_labels_dir,
                                      nms_thr=nms, score_thr=0, clean_up=False, output_false_neg_num=True)
    fn_score /= len(test_rois)

    for epoch_idx, epoch in enumerate(epoch_list):
        classification_model.load_state_dict(torch.load('rois2/results/model2/round3_model_epoch' + str(epoch) + '.pth'))
        loss_score = 0
        for input_roi_path in test_rois:
            roi_name = input_roi_path.split('/')[-1].split('.')[0]

            loader = generate_model1_result_dataloader(result_model1_test_labels_dir + roi_name + '.txt')
            model2_results, class_loss = get_model2_result(input_roi_path, loader, classification_model, output_softmax_score=True)
            loss_score += class_loss

        loss_score /= len(test_rois)

        # Index by position so epoch_list need not be exactly 1..N.
        fn_list[epoch_idx] = fn_score
        class_loss_list[epoch_idx] = loss_score

        # Progress plot, redrawn after every epoch; epochs not evaluated yet
        # are still None and are simply not drawn.
        plt.figure()
        plt.title('NMS threshold: ' + str(nms))
        plt.plot(epoch_list, fn_list, color='red', label='False Negative Number')
        plt.plot(epoch_list, class_loss_list, color='green', label='Classification Loss')
        plt.xlabel('Epoch')
        plt.xticks(epoch_list)
        plt.legend()  # the labels above are only shown once a legend is requested
        plt.show()
        plt.close()
            
            




In [13]:
# get_model1_result(input_roi_path, detection_model, nms_thr=1.0, score_thr=0, output_false_neg_num=True)
# loader = generate_model1_result_dataloader(result_model1_test_labels_dir + input_roi_path.split('/')[-1].split('.')[0] + '.txt')
# results, loss = get_model2_result(input_roi_path, loader, classification_model, output_softmax_score=True)
# print('Loss:', loss)
# plot_results(input_roi_path, results)