In [1]:
import os
import torch
import torch.nn.functional as F
import numpy as np

from PIL import Image
from tqdm import tqdm

In [2]:
class mIoUMask:
    """Accumulate per-class TP/FP/FN counts over segmentation masks and report IoU metrics.

    Args:
        num_classes: number of label classes to score.
        ignore_class: optional label whose ground-truth pixels are excluded from scoring.
            When given, one extra class slot is allocated (``num_class = num_classes + 1``),
            and that label is left out of the mIoU mean if it lies in ``[0, num_class)``.
        eps: small constant added to IoU denominators to avoid division by zero.
    """

    def __init__(self, num_classes=3, ignore_class=None, eps=1e-7):
        self.eps = eps
        self.num_class = num_classes + (1 if ignore_class is not None else 0)
        self.ignore_class = ignore_class
        # One row per class; columns are [tp, fp, fn].
        self.confusion_matrix = np.zeros((self.num_class, 3))
        # Number of scored ground-truth pixels per class (weights for FWIoU).
        self.cnt = np.zeros(self.num_class)

    def _generate_matrix(self, pre_image, gt_image):
        """Return ([tp, fp, fn] per class, ground-truth pixel count per class) for one pair.

        Only pixels whose ground truth lies in ``[0, num_class)`` and is not ``ignore_class``
        are scored; predictions outside that range simply match no class.
        """
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        if self.ignore_class is not None:
            mask = mask & (gt_image != self.ignore_class)
        pre_image = pre_image[mask].astype('int')
        gt_image = gt_image[mask].astype('int')

        confusion_matrix = np.zeros((self.num_class, 3))
        cnt = np.zeros(self.num_class)
        for i in range(self.num_class):
            gt = gt_image == i
            pred = pre_image == i
            confusion_matrix[i, 0] = np.sum(gt & pred)   # tp
            confusion_matrix[i, 1] = np.sum(~gt & pred)  # fp
            confusion_matrix[i, 2] = np.sum(gt & ~pred)  # fn
            cnt[i] = np.sum(gt)
        return confusion_matrix, cnt

    def reset(self):
        """Clear all accumulated counts."""
        # Must match the (num_class, 3) layout built in __init__; cnt is cleared too.
        self.confusion_matrix = np.zeros((self.num_class, 3))
        self.cnt = np.zeros(self.num_class)

    def add_batch(self, gt_image, pre_image):
        """Accumulate counts for one ground-truth / prediction pair of identical shape."""
        gt_image = np.asarray(gt_image)
        pre_image = np.asarray(pre_image)
        assert gt_image.shape == pre_image.shape, (
            f"shape mismatch: gt {gt_image.shape} vs pred {pre_image.shape}")
        # Pass by keyword: _generate_matrix takes (pre_image, gt_image), the opposite order.
        confusion_matrix, cnt = self._generate_matrix(pre_image=pre_image, gt_image=gt_image)
        self.confusion_matrix += confusion_matrix
        self.cnt += cnt

    def Tissue_Intersection_over_Union(self):
        """Return per-class IoU, tp / (tp + fp + fn + eps), as an array of length num_class."""
        tp = self.confusion_matrix[:, 0]
        fp = self.confusion_matrix[:, 1]
        fn = self.confusion_matrix[:, 2]
        return tp / (tp + fp + fn + self.eps)

    def Mean_Intersection_over_Union(self):
        """Return the mean per-class IoU, excluding ignore_class when it is a valid index."""
        tiou = self.Tissue_Intersection_over_Union()
        if self.ignore_class is not None and 0 <= self.ignore_class < self.num_class:
            # Ground-truth pixels of the ignored class are masked out, so its tp is always 0;
            # averaging it in would only drag the mean down.
            tiou = np.delete(tiou, self.ignore_class)
        return np.mean(tiou)

    def Frequency_Weighted_Intersection_over_Union(self):
        """Return per-class IoU weighted by ground-truth pixel frequency (0.0 if no pixels)."""
        total = np.sum(self.cnt)
        if total == 0:
            # Nothing scored yet; avoid the 0/0 RuntimeWarning (result would be 0.0 anyway).
            return 0.0
        freq = self.cnt / total
        tiou = self.Tissue_Intersection_over_Union()
        present = freq > 0
        return (freq[present] * tiou[present]).sum()

    def __call__(self, pred, mask):
        """Accumulate one (prediction, ground-truth) pair and return running (mIoU, FWIoU)."""
        self.add_batch(mask, pred)
        return self.Mean_Intersection_over_Union(), self.Frequency_Weighted_Intersection_over_Union()

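## Patch-level evaluation

Score the predicted test masks against the ground-truth masks and report mean IoU, frequency-weighted IoU, and per-tissue IoU.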

In [3]:
# Directory of predicted masks and matching ground-truth masks (same filenames in both).
pred_dir = "./logs/wsss4luad3k6_3k6_run1_2/test/mask"
gt_dir = "./data/WSSS4LUAD/3.testing/mask"

In [4]:
# Accumulate counts over every ground-truth test mask and its same-named prediction.
metrics = mIoUMask(num_classes=3)
for name in tqdm(sorted(os.listdir(gt_dir))):  # sorted: deterministic processing order
    with Image.open(os.path.join(gt_dir, name)) as gt_img:
        gt_mask = np.array(gt_img)
    with Image.open(os.path.join(pred_dir, name)) as pred_img:
        pred_mask = np.array(pred_img)
    if gt_mask.shape != pred_mask.shape:
        # Fail with the offending filename instead of a bare AssertionError inside the metric.
        raise ValueError(f"{name}: gt shape {gt_mask.shape} != pred shape {pred_mask.shape}")
    # __call__ takes (pred, gt); its running return value is not needed inside the loop.
    metrics(pred_mask, gt_mask)

miou = metrics.Mean_Intersection_over_Union()
fwiou = metrics.Frequency_Weighted_Intersection_over_Union()
tissue_iou = metrics.Tissue_Intersection_over_Union()

print(f"mIoU: {miou:.4f}, FWIoU: {fwiou:.4f}")
print(f"tissue_iou: {tissue_iou}")

100%|██████████| 80/80 [00:07<00:00, 10.34it/s]

mIoU: 0.6541, FWIoU: 0.7065
tissue_iou: [0.76666176 0.63436127 0.56127134]



