In [1]:
#default_exp segmentation.metrics

In [2]:
# Autosave every 60 s and log the fastai version that produced the outputs below.
%autosave 60 
import fastai; print(fastai.__version__)

Autosaving every 60 seconds
1.0.58.dev0


In [3]:
#export
from fastai.vision import *
from fastai.metrics import foreground_acc, dice

In [4]:
#export
# nbdev convention: names listed in `_all_` (here re-exported from fastai.metrics)
# are presumably added to the generated module's `__all__` by notebook2script — confirm.
_all_ = ["foreground_acc", "dice"]

### Test fixtures

In [5]:
from local.test import *

In [227]:
# Each fixture is one 3x3 image (batch dim first), tiled to a batch of 5 at the end.

# Softmax-style logits with 2 channels: argmax over dim 1 gives the class map
# [[1,1,0],[1,1,0],[0,0,0]].
_input_softmax = tensor([[
    [[0,0,1],
     [0,0,1],
     [1,1,1]],
    [[1,1,0],
     [1,1,0],
     [0,0,0]],
]]).float()

# Sigmoid-style logits with 2 channels (channel k is scored as class id k+1
# by `_to_sigmoid_input`; pixels where no channel clears the threshold are 0).
_input_sigmoid = tensor([[
    [[ 2, 2,-2],
     [ 2,-2, 1],
     [-2, 1, 1]],
    [[ 1, 1,-1],
     [ 1,-1, 2],
     [-1, 2, 2]],
]]).float()

_target_softmax = tensor([[[
    [0,0,0],
    [0,0,1],
    [0,1,1],
]]])

_target_sigmoid = tensor([[[
    [1,1,0],
    [1,0,2],
    [0,2,2],
]]])

# Target for the `dice` / `iou` / `multilabel_*` checks. It must be defined here
# so the notebook survives Restart & Run All. Against argmax(_input_softmax) it
# shares 1 of 4+4 class-1 pixels and 2 of 5+5 class-0 pixels.
_target = tensor([[[
    [0,0,0],
    [0,1,1],
    [0,1,1],
]]])

_input_sigmoid  = torch.cat([_input_sigmoid]  * 5)
_input_softmax  = torch.cat([_input_softmax]  * 5)
_target_sigmoid = torch.cat([_target_sigmoid] * 5)
_target_softmax = torch.cat([_target_softmax] * 5)
_target         = torch.cat([_target]         * 5)

_input_softmax.size(), _input_sigmoid.size(), _target_sigmoid.size(), _target_softmax.size()

(torch.Size([5, 2, 3, 3]),
 torch.Size([5, 2, 3, 3]),
 torch.Size([5, 1, 3, 3]),
 torch.Size([5, 1, 3, 3]))

In [228]:
# Dice = 2*|pred ∩ target| / (|pred| + |target|); expected here 2*1 / 8
test_eq(dice(_input_softmax, _target, eps=1e-8), 0.25)

### `iou`

In [229]:
#export
def iou(input: torch.Tensor, targs: torch.Tensor, **kwargs)->Rank0Tensor:
    "Binary IOU: fastai's `dice` with `iou=True`; other keyword arguments (e.g. `eps`) are forwarded as-is."
    return dice(input, targs, **kwargs, iou=True)

In [230]:
# IoU = |pred ∩ target| / |pred ∪ target|; expected here 1 / (8 - 1)
test_eq(iou(_input_softmax, _target, eps=1e-8), 1. / 7.)

### `multilabel_dice`

In [231]:
#export
def multilabel_dice(input:Tensor, targs:Tensor, c:int, iou:bool=False, 
                    mean=True, eps:float=1e-8, sigmoid:bool=False, threshold:float=0.5)->Rank0Tensor:
    """Batch/Dataset Dice (or IoU when `iou=True`) for class ids `0..c-1`.

    Predicted class ids:
      - `sigmoid=False`: argmax of `input` over dim 1.
      - `sigmoid=True`: `input` goes through a sigmoid. A pixel gets id k+1 for its
        most probable channel k if any channel exceeds `threshold`, else 0 (void).

    Counts for each class are pooled over every pixel of every image before the
    ratio is taken. A class absent from both prediction and target scores 1, for
    Dice and IoU alike (consistent with `_dice`). `eps` only stabilises the IoU
    denominator.

    Returns the mean over classes when `mean` is True, else a length-`c` tensor.

    NOTE(review): with `sigmoid=True`, id 0 is void and channel k maps to id k+1.
    `range(c)` therefore scores void and never scores id `c`. Callers presumably
    pass c = channels + 1 — confirm.
    """
    if sigmoid:
        probs = input.sigmoid()
        _, top_channel = torch.max(probs, dim=1)
        fired, _ = torch.max(probs > threshold, dim=1)
        preds = (fired.float() * (top_channel + 1).float()).view(-1)
    else:
        preds = input.argmax(dim=1, keepdim=True).view(-1)

    targs = targs.view(-1)
    res = []
    for ci in range(c):
        pred_mask, targ_mask = preds == ci, targs == ci
        # Count with integer sums and cast afterwards. `+` on two bool masks is a
        # logical OR in PyTorch >= 1.2, which would give |A ∪ B| instead of |A| + |B|.
        # Integer sums also stay exact for very large masks.
        intersect = (pred_mask & targ_mask).sum().float()
        union = (pred_mask.sum() + targ_mask.sum()).float()
        if union == 0:
            res.append(torch.ones_like(union))  # class absent from both: perfect score
        elif iou:
            res.append(intersect / (union - intersect + eps))
        else:
            res.append(2. * intersect / union)
    # stack keeps the result on the input's device; c == 0 yields an empty tensor (mean -> nan)
    res = torch.stack(res) if res else torch.zeros(0, device=preds.device)
    return res.mean() if mean else res

In [232]:
# class 1 overlaps 1 of 4+4 px, class 0 overlaps 2 of 5+5 px; mean of the two Dice scores
test_eq(multilabel_dice(_input_softmax, _target, 2, eps=1e-8), (2/8 + 4/10) / 2)

### `multilabel_iou`

In [233]:
#export
def multilabel_iou(input: torch.Tensor, targs: torch.Tensor, c, mean:bool=True, eps:float=1e-8,
                   sigmoid:bool=False, threshold:float=0.5)->Rank0Tensor:
    """Batch/Dataset Mean IOU: `multilabel_dice` with `iou=True`.

    `mean`, `eps`, `sigmoid` and `threshold` are passed through unchanged and use
    the same defaults as `multilabel_dice`. Callers can therefore request
    per-class scores or sigmoid-thresholded IoU through this wrapper.
    """
    return multilabel_dice(input, targs, c=c, iou=True, mean=mean, eps=eps,
                           sigmoid=sigmoid, threshold=threshold)

In [234]:
# per-class IoU: class 1 -> 1/7, class 0 -> 2/8; averaged over the 2 classes
test_eq(multilabel_iou(_input_softmax, _target, c=2), (1/7 + 2/8) / 2)

### `sigmoid_multilabel_dice`

In [235]:
#export
def sigmoid_multilabel_dice(input, target, c, threshold): 
    "Mean Dice over class ids `0..c-1` of sigmoid-thresholded predictions (`multilabel_dice` with `sigmoid=True`)."
    return multilabel_dice(input, target, c, sigmoid=True, threshold=threshold)

### `sigmoid_multilabel_iou`

In [236]:
#export
def sigmoid_multilabel_iou(input, target, c, threshold):
    "Mean IoU over class ids `0..c-1` of sigmoid-thresholded predictions (`multilabel_dice` with `sigmoid=True, iou=True`)."
    # Call `multilabel_dice` directly. `multilabel_iou` already sets `iou=True` itself,
    # so forwarding `iou=True` through it raises a TypeError.
    return multilabel_dice(input, target, c=c, threshold=threshold, sigmoid=True, iou=True)

### `_dice`

In [237]:
#export
def _dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8, 
          reduce:bool=True)->Rank0Tensor:
    "Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems."
    n = targs.shape[0]
    input = input.view(n,-1).float()
    targs = targs.view(n,-1).float()
    intersect = (input * targs).sum(dim=1).float()
    union = (input+targs).sum(dim=1).float()
    if not iou: l = 2. * intersect / union
    else: l = intersect / (union-intersect+eps)
    l[union == 0.] = 1.
    if reduce: return l.mean()
    else: return l

In [238]:
# Both masks empty: the union == 0 convention must score a perfect 1
# (a bare truthy assert would also accept NaN or any nonzero value).
assert _dice(tensor([[0]]), tensor([[0]])).item() == 1

In [239]:
# Identical one-pixel masks: 2*1 / (1 + 1) == 1 exactly
assert _dice(tensor([[1]]), tensor([[1]])).item() == 1

In [240]:
# prediction fires where the target is empty: no overlap -> 0
assert _dice(tensor([[1]]), tensor([[0]])).item() == 0.0

In [241]:
# target pixel missed by the prediction: no overlap -> 0
assert _dice(tensor([[0]]), tensor([[1]])).item() == 0.0

### `_to_sigmoid_input` and `sigmoid_dice_novoid`

In [311]:
#export
def _to_sigmoid_input(logits, threshold):
    "convert logits to preds with sigmoid and thresh (void=0)"
    sigmoid_input = logits.sigmoid()
    thresholded_input = sigmoid_input > threshold
   
    _, indices = torch.max(sigmoid_input, dim=1)
    indices += 1
    values, _ = torch.max(thresholded_input, dim=1)
    preds = (values.float()*indices.float())
    return preds

In [312]:
# sigmoid outputs never exceed 1, so with threshold=1 every pixel is void (0)
assert (_to_sigmoid_input(_input_sigmoid, threshold=1) == 0).all()

In [313]:
# threshold 0: every sigmoid is > 0, so each pixel gets 1 + the index of its larger logit
test_eq(_to_sigmoid_input(_input_sigmoid, threshold=0)[0],
        tensor([[1., 1., 2.],
                [1., 2., 2.],
                [2., 2., 2.]]))

In [314]:
# threshold 0.5: pixels whose logits are all negative fall back to void (0)
test_eq(_to_sigmoid_input(_input_sigmoid, threshold=0.5)[0],
        tensor([[1., 1., 0.],
                [1., 0., 2.],
                [0., 2., 2.]]))

In [366]:
#export 
def sigmoid_dice_novoid(input:Tensor, target:Tensor, threshold:float=0.5,
                        macro:bool=True)->Rank0Tensor:
    """Mean Dice over class ids 1..C of sigmoid-thresholded predictions; void (0) is ignored.

    `input` has one logit channel per class (C = input.size(1)). `_to_sigmoid_input`
    maps them to ids 0..C, with 0 meaning no channel above `threshold`. Per-image
    Dice for each class comes from `_dice(..., reduce=False)`.

    macro=True:  mean over classes of the per-class (batch-averaged) Dice.
    macro=False: mean over images of the per-image (class-averaged) Dice.
    The result stays on `input`'s device in both modes.

    NOTE(review): both modes average the same (class x image) score matrix, so
    they agree up to float rounding. If a pooled, dataset-level macro Dice was
    intended, this needs rework — confirm.
    """
    c = input.size(1)
    preds = _to_sigmoid_input(input, threshold)
    # (c, batch) matrix of per-image Dice scores for class ids 1..c
    scores = torch.stack([_dice(preds == ci, target == ci, reduce=False) for ci in range(1, c + 1)])
    return scores.mean(1).mean() if macro else scores.mean(0).mean()

In [372]:
# threshold 0: class 1 matches exactly (Dice 1); class 2 predicts 6 px vs 3 target px
# with 3 shared (Dice 6/9) -> mean 5/6
test_close(sigmoid_dice_novoid(_input_sigmoid, _target_sigmoid, threshold=0, macro=True).item(),
           0.833, eps=1e-3)

In [374]:
# At the default threshold (0.5) the thresholded predictions equal `_target_sigmoid`
# pixel for pixel, so Dice must be exactly 1 (a truthy assert would also accept NaN).
test_close(sigmoid_dice_novoid(_input_sigmoid, _target_sigmoid, macro=True).item(), 1., eps=1e-6)

In [375]:
# Same perfect match as above, averaged per image first: still exactly 1.
test_close(sigmoid_dice_novoid(_input_sigmoid, _target_sigmoid, macro=False).item(), 1., eps=1e-6)

### export

In [376]:
# Regenerate the library modules from the `#export` cells. `all_fs=True`
# converts every notebook in the project (see the list printed below).
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_script.ipynb.
Converted 02_scheduler.ipynb.
Converted 03_callbacks.ipynb.
Converted 10_segmentation_dataset.ipynb.
Converted 11_segmentation_losses_mulitlabel.ipynb.
Converted 11b_segmentation_losses_binary.ipynb.
Converted 12_segmentation_metrics.ipynb.
Converted 13_segmentation_models.ipynb.
Converted segmentation_training.ipynb.


### fin