Skip to content

Commit

Permalink
added RNNEncoderBase
Browse files Browse the repository at this point in the history
Former-commit-id: 12dea84
  • Loading branch information
ZhitingHu committed Sep 23, 2017
1 parent 97e1181 commit e988614
Showing 1 changed file with 56 additions and 24 deletions.
80 changes: 56 additions & 24 deletions txtgen/modules/encoders/rnn_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from txtgen.core import layers


class ForwardRNNEncoder(EncoderBase):
"""One directional forward RNN encoder.
class RNNEncoderBase(EncoderBase):
"""Base class for all RNN encoder classes.
Args:
cell: (RNNCell, optional) If it is not specified,
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(self, # pylint: disable=too-many-arguments
embedding_trainable=True,
vocab_size=None,
hparams=None,
name="forward_rnn_encoder"):
name="rnn_encoder"):
EncoderBase.__init__(self, hparams, name)

# Make rnn cell
Expand Down Expand Up @@ -106,6 +106,59 @@ def default_hparams():
"embedding": layers.default_embedding_hparams()
}

def _build(self, inputs, *args, **kwargs):
"""Encodes the inputs.
Args:
inputs: Inputs to the encoder.
*args: Other arguments.
**kwargs: Keyword arguments.
Returns:
Encoding results.
"""
raise NotImplementedError

@property
def embedding(self):
"""The embedding variable.
"""
return self._embedding

@property
def cell(self):
"""The RNN cell.
"""
return self._cell

@property
def state_size(self):
"""The state size of encoder cell.
Same as :attr:`encoder.cell.state_size`.
"""
return self.cell.state_size


class ForwardRNNEncoder(RNNEncoderBase):
"""One directional forward RNN encoder.
See :class:`~txtgen.modules.encoders.rnn_encoders.RNNEncoderBase` for the
arguments, and
:class:`~txtgen.modules.encoders.rnn_encoders.RNNEncoderBase.`
`default_hparams` for the default hyperparameters.
"""

def __init__(self, # pylint: disable=too-many-arguments
cell=None,
embedding=None,
embedding_trainable=True,
vocab_size=None,
hparams=None,
name="forward_rnn_encoder"):
RNNEncoderBase.__init__(
self, cell, embedding, embedding_trainable,
vocab_size, hparams, name)

def _build(self, inputs, **kwargs):
"""Encodes the inputs.
Expand Down Expand Up @@ -140,24 +193,3 @@ def _build(self, inputs, **kwargs):
self._built = True

return results


@property
def embedding(self):
"""The embedding variable.
"""
return self._embedding

@property
def cell(self):
"""The RNN cell.
"""
return self._cell

@property
def state_size(self):
"""The state size of encoder cell.
Same as :attr:`encoder.cell.state_size`.
"""
return self.cell.state_size

0 comments on commit e988614

Please sign in to comment.