Skip to content

Commit

Permalink
updated module interfaces, put in rather than as an argument to __init__
Browse files Browse the repository at this point in the history
Former-commit-id: e08ffbc
  • Loading branch information
ZhitingHu committed Sep 29, 2017
1 parent 005bf91 commit fe3c56d
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 102 deletions.
40 changes: 24 additions & 16 deletions txtgen/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from __future__ import print_function
from __future__ import unicode_literals

import copy

import tensorflow as tf
import tensorflow.contrib.slim as tf_slim

Expand All @@ -21,13 +19,12 @@
from txtgen.data.embedding import Embedding


def default_text_database_hparams():
def default_text_dataset_hparams():
"""Returns a dictionary of hyperparameters of a text dataset with default
values.
"""
# TODO(zhiting): add more docs
return {
"name": "",
"files": [],
"vocab.file": "",
"vocab.share_with": "",
Expand All @@ -46,19 +43,25 @@ def default_text_database_hparams():
}
}

def default_paired_text_dataset_hparams(): # pylint: disable=invalid-name
"""Returns
"""
return {
}

class DataBaseBase(object):
"""Base class of all data classes.
"""

def __init__(self, hparams, name="database"):
self.name = name
def __init__(self, hparams):
self._hparams = HParams(hparams, self.default_hparams())

@staticmethod
def default_hparams():
"""Returns a dicitionary of default hyperparameters.
"""
return {
"name": "database",
"num_epochs": 1,
"batch_size": 64,
"allow_smaller_final_batch": False,
Expand Down Expand Up @@ -101,6 +104,12 @@ def hparams(self):
"""
return self._hparams

@property
def name(self):
"""The name of the data base.
"""
return self.hparams.name


class MonoTextDataBase(DataBaseBase):
"""Text data base that reads single set of text files.
Expand All @@ -110,27 +119,27 @@ class MonoTextDataBase(DataBaseBase):
:class:`~txtgen.data.database.PairedTextDataBase`.
Args:
hparams (dict): Hyperparameters. See
:meth:`~txgen.data.database.default_text_database_hparams` for
the defaults.
hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
defaults.
name (str): Name of the database.
"""

def __init__(self, hparams, name="mono_text_database"):
DataBaseBase.__init__(self, hparams, name)
def __init__(self, hparams):
DataBaseBase.__init__(self, hparams)

# pylint: disable=not-context-manager
with tf.name_scope(name, "mono_text_database"):
with tf.name_scope(self.name, self.default_hparams["name"]):
self._dataset = self.make_dataset(self._hparams.dataset)
self._data_provider = self._make_data_provider(self._dataset)

@staticmethod
def default_hparams():
"""Returns a dicitionary of default hyperparameters.
"""
hparams = copy.deepcopy(DataBaseBase.default_hparams())
hparams = DataBaseBase.default_hparams()
hparams["name"] = "mono_text_database"
hparams.update({
"dataset": default_text_database_hparams()
"dataset": default_text_dataset_hparams()
})
return hparams

Expand Down Expand Up @@ -172,8 +181,7 @@ def make_dataset(dataset_hparams):
num_samples=None,
items_to_descriptions=None,
vocab=vocab,
embedding=embedding,
name=dataset_hparams["name"])
embedding=embedding)

return dataset

Expand Down
13 changes: 10 additions & 3 deletions txtgen/modules/connectors/connector_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ class ConnectorBase(ModuleBase):
Integer, a Tensorshape , or a tuple of Integers or TensorShapes.
This can typically be obtained by `decoder.cell.state_size`.
hparams (dict): Hyperparameters of connector.
name (str): Name of connector.
"""

def __init__(self, decoder_state_size, hparams=None, name="connector"):
ModuleBase.__init__(self, name, hparams)
def __init__(self, decoder_state_size, hparams=None):
ModuleBase.__init__(self, hparams)
self._decoder_state_size = decoder_state_size

@staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values.
"""
return {
"name": "connector"
}

def _build(self, *args, **kwargs):
"""Transforms inputs to the decoder initial states.
"""
Expand Down
60 changes: 32 additions & 28 deletions txtgen/modules/connectors/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from txtgen.modules.connectors.connector_base import ConnectorBase
from txtgen.core.utils import get_function
from txtgen.core import distributions


def _mlp_transform(inputs, output_size, activation_fn=tf.identity):
Expand Down Expand Up @@ -56,6 +55,7 @@ def _mlp_transform(inputs, output_size, activation_fn=tf.identity):

return output


class ConstantConnector(ConnectorBase):
"""Creates decoder initial state that has a constant value.
Expand All @@ -64,11 +64,9 @@ class ConstantConnector(ConnectorBase):
Integer, a Tensorshape, or a tuple of Integers or TensorShapes.
This can typically be obtained by :attr:`decoder.state_size`.
hparams (dict): Hyperparameters of the connector.
name (str): Name of connector.
"""
def __init__(self, decoder_state_size, hparams=None,
name="constant_connector"):
ConnectorBase.__init__(self, decoder_state_size, hparams, name)
def __init__(self, decoder_state_size, hparams=None):
ConnectorBase.__init__(self, decoder_state_size, hparams)

@staticmethod
def default_hparams():
Expand All @@ -78,11 +76,14 @@ def default_hparams():
{
# The constant value that the decoder initial state has.
"value": 0.
"value": 0.,
# The name of the connector.
"name": "constant_connector"
}
"""
return {
"value": 0.
"value": 0.,
"name": "constant_connector"
}

def _build(self, batch_size, value=None): # pylint: disable=W0221
Expand Down Expand Up @@ -120,20 +121,24 @@ class ForwardConnector(ConnectorBase):
decoder_state_size: Size of state of the decoder cell. Can be an
Integer, a Tensorshape , or a tuple of Integers or TensorShapes.
This can typically be obtained by :attr:`decoder.cell.state_size`.
name (str): Name of connector.
"""

def __init__(self, decoder_state_size, name="forward_connector"):
ConnectorBase.__init__(self, decoder_state_size, None, name)
def __init__(self, decoder_state_size):
ConnectorBase.__init__(self, decoder_state_size, None)

@staticmethod
def default_hparams():
"""Returns a dictionary of default hyperparameters.
The dictionary is empty since the connector does not have any
configurable hyperparameters.
.. code-block:: python
{
# The name of the connector.
"name": "forward_connector"
"""
return {}
return {
"name": "forward_connector"
}

def _build(self, inputs): # pylint: disable=W0221
"""Passes inputs to the initial states of decoder.
Expand Down Expand Up @@ -173,8 +178,8 @@ class MLPTransformConnector(ConnectorBase):
name (str): Name of connector.
"""

def __init__(self, decoder_state_size, hparams=None, name="mlp_connector"):
ConnectorBase.__init__(self, decoder_state_size, hparams, name)
def __init__(self, decoder_state_size, hparams=None):
ConnectorBase.__init__(self, decoder_state_size, hparams)

@staticmethod
def default_hparams():
Expand All @@ -188,11 +193,15 @@ def default_hparams():
# functions defined in module `tensorflow` or `tensorflow.nn`,
# or user-defined functions defined in `user.custom`, or a
# full path like "my_module.my_activation_fn".
"activation_fn": "tensorflow.identity"
"activation_fn": "tensorflow.identity",
# Name of the connector.
"name": "mlp_connector"
}
"""
return {
"activation_fn": "tensorflow.identity"
"activation_fn": "tensorflow.identity",
"name": "mlp_connector"
}

def _build(self, inputs): #pylint: disable=W0221
Expand All @@ -218,6 +227,7 @@ def _build(self, inputs): #pylint: disable=W0221

return output


#TODO(junxian): Customize reparameterize type
class StochasticConnector(ConnectorBase):
"""Samples decoder initial state from a distribution defined by the inputs.
Expand All @@ -226,29 +236,23 @@ class StochasticConnector(ConnectorBase):
models.
"""

def __init__(self, decoder_state_size, hparams=None,
name="stochastic_connector"):
ConnectorBase.__init__(self, decoder_state_size, hparams, name)
def __init__(self, decoder_state_size, hparams=None):
ConnectorBase.__init__(self, decoder_state_size, hparams)

#TODO(zhiting): add docs
@staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values.
Returns:
```python
{
# The name or full path of the activation function applied to
# the outputs of the MLP layer. E.g., the name of built-in
# functions defined in module `tensorflow` or `tensorflow.nn`,
# or user-defined functions defined in `user.custom`, or a
# full path like "my_module.my_activation_fn".
"activation_fn": "tensorflow.identity"
}
```
"""
return {
"distribution": "tf.contrib.distributions.MultivariateNormalDiag"
"distribution": "tf.contrib.distributions.MultivariateNormalDiag",
"name": "stochastic_connector"
}

def _build(self, inputs): # pylint: disable=W0221
Expand Down
14 changes: 8 additions & 6 deletions txtgen/modules/decoders/rnn_decoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,15 @@ class RNNDecoderBase(ModuleBase, TFDecoder):
hparams (dict, optional): Hyperparameters. If not specified, the default
hyperparameter setting is used. See :attr:`default_hparams` for the
structure and default values.
name (str): Name of the decoder.
"""

def __init__(self, # pylint: disable=too-many-arguments
cell=None,
embedding=None,
embedding_trainable=True,
vocab_size=None,
hparams=None,
name="rnn_decoder"):
ModuleBase.__init__(self, name, hparams)
hparams=None):
ModuleBase.__init__(self, hparams)

self._helper = None
self._initial_state = None
Expand Down Expand Up @@ -117,7 +115,10 @@ def default_hparams():
# (optional) An integer. Maximum allowed number of decoding
# steps at inference time. If `None` (default), decoding is
# performed until fully done, e.g., encountering the EOS token.
"max_decoding_length_infer": None
"max_decoding_length_infer": None,
# Name of the decoder.
"name": "rnn_decoder"
}
"""
return {
Expand All @@ -126,7 +127,8 @@ def default_hparams():
"helper_train": rnn_decoder_helpers.default_helper_train_hparams(),
"helper_infer": rnn_decoder_helpers.default_helper_infer_hparams(),
"max_decoding_length_train": None,
"max_decoding_length_infer": None
"max_decoding_length_infer": None,
"name": "rnn_decoder"
}


Expand Down
17 changes: 14 additions & 3 deletions txtgen/modules/decoders/rnn_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,21 @@ def __init__(self, # pylint: disable=too-many-arguments
embedding=None,
embedding_trainable=True,
vocab_size=None,
hparams=None,
name="basic_rnn_decoder"):
hparams=None):
RNNDecoderBase.__init__(self, cell, embedding, embedding_trainable,
vocab_size, hparams, name)
vocab_size, hparams)

@staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values.
The hyperparameters have the same structure as in
:meth:`txtgen.modules.RNNDecoderBase.default_hparams`, except that
the default "name" is "basic_rnn_decoder".
"""
hparams = RNNDecoderBase.default_hparams()
hparams["name"] = "basic_rnn_decoder"
return hparams

def initialize(self, name=None):
return self._helper.initialize() + (self._initial_state,)
Expand Down
12 changes: 10 additions & 2 deletions txtgen/modules/encoders/encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,16 @@ class EncoderBase(ModuleBase):
"""Base class inherited by all encoder classes.
"""

def __init__(self, hparams=None, name="encoder"):
ModuleBase.__init__(self, name, hparams)
def __init__(self, hparams=None):
ModuleBase.__init__(self, hparams)

@staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values.
"""
return {
"name": "encoder"
}

def _build(self, inputs, *args, **kwargs):
"""Encodes the inputs.
Expand Down

0 comments on commit fe3c56d

Please sign in to comment.