In [1]:
import torch
from rpn.build_rpn import RPN_Model
from sam.build_sam import SAM_Model
from data_builder.build_dataset import PlanetscopeDataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader

In [2]:
ROOT_PATH = 'C://Users/anind/Dropbox (ASU)/ASU/Kerner-Lab/SAT-SAM(Dataset)/ps_france/all_dataset/'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
rpn_model = RPN_Model('rpn/checkpoint/rpn_model_1.14_Fr_Rw.pth', 2, device)
sam_model = SAM_Model('sam/checkpoint/sam_vit_l_0b3195.pth', 'large', device)

In [4]:
dataset = PlanetscopeDataset(ROOT_PATH, train=False)

In [5]:
def filter_boxes(predictions, ensemble, overlap):
    ensemble = ensemble.astype(np.uint8)
    filtered_boxes = []

    for i in range (len(predictions['boxes'])):
        xmin = predictions['boxes'][i][0]
        ymin = predictions['boxes'][i][1]
        xmax = predictions['boxes'][i][2]
        ymax = predictions['boxes'][i][3]
        
        box_mask = np.zeros((ensemble.shape[0], ensemble.shape[1]))
        box_mask[int(ymin):int(ymax), int(xmin):int(xmax)] = 1
        num_ones_box = np.count_nonzero(box_mask)
        res = ensemble * box_mask
        num_ones_intersection = np.count_nonzero(res)

        _overlap = num_ones_intersection / num_ones_box

        if _overlap > overlap:
            filtered_boxes.append(predictions['boxes'][i])

    return filtered_boxes

def iou(mask1, mask2):
    intersection = np.logical_and(mask1, mask2)
    union = np.logical_or(mask1, mask2)
    iou_score = np.sum(intersection) / np.sum(union)
    return iou_score

def calculate_iou_matrix(target_masks, predicted_masks):
    num_target_masks = len(target_masks)
    num_predicted_masks = len(predicted_masks)

    print("num_target_masks: ", num_target_masks)
    print("num_predicted_masks: ", num_predicted_masks)

    iou_matrix = np.zeros((num_target_masks, num_predicted_masks))
    for i in range(num_target_masks):
        for j in range(num_predicted_masks):
            iou_matrix[i, j] = iou(target_masks[i], predicted_masks[j])

    return iou_matrix

def calculate_iou(target_masks, predicted_masks):
    iou_matrix = calculate_iou_matrix(target_masks, predicted_masks)
    # Use the Hungarian algorithm to find the best assignment
    row_ind, col_ind = linear_sum_assignment(-iou_matrix)

    print("row_ind: ", row_ind)
    print("col_ind: ", col_ind)

    total_iou = 0.0
    for i, j in zip(row_ind, col_ind):
        total_iou += iou_matrix[i, j]

    average_iou = total_iou / len(row_ind)
    return average_iou

def calculate_precision_recall(matched_ious, num_pred_instances, num_gt_instances, threshold=0.5):
    num_true_positives = len(matched_ious[matched_ious >= threshold])
    num_false_positives = num_pred_instances - num_true_positives
    num_false_negatives = num_gt_instances - num_true_positives
    precision = num_true_positives / (num_true_positives + num_false_positives)
    recall = num_true_positives / (num_true_positives + num_false_negatives)
    
    return precision, recall

In [6]:
iou_scores = []

for i, (sam_image, rpn_image, target, ensemble)  in enumerate(dataset): 
    try:
        rpn_image = rpn_image.squeeze(0).to(device)  
        predictions = rpn_model.predict(rpn_image)
        predictions = rpn_model.postprocess(predictions, nms_threshold=0.9, score_threshold=0.6)

        filtered_predictions = filter_boxes(predictions, ensemble, 0.5)

        low_res_masks, iou_predictions = sam_model.predict(sam_image, filtered_predictions)
        high_res_masks = sam_model.postprocess(low_res_masks, tuple(sam_image.size))
        high_res_masks = high_res_masks.squeeze().cpu().numpy()

        iou_score = calculate_iou(target_masks=np.array(target['masks']), predicted_masks=np.array(high_res_masks))
        print("Image Id: ", i, " Average IoU Score: ", iou_score)

        iou_scores.append(iou_score)
    except:
        print("Error in image: ", i)
        continue

Error in image:  0
Error in image:  1
num_target_masks:  9
num_predicted_masks:  10
row_ind:  [0 1 2 3 4 5 6 7 8]
col_ind:  [5 0 7 1 2 3 8 6 4]
Image Id:  2  Average IoU Score:  0.025668721933851044
num_target_masks:  1
num_predicted_masks:  2
row_ind:  [0]
col_ind:  [0]
Image Id:  3  Average IoU Score:  0.0
num_target_masks:  14
num_predicted_masks:  4
row_ind:  [ 0  2  8 10]
col_ind:  [0 2 1 3]
Image Id:  4  Average IoU Score:  0.04135203244053413
num_target_masks:  1
num_predicted_masks:  8
row_ind:  [0]
col_ind:  [0]
Image Id:  5  Average IoU Score:  0.0028370543974343162
num_target_masks:  33
num_predicted_masks:  19
row_ind:  [ 0  1  2  3  4  5  6  7  8  9 10 11 14 15 17 20 24 26 29]
col_ind:  [ 1  3  6 11  8 12 13 14 16 18 15 10  7  5  4  0 17  2  9]
Image Id:  6  Average IoU Score:  0.3771524458426688
num_target_masks:  9
num_predicted_masks:  6
row_ind:  [0 1 2 3 7 8]
col_ind:  [5 2 4 0 3 1]
Image Id:  7  Average IoU Score:  0.2800423076456437
num_target_masks:  10
num_predict

In [7]:
np.mean(iou_scores)

0.27028505351723836