Skip to content

Commit

Permalink
Extractive: Use '-9e3' instead of '-9e9' to replace padding values in classifiers before loss calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
HHousen committed Oct 27, 2020
1 parent ad3a31e commit e1f6022
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def forward(self, x, mask):
x = self.linear2(x)
# x = self.sigmoid(x)
sent_scores = x.squeeze(-1) * mask.float()
sent_scores[sent_scores == 0] = -9e9
sent_scores[sent_scores == 0] = -9e3
return sent_scores


Expand All @@ -88,7 +88,7 @@ def forward(self, x, mask):
x = self.linear(x).squeeze(-1)
# x = self.sigmoid(x)
sent_scores = x * mask.float()
sent_scores[sent_scores == 0] = -9e9
sent_scores[sent_scores == 0] = -9e3
return sent_scores


Expand Down Expand Up @@ -183,5 +183,5 @@ def forward(self, x, mask):
# x is shape (batch size, source sequence length, 1)
# mask is shape (batch size, source sequence length)
sent_scores = x.squeeze(-1) * mask.float()
sent_scores[sent_scores == 0] = -9e9
sent_scores[sent_scores == 0] = -9e3
return sent_scores
4 changes: 2 additions & 2 deletions src/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def __init__(self, hparams, embedding_model_config=None, classifier_obj=None):
# trick for numerical stability. Padding values are 0 and if 0 is the input
# to the sigmoid function the output will be 0.5. This will cause issues when
# inputs with more padding will have higher loss values. To solve this, all
# padding values are set to -9e9 as the last step of each encoder. The sigmoid
# function transforms -9e9 to nearly 0, thus preserving the proper loss
# padding values are set to -9e3 as the last step of each encoder. The sigmoid
# function transforms -9e3 to nearly 0, thus preserving the proper loss
# calculation. See `compute_loss()` for more info.
self.loss_func = nn.BCEWithLogitsLoss(reduction="none")

Expand Down

0 comments on commit e1f6022

Please sign in to comment.