Skip to content

Commit

Permalink
Merge pull request #3914 from RanaMostafaAbdElMohsen/rasa_train_comma…
Browse files Browse the repository at this point in the history
…nd_fix

Fixed crash when --domain not provided for rasa train command
  • Loading branch information
erohmensing committed Jul 8, 2019
2 parents 76a9118 + bc61f20 commit 7097a41
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Fixed
- ``x in AnySlotDict`` is now ``True`` for any ``x``, which fixes empty slot warnings in
interactive learning
- ``rasa train`` now also includes NLU files in other formats than the Rasa format
- ``rasa train core`` no longer crashes without a ``--domain`` arg


[1.1.4] - 2019-06-18
Expand All @@ -57,7 +58,6 @@ Fixed
- take FB quick reply payload rather than text as input
- fixed bug where `training_data` path in `metadata.json` was an absolute path


[1.1.3] - 2019-06-14
^^^^^^^^^^^^^^^^^^^^

Expand Down
53 changes: 37 additions & 16 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,8 @@ async def train_async(
try:
domain = Domain.load(domain, skill_imports)
domain.check_missing_templates()
except InvalidDomain as e:
print_error(
"Could not load domain due to error: {} \nTo specify a valid domain "
"path, use the '--domain' argument.".format(e)
)
return None
except InvalidDomain:
domain = None

story_directory, nlu_data_directory = data.get_core_nlu_directories(
training_files, skill_imports
Expand All @@ -87,6 +83,11 @@ async def train_async(
nlu_data = stack.enter_context(TempDirectoryPath(nlu_data_directory))
story = stack.enter_context(TempDirectoryPath(story_directory))

if domain is None:
return handle_domain_if_not_exists(
config, nlu_data_directory, output_path, fixed_model_name
)

return await _train_async_internal(
domain,
config,
Expand All @@ -99,6 +100,27 @@ async def train_async(
kwargs,
)

if domain is None:
return handle_domain_if_not_exists(
config, nlu_data_directory, output_path, fixed_model_name
)


def handle_domain_if_not_exists(
config, nlu_data_directory, output_path, fixed_model_name
):
nlu_model_only = _train_nlu_with_validated_data(
config=config,
nlu_data_directory=nlu_data_directory,
output=output_path,
fixed_model_name=fixed_model_name,
)
print_warning(
"Core training was skipped because no valid domain file was found. Only an nlu-model was created."
"Please specify a valid domain using '--domain' argument or check if the provided domain file exists."
)
return nlu_model_only


async def _train_async_internal(
domain: Union[Domain, Text],
Expand Down Expand Up @@ -293,16 +315,15 @@ async def train_core_async(

skill_imports = SkillSelector.load(config, stories)

if isinstance(domain, str):
try:
domain = Domain.load(domain, skill_imports)
domain.check_missing_templates()
except InvalidDomain as e:
print_error(
"Could not load domain due to: '{}'. To specify a valid domain path "
"use the '--domain' argument.".format(e)
)
return None
try:
domain = Domain.load(domain, skill_imports)
domain.check_missing_templates()
except InvalidDomain:
print_error(
"Core training was skipped because no valid domain file was found. "
"Please specify a valid domain using '--domain' argument or check if the provided domain file exists."
)
return None

train_context = TempDirectoryPath(data.get_core_directory(stories, skill_imports))

Expand Down
50 changes: 50 additions & 0 deletions tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pytest

from rasa import model

from rasa.cli.train import _get_valid_config
from rasa.constants import (
CONFIG_MANDATORY_KEYS_CORE,
Expand Down Expand Up @@ -37,6 +39,32 @@ def test_train(run_in_default_project):
assert os.path.basename(files[0]) == "test-model.tar.gz"


def test_train_no_domain_exists(run_in_default_project):

os.remove("domain.yml")
run_in_default_project(
"train",
"-c",
"config.yml",
"--data",
"data",
"--out",
"train_models_no_domain",
"--fixed-model-name",
"nlu-model-only",
)

assert os.path.exists("train_models_no_domain")
files = list_files("train_models_no_domain")
assert len(files) == 1

trained_model_path = "train_models_no_domain/nlu-model-only.tar.gz"
unpacked = model.unpack_model(trained_model_path)

metadata_path = os.path.join(unpacked, "nlu", "metadata.json")
assert os.path.exists(metadata_path)


def test_train_skip_on_model_not_changed(run_in_default_project):
temp_dir = os.getcwd()

Expand Down Expand Up @@ -118,6 +146,28 @@ def test_train_core(run_in_default_project):
assert os.path.isfile("train_rasa_models/rasa-model.tar.gz")


def test_train_core_no_domain_exists(run_in_default_project):

os.remove("domain.yml")
run_in_default_project(
"train",
"core",
"--config",
"config.yml",
"--domain",
"domain1.yml",
"--stories",
"data",
"--out",
"train_rasa_models_no_domain",
"--fixed-model-name",
"rasa-model",
)

assert not os.path.exists("train_rasa_models_no_domain/rasa-model.tar.gz")
assert not os.path.isfile("train_rasa_models_no_domain/rasa-model.tar.gz")


def count_rasa_temp_files():
count = 0
for entry in os.scandir(tempfile.gettempdir()):
Expand Down

0 comments on commit 7097a41

Please sign in to comment.