Skip to content

Commit

Permalink
Merge pull request #7749 from RasaHQ/fix_no_user_pred
Browse files Browse the repository at this point in the history
fix e2e overriding loop prediction
  • Loading branch information
rasabot committed Jan 27, 2021
2 parents 8e68e39 + b60b2ea commit 9379a22
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog/7749.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix the bug when `RulePolicy` handling loop predictions are overwritten by e2e `TEDPolicy`.
24 changes: 20 additions & 4 deletions rasa/core/policies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,16 +578,30 @@ def _pick_best_policy(

form_confidence = None
form_policy_name = None
# End-to-end predictions overrule all other predictions.
use_only_end_to_end = any(
# different type of predictions have different priorities
# No user predictions overrule all other predictions.
is_no_user_prediction = any(
prediction.is_no_user_prediction for prediction in predictions.values()
)
# End-to-end predictions overrule all other predictions based on user input.
is_end_to_end_prediction = any(
prediction.is_end_to_end_prediction for prediction in predictions.values()
)
policy_events = []

policy_events = []
for policy_name, prediction in predictions.items():
policy_events += prediction.events

if prediction.is_end_to_end_prediction != use_only_end_to_end:
# No user predictions (e.g. happy path loop predictions)
# overrule all other predictions.
if prediction.is_no_user_prediction != is_no_user_prediction:
continue

# End-to-end predictions overrule all other predictions based on user input.
if (
not is_no_user_prediction
and prediction.is_end_to_end_prediction != is_end_to_end_prediction
):
continue

confidence = (prediction.max_confidence, prediction.policy_priority)
Expand Down Expand Up @@ -617,6 +631,8 @@ def _pick_best_policy(
best_prediction.policy_priority,
policy_events,
is_end_to_end_prediction=best_prediction.is_end_to_end_prediction,
is_no_user_prediction=best_prediction.is_no_user_prediction,
diagnostic_data=best_prediction.diagnostic_data,
)

def _best_policy_prediction(
Expand Down
11 changes: 10 additions & 1 deletion rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def _prediction(
events: Optional[List[Event]] = None,
optional_events: Optional[List[Event]] = None,
is_end_to_end_prediction: bool = False,
is_no_user_prediction: bool = False,
diagnostic_data: Optional[Dict[Text, Any]] = None,
) -> "PolicyPrediction":
return PolicyPrediction(
Expand All @@ -245,6 +246,7 @@ def _prediction(
events,
optional_events,
is_end_to_end_prediction,
is_no_user_prediction,
diagnostic_data,
)

Expand Down Expand Up @@ -343,7 +345,8 @@ def _default_predictions(domain: Domain) -> List[float]:
"""
return [0.0] * domain.num_actions

def format_tracker_states(self, states: List[Dict]) -> Text:
@staticmethod
def format_tracker_states(states: List[Dict]) -> Text:
"""Format tracker states to human readable format on debug log.
Args:
Expand Down Expand Up @@ -402,6 +405,7 @@ def __init__(
events: Optional[List[Event]] = None,
optional_events: Optional[List[Event]] = None,
is_end_to_end_prediction: bool = False,
is_no_user_prediction: bool = False,
diagnostic_data: Optional[Dict[Text, Any]] = None,
) -> None:
"""Creates a `PolicyPrediction`.
Expand All @@ -420,6 +424,9 @@ def __init__(
you return as they can potentially influence the conversation flow.
is_end_to_end_prediction: `True` if the prediction used the text of the
user message instead of the intent.
is_no_user_prediction: `True` if the prediction uses neither the text
of the user message nor the intent. This is for the example the case
for happy loop paths.
diagnostic_data: Intermediate results or other information that is not
necessary for Rasa to function, but intended for debugging and
fine-tuning purposes.
Expand All @@ -430,6 +437,7 @@ def __init__(
self.events = events or []
self.optional_events = optional_events or []
self.is_end_to_end_prediction = is_end_to_end_prediction
self.is_no_user_prediction = is_no_user_prediction
self.diagnostic_data = diagnostic_data or {}

@staticmethod
Expand Down Expand Up @@ -473,6 +481,7 @@ def __eq__(self, other: Any) -> bool:
and self.events == other.events
and self.optional_events == other.events
and self.is_end_to_end_prediction == other.is_end_to_end_prediction
and self.is_no_user_prediction == other.is_no_user_prediction
# We do not compare `diagnostic_data`, because it has no effect on the
# action prediction.
)
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,8 @@ def predict_action_probabilities(
# this prediction doesn't use user input
# and happy user input anyhow should be ignored during featurization
return self._prediction(
self._prediction_result(loop_happy_path_action_name, tracker, domain)
self._prediction_result(loop_happy_path_action_name, tracker, domain),
is_no_user_prediction=True,
)

# predict rules from text first
Expand Down
20 changes: 14 additions & 6 deletions tests/core/policies/test_rule_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Text
from typing import Text, Optional

import pytest

Expand Down Expand Up @@ -767,11 +767,15 @@ def assert_predicted_action(
domain: Domain,
expected_action_name: Text,
confidence: float = 1.0,
is_end_to_end_prediction: bool = False,
is_no_user_prediction: bool = False,
) -> None:
assert prediction.max_confidence == confidence
index_of_predicted_action = prediction.max_confidence_index
prediction_action_name = domain.action_names_or_texts[index_of_predicted_action]
assert prediction_action_name == expected_action_name
assert prediction.is_end_to_end_prediction == is_end_to_end_prediction
assert prediction.is_no_user_prediction == is_no_user_prediction


async def test_predict_form_action_if_in_form():
Expand Down Expand Up @@ -813,7 +817,7 @@ async def test_predict_form_action_if_in_form():
prediction = policy.predict_action_probabilities(
form_conversation, domain, RegexInterpreter()
)
assert_predicted_action(prediction, domain, form_name)
assert_predicted_action(prediction, domain, form_name, is_no_user_prediction=True)


async def test_predict_loop_action_if_in_loop_but_there_is_e2e_rule():
Expand Down Expand Up @@ -866,7 +870,7 @@ async def test_predict_loop_action_if_in_loop_but_there_is_e2e_rule():
prediction = policy.predict_action_probabilities(
loop_conversation, domain, RegexInterpreter()
)
assert_predicted_action(prediction, domain, loop_name)
assert_predicted_action(prediction, domain, loop_name, is_no_user_prediction=True)


async def test_predict_form_action_if_multiple_turns():
Expand Down Expand Up @@ -915,7 +919,7 @@ async def test_predict_form_action_if_multiple_turns():
prediction = policy.predict_action_probabilities(
form_conversation, domain, RegexInterpreter()
)
assert_predicted_action(prediction, domain, form_name)
assert_predicted_action(prediction, domain, form_name, is_no_user_prediction=True)


async def test_predict_action_listen_after_form():
Expand Down Expand Up @@ -959,7 +963,9 @@ async def test_predict_action_listen_after_form():
prediction = policy.predict_action_probabilities(
form_conversation, domain, RegexInterpreter()
)
assert_predicted_action(prediction, domain, ACTION_LISTEN_NAME)
assert_predicted_action(
prediction, domain, ACTION_LISTEN_NAME, is_no_user_prediction=True
)


async def test_dont_predict_form_if_already_finished():
Expand Down Expand Up @@ -1853,7 +1859,9 @@ def test_e2e_beats_default_actions(intent_name: Text):
prediction = policy.predict_action_probabilities(
new_conversation, domain, RegexInterpreter()
)
assert_predicted_action(prediction, domain, UTTER_GREET_ACTION)
assert_predicted_action(
prediction, domain, UTTER_GREET_ACTION, is_end_to_end_prediction=True
)


@pytest.mark.parametrize(
Expand Down
31 changes: 31 additions & 0 deletions tests/core/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
predict_index: Optional[int] = None,
confidence: float = 1,
is_end_to_end_prediction: bool = False,
is_no_user_prediction: bool = False,
events: Optional[List[Event]] = None,
optional_events: Optional[List[Event]] = None,
**kwargs: Any,
Expand All @@ -124,6 +125,7 @@ def __init__(
self.predict_index = predict_index
self.confidence = confidence
self.is_end_to_end_prediction = is_end_to_end_prediction
self.is_no_user_prediction = is_no_user_prediction
self.events = events or []
self.optional_events = optional_events or []

Expand Down Expand Up @@ -158,6 +160,7 @@ def predict_action_probabilities(
self.__class__.__name__,
policy_priority=self.priority,
is_end_to_end_prediction=self.is_end_to_end_prediction,
is_no_user_prediction=self.is_no_user_prediction,
events=self.events,
optional_events=self.optional_events,
)
Expand Down Expand Up @@ -527,6 +530,34 @@ def test_end_to_end_prediction_supersedes_others(default_domain: Domain):
assert prediction.policy_name == f"policy_1_{ConstantPolicy.__name__}"


def test_no_user_prediction_supersedes_others(default_domain: Domain):
expected_action_index = 2
expected_confidence = 0.5
ensemble = SimplePolicyEnsemble(
[
ConstantPolicy(priority=100, predict_index=0),
ConstantPolicy(priority=1, predict_index=1, is_end_to_end_prediction=True),
ConstantPolicy(
priority=1,
predict_index=expected_action_index,
confidence=expected_confidence,
is_no_user_prediction=True,
),
]
)
tracker = DialogueStateTracker.from_events("test", evts=[])

prediction = ensemble.probabilities_using_best_policy(
tracker, default_domain, RegexInterpreter()
)

assert prediction.max_confidence == expected_confidence
assert prediction.max_confidence_index == expected_action_index
assert prediction.policy_name == f"policy_2_{ConstantPolicy.__name__}"
assert prediction.is_no_user_prediction
assert not prediction.is_end_to_end_prediction


def test_prediction_applies_must_have_policy_events(default_domain: Domain):
must_have_events = [ActionExecuted("my action")]

Expand Down

0 comments on commit 9379a22

Please sign in to comment.