diff --git a/rasa_core/domain.py b/rasa_core/domain.py index 6a39be4d9bb..2d3d6b28d88 100644 --- a/rasa_core/domain.py +++ b/rasa_core/domain.py @@ -28,7 +28,7 @@ ensure_action_name_uniqueness) from rasa_core.slots import Slot from rasa_core.trackers import DialogueStateTracker, SlotSet -from rasa_core.utils import read_yaml_file +from rasa_core.utils import read_file, read_yaml_string logger = logging.getLogger(__name__) @@ -390,33 +390,37 @@ def templates(self): class TemplateDomain(Domain): + @classmethod def load(cls, filename, action_factory=None): if not os.path.isfile(filename): raise Exception( "Failed to load domain specification from '{}'. " "File not found!".format(os.path.abspath(filename))) + return cls.load_from_yaml(read_file(filename), action_factory=action_factory) - cls.validate_domain_yaml(filename) - data = read_yaml_file(filename) + @classmethod + def load_from_yaml(cls, yaml, action_factory=None): + cls.validate_domain_yaml(yaml) + data = read_yaml_string(yaml) utter_templates = cls.collect_templates(data.get("templates", {})) if not action_factory: action_factory = data.get("action_factory", None) slots = cls.collect_slots(data.get("slots", {})) additional_arguments = data.get("config", {}) - return TemplateDomain( - data.get("intents", []), - data.get("entities", []), - slots, - utter_templates, - data.get("actions", []), - data.get("action_names", []), - action_factory, - **additional_arguments + return cls( + data.get("intents", []), + data.get("entities", []), + slots, + utter_templates, + data.get("actions", []), + data.get("action_names", []), + action_factory, + **additional_arguments ) @classmethod - def validate_domain_yaml(cls, filename): + def validate_domain_yaml(cls, input): """Validate domain yaml.""" from pykwalify.core import Core @@ -425,7 +429,8 @@ def validate_domain_yaml(cls, filename): schema_file = pkg_resources.resource_filename(__name__, "schemas/domain.yml") - c = Core(source_data=utils.read_yaml_file(filename), + source_data = utils.read_yaml_string(input) + c = Core(source_data=source_data, schema_files=[schema_file]) try: c.validate(raise_exception=True) @@ -434,7 +439,7 @@ def validate_domain_yaml(cls, filename): "Make sure the file is correct, to do so" "take a look at the errors logged during " "validation previous to this exception. " - "".format(os.path.abspath(filename))) + "".format(os.path.abspath(input))) @staticmethod def collect_slots(slot_dict): @@ -496,7 +501,7 @@ def instantiate_actions(factory_name, action_classes, action_names, def _slot_definitions(self): return {slot.name: slot.persistence_info() for slot in self.slots} - def persist(self, filename): + def as_dict(self): additional_config = { "store_entities_as_slots": self.store_entities_as_slots} action_names = self.action_names[len(Domain.DEFAULT_ACTIONS):] @@ -511,9 +516,16 @@ def persist(self, filename): "action_names": action_names, # names in stories "action_factory": self._factory_name } + return domain_data + def persist(self, filename): + domain_data = self.as_dict() utils.dump_obj_as_yaml_to_file(filename, domain_data) + def as_yaml(self): + domain_data = self.as_dict() + return utils.dump_obj_as_yaml_to_string(domain_data) + @utils.lazyproperty def templates(self): return self._templates diff --git a/rasa_core/server.py b/rasa_core/server.py index 940ebe0ef87..829d29c4174 100644 --- a/rasa_core/server.py +++ b/rasa_core/server.py @@ -305,6 +305,26 @@ def update_tracker(sender_id): agent().tracker_store.save(tracker) return jsonify(tracker.current_state(should_include_events=True)) + @app.route("/domain", + methods=['GET']) + @cross_origin(origins=cors_origins) + @requires_auth(auth_token) + @ensure_loaded_agent(agent) + def get_domain(): + """Get current domain in yaml format.""" + accepts = request.headers.get("Accept", default="application/json") + if accepts.endswith("json"): + domain = agent().domain.as_dict() + return jsonify(domain) + elif accepts.endswith("yml"): + domain_yaml = agent().domain.as_yaml() + return Response(domain_yaml, status=200, content_type="application/x-yml") + else: + return Response( + """Invalid accept header. Domain can be provided as json ("Accept: application/json") or yml ("Accept: application/x-yml"). Make sure you've set the appropriate Accept header.""", + status=406) + + @app.route("/conversations//parse", methods=['GET', 'POST', 'OPTIONS']) @cross_origin(origins=cors_origins) diff --git a/rasa_core/utils.py b/rasa_core/utils.py index 47f7f754ffe..cfa602c4105 100644 --- a/rasa_core/utils.py +++ b/rasa_core/utils.py @@ -18,6 +18,11 @@ from numpy import all, array from typing import Text, Any, List, Optional, Tuple, Dict, Set +if six.PY2: + from StringIO import StringIO +else: + from io import StringIO + logger = logging.getLogger(__name__) @@ -304,12 +309,15 @@ def construct_yaml_str(self, node): def read_yaml_file(filename): """Read contents of `filename` interpreting them as yaml.""" + return read_yaml_string(read_file(filename)) + +def read_yaml_string(string): if six.PY2: import yaml fix_yaml_loader() - return yaml.load(read_file(filename, "utf-8")) + return yaml.load(string) else: import ruamel.yaml @@ -317,17 +325,14 @@ def read_yaml_file(filename): yaml_parser.allow_unicode = True yaml_parser.unicode_supplementary = True - return yaml_parser.load(read_file(filename)) - + return yaml_parser.load(string) -def dump_obj_as_yaml_to_file(filename, obj): - """Writes data (python dict) to the filename in yaml repr.""" +def _dump_yaml(obj, output): if six.PY2: import yaml - with io.open(filename, 'w', encoding="utf-8") as yaml_file: - yaml.safe_dump(obj, yaml_file, + yaml.safe_dump(obj, output, default_flow_style=False, allow_unicode=True) else: @@ -338,9 +343,19 @@ def dump_obj_as_yaml_to_file(filename, obj): yaml_writer.default_flow_style = False yaml_writer.allow_unicode = True - with io.open(filename, 'w', encoding="utf-8") as yaml_file: - yaml_writer.dump(obj, yaml_file) + yaml_writer.dump(obj, output) + +def dump_obj_as_yaml_to_file(filename, obj): + """Writes data (python dict) to the filename in yaml repr.""" + with io.open(filename, 'w', encoding="utf-8") as output: + _dump_yaml(obj, output) + +def dump_obj_as_yaml_to_string(obj): + """Writes data (python dict) to a yaml string.""" + str_io = StringIO() + _dump_yaml(obj, str_io) + return str_io.getvalue() def read_file(filename, encoding="utf-8"): """Read text from a file.""" diff --git a/tests/test_domain.py b/tests/test_domain.py index 449885e1d73..282d915e651 100644 --- a/tests/test_domain.py +++ b/tests/test_domain.py @@ -9,6 +9,7 @@ from rasa_core import training from rasa_core.domain import TemplateDomain from rasa_core.featurizers import MaxHistoryTrackerFeaturizer +from rasa_core.utils import read_file from tests import utilities from tests.conftest import DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE @@ -133,8 +134,8 @@ def test_utter_templates(): def test_restaurant_domain_is_valid(): # should raise no exception - TemplateDomain.validate_domain_yaml( - 'examples/restaurantbot/restaurant_domain.yml') + TemplateDomain.validate_domain_yaml(read_file( + 'examples/restaurantbot/restaurant_domain.yml')) def test_custom_slot_type(tmpdir): @@ -166,3 +167,22 @@ def test_domain_fails_on_unknown_custom_slot_type(tmpdir): - utter_greet""") with pytest.raises(ValueError): TemplateDomain.load(domain_path) + + +def test_domain_to_yaml(): + test_yaml = """action_factory: null +action_names: +- utter_greet +actions: +- utter_greet +config: + store_entities_as_slots: true +entities: [] +intents: [] +slots: {} +templates: + utter_greet: + - text: hey there!""" + domain = TemplateDomain.load_from_yaml(test_yaml) + assert test_yaml.strip() == domain.as_yaml().strip() + domain = TemplateDomain.load_from_yaml(domain.as_yaml())