Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fixes to ERM decoding script (#3041)
Browse files Browse the repository at this point in the history
* fixes to erm decoding script

* changed variable name
  • Loading branch information
pdasigi committed Jul 9, 2019
1 parent 64d16ac commit 0663e0b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions scripts/wikitables/generate_data_from_erm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))))

from allennlp.data.dataset_readers import WikiTablesDatasetReader
from allennlp.data.dataset_readers.semantic_parsing.wikitables import util
from allennlp.models.archival import load_archive


Expand All @@ -18,27 +19,26 @@ def make_data(input_examples_file: str,
output_dir: str,
num_logical_forms: int) -> None:
reader = WikiTablesDatasetReader(tables_directory=tables_directory,
keep_if_no_dpd=True,
keep_if_no_logical_forms=True,
output_agendas=True)
dataset = reader.read(input_examples_file)
input_lines = []
with open(input_examples_file) as input_file:
input_lines = input_file.readlines()
# Note: Double { for escaping {.
new_tables_config = f"{{model: {{tables_directory: {tables_directory}}}}}"
archive = load_archive(archived_model_file,
overrides=new_tables_config)
archive = load_archive(archived_model_file)
model = archive.model
model.training = False
model._decoder_trainer._max_num_decoded_sequences = 100
for instance, example_line in zip(dataset, input_lines):
outputs = model.forward_on_instance(instance)
parsed_info = reader._parse_example_line(example_line)
world = instance.fields['world'].metadata
parsed_info = util.parse_example_line(example_line)
example_id = parsed_info["id"]
target_list = parsed_info["target_values"]
logical_forms = outputs["logical_form"]
correct_logical_forms = []
for logical_form in logical_forms:
if model._denotation_accuracy.evaluate_logical_form(logical_form, example_line):
if world.evaluate_logical_form(logical_form, target_list):
correct_logical_forms.append(logical_form)
if len(correct_logical_forms) >= num_logical_forms:
break
Expand Down

0 comments on commit 0663e0b

Please sign in to comment.