modified connectors
Former-commit-id: 2ee4a9d
jxhe committed Oct 20, 2017
1 parent 4a133a7 commit e02f1a3
Showing 2 changed files with 148 additions and 60 deletions.
166 changes: 112 additions & 54 deletions txtgen/modules/connectors/connectors.py
@@ -16,6 +16,24 @@
from txtgen.core.utils import get_instance

# pylint: disable=too-many-locals, arguments-differ, too-many-arguments
+def _assert_same_size(outputs, output_size):
+    """Checks if outputs match output_size.
+
+    Args:
+        outputs: A Tensor or a (nested) tuple of Tensors.
+        output_size: Can be an Integer, a TensorShape, or a (nested) tuple
+            of Integers or TensorShapes.
+    """
+    nest.assert_same_structure(outputs, output_size)
+    flat_output_size = nest.flatten(output_size)
+    flat_output = nest.flatten(outputs)
+
+    for (output, size) in zip(flat_output, flat_output_size):
+        if output[0].shape != tf.TensorShape(size):
+            raise ValueError(
+                "The output size does not match the required output_size")
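A minimal sketch of what this new check enforces, using plain TF 1.x tensors (shapes are made up for illustration): each flattened output carries a batch dimension, so `output[0]` is a single example whose shape must equal the corresponding `output_size` entry.

```python
import tensorflow as tf

# Hypothetical batched outputs (batch_size=32) and their expected sizes.
outputs = (tf.zeros([32, 64]), tf.zeros([32, 128]))
output_size = (64, 128)

for out, size in zip(outputs, output_size):
    # out[0] indexes one example, dropping the batch dimension.
    assert out[0].shape == tf.TensorShape(size)
```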

def _mlp_transform(inputs, output_size, activation_fn=tf.identity):
"""Transforms inputs through a fully-connected layer that creates the output
with specified size.
@@ -74,13 +92,13 @@ class ConstantConnector(ConnectorBase):
"""Creates decoder initial state that has a constant value.
Args:
-        decoder_state_size: Size of state of the decoder cell. Can be an
+        output_size: Size of state of the decoder cell. Can be an
             Integer, a TensorShape, or a tuple of Integers or TensorShapes.
This can typically be obtained by :attr:`decoder.state_size`.
hparams (dict): Hyperparameters of the connector.
"""
-    def __init__(self, decoder_state_size, hparams=None):
-        ConnectorBase.__init__(self, decoder_state_size, hparams)
+    def __init__(self, output_size, hparams=None):
+        ConnectorBase.__init__(self, output_size, hparams)

@staticmethod
def default_hparams():
Expand Down Expand Up @@ -118,7 +136,7 @@ def _build(self, batch_size, value=None): # pylint: disable=W0221
value_ = self.hparams.value
output = nest.map_structure(
lambda x: tf.constant(value_, shape=[batch_size, x]),
-            self._decoder_state_size)
+            self._output_size)
return output
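A hypothetical usage sketch (the direct `_build` call and the state sizes are illustrative assumptions, not documented API):

```python
from txtgen.modules.connectors.connectors import ConstantConnector

# Constant initial state for a decoder cell whose state_size is (256, 256).
connector = ConstantConnector(output_size=(256, 256))
state = connector._build(batch_size=32)  # each part: [32, 256], filled with hparams.value
```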


@@ -132,13 +150,13 @@ class ForwardConnector(ConnectorBase):
:meth:`~tensorflow.python.util.nest.pack_sequence_as` for more details).
Args:
-        decoder_state_size: Size of state of the decoder cell. Can be an
+        output_size: Size of state of the decoder cell. Can be an
             Integer, a TensorShape, or a tuple of Integers or TensorShapes.
This can typically be obtained by :attr:`decoder.cell.state_size`.
"""

-    def __init__(self, decoder_state_size):
-        ConnectorBase.__init__(self, decoder_state_size, None)
+    def __init__(self, output_size):
+        ConnectorBase.__init__(self, output_size, None)

@staticmethod
def default_hparams():
@@ -169,11 +187,11 @@ def _build(self, inputs):  # pylint: disable=W0221
"""
output = inputs
try:
-            nest.assert_same_structure(inputs, self._decoder_state_size)
+            nest.assert_same_structure(inputs, self._output_size)
except (ValueError, TypeError):
flat_input = nest.flatten(inputs)
output = nest.pack_sequence_as(
-                self._decoder_state_size, flat_input)
+                self._output_size, flat_input)

self._built = True
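The repacking above relies on `nest.pack_sequence_as`; a minimal sketch of that behavior with assumed shapes:

```python
import tensorflow as tf
from tensorflow.python.util import nest

output_size = (256, 256)  # e.g. decoder_cell.state_size
inputs = [tf.zeros([32, 256]), tf.zeros([32, 256])]  # flat inputs, hypothetical
# Packs the flat tensors into the (nested) structure of output_size.
packed = nest.pack_sequence_as(output_size, nest.flatten(inputs))
```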

@@ -185,14 +203,14 @@ class MLPTransformConnector(ConnectorBase):
initial state.
Args:
-        decoder_state_size: Size of state of the decoder cell. Can be an
+        output_size: Size of state of the decoder cell. Can be an
             Integer, a TensorShape, or a tuple of Integers or TensorShapes.
This can typically be obtained by :attr:`decoder.cell.state_size`.
hparams (dict): Hyperparameters of the connector.
"""

-    def __init__(self, decoder_state_size, hparams=None):
-        ConnectorBase.__init__(self, decoder_state_size, hparams)
+    def __init__(self, output_size, hparams=None):
+        ConnectorBase.__init__(self, output_size, hparams)

@staticmethod
def default_hparams():
Expand Down Expand Up @@ -233,7 +251,7 @@ def _build(self, inputs): #pylint: disable=W0221
fn_modules = ['txtgen.custom', 'tensorflow', 'tensorflow.nn']
activation_fn = get_function(self.hparams.activation_fn, fn_modules)

-        output = _mlp_transform(inputs, self._decoder_state_size, activation_fn)
+        output = _mlp_transform(inputs, self._output_size, activation_fn)

self._add_internal_trainable_variables()
self._built = True
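A hypothetical sketch of wiring an encoder state of one size into a decoder state of another (the direct `_build` call and all shapes are assumptions):

```python
import tensorflow as tf
from txtgen.modules.connectors.connectors import MLPTransformConnector

encoder_state = tf.zeros([32, 100])  # hypothetical encoder final state
connector = MLPTransformConnector(output_size=256)
decoder_init = connector._build(encoder_state)  # [32, 256] via one dense layer
```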
@@ -270,20 +288,19 @@ def default_hparams():
"""
return {
"distribution": {
"type": "tf.contrib.distributions.MultivariateNormalDiag",
"type": "MultivariateNormalDiag",
"kwargs": {}
},
"activation_fn": "tensorflow.identity",
"name": "reparameterized_stochastic_connector"
}

-    # TODO(zhiting): Is the docstring of returned value correct?
def _build(self,
distribution=None,
distribution_type=None,
distribution_kwargs=None,
transform=True,
-               num_samples=1):
+               num_samples=None):
"""Samples from a distribution and optionally performs transformation.
The distribution must be reparameterizable, i.e.,
@@ -304,21 +321,22 @@ class which inherits
transform (bool): Whether to perform MLP transformation of the
samples. If `False`, the shape of a sample must match the
:attr:`output_size`.
-        num_samples (int or scalar int Tensor): Number of samples to
-            generate.
+        num_samples (int or scalar int Tensor, optional): Number of samples
+            to generate. Must be left as `None` during training.
Returns:
-        If `num_samples`=1, returns a Tensor of shape `output_size`. If
-        `num_samples`>1, returns a Tensor of shape
-        `[num_samples x output_size]`.
+        If `num_samples` is `None`, returns a Tensor of shape `[batch_size x
+        output_size]`; otherwise returns a Tensor of shape `[num_samples x
+        output_size]`. `num_samples` must be specified when not in the
+        training stage.
Raises:
ValueError: If the distribution cannot be reparameterized.
ValueError: The output does not match the :attr:`output_size`.
"""
if distribution:
dstr = distribution
-        elif distribution_type:
+        elif distribution_type and distribution_kwargs:
dstr = get_instance(
distribution_type, distribution_kwargs,
["txtgen.custom", "tensorflow.contrib.distributions"])
@@ -334,6 +352,8 @@

if num_samples:
output = dstr.sample(num_samples)
+        else:
+            output = dstr.sample()

if dstr.event_shape == []:
output = tf.reshape(output,
@@ -344,10 +364,7 @@
fn_modules = ['txtgen.custom', 'tensorflow', 'tensorflow.nn']
activation_fn = get_function(self.hparams.activation_fn, fn_modules)
output = _mlp_transform(output, self._output_size, activation_fn)
-        # TODO (zhiting): does self._output_size include batch_size as a
-        # dimension? If not, we also need to remove the num_samples dimension
-        # of `output` before this assert statement.
-        nest.assert_same_structure(output, self._output_size)
+        _assert_same_size(output, self._output_size)

self._add_internal_trainable_variables()
self._built = True
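A hypothetical VAE-style sketch, assuming this class takes the same `(output_size, hparams)` constructor as the other connectors (its `__init__` is not shown in this diff); the kwargs follow `tf.contrib.distributions.MultivariateNormalDiag`:

```python
import tensorflow as tf
from txtgen.modules.connectors.connectors import \
    ReparameterizedStochasticConnector

z_mean = tf.zeros([32, 16])   # hypothetical posterior parameters
z_scale = tf.ones([32, 16])
connector = ReparameterizedStochasticConnector(output_size=256)
# Training: num_samples=None draws one sample per batch entry -> [32, 256];
# gradients flow through the samples (reparameterization trick).
state = connector._build(
    distribution_type="MultivariateNormalDiag",
    distribution_kwargs={"loc": z_mean, "scale_diag": z_scale})
```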
@@ -362,8 +379,8 @@ class StochasticConnector(ConnectorBase):
models.
"""

-    def __init__(self, decoder_state_size, hparams=None):
-        ConnectorBase.__init__(self, decoder_state_size, hparams)
+    def __init__(self, output_size, hparams=None):
+        ConnectorBase.__init__(self, output_size, hparams)

#TODO(zhiting): add docs
@staticmethod
@@ -377,39 +394,71 @@ def default_hparams():
```
"""
return {
"distribution": "tf.contrib.distributions.Categorical",
"distribution": {
"type": "tf.contrib.distributions.Categorical",
"kwargs": {}
},
"activation_fn": "tensorflow.identity",
"name": "stochastic_connector"
}

-    # pylint: disable=arguments-differ
-    def _build(self, distribution=None, batch_size=None, trans=False,
-               ds_name=None, **kwargs):
-        """Samples from a distribution defined by the inputs.
+    def _build(self,
+               distribution=None,
+               distribution_type=None,
+               distribution_kwargs=None,
+               transform=False,
+               num_samples=None):
+        """Samples from a distribution and optionally performs transformation.
+
+        Gradients do not propagate through the random samples.
Args:
-            batch_size (int or scalar int Tensor): The batch size.
-            distribution: Instance of tf.contrib.distributions
+            distribution (optional): An instance of
+                :class:`~tensorflow.contrib.distributions.Distribution`. If
+                `None` (default), the distribution is constructed from
+                :attr:`distribution_type` or
+                :attr:`hparams['distribution']['type']`.
+            distribution_type (str, optional): Name or path of the
+                distribution class, which must inherit from
+                :class:`~tensorflow.contrib.distributions.Distribution`.
+                Ignored if :attr:`distribution` is specified.
+            distribution_kwargs (dict, optional): Keyword arguments of the
+                distribution class specified in :attr:`distribution_type`.
+            transform (bool): Whether to perform MLP transformation of the
+                samples. If `False`, the shape of a sample must match
+                :attr:`output_size`.
+            num_samples (int or scalar int Tensor, optional): Number of
+                samples to generate. Must be left as `None` during training.
Returns:
-            A Tensor or a (nested) tuple of Tensors of the same structure of
-            the decoder state.
+            If `num_samples` is `None`, returns a Tensor of shape
+            `[batch_size x output_size]`; otherwise returns a Tensor of shape
+            `[num_samples x output_size]`. `num_samples` must be specified
+            when not in the training stage.
+        Raises:
+            ValueError: The output does not match :attr:`output_size`.
"""
if distribution:
-            dist_instance = distribution
-        elif ds_name and kwargs:
-            dist_instance = get_instance(ds_name, kwargs,
-                                         ["tensorflow.contrib.distributions"])
+            dstr = distribution
+        elif distribution_type and distribution_kwargs:
+            dstr = get_instance(
+                distribution_type, distribution_kwargs,
+                ["txtgen.custom", "tensorflow.contrib.distributions"])
else:
-            raise ValueError(
-                "Either distribution or (ds_name, kwargs) must be provided")
+            dstr = get_instance(
+                self.hparams.distribution.type,
+                self.hparams.distribution.kwargs,
+                ["txtgen.custom", "tensorflow.contrib.distributions"])

-        if batch_size:
-            output = dist_instance.sample(batch_size)
+        if num_samples:
+            output = dstr.sample(num_samples)
         else:
-            output = dist_instance.sample()
+            output = dstr.sample()

-        if dist_instance.event_shape == []:
+        if dstr.event_shape == []:
output = tf.reshape(output,
output.shape.concatenate(tf.TensorShape(1)))

@@ -418,8 +467,11 @@ def _build(self, distribution=None, batch_size=None, trans=False,

output = tf.cast(output, tf.float32)

-        if trans:
-            output = _mlp_transform(output, self._decoder_state_size)
+        if transform:
+            fn_modules = ['txtgen.custom', 'tensorflow', 'tensorflow.nn']
+            activation_fn = get_function(self.hparams.activation_fn, fn_modules)
+            output = _mlp_transform(output, self._output_size, activation_fn)
+        _assert_same_size(output, self._output_size)

self._add_internal_trainable_variables()
self._built = True
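A sketch of this non-reparameterized path with an assumed `Categorical` posterior (shapes and the direct `_build` call are illustrative):

```python
import tensorflow as tf
from txtgen.modules.connectors.connectors import StochasticConnector

logits = tf.zeros([32, 10])  # hypothetical unnormalized class scores
connector = StochasticConnector(output_size=256)
# Categorical has empty event shape: samples come out as [32], are reshaped
# to [32, 1], cast to float32, then MLP-transformed to [32, 256]. No
# gradient propagates through the samples.
state = connector._build(
    distribution_type="Categorical",
    distribution_kwargs={"logits": logits},
    transform=True)
```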
@@ -433,8 +485,8 @@ class ConcatConnector(ConnectorBase):
models.
"""

-    def __init__(self, decoder_state_size, hparams=None):
-        ConnectorBase.__init__(self, decoder_state_size, hparams)
+    def __init__(self, output_size, hparams=None):
+        ConnectorBase.__init__(self, output_size, hparams)

@staticmethod
def default_hparams():
@@ -447,10 +499,11 @@ def default_hparams():
```
"""
return {
"name": "concatconnector"
"activation_fn": "tensorflow.identity",
"name": "concat_connector"
}

-    def _build(self, connector_inputs):  # pylint: disable=W0221
+    def _build(self, connector_inputs, transform=True):  # pylint: disable=W0221
"""Concatenate multiple input connectors
Args:
Expand All @@ -464,7 +517,12 @@ def _build(self, connector_inputs): # pylint: disable=W0221
connector_inputs = [tf.cast(connector, tf.float32)
for connector in connector_inputs]
output = tf.concat(connector_inputs, axis=1)
-        output = _mlp_transform(output, self._decoder_state_size)
+
+        if transform:
+            fn_modules = ['txtgen.custom', 'tensorflow', 'tensorflow.nn']
+            activation_fn = get_function(self.hparams.activation_fn, fn_modules)
+            output = _mlp_transform(output, self._output_size, activation_fn)
+        _assert_same_size(output, self._output_size)

self._add_internal_trainable_variables()
self._built = True
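A sketch of fusing two feature tensors into one decoder state (all names and shapes assumed for illustration):

```python
import tensorflow as tf
from txtgen.modules.connectors.connectors import ConcatConnector

parts = [tf.zeros([32, 64]), tf.zeros([32, 32])]  # hypothetical inputs
connector = ConcatConnector(output_size=128)
state = connector._build(parts)  # concat -> [32, 96], then MLP -> [32, 128]
```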