This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Commit
Adding the decoding framework from the wikitables branch (#1086)
* Adding the decoding framework from the wikitables branch
* Fix docs
* Add test for MaximumMarginalLikelihood.decode()
* Added tests for BeamSearch
* Add test for "keep final unfinished states" case
* Improve docs from Michael's comments
* Minor docstring improvement
1 parent 3ea82b1, commit 1655f22
Showing 20 changed files with 1,244 additions and 0 deletions.
@@ -0,0 +1,35 @@
"""
This module contains code for transition-based decoding. "Transition-based decoding" is where you
start in some state, iteratively transition between states, and have some kind of supervision
signal that tells you which end states, or which transition sequences, are "good".

If you want to do decoding for a vocabulary-based model, where the allowable outputs are the same
at every timestep of decoding, this code is not what you are looking for, and it will be quite
inefficient compared to other things you could do.

The key abstractions in this code are the following:

- ``DecoderState`` represents the current state of decoding, containing a list of all of the
  actions taken so far, and a current score for the state. It also has methods around
  determining whether the state is "finished" and for combining states for batched computation.
- ``DecoderStep`` is a ``torch.nn.Module`` that models the transition function between states.
  Its main method is ``take_step``, which generates a ranked list of next states given a
  current state.
- ``DecoderTrainer`` is an algorithm for training the transition function with some kind of
  supervision signal. There are many options for training algorithms and supervision signals;
  this is an abstract class that is generic over the type of the supervision signal.

The module also has some classes to help represent the ``DecoderState``, including ``RnnState``,
which you can use to keep track of a decoder RNN's internal state, and ``GrammarState``, which
keeps track of what actions are allowed at each timestep of decoding, if your outputs are
production rules from a grammar.

There is also a generic ``BeamSearch`` class for finding the ``k`` highest-scoring transition
sequences given a trained ``DecoderStep`` and an initial ``DecoderState``.
"""
from allennlp.nn.decoding.beam_search import BeamSearch
from allennlp.nn.decoding.decoder_state import DecoderState
from allennlp.nn.decoding.decoder_step import DecoderStep
from allennlp.nn.decoding.decoder_trainers.decoder_trainer import DecoderTrainer
from allennlp.nn.decoding.grammar_state import GrammarState
from allennlp.nn.decoding.rnn_state import RnnState
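The module docstring above distinguishes transition-based decoding from vocabulary-based decoding by the fact that the allowable outputs change from step to step. A minimal sketch of that idea follows, in plain Python with no AllenNLP or PyTorch dependency. `ToyState`, `VALID_NEXT`, `is_finished` and `take_step` here are hypothetical illustrations, not the classes added in this commit, and the grammar and log-probabilities are made up.

```python
from typing import Dict, List, NamedTuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for a decoder state with a group size of 1."""
    action_history: List[str]
    score: float


# Hypothetical grammar: the valid next actions (with log-probabilities)
# depend on the last action taken, so they differ from step to step.
VALID_NEXT: Dict[str, Dict[str, float]] = {
    "<start>": {"S -> A B": -0.5, "S -> B": -2.0},
    "S -> A B": {"A -> a": -0.25},
    "A -> a": {"B -> b": -0.25},
    "S -> B": {"B -> b": -0.5},
    "B -> b": {},  # nothing left to expand: the state is finished
}


def is_finished(state: ToyState) -> bool:
    return not VALID_NEXT[state.action_history[-1]]


def take_step(state: ToyState) -> List[ToyState]:
    """Return all valid next states, ranked from highest to lowest score."""
    candidates = VALID_NEXT[state.action_history[-1]]
    next_states = [ToyState(state.action_history + [action], state.score + log_prob)
                   for action, log_prob in candidates.items()]
    return sorted(next_states, key=lambda s: s.score, reverse=True)


# Greedy decoding: always follow the best-ranked transition.
state = ToyState(["<start>"], 0.0)
while not is_finished(state):
    state = take_step(state)[0]
print(state.action_history, state.score)
```

Greedy decoding is used here only to keep the sketch short; it is the one-state special case of the beam search described in the docstring.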
@@ -0,0 +1,88 @@
from collections import defaultdict
from typing import Dict, List

from allennlp.common import Params
from allennlp.nn.decoding.decoder_step import DecoderStep
from allennlp.nn.decoding.decoder_state import DecoderState


class BeamSearch:
    """
    This class implements beam search over transition sequences given an initial ``DecoderState``
    and a ``DecoderStep``, returning the highest scoring final states found by the beam (the states
    will keep track of the transition sequence themselves).

    The initial ``DecoderState`` is assumed to be `batched`. The value we return from the search
    is a dictionary from batch indices to ranked finished states.

    IMPORTANT: We assume that the ``DecoderStep`` that you are using returns possible next states
    in sorted order, so we do not do an additional sort inside of ``BeamSearch.search()``. If
    you're implementing your own ``DecoderStep``, you must ensure that you've sorted the states
    that you return.
    """
    def __init__(self, beam_size: int) -> None:
        self._beam_size = beam_size

    def search(self,
               num_steps: int,
               initial_state: DecoderState,
               decoder_step: DecoderStep,
               keep_final_unfinished_states: bool = True) -> Dict[int, List[DecoderState]]:
        """
        Parameters
        ----------
        num_steps : ``int``
            How many steps should we take in our search? This is an upper bound, as it's possible
            for the search to run out of valid actions before hitting this number, or for all
            states on the beam to finish.
        initial_state : ``DecoderState``
            The starting state of our search. This is assumed to be `batched`, and our beam search
            is batch-aware - we'll keep ``beam_size`` states around for each instance in the batch.
        decoder_step : ``DecoderStep``
            The ``DecoderStep`` object that defines and scores transitions from one state to the
            next.
        keep_final_unfinished_states : ``bool``, optional (default=True)
            If we run out of steps before a state is "finished", should we return that state in our
            search results?

        Returns
        -------
        best_states : ``Dict[int, List[DecoderState]]``
            This is a mapping from batch index to the top states for that instance.
        """
        finished_states: Dict[int, List[DecoderState]] = defaultdict(list)
        states = [initial_state]
        step_num = 1
        while states and step_num <= num_steps:
            next_states: Dict[int, List[DecoderState]] = defaultdict(list)
            grouped_state = states[0].combine_states(states)
            for next_state in decoder_step.take_step(grouped_state, max_actions=self._beam_size):
                # NOTE: we're doing state.batch_indices[0] here (and similar things below),
                # hard-coding a group size of 1. But, our use of `next_state.is_finished()`
                # already checks for that, as it crashes if the group size is not 1.
                batch_index = next_state.batch_indices[0]
                if next_state.is_finished():
                    finished_states[batch_index].append(next_state)
                else:
                    if step_num == num_steps and keep_final_unfinished_states:
                        finished_states[batch_index].append(next_state)
                    next_states[batch_index].append(next_state)
            states = []
            for batch_index, batch_states in next_states.items():
                # The states from the generator are already sorted, so we can just take the first
                # ones here, without an additional sort.
                states.extend(batch_states[:self._beam_size])
            step_num += 1
        best_states: Dict[int, List[DecoderState]] = {}
        for batch_index, batch_states in finished_states.items():
            # The time this sort takes is pretty negligible, no particular need to optimize this
            # yet. Maybe with a larger beam size...
            finished_to_sort = [(-state.score[0].data[0], state) for state in batch_states]
            finished_to_sort.sort(key=lambda x: x[0])
            best_states[batch_index] = [state[1] for state in finished_to_sort[:self._beam_size]]
        return best_states

    @classmethod
    def from_params(cls, params: Params) -> 'BeamSearch':
        beam_size = params.pop('beam_size')
        return cls(beam_size=beam_size)
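To make the control flow of `BeamSearch.search()` easier to follow, here is a hedged, dependency-free sketch of the same loop. It is an assumption-laden toy: states are plain tuples with float scores instead of `DecoderState` objects holding variables, and `ToyState`, `TARGET_LENGTH`, `ACTION_LOG_PROBS`, `toy_take_step` and `toy_beam_search` are hypothetical names introduced only for illustration.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class ToyState(NamedTuple):
    batch_index: int
    actions: List[int]
    score: float


TARGET_LENGTH = 3                     # hypothetical: finished after three actions
ACTION_LOG_PROBS = [-1.0, -2.0, -4.0]  # hypothetical, identical at every step


def is_finished(state: ToyState) -> bool:
    return len(state.actions) >= TARGET_LENGTH


def toy_take_step(states: List[ToyState], max_actions: int) -> List[ToyState]:
    """Expand every state; keep the best `max_actions` per batch instance, sorted."""
    by_batch: Dict[int, List[ToyState]] = defaultdict(list)
    for state in states:
        for action, log_prob in enumerate(ACTION_LOG_PROBS):
            by_batch[state.batch_index].append(
                ToyState(state.batch_index, state.actions + [action], state.score + log_prob))
    result: List[ToyState] = []
    for candidates in by_batch.values():
        candidates.sort(key=lambda s: -s.score)
        result.extend(candidates[:max_actions])
    return result


def toy_beam_search(initial_states: List[ToyState], num_steps: int, beam_size: int,
                    keep_final_unfinished_states: bool = True) -> Dict[int, List[ToyState]]:
    finished: Dict[int, List[ToyState]] = defaultdict(list)
    states = list(initial_states)
    step_num = 1
    while states and step_num <= num_steps:
        next_states: Dict[int, List[ToyState]] = defaultdict(list)
        for next_state in toy_take_step(states, max_actions=beam_size):
            if is_finished(next_state):
                finished[next_state.batch_index].append(next_state)
            else:
                if step_num == num_steps and keep_final_unfinished_states:
                    finished[next_state.batch_index].append(next_state)
                next_states[next_state.batch_index].append(next_state)
        # The step function already returned sorted states, so slicing is enough.
        states = [s for batch_states in next_states.values() for s in batch_states[:beam_size]]
        step_num += 1
    return {index: sorted(batch_states, key=lambda s: -s.score)[:beam_size]
            for index, batch_states in finished.items()}


initial = [ToyState(0, [], 0.0), ToyState(1, [], 0.0)]
full = toy_beam_search(initial, num_steps=3, beam_size=2)
partial = toy_beam_search(initial, num_steps=2, beam_size=2)
dropped = toy_beam_search(initial, num_steps=2, beam_size=2,
                          keep_final_unfinished_states=False)
```

In this sketch, as in the loop above, a batch instance with no finished states does not appear in the returned dictionary at all, so `dropped` is empty when the step budget runs out and unfinished states are not kept.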
@@ -0,0 +1,67 @@
from typing import Generic, List, TypeVar

import torch

# Note that the bound here is `DecoderState` itself. This is what lets us have methods that take
# lists of a `DecoderState` subclass and output structures with the subclass. Really ugly that we
# have to do this generic typing _for our own class_, but it makes mypy happy and gives us good
# type checking in a few important methods.
T = TypeVar('T', bound='DecoderState')

class DecoderState(Generic[T]):
    """
    Represents the (batched) state of a transition-based decoder.

    There are two different kinds of batching we need to distinguish here. First, there's the
    batch of training instances passed to ``model.forward()``. We'll use "batch" and
    ``batch_size`` to refer to this through the docs and code. We additionally batch together
    computation for several states at the same time, where each state could be from the same
    training instance in the original batch, or different instances. We use "group" and
    ``group_size`` in the docs and code to refer to this kind of batching, to distinguish it from
    the batch of training instances.

    So, using this terminology, a single ``DecoderState`` object represents a `grouped` collection
    of states. Because different states in this group might finish at different timesteps, we have
    methods and member variables to handle some bookkeeping around this, to split and regroup
    things.

    Parameters
    ----------
    batch_indices : ``List[int]``
        A ``group_size``-length list, where each element specifies which ``batch_index`` that group
        element came from.

        Our internal variables (like scores, action histories, hidden states, whatever) are
        `grouped`, and our ``group_size`` is likely different from the original ``batch_size``.
        This variable keeps track of which batch instance each group element came from (e.g., to
        know what the correct action sequences are, or which encoder outputs to use).
    action_history : ``List[List[int]]``
        The list of actions taken so far in this state. This is also grouped, so each state in the
        group has a list of actions.
    score : ``List[torch.autograd.Variable]``
        This state's score. It's a variable, because typically we'll be computing a loss based on
        this score, and using it for backprop during training. Like the other variables here, this
        is a ``group_size``-length list.
    """
    def __init__(self,
                 batch_indices: List[int],
                 action_history: List[List[int]],
                 score: List[torch.autograd.Variable]) -> None:
        self.batch_indices = batch_indices
        self.action_history = action_history
        self.score = score

    def is_finished(self) -> bool:
        """
        If this state has a ``group_size`` of 1, this returns whether the single action sequence in
        this state is finished or not. If this state has a ``group_size`` other than 1, this
        method raises an error.
        """
        raise NotImplementedError

    @classmethod
    def combine_states(cls, states: List[T]) -> T:
        """
        Combines a list of states, each with their own group size, into a single state.
        """
        raise NotImplementedError
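The "batch" versus "group" distinction and the contract of `combine_states` can be illustrated with a small sketch. This is a hypothetical subclass-like toy, not code from the commit: `ToyDecoderState` stores plain floats as scores (the real class stores variables for backprop), its "finished after two actions" rule is invented, and it uses a `Type[T]` annotation on the classmethod rather than the `Generic[T]` pattern shown above.

```python
from typing import List, Type, TypeVar

T = TypeVar('T', bound='ToyDecoderState')


class ToyDecoderState:
    """Hypothetical grouped state; every attribute is a group_size-length list."""

    def __init__(self,
                 batch_indices: List[int],
                 action_history: List[List[int]],
                 score: List[float]) -> None:
        self.batch_indices = batch_indices
        self.action_history = action_history
        self.score = score

    def is_finished(self) -> bool:
        # Only defined for a single state, mirroring the documented contract.
        if len(self.batch_indices) != 1:
            raise RuntimeError("is_finished() requires a group size of 1")
        # Assumption for this toy: a sequence is finished after two actions.
        return len(self.action_history[0]) >= 2

    @classmethod
    def combine_states(cls: Type[T], states: List[T]) -> T:
        # Concatenate the per-group lists of every input state, in order.
        batch_indices = [index for state in states for index in state.batch_indices]
        action_history = [history for state in states for history in state.action_history]
        score = [value for state in states for value in state.score]
        return cls(batch_indices, action_history, score)


single = ToyDecoderState([0], [[3]], [-1.0])
pair = ToyDecoderState([0, 1], [[4, 2], [5]], [-2.0, -0.5])
combined = ToyDecoderState.combine_states([single, pair])
print(combined.batch_indices)  # two group elements come from batch instance 0
```

Here the combined group has three elements but only two distinct batch indices, which is why `batch_indices` has to be tracked separately from the group position.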
@@ -0,0 +1,77 @@
from typing import Generic, List, Set, TypeVar

import torch

from allennlp.nn.decoding.decoder_state import DecoderState

StateType = TypeVar('StateType', bound=DecoderState)  # pylint: disable=invalid-name


class DecoderStep(torch.nn.Module, Generic[StateType]):
    """
    A ``DecoderStep`` is a module that assigns scores to state transitions in a transition-based
    decoder.

    The ``DecoderStep`` takes a ``DecoderState`` and outputs a ranked list of next states, ordered
    by the state's score.

    The intention with this class is that a model will implement a subclass of ``DecoderStep`` that
    defines how exactly you want to handle the input and what computations get done at each step of
    decoding, and how states are scored. This subclass then gets passed to a ``DecoderTrainer`` to
    have its parameters trained.
    """
    def take_step(self,
                  state: StateType,
                  max_actions: int = None,
                  allowed_actions: List[Set] = None) -> List[StateType]:
        """
        The main method in the ``DecoderStep`` API. This function defines the computation done at
        each step of decoding and returns a ranked list of next states.

        The input state is `grouped`, to allow for efficient computation, but the output states
        should all have a ``group_size`` of 1, to make things easier on the decoding algorithm.
        They will get regrouped later as needed.

        Because of the way we handle grouping in the decoder states, constructing a new state is
        actually a relatively expensive operation. If you know a priori that only some of the
        states will be needed (either because you have a set of gold action sequences, or you have
        a fixed beam size), passing that information into this function will keep us from
        constructing more states than we need, which will greatly speed up your computation.

        IMPORTANT: This method `must` return states already sorted by their score, otherwise
        ``BeamSearch`` and other methods will break. For efficiency, we do not perform an
        additional sort in those methods.

        Parameters
        ----------
        state : ``DecoderState``
            The current state of the decoder, which we will take a step `from`. We may be grouping
            together computation for several states here. Because we can have several states for
            each instance in the original batch being evaluated at the same time, we use
            ``group_size`` for this kind of batching, and ``batch_size`` for the `original` batch
            in ``model.forward``.
        max_actions : ``int``, optional
            If you know that you will only need a certain number of states out of this (e.g., in a
            beam search), you can pass in the max number of actions that you need, and we will only
            construct that many states (for each `batch` instance - `not` for each `group`
            instance!). This can save a whole lot of computation if you have an action space
            that's much larger than your beam size.
        allowed_actions : ``List[Set]``, optional
            If the ``DecoderTrainer`` has constraints on which actions need to be evaluated (e.g.,
            maximum marginal likelihood only needs to evaluate action sequences in a given set),
            you can pass those constraints here, to avoid constructing state objects unnecessarily.
            If there are no constraints from the trainer, passing a value of ``None`` here will
            allow all actions to be considered.

            This is a list because it is `batched` - every instance in the batch has a set of
            allowed actions. Note that the size of this list is the ``group_size`` in the
            ``DecoderState``, `not` the ``batch_size`` of ``model.forward``. The training
            algorithm needs to convert from the `batched` allowed action sequences that it has to a
            `grouped` allowed action sequence list.

        Returns
        -------
        next_states : ``List[DecoderState]``
            A list of next states, ordered by score.
        """
        raise NotImplementedError
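Two details of the `take_step` contract are easy to misread: `max_actions` caps the number of states per batch instance (not per group element), while `allowed_actions` is indexed per group element. The sketch below illustrates both under stated assumptions. It works on plain lists and tuples rather than `DecoderState` objects, and `group_allowed_actions` and `toy_take_step` are hypothetical helpers, not part of the API in this commit.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

Candidate = Tuple[int, List[int], float]  # (batch_index, action_history, score)


def group_allowed_actions(batch_indices: List[int],
                          batched_allowed: List[Set[int]]) -> List[Set[int]]:
    """Convert per-batch-instance constraints into per-group-element constraints."""
    return [batched_allowed[batch_index] for batch_index in batch_indices]


def toy_take_step(batch_indices: List[int],
                  action_history: List[List[int]],
                  scores: List[float],
                  action_log_probs: List[List[float]],
                  max_actions: Optional[int] = None,
                  allowed_actions: Optional[List[Set[int]]] = None) -> List[Candidate]:
    per_batch: Dict[int, List[Candidate]] = defaultdict(list)
    for group_index, batch_index in enumerate(batch_indices):
        for action, log_prob in enumerate(action_log_probs[group_index]):
            # Skip actions the trainer has ruled out for this group element.
            if allowed_actions is not None and action not in allowed_actions[group_index]:
                continue
            per_batch[batch_index].append(
                (batch_index, action_history[group_index] + [action],
                 scores[group_index] + log_prob))
    next_states: List[Candidate] = []
    for batch_index in sorted(per_batch):
        ranked = sorted(per_batch[batch_index], key=lambda candidate: -candidate[2])
        # The cap applies to each batch instance as a whole; None keeps everything.
        next_states.extend(ranked[:max_actions])
    return next_states


# A group of three states: two from batch instance 0, one from batch instance 1.
batch_indices = [0, 0, 1]
histories = [[0], [1], [0]]
scores = [-1.0, -2.0, -1.0]
log_probs = [[-1.0, -4.0], [-0.5, -3.0], [-2.0, -1.0]]

capped = toy_take_step(batch_indices, histories, scores, log_probs, max_actions=2)
grouped = group_allowed_actions(batch_indices, [{1}, {0}])
constrained = toy_take_step(batch_indices, histories, scores, log_probs,
                            allowed_actions=grouped)
```

With `max_actions=2`, the call returns four candidates (two per batch instance), not six, even though the group has three elements.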
@@ -0,0 +1,2 @@
from allennlp.nn.decoding.decoder_trainers.expected_risk_minimization import ExpectedRiskMinimization
from allennlp.nn.decoding.decoder_trainers.maximum_marginal_likelihood import MaximumMarginalLikelihood
@@ -0,0 +1,52 @@
from typing import Dict, Generic, TypeVar

import torch

from allennlp.nn.decoding.decoder_step import DecoderStep
from allennlp.nn.decoding.decoder_state import DecoderState

SupervisionType = TypeVar('SupervisionType')  # pylint: disable=invalid-name

class DecoderTrainer(Generic[SupervisionType]):
    """
    ``DecoderTrainers`` define a training regime for transition-based decoders. A
    ``DecoderTrainer`` assumes an initial ``DecoderState``, a ``DecoderStep`` function that can
    traverse the state space, and some supervision signal. Given these things, the
    ``DecoderTrainer`` trains the ``DecoderStep`` function to traverse the state space to end up at
    good end states.

    Concrete implementations of this abstract base class could do things like maximum marginal
    likelihood, SEARN, LaSO, or other structured learning algorithms. If you're just trying to
    maximize the probability of a single target sequence where the possible outputs are the same
    for each timestep (as in, e.g., typical machine translation training regimes), there are way
    more efficient ways to do that than using this API.
    """
    def decode(self,
               initial_state: DecoderState,
               decode_step: DecoderStep,
               supervision: SupervisionType) -> Dict[str, torch.Tensor]:
        """
        Takes an initial state object, a means of transitioning from state to state, and a
        supervision signal, and uses the supervision to train the transition function to pick
        "good" states.

        This function should typically return a ``loss`` key during training, which the ``Model``
        will use as its loss.

        Parameters
        ----------
        initial_state : ``DecoderState``
            This is the initial state for decoding, typically initialized after running some kind
            of encoder on some inputs.
        decode_step : ``DecoderStep``
            This is the transition function that scores all possible actions that can be taken in a
            given state, and returns a ranked list of next states at each step of decoding.
        supervision : ``SupervisionType``
            This is the supervision that is used to train the ``decode_step`` function to pick
            "good" states. You can use whatever kind of supervision you want (e.g., a single
            "gold" action sequence, a set of possible "gold" action sequences, a reward function,
            etc.). We use ``typing.Generic`` to make sure that our static type checker is happy
            with how you've matched the supervision that you provide in the model to the
            ``DecoderTrainer`` that you want to use.
        """
        raise NotImplementedError
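As one concrete reading of the `decode` contract, the sketch below computes a maximum-marginal-likelihood style loss for a supervision signal that is a set of "gold" action sequences: the negative log of the summed probability of all gold sequences. This is not the `MaximumMarginalLikelihood` class imported in this commit, whose body is not shown here. `toy_step_log_probs` and `toy_mml_decode` are hypothetical, the transition distribution is made up, and plain floats stand in for the tensors a real trainer would need for backprop.

```python
import math
from typing import Dict, List, Sequence, Tuple


def log_sum_exp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    maximum = max(values)
    return maximum + math.log(sum(math.exp(value - maximum) for value in values))


def toy_step_log_probs(history: Tuple[int, ...]) -> List[float]:
    """Hypothetical transition function: a fixed distribution over two actions."""
    return [math.log(0.75), math.log(0.25)]


def toy_mml_decode(gold_sequences: List[Sequence[int]]) -> Dict[str, float]:
    """Score each gold action sequence and return the marginal negative log-likelihood."""
    sequence_scores: List[float] = []
    for sequence in gold_sequences:
        score = 0.0
        history: Tuple[int, ...] = ()
        for action in sequence:
            # Only the gold action is followed at each step; others are never expanded.
            score += toy_step_log_probs(history)[action]
            history += (action,)
        sequence_scores.append(score)
    return {"loss": -log_sum_exp(sequence_scores)}


# Supervision: two acceptable action sequences for a single instance.
output = toy_mml_decode([[0, 0], [1, 1]])
print(output["loss"])
```

The returned dictionary carries a `loss` key, matching the convention described in the `decode` docstring; a different supervision type (for example a reward function) would call for a different trainer.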