-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss_functions.py
37 lines (29 loc) · 1003 Bytes
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Loss functions and batch-level metric helpers
import torch
import torch.nn.functional as F
def dice_loss(pred, target, smooth=1e-5):
    """Compute the summed soft Dice loss and summed Dice score.

    The overlap is reduced over dimensions 2 and 3, so both tensors need at
    least four dimensions (presumably ``(batch, channel, height, width)`` --
    confirm against the caller). ``pred`` is used as given; callers are
    expected to pass values that already lie in ``[0, 1]``.

    Args:
        pred: Predicted mask values.
        target: Ground-truth mask, broadcastable against ``pred``.
        smooth: Small constant added to numerator and denominator so that an
            empty prediction/target pair does not divide by zero.

    Returns:
        A ``(loss, dice)`` tuple of scalar tensors: the sum of ``1 - dice``
        and the sum of ``dice`` over all remaining dimensions.
    """
    spatial_dims = (2, 3)
    intersection = (pred * target).sum(dim=spatial_dims)
    union = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    # The smoothing term is added once to the numerator and once to the
    # denominator. Scaling it by 2 in the numerator (as before) makes the
    # score exceed 1 -- reaching 2.0 for two empty masks -- and the loss
    # go negative.
    dice = (2.0 * intersection + smooth) / (union + smooth)
    loss = 1.0 - dice
    return loss.sum(), dice.sum()
def loss_func(pred, target):
    """Return the combined BCE-with-logits and Dice loss for a batch.

    ``pred`` is treated as raw logits: it is fed directly to
    ``binary_cross_entropy_with_logits`` (summed reduction) and passed
    through a sigmoid before the Dice term is computed.

    Args:
        pred: Raw model outputs (logits).
        target: Ground-truth tensor with the same shape as ``pred``.

    Returns:
        Scalar tensor: summed BCE plus summed Dice loss.
    """
    bce_term = F.binary_cross_entropy_with_logits(
        pred, target, reduction='sum'
    )
    probabilities = torch.sigmoid(pred)
    # Only the loss component of dice_loss is needed here.
    dice_term = dice_loss(probabilities, target)[0]
    return bce_term + dice_term
def metrics_batch(pred, target):
    """Return the summed Dice score of a batch of logits against ``target``.

    The logits are converted with a sigmoid before being handed to
    ``dice_loss``; only the score (second element) of its result is kept.
    """
    probabilities = torch.sigmoid(pred)
    return dice_loss(probabilities, target)[1]
def loss_batch(loss_func, output, target, opt=None):
    """Evaluate one batch and, if an optimizer is given, take a step.

    Args:
        loss_func: Callable ``(output, target) -> scalar tensor``.
        output: Raw model outputs (a sigmoid is applied for the metric).
        target: Ground-truth tensor.
        opt: Optional optimizer. When provided, gradients are zeroed, the
            loss is back-propagated and ``opt.step()`` is called.

    Returns:
        ``(loss_value, metric)`` where ``loss_value`` is a Python float and
        ``metric`` is the summed Dice score tensor from ``dice_loss``.
    """
    batch_loss = loss_func(output, target)

    # The metric is for reporting only, so keep it out of the autograd graph.
    with torch.no_grad():
        _, metric_b = dice_loss(torch.sigmoid(output), target)

    # Evaluation mode: nothing to optimize.
    if opt is None:
        return batch_loss.item(), metric_b

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), metric_b