This repository has been archived by the owner on Aug 22, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
test_nlg.py
86 lines (60 loc) · 2.36 KB
/
test_nlg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import uuid
import jsonschema
import pytest
from flask import Flask, request, jsonify
from pytest_localserver.http import WSGIServer
from rasa_core import utils
from rasa_core.nlg.callback import (
nlg_request_format_spec,
CallbackNaturalLanguageGenerator)
from rasa_core.utils import EndpointConfig
from rasa_core.agent import Agent
from tests.conftest import DEFAULT_ENDPOINTS_FILE
def nlg_app(base_url="/"):
    """Build a small Flask application that plays the role of an NLG server.

    The application registers a single POST route at ``base_url``. Each
    request body is checked with ``jsonschema.validate`` against
    ``nlg_request_format_spec()`` before a canned JSON answer is returned.

    :param base_url: URL rule under which the generator route is mounted.
    :return: the configured Flask application.
    """
    app = Flask(__name__)

    @app.route(base_url, methods=['POST'])
    def generate():
        """Validate the incoming NLG request against the spec and
        reply with a fixed text for the requested template."""
        payload = request.json
        jsonschema.validate(payload, nlg_request_format_spec())

        # Only the greeting template has a dedicated answer; any other
        # template gets the fallback text.
        is_greeting = payload.get("template") == "utter_greet"
        text = "Hey there!" if is_greeting else "Sorry, didn't get that."
        return jsonify({"text": text})

    return app
@pytest.fixture(scope="module")
def http_nlg(request):
    """Serve the test NLG application over HTTP and return its URL.

    The fixture is module scoped. The server's ``stop`` method is
    registered as a finalizer on the pytest ``request`` object.
    """
    server = WSGIServer(application=nlg_app())
    server.start()
    # The finalizer is registered after the server has been started.
    request.addfinalizer(server.stop)
    nlg_url = server.url
    return nlg_url
def test_nlg(http_nlg, default_agent_path):
    """An agent using the HTTP NLG endpoint answers ``/greet`` with the
    text produced by the test server, addressed to the sender."""
    conversation_id = str(uuid.uuid1())
    endpoint = EndpointConfig.from_dict({"url": http_nlg})
    agent = Agent.load(default_agent_path, None, generator=endpoint)

    replies = agent.handle_message("/greet", sender_id=conversation_id)

    expected_reply = {"text": "Hey there!", "recipient_id": conversation_id}
    assert len(replies) == 1
    assert replies[0] == expected_reply
def test_nlg_endpoint_config_loading():
    """The ``nlg`` section of the default endpoints file is read into the
    expected ``EndpointConfig``."""
    cfg = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")
    expected = EndpointConfig.from_dict({"url": "http://localhost:5055/nlg"})
    assert cfg == expected
def test_nlg_schema_validation():
    """A response containing only ``text`` is accepted by the validator."""
    assert CallbackNaturalLanguageGenerator.validate_response(
        {"text": "Hey there!"})
def test_nlg_schema_validation_empty_buttons():
    """A response with an empty ``buttons`` list is accepted by the
    validator."""
    assert CallbackNaturalLanguageGenerator.validate_response(
        {"text": "Hey there!", "buttons": []})
def test_nlg_schema_validation_empty_image():
    """A response whose ``image`` is ``None`` is accepted by the
    validator."""
    assert CallbackNaturalLanguageGenerator.validate_response(
        {"text": "Hey there!", "image": None})