Skip to content

Commit

Permalink
Use MSELoss in (M)BartForSequenceClassification (huggingface#11178)
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt authored and Iwontbecreative committed Jul 15, 2021
1 parent 73980c8 commit 692b959
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
11 changes: 8 additions & 3 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
Expand Down Expand Up @@ -1437,8 +1437,13 @@ def forward(

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if self.config.num_labels == 1:
# regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
11 changes: 8 additions & 3 deletions src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
Expand Down Expand Up @@ -1442,8 +1442,13 @@ def forward(

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if self.config.num_labels == 1:
# regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down

0 comments on commit 692b959

Please sign in to comment.