Adding WikiTables and NLVR predictors (#1118)
* Adding WikiTables and NLVR predictors

* Fixed test paths and fixtures
matt-gardner committed Apr 24, 2018
1 parent 7627a09 commit 960f913
Showing 11 changed files with 276 additions and 11 deletions.
6 changes: 4 additions & 2 deletions allennlp/service/predictors/__init__.py
@@ -8,9 +8,11 @@
"""
from .predictor import Predictor, DemoModel
from .bidaf import BidafPredictor
+from .constituency_parser import ConstituencyParserPredictor
+from .coref import CorefPredictor
from .decomposable_attention import DecomposableAttentionPredictor
from .semantic_role_labeler import SemanticRoleLabelerPredictor
-from .coref import CorefPredictor
from .sentence_tagger import SentenceTaggerPredictor
-from .constituency_parser import ConstituencyParserPredictor
from .simple_seq2seq import SimpleSeq2SeqPredictor
+from .wikitables_parser import WikiTablesParserPredictor
+from .nlvr_parser import NlvrParserPredictor
16 changes: 16 additions & 0 deletions allennlp/service/predictors/nlvr_parser.py
@@ -0,0 +1,16 @@
from typing import Tuple
from overrides import overrides

from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.service.predictors.predictor import Predictor


@Predictor.register('nlvr-parser')
class NlvrParserPredictor(Predictor):
    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
        sentence = json_dict['sentence']
        worlds = json_dict['worlds']
        instance = self._dataset_reader.text_to_instance(sentence, worlds)
        return instance, {}
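
For context, a minimal usage sketch (not part of this commit) of how the new predictor might be driven from Python; the archive path and example input are placeholders, while the registered name and the 'sentence'/'worlds' keys come from the code above and the tests added below:

# Usage sketch only -- assumes a trained NLVR parser archive at a hypothetical path.
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor

archive = load_archive('/path/to/nlvr/model.tar.gz')          # hypothetical archive path
predictor = Predictor.from_archive(archive, 'nlvr-parser')    # name registered above
result = predictor.predict_json({
    'sentence': 'There is exactly one yellow square.',        # hypothetical example sentence
    'worlds': [[[{'type': 'square', 'color': 'Yellow', 'x_loc': 10, 'y_loc': 10, 'size': 20}]]],
})
# The tests below check for 'logical_form' and 'denotations' in the result.
print(result['logical_form'])
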
120 changes: 120 additions & 0 deletions allennlp/service/predictors/wikitables_parser.py
@@ -0,0 +1,120 @@
import os
from subprocess import run
from typing import Tuple

from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.service.predictors.predictor import Predictor

# TODO(mattg): We should merge how this works with how the `WikiTablesAccuracy` metric works, maybe
# just removing the need for adding this stuff at all, because the parser already runs the java
# process. This requires modifying the scala `wikitables-executor` code to also return the
# denotation when running it as a server, and updating the model to parse the output correctly, but
# that shouldn't be too hard.
DEFAULT_EXECUTOR_JAR = "https://s3-us-west-2.amazonaws.com/allennlp/misc/wikitables-executor-0.1.0.jar"
ABBREVIATIONS_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/misc/wikitables-abbreviations.tsv"
GROW_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/misc/wikitables-grow.grammar"
SEMPRE_DIR = 'data/'

@Predictor.register('wikitables-parser')
class WikiTablesParserPredictor(Predictor):
"""
Wrapper for the
:class:`~allennlp.models.encoder_decoders.wikitables_semantic_parser.WikiTablesSemanticParser`
model.
"""

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)
        # Load auxiliary sempre files during startup for faster logical form execution.
        os.makedirs(SEMPRE_DIR, exist_ok=True)
        abbreviations_path = os.path.join(SEMPRE_DIR, 'abbreviations.tsv')
        if not os.path.exists(abbreviations_path):
            run(f'wget {ABBREVIATIONS_FILE}', shell=True)
            run(f'mv wikitables-abbreviations.tsv {abbreviations_path}', shell=True)

        grammar_path = os.path.join(SEMPRE_DIR, 'grow.grammar')
        if not os.path.exists(grammar_path):
            run(f'wget {GROW_FILE}', shell=True)
            run(f'mv wikitables-grow.grammar {grammar_path}', shell=True)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
        """
        Expects JSON that looks like ``{"question": "...", "table": "..."}``.
        """
        question_text = json_dict["question"]
        table_text = json_dict["table"]
        cells = []
        for row_index, line in enumerate(table_text.split('\n')):
            line = line.rstrip('\n')
            if row_index == 0:
                columns = line.split('\t')
            else:
                cells.append(line.split('\t'))
        # pylint: disable=protected-access
        tokenized_question = self._dataset_reader._tokenizer.tokenize(question_text.lower())  # type: ignore
        # pylint: enable=protected-access
        table_json = {"question": tokenized_question, "columns": columns, "cells": cells}
        instance = self._dataset_reader.text_to_instance(question_text,  # type: ignore
                                                         table_json,
                                                         tokenized_question=tokenized_question)
        extra_info = {'question_tokens': tokenized_question}
        return instance, extra_info

    @overrides
    def predict_json(self, inputs: JsonDict, cuda_device: int = -1) -> JsonDict:
        instance, return_dict = self._json_to_instance(inputs)
        outputs = self._model.forward_on_instance(instance, cuda_device)
        outputs['answer'] = self._execute_logical_form_on_table(outputs['logical_form'],
                                                                inputs['table'])

        return_dict.update(outputs)
        return sanitize(return_dict)

    @staticmethod
    def _execute_logical_form_on_table(logical_form, table):
        """
        Writes the logical form and the table out to files, then calls the SEMPRE jar, which reads
        those files, executes the logical form against the table, and writes out the denotation.
        """
        logical_form_filename = os.path.join(SEMPRE_DIR, 'logical_forms.txt')
        with open(logical_form_filename, 'w') as temp_file:
            temp_file.write(logical_form + '\n')

        table_dir = os.path.join(SEMPRE_DIR, 'tsv/')
        os.makedirs(table_dir, exist_ok=True)
        # The .tsv file extension is important here since the table string parameter is in tsv format.
        # If this file were named with a .csv suffix then SEMPRE would interpret it as comma separated
        # and return the wrong denotation.
        table_filename = 'context.tsv'
        with open(os.path.join(table_dir, table_filename), 'w', encoding='utf-8') as temp_file:
            temp_file.write(table)

        # The id, target, and utterance are ignored; we just need to get the
        # table filename into SEMPRE's lisp format.
        test_record = ('(example (id nt-0) (utterance none) (context (graph tables.TableKnowledgeGraph %s))'
                       '(targetValue (list (description "6"))))' % (table_filename))
        test_data_filename = os.path.join(SEMPRE_DIR, 'data.examples')
        with open(test_data_filename, 'w') as temp_file:
            temp_file.write(test_record)

        # TODO(matt): The jar that we have isn't optimal for this use case - we're using a
        # script designed for computing accuracy, and just pulling out a piece of it.  Writing
        # a new entry point to the jar that's tailored for this use would be cleaner.
        command = ' '.join(['java',
                            '-jar',
                            cached_path(DEFAULT_EXECUTOR_JAR),
                            test_data_filename,
                            logical_form_filename,
                            table_dir])
        run(command, shell=True)

        denotations_file = os.path.join(SEMPRE_DIR, 'logical_forms_denotations.tsv')
        with open(denotations_file) as temp_file:
            line = temp_file.readline().split('\t')
        return line[1] if len(line) > 1 else line[0]
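
For reference, a hedged usage sketch (not part of this commit) of the call pattern this predictor expects; the table value is a tab-separated string whose first line is the header row, as parsed in `_json_to_instance`, and the archive path is a placeholder. Note that `predict_json` shells out to the SEMPRE jar, so Java must be available:

# Usage sketch only -- assumes a trained WikiTables parser archive and Java on PATH.
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor

archive = load_archive('/path/to/wikitables/model.tar.gz')        # hypothetical archive path
predictor = Predictor.from_archive(archive, 'wikitables-parser')  # name registered above
result = predictor.predict_json({
    'question': 'Who is 18 years old?',
    'table': 'Name\tAge\nShallan\t16\nKaladin\t18',  # header row, then one row per line
})
# predict_json adds the executed denotation under 'answer'.
print(result['logical_form'], result['answer'])
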
14 changes: 14 additions & 0 deletions doc/api/allennlp.service.predictors.rst
@@ -14,6 +14,8 @@ allennlp.service.predictors
* :ref:`CorefPredictor<coreference-resolution>`
* :ref:`ConstituencyParserPredictor<constituency-parser>`
* :ref:`SimpleSeq2SeqPredictor<simple-seq2seq>`
* :ref:`WikiTablesParserPredictor<wikitables-parser>`
* :ref:`NlvrParserPredictor<nlvr-parser>`

.. _predictor:
.. automodule:: allennlp.service.predictors.predictor
@@ -62,3 +64,15 @@ allennlp.service.predictors
   :members:
   :undoc-members:
   :show-inheritance:

.. _wikitables-parser:
.. automodule:: allennlp.service.predictors.wikitables_parser
   :members:
   :undoc-members:
   :show-inheritance:

.. _nlvr-parser:
.. automodule:: allennlp.service.predictors.nlvr_parser
   :members:
   :undoc-members:
   :show-inheritance:
2 changes: 1 addition & 1 deletion scripts/train_fixtures.py
@@ -57,7 +57,7 @@ def train_fixture_gpu(config_prefix: str) -> None:
'encoder_decoder/simple_seq2seq',
'semantic_parsing/nlvr_coverage_semantic_parser',
'semantic_parsing/nlvr_direct_semantic_parser',
-'semantic_parsing/wikitables_semantic_parser',
+'semantic_parsing/wikitables',
'srl',
]
for model in models:
Binary file modified tests/fixtures/semantic_parsing/wikitables/serialization/best.th
@@ -50,11 +50,11 @@
<r,r> -> [<<#1,#2>,<#2,#1>>, <r,r>]
<r,r> -> fb:row.row.next
<r,r> -> fb:type.object.type
-@START@ -> c
-@START@ -> d
-@START@ -> n
-@START@ -> p
-@START@ -> r
+@start@ -> c
+@start@ -> d
+@start@ -> n
+@start@ -> p
+@start@ -> r
c -> [<#1,#1>, c]
c -> [<#1,<#1,#1>>, c, c]
c -> [<d,c>, d]
44 changes: 44 additions & 0 deletions tests/service/predictors/nlvr_parser_test.py
@@ -0,0 +1,44 @@
import os

from allennlp.common.testing import AllenNlpTestCase
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor


class TestNlvrParserPredictor(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.inputs = {'worlds': [[[{'y_loc': 80, 'type': 'triangle', 'color': '#0099ff', 'x_loc': 80,
                                     'size': 20}],
                                   [{'y_loc': 80, 'type': 'square', 'color': 'Yellow', 'x_loc': 13,
                                     'size': 20}],
                                   [{'y_loc': 67, 'type': 'triangle', 'color': 'Yellow', 'x_loc': 35,
                                     'size': 10}]],
                                  [[{'y_loc': 8, 'type': 'square', 'color': 'Yellow', 'x_loc': 57,
                                     'size': 30}],
                                   [{'y_loc': 43, 'type': 'square', 'color': '#0099ff', 'x_loc': 70,
                                     'size': 30}],
                                   [{'y_loc': 59, 'type': 'square', 'color': 'Yellow', 'x_loc': 47,
                                     'size': 10}]]],
                       'identifier': 'fake_id',
                       'sentence': 'Each grey box contains atleast one yellow object touching the edge'}

    def test_predictor_with_coverage_parser(self):
        archive_dir = 'tests/fixtures/semantic_parsing/nlvr_coverage_semantic_parser/serialization'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'nlvr-parser')

        result = predictor.predict_json(self.inputs)
        assert 'logical_form' in result
        assert 'denotations' in result
        assert len(result['denotations']) == 2  # Because there are two worlds in the input.

    def test_predictor_with_direct_parser(self):
        archive_dir = 'tests/fixtures/semantic_parsing/nlvr_direct_semantic_parser/serialization'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'nlvr-parser')

        result = predictor.predict_json(self.inputs)
        assert 'logical_form' in result
        assert 'denotations' in result
        assert len(result['denotations']) == 2  # Because there are two worlds in the input.
57 changes: 57 additions & 0 deletions tests/service/predictors/wikitables_parser_test.py
@@ -0,0 +1,57 @@
# pylint: disable=no-self-use,invalid-name
import os
import shutil
from unittest import TestCase

from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor


class TestWikiTablesParserPredictor(TestCase):
    def setUp(self):
        super().setUp()
        self.should_remove_data_dir = not os.path.exists('data')

    def tearDown(self):
        super().tearDown()
        if self.should_remove_data_dir and os.path.exists('data'):
            shutil.rmtree('data')

    def test_uses_named_inputs(self):
        inputs = {
                "question": "names",
                "table": "name\tdate\nmatt\t2017\npradeep\t2018"
        }

        archive_dir = 'tests/fixtures/semantic_parsing/wikitables/serialization/'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'wikitables-parser')

        result = predictor.predict_json(inputs)

        action_sequence = result.get("best_action_sequence")
        if action_sequence:
            # We don't currently disallow endless loops in the decoder, and an untrained seq2seq
            # model will easily get itself into a loop.  An endless loop isn't a finished logical
            # form, so decoding doesn't return any finished states, which means no actions.  So,
            # sadly, we don't have a great test here.  This is just testing that the predictor
            # runs, basically.
            assert len(action_sequence) > 1
            assert all([isinstance(action, str) for action in action_sequence])

        logical_form = result.get("logical_form")
        assert logical_form is not None

    def test_answer_present(self):
        inputs = {
                "question": "Who is 18 years old?",
                "table": "Name\tAge\nShallan\t16\nKaladin\t18"
        }

        archive_dir = 'tests/fixtures/semantic_parsing/wikitables/serialization/'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'wikitables-parser')

        result = predictor.predict_json(inputs)
        answer = result.get("answer")
        assert answer is not None
18 changes: 15 additions & 3 deletions training_config/wikitables_parser.json
@@ -1,11 +1,23 @@
{
  "dataset_reader": {
-   "type": "wikitables-preprocessed",
+   "type": "wikitables",
    "lazy": false,
+   "tables_directory": "/wikitables/",
+   "dpd_output_directory": "/wikitables/dpd_output/",
    "question_token_indexers": {
      "tokens": {"type": "single_id"}
    }
  },
+ "validation_dataset_reader": {
+   "type": "wikitables",
+   "lazy": false,
+   "tables_directory": "/wikitables/",
+   "dpd_output_directory": "/wikitables/dpd_output/",
+   "question_token_indexers": {
+     "tokens": {"type": "single_id"}
+   },
+   "keep_if_no_dpd": true
+ },
  "vocabulary": {
    "min_count": {"tokens": 3}
  },
@@ -34,9 +46,9 @@
"averaged": true
},
"decoder_beam_search": {
"beam_size": 3
"beam_size": 10
},
"max_decoding_steps": 20,
"max_decoding_steps": 40,
"attention_function": {
"type": "bilinear",
"tensor_1_dim": 200,
