Skip to content

Commit

Permalink
Merge pull request #3721 from RasaHQ/config-argument
Browse files Browse the repository at this point in the history
Update validation of config for rasa train
  • Loading branch information
tabergma committed Jun 12, 2019
2 parents f799189 + 2fa8734 commit a068663
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 89 deletions.
5 changes: 3 additions & 2 deletions rasa/cli/arguments/default_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ def add_domain_param(


def add_config_param(
    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
    default: Optional[Text] = DEFAULT_CONFIG_PATH,
):
    """Register the ``-c`` / ``--config`` option on ``parser``.

    Args:
        parser: Argument parser (or argument group) that receives the option.
        default: Value stored when the option is not given on the command
            line. Defaults to ``DEFAULT_CONFIG_PATH``; the annotation also
            allows ``None``.
    """
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=default,
        help="The policy and NLU pipeline configuration of your bot.",
    )

Expand Down
2 changes: 1 addition & 1 deletion rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
add_out_param,
add_domain_param,
)
from rasa.constants import DEFAULT_CONFIG_PATH, DEFAULT_DATA_PATH
from rasa.constants import DEFAULT_DATA_PATH, DEFAULT_CONFIG_PATH


def set_train_arguments(parser: argparse.ArgumentParser):
Expand Down
52 changes: 47 additions & 5 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
import argparse
import os
import tempfile
from typing import List, Optional, Text, Dict
import rasa.cli.arguments as arguments

from rasa.cli.utils import get_validated_path
from rasa.constants import DEFAULT_CONFIG_PATH, DEFAULT_DATA_PATH, DEFAULT_DOMAIN_PATH
from rasa.cli.utils import (
get_validated_path,
missing_config_keys,
print_error,
print_warning,
)
from rasa.constants import (
DEFAULT_CONFIG_PATH,
DEFAULT_DATA_PATH,
DEFAULT_DOMAIN_PATH,
FALLBACK_CONFIG_PATH,
CONFIG_MANDATORY_KEYS_NLU,
CONFIG_MANDATORY_KEYS_CORE,
CONFIG_MANDATORY_KEYS,
)


# noinspection PyProtectedMember
Expand Down Expand Up @@ -52,7 +66,8 @@ def train(args: argparse.Namespace) -> Optional[Text]:
domain = get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
)
config = args.config or DEFAULT_CONFIG_PATH

config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS)

training_files = [
get_validated_path(f, "data", DEFAULT_DATA_PATH, none_is_valid=True)
Expand Down Expand Up @@ -94,7 +109,7 @@ def train_core(
if isinstance(args.config, list):
args.config = args.config[0]

config = args.config or DEFAULT_CONFIG_PATH
config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE)

return train_core(
domain=args.domain,
Expand All @@ -119,7 +134,7 @@ def train_nlu(

output = train_path or args.out

config = args.config or DEFAULT_CONFIG_PATH
config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_NLU)
nlu_data = get_validated_path(
args.nlu, "nlu", DEFAULT_DATA_PATH, none_is_valid=True
)
Expand All @@ -144,3 +159,30 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict:
arguments["debug_plots"] = args.debug_plots

return arguments


def _get_valid_config(
    config: Optional[Text],
    mandatory_keys: List[Text],
    default_config: Text = DEFAULT_CONFIG_PATH,
) -> Text:
    """Resolve the config path and abort if the file cannot be used.

    Args:
        config: Config path to validate; may be ``None``.
        mandatory_keys: Keys that have to be present in the config file.
        default_config: Path passed to ``get_validated_path`` as the default.

    Returns:
        The resolved config path once both checks have passed.

    Raises:
        SystemExit: With exit code 1 if the resolved file does not exist or
            if ``missing_config_keys`` reports any missing mandatory key.
    """
    # The bare `exit` builtin is injected by the `site` module for interactive
    # sessions and is not available in every runtime; `sys.exit` always is.
    import sys

    config = get_validated_path(config, "config", default_config)

    if not os.path.exists(config):
        print_error(
            "The config file '{}' does not exist. Use '--config' to specify a "
            "valid config file.".format(config)
        )
        sys.exit(1)

    missing_keys = missing_config_keys(config, mandatory_keys)
    if missing_keys:
        print_error(
            "The config file '{}' is missing mandatory parameters: "
            "'{}'. Add missing parameters to config file and try again."
            "".format(config, "', '".join(missing_keys))
        )
        sys.exit(1)

    return config
43 changes: 1 addition & 42 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,12 @@
from rasa.cli.utils import (
create_output_path,
print_success,
missing_config_keys,
print_warning,
get_validated_path,
print_error,
bcolors,
print_color,
)
from rasa.constants import (
DEFAULT_MODELS_PATH,
CONFIG_MANDATORY_KEYS,
CONFIG_MANDATORY_KEYS_CORE,
CONFIG_MANDATORY_KEYS_NLU,
FALLBACK_CONFIG_PATH,
)
from rasa.constants import DEFAULT_MODELS_PATH


def train(
Expand Down Expand Up @@ -73,7 +65,6 @@ async def train_async(
Returns:
Path of the trained model archive.
"""
config = _get_valid_config(config, CONFIG_MANDATORY_KEYS)
train_path = tempfile.mkdtemp()

skill_imports = SkillSelector.load(config, training_files)
Expand Down Expand Up @@ -254,7 +245,6 @@ async def train_core_async(
"""

config = _get_valid_config(config, CONFIG_MANDATORY_KEYS_CORE)
skill_imports = SkillSelector.load(config, stories)

if isinstance(domain, str):
Expand Down Expand Up @@ -353,7 +343,6 @@ def train_nlu(
otherwise the path to the directory with the trained model files.
"""
config = _get_valid_config(config, CONFIG_MANDATORY_KEYS_NLU)

# training NLU only hence the training files still have to be selected
skill_imports = SkillSelector.load(config, nlu_data)
Expand Down Expand Up @@ -409,36 +398,6 @@ def _train_nlu_with_validated_data(
return _train_path


def _enrich_config(
    config_path: Text, missing_keys: List[Text], FALLBACK_CONFIG_PATH: Text
):
    """Fill ``missing_keys`` in a config file from the fallback config.

    Both YAML files are read, the value of every key in ``missing_keys`` is
    copied from the fallback data into the config data, and the result is
    written back to ``config_path`` (the file is overwritten in place).

    Args:
        config_path: Path of the config file to complete.
        missing_keys: Keys to copy over from the fallback config.
        FALLBACK_CONFIG_PATH: Path of the fallback config file.
    """
    import rasa.utils.io

    load_yaml = rasa.utils.io.read_yaml_file
    target_data = load_yaml(config_path)
    fallback_data = load_yaml(FALLBACK_CONFIG_PATH)

    for key in missing_keys:
        target_data[key] = fallback_data[key]

    rasa.utils.io.write_yaml_file(target_data, config_path)


def _get_valid_config(config: Text, mandatory_keys: List[Text]) -> Text:
    """Return a config path whose file contains all ``mandatory_keys``.

    The path is resolved through ``get_validated_path`` with
    ``FALLBACK_CONFIG_PATH`` as the default. When the resolved file lacks
    mandatory keys, a warning is printed and ``_enrich_config`` is called to
    fill them in from the fallback config.

    Args:
        config: Config path to validate.
        mandatory_keys: Keys that have to be present in the config file.

    Returns:
        The resolved config path.
    """
    config_path = get_validated_path(config, "config", FALLBACK_CONFIG_PATH)

    missing_keys = missing_config_keys(config_path, mandatory_keys)
    if not missing_keys:
        return config_path

    # Report the resolved path: it is the file that was inspected and that is
    # about to be rewritten, whereas the raw `config` argument may differ.
    print_warning(
        "Configuration file '{}' is missing mandatory parameters: "
        "{}. Filling missing parameters from fallback configuration file: '{}'."
        "".format(config_path, ", ".join(missing_keys), FALLBACK_CONFIG_PATH)
    )
    _enrich_config(config_path, missing_keys, FALLBACK_CONFIG_PATH)

    return config_path


def _package_model(
new_fingerprint: Fingerprint,
output_path: Text,
Expand Down
112 changes: 112 additions & 0 deletions tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import os
import shutil
import tempfile

import pytest

from rasa.cli.train import _get_valid_config
from rasa.constants import (
CONFIG_MANDATORY_KEYS_CORE,
CONFIG_MANDATORY_KEYS,
CONFIG_MANDATORY_KEYS_NLU,
DEFAULT_CONFIG_PATH,
)
from rasa.nlu.utils import list_files


Expand Down Expand Up @@ -169,3 +179,105 @@ def test_train_core_help(run):

for i, line in enumerate(lines):
assert output.outlines[i] == line


@pytest.mark.parametrize(
    "parameters",
    [
        {
            "config_data": {"language": "en", "pipeline": "supervised"},
            "default_config": {
                "language": "en",
                "pipeline": "supervised",
                "policies": ["KerasPolicy", "FallbackPolicy"],
            },
            "mandatory_keys": CONFIG_MANDATORY_KEYS_CORE,
            "error": True,
        },
        {
            "config_data": {},
            "default_config": {
                "language": "en",
                "pipeline": "supervised",
                "policies": ["KerasPolicy", "FallbackPolicy"],
            },
            "mandatory_keys": CONFIG_MANDATORY_KEYS,
            "error": True,
        },
        {
            "config_data": {
                "policies": ["KerasPolicy", "FallbackPolicy"],
                "imports": "other-folder",
            },
            "default_config": {
                "language": "en",
                "pipeline": "supervised",
                "policies": ["KerasPolicy", "FallbackPolicy"],
            },
            "mandatory_keys": CONFIG_MANDATORY_KEYS_NLU,
            "error": True,
        },
        {
            "config_data": None,
            "default_config": {
                "pipeline": "supervised",
                "policies": ["KerasPolicy", "FallbackPolicy"],
            },
            "mandatory_keys": CONFIG_MANDATORY_KEYS_NLU,
            "error": True,
        },
        {
            "config_data": None,
            "default_config": {
                "language": "en",
                "pipeline": "supervised",
                "policies": ["KerasPolicy", "FallbackPolicy"],
            },
            "mandatory_keys": CONFIG_MANDATORY_KEYS,
            "error": False,
        },
        {
            "config_data": None,
            "default_config": {"language": "en", "pipeline": "supervised"},
            "mandatory_keys": CONFIG_MANDATORY_KEYS_CORE,
            "error": True,
        },
        {
            "config_data": None,
            "default_config": None,
            "mandatory_keys": CONFIG_MANDATORY_KEYS,
            "error": True,
        },
    ],
)
def test_get_valid_config(parameters):
    """Check ``_get_valid_config`` against user/default config combinations.

    Each case optionally writes a user config and a default config, then
    expects either ``SystemExit`` (``"error": True``) or a returned path whose
    YAML content contains every mandatory key.
    """
    import rasa.utils.io

    # One self-cleaning directory holds both files so nothing is left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = None
        if parameters["config_data"] is not None:
            config_path = os.path.join(tmp_dir, "config.yml")
            rasa.utils.io.write_yaml_file(parameters["config_data"], config_path)

        # Always use a default path inside the temp dir (left non-existent
        # when the case has no default config) so that the outcome never
        # depends on a config file in the current working directory.
        default_config_path = os.path.join(tmp_dir, "default-config.yml")
        if parameters["default_config"] is not None:
            rasa.utils.io.write_yaml_file(
                parameters["default_config"], default_config_path
            )

        if parameters["error"]:
            with pytest.raises(SystemExit):
                _get_valid_config(
                    config_path, parameters["mandatory_keys"], default_config_path
                )
        else:
            config_path = _get_valid_config(
                config_path, parameters["mandatory_keys"], default_config_path
            )

            config_data = rasa.utils.io.read_yaml_file(config_path)

            for k in parameters["mandatory_keys"]:
                assert k in config_data


def test_get_valid_config_with_non_existing_file():
    """``_get_valid_config`` exits when given a path that does not exist."""
    # NOTE(review): no `default_config` is passed, so the result may also
    # depend on how `get_validated_path` treats the default path - confirm.
    missing_path = "non-existing-file.yml"

    with pytest.raises(SystemExit):
        _get_valid_config(missing_path, CONFIG_MANDATORY_KEYS)
40 changes: 1 addition & 39 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
import os

import pytest
from rasa.constants import (
CONFIG_MANDATORY_KEYS_CORE,
CONFIG_MANDATORY_KEYS,
CONFIG_MANDATORY_KEYS_NLU,
)

from rasa.model import unpack_model

from rasa.train import _package_model, _get_valid_config
from rasa.train import _package_model
from tests.core.test_model import _fingerprint


Expand Down Expand Up @@ -45,36 +40,3 @@ def test_package_model(trained_rasa_model, parameters):
assert parameters["prefix"] in file_name

assert file_name.endswith(".tar.gz")


@pytest.mark.parametrize(
    "parameters",
    [
        {
            "config_data": {"language": "en", "pipeline": "supervised"},
            "mandatory_keys": CONFIG_MANDATORY_KEYS_CORE,
        },
        {"config_data": {}, "mandatory_keys": CONFIG_MANDATORY_KEYS},
        {
            "config_data": {
                "policy": ["KerasPolicy", "FallbackPolicy"],
                "imports": "other-folder",
            },
            "mandatory_keys": CONFIG_MANDATORY_KEYS_NLU,
        },
    ],
)
def test_get_valid_config(parameters):
    """Incomplete configs get their mandatory keys filled in.

    Writes the case's config data to a temporary ``config.yml``, runs it
    through ``_get_valid_config`` and checks that the resulting file holds
    every mandatory key while keeping all values that were written.
    """
    import rasa.utils.io

    written_data = parameters["config_data"]
    required_keys = parameters["mandatory_keys"]

    source_path = os.path.join(tempfile.mkdtemp(), "config.yml")
    rasa.utils.io.write_yaml_file(written_data, source_path)

    validated_path = _get_valid_config(source_path, required_keys)
    loaded_data = rasa.utils.io.read_yaml_file(validated_path)

    # Every mandatory key must be present after validation.
    for key in required_keys:
        assert key in loaded_data

    # Values that were written originally must be unchanged.
    for key, value in written_data.items():
        assert loaded_data[key] == value

0 comments on commit a068663

Please sign in to comment.