
Commit

polish helper docs
ZhitingHu committed Apr 9, 2019
1 parent 204d886 commit b820d1c
Showing 2 changed files with 12 additions and 5 deletions.
texar/modules/decoders/rnn_decoder_helpers.py (2 changes: 1 addition & 1 deletion)
@@ -258,7 +258,7 @@ def __init__(self, embedding, start_tokens, end_token, top_k=10,
self._seed = seed

def sample(self, time, outputs, state, name=None):
"""sample for SampleEmbeddingHelper."""
"""Gets a sample for one step."""
del time, state # unused by sample_fn
# Outputs are logits, we sample from the top_k candidates
if not isinstance(outputs, tf.Tensor):
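
For context, the sample step above treats the decoder outputs as logits and draws from the top_k highest-scoring candidates instead of taking a plain argmax. A minimal standalone sketch of that kind of top-k sampling (illustrative only, not code from this commit; the name top_k_sample is hypothetical, and tf.random.categorical assumes TF 1.13 or later):

    import tensorflow as tf

    def top_k_sample(logits, top_k=10, seed=None):
        """Samples one token id per batch entry from the top_k largest logits."""
        # Find the k-th largest logit in each row and mask everything below it.
        top_values, _ = tf.nn.top_k(logits, k=top_k)
        kth_value = top_values[:, -1, tf.newaxis]
        masked_logits = tf.where(
            logits < kth_value,
            tf.ones_like(logits) * logits.dtype.min,
            logits)
        # Sample from the renormalized distribution over the surviving logits.
        sample_ids = tf.random.categorical(masked_logits, num_samples=1, seed=seed)
        return tf.cast(tf.squeeze(sample_ids, axis=-1), tf.int32)

Masking the discarded positions to dtype.min before the categorical draw is the usual way to restrict sampling to the top k candidates while keeping the relative probabilities of the survivors.
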
texar/modules/decoders/tf_helpers.py (15 changes: 11 additions & 4 deletions)
@@ -249,13 +249,14 @@ def initialize(self, name=None):
return (finished, next_inputs)

def sample(self, time, outputs, name=None, **unused_kwargs):
"""Gets a sample for one step."""
with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
sample_ids = math_ops.cast(
math_ops.argmax(outputs, axis=-1), dtypes.int32)
return sample_ids

def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
"""next_inputs_fn for TrainingHelper."""
"""Gets the inputs for next step."""
with ops.name_scope(name, "TrainingHelperNextInputs",
[time, outputs, state]):
next_time = time + 1
@@ -335,6 +336,7 @@ def initialize(self, name=None):
name=name)

def sample(self, time, outputs, state, name=None):
"""Gets a sample for one step."""
with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
[time, outputs, state]):
# Return -1s where we did not sample, and sample_ids elsewhere
@@ -349,6 +351,7 @@ def sample(self, time, outputs, state, name=None):
gen_array_ops.fill([self.batch_size], -1))

def next_inputs(self, time, outputs, state, sample_ids, name=None):
"""Gets the outputs for next step."""
with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperNextInputs",
[time, outputs, state, sample_ids]):
(finished, base_next_inputs, state) = (
@@ -462,12 +465,14 @@ def initialize(self, name=None):
return super(ScheduledOutputTrainingHelper, self).initialize(name=name)

def sample(self, time, outputs, state, name=None):
"""Gets a sample for one step."""
with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
[time, outputs, state]):
sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
return sampler.sample(sample_shape=self.batch_size, seed=self._seed)

def next_inputs(self, time, outputs, state, sample_ids, name=None):
"""Gets the next inputs for next step."""
with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
[time, outputs, state, sample_ids]):
(finished, base_next_inputs, state) = (
@@ -602,7 +607,7 @@ def initialize(self, name=None):
return finished, self._start_inputs

def sample(self, time, outputs, state, name=None):
"""sample for GreedyEmbeddingHelper."""
"""Gets a sample for one step."""
del time, state # unused by sample_fn
# Outputs are logits, use argmax to get the most probable id
if not isinstance(outputs, ops.Tensor):
@@ -612,7 +617,7 @@ def sample(self, time, outputs, state, name=None):
return sample_ids

def next_inputs(self, time, outputs, state, sample_ids, name=None, reach_max_time=None):
"""next_inputs_fn for GreedyEmbeddingHelper."""
"""Gets the inputs for next step."""
finished = math_ops.equal(sample_ids, self._end_token)
all_finished = math_ops.reduce_all(finished)
if reach_max_time is not None:
@@ -683,7 +688,7 @@ def __init__(self, embedding, start_tokens, end_token,
self._seed = seed

def sample(self, time, outputs, state, name=None):
"""sample for SampleEmbeddingHelper."""
"""Gets a sample for one step."""
del time, state # unused by sample_fn
# Outputs are logits, we sample instead of argmax (greedy).
if not isinstance(outputs, ops.Tensor):
@@ -745,10 +750,12 @@ def initialize(self, name=None):
return (finished, self._start_inputs)

def sample(self, time, outputs, state, name=None):
"""Gets a sample for one step."""
del time, state # unused by sample
return self._sample_fn(outputs)

def next_inputs(self, time, outputs, state, sample_ids, name=None):
"""Gets the outputs for next step."""
del time, outputs # unused by next_inputs
if self._next_inputs_fn is None:
next_inputs = sample_ids
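
Every helper touched in this file follows the same three-call protocol: initialize() returns the first (finished, inputs) pair, sample() turns the current step's outputs into sample ids, and next_inputs() decides what the decoder feeds in at the next step. A condensed, illustrative sketch of the loop a decoder runs around a helper (not code from this commit; unrolled_decode, cell, and output_layer are placeholder names, and real decoders such as tf.contrib.seq2seq.dynamic_decode or Texar's RNN decoders do this inside a tf.while_loop and stop once every sequence is finished):

    import tensorflow as tf

    def unrolled_decode(helper, cell, output_layer, max_steps=10):
        """Drives a seq2seq-style helper for a fixed number of steps."""
        finished, inputs = helper.initialize()
        state = cell.zero_state(helper.batch_size, tf.float32)
        sample_ids_per_step = []
        for time in range(max_steps):
            cell_outputs, state = cell(inputs, state)
            logits = output_layer(cell_outputs)      # [batch, vocab] scores
            sample_ids = helper.sample(time=time, outputs=logits, state=state)
            # next_inputs returns a (finished, next_inputs, state) triple.
            finished, inputs, state = helper.next_inputs(
                time=time, outputs=logits, state=state, sample_ids=sample_ids)
            sample_ids_per_step.append(sample_ids)
        # `finished` is ignored here; a real decoder stops early once every
        # sequence in the batch has emitted its end token.
        return tf.stack(sample_ids_per_step, axis=1)  # [batch, max_steps]

Seen through this loop, the helpers above differ only in how sample() and next_inputs() are implemented: the training helper feeds back the ground-truth inputs, the scheduled variants mix ground truth and model samples via a Bernoulli draw on sampling_probability, and the greedy/sample embedding helpers embed their own sample ids as the next inputs.
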
