-
Notifications
You must be signed in to change notification settings - Fork 1
/
FocalTverskyLoss.py
44 lines (30 loc) · 1.3 KB
/
FocalTverskyLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from keras import backend as K
# For details, see:
# "A Novel Focal Tversky Loss Function With Improved Attention U-Net for Lesion Segmentation"
# https://arxiv.org/abs/1810.07842
# Default Tversky weights: alpha penalizes false positives, beta penalizes
# false negatives. With alpha = beta = 0.5 the Tversky index reduces to the
# Dice coefficient.
ALPHA = 0.5
BETA = 0.5
def TverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, smooth=1e-6):
    """Compute the Tversky loss, i.e. ``1 - Tversky index``.

    The Tversky index generalizes the Dice coefficient by weighting false
    positives (``alpha``) and false negatives (``beta``) independently.
    See https://arxiv.org/abs/1810.07842.

    Args:
        targets: ground-truth tensor (flattened internally).
        inputs: predicted probabilities, same shape as ``targets``.
        alpha: weight applied to false positives.
        beta: weight applied to false negatives.
        smooth: small constant guarding against division by zero.

    Returns:
        Scalar tensor in [0, 1]: ``1 - Tversky index``.
    """
    # Collapse both tensors to 1-D so the counts are taken over all elements.
    y_pred = K.flatten(inputs)
    y_true = K.flatten(targets)

    # Soft confusion-matrix counts.
    true_pos = K.sum(y_pred * y_true)
    false_pos = K.sum(y_pred * (1 - y_true))
    false_neg = K.sum((1 - y_pred) * y_true)

    tversky_index = (true_pos + smooth) / (
        true_pos + alpha * false_pos + beta * false_neg + smooth
    )
    return 1 - tversky_index
# NOTE(review): ALPHA and BETA are rebound here, so TverskyLoss (defined
# above) keeps defaults 0.5/0.5 captured at definition time, while
# FocalTverskyLoss (below) captures 0.3/0.7. This matches the paper's
# recommendation of alpha=0.3 to emphasize recall, but the shadowing is
# easy to misread — consider distinct constant names.
ALPHA = 0.3
BETA = 1-ALPHA
# Focal exponent; gamma = 1 makes the focal variant equal to plain Tversky loss.
GAMMA = 1
def FocalTverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, gamma=GAMMA, smooth=1e-6):
    """Compute the focal Tversky loss: ``(1 - Tversky index) ** gamma``.

    Raising the Tversky loss to the power ``gamma`` focuses training on
    hard examples (small/low-overlap regions). With ``gamma`` = 1 this is
    the plain Tversky loss. See https://arxiv.org/abs/1810.07842.

    Args:
        targets: ground-truth tensor (flattened internally).
        inputs: predicted probabilities, same shape as ``targets``.
        alpha: weight applied to false positives.
        beta: weight applied to false negatives.
        gamma: focal exponent applied to the Tversky loss.
        smooth: small constant guarding against division by zero.

    Returns:
        Scalar tensor: the focal Tversky loss.
    """
    # Flatten so the soft counts run over every element.
    pred_flat = K.flatten(inputs)
    true_flat = K.flatten(targets)

    # Soft true positives / false positives / false negatives.
    tp = K.sum(pred_flat * true_flat)
    fp = K.sum(pred_flat * (1 - true_flat))
    fn = K.sum((1 - pred_flat) * true_flat)

    tversky_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    # Focal modulation of the Tversky loss.
    return K.pow(1 - tversky_index, gamma)