Skip to content

Commit

Permalink
Merge c8ce141 into 4b9bb05
Browse files Browse the repository at this point in the history
  • Loading branch information
kedz committed Jan 11, 2022
2 parents 4b9bb05 + c8ce141 commit 1fb0c70
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-model-regression.yml
Expand Up @@ -643,7 +643,7 @@ jobs:
# Get ID of the schedule workflow
SCHEDULE_ID=$(curl -X GET -s -H 'Authorization: token ${{ secrets.GITHUB_TOKEN }}' -H "Accept: application/vnd.github.v3+json" \
"https://api.github.com/repos/${{ github.repository }}/actions/workflows" \
| jq -r '.workflows[] | select(.name == "${{ github.workflow }}") | select(.path | test("schedule")) | .id')
| jq -r '.workflows[] | select(.name == "CI - Model Regression on schedule") | select(.path | test("schedule")) | .id')
ARTIFACT_URL=$(curl -s -H 'Authorization: token ${{ secrets.GITHUB_TOKEN }}' -H "Accept: application/vnd.github.v3+json" \
"https://api.github.com/repos/${{ github.repository }}/actions/workflows/${SCHEDULE_ID}/runs?event=schedule&status=completed&branch=main&per_page=1" | jq -r .workflow_runs[0].artifacts_url)
Expand Down
2 changes: 2 additions & 0 deletions changelog/10606.bugfix.md
@@ -0,0 +1,2 @@
Fix `max_history` truncation in `AugmentedMemoizationPolicy` to preserve the most recent `UserUttered` event.
Previously, `AugmentedMemoizationPolicy` failed to predict next action after long sequences of actions (longer than `max_history`) because the policy did not have access to the most recent user message.
56 changes: 36 additions & 20 deletions rasa/core/policies/memoization.py
Expand Up @@ -22,6 +22,7 @@
from rasa.shared.core.generator import TrackerWithCachedStates
from rasa.shared.utils.io import is_logging_disabled
from rasa.core.constants import MEMOIZATION_POLICY_PRIORITY
from rasa.shared.core.constants import ACTION_LISTEN_NAME

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -286,20 +287,27 @@ class AugmentedMemoizationPolicy(MemoizationPolicy):
"""

@staticmethod
def _back_to_the_future(
def _strip_leading_events_until_action_executed(
tracker: DialogueStateTracker, again: bool = False
) -> Optional[DialogueStateTracker]:
"""Send Marty to the past to get
the new featurization for the future"""
"""Truncates the tracker to begin at the next `ActionExecuted` event.
Args:
tracker: The tracker to truncate.
again: When true, truncate tracker at the second action.
Otherwise truncate to the first action.
Returns:
The truncated tracker if there were actions present.
If none are found, returns `None`.
"""
idx_of_first_action = None
idx_of_second_action = None

applied_events = tracker.applied_events()

# we need to find second executed action
for e_i, event in enumerate(applied_events):
# find second ActionExecuted
if isinstance(event, ActionExecuted):
if idx_of_first_action is None:
idx_of_first_action = e_i
Expand All @@ -317,19 +325,19 @@ def _back_to_the_future(
if not events:
return None

mcfly_tracker = tracker.init_copy()
truncated_tracker = tracker.init_copy()
for e in events:
mcfly_tracker.update(e)
truncated_tracker.update(e)

return mcfly_tracker
return truncated_tracker

def _recall_using_delorean(
def _recall_using_truncation(
self, old_states: List[State], tracker: DialogueStateTracker, domain: Domain,
) -> Optional[Text]:
"""Applies to the future idea to change the past and get the new future.
"""Attempts to match memorized states to progressively shorter trackers.
Recursively go to the past to correctly forget slots,
and then back to the future to recall.
This matching will iteratively remove prior slot setting events and
other actions, looking for the first matching memorized state sequence.
Args:
old_states: List of states.
Expand All @@ -342,10 +350,12 @@ def _recall_using_delorean(
logger.debug("Launch DeLorean...")

# Truncate the tracker based on `max_history`
mcfly_tracker = _trim_tracker_by_max_history(tracker, self.max_history)
mcfly_tracker = self._back_to_the_future(mcfly_tracker)
while mcfly_tracker is not None:
states = self._prediction_states(mcfly_tracker, domain,)
truncated_tracker = _trim_tracker_by_max_history(tracker, self.max_history)
truncated_tracker = self._strip_leading_events_until_action_executed(
truncated_tracker
)
while truncated_tracker is not None:
states = self._prediction_states(truncated_tracker, domain)

if old_states != states:
# check if we like new futures
Expand All @@ -356,7 +366,9 @@ def _recall_using_delorean(
old_states = states

# go back again
mcfly_tracker = self._back_to_the_future(mcfly_tracker, again=True)
truncated_tracker = self._strip_leading_events_until_action_executed(
truncated_tracker, again=True
)

# No match found
logger.debug(f"Current tracker state {old_states}")
Expand All @@ -381,7 +393,7 @@ def recall(
predicted_action_name = self._recall_states(states)
if predicted_action_name is None:
# let's try a different method to recall that tracker
return self._recall_using_delorean(states, tracker, domain,)
return self._recall_using_truncation(states, tracker, domain)
else:
return predicted_action_name

Expand All @@ -391,12 +403,16 @@ def _get_max_applied_events_for_max_history(
) -> Optional[int]:
"""Computes the number of events in the tracker that correspond to max_history.
To ensure that the last user utterance is correctly included in the prediction
states, return the index of the most recent `action_listen` event occuring
before the tracker would be truncated according to the value of `max_history`.
Args:
tracker: Some tracker holding the events
max_history: The number of actions to count
Returns:
The number of actions, as counted from the end of the event list, that should
The number of events, as counted from the end of the event list, that should
be taken into accout according to the `max_history` setting. If all events
should be taken into account, the return value is `None`.
"""
Expand All @@ -408,8 +424,8 @@ def _get_max_applied_events_for_max_history(
num_events += 1
if isinstance(event, ActionExecuted):
num_actions += 1
if num_actions > max_history:
return num_events
if num_actions > max_history and event.action_name == ACTION_LISTEN_NAME:
return num_events
return None


Expand Down

0 comments on commit 1fb0c70

Please sign in to comment.