Fix label_weights in bert models (facebookresearch#1100)
Summary:
Pull Request resolved: facebookresearch#1100

Following D18176357, I realized we also aren't passing label weights in bert models.

Reviewed By: anchit, nihit

Differential Revision: D18283670

fbshipit-source-id: 67a2bcda75f03c556648893ad4db329b244527a8
rutyrinott authored and facebook-github-bot committed Nov 4, 2019
1 parent 3bd7b34 commit 9d64d10
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pytext/models/bert_classification_models.py
@@ -31,6 +31,7 @@
 from pytext.models.representations.transformer_sentence_encoder_base import (
     TransformerSentenceEncoderBase,
 )
+from pytext.utils.label import get_label_weights


 class NewBertModel(BaseModel):
@@ -89,7 +90,13 @@ def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
             out_dim=len(labels),
         )

-        loss = create_loss(config.output_layer.loss)
+        label_weights = (
+            get_label_weights(labels.idx, config.output_layer.label_weights)
+            if config.output_layer.label_weights
+            else None
+        )
+
+        loss = create_loss(config.output_layer.loss, weight=label_weights)

         if isinstance(loss, BinaryCrossEntropyLoss):
             output_layer_cls = BinaryClassificationOutputLayer
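For context, the fix hinges on turning the config's per-label weight mapping into a weight vector ordered by label index, which the loss can then consume. Below is a minimal, hypothetical sketch of that mapping (the real `get_label_weights` lives in `pytext.utils.label` and its exact signature and defaults are assumptions here): labels missing from the config default to a weight of 1.0.

```python
def label_weights_sketch(label_to_idx, weights):
    """Hypothetical sketch: map a {label: weight} config onto a list
    ordered by label index, defaulting unmentioned labels to 1.0.
    The resulting vector is what a weighted loss (e.g. cross-entropy
    with a `weight` argument) would consume."""
    out = [1.0] * len(label_to_idx)
    for label, weight in weights.items():
        out[label_to_idx[label]] = weight
    return out

# Example: down-weight an over-represented "neutral" class.
vocab = {"negative": 0, "neutral": 1, "positive": 2}
print(label_weights_sketch(vocab, {"neutral": 0.2}))  # [1.0, 0.2, 1.0]
```

The fix then threads this vector into `create_loss(..., weight=label_weights)`, so class imbalance handling configured via `label_weights` actually takes effect for the BERT models, matching the behavior added to other models in D18176357.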