This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit: WTQ dataset reader and models moved from iterative-search-semparse (#2764)

* added dataset reader

* moved dataset reader

* moved models

* removed unnecessary except

* fixed the name of the model in docs

* change trainer_test

* minor fixes

* removed unnecessary modules

* removed unnecessary tests

* removed table knowledge graph

* removed table knowledge graph

* updated knowledge graph field test with new context

* fixed docs after removing table knowledge graph

* got the predictor tests passing

* fixed a bug in checking if data is tagged

* minor fixes

* addressed PR comments and added a test for reading untagged tables

* fixed a silly doc issue

* removed jdk dependency

* more language changes

* undid changes made for backwards compatibility

* pylint and mypy fixes

* better error raising
pdasigi committed Jun 20, 2019
1 parent 7e08298 commit 0fbd1ca
Showing 51 changed files with 1,044 additions and 3,112 deletions.
3 changes: 1 addition & 2 deletions Dockerfile
@@ -30,8 +30,7 @@ RUN apt-get update --fix-missing && apt-get install -y \
libxrender1 \
wget \
libevent-dev \
build-essential \
openjdk-8-jdk && \
build-essential && \
rm -rf /var/lib/apt/lists/*

# Copy select files needed for installing requirements.
363 changes: 113 additions & 250 deletions allennlp/data/dataset_readers/semantic_parsing/wikitables/wikitables.py

Large diffs are not rendered by default.

@@ -12,12 +12,12 @@
from allennlp.models.model import Model
from allennlp.models.semantic_parsing.wikitables.wikitables_semantic_parser import WikiTablesSemanticParser
from allennlp.modules import Attention, FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.semparse.type_declarations import wikitables_lambda_dcs as types
from allennlp.semparse.worlds import WikiTablesWorld
from allennlp.state_machines.states import CoverageState, ChecklistStatelet
from allennlp.state_machines.trainers import ExpectedRiskMinimization
from allennlp.state_machines.transition_functions import LinkingCoverageTransitionFunction
from allennlp.training.metrics import Average
from allennlp.semparse.domain_languages import WikiTablesLanguage


logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -78,10 +78,6 @@ class WikiTablesErmSemanticParser(WikiTablesSemanticParser):
The vocabulary namespace to use for production rules. The default corresponds to the
default used in the dataset reader, so you likely don't need to modify this. Passed to super
class.
tables_directory : ``str``, optional (default=/wikitables/)
The directory to find tables when evaluating logical forms. We rely on a call to SEMPRE to
evaluate logical forms, and SEMPRE needs to read the table from disk itself. This tells
SEMPRE where to find the tables. Passed to super class.
mml_model_file : ``str``, optional (default=None)
If you want to initialize this model using weights from another model trained using MML,
pass the path to the ``model.tar.gz`` file of that model here.
@@ -104,7 +100,6 @@ def __init__(self,
dropout: float = 0.0,
num_linking_features: int = 10,
rule_namespace: str = 'rule_labels',
tables_directory: str = '/wikitables/',
mml_model_file: str = None) -> None:
use_similarity = use_neighbor_similarity_for_linking
super().__init__(vocab=vocab,
@@ -117,22 +112,13 @@
use_neighbor_similarity_for_linking=use_similarity,
dropout=dropout,
num_linking_features=num_linking_features,
rule_namespace=rule_namespace,
tables_directory=tables_directory)
rule_namespace=rule_namespace)
# Not sure why mypy needs a type annotation for this!
self._decoder_trainer: ExpectedRiskMinimization = \
ExpectedRiskMinimization(beam_size=decoder_beam_size,
normalize_by_length=normalize_beam_score_by_length,
max_decoding_steps=self._max_decoding_steps,
max_num_finished_states=decoder_num_finished_states)
unlinked_terminals_global_indices = []
global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
for production, index in global_vocab.items():
right_side = production.split(" -> ")[1]
if right_side in types.COMMON_NAME_MAPPING:
# This is a terminal production.
unlinked_terminals_global_indices.append(index)
self._num_unlinked_terminals = len(unlinked_terminals_global_indices)
self._decoder_step = LinkingCoverageTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
action_embedding_dim=action_embedding_dim,
input_attention=attention,
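For readers unfamiliar with it (added context, not part of this commit): the ``ExpectedRiskMinimization`` trainer constructed in the hunk above decodes with beam search and minimizes the expected cost of the finished sequences on the beam, with probabilities re-normalized over the beam. Roughly, as a sketch that omits the length normalization and state bookkeeping handled by the trainer itself:

    L(\theta) = \sum_{y \in B(x)} \tilde{p}_\theta(y \mid x) \, C(y),
    \qquad
    \tilde{p}_\theta(y \mid x) = \frac{\exp s_\theta(y)}{\sum_{y' \in B(x)} \exp s_\theta(y')}

where B(x) is the beam of finished action sequences for question x, s_\theta(y) is the model's score for sequence y, and C(y) is the cost computed by ``_get_state_cost`` further down in this file.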
@@ -204,11 +190,10 @@ def _get_vocab_index_mapping(self, archived_vocab: Vocabulary) -> List[Tuple[int
def forward(self, # type: ignore
question: Dict[str, torch.LongTensor],
table: Dict[str, torch.LongTensor],
world: List[WikiTablesWorld],
world: List[WikiTablesLanguage],
actions: List[List[ProductionRule]],
agenda: torch.LongTensor,
example_lisp_string: List[str],
metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
target_values: List[List[str]] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
Parameters
@@ -221,22 +206,20 @@ def forward(self, # type: ignore
``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
get embeddings for each entity.
world : ``List[WikiTablesWorld]``
We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``,
world : ``List[WikiTablesLanguage]``
We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input instance.
Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesLanguage]``,
actions : ``List[List[ProductionRule]]``
A list of all possible actions for each ``World`` in the batch, indexed into a
A list of all possible actions for each ``world`` in the batch, indexed into a
``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
and use the embeddings to determine which action to take at each timestep in the
decoder.
agenda : ``torch.LongTensor``
Agenda of one instance of size ``(agenda_size, 1)``.
example_lisp_string : ``List[str]``
The example (lisp-formatted) string corresponding to the given input. This comes
directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE
when evaluating denotation accuracy; it is otherwise unused.
metadata : ``List[Dict[str, Any]]``, optional, (default = None)
Metadata containing the original tokenized question within a 'question_tokens' key.
Agenda vectors that the checklist vectors will be compared against to compute the checklist
cost.
target_values : ``List[List[str]]``, optional (default = None)
For each instance, a list of target values taken from the example lisp string. We pass
this list to the evaluator along with logical forms to compute denotation accuracy.
"""
batch_size = list(question.values())[0].size(0)
# Each instance's agenda is of size (agenda_size, 1)
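As a rough illustration of what the new ``target_values`` input documented above holds (a sketch only; the actual extraction lives in the dataset reader in wikitables.py, whose diff is not rendered here, and may differ in detail), the target values are the answer strings listed under ``targetValue`` in a WikiTableQuestions example string:

import re

def extract_target_values(example_lisp_string: str) -> list:
    # Toy illustration, not the dataset reader's code: pull every
    # (description "...") entry out of the example's targetValue list.
    return re.findall(r'\(description "([^"]*)"\)', example_lisp_string)

# Hypothetical example string in the WikiTableQuestions .examples format.
example = ('(example (id nt-39) (utterance "which team finished first?") '
           '(context (graph tables.TableKnowledgeGraph csv/204-csv/590.csv)) '
           '(targetValue (list (description "Bombo United"))))')
print(extract_target_values(example))  # ['Bombo United']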
@@ -275,7 +258,7 @@ def forward(self, # type: ignore
grammar_state=grammar_state,
checklist_state=checklist_states,
possible_actions=actions,
extras=example_lisp_string,
extras=target_values,
debug_info=None)

if not self.training:
@@ -316,10 +299,11 @@ def forward(self, # type: ignore
in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
self._agenda_coverage(in_agenda_ratio)

metadata = None
self._compute_validation_outputs(actions,
best_final_states,
world,
example_lisp_string,
target_values,
metadata,
outputs)
return outputs
@@ -350,7 +334,7 @@ def _get_checklist_info(agenda: torch.LongTensor,
"""
terminal_indices = []
target_checklist_list = []
agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
agenda_indices_set = {int(x) for x in agenda.squeeze(0).detach().cpu().numpy()}
# We want to return checklist target and terminal actions that are column vectors to make
# computing softmax over the difference between checklist and target easier.
for index, action in enumerate(all_actions):
@@ -372,7 +356,7 @@
checklist_mask = (target_checklist != 0).float()
return target_checklist, terminal_actions, checklist_mask
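A toy illustration of the column-vector outputs returned above (made-up sizes, and it keeps every action rather than only terminal productions, purely to show the shapes):

import torch

# Suppose an instance has 6 possible actions and its agenda contains actions 1 and 4.
agenda = torch.tensor([[1], [4]])                  # (agenda_size, 1)
agenda_indices = {int(x) for x in agenda.squeeze(1)}

terminal_indices = []
target_checklist_list = []
for index in range(6):                             # stand-in for iterating over all_actions
    terminal_indices.append([index])
    target_checklist_list.append([1] if index in agenda_indices else [0])

terminal_actions = torch.tensor(terminal_indices)                          # (6, 1)
target_checklist = torch.tensor(target_checklist_list, dtype=torch.float)  # (6, 1)
checklist_mask = (target_checklist != 0).float()   # 1.0 only where the agenda wants an action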

def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -> torch.Tensor:
def _get_state_cost(self, worlds: List[WikiTablesLanguage], state: CoverageState) -> torch.Tensor:
if not state.is_finished():
raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
world = worlds[state.batch_indices[0]]
@@ -390,9 +374,13 @@ def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -
action_history = state.action_history[0]
batch_index = state.batch_indices[0]
action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
logical_form = world.get_logical_form(action_strings)
lisp_string = state.extras[batch_index]
if self._executor.evaluate_logical_form(logical_form, lisp_string):
target_values = state.extras[batch_index]
evaluation = False
executor_logger = \
logging.getLogger('allennlp.semparse.domain_languages.wikitables_language')
executor_logger.setLevel(logging.ERROR)
evaluation = world.evaluate_action_sequence(action_strings, target_values)
if evaluation:
cost = checklist_cost
else:
cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
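To restate the cost rule visible at the end of ``_get_state_cost`` above (a summary for clarity, not code from this commit): a decoded sequence whose denotation matches the target values pays only the checklist cost, while an incorrect one additionally pays a weighted denotation cost.

def combine_costs(checklist_cost: float,
                  denotation_cost: float,
                  checklist_cost_weight: float,
                  denotation_is_correct: bool) -> float:
    # Mirrors the if/else above: correct denotation -> checklist cost only;
    # otherwise add (1 - checklist_cost_weight) * denotation_cost.
    if denotation_is_correct:
        return checklist_cost
    return checklist_cost + (1 - checklist_cost_weight) * denotation_cost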
@@ -4,16 +4,16 @@
import torch

from allennlp.data import Vocabulary
from allennlp.data.fields.production_rule_field import ProductionRule
from allennlp.data.fields.production_rule_field import ProductionRuleArray
from allennlp.models.model import Model
from allennlp.models.semantic_parsing.wikitables.wikitables_semantic_parser import WikiTablesSemanticParser
from allennlp.modules import Attention, FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.semparse.worlds import WikiTablesWorld
from allennlp.state_machines import BeamSearch
from allennlp.state_machines.states import GrammarBasedState
from allennlp.state_machines.trainers import MaximumMarginalLikelihood
from allennlp.state_machines.transition_functions import LinkingTransitionFunction

from allennlp.semparse.domain_languages import WikiTablesLanguage
from allennlp.models.semantic_parsing.wikitables.wikitables_semantic_parser \
import WikiTablesSemanticParser

@Model.register("wikitables_mml_parser")
class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
Expand All @@ -23,10 +23,9 @@ class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
denotation. This is a re-implementation of the model used for the paper `Neural Semantic Parsing with Type
Constraints for Semi-Structured Tables
<https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a>`_,
by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017).
WORK STILL IN PROGRESS. We'll iteratively improve it until we've reproduced the performance of
the original parser.
by Jayant Krishnamurthy, Pradeep Dasigi, and Matt Gardner (EMNLP 2017). The language used by
this model is different from LambdaDCS, the one in the paper above though. This model uses the
variable free language from ``allennlp.semparse.domain_languages.wikitables_language``.
Parameters
----------
@@ -77,10 +76,6 @@ class WikiTablesMmlSemanticParser(WikiTablesSemanticParser):
The vocabulary namespace to use for production rules. The default corresponds to the
default used in the dataset reader, so you likely don't need to modify this. Passed to super
class.
tables_directory : ``str``, optional (default=/wikitables/)
The directory to find tables when evaluating logical forms. We rely on a call to SEMPRE to
evaluate logical forms, and SEMPRE needs to read the table from disk itself. This tells
SEMPRE where to find the tables. Passed to super class.
"""
def __init__(self,
vocab: Vocabulary,
@@ -97,8 +92,7 @@ def __init__(self,
use_neighbor_similarity_for_linking: bool = False,
dropout: float = 0.0,
num_linking_features: int = 10,
rule_namespace: str = 'rule_labels',
tables_directory: str = '/wikitables/') -> None:
rule_namespace: str = 'rule_labels') -> None:
use_similarity = use_neighbor_similarity_for_linking
super().__init__(vocab=vocab,
question_embedder=question_embedder,
@@ -110,8 +104,7 @@ def __init__(self,
use_neighbor_similarity_for_linking=use_similarity,
dropout=dropout,
num_linking_features=num_linking_features,
rule_namespace=rule_namespace,
tables_directory=tables_directory)
rule_namespace=rule_namespace)
self._beam_search = decoder_beam_search
self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
@@ -127,11 +120,10 @@ def __init__(self,
def forward(self, # type: ignore
question: Dict[str, torch.LongTensor],
table: Dict[str, torch.LongTensor],
world: List[WikiTablesWorld],
actions: List[List[ProductionRule]],
example_lisp_string: List[str] = None,
target_action_sequences: torch.LongTensor = None,
metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
world: List[WikiTablesLanguage],
actions: List[List[ProductionRuleArray]],
target_values: List[List[str]] = None,
target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
In this method we encode the table entities, link them to words in the question, then
@@ -149,24 +141,21 @@ def forward(self, # type: ignore
``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
get embeddings for each entity.
world : ``List[WikiTablesWorld]``
We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``,
actions : ``List[List[ProductionRule]]``
A list of all possible actions for each ``World`` in the batch, indexed into a
``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
world : ``List[WikiTablesLanguage]``
We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input instance.
Because of how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesLanguage]``,
actions : ``List[List[ProductionRuleArray]]``
A list of all possible actions for each ``world`` in the batch, indexed into a
``ProductionRuleArray`` using a ``ProductionRuleField``. We will embed all of these
and use the embeddings to determine which action to take at each timestep in the
decoder.
example_lisp_string : ``List[str]``, optional (default = None)
The example (lisp-formatted) string corresponding to the given input. This comes
directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE
when evaluating denotation accuracy; it is otherwise unused.
target_values : ``List[List[str]]``, optional (default = None)
For each instance, a list of target values taken from the example lisp string. We pass
this list to the evaluator along with logical forms to compute denotation accuracy.
target_action_sequences : torch.Tensor, optional (default = None)
A list of possibly valid action sequences, where each action is an index into the list
of possible actions. This tensor has shape ``(batch_size, num_action_sequences,
sequence_length)``.
metadata : ``List[Dict[str, Any]]``, optional, (default = None)
Metadata containing the original tokenized question within a 'question_tokens' key.
"""
outputs: Dict[str, Any] = {}
rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
@@ -183,7 +172,7 @@ def forward(self, # type: ignore
rnn_state=rnn_state,
grammar_state=grammar_state,
possible_actions=actions,
extras=example_lisp_string,
extras=target_values,
debug_info=None)

if target_action_sequences is not None:
@@ -223,11 +212,11 @@ def forward(self, # type: ignore
sequence_in_targets = self._action_history_match(best_action_indices, targets)
self._action_sequence_accuracy(sequence_in_targets)

metadata = None
self._compute_validation_outputs(actions,
best_final_states,
world,
example_lisp_string,
target_values,
metadata,
outputs)

return outputs
