Domain API endpoint #679
Domain API endpoint #679
Changes from 8 commits
2f650be
73ca75c
df88903
3cc7047
114af62
4d3ae82
fea258e
8bb98b4
949356a
7742b3e
66265cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
ensure_action_name_uniqueness) | ||
from rasa_core.slots import Slot | ||
from rasa_core.trackers import DialogueStateTracker, SlotSet | ||
from rasa_core.utils import read_yaml_file | ||
from rasa_core.utils import read_file, read_yaml_string | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -390,33 +390,37 @@ def templates(self): | |
|
||
|
||
class TemplateDomain(Domain): | ||
|
||
@classmethod | ||
def load(cls, filename, action_factory=None): | ||
if not os.path.isfile(filename): | ||
raise Exception( | ||
"Failed to load domain specification from '{}'. " | ||
"File not found!".format(os.path.abspath(filename))) | ||
return cls.load_from_yaml(read_file(filename), action_factory=action_factory) | ||
|
||
cls.validate_domain_yaml(filename) | ||
data = read_yaml_file(filename) | ||
@classmethod | ||
def load_from_yaml(cls, yaml, action_factory=None): | ||
cls.validate_domain_yaml(yaml, string_input=True) | ||
data = read_yaml_string(yaml) | ||
utter_templates = cls.collect_templates(data.get("templates", {})) | ||
if not action_factory: | ||
action_factory = data.get("action_factory", None) | ||
slots = cls.collect_slots(data.get("slots", {})) | ||
additional_arguments = data.get("config", {}) | ||
return TemplateDomain( | ||
data.get("intents", []), | ||
data.get("entities", []), | ||
slots, | ||
utter_templates, | ||
data.get("actions", []), | ||
data.get("action_names", []), | ||
action_factory, | ||
**additional_arguments | ||
return cls( | ||
data.get("intents", []), | ||
data.get("entities", []), | ||
slots, | ||
utter_templates, | ||
data.get("actions", []), | ||
data.get("action_names", []), | ||
action_factory, | ||
**additional_arguments | ||
) | ||
|
||
@classmethod | ||
def validate_domain_yaml(cls, filename): | ||
def validate_domain_yaml(cls, input, string_input=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
do we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it wasnt in tests but i changed tests :) |
||
"""Validate domain yaml.""" | ||
from pykwalify.core import Core | ||
|
||
|
@@ -425,7 +429,11 @@ def validate_domain_yaml(cls, filename): | |
|
||
schema_file = pkg_resources.resource_filename(__name__, | ||
"schemas/domain.yml") | ||
c = Core(source_data=utils.read_yaml_file(filename), | ||
if not string_input: | ||
source_data = utils.read_yaml_file(input) | ||
else: | ||
source_data = utils.read_yaml_string(input) | ||
c = Core(source_data=source_data, | ||
schema_files=[schema_file]) | ||
try: | ||
c.validate(raise_exception=True) | ||
|
@@ -434,7 +442,7 @@ def validate_domain_yaml(cls, filename): | |
"Make sure the file is correct, to do so" | ||
"take a look at the errors logged during " | ||
"validation previous to this exception. " | ||
"".format(os.path.abspath(filename))) | ||
"".format(os.path.abspath(input))) | ||
|
||
@staticmethod | ||
def collect_slots(slot_dict): | ||
|
@@ -496,7 +504,7 @@ def instantiate_actions(factory_name, action_classes, action_names, | |
def _slot_definitions(self): | ||
return {slot.name: slot.persistence_info() for slot in self.slots} | ||
|
||
def persist(self, filename): | ||
def _make_domain_data(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
additional_config = { | ||
"store_entities_as_slots": self.store_entities_as_slots} | ||
action_names = self.action_names[len(Domain.DEFAULT_ACTIONS):] | ||
|
@@ -511,9 +519,16 @@ def persist(self, filename): | |
"action_names": action_names, # names in stories | ||
"action_factory": self._factory_name | ||
} | ||
return domain_data | ||
|
||
def persist(self, filename): | ||
domain_data = self._make_domain_data() | ||
utils.dump_obj_as_yaml_to_file(filename, domain_data) | ||
|
||
def to_yaml(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
domain_data = self._make_domain_data() | ||
return utils.dump_obj_as_yaml_to_string(domain_data) | ||
|
||
@utils.lazyproperty | ||
def templates(self): | ||
return self._templates | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,6 +305,16 @@ def update_tracker(sender_id): | |
agent().tracker_store.save(tracker) | ||
return jsonify(tracker.current_state(should_include_events=True)) | ||
|
||
@app.route("/domain", | ||
methods=['GET']) | ||
@cross_origin(origins=cors_origins) | ||
@requires_auth(auth_token) | ||
@ensure_loaded_agent(agent) | ||
def get_domain_yaml(): | ||
"""Get current domain in yaml format.""" | ||
domain_yaml = agent().domain.to_yaml() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are we sending the domain as yaml instead of json? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean its the way that the domain is stored in the model folder so why not? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a REST api, so usually the responses are in json format. I mean I don't have a strong opinion on this, but if yaml is used, content negotiation should be used: e.g. accepts = request.headers.get("Accept", default="application/json")
if accepts.endswith("json"):
domain = agent().domain.as_dict()
return jsonify(domain)
elif accepts.endswith("yml"):
domain = agent().domain.as_yaml()
return Response(domain_yaml, status=200, content_type="application/x-yml")
else:
return Response("""Invalid accept header. Domain can be provided as json ("Accept: application/json") or yml ("Accept: application/x-yml"). Make sure you've set the appropriate Accept header.""", status=406) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok cool! |
||
return jsonify(domain_yaml) | ||
|
||
@app.route("/conversations/<sender_id>/parse", | ||
methods=['GET', 'POST', 'OPTIONS']) | ||
@cross_origin(origins=cors_origins) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,11 @@ | |
from numpy import all, array | ||
from typing import Text, Any, List, Optional, Tuple, Dict, Set | ||
|
||
if six.PY2: | ||
from StringIO import StringIO | ||
else: | ||
from io import StringIO | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
|
@@ -304,30 +309,30 @@ def construct_yaml_str(self, node): | |
|
||
def read_yaml_file(filename): | ||
"""Read contents of `filename` interpreting them as yaml.""" | ||
return read_yaml_string(read_file(filename)) | ||
|
||
|
||
def read_yaml_string(string): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we then rewrite |
||
if six.PY2: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here an example of using |
||
import yaml | ||
|
||
fix_yaml_loader() | ||
return yaml.load(read_file(filename, "utf-8")) | ||
return yaml.load(string) | ||
else: | ||
import ruamel.yaml | ||
|
||
yaml_parser = ruamel.yaml.YAML(typ="safe") | ||
yaml_parser.allow_unicode = True | ||
yaml_parser.unicode_supplementary = True | ||
|
||
return yaml_parser.load(read_file(filename)) | ||
|
||
return yaml_parser.load(string) | ||
|
||
def dump_obj_as_yaml_to_file(filename, obj): | ||
"""Writes data (python dict) to the filename in yaml repr.""" | ||
|
||
def _dump_yaml(obj, output): | ||
if six.PY2: | ||
import yaml | ||
|
||
with io.open(filename, 'w', encoding="utf-8") as yaml_file: | ||
yaml.safe_dump(obj, yaml_file, | ||
yaml.safe_dump(obj, output, | ||
default_flow_style=False, | ||
allow_unicode=True) | ||
else: | ||
|
@@ -338,9 +343,19 @@ def dump_obj_as_yaml_to_file(filename, obj): | |
yaml_writer.default_flow_style = False | ||
yaml_writer.allow_unicode = True | ||
|
||
with io.open(filename, 'w', encoding="utf-8") as yaml_file: | ||
yaml_writer.dump(obj, yaml_file) | ||
yaml_writer.dump(obj, output) | ||
|
||
def dump_obj_as_yaml_to_file(filename, obj): | ||
"""Writes data (python dict) to the filename in yaml repr.""" | ||
with io.open(filename, 'w', encoding="utf-8") as output: | ||
_dump_yaml(obj, output) | ||
|
||
|
||
def dump_obj_as_yaml_to_string(obj): | ||
"""Writes data (python dict) to a yaml string.""" | ||
str_io = StringIO() | ||
_dump_yaml(obj, str_io) | ||
return str_io.getvalue() | ||
|
||
def read_file(filename, encoding="utf-8"): | ||
"""Read text from a file.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can
load
useload_from_yaml
once it read the yaml from the file as well? want to avoid the duplication