Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
haoransh committed Mar 30, 2019
1 parent 0513239 commit 2e2201f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
3 changes: 2 additions & 1 deletion texar/modules/decoders/rnn_decoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from texar.module_base import ModuleBase
from texar.modules.decoders import rnn_decoder_helpers
from texar.utils.dtypes import is_callable
from texar.utils.shapes import shape_list

__all__ = [
"RNNDecoderBase"
Expand Down Expand Up @@ -74,11 +75,11 @@ def __init__(self,

# Make the output layer
self._vocab_size = vocab_size
self._output_layer = output_layer

if is_callable(output_layer):
self._output_layer = output_layer
elif tf.contrib.framework.is_tensor(output_layer):
self._vocab_size = shape_list(output_layer)[1]
self._output_layer = self._make_output_layer_from_tensor(
output_layer, self._vocab_size)
elif output_layer is None:
Expand Down
11 changes: 5 additions & 6 deletions texar/modules/decoders/rnn_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,12 @@ class BasicRNNDecoder(RNNDecoderBase):
Ignored if :attr:`cell` is given.
vocab_size (int, optional): Vocabulary size. Required if
:attr:`output_layer` is `None`.
output_layer (optional): An instance of
:tf_main:`tf.layers.Layer <layers/Layer>`, or
:tf_main:`tf.identity <identity>`. Apply to the RNN cell
output to get logits. If `None`, a dense layer
is used with output dimension set to :attr:`vocab_size`.
output_layer (optional): A callable layer to transform
cell outputs to logits, or a tensor which is used as the kernel
weights to transform hidden states into logits. If `None`, use
`vocab_size` and `hparams.output_layer_bias` to create the output
layer. Set `output_layer=tf.identity` if you do not want to have an
output layer after the cell outputs.
hparams (dict, optional): Hyperparameters. Missing
hyperparameters will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
Expand Down
31 changes: 18 additions & 13 deletions texar/modules/decoders/transformer_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,23 @@ class TransformerDecoder(ModuleBase, TFDecoder):
:class:`~texar.modules.FeedForwardNetwork`, and residual connections.
Args:
vocab_size:
Specify the size of the output vocabulary.
Ignored if output_layer is provided.
output_layer:
a callable function to transform the hidden states to logits.
Or a tensor which is used as the kernel weights to transform hidden
states into logits.
If it's not provided, use `vocab_size` and
`hparams.output_layer_bias` to create the output layer.
vocab_size (int, optional): Vocabulary size. Required if
:attr:`output_layer` is `None`.
output_layer (optional): A callable layer to transform
cell outputs to logits, or a tensor which is used as the kernel
weights to transform hidden states into logits. If `None`, use
`vocab_size` and `hparams.output_layer_bias` to create the output
layer. Set `output_layer=tf.identity` if you do not want to have an
output layer after the cell outputs.
.. document private functions
.. automethod:: _build
"""

def __init__(self, vocab_size, output_layer, hparams=None):
def __init__(self,
vocab_size=None,
output_layer=None,
hparams=None):
ModuleBase.__init__(self, hparams)

self._vocab_size = vocab_size
Expand All @@ -93,11 +95,13 @@ def __init__(self, vocab_size, output_layer, hparams=None):
tf.get_variable_scope().set_initializer(
layers.get_initializer(self._hparams.initializer))

# Make the output layer
if is_callable(output_layer):
self._output_layer = output_layer
elif tf.contrib.framework.is_tensor(output_layer):
self._vocab_size = shape_list(output_layer)[1]
self._output_layer = self._make_output_layer_from_tensor(
output_layer, self._vocab_size)
output_layer)
elif output_layer is None:
if self._vocab_size is None:
raise ValueError(
Expand Down Expand Up @@ -157,11 +161,12 @@ def __init__(self, vocab_size, output_layer, hparams=None):
self._helper = None
self._cache = None

def _make_output_layer_from_tensor(self, output_layer_tensor, vocab_size):
def _make_output_layer_from_tensor(self, output_layer_tensor):
"""Creates an output layer from a Tensor. Used to tie word embedding
with the output layer weight.
"""
affine_bias = None
vocab_size = self._vocab_size
if self._hparams.output_layer_bias:
with tf.variable_scope(self.variable_scope):
affine_bias = tf.get_variable('affine_bias', [vocab_size])
Expand All @@ -173,7 +178,7 @@ def _outputs_to_logits(outputs):
logits = tf.matmul(outputs, output_layer_tensor)
if affine_bias is not None:
logits += affine_bias
logits = tf.reshape(logits, shape[:-1] + [self._vocab_size])
logits = tf.reshape(logits, shape[:-1] + [vocab_size])
return logits

return _outputs_to_logits
Expand Down

0 comments on commit 2e2201f

Please sign in to comment.