Skip to content

Commit

Permalink
Adds debug message and exception to invalid loss
Browse files Browse the repository at this point in the history
  • Loading branch information
JSybrandt committed May 21, 2020
1 parent 0f41048 commit 4f700fe
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions agatha/ml/hypothesis_predictor/hypothesis_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,17 +317,27 @@ def _step(
.collate_predicate_embeddings(neg)
.to(self.get_device())
)
partial_losses.append(
self.loss_fn(
positive_predictions,
negative_predictions,
positive_predictions.new_ones(len(positive_predictions))
)
part_loss = self.loss_fn(
positive_predictions,
negative_predictions,
positive_predictions.new_ones(len(positive_predictions))
)
# If something has gone terrible
if torch.isnan(part_loss) or torch.isinf(part_loss):
print("ERROR: Loss is:\n", part_loss)
print("Positive Predicates:\n", positive_predicates)
print("Positive Scores:\n", positive_predictions)
print("Negative Scores:\n", negative_predictions)
print("Positive Sample:\n", pos)
print("Negative Sample:\n", neg)
raise Exception("Invalid loss")
else:
partial_losses.append(part_loss)
# End of batch
loss=torch.mean(torch.stack(partial_losses))
return (
loss,
dict(
dict( # pbar metrics
)
)

Expand Down

0 comments on commit 4f700fe

Please sign in to comment.