In [1]:
!pip install torcheval
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torcheval
  Downloading torcheval-0.0.6-py3-none-any.whl (158 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m158.4/158.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchtnt>=0.0.5 (from torcheval)
  Downloading torchtnt-0.1.0-py3-none-any.whl (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.9/87.9 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
Collecting pyre-extensions (from torchtnt>=0.0.5->torcheval)
  Downloading pyre_extensions-0.0.30-py3-none-any.whl (12 kB)
Collecting typing-inspect (from pyre-extensions->torchtnt>=0.0.5->torcheval)
  Downloading typing_inspect-0.8.0-py3-none-any.whl (8.7 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect->pyre-extensions->torchtnt>=0.0.5->torcheval)
  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)
Installing collected packages: mypy-extensions, typing-

In [2]:
from torcheval.metrics.functional import binary_auprc
from torchmetrics import PrecisionRecallCurve
import matplotlib.pyplot as plt
import numpy as np
import torch

# Copy these functions into a notebook and call compare_metrics(gt, pred) to print the result of each metric; both inputs are 2D numpy arrays.

# Compare AUPRC vs. AUC Judd metrics
# Input: Two numpy arrays representing an image where pixels are a value from 0 to 1
#   gt:   Ground truth
#   pred: Prediction

def compare_metrics(gt, pred):
    """Show ``gt`` and ``pred`` side by side, then print AUC Judd and AUPRC.

    Args:
        gt: Ground-truth map, a 2D numpy array with values in [0, 1].
        pred: Prediction map with the same shape as ``gt``.
    """
    fig, (ax_gt, ax_pred) = plt.subplots(1, 2)
    ax_gt.imshow(gt)
    ax_gt.set_title("Ground truth")
    ax_pred.imshow(pred)
    ax_pred.set_title("Prediction")
    # Render the inputs now, so that any pyplot-state drawing done by the metric
    # functions starts on a fresh figure instead of landing on the prediction axes.
    plt.show()

    print("AUC Judd:", auc_judd(gt, pred))
    print("AUPRC:", auprc(gt, pred))

# AUC Judd implementation; it also visualizes the ROC curve.
# Source: https://github.com/tarunsharma1/saliency_metrics/blob/master/salience_metrics.py

def auc_judd(gt, pred, show_plot=True):
    """Compute the AUC-Judd saliency metric and optionally plot its ROC points.

    Every ground-truth pixel with value >= 0.5 is a fixation. The predicted
    value at each fixation is used as a threshold. For each threshold:
      * TPR = fraction of fixations with pred >= threshold;
      * FPR = fraction of non-fixation pixels with pred >= threshold.
    Both rates are rounded to 4 decimals. The points are combined with the
    (0, 0) and (1, 1) end points and integrated with the trapezoidal rule.

    Args:
        gt: Ground-truth map with values in [0, 1]; binarized at 0.5.
        pred: Prediction map with the same shape as ``gt``.
        show_plot: If True (default), scatter-plot the ROC points on a new figure.

    Returns:
        The area under the ROC curve. This is 0.5 when ``gt`` has no fixations,
        because only the two end points remain. It is NaN when every pixel is
        a fixation, because the false-positive rate is then 0/0.
    """
    gt = np.asarray(gt)
    pred = np.asarray(pred)

    fixation_mask = gt >= 0.5
    # One ROC threshold per fixation: its predicted value (row-major order).
    thresholds = pred[fixation_mask]
    num_fixations = thresholds.size
    num_negatives = fixation_mask.size - num_fixations

    if num_fixations and not num_negatives:
        # No non-fixation pixels: FPR is undefined for every threshold.
        return float("nan")

    points = [(0.0, 0.0)]
    for thresh in thresholds:
        above = pred >= thresh
        num_overlap = np.count_nonzero(above & fixation_mask)
        tp = num_overlap / num_fixations
        fp = (np.count_nonzero(above) - num_overlap) / num_negatives
        points.append((round(fp, 4), round(tp, 4)))
    points.append((1.0, 1.0))

    if show_plot:
        # Own figure: plt.scatter would draw on whatever axes happen to be current.
        fig, ax = plt.subplots()
        ax.scatter(*zip(*points))
        ax.set(title="AUC Judd ROC points",
               xlabel="False positive rate", ylabel="True positive rate")
        plt.show()

    # Sort by FPR and then TPR. This is the order of decreasing threshold.
    # Sorting by FPR alone leaves tied points in pixel order, so the trapezoid
    # rule picks the wrong TPR at the ends of each tie and misreports the area.
    points.sort()
    fp_arr, tp_arr = (np.array(axis) for axis in zip(*points))

    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(tp_arr, fp_arr)

def auprc(gt, pred, show_plot=True):
    """Compute the area under the precision-recall curve of ``pred`` against ``gt``.

    Args:
        gt: Ground-truth map with values in [0, 1]. It is binarized at 0.5
            (the same rule as AUC Judd) into 0/1 integer labels, because the
            torchmetrics binary PR curve validates its target as integer labels.
        pred: Prediction map with the same number of pixels as ``gt``. Its
            values are the per-pixel scores. They are cast to float64 because
            torchmetrics expects floating-point scores; integer inputs such as
            0/1 bbox masks are therefore cast too.
        show_plot: If True (default), plot the precision-recall curve on a new figure.

    Returns:
        The AUPRC as returned by torcheval's ``binary_auprc``.
    """
    labels = torch.tensor((np.asarray(gt) >= 0.5).ravel(), dtype=torch.long)
    scores = torch.tensor(np.asarray(pred, dtype=np.float64).ravel())

    # Both APIs take the scores first and the ground-truth labels second:
    # binary_auprc(input, target) and metric(preds, target).
    area = binary_auprc(scores, labels)

    if show_plot:
        precision, recall, _ = PrecisionRecallCurve(task="binary")(scores, labels)
        fig, ax = plt.subplots()
        ax.plot(recall, precision)
        ax.set(title=f"Precision-recall curve (AUPRC = {float(area):.4f})",
               xlabel="Recall", ylabel="Precision")
        plt.show()

    return area

# Usage (gt_array and pred_array are 2D numpy arrays of the same shape):

# compare_metrics(gt_array, pred_array)


In [None]:
import random
import imageio
import matplotlib.patches as patches
# Helper functions to generate prediction bboxes

# Generate a rectangular bounding-box mask (used to compute intersecting areas)
#   img: Ground truth image
#   ax: Matplotlib plot to add bbox visualization to
def generate_random_bbox(img, ax, bbox=(150, 40, 50, 50)):
    """Return a 0/1 mask, the size of ``img``, containing one rectangular box.

    Args:
        img: Image (anything ``np.array`` accepts). Its first two dimensions
            give the mask's height and width.
        ax: Matplotlib Axes to draw the box outline on, or None to skip drawing.
        bbox: ``(x, y, w, h)`` of the box in pixels. The default is the fixed
            test box this helper previously hard-coded, so existing calls behave
            as before. Pass ``None`` for a random box whose corners lie inside
            the image.

    Returns:
        Integer array of shape ``(height, width)``, with 1 inside the box and 0
        elsewhere. Any part of the box outside the image is clipped.
    """
    height, width = np.array(img).shape[:2]

    if bbox is None:
        # Two random corner coordinates per axis; the box spans between them.
        x1, x2 = int(random.random() * width), int(random.random() * width)
        y1, y2 = int(random.random() * height), int(random.random() * height)
        x, w = min(x1, x2), abs(x1 - x2)
        y, h = min(y1, y2), abs(y1 - y2)
    else:
        x, y, w, h = bbox

    if ax is not None:
        rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

    data = np.zeros((height, width), dtype=int)
    # Clipped slicing keeps boxes near the border from raising IndexError
    # (or wrapping around for negative coordinates).
    data[max(y, 0):max(y + h, 0), max(x, 0):max(x + w, 0)] = 1
    return data

# Alternative to a rectangular bbox: build the mask from an image (e.g. a JPEG drawing of the mask), treating dark pixels as the mask.
def generate_bbox_from_image(img_path, threshold=50):
    """Build a 0/1 mask from an image of a hand-drawn mask.

    A pixel belongs to the mask when its first channel is below ``threshold``.
    For single-channel images, the gray level is used instead.

    Args:
        img_path: Path to the image file, or an already-loaded image array.
        threshold: Intensity cut-off. The default 50 keeps the previous behavior.

    Returns:
        Integer array of shape ``(height, width)``, with 1 for mask pixels and
        0 elsewhere.
    """
    if isinstance(img_path, np.ndarray):
        img = img_path
    else:
        # Recent imageio 2.x releases deprecate top-level imageio.imread in
        # favour of imageio.v2.imread; fall back for older versions without v2.
        reader = getattr(imageio, "v2", imageio)
        img = np.asarray(reader.imread(img_path))
    # Grayscale images are 2D; colour images are (height, width, channels).
    channel = img if img.ndim == 2 else img[:, :, 0]
    return (channel < threshold).astype(int)