Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge pull request #989 from RasaHQ/fallback_fix
Browse files Browse the repository at this point in the history
fixes #988: fix fallback for two action_listen in a row
  • Loading branch information
tmbo committed Sep 14, 2018
2 parents 394fcc9 + 88fadb5 commit 2569021
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 26 deletions.
2 changes: 1 addition & 1 deletion rasa_core/policies/embedding_policy.py
Expand Up @@ -21,7 +21,7 @@
from rasa_core.featurizers import (TrackerFeaturizer,
FullDialogueTrackerFeaturizer,
LabelTokenizerSingleStateFeaturizer)
from rasa_core.policies import Policy
from rasa_core.policies.policy import Policy

import tensorflow as tf
from rasa_core.policies.tf_utils import (TimeAttentionWrapper,
Expand Down
55 changes: 32 additions & 23 deletions rasa_core/policies/ensemble.py
Expand Up @@ -13,15 +13,18 @@
import numpy as np
from builtins import str
import typing
from typing import Text, Optional, Any, List, Dict
from typing import Text, Optional, Any, List, Dict, Tuple

import rasa_core
from rasa_core import utils, training, constants
from rasa_core.events import SlotSet, ActionExecuted, UserUttered
from rasa_core.events import SlotSet, ActionExecuted
from rasa_core.exceptions import UnsupportedDialogueModelError
from rasa_core.featurizers import MaxHistoryTrackerFeaturizer
from rasa_core.policies.fallback import FallbackPolicy
from rasa_core.policies.memoization import MemoizationPolicy, AugmentedMemoizationPolicy
from rasa_core.policies.memoization import (MemoizationPolicy,
AugmentedMemoizationPolicy)

from rasa_core.actions.action import ACTION_LISTEN_NAME

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -198,10 +201,11 @@ def continue_training(self, trackers, domain, **kwargs):

class SimplePolicyEnsemble(PolicyEnsemble):

def is_not_memo_policy(self, best_policy_name):
@staticmethod
def is_not_memo_policy(best_policy_name):
return not (best_policy_name.endswith("_" + MemoizationPolicy.__name__)
or best_policy_name.endswith(
"_" + AugmentedMemoizationPolicy.__name__))
"_" + AugmentedMemoizationPolicy.__name__))

def probabilities_using_best_policy(self, tracker, domain):
# type: (DialogueStateTracker, Domain) -> Tuple[List[float], Text]
Expand All @@ -216,27 +220,32 @@ def probabilities_using_best_policy(self, tracker, domain):
result = probabilities
best_policy_name = 'policy_{}_{}'.format(i, type(p).__name__)

policy_names = [type(p).__name__ for p in self.policies]

# Trigger the fallback policy when ActionListen is predicted after
# a user utterance. This is done on the condition that: a fallback
# policy is present, there was just a user message and the predicted
# action is action_listen by a policy other than the MemoizationPolicy
if (result.index(max_confidence) ==
domain.index_for_action(ACTION_LISTEN_NAME) and
tracker.latest_action_name == ACTION_LISTEN_NAME and
self.is_not_memo_policy(best_policy_name)):
# Trigger the fallback policy when ActionListen is predicted after
# a user utterance. This is done on the condition that:
# - a fallback policy is present,
# - there was just a user message, and
# - the predicted action is action_listen, predicted by a policy
#   other than the MemoizationPolicy

fallback_idx_policy = [(i, p) for i, p in enumerate(self.policies)
if isinstance(p, FallbackPolicy)]

if fallback_idx_policy:
fallback_idx, fallback_policy = fallback_idx_policy[0]

logger.debug("Action 'action_listen' was predicted after "
"a user message using {}. "
"Predicting fallback action: {}"
"".format(best_policy_name,
fallback_policy.fallback_action_name))

if FallbackPolicy.__name__ in policy_names:
idx = policy_names.index(FallbackPolicy.__name__)
fallback_policy = self.policies[idx]

if (result.index(max_confidence) == 0 and
self.is_not_memo_policy(best_policy_name)
and isinstance(tracker.events[-1], UserUttered)):
logger.debug("Action listen was predicted after a user message."
" Predicting fallback action: {}"
"".format(fallback_policy.fallback_action_name))
result = fallback_policy.fallback_scores(domain)

best_policy_name = 'policy_{}_{}'.format(
idx,
fallback_idx,
type(fallback_policy).__name__)

# normalize probabilities
Expand Down
2 changes: 1 addition & 1 deletion rasa_core/policies/keras_policy.py
Expand Up @@ -14,7 +14,7 @@
from typing import Any, List, Dict, Text, Optional, Tuple

from rasa_core import utils
from rasa_core.policies import Policy
from rasa_core.policies.policy import Policy
from rasa_core.featurizers import TrackerFeaturizer

if typing.TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion rasa_core/policies/sklearn_policy.py
Expand Up @@ -19,7 +19,7 @@
# noinspection PyProtectedMember
from sklearn.utils import shuffle as sklearn_shuffle

from rasa_core.policies import Policy
from rasa_core.policies.policy import Policy
from rasa_core.featurizers import \
TrackerFeaturizer, MaxHistoryTrackerFeaturizer

Expand Down

0 comments on commit 2569021

Please sign in to comment.