In [49]:
import numpy as np
from scipy.ndimage import distance_transform_edt
import torch

def calculate_score(mask_gt, mask_pred, r):
    # Calculate the dimensions of the masks
    height, width = mask_gt.shape

    # Initialize the score
    total_score = 0

    # Create a distance field for the ground truth
    mask_distance = distance_transform_edt(mask_gt == 0)
    print(mask_distance)

    # Iterate over each pixel in the ground truth mask
    for y in range(height):
        for x in range(width):
            gt_class = mask_gt[y, x]
            pred_class = mask_pred[y, x]

            # Check if the prediction is in the exact same location
            if pred_class == gt_class:
                total_score += 100  # Exact match: same class and location
            else:
                # Check within a distance r
                if mask_distance[y, x] <= r:
                    if abs(pred_class - gt_class) <= 1:
                        total_score += 25  # Different class but within tolerance and distance r
                    if pred_class == gt_class:
                        total_score += 50  # Same class but within distance r

                # Check for one level class mismatch at the exact same location
                if abs(pred_class - gt_class) <= 1:
                    total_score += 50

    # Normalize the score by the maximum possible score
    max_score = height * width * 100
    normalized_score = total_score / max_score

    return normalized_score

def calculate_new_score(mask_gt, mask_pred, r):
    height, width = mask_gt.shape
    total_score = 0

    # Initialize score and distance arrays
    scores = np.zeros((height, width))

    # Iterate over each unique class in the ground truth
    for class_value in np.unique(mask_gt):
        # Create a mask for the current class
        class_mask = (mask_gt == class_value)
        print("class_mask")
        print(class_mask)
        
        # Compute the distance transform for the current class
        class_distances = distance_transform_edt(~class_mask)
        print("class_distances")
        print(class_distances)

        # Evaluate predictions within radius r for this class
        near_mask = (class_distances <= r)
        print("near_mask")
        print(near_mask)
        exact_mask = (mask_pred == class_value)
        print("exact_mask")
        print(exact_mask)

        # Scoring: exact location and same class
        scores += 100 * (class_mask & exact_mask)
        
        # Scoring: within r distance and same class
        scores += 50 * near_mask & exact_mask
        
        # Scoring: exact location but class mismatch within one level tolerance
        tolerance_mask = (np.abs(mask_pred - class_value) == 1)
        scores += 50 * (class_mask & tolerance_mask)

        # Scoring: within r distance and class mismatch within one level tolerance
        scores += 25 * (near_mask & tolerance_mask)

    # Normalize the total score by the maximum possible score
    max_score = height * width * 100
    normalized_score = np.sum(scores) / max_score

    return normalized_score

def mIOU(label, pred, num_classes=9):
#     pred = F.softmax(pred, dim=1)              
#     pred = torch.argmax(pred, dim=1).squeeze(1)
    iou_list = list()
    present_iou_list = list()

#     pred = pred.view(-1)
#     label = label.view(-1)
    # Note: Following for loop goes from 0 to (num_classes-1)
    # and ignore_index is num_classes, thus ignore_index is
    # not considered in computation of IoU.
    for sem_class in range(num_classes):
        pred_inds = (pred == sem_class)
        target_inds = (label == sem_class)
        if target_inds.long().sum().item() == 0:
            iou_now = float('nan')
        else: 
            intersection_now = (pred_inds[target_inds]).long().sum().item()
            union_now = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection_now
            iou_now = float(intersection_now) / float(union_now)
            present_iou_list.append(iou_now)
            # print(iou_now)
        iou_list.append(iou_now)
    return iou_list, present_iou_list, np.mean(present_iou_list)

# # Example usage
# mask_gt = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
# mask_pred = np.array([[2, 2, 3], [1, 2, 3], [1, 2, 3]])
# r = 1

# print(f"Normalized Score: {calculate_score(mask_gt, mask_pred, r):.2f}")

In [50]:
gt =np.array(([0,0,1,0,0],

              [0,0,0,1,0],

              [0,0,0,0,1],

              [0,0,0,1,0],

              [0,0,1,0,0]))

b = np.array(([0,1,0,0,0],

              [0,0,1,0,0],

              [0,0,0,1,0],

              [0,0,1,0,0],

              [0,1,0,0,0]))

c = np.array(([0,0,2,0,0],

              [0,0,0,2,0],

              [0,0,0,0,2],

              [0,0,0,2,0],

              [0,0,2,0,0]))

d = np.array(([0,0,3,0,0],

              [0,0,0,3,0],

              [0,0,0,0,3],

              [0,0,0,3,0],

              [0,0,3,0,0]))

r = 1

pred = d

iou_list, present_iou_list, iou_score = mIOU(torch.tensor(gt), torch.tensor(pred))

# print(f"Normalized Score: {calculate_score(gt, pred, r):.2f}")
print(f"Normalized New Score: {calculate_new_score(gt, pred, r):.2f}")
# print(f"mIoU Score: {iou_score}")

class_mask
[[ True  True False  True  True]
 [ True  True  True False  True]
 [ True  True  True  True False]
 [ True  True  True False  True]
 [ True  True False  True  True]]
class_distances
[[0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 1. 0.]
 [0. 0. 1. 0. 0.]]
near_mask
[[ True  True  True  True  True]
 [ True  True  True  True  True]
 [ True  True  True  True  True]
 [ True  True  True  True  True]
 [ True  True  True  True  True]]
exact_mask
[[ True  True False  True  True]
 [ True  True  True False  True]
 [ True  True  True  True False]
 [ True  True  True False  True]
 [ True  True False  True  True]]
class_mask
[[False False  True False False]
 [False False False  True False]
 [False False False False  True]
 [False False False  True False]
 [False False  True False False]]
class_distances
[[2.         1.         0.         1.         1.41421356]
 [2.23606798 1.41421356 1.         0.         1.        ]
 [2.82842712 2.23606798 1.41421356 1.         0.      