diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 790785a1f481..488f760280cd 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -36,6 +36,7 @@ Fixed - ``x in AnySlotDict`` is now ``True`` for any ``x``, which fixes empty slot warnings in interactive learning - ``rasa train`` now also includes NLU files in other formats than the Rasa format +- ``rasa train core`` no longer crashes without a ``--domain`` arg [1.1.4] - 2019-06-18 @@ -57,7 +58,6 @@ Fixed - take FB quick reply payload rather than text as input - fixed bug where `training_data` path in `metadata.json` was an absolute path - [1.1.3] - 2019-06-14 ^^^^^^^^^^^^^^^^^^^^ diff --git a/rasa/train.py b/rasa/train.py index 018694e89755..0802536fba4e 100644 --- a/rasa/train.py +++ b/rasa/train.py @@ -71,12 +71,8 @@ async def train_async( try: domain = Domain.load(domain, skill_imports) domain.check_missing_templates() - except InvalidDomain as e: - print_error( - "Could not load domain due to error: {} \nTo specify a valid domain " - "path, use the '--domain' argument.".format(e) - ) - return None + except InvalidDomain: + domain = None story_directory, nlu_data_directory = data.get_core_nlu_directories( training_files, skill_imports @@ -87,6 +83,11 @@ async def train_async( nlu_data = stack.enter_context(TempDirectoryPath(nlu_data_directory)) story = stack.enter_context(TempDirectoryPath(story_directory)) + if domain is None: + return handle_domain_if_not_exists( + config, nlu_data_directory, output_path, fixed_model_name + ) + return await _train_async_internal( domain, config, @@ -99,6 +100,27 @@ async def train_async( kwargs, ) + if domain is None: + return handle_domain_if_not_exists( + config, nlu_data_directory, output_path, fixed_model_name + ) + + +def handle_domain_if_not_exists( + config, nlu_data_directory, output_path, fixed_model_name +): + nlu_model_only = _train_nlu_with_validated_data( + config=config, + nlu_data_directory=nlu_data_directory, + output=output_path, + fixed_model_name=fixed_model_name, + ) + print_warning( + "Core training was skipped because no valid domain file was found. Only an nlu-model was created." + "Please specify a valid domain using '--domain' argument or check if the provided domain file exists." + ) + return nlu_model_only + async def _train_async_internal( domain: Union[Domain, Text], @@ -293,16 +315,15 @@ async def train_core_async( skill_imports = SkillSelector.load(config, stories) - if isinstance(domain, str): - try: - domain = Domain.load(domain, skill_imports) - domain.check_missing_templates() - except InvalidDomain as e: - print_error( - "Could not load domain due to: '{}'. To specify a valid domain path " - "use the '--domain' argument.".format(e) - ) - return None + try: + domain = Domain.load(domain, skill_imports) + domain.check_missing_templates() + except InvalidDomain: + print_error( + "Core training was skipped because no valid domain file was found. " + "Please specify a valid domain using '--domain' argument or check if the provided domain file exists." + ) + return None train_context = TempDirectoryPath(data.get_core_directory(stories, skill_imports)) diff --git a/tests/cli/test_rasa_train.py b/tests/cli/test_rasa_train.py index 9062567a4f05..7d3c879030e7 100644 --- a/tests/cli/test_rasa_train.py +++ b/tests/cli/test_rasa_train.py @@ -4,6 +4,8 @@ import pytest +from rasa import model + from rasa.cli.train import _get_valid_config from rasa.constants import ( CONFIG_MANDATORY_KEYS_CORE, @@ -37,6 +39,32 @@ def test_train(run_in_default_project): assert os.path.basename(files[0]) == "test-model.tar.gz" +def test_train_no_domain_exists(run_in_default_project): + + os.remove("domain.yml") + run_in_default_project( + "train", + "-c", + "config.yml", + "--data", + "data", + "--out", + "train_models_no_domain", + "--fixed-model-name", + "nlu-model-only", + ) + + assert os.path.exists("train_models_no_domain") + files = list_files("train_models_no_domain") + assert len(files) == 1 + + trained_model_path = "train_models_no_domain/nlu-model-only.tar.gz" + unpacked = model.unpack_model(trained_model_path) + + metadata_path = os.path.join(unpacked, "nlu", "metadata.json") + assert os.path.exists(metadata_path) + + def test_train_skip_on_model_not_changed(run_in_default_project): temp_dir = os.getcwd() @@ -118,6 +146,28 @@ def test_train_core(run_in_default_project): assert os.path.isfile("train_rasa_models/rasa-model.tar.gz") +def test_train_core_no_domain_exists(run_in_default_project): + + os.remove("domain.yml") + run_in_default_project( + "train", + "core", + "--config", + "config.yml", + "--domain", + "domain1.yml", + "--stories", + "data", + "--out", + "train_rasa_models_no_domain", + "--fixed-model-name", + "rasa-model", + ) + + assert not os.path.exists("train_rasa_models_no_domain/rasa-model.tar.gz") + assert not os.path.isfile("train_rasa_models_no_domain/rasa-model.tar.gz") + + def count_rasa_temp_files(): count = 0 for entry in os.scandir(tempfile.gettempdir()):