Skip to content

Commit

Permalink
Merge pull request #3769 from mmalhotra/model_server_fix
Browse files Browse the repository at this point in the history
Create EndPointConfig object from model_server in /model
  • Loading branch information
tabergma committed Jul 9, 2019
2 parents 7097a41 + 6ee051c commit a55530a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Fixed
-----
- all temporal model files are now deleted after stopping the Rasa server
- ``rasa shell nlu`` now outputs unicode characters instead of ``\uxxxx`` codes
- fixed PUT /model with model_server by deserializing the model_server to
EndpointConfig.
- ``x in AnySlotDict`` is now ``True`` for any ``x``, which fixes empty slot warnings in
interactive learning
- ``rasa train`` now also includes NLU files in other formats than the Rasa format
Expand All @@ -51,7 +53,6 @@ Changed
-------
- removed leading underscore from name of '_create_initial_project' function.


Fixed
-----
- fixed bug where facebook quick replies were not rendering
Expand Down
12 changes: 11 additions & 1 deletion rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,17 @@ async def load_model(request: Request):
model_path = request.json.get("model_file", None)
model_server = request.json.get("model_server", None)
remote_storage = request.json.get("remote_storage", None)

if model_server:
try:
model_server = EndpointConfig.from_dict(model_server)
except TypeError as e:
logger.debug(traceback.format_exc())
raise ErrorResponse(
400,
"BadRequest",
"Supplied 'model_server' is not valid. Error: {}".format(e),
{"parameter": "model_server", "in": "body"},
)
app.agent = await _load_agent(
model_path, model_server, remote_storage, endpoints
)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import tempfile
import uuid
from aioresponses import aioresponses

import pytest
from freezegun import freeze_time
Expand All @@ -12,6 +13,7 @@
from rasa.core import events, utils
from rasa.core.events import Event, UserUttered, SlotSet, BotUttered
from rasa.model import unpack_model
from rasa.utils.endpoints import EndpointConfig
from tests.nlu.utilities import ResponseTest


Expand Down Expand Up @@ -663,6 +665,42 @@ def test_load_model(rasa_app, trained_core_model):
assert old_fingerprint != response.json["fingerprint"]


def test_load_model_from_model_server(rasa_app, trained_core_model):
_, response = rasa_app.get("/status")

assert response.status == 200
assert "fingerprint" in response.json

old_fingerprint = response.json["fingerprint"]

endpoint = EndpointConfig("https://example.com/model/trained_core_model")
with open(trained_core_model, "rb") as f:
with aioresponses(passthrough=["http://127.0.0.1"]) as mocked:
headers = {}
fs = os.fstat(f.fileno())
headers["Content-Length"] = str(fs[6])
mocked.get(
"https://example.com/model/trained_core_model",
content_type="application/x-tar",
body=f.read(),
)
data = {"model_server": {"url": endpoint.url}}
_, response = rasa_app.put("/model", json=data)

assert response.status == 204

_, response = rasa_app.get("/status")

assert response.status == 200
assert "fingerprint" in response.json

assert old_fingerprint != response.json["fingerprint"]

import rasa.core.jobs

rasa.core.jobs.__scheduler = None


def test_load_model_invalid_request_body(rasa_app):
_, response = rasa_app.put("/model")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_concat_url(base, subpath, expected_result):
def test_warning_for_base_paths_with_trailing_slash(caplog):
test_path = "base/"

with caplog.at_level(logging.DEBUG):
with caplog.at_level(logging.DEBUG, logger="rasa.utils.endpoints"):
assert concat_url(test_path, None) == test_path

assert len(caplog.records) == 1

0 comments on commit a55530a

Please sign in to comment.