Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge 7742b3e into 1667307
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyFaulkner committed Jul 2, 2018
2 parents 1667307 + 7742b3e commit 18e2b5a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 27 deletions.
44 changes: 28 additions & 16 deletions rasa_core/domain.py
Expand Up @@ -28,7 +28,7 @@
ensure_action_name_uniqueness)
from rasa_core.slots import Slot
from rasa_core.trackers import DialogueStateTracker, SlotSet
from rasa_core.utils import read_yaml_file
from rasa_core.utils import read_file, read_yaml_string

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -390,33 +390,37 @@ def templates(self):


class TemplateDomain(Domain):

@classmethod
def load(cls, filename, action_factory=None):
if not os.path.isfile(filename):
raise Exception(
"Failed to load domain specification from '{}'. "
"File not found!".format(os.path.abspath(filename)))
return cls.load_from_yaml(read_file(filename), action_factory=action_factory)

cls.validate_domain_yaml(filename)
data = read_yaml_file(filename)
@classmethod
def load_from_yaml(cls, yaml, action_factory=None):
cls.validate_domain_yaml(yaml)
data = read_yaml_string(yaml)
utter_templates = cls.collect_templates(data.get("templates", {}))
if not action_factory:
action_factory = data.get("action_factory", None)
slots = cls.collect_slots(data.get("slots", {}))
additional_arguments = data.get("config", {})
return TemplateDomain(
data.get("intents", []),
data.get("entities", []),
slots,
utter_templates,
data.get("actions", []),
data.get("action_names", []),
action_factory,
**additional_arguments
return cls(
data.get("intents", []),
data.get("entities", []),
slots,
utter_templates,
data.get("actions", []),
data.get("action_names", []),
action_factory,
**additional_arguments
)

@classmethod
def validate_domain_yaml(cls, filename):
def validate_domain_yaml(cls, input):
"""Validate domain yaml."""
from pykwalify.core import Core

Expand All @@ -425,7 +429,8 @@ def validate_domain_yaml(cls, filename):

schema_file = pkg_resources.resource_filename(__name__,
"schemas/domain.yml")
c = Core(source_data=utils.read_yaml_file(filename),
source_data = utils.read_yaml_string(input)
c = Core(source_data=source_data,
schema_files=[schema_file])
try:
c.validate(raise_exception=True)
Expand All @@ -434,7 +439,7 @@ def validate_domain_yaml(cls, filename):
"Make sure the file is correct, to do so"
"take a look at the errors logged during "
"validation previous to this exception. "
"".format(os.path.abspath(filename)))
"".format(os.path.abspath(input)))

@staticmethod
def collect_slots(slot_dict):
Expand Down Expand Up @@ -496,7 +501,7 @@ def instantiate_actions(factory_name, action_classes, action_names,
def _slot_definitions(self):
return {slot.name: slot.persistence_info() for slot in self.slots}

def persist(self, filename):
def as_dict(self):
additional_config = {
"store_entities_as_slots": self.store_entities_as_slots}
action_names = self.action_names[len(Domain.DEFAULT_ACTIONS):]
Expand All @@ -511,9 +516,16 @@ def persist(self, filename):
"action_names": action_names, # names in stories
"action_factory": self._factory_name
}
return domain_data

def persist(self, filename):
domain_data = self.as_dict()
utils.dump_obj_as_yaml_to_file(filename, domain_data)

def as_yaml(self):
domain_data = self.as_dict()
return utils.dump_obj_as_yaml_to_string(domain_data)

@utils.lazyproperty
def templates(self):
return self._templates
Expand Down
20 changes: 20 additions & 0 deletions rasa_core/server.py
Expand Up @@ -305,6 +305,26 @@ def update_tracker(sender_id):
agent().tracker_store.save(tracker)
return jsonify(tracker.current_state(should_include_events=True))

@app.route("/domain",
methods=['GET'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@ensure_loaded_agent(agent)
def get_domain():
"""Get current domain in yaml format."""
accepts = request.headers.get("Accept", default="application/json")
if accepts.endswith("json"):
domain = agent().domain.as_dict()
return jsonify(domain)
elif accepts.endswith("yml"):
domain_yaml = agent().domain.as_yaml()
return Response(domain_yaml, status=200, content_type="application/x-yml")
else:
return Response(
"""Invalid accept header. Domain can be provided as json ("Accept: application/json") or yml ("Accept: application/x-yml"). Make sure you've set the appropriate Accept header.""",
status=406)


@app.route("/conversations/<sender_id>/parse",
methods=['GET', 'POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
Expand Down
33 changes: 24 additions & 9 deletions rasa_core/utils.py
Expand Up @@ -18,6 +18,11 @@
from numpy import all, array
from typing import Text, Any, List, Optional, Tuple, Dict, Set

if six.PY2:
from StringIO import StringIO
else:
from io import StringIO

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -304,30 +309,30 @@ def construct_yaml_str(self, node):

def read_yaml_file(filename):
"""Read contents of `filename` interpreting them as yaml."""
return read_yaml_string(read_file(filename))


def read_yaml_string(string):
if six.PY2:
import yaml

fix_yaml_loader()
return yaml.load(read_file(filename, "utf-8"))
return yaml.load(string)
else:
import ruamel.yaml

yaml_parser = ruamel.yaml.YAML(typ="safe")
yaml_parser.allow_unicode = True
yaml_parser.unicode_supplementary = True

return yaml_parser.load(read_file(filename))

return yaml_parser.load(string)

def dump_obj_as_yaml_to_file(filename, obj):
"""Writes data (python dict) to the filename in yaml repr."""

def _dump_yaml(obj, output):
if six.PY2:
import yaml

with io.open(filename, 'w', encoding="utf-8") as yaml_file:
yaml.safe_dump(obj, yaml_file,
yaml.safe_dump(obj, output,
default_flow_style=False,
allow_unicode=True)
else:
Expand All @@ -338,9 +343,19 @@ def dump_obj_as_yaml_to_file(filename, obj):
yaml_writer.default_flow_style = False
yaml_writer.allow_unicode = True

with io.open(filename, 'w', encoding="utf-8") as yaml_file:
yaml_writer.dump(obj, yaml_file)
yaml_writer.dump(obj, output)

def dump_obj_as_yaml_to_file(filename, obj):
"""Writes data (python dict) to the filename in yaml repr."""
with io.open(filename, 'w', encoding="utf-8") as output:
_dump_yaml(obj, output)


def dump_obj_as_yaml_to_string(obj):
"""Writes data (python dict) to a yaml string."""
str_io = StringIO()
_dump_yaml(obj, str_io)
return str_io.getvalue()

def read_file(filename, encoding="utf-8"):
"""Read text from a file."""
Expand Down
24 changes: 22 additions & 2 deletions tests/test_domain.py
Expand Up @@ -9,6 +9,7 @@
from rasa_core import training
from rasa_core.domain import TemplateDomain
from rasa_core.featurizers import MaxHistoryTrackerFeaturizer
from rasa_core.utils import read_file
from tests import utilities
from tests.conftest import DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE

Expand Down Expand Up @@ -133,8 +134,8 @@ def test_utter_templates():

def test_restaurant_domain_is_valid():
# should raise no exception
TemplateDomain.validate_domain_yaml(
'examples/restaurantbot/restaurant_domain.yml')
TemplateDomain.validate_domain_yaml(read_file(
'examples/restaurantbot/restaurant_domain.yml'))


def test_custom_slot_type(tmpdir):
Expand Down Expand Up @@ -166,3 +167,22 @@ def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
- utter_greet""")
with pytest.raises(ValueError):
TemplateDomain.load(domain_path)


def test_domain_to_yaml():
test_yaml = """action_factory: null
action_names:
- utter_greet
actions:
- utter_greet
config:
store_entities_as_slots: true
entities: []
intents: []
slots: {}
templates:
utter_greet:
- text: hey there!"""
domain = TemplateDomain.load_from_yaml(test_yaml)
assert test_yaml.strip() == domain.as_yaml().strip()
domain = TemplateDomain.load_from_yaml(domain.as_yaml())

0 comments on commit 18e2b5a

Please sign in to comment.