Skip to content
Permalink
Browse files

remove PREDICTOR_OVERRIDES and MODEL_OVERRIDES (#999)

  • Loading branch information...
joelgrus committed Mar 19, 2018
1 parent 1d68f88 commit e28415ca6a93d76006a8544b7fc71f2179b2ea42
@@ -18,22 +18,13 @@


def main(prog: str = None,
model_overrides: Dict[str, DemoModel] = {},
predictor_overrides: Dict[str, str] = {},
subcommand_overrides: Dict[str, Subcommand] = {}) -> None:
"""
The :mod:`~allennlp.run` command only knows about the registered classes in the ``allennlp``
codebase. In particular, once you start creating your own ``Model`` s and so forth, it won't
work for them, unless you use the ``--include-package`` flag available for most commands.
work for them, unless you use the ``--include-package`` flag.
"""
# pylint: disable=dangerous-default-value

# TODO(mattg): document and/or remove the `predictor_overrides` and `model_overrides` commands.
# The `--predictor` option for the `predict` command largely removes the need for
# `predictor_overrides`, and I think the simple server largely removes the need for
# `model_overrides`, and maybe the whole `serve` command as a public API (we only need that
# path for demo.allennlp.org, and it's not likely anyone else would host that particular demo).

parser = argparse.ArgumentParser(description="Run AllenNLP", usage='%(prog)s', prog=prog)

subparsers = parser.add_subparsers(title='Commands', metavar='')
@@ -42,8 +33,8 @@ def main(prog: str = None,
# Default commands
"train": Train(),
"evaluate": Evaluate(),
"predict": Predict(predictor_overrides),
"serve": Serve(model_overrides),
"predict": Predict(),
"serve": Serve(),
"make-vocab": MakeVocab(),
"elmo": Elmo(),
"fine-tune": FineTune(),
@@ -42,7 +42,7 @@
import argparse
from contextlib import ExitStack
import sys
from typing import Optional, IO, Dict
from typing import Optional, IO

from allennlp.commands.subcommand import Subcommand
from allennlp.common.checks import ConfigurationError
@@ -61,10 +61,6 @@


class Predict(Subcommand):
def __init__(self, predictor_overrides: Dict[str, str] = {}) -> None:
    """Build the model-type -> predictor-name mapping used by ``predict``.

    ``predictor_overrides`` entries take precedence over the registered
    ``DEFAULT_PREDICTORS``.
    """
    # pylint: disable=dangerous-default-value
    mapping = dict(DEFAULT_PREDICTORS)
    mapping.update(predictor_overrides)
    self.predictors = mapping

def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
# pylint: disable=protected-access
description = '''Run the specified model against a JSON-lines input file.'''
@@ -96,11 +92,11 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
type=str,
help='optionally specify a specific predictor to use')

subparser.set_defaults(func=_predict(self.predictors))
subparser.set_defaults(func=_predict)

return subparser

def _get_predictor(args: argparse.Namespace, predictors: Dict[str, str]) -> Predictor:
def _get_predictor(args: argparse.Namespace) -> Predictor:
archive = load_archive(args.archive_file,
weights_file=args.weights_file,
cuda_device=args.cuda_device,
@@ -112,9 +108,10 @@ def _get_predictor(args: argparse.Namespace, predictors: Dict[str, str]) -> Pred

# Otherwise, use the mapping
model_type = archive.config.get("model").get("type")
if model_type not in predictors:
raise ConfigurationError("no known predictor for model type {}".format(model_type))
return Predictor.from_archive(archive, predictors[model_type])
if model_type not in DEFAULT_PREDICTORS:
raise ConfigurationError(f"No known predictor for model type {model_type}.\n"
f"Specify one with the --predictor flag.")
return Predictor.from_archive(archive, DEFAULT_PREDICTORS[model_type])

def _run(predictor: Predictor,
input_file: IO,
@@ -156,27 +153,24 @@ def _run_predictor(batch_data):
_run_predictor(batch_json_data)


def _predict(predictors: Dict[str, str]):
def predict_inner(args: argparse.Namespace) -> None:
predictor = _get_predictor(args, predictors)
output_file = None

if args.silent and not args.output_file:
print("--silent specified without --output-file.")
print("Exiting early because no output will be created.")
sys.exit(0)
def _predict(args: argparse.Namespace) -> None:
predictor = _get_predictor(args)
output_file = None

# ExitStack allows us to conditionally context-manage `output_file`, which may or may not exist
with ExitStack() as stack:
input_file = stack.enter_context(args.input_file) # type: ignore
if args.output_file:
output_file = stack.enter_context(args.output_file) # type: ignore
if args.silent and not args.output_file:
print("--silent specified without --output-file.")
print("Exiting early because no output will be created.")
sys.exit(0)

_run(predictor,
input_file,
output_file,
args.batch_size,
not args.silent,
args.cuda_device)
# ExitStack allows us to conditionally context-manage `output_file`, which may or may not exist
with ExitStack() as stack:
input_file = stack.enter_context(args.input_file) # type: ignore
if args.output_file:
output_file = stack.enter_context(args.output_file) # type: ignore

return predict_inner
_run(predictor,
input_file,
output_file,
args.batch_size,
not args.silent,
args.cuda_device)
@@ -17,7 +17,6 @@
"""

import argparse
from typing import Dict

from allennlp.commands.subcommand import Subcommand
from allennlp.service import server_flask as server
@@ -58,10 +57,6 @@


class Serve(Subcommand):
def __init__(self, model_overrides: Dict[str, DemoModel] = {}) -> None:
    """Build the name -> DemoModel mapping served by the demo server.

    ``model_overrides`` entries take precedence over ``DEFAULT_MODELS``.
    """
    # pylint: disable=dangerous-default-value
    merged = dict(DEFAULT_MODELS)
    merged.update(model_overrides)
    self.trained_models = merged

def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
# pylint: disable=protected-access
description = '''Run the web service, which provides an HTTP API as well as a web demo.'''
@@ -70,12 +65,9 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar

subparser.add_argument('--port', type=int, default=8000)

subparser.set_defaults(func=_serve(self.trained_models))
subparser.set_defaults(func=_serve)

return subparser

def _serve(trained_models: Dict[str, DemoModel]):
    """Factory returning the argparse callback that starts the demo server.

    The closure captures ``trained_models`` so the callback matches the
    single-argument signature argparse's ``set_defaults(func=...)`` expects.
    NOTE(review): tests appear to assert on the name ``serve_inner``, so the
    inner function's name must not change.
    """
    def serve_inner(args: argparse.Namespace) -> None:
        # args.port is supplied by the `serve` subparser's --port option.
        server.run(args.port, trained_models)

    return serve_inner
def _serve(args: argparse.Namespace):
    """argparse callback: launch the demo web server with the default models."""
    port = args.port
    server.run(port, DEFAULT_MODELS)
@@ -15,7 +15,7 @@
from allennlp.common.util import JsonDict
from allennlp.common.testing import AllenNlpTestCase
from allennlp.commands import main
from allennlp.commands.predict import Predict, DEFAULT_PREDICTORS
from allennlp.commands.predict import Predict
from allennlp.service.predictors import Predictor, BidafPredictor


@@ -24,7 +24,7 @@ class TestPredict(AllenNlpTestCase):
def test_add_predict_subparser(self):
parser = argparse.ArgumentParser(description="Testing")
subparsers = parser.add_subparsers(title='Commands', metavar='')
Predict(DEFAULT_PREDICTORS).add_subparser('predict', subparsers)
Predict().add_subparser('predict', subparsers)

kebab_args = ["predict", # command
"/path/to/archive", # archive
@@ -36,7 +36,7 @@ def test_add_predict_subparser(self):

args = parser.parse_args(kebab_args)

assert args.func.__name__ == 'predict_inner'
assert args.func.__name__ == '_predict'
assert args.archive_file == "/path/to/archive"
assert args.output_file.name == "/dev/null"
assert args.batch_size == 10
@@ -122,49 +122,6 @@ def test_fails_without_required_args(self):

assert cm.exception.code == 2 # argparse code for incorrect usage

def test_can_override_predictors(self):
    # End-to-end test: `main(predictor_overrides=...)` should replace the
    # default predictor for an existing model type ('bidaf' here) with a
    # freshly registered one, and the override should be visible in the
    # predict command's JSON output.

    @Predictor.register('bidaf-override') # pylint: disable=unused-variable
    class Bidaf2Predictor(BidafPredictor):
        """same as bidaf predictor but with an extra field"""
        def predict_json(self, inputs: JsonDict, cuda_device: int = -1) -> JsonDict:
            result = super().predict_json(inputs)
            # Marker field proving that the overriding predictor actually ran.
            result["overridden"] = True
            return result

    # Stage two JSON-lines inputs in a temp directory for the predict command.
    tempdir = tempfile.mkdtemp()
    infile = os.path.join(tempdir, "inputs.txt")
    outfile = os.path.join(tempdir, "outputs.txt")

    with open(infile, 'w') as f:
        f.write("""{"passage": "the seahawks won the super bowl in 2016", """
                """ "question": "when did the seahawks win the super bowl?"}\n""")
        f.write("""{"passage": "the mariners won the super bowl in 2037", """
                """ "question": "when did the mariners win the super bowl?"}\n""")

    # Simulate the CLI invocation; main() reads sys.argv via argparse.
    sys.argv = ["run.py",  # executable
                "predict",  # command
                "tests/fixtures/bidaf/serialization/model.tar.gz",
                infile,  # input_file
                "--output-file", outfile,
                "--silent"]

    main(predictor_overrides={'bidaf': 'bidaf-override'})
    assert os.path.exists(outfile)

    with open(outfile, 'r') as f:
        results = [json.loads(line) for line in f]

    assert len(results) == 2
    # Overridden predictor should output extra field
    for result in results:
        assert set(result.keys()) == {"span_start_logits", "span_end_logits",
                                      "passage_question_attention", "question_tokens",
                                      "passage_tokens", "span_start_probs", "span_end_probs",
                                      "best_span", "best_span_str", "overridden"}

    # Clean up the temp fixtures so repeated runs don't accumulate files.
    shutil.rmtree(tempdir)

def test_can_specify_predictor(self):

@Predictor.register('bidaf-explicit') # pylint: disable=unused-variable
@@ -298,15 +255,15 @@ def dump_line(self, outputs: JsonDict) -> str:
writer.writerow(["the mariners won the super bowl in 2037",
"when did the mariners win the super bowl?"])


sys.argv = ["run.py", # executable
"predict", # command
"tests/fixtures/bidaf/serialization/model.tar.gz",
infile, # input_file
"--output-file", outfile,
"--predictor", 'bidaf-csv',
"--silent"]

main(predictor_overrides={'bidaf': 'bidaf-csv'})
main()
assert os.path.exists(outfile)

with open(outfile, 'r') as f:
@@ -2,20 +2,20 @@
import argparse
from unittest import TestCase

from allennlp.commands.serve import Serve, DEFAULT_MODELS
from allennlp.commands.serve import Serve


class TestServe(TestCase):

def test_add_serve(self):
parser = argparse.ArgumentParser(description="Testing")
subparsers = parser.add_subparsers(title='Commands', metavar='')
Serve(DEFAULT_MODELS).add_subparser('serve', subparsers)
Serve().add_subparser('serve', subparsers)

raw_args = ["serve",
"--port", "8000"]

args = parser.parse_args(raw_args)

assert args.func.__name__ == 'serve_inner'
assert args.func.__name__ == '_serve'
assert args.port == 8000

0 comments on commit e28415c

Please sign in to comment.
You can’t perform that action at this time.