Resolve issue mentioned in #242 (PR #246)

Merged (3 commits, Nov 19, 2019)
Changes from 2 commits
texar/tf/data/tokenizers/xlnet_tokenizer_test.py (6 changes: 3 additions & 3 deletions)
@@ -19,9 +19,9 @@ class XLNetTokenizerTest(tf.test.TestCase):
     def setUp(self):
         self.tmp_dir = tempfile.TemporaryDirectory()
         self.SAMPLE_VOCAB = maybe_download(
-            'https://github.com/gpengzhi/pytorch-transformers/blob/master/'
-            'pytorch_transformers/tests/fixtures/test_sentencepiece.model'
-            '?raw=true', self.tmp_dir.name)
+            'https://github.com/huggingface/transformers/blob/master/'
+            'transformers/tests/fixtures/test_sentencepiece.model?raw=true',
+            self.tmp_dir.name)

         self.tokenizer = XLNetTokenizer.load(
             self.SAMPLE_VOCAB[0], configs={'keep_accents': True})
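As a quick aside, the new upstream fixture URL can be smoke-tested outside the test suite. The sketch below is illustrative only (it is not part of this PR), uses just the Python standard library, and assumes network access; FIXTURE_URL simply repeats the address from the diff above.

import urllib.request

# Hypothetical smoke check: confirm the upstream fixture URL still serves
# the raw SentencePiece model bytes. `?raw=true` asks GitHub to return the
# file contents rather than the HTML blob page.
FIXTURE_URL = ('https://github.com/huggingface/transformers/blob/master/'
               'transformers/tests/fixtures/test_sentencepiece.model'
               '?raw=true')

with urllib.request.urlopen(FIXTURE_URL) as resp:
    data = resp.read()

# The fixture is a binary SentencePiece model, so a non-empty payload is a
# reasonable minimal check.
assert len(data) > 0, 'fixture download returned no data'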
texar/tf/modules/decoders/dynamic_decode.py (295 changes: 151 additions & 144 deletions)
@@ -160,9 +160,9 @@ def dynamic_decode(decoder,
     Args:
         decoder: A `Decoder` instance.
         output_time_major: Python boolean. Default: `False` (batch major). If
-            `True`, outputs are returned as time major tensors (this mode is faster).
-            Otherwise, outputs are returned as batch major tensors (this adds extra
-            time to the computation).
+            `True`, outputs are returned as time major tensors (this mode is
+            faster). Otherwise, outputs are returned as batch major tensors
+            (this adds extra time to the computation).
         impute_finished: Python boolean. If `True`, then states for batch
             entries which are marked as finished get copied through and the
             corresponding outputs get zeroed out. This causes some slowdown at
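For readers skimming the docstring change: time-major output is just batch-major output with the first two axes swapped, which is why returning batch-major results costs an extra transpose (handled by _transpose_batch_time at the end of this function). A minimal illustration in plain NumPy, independent of any decoder:

import numpy as np

# Illustrative only: the layout difference behind `output_time_major`.
batch_major = np.zeros((4, 7, 16))                 # [batch, time, depth]
time_major = np.transpose(batch_major, (1, 0, 2))  # [time, batch, depth]
assert time_major.shape == (7, 4, 16)

# TensorArray.stack() naturally produces time-major results (one slice per
# step along axis 0), so time-major mode skips the transpose entirely.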
@@ -186,153 +186,160 @@
type(decoder))

    with tf.variable_scope(scope, "decoder") as varscope:
        # Properly cache variable values inside the while_loop
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_finished, initial_inputs, initial_state = decoder.initialize()

        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if maximum_iterations is not None:
            initial_finished = tf.logical_or(
                initial_finished, 0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(
            initial_finished, dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _shape(batch_size, from_shape):
            if (not isinstance(from_shape, tensor_shape.TensorShape) or
                    from_shape.ndims == 0):
                return None
            else:
                batch_size = tf.get_static_value(
                    tf.convert_to_tensor(batch_size, name="batch_size"))
                return tensor_shape.TensorShape([batch_size]).\
                    concatenate(from_shape)

        dynamic_size = True

        def _create_ta(s, d):
            return tf.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, s))

        initial_outputs_ta = nest.map_structure(_create_ta,
                                                decoder.output_size,
                                                decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            cond = tf.logical_not(tf.reduce_all(finished))
            cond_time = (maximum_iterations is None or
                         unused_time < maximum_iterations)
            ret = tf.logical_and(cond, tf.convert_to_tensor(cond_time))
            return ret

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            r"""Internal while_loop body.

            Args:
                time: scalar int32 tensor.
                outputs_ta: structure of TensorArray.
                state: (structure of) state tensors and TensorArrays.
                inputs: (structure of) input tensors.
                finished: bool tensor (keeping track of what's finished).
                sequence_lengths: int32 tensor (keeping track of time of
                    finish).

            Returns:
                `(time + 1, outputs_ta, next_state, next_inputs,
                next_finished, next_sequence_lengths)`.
            """
            (next_outputs, state) = decoder.step(time, inputs, state)

            # Check if the maximum iteration is met. If it is met, do not
            # compute the next inputs.
            reach_max = tf.equal(time + 1, maximum_iterations)
            (decoder_finished, next_inputs, decoder_state) = tf.cond(
                reach_max,
                lambda: (tf.cast(tf.ones_like(finished), tf.bool),
                         inputs, state),
                lambda: decoder.next_inputs(time, next_outputs, state)
            )
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = tf.logical_or(decoder_finished, finished)
            next_sequence_lengths = tf.where(
                tf.logical_not(finished),
                tf.fill(tf.shape(sequence_lengths), time + 1),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: tf.where(finished, zero, out),
                    next_outputs,
                    zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else tf.where(finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(
                    _maybe_copy_state, decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = tf.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
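The substantive change in this file is the new tf.variable_scope wrapper with a caching device; the rest of the hunk is the same function body re-indented under that scope. Below is a stripped-down sketch of the idiom, assuming TF 1.x; demo_decoder, w, and the loop body are arbitrary illustrations, not texar code. Without a caching device, each while_loop iteration may re-fetch variable values from the device where the variable lives; caching them on the consuming op's device avoids the repeated fetches.

import tensorflow as tf

with tf.variable_scope("demo_decoder") as varscope:
    # Same idiom as above: cache variable reads on the op that uses them.
    if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    w = tf.get_variable("w", shape=[16, 16])

    def cond(i, x):
        return i < 10

    def body(i, x):
        # Reads of `w` here are cached rather than re-fetched per iteration.
        return i + 1, tf.matmul(x, w)

    _, y = tf.while_loop(cond, body, (tf.constant(0), tf.ones([1, 16])))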