Skip to content

Commit

Permalink
Merge b744fbd into 243fa48
Browse files Browse the repository at this point in the history
  • Loading branch information
raoulvm committed Aug 16, 2022
2 parents 243fa48 + b744fbd commit af7f1af
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions rasa/shared/core/generator.py
Expand Up @@ -62,6 +62,7 @@ def __init__(
sender_id, slots, max_event_history, is_rule_tracker=is_rule_tracker
)
self._states_for_hashing = None
self._omit_unset_slots = None
self.domain = domain
# T/F property to filter augmented stories
self.is_augmented = is_augmented
Expand Down Expand Up @@ -103,8 +104,13 @@ def past_states_for_hashing(

# if don't have it cached, we use the domain to calculate the states
# from the events
if self._states_for_hashing is None:
if (
self._states_for_hashing is None
or not hasattr(self, "_omit_unset_slots")
or omit_unset_slots != self._omit_unset_slots
):
states = super().past_states(domain, omit_unset_slots=omit_unset_slots)
self._omit_unset_slots = omit_unset_slots
self._states_for_hashing = deque(
self.freeze_current_state(s) for s in states
)
Expand Down Expand Up @@ -146,6 +152,7 @@ def past_states(
def clear_states(self) -> None:
"""Reset the states."""
self._states_for_hashing = None
self._omit_unset_slots = None

def init_copy(self) -> "TrackerWithCachedStates":
"""Create a new state tracker with the same initial values."""
Expand Down Expand Up @@ -177,14 +184,28 @@ def copy(
tracker.update(event, skip_states=True)

tracker._states_for_hashing = copy.copy(self._states_for_hashing)
tracker._omit_unset_slots = (
self._omit_unset_slots if hasattr(self, "_omit_unset_slots") else None
)

return tracker

def _append_current_state(self) -> None:
if self._states_for_hashing is None:
self._states_for_hashing = self.past_states_for_hashing(self.domain)
self._states_for_hashing = self.past_states_for_hashing(
self.domain,
omit_unset_slots=self._omit_unset_slots
if hasattr(self, "_omit_unset_slots")
else False,
)
self._omit_unset_slots = False
else:
state = self.domain.get_active_state(self)
state = self.domain.get_active_state(
self,
omit_unset_slots=self._omit_unset_slots
if hasattr(self, "_omit_unset_slots")
else False,
)
frozen_state = self.freeze_current_state(state)
self._states_for_hashing.append(frozen_state)

Expand All @@ -197,7 +218,12 @@ def update(self, event: Event, skip_states: bool = False) -> None:
if self._states_for_hashing is None and not skip_states:
# rest of this function assumes we have the previous state
# cached. let's make sure it is there.
self._states_for_hashing = self.past_states_for_hashing(self.domain)
self._states_for_hashing = self.past_states_for_hashing(
self.domain,
omit_unset_slots=self._omit_unset_slots
if hasattr(self, "_omit_unset_slots")
else False,
)

super().update(event)

Expand Down

0 comments on commit af7f1af

Please sign in to comment.