Skip to content

Commit

Permalink
Merge 91b1e4d into 316b3c4
Browse files Browse the repository at this point in the history
  • Loading branch information
ricwo committed Sep 4, 2019
2 parents 316b3c4 + 91b1e4d commit cfa6526
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -85,6 +85,7 @@ Fixed
- ``rasa test nlu`` with a folder of configuration files
- ``MappingPolicy`` standard featurizer is set to ``None``
- Removed ``text`` parameter from send_attachment function in slack.py to avoid duplication of text output to slackbot
- server ``/status`` endpoint reports status when an NLU-only model is loaded

Removed
-------
Expand Down
34 changes: 16 additions & 18 deletions rasa/core/agent.py
Expand Up @@ -396,21 +396,17 @@ def load(
remote_storage=remote_storage,
)

def is_ready(self, allow_nlu_only: bool = False):
def is_core_ready(self):
"""Check if all necessary components and policies are ready to use the agent.
"""
return self.is_ready() and self.policy_ensemble

def is_ready(self):
"""Check if all necessary components are instantiated to use agent.
Args:
allow_nlu_only: If `True`, consider the agent ready event if no policy
ensemble is present.
Policies might not be available, if this is an NLU only agent."""

"""
return all(
[
self.tracker_store,
self.interpreter,
self.policy_ensemble or allow_nlu_only,
]
)
return self.tracker_store and self.interpreter

async def parse_message_using_nlu_interpreter(
self, message_data: Text, tracker: DialogueStateTracker = None
Expand Down Expand Up @@ -466,7 +462,7 @@ def noop(_):
logger.info("Ignoring message as there is no agent to handle it.")
return None

if not self.is_ready(allow_nlu_only=True):
if not self.is_ready():
return noop(message)

processor = self.create_processor(message_preprocessor)
Expand Down Expand Up @@ -567,7 +563,7 @@ def continue_training(
self, trackers: List[DialogueStateTracker], **kwargs: Any
) -> None:

if not self.is_ready():
if not self.is_core_ready():
raise AgentNotReady("Can't continue training without a policy ensemble.")

self.policy_ensemble.continue_training(trackers, self.domain, **kwargs)
Expand Down Expand Up @@ -651,7 +647,7 @@ def train(
**kwargs: additional arguments passed to the underlying ML
trainer (e.g. keras parameters)
"""
if not self.is_ready():
if not self.is_core_ready():
raise AgentNotReady("Can't train without a policy ensemble.")

# deprecation tests
Expand Down Expand Up @@ -760,7 +756,7 @@ def _clear_model_directory(model_path: Text) -> None:
def persist(self, model_path: Text, dump_flattened_stories: bool = False) -> None:
"""Persists this agent into a directory for later loading and usage."""

if not self.is_ready():
if not self.is_core_ready():
raise AgentNotReady("Can't persist without a policy ensemble.")

if not model_path.endswith("core"):
Expand Down Expand Up @@ -810,7 +806,7 @@ def create_processor(
"""Instantiates a processor based on the set state of the agent."""
# Checks that the interpreter and tracker store are set and
# creates a processor
if not self.is_ready(allow_nlu_only=True):
if not self.is_ready():
raise AgentNotReady(
"Agent needs to be prepared before usage. You need to set an "
"interpreter and a tracker store."
Expand All @@ -835,7 +831,9 @@ def _create_domain(domain: Union[Domain, Text]) -> Domain:
return domain
elif isinstance(domain, Domain):
return domain
elif domain is not None:
elif domain is None:
return Domain.empty()
else:
raise ValueError(
"Invalid param `domain`. Expected a path to a domain "
"specification or a domain instance. But got "
Expand Down
7 changes: 7 additions & 0 deletions rasa/core/processor.py
Expand Up @@ -111,6 +111,13 @@ def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]:
)
return None

if not self.policy_ensemble or not self.domain:
# save tracker state to continue conversation from this state
logger.warning(
"No policy ensemble or domain set. Skipping action prediction "
)
return None

probabilities, policy = self._get_next_action_probabilities(tracker)
# save tracker state to continue conversation from this state
self._save_tracker(tracker)
Expand Down
38 changes: 16 additions & 22 deletions rasa/server.py
Expand Up @@ -66,13 +66,22 @@ def _docs(sub_url: Text) -> Text:
return DOCS_BASE_URL + sub_url


def ensure_loaded_agent(app: Sanic):
"""Wraps a request handler ensuring there is a loaded and usable agent."""
def ensure_loaded_agent(app: Sanic, require_core_is_ready=False):
"""Wraps a request handler ensuring there is a loaded and usable agent.
Require the agent to have a loaded Core model if `require_core_is_ready` is
`True`.
"""

def decorator(f):
@wraps(f)
def decorated(*args, **kwargs):
if not app.agent or not app.agent.is_ready():
# noinspection PyUnresolvedReferences
if not app.agent or not (
app.agent.is_core_ready()
if require_core_is_ready
else app.agent.is_ready()
):
raise ErrorResponse(
409,
"Conflict",
Expand Down Expand Up @@ -385,14 +394,6 @@ async def status(request: Request):
@ensure_loaded_agent(app)
async def retrieve_tracker(request: Request, conversation_id: Text):
"""Get a dump of a conversation's tracker including its events."""
if not app.agent.tracker_store:
raise ErrorResponse(
409,
"Conflict",
"No tracker store available. Make sure to "
"configure a tracker store when starting "
"the server.",
)

verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
until_time = rasa.utils.endpoints.float_arg(request, "until")
Expand Down Expand Up @@ -497,14 +498,6 @@ async def replace_events(request: Request, conversation_id: Text):
@ensure_loaded_agent(app)
async def retrieve_story(request: Request, conversation_id: Text):
"""Get an end-to-end story corresponding to this conversation."""
if not app.agent.tracker_store:
raise ErrorResponse(
409,
"Conflict",
"No tracker store available. Make sure to "
"configure a tracker store when starting "
"the server.",
)

# retrieve tracker and set to requested state
tracker = get_tracker(app.agent, conversation_id)
Expand Down Expand Up @@ -729,7 +722,7 @@ def validate_request(rjs):

@app.post("/model/test/stories")
@requires_auth(app, auth_token)
@ensure_loaded_agent(app)
@ensure_loaded_agent(app, require_core_is_ready=True)
async def evaluate_stories(request: Request):
"""Evaluate stories against the currently loaded model."""
validate_request_body(
Expand Down Expand Up @@ -795,7 +788,7 @@ async def evaluate_intents(request: Request):

@app.post("/model/predict")
@requires_auth(app, auth_token)
@ensure_loaded_agent(app)
@ensure_loaded_agent(app, require_core_is_ready=True)
async def tracker_predict(request: Request):
""" Given a list of events, predicts the next action"""
validate_request_body(
Expand Down Expand Up @@ -848,6 +841,7 @@ async def tracker_predict(request: Request):

@app.post("/model/parse")
@requires_auth(app, auth_token)
@ensure_loaded_agent(app)
async def parse(request: Request):
validate_request_body(
request,
Expand Down Expand Up @@ -915,7 +909,7 @@ async def unload_model(request: Request):

app.agent = Agent(lock_store=app.agent.lock_store)

logger.debug("Successfully unload model '{}'.".format(model_file))
logger.debug("Successfully unloaded model '{}'.".format(model_file))
return response.json(None, status=204)

@app.get("/domain")
Expand Down
20 changes: 10 additions & 10 deletions tests/test_server.py
Expand Up @@ -105,13 +105,21 @@ def test_status(rasa_app: SanicTestClient):
assert "model_file" in response.json


def test_status_nlu_only(rasa_app_nlu: SanicTestClient):
_, response = rasa_app_nlu.get("/status")
assert response.status == 200
assert "fingerprint" in response.json
assert "model_file" in response.json


def test_status_secured(rasa_secured_app: SanicTestClient):
_, response = rasa_secured_app.get("/status")
assert response.status == 401


def test_status_not_ready_agent(rasa_app_nlu: SanicTestClient):
_, response = rasa_app_nlu.get("/status")
def test_status_not_ready_agent(rasa_app: SanicTestClient):
rasa_app.app.agent = None
_, response = rasa_app.get("/status")
assert response.status == 409


Expand Down Expand Up @@ -471,11 +479,6 @@ def test_predict(rasa_app: SanicTestClient):
assert "policy" in content


def test_retrieve_tracker_not_ready_agent(rasa_app_nlu: SanicTestClient):
_, response = rasa_app_nlu.get("/conversations/test/tracker")
assert response.status == 409


@freeze_time("2018-01-01")
def test_requesting_non_existent_tracker(rasa_app: SanicTestClient):
_, response = rasa_app.get("/conversations/madeupid/tracker")
Expand Down Expand Up @@ -660,9 +663,6 @@ def test_unload_model_error(rasa_app: SanicTestClient):
_, response = rasa_app.delete("/model")
assert response.status == 204

_, response = rasa_app.get("/status")
assert response.status == 409


def test_get_domain(rasa_app: SanicTestClient):
_, response = rasa_app.get("/domain", headers={"accept": "application/json"})
Expand Down

0 comments on commit cfa6526

Please sign in to comment.