added dropout strategies in embedders
Former-commit-id: fc094aa
ZhitingHu committed May 28, 2018
1 parent 791e732 commit 89b081e
Showing 5 changed files with 187 additions and 52 deletions.
33 changes: 26 additions & 7 deletions texar/modules/embedders/embedder_base.py
@@ -12,6 +12,8 @@
from texar.module_base import ModuleBase
from texar.modules.embedders import embedder_utils

# pylint: disable=invalid-name

__all__ = [
"EmbedderBase"
]
@@ -32,18 +34,35 @@ def _init_parameterized_embedding(self, init_value, num_embeds, hparams):
if hparams.trainable:
self._add_trainable_variable(self._embedding)

self._dropout_layer = None
if hparams.dropout_rate > 0.:
with tf.variable_scope(tf.variable_scope):
self._dropout_layer = tf.layers.Dropout(
rate=hparams.dropout_rate)

self._num_embeds = self._embedding.get_shape().as_list()[0]

self._dim = self._embedding.get_shape().as_list()[1:]
if len(self._dim) == 1:
self._dim_rank = len(self._dim)
if self._dim_rank == 1:
self._dim = self._dim[0]

def _get_dropout_layer(self, hparams, inputs):
    """Creates dropout layer according to dropout strategy.
    Called in :meth:`_build()`.
    """
    dropout_layer = None
    if hparams.dropout_rate > 0.:
        st = hparams.dropout_strategy
        if st == 'element':
            noise_shape = None
        elif st == 'item':
            index_rank = len(inputs.shape.dims) - 1
            noise_shape = [None] * index_rank + [1] * self._dim_rank
        elif st == 'item_type':
            noise_shape = [None] + [1] * self._dim_rank
        else:
            raise ValueError('Unknown dropout strategy: {}'.format(st))

        dropout_layer = tf.layers.Dropout(
            rate=hparams.dropout_rate, noise_shape=noise_shape)

    return dropout_layer
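
A quick sketch of where the returned layer ends up applied, following the `_build()` hunks later in this diff (an illustration, not part of the commit; the table size, batch/time sizes, and rate are made up):

import tensorflow as tf

table = tf.get_variable("emb", shape=[100, 1024])   # [vocab, dim]
ids = tf.ones([64, 16], dtype=tf.int32)             # [batch, time]

# 'item_type': drop whole rows of the table itself, before the lookup,
# so every occurrence of a word type shares one mask value.
type_dropout = tf.layers.Dropout(rate=0.3, noise_shape=[None, 1])
dropped_table = type_dropout.apply(table, training=True)
outputs = tf.nn.embedding_lookup(dropped_table, ids)

# 'element' / 'item': look up first, then drop individual scalars
# (noise_shape=None) or whole word vectors (one mask value per
# [batch, time] position, broadcast over the embedding dim).
outputs2 = tf.nn.embedding_lookup(table, ids)
item_dropout = tf.layers.Dropout(rate=0.3, noise_shape=[None, None, 1])
outputs2 = item_dropout.apply(outputs2, training=True)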

@staticmethod
def default_hparams():
43 changes: 29 additions & 14 deletions texar/modules/embedders/embedder_utils.py
@@ -37,6 +37,7 @@ def default_embedding_hparams():
}
},
"dropout_rate": 0,
"dropout_strategy": 'element',
"trainable": True,
}
@@ -45,8 +46,9 @@
"name" : str
Name of the embedding variable.
"dim" : int
Embedding dimension.
"dim" : int or list
Embedding dimension. Can be a list of integers to yield embeddings
with dimensionality > 1.
"initializer" : dict or None
Hyperparameters of the initializer for the embedding values. An
@@ -124,6 +126,19 @@ class can be
The dropout rate between 0 and 1. E.g., `dropout_rate=0.1` would
drop out 10% of the embedding.
"dropout_strategy" : str
The dropout strategy. Can be one of the following
- 'element': The regular strategy that drops individual elements \
in the embedding vectors.
- 'item': Drops individual items (e.g., words) entirely. E.g., for \
the word sequence 'the simpler the better', the strategy can \
yield '_ simpler the better', where the first `the` is dropped.
- 'item_type': Drops item types (e.g., word types). E.g., for the \
above sequence, the strategy can yield '_ simpler _ better', \
where the word type 'the' is dropped. The dropout will never \
yield '_ simpler the better' as in the 'item' strategy.
"trainable" : bool
Whether the embedding is trainable.
"""
@@ -133,13 +148,15 @@ class can be
"initializer": None,
"regularizer": layers.default_regularizer_hparams(),
"dropout_rate": 0,
"trainable": True
"dropout_strategy": 'element',
"trainable": True,
"@no_typecheck": ["dim"]
}


def get_embedding(hparams=None,
init_value=None,
vocab_size=None,
num_embeds=None,
variable_scope='Embedding'):
"""Creates embedding variable if not exists.
@@ -154,24 +171,28 @@ def get_embedding(hparams=None,
init_value (Tensor or numpy array, optional): Initial values of the
embedding variable. If not given, embedding is initialized as
specified in :attr:`hparams["initializer"]`.
vocab_size (int, optional): The vocabulary size. Required if
:attr:`init_value` is not provided.
num_embeds (int, optional): The number of embedding items
(e.g., vocabulary size). Required if :attr:`init_value` is
not provided.
variable_scope (str or VariableScope, optional): Variable scope of
the embedding variable.
Returns:
Variable or Tensor: A 2D `Variable` or `Tensor` of the same shape with
:attr:`init_value` or of the shape
:attr:`[vocab_size, hparams["dim"]]`.
:attr:`[num_embeds, hparams["dim"]]`.
"""
with tf.variable_scope(variable_scope):
if hparams is None or isinstance(hparams, dict):
hparams = HParams(hparams, default_embedding_hparams())
regularizer = layers.get_regularizer(hparams["regularizer"])
if init_value is None:
initializer = layers.get_initializer(hparams["initializer"])
dim = hparams["dim"]
if not isinstance(hparams["dim"], (list, tuple)):
dim = [dim]
embedding = tf.get_variable(name=hparams["name"],
shape=[vocab_size, hparams["dim"]],
shape=[num_embeds] + dim,
initializer=initializer,
regularizer=regularizer,
trainable=hparams["trainable"])
@@ -181,10 +202,4 @@
regularizer=regularizer,
trainable=hparams["trainable"])

#if hparams["dropout_rate"] > 0.:
# keep_prob = utils.switch_dropout(
# hparams["dropout"]["keep_prob"], mode)
# embedding = tf.nn.dropout(embedding, keep_prob=keep_prob)
# # TODO: Return value type changed and may not be compatible with
# # previous semantic.
return embedding
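
A hypothetical call (not in the diff; the scope name and sizes are illustrative) using the renamed `num_embeds` argument together with a list-valued `"dim"`:

import tensorflow as tf
from texar.modules.embedders import embedder_utils

emb = embedder_utils.get_embedding(
    hparams={"dim": [100, 10]},          # rank-2 embedding per item
    num_embeds=5000,
    variable_scope="my_embedding")
print(emb.shape)                         # expected shape: (5000, 100, 10)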
20 changes: 15 additions & 5 deletions texar/modules/embedders/embedders.py
@@ -17,7 +17,7 @@
"WordEmbedder"
]

#TODO(zhiting): add soft-embedder, position-embedder, embedder combiner
#TODO(zhiting): add soft-embedder, embedder combiner


class WordEmbedder(EmbedderBase):
@@ -29,7 +29,8 @@ class WordEmbedder(EmbedderBase):
Args:
init_value (optional): A `Tensor` or numpy array that contains the
initial value of embeddings. It is typically of shape
`[vocab_size, embedding dim]`
`[vocab_size] + embedding dim`. Embedding can have dimensionality
> 1.
If `None`, embedding is initialized as specified in
:attr:`hparams["initializer"]`. Otherwise, the
@@ -90,6 +91,7 @@ def default_hparams():
}
},
"dropout_rate": 0,
"dropout_strategy": 'element',
"trainable": True,
}
@@ -117,11 +119,19 @@ def _build(self, inputs, mode=None, **kwargs):
A `Tensor` of shape `shape(inputs) + embedding dimension`.
"""
embedding = self._embedding
if self._dropout_layer is not None:
dropout_layer = self._get_dropout_layer(self._hparams, inputs)
if dropout_layer:
is_training = utils.is_train_mode(mode)
embedding = self._dropout_layer.apply(
inputs=embedding, training=is_training)
if self._hparams.dropout_strategy == 'item_type':
embedding = dropout_layer.apply(
inputs=embedding, training=is_training)

outputs = tf.nn.embedding_lookup(embedding, inputs, **kwargs)

if dropout_layer and self._hparams.dropout_strategy != 'item_type':
outputs = dropout_layer.apply(
inputs=outputs, training=is_training)

return outputs

@property
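
Whether the new dropout layers fire still depends on the mode resolved in `_build()`. A usage sketch (the import path and sizes are assumptions, not taken from the diff):

import tensorflow as tf
from texar.modules import WordEmbedder   # assumed import path

embedder = WordEmbedder(
    vocab_size=100,
    hparams={"dim": 1024, "dropout_rate": 0.3})
ids = tf.ones([64, 16], dtype=tf.int32)

train_out = embedder(ids, mode=tf.estimator.ModeKeys.TRAIN)  # dropout applied
eval_out = embedder(ids, mode=tf.estimator.ModeKeys.EVAL)    # dropout disabled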
118 changes: 101 additions & 17 deletions texar/modules/embedders/embedders_test.py
@@ -16,35 +16,57 @@
from texar.modules.embedders.position_embedders import PositionEmbedder
from texar.context import global_mode

class WordEmbedderTest(tf.test.TestCase):
"""Tests word embedder.
class EmbedderTest(tf.test.TestCase):
"""Tests parameterized embedder.
"""
def test_word_embedder(self):

def _test_word_embedder(self, hparams):
"""Tests :class:`texar.modules.WordEmbedder`.
"""
embedder = WordEmbedder(
vocab_size=100,
hparams={"dim": 1024, "dropout_rate": 0.3})
vocab_size=100, hparams=hparams)
inputs = tf.ones([64, 16], dtype=tf.int32)
outputs = embedder(inputs)
self.assertEqual(outputs.shape, [64, 16, 1024])
self.assertEqual(embedder.dim, 1024)

emb_dim = embedder.dim
if not isinstance(emb_dim, (list, tuple)):
emb_dim = [emb_dim]

hparams_dim = hparams["dim"]
if not isinstance(hparams["dim"], (list, tuple)):
hparams_dim = [hparams["dim"]]

self.assertEqual(outputs.shape, [64, 16] + emb_dim)
self.assertEqual(emb_dim, hparams_dim)
self.assertEqual(embedder.vocab_size, 100)
self.assertEqual(len(embedder.trainable_variables), 1)

class PositionEmbedderTest(tf.test.TestCase):
"""Tests position embedder.
"""
def test_position_embedder(self):
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
outputs_ = sess.run(
outputs,
feed_dict={global_mode(): tf.estimator.ModeKeys.TRAIN})
self.assertEqual(outputs_.shape, (64, 16) + tuple(emb_dim))

def _test_position_embedder(self, hparams):
"""Tests :class:`texar.modules.PositionEmbedder`.
"""
pos_size = 100
embedder = PositionEmbedder(
position_size=pos_size, hparams={"dim": 1024})
inputs = tf.random_uniform([64, 16], maxval=pos_size, dtype=tf.int32)
outputs = embedder(positions=inputs)
self.assertEqual(outputs.shape, [64, 16, 1024])
self.assertEqual(embedder.dim, 1024)
position_size=pos_size, hparams=hparams)
inputs = tf.ones([64, 16], dtype=tf.int32)
outputs = embedder(inputs)

emb_dim = embedder.dim
if not isinstance(emb_dim, (list, tuple)):
emb_dim = [emb_dim]

hparams_dim = hparams["dim"]
if not isinstance(hparams["dim"], (list, tuple)):
hparams_dim = [hparams["dim"]]

self.assertEqual(outputs.shape, [64, 16] + emb_dim)
self.assertEqual(emb_dim, hparams_dim)
self.assertEqual(embedder.position_size, 100)
self.assertEqual(len(embedder.trainable_variables), 1)

@@ -55,7 +77,69 @@ def test_position_embedder(self):
outputs_, max_seq_length = sess.run(
[outputs, tf.reduce_max(seq_length)],
feed_dict={global_mode(): tf.estimator.ModeKeys.TRAIN})
self.assertEqual(outputs_.shape, (64, max_seq_length, 1024))
self.assertEqual(outputs_.shape,
(64, max_seq_length) + tuple(emb_dim))


def test_embedder(self):
"""Tests various embedders.
"""
# no dropout
hparams = {"dim": 1024, "dropout_rate": 0}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024], "dropout_rate": 0}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024, 10], "dropout_rate": 0}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

# dropout with default strategy
hparams = {"dim": 1024, "dropout_rate": 0.3}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024], "dropout_rate": 0.3}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024, 10], "dropout_rate": 0.3}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

# dropout with different strategies
hparams = {"dim": 1024, "dropout_rate": 0.3,
"dropout_strategy": "item"}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024], "dropout_rate": 0.3,
"dropout_strategy": "item"}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024, 10], "dropout_rate": 0.3,
"dropout_strategy": "item"}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": 1024, "dropout_rate": 0.3,
"dropout_strategy": "item_type"}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024], "dropout_rate": 0.3,
"dropout_strategy": "item_type"}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

hparams = {"dim": [1024, 10], "dropout_rate": 0.3,
"dropout_strategy": "item_type"}
self._test_word_embedder(hparams)
self._test_position_embedder(hparams)

if __name__ == "__main__":
tf.test.main()
25 changes: 16 additions & 9 deletions texar/modules/embedders/position_embedders.py
@@ -123,14 +123,8 @@ def _build(self, positions=None, sequence_length=None, mode=None, **kwargs):
Returns:
A `Tensor` of shape `shape(inputs) + embedding dimension`.
"""
embedding = self._embedding
if self._dropout_layer is not None:
is_training = utils.is_train_mode(mode)
embedding = self._dropout_layer.apply(
inputs=embedding, training=is_training)

inputs = positions
if inputs is None:
if positions is None:
if sequence_length is None:
raise ValueError(
'Either `positions` or `sequence_length` is required.')
@@ -139,10 +133,23 @@
inputs = tf.tile(tf.expand_dims(single_inputs, 0),
[utils.get_batch_size(sequence_length), 1])

embedding = self._embedding
dropout_layer = self._get_dropout_layer(self._hparams, inputs)
if dropout_layer:
is_training = utils.is_train_mode(mode)
if self._hparams.dropout_strategy == 'item_type':
embedding = dropout_layer.apply(
inputs=embedding, training=is_training)

outputs = tf.nn.embedding_lookup(embedding, inputs, **kwargs)

if inputs is None:
outputs = utils.mask_sequences(outputs, sequence_length, rank=3)
if dropout_layer and self._hparams.dropout_strategy != 'item_type':
outputs = dropout_layer.apply(
inputs=outputs, training=is_training)

if positions is None:
outputs = utils.mask_sequences(
outputs, sequence_length, rank=2+self._dim_rank)

return outputs

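
For reference, a small usage sketch of the updated position embedder (sizes are made up; the import path mirrors the test file above):

import tensorflow as tf
from texar.modules.embedders.position_embedders import PositionEmbedder

embedder = PositionEmbedder(
    position_size=100,
    hparams={"dim": 1024, "dropout_rate": 0.3})

# Pass explicit positions, or only sequence lengths; with `sequence_length`,
# positions 0..max_len-1 are generated per example and outputs beyond each
# sequence's length are masked out.
seq_length = tf.constant([10, 16, 3], dtype=tf.int32)
outputs = embedder(sequence_length=seq_length)   # [batch, max_len, 1024]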
