Skip to content

Commit

Permalink
LossCompute: remove epoch arg
Browse files Browse the repository at this point in the history
epoch arg in LossCompute constructor isn't used, remove it.
  • Loading branch information
JianyuZhan committed Sep 19, 2017
1 parent c49d948 commit b7e653d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions onmt/Loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,11 @@ def shards(state, shard_size, eval=False):


class LossCompute(object):
def __init__(self, generator, crit, tgt_vocab, dataset, epoch, copy_attn):
def __init__(self, generator, crit, tgt_vocab, dataset, copy_attn):
self.generator = generator
self.crit = crit
self.tgt_vocab = tgt_vocab
self.dataset = dataset
self.epoch = epoch
self.copy_attn = copy_attn

def make_loss_batch(self, outputs, batch, attns, range_):
Expand Down
4 changes: 2 additions & 2 deletions onmt/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def train(self, epoch, report_func=None):
""" Called for each epoch to train. """
closs = onmt.Loss.LossCompute(self.model.generator, self.criterion,
self.fields["tgt"].vocab,
self.train_data, epoch,
self.train_data,
self.copy_attn)

total_stats = onmt.Statistics()
Expand Down Expand Up @@ -109,7 +109,7 @@ def validate(self):

loss = onmt.Loss.LossCompute(self.model.generator, self.criterion,
self.fields["tgt"].vocab,
self.valid_data, 0,
self.valid_data,
self.copy_attn)
stats = onmt.Statistics()

Expand Down

0 comments on commit b7e653d

Please sign in to comment.