In [1]:
import numpy as np
import torch
from sklearn.metrics import confusion_matrix

In [9]:
# Number of segmentation classes; cm() only counts labels in [0, num_classes).
num_classes = 2
# Label value that cm() excludes from the confusion matrix.
ignore_index = 250

def cm(preds, targets, mask=None):
    """
    Calculate confusion matrix
    preds: B*H*W torch tensor of predicted class indices
    targets: B*H*W torch tensor of target class indices (cast to int)
    mask: optional array-like; it is squeezed and cast to bool, and only
        pixels where it is True are counted. Defaults to all pixels.

    Pixels whose prediction or target lies outside [0, num_classes) or equals
    ignore_index are dropped before counting.

    Returns a (num_classes, num_classes) matrix. `preds` is passed as the first
    argument of sklearn's confusion_matrix, so rows are indexed by predicted
    class and columns by target class.
    """
    # detach()/cpu() make the conversion work for tensors that track gradients
    # or live on an accelerator; for plain CPU tensors they are no-ops.
    preds = preds.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy().astype(int)
    if mask is None:
        mask = np.ones_like(preds, dtype=bool)
    else:
        # np.asarray lets a CPU tensor be passed as well as an ndarray.
        # The builtin `bool` replaces the `np.bool` alias, which was removed
        # from NumPy in 1.24.
        mask = np.squeeze(np.asarray(mask)).astype(bool)
    k = (preds >= 0) & (preds < num_classes) & (preds != ignore_index)
    k &= (targets >= 0) & (targets < num_classes) & (targets != ignore_index)
    k &= mask
    return confusion_matrix(preds[k].flatten(), targets[k].flatten(), labels=list(range(num_classes)))

def get_segmt_scores(cm):
    """
    Calculate scores for segmentation task from the confusion matrix

    cm: square matrix of counts (one row/column per class).

    Returns a 3-tuple:
      - pixel accuracy: trace / total count
      - mean per-class score: diagonal / column sums (axis=0), averaged over
        the classes that have a non-zero column sum
      - mean IoU: diagonal / (row sum + column sum - diagonal), averaged over
        the classes for which the denominator is non-zero
    Returns (0, 0, 0) when the matrix contains no counts at all.
    """
    cm = np.asarray(cm)
    total = cm.sum()
    if total == 0:
        return 0, 0, 0
    hits = np.diag(cm)
    col_totals = cm.sum(axis=0)
    row_totals = cm.sum(axis=1)
    pixel = hits.sum() / total
    # A class that never occurs produces 0/0 = NaN here. np.nanmean below skips
    # those classes on purpose, so the divide/invalid warnings are suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        perclass = hits / col_totals
        iou = hits / (row_totals + col_totals - hits)
    return pixel, np.nanmean(perclass), np.nanmean(iou)


In [72]:
def compute_miou(x_pred, x_output, class_nb=2, ignore_index=250):
    """
    Mean intersection-over-union of predicted vs. target label maps.

    For every image (index along dim 0) the IoU is computed for each class in
    range(class_nb), averaged over the classes whose union is non-empty, and
    the per-image values are then averaged.

    x_pred: tensor of predicted class indices, first dimension is the batch.
    x_output: tensor of target class indices with the same shape; pixels equal
        to `ignore_index` are excluded from both union and intersection.
    class_nb: number of classes to evaluate (default 2).
    ignore_index: target value marking pixels to ignore (default 250).

    Returns a 0-dim float tensor on x_pred's device. Images with no valid
    class (e.g. every target pixel is ignored) are skipped instead of raising
    or producing inf; if no image contributes, the result is 0.
    """
    batch_size = x_pred.size(0)
    device = x_pred.device
    iou_sum = torch.zeros((), device=device)
    n_images = 0
    for i in range(batch_size):
        pred_i, target_i = x_pred[i], x_output[i]
        valid = target_i != ignore_index
        image_iou = torch.zeros((), device=device)
        n_classes_seen = 0
        for j in range(class_nb):
            pred_mask = pred_i == j
            true_mask = target_i == j
            # remove non-defined pixel predictions
            union = ((pred_mask | true_mask) & valid).sum()
            if union == 0:
                continue
            intsec = (pred_mask & true_mask & valid).sum()
            image_iou = image_iou + intsec.float() / union.float()
            n_classes_seen += 1
        if n_classes_seen == 0:
            # Nothing to score in this image; leave it out of the average.
            continue
        iou_sum = iou_sum + image_iou / n_classes_seen
        n_images += 1
    if n_images == 0:
        return iou_sum
    return iou_sum / n_images


def compute_iou(x_pred, x_output, ignore_index=250):
    """
    Mean pixel accuracy of predicted vs. target label maps.

    Despite the name, this is not an IoU: for every image (index along dim 0)
    it divides the number of correctly predicted pixels by the number of
    target pixels that are not `ignore_index`, then averages over images.

    x_pred: tensor of predicted class indices, first dimension is the batch.
    x_output: tensor of target class indices with the same shape.
    ignore_index: target value marking pixels to ignore (default 250).

    Returns a 0-dim float tensor on x_pred's device. Images without any valid
    pixel are skipped instead of contributing NaN (0/0); if no image
    contributes, the result is 0.
    """
    batch_size = x_pred.size(0)
    acc_sum = torch.zeros((), device=x_pred.device)
    n_images = 0
    for i in range(batch_size):
        valid = x_output[i] != ignore_index
        n_valid = valid.sum()
        if n_valid == 0:
            # Nothing to score in this image; leave it out of the average.
            continue
        correct = ((x_pred[i] == x_output[i]) & valid).sum()
        acc_sum = acc_sum + correct.float() / n_valid.float()
        n_images += 1
    if n_images == 0:
        return acc_sum
    return acc_sum / n_images


def depth_error(x_pred, x_output):
    """
    Mean absolute and mean relative error between prediction and target.

    A location is treated as valid when the target, summed over dim 1, is
    non-zero; only values at valid locations enter the errors. Both sums are
    divided by the number of valid locations.

    Returns a tuple of Python floats: (absolute error, relative error).
    """
    # Mask of valid locations, with dim 1 restored so it broadcasts over it.
    valid = (x_output.sum(dim=1) != 0).unsqueeze(1).to(x_pred.device)
    n_valid = int(valid.sum())
    pred_vals = torch.masked_select(x_pred, valid)
    true_vals = torch.masked_select(x_output, valid)
    abs_diff = (pred_vals - true_vals).abs()
    mean_abs = (abs_diff.sum() / n_valid).item()
    mean_rel = ((abs_diff / true_vals).sum() / n_valid).item()
    return mean_abs, mean_rel

In [73]:
# Predicted label maps for a batch of two 4x4 images (class indices 0/1).
pred = torch.tensor([
   [[0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 0],
    [0, 0, 0, 0]],
    
   [[1, 1, 1, 1],
    [1, 1, 1, 1],
    [0, 0, 1, 0],
    [0, 0, 0, 0]]
])

# Matching target label maps. The value 250 marks pixels that compute_miou
# and compute_iou leave out of their denominators.
target = torch.tensor([
   [[0, 0, 1, 1],
    [0, 250, 1, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0]],
    
   [[0, 1, 1, 1],
    [0, 0, 1, 1],
    [250, 0, 1, 0],
    [250, 0, 0, 0]]
])

In [74]:
# cm_ = cm(pred, target)
# print(cm_)
# get_segmt_scores(cm_)

In [75]:
# Hand-computed reference values for the `pred`/`target` fixtures.
# Each inner term belongs to one image; the outer 0.5 averages the two images.

# Pixel accuracy: correct pixels / target pixels that are not 250.
pix_acc = 0.5 * (14/15 + 11/14)
# Mean per-class accuracy, averaged over the two classes and then over images.
# NOTE(review): the second image's terms (8/8, 4/6) do not match per-class
# recall counted from the fixtures (5/8 and 6/6) - confirm the intended
# definition before relying on this value.
mean = 0.5 * (0.5 * (10/11 + 4/4) + 0.5 * (8/8 + 4/6))
# Mean IoU: per-class IoU averaged over classes, then over images
# (the same averaging order as compute_miou).
miou = 0.5 * (0.5 * (10/11 + 4/5) + 0.5 * (5/8 + 6/9))
pix_acc, mean, miou

(0.8595238095238096, 0.8939393939393939, 0.7501893939393939)

In [76]:
# Expected to agree with the hand-computed `miou` (about 0.7502).
compute_miou(pred, target)

tensor(0.7502)

In [77]:
# compute_iou returns mean pixel accuracy; expected to agree with the
# hand-computed `pix_acc` (about 0.8595).
compute_iou(pred, target)

tensor(0.8595)