Skip to content

Commit

Permalink
passing pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhitingHu committed Aug 30, 2018
1 parent 59c38e1 commit 71a7cee
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 49 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ docs/_build
### IntelliJ ###
*.iml

### pytest ###
/.pytest_cache/

### Project ###
/data/
Expand Down
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@
'tf_r0.12': (
'https://www.tensorflow.org/versions/r0.12/api_docs/python/%s',
None),
'tf_hmpg': (
'https://www.tensorflow.org/%s',
None),
'gym': (
'https://gym.openai.com/docs/%s',
None),
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
tensorflow >= 1.6.0
tensorflow-gpu >= 1.6.0
tensorflow-probability >= 0.3.0
tensorflow-probability-gpu >= 0.3.0
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

setuptools.setup(
name="texar",
version="0.0.1",
url="https://github.com/ZhitingHu/txtgen",
version="0.1",
url="https://github.com/asyml/texar",

description="An open and flexible framework for text generation.",
description="Toolkit for Text Generation and Beyond",

packages=setuptools.find_packages(),
platforms='any',
Expand Down
22 changes: 10 additions & 12 deletions texar/core/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,9 @@ def __init__(self,
else:
self._layers.append(get_layer(hparams=layer))

# Keep tracks of whether trainable variables have been created
self._vars_built = False

def compute_output_shape(self, input_shape):
if self._layers is None:
_shapes = input_shape
Expand Down Expand Up @@ -870,8 +873,9 @@ def call(self, inputs):
else:
raise ValueError("Unknown merge mode: '%s'" % self._mode)

if not self.built:
if not self.built or not self._vars_built:
self._collect_weights()
self._vars_built = True

return outputs

Expand All @@ -881,11 +885,6 @@ def layers(self):
"""
return self._layers

def build(self, _):
"""Dumb method.
"""
# Does not set :attr:`self.built` as this point.
pass

class SequentialLayer(tf.layers.Layer):
"""A subclass of :tf_main:`tf.layers.Layer <layers/Layer>`.
Expand Down Expand Up @@ -915,6 +914,9 @@ def __init__(self,
else:
self._layers.append(get_layer(hparams=layer))

# Keep tracks of whether trainable variables have been created
self._vars_built = False

def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape)
for layer in self._layers:
Expand Down Expand Up @@ -947,8 +949,9 @@ def call(self, inputs, mode=None): # pylint: disable=arguments-differ
outputs = layer(inputs)
inputs = outputs

if not self.built:
if not self.built or not self._vars_built:
self._collect_weights()
self._vars_built = True

return outputs

Expand All @@ -958,11 +961,6 @@ def layers(self):
"""
return self._layers

def build(self, _):
"""Dumb method.
"""
# Does not set :attr:`self.built` as this point.
pass

def _common_default_conv_dense_kwargs():
"""Returns the default keyword argument values that are common to
Expand Down
6 changes: 2 additions & 4 deletions texar/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from __future__ import print_function
from __future__ import division

import inspect

import tensorflow as tf

from texar.hyperparams import HParams
Expand Down Expand Up @@ -216,7 +214,7 @@ def get_optimizer_fn(hparams=None):

def _get_opt(learning_rate=None):
opt_kwargs = hparams["kwargs"].todict()
fn_args = set(inspect.getargspec(opt_class.__init__).args)
fn_args = set(utils.get_args(opt_class.__init__))
if 'learning_rate' in fn_args and learning_rate is not None:
opt_kwargs["learning_rate"] = learning_rate
return opt_class(**opt_kwargs)
Expand Down Expand Up @@ -319,7 +317,7 @@ def get_gradient_clip_fn(hparams=None):

fn_modules = ["tensorflow", "texar.custom"]
clip_fn = utils.get_function(fn_type, fn_modules)
clip_fn_args = inspect.getargspec(clip_fn).args
clip_fn_args = utils.get_args(clip_fn)
fn_kwargs = hparams["kwargs"]
if isinstance(fn_kwargs, HParams):
fn_kwargs = fn_kwargs.todict()
Expand Down
12 changes: 8 additions & 4 deletions texar/losses/rewards_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def test_discount_reward(self):

r, r_n = sess.run([discounted_reward_, discounted_reward_n_])

np.testing.assert_array_equal(discounted_reward, r)
np.testing.assert_array_equal(discounted_reward_n, r_n)
np.testing.assert_array_almost_equal(
discounted_reward, r, decimal=6)
np.testing.assert_array_almost_equal(
discounted_reward_n, r_n, decimal=6)

# 2D
reward = np.ones([2, 10], dtype=np.float64)
Expand All @@ -70,8 +72,10 @@ def test_discount_reward(self):

r, r_n = sess.run([discounted_reward_, discounted_reward_n_])

np.testing.assert_array_equal(discounted_reward, r)
np.testing.assert_array_equal(discounted_reward_n, r_n)
np.testing.assert_array_almost_equal(
discounted_reward, r, decimal=6)
np.testing.assert_array_almost_equal(
discounted_reward_n, r_n, decimal=6)

def test_discount_reward_py_1d(self):
"""Tests :func:`texar.losses.rewards._discount_reward_py_1d`
Expand Down
36 changes: 25 additions & 11 deletions texar/modules/connectors/connectors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various connectors.
"""
Expand All @@ -10,7 +22,7 @@
import numpy as np

import tensorflow as tf
import tensorflow.contrib.distributions as tf_dstr
from tensorflow import distributions as tf_dstr
from tensorflow.python.util import nest # pylint: disable=E0611

from texar.modules.connectors.connector_base import ConnectorBase
Expand Down Expand Up @@ -484,10 +496,10 @@ def _build(self,
`distribution.reparameterization_type = FULLY_REPARAMETERIZED`.
Args:
distribution: A
:tf_main:`TF Distribution <contrib/distributions/Distribution>`.
Can be a class, its name or module path, or an instance of
a subclass.
distribution: A instance of subclass of
:tf_main:`TF Distribution <distributions/Distribution>`,
or :tf_hmpg:`tensorflow_probability Distribution <probability>`,
Can be a class, its name or module path, or a class instance.
distribution_kwargs (dict, optional): Keyword arguments for the
distribution constructor. Ignored if `distribution` is a
class instance.
Expand Down Expand Up @@ -515,7 +527,8 @@ class instance.
"""
dstr = check_or_get_instance(
distribution, distribution_kwargs,
["tensorflow.contrib.distributions", "texar.custom"])
["tensorflow.distributions", "tensorflow_probability.distributions",
"texar.custom"])

if dstr.reparameterization_type == tf_dstr.NOT_REPARAMETERIZED:
raise ValueError(
Expand Down Expand Up @@ -614,10 +627,10 @@ def _build(self,
cannot be back-propagate through the samples.
Args:
distribution: A
:tf_main:`TF Distribution <contrib/distributions/Distribution>`.
Can be a class, its name or module path, or an instance of
a subclass.
distribution: A instance of subclass of
:tf_main:`TF Distribution <distributions/Distribution>`,
or :tf_hmpg:`tensorflow_probability Distribution <probability>`.
Can be a class, its name or module path, or a class instance.
distribution_kwargs (dict, optional): Keyword arguments for the
distribution constructor. Ignored if `distribution` is a
class instance.
Expand All @@ -644,7 +657,8 @@ class instance.
"""
dstr = check_or_get_instance(
distribution, distribution_kwargs,
["tensorflow.contrib.distributions", "texar.custom"])
["tensorflow.distributions", "tensorflow_probability.distributions",
"tensorflow.contrib.distributions", "texar.custom"])

if num_samples:
output = dstr.sample(num_samples)
Expand Down
6 changes: 3 additions & 3 deletions texar/modules/connectors/connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import unicode_literals

import tensorflow as tf
import tensorflow.contrib.distributions as tfds
from tensorflow_probability import distributions as tfpd
from tensorflow.python.util import nest # pylint: disable=E0611

from texar.core import layers
Expand Down Expand Up @@ -85,8 +85,8 @@ def test_reparameterized_stochastic_connector(self):
var = tf.ones([self._batch_size, variable_size])
mu_vec = tf.zeros([variable_size])
var_vec = tf.ones([variable_size])
gauss_ds = tfds.MultivariateNormalDiag(loc=mu, scale_diag=var)
gauss_ds_vec = tfds.MultivariateNormalDiag(loc=mu_vec,
gauss_ds = tfpd.MultivariateNormalDiag(loc=mu, scale_diag=var)
gauss_ds_vec = tfpd.MultivariateNormalDiag(loc=mu_vec,
scale_diag=var_vec)
gauss_connector = ReparameterizedStochasticConnector(state_size)
gauss_connector_ts = ReparameterizedStochasticConnector(state_size_ts)
Expand Down
12 changes: 12 additions & 0 deletions texar/modules/embedders/embedder_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils of embedder.
"""

Expand Down
13 changes: 9 additions & 4 deletions texar/modules/encoders/hierarchical_encoders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ def test_order(self):
minval=-1,
dtype=tf.float32)

outputs, state = encoder(inputs, order='btu', time_major=False)
outputs, state = encoder(inputs, order='utb', time_major=True)
outputs, state = encoder(inputs, order='tbu', time_major_major=True)
outputs, state = encoder(inputs, order='ubt', time_major_minor=True)
outputs_1, state_1 = encoder(inputs, order='btu')
outputs_2, state_2 = encoder(inputs, order='utb')
outputs_3, state_3 = encoder(inputs, order='tbu')
outputs_4, state_4 = encoder(inputs, order='ubt')

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
sess.run([outputs_1, state_1, outputs_2, state_2,
outputs_3, state_3, outputs_4, state_4])

def test_depack(self):
hparams = {
Expand Down
28 changes: 20 additions & 8 deletions texar/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from texar.utils.dtypes import is_str, is_callable, compat_as_text, \
_maybe_list_to_array

# pylint: disable=anomalous-backslash-in-string

MAX_SEQ_LENGTH = np.iinfo(np.int32).max

## Some modules cannot be imported directly,
Expand All @@ -44,6 +46,7 @@
#}

__all__ = [
"_inspect_getargspec",
"check_or_get_class",
"get_class",
"check_or_get_instance",
Expand Down Expand Up @@ -81,6 +84,15 @@ def _expand_name(name):
"""
return name

def _inspect_getargspec(fn):
"""Returns `inspect.getargspec(fn)` for Py2 and `inspect.getfullargspec(fn)`
for Py3
"""
try:
return inspect.getfullargspec(fn)
except AttributeError:
return inspect.getargspec(fn)

def check_or_get_class(class_or_name, module_path=None, superclass=None):
"""Returns the class and checks if the class inherits :attr:`superclass`.
Expand Down Expand Up @@ -214,14 +226,14 @@ def get_instance(class_or_name, kwargs, module_paths=None):
if is_str(class_):
class_ = get_class(class_, module_paths)
# Check validity of arguments
class_args = set(inspect.getargspec(class_.__init__).args)
class_args = set(_inspect_getargspec(class_.__init__).args)
if kwargs is None:
kwargs = {}
for key in kwargs.keys():
if key not in class_args:
raise ValueError(
"Invalid argument for class %s.%s: %s, valid args:%s" %
(class_.__module__, class_.__name__, key, class_args))
"Invalid argument for class %s.%s: %s, valid args: %s" %
(class_.__module__, class_.__name__, key, list(class_args)))

return class_(**kwargs)

Expand Down Expand Up @@ -293,7 +305,7 @@ class construction method are used.

# Select valid arguments
selected_kwargs = {}
class_args = set(inspect.getargspec(class_.__init__).args)
class_args = set(_inspect_getargspec(class_.__init__).args)
if kwargs is None:
kwargs = {}
for key, value in kwargs.items():
Expand Down Expand Up @@ -354,9 +366,9 @@ def call_function_with_redundant_kwargs(fn, kwargs):
The returned results by calling :attr:`fn`.
"""
try:
fn_args = set(inspect.getargspec(fn).args)
fn_args = set(_inspect_getargspec(fn).args)
except TypeError:
fn_args = set(inspect.getargspec(fn.__call__).args)
fn_args = set(_inspect_getargspec(fn.__call__).args)

if kwargs is None:
kwargs = {}
Expand All @@ -378,7 +390,7 @@ def get_args(fn):
Returns:
list: A list of argument names (str) of the function.
"""
argspec = inspect.getargspec(fn)
argspec = _inspect_getargspec(fn)
return argspec.args

def get_default_arg_values(fn):
Expand All @@ -393,7 +405,7 @@ def get_default_arg_values(fn):
dict: A dictionary that maps argument names (str) to their default
values. The dictionary is empty if no arguments have default values.
"""
argspec = inspect.getargspec(fn)
argspec = _inspect_getargspec(fn)
if argspec.defaults is None:
return {}
num_defaults = len(argspec.defaults)
Expand Down

0 comments on commit 71a7cee

Please sign in to comment.