"""A generic NCE wrapper which speedup the training and inferencing"""
import math
import torch
import torch.nn as nn
from .alias_multinomial import AliasMultinomial
# A backoff probability to stabilize log operation
BACKOFF_PROB = 1e-10
class NCELoss(nn.Module):
"""Noise Contrastive Estimation
NCE is to eliminate the computational cost of softmax
normalization.
There are 3 loss modes in this NCELoss module:
- nce: enable the NCE approximation
- sampled: enabled sampled softmax approximation
- full: use the original cross entropy as default loss
They can be switched by directly setting `nce.loss_type = 'nce'`.
Ref:
X.Chen etal Recurrent neural network language
model training with noise contrastive estimation
for speech recognition
https://core.ac.uk/download/pdf/42338485.pdf
Attributes:
noise: the distribution of noise
noise_ratio: $\frac{#noises}{#real data samples}$ (k in paper)
norm_term: the normalization term (lnZ in paper), can be heuristically
determined by the number of classes, plz refer to the code.
reduction: reduce methods, same with pytorch's loss framework, 'none',
'elementwise_mean' and 'sum' are supported.
loss_type: loss type of this module, currently 'full', 'sampled', 'nce'
are supported
Shape:
- noise: :math:`(V)` where `V = vocabulary size`
- target: :math:`(B, N)`
- loss: a scalar loss by default, :math:`(B, N)` if `reduction='none'`
Input:
target: the supervised training label.
args&kwargs: extra arguments passed to underlying index module
Return:
loss: if `reduction='sum' or 'elementwise_mean'` the scalar NCELoss ready for backward,
else the loss matrix for every individual targets.
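
    Example (a minimal usage sketch; ``MyNCE`` stands for any concrete
    subclass that implements `get_score` and `ce_loss`, and ``word_freqs``
    / ``hidden`` are placeholder tensors)::

        noise = word_freqs / word_freqs.sum()    # (V,) unigram distribution
        criterion = MyNCE(noise, noise_ratio=50)
        loss = criterion(target, hidden)         # scalar, ready for .backward()
        criterion.loss_type = 'sampled'          # switch approximation on the fly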
"""
def __init__(self,
noise,
noise_ratio=100,
norm_term='auto',
reduction='elementwise_mean',
per_word=False,
loss_type='nce',
):
super(NCELoss, self).__init__()
        # Re-normalize the given noise frequency list and lift words with
        # extremely low probability up to BACKOFF_PROB, for numeric stability
probs = noise / noise.sum()
probs = probs.clamp(min=BACKOFF_PROB)
renormed_probs = probs / probs.sum()
self.register_buffer('logprob_noise', renormed_probs.log())
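        # The alias method allows O(1) sampling from the fixed noise
        # distribution, regardless of vocabulary size.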
self.alias = AliasMultinomial(renormed_probs)
self.noise_ratio = noise_ratio
if norm_term == 'auto':
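            # ln(V) is the exact lnZ of a uniform model over V classes,
            # a common heuristic starting point for the normalization term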
self.norm_term = math.log(noise.numel())
else:
self.norm_term = norm_term
self.reduction = reduction
self.per_word = per_word
self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')
self.ce = nn.CrossEntropyLoss(reduction='none')
self.loss_type = loss_type
def forward(self, target, *args, **kwargs):
"""compute the loss with output and the desired target
The `forward` is the same among all NCELoss submodules, it
takes care of generating noises and calculating the loss
given target and noise scores.
"""
batch = target.size(0)
max_len = target.size(1)
if self.loss_type != 'full':
noise_samples = self.get_noise(batch, max_len)
# B,N,Nr
logit_noise_in_noise = self.logprob_noise[noise_samples.data.view(-1)].view_as(noise_samples)
logit_target_in_noise = self.logprob_noise[target.data.view(-1)].view_as(target)
# (B,N), (B,N,Nr)
logit_target_in_model, logit_noise_in_model = self._get_logit(target, noise_samples, *args, **kwargs)
if self.loss_type == 'nce':
if self.training:
loss = self.nce_loss(
logit_target_in_model, logit_noise_in_model,
logit_noise_in_noise, logit_target_in_noise,
)
else:
# directly output the approximated posterior
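                    # at evaluation time the model is assumed to be self-normalized,
                    # so the per-word NLL is simply -log p_model(target)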
loss = - logit_target_in_model
elif self.loss_type == 'sampled':
loss = self.sampled_softmax_loss(
logit_target_in_model, logit_noise_in_model,
logit_noise_in_noise, logit_target_in_noise,
)
# NOTE: The mix mode is still under investigation
elif self.loss_type == 'mix' and self.training:
loss = 0.5 * self.nce_loss(
logit_target_in_model, logit_noise_in_model,
logit_noise_in_noise, logit_target_in_noise,
)
loss += 0.5 * self.sampled_softmax_loss(
logit_target_in_model, logit_noise_in_model,
logit_noise_in_noise, logit_target_in_noise,
)
else:
current_stage = 'training' if self.training else 'inference'
raise NotImplementedError(
'loss type {} not implemented at {}'.format(
self.loss_type, current_stage
)
)
else:
# Fallback into conventional cross entropy
loss = self.ce_loss(target, *args, **kwargs)
if self.reduction == 'elementwise_mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
def get_noise(self, batch_size, max_len):
"""Generate noise samples from noise distribution"""
noise_size = (batch_size, max_len, self.noise_ratio)
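        # per_word=True draws an independent set of Nr noise samples for every
        # target position (B * N * Nr draws); per_word=False draws a single set
        # of Nr samples and shares it across the whole batch, which is cheaper.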
if self.per_word:
noise_samples = self.alias.draw(*noise_size)
else:
noise_samples = self.alias.draw(1, 1, self.noise_ratio).expand(*noise_size)
noise_samples = noise_samples.contiguous()
return noise_samples
def _get_logit(self, target_idx, noise_idx, *args, **kwargs):
"""Get the logits of NCE estimated probability for target and noise
Both NCE and sampled softmax Loss are unchanged when the probabilities are scaled
evenly, here we subtract the maximum value as in softmax, for numeric stability.
Shape:
- Target_idx: :math:`(N)`
- Noise_idx: :math:`(N, N_r)` where `N_r = noise ratio`
"""
target_logit, noise_logit = self.get_score(target_idx, noise_idx, *args, **kwargs)
target_logit = target_logit.sub(self.norm_term)
noise_logit = noise_logit.sub(self.norm_term)
return target_logit, noise_logit
def get_score(self, target_idx, noise_idx, *args, **kwargs):
"""Get the target and noise score
Usually logits are used as score.
This method should be override by inherit classes
Returns:
- target_score: real valued score for each target index
- noise_score: real valued score for each noise index
"""
raise NotImplementedError()
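
    # A minimal sketch of a concrete subclass (illustrative only; the class
    # and attribute names below are assumptions, not this repository's API):
    #
    #   class DotProductNCE(NCELoss):
    #       def __init__(self, emb_dim, num_classes, noise, **kwargs):
    #           super().__init__(noise, **kwargs)
    #           self.emb = nn.Embedding(num_classes, emb_dim)
    #           self.bias = nn.Parameter(torch.zeros(num_classes))
    #
    #       def get_score(self, target_idx, noise_idx, input):
    #           # input: (B, N, E) hidden states from the language model
    #           target_score = (input * self.emb(target_idx)).sum(-1) \
    #                          + self.bias[target_idx]                  # (B, N)
    #           noise_score = torch.einsum(
    #               'bne,bnre->bnr', input, self.emb(noise_idx)
    #           ) + self.bias[noise_idx]                                # (B, N, Nr)
    #           return target_score, noise_score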
def ce_loss(self, target_idx, *args, **kwargs):
"""Get the conventional CrossEntropyLoss
The returned loss should be of the same size of `target`
Args:
- target_idx: batched target index
- args, kwargs: any arbitrary input if needed by sub-class
Returns:
- loss: the estimated loss for each target
"""
raise NotImplementedError()
def nce_loss(self, logit_target_in_model, logit_noise_in_model, logit_noise_in_noise, logit_target_in_noise):
"""Compute the classification loss given all four probabilities
Args:
- logit_target_in_model: logit of target words given by the model (RNN)
- logit_noise_in_model: logit of noise words given by the model
- logit_noise_in_noise: logit of noise words given by the noise distribution
- logit_target_in_noise: logit of target words given by the noise distribution
Returns:
            - loss: the data-vs-noise binary classification loss for every single case
"""
        # NOTE: the model scores are unnormalized, so prob <= 1 is not guaranteed
logit_model = torch.cat([logit_target_in_model.unsqueeze(2), logit_noise_in_model], dim=2)
logit_noise = torch.cat([logit_target_in_noise.unsqueeze(2), logit_noise_in_noise], dim=2)
# predicted probability of the word comes from true data distribution
# The posterior can be computed as following
# p_true = logit_model.exp() / (logit_model.exp() + self.noise_ratio * logit_noise.exp())
# For numeric stability we compute the logits of true label and
# directly use bce_with_logits.
# Ref https://pytorch.org/docs/stable/nn.html?highlight=bce#torch.nn.BCEWithLogitsLoss
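        # Derivation: with lm = log p_model(w) and ln = log p_noise(w),
        #     p_true = exp(lm) / (exp(lm) + k * exp(ln)) = sigmoid(lm - ln - log k)
        # so BCE on the logit (lm - ln - log k) is exactly the NCE objective.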
logit_true = logit_model - logit_noise - math.log(self.noise_ratio)
label = torch.zeros_like(logit_model)
label[:, :, 0] = 1
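        # column 0 (the true word) is the positive class, the k noise columns
        # are negatives; sum the k+1 binary losses into one per-word loss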
loss = self.bce_with_logits(logit_true, label).sum(dim=2)
return loss
def sampled_softmax_loss(self, logit_target_in_model, logit_noise_in_model, logit_noise_in_noise, logit_target_in_noise):
"""Compute the sampled softmax loss based on the tensorflow's impl"""
logits = torch.cat([logit_target_in_model.unsqueeze(2), logit_noise_in_model], dim=2)
q_logits = torch.cat([logit_target_in_noise.unsqueeze(2), logit_noise_in_noise], dim=2)
# subtract Q for correction of biased sampling
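        # (as in tf.nn.sampled_softmax_loss: training on `logits - log Q(w)`
        # compensates for frequent words being over-sampled as negatives, so
        # the sampled softmax tracks the full softmax)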
logits = logits - q_logits
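        # the target was concatenated at column 0, so the correct class is 0 everywhere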
labels = torch.zeros_like(logits.narrow(2, 0, 1)).squeeze(2).long()
loss = self.ce(
logits.view(-1, logits.size(-1)),
labels.view(-1),
).view_as(labels)
return loss