Skip to content

Commit

Permalink
Quick fix for slow augmented memo
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes E. M. Mosig committed Jul 6, 2021
1 parent f5c0354 commit 9760f57
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 5 deletions.
22 changes: 17 additions & 5 deletions rasa/core/policies/memoization.py
Expand Up @@ -287,16 +287,21 @@ class AugmentedMemoizationPolicy(MemoizationPolicy):

@staticmethod
def _back_to_the_future(
tracker: DialogueStateTracker, again: bool = False
tracker: DialogueStateTracker,
again: bool = False,
max_history: Optional[int] = None,
) -> Optional[DialogueStateTracker]:
"""Send Marty to the past to get
the new featurization for the future"""

idx_of_first_action = None
idx_of_second_action = None

applied_events = tracker.applied_events()
start_index = max(0, len(applied_events) - max_history) if max_history else 0

# we need to find second executed action
for e_i, event in enumerate(tracker.applied_events()):
for e_i, event in enumerate(applied_events[start_index:]):
# find second ActionExecuted
if isinstance(event, ActionExecuted):
if idx_of_first_action is None:
Expand All @@ -311,7 +316,7 @@ def _back_to_the_future(
return

# make second ActionExecuted the first one
events = tracker.applied_events()[idx_to_use:]
events = applied_events[idx_to_use:]
if not events:
return

Expand Down Expand Up @@ -339,7 +344,10 @@ def _recall_using_delorean(
"""
logger.debug("Launch DeLorean...")

mcfly_tracker = self._back_to_the_future(tracker)
mcfly_tracker = self._back_to_the_future(
tracker,
max_history=self.max_history
)
while mcfly_tracker is not None:
states = self._prediction_states(mcfly_tracker, domain,)

Expand All @@ -352,7 +360,11 @@ def _recall_using_delorean(
old_states = states

# go back again
mcfly_tracker = self._back_to_the_future(mcfly_tracker, again=True)
mcfly_tracker = self._back_to_the_future(
mcfly_tracker,
max_history=None,
again=True,
)

# No match found
logger.debug(f"Current tracker state {old_states}")
Expand Down
150 changes: 150 additions & 0 deletions tests/core/policies/test_memoization.py
@@ -0,0 +1,150 @@

import pytest
from tests.core.test_policies import PolicyTestCollection
from typing import Optional
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer, MaxHistoryTrackerFeaturizer
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.generator import TrackerWithCachedStates
from rasa.core.policies.memoization import AugmentedMemoizationPolicy, MemoizationPolicy
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import (
ActionExecuted,
UserUttered,
SlotSet,
)
from rasa.shared.nlu.interpreter import RegexInterpreter
from pathlib import Path



class TestMemoizationPolicy(PolicyTestCollection):

def create_policy(
self, featurizer: Optional[TrackerFeaturizer], priority: int
) -> MemoizationPolicy:
return AugmentedMemoizationPolicy(featurizer=featurizer, priority=priority)
# return MemoizationPolicy(featurizer=featurizer, priority=priority)

def test_prediction(self):
policy = self.create_policy(
featurizer=MaxHistoryTrackerFeaturizer(max_history=2),
priority=1
)

GREET_INTENT_NAME = "greet"
UTTER_GREET_ACTION = "utter_greet"
domain = Domain.from_yaml(
f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {UTTER_GREET_ACTION}
slots:
slot_1:
type: bool
slot_2:
type: bool
slot_3:
type: bool
slot_4:
type: bool
"""
)
events = [
UserUttered(intent={"name": GREET_INTENT_NAME}),
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_1", True),
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_2", True),
SlotSet("slot_3", True),
ActionExecuted(UTTER_GREET_ACTION),
ActionExecuted(UTTER_GREET_ACTION),
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_4", True),
ActionExecuted(UTTER_GREET_ACTION),
]
training_story = TrackerWithCachedStates.from_events(
"training story",
events,
domain=domain,
slots=domain.slots,
)
test_story = TrackerWithCachedStates.from_events(
"training story",
events[:-2],
domain=domain,
slots=domain.slots,
)
policy.train([training_story], domain, RegexInterpreter())
prediction = policy.predict_action_probabilities(
test_story, domain, RegexInterpreter()
)
assert domain.action_names_or_texts[prediction.max_confidence_index] == UTTER_GREET_ACTION


class TestAugmentedMemoizationPolicy(TestMemoizationPolicy):

def test_augmented_prediction(self):
policy = self.create_policy(
featurizer=MaxHistoryTrackerFeaturizer(max_history=2),
priority=1
)

GREET_INTENT_NAME = "greet"
UTTER_GREET_ACTION = "utter_greet"
UTTER_BYE_ACTION = "utter_goodbye"
domain = Domain.from_yaml(
f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {UTTER_GREET_ACTION}
- {UTTER_BYE_ACTION}
slots:
slot_1:
type: bool
influence_conversation: true
initial_value: true
slot_2:
type: bool
influence_conversation: true
slot_3:
type: bool
influence_conversation: true
slot_4:
type: bool
influence_conversation: true
"""
)
training_story = TrackerWithCachedStates.from_events(
"training story",
[
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_4", True),
ActionExecuted(UTTER_BYE_ACTION),
],
domain=domain,
slots=domain.slots,
)
test_story = TrackerWithCachedStates.from_events(
"test story",
[
UserUttered(intent={"name": GREET_INTENT_NAME}),
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_1", False),
ActionExecuted(UTTER_GREET_ACTION),
ActionExecuted(UTTER_GREET_ACTION),
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_3", True),
ActionExecuted(UTTER_GREET_ACTION),
SlotSet("slot_4", True),
# ActionExecuted(UTTER_BYE_ACTION),
],
domain=domain,
slots=domain.slots,
)
policy.train([training_story], domain, RegexInterpreter())
prediction = policy.predict_action_probabilities(
test_story, domain, RegexInterpreter()
)
assert domain.action_names_or_texts[prediction.max_confidence_index] == UTTER_BYE_ACTION

0 comments on commit 9760f57

Please sign in to comment.