-
Notifications
You must be signed in to change notification settings - Fork 1
/
losses.py
40 lines (27 loc) · 964 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
def get_gan_losses(gan_type):
    """Look up the (generator, discriminator) loss pair for a GAN variant.

    Args:
        gan_type: Name of the GAN objective. Only ``'gan'`` is recognised.

    Returns:
        The tuple ``(gan_g_loss, gan_d_loss)``.

    Raises:
        ValueError: If ``gan_type`` is not a supported name.
    """
    # Guard clause: anything other than the single supported variant is rejected.
    if gan_type != 'gan':
        raise ValueError('Improper GAN type "%s"' % gan_type)
    return gan_g_loss, gan_d_loss
def bce_loss(input, target):
    """Numerically stable binary cross-entropy on raw logits, averaged.

    Computes the mean over all elements of::

        max(x, 0) - x * z + log(1 + exp(-|x|))

    which is algebraically equal to
    ``-[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))]``. It only
    exponentiates non-positive numbers, so it cannot overflow.

    Args:
        input: Tensor of logits (pre-sigmoid scores).
        target: Tensor of targets, broadcastable against ``input``
            (presumably values in [0, 1] -- not checked here).

    Returns:
        0-dim tensor holding the mean loss.
    """
    neg_abs = -input.abs()
    # log1p keeps precision when exp(-|x|) is far below machine epsilon.
    # There, 1 + exp(-|x|) would round to exactly 1 and the term to 0.
    loss = input.clamp(min=0) - input * target + neg_abs.exp().log1p()
    return loss.mean()
def _make_targets(x, y):
return torch.full_like(x, y)
def gan_g_loss(scores_fake):
    """Generator loss: BCE of the fake-sample scores against target 1.

    Args:
        scores_fake: Discriminator logits for generated samples, any shape.

    Returns:
        0-dim tensor with the mean loss, as produced by ``bce_loss``.
    """
    if scores_fake.dim() > 1:
        # reshape (not view) also accepts non-contiguous score tensors,
        # e.g. the result of a transpose or a strided slice.
        scores_fake = scores_fake.reshape(-1)
    y_fake = _make_targets(scores_fake, 1)
    return bce_loss(scores_fake, y_fake)
def gan_d_loss(scores_real, scores_fake):
    """Discriminator loss: push real scores toward 1 and fake scores toward 0.

    Args:
        scores_real: Discriminator logits for real samples.
        scores_fake: Discriminator logits for generated samples; must have
            the same size as ``scores_real``.

    Returns:
        0-dim tensor: mean BCE on the real scores plus mean BCE on the
        fake scores.

    Raises:
        AssertionError: If the two score tensors differ in size. The check
            is skipped entirely when Python runs with ``-O``.
    """
    assert scores_real.size() == scores_fake.size(), (
        'scores_real and scores_fake must have the same size, got %s and %s'
        % (tuple(scores_real.size()), tuple(scores_fake.size()))
    )
    if scores_real.dim() > 1:
        # Sizes are equal, so checking one tensor covers both. reshape
        # tolerates non-contiguous inputs, where view would raise.
        scores_real = scores_real.reshape(-1)
        scores_fake = scores_fake.reshape(-1)
    y_real = _make_targets(scores_real, 1)
    y_fake = _make_targets(scores_fake, 0)
    loss_real = bce_loss(scores_real, y_real)
    loss_fake = bce_loss(scores_fake, y_fake)
    return loss_real + loss_fake