-
Notifications
You must be signed in to change notification settings - Fork 4
/
losses.py
150 lines (124 loc) · 4.78 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Loss functions for training a lens flare reduction model."""
from typing import Callable, Dict, Mapping, Optional, Union
import tensorflow as tf
import vgg
def get_loss(name: str) -> 'tf.keras.losses.Loss':
    """Builds the training loss identified by `name`.

    Recognized identifiers (compared after lower-casing `name`):
      - "l1": pixel-wise mean absolute error.
      - "l2": pixel-wise mean squared error.
      - "perceptual" or "percep": a `CompositeLoss` made of a VGG19-based
        `PerceptualLoss` and a pixel-wise L1 term, both with weight 1.0.

    Args:
        name: Loss identifier from the list above; not case-sensitive.

    Returns:
        A Keras `Loss` object.

    Raises:
        ValueError: If `name` is not a recognized identifier. The message
            shows the lower-cased name.
        AttributeError: If `name` has no `lower()` method (e.g. None).
    """
    key = name.lower()
    if key == 'l1':
        return tf.keras.losses.MeanAbsoluteError()
    if key == 'l2':
        return tf.keras.losses.MeanSquaredError()
    if key not in ('percep', 'perceptual'):
        raise ValueError(f'Unrecognized loss function name: {key}')
    # Perceptual features plus plain L1, equally weighted.
    composite = CompositeLoss()
    composite.add_loss(PerceptualLoss(), weight=1.0)
    composite.add_loss('L1', weight=1.0)
    return composite
class PerceptualLoss(tf.keras.losses.Loss):
    """Perceptual loss computed from VGG-19 feature maps.

    Ground truth and prediction are both run through a `vgg.Vgg19` feature
    extractor. For every selected tap-out layer, the mean absolute difference
    between the two feature maps is computed. The per-layer values are then
    combined as a weighted sum.
    """

    # Tap-out layer name -> weight of that layer in the sum. Note that
    # block5_conv2 is weighted far more heavily than the other layers.
    # NOTE(review): the origin of these constants is not documented here.
    DEFAULT_COEFFS = {
        'block1_conv2': 1 / 2.6,
        'block2_conv2': 1 / 4.8,
        'block3_conv2': 1 / 3.7,
        'block4_conv2': 1 / 5.6,
        'block5_conv2': 10 / 1.5,
    }

    def __init__(self, coeffs=None, name='perceptual'):
        """Creates the loss together with its VGG-19 feature extractor.

        Args:
            coeffs: Mapping from tap-out layer name to that layer's weight.
                Any falsy value (None or an empty mapping) selects
                `DEFAULT_COEFFS`.
            name: Name of this Tensorflow object.
        """
        super(PerceptualLoss, self).__init__(name=name)
        if not coeffs:
            coeffs = self.DEFAULT_COEFFS
        pairs = list(coeffs.items())
        tap_out_layers = tuple(layer for layer, _ in pairs)
        # Weights are kept in the same order as the layers given to the model.
        # call() pairs them positionally with the model's outputs, which
        # assumes vgg.Vgg19 returns one feature map per tap-out layer, in
        # this order. NOTE: verify against vgg.Vgg19.
        self._coeffs = tuple(weight for _, weight in pairs)
        self._model = vgg.Vgg19(tap_out_layers=tap_out_layers)

    def call(self, y_true, y_pred):
        """Computes the per-example perceptual loss.

        Do not invoke this directly. Use `__call__()`, which wraps it and
        applies the base class's reduction. See the base class for details.

        Args:
            y_true: Ground-truth image batch, shape [B, H, W, C].
            y_pred: Predicted image batch with the same shape.

        Returns:
            A [B, 1, 1] tensor with one loss value per example. The base class
            expects a D-1 dimensional result for D-dimensional inputs, which
            is why the two spatial axes are kept (as size 1) rather than
            dropped.
        """
        target_maps = self._model(y_true)
        predicted_maps = self._model(y_pred)
        weighted_sum = tf.constant(0.0)
        for target, predicted, weight in zip(target_maps, predicted_maps,
                                             self._coeffs):
            # MAE averages over the last (channel) axis only, giving a
            # [B, H, W]-shaped map. The spatial axes are then averaged while
            # kept as size-1 dimensions.
            per_location = tf.keras.losses.MAE(target, predicted)
            per_example = tf.reduce_mean(per_location, axis=[1, 2],
                                         keepdims=True)
            weighted_sum = weighted_sum + per_example * weight
        return weighted_sum
class CompositeLoss(tf.keras.losses.Loss):
    """A weighted sum of individual loss functions for images.

    Components are registered with `add_loss()`. `call()` casts both inputs
    to float32 and returns the sum of `weight * loss(y_true, y_pred)` over
    all registered components.

    Attributes:
        losses: Mapping from component losses (Keras loss objects, or the
            loss callables resolved by `tf.keras.losses.get`) to weights.
    """

    def __init__(self, name='composite'):
        """Initializes an instance with no component losses.

        Args:
            name: Optional name for this Tensorflow object.
        """
        super(CompositeLoss, self).__init__(name=name)
        self.losses: Dict[tf.keras.losses.Loss, float] = {}

    def add_loss(self, loss, weight):
        """Adds a component loss to the composite with specific weight.

        Args:
            loss: A Keras loss object or identifier. All standard Keras loss
                identifiers are supported (e.g., string like "mse", loss
                functions, and `tf.keras.losses.Loss` objects). In addition,
                the strings "l1" and "l2" (any case) are mapped to "mae" and
                "mse". Cannot be a loss that is already added to this
                `CompositeLoss`.
            weight: Weight associated with this loss. Must be > 0.

        Raises:
            ValueError: If the given `loss` already exists, or if `weight` is
                None, NaN or <= 0.
        """
        # None is rejected explicitly, because `None <= 0.0` would raise
        # TypeError instead of the documented ValueError. Testing
        # `not weight > 0.0` (rather than `weight <= 0.0`) also rejects NaN,
        # for which every comparison is False.
        if weight is None or not weight > 0.0:
            raise ValueError(f'Weight must be > 0, but is {weight}.')
        if isinstance(loss, str):
            loss = loss.lower()
            loss = {'l1': 'mae', 'l2': 'mse'}.get(loss, loss)
            loss_fn = tf.keras.losses.get(loss)
        else:
            loss_fn = loss
        # Duplicates are detected by dict membership, i.e. by the hash and
        # equality of the resolved loss object.
        if loss_fn in self.losses:
            raise ValueError('The same loss already exists.')
        self.losses[loss_fn] = weight  # pytype: disable=container-type-mismatch # typed-keras

    def call(self, y_true, y_pred):
        """Returns the weighted sum of all component losses.

        See base class for details.

        Raises:
            ValueError: If no component loss has been added.
        """
        # Use an explicit raise instead of `assert`. Asserts are stripped
        # under `python -O`, which would let an empty composite silently
        # return 0.
        if not self.losses:
            raise ValueError('At least one component loss must be added.')
        loss_sum = tf.constant(0.0)
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        for loss, weight in self.losses.items():
            # Components may return tensors of different ranks. For example,
            # `PerceptualLoss.call` yields [B, 1, 1] before reduction, while
            # a plain loss function presumably returns an unreduced map
            # (verify). `+` relies on broadcasting to combine them.
            loss_sum = loss(y_true, y_pred) * weight + loss_sum
        return loss_sum