diff --git a/allennlp/commands/predict.py b/allennlp/commands/predict.py index f998aa6f2dd..d6409f743ed 100644 --- a/allennlp/commands/predict.py +++ b/allennlp/commands/predict.py @@ -11,9 +11,10 @@ def add_subparser(parser: argparse._SubParsersAction) -> argparse.ArgumentParser subparser = parser.add_parser( 'predict', description=description, help='Use a trained model to make predictions.') subparser.add_argument('archive_file', type=str, help='the archived model to make predictions with') - subparser.add_argument('input_file', metavar='input-file', type=str, help='path to input file') - subparser.add_argument('--output-file', type=str, help='path to output file') - subparser.add_argument('--print', action='store_true', help='print results to string') + subparser.add_argument('input_file', metavar='input-file', type=argparse.FileType('r'), + help='path to input file') + subparser.add_argument('--output-file', type=argparse.FileType('w'), help='path to output file') + subparser.add_argument('--print', action='store_true', help='print results to stdout') subparser.set_defaults(func=predict) @@ -37,13 +38,12 @@ def run(predictor: Predictor, input_file: IO, output_file: Optional[IO], print_t def predict(args: argparse.Namespace) -> None: predictor = get_predictor(args) + output_file = None # ExitStack allows us to conditionally context-manage `output_file`, which may or may not exist with ExitStack() as stack: - input_file = stack.enter_context(open(args.input_file, 'r')) # type: ignore + input_file = stack.enter_context(args.input_file) # type: ignore if args.output_file: - output_file = stack.enter_context(open(args.output_file, 'w')) # type: ignore - else: - output_file = None + output_file = stack.enter_context(args.output_file) # type: ignore run(predictor, input_file, output_file, args.print) diff --git a/tests/commands/predict_test.py b/tests/commands/predict_test.py index 4083a99b359..366fbeba79a 100644 --- a/tests/commands/predict_test.py +++ b/tests/commands/predict_test.py @@ -18,16 +18,14 @@ def test_add_predict_subparser(self): raw_args = ["predict", # command "/path/to/archive", # archive - "/path/to/data", # input_file - "--output-file", "outfile", + "/dev/null", # input_file + "--output-file", "/dev/null", "--print"] args = parser.parse_args(raw_args) assert args.func == predict assert args.archive_file == "/path/to/archive" - assert args.input_file == "/path/to/data" - assert args.output_file == "outfile" assert args.print def test_works_with_known_model(self):