Skip to content

Commit

Permalink
Update test.
Browse files Browse the repository at this point in the history
  • Loading branch information
kedz committed Jan 11, 2022
1 parent 4bfca13 commit c8ce141
Showing 1 changed file with 9 additions and 48 deletions.
57 changes: 9 additions & 48 deletions tests/core/policies/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,9 @@ def test_aug_pred_sensitive_to_intent_across_max_history_actions(self, max_histo
)

@pytest.mark.parametrize("max_history", [1, 2, 3, 4, None])
def test_aug_pred_connects_different_memoizations(self, max_history):
"""Tests memoization works for a memoized state sequence that starts
with a user utterance that leads to memoized state that does not
have user utterance information.
def test_aug_pred_without_intent(self, max_history):
"""Tests memoization works for a memoized state sequence that does
not have a user utterance.
"""
policy = self.create_policy(
featurizer=MaxHistoryTrackerFeaturizer(max_history=max_history), priority=1
Expand All @@ -367,7 +366,6 @@ def test_aug_pred_connects_different_memoizations(self, max_history):
UTTER_ACTION_2 = "utter_2"
UTTER_ACTION_3 = "utter_3"
UTTER_ACTION_4 = "utter_4"
UTTER_BYE_ACTION = "utter_goodbye"
domain = Domain.from_yaml(
f"""
intents:
Expand All @@ -381,20 +379,7 @@ def test_aug_pred_connects_different_memoizations(self, max_history):
- {UTTER_ACTION_4}
"""
)
training_story1 = TrackerWithCachedStates.from_events(
"training story",
[
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(intent={"name": GREET_INTENT_NAME}),
ActionExecuted(UTTER_ACTION_1),
ActionExecuted(UTTER_ACTION_2),
ActionExecuted(UTTER_ACTION_3),
ActionExecuted(ACTION_LISTEN_NAME),
],
domain=domain,
slots=domain.slots,
)
training_story2 = TrackerWithCachedStates.from_events(
training_story = TrackerWithCachedStates.from_events(
"training story",
[
ActionExecuted(UTTER_ACTION_3),
Expand All @@ -406,51 +391,27 @@ def test_aug_pred_connects_different_memoizations(self, max_history):
)

interpreter = RegexInterpreter()
policy.train([training_story1, training_story2], domain, interpreter)
policy.train([training_story], domain, interpreter)

test_story1 = TrackerWithCachedStates.from_events(
test_story = TrackerWithCachedStates.from_events(
"test story",
[
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(intent={"name": GREET_INTENT_NAME}),
ActionExecuted(UTTER_ACTION_1),
ActionExecuted(UTTER_ACTION_2),
# ActionExecuted(UTTER_ACTION_3),
],
domain=domain,
slots=domain.slots,
)
prediction1 = policy.predict_action_probabilities(
test_story1, domain, interpreter
)
assert (
domain.action_names_or_texts[
prediction1.probabilities.index(max(prediction1.probabilities))
]
== UTTER_ACTION_3
)

test_story2 = TrackerWithCachedStates.from_events(
"test story",
[
UserUttered(intent={"name": GREET_INTENT_NAME}),
ActionExecuted(UTTER_BYE_ACTION),
UserUttered(intent={"name": GOODBYE_INTENT_NAME}),
ActionExecuted(UTTER_ACTION_1),
ActionExecuted(UTTER_ACTION_2),
ActionExecuted(UTTER_ACTION_3),
# ActionExecuted(UTTER_ACTION_4),
],
domain=domain,
slots=domain.slots,
)

prediction2 = policy.predict_action_probabilities(
test_story2, domain, interpreter
prediction = policy.predict_action_probabilities(
test_story, domain, interpreter
)
assert (
domain.action_names_or_texts[
prediction2.probabilities.index(max(prediction2.probabilities))
prediction.probabilities.index(max(prediction.probabilities))
]
== UTTER_ACTION_4
)

0 comments on commit c8ce141

Please sign in to comment.