Commit
Adding WikiTables and NLVR predictors (#1118)
* Adding WikiTables and NLVR predictors
* Fixed test paths and fixtures

1 parent 7627a09 · commit 960f913
Showing 11 changed files with 276 additions and 11 deletions.
@@ -0,0 +1,16 @@
from typing import Tuple
from overrides import overrides

from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.service.predictors.predictor import Predictor


@Predictor.register('nlvr-parser')
class NlvrParserPredictor(Predictor):
    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
        sentence = json_dict['sentence']
        worlds = json_dict['worlds']
        instance = self._dataset_reader.text_to_instance(sentence, worlds)
        return instance, {}
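
The file above registers the predictor under the name `nlvr-parser`. As a rough usage sketch (not part of this commit), loading it from a trained archive follows the same pattern as the tests added further down; the archive path is a placeholder and the tiny single-object world is only illustrative:

from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor

# Placeholder path: any archive of a trained NLVR semantic parser would do.
archive = load_archive('/path/to/nlvr/model.tar.gz')
predictor = Predictor.from_archive(archive, 'nlvr-parser')

# 'worlds' is a list of structured world representations; each world is a list
# of boxes, and each box is a list of objects with location, type, color, size.
inputs = {
    'sentence': 'There is a yellow object touching the edge.',
    'worlds': [[[{'y_loc': 80, 'type': 'triangle', 'color': 'Yellow', 'x_loc': 80, 'size': 20}]]],
}
result = predictor.predict_json(inputs)
print(result)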
@@ -0,0 +1,120 @@
import os
from subprocess import run
from typing import Tuple

from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.service.predictors.predictor import Predictor

# TODO(mattg): We should merge how this works with how the `WikiTablesAccuracy` metric works, maybe
# just removing the need for adding this stuff at all, because the parser already runs the java
# process.  This requires modifying the scala `wikitables-executor` code to also return the
# denotation when running it as a server, and updating the model to parse the output correctly, but
# that shouldn't be too hard.
DEFAULT_EXECUTOR_JAR = "https://s3-us-west-2.amazonaws.com/allennlp/misc/wikitables-executor-0.1.0.jar"
ABBREVIATIONS_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/misc/wikitables-abbreviations.tsv"
GROW_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/misc/wikitables-grow.grammar"
SEMPRE_DIR = 'data/'

@Predictor.register('wikitables-parser')
class WikiTablesParserPredictor(Predictor):
    """
    Wrapper for the
    :class:`~allennlp.models.encoder_decoders.wikitables_semantic_parser.WikiTablesSemanticParser`
    model.
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)
        # Load auxiliary SEMPRE files during startup for faster logical form execution.
        os.makedirs(SEMPRE_DIR, exist_ok=True)
        abbreviations_path = os.path.join(SEMPRE_DIR, 'abbreviations.tsv')
        if not os.path.exists(abbreviations_path):
            run(f'wget {ABBREVIATIONS_FILE}', shell=True)
            run(f'mv wikitables-abbreviations.tsv {abbreviations_path}', shell=True)

        grammar_path = os.path.join(SEMPRE_DIR, 'grow.grammar')
        if not os.path.exists(grammar_path):
            run(f'wget {GROW_FILE}', shell=True)
            run(f'mv wikitables-grow.grammar {grammar_path}', shell=True)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
        """
        Expects JSON that looks like ``{"question": "...", "table": "..."}``.
        """
        question_text = json_dict["question"]
        table_text = json_dict["table"]
        cells = []
        for row_index, line in enumerate(table_text.split('\n')):
            line = line.rstrip('\n')
            if row_index == 0:
                columns = line.split('\t')
            else:
                cells.append(line.split('\t'))
        # pylint: disable=protected-access
        tokenized_question = self._dataset_reader._tokenizer.tokenize(question_text.lower())  # type: ignore
        # pylint: enable=protected-access
        table_json = {"question": tokenized_question, "columns": columns, "cells": cells}
        instance = self._dataset_reader.text_to_instance(question_text,  # type: ignore
                                                         table_json,
                                                         tokenized_question=tokenized_question)
        extra_info = {'question_tokens': tokenized_question}
        return instance, extra_info

    @overrides
    def predict_json(self, inputs: JsonDict, cuda_device: int = -1) -> JsonDict:
        instance, return_dict = self._json_to_instance(inputs)
        outputs = self._model.forward_on_instance(instance, cuda_device)
        outputs['answer'] = self._execute_logical_form_on_table(outputs['logical_form'],
                                                                inputs['table'])

        return_dict.update(outputs)
        return sanitize(return_dict)

    @staticmethod
    def _execute_logical_form_on_table(logical_form, table):
        """
        Writes the logical form and table out to files that the SEMPRE executor jar reads, runs
        the jar, and returns the resulting denotation.
        """
        logical_form_filename = os.path.join(SEMPRE_DIR, 'logical_forms.txt')
        with open(logical_form_filename, 'w') as temp_file:
            temp_file.write(logical_form + '\n')

        table_dir = os.path.join(SEMPRE_DIR, 'tsv/')
        os.makedirs(table_dir, exist_ok=True)
        # The .tsv file extension is important here, since the table string parameter is in TSV
        # format.  If this file were named with a .csv suffix, SEMPRE would interpret it as comma
        # separated and return the wrong denotation.
        table_filename = 'context.tsv'
        with open(os.path.join(table_dir, table_filename), 'w', encoding='utf-8') as temp_file:
            temp_file.write(table)

        # The id, target, and utterance are ignored; we just need to get the
        # table filename into SEMPRE's lisp format.
        test_record = ('(example (id nt-0) (utterance none) (context (graph tables.TableKnowledgeGraph %s))'
                       '(targetValue (list (description "6"))))' % table_filename)
        test_data_filename = os.path.join(SEMPRE_DIR, 'data.examples')
        with open(test_data_filename, 'w') as temp_file:
            temp_file.write(test_record)

        # TODO(matt): The jar that we have isn't optimal for this use case - we're using a
        # script designed for computing accuracy, and just pulling out a piece of it.  Writing
        # a new entry point to the jar that's tailored for this use would be cleaner.
        command = ' '.join(['java',
                            '-jar',
                            cached_path(DEFAULT_EXECUTOR_JAR),
                            test_data_filename,
                            logical_form_filename,
                            table_dir])
        run(command, shell=True)

        denotations_file = os.path.join(SEMPRE_DIR, 'logical_forms_denotations.tsv')
        with open(denotations_file) as temp_file:
            line = temp_file.readline().split('\t')
        return line[1] if len(line) > 1 else line[0]
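
For reference, a minimal sketch of how this predictor might be called (not part of this commit). The `table` value is a single tab-separated string whose first line names the columns, which is exactly what `_json_to_instance` above splits into `columns` and `cells`; the archive path is a placeholder, and running this for real also requires Java and network access so the SEMPRE jar and auxiliary files can be fetched:

from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor

# Placeholder path to a trained WikiTables semantic parser archive.
archive = load_archive('/path/to/wikitables/model.tar.gz')
predictor = Predictor.from_archive(archive, 'wikitables-parser')

# Rows are newline separated and cells tab separated; the first row is the header.
# This is also why the predictor writes the table out as a .tsv file for SEMPRE.
inputs = {
    'question': 'Who is 18 years old?',
    'table': 'Name\tAge\nShallan\t16\nKaladin\t18',
}
result = predictor.predict_json(inputs)
print(result['logical_form'])  # Best logical form found by the parser.
print(result['answer'])        # Denotation returned by the SEMPRE executor jar.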
Binary file modified (+0 Bytes): tests/fixtures/semantic_parsing/wikitables/serialization/best.th
Binary file modified (-4 Bytes): tests/fixtures/semantic_parsing/wikitables/serialization/model.tar.gz
@@ -0,0 +1,44 @@
import os

from allennlp.common.testing import AllenNlpTestCase
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor


class TestNlvrParserPredictor(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.inputs = {'worlds': [[[{'y_loc': 80, 'type': 'triangle', 'color': '#0099ff', 'x_loc': 80,
                                     'size': 20}],
                                   [{'y_loc': 80, 'type': 'square', 'color': 'Yellow', 'x_loc': 13,
                                     'size': 20}],
                                   [{'y_loc': 67, 'type': 'triangle', 'color': 'Yellow', 'x_loc': 35,
                                     'size': 10}]],
                                  [[{'y_loc': 8, 'type': 'square', 'color': 'Yellow', 'x_loc': 57,
                                     'size': 30}],
                                   [{'y_loc': 43, 'type': 'square', 'color': '#0099ff', 'x_loc': 70,
                                     'size': 30}],
                                   [{'y_loc': 59, 'type': 'square', 'color': 'Yellow', 'x_loc': 47,
                                     'size': 10}]]],
                       'identifier': 'fake_id',
                       'sentence': 'Each grey box contains atleast one yellow object touching the edge'}

    def test_predictor_with_coverage_parser(self):
        archive_dir = 'tests/fixtures/semantic_parsing/nlvr_coverage_semantic_parser/serialization'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'nlvr-parser')

        result = predictor.predict_json(self.inputs)
        assert 'logical_form' in result
        assert 'denotations' in result
        assert len(result['denotations']) == 2  # Because there are two worlds in the input.

    def test_predictor_with_direct_parser(self):
        archive_dir = 'tests/fixtures/semantic_parsing/nlvr_direct_semantic_parser/serialization'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'nlvr-parser')

        result = predictor.predict_json(self.inputs)
        assert 'logical_form' in result
        assert 'denotations' in result
        assert len(result['denotations']) == 2  # Because there are two worlds in the input.
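
The assertions above only check which keys come back. Purely as an illustration (not from this commit; the values below are invented, and only the keys plus the two-entry length of `denotations` follow from the tests), the returned dictionary has roughly this shape:

# Hypothetical output for the two-world input above; the values are made up.
result = {
    'logical_form': '(object_exists (yellow (touch_wall (all_objects))))',
    'denotations': [['True'], ['False']],  # One entry per world in the input.
    # ...plus whatever other fields the model adds to its output dictionary.
}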
@@ -0,0 +1,57 @@
# pylint: disable=no-self-use,invalid-name
import os
import shutil
from unittest import TestCase

from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor


class TestWikiTablesParserPredictor(TestCase):
    def setUp(self):
        super().setUp()
        self.should_remove_data_dir = not os.path.exists('data')

    def tearDown(self):
        super().tearDown()
        if self.should_remove_data_dir and os.path.exists('data'):
            shutil.rmtree('data')

    def test_uses_named_inputs(self):
        inputs = {
                "question": "names",
                "table": "name\tdate\nmatt\t2017\npradeep\t2018"
        }

        archive_dir = 'tests/fixtures/semantic_parsing/wikitables/serialization/'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'wikitables-parser')

        result = predictor.predict_json(inputs)

        action_sequence = result.get("best_action_sequence")
        if action_sequence:
            # We don't currently disallow endless loops in the decoder, and an untrained seq2seq
            # model will easily get itself into a loop.  An endless loop isn't a finished logical
            # form, so decoding doesn't return any finished states, which means no actions.  So,
            # sadly, we don't have a great test here.  This is just testing that the predictor
            # runs, basically.
            assert len(action_sequence) > 1
            assert all([isinstance(action, str) for action in action_sequence])

        logical_form = result.get("logical_form")
        assert logical_form is not None

    def test_answer_present(self):
        inputs = {
                "question": "Who is 18 years old?",
                "table": "Name\tAge\nShallan\t16\nKaladin\t18"
        }

        archive_dir = 'tests/fixtures/semantic_parsing/wikitables/serialization/'
        archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
        predictor = Predictor.from_archive(archive, 'wikitables-parser')

        result = predictor.predict_json(inputs)
        answer = result.get("answer")
        assert answer is not None