Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Domain API endpoint #679

Merged
merged 11 commits into from Jul 5, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 31 additions & 16 deletions rasa_core/domain.py
Expand Up @@ -28,7 +28,7 @@
ensure_action_name_uniqueness)
from rasa_core.slots import Slot
from rasa_core.trackers import DialogueStateTracker, SlotSet
from rasa_core.utils import read_yaml_file
from rasa_core.utils import read_file, read_yaml_string

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -390,33 +390,37 @@ def templates(self):


class TemplateDomain(Domain):

@classmethod
def load(cls, filename, action_factory=None):
if not os.path.isfile(filename):
raise Exception(
"Failed to load domain specification from '{}'. "
"File not found!".format(os.path.abspath(filename)))
return cls.load_from_yaml(read_file(filename), action_factory=action_factory)

cls.validate_domain_yaml(filename)
data = read_yaml_file(filename)
@classmethod
def load_from_yaml(cls, yaml, action_factory=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can load use load_from_yaml once it read the yaml from the file as well? want to avoid the duplication

cls.validate_domain_yaml(yaml, string_input=True)
data = read_yaml_string(yaml)
utter_templates = cls.collect_templates(data.get("templates", {}))
if not action_factory:
action_factory = data.get("action_factory", None)
slots = cls.collect_slots(data.get("slots", {}))
additional_arguments = data.get("config", {})
return TemplateDomain(
data.get("intents", []),
data.get("entities", []),
slots,
utter_templates,
data.get("actions", []),
data.get("action_names", []),
action_factory,
**additional_arguments
return cls(
data.get("intents", []),
data.get("entities", []),
slots,
utter_templates,
data.get("actions", []),
data.get("action_names", []),
action_factory,
**additional_arguments
)

@classmethod
def validate_domain_yaml(cls, filename):
def validate_domain_yaml(cls, input, string_input=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input needs to be renamed (shadows build in function)

do we need string_input? I think it is always True

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it wasnt in tests but i changed tests :)

"""Validate domain yaml."""
from pykwalify.core import Core

Expand All @@ -425,7 +429,11 @@ def validate_domain_yaml(cls, filename):

schema_file = pkg_resources.resource_filename(__name__,
"schemas/domain.yml")
c = Core(source_data=utils.read_yaml_file(filename),
if not string_input:
source_data = utils.read_yaml_file(input)
else:
source_data = utils.read_yaml_string(input)
c = Core(source_data=source_data,
schema_files=[schema_file])
try:
c.validate(raise_exception=True)
Expand All @@ -434,7 +442,7 @@ def validate_domain_yaml(cls, filename):
"Make sure the file is correct, to do so"
"take a look at the errors logged during "
"validation previous to this exception. "
"".format(os.path.abspath(filename)))
"".format(os.path.abspath(input)))

@staticmethod
def collect_slots(slot_dict):
Expand Down Expand Up @@ -496,7 +504,7 @@ def instantiate_actions(factory_name, action_classes, action_names,
def _slot_definitions(self):
return {slot.name: slot.persistence_info() for slot in self.slots}

def persist(self, filename):
def _make_domain_data(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def as_dict(self) (for consistency across classes)

additional_config = {
"store_entities_as_slots": self.store_entities_as_slots}
action_names = self.action_names[len(Domain.DEFAULT_ACTIONS):]
Expand All @@ -511,9 +519,16 @@ def persist(self, filename):
"action_names": action_names, # names in stories
"action_factory": self._factory_name
}
return domain_data

def persist(self, filename):
domain_data = self._make_domain_data()
utils.dump_obj_as_yaml_to_file(filename, domain_data)

def to_yaml(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as_yaml for consistency

domain_data = self._make_domain_data()
return utils.dump_obj_as_yaml_to_string(domain_data)

@utils.lazyproperty
def templates(self):
return self._templates
Expand Down
10 changes: 10 additions & 0 deletions rasa_core/server.py
Expand Up @@ -305,6 +305,16 @@ def update_tracker(sender_id):
agent().tracker_store.save(tracker)
return jsonify(tracker.current_state(should_include_events=True))

@app.route("/domain",
methods=['GET'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@ensure_loaded_agent(agent)
def get_domain_yaml():
"""Get current domain in yaml format."""
domain_yaml = agent().domain.to_yaml()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we sending the domain as yaml instead of json?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean its the way that the domain is stored in the model folder so why not?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a REST api, so usually the responses are in json format. I mean I don't have a strong opinion on this, but if yaml is used, content negotiation should be used: e.g.

accepts = request.headers.get("Accept", default="application/json")
if accepts.endswith("json"):
  domain = agent().domain.as_dict()
  return jsonify(domain)
elif accepts.endswith("yml"):
  domain = agent().domain.as_yaml()
  return Response(domain_yaml, status=200, content_type="application/x-yml")
else:
  return Response("""Invalid accept header. Domain can be provided as json ("Accept: application/json") or yml ("Accept: application/x-yml"). Make sure you've set the appropriate Accept header.""", status=406)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok cool!

return jsonify(domain_yaml)

@app.route("/conversations/<sender_id>/parse",
methods=['GET', 'POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
Expand Down
33 changes: 24 additions & 9 deletions rasa_core/utils.py
Expand Up @@ -18,6 +18,11 @@
from numpy import all, array
from typing import Text, Any, List, Optional, Tuple, Dict, Set

if six.PY2:
from StringIO import StringIO
else:
from io import StringIO

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -304,30 +309,30 @@ def construct_yaml_str(self, node):

def read_yaml_file(filename):
"""Read contents of `filename` interpreting them as yaml."""
return read_yaml_string(read_file(filename))


def read_yaml_string(string):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we then rewrite read_yaml_file to just be read_yaml_string(read_file(filename)) ?

if six.PY2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here an example of using six

import yaml

fix_yaml_loader()
return yaml.load(read_file(filename, "utf-8"))
return yaml.load(string)
else:
import ruamel.yaml

yaml_parser = ruamel.yaml.YAML(typ="safe")
yaml_parser.allow_unicode = True
yaml_parser.unicode_supplementary = True

return yaml_parser.load(read_file(filename))

return yaml_parser.load(string)

def dump_obj_as_yaml_to_file(filename, obj):
"""Writes data (python dict) to the filename in yaml repr."""

def _dump_yaml(obj, output):
if six.PY2:
import yaml

with io.open(filename, 'w', encoding="utf-8") as yaml_file:
yaml.safe_dump(obj, yaml_file,
yaml.safe_dump(obj, output,
default_flow_style=False,
allow_unicode=True)
else:
Expand All @@ -338,9 +343,19 @@ def dump_obj_as_yaml_to_file(filename, obj):
yaml_writer.default_flow_style = False
yaml_writer.allow_unicode = True

with io.open(filename, 'w', encoding="utf-8") as yaml_file:
yaml_writer.dump(obj, yaml_file)
yaml_writer.dump(obj, output)

def dump_obj_as_yaml_to_file(filename, obj):
"""Writes data (python dict) to the filename in yaml repr."""
with io.open(filename, 'w', encoding="utf-8") as output:
_dump_yaml(obj, output)


def dump_obj_as_yaml_to_string(obj):
"""Writes data (python dict) to a yaml string."""
str_io = StringIO()
_dump_yaml(obj, str_io)
return str_io.getvalue()

def read_file(filename, encoding="utf-8"):
"""Read text from a file."""
Expand Down
19 changes: 19 additions & 0 deletions tests/test_domain.py
Expand Up @@ -166,3 +166,22 @@ def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
- utter_greet""")
with pytest.raises(ValueError):
TemplateDomain.load(domain_path)


def test_domain_to_yaml():
test_yaml = """action_factory: null
action_names:
- utter_greet
actions:
- utter_greet
config:
store_entities_as_slots: true
entities: []
intents: []
slots: {}
templates:
utter_greet:
- text: hey there!"""
domain = TemplateDomain.load_from_yaml(test_yaml)
assert test_yaml.strip() == domain.to_yaml().strip()
domain = TemplateDomain.load_from_yaml(domain.to_yaml())