Skip to content

Commit

Permalink
polish beam_search
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhitingHu committed Apr 5, 2019
1 parent 6a6202c commit d7a7385
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions texar/utils/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications copyright (C) 2018 Texar
# Modifications copyright (C) 2019 Texar
# ==============================================================================
"""
Implementation of beam search with penalties.
Expand Down Expand Up @@ -103,6 +103,7 @@ def compute_batch_indices(batch_size, beam_size):
Args:
batch_size: Batch size
beam_size: Size of the beam.
Returns:
batch_pos: [batch_size, beam_size] tensor of ids
"""
Expand Down Expand Up @@ -142,11 +143,12 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
grow_finished, we will need to return the length penalized
scores.
flags: Tensor of bools for sequences that say whether a sequence
has reached EOS or not
has reached EOS or not
beam_size: int
batch_size: int
prefix: string that will prefix unique names for the ops run.
states_to_gather: dict (possibly nested) of decoding states.
Returns:
Tuple of
(topk_seq [batch_size, beam_size, decode_length],
Expand Down Expand Up @@ -184,14 +186,14 @@ def gather(tensor, name):


def beam_search(symbols_to_logits_fn,
initial_ids,
beam_size,
decode_length,
vocab_size,
alpha,
eos_id,
states=None,
stop_early=True):
initial_ids,
beam_size,
decode_length,
vocab_size,
alpha,
eos_id,
states=None,
stop_early=True):
"""Beam search with length penalties.
Requires a function that can take the currently decoded symbols and
Expand Down Expand Up @@ -222,20 +224,21 @@ def beam_search(symbols_to_logits_fn,
Args:
symbols_to_logits_fn: Interface to the model, to provide logits.
Shoud take [batch_size, decoded_ids] and return
[batch_size, vocab_size]
Should take [batch_size, decoded_ids] and return
[batch_size, vocab_size]
initial_ids: Ids to start off the decoding, this will be the first
thing handed to symbols_to_logits_fn (after expanding to beam size)
thing handed to symbols_to_logits_fn (after expanding to beam size)
[batch_size]
beam_size: Size of the beam.
decode_length: Number of steps to decode for.
vocab_size: Size of the vocab, must equal the size of the logits
returned by symbols_to_logits_fn
returned by symbols_to_logits_fn
alpha: alpha for length penalty.
states: dict (possibly nested) of decoding states.
eos_id: ID for end of sentence.
stop_early: a boolean - stop once best sequence is provably
determined.
determined.
Returns:
Tuple of
(decoded beams [batch_size, beam_size, decode_length]
Expand Down Expand Up @@ -282,12 +285,13 @@ def grow_finished(finished_seq, finished_scores, finished_flags,
finished_flags: finished bools for each of these sequences.
[batch_size, beam_size]
curr_seq: current topk sequence that has been grown by one
position.
position.
[batch_size, beam_size, current_decoded_length]
curr_scores: scores for each of these sequences. [batch_size,
beam_size]
beam_size]
curr_finished: Finished flags for each of these sequences.
[batch_size, beam_size]
Returns:
Tuple of
(Topk sequences based on scores,
Expand Down Expand Up @@ -321,7 +325,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
Args:
curr_seq: current topk sequence that has been grown by one
position.
position.
[batch_size, beam_size, i+1]
curr_scores: scores for each of these sequences. [batch_size,
beam_size]
Expand All @@ -330,6 +334,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
curr_finished: Finished flags for each of these sequences.
[batch_size, beam_size]
states: dict (possibly nested) of decoding states.
Returns:
Tuple of
(Topk sequences based on scores,
Expand Down Expand Up @@ -363,6 +368,7 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
alive_log_probs: probabilities of these sequences.
[batch_size, beam_size]
states: dict (possibly nested) of decoding states.
Returns:
Tuple of
(Topk sequences extended by the next word,
Expand Down Expand Up @@ -521,7 +527,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs,
finished_scores: scores for each of these sequences.
[batch_size, beam_size]
finished_in_finished: finished bools for each of these
sequences. [batch_size, beam_size]
sequences. [batch_size, beam_size]
Returns:
Bool.
Expand Down

0 comments on commit d7a7385

Please sign in to comment.