In [11]:
import sys

sys.path.append("..")

In [12]:
import torch
import torchvision.datasets
import torch.utils.data
import monai.metrics

import transforms
import datasets

from brain_segmentation_pytorch.unet import UNet

## Dataset setup

### Transforms

In [13]:
tr_inf = transforms.make_transforms(transforms.PASCAL_VOC_2012_MEAN, transforms.PASCAL_VOC_2012_STD)
tr_img_inv = transforms.inv_normalize(transforms.PASCAL_VOC_2012_MEAN, transforms.PASCAL_VOC_2012_STD)

### Dataset loading

In [14]:
dataroot = '../data/'

In [15]:
# Pascal VOC 2012 segmentation, "train" split, read from dataroot (no download).
voc_train = torchvision.datasets.VOCSegmentation(
    root=dataroot, year="2012", image_set="train", download=False, transforms=tr_inf
)
# Wrap the dataset for use with the torchvision.transforms.v2 API.
ds_train = torchvision.datasets.wrap_dataset_for_transforms_v2(voc_train)


In [16]:
# Pascal VOC 2012 segmentation, "val" split, read from dataroot (no download).
voc_val = torchvision.datasets.VOCSegmentation(
    root=dataroot, year="2012", image_set="val", download=False, transforms=tr_inf
)
# Wrap the dataset for use with the torchvision.transforms.v2 API.
ds_val = torchvision.datasets.wrap_dataset_for_transforms_v2(voc_val)

## Load Model

In [17]:
# Use GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [18]:
checkpointfile = "../checkpoints/test.chpt.pt"

In [19]:
# Model from checkpoint file.
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
checkpoint = torch.load(checkpointfile, map_location=device)
state_dict = checkpoint['model_state_dict']
# The size of dim 0 of the first encoder conv weight is used to recover the
# init_features value the UNet has to be rebuilt with.
unet_features = state_dict['encoder1.enc1conv1.weight'].size(0)
model = UNet(
    in_channels=3,
    out_channels=datasets.CLASS_MAX+1,
    init_features=unet_features,
)
model.load_state_dict(state_dict)
# eval(): inference mode for the layers; to(device): move the weights to the target device
model = model.eval().to(device)

In [20]:
def topclass_dict(mask: torch.Tensor, k: int=4):
    """Return the pixel counts of the ``k`` most frequent class ids in ``mask``.

    Args:
        mask: Tensor of class ids. Every element is counted, regardless of
            shape. The ids are used to look names up in ``datasets.CLASSNAMES``.
        k: Maximum number of classes to report. Fewer entries are returned
            when ``mask`` contains fewer than ``k`` distinct ids.

    Returns:
        Dict mapping class name to the number of elements carrying that class,
        inserted in order of decreasing count.
    """
    pred_classes, counts = mask.unique(return_counts=True)
    # topk must not be asked for more entries than there are distinct classes
    top_idx = torch.topk(counts, min(k, len(counts))).indices

    return {
        datasets.CLASSNAMES[pred_classes[i].item()]: counts[i].item()
        for i in top_idx.tolist()
    }

## Metrics

This section calculates some metrics in different ways:
 - Manually comparing pixel values
 - Metric classes from MONAI
 - Metrics based on ConfusionMatrixMetric from MONAI

### Predictions

In [21]:
# Fetch every sample exactly once and take image and target from the same
# access, instead of indexing the dataset twice per sample.
n_samples = 3
samples = [ds_train[i] for i in range(n_samples)]
img = torch.stack([sample[0] for sample in samples])
mask = torch.stack([sample[1] for sample in samples])

In [22]:
with torch.no_grad():
    # forward pass on the model's device, result brought back to the CPU
    pred = model(img.to(device)).cpu()
    # index of the highest-scoring channel per pixel; argmax already yields int64
    pred_amax = pred.argmax(dim=1, keepdim=True)

In [23]:
pred_amax.shape

torch.Size([3, 1, 256, 256])

In [24]:
topclass_dict(pred_amax)

{'background': 127721, 'chair': 38425, 'tv/monitor': 13871, 'dog': 12967}

In [25]:
# required format for MONAI metrics:
# float, but 1.0 at whatever we define as predicting a class, 0.0 elsewhere
# scatter_ along dim 1 writes 1.0 into the channel selected by pred_amax at every position
pred_binarized = torch.zeros_like(pred).scatter_(1, pred_amax, 1.)

In [26]:
# 12 is dog
pred_amax[2,:,128,128]

tensor([12])

In [27]:
# 9 is chair
pred_amax[2,:,10,128]

tensor([9])

In [28]:
pred_binarized.shape

torch.Size([3, 21, 256, 256])

In [29]:
# class slice through center pixel, should be dog
pred_binarized[2,:,128,128]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [30]:
# class slice through top center pixel, should be chair
pred_binarized[2,:,9,128]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [31]:
# this better be 1
pred_binarized.max()

tensor(1.)

In [32]:
# One-hot encode the target mask along dim 1, with the same shape and dtype as pred
# NOTE(review): assumes mask only holds indices that are valid for dim 1 of pred
# (e.g. no 255 "void" label left in it) - scatter_ rejects out-of-range indices; confirm in transforms
mask_onehot = torch.zeros_like(pred).scatter_(1, mask.to(dtype=torch.long), 1.)

In [33]:
mask_onehot.shape

torch.Size([3, 21, 256, 256])

In [34]:
mask_onehot[2,:,9,128]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [35]:
# Number of pixels in the mask per class per batch
pixelcount_mask = mask_onehot.count_nonzero(dim=(-2,-1))
pixelcount_mask

tensor([[61316,  3884,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,   336,     0,     0,     0,     0,
             0],
        [51571,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         13965],
        [15904,     0,     0,     0,     0,     0,     0,     0,     0, 36763,
             0,     0, 12869,     0,     0,     0,     0,     0,     0,     0,
             0]])

In [36]:
# Number of pixels in the prediction per class per batch
pixelcount_pred = pred_binarized.count_nonzero(dim=(-2,-1))
pixelcount_pred

tensor([[61918,  3381,     0,     0,     0,     0,     0,     0,     0,     0,
            37,     0,     0,     0,     0,   200,     0,     0,     0,     0,
             0],
        [51665,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         13871],
        [14138,     0,     0,     0,     0,     0,     0,     0,     1, 38425,
             0,     0, 12967,     0,     0,     5,     0,     0,     0,     0,
             0]])

In [37]:
# Number of pixels BOTH in the mask AND prediction per class per batch
pixelcount_intersection = torch.logical_and(mask_onehot, pred_binarized).count_nonzero(dim=(-2,-1))
pixelcount_intersection

tensor([[60792,  2863,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,   178,     0,     0,     0,     0,
             0],
        [51276,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         13576],
        [13297,     0,     0,     0,     0,     0,     0,     0,     0, 36163,
             0,     0, 12518,     0,     0,     0,     0,     0,     0,     0,
             0]])

In [38]:
# Number of pixels EITHER in the mask OR prediction per class per batch
pixelcount_union = torch.logical_or(mask_onehot, pred_binarized).count_nonzero(dim=(-2,-1))
pixelcount_union

tensor([[62442,  4402,     0,     0,     0,     0,     0,     0,     0,     0,
            37,     0,     0,     0,     0,   358,     0,     0,     0,     0,
             0],
        [51960,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         14260],
        [16745,     0,     0,     0,     0,     0,     0,     0,     1, 39025,
             0,     0, 13318,     0,     0,     5,     0,     0,     0,     0,
             0]])

### Overall Accuracy

or Old Pixel-Wise Accuracy

Simply the sum of the correctly predicted pixels in this batch divided by the total number of pixels

In [39]:
mask.shape

torch.Size([3, 1, 256, 256])

In [40]:
pred_amax.shape

torch.Size([3, 1, 256, 256])

In [41]:
# boolean hit map: True where the predicted class equals the target class
hits = mask == pred_amax
# fraction of hits over all compared elements
acc = hits.sum() / hits.numel()
acc

tensor(0.9698)

In [42]:
acc = (mask == pred_amax).count_nonzero() / mask.numel()
acc

tensor(0.9698)

### Dice Score / F1-Score

https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

Sets:

`DSC = 2*|A ∩ B| / (|A| + |B|)`


Bool only:

`DSC = 2*TP / (2*TP + FP + FN)`

#### Sample-Wise Dice score

Only pixels in one prediction will be compared to pixels in the corresponding target mask

In [43]:
dice_samplewise = 2 * pixelcount_intersection / (pixelcount_pred + pixelcount_mask)
dice_samplewise

tensor([[0.9866, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.6642,    nan,    nan,
            nan,    nan,    nan],
        [0.9934,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9754],
        [0.8852,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9619,    nan,    nan, 0.9690,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

#### mean Dice

All pixels of the dataset within one class will be added up and a per-class dice score calculated from that.
The overall mean dice is the per-class dice score averaged over the classes.

This has the effect that underrepresented classes contribute equally to the overall score.

In [44]:
# aggregate pixels over batches

dice_batchagg = 2.0 * pixelcount_intersection.sum(dim=0) / (pixelcount_pred + pixelcount_mask).sum(dim=0)
dice_batchagg

tensor([0.9775, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9619, 0.0000,    nan, 0.9690,    nan,    nan, 0.6580,    nan,    nan,
           nan,    nan, 0.9754])

In [45]:
# calculate average over classes (nan entries are skipped by nanmean). This is our meanDice
dice_batchagg.nanmean()

tensor(0.6663)

In [46]:
# Ignore prediction of classes that were not in the mask, corresponds to ignore_empty=True
# this is for comparing with monai.metrics.DiceMetric
# dn is the Dice denominator |pred| + |mask| per sample and class
dn = (pixelcount_pred + pixelcount_mask)
# zero the denominator wherever the class does not occur in the target mask
dn_nonempty = torch.where(pixelcount_mask > 0, dn, 0)

In [47]:
# aggregate pixels over batches

dice_batchagg_nonempty = 2.0 * pixelcount_intersection.sum(dim=0) / dn_nonempty.sum(dim=0)
dice_batchagg_nonempty

tensor([0.9775, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        0.9619,    nan,    nan, 0.9690,    nan,    nan, 0.6642,    nan,    nan,
           nan,    nan, 0.9754])

In [48]:
# then average over classes
mean_dice_nonempty = dice_batchagg_nonempty.nanmean()
mean_dice_nonempty

tensor(0.8894)

#### monai.metrics.DiceMetric()

In [49]:
metric = monai.metrics.DiceMetric()

In [50]:
m = metric(pred_binarized, mask_onehot)
m

tensor([[0.9866, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan, 0.6642,    nan,    nan,
            nan,    nan,    nan],
        [0.9934,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9754],
        [0.8852,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
         0.9619,    nan,    nan, 0.9690,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan]])

In [51]:
metric.aggregate(reduction="mean_batch")

tensor([0.9551, 0.7882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.9619, 0.0000, 0.0000, 0.9690, 0.0000, 0.0000, 0.6642, 0.0000, 0.0000,
        0.0000, 0.0000, 0.9754])

In [52]:
metric.aggregate()

tensor([0.9120])

In [53]:
dice_samplewise_nonempty = torch.where(pixelcount_mask != 0, dice_samplewise, torch.nan)

In [54]:
dice_samplewise_nonempty.nanmean(dim=1).nanmean(dim=0)

tensor(0.9120)

### IoU (Intersection over Union / Jaccard)

https://en.wikipedia.org/wiki/Jaccard_index

Sets:

`IoU = |A ∩ B| / |A ∪ B|`

#### Sample-Wise IoU

Only pixels in one prediction will be compared to pixels in the corresponding target mask

In [55]:
# IoU per one sample
iou_samplewise = pixelcount_intersection / pixelcount_union
iou_samplewise

tensor([[0.9736, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.4972,    nan,    nan,
            nan,    nan,    nan],
        [0.9868,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9520],
        [0.7941,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9267,    nan,    nan, 0.9399,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

In [56]:
# IoU per one sample, calculated only from values available in a confusion matrix,
# using |union| = |pred| + |mask| - |intersection|. Should be the same as above.
iou_samplewise = pixelcount_intersection / (pixelcount_pred + pixelcount_mask - pixelcount_intersection)
iou_samplewise

tensor([[0.9736, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.4972,    nan,    nan,
            nan,    nan,    nan],
        [0.9868,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9520],
        [0.7941,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9267,    nan,    nan, 0.9399,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

#### mIoU (mean IoU / mean Jaccard Index)

All pixels of the dataset within one class will be added up and a per-class IoU score calculated from that.
The overall mean IoU is the per-class IoU score averaged over the classes.

This has the effect that underrepresented classes contribute equally to the overall score.

In [57]:
# aggregate pixels over batches

my_iou_batchagg = pixelcount_intersection.sum(dim=0) / (pixelcount_pred + pixelcount_mask - pixelcount_intersection).sum(dim=0)
my_iou_batchagg

tensor([0.9559, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9267, 0.0000,    nan, 0.9399,    nan,    nan, 0.4904,    nan,    nan,
           nan,    nan, 0.9520])

In [58]:
# turn nans into ones, this corresponds to ignore_empty=False
my_iou_batchagg_withempty = my_iou_batchagg.nan_to_num(1.0)

In [59]:
# then average over classes
my_mean_iou_withempty = my_iou_batchagg_withempty.mean()
my_mean_iou_withempty

tensor(0.8531)

In [60]:
# Ignore prediction of classes that were not in the mask, corresponds to ignore_empty=True
# TODO: Think hard about whether this is actually a good idea
pixelcount_union_nonempty = torch.where(pixelcount_mask > 0, pixelcount_union, 0)

In [61]:
# aggregate pixels over batches

iou_batchagg_nonempty = pixelcount_intersection.sum(dim=0) / pixelcount_union_nonempty.sum(dim=0)
iou_batchagg_nonempty

tensor([0.9559, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        0.9267,    nan,    nan, 0.9399,    nan,    nan, 0.4972,    nan,    nan,
           nan,    nan, 0.9520])

In [62]:
# then average over classes
mean_iou_nonempty = iou_batchagg_nonempty.nanmean()
mean_iou_nonempty

tensor(0.8204)

#### monai.metrics.MeanIoU()

This calculates a "mean" IoU, but the wrong one: it first calculates the sample IoU, then averages over channels then
over batches.

In [63]:
metric = monai.metrics.MeanIoU()

In [64]:
m = metric(y_pred=pred_binarized, y=mask_onehot)
m

tensor([[0.9736, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan, 0.4972,    nan,    nan,
            nan,    nan,    nan],
        [0.9868,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9520],
        [0.7941,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
         0.9267,    nan,    nan, 0.9399,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan]])

In [65]:
metric.aggregate()

tensor([0.8545])

In [66]:
# nan out classes that were not in the target mask at all
# This is equivalent to MONAI metrics ignore_empty=True
iou_samplewise_nonempty = torch.where(pixelcount_mask != 0, iou_samplewise, torch.nan)
iou_samplewise_nonempty

tensor([[0.9736, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan, 0.4972,    nan,    nan,
            nan,    nan,    nan],
        [0.9868,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9520],
        [0.7941,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
         0.9267,    nan,    nan, 0.9399,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan]])

In [67]:
# MeanIoU from MonAI is equivalent to averaging sample-wise IoU scores
iou_samplewise_nonempty.nanmean(dim=1).nanmean(dim=0)

tensor(0.8545)

## Metrics based on Confusion Matrix

In [68]:
# Order matters: the results of confm_metric.aggregate() are indexed by position in this list
# further down (0 precision, 1 recall, 2 accuracy, 3 f1 score, 4 threat score)
confm_metrics_types = [
    "precision",
    "recall",
    "accuracy",
    "f1 score", # aka Dice Score
    "threat score" # aka Intersection over Union aka Jaccard Index
]

In [69]:
confm_metric = monai.metrics.ConfusionMatrixMetric(metric_name=confm_metrics_types)

In [70]:
confm = confm_metric(y_pred=pred_binarized, y=mask_onehot)

In [71]:
confm.shape

torch.Size([3, 21, 4])

In [72]:
# true positives
# (the last dim of confm is read in this notebook as 0 = tp, 1 = fp, 2 = tn, 3 = fn)
tp = confm[:,:,0]
tp

tensor([[60792.,  2863.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
             0.,     0.,     0.,     0.,     0.,     0.,   178.,     0.,     0.,
             0.,     0.,     0.],
        [51276.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
             0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
             0.,     0., 13576.],
        [13297.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
         36163.,     0.,     0., 12518.,     0.,     0.,     0.,     0.,     0.,
             0.,     0.,     0.]])

In [73]:
# false positives
fp = confm[:,:,1]
fp

tensor([[1.1260e+03, 5.1800e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.7000e+01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 2.2000e+01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.8900e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 2.9500e+02],
        [8.4100e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0000e+00, 2.2620e+03, 0.0000e+00, 0.0000e+00,
         4.4900e+02, 0.0000e+00, 0.0000e+00, 5.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00]])

In [74]:
# true negatives
tn = confm[:,:,2]
tn

tensor([[ 3094., 61134., 65536., 65536., 65536., 65536., 65536., 65536., 65536.,
         65536., 65499., 65536., 65536., 65536., 65536., 65178., 65536., 65536.,
         65536., 65536., 65536.],
        [13576., 65536., 65536., 65536., 65536., 65536., 65536., 65536., 65536.,
         65536., 65536., 65536., 65536., 65536., 65536., 65536., 65536., 65536.,
         65536., 65536., 51276.],
        [48791., 65536., 65536., 65536., 65536., 65536., 65536., 65536., 65535.,
         26511., 65536., 65536., 52218., 65536., 65536., 65531., 65536., 65536.,
         65536., 65536., 65536.]])

In [75]:
# false negatives
fn = confm[:,:,3]
fn

tensor([[ 524., 1021.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,  158.,    0.,    0.,    0.,    0.,
            0.],
        [ 295.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
          389.],
        [2607.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,  600.,
            0.,    0.,  351.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.]])

### Precision

In [76]:
# precision (sample-wise)
tp / (tp + fp)

tensor([[0.9818, 0.8468,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.8900,    nan,    nan,
            nan,    nan,    nan],
        [0.9925,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9787],
        [0.9405,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9411,    nan,    nan, 0.9654,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

In [77]:
# per-class precision of batch aggregate
precision_batchagg = tp.sum(dim=0) / (tp + fp).sum(dim=0)
precision_batchagg

tensor([0.9816, 0.8468,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9411, 0.0000,    nan, 0.9654,    nan,    nan, 0.8683,    nan,    nan,
           nan,    nan, 0.9787])

In [78]:
# per-class precision of batch aggregate, calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[0]

tensor([0.9816, 0.8468,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9411, 0.0000,    nan, 0.9654,    nan,    nan, 0.8683,    nan,    nan,
           nan,    nan, 0.9787])

In [79]:
# overall precision (batch aggregate precision averaged over non-nan classes)
# this is what we want
precision_batchagg.nanmean()

tensor(0.6977)

In [80]:
# reduction of confusion matrix values prior to metric calculation
# This is NOT what we want.
confm_metric.aggregate(reduction="mean")[0]

tensor([0.9698])

In [81]:
tp.sum() / (tp.sum() + fp.sum())

tensor(0.9698)

### Recall

In [82]:
# recall (sample-wise)
tp / (tp + fn)

tensor([[0.9915, 0.7371,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan, 0.5298,    nan,    nan,
            nan,    nan,    nan],
        [0.9943,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9721],
        [0.8361,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
         0.9837,    nan,    nan, 0.9727,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan]])

In [83]:
# per-class recall of batch aggregate.
# This is what we want
tp.sum(dim=0) / (tp + fn).sum(dim=0)

tensor([0.9734, 0.7371,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        0.9837,    nan,    nan, 0.9727,    nan,    nan, 0.5298,    nan,    nan,
           nan,    nan, 0.9721])

In [84]:
# per-class recall, calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[1]

tensor([0.9734, 0.7371,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        0.9837,    nan,    nan, 0.9727,    nan,    nan, 0.5298,    nan,    nan,
           nan,    nan, 0.9721])

In [85]:
# overall recall
torch.nanmean(tp.sum(dim=0) / (tp + fn).sum(dim=0))

tensor(0.8615)

In [86]:
confm_metric.aggregate(reduction="mean")[1]

tensor([0.9698])

In [87]:
tp.sum() / (tp.sum() + fn.sum())

tensor(0.9698)

### Overall Accuracy from confusion matrix

In [88]:
# sample-wise accuracy from confusion matrix
(tp + tn) / (tp + fp + tn + fn)

tensor([[0.9748, 0.9765, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 0.9994, 1.0000, 1.0000, 1.0000, 1.0000, 0.9973, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000],
        [0.9896, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 0.9896],
        [0.9474, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         0.9563, 1.0000, 1.0000, 0.9878, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000]])

In [89]:
# sample-wise per-class accuracy through the confusion matrix metric
# (reduction="none" keeps one row per sample; compare with the manual result above)
confm_metric.aggregate(reduction="none")[2]

tensor([[0.9748, 0.9765, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 0.9994, 1.0000, 1.0000, 1.0000, 1.0000, 0.9973, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000],
        [0.9896, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 0.9896],
        [0.9474, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         0.9563, 1.0000, 1.0000, 0.9878, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000]])

In [90]:
# ConfusionMatrixMetric aggregates hits per-class. We can recover the overall hits by summing over only the
# true positives and NOT the true negatives, because the true negatives from each class will be contained within
# the true positives of the matching class
tp.sum()

tensor(190663.)

In [91]:
(mask == pred_amax).count_nonzero()

tensor(190663)

In [92]:
# incorrectly identified pixels will count towards the "false positives"
fp.sum()

tensor(5945.)

In [93]:
(mask != pred_amax).count_nonzero()

tensor(5945)

In [94]:
# total number of pixels from confusion matrix
(tp+fp).sum()

tensor(196608.)

In [95]:
mask.numel()

196608

In [96]:
# overall accuracy calculated from per-class TP/FP
tp.sum() / (tp+fp).sum()

tensor(0.9698)

### mean Accuracy from confusion matrix

In [97]:
# per-class batch aggregate accuracy, calculated manually from the confusion matrix values
(tp + tn).sum(dim=0) / (tp + tn + fp + fn).sum(dim=0)

tensor([0.9706, 0.9922, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.9854, 0.9998, 1.0000, 0.9959, 1.0000, 1.0000, 0.9991, 1.0000, 1.0000,
        1.0000, 1.0000, 0.9965])

In [98]:
# mean accuracy aka per-class averaged accuracy
(tp + tn).sum() / (tp + tn + fp + fn).sum()

tensor(0.9971)

In [99]:
# per-class batch aggregate accuracy
confm_metric.aggregate(reduction="sum_batch")[2]

tensor([0.9706, 0.9922, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.9854, 0.9998, 1.0000, 0.9959, 1.0000, 1.0000, 0.9991, 1.0000, 1.0000,
        1.0000, 1.0000, 0.9965])

In [100]:
# mean accuracy aka per-class averaged accuracy
confm_metric.aggregate(reduction="sum")[2]

tensor([0.9971])

### mean Dice from confusion matrix

In [101]:
# sample-wise dice from confusion matrix
2.0*tp / (2.0 * tp + fp + fn)

tensor([[0.9866, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.6642,    nan,    nan,
            nan,    nan,    nan],
        [0.9934,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9754],
        [0.8852,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9619,    nan,    nan, 0.9690,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

In [102]:
dice_samplewise

tensor([[0.9866, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.6642,    nan,    nan,
            nan,    nan,    nan],
        [0.9934,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9754],
        [0.8852,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9619,    nan,    nan, 0.9690,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

In [103]:
# per-class batch aggregate Dice, calculated manually from the confusion matrix values
# this is what we want
2.0*tp.sum(dim=0) / (2.0 * tp + fn + fp).sum(dim=0)

tensor([0.9775, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9619, 0.0000,    nan, 0.9690,    nan,    nan, 0.6580,    nan,    nan,
           nan,    nan, 0.9754])

In [104]:
# per-class batch aggregate Dice ("f1 score"), calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[3]

tensor([0.9775, 0.7882,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9619, 0.0000,    nan, 0.9690,    nan,    nan, 0.6580,    nan,    nan,
           nan,    nan, 0.9754])

### mean IoU from confusion matrix

In [105]:
# sample-wise IoU from confusion matrix
tp / (tp + fn + fp)

tensor([[0.9736, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.4972,    nan,    nan,
            nan,    nan,    nan],
        [0.9868,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9520],
        [0.7941,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9267,    nan,    nan, 0.9399,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

In [106]:
# manually calculated sample-wise IoU, shown for comparison
iou_samplewise

tensor([[0.9736, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan, 0.0000,    nan,    nan,    nan,    nan, 0.4972,    nan,    nan,
            nan,    nan,    nan],
        [0.9868,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan, 0.9520],
        [0.7941,    nan,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
         0.9267,    nan,    nan, 0.9399,    nan,    nan, 0.0000,    nan,    nan,
            nan,    nan,    nan]])

In [107]:
# per-class batch aggregate IoU, calculated manually from the confusion matrix values
# this is what we want
tp.sum(dim=0) / (tp + fn + fp).sum(dim=0)

tensor([0.9559, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9267, 0.0000,    nan, 0.9399,    nan,    nan, 0.4904,    nan,    nan,
           nan,    nan, 0.9520])

In [108]:
# per-class batch aggregate IoU ("threat score"), calculated through the confusion matrix metric
# this is what we want
confm_metric.aggregate(reduction="sum_batch")[4]

tensor([0.9559, 0.6504,    nan,    nan,    nan,    nan,    nan,    nan, 0.0000,
        0.9267, 0.0000,    nan, 0.9399,    nan,    nan, 0.4904,    nan,    nan,
           nan,    nan, 0.9520])

In [109]:
# reduction of confusion matrix values prior to metric calculation
# This is NOT what we want.
confm_metric.aggregate(reduction="sum")[4]

tensor([0.9413])

In [110]:
# reduction of confusion matrix values prior to metric calculation
# This is NOT what we want.
confm_iou_mean = tp.sum() / (tp.sum() + fn.sum() + fp.sum())
confm_iou_mean

tensor(0.9413)