Skip to content
Permalink
Browse files

Number predictions. (#2709)

  • Loading branch information...
schmmd committed May 7, 2019
1 parent aeafe1c commit 48f46ea9b7390658c200a9f74e6dcaca93feb4ab
Showing with 7 additions and 3 deletions.
  1. +7 −3 allennlp/commands/predict.py
@@ -141,11 +141,12 @@ def _predict_instances(self, batch_data: List[Instance]) -> Iterator[str]:
yield self._predictor.dump_line(output)

def _maybe_print_to_console_and_file(self,
                                     index: int,
                                     prediction: str,
                                     model_input: str = None) -> None:
    """Emit a single prediction to the console and/or the output file.

    Parameters
    ----------
    index : int
        Zero-based position of this prediction in the overall run; printed
        alongside the input so console output can be matched to inputs.
    prediction : str
        The serialized prediction line produced by the predictor.
    model_input : str, optional
        String form of the model input; printed only when provided.
    """
    if self._print_to_console:
        if model_input is not None:
            # Number the input so multiple predictions are distinguishable.
            print(f"input {index}: ", model_input)
        print("prediction: ", prediction)
    if self._output_file is not None:
        self._output_file.write(prediction)
@@ -171,14 +172,17 @@ def _get_instance_data(self) -> Iterator[Instance]:

def run(self) -> None:
    """Run prediction over every input, numbering predictions sequentially.

    When a dataset reader is available, inputs are read as ``Instance``s;
    otherwise raw JSON lines are used. Each prediction is printed/written
    with a running, zero-based index, and the output file (if any) is
    closed when finished.
    """
    has_reader = self._dataset_reader is not None
    index = 0  # running count across all batches, passed to the printer
    if has_reader:
        for batch in lazy_groups_of(self._get_instance_data(), self._batch_size):
            for model_input_instance, result in zip(batch, self._predict_instances(batch)):
                self._maybe_print_to_console_and_file(index, result, str(model_input_instance))
                index = index + 1
    else:
        for batch_json in lazy_groups_of(self._get_json_data(), self._batch_size):
            for model_input_json, result in zip(batch_json, self._predict_json(batch_json)):
                self._maybe_print_to_console_and_file(index, result, json.dumps(model_input_json))
                index = index + 1

    if self._output_file is not None:
        self._output_file.close()

0 comments on commit 48f46ea

Please sign in to comment.
You can’t perform that action at this time.