This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Adding the decoding framework from the wikitables branch (#1086)
* Adding the decoding framework from the wikitables branch

* Fix docs

* Add test for MaximumMarginalLikelihood.decode()

* Added tests for BeamSearch

* Add test for "keep final unfinished states" case

* Improve docs from Michael's comments

* Minor docstring improvement
matt-gardner committed Apr 13, 2018
1 parent 3ea82b1 commit 1655f22
Showing 20 changed files with 1,244 additions and 0 deletions.
35 changes: 35 additions & 0 deletions allennlp/nn/decoding/__init__.py
@@ -0,0 +1,35 @@
"""
This module contains code for transition-based decoding. "Transition-based decoding" is where you
start in some state, iteratively transition between states, and have some kind of supervision
signal that tells you which end states, or which transition sequences, are "good".

If you want to do decoding for a vocabulary-based model, where the allowable outputs are the same
at every timestep of decoding, this code is not what you are looking for, and it will be quite
inefficient compared to other things you could do.

The key abstractions in this code are the following:

    - ``DecoderState`` represents the current state of decoding, containing a list of all of the
      actions taken so far, and a current score for the state. It also has methods around
      determining whether the state is "finished" and for combining states for batched computation.
    - ``DecoderStep`` is a ``torch.nn.Module`` that models the transition function between states.
      Its main method is ``take_step``, which generates a ranked list of next states given a
      current state.
    - ``DecoderTrainer`` is an algorithm for training the transition function with some kind of
      supervision signal. There are many options for training algorithms and supervision signals;
      this is an abstract class that is generic over the type of the supervision signal.

The module also has some classes to help represent the ``DecoderState``, including ``RnnState``,
which you can use to keep track of a decoder RNN's internal state, and ``GrammarState``, which
keeps track of what actions are allowed at each timestep of decoding, if your outputs are
production rules from a grammar.

There is also a generic ``BeamSearch`` class for finding the ``k`` highest-scoring transition
sequences given a trained ``DecoderStep`` and an initial ``DecoderState``.
"""
from allennlp.nn.decoding.beam_search import BeamSearch
from allennlp.nn.decoding.decoder_state import DecoderState
from allennlp.nn.decoding.decoder_step import DecoderStep
from allennlp.nn.decoding.decoder_trainers.decoder_trainer import DecoderTrainer
from allennlp.nn.decoding.grammar_state import GrammarState
from allennlp.nn.decoding.rnn_state import RnnState
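The difference from vocabulary-based decoding described in the module docstring can be illustrated with a minimal pure-Python sketch. This is not the AllenNLP implementation: ``ToyState``, ``VALID_NEXT``, ``toy_take_step`` and all scores are hypothetical names and values chosen for illustration. The point is only that the set of valid actions depends on the current state, and that a step returns next states ranked by score.

```python
from typing import Dict, List, NamedTuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for a decoder state: actions so far plus a score."""
    action_history: List[str]
    score: float


# Assumed toy transition table: which actions are valid depends on the last action,
# so the allowable outputs are NOT the same at every timestep.
VALID_NEXT: Dict[str, Dict[str, float]] = {
    "<start>": {"open": -0.1, "stop": -2.0},
    "open": {"word": -0.3, "close": -1.5},
    "word": {"word": -1.0, "close": -0.4},
    "close": {"stop": -0.05},
}


def toy_take_step(state: ToyState) -> List[ToyState]:
    """Return all valid next states, sorted from highest to lowest score."""
    last_action = state.action_history[-1] if state.action_history else "<start>"
    next_states = [
        ToyState(state.action_history + [action], state.score + log_prob)
        for action, log_prob in VALID_NEXT.get(last_action, {}).items()
    ]
    return sorted(next_states, key=lambda s: s.score, reverse=True)


initial = ToyState(action_history=[], score=0.0)
ranked = toy_take_step(initial)
print([(s.action_history, s.score) for s in ranked])
```

A state whose last action has no entry in the table (here, ``"stop"``) has no successors, which is one way a search can run out of valid actions.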
88 changes: 88 additions & 0 deletions allennlp/nn/decoding/beam_search.py
@@ -0,0 +1,88 @@
from collections import defaultdict
from typing import Dict, List

from allennlp.common import Params
from allennlp.nn.decoding.decoder_step import DecoderStep
from allennlp.nn.decoding.decoder_state import DecoderState


class BeamSearch:
    """
    This class implements beam search over transition sequences given an initial ``DecoderState``
    and a ``DecoderStep``, returning the highest scoring final states found by the beam (the states
    will keep track of the transition sequence themselves).

    The initial ``DecoderState`` is assumed to be `batched`. The value we return from the search
    is a dictionary from batch indices to ranked finished states.

    IMPORTANT: We assume that the ``DecoderStep`` that you are using returns possible next states
    in sorted order, so we do not do an additional sort inside of ``BeamSearch.search()``. If
    you're implementing your own ``DecoderStep``, you must ensure that you've sorted the states
    that you return.
    """
    def __init__(self, beam_size: int) -> None:
        self._beam_size = beam_size

    def search(self,
               num_steps: int,
               initial_state: DecoderState,
               decoder_step: DecoderStep,
               keep_final_unfinished_states: bool = True) -> Dict[int, List[DecoderState]]:
        """
        Parameters
        ----------
        num_steps : ``int``
            How many steps should we take in our search? This is an upper bound, as it's possible
            for the search to run out of valid actions before hitting this number, or for all
            states on the beam to finish.
        initial_state : ``DecoderState``
            The starting state of our search. This is assumed to be `batched`, and our beam search
            is batch-aware - we'll keep ``beam_size`` states around for each instance in the batch.
        decoder_step : ``DecoderStep``
            The ``DecoderStep`` object that defines and scores transitions from one state to the
            next.
        keep_final_unfinished_states : ``bool``, optional (default=True)
            If we run out of steps before a state is "finished", should we return that state in our
            search results?

        Returns
        -------
        best_states : ``Dict[int, List[DecoderState]]``
            This is a mapping from batch index to the top states for that instance.
        """
        finished_states: Dict[int, List[DecoderState]] = defaultdict(list)
        states = [initial_state]
        step_num = 1
        while states and step_num <= num_steps:
            next_states: Dict[int, List[DecoderState]] = defaultdict(list)
            grouped_state = states[0].combine_states(states)
            for next_state in decoder_step.take_step(grouped_state, max_actions=self._beam_size):
                # NOTE: we're doing state.batch_indices[0] here (and similar things below),
                # hard-coding a group size of 1. But, our use of `next_state.is_finished()`
                # already checks for that, as it crashes if the group size is not 1.
                batch_index = next_state.batch_indices[0]
                if next_state.is_finished():
                    finished_states[batch_index].append(next_state)
                else:
                    if step_num == num_steps and keep_final_unfinished_states:
                        finished_states[batch_index].append(next_state)
                    next_states[batch_index].append(next_state)
            states = []
            for batch_index, batch_states in next_states.items():
                # The states from the generator are already sorted, so we can just take the first
                # ones here, without an additional sort.
                states.extend(batch_states[:self._beam_size])
            step_num += 1
        best_states: Dict[int, List[DecoderState]] = {}
        for batch_index, batch_states in finished_states.items():
            # The time this sort takes is pretty negligible, no particular need to optimize this
            # yet. Maybe with a larger beam size...
            finished_to_sort = [(-state.score[0].data[0], state) for state in batch_states]
            finished_to_sort.sort(key=lambda x: x[0])
            best_states[batch_index] = [state[1] for state in finished_to_sort[:self._beam_size]]
        return best_states

    @classmethod
    def from_params(cls, params: Params) -> 'BeamSearch':
        beam_size = params.pop('beam_size')
        return cls(beam_size=beam_size)
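The control flow of ``BeamSearch.search()`` can be traced without PyTorch using a small stand-in. The sketch below mirrors the loop structure of the method above, but everything else is an assumption made for illustration: ``ToyState``, ``ACTION_SCORES``, ``toy_step`` and ``toy_beam_search`` are hypothetical names, scores are plain floats, and each toy state carries a single ``batch_index`` rather than a grouped list.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple, Tuple


class ToyState(NamedTuple):
    batch_index: int
    action_history: Tuple[str, ...]
    score: float

    def is_finished(self) -> bool:
        return bool(self.action_history) and self.action_history[-1] == "stop"


# Assumed action log-probabilities, identical at every step to keep the trace short.
ACTION_SCORES = {"a": -0.2, "b": -0.9, "stop": -0.5}


def toy_step(states: List[ToyState], max_actions: int) -> List[ToyState]:
    """Expand every state; return successors best-first, at most ``max_actions`` per instance."""
    candidates = [
        ToyState(s.batch_index, s.action_history + (action,), s.score + log_prob)
        for s in states
        for action, log_prob in ACTION_SCORES.items()
    ]
    candidates.sort(key=lambda s: s.score, reverse=True)
    kept: List[ToyState] = []
    per_instance: Dict[int, int] = defaultdict(int)
    for candidate in candidates:
        if per_instance[candidate.batch_index] < max_actions:
            kept.append(candidate)
            per_instance[candidate.batch_index] += 1
    return kept


def toy_beam_search(initial_states: List[ToyState],
                    num_steps: int,
                    beam_size: int,
                    keep_final_unfinished_states: bool = True) -> Dict[int, List[ToyState]]:
    finished: Dict[int, List[ToyState]] = defaultdict(list)
    states = list(initial_states)
    step_num = 1
    while states and step_num <= num_steps:
        next_states: Dict[int, List[ToyState]] = defaultdict(list)
        for next_state in toy_step(states, max_actions=beam_size):
            if next_state.is_finished():
                finished[next_state.batch_index].append(next_state)
            else:
                if step_num == num_steps and keep_final_unfinished_states:
                    finished[next_state.batch_index].append(next_state)
                next_states[next_state.batch_index].append(next_state)
        # Successors arrive sorted, so truncating keeps the best ones per instance.
        states = [s for group in next_states.values() for s in group[:beam_size]]
        step_num += 1
    return {
        index: sorted(group, key=lambda s: s.score, reverse=True)[:beam_size]
        for index, group in finished.items()
    }


results = toy_beam_search([ToyState(0, (), 0.0), ToyState(1, (), 0.0)],
                          num_steps=2, beam_size=2)
print({index: [s.action_history for s in group] for index, group in results.items()})
```

With ``keep_final_unfinished_states=True`` the unfinished sequence ``("a", "a")`` is returned because the search ran out of steps; passing ``False`` leaves only sequences that end in ``"stop"``.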
67 changes: 67 additions & 0 deletions allennlp/nn/decoding/decoder_state.py
@@ -0,0 +1,67 @@
from typing import Generic, List, TypeVar

import torch

# Note that the bound here is `DecoderState` itself. This is what lets us have methods that take
# lists of a `DecoderState` subclass and output structures with the subclass. Really ugly that we
# have to do this generic typing _for our own class_, but it makes mypy happy and gives us good
# type checking in a few important methods.
T = TypeVar('T', bound='DecoderState')

class DecoderState(Generic[T]):
    """
    Represents the (batched) state of a transition-based decoder.

    There are two different kinds of batching we need to distinguish here. First, there's the
    batch of training instances passed to ``model.forward()``. We'll use "batch" and
    ``batch_size`` to refer to this through the docs and code. We additionally batch together
    computation for several states at the same time, where each state could be from the same
    training instance in the original batch, or different instances. We use "group" and
    ``group_size`` in the docs and code to refer to this kind of batching, to distinguish it from
    the batch of training instances.

    So, using this terminology, a single ``DecoderState`` object represents a `grouped` collection
    of states. Because different states in this group might finish at different timesteps, we have
    methods and member variables to handle some bookkeeping around this, to split and regroup
    things.

    Parameters
    ----------
    batch_indices : ``List[int]``
        A ``group_size``-length list, where each element specifies which ``batch_index`` that group
        element came from.

        Our internal variables (like scores, action histories, hidden states, whatever) are
        `grouped`, and our ``group_size`` is likely different from the original ``batch_size``.
        This variable keeps track of which batch instance each group element came from (e.g., to
        know what the correct action sequences are, or which encoder outputs to use).
    action_history : ``List[List[int]]``
        The list of actions taken so far in this state. This is also grouped, so each state in the
        group has a list of actions.
    score : ``List[torch.autograd.Variable]``
        This state's score. It's a variable, because typically we'll be computing a loss based on
        this score, and using it for backprop during training. Like the other variables here, this
        is a ``group_size``-length list.
    """
    def __init__(self,
                 batch_indices: List[int],
                 action_history: List[List[int]],
                 score: List[torch.autograd.Variable]) -> None:
        self.batch_indices = batch_indices
        self.action_history = action_history
        self.score = score

    def is_finished(self) -> bool:
        """
        If this state has a ``group_size`` of 1, this returns whether the single action sequence in
        this state is finished or not. If this state has a ``group_size`` other than 1, this
        method raises an error.
        """
        raise NotImplementedError

    @classmethod
    def combine_states(cls, states: List[T]) -> T:
        """
        Combines a list of states, each with their own group size, into a single state.
        """
        raise NotImplementedError
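The "batch" versus "group" bookkeeping is easier to see with a concrete subclass-like sketch. The class below is hypothetical and does not inherit from ``DecoderState``: it stores plain floats instead of ``torch.autograd.Variable`` scores, and it assumes that action id ``0`` marks the end of a sequence. It shows one plausible way to satisfy the two documented contracts, ``is_finished()`` raising for a group size other than 1 and ``combine_states()`` concatenating the per-element lists.

```python
from typing import List


class ToyGroupedState:
    """Hypothetical grouped state: parallel lists with one entry per group element."""

    def __init__(self,
                 batch_indices: List[int],
                 action_history: List[List[int]],
                 score: List[float]) -> None:
        self.batch_indices = batch_indices
        self.action_history = action_history
        self.score = score

    @property
    def group_size(self) -> int:
        return len(self.batch_indices)

    def is_finished(self, end_action: int = 0) -> bool:
        # Mirrors the documented contract: only defined for a group size of 1.
        if self.group_size != 1:
            raise RuntimeError("is_finished() is only defined for group_size == 1")
        history = self.action_history[0]
        return bool(history) and history[-1] == end_action

    @classmethod
    def combine_states(cls, states: List["ToyGroupedState"]) -> "ToyGroupedState":
        # Concatenate the per-element lists of every input state, preserving order.
        return cls(
            batch_indices=[index for state in states for index in state.batch_indices],
            action_history=[history for state in states for history in state.action_history],
            score=[score for state in states for score in state.score],
        )


# Two states from batch instance 0 and one from instance 1: group_size 3, batch_size 2.
singles = [
    ToyGroupedState([0], [[3, 0]], [-0.4]),
    ToyGroupedState([0], [[5]], [-0.9]),
    ToyGroupedState([1], [[2]], [-0.1]),
]
grouped = ToyGroupedState.combine_states(singles)
print(grouped.group_size, grouped.batch_indices)
```

Here ``batch_indices`` is what lets a step function look up per-instance data (such as encoder outputs) for each of the three group elements, even though only two training instances are involved.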
77 changes: 77 additions & 0 deletions allennlp/nn/decoding/decoder_step.py
@@ -0,0 +1,77 @@
from typing import Generic, List, Set, TypeVar

import torch

from allennlp.nn.decoding.decoder_state import DecoderState

StateType = TypeVar('StateType', bound=DecoderState) # pylint: disable=invalid-name


class DecoderStep(torch.nn.Module, Generic[StateType]):
    """
    A ``DecoderStep`` is a module that assigns scores to state transitions in a transition-based
    decoder.

    The ``DecoderStep`` takes a ``DecoderState`` and outputs a ranked list of next states, ordered
    by the state's score.

    The intention with this class is that a model will implement a subclass of ``DecoderStep`` that
    defines how exactly you want to handle the input and what computations get done at each step of
    decoding, and how states are scored. This subclass then gets passed to a ``DecoderTrainer`` to
    have its parameters trained.
    """
    def take_step(self,
                  state: StateType,
                  max_actions: int = None,
                  allowed_actions: List[Set] = None) -> List[StateType]:
        """
        The main method in the ``DecoderStep`` API. This function defines the computation done at
        each step of decoding and returns a ranked list of next states.

        The input state is `grouped`, to allow for efficient computation, but the output states
        should all have a ``group_size`` of 1, to make things easier on the decoding algorithm.
        They will get regrouped later as needed.

        Because of the way we handle grouping in the decoder states, constructing a new state is
        actually a relatively expensive operation. If you know a priori that only some of the
        states will be needed (either because you have a set of gold action sequences, or you have
        a fixed beam size), passing that information into this function will keep us from
        constructing more states than we need, which will greatly speed up your computation.

        IMPORTANT: This method `must` return states already sorted by their score, otherwise
        ``BeamSearch`` and other methods will break. For efficiency, we do not perform an
        additional sort in those methods.

        Parameters
        ----------
        state : ``DecoderState``
            The current state of the decoder, which we will take a step `from`. We may be grouping
            together computation for several states here. Because we can have several states for
            each instance in the original batch being evaluated at the same time, we use
            ``group_size`` for this kind of batching, and ``batch_size`` for the `original` batch
            in ``model.forward``.
        max_actions : ``int``, optional
            If you know that you will only need a certain number of states out of this (e.g., in a
            beam search), you can pass in the max number of actions that you need, and we will only
            construct that many states (for each `batch` instance - `not` for each `group`
            instance!). This can save a whole lot of computation if you have an action space
            that's much larger than your beam size.
        allowed_actions : ``List[Set]``, optional
            If the ``DecoderTrainer`` has constraints on which actions need to be evaluated (e.g.,
            maximum marginal likelihood only needs to evaluate action sequences in a given set),
            you can pass those constraints here, to avoid constructing state objects unnecessarily.
            If there are no constraints from the trainer, passing a value of ``None`` here will
            allow all actions to be considered.

            This is a list because it is `batched` - every instance in the batch has a set of
            allowed actions. Note that the size of this list is the ``group_size`` in the
            ``DecoderState``, `not` the ``batch_size`` of ``model.forward``. The training
            algorithm needs to convert from the `batched` allowed action sequences that it has to a
            `grouped` allowed action sequence list.

        Returns
        -------
        next_states : ``List[DecoderState]``
            A list of next states, ordered by score.
        """
        raise NotImplementedError
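The two optional arguments of ``take_step`` are indexed differently, which is easy to get wrong: ``max_actions`` is a cap per `batch` instance, while ``allowed_actions`` has one entry per `group` element. The sketch below is a hypothetical, PyTorch-free illustration of those semantics only. ``Element``, ``toy_take_step`` and the ``action_log_probs`` argument are assumptions made for this example; a real ``DecoderStep`` would compute action scores with a neural network.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

# One group element as (batch_index, action_history, score); a hypothetical flat format.
Element = Tuple[int, Tuple[int, ...], float]


def toy_take_step(group: List[Element],
                  action_log_probs: List[Dict[int, float]],
                  max_actions: Optional[int] = None,
                  allowed_actions: Optional[List[Set[int]]] = None) -> List[Element]:
    """Return successor states of group size 1, sorted best-first."""
    candidates: List[Element] = []
    for group_index, (batch_index, history, score) in enumerate(group):
        for action, log_prob in action_log_probs[group_index].items():
            # allowed_actions is indexed by group element, not by batch instance.
            if allowed_actions is not None and action not in allowed_actions[group_index]:
                continue
            candidates.append((batch_index, history + (action,), score + log_prob))
    candidates.sort(key=lambda element: element[2], reverse=True)
    if max_actions is None:
        return candidates
    # max_actions caps successors per batch instance, pooling all of its group elements.
    kept: List[Element] = []
    per_instance: Dict[int, int] = defaultdict(int)
    for candidate in candidates:
        if per_instance[candidate[0]] < max_actions:
            kept.append(candidate)
            per_instance[candidate[0]] += 1
    return kept


# Group of 3 elements covering 2 batch instances (two beam entries for instance 0).
group = [(0, (1,), -0.25), (0, (2,), -0.5), (1, (1,), -0.25)]
log_probs = [{3: -0.25, 4: -1.0}, {3: -0.125, 4: -2.0}, {3: -0.5, 4: -0.75}]

beam_successors = toy_take_step(group, log_probs, max_actions=2)
constrained = toy_take_step(group, log_probs, allowed_actions=[{4}, {3}, {3, 4}])
print([history for _, history, _ in beam_successors])
```

With ``max_actions=2`` the call returns four states (two per batch instance), not six (two per group element), which is the saving the docstring describes for action spaces much larger than the beam.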
2 changes: 2 additions & 0 deletions allennlp/nn/decoding/decoder_trainers/__init__.py
@@ -0,0 +1,2 @@
from allennlp.nn.decoding.decoder_trainers.expected_risk_minimization import ExpectedRiskMinimization
from allennlp.nn.decoding.decoder_trainers.maximum_marginal_likelihood import MaximumMarginalLikelihood
52 changes: 52 additions & 0 deletions allennlp/nn/decoding/decoder_trainers/decoder_trainer.py
@@ -0,0 +1,52 @@
from typing import Dict, Generic, TypeVar

import torch

from allennlp.nn.decoding.decoder_step import DecoderStep
from allennlp.nn.decoding.decoder_state import DecoderState

SupervisionType = TypeVar('SupervisionType') # pylint: disable=invalid-name

class DecoderTrainer(Generic[SupervisionType]):
    """
    ``DecoderTrainers`` define a training regime for transition-based decoders. A
    ``DecoderTrainer`` assumes an initial ``DecoderState``, a ``DecoderStep`` function that can
    traverse the state space, and some supervision signal. Given these things, the
    ``DecoderTrainer`` trains the ``DecoderStep`` function to traverse the state space to end up at
    good end states.

    Concrete implementations of this abstract base class could do things like maximum marginal
    likelihood, SEARN, LaSO, or other structured learning algorithms. If you're just trying to
    maximize the probability of a single target sequence where the possible outputs are the same
    for each timestep (as in, e.g., typical machine translation training regimes), there are way
    more efficient ways to do that than using this API.
    """
    def decode(self,
               initial_state: DecoderState,
               decode_step: DecoderStep,
               supervision: SupervisionType) -> Dict[str, torch.Tensor]:
        """
        Takes an initial state object, a means of transitioning from state to state, and a
        supervision signal, and uses the supervision to train the transition function to pick
        "good" states.

        This function should typically return a ``loss`` key during training, which the ``Model``
        will use as its loss.

        Parameters
        ----------
        initial_state : ``DecoderState``
            This is the initial state for decoding, typically initialized after running some kind
            of encoder on some inputs.
        decode_step : ``DecoderStep``
            This is the transition function that scores all possible actions that can be taken in a
            given state, and returns a ranked list of next states at each step of decoding.
        supervision : ``SupervisionType``
            This is the supervision that is used to train the ``decode_step`` function to pick
            "good" states. You can use whatever kind of supervision you want (e.g., a single
            "gold" action sequence, a set of possible "gold" action sequences, a reward function,
            etc.). We use ``typing.Generic`` to make sure that our static type checker is happy
            with how you've matched the supervision that you provide in the model to the
            ``DecoderTrainer`` that you want to use.
        """
        raise NotImplementedError
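As a rough illustration of what a concrete trainer has to do with a "set of gold action sequences" supervision signal, the sketch below shows two pieces in plain Python. It is written in the spirit of maximum marginal likelihood but is not the ``MaximumMarginalLikelihood`` class added in this commit. The function names, the dictionary-based supervision format and the use of floats instead of tensors are all assumptions for the example.

```python
import math
from typing import Dict, List, Sequence, Set


def grouped_allowed_actions(batch_indices: List[int],
                            action_history: List[List[int]],
                            gold_sequences: Dict[int, List[List[int]]]) -> List[Set[int]]:
    """Convert batched gold sequences into one allowed-action set per group element."""
    allowed: List[Set[int]] = []
    for batch_index, history in zip(batch_indices, action_history):
        next_actions: Set[int] = set()
        for sequence in gold_sequences[batch_index]:
            # A gold sequence is relevant only if it extends this element's history.
            if sequence[:len(history)] == history and len(sequence) > len(history):
                next_actions.add(sequence[len(history)])
        allowed.append(next_actions)
    return allowed


def marginal_likelihood_loss(finished_scores: Sequence[float]) -> float:
    """Negative log of the summed probability of the finished sequences.

    Scores are assumed to be log-probabilities; the sum is computed with the
    log-sum-exp trick for numerical stability.
    """
    highest = max(finished_scores)
    log_sum = highest + math.log(sum(math.exp(score - highest) for score in finished_scores))
    return -log_sum


# Supervision is keyed by batch index; the grouped state has three elements.
gold = {0: [[1, 2, 0], [1, 3, 0]], 1: [[4, 0]]}
allowed = grouped_allowed_actions(batch_indices=[0, 0, 1],
                                  action_history=[[1], [1, 2], [4]],
                                  gold_sequences=gold)
output = {"loss": marginal_likelihood_loss([math.log(0.25), math.log(0.25)])}
print(allowed, output)
```

The resulting list has ``group_size`` entries and could be passed as ``allowed_actions`` to a step function, while the ``loss`` entry follows the convention described in the ``decode`` docstring of returning a dictionary with a ``loss`` key.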