Skip to content

Commit

Permalink
Merge pull request #4094 from RasaHQ/rasa-core-comparision
Browse files Browse the repository at this point in the history
fix 'rasa test core'
  • Loading branch information
tabergma authored Jul 26, 2019
2 parents b01582e + d0b5d48 commit ad41208
Show file tree
Hide file tree
Showing 20 changed files with 312 additions and 204 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0.

Added
-----
- add ``--evaluate-model-directory`` to ``rasa test core`` to evaluate models from ``rasa train core -c <config-1> <config-2>``

Changed
-------
Expand All @@ -21,6 +22,7 @@ Removed

Fixed
-----
- ``rasa test core`` can handle compressed model files


[1.1.8] - 2019-07-25
Expand All @@ -47,6 +49,7 @@ Fixed
- ``rasa train core`` in comparison mode stores the model files compressed (``tar.gz`` files)
- slot setting in interactive learning with the TwoStageFallbackPolicy


[1.1.7] - 2019-07-18
^^^^^^^^^^^^^^^^^^^^

Expand Down
6 changes: 3 additions & 3 deletions docs/user-guide/evaluating-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ mode to evaluate the models you just trained:

.. code-block:: bash
$ rasa test core -m comparison_models/<model-1>.tar.gz comparison_models/<model-2>.tar.gz \
--stories stories_folder --out comparison_results
$ rasa test core -m comparison_models --stories stories_folder
--out comparison_results --evaluate-model-directory
This will evaluate each of the models on the training set and plot some graphs
to show you which policy performs best. By evaluating on the full set of stories, you
to show you which policy performs best. By evaluating on the full set of stories, you
can measure how well Rasa Core is predicting the held-out stories.

If you're not sure which policies to compare, we'd recommend trying out the
Expand Down
1 change: 0 additions & 1 deletion rasa/cli/arguments/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
add_nlu_data_param,
add_out_param,
add_data_param,
add_stories_param,
add_domain_param,
)

Expand Down
19 changes: 14 additions & 5 deletions rasa/cli/arguments/test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from typing import Union

from rasa.constants import DEFAULT_MODELS_PATH, DEFAULT_CONFIG_PATH
from rasa.constants import DEFAULT_MODELS_PATH, DEFAULT_RESULTS_PATH

from rasa.cli.arguments.default_arguments import (
add_stories_param,
Expand Down Expand Up @@ -42,7 +42,7 @@ def add_test_core_argument_group(
)
add_out_param(
parser,
default="results",
default=DEFAULT_RESULTS_PATH,
help_text="Output path for any files created during the evaluation.",
)
parser.add_argument(
Expand Down Expand Up @@ -70,6 +70,15 @@ def add_test_core_argument_group(
"trains on it. Fetches the data by sending a GET request "
"to the supplied URL.",
)
parser.add_argument(
"--evaluate-model-directory",
default=False,
action="store_true",
help="Should be set to evaluate models trained via "
"'rasa train core --config <config-1> <config-2>'. "
"All models in the provided directory are evaluated "
"and compared against each other.",
)


def add_test_nlu_argument_group(
Expand Down Expand Up @@ -150,7 +159,7 @@ def add_test_nlu_argument_group(
required=False,
nargs="+",
type=int,
default=[0, 25, 50, 75, 90],
default=[0, 25, 50, 75],
help="Percentages of training data to exclude during comparison.",
)

Expand All @@ -164,6 +173,6 @@ def add_test_core_model_param(parser: argparse.ArgumentParser):
default=[default_path],
help="Path to a pre-trained model. If it is a 'tar.gz' file that model file "
"will be used. If it is a directory, the latest model in that directory "
"will be used. If multiple 'tar.gz' files are provided, all those models "
"will be compared.",
"will be used (exception: '--evaluate-model-directory' flag is set). If multiple "
"'tar.gz' files are provided, all those models will be compared.",
)
2 changes: 1 addition & 1 deletion rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def add_compare_params(
"--percentages",
nargs="*",
type=int,
default=[0, 5, 25, 50, 70, 90, 95],
default=[0, 25, 50, 75],
help="Range of exclusion percentages.",
)
parser.add_argument(
Expand Down
57 changes: 33 additions & 24 deletions rasa/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import os
from typing import List

from rasa import data
from rasa.cli.arguments import test as arguments
from rasa.cli.utils import get_validated_path
from rasa.constants import (
DEFAULT_CONFIG_PATH,
DEFAULT_DATA_PATH,
Expand All @@ -15,8 +13,8 @@
DEFAULT_NLU_RESULTS_PATH,
CONFIG_SCHEMA_FILE,
)
from rasa.test import test_compare_core, compare_nlu_models
from rasa.utils.validation import validate_yaml_schema, InvalidYamlFileError
import rasa.utils.validation as validation_utils
import rasa.cli.utils as cli_utils

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,12 +57,13 @@ def add_subparser(


def test_core(args: argparse.Namespace) -> None:
from rasa.test import test_core
from rasa import data
from rasa.test import test_core_models_in_directory, test_core, test_core_models

endpoints = get_validated_path(
endpoints = cli_utils.get_validated_path(
args.endpoints, "endpoints", DEFAULT_ENDPOINTS_PATH, True
)
stories = get_validated_path(args.stories, "stories", DEFAULT_DATA_PATH)
stories = cli_utils.get_validated_path(args.stories, "stories", DEFAULT_DATA_PATH)
stories = data.get_core_directory(stories)
output = args.out or DEFAULT_RESULTS_PATH

Expand All @@ -75,25 +74,31 @@ def test_core(args: argparse.Namespace) -> None:
args.model = args.model[0]

if isinstance(args.model, str):
model_path = get_validated_path(args.model, "model", DEFAULT_MODELS_PATH)

test_core(
model=model_path,
stories=stories,
endpoints=endpoints,
output=output,
kwargs=vars(args),
model_path = cli_utils.get_validated_path(
args.model, "model", DEFAULT_MODELS_PATH
)

if args.evaluate_model_directory:
test_core_models_in_directory(args.model, stories, output)
else:
test_core(
model=model_path,
stories=stories,
endpoints=endpoints,
output=output,
kwargs=vars(args),
)

else:
test_compare_core(args.model, stories, output)
test_core_models(args.model, stories, output)


def test_nlu(args: argparse.Namespace) -> None:
from rasa.test import test_nlu, perform_nlu_cross_validation
import rasa.utils.io
from rasa import data
import rasa.utils.io as io_utils
from rasa.test import compare_nlu_models, perform_nlu_cross_validation, test_nlu

nlu_data = get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH)
nlu_data = cli_utils.get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH)
nlu_data = data.get_nlu_directory(nlu_data)

if args.config is not None and len(args.config) == 1:
Expand All @@ -114,13 +119,13 @@ def test_nlu(args: argparse.Namespace) -> None:
config_files = []
for file in args.config:
try:
validate_yaml_schema(
rasa.utils.io.read_file(file),
validation_utils.validate_yaml_schema(
io_utils.read_file(file),
CONFIG_SCHEMA_FILE,
show_validation_errors=False,
)
config_files.append(file)
except InvalidYamlFileError:
except validation_utils.InvalidYamlFileError:
logger.debug(
"Ignoring file '{}' as it is not a valid config file.".format(file)
)
Expand All @@ -136,10 +141,14 @@ def test_nlu(args: argparse.Namespace) -> None:
)
elif args.cross_validation:
logger.info("Test model using cross validation.")
config = get_validated_path(args.config, "config", DEFAULT_CONFIG_PATH)
config = cli_utils.get_validated_path(
args.config, "config", DEFAULT_CONFIG_PATH
)
perform_nlu_cross_validation(config, nlu_data, vars(args))
else:
model_path = get_validated_path(args.model, "model", DEFAULT_MODELS_PATH)
model_path = cli_utils.get_validated_path(
args.model, "model", DEFAULT_MODELS_PATH
)
test_nlu(model_path, nlu_data, vars(args))


Expand Down
6 changes: 3 additions & 3 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def train_core(
args.domain = get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
)
stories = get_validated_path(
story_file = get_validated_path(
args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
)

Expand All @@ -105,7 +105,7 @@ def train_core(
return train_core(
domain=args.domain,
config=config,
stories=stories,
stories=story_file,
output=output,
train_path=train_path,
fixed_model_name=args.fixed_model_name,
Expand All @@ -114,7 +114,7 @@ def train_core(
else:
from rasa.core.train import do_compare_training

loop.run_until_complete(do_compare_training(args, stories))
loop.run_until_complete(do_compare_training(args, story_file))


def train_nlu(
Expand Down
73 changes: 48 additions & 25 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ def collect_story_predictions(
story_eval_store = EvaluationStore()
failed = []
correct_dialogues = []
num_stories = len(completed_trackers)
number_of_stories = len(completed_trackers)

logger.info("Evaluating {} stories\nProgress:".format(num_stories))
logger.info("Evaluating {} stories\nProgress:".format(number_of_stories))

action_list = []

Expand Down Expand Up @@ -451,7 +451,7 @@ def collect_story_predictions(
action_list=action_list,
in_training_data_fraction=in_training_data_fraction,
),
num_stories,
number_of_stories,
)


Expand Down Expand Up @@ -587,38 +587,61 @@ def plot_story_evaluation(
fig.savefig(os.path.join(out_directory, "story_confmat.pdf"), bbox_inches="tight")


async def compare(models: Text, stories_file: Text, output: Text) -> None:
"""Evaluates multiple trained models on a test set."""
from rasa.core.agent import Agent
import rasa.nlu.utils as nlu_utils
async def compare_models_in_dir(
model_dir: Text, stories_file: Text, output: Text
) -> None:
"""Evaluates multiple trained models in a directory on a test set."""
from rasa.core import utils
import rasa.utils.io as io_utils

num_correct = defaultdict(list)

for run in nlu_utils.list_subdirectories(models):
num_correct_run = defaultdict(list)
number_correct = defaultdict(list)

for model in sorted(nlu_utils.list_subdirectories(run)):
logger.info("Evaluating model {}".format(model))
for run in io_utils.list_subdirectories(model_dir):
number_correct_in_run = defaultdict(list)

agent = Agent.load(model)
for model in sorted(io_utils.list_files(run)):
if not model.endswith("tar.gz"):
continue

completed_trackers = await _generate_trackers(stories_file, agent)

story_eval_store, no_of_stories = collect_story_predictions(
completed_trackers, agent
)

failed_stories = story_eval_store.failed_stories
# The model files are named like <policy-name><number>.tar.gz
# Remove the number from the name to get the policy name
policy_name = "".join(
[i for i in os.path.basename(model) if not i.isdigit()]
)
num_correct_run[policy_name].append(no_of_stories - len(failed_stories))
number_of_correct_stories = await _evaluate_core_model(model, stories_file)
number_correct_in_run[policy_name].append(number_of_correct_stories)

for k, v in number_correct_in_run.items():
number_correct[k].append(v)

utils.dump_obj_as_json_to_file(os.path.join(output, RESULTS_FILE), number_correct)


async def compare_models(models: List[Text], stories_file: Text, output: Text) -> None:
"""Evaluates provided trained models on a test set."""
from rasa.core import utils

number_correct = defaultdict(list)

for model in models:
number_of_correct_stories = await _evaluate_core_model(model, stories_file)
number_correct[os.path.basename(model)].append(number_of_correct_stories)

utils.dump_obj_as_json_to_file(os.path.join(output, RESULTS_FILE), number_correct)

for k, v in num_correct_run.items():
num_correct[k].append(v)

utils.dump_obj_as_json_to_file(os.path.join(output, "results.json"), num_correct)
async def _evaluate_core_model(model: Text, stories_file: Text) -> int:
from rasa.core.agent import Agent

logger.info("Evaluating model '{}'".format(model))

agent = Agent.load(model)
completed_trackers = await _generate_trackers(stories_file, agent)
story_eval_store, number_of_stories = collect_story_predictions(
completed_trackers, agent
)
failed_stories = story_eval_store.failed_stories
return number_of_stories - len(failed_stories)


def plot_nlu_results(output: Text, number_of_examples: List[int]) -> None:
Expand Down
5 changes: 2 additions & 3 deletions rasa/core/training/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from typing import Optional, List, Text, Any, Dict, TYPE_CHECKING, Iterable

import rasa.utils.io as io_utils
from rasa.constants import DOCS_BASE_URL
from rasa.core import utils
from rasa.core.constants import INTENT_MESSAGE_PREFIX
Expand Down Expand Up @@ -175,16 +176,14 @@ async def read_from_folder(
exclusion_percentage: Optional[int] = None,
) -> List[StoryStep]:
"""Given a path reads all contained story files."""
import rasa.nlu.utils as nlu_utils

if not os.path.exists(resource_name):
raise ValueError(
"Story file or folder could not be found. Make "
"sure '{}' exists and points to a story folder "
"or file.".format(os.path.abspath(resource_name))
)

files = nlu_utils.list_files(resource_name)
files = io_utils.list_files(resource_name)

return await StoryFileReader.read_from_files(
files,
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rasa.nlu.config import RasaNLUModelConfig, component_config_from_pipeline
from rasa.nlu.persistor import Persistor
from rasa.nlu.training_data import TrainingData, Message
from rasa.nlu.utils import create_dir, write_json_to_file
from rasa.nlu.utils import write_json_to_file
import rasa.utils.io

MODEL_NAME_PREFIX = "nlu_"
Expand Down Expand Up @@ -221,7 +221,7 @@ def persist(
path = os.path.abspath(path)
dir_name = os.path.join(path, model_name)

create_dir(dir_name)
rasa.utils.io.create_directory(dir_name)

if self.training_data:
metadata.update(self.training_data.persist(dir_name))
Expand Down
Loading

0 comments on commit ad41208

Please sign in to comment.