Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge pull request #1169 from twollnik/policy_cofiguration_via_yaml
Browse files Browse the repository at this point in the history
Policy cofiguration via yaml file
  • Loading branch information
tmbo committed Oct 25, 2018
2 parents fc6340d + 17225df commit d83693c
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 42 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -19,6 +19,8 @@ Added
- Command line interface for interactive learning now displays policy
confidence alongside the action name
- added action prediction confidence & policy to ``ActionExecuted`` event
- the Core policy configuration can now be set in a config.yaml file.
This makes training custom policies possible.
- both the date and the time at which a model was trained are now
included in the policy's metadata when it is persisted
- show visualization of conversation while doing interactive learning
Expand Down
5 changes: 5 additions & 0 deletions data/test_config/example_config.yaml
@@ -0,0 +1,5 @@
policies:
- name: MemoizationPolicy
max_history: 5
- name: tests.conftest.ExamplePolicy
example_arg: 10
41 changes: 39 additions & 2 deletions docs/policies.rst
Expand Up @@ -115,8 +115,45 @@ at every step in the conversation.
There are different policies to choose from, and you can include
multiple policies in a single :class:`rasa_core.agent.Agent`. At
every turn, the policy which predicts the next action with the
highest confidence will be used. You can pass a list of policies
when you create an agent:
highest confidence will be used.

.. _policy_file:

Configuring polices using a configuration file
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can set the policies you would like the Core model to use in a YAML file.

For example:

.. code-block:: yaml
policies:
- name: "KerasPolicy"
max_history: 5
- name: "MemoizationPolicy"
max_history: 5
- name: "FallbackPolicy"
nlu_threshold: 0.4
core_threshold: 0.3
fallback_action_name: "my_fallback_action"
- name: "path.to.your.policy.class"
arg1: "..."
Pass the YAML file's name to the train script using the ``--config``
argument (or just ``-c``). If no config.yaml is given, the policies
default to ``[KerasPolicy(), MemoizationPolicy(), FallbackPolicy()]``.

.. note::

Policies specified higher in the ``config.yaml`` will take
precedence over a policy specified lower if the confidences
are equal.

Configuring polices in code
^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can pass a list of policies when you create an agent:

.. code-block:: python
Expand Down
64 changes: 64 additions & 0 deletions rasa_core/config.py
@@ -0,0 +1,64 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from rasa_core.constants import (
DEFAULT_NLU_FALLBACK_THRESHOLD,
DEFAULT_CORE_FALLBACK_THRESHOLD, DEFAULT_FALLBACK_ACTION)
from rasa_core import utils
from rasa_core.policies import PolicyEnsemble


def load(config_file, fallback_args, max_history):
# type: Dict[Text, Any] -> List[Policy]
"""Load policy data stored in the specified file. fallback_args and
max_history are typically command line arguments. They take precedence
over the arguments specified in the config yaml.
"""

if config_file is None:
return PolicyEnsemble.default_policies(fallback_args, max_history)

config_data = utils.read_yaml_file(config_file)
config_data = handle_precedence_and_defaults(
config_data, fallback_args, max_history)

return PolicyEnsemble.from_dict(config_data)


def handle_precedence_and_defaults(config_data, fallback_args, max_history):
# type: Dict[Text, Any] -> Dict[Text, Any]

for policy in config_data.get('policies'):

if policy.get('name') == 'FallbackPolicy' and fallback_args is not None:
set_fallback_args(policy, fallback_args)

elif policy.get('name') in {'KerasPolicy', 'MemoizationPolicy'}:
set_arg(policy, "max_history", max_history, 3)

return config_data


def set_arg(data_dict, argument, value, default):

if value is not None:
data_dict[argument] = value
elif data_dict.get(argument) is None:
data_dict[argument] = default

return data_dict


def set_fallback_args(policy, fallback_args):

set_arg(policy, "nlu_threshold",
fallback_args.get("nlu_threshold"),
DEFAULT_NLU_FALLBACK_THRESHOLD)
set_arg(policy, "core_threshold",
fallback_args.get("core_threshold"),
DEFAULT_CORE_FALLBACK_THRESHOLD)
set_arg(policy, "fallback_action_name",
fallback_args.get("fallback_action_name"),
DEFAULT_FALLBACK_ACTION)
50 changes: 49 additions & 1 deletion rasa_core/policies/ensemble.py
Expand Up @@ -3,6 +3,8 @@
from __future__ import print_function
from __future__ import unicode_literals

import importlib
import io
import json
import logging
import os
Expand All @@ -17,9 +19,14 @@

import rasa_core
from rasa_core import utils, training, constants
from rasa_core.constants import (
DEFAULT_NLU_FALLBACK_THRESHOLD,
DEFAULT_CORE_FALLBACK_THRESHOLD, DEFAULT_FALLBACK_ACTION)
from rasa_core.events import SlotSet, ActionExecuted
from rasa_core.exceptions import UnsupportedDialogueModelError
from rasa_core.featurizers import MaxHistoryTrackerFeaturizer
from rasa_core.featurizers import (MaxHistoryTrackerFeaturizer,
BinarySingleStateFeaturizer)
from rasa_core.policies.keras_policy import KerasPolicy
from rasa_core.policies.fallback import FallbackPolicy
from rasa_core.policies.memoization import (MemoizationPolicy,
AugmentedMemoizationPolicy)
Expand Down Expand Up @@ -193,6 +200,47 @@ def load(cls, path):
ensemble = ensemble_cls(policies, fingerprints)
return ensemble

@classmethod
def from_dict(cls, dictionary):
# type: Dict[Text, Any] -> List[Policy]

policies = []

for policy in dictionary.get('policies', []):

policy_name = policy.pop('name')

if policy_name == 'KerasPolicy':
policy_object = KerasPolicy(MaxHistoryTrackerFeaturizer(
BinarySingleStateFeaturizer(),
max_history=policy.get('max_history', 3)))
constr_func = utils.class_from_module_path(policy_name)

policy_object = constr_func(**policy)
policies.append(policy_object)

return policies

@classmethod
def default_policies(cls, fallback_args, max_history):
# type: None -> List[Policy]
"""Load the default policy setup consisting of
FallbackPolicy, MemoizationPolicy and KerasPolicy."""

return [
FallbackPolicy(
fallback_args.get("nlu_threshold",
DEFAULT_NLU_FALLBACK_THRESHOLD),
fallback_args.get("core_threshold",
DEFAULT_CORE_FALLBACK_THRESHOLD),
fallback_args.get("fallback_action_name",
DEFAULT_FALLBACK_ACTION)),
MemoizationPolicy(
max_history=max_history),
KerasPolicy(
MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(),
max_history=max_history))]

def continue_training(self, trackers, domain, **kwargs):
# type: (List[DialogueStateTracker], Domain, Any) -> None

Expand Down
84 changes: 47 additions & 37 deletions rasa_core/train.py
Expand Up @@ -8,15 +8,14 @@
import argparse
import logging

from rasa_core import config
from rasa_core import utils
from rasa_core.agent import Agent
from rasa_core.constants import (
DEFAULT_NLU_FALLBACK_THRESHOLD,
DEFAULT_CORE_FALLBACK_THRESHOLD, DEFAULT_FALLBACK_ACTION)
from rasa_core.featurizers import (
MaxHistoryTrackerFeaturizer, BinarySingleStateFeaturizer)
from rasa_core.interpreter import NaturalLanguageInterpreter
from rasa_core.policies import FallbackPolicy
from rasa_core.policies.ensemble import PolicyEnsemble
from rasa_core.policies.keras_policy import KerasPolicy
from rasa_core.policies.memoization import MemoizationPolicy
from rasa_core.run import AvailableEndpoints
Expand All @@ -34,21 +33,15 @@ def create_argument_parser():
# either the user can pass in a story file, or the data will get
# downloaded from a url
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
'-s', '--stories',
type=str,
help="file or folder containing the training stories")
group.add_argument(
'--url',
type=str,
help="If supplied, downloads a story file from a URL and "
"trains on it. Fetches the data by sending a GET request "
"to the supplied URL.")
group.add_argument(
'--core',
default=None,
help="path to load a pre-trained model instead of training (for "
"interactive mode only)")

group = add_args_to_group(group)
parser = add_args_to_parser(parser)

utils.add_logging_option_arguments(parser)
return parser


def add_args_to_parser(parser):

parser.add_argument(
'-o', '--out',
Expand All @@ -68,7 +61,7 @@ def create_argument_parser():
parser.add_argument(
'--history',
type=int,
default=3,
default=None,
help="max history to use of a story")
parser.add_argument(
'--epochs',
Expand Down Expand Up @@ -126,32 +119,60 @@ def create_argument_parser():
parser.add_argument(
'--nlu_threshold',
type=float,
default=DEFAULT_NLU_FALLBACK_THRESHOLD,
default=None,
required=False,
help="If NLU prediction confidence is below threshold, fallback "
"will get triggered.")
parser.add_argument(
'--core_threshold',
type=float,
default=DEFAULT_CORE_FALLBACK_THRESHOLD,
default=None,
required=False,
help="If Core action prediction confidence is below the threshold "
"a fallback action will get triggered")
parser.add_argument(
'--fallback_action_name',
type=str,
default=DEFAULT_FALLBACK_ACTION,
default=None,
required=False,
help="When a fallback is triggered (e.g. because the ML prediction "
"is of low confidence) this is the name of tje action that "
"will get triggered instead.")

utils.add_logging_option_arguments(parser)
parser.add_argument(
'-c', '--config',
type=str,
required=False,
help="Policy specification yaml file."
)
return parser


def add_args_to_group(group):

group.add_argument(
'-s', '--stories',
type=str,
help="file or folder containing the training stories")
group.add_argument(
'--url',
type=str,
help="If supplied, downloads a story file from a URL and "
"trains on it. Fetches the data by sending a GET request "
"to the supplied URL.")
group.add_argument(
'--core',
default=None,
help="path to load a pre-trained model instead of training (for "
"interactive mode only)")
return group


def train_dialogue_model(domain_file, stories_file, output_path,
interpreter=None,
endpoints=AvailableEndpoints(),
max_history=None,
dump_flattened_stories=False,
policy_config=None,
kwargs=None):
if not kwargs:
kwargs = {}
Expand All @@ -161,19 +182,7 @@ def train_dialogue_model(domain_file, stories_file, output_path,
"core_threshold",
"fallback_action_name"})

policies = [
FallbackPolicy(
fallback_args.get("nlu_threshold",
DEFAULT_NLU_FALLBACK_THRESHOLD),
fallback_args.get("core_threshold",
DEFAULT_CORE_FALLBACK_THRESHOLD),
fallback_args.get("fallback_action_name",
DEFAULT_FALLBACK_ACTION)),
MemoizationPolicy(
max_history=max_history),
KerasPolicy(
MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(),
max_history=max_history))]
policies = config.load(policy_config, fallback_args, max_history)

agent = Agent(domain_file,
generator=endpoints.nlg,
Expand Down Expand Up @@ -248,6 +257,7 @@ def train_dialogue_model(domain_file, stories_file, output_path,
_endpoints,
cmdline_args.history,
cmdline_args.dump_stories,
cmdline_args.config,
additional_arguments)

if cmdline_args.interactive:
Expand Down
6 changes: 5 additions & 1 deletion rasa_core/utils.py
Expand Up @@ -81,13 +81,17 @@ def class_from_module_path(module_path):
import importlib

# load the module, will raise ImportError if module cannot be loaded
from rasa_core.policies.keras_policy import KerasPolicy
from rasa_core.policies.fallback import FallbackPolicy
from rasa_core.policies.memoization import MemoizationPolicy

if "." in module_path:
module_name, _, class_name = module_path.rpartition('.')
m = importlib.import_module(module_name)
# get the class, will raise AttributeError if class cannot be found
return getattr(m, class_name)
else:
return globals()[module_path]
return globals().get(module_path, locals().get(module_path))


def module_path_from_instance(inst):
Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Expand Up @@ -20,7 +20,7 @@
from rasa_core.nlg import TemplatedNaturalLanguageGenerator
from rasa_core.policies.ensemble import SimplePolicyEnsemble, PolicyEnsemble
from rasa_core.policies.memoization import (
MemoizationPolicy, AugmentedMemoizationPolicy)
Policy, MemoizationPolicy, AugmentedMemoizationPolicy)
from rasa_core.processor import MessageProcessor
from rasa_core.slots import Slot
from rasa_core.tracker_store import InMemoryTrackerStore
Expand All @@ -47,6 +47,12 @@ def as_feature(self):
return [0.5]


class ExamplePolicy(Policy):

def __init__(self, example_arg):
pass


@pytest.fixture(scope="session")
def default_domain():
return Domain.load(DEFAULT_DOMAIN_PATH)
Expand Down

0 comments on commit d83693c

Please sign in to comment.