Commit f287003

srush committed Dec 22, 2017
1 parent 4f2e179 commit f287003
Showing 9 changed files with 190 additions and 111 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -71,6 +71,7 @@ matrix:
script: flake8
- python: "3.5"
install:
- python setup.py install
- pip install doctr
script:
- pip install -r docs/requirements.txt
22 changes: 18 additions & 4 deletions docs/source/onmt.modules.rst
@@ -11,16 +11,30 @@ Core Modules
:undoc-members:


.. autoclass:: onmt.modules.StackedGRU
.. autoclass:: onmt.modules.GlobalAttention
:members:
:undoc-members:

.. autoclass:: onmt.modules.StackedLSTM

Encoders
---------

.. autoclass:: onmt.modules.EncoderBase
:members:
:undoc-members:


.. autoclass:: onmt.modules.GlobalAttention
.. autoclass:: onmt.modules.MeanEncoder
:members:

.. autoclass:: onmt.modules.RNNEncoder
:members:


Decoders
---------


.. autoclass:: onmt.modules.RNNDecoderBase
:members:
:undoc-members:

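The restructured reference page groups the API into Core Modules, Encoders, and Decoders. To make the encoder contract these pages document concrete, here is a minimal, hypothetical sketch that is not part of this commit: a custom encoder only needs to subclass EncoderBase and return the (final hidden state, memory bank) pair described in its forward docstring further down. The import path and the behavior of `embeddings` are assumptions, noted in the comments.

# Hypothetical sketch (not part of this commit) of the encoder contract the
# "Encoders" section documents.
from onmt.modules import EncoderBase  # path assumed from the autoclass directive above


class LastStateEncoder(EncoderBase):
    """Toy encoder: embeds the source and uses the last position as its state.

    `embeddings` is assumed to behave like onmt.modules.Embeddings, mapping a
    `[len x batch x nfeat]` LongTensor to a `[len x batch x dim]` FloatTensor.
    """

    def __init__(self, embeddings):
        super(LastStateEncoder, self).__init__()
        self.embeddings = embeddings

    def forward(self, input, lengths=None, hidden=None):
        self._check_args(input, lengths, hidden)
        emb = self.embeddings(input)           # len x batch x dim
        memory_bank = emb                      # "outputs": len x batch x rnn_size
        encoder_final = emb[-1].unsqueeze(0)   # layers(=1) x batch x rnn_size
        return encoder_final, memory_bank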
13 changes: 3 additions & 10 deletions docs/source/onmt.rst
@@ -2,20 +2,13 @@ OpenNMT Framework
=================

Model
------

.. autoclass:: onmt.NMTModel
:members:
:undoc-members:
-----

.. autoclass:: onmt.EncoderBase
.. autoclass:: onmt.Models.NMTModel
:members:
:undoc-members:

.. autoclass:: onmt.DecoderBase
.. autoclass:: onmt.Models.DecoderState
:members:
:undoc-members:


Trainer
-------
105 changes: 74 additions & 31 deletions onmt/Loss.py
@@ -15,22 +15,33 @@

class LossComputeBase(nn.Module):
"""
This is the loss criterion base class. Users can implement their own
loss computation strategy by making subclass of this one.
Users need to implement the compute_loss() and make_shard_state() methods.
We inherits from nn.Module to leverage the cuda behavior.
Class for managing efficient loss computation. Handles
sharding next step predictions and accumulating multiple
loss computations.
Users can implement their own loss computation strategy by making
a subclass of this one. Users need to implement the _compute_loss()
and _make_shard_state() methods.
Args:
generator (:obj:`nn.Module`) :
module that maps the output of the decoder to a
distribution over the target vocabulary.
tgt_vocab (:obj:`Vocab`) :
torchtext vocab object representing the target output
"""
def __init__(self, generator, tgt_vocab):
super(LossComputeBase, self).__init__()
self.generator = generator
self.tgt_vocab = tgt_vocab
self.padding_idx = tgt_vocab.stoi[onmt.io.PAD_WORD]

def make_shard_state(self, batch, output, range_, attns=None):
def _make_shard_state(self, batch, output, range_, attns=None):
"""
Make shard state dictionary for shards() to return iterable
shards for efficient loss computation. Subclass must define
this method to match its own compute_loss() interface.
this method to match its own _compute_loss() interface.
Args:
batch: the current batch.
output: the predict output from the model.
@@ -40,10 +51,12 @@ def make_shard_state(self, batch, output, range_, attns=None):
"""
return NotImplementedError

def compute_loss(self, batch, output, target, **kwargs):
def _compute_loss(self, batch, output, target, **kwargs):
"""
Compute the loss. Subclass must define this method.
Args:
batch: the current batch.
output: the predict output from the model.
target: the validate target to compare output with.
@@ -53,37 +66,72 @@ def compute_loss(self, batch, output, target, **kwargs):

def monolithic_compute_loss(self, batch, output, attns):
"""
Compute the loss monolithically, not dividing into shards.
Compute the forward loss for the batch.
Args:
batch (batch): batch of labeled examples
output (:obj:`FloatTensor`):
output of decoder model `[tgt_len x batch x hidden]`
attns (dict of :obj:`FloatTensor`) :
dictionary of attention distributions
`[tgt_len x batch x src_len]`
Returns:
:obj:`onmt.Statistics`: loss statistics
"""
range_ = (0, batch.tgt.size(0))
shard_state = self.make_shard_state(batch, output, range_, attns)
_, batch_stats = self.compute_loss(batch, **shard_state)
shard_state = self._make_shard_state(batch, output, range_, attns)
_, batch_stats = self._compute_loss(batch, **shard_state)

return batch_stats

def sharded_compute_loss(self, batch, output, attns,
cur_trunc, trunc_size, shard_size):
"""
Compute the loss in shards for efficiency.
"""Compute the forward loss and backpropagate. Computation is done
with shards and optionally truncation for memory efficiency.
Also supports truncated BPTT for long sequences by taking a
range in the decoder output sequence to back propagate in.
Range is from `(cur_trunc, cur_trunc + trunc_size)`.
Note harding is an exact efficiency trick to relieve memory
required for the generation buffers. Truncation is an
approximate efficiency trick to relieve the memory required
in the RNN buffers.
Args:
batch (batch) : batch of labeled examples
output (:obj:`FloatTensor`) :
output of decoder model `[tgt_len x batch x hidden]`
attns (dict) : dictionary of attention distributions
`[tgt_len x batch x src_len]`
cur_trunc (int) : starting position of truncation window
trunc_size (int) : length of truncation window
shard_size (int) : maximum number of examples in a shard
Returns:
:obj:`onmt.Statistics`: loss statistics
"""
batch_stats = onmt.Statistics()
range_ = (cur_trunc, cur_trunc + trunc_size)
shard_state = self.make_shard_state(batch, output, range_, attns)
shard_state = self._make_shard_state(batch, output, range_, attns)

for shard in shards(shard_state, shard_size):
loss, stats = self.compute_loss(batch, **shard)
loss, stats = self._compute_loss(batch, **shard)
loss.div(batch.batch_size).backward()
batch_stats.update(stats)

return batch_stats

def stats(self, loss, scores, target):
def _stats(self, loss, scores, target):
"""
Compute and return a Statistics object.
Args:
loss(Tensor): the loss computed by the loss criterion.
scores(Tensor): a sequence of predict output with scores.
loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
scores (:obj:`FloatTensor`): a score for each possible output
target (:obj:`FloatTensor`): true targets
Returns:
:obj:`Statistics` : statistics for this batch.
"""
pred = scores.max(1)[1]
non_padding = target.ne(self.padding_idx)
@@ -92,10 +140,10 @@ def stats(self, loss, scores, target):
.sum()
return onmt.Statistics(loss[0], non_padding.sum(), num_correct)

def bottle(self, v):
def _bottle(self, v):
return v.view(-1, v.size(2))

def unbottle(self, v, batch_size):
def _unbottle(self, v, batch_size):
return v.view(-1, batch_size, v.size(1))


@@ -105,10 +153,7 @@ class NMTLossCompute(LossComputeBase):
"""
def __init__(self, generator, tgt_vocab, label_smoothing=0.0):
super(NMTLossCompute, self).__init__(generator, tgt_vocab)

# CHECK
assert (label_smoothing >= 0.0 and label_smoothing <= 1.0)
# END CHECK

if label_smoothing > 0:
# When label smoothing is turned on,
@@ -128,16 +173,14 @@ def __init__(self, generator, tgt_vocab, label_smoothing=0.0):
self.criterion = nn.NLLLoss(weight, size_average=False)
self.confidence = 1.0 - label_smoothing

def make_shard_state(self, batch, output, range_, attns=None):
""" See base class for args description. """
def _make_shard_state(self, batch, output, range_, attns=None):
return {
"output": output,
"target": batch.tgt[range_[0] + 1: range_[1]],
}

def compute_loss(self, batch, output, target):
""" See base class for args description. """
scores = self.generator(self.bottle(output))
def _compute_loss(self, batch, output, target):
scores = self.generator(self._bottle(output))

gtruth = target.view(-1)
if self.confidence < 1:
@@ -157,7 +200,7 @@ def compute_loss(self, batch, output, target):
else:
loss_data = loss.data.clone()

stats = self.stats(loss_data, scores.data, target.view(-1).data)
stats = self._stats(loss_data, scores.data, target.view(-1).data)

return loss, stats

@@ -174,7 +217,7 @@ def shards(state, shard_size, eval=False):
"""
Args:
state: A dictionary which corresponds to the output of
*LossCompute.make_shard_state(). The values for
*LossCompute._make_shard_state(). The values for
those keys are Tensor-like or None.
shard_size: The maximum size of the shards yielded by the model.
eval: If True, only yield the state, nothing else.
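For context on how the renamed private helpers fit together, the following is a hedged usage sketch, not code from this commit. The constructor and method signatures are taken from the diff above; `model`, `tgt_vocab`, `batch`, `outputs`, `attns`, and `optim` are hypothetical stand-ins for the objects a trainer would normally hold.

# Hypothetical usage sketch (not from this commit) of the LossComputeBase API.
from onmt.Loss import NMTLossCompute

loss_compute = NMTLossCompute(model.generator, tgt_vocab, label_smoothing=0.1)

# Validation path: one monolithic pass over the batch, no backward() inside.
valid_stats = loss_compute.monolithic_compute_loss(batch, outputs, attns)

# Training path: shard the generator computation (exact memory saving) and
# optionally truncate BPTT (approximate).  sharded_compute_loss() divides the
# loss by batch size and calls backward() internally, so the caller only has
# to step the optimizer afterwards.
trunc_size = batch.tgt.size(0)   # one truncation window over the whole target
train_stats = loss_compute.sharded_compute_loss(
    batch, outputs, attns,
    cur_trunc=0, trunc_size=trunc_size, shard_size=32)
optim.step()

The design point the docstrings make is that sharding only splits the generator's work over slices of the decoder output, so it changes memory use but not the gradient, whereas truncation cuts BPTT and therefore approximates it.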
66 changes: 35 additions & 31 deletions onmt/Models.py
@@ -22,13 +22,13 @@ def _check_args(self, input, lengths=None, hidden=None):
def forward(self, input, lengths=None, hidden=None):
"""
Args:
input (LongTensor): len x batch x nfeat.
lengths (LongTensor): batch
hidden: Initial hidden state.
input (:obj:`LongTensor`): len x batch x nfeat.
lengths (:obj:`LongTensor`): batch
hidden (class specific): Initial hidden state.
Returns:
hidden_t (Variable): Pair of layers x batch x rnn_size - final
:obj:`Variable`: Pair of layers x batch x rnn_size - final
encoder state
outputs (FloatTensor): len x batch x rnn_size - Memory bank
:obj:`FloatTensor`: outputs len x batch x rnn_size - Memory bank
"""
raise NotImplementedError

@@ -391,34 +391,40 @@ def _input_size(self):

class NMTModel(nn.Module):
"""
The encoder + decoder Neural Machine Translation Model.
Core trainable object in OpenNMT. Implements a trainable interface
for a simple, generic encoder + decoder model.
Args:
encoder (:obj:`EncoderBase`): an encoder object
decoder (:obj:`RNNDecoderBase`): a decoder object
multigpu (bool): setup for multigpu support
"""
def __init__(self, encoder, decoder, multigpu=False):
"""
Args:
encoder(*Encoder): the various encoder.
decoder(*Decoder): the various decoder.
multigpu(bool): run parellel on multi-GPU?
"""
self.multigpu = multigpu
super(NMTModel, self).__init__()
self.encoder = encoder
self.decoder = decoder

def forward(self, src, tgt, lengths, dec_state=None):
"""
"""Forward propagate a `src` and `tgt` pair for training.
Possibly initialized with a beginning decoder state.
Args:
src(FloatTensor): a sequence of source tensors with
optional feature tensors of size (len x batch).
tgt(FloatTensor): a sequence of target tensors with
optional feature tensors of size (len x batch).
lengths([int]): an array of the src length.
dec_state: A decoder state object
src (:obj:`Tensor`):
a source sequence passed to encoder.
Typically for inputs this will be a padded :obj:`LongTensor`
of size `[len x batch x features]`; however, it may be an
image or other generic input depending on the encoder.
tgt (:obj:`LongTensor`):
a target sequence of size `[tgt_len x batch]`.
lengths(:obj:`LongTensor`): the src lengths, pre-padding `[batch]`.
dec_state (:obj:`DecoderState`, optional): the initial decoder state
Returns:
outputs (FloatTensor): (len x batch x hidden_size): decoder outputs
attns (FloatTensor): Dictionary of (src_len x batch)
dec_hidden (FloatTensor): tuple (1 x batch x hidden_size)
Init hidden state
(:obj:`FloatTensor`, `dict` of :obj:`FloatTensor`, :obj:`DecoderState`) :
* decoder output `[tgt_len x batch x hidden]`
* dictionary attention dists of `[tgt_len x batch x src_len]`
* final decoder state
"""
tgt = tgt[:-1] # exclude last target from inputs
enc_hidden, context = self.encoder(src, lengths)
@@ -434,21 +440,19 @@ def forward(self, src, tgt, lengths, dec_state=None):


class DecoderState(object):
"""
DecoderState is a base class for models, used during translation
for storing translation states.
"""Interface for grouping together the current state of a recurrent
decoder. In the simplest case just represents the hidden state of
the model. But can also be used for implementing various forms of
input_feeding and non-recurrent models.
Modules need to implement this to utilize beam search decoding.
"""
def detach(self):
"""
Detaches all Variables from the graph
that created it, making it a leaf.
"""
for h in self._all:
if h is not None:
h.detach_()

def beam_update(self, idx, positions, beam_size):
""" Update when beam advances. """
for e in self._all:
a, br, d = e.size()
sent_states = e.view(a, beam_size, br // beam_size, d)[:, :, idx]
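The rewritten `NMTModel.forward` docstring pins down the input shapes and the three return values. Below is a hedged sketch, not part of the commit, of a single forward call using that contract; `encoder`, `decoder`, and the input tensors are hypothetical stand-ins, since their construction lies outside this diff.

# Hypothetical sketch (not part of this commit) of the NMTModel.forward
# contract documented above.
from onmt.Models import NMTModel

# `encoder`, `decoder`, `src`, `tgt`, and `lengths` are stand-ins; their
# construction lies outside this diff.
#   src:     [src_len x batch x features] LongTensor (or an image, per encoder)
#   tgt:     [tgt_len x batch]            LongTensor
#   lengths: [batch]                      LongTensor of pre-padding source lengths
model = NMTModel(encoder, decoder)

outputs, attns, dec_state = model(src, tgt, lengths)

# outputs:   [(tgt_len - 1) x batch x hidden] decoder states -- the last target
#            token is excluded from the decoder inputs, hence tgt_len - 1 steps.
# attns:     dict of attention distributions, each [(tgt_len - 1) x batch x src_len].
# dec_state: a DecoderState; its detach() cuts the graph between truncated-BPTT
#            windows, as the DecoderState docstring above describes.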
