Skip to content

Commit

Permalink
Merge pull request #7 from AdeelH/AdeelH-patch-1
Browse files Browse the repository at this point in the history
Ensure forward() always returns a tensor
  • Loading branch information
AdeelH committed Jun 18, 2022
2 parents f15f440 + 00c99ad commit 6359c1e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
unignored_mask = y != self.ignore_index
y = y[unignored_mask]
if len(y) == 0:
return 0.
return torch.tensor(0.)
x = x[unignored_mask]

# compute weighted cross entropy term: -alpha * log(pt)
Expand Down

0 comments on commit 6359c1e

Please sign in to comment.