From e54880d404055cb74a40d0912794d2f4dbaa66bb Mon Sep 17 00:00:00 2001 From: kedz Date: Wed, 29 Dec 2021 14:29:46 -0500 Subject: [PATCH 01/12] AugmentedMemoizationPolicy truncates tracker on action_listen instead of arbitrary action. Added corresponding tests. --- rasa/core/policies/memoization.py | 11 +++- tests/core/policies/test_memoization.py | 86 +++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 977bc254ddac..1175ed328844 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -22,6 +22,7 @@ from rasa.shared.core.generator import TrackerWithCachedStates from rasa.shared.utils.io import is_logging_disabled from rasa.core.constants import MEMOIZATION_POLICY_PRIORITY +from rasa.shared.core.constants import ACTION_LISTEN_NAME logger = logging.getLogger(__name__) @@ -391,12 +392,16 @@ def _get_max_applied_events_for_max_history( ) -> Optional[int]: """Computes the number of events in the tracker that correspond to max_history. + To ensure that the last user utterance is correctly included in the prediction + states, return the index of the most recent `action_listen` event occuring + before the tracker would be truncated according to the value of `max_history`. + Args: tracker: Some tracker holding the events max_history: The number of actions to count Returns: - The number of actions, as counted from the end of the event list, that should + The number of events, as counted from the end of the event list, that should be taken into accout according to the `max_history` setting. If all events should be taken into account, the return value is `None`. """ @@ -408,8 +413,8 @@ def _get_max_applied_events_for_max_history( num_events += 1 if isinstance(event, ActionExecuted): num_actions += 1 - if num_actions > max_history: - return num_events + if num_actions > max_history and event.action_name == ACTION_LISTEN_NAME: + return num_events return None diff --git a/tests/core/policies/test_memoization.py b/tests/core/policies/test_memoization.py index b2578d1acc55..7e6dd242a420 100644 --- a/tests/core/policies/test_memoization.py +++ b/tests/core/policies/test_memoization.py @@ -157,3 +157,89 @@ def test_augmented_prediction(self, max_history): ] == UTTER_BYE_ACTION ) + + @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) + def test_augmented_prediction_long_story(self, max_history): + policy = self.create_policy( + featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 + ) + + GREET_INTENT_NAME = "greet" + UTTER_GREET_ACTION = "utter_greet" + UTTER_ACTION_1 = "utter_1" + UTTER_ACTION_2 = "utter_2" + UTTER_ACTION_3 = "utter_3" + UTTER_ACTION_4 = "utter_4" + UTTER_ACTION_5 = "utter_5" + UTTER_BYE_ACTION = "utter_goodbye" + domain = Domain.from_yaml( + f""" + intents: + - {GREET_INTENT_NAME} + actions: + - {UTTER_GREET_ACTION} + - {UTTER_ACTION_1} + - {UTTER_ACTION_2} + - {UTTER_ACTION_3} + - {UTTER_ACTION_4} + - {UTTER_ACTION_5} + - {UTTER_BYE_ACTION} + slots: + slot_1: + type: bool + initial_value: true + slot_2: + type: bool + slot_3: + type: bool + """ + ) + training_story = TrackerWithCachedStates.from_events( + "training story", + [ + ActionExecuted(UTTER_GREET_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + ActionExecuted(UTTER_ACTION_4), + ActionExecuted(UTTER_ACTION_5), + ActionExecuted(UTTER_BYE_ACTION), + ], + domain=domain, + slots=domain.slots, + ) + test_story = TrackerWithCachedStates.from_events( + "test story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + SlotSet("slot_1", False), + ActionExecuted(UTTER_GREET_ACTION), + ActionExecuted(UTTER_GREET_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + SlotSet("slot_2", True), + ActionExecuted(UTTER_GREET_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + ActionExecuted(UTTER_ACTION_4), + ActionExecuted(UTTER_ACTION_5), + # ActionExecuted(UTTER_BYE_ACTION), + ], + domain=domain, + slots=domain.slots, + ) + interpreter = RegexInterpreter() + policy.train([training_story], domain, interpreter) + prediction = policy.predict_action_probabilities( + test_story, domain, interpreter + ) + assert ( + domain.action_names_or_texts[ + prediction.probabilities.index(max(prediction.probabilities)) + ] + == UTTER_BYE_ACTION + ) From b6197108b21df5c611e058f9d8fdc77bfb4c8a1a Mon Sep 17 00:00:00 2001 From: kedz Date: Wed, 29 Dec 2021 14:47:34 -0500 Subject: [PATCH 02/12] Add changelog entry. --- changelog/10606.bugfix.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/10606.bugfix.md diff --git a/changelog/10606.bugfix.md b/changelog/10606.bugfix.md new file mode 100644 index 000000000000..ebca3da91a52 --- /dev/null +++ b/changelog/10606.bugfix.md @@ -0,0 +1 @@ +Fix `max_history` truncation in `AugmentedMemoizationPolicy` to preserve the most recent `UserUttered` event. From b2cbf2960e52d46ad74904666ab045b4ac4d1419 Mon Sep 17 00:00:00 2001 From: kedz Date: Sun, 2 Jan 2022 22:58:24 -0500 Subject: [PATCH 03/12] Fix broken CI id. --- .github/workflows/ci-model-regression.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-model-regression.yml b/.github/workflows/ci-model-regression.yml index ca874cf62625..3ef275f175c1 100644 --- a/.github/workflows/ci-model-regression.yml +++ b/.github/workflows/ci-model-regression.yml @@ -643,7 +643,7 @@ jobs: # Get ID of the schedule workflow SCHEDULE_ID=$(curl -X GET -s -H 'Authorization: token ${{ secrets.GITHUB_TOKEN }}' -H "Accept: application/vnd.github.v3+json" \ "https://api.github.com/repos/${{ github.repository }}/actions/workflows" \ - | jq -r '.workflows[] | select(.name == "${{ github.workflow }}") | select(.path | test("schedule")) | .id') + | jq -r '.workflows[] | select(.name == "CI - Model Regression on schedule") | select(.path | test("schedule")) | .id') ARTIFACT_URL=$(curl -s -H 'Authorization: token ${{ secrets.GITHUB_TOKEN }}' -H "Accept: application/vnd.github.v3+json" \ "https://api.github.com/repos/${{ github.repository }}/actions/workflows/${SCHEDULE_ID}/runs?event=schedule&status=completed&branch=main&per_page=1" | jq -r .workflow_runs[0].artifacts_url) From 4bcf344a44ef494239e404a776c6775f2af9ba2b Mon Sep 17 00:00:00 2001 From: Chris Kedzie Date: Mon, 3 Jan 2022 11:57:42 -0500 Subject: [PATCH 04/12] Update changelog/10606.bugfix.md Co-authored-by: Sam Sucik --- changelog/10606.bugfix.md | 1 + 1 file changed, 1 insertion(+) diff --git a/changelog/10606.bugfix.md b/changelog/10606.bugfix.md index ebca3da91a52..4f0b55ee6df9 100644 --- a/changelog/10606.bugfix.md +++ b/changelog/10606.bugfix.md @@ -1 +1,2 @@ Fix `max_history` truncation in `AugmentedMemoizationPolicy` to preserve the most recent `UserUttered` event. +Previously, `AugmentedMemoizationPolicy` failed to predict next action after long sequences of actions (longer than `max_history`) because the policy did not have access to the most recent user message. From a9d885f20add8f2bdcc5b65b6863555dd68c9fc5 Mon Sep 17 00:00:00 2001 From: kedz Date: Mon, 3 Jan 2022 14:04:52 -0500 Subject: [PATCH 05/12] Added tests and fixed edge case. --- rasa/core/policies/memoization.py | 37 ++- tests/core/policies/test_memoization.py | 311 +++++++++++++++++++++++- 2 files changed, 323 insertions(+), 25 deletions(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 1175ed328844..41372698f8e2 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -11,7 +11,7 @@ import rasa.shared.utils.io from rasa.shared.constants import DOCS_URL_POLICIES from rasa.shared.core.domain import State, Domain -from rasa.shared.core.events import ActionExecuted +from rasa.shared.core.events import ActionExecuted, UserUttered from rasa.core.featurizers.tracker_featurizers import ( TrackerFeaturizer, MaxHistoryTrackerFeaturizer, @@ -290,26 +290,37 @@ class AugmentedMemoizationPolicy(MemoizationPolicy): def _back_to_the_future( tracker: DialogueStateTracker, again: bool = False ) -> Optional[DialogueStateTracker]: - """Send Marty to the past to get - the new featurization for the future""" + """Truncates the tracker to the next `ActionExecuted` or `UserUttered` event. - idx_of_first_action = None - idx_of_second_action = None + Args: + tracker: The tracker to truncate. + again: When true, truncate tracker at the second action or + user utterance. Otherwise truncate to the firt action or + user utterance. + + Returns: + The truncated tracker if there were actions or user utterances + present. If none are found, returns `None`. + """ + idx_of_first_action_or_user = None + idx_of_second_action_or_user = None applied_events = tracker.applied_events() - # we need to find second executed action + # We need to find the second `ActionExecuted` or `UserUttered` event. for e_i, event in enumerate(applied_events): - # find second ActionExecuted - if isinstance(event, ActionExecuted): - if idx_of_first_action is None: - idx_of_first_action = e_i + if isinstance(event, ActionExecuted) or isinstance(event, UserUttered): + if idx_of_first_action_or_user is None: + idx_of_first_action_or_user = e_i else: - idx_of_second_action = e_i + idx_of_second_action_or_user = e_i break - # use first action, if we went first time and second action, if we went again - idx_to_use = idx_of_second_action if again else idx_of_first_action + # use first action/user utterance, if we went first time and + # second action/user utterance, if we went again + idx_to_use = ( + idx_of_second_action_or_user if again else idx_of_first_action_or_user + ) if idx_to_use is None: return None diff --git a/tests/core/policies/test_memoization.py b/tests/core/policies/test_memoization.py index 7e6dd242a420..4c06e8b6f11f 100644 --- a/tests/core/policies/test_memoization.py +++ b/tests/core/policies/test_memoization.py @@ -13,6 +13,7 @@ UserUttered, SlotSet, ) +from rasa.shared.core.constants import ACTION_LISTEN_NAME from rasa.shared.nlu.interpreter import RegexInterpreter @@ -89,7 +90,76 @@ def create_policy( return AugmentedMemoizationPolicy(featurizer=featurizer, priority=priority) @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) - def test_augmented_prediction(self, max_history): + def test_augmented_prediction_when_starts_with_intent(self, max_history): + policy = self.create_policy( + featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 + ) + + GREET_INTENT_NAME = "greet" + UTTER_GREET_ACTION = "utter_greet" + UTTER_BYE_ACTION = "utter_goodbye" + domain = Domain.from_yaml( + f""" + intents: + - {GREET_INTENT_NAME} + actions: + - {UTTER_GREET_ACTION} + - {UTTER_BYE_ACTION} + slots: + slot_1: + type: bool + initial_value: true + slot_2: + type: bool + slot_3: + type: bool + """ + ) + training_story = TrackerWithCachedStates.from_events( + "training story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + SlotSet("slot_3", True), + ActionExecuted(UTTER_BYE_ACTION), + ], + domain=domain, + slots=domain.slots, + ) + test_story = TrackerWithCachedStates.from_events( + "test story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + SlotSet("slot_1", False), + ActionExecuted(UTTER_GREET_ACTION), + ActionExecuted(UTTER_GREET_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + SlotSet("slot_2", True), + ActionExecuted(UTTER_GREET_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + SlotSet("slot_3", True), + # ActionExecuted(UTTER_BYE_ACTION), + ], + domain=domain, + slots=domain.slots, + ) + interpreter = RegexInterpreter() + policy.train([training_story], domain, interpreter) + prediction = policy.predict_action_probabilities( + test_story, domain, interpreter + ) + assert ( + domain.action_names_or_texts[ + prediction.probabilities.index(max(prediction.probabilities)) + ] + == UTTER_BYE_ACTION + ) + + @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) + def test_augmented_prediction_when_starts_with_action(self, max_history): policy = self.create_policy( featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 ) @@ -159,7 +229,11 @@ def test_augmented_prediction(self, max_history): ) @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) - def test_augmented_prediction_long_story(self, max_history): + def test_augmented_prediction_across_max_history_actions(self, max_history): + """Tests that the last user utterance is preserved in action states + even when the utterance occurs prior to `max_history` actions in the + past. + """ policy = self.create_policy( featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 ) @@ -197,7 +271,6 @@ def test_augmented_prediction_long_story(self, max_history): training_story = TrackerWithCachedStates.from_events( "training story", [ - ActionExecuted(UTTER_GREET_ACTION), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -212,15 +285,6 @@ def test_augmented_prediction_long_story(self, max_history): test_story = TrackerWithCachedStates.from_events( "test story", [ - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_GREET_ACTION), - SlotSet("slot_1", False), - ActionExecuted(UTTER_GREET_ACTION), - ActionExecuted(UTTER_GREET_ACTION), - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_GREET_ACTION), - SlotSet("slot_2", True), - ActionExecuted(UTTER_GREET_ACTION), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -243,3 +307,226 @@ def test_augmented_prediction_long_story(self, max_history): ] == UTTER_BYE_ACTION ) + + @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) + def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_history): + """Tests that only the most recent user utterance propagates to state + creation of following actions. + """ + policy = self.create_policy( + featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 + ) + + GREET_INTENT_NAME = "greet" + GOODBYE_INTENT_NAME = "goodbye" + UTTER_GREET_ACTION = "utter_greet" + UTTER_ACTION_1 = "utter_1" + UTTER_ACTION_2 = "utter_2" + UTTER_ACTION_3 = "utter_3" + UTTER_ACTION_4 = "utter_4" + UTTER_ACTION_5 = "utter_5" + UTTER_BYE_ACTION = "utter_goodbye" + domain = Domain.from_yaml( + f""" + intents: + - {GREET_INTENT_NAME} + - {GOODBYE_INTENT_NAME} + actions: + - {UTTER_GREET_ACTION} + - {UTTER_ACTION_1} + - {UTTER_ACTION_2} + - {UTTER_ACTION_3} + - {UTTER_ACTION_4} + - {UTTER_ACTION_5} + - {UTTER_BYE_ACTION} + slots: + slot_1: + type: bool + initial_value: true + slot_2: + type: bool + slot_3: + type: bool + """ + ) + training_story = TrackerWithCachedStates.from_events( + "training story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + ActionExecuted(UTTER_ACTION_4), + ActionExecuted(UTTER_ACTION_5), + ActionExecuted(UTTER_BYE_ACTION), + ], + domain=domain, + slots=domain.slots, + ) + interpreter = RegexInterpreter() + policy.train([training_story], domain, interpreter) + + test_story1 = TrackerWithCachedStates.from_events( + "test story", + [ + UserUttered(intent={"name": GOODBYE_INTENT_NAME}), + SlotSet("slot_1", False), + ActionExecuted(UTTER_BYE_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + ActionExecuted(UTTER_ACTION_4), + ActionExecuted(UTTER_ACTION_5), + # ActionExecuted(UTTER_BYE_ACTION), + ], + domain=domain, + slots=domain.slots, + ) + prediction1 = policy.predict_action_probabilities( + test_story1, domain, interpreter + ) + assert ( + domain.action_names_or_texts[ + prediction1.probabilities.index(max(prediction1.probabilities)) + ] + == UTTER_BYE_ACTION + ) + + test_story2_no_match_expected = TrackerWithCachedStates.from_events( + "test story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + SlotSet("slot_1", False), + ActionExecuted(UTTER_BYE_ACTION), + UserUttered(intent={"name": GOODBYE_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + ActionExecuted(UTTER_ACTION_4), + ActionExecuted(UTTER_ACTION_5), + # ActionExecuted(ACTION_LISTEN_NAME), + ], + domain=domain, + slots=domain.slots, + ) + + prediction2 = policy.predict_action_probabilities( + test_story2_no_match_expected, domain, interpreter + ) + assert ( + domain.action_names_or_texts[ + prediction2.probabilities.index(max(prediction2.probabilities)) + ] + == ACTION_LISTEN_NAME + ) + + @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) + def test_aug_pred_connects_different_memoizations(self, max_history): + """Tests memoization works for a memoized state sequence that starts + with a user utterance that leads to memoized state that does not + have user utterance information. + """ + policy = self.create_policy( + featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 + ) + + GREET_INTENT_NAME = "greet" + GOODBYE_INTENT_NAME = "goodbye" + UTTER_GREET_ACTION = "utter_greet" + UTTER_ACTION_1 = "utter_1" + UTTER_ACTION_2 = "utter_2" + UTTER_ACTION_3 = "utter_3" + UTTER_ACTION_4 = "utter_4" + UTTER_BYE_ACTION = "utter_goodbye" + domain = Domain.from_yaml( + f""" + intents: + - {GREET_INTENT_NAME} + - {GOODBYE_INTENT_NAME} + actions: + - {UTTER_GREET_ACTION} + - {UTTER_ACTION_1} + - {UTTER_ACTION_2} + - {UTTER_ACTION_3} + - {UTTER_ACTION_4} + slots: + slot_1: + type: bool + initial_value: true + slot_2: + type: bool + slot_3: + type: bool + """ + ) + training_story1 = TrackerWithCachedStates.from_events( + "training story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + ], + domain=domain, + slots=domain.slots, + ) + training_story2 = TrackerWithCachedStates.from_events( + "training story", + [ActionExecuted(UTTER_ACTION_3), ActionExecuted(UTTER_ACTION_4),], + domain=domain, + slots=domain.slots, + ) + + interpreter = RegexInterpreter() + policy.train([training_story1, training_story2], domain, interpreter) + + test_story1 = TrackerWithCachedStates.from_events( + "test story", + [ + UserUttered(intent={"name": GOODBYE_INTENT_NAME}), + SlotSet("slot_1", False), + ActionExecuted(UTTER_BYE_ACTION), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + # ActionExecuted(UTTER_ACTION_3), + ], + domain=domain, + slots=domain.slots, + ) + prediction1 = policy.predict_action_probabilities( + test_story1, domain, interpreter + ) + assert ( + domain.action_names_or_texts[ + prediction1.probabilities.index(max(prediction1.probabilities)) + ] + == UTTER_ACTION_3 + ) + + test_story2 = TrackerWithCachedStates.from_events( + "test story", + [ + UserUttered(intent={"name": GREET_INTENT_NAME}), + SlotSet("slot_1", False), + ActionExecuted(UTTER_BYE_ACTION), + UserUttered(intent={"name": GOODBYE_INTENT_NAME}), + ActionExecuted(UTTER_ACTION_1), + ActionExecuted(UTTER_ACTION_2), + ActionExecuted(UTTER_ACTION_3), + # ActionExecuted(UTTER_ACTION_4), + ], + domain=domain, + slots=domain.slots, + ) + + prediction2 = policy.predict_action_probabilities( + test_story2, domain, interpreter + ) + assert ( + domain.action_names_or_texts[ + prediction2.probabilities.index(max(prediction2.probabilities)) + ] + == UTTER_ACTION_4 + ) From 163bb9e4d3eb2e6c083e1dc999bf6c7eada775f1 Mon Sep 17 00:00:00 2001 From: kedz Date: Mon, 10 Jan 2022 13:13:27 -0500 Subject: [PATCH 06/12] Made test cases more realistic. Renamed some functions. Rewrote some doc strings. --- rasa/core/policies/memoization.py | 63 ++++++------ tests/core/policies/test_memoization.py | 130 +++++------------------- 2 files changed, 59 insertions(+), 134 deletions(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 41372698f8e2..2820671de808 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -11,7 +11,7 @@ import rasa.shared.utils.io from rasa.shared.constants import DOCS_URL_POLICIES from rasa.shared.core.domain import State, Domain -from rasa.shared.core.events import ActionExecuted, UserUttered +from rasa.shared.core.events import ActionExecuted from rasa.core.featurizers.tracker_featurizers import ( TrackerFeaturizer, MaxHistoryTrackerFeaturizer, @@ -160,7 +160,7 @@ def _create_feature_key(self, states: List[State]) -> Text: # represented as dictionaries have the same json strings # quotes are removed for aesthetic reasons feature_str = json.dumps(states, sort_keys=True).replace('"', "") - if self.ENABLE_FEATURE_STRING_COMPRESSION: + if False: # self.ENABLE_FEATURE_STRING_COMPRESSION: compressed = zlib.compress( bytes(feature_str, rasa.shared.utils.io.DEFAULT_ENCODING) ) @@ -287,40 +287,37 @@ class AugmentedMemoizationPolicy(MemoizationPolicy): """ @staticmethod - def _back_to_the_future( + def _truncate_leading_events_until_action_executed( tracker: DialogueStateTracker, again: bool = False ) -> Optional[DialogueStateTracker]: - """Truncates the tracker to the next `ActionExecuted` or `UserUttered` event. + """Truncates the tracker to begin at the next `ActionExecuted` event. Args: tracker: The tracker to truncate. - again: When true, truncate tracker at the second action or - user utterance. Otherwise truncate to the firt action or - user utterance. + again: When true, truncate tracker at the second action. + Otherwise truncate to the first action. Returns: - The truncated tracker if there were actions or user utterances - present. If none are found, returns `None`. + The truncated tracker if there were actions present. + If none are found, returns `None`. """ - idx_of_first_action_or_user = None - idx_of_second_action_or_user = None + idx_of_first_action = None + idx_of_second_action = None applied_events = tracker.applied_events() # We need to find the second `ActionExecuted` or `UserUttered` event. for e_i, event in enumerate(applied_events): - if isinstance(event, ActionExecuted) or isinstance(event, UserUttered): - if idx_of_first_action_or_user is None: - idx_of_first_action_or_user = e_i + if isinstance(event, ActionExecuted): + if idx_of_first_action is None: + idx_of_first_action = e_i else: - idx_of_second_action_or_user = e_i + idx_of_second_action = e_i break # use first action/user utterance, if we went first time and # second action/user utterance, if we went again - idx_to_use = ( - idx_of_second_action_or_user if again else idx_of_first_action_or_user - ) + idx_to_use = idx_of_second_action if again else idx_of_first_action if idx_to_use is None: return None @@ -329,19 +326,19 @@ def _back_to_the_future( if not events: return None - mcfly_tracker = tracker.init_copy() + truncated_tracker = tracker.init_copy() for e in events: - mcfly_tracker.update(e) + truncated_tracker.update(e) - return mcfly_tracker + return truncated_tracker - def _recall_using_delorean( + def _recall_using_truncation( self, old_states: List[State], tracker: DialogueStateTracker, domain: Domain, ) -> Optional[Text]: - """Applies to the future idea to change the past and get the new future. + """Attempts to match memorized states to progressively shorter trackers. - Recursively go to the past to correctly forget slots, - and then back to the future to recall. + This matching will iteratively remove prior slot setting events and + other actions, looking for the first matching memorized state sequence. Args: old_states: List of states. @@ -354,10 +351,12 @@ def _recall_using_delorean( logger.debug("Launch DeLorean...") # Truncate the tracker based on `max_history` - mcfly_tracker = _trim_tracker_by_max_history(tracker, self.max_history) - mcfly_tracker = self._back_to_the_future(mcfly_tracker) - while mcfly_tracker is not None: - states = self._prediction_states(mcfly_tracker, domain,) + truncated_tracker = _trim_tracker_by_max_history(tracker, self.max_history) + truncated_tracker = self._truncate_leading_events_until_action_executed( + truncated_tracker + ) + while truncated_tracker is not None: + states = self._prediction_states(truncated_tracker, domain,) if old_states != states: # check if we like new futures @@ -368,7 +367,9 @@ def _recall_using_delorean( old_states = states # go back again - mcfly_tracker = self._back_to_the_future(mcfly_tracker, again=True) + truncated_tracker = self._truncate_leading_events_until_action_executed( + truncated_tracker, again=True + ) # No match found logger.debug(f"Current tracker state {old_states}") @@ -393,7 +394,7 @@ def recall( predicted_action_name = self._recall_states(states) if predicted_action_name is None: # let's try a different method to recall that tracker - return self._recall_using_delorean(states, tracker, domain,) + return self._recall_using_truncation(states, tracker, domain,) else: return predicted_action_name diff --git a/tests/core/policies/test_memoization.py b/tests/core/policies/test_memoization.py index 4c06e8b6f11f..150a1ed25e46 100644 --- a/tests/core/policies/test_memoization.py +++ b/tests/core/policies/test_memoization.py @@ -51,6 +51,7 @@ def test_prediction(self, max_history): """ ) events = [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_GREET_ACTION), SlotSet("slot_1", True), @@ -59,16 +60,18 @@ def test_prediction(self, max_history): SlotSet("slot_3", True), ActionExecuted(UTTER_GREET_ACTION), ActionExecuted(UTTER_GREET_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_GREET_ACTION), SlotSet("slot_4", True), ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), ] training_story = TrackerWithCachedStates.from_events( "training story", evts=events, domain=domain, slots=domain.slots, ) test_story = TrackerWithCachedStates.from_events( - "training story", events[:-1], domain=domain, slots=domain.slots, + "training story", events[:-2], domain=domain, slots=domain.slots, ) interpreter = RegexInterpreter() policy.train([training_story], domain, interpreter) @@ -90,7 +93,7 @@ def create_policy( return AugmentedMemoizationPolicy(featurizer=featurizer, priority=priority) @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) - def test_augmented_prediction_when_starts_with_intent(self, max_history): + def test_augmented_prediction(self, max_history): policy = self.create_policy( featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 ) @@ -118,10 +121,12 @@ def test_augmented_prediction_when_starts_with_intent(self, max_history): training_story = TrackerWithCachedStates.from_events( "training story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_GREET_ACTION), SlotSet("slot_3", True), ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), ], domain=domain, slots=domain.slots, @@ -129,85 +134,18 @@ def test_augmented_prediction_when_starts_with_intent(self, max_history): test_story = TrackerWithCachedStates.from_events( "test story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_GREET_ACTION), SlotSet("slot_1", False), ActionExecuted(UTTER_GREET_ACTION), ActionExecuted(UTTER_GREET_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_GREET_ACTION), SlotSet("slot_2", True), ActionExecuted(UTTER_GREET_ACTION), - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_GREET_ACTION), - SlotSet("slot_3", True), - # ActionExecuted(UTTER_BYE_ACTION), - ], - domain=domain, - slots=domain.slots, - ) - interpreter = RegexInterpreter() - policy.train([training_story], domain, interpreter) - prediction = policy.predict_action_probabilities( - test_story, domain, interpreter - ) - assert ( - domain.action_names_or_texts[ - prediction.probabilities.index(max(prediction.probabilities)) - ] - == UTTER_BYE_ACTION - ) - - @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) - def test_augmented_prediction_when_starts_with_action(self, max_history): - policy = self.create_policy( - featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 - ) - - GREET_INTENT_NAME = "greet" - UTTER_GREET_ACTION = "utter_greet" - UTTER_BYE_ACTION = "utter_goodbye" - domain = Domain.from_yaml( - f""" - intents: - - {GREET_INTENT_NAME} - actions: - - {UTTER_GREET_ACTION} - - {UTTER_BYE_ACTION} - slots: - slot_1: - type: bool - initial_value: true - slot_2: - type: bool - slot_3: - type: bool - """ - ) - training_story = TrackerWithCachedStates.from_events( - "training story", - [ - ActionExecuted(UTTER_GREET_ACTION), - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_GREET_ACTION), - SlotSet("slot_3", True), - ActionExecuted(UTTER_BYE_ACTION), - ], - domain=domain, - slots=domain.slots, - ) - test_story = TrackerWithCachedStates.from_events( - "test story", - [ - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_GREET_ACTION), - SlotSet("slot_1", False), - ActionExecuted(UTTER_GREET_ACTION), - ActionExecuted(UTTER_GREET_ACTION), - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_GREET_ACTION), - SlotSet("slot_2", True), - ActionExecuted(UTTER_GREET_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_GREET_ACTION), SlotSet("slot_3", True), @@ -258,19 +196,12 @@ def test_augmented_prediction_across_max_history_actions(self, max_history): - {UTTER_ACTION_4} - {UTTER_ACTION_5} - {UTTER_BYE_ACTION} - slots: - slot_1: - type: bool - initial_value: true - slot_2: - type: bool - slot_3: - type: bool """ ) training_story = TrackerWithCachedStates.from_events( "training story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -278,6 +209,7 @@ def test_augmented_prediction_across_max_history_actions(self, max_history): ActionExecuted(UTTER_ACTION_4), ActionExecuted(UTTER_ACTION_5), ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), ], domain=domain, slots=domain.slots, @@ -285,6 +217,7 @@ def test_augmented_prediction_across_max_history_actions(self, max_history): test_story = TrackerWithCachedStates.from_events( "test story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -339,19 +272,12 @@ def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_histo - {UTTER_ACTION_4} - {UTTER_ACTION_5} - {UTTER_BYE_ACTION} - slots: - slot_1: - type: bool - initial_value: true - slot_2: - type: bool - slot_3: - type: bool """ ) training_story = TrackerWithCachedStates.from_events( "training story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -359,6 +285,7 @@ def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_histo ActionExecuted(UTTER_ACTION_4), ActionExecuted(UTTER_ACTION_5), ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), ], domain=domain, slots=domain.slots, @@ -369,9 +296,10 @@ def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_histo test_story1 = TrackerWithCachedStates.from_events( "test story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GOODBYE_INTENT_NAME}), - SlotSet("slot_1", False), ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -396,9 +324,10 @@ def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_histo test_story2_no_match_expected = TrackerWithCachedStates.from_events( "test story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), - SlotSet("slot_1", False), ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GOODBYE_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -450,30 +379,28 @@ def test_aug_pred_connects_different_memoizations(self, max_history): - {UTTER_ACTION_2} - {UTTER_ACTION_3} - {UTTER_ACTION_4} - slots: - slot_1: - type: bool - initial_value: true - slot_2: - type: bool - slot_3: - type: bool """ ) training_story1 = TrackerWithCachedStates.from_events( "training story", [ + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), ActionExecuted(UTTER_ACTION_3), + ActionExecuted(ACTION_LISTEN_NAME), ], domain=domain, slots=domain.slots, ) training_story2 = TrackerWithCachedStates.from_events( "training story", - [ActionExecuted(UTTER_ACTION_3), ActionExecuted(UTTER_ACTION_4),], + [ + ActionExecuted(UTTER_ACTION_3), + ActionExecuted(UTTER_ACTION_4), + ActionExecuted(ACTION_LISTEN_NAME), + ], domain=domain, slots=domain.slots, ) @@ -484,9 +411,7 @@ def test_aug_pred_connects_different_memoizations(self, max_history): test_story1 = TrackerWithCachedStates.from_events( "test story", [ - UserUttered(intent={"name": GOODBYE_INTENT_NAME}), - SlotSet("slot_1", False), - ActionExecuted(UTTER_BYE_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), @@ -509,7 +434,6 @@ def test_aug_pred_connects_different_memoizations(self, max_history): "test story", [ UserUttered(intent={"name": GREET_INTENT_NAME}), - SlotSet("slot_1", False), ActionExecuted(UTTER_BYE_ACTION), UserUttered(intent={"name": GOODBYE_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), From 488a9cdb60ef67de1b8f99bbf298122a842bd80a Mon Sep 17 00:00:00 2001 From: kedz Date: Mon, 10 Jan 2022 13:30:25 -0500 Subject: [PATCH 07/12] Renamed function. --- rasa/core/policies/memoization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 2820671de808..7b419d6304b9 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -287,7 +287,7 @@ class AugmentedMemoizationPolicy(MemoizationPolicy): """ @staticmethod - def _truncate_leading_events_until_action_executed( + def _strip_leading_events_until_action_executed( tracker: DialogueStateTracker, again: bool = False ) -> Optional[DialogueStateTracker]: """Truncates the tracker to begin at the next `ActionExecuted` event. @@ -352,7 +352,7 @@ def _recall_using_truncation( # Truncate the tracker based on `max_history` truncated_tracker = _trim_tracker_by_max_history(tracker, self.max_history) - truncated_tracker = self._truncate_leading_events_until_action_executed( + truncated_tracker = self._strip_leading_events_until_action_executed( truncated_tracker ) while truncated_tracker is not None: @@ -367,7 +367,7 @@ def _recall_using_truncation( old_states = states # go back again - truncated_tracker = self._truncate_leading_events_until_action_executed( + truncated_tracker = self._strip_leading_events_until_action_executed( truncated_tracker, again=True ) From 6885710d02ae91696af7b795073c9d0da9a395a1 Mon Sep 17 00:00:00 2001 From: Chris Kedzie Date: Tue, 11 Jan 2022 12:20:55 -0500 Subject: [PATCH 08/12] Update rasa/core/policies/memoization.py Co-authored-by: Johannes E. M. Mosig --- rasa/core/policies/memoization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 7b419d6304b9..43f49e4cd57c 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -356,7 +356,7 @@ def _recall_using_truncation( truncated_tracker ) while truncated_tracker is not None: - states = self._prediction_states(truncated_tracker, domain,) + states = self._prediction_states(truncated_tracker, domain) if old_states != states: # check if we like new futures From 77402aa4742b2d0df4fe08ace381135536a354aa Mon Sep 17 00:00:00 2001 From: Chris Kedzie Date: Tue, 11 Jan 2022 12:21:08 -0500 Subject: [PATCH 09/12] Update rasa/core/policies/memoization.py Co-authored-by: Johannes E. M. Mosig --- rasa/core/policies/memoization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 43f49e4cd57c..cfe4d0eb2187 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -394,7 +394,7 @@ def recall( predicted_action_name = self._recall_states(states) if predicted_action_name is None: # let's try a different method to recall that tracker - return self._recall_using_truncation(states, tracker, domain,) + return self._recall_using_truncation(states, tracker, domain) else: return predicted_action_name From f3e77c9a36e10ade865016bf5671fac7d24ffda5 Mon Sep 17 00:00:00 2001 From: kedz Date: Tue, 11 Jan 2022 12:25:14 -0500 Subject: [PATCH 10/12] Revert comment and compression flag. --- rasa/core/policies/memoization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index cfe4d0eb2187..65ba5e08a5fa 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -160,7 +160,7 @@ def _create_feature_key(self, states: List[State]) -> Text: # represented as dictionaries have the same json strings # quotes are removed for aesthetic reasons feature_str = json.dumps(states, sort_keys=True).replace('"', "") - if False: # self.ENABLE_FEATURE_STRING_COMPRESSION: + if self.ENABLE_FEATURE_STRING_COMPRESSION: compressed = zlib.compress( bytes(feature_str, rasa.shared.utils.io.DEFAULT_ENCODING) ) @@ -306,7 +306,7 @@ def _strip_leading_events_until_action_executed( applied_events = tracker.applied_events() - # We need to find the second `ActionExecuted` or `UserUttered` event. + # we need to find second executed action for e_i, event in enumerate(applied_events): if isinstance(event, ActionExecuted): if idx_of_first_action is None: From 4bfca135ccbabb35a63d228cd66ac73ae63a5d9c Mon Sep 17 00:00:00 2001 From: kedz Date: Tue, 11 Jan 2022 12:33:25 -0500 Subject: [PATCH 11/12] Revert comment. --- rasa/core/policies/memoization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 65ba5e08a5fa..1836f6bcff6c 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -315,8 +315,7 @@ def _strip_leading_events_until_action_executed( idx_of_second_action = e_i break - # use first action/user utterance, if we went first time and - # second action/user utterance, if we went again + # use first action, if we went first time and second action, if we went again idx_to_use = idx_of_second_action if again else idx_of_first_action if idx_to_use is None: return None From c8ce1410db110dc665037139a883468035c653ad Mon Sep 17 00:00:00 2001 From: kedz Date: Tue, 11 Jan 2022 14:02:22 -0500 Subject: [PATCH 12/12] Update test. --- tests/core/policies/test_memoization.py | 57 ++++--------------------- 1 file changed, 9 insertions(+), 48 deletions(-) diff --git a/tests/core/policies/test_memoization.py b/tests/core/policies/test_memoization.py index 150a1ed25e46..5a34a3aa2b44 100644 --- a/tests/core/policies/test_memoization.py +++ b/tests/core/policies/test_memoization.py @@ -351,10 +351,9 @@ def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_histo ) @pytest.mark.parametrize("max_history", [1, 2, 3, 4, None]) - def test_aug_pred_connects_different_memoizations(self, max_history): - """Tests memoization works for a memoized state sequence that starts - with a user utterance that leads to memoized state that does not - have user utterance information. + def test_aug_pred_without_intent(self, max_history): + """Tests memoization works for a memoized state sequence that does + not have a user utterance. """ policy = self.create_policy( featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1 @@ -367,7 +366,6 @@ def test_aug_pred_connects_different_memoizations(self, max_history): UTTER_ACTION_2 = "utter_2" UTTER_ACTION_3 = "utter_3" UTTER_ACTION_4 = "utter_4" - UTTER_BYE_ACTION = "utter_goodbye" domain = Domain.from_yaml( f""" intents: @@ -381,20 +379,7 @@ def test_aug_pred_connects_different_memoizations(self, max_history): - {UTTER_ACTION_4} """ ) - training_story1 = TrackerWithCachedStates.from_events( - "training story", - [ - ActionExecuted(ACTION_LISTEN_NAME), - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_ACTION_1), - ActionExecuted(UTTER_ACTION_2), - ActionExecuted(UTTER_ACTION_3), - ActionExecuted(ACTION_LISTEN_NAME), - ], - domain=domain, - slots=domain.slots, - ) - training_story2 = TrackerWithCachedStates.from_events( + training_story = TrackerWithCachedStates.from_events( "training story", [ ActionExecuted(UTTER_ACTION_3), @@ -406,51 +391,27 @@ def test_aug_pred_connects_different_memoizations(self, max_history): ) interpreter = RegexInterpreter() - policy.train([training_story1, training_story2], domain, interpreter) + policy.train([training_story], domain, interpreter) - test_story1 = TrackerWithCachedStates.from_events( + test_story = TrackerWithCachedStates.from_events( "test story", [ ActionExecuted(ACTION_LISTEN_NAME), UserUttered(intent={"name": GREET_INTENT_NAME}), ActionExecuted(UTTER_ACTION_1), ActionExecuted(UTTER_ACTION_2), - # ActionExecuted(UTTER_ACTION_3), - ], - domain=domain, - slots=domain.slots, - ) - prediction1 = policy.predict_action_probabilities( - test_story1, domain, interpreter - ) - assert ( - domain.action_names_or_texts[ - prediction1.probabilities.index(max(prediction1.probabilities)) - ] - == UTTER_ACTION_3 - ) - - test_story2 = TrackerWithCachedStates.from_events( - "test story", - [ - UserUttered(intent={"name": GREET_INTENT_NAME}), - ActionExecuted(UTTER_BYE_ACTION), - UserUttered(intent={"name": GOODBYE_INTENT_NAME}), - ActionExecuted(UTTER_ACTION_1), - ActionExecuted(UTTER_ACTION_2), ActionExecuted(UTTER_ACTION_3), # ActionExecuted(UTTER_ACTION_4), ], domain=domain, slots=domain.slots, ) - - prediction2 = policy.predict_action_probabilities( - test_story2, domain, interpreter + prediction = policy.predict_action_probabilities( + test_story, domain, interpreter ) assert ( domain.action_names_or_texts[ - prediction2.probabilities.index(max(prediction2.probabilities)) + prediction.probabilities.index(max(prediction.probabilities)) ] == UTTER_ACTION_4 )