Skip to content

Commit

Permalink
Warn when the ever-present PyTorch assertion error occurs on backward()
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnVinyard committed Mar 3, 2018
1 parent ffd7c25 commit ade66d8
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions zounds/learn/supervised.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from trainer import Trainer
import numpy as np
import warnings


class SupervisedTrainer(Trainer):
Expand Down Expand Up @@ -96,8 +97,12 @@ def batch(d, l, test=False):
minibatch_data = data[minibatch_slice]
minibatch_labels = labels[minibatch_slice]

inp, output, e = batch(
minibatch_data, minibatch_labels, test=False)
try:
inp, output, e = batch(
minibatch_data, minibatch_labels, test=False)
except RuntimeError as e:
warnings.warn(e.message)
continue

# test batch
if test_size:
Expand Down

0 comments on commit ade66d8

Please sign in to comment.