Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Prune finished states in ERM (#1187)
Browse files Browse the repository at this point in the history
* prune finished states in erm

* addressed pr comments
  • Loading branch information
pdasigi committed May 12, 2018
1 parent 77b33fb commit b5af3e4
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 20 deletions.
Expand Up @@ -51,6 +51,8 @@ class NlvrCoverageSemanticParser(NlvrSemanticParser):
attention.
beam_size : ``int``
Beam size for the beam search used during training.
max_num_finished_states : ``int``
Maximum number of finished states the trainer should compute costs for.
normalize_beam_score_by_length : ``bool``, optional (default=False)
Should the log probabilities be normalized by length before renormalizing them? Edunov et
al. do this in their work, but we found that not doing it works better. It's possible they
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(self,
encoder: Seq2SeqEncoder,
attention_function: SimilarityFunction,
beam_size: int,
max_num_finished_states: int,
max_decoding_steps: int,
normalize_beam_score_by_length: bool = False,
checklist_cost_weight: float = 0.6,
Expand All @@ -93,7 +96,10 @@ def __init__(self,
encoder=encoder)
self._agenda_coverage = Average()
self._decoder_trainer: DecoderTrainer[Callable[[NlvrDecoderState], torch.Tensor]] = \
ExpectedRiskMinimization(beam_size, normalize_beam_score_by_length, max_decoding_steps)
ExpectedRiskMinimization(
        # Keyword arguments are required here: the fourth positional parameter of
        # ``ExpectedRiskMinimization.__init__`` is ``max_num_decoded_sequences``, so passing
        # ``max_num_finished_states`` positionally would bind it to the wrong parameter and
        # leave the finished-state pruning limit unset.
        beam_size=beam_size,
        normalize_by_length=normalize_beam_score_by_length,
        max_decoding_steps=max_decoding_steps,
        max_num_finished_states=max_num_finished_states)

# Instantiating an empty NlvrWorld just to get the number of terminals.
self._terminal_productions = set(NlvrWorld([]).terminal_productions.values())
Expand Down Expand Up @@ -423,6 +429,7 @@ def from_params(cls, vocab, params: Params) -> 'NlvrCoverageSemanticParser':
else:
attention_function = None
beam_size = params.pop_int('beam_size')
max_num_finished_states = params.pop_int('max_num_finished_states', None)
normalize_beam_score_by_length = params.pop_bool('normalize_beam_score_by_length', False)
max_decoding_steps = params.pop_int("max_decoding_steps")
checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.6)
Expand All @@ -436,6 +443,7 @@ def from_params(cls, vocab, params: Params) -> 'NlvrCoverageSemanticParser':
encoder=encoder,
attention_function=attention_function,
beam_size=beam_size,
max_num_finished_states=max_num_finished_states,
max_decoding_steps=max_decoding_steps,
normalize_beam_score_by_length=normalize_beam_score_by_length,
checklist_cost_weight=checklist_cost_weight,
Expand Down
Expand Up @@ -50,6 +50,9 @@ class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
attention. Passed to super class.
decoder_beam_size : ``int``
Beam size to be used by the ExpectedRiskMinimization algorithm.
decoder_num_finished_states : ``int``
Number of finished states for which costs will be computed by the ExpectedRiskMinimization
algorithm.
max_decoding_steps : ``int``
Maximum number of steps the decoder should take before giving up. Used both during training
and evaluation. Passed to super class.
Expand Down Expand Up @@ -95,6 +98,7 @@ def __init__(self,
mixture_feedforward: FeedForward,
attention_function: SimilarityFunction,
decoder_beam_size: int,
decoder_num_finished_states: int,
max_decoding_steps: int,
normalize_beam_score_by_length: bool = False,
checklist_cost_weight: float = 0.6,
Expand All @@ -120,7 +124,8 @@ def __init__(self,
self._decoder_trainer: ExpectedRiskMinimization = \
ExpectedRiskMinimization(beam_size=decoder_beam_size,
normalize_by_length=normalize_beam_score_by_length,
max_decoding_steps=self._max_decoding_steps)
max_decoding_steps=self._max_decoding_steps,
max_num_finished_states=decoder_num_finished_states)
unlinked_terminals_global_indices = []
global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
for production, index in global_vocab.items():
Expand Down Expand Up @@ -432,6 +437,7 @@ def from_params(cls, vocab, params: Params) -> 'WikiTablesErmSemanticParser':
else:
attention_function = None
decoder_beam_size = params.pop_int("decoder_beam_size")
decoder_num_finished_states = params.pop_int("decoder_num_finished_states", None)
max_decoding_steps = params.pop_int("max_decoding_steps")
# Read the flag with ``pop_bool`` (as is done for the other boolean options) so the value is
# handled as a boolean rather than passed through untyped.
normalize_beam_score_by_length = params.pop_bool("normalize_beam_score_by_length", False)
use_neighbor_similarity_for_linking = params.pop_bool("use_neighbor_similarity_for_linking", False)
Expand All @@ -450,6 +456,7 @@ def from_params(cls, vocab, params: Params) -> 'WikiTablesErmSemanticParser':
mixture_feedforward=mixture_feedforward,
attention_function=attention_function,
decoder_beam_size=decoder_beam_size,
decoder_num_finished_states=decoder_num_finished_states,
max_decoding_steps=max_decoding_steps,
normalize_beam_score_by_length=normalize_beam_score_by_length,
checklist_cost_weight=checklist_cost_weight,
Expand Down
60 changes: 42 additions & 18 deletions allennlp/nn/decoding/decoder_trainers/expected_risk_minimization.py
Expand Up @@ -32,16 +32,22 @@ class ExpectedRiskMinimization(DecoderTrainer[Callable[[StateType], torch.Tensor
The maximum number of steps we should take during decoding.
max_num_decoded_sequences : ``int``, optional (default=1)
Maximum number of sorted decoded sequences to return. Defaults to 1.
max_num_finished_states : ``int``, optional (default = None)
Maximum number of finished states to keep after search. This is to finished states what
``beam_size`` is to unfinished ones. Costs are computed for at most this many states per
instance. If not set, we will keep all the finished states.
"""
def __init__(self,
beam_size: int,
normalize_by_length: bool,
max_decoding_steps: int,
max_num_decoded_sequences: int = 1) -> None:
max_num_decoded_sequences: int = 1,
max_num_finished_states: int = None) -> None:
self._beam_size = beam_size
self._normalize_by_length = normalize_by_length
self._max_decoding_steps = max_decoding_steps
self._max_num_decoded_sequences = max_num_decoded_sequences
self._max_num_finished_states = max_num_finished_states

def decode(self,
initial_state: DecoderState,
Expand Down Expand Up @@ -82,10 +88,44 @@ def _get_finished_states(self,
else:
next_states.append(next_state)

states = self._prune_beam(next_states)
states = self._prune_beam(states=next_states,
beam_size=self._beam_size,
sort_states=False)
num_steps += 1
if self._max_num_finished_states is not None:
finished_states = self._prune_beam(states=finished_states,
beam_size=self._max_num_finished_states,
sort_states=True)
return finished_states

# TODO(pradeep): Move this method to nn.decoding.util
@staticmethod
def _prune_beam(states: List[DecoderState],
                beam_size: int,
                sort_states: bool = False) -> List[DecoderState]:
    """
    Keeps at most ``beam_size`` states for each batch instance and returns the survivors as a
    flat list.

    This is used both for pruning unfinished states on the beam and for pruning finished states
    at the end of the search. Unfinished states all come from the same decoding step, which
    already sorts them, so they can be truncated as they are (``sort_states=False``). Finished
    states are collected across different decoding steps, so they must be sorted by score
    before truncation (``sort_states=True``).

    Every state is expected to be ungrouped, i.e. to carry exactly one batch index.
    """
    # Bucket the states by the single batch instance each one belongs to.
    grouped: Dict[int, List[DecoderState]] = defaultdict(list)
    for state in states:
        assert len(state.batch_indices) == 1
        grouped[state.batch_indices[0]].append(state)

    kept: List[DecoderState] = []
    for candidates in grouped.values():
        if sort_states:
            # Order this instance's states from highest to lowest score.
            scores = torch.cat([candidate.score[0] for candidate in candidates])
            order = scores.sort(-1, descending=True)[1]
            candidates = [candidates[i] for i in order.data.cpu().numpy()]
        kept.extend(candidates[:beam_size])
    return kept

def _get_model_scores_by_batch(self, states: List[StateType]) -> Dict[int, List[Variable]]:
batch_scores: Dict[int, List[Variable]] = defaultdict(list)
for state in states:
Expand Down Expand Up @@ -132,19 +172,3 @@ def _get_best_action_sequences(self,
for i in best_action_indices]
best_action_sequences[batch_index] = instance_best_sequences
return best_action_sequences

def _prune_beam(self, states: List[DecoderState]) -> List[DecoderState]:
    """
    Returns the given states with at most ``self._beam_size`` of them kept per instance,
    preserving their input order. The ``states`` are assumed to be ungrouped (exactly one
    batch index each) and already sorted, so the first ones seen for an instance are the ones
    that survive.
    """
    kept_per_instance: Dict[int, int] = defaultdict(int)
    survivors = []
    for state in states:
        assert len(state.batch_indices) == 1
        instance = state.batch_indices[0]
        # This instance's beam is already full; drop the state.
        if kept_per_instance[instance] >= self._beam_size:
            continue
        survivors.append(state)
        kept_per_instance[instance] += 1
    return survivors
Expand Up @@ -25,6 +25,7 @@
"num_layers": 1
},
"beam_size": 20,
"max_num_finished_states": 20,
"max_decoding_steps": 20,
"attention_function": {"type": "dot_product"},
"checklist_cost_weight": 0.8,
Expand Down
Expand Up @@ -31,6 +31,7 @@
"checklist_cost_weight": 0.6,
"max_decoding_steps": 7,
"decoder_beam_size": 10,
"decoder_num_finished_states": 200,
"attention_function": {"type": "dot_product"},
"mml_model_file": "tests/fixtures/semantic_parsing/wikitables/serialization/model.tar.gz"
},
Expand Down

0 comments on commit b5af3e4

Please sign in to comment.