Skip to content

Commit

Permalink
Merge pull request #4627 from ronancummins/add_persist_flag_cli
Browse files Browse the repository at this point in the history
Add flag to CLI to persist NLU data #4599
  • Loading branch information
tmbo committed Oct 18, 2019
2 parents 6432b95 + 7ec5edc commit 1979b0a
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 6 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0.

Added
-----
- add flag to CLI to persist NLU training data if needed
- log a warning if the ``Interpreter`` picks up an intent or an entity that does not
exist in the domain file.
- added ``DynamoTrackerStore`` to support persistence of agents running on AWS
Expand All @@ -20,7 +21,7 @@ Added
- `CRFEntityExtractor` updated to accept arbitrary token-level features like word
vectors (issues/4214)
- `SpacyFeaturizer` updated to add `ner_features` for `CRFEntityExtractor`
- Sanitizing incoming messages from slack to remove slack formatting like <mailto:xyz@rasa.com|xyz@rasa.com>
- Sanitizing incoming messages from slack to remove slack formatting like <mailto:xyz@rasa.com|xyz@rasa.com>
or <http://url.com|url.com> and substitute it with original content
- Added the ability to configure the number of Sanic worker processes in the HTTP
server (``rasa.server``) and input channel server
Expand Down
12 changes: 12 additions & 0 deletions rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def set_train_arguments(parser: argparse.ArgumentParser):
add_dump_stories_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)
add_force_param(parser)


Expand Down Expand Up @@ -50,6 +51,7 @@ def set_train_nlu_arguments(parser: argparse.ArgumentParser):
add_nlu_data_param(parser, help_text="File or folder containing your NLU data.")

add_model_name_param(parser)
add_persist_nlu_data_param(parser)


def add_force_param(parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]):
Expand Down Expand Up @@ -138,3 +140,13 @@ def add_model_name_param(parser: argparse.ArgumentParser):
help="If set, the name of the model file/directory will be set to the given "
"name.",
)


def add_persist_nlu_data_param(
    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
) -> None:
    """Register the ``--persist-nlu-data`` flag on ``parser``.

    The flag is a plain boolean switch (``store_true``): it is ``False`` unless
    given on the command line, and argparse exposes the parsed value as
    ``persist_nlu_data`` on the resulting namespace.

    Args:
        parser: Argument parser or argument group the flag is added to.
    """
    parser.add_argument(
        "--persist-nlu-data",
        action="store_true",
        help="Persist the nlu training data in the saved model.",
    )
2 changes: 2 additions & 0 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def train(args: argparse.Namespace) -> Optional[Text]:
output=args.out,
force_training=args.force,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
kwargs=extract_additional_arguments(args),
)

Expand Down Expand Up @@ -134,6 +135,7 @@ def train_nlu(
output=output,
train_path=train_path,
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
)


Expand Down
18 changes: 16 additions & 2 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def train(
output: Text = DEFAULT_MODELS_PATH,
force_training: bool = False,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
kwargs: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Optional[Text]:
Expand All @@ -40,6 +41,7 @@ def train(
output_path=output,
force_training=force_training,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
kwargs=kwargs,
)
)
Expand Down Expand Up @@ -124,6 +126,8 @@ async def _train_async_internal(
train_path: Directory in which to train the model.
output_path: Output path.
force_training: If `True` retrain model even if data has not changed.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
fixed_model_name: Name of model to be stored.
kwargs: Additional training parameters.
Expand Down Expand Up @@ -363,6 +367,7 @@ def train_nlu(
output: Text,
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
) -> Optional[Text]:
"""Trains an NLU model.
Expand All @@ -373,7 +378,9 @@ def train_nlu(
train_path: If `None` the model will be trained in a temporary
directory, otherwise in the provided directory.
fixed_model_name: Name of the model to be stored.
uncompress: If `True` the model will not be compressed.
persist_nlu_training_data: `True` if the NLU training data should be persisted
with the model.
Returns:
If `train_path` is given it returns the path to the model archive,
Expand All @@ -383,7 +390,14 @@ def train_nlu(

loop = asyncio.get_event_loop()
return loop.run_until_complete(
_train_nlu_async(config, nlu_data, output, train_path, fixed_model_name)
_train_nlu_async(
config,
nlu_data,
output,
train_path,
fixed_model_name,
persist_nlu_training_data,
)
)


Expand Down
79 changes: 76 additions & 3 deletions tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

from rasa import model

from rasa.nlu.model import Metadata
from rasa.nlu.training_data import training_data
from rasa.cli.train import _get_valid_config
from rasa.constants import (
CONFIG_MANDATORY_KEYS_CORE,
Expand Down Expand Up @@ -36,6 +37,44 @@ def test_train(run_in_default_project):
files = io_utils.list_files(os.path.join(temp_dir, "train_models"))
assert len(files) == 1
assert os.path.basename(files[0]) == "test-model.tar.gz"
model_dir = model.get_model("train_models")
assert model_dir is not None
metadata = Metadata.load(os.path.join(model_dir, "nlu"))
assert metadata.get("training_data") is None
assert not os.path.exists(
os.path.join(model_dir, "nlu", training_data.DEFAULT_TRAINING_DATA_OUTPUT_PATH)
)


def test_train_persist_nlu_data(run_in_default_project):
    """Check that ``rasa train --persist-nlu-data`` keeps the NLU training data."""
    temp_dir = os.getcwd()
    output_dir = os.path.join(temp_dir, "train_models")

    run_in_default_project(
        "train",
        "-c",
        "config.yml",
        "-d",
        "domain.yml",
        "--data",
        "data",
        "--out",
        "train_models",
        "--fixed-model-name",
        "test-model",
        "--persist-nlu-data",
    )

    # Exactly one archive, carrying the fixed model name, must be produced.
    assert os.path.exists(output_dir)
    files = io_utils.list_files(output_dir)
    assert len(files) == 1
    assert os.path.basename(files[0]) == "test-model.tar.gz"

    model_dir = model.get_model("train_models")
    assert model_dir is not None

    # With the flag set, the NLU metadata has a "training_data" entry and the
    # training data file exists inside the model's "nlu" directory.
    nlu_dir = os.path.join(model_dir, "nlu")
    metadata = Metadata.load(nlu_dir)
    assert metadata.get("training_data") is not None
    assert os.path.exists(
        os.path.join(nlu_dir, training_data.DEFAULT_TRAINING_DATA_OUTPUT_PATH)
    )


def test_train_core_compare(run_in_default_project):
Expand Down Expand Up @@ -257,6 +296,39 @@ def test_train_nlu(run_in_default_project):
files = io_utils.list_files("train_models")
assert len(files) == 1
assert os.path.basename(files[0]).startswith("nlu-")
model_dir = model.get_model("train_models")
assert model_dir is not None
metadata = Metadata.load(os.path.join(model_dir, "nlu"))
assert metadata.get("training_data") is None
assert not os.path.exists(
os.path.join(model_dir, "nlu", training_data.DEFAULT_TRAINING_DATA_OUTPUT_PATH)
)


def test_train_nlu_persist_nlu_data(run_in_default_project):
    """Check that ``rasa train nlu --persist-nlu-data`` keeps the training data."""
    run_in_default_project(
        "train",
        "nlu",
        "-c",
        "config.yml",
        "--nlu",
        "data/nlu.md",
        "--out",
        "train_models",
        "--persist-nlu-data",
    )

    # Exactly one archive must be produced, and its name starts with "nlu-".
    assert os.path.exists("train_models")
    files = io_utils.list_files("train_models")
    assert len(files) == 1
    assert os.path.basename(files[0]).startswith("nlu-")

    model_dir = model.get_model("train_models")
    assert model_dir is not None

    # With the flag set, the NLU metadata has a "training_data" entry and the
    # training data file exists inside the model's "nlu" directory.
    nlu_dir = os.path.join(model_dir, "nlu")
    metadata = Metadata.load(nlu_dir)
    assert metadata.get("training_data") is not None
    assert os.path.exists(
        os.path.join(nlu_dir, training_data.DEFAULT_TRAINING_DATA_OUTPUT_PATH)
    )


def test_train_nlu_temp_files(run_in_default_project):
Expand All @@ -272,7 +344,7 @@ def test_train_help(run):
[-c CONFIG] [-d DOMAIN] [--out OUT]
[--augmentation AUGMENTATION] [--debug-plots]
[--dump-stories] [--fixed-model-name FIXED_MODEL_NAME]
[--force]
[--persist-nlu-data] [--force]
{core,nlu} ..."""

lines = help_text.split("\n")
Expand All @@ -285,7 +357,8 @@ def test_train_nlu_help(run):
output = run("train", "nlu", "--help")

help_text = """usage: rasa train nlu [-h] [-v] [-vv] [--quiet] [-c CONFIG] [--out OUT]
[-u NLU] [--fixed-model-name FIXED_MODEL_NAME]"""
[-u NLU] [--fixed-model-name FIXED_MODEL_NAME]
[--persist-nlu-data]"""

lines = help_text.split("\n")

Expand Down

0 comments on commit 1979b0a

Please sign in to comment.