Skip to content

Commit

Permalink
fixed module trainable variable collector
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhitingHu committed Aug 22, 2017
1 parent ff81731 commit 406f63d
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 7 deletions.
3 changes: 3 additions & 0 deletions txtgen/modules/connectors/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _build(self, encoder_state): # pylint: disable=W0221
"""
nest.assert_same_structure(encoder_state,
self._decoder_state_size)

return encoder_state

@staticmethod
Expand Down Expand Up @@ -118,6 +119,8 @@ def _build(self, encoder_result): #pylint: disable=W0221
decoder_state = _mlp_transform(
encoder_result, self._decoder_state_size, activation_fn)

self._add_internal_trainable_variables()

return decoder_state

@staticmethod
Expand Down
7 changes: 6 additions & 1 deletion txtgen/modules/decoders/rnn_decoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class RNNDecoderBase(ModuleBase, TFDecoder):
"""Base class inherited by all RNN decoder classes.
"""

def __init__(self, cell=None, hparams=None, name="decoder"):
def __init__(self, cell=None, hparams=None, name="rnn_decoder"):
"""Initializes the decoder.
Args:
Expand Down Expand Up @@ -61,6 +61,11 @@ def _build(self, helper, initial_state): # pylint: disable=W0221
outputs, final_state, sequence_lengths = dynamic_decode(
decoder=self, maximum_iterations=max_decoding_length)

self._add_internal_trainable_variables()
# Add trainable variables of `self._cell` which may be constructed
# externally
self._add_trainable_variable(self._cell.trainable_variables())

return outputs, final_state, sequence_lengths

@staticmethod
Expand Down
11 changes: 9 additions & 2 deletions txtgen/modules/encoders/rnn_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,24 @@ def _build(self, inputs, **kwargs):
Outputs and final state of the encoder.
"""
if ('dtype' not in kwargs) and ('initial_state' not in kwargs):
return tf.nn.dynamic_rnn(
results = tf.nn.dynamic_rnn(
cell=self._cell,
inputs=inputs,
dtype=tf.float32,
**kwargs)
else:
return tf.nn.dynamic_rnn(
results = tf.nn.dynamic_rnn(
cell=self._cell,
inputs=inputs,
**kwargs)

self._add_internal_trainable_variables()
# Add trainable variables of `self._cell` which may be constructed
# externally
self._add_trainable_variable(self._cell.trainable_variables())

return results

@staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values.
Expand Down
42 changes: 38 additions & 4 deletions txtgen/modules/module_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf

from txtgen.hyperparams import HParams
Expand All @@ -29,8 +31,8 @@ def __init__(self, name, hparams=None):
create_scope_now_=True)
self._hparams = HParams(hparams, self.default_hparams())
self._unique_name = self._template.variable_scope.name.split("/")[-1]
self._variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=self.variable_scope.name)
self._trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.variable_scope.name)


def _build(self, *args, **kwargs):
Expand All @@ -57,6 +59,38 @@ def __call__(self, *args, **kwargs):
"""
return self._template(*args, **kwargs)

def _add_internal_trainable_variables(self): # pylint: disable=invalid-name
"""Collects trainable variables constructured internally in this module.
This is typically called at the end of `_build()` where all necessary
trainable variables have been constructed.
"""
scope_name = self.variable_scope.name
# Escape to handle possible "." characters in the name.
# Append a slash to the end to avoid searching scopes that have this
# scope name as a prefix.
scope_name = re.escape(scope_name) + "/"
internal_trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name)
self._add_trainable_variable(internal_trainable_variables)

def _add_trainable_variable(self, variable):
"""Adds a trainable variable to the trainable variable list of the
module.
Args:
variable: a (list of) trainable variable(s) constructed either
internally in the module or constructured outside but used
inside the module.
"""
if isinstance(variable, list):
for var in variable:
if var not in self.trainable_variables:
self._trainable_variables.append(var)
else:
if variable not in self.trainable_variables:
self._trainable_variables.append(variable)

@staticmethod
def default_hparams():
"""Returns a dictionary of default hyperparameters of the module.
Expand All @@ -78,10 +112,10 @@ def module_name(self):
return self._unique_name

@property
def variables(self):
def trainable_variables(self):
"""Returns the list of trainable variables of the module.
"""
return self._variables
return self._trainable_variables

@property
def hparams(self):
Expand Down

0 comments on commit 406f63d

Please sign in to comment.