In [1]:
import sys

sys.path.append("..")

In [2]:
import torch
import torchvision.datasets
import torch.utils.data
import monai.metrics

import transforms
import datasets

from brain_segmentation_pytorch.unet import UNet



## Dataset setup

### Transforms

In [3]:
tr_inf = transforms.make_transforms(datasets.DATASET_STATS["2012"]["rgb_mean"], datasets.DATASET_STATS["2012"]["rgb_std"])
tr_img_inv = transforms.inv_normalize(datasets.DATASET_STATS["2012"]["rgb_mean"], datasets.DATASET_STATS["2012"]["rgb_std"])

### Dataset loading

In [4]:
dataroot = '../data/'

In [5]:
ds_train = torchvision.datasets.wrap_dataset_for_transforms_v2(torchvision.datasets.VOCSegmentation(
    root=dataroot,
    year="2012",
    image_set="train",
    download=False,
    transforms=tr_inf))


In [6]:
ds_val = torchvision.datasets.wrap_dataset_for_transforms_v2(torchvision.datasets.VOCSegmentation(
    root=dataroot,
    year="2012",
    image_set="val",
    download=False,
    transforms=tr_inf))

## Load Model

In [7]:
# Use GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
checkpointfile = "../checkpoints/test.chpt.pt"

In [9]:
# Model from checkpoint file.
# map_location remaps tensors that were saved from another device (e.g. a GPU)
# onto `device`, so the checkpoint also loads when we fell back to the CPU.
checkpoint = torch.load(checkpointfile, map_location=device)
state_dict = checkpoint['model_state_dict']
# The first dimension of the first encoder conv weight is used as init_features,
# so the network is rebuilt with the same width it was trained with.
unet_features = state_dict['encoder1.enc1conv1.weight'].size(0)
model = UNet(
    in_channels=3,
    out_channels=datasets.CLASS_MAX+1,
    init_features=unet_features,
)
model.load_state_dict(state_dict)
# eval() switches the module to evaluation mode before it is moved to the target device
model = model.eval().to(device)

In [10]:
def topclass_dict(mask: torch.Tensor, k: int=4):
    """Return the (at most) ``k`` most frequent class values in ``mask``.

    Args:
        mask: Tensor of class indices; every element counts as one pixel,
            regardless of the tensor's shape.
        k: Maximum number of classes to report. When ``mask`` contains fewer
            than ``k`` distinct values, all of them are reported.

    Returns:
        A dict mapping ``datasets.CLASSNAMES[class_index]`` to the number of
        elements of ``mask`` holding that class index, ordered from the most
        to the least frequent class.
    """
    pred_classes, counts = mask.unique(return_counts=True)
    # topk cannot return more entries than there are distinct classes
    topcounts, topcounts_idx = torch.topk(counts, min(k, len(counts)))

    # topcounts already holds the counts in descending order; only the class ids
    # need to be gathered through the matching indices.
    return {
        datasets.CLASSNAMES[class_id]: pixel_count
        for class_id, pixel_count in zip(pred_classes[topcounts_idx].tolist(), topcounts.tolist())
    }

## Metrics

This section calculates some metrics in different ways:
 - Manually comparing pixel values
 - Metrics classes from MONAI
 - Metrics based on ConfusionMatrixMetric from MONAI

### Predictions

In [11]:
# Fetch every sample exactly once: image and mask then always stem from the same
# dataset access, and each sample is loaded and transformed only one time.
# (With two separate accesses, a transform pipeline containing random
# augmentations could hand back an image and a mask that no longer match.)
samples = [ds_train[i] for i in range(3)]
img = torch.stack([sample[0] for sample in samples])
mask = torch.stack([sample[1] for sample in samples])

In [12]:
# Inference only: run the model without gradient tracking, bring its output back
# to the CPU and reduce it to one class index per pixel (channel dim is kept).
with torch.no_grad():
    pred = model(img.to(device)).cpu()
    pred_amax = pred.argmax(dim=1, keepdim=True).long()

In [13]:
pred_amax.shape

torch.Size([3, 1, 256, 256])

In [14]:
topclass_dict(pred_amax)

{'background': 147526, 'cat': 29973, 'bus': 15064, 'sofa': 985}

In [15]:
# required format for MONAI metrics:
# float, but 1.0 at whatever we define as predicting a class, 0.0 elsewhere
pred_binarized = torch.zeros_like(pred).scatter_(1, pred_amax, 1.)

In [16]:
# predicted class index at the center pixel of sample 2
# (12 is dog, which is the class this pixel should be)
pred_amax[2,:,128,128]

tensor([8])

In [17]:
# predicted class index near the top center of sample 2
# (9 is chair, which is the class this pixel should be)
# NOTE(review): this reads row 10, while the one-hot checks further down read row 9 - confirm which is intended
pred_amax[2,:,10,128]

tensor([8])

In [18]:
pred_binarized.shape

torch.Size([3, 21, 256, 256])

In [19]:
# class slice through center pixel, should be dog
pred_binarized[2,:,128,128]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [20]:
# class slice through top center pixel, should be chair
pred_binarized[2,:,9,128]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [21]:
# this better be 1
pred_binarized.max()

tensor(1.)

In [22]:
mask_onehot = torch.zeros_like(pred).scatter_(1, mask.to(dtype=torch.long), 1.)

In [23]:
mask_onehot.shape

torch.Size([3, 21, 256, 256])

In [24]:
mask_onehot[2,:,9,128]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [25]:
# Number of pixels in the mask per class per batch
pixelcount_mask = mask_onehot.count_nonzero(dim=(-2,-1))
pixelcount_mask

tensor([[61316,  3884,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,   336,     0,     0,     0,     0,
             0],
        [51571,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         13965],
        [15904,     0,     0,     0,     0,     0,     0,     0,     0, 36763,
             0,     0, 12869,     0,     0,     0,     0,     0,     0,     0,
             0]])

In [26]:
# Number of pixels in the prediction per class per batch
pixelcount_pred = pred_binarized.count_nonzero(dim=(-2,-1))
pixelcount_pred

tensor([[51413,     0,     0,     0,     0,     0, 14121,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     2,
             0],
        [61957,     0,     0,     0,     0,     0,   935,    11,  2123,     0,
             0,     0,     0,     0,   281,     0,     0,    39,     0,    99,
            91],
        [34156,     0,     0,     0,     0,     0,     8,     0, 27850,     0,
             1,   492,     4,   786,     0,   472,     0,   776,   985,     6,
             0]])

In [27]:
# Number of pixels BOTH in the mask AND prediction per class per batch
pixelcount_intersection = torch.logical_and(mask_onehot, pred_binarized).count_nonzero(dim=(-2,-1))
pixelcount_intersection

tensor([[51201,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [48749,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [12777,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     2,     0,     0,     0,     0,     0,     0,     0,
             0]])

In [28]:
# Number of pixels EITHER in the mask OR prediction per class per batch
pixelcount_union = torch.logical_or(mask_onehot, pred_binarized).count_nonzero(dim=(-2,-1))
pixelcount_union

tensor([[61528,  3884,     0,     0,     0,     0, 14121,     0,     0,     0,
             0,     0,     0,     0,     0,   336,     0,     0,     0,     2,
             0],
        [64779,     0,     0,     0,     0,     0,   935,    11,  2123,     0,
             0,     0,     0,     0,   281,     0,     0,    39,     0,    99,
         14056],
        [37283,     0,     0,     0,     0,     0,     8,     0, 27850, 36763,
             1,   492, 12871,   786,     0,   472,     0,   776,   985,     6,
             0]])

### Overall Accuracy

or Old Pixel-Wise Accuracy

Simply the sum of the correctly predicted pixels in this batch divided by the total number of pixels

In [29]:
mask.shape

torch.Size([3, 1, 256, 256])

In [30]:
pred_amax.shape

torch.Size([3, 1, 256, 256])

In [31]:
acc = (mask == pred_amax)
acc = (acc.sum()/acc.numel())
acc

tensor(0.5734)

In [32]:
acc = (mask == pred_amax).count_nonzero() / mask.numel()
acc

tensor(0.5734)

In [33]:
mask.shape

torch.Size([3, 1, 256, 256])

In [34]:
mask.size()[-2:]

torch.Size([256, 256])

### Dice Score / F1-Score

https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

Sets:

`DSC = 2*|A ∩ B| / (|A| + |B|)`


Bool only:

`DSC = 2*TP / (2*TP + FP + FN)`

#### Sample-Wise Dice score

Only pixels in one prediction will be compared to pixels in the corresponding target mask

In [35]:
dice_samplewise = 2 * pixelcount_intersection / (pixelcount_pred + pixelcount_mask)
dice_samplewise

tensor([[9.0839e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [8.5880e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [5.1047e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.1073e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

#### mean Dice

All pixels of the dataset within one class will be added up and a per-class dice score calculated from that.
The overall mean dice is the per-class dice score averaged over the classes.

This has the effect that underrepresented classes contribute equally to the overall score.

In [36]:
# aggregate pixels over batches

dice_batchagg = 2.0 * pixelcount_intersection.sum(dim=0) / (pixelcount_pred + pixelcount_mask).sum(dim=0)
dice_batchagg

tensor([8.1593e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.1073e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,        nan, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [37]:
# calculate average over classes. This is our meanDice
dice_batchagg.nanmean()

tensor(0.0510)

In [38]:
# Ignore prediction of classes that were not in the mask, corresponds to ignore_empty=True
# this is for comparing with monai.metrics.DiceMetric
dn = (pixelcount_pred + pixelcount_mask)
dn_nonempty = torch.where(pixelcount_mask > 0, dn, 0)

In [39]:
# aggregate pixels over batches

dice_batchagg_nonempty = 2.0 * pixelcount_intersection.sum(dim=0) / dn_nonempty.sum(dim=0)
dice_batchagg_nonempty

tensor([8.1593e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
               nan,        nan,        nan, 0.0000e+00,        nan,        nan,
        3.1073e-04,        nan,        nan, 0.0000e+00,        nan,        nan,
               nan,        nan, 0.0000e+00])

In [40]:
# then average over classes
mean_dice_nonempty = dice_batchagg_nonempty.nanmean()
mean_dice_nonempty

tensor(0.1360)

#### monai.metrics.DiceMetric()

In [41]:
metric = monai.metrics.DiceMetric()

In [42]:
m = metric(pred_binarized, mask_onehot)
m

tensor([[9.0839e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan,        nan,        nan],
        [8.5880e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00],
        [5.1047e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
         3.1073e-04,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan]])

In [43]:
metric.aggregate(reduction="mean_batch")

tensor([7.5922e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.1073e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [44]:
metric.aggregate()

tensor([0.3008])

In [45]:
dice_samplewise_nonempty = torch.where(pixelcount_mask != 0, dice_samplewise, torch.nan)

In [46]:
dice_samplewise_nonempty.nanmean(dim=1).nanmean(dim=0)

tensor(0.3008)

### IoU (Intersection over Union / Jaccard)

https://en.wikipedia.org/wiki/Jaccard_index

Sets:

`IoU = |A ∩ B| / |A ∪ B|`

#### Sample-Wise IoU

Only pixels in one prediction will be compared to pixels in the corresponding target mask

In [47]:
# IoU per one sample
iou_samplewise = pixelcount_intersection / pixelcount_union
iou_samplewise

tensor([[8.3216e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [7.5254e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [3.4270e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.5539e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

In [48]:
# IoU per one sample, calculated from values available in confusion matrix should be same as above
iou_samplewise = pixelcount_intersection / (pixelcount_pred + pixelcount_mask - pixelcount_intersection)
iou_samplewise

tensor([[8.3216e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [7.5254e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [3.4270e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.5539e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

#### mIoU (mean IoU / mean Jaccard Index)

All pixels of the dataset within one class will be added up and a per-class IoU score calculated from that.
The overall mean IoU is the per-class IoU score averaged over the classes.

This has the effect that underrepresented classes contribute equally to the overall score.

In [49]:
# aggregate pixels over batches

my_iou_batchagg = pixelcount_intersection.sum(dim=0) / (pixelcount_pred + pixelcount_mask - pixelcount_intersection).sum(dim=0)
my_iou_batchagg

tensor([6.8908e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.5539e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,        nan, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [50]:
# turn nans into ones, this corresponds to ignore_empty=False
my_iou_batchagg_withempty = my_iou_batchagg.nan_to_num(1.0)

In [51]:
# then average over classes
my_mean_iou_withempty = my_iou_batchagg_withempty.mean()
my_mean_iou_withempty

tensor(0.2709)

In [52]:
# Ignore prediction of classes that were not in the mask, corresponds to ignore_empty=True
# TODO: Think hard about whether this is actually a good idea
pixelcount_union_nonempty = torch.where(pixelcount_mask > 0, pixelcount_union, 0)

In [53]:
# aggregate pixels over batches

iou_batchagg_nonempty = pixelcount_intersection.sum(dim=0) / pixelcount_union_nonempty.sum(dim=0)
iou_batchagg_nonempty

tensor([6.8908e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
               nan,        nan,        nan, 0.0000e+00,        nan,        nan,
        1.5539e-04,        nan,        nan, 0.0000e+00,        nan,        nan,
               nan,        nan, 0.0000e+00])

In [54]:
# then average over classes
mean_iou_nonempty = iou_batchagg_nonempty.nanmean()
mean_iou_nonempty

tensor(0.1149)

#### monai.metrics.MeanIoU()

This calculates a "mean" IoU, but the wrong one: it first calculates the sample IoU, then averages over channels then
over batches.

In [55]:
metric = monai.metrics.MeanIoU()

In [56]:
m = metric(y_pred=pred_binarized, y=mask_onehot)
m

tensor([[8.3216e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan,        nan,        nan],
        [7.5254e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00],
        [3.4270e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
         1.5539e-04,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan]])

In [57]:
metric.aggregate()

tensor([0.2560])

In [58]:
# nan out classes that were not in the target mask at all
# This is equivalent to MONAI metrics ignore_empty=True
iou_samplewise_nonempty = torch.where(pixelcount_mask != 0, iou_samplewise, torch.nan)
iou_samplewise_nonempty

tensor([[8.3216e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan,        nan,        nan],
        [7.5254e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00],
        [3.4270e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
         1.5539e-04,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan]])

In [59]:
# MeanIoU from MonAI is equivalent to averaging sample-wise IoU scores
iou_samplewise_nonempty.nanmean(dim=1).nanmean(dim=0)

tensor(0.2560)

## Metrics based on Confusion Matrix

In [60]:
confm_metrics_types = [
    "precision",
    "recall",
    "accuracy",
    "f1 score", # aka Dice Score
    "threat score" # aka Intersection over Union aka Jaccard Index
]

In [61]:
confm_metric = monai.metrics.ConfusionMatrixMetric(metric_name=confm_metrics_types)

In [62]:
confm = confm_metric(y_pred=pred_binarized, y=mask_onehot)

In [63]:
confm.shape

torch.Size([3, 21, 4])

In [64]:
# true positives
tp = confm[:,:,0]
tp

tensor([[5.1201e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.8749e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.2777e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         2.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00]])

In [65]:
# false positives
fp = confm[:,:,1]
fp

tensor([[2.1200e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.4121e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 2.0000e+00, 0.0000e+00],
        [1.3208e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         9.3500e+02, 1.1000e+01, 2.1230e+03, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 2.8100e+02, 0.0000e+00, 0.0000e+00, 3.9000e+01,
         0.0000e+00, 9.9000e+01, 9.1000e+01],
        [2.1379e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         8.0000e+00, 0.0000e+00, 2.7850e+04, 0.0000e+00, 1.0000e+00, 4.9200e+02,
         2.0000e+00, 7.8600e+02, 0.0000e+00, 4.7200e+02, 0.0000e+00, 7.7600e+02,
         9.8500e+02, 6.0000e+00, 0.0000e+00]])

In [66]:
# true negatives
tn = confm[:,:,2]
tn

tensor([[ 4008., 61652., 65536., 65536., 65536., 65536., 51415., 65536., 65536.,
         65536., 65536., 65536., 65536., 65536., 65536., 65200., 65536., 65536.,
         65536., 65534., 65536.],
        [  757., 65536., 65536., 65536., 65536., 65536., 64601., 65525., 63413.,
         65536., 65536., 65536., 65536., 65536., 65255., 65536., 65536., 65497.,
         65536., 65437., 51480.],
        [28253., 65536., 65536., 65536., 65536., 65536., 65528., 65536., 37686.,
         28773., 65535., 65044., 52665., 64750., 65536., 65064., 65536., 64760.,
         64551., 65530., 65536.]])

In [67]:
# false negatives
fn = confm[:,:,3]
fn

tensor([[10115.,  3884.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
             0.,     0.,     0.,     0.,     0.,     0.,   336.,     0.,     0.,
             0.,     0.,     0.],
        [ 2822.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
             0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
             0.,     0., 13965.],
        [ 3127.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
         36763.,     0.,     0., 12867.,     0.,     0.,     0.,     0.,     0.,
             0.,     0.,     0.]])

### Precision

In [68]:
# precision (sample-wise)
tp / (tp + fp)

tensor([[0.9959,    nan,    nan,    nan,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan],
        [0.7868,    nan,    nan,    nan,    nan,    nan, 0.0000, 0.0000, 0.0000,
            nan,    nan,    nan,    nan,    nan, 0.0000,    nan,    nan, 0.0000,
            nan, 0.0000, 0.0000],
        [0.3741,    nan,    nan,    nan,    nan,    nan, 0.0000,    nan, 0.0000,
            nan, 0.0000, 0.0000, 0.5000, 0.0000,    nan, 0.0000,    nan, 0.0000,
         0.0000, 0.0000,    nan]])

In [69]:
# per-class precision of batch aggregate
precision_batchagg = tp.sum(dim=0) / (tp + fp).sum(dim=0)
precision_batchagg

tensor([0.7641,    nan,    nan,    nan,    nan,    nan, 0.0000, 0.0000, 0.0000,
           nan, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000,    nan, 0.0000,
        0.0000, 0.0000, 0.0000])

In [70]:
# per-class precision of batch aggregate, calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[0]

tensor([0.7641,    nan,    nan,    nan,    nan,    nan, 0.0000, 0.0000, 0.0000,
           nan, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000,    nan, 0.0000,
        0.0000, 0.0000, 0.0000])

In [71]:
# overall precision (batch aggregate precision averaged over non-nan classes)
# this is what we want
precision_batchagg.nanmean()

tensor(0.0903)

In [72]:
# reduction of confusion matrix values prior to metric calculation
# This is NOT what we want.
confm_metric.aggregate(reduction="mean")[0]

tensor([0.5734])

In [73]:
tp.sum() / (tp.sum() + fp.sum())

tensor(0.5734)

### Recall

In [74]:
# recall (sample-wise)
tp / (tp + fn)

tensor([[8.3503e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan,        nan,        nan],
        [9.4528e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00],
        [8.0338e-01,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
         1.5541e-04,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan]])

In [75]:
# per-class recall of batch aggregate.
# This is what we want
tp.sum(dim=0) / (tp + fn).sum(dim=0)

tensor([8.7527e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
               nan,        nan,        nan, 0.0000e+00,        nan,        nan,
        1.5541e-04,        nan,        nan, 0.0000e+00,        nan,        nan,
               nan,        nan, 0.0000e+00])

In [76]:
# per-class recall, calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[1]

tensor([8.7527e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
               nan,        nan,        nan, 0.0000e+00,        nan,        nan,
        1.5541e-04,        nan,        nan, 0.0000e+00,        nan,        nan,
               nan,        nan, 0.0000e+00])

In [77]:
# overall recall
torch.nanmean(tp.sum(dim=0) / (tp + fn).sum(dim=0))

tensor(0.1459)

In [78]:
confm_metric.aggregate(reduction="mean")[1]

tensor([0.5734])

In [79]:
tp.sum() / (tp.sum() + fn.sum())

tensor(0.5734)

### Overall Accuracy from confusion matrix

In [80]:
# sample-wise accuracy from confusion matrix
(tp + tn) / (tp + fp + tn + fn)

tensor([[0.8424, 0.9407, 1.0000, 1.0000, 1.0000, 1.0000, 0.7845, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9949, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000],
        [0.7554, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9857, 0.9998, 0.9676,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9957, 1.0000, 1.0000, 0.9994,
         1.0000, 0.9985, 0.7855],
        [0.6261, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9999, 1.0000, 0.5750,
         0.4390, 1.0000, 0.9925, 0.8036, 0.9880, 1.0000, 0.9928, 1.0000, 0.9882,
         0.9850, 0.9999, 1.0000]])

In [81]:
# sample-wise per-class accuracy from the confusion matrix metric (reduction="none" keeps the batch dimension)
confm_metric.aggregate(reduction="none")[2]

tensor([[0.8424, 0.9407, 1.0000, 1.0000, 1.0000, 1.0000, 0.7845, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9949, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000],
        [0.7554, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9857, 0.9998, 0.9676,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9957, 1.0000, 1.0000, 0.9994,
         1.0000, 0.9985, 0.7855],
        [0.6261, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9999, 1.0000, 0.5750,
         0.4390, 1.0000, 0.9925, 0.8036, 0.9880, 1.0000, 0.9928, 1.0000, 0.9882,
         0.9850, 0.9999, 1.0000]])

In [82]:
# ConfusionMatrixMetric aggregates hits per-class. We can recover the overall hits by summing over only the
# true positives and NOT the true negatives, because the true negatives from each class will be contained within
# the true positives of the matching class
tp.sum()

tensor(112729.)

In [83]:
(mask == pred_amax).count_nonzero()

tensor(112729)

In [84]:
# incorrectly identified pixels will count towards the "false positives"
fp.sum()

tensor(83879.)

In [85]:
(mask != pred_amax).count_nonzero()

tensor(83879)

In [86]:
# total number of pixels from confusion matrix
(tp+fp).sum()

tensor(196608.)

In [87]:
mask.numel()

196608

In [88]:
# overall accuracy calculated from per-class TP/FP
tp.sum() / (tp+fp).sum()

tensor(0.5734)

### mean Accuracy from confusion matrix

In [89]:
# per-class batch aggregate accuracy, calculated through the confusion matrix metric
(tp + tn).sum(dim=0) / (tp + tn + fp + fn).sum(dim=0)

tensor([0.7413, 0.9802, 1.0000, 1.0000, 1.0000, 1.0000, 0.9234, 0.9999, 0.8475,
        0.8130, 1.0000, 0.9975, 0.9345, 0.9960, 0.9986, 0.9959, 1.0000, 0.9959,
        0.9950, 0.9995, 0.9285])

In [90]:
# mean accuracy aka per-class averaged accuracy
(tp + tn).sum() / (tp + tn + fp + fn).sum()

tensor(0.9594)

In [91]:
# per-class batch aggregate accuracy
confm_metric.aggregate(reduction="sum_batch")[2]

tensor([0.7413, 0.9802, 1.0000, 1.0000, 1.0000, 1.0000, 0.9234, 0.9999, 0.8475,
        0.8130, 1.0000, 0.9975, 0.9345, 0.9960, 0.9986, 0.9959, 1.0000, 0.9959,
        0.9950, 0.9995, 0.9285])

In [92]:
# mean accuracy aka per-class averaged accuracy
confm_metric.aggregate(reduction="sum")[2]

tensor([0.9594])

### mean Dice from confusion matrix

In [93]:
# sample-wise dice from confusion matrix
2.0*tp / (2.0 * tp + fp + fn)

tensor([[9.0839e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [8.5880e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [5.1047e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.1073e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

In [94]:
dice_samplewise

tensor([[9.0839e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [8.5880e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [5.1047e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.1073e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

In [95]:
# per-class batch aggregate Dice, calculated through the confusion matrix metric
# this is what we want
2.0*tp.sum(dim=0) / (2.0 * tp + fn + fp).sum(dim=0)

tensor([8.1593e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.1073e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,        nan, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [96]:
# per-class batch aggregate Dice (index 3 is "f1 score"), calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[3]

tensor([8.1593e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.1073e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,        nan, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [97]:
(2.0*tp.sum(dim=0) / (2.0 * tp + fn + fp).sum(dim=0)).nanmean()

tensor(0.0510)

### mean IoU from confusion matrix

In [98]:
# sample-wise IoU from confusion matrix
tp / (tp + fn + fp)

tensor([[8.3216e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [7.5254e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [3.4270e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.5539e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

In [99]:
# manually calculated sample-wise IoU for comparison
iou_samplewise

tensor([[8.3216e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan,        nan,        nan,        nan,        nan,
                nan,        nan,        nan, 0.0000e+00,        nan,        nan,
                nan, 0.0000e+00,        nan],
        [7.5254e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan,        nan,
                nan,        nan, 0.0000e+00,        nan,        nan, 0.0000e+00,
                nan, 0.0000e+00, 0.0000e+00],
        [3.4270e-01,        nan,        nan,        nan,        nan,        nan,
         0.0000e+00,        nan, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.5539e-04, 0.0000e+00,        nan, 0.0000e+00,        nan, 0.0000e+00,
         0.0000e+00, 0.0000e+00,        nan]])

In [100]:
# per-class batch aggregate IoU, calculated through the confusion matrix metric
# this is what we want
tp.sum(dim=0) / (tp + fn + fp).sum(dim=0)

tensor([6.8908e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.5539e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,        nan, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [101]:
# per-class batch aggregate IoU
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[4]

tensor([6.8908e-01, 0.0000e+00,        nan,        nan,        nan,        nan,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.5539e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,        nan, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [102]:
(tp.sum(dim=0) / (tp + fn + fp).sum(dim=0)).nanmean()

tensor(0.0431)

In [103]:
# reduction of confusion matrix values prior to metric calculation
# This is NOT what we want.
confm_metric.aggregate(reduction="sum")[4]

tensor([0.4019])

In [104]:
# reduction of confusion matrix values prior to metric calculation
# This is NOT what we want.
confm_iou_mean = tp.sum() / (tp.sum() + fn.sum() + fp.sum())
confm_iou_mean

tensor(0.4019)

## metrics.MultiMetrics class

My MultiMetrics class uses all the metrics calculated above

In [105]:
import metrics

In [106]:
multimetrics = metrics.MultiMetrics()

In [107]:
multimetrics.update(pred, mask);

In [108]:
multimetrics.calculate()

{'OverallAccuracy': 0.5733693440755209,
 'MeanPrecision': 0.0902940109372139,
 'MeanRecall': 0.14590436220169067,
 'MeanAccuracy': 0.9593685865402222,
 'MeanDice': 0.051014743745326996,
 'MeanIoU': 0.043077364563941956}