From a12885fbbd5003c323bab3d086a865e22d4003e7 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 27 Nov 2019 07:31:15 -0500 Subject: [PATCH 01/74] add SessionStarted event --- rasa/core/events/__init__.py | 35 ++++++++++++++++++++++++------ rasa/core/training/structures.py | 37 ++++++++++++++++++++------------ tests/core/test_dsl.py | 6 ++++++ 3 files changed, 58 insertions(+), 20 deletions(-) diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index ec9c5a274004..6fee56dc4e58 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -1,13 +1,12 @@ -import time -import typing - import json -import warnings -import jsonpickle import logging + +import jsonpickle +import time +import typing import uuid from dateutil import parser -from typing import List, Dict, Text, Any, Type, Optional +from typing import List, Dict, Text, Any, Type, Optional, NoReturn from rasa.core import utils @@ -1172,3 +1171,27 @@ def as_dict(self): def apply_to(self, tracker: "DialogueStateTracker") -> None: tracker.reject_action(self.action_name) + + +class SessionStarted(Event): + """Mark the beginning of a new conversation session.""" + + type_name = "session_started" + + def __hash__(self) -> int: + return hash(32143124320) + + def __eq__(self, other: Any) -> bool: + return isinstance(other, SessionStarted) + + def __str__(self) -> Text: + return "SessionStarted()" + + def as_story_string(self) -> NoReturn: + raise NotImplementedError( + f"'{self.type_name}' events cannot be serialised as story strings." + ) + + # todo + def apply_to(self, tracker: "DialogueStateTracker") -> None: + tracker._paused = True diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index de3a0ea724de..c5cad08620b1 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -1,8 +1,9 @@ import json import logging -import sys -import uuid from collections import deque, defaultdict + +import uuid +import typing from typing import List, Text, Dict, Optional, Tuple, Any, Set, ValuesView from rasa.core import utils @@ -17,8 +18,12 @@ SlotSet, Event, ActionExecutionRejected, + SessionStarted, ) +if typing.TYPE_CHECKING: + import networkx as nx + logger = logging.getLogger(__name__) # Checkpoint id used to identify story starting blocks @@ -118,7 +123,7 @@ def __init__( self.story_string_helper = StoryStringHelper() - def create_copy(self, use_new_id): + def create_copy(self, use_new_id: bool) -> "StoryStep": copied = StoryStep( self.block_name, self.start_checkpoints, @@ -129,10 +134,10 @@ def create_copy(self, use_new_id): copied.id = self.id return copied - def add_user_message(self, user_message): + def add_user_message(self, user_message: UserUttered) -> None: self.add_event(user_message) - def add_event(self, event): + def add_event(self, event: Event) -> None: self.events.append(event) @staticmethod @@ -156,10 +161,10 @@ def _store_user_strings( ) @staticmethod - def _bot_string(story_step_element, prefix=""): + def _bot_string(story_step_element: Event, prefix: Text = "") -> Text: return " - {}{}\n".format(prefix, story_step_element.as_story_string()) - def _store_bot_strings(self, story_step_element, prefix=""): + def _store_bot_strings(self, story_step_element: Event, prefix: Text = "") -> None: self.story_string_helper.no_form_prefix_string += self._bot_string( story_step_element ) @@ -167,7 +172,7 @@ def _store_bot_strings(self, story_step_element, prefix=""): story_step_element, prefix ) - def _reset_stored_strings(self): + def _reset_stored_strings(self) -> None: self.story_string_helper.form_prefix_string = "" self.story_string_helper.no_form_prefix_string = "" @@ -212,6 +217,10 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: result += self._bot_string(s) + elif isinstance(s, SessionStarted): + # `SessionStarted` events are not dumped in stories + continue + elif isinstance(s, FormValidation): self.story_string_helper.form_validation = s.validate @@ -303,12 +312,12 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: return result @staticmethod - def _is_action_listen(event): + def _is_action_listen(event: Event) -> bool: # this is not an `isinstance` because # we don't want to allow subclasses here return type(event) == ActionExecuted and event.action_name == ACTION_LISTEN_NAME - def _add_action_listen(self, events): + def _add_action_listen(self, events: List[Event]) -> None: if not events or not self._is_action_listen(events[-1]): # do not add second action_listen events.append(ActionExecuted(ACTION_LISTEN_NAME)) @@ -362,7 +371,7 @@ def __init__( self.story_name = story_name @staticmethod - def from_events(events, story_name=None): + def from_events(events: List[Event], story_name: Optional[Text] = None) -> "Story": """Create a story from a list of events.""" story_step = StoryStep() @@ -370,7 +379,7 @@ def from_events(events, story_name=None): story_step.add_event(event) return Story([story_step], story_name) - def as_dialogue(self, sender_id, domain): + def as_dialogue(self, sender_id: Text, domain: Domain) -> Dialogue: events = [] for step in self.story_steps: events.extend( @@ -732,7 +741,7 @@ def dfs(node): return ordered, sorted(removed_edges) - def visualize(self, output_file=None): + def visualize(self, output_file=None) -> "nx.MultiDiGraph": import networkx as nx from rasa.core.training import visualization # pytype: disable=pyi-error from colorhash import ColorHash @@ -741,7 +750,7 @@ def visualize(self, output_file=None): next_node_idx = [0] nodes = {"STORY_START": 0, "STORY_END": -1} - def ensure_checkpoint_is_drawn(cp): + def ensure_checkpoint_is_drawn(cp: Checkpoint) -> None: if cp.name not in nodes: next_node_idx[0] += 1 nodes[cp.name] = next_node_idx[0] diff --git a/tests/core/test_dsl.py b/tests/core/test_dsl.py index 8da97152d923..220da918f0a8 100644 --- a/tests/core/test_dsl.py +++ b/tests/core/test_dsl.py @@ -18,6 +18,7 @@ ActionExecutionRejected, Form, FormValidation, + SessionStarted, ) from rasa.core.training.structures import Story from rasa.core.featurizers import ( @@ -480,6 +481,11 @@ def test_user_uttered_to_e2e(parse_data: Dict, expected_story_string: Text): assert event.as_story_string(e2e=True) == expected_story_string +def test_session_started_event_cannot_be_serialised(): + with pytest.raises(NotImplementedError): + SessionStarted().as_story_string() + + @pytest.mark.parametrize("line", [" greet{: hi"]) def test_invalid_end_to_end_format(line: Text): reader = EndToEndReader() From c3c2d4201f211451e78a31e76b30187f1a06695b Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 27 Nov 2019 16:01:06 -0500 Subject: [PATCH 02/74] add structures tests --- rasa/core/events/__init__.py | 11 +++++---- tests/core/test_events.py | 44 ++++++++++++++++++----------------- tests/core/test_structures.py | 29 +++++++++++++++++++++++ tests/core/test_trackers.py | 34 ++++++++++++++++++++------- 4 files changed, 84 insertions(+), 34 deletions(-) create mode 100644 tests/core/test_structures.py diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 6fee56dc4e58..fa776e0b653c 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -490,7 +490,7 @@ def _from_parameters(cls, parameters): except KeyError as e: raise ValueError(f"Failed to parse set slot event. {e}") - def apply_to(self, tracker): + def apply_to(self, tracker: "DialogueStateTracker") -> None: tracker._set_slot(self.key, self.value) @@ -516,7 +516,7 @@ def __str__(self): def as_story_string(self): return self.type_name - def apply_to(self, tracker): + def apply_to(self, tracker: "DialogueStateTracker") -> None: from rasa.core.actions.action import ( # pytype: disable=pyi-error ACTION_LISTEN_NAME, ) @@ -1192,6 +1192,9 @@ def as_story_string(self) -> NoReturn: f"'{self.type_name}' events cannot be serialised as story strings." ) - # todo def apply_to(self, tracker: "DialogueStateTracker") -> None: - tracker._paused = True + from rasa.core.actions.action import ( # pytype: disable=pyi-error + ACTION_LISTEN_NAME, + ) + + tracker.trigger_followup_action(ACTION_LISTEN_NAME) diff --git a/tests/core/test_events.py b/tests/core/test_events.py index f1025f498168..15d88d87aadc 100644 --- a/tests/core/test_events.py +++ b/tests/core/test_events.py @@ -1,11 +1,11 @@ -import time - -import pytz -from datetime import datetime import copy import pytest +import pytz +import time +from datetime import datetime from dateutil import parser + from rasa.core import utils from rasa.core.events import ( Event, @@ -23,6 +23,7 @@ FollowupAction, UserUtteranceReverted, AgentUttered, + SessionStarted, ) @@ -41,6 +42,7 @@ (StoryExported(), None), (ActionReverted(), None), (UserUtteranceReverted(), None), + (SessionStarted(), None), (ActionExecuted("my_action"), ActionExecuted("my_other_action")), (FollowupAction("my_action"), FollowupAction("my_other_action")), ( @@ -92,6 +94,7 @@ def test_event_has_proper_implementation(one_event, another_event): StoryExported(), ActionReverted(), UserUtteranceReverted(), + SessionStarted(), ActionExecuted("my_action"), ActionExecuted("my_action", "policy_1_KerasPolicy", 0.8), FollowupAction("my_action"), @@ -131,18 +134,18 @@ def test_json_parse_reset(): def test_json_parse_user(): # fmt: off # DOCS MARKER UserUttered - evt={ - "event": "user", - "text": "Hey", - "parse_data": { + evt = { + "event": "user", + "text": "Hey", + "parse_data": { "intent": { - "name": "greet", - "confidence": 0.9 + "name": "greet", + "confidence": 0.9 }, "entities": [] - }, - "metadata": {}, - } + }, + "metadata": {}, + } # DOCS END # fmt: on assert Event.from_parameters(evt) == UserUttered( @@ -171,13 +174,13 @@ def test_json_parse_rewind(): def test_json_parse_reminder(): # fmt: off # DOCS MARKER ReminderScheduled - evt={ - "event": "reminder", - "action": "my_action", - "date_time": "2018-09-03T11:41:10.128172", - "name": "my_reminder", - "kill_on_user_msg": True, - } + evt = { + "event": "reminder", + "action": "my_action", + "date_time": "2018-09-03T11:41:10.128172", + "name": "my_reminder", + "kill_on_user_msg": True, + } # DOCS END # fmt: on assert Event.from_parameters(evt) == ReminderScheduled( @@ -288,7 +291,6 @@ def test_event_metadata_dict(event_class): @pytest.mark.parametrize("event_class", utils.all_subclasses(Event)) def test_event_default_metadata(event_class): - # Create an event without metadata. # When converting the Event to a dict, it should not include a `metadata` # property - unless it's a UserUttered or a BotUttered event (or subclasses diff --git a/tests/core/test_structures.py b/tests/core/test_structures.py new file mode 100644 index 000000000000..58f5d8d4d788 --- /dev/null +++ b/tests/core/test_structures.py @@ -0,0 +1,29 @@ +from rasa.core.domain import Domain +from rasa.core.events import SessionStarted, SlotSet, UserUttered +from rasa.core.trackers import DialogueStateTracker +from rasa.core.training.structures import Story + +domain = Domain.load("examples/moodbot/domain.yml") + + +def test_session_start_is_not_serialised(default_domain: Domain): + tracker = DialogueStateTracker("default", default_domain.slots) + # the retrieved tracker should be empty + assert len(tracker.events) == 0 + + # add SlotSet event + tracker.update(SlotSet("slot", "value")) + + # add a SessionStarted event and a user event + tracker.update(SessionStarted()) + tracker.update(UserUttered("say something")) + + # make sure session start is not serialised + story = Story.from_events(tracker.events, "some-story01") + + expected = """## some-story01 + - slot{"slot": "value"} +* say something +""" + + assert story.as_story_string(flat=True) == expected diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index 7f42708fdfa5..ba1994cb06af 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -17,6 +17,7 @@ Restarted, ActionReverted, UserUtteranceReverted, + SessionStarted, ) from rasa.core.tracker_store import ( InMemoryTrackerStore, @@ -37,14 +38,14 @@ class MockRedisTrackerStore(RedisTrackerStore): - def __init__(self, domain): + def __init__(self, _domain: Domain) -> None: self.red = fakeredis.FakeStrictRedis() self.record_exp = None # added in redis==3.3.0, but not yet in fakeredis self.red.connection_pool.connection_class.health_check_interval = 0 - TrackerStore.__init__(self, domain) + TrackerStore.__init__(self, _domain) def stores_to_be_tested(): @@ -113,7 +114,7 @@ def test_tracker_store(store, pair): assert restored == tracker -async def test_tracker_write_to_story(tmpdir, moodbot_domain): +async def test_tracker_write_to_story(tmpdir, moodbot_domain: Domain): tracker = tracker_from_dialogue_file( "data/test_dialogues/moodbot.json", moodbot_domain ) @@ -181,7 +182,7 @@ async def test_bot_utterance_comes_after_action_event(default_agent): assert [e.type_name for e in tracker.events] == expected -def test_tracker_entity_retrieval(default_domain): +def test_tracker_entity_retrieval(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) # the retrieved tracker should be empty assert len(tracker.events) == 0 @@ -207,7 +208,7 @@ def test_tracker_entity_retrieval(default_domain): assert list(tracker.get_latest_entity_values("unknown")) == [] -def test_tracker_update_slots_with_entity(default_domain): +def test_tracker_update_slots_with_entity(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) test_entity = default_domain.entities[0] @@ -234,7 +235,7 @@ def test_tracker_update_slots_with_entity(default_domain): assert tracker.get_slot(test_entity) == expected_slot_value -def test_restart_event(default_domain): +def test_restart_event(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) # the retrieved tracker should be empty assert len(tracker.events) == 0 @@ -268,7 +269,22 @@ def test_restart_event(default_domain): assert len(list(recovered.generate_all_prior_trackers())) == 1 -def test_revert_action_event(default_domain): +def test_session_start(default_domain: Domain): + tracker = DialogueStateTracker("default", default_domain.slots) + # the retrieved tracker should be empty + assert len(tracker.events) == 0 + + # add a SessionStarted event + tracker.update(SessionStarted()) + + # tracker has one event + assert len(tracker.events) == 1 + + # follow-up action should be 'action_listen' + assert tracker.followup_action == ACTION_LISTEN_NAME + + +def test_revert_action_event(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) # the retrieved tracker should be empty assert len(tracker.events) == 0 @@ -304,7 +320,7 @@ def test_revert_action_event(default_domain): assert len(list(tracker.generate_all_prior_trackers())) == 3 -def test_revert_user_utterance_event(default_domain): +def test_revert_user_utterance_event(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) # the retrieved tracker should be empty assert len(tracker.events) == 0 @@ -346,7 +362,7 @@ def test_revert_user_utterance_event(default_domain): assert len(list(tracker.generate_all_prior_trackers())) == 3 -def test_traveling_back_in_time(default_domain): +def test_traveling_back_in_time(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) # the retrieved tracker should be empty assert len(tracker.events) == 0 From c3d07c4e08a6d1e4448f9027d80084928dd4df0d Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 27 Nov 2019 21:19:47 -0500 Subject: [PATCH 03/74] fix typing --- rasa/core/training/structures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index c5cad08620b1..087dbd995656 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -312,12 +312,12 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: return result @staticmethod - def _is_action_listen(event: Event) -> bool: + def _is_action_listen(event: ActionExecuted) -> bool: # this is not an `isinstance` because # we don't want to allow subclasses here return type(event) == ActionExecuted and event.action_name == ACTION_LISTEN_NAME - def _add_action_listen(self, events: List[Event]) -> None: + def _add_action_listen(self, events: List[ActionExecuted]) -> None: if not events or not self._is_action_listen(events[-1]): # do not add second action_listen events.append(ActionExecuted(ACTION_LISTEN_NAME)) @@ -741,7 +741,7 @@ def dfs(node): return ordered, sorted(removed_edges) - def visualize(self, output_file=None) -> "nx.MultiDiGraph": + def visualize(self, output_file: Optional[Text] = None) -> "nx.MultiDiGraph": import networkx as nx from rasa.core.training import visualization # pytype: disable=pyi-error from colorhash import ColorHash From 3b5542b6bd3e5fc5b4bc23ecec422c7f773a0049 Mon Sep 17 00:00:00 2001 From: ricwo Date: Thu, 28 Nov 2019 13:44:20 -0500 Subject: [PATCH 04/74] add changelog file --- changelog/4830.feature.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changelog/4830.feature.rst diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst new file mode 100644 index 000000000000..ad8135c39fb1 --- /dev/null +++ b/changelog/4830.feature.rst @@ -0,0 +1,2 @@ +Added a new event ``SessionStarted`` that marks the beginning of a new conversation +session. From 295fb8a94467ca1d1be93ca071b07a4a063ab370 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 14:22:40 +0100 Subject: [PATCH 05/74] add session logic to processor --- rasa/core/actions/action.py | 26 ++++++++++++ rasa/core/constants.py | 2 + rasa/core/events/__init__.py | 1 - rasa/core/policies/mapping_policy.py | 9 ++++- rasa/core/processor.py | 59 ++++++++++++++++++++++++++-- rasa/core/tracker_store.py | 17 +++++--- rasa/core/trackers.py | 4 +- 7 files changed, 104 insertions(+), 14 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 523de3752b59..9bdbb0710dd1 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -43,6 +43,8 @@ ACTION_RESTART_NAME = "action_restart" +ACTION_SESSION_START_NAME = "action_session_start" + ACTION_DEFAULT_FALLBACK_NAME = "action_default_fallback" ACTION_DEACTIVATE_FORM_NAME = "action_deactivate_form" @@ -61,6 +63,7 @@ def default_actions() -> List["Action"]: return [ ActionListen(), ActionRestart(), + ActionSessionStart(), ActionDefaultFallback(), ActionDeactivateForm(), ActionRevertFallbackEvents(), @@ -306,6 +309,29 @@ async def run(self, output_channel, nlg, tracker, domain): return evts + [Restarted()] +class ActionSessionStart(Action): + """Resets the tracker to its initial state. + + Utters the restart template if available.""" + + def name(self) -> Text: + return ACTION_SESSION_START_NAME + + async def run(self, output_channel, nlg, tracker, domain): + from rasa.core.events import SessionStarted, SlotSet + + # TODO: config check whether slots should be carried over + # fetch SlotSet events from tracker + # carry over key, value and metadata + slot_set_events: List[Event] = [ + SlotSet(key=event.key, value=event.value, metadata=event.metadata) + for event in tracker.events + if isinstance(event, SlotSet) + ] + + return slot_set_events + [SessionStarted()] + + class ActionDefaultFallback(ActionUtterTemplate): """Executes the fallback action and goes back to the previous state of the dialogue""" diff --git a/rasa/core/constants.py b/rasa/core/constants.py index 104b79fa0aa0..d8a6da541573 100644 --- a/rasa/core/constants.py +++ b/rasa/core/constants.py @@ -27,6 +27,8 @@ USER_INTENT_RESTART = "restart" +USER_INTENT_SESSION_START = "session_start" + USER_INTENT_BACK = "back" USER_INTENT_OUT_OF_SCOPE = "out_of_scope" diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 259438d5d02b..14b9bd45a519 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -525,7 +525,6 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None: ) tracker._reset() - tracker.trigger_followup_action(ACTION_LISTEN_NAME) # noinspection PyProtectedMember diff --git a/rasa/core/policies/mapping_policy.py b/rasa/core/policies/mapping_policy.py index 246559b62079..d81c3f91fb88 100644 --- a/rasa/core/policies/mapping_policy.py +++ b/rasa/core/policies/mapping_policy.py @@ -11,8 +11,13 @@ ACTION_BACK_NAME, ACTION_LISTEN_NAME, ACTION_RESTART_NAME, + ACTION_SESSION_START_NAME, +) +from rasa.core.constants import ( + USER_INTENT_BACK, + USER_INTENT_RESTART, + USER_INTENT_SESSION_START, ) -from rasa.core.constants import USER_INTENT_BACK, USER_INTENT_RESTART from rasa.core.domain import Domain, InvalidDomain from rasa.core.events import ActionExecuted from rasa.core.policies.policy import Policy @@ -91,6 +96,8 @@ def predict_action_probabilities( action = ACTION_RESTART_NAME elif intent == USER_INTENT_BACK: action = ACTION_BACK_NAME + elif intent == USER_INTENT_SESSION_START: + action = ACTION_SESSION_START_NAME else: action = domain.intent_properties.get(intent, {}).get("triggers") diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 03492d0b0c13..3c3c2323c09b 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -33,6 +33,7 @@ SlotSet, UserUttered, BotUttered, + SessionStarted, ) from rasa.core.interpreter import ( INTENT_MESSAGE_PREFIX, @@ -580,9 +581,60 @@ def _log_action_on_tracker(self, tracker, action_name, events, policy, confidenc e.timestamp = time.time() tracker.update(e, self.domain) - def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]: + @staticmethod + def _session_start_timestamp_from(tracker: DialogueStateTracker) -> Optional[float]: + """Retrieve timestamp of the beginning of the last session start for + `tracker`. + + Args: + tracker: pass. + + Returns: + + """ + + if not tracker.events: + return None + + # try to fetch the timestamp of the latest `SessionStarted` event + for event in reversed(tracker.events): + if isinstance(event, SessionStarted): + return event.timestamp + + # otherwise fetch the timestamp of the first event + return tracker.events[0].timestamp + + def _tracker_has_valid_session( + self, tracker: DialogueStateTracker, session_length_in_minutes: int + ) -> bool: + """Determine whether `tracker` requires a new session. + + Args: + tracker: Tracker to inspect. + session_length_in_minutes: Session length in minutes. + Returns: + `True` if a new session is required, else `False`. + + """ + session_start_timestamp = self._session_start_timestamp_from(tracker) + + if not session_start_timestamp: + return True + + time_delta_in_seconds = time.time() - session_start_timestamp + + return time_delta_in_seconds / 60 > session_length_in_minutes + + def _get_tracker( + self, sender_id: Text, session_length_in_minutes: int = 60 + ) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID - return self.tracker_store.get_or_create_tracker(sender_id) + tracker = self.tracker_store.get_or_create_tracker(sender_id) + + if not self._tracker_has_valid_session(tracker, session_length_in_minutes): + tracker.update(SessionStarted()) + + return tracker def _save_tracker(self, tracker: DialogueStateTracker) -> None: self.tracker_store.save(tracker) @@ -601,8 +653,7 @@ def _prob_array_for_action( def _get_next_action_probabilities( self, tracker: DialogueStateTracker ) -> Tuple[Optional[List[float]], Optional[Text]]: - """Collect predictions from ensemble and return action and predictions. - """ + """Collect predictions from ensemble and return action and predictions.""" followup_action = tracker.followup_action if followup_action: diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 8bb5febf0713..4da553abef59 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -6,7 +6,8 @@ import pickle import typing from datetime import datetime, timezone -from typing import Iterator, Optional, Text, Iterable, Union, Dict, Callable + +from typing import Iterator, Optional, Text, Iterable, Union, Dict, Callable, List import itertools from boto3.dynamodb.conditions import Key @@ -14,7 +15,7 @@ # noinspection PyPep8Naming from time import sleep -from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME from rasa.core.brokers.event_channel import EventChannel from rasa.core.conversation import Dialogue from rasa.core.domain import Domain @@ -28,6 +29,7 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session import boto3 + from rasa.core.events import Event logger = logging.getLogger(__name__) @@ -133,13 +135,13 @@ def load_tracker_from_module_string( return InMemoryTrackerStore(domain) def get_or_create_tracker( - self, sender_id: Text, max_event_history: Optional[int] = None + self, sender_id: Text, max_event_history: Optional[int] = None, ) -> "DialogueStateTracker": """Returns tracker or creates one if the retrieval returns None""" tracker = self.retrieve(sender_id) self.max_event_history = max_event_history if tracker is None: - tracker = self.create_tracker(sender_id) + tracker = self.create_tracker(sender_id, append_action_listen=True,) return tracker def init_tracker(self, sender_id: Text) -> "DialogueStateTracker": @@ -151,14 +153,17 @@ def init_tracker(self, sender_id: Text) -> "DialogueStateTracker": ) def create_tracker( - self, sender_id: Text, append_action_listen: bool = True + self, sender_id: Text, append_action_listen: bool = True, ) -> DialogueStateTracker: - """Creates a new tracker for the sender_id. The tracker is initially listening.""" + """Creates a new tracker for the sender_id. The tracker is initially listening. + """ tracker = self.init_tracker(sender_id) if tracker: if append_action_listen: tracker.update(ActionExecuted(ACTION_LISTEN_NAME)) + self.save(tracker) + return tracker def save(self, tracker): diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index f1c35c4abf24..02974e680f02 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -2,7 +2,7 @@ import logging from collections import deque from enum import Enum -from typing import Dict, Text, Any, Optional, Iterator, Generator, Type, List +from typing import Dict, Text, Any, Optional, Iterator, Generator, Type, List, Deque from rasa.core import events # pytype: disable=pyi-error from rasa.core.actions.action import ACTION_LISTEN_NAME # pytype: disable=pyi-error @@ -525,7 +525,7 @@ def _set_slot(self, key: Text, value: Any) -> None: "".format(key) ) - def _create_events(self, evts: List[Event]) -> deque: + def _create_events(self, evts: List[Event]) -> Deque[Event]: if evts and not isinstance(evts[0], Event): # pragma: no cover raise ValueError("events, if given, must be a list of events") From 076562f47e1f4c7e45159518c77b8145e79a2bcb Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Mon, 2 Dec 2019 15:06:45 +0100 Subject: [PATCH 06/74] retrieve only latest session in SQLTrackerStore --- rasa/core/tracker_store.py | 33 +++++++++++++++-- tests/core/test_tracker_stores.py | 61 ++++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 5 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 8bb5febf0713..238633bd40ce 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -694,15 +694,40 @@ def keys(self) -> Iterable[Text]: def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: """Create a tracker from all previously stored events.""" + import sqlalchemy as sa + from rasa.core.events import SessionStarted + with self.session_scope() as session: - query = session.query(self.SQLEvent) - result = ( - query.filter_by(sender_id=sender_id) + # Subquery to find the timestamp of the first `SessionStartedEvent`. + session_start_sub_query = ( + session.query( + sa.func.max(self.SQLEvent.timestamp).label("session_start") + ) + .filter( + self.SQLEvent.sender_id == sender_id, + self.SQLEvent.type_name == SessionStarted.type_name, + ) + .subquery() + ) + + results = ( + session.query(self.SQLEvent) + .filter( + self.SQLEvent.sender_id == sender_id, + # Find events after the latest `SessionStarted` event or return all + # events + sa.or_( + self.SQLEvent.timestamp + >= session_start_sub_query.c.session_start, + # Compare `None` with `==` since this happens in SQL + session_start_sub_query.c.session_start == None, + ), + ) .order_by(self.SQLEvent.timestamp) .all() ) - events = [json.loads(event.data) for event in result] + events = [json.loads(event.data) for event in results] if self.domain and len(events) > 0: logger.debug(f"Recreating tracker from sender id '{sender_id}'") diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 7aad4328d2d4..3327ac8b1ff9 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -10,7 +10,14 @@ from rasa.core.channels.channel import UserMessage from rasa.core.domain import Domain -from rasa.core.events import SlotSet, ActionExecuted, Restarted +from rasa.core.events import ( + SlotSet, + ActionExecuted, + Restarted, + UserUttered, + SessionStarted, + BotUttered, +) from rasa.core.tracker_store import ( TrackerStore, InMemoryTrackerStore, @@ -379,3 +386,55 @@ def test_set_fail_safe_tracker_store_domain(default_domain: Domain): assert failsafe_store.domain is default_domain assert tracker_store.domain is failsafe_store.domain assert fallback_tracker_store.domain is failsafe_store.domain + + +def test_sql_tracker_store_retrieve_with_session_started_events(default_domain: Domain): + tracker_store = SQLTrackerStore(default_domain, host="sqlite:///") + + # Create tracker with a SessionStarted event + events = [ + UserUttered("Hola", {"name": "greet"}), + BotUttered("Hi"), + SessionStarted(), + UserUttered("Ciao", {"name": "greet"}), + ] + sender_id = "test_sql_tracker_store_with_session_events" + tracker = DialogueStateTracker.from_events(sender_id, events) + tracker_store.save(tracker) + + # Save other tracker to ensure that we don't run into problems with other senders + other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()]) + tracker_store.save(other_tracker) + + # Retrieve tracker with events since latest restart + tracker = tracker_store.retrieve(sender_id) + + assert len(tracker.events) == 2 + assert all((event == tracker.events[i] for i, event in enumerate(events[2:]))) + + +def test_sql_tracker_store_retrieve_without_session_started_events( + default_domain: Domain, +): + tracker_store = SQLTrackerStore(default_domain, host="sqlite:///") + + # Create tracker with a SessionStarted event + events = [ + UserUttered("Hola", {"name": "greet"}), + BotUttered("Hi"), + UserUttered("Ciao", {"name": "greet"}), + BotUttered("Hi2"), + ] + + sender_id = "test_sql_tracker_store_retrieve_without_session_started_events" + tracker = DialogueStateTracker.from_events(sender_id, events) + tracker_store.save(tracker) + + # Save other tracker to ensure that we don't run into problems with other senders + other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()]) + tracker_store.save(other_tracker) + + tracker = tracker_store.retrieve(sender_id) + + assert len(tracker.events) == 4 + assert all(event == tracker.events[i] for i, event in enumerate(events)) From 3d716fe99eb1eab7f47861ca0516f01a463339fa Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Mon, 2 Dec 2019 15:10:08 +0100 Subject: [PATCH 07/74] update changelog --- changelog/4830.improvement.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/4830.improvement.rst diff --git a/changelog/4830.improvement.rst b/changelog/4830.improvement.rst new file mode 100644 index 000000000000..6c81fff327a8 --- /dev/null +++ b/changelog/4830.improvement.rst @@ -0,0 +1 @@ +``SQLTrackerStore`` only retrieves events from the last session from the database. \ No newline at end of file From 1af10f60e0be34787ec3a506b3cf6029518f81ad Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Mon, 2 Dec 2019 16:22:49 +0100 Subject: [PATCH 08/74] improve the way we test for 'null' --- rasa/core/tracker_store.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 238633bd40ce..452bad54fc4e 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -719,8 +719,7 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: sa.or_( self.SQLEvent.timestamp >= session_start_sub_query.c.session_start, - # Compare `None` with `==` since this happens in SQL - session_start_sub_query.c.session_start == None, + session_start_sub_query.c.session_start.is_(None), ), ) .order_by(self.SQLEvent.timestamp) From 22c4f0789fa6a0f535887cca98eb4a949a787da0 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 16:32:53 +0100 Subject: [PATCH 09/74] add tests --- rasa/core/actions/action.py | 24 ++++++++++++++++------ tests/core/test_actions.py | 41 +++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 9bdbb0710dd1..66e0f7b97860 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -317,19 +317,31 @@ class ActionSessionStart(Action): def name(self) -> Text: return ACTION_SESSION_START_NAME - async def run(self, output_channel, nlg, tracker, domain): + async def run( + self, + output_channel: "OutputChannel", + nlg: "NaturalLanguageGenerator", + tracker: "DialogueStateTracker", + domain: "DialogueStateTracker", + ) -> List[Event]: from rasa.core.events import SessionStarted, SlotSet # TODO: config check whether slots should be carried over - # fetch SlotSet events from tracker - # carry over key, value and metadata - slot_set_events: List[Event] = [ + # fetch SlotSet events from tracker and carry over key, value and metadata + # use generator so the timestamps are greater than that of the returned + # `SessionStarted` event + slot_set_events = ( SlotSet(key=event.key, value=event.value, metadata=event.metadata) for event in tracker.events if isinstance(event, SlotSet) - ] + ) - return slot_set_events + [SessionStarted()] + # noinspection PyTypeChecker + return ( + [SessionStarted()] + + list(slot_set_events) + + [ActionExecuted(action_name=ACTION_LISTEN_NAME)] + ) class ActionDefaultFallback(ActionUtterTemplate): diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index 5d7f2ea19782..9c4cc26b5abe 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -12,6 +12,7 @@ ACTION_LISTEN_NAME, ACTION_RESTART_NAME, ACTION_REVERT_FALLBACK_EVENTS_NAME, + ACTION_SESSION_START_NAME, ActionBack, ActionDefaultAskAffirmation, ActionDefaultAskRephrase, @@ -22,9 +23,17 @@ ActionUtterTemplate, ActionRetrieveResponse, RemoteAction, + ActionSessionStart, ) from rasa.core.domain import Domain, InvalidDomain -from rasa.core.events import Restarted, SlotSet, UserUtteranceReverted, BotUttered, Form +from rasa.core.events import ( + Restarted, + SlotSet, + UserUtteranceReverted, + BotUttered, + Form, + SessionStarted, +) from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import ClientResponseError, EndpointConfig @@ -98,18 +107,19 @@ def test_domain_action_instantiation(): instantiated_actions = domain.actions(None) - assert len(instantiated_actions) == 11 + assert len(instantiated_actions) == 12 assert instantiated_actions[0].name() == ACTION_LISTEN_NAME assert instantiated_actions[1].name() == ACTION_RESTART_NAME - assert instantiated_actions[2].name() == ACTION_DEFAULT_FALLBACK_NAME - assert instantiated_actions[3].name() == ACTION_DEACTIVATE_FORM_NAME - assert instantiated_actions[4].name() == ACTION_REVERT_FALLBACK_EVENTS_NAME - assert instantiated_actions[5].name() == (ACTION_DEFAULT_ASK_AFFIRMATION_NAME) - assert instantiated_actions[6].name() == (ACTION_DEFAULT_ASK_REPHRASE_NAME) - assert instantiated_actions[7].name() == ACTION_BACK_NAME - assert instantiated_actions[8].name() == "my_module.ActionTest" - assert instantiated_actions[9].name() == "utter_test" - assert instantiated_actions[10].name() == "respond_test" + assert instantiated_actions[2].name() == ACTION_SESSION_START_NAME + assert instantiated_actions[3].name() == ACTION_DEFAULT_FALLBACK_NAME + assert instantiated_actions[4].name() == ACTION_DEACTIVATE_FORM_NAME + assert instantiated_actions[5].name() == ACTION_REVERT_FALLBACK_EVENTS_NAME + assert instantiated_actions[6].name() == ACTION_DEFAULT_ASK_AFFIRMATION_NAME + assert instantiated_actions[7].name() == ACTION_DEFAULT_ASK_REPHRASE_NAME + assert instantiated_actions[8].name() == ACTION_BACK_NAME + assert instantiated_actions[9].name() == "my_module.ActionTest" + assert instantiated_actions[10].name() == "utter_test" + assert instantiated_actions[11].name() == "respond_test" async def test_remote_action_runs( @@ -482,6 +492,15 @@ async def test_action_restart( assert events == [BotUttered("congrats, you've restarted me!"), Restarted()] +async def test_action_session_start( + default_channel, template_nlg, template_sender_tracker, default_domain +): + events = await ActionSessionStart().run( + default_channel, template_nlg, template_sender_tracker, default_domain + ) + assert events == [SessionStarted(), ActionExecuted()] + + async def test_action_default_fallback( default_channel, default_nlg, default_tracker, default_domain ): From 8b7fbf4b489c8b78c9fbfeb74ed8530663397c63 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 16:52:07 +0100 Subject: [PATCH 10/74] test action_session_start --- rasa/core/actions/action.py | 2 +- rasa/core/events/__init__.py | 5 +++-- tests/core/test_actions.py | 40 +++++++++++++++++++++++++++++++++--- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 66e0f7b97860..22d253d9f96f 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -322,7 +322,7 @@ async def run( output_channel: "OutputChannel", nlg: "NaturalLanguageGenerator", tracker: "DialogueStateTracker", - domain: "DialogueStateTracker", + domain: "Domain", ) -> List[Event]: from rasa.core.events import SessionStarted, SlotSet diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 14b9bd45a519..6192c66c8348 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -525,6 +525,7 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None: ) tracker._reset() + tracker.trigger_followup_action(ACTION_LISTEN_NAME) # noinspection PyProtectedMember @@ -1206,7 +1207,7 @@ def as_story_string(self) -> NoReturn: def apply_to(self, tracker: "DialogueStateTracker") -> None: from rasa.core.actions.action import ( # pytype: disable=pyi-error - ACTION_LISTEN_NAME, + ACTION_SESSION_START_NAME, ) - tracker.trigger_followup_action(ACTION_LISTEN_NAME) + tracker.trigger_followup_action(ACTION_SESSION_START_NAME) diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index 9c4cc26b5abe..a5f530dca20f 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -25,6 +25,7 @@ RemoteAction, ActionSessionStart, ) +from rasa.core.channels import CollectingOutputChannel from rasa.core.domain import Domain, InvalidDomain from rasa.core.events import ( Restarted, @@ -33,6 +34,7 @@ BotUttered, Form, SessionStarted, + ActionExecuted, ) from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator from rasa.core.trackers import DialogueStateTracker @@ -492,13 +494,45 @@ async def test_action_restart( assert events == [BotUttered("congrats, you've restarted me!"), Restarted()] -async def test_action_session_start( - default_channel, template_nlg, template_sender_tracker, default_domain +async def test_action_session_start_without_slots( + default_channel: CollectingOutputChannel, + template_nlg: TemplatedNaturalLanguageGenerator, + template_sender_tracker: DialogueStateTracker, + default_domain: Domain, ): events = await ActionSessionStart().run( default_channel, template_nlg, template_sender_tracker, default_domain ) - assert events == [SessionStarted(), ActionExecuted()] + assert events == [SessionStarted(), ActionExecuted(action_name=ACTION_LISTEN_NAME)] + + +async def test_action_session_start_with_slots( + default_channel: CollectingOutputChannel, + template_nlg: TemplatedNaturalLanguageGenerator, + template_sender_tracker: DialogueStateTracker, + default_domain: Domain, +): + # set a few slots on tracker + slot_set_event_1 = SlotSet("my_slot", "value") + slot_set_event_2 = SlotSet("another-slot", "value2") + for event in [slot_set_event_1, slot_set_event_2]: + template_sender_tracker.update(event) + + events = await ActionSessionStart().run( + default_channel, template_nlg, template_sender_tracker, default_domain + ) + + assert events == [ + SessionStarted(), + slot_set_event_1, + slot_set_event_2, + ActionExecuted(action_name=ACTION_LISTEN_NAME), + ] + + # make sure that the list of events has ascending timestamps + assert all( + events[i].timestamp <= events[i + 1].timestamp for i in range(len(events) - 1) + ) async def test_action_default_fallback( From aee83a75172661d7c407bee9cea7ca9d4b63d54a Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 18:20:04 +0100 Subject: [PATCH 11/74] add tests for 'has_session_expired()' --- rasa/core/events/__init__.py | 24 +++++++------- rasa/core/processor.py | 58 ++++++++++++++++++++++++++-------- rasa/core/tracker_store.py | 2 +- rasa/core/trackers.py | 2 +- tests/core/test_actions.py | 4 +-- tests/core/test_processor.py | 61 ++++++++++++++++++++++++++++++++++++ 6 files changed, 121 insertions(+), 30 deletions(-) diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 6192c66c8348..eb6b06474a69 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -209,7 +209,7 @@ def __init__( intent=None, entities=None, parse_data: Optional[Dict[Text, Any]] = None, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, input_channel: Optional[Text] = None, message_id: Optional[Text] = None, metadata: Optional[Dict] = None, @@ -237,7 +237,7 @@ def __init__( def _from_parse_data( text: Text, parse_data: Dict[Text, Any], - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, input_channel: Optional[Text] = None, message_id: Optional[Text] = None, metadata: Optional[Dict] = None, @@ -441,7 +441,7 @@ def __init__( self, key: Text, value: Optional[Any] = None, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ) -> None: self.key = key @@ -596,7 +596,7 @@ def __init__( trigger_date_time: datetime, name: Optional[Text] = None, kill_on_user_message: bool = True, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ): """Creates the reminder @@ -684,7 +684,7 @@ class ReminderCancelled(Event): def __init__( self, action_name: Text, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ): """ @@ -758,7 +758,7 @@ class StoryExported(Event): def __init__( self, path: Optional[Text] = None, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ): self.path = path @@ -800,7 +800,7 @@ class FollowupAction(Event): def __init__( self, name: Text, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ) -> None: self.action_name = name @@ -906,7 +906,7 @@ def __init__( action_name: Text, policy: Optional[Text] = None, confidence: Optional[float] = None, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict] = None, ): self.action_name = action_name @@ -975,7 +975,7 @@ def __init__( self, text: Optional[Text] = None, data: Optional[Any] = None, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ) -> None: self.text = text @@ -1038,7 +1038,7 @@ class Form(Event): def __init__( self, name: Optional[Text], - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ) -> None: self.name = name @@ -1089,7 +1089,7 @@ class FormValidation(Event): def __init__( self, validate: bool, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ) -> None: self.validate = validate @@ -1134,7 +1134,7 @@ def __init__( action_name: Text, policy: Optional[Text] = None, confidence: Optional[float] = None, - timestamp: Optional[int] = None, + timestamp: Optional[float] = None, metadata: Optional[Dict[Text, Any]] = None, ) -> None: self.action_name = action_name diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 3c3c2323c09b..3c8c6320e8c0 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -292,7 +292,8 @@ def _log_slots(tracker): logger.debug(f"Current slot values: \n{slot_values}") def _log_unseen_features(self, parse_data: Dict[Text, Any]) -> None: - """Check if the NLU interpreter picks up intents or entities that aren't recognized.""" + """Check if the NLU interpreter picks up intents or entities that aren't + recognized.""" domain_is_not_empty = self.domain and not self.domain.is_empty() @@ -582,56 +583,87 @@ def _log_action_on_tracker(self, tracker, action_name, events, policy, confidenc tracker.update(e, self.domain) @staticmethod - def _session_start_timestamp_from(tracker: DialogueStateTracker) -> Optional[float]: + def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[float]: """Retrieve timestamp of the beginning of the last session start for `tracker`. Args: - tracker: pass. + tracker: Tracker to inspect. Returns: + Timestamp of last `SessionStarted` event if available, else timestamp of + oldest event. `None` if no events are available. """ - if not tracker.events: return None # try to fetch the timestamp of the latest `SessionStarted` event - for event in reversed(tracker.events): - if isinstance(event, SessionStarted): - return event.timestamp + last_session_started_event = tracker.get_last_event_for(SessionStarted) + if last_session_started_event: + return last_session_started_event.timestamp # otherwise fetch the timestamp of the first event return tracker.events[0].timestamp - def _tracker_has_valid_session( + def _has_session_expired( self, tracker: DialogueStateTracker, session_length_in_minutes: int ) -> bool: - """Determine whether `tracker` requires a new session. + """Determine whether the latest session in `tracker` has expired. Args: tracker: Tracker to inspect. session_length_in_minutes: Session length in minutes. + Returns: - `True` if a new session is required, else `False`. + `True` if the session has expired, else `False`. """ - session_start_timestamp = self._session_start_timestamp_from(tracker) + session_start_timestamp = self._session_start_timestamp_for(tracker) + # this is a legacy tracker (pre-sessions) if not session_start_timestamp: - return True + return False time_delta_in_seconds = time.time() - session_start_timestamp return time_delta_in_seconds / 60 > session_length_in_minutes + @staticmethod + def _is_new_tracker(tracker: DialogueStateTracker) -> bool: + """Determine whether `tracker` is new. + + A new tracker is a tracker that has either no events, or one event that is an + executed 'action_listen'. + + Args: + tracker: Tracker to inspect. + + Returns: + `True` if the tracker contains no events, `True` if the tracker contains + one event that is an executed "ActionListen", `False` otherwise. + + """ + + if len(tracker.events) > 1: + return False + + last_action_executed_event = tracker.get_last_event_for(ActionExecuted) + + return not tracker.events or ( + last_action_executed_event + and last_action_executed_event.action_name == ACTION_LISTEN_NAME + ) + def _get_tracker( self, sender_id: Text, session_length_in_minutes: int = 60 ) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID tracker = self.tracker_store.get_or_create_tracker(sender_id) - if not self._tracker_has_valid_session(tracker, session_length_in_minutes): + if self._is_new_tracker(tracker) or self._has_session_expired( + tracker, session_length_in_minutes + ): tracker.update(SessionStarted()) return tracker diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 4da553abef59..ceb66a3acb95 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -141,7 +141,7 @@ def get_or_create_tracker( tracker = self.retrieve(sender_id) self.max_event_history = max_event_history if tracker is None: - tracker = self.create_tracker(sender_id, append_action_listen=True,) + tracker = self.create_tracker(sender_id) return tracker def init_tracker(self, sender_id: Text) -> "DialogueStateTracker": diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 02974e680f02..acb9b22ba073 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -466,10 +466,10 @@ def get_last_event_for( def filter_function(e: Event): has_instance = isinstance(e, event_type) excluded = isinstance(e, ActionExecuted) and e.action_name in to_exclude - return has_instance and not excluded filtered = filter(filter_function, reversed(self.applied_events())) + for i in range(skip): next(filtered, None) diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index a5f530dca20f..77dcd01dc257 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -530,9 +530,7 @@ async def test_action_session_start_with_slots( ] # make sure that the list of events has ascending timestamps - assert all( - events[i].timestamp <= events[i + 1].timestamp for i in range(len(events) - 1) - ) + assert sorted(events, key=lambda x: x.timestamp) == events async def test_action_default_fallback( diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index e58d13124389..f4a48c940b00 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -1,3 +1,7 @@ +from typing import Optional + +import time + import aiohttp import asyncio import datetime @@ -19,6 +23,8 @@ ReminderScheduled, Restarted, UserUttered, + SessionStarted, + Event, ) from rasa.core.trackers import DialogueStateTracker from rasa.core.slots import Slot @@ -267,3 +273,58 @@ async def test_reminder_restart( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) assert len(t.events) == 4 # nothing should have been executed + + +# noinspection PyProtectedMember +async def test_is_new_tracker(default_processor: MessageProcessor): + sender_id = uuid.uuid4().hex + + # tracker with just an action_listen is a new tacker + tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) + assert default_processor._is_new_tracker(tracker) + + # adding another event means it's no longer considered a new tracker + tracker.update(UserUttered("hello")) + assert not default_processor._is_new_tracker(tracker) + + # a tracker without any events is also a new tracker + tracker.events.clear() + assert default_processor._is_new_tracker(tracker) + + +# noinspection PyProtectedMember +@pytest.mark.parametrize( + "event_to_apply,session_length_in_minutes,has_expired", + [ + # session start is way in the past + (SessionStarted(timestamp=1), 60, True), + # session start is very recent + (SessionStarted(timestamp=time.time()), 1, False), + # there is no session start event (legacy tracker) + (UserUttered("hello", timestamp=time.time()), 1, False), + # there is no event + (None, 1, False), + ], +) +async def test_has_session_expired( + event_to_apply: Optional[Event], + session_length_in_minutes: int, + has_expired: bool, + default_processor: MessageProcessor, +): + sender_id = uuid.uuid4().hex + + # create new tracker without events + tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) + tracker.events.clear() + + # apply desired event + if event_to_apply: + tracker.update(event_to_apply) + + assert ( + default_processor._has_session_expired( + tracker, session_length_in_minutes=session_length_in_minutes + ) + == has_expired + ) From ec94163f61caf72872b5c522d839d000d4a86523 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 18:25:01 +0100 Subject: [PATCH 12/74] pep484 Union[int, float]->float --- rasa/cli/x.py | 5 ++--- rasa/core/brokers/pika.py | 6 +++--- rasa/core/lock.py | 4 ++-- rasa/core/lock_store.py | 8 ++++---- rasa/core/processor.py | 4 ++-- tests/core/test_lock_store.py | 11 ++++------- 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/rasa/cli/x.py b/rasa/cli/x.py index 087ee3bce751..c845c8b5aa73 100644 --- a/rasa/cli/x.py +++ b/rasa/cli/x.py @@ -2,12 +2,11 @@ import asyncio import importlib.util import logging -import warnings import os import signal import traceback from multiprocessing import get_context -from typing import List, Text, Optional, Tuple, Union, Iterable +from typing import List, Text, Optional, Tuple, Iterable import aiohttp import ruamel.yaml as yaml @@ -328,7 +327,7 @@ def rasa_x(args: argparse.Namespace): async def _pull_runtime_config_from_server( config_endpoint: Optional[Text], attempts: int = 60, - wait_time_between_pulls: Union[int, float] = 5, + wait_time_between_pulls: float = 5, keys: Iterable[Text] = ("endpoints", "credentials"), ) -> Optional[List[Text]]: """Pull runtime config from `config_endpoint`. diff --git a/rasa/core/brokers/pika.py b/rasa/core/brokers/pika.py index 1d44fb13426d..32bf68a36e64 100644 --- a/rasa/core/brokers/pika.py +++ b/rasa/core/brokers/pika.py @@ -29,7 +29,7 @@ def initialise_pika_connection( password: Text, port: Union[Text, int] = 5672, connection_attempts: int = 20, - retry_delay_in_seconds: Union[int, float] = 5, + retry_delay_in_seconds: float = 5, ) -> "BlockingConnection": """Create a Pika `BlockingConnection`. @@ -60,7 +60,7 @@ def _get_pika_parameters( password: Text, port: Union[Text, int] = 5672, connection_attempts: int = 20, - retry_delay_in_seconds: Union[int, float] = 5, + retry_delay_in_seconds: float = 5, ) -> "Parameters": """Create Pika `Parameters`. @@ -135,7 +135,7 @@ def initialise_pika_channel( password: Text, port: Union[Text, int] = 5672, connection_attempts: int = 20, - retry_delay_in_seconds: Union[int, float] = 5, + retry_delay_in_seconds: float = 5, ) -> "BlockingChannel": """Initialise a Pika channel with a durable queue. diff --git a/rasa/core/lock.py b/rasa/core/lock.py index 8c01a85848e8..0e225aab755d 100644 --- a/rasa/core/lock.py +++ b/rasa/core/lock.py @@ -1,9 +1,9 @@ import json import logging from collections import deque -from typing import Text, Optional, Union, Deque, Dict, Any import time +from typing import Text, Optional, Union, Deque, Dict, Any logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ def is_locked(self, ticket_number: int) -> bool: return self.now_serving != ticket_number - def issue_ticket(self, lifetime: Union[float, int]) -> int: + def issue_ticket(self, lifetime: float) -> int: """Issue a new ticket and return its number.""" self.remove_expired_tickets() diff --git a/rasa/core/lock_store.py b/rasa/core/lock_store.py index a3bf1ae709d8..4adedf28a8fc 100644 --- a/rasa/core/lock_store.py +++ b/rasa/core/lock_store.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Text, Optional, Union, AsyncGenerator +from typing import Text, Optional, AsyncGenerator from async_generator import asynccontextmanager @@ -93,7 +93,7 @@ def save_lock(self, lock: TicketLock) -> None: raise NotImplementedError def issue_ticket( - self, conversation_id: Text, lock_lifetime: Union[float, int] = LOCK_LIFETIME + self, conversation_id: Text, lock_lifetime: float = LOCK_LIFETIME ) -> int: """Issue new ticket with `lock_lifetime` for lock associated with `conversation_id`. @@ -124,7 +124,7 @@ async def lock( self, conversation_id: Text, lock_lifetime: int = LOCK_LIFETIME, - wait_time_in_seconds: Union[int, float] = 1, + wait_time_in_seconds: float = 1, ) -> AsyncGenerator[TicketLock, None]: """Acquire lock with lifetime `lock_lifetime`for `conversation_id`. @@ -146,7 +146,7 @@ async def _acquire_lock( self, conversation_id: Text, ticket: int, - wait_time_in_seconds: Union[int, float], + wait_time_in_seconds: float, ) -> TicketLock: while True: diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 3c8c6320e8c0..c0a22ca2806e 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -3,7 +3,7 @@ import logging import os from types import LambdaType -from typing import Any, Dict, List, Optional, Text, Tuple +from typing import Any, Dict, List, Optional, Text, Tuple, Union import numpy as np import time @@ -607,7 +607,7 @@ def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[floa return tracker.events[0].timestamp def _has_session_expired( - self, tracker: DialogueStateTracker, session_length_in_minutes: int + self, tracker: DialogueStateTracker, session_length_in_minutes: float ) -> bool: """Determine whether the latest session in `tracker` has expired. diff --git a/tests/core/test_lock_store.py b/tests/core/test_lock_store.py index 47989a23bd96..33ca2f8503a8 100644 --- a/tests/core/test_lock_store.py +++ b/tests/core/test_lock_store.py @@ -1,15 +1,14 @@ import asyncio import copy import os -from typing import Union, Text -from unittest.mock import patch import numpy as np import pytest import time from _pytest.tmpdir import TempdirFactory +from typing import Text +from unittest.mock import patch -import rasa.utils.io from rasa.core.agent import Agent from rasa.core.channels import UserMessage from rasa.core.constants import INTENT_MESSAGE_PREFIX, DEFAULT_LOCK_LIFETIME @@ -122,9 +121,7 @@ def test_lock_expiration(): def test_ticket_exists_error(): def mocked_issue_ticket( - self, - conversation_id: Text, - lock_lifetime: Union[float, int] = DEFAULT_LOCK_LIFETIME, + self, conversation_id: Text, lock_lifetime: float = DEFAULT_LOCK_LIFETIME, ) -> None: # mock LockStore.issue_ticket() so it issues two tickets for the same # conversation ID simultaneously @@ -177,7 +174,7 @@ async def test_message_order(tmpdir_factory: TempdirFactory, default_agent: Agen # can check the order later on. We don't need the return value of this method so # we'll just return None. async def mocked_handle_message( - self, message: UserMessage, wait: Union[int, float] + self, message: UserMessage, wait: float ) -> None: # write incoming message to file with open(str(incoming_order_file), "a+") as f_0: From b24d69f5bddfc17df9b56d4dab7aa1cdce294e12 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 18:25:09 +0100 Subject: [PATCH 13/74] black --- rasa/core/lock_store.py | 5 +---- tests/core/test_lock_store.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/rasa/core/lock_store.py b/rasa/core/lock_store.py index 4adedf28a8fc..ace43ef58122 100644 --- a/rasa/core/lock_store.py +++ b/rasa/core/lock_store.py @@ -143,10 +143,7 @@ async def lock( self.cleanup(conversation_id, ticket) async def _acquire_lock( - self, - conversation_id: Text, - ticket: int, - wait_time_in_seconds: float, + self, conversation_id: Text, ticket: int, wait_time_in_seconds: float, ) -> TicketLock: while True: diff --git a/tests/core/test_lock_store.py b/tests/core/test_lock_store.py index 33ca2f8503a8..50ac789419d3 100644 --- a/tests/core/test_lock_store.py +++ b/tests/core/test_lock_store.py @@ -173,9 +173,7 @@ async def test_message_order(tmpdir_factory: TempdirFactory, default_agent: Agen # record messages as they come and and as they're processed in files so we # can check the order later on. We don't need the return value of this method so # we'll just return None. - async def mocked_handle_message( - self, message: UserMessage, wait: float - ) -> None: + async def mocked_handle_message(self, message: UserMessage, wait: float) -> None: # write incoming message to file with open(str(incoming_order_file), "a+") as f_0: f_0.write(message.text + "\n") From 02fc02ddffffbaae951072c47eb6787633f9c438 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 2 Dec 2019 20:48:32 +0100 Subject: [PATCH 14/74] do not use SessionStarted events in prediction states --- rasa/core/actions/action.py | 6 ++--- rasa/core/processor.py | 34 +++++++++++++-------------- rasa/core/run.py | 2 +- rasa/core/tracker_store.py | 8 ++++++- rasa/core/trackers.py | 7 +++--- tests/core/test_processor.py | 45 ++++++++++++++++++++++++++---------- 6 files changed, 64 insertions(+), 38 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 22d253d9f96f..9580e74059b4 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -310,9 +310,9 @@ async def run(self, output_channel, nlg, tracker, domain): class ActionSessionStart(Action): - """Resets the tracker to its initial state. + """Applies. - Utters the restart template if available.""" + Utters the 'session start' template if available.""" def name(self) -> Text: return ACTION_SESSION_START_NAME @@ -338,7 +338,7 @@ async def run( # noinspection PyTypeChecker return ( - [SessionStarted()] + [ActionExecuted(action_name=ACTION_SESSION_START_NAME)] + list(slot_set_events) + [ActionExecuted(action_name=ACTION_LISTEN_NAME)] ) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index c0a22ca2806e..290a2c9193bf 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -9,7 +9,7 @@ import time from rasa.core import jobs -from rasa.core.actions.action import Action +from rasa.core.actions.action import Action, ACTION_SESSION_START_NAME from rasa.core.actions.action import ACTION_LISTEN_NAME, ActionExecutionRejection from rasa.core.channels.channel import ( CollectingOutputChannel, @@ -97,8 +97,11 @@ async def handle_message( await self._predict_and_execute_next_action(message, tracker) # save tracker state to continue conversation from this state self._save_tracker(tracker) - + print("have tracker events") + for e in tracker.events: + print(e) if isinstance(message.output_channel, CollectingOutputChannel): + print("done in processor") return message.output_channel.messages else: return None @@ -596,7 +599,8 @@ def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[floa """ if not tracker.events: - return None + # this is a legacy tracker (pre-sessions) + return time.time() # try to fetch the timestamp of the latest `SessionStarted` event last_session_started_event = tracker.get_last_event_for(SessionStarted) @@ -604,6 +608,7 @@ def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[floa return last_session_started_event.timestamp # otherwise fetch the timestamp of the first event + # this also is a legacy tracker (pre-sessions) return tracker.events[0].timestamp def _has_session_expired( @@ -621,16 +626,12 @@ def _has_session_expired( """ session_start_timestamp = self._session_start_timestamp_for(tracker) - # this is a legacy tracker (pre-sessions) - if not session_start_timestamp: - return False - time_delta_in_seconds = time.time() - session_start_timestamp return time_delta_in_seconds / 60 > session_length_in_minutes @staticmethod - def _is_new_tracker(tracker: DialogueStateTracker) -> bool: + def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: """Determine whether `tracker` is new. A new tracker is a tracker that has either no events, or one event that is an @@ -645,14 +646,10 @@ def _is_new_tracker(tracker: DialogueStateTracker) -> bool: """ - if len(tracker.events) > 1: - return False - - last_action_executed_event = tracker.get_last_event_for(ActionExecuted) - - return not tracker.events or ( - last_action_executed_event - and last_action_executed_event.action_name == ACTION_LISTEN_NAME + return not any( + isinstance(event, ActionExecuted) + and event.action_name == ACTION_SESSION_START_NAME + for event in tracker.events ) def _get_tracker( @@ -661,9 +658,10 @@ def _get_tracker( sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID tracker = self.tracker_store.get_or_create_tracker(sender_id) - if self._is_new_tracker(tracker) or self._has_session_expired( - tracker, session_length_in_minutes + if self._is_legacy_tracker(tracker) or self._has_session_expired( + tracker, 0.083 ): + print("has expired") tracker.update(SessionStarted()) return tracker diff --git a/rasa/core/run.py b/rasa/core/run.py index 6c37a06995ba..47292ce45b94 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -129,7 +129,7 @@ async def run_cmdline_io(running_app: Sanic): ) logger.info("Killing Sanic server now.") - running_app.stop() # kill the sanic serverx + running_app.stop() # kill the sanic server app.add_task(run_cmdline_io) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index ceb66a3acb95..e71aca370bae 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -153,12 +153,18 @@ def init_tracker(self, sender_id: Text) -> "DialogueStateTracker": ) def create_tracker( - self, sender_id: Text, append_action_listen: bool = True, + self, + sender_id: Text, + append_action_listen: bool = True, + should_append_session_started=True, ) -> DialogueStateTracker: """Creates a new tracker for the sender_id. The tracker is initially listening. """ tracker = self.init_tracker(sender_id) if tracker: + if should_append_session_started: + tracker.update(ActionExecuted(ACTION_SESSION_START_NAME)) + if append_action_listen: tracker.update(ActionExecuted(ACTION_LISTEN_NAME)) diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index acb9b22ba073..8e1db1d4e20c 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -17,6 +17,7 @@ UserUtteranceReverted, BotUttered, Form, + SessionStarted, ) from rasa.core.domain import Domain # pytype: disable=pyi-error from rasa.core.slots import Slot @@ -261,13 +262,13 @@ def init_copy(self) -> "DialogueStateTracker": UserMessage.DEFAULT_SENDER_ID, self.slots.values(), self._max_event_history ) + # TODO: exclude SessionStart from prior states def generate_all_prior_trackers( self, ) -> Generator["DialogueStateTracker", None, None]: """Returns a generator of the previous trackers of this tracker. - The resulting array is representing - the trackers before each action.""" + The resulting array is representing the trackers before each action.""" tracker = self.init_copy() @@ -344,7 +345,7 @@ def undo_till_previous(event_type, done_events): applied_events = [] for event in self.events: - if isinstance(event, Restarted): + if isinstance(event, (Restarted, SessionStarted)): applied_events = [] elif isinstance(event, ActionReverted): undo_till_previous(ActionExecuted, applied_events) diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index f4a48c940b00..ece0c5eda0f7 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List import time @@ -14,6 +14,11 @@ import rasa.utils.io from rasa.core import jobs +from rasa.core.actions.action import ( + ACTION_LISTEN_NAME, + ActionSessionStart, + ACTION_SESSION_START_NAME, +) from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, UserMessage from rasa.core.events import ( @@ -203,7 +208,7 @@ async def test_reminder_aborted( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) - assert len(t.events) == 3 # nothing should have been executed + assert len(t.events) == 4 # nothing should have been executed async def test_reminder_cancelled( @@ -272,24 +277,40 @@ async def test_reminder_restart( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) - assert len(t.events) == 4 # nothing should have been executed + assert len(t.events) == 5 # nothing should have been executed # noinspection PyProtectedMember -async def test_is_new_tracker(default_processor: MessageProcessor): +@pytest.mark.parametrize( + "events_to_apply,is_legacy", + [ + # just an action listen means it's legacy + ([ActionExecuted(ACTION_LISTEN_NAME)], True), + # action listen and session start means it isn't legacy + ( + [ + ActionExecuted(ACTION_SESSION_START_NAME), + ActionExecuted(ACTION_LISTEN_NAME), + ], + False, + ), + # just a single event means it's legacy + ([UserUttered("hello")], True), + ], +) +async def test_is_new_tracker( + events_to_apply: List[Event], is_legacy: bool, default_processor: MessageProcessor, +): sender_id = uuid.uuid4().hex - # tracker with just an action_listen is a new tacker + # new tracker without events tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) - assert default_processor._is_new_tracker(tracker) + tracker.events.clear() - # adding another event means it's no longer considered a new tracker - tracker.update(UserUttered("hello")) - assert not default_processor._is_new_tracker(tracker) + for event in events_to_apply: + tracker.update(event) - # a tracker without any events is also a new tracker - tracker.events.clear() - assert default_processor._is_new_tracker(tracker) + assert default_processor._is_legacy_tracker(tracker) == is_legacy # noinspection PyProtectedMember From 039992e1f9fbcfd4aaabd3738e1a57ec2416aa7e Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 3 Dec 2019 08:58:56 +0100 Subject: [PATCH 15/74] new tracker method to get last executed action for action name --- rasa/core/actions/action.py | 1 + rasa/core/processor.py | 39 ++++++++++++++++++------------------ rasa/core/trackers.py | 19 ++++++++++++++++-- tests/core/test_actions.py | 7 +++++-- tests/core/test_processor.py | 16 +++++++-------- tests/core/test_trackers.py | 29 +++++++++++++++++++++++++++ 6 files changed, 80 insertions(+), 31 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 9580e74059b4..6b43b14fad9c 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -338,6 +338,7 @@ async def run( # noinspection PyTypeChecker return ( + # TODO: should this rather return a `SessionStarted()` event rather than a SessionExecuted()? [ActionExecuted(action_name=ACTION_SESSION_START_NAME)] + list(slot_set_events) + [ActionExecuted(action_name=ACTION_LISTEN_NAME)] diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 290a2c9193bf..4bf3cb07e69a 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -585,8 +585,9 @@ def _log_action_on_tracker(self, tracker, action_name, events, policy, confidenc e.timestamp = time.time() tracker.update(e, self.domain) - @staticmethod - def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[float]: + def _session_start_timestamp_for( + self, tracker: DialogueStateTracker + ) -> Optional[float]: """Retrieve timestamp of the beginning of the last session start for `tracker`. @@ -595,17 +596,19 @@ def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[floa Returns: Timestamp of last `SessionStarted` event if available, else timestamp of - oldest event. `None` if no events are available. + oldest event. Current time if no events are available. """ if not tracker.events: - # this is a legacy tracker (pre-sessions) + # this is a legacy tracker (pre-sessions), return current time return time.time() # try to fetch the timestamp of the latest `SessionStarted` event - last_session_started_event = tracker.get_last_event_for(SessionStarted) - if last_session_started_event: - return last_session_started_event.timestamp + last_executed_session_started_action = tracker.get_last_executed( + ACTION_SESSION_START_NAME + ) + if last_executed_session_started_action: + return last_executed_session_started_action.timestamp # otherwise fetch the timestamp of the first event # this also is a legacy tracker (pre-sessions) @@ -621,7 +624,7 @@ def _has_session_expired( session_length_in_minutes: Session length in minutes. Returns: - `True` if the session has expired, else `False`. + `True` if the session in `tracker` has expired, `False` otherwise. """ session_start_timestamp = self._session_start_timestamp_for(tracker) @@ -632,26 +635,25 @@ def _has_session_expired( @staticmethod def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: - """Determine whether `tracker` is new. + """Determine whether `tracker` is a legacy tracker. - A new tracker is a tracker that has either no events, or one event that is an - executed 'action_listen'. + A legacy tracker is a tracker that has been created before the introduction + of sessions in release 1.7.0. Args: tracker: Tracker to inspect. Returns: - `True` if the tracker contains no events, `True` if the tracker contains - one event that is an executed "ActionListen", `False` otherwise. + `True` if the tracker contains `SessionStarted` event, `False` otherwise. """ - return not any( - isinstance(event, ActionExecuted) - and event.action_name == ACTION_SESSION_START_NAME - for event in tracker.events + last_executed_session_started_action = tracker.get_last_executed( + ACTION_SESSION_START_NAME ) + return last_executed_session_started_action is None + def _get_tracker( self, sender_id: Text, session_length_in_minutes: int = 60 ) -> Optional[DialogueStateTracker]: @@ -659,9 +661,8 @@ def _get_tracker( tracker = self.tracker_store.get_or_create_tracker(sender_id) if self._is_legacy_tracker(tracker) or self._has_session_expired( - tracker, 0.083 + tracker, session_length_in_minutes ): - print("has expired") tracker.update(SessionStarted()) return tracker diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 8e1db1d4e20c..7c1bc719fb73 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -5,7 +5,10 @@ from typing import Dict, Text, Any, Optional, Iterator, Generator, Type, List, Deque from rasa.core import events # pytype: disable=pyi-error -from rasa.core.actions.action import ACTION_LISTEN_NAME # pytype: disable=pyi-error +from rasa.core.actions.action import ( + ACTION_LISTEN_NAME, + ACTION_SESSION_START_NAME, +) # pytype: disable=pyi-error from rasa.core.conversation import Dialogue # pytype: disable=pyi-error from rasa.core.events import ( # pytype: disable=pyi-error UserUttered, @@ -487,11 +490,23 @@ def last_executed_action_has(self, name: Text, skip=0) -> bool: `True` if last executed action had name `name`, otherwise `False`. """ - last = self.get_last_event_for( + last: Optional[ActionExecuted] = self.get_last_event_for( ActionExecuted, action_names_to_exclude=[ACTION_LISTEN_NAME], skip=skip ) return last is not None and last.action_name == name + def get_last_executed(self, action_name: Text) -> Optional[ActionExecuted]: + """Get the last executed `action_name`. + + Returns: + The last `ActionExecuted` marking a session start if available, + otherwise `None`. + + """ + for event in reversed(self.applied_events()): + if isinstance(event, ActionExecuted) and event.action_name == action_name: + return event + ### # Internal methods for the modification of the trackers state. Should # only be called by events, not directly. Rather update the tracker diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index 77dcd01dc257..fd708d33ba7e 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -503,7 +503,10 @@ async def test_action_session_start_without_slots( events = await ActionSessionStart().run( default_channel, template_nlg, template_sender_tracker, default_domain ) - assert events == [SessionStarted(), ActionExecuted(action_name=ACTION_LISTEN_NAME)] + assert events == [ + ActionExecuted(action_name=ACTION_SESSION_START_NAME), + ActionExecuted(action_name=ACTION_LISTEN_NAME), + ] async def test_action_session_start_with_slots( @@ -523,7 +526,7 @@ async def test_action_session_start_with_slots( ) assert events == [ - SessionStarted(), + ActionExecuted(action_name=ACTION_SESSION_START_NAME), slot_set_event_1, slot_set_event_2, ActionExecuted(action_name=ACTION_LISTEN_NAME), diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index ece0c5eda0f7..2119d21b8a33 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -280,17 +280,16 @@ async def test_reminder_restart( assert len(t.events) == 5 # nothing should have been executed -# noinspection PyProtectedMember @pytest.mark.parametrize( "events_to_apply,is_legacy", [ # just an action listen means it's legacy - ([ActionExecuted(ACTION_LISTEN_NAME)], True), - # action listen and session start means it isn't legacy + ([ActionExecuted(action_name=ACTION_LISTEN_NAME)], True), + # action listen and session at the beginning start means it isn't legacy ( [ - ActionExecuted(ACTION_SESSION_START_NAME), - ActionExecuted(ACTION_LISTEN_NAME), + ActionExecuted(action_name=ACTION_SESSION_START_NAME), + ActionExecuted(action_name=ACTION_LISTEN_NAME), ], False, ), @@ -298,22 +297,22 @@ async def test_reminder_restart( ([UserUttered("hello")], True), ], ) -async def test_is_new_tracker( +async def test_is_legacy_tracker( events_to_apply: List[Event], is_legacy: bool, default_processor: MessageProcessor, ): sender_id = uuid.uuid4().hex - # new tracker without events + # create a new tracker without events tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) tracker.events.clear() for event in events_to_apply: tracker.update(event) + # noinspection PyProtectedMember assert default_processor._is_legacy_tracker(tracker) == is_legacy -# noinspection PyProtectedMember @pytest.mark.parametrize( "event_to_apply,session_length_in_minutes,has_expired", [ @@ -343,6 +342,7 @@ async def test_has_session_expired( if event_to_apply: tracker.update(event_to_apply) + # noinspection PyProtectedMember assert ( default_processor._has_session_expired( tracker, session_length_in_minutes=session_length_in_minutes diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index ba1994cb06af..3e82bd72a87a 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -2,6 +2,7 @@ import logging import os import tempfile +from typing import List, Optional import fakeredis import pytest @@ -18,6 +19,7 @@ ActionReverted, UserUtteranceReverted, SessionStarted, + Event, ) from rasa.core.tracker_store import ( InMemoryTrackerStore, @@ -605,3 +607,30 @@ def test_tracker_without_slots(key, value, caplog): v = tracker.get_slot(key) assert v == value assert len(caplog.records) == 0 + + +@pytest.mark.parametrize( + "events,index_of_last_executed_event", + [ + ([ActionExecuted("one")], 0), + ([ActionExecuted("a"), ActionExecuted("b")], 1), + ([ActionExecuted("first"), UserUttered("b"), ActionExecuted("second")], 2), + ([ActionExecuted("this"), UserUttered("b")], 0), + ([UserUttered("b")], None), # no `ActionExecuted` event + ], +) +def test_get_last_executed(events: List[Event], index_of_last_executed_event: int): + tracker = get_tracker(events) + + # noinspection PyTypeChecker + expected_event: Optional[ActionExecuted] = events[ + index_of_last_executed_event + ] if index_of_last_executed_event is not None else None + + fetched_event = ( + tracker.get_last_executed(expected_event.action_name) + if expected_event + else None + ) + + assert expected_event == fetched_event From bebe7734796e58641977837bd6492e4f4fe1508b Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 3 Dec 2019 09:00:06 +0100 Subject: [PATCH 16/74] update todo --- rasa/core/actions/action.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 6b43b14fad9c..afbc62494712 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -338,7 +338,8 @@ async def run( # noinspection PyTypeChecker return ( - # TODO: should this rather return a `SessionStarted()` event rather than a SessionExecuted()? + # TODO: should return a `SessionStarted()` event rather than an + # ActionExecuted(action_name=ACTION_SESSION_START_NAME)? [ActionExecuted(action_name=ACTION_SESSION_START_NAME)] + list(slot_set_events) + [ActionExecuted(action_name=ACTION_LISTEN_NAME)] From 7ad62cd8e59032b491e120d7536c98cd19d86129 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 3 Dec 2019 19:01:44 +0100 Subject: [PATCH 17/74] wip --- rasa/core/actions/action.py | 13 ++--- rasa/core/events/__init__.py | 2 + rasa/core/processor.py | 93 +++++++++++++++++++++++++----------- rasa/core/tracker_store.py | 11 +++-- rasa/core/trackers.py | 10 ++-- tests/core/test_processor.py | 27 ++++------- tests/core/test_trackers.py | 19 ++++---- 7 files changed, 102 insertions(+), 73 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index afbc62494712..af339256c611 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -324,9 +324,9 @@ async def run( tracker: "DialogueStateTracker", domain: "Domain", ) -> List[Event]: - from rasa.core.events import SessionStarted, SlotSet + from rasa.core.events import SessionStarted, SlotSet, FollowupAction - # TODO: config check whether slots should be carried over + # TODO: check in domain whether slots should be carried over # fetch SlotSet events from tracker and carry over key, value and metadata # use generator so the timestamps are greater than that of the returned # `SessionStarted` event @@ -335,14 +335,15 @@ async def run( for event in tracker.events if isinstance(event, SlotSet) ) + slot_set_events = [ + SlotSet(f"dummy slot {i}", f"dummy slot {i}") for i in range(4) + ] # noinspection PyTypeChecker return ( - # TODO: should return a `SessionStarted()` event rather than an - # ActionExecuted(action_name=ACTION_SESSION_START_NAME)? - [ActionExecuted(action_name=ACTION_SESSION_START_NAME)] + [SessionStarted()] + list(slot_set_events) - + [ActionExecuted(action_name=ACTION_LISTEN_NAME)] + + [FollowupAction(ACTION_LISTEN_NAME)] ) diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index eb6b06474a69..0b390d51e136 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -1210,4 +1210,6 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None: ACTION_SESSION_START_NAME, ) + # noinspection PyProtectedMember + tracker._reset() tracker.trigger_followup_action(ACTION_SESSION_START_NAME) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 4bf3cb07e69a..834517356013 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -1,3 +1,4 @@ +import asyncio import json import warnings import logging @@ -9,7 +10,11 @@ import time from rasa.core import jobs -from rasa.core.actions.action import Action, ACTION_SESSION_START_NAME +from rasa.core.actions.action import ( + Action, + ACTION_SESSION_START_NAME, + ActionSessionStart, +) from rasa.core.actions.action import ACTION_LISTEN_NAME, ActionExecutionRejection from rasa.core.channels.channel import ( CollectingOutputChannel, @@ -97,11 +102,11 @@ async def handle_message( await self._predict_and_execute_next_action(message, tracker) # save tracker state to continue conversation from this state self._save_tracker(tracker) - print("have tracker events") + print("\n\nhave tracker events") for e in tracker.events: print(e) + print("done in processor\n\n") if isinstance(message.output_channel, CollectingOutputChannel): - print("done in processor") return message.output_channel.messages else: return None @@ -138,6 +143,26 @@ def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: "tracker": tracker.current_state(EventVerbosity.AFTER_RESTART), } + async def _update_tracker_session( + self, + tracker: DialogueStateTracker, + output_channel: OutputChannel, + session_length_in_minutes: float, + ) -> None: + + if self._is_legacy_tracker(tracker) or self._has_session_expired( + tracker, session_length_in_minutes + ): + logger.debug( + f"Starting a new session for conversation ID '{tracker.sender_id}'." + ) + await self._run_action( + action=self._get_action(ACTION_SESSION_START_NAME), + tracker=tracker, + output_channel=output_channel, + nlg=self.nlg, + ) + async def log_message( self, message: UserMessage, should_save_tracker: bool = True ) -> Optional[DialogueStateTracker]: @@ -154,7 +179,11 @@ async def log_message( # we have a Tracker instance for each user # which maintains conversation state tracker = self._get_tracker(message.sender_id) + if tracker: + # TODO: get session length from domain + await self._update_tracker_session(tracker, message.output_channel, 1) + await self._handle_message_with_tracker(message, tracker) if should_save_tracker: @@ -253,7 +282,8 @@ async def handle_reminder( """Handle a reminder that is triggered asynchronously.""" tracker = self._get_tracker(sender_id) - + print("fetched", sender_id, tracker.as_dialogue()) + raise if not tracker: logger.warning( f"Failed to retrieve or create tracker for sender '{sender_id}'." @@ -585,9 +615,8 @@ def _log_action_on_tracker(self, tracker, action_name, events, policy, confidenc e.timestamp = time.time() tracker.update(e, self.domain) - def _session_start_timestamp_for( - self, tracker: DialogueStateTracker - ) -> Optional[float]: + @staticmethod + def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[float]: """Retrieve timestamp of the beginning of the last session start for `tracker`. @@ -603,12 +632,14 @@ def _session_start_timestamp_for( # this is a legacy tracker (pre-sessions), return current time return time.time() - # try to fetch the timestamp of the latest `SessionStarted` event - last_executed_session_started_action = tracker.get_last_executed( - ACTION_SESSION_START_NAME - ) - if last_executed_session_started_action: - return last_executed_session_started_action.timestamp + # return last_executed_session_started_action.timestamp + + last_session_started_event = tracker.get_last_session_started_event() + + for e in tracker.applied_events(): + print(e) + if last_session_started_event: + return last_session_started_event.timestamp # otherwise fetch the timestamp of the first event # this also is a legacy tracker (pre-sessions) @@ -631,14 +662,22 @@ def _has_session_expired( time_delta_in_seconds = time.time() - session_start_timestamp - return time_delta_in_seconds / 60 > session_length_in_minutes + has_expired = time_delta_in_seconds / 60 > session_length_in_minutes + + if has_expired: + logger.debug( + f"The latest session for conversation ID {tracker.sender_id} has " + f"expired." + ) + + return has_expired @staticmethod def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: """Determine whether `tracker` is a legacy tracker. A legacy tracker is a tracker that has been created before the introduction - of sessions in release 1.7.0. + of sessions in release 1.6.0. Args: tracker: Tracker to inspect. @@ -647,24 +686,24 @@ def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: `True` if the tracker contains `SessionStarted` event, `False` otherwise. """ + last_session_started_event = tracker.get_last_session_started_event() - last_executed_session_started_action = tracker.get_last_executed( - ACTION_SESSION_START_NAME - ) + is_legacy_tracker = last_session_started_event is None - return last_executed_session_started_action is None + if is_legacy_tracker: + logger.debug( + f"Tracker for conversation ID '{tracker.sender_id}' is a legacy " + f"tracker. A legacy tracker is a tracker that contains no " + f"'SessionStarted' and was last saved before the introduction of " + f"tracker sessions in release 1.6.0." + ) - def _get_tracker( - self, sender_id: Text, session_length_in_minutes: int = 60 - ) -> Optional[DialogueStateTracker]: + return is_legacy_tracker + + def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID tracker = self.tracker_store.get_or_create_tracker(sender_id) - if self._is_legacy_tracker(tracker) or self._has_session_expired( - tracker, session_length_in_minutes - ): - tracker.update(SessionStarted()) - return tracker def _save_tracker(self, tracker: DialogueStateTracker) -> None: diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index e71aca370bae..bfc3fcfb76a4 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -7,7 +7,7 @@ import typing from datetime import datetime, timezone -from typing import Iterator, Optional, Text, Iterable, Union, Dict, Callable, List +from typing import Iterator, Optional, Text, Iterable, Union, Dict, Callable import itertools from boto3.dynamodb.conditions import Key @@ -15,8 +15,10 @@ # noinspection PyPep8Naming from time import sleep -from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME +from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.brokers.event_channel import EventChannel +from rasa.core.events import SessionStarted + from rasa.core.conversation import Dialogue from rasa.core.domain import Domain from rasa.core.trackers import ActionExecuted, DialogueStateTracker, EventVerbosity @@ -29,7 +31,6 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session import boto3 - from rasa.core.events import Event logger = logging.getLogger(__name__) @@ -156,14 +157,14 @@ def create_tracker( self, sender_id: Text, append_action_listen: bool = True, - should_append_session_started=True, + should_append_session_started: bool = True, ) -> DialogueStateTracker: """Creates a new tracker for the sender_id. The tracker is initially listening. """ tracker = self.init_tracker(sender_id) if tracker: if should_append_session_started: - tracker.update(ActionExecuted(ACTION_SESSION_START_NAME)) + tracker.update(SessionStarted()) if append_action_listen: tracker.update(ActionExecuted(ACTION_LISTEN_NAME)) diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 7c1bc719fb73..f4d18db17043 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -495,16 +495,16 @@ def last_executed_action_has(self, name: Text, skip=0) -> bool: ) return last is not None and last.action_name == name - def get_last_executed(self, action_name: Text) -> Optional[ActionExecuted]: - """Get the last executed `action_name`. + def get_last_session_started_event(self) -> Optional[SessionStarted]: + """Get the last `SessionStarted` event. Returns: - The last `ActionExecuted` marking a session start if available, + The last `SessionStarted` marking a session start if available, otherwise `None`. """ - for event in reversed(self.applied_events()): - if isinstance(event, ActionExecuted) and event.action_name == action_name: + for event in reversed(self.events): + if isinstance(event, SessionStarted): return event ### diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 2119d21b8a33..aa04c9bc3121 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -2,7 +2,6 @@ import time -import aiohttp import asyncio import datetime import uuid @@ -12,13 +11,8 @@ from unittest.mock import patch -import rasa.utils.io from rasa.core import jobs -from rasa.core.actions.action import ( - ACTION_LISTEN_NAME, - ActionSessionStart, - ACTION_SESSION_START_NAME, -) +from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, UserMessage from rasa.core.events import ( @@ -33,14 +27,11 @@ ) from rasa.core.trackers import DialogueStateTracker from rasa.core.slots import Slot -from rasa.core.processor import MessageProcessor from rasa.core.interpreter import RasaNLUHttpInterpreter from rasa.core.processor import MessageProcessor from rasa.utils.endpoints import EndpointConfig -from tests.utilities import json_of_latest_request, latest_request +from tests.utilities import latest_request -from tests.core.conftest import DEFAULT_DOMAIN_PATH_WITH_SLOTS -from rasa.core.domain import Domain import logging @@ -166,12 +157,16 @@ async def test_reminder_scheduled( tracker.update(reminder) default_processor.tracker_store.save(tracker) + + print("before", default_processor._get_tracker(sender_id).as_dialogue()) + print("before", default_processor.tracker_store.retrieve(sender_id).as_dialogue()) await default_processor.handle_reminder( reminder, sender_id, default_channel, default_processor.nlg ) # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) + assert t.events[-4] == UserUttered(None) assert t.events[-3] == ActionExecuted("utter_greet") assert t.events[-2] == BotUttered( @@ -208,7 +203,7 @@ async def test_reminder_aborted( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) - assert len(t.events) == 4 # nothing should have been executed + assert len(t.events) == 2 # nothing should have been executed async def test_reminder_cancelled( @@ -286,13 +281,7 @@ async def test_reminder_restart( # just an action listen means it's legacy ([ActionExecuted(action_name=ACTION_LISTEN_NAME)], True), # action listen and session at the beginning start means it isn't legacy - ( - [ - ActionExecuted(action_name=ACTION_SESSION_START_NAME), - ActionExecuted(action_name=ACTION_LISTEN_NAME), - ], - False, - ), + ([SessionStarted(), ActionExecuted(action_name=ACTION_LISTEN_NAME)], False), # just a single event means it's legacy ([UserUttered("hello")], True), ], diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index 3e82bd72a87a..a0d648703e9c 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -612,14 +612,15 @@ def test_tracker_without_slots(key, value, caplog): @pytest.mark.parametrize( "events,index_of_last_executed_event", [ - ([ActionExecuted("one")], 0), - ([ActionExecuted("a"), ActionExecuted("b")], 1), - ([ActionExecuted("first"), UserUttered("b"), ActionExecuted("second")], 2), - ([ActionExecuted("this"), UserUttered("b")], 0), - ([UserUttered("b")], None), # no `ActionExecuted` event + ([ActionExecuted("one")], None), # no SessionStarted event + ([ActionExecuted("a"), SessionStarted()], 1), + ([ActionExecuted("first"), UserUttered("b"), SessionStarted()], 2), + ([SessionStarted(), UserUttered("b")], 0), ], ) -def test_get_last_executed(events: List[Event], index_of_last_executed_event: int): +def test_last_session_started_event( + events: List[Event], index_of_last_executed_event: int +): tracker = get_tracker(events) # noinspection PyTypeChecker @@ -627,10 +628,6 @@ def test_get_last_executed(events: List[Event], index_of_last_executed_event: in index_of_last_executed_event ] if index_of_last_executed_event is not None else None - fetched_event = ( - tracker.get_last_executed(expected_event.action_name) - if expected_event - else None - ) + fetched_event = tracker.get_last_session_started_event() if expected_event else None assert expected_event == fetched_event From d08cbfed180a2d36513fc20e924ec2b22367e112 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 3 Dec 2019 19:03:24 +0100 Subject: [PATCH 18/74] remove debug messages --- rasa/core/processor.py | 16 +++------------- tests/core/test_processor.py | 2 -- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 834517356013..5b38e8184847 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -102,10 +102,7 @@ async def handle_message( await self._predict_and_execute_next_action(message, tracker) # save tracker state to continue conversation from this state self._save_tracker(tracker) - print("\n\nhave tracker events") - for e in tracker.events: - print(e) - print("done in processor\n\n") + if isinstance(message.output_channel, CollectingOutputChannel): return message.output_channel.messages else: @@ -282,8 +279,7 @@ async def handle_reminder( """Handle a reminder that is triggered asynchronously.""" tracker = self._get_tracker(sender_id) - print("fetched", sender_id, tracker.as_dialogue()) - raise + if not tracker: logger.warning( f"Failed to retrieve or create tracker for sender '{sender_id}'." @@ -632,12 +628,8 @@ def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[floa # this is a legacy tracker (pre-sessions), return current time return time.time() - # return last_executed_session_started_action.timestamp - last_session_started_event = tracker.get_last_session_started_event() - for e in tracker.applied_events(): - print(e) if last_session_started_event: return last_session_started_event.timestamp @@ -702,9 +694,7 @@ def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID - tracker = self.tracker_store.get_or_create_tracker(sender_id) - - return tracker + return self.tracker_store.get_or_create_tracker(sender_id) def _save_tracker(self, tracker: DialogueStateTracker) -> None: self.tracker_store.save(tracker) diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index aa04c9bc3121..928e9ccc1b5f 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -158,8 +158,6 @@ async def test_reminder_scheduled( default_processor.tracker_store.save(tracker) - print("before", default_processor._get_tracker(sender_id).as_dialogue()) - print("before", default_processor.tracker_store.retrieve(sender_id).as_dialogue()) await default_processor.handle_reminder( reminder, sender_id, default_channel, default_processor.nlg ) From b8c2ca3cb33cd0f4c40642b6bcc32580fd535b3a Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 3 Dec 2019 19:48:20 +0100 Subject: [PATCH 19/74] restructure session creation --- rasa/core/actions/action.py | 6 +++--- tests/core/test_actions.py | 9 +++++---- tests/core/test_processor.py | 31 +++++++++++++++++++++++++++++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index af339256c611..976f36627449 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -335,9 +335,9 @@ async def run( for event in tracker.events if isinstance(event, SlotSet) ) - slot_set_events = [ - SlotSet(f"dummy slot {i}", f"dummy slot {i}") for i in range(4) - ] + # slot_set_events = [ + # SlotSet(f"dummy slot {i}", f"dummy slot {i}") for i in range(4) + # ] # noinspection PyTypeChecker return ( diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index fd708d33ba7e..68e269018ff2 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -35,6 +35,7 @@ Form, SessionStarted, ActionExecuted, + FollowupAction, ) from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator from rasa.core.trackers import DialogueStateTracker @@ -504,8 +505,8 @@ async def test_action_session_start_without_slots( default_channel, template_nlg, template_sender_tracker, default_domain ) assert events == [ - ActionExecuted(action_name=ACTION_SESSION_START_NAME), - ActionExecuted(action_name=ACTION_LISTEN_NAME), + SessionStarted(), + FollowupAction(ACTION_LISTEN_NAME), ] @@ -526,10 +527,10 @@ async def test_action_session_start_with_slots( ) assert events == [ - ActionExecuted(action_name=ACTION_SESSION_START_NAME), + SessionStarted(), slot_set_event_1, slot_set_event_2, - ActionExecuted(action_name=ACTION_LISTEN_NAME), + FollowupAction(ACTION_LISTEN_NAME), ] # make sure that the list of events has ascending timestamps diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 928e9ccc1b5f..93b8d315e948 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -12,7 +12,7 @@ from unittest.mock import patch from rasa.core import jobs -from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, UserMessage from rasa.core.events import ( @@ -24,6 +24,7 @@ UserUttered, SessionStarted, Event, + FollowupAction, ) from rasa.core.trackers import DialogueStateTracker from rasa.core.slots import Slot @@ -201,7 +202,7 @@ async def test_reminder_aborted( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) - assert len(t.events) == 2 # nothing should have been executed + assert len(t.events) == 4 # nothing should have been executed async def test_reminder_cancelled( @@ -336,3 +337,29 @@ async def test_has_session_expired( ) == has_expired ) + + +# noinspection PyProtectedMember +async def test_update_tracker_session( + default_channel: CollectingOutputChannel, default_processor: MessageProcessor, +): + sender_id = uuid.uuid4().hex + tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) + + # make sure session expires + await asyncio.sleep(1e-2) # in seconds + + await default_processor._update_tracker_session(tracker, default_channel, 1e-5) + + # the save is not called in _update_tracker_session() + default_processor._save_tracker(tracker) + + # inspect tracker and make sure all events are present + tracker = default_processor.tracker_store.retrieve(sender_id) + assert list(tracker.events) == [ + SessionStarted(), + ActionExecuted(ACTION_LISTEN_NAME), + ActionExecuted(ACTION_SESSION_START_NAME), + SessionStarted(), + FollowupAction(ACTION_LISTEN_NAME), + ] From 7b8a84d9245df513ee58edb26feebeded4f2ff9b Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 3 Dec 2019 19:52:34 +0100 Subject: [PATCH 20/74] update docstring --- rasa/core/processor.py | 11 +++++++++++ tests/core/test_processor.py | 3 +-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 5b38e8184847..45b80709b172 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -146,7 +146,18 @@ async def _update_tracker_session( output_channel: OutputChannel, session_length_in_minutes: float, ) -> None: + """Check the current session in `tracker` and update it if expired. + A 'session_start' is run if the tracker is a legacy tracker, or if the latest + tracker session has expired. + + Args: + tracker: Tracker to inspect. + output_channel: Output channel for potential utterances in a custom + `ActionSessionStart`. + session_length_in_minutes: Session length in minutes. + + """ if self._is_legacy_tracker(tracker) or self._has_session_expired( tracker, session_length_in_minutes ): diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 93b8d315e948..57e3691ea8e8 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -346,9 +346,8 @@ async def test_update_tracker_session( sender_id = uuid.uuid4().hex tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) - # make sure session expires + # make sure session expires and run tracker session update await asyncio.sleep(1e-2) # in seconds - await default_processor._update_tracker_session(tracker, default_channel, 1e-5) # the save is not called in _update_tracker_session() From c7717ef7a01a30008d80a14ac3c3da9c3f45b9ca Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 10:05:29 +0100 Subject: [PATCH 21/74] fix tests --- rasa/core/actions/action.py | 2 +- rasa/core/events/__init__.py | 8 +++-- rasa/core/processor.py | 38 +++++++++++--------- rasa/core/trackers.py | 5 +-- tests/core/test_actions.py | 8 ++--- tests/core/test_agent.py | 2 +- tests/core/test_domain.py | 2 +- tests/core/test_dsl.py | 14 ++++---- tests/core/test_processor.py | 68 +++++++++++++++++++++++++++++++----- tests/core/test_trackers.py | 18 ++++++---- tests/test_server.py | 13 +++++-- 11 files changed, 121 insertions(+), 57 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 976f36627449..04d2e08f8851 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -343,7 +343,7 @@ async def run( return ( [SessionStarted()] + list(slot_set_events) - + [FollowupAction(ACTION_LISTEN_NAME)] + + [ActionExecuted(ACTION_LISTEN_NAME)] ) diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 89497ac42583..38ac264d2548 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -1,5 +1,6 @@ import json import logging +import warnings import jsonpickle import time @@ -8,7 +9,7 @@ from dateutil import parser from datetime import datetime -from typing import List, Dict, Text, Any, Type, Optional, NoReturn +from typing import List, Dict, Text, Any, Type, Optional from rasa.core import utils @@ -1197,10 +1198,11 @@ def __eq__(self, other: Any) -> bool: def __str__(self) -> Text: return "SessionStarted()" - def as_story_string(self) -> NoReturn: - raise NotImplementedError( + def as_story_string(self) -> None: + warnings.warn( f"'{self.type_name}' events cannot be serialised as story strings." ) + return None def apply_to(self, tracker: "DialogueStateTracker") -> None: from rasa.core.actions.action import ( # pytype: disable=pyi-error diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 45b80709b172..0d7381026ccc 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -1,10 +1,8 @@ -import asyncio -import json import warnings import logging import os from types import LambdaType -from typing import Any, Dict, List, Optional, Text, Tuple, Union +from typing import Any, Dict, List, Optional, Text, Tuple import numpy as np import time @@ -13,7 +11,6 @@ from rasa.core.actions.action import ( Action, ACTION_SESSION_START_NAME, - ActionSessionStart, ) from rasa.core.actions.action import ACTION_LISTEN_NAME, ActionExecutionRejection from rasa.core.channels.channel import ( @@ -27,6 +24,7 @@ UTTER_PREFIX, USER_INTENT_BACK, USER_INTENT_OUT_OF_SCOPE, + USER_INTENT_SESSION_START, ) from rasa.core.domain import Domain from rasa.core.events import ( @@ -38,7 +36,6 @@ SlotSet, UserUttered, BotUttered, - SessionStarted, ) from rasa.core.interpreter import ( INTENT_MESSAGE_PREFIX, @@ -56,6 +53,13 @@ MAX_NUMBER_OF_PREDICTIONS = int(os.environ.get("MAX_NUMBER_OF_PREDICTIONS", "10")) +DEFAULT_INTENTS = [ + USER_INTENT_RESTART, + USER_INTENT_BACK, + USER_INTENT_OUT_OF_SCOPE, + USER_INTENT_SESSION_START, +] + class MessageProcessor: def __init__( @@ -102,7 +106,13 @@ async def handle_message( await self._predict_and_execute_next_action(message, tracker) # save tracker state to continue conversation from this state self._save_tracker(tracker) + print("have tracker events") + for e in tracker.events: + print(e) + print("\napplied events") + for e in tracker.applied_events(): + print(e) if isinstance(message.output_channel, CollectingOutputChannel): return message.output_channel.messages else: @@ -156,7 +166,7 @@ async def _update_tracker_session( output_channel: Output channel for potential utterances in a custom `ActionSessionStart`. session_length_in_minutes: Session length in minutes. - + """ if self._is_legacy_tracker(tracker) or self._has_session_expired( tracker, session_length_in_minutes @@ -179,6 +189,7 @@ async def log_message( Optionally save the tracker if `should_save_tracker` is `True`. Tracker saving can be skipped if the tracker returned by this method is used for further processing and saved at a later stage. + """ # preprocess message if necessary @@ -199,7 +210,7 @@ async def log_message( self._save_tracker(tracker) else: logger.warning( - "Failed to retrieve or create tracker for sender " + "Failed to retrieve or create tracker for conversation ID " f"'{message.sender_id}'." ) return tracker @@ -227,7 +238,8 @@ async def execute_action( self._save_tracker(tracker) else: logger.warning( - f"Failed to retrieve or create tracker for sender '{sender_id}'." + f"Failed to retrieve or create tracker for conversation ID " + f"'{sender_id}'." ) return tracker @@ -293,7 +305,7 @@ async def handle_reminder( if not tracker: logger.warning( - f"Failed to retrieve or create tracker for sender '{sender_id}'." + f"Failed to retrieve tracker for conversation ID '{sender_id}'." ) return None @@ -337,17 +349,11 @@ def _log_unseen_features(self, parse_data: Dict[Text, Any]) -> None: domain_is_not_empty = self.domain and not self.domain.is_empty() - default_intents = [ - USER_INTENT_RESTART, - USER_INTENT_BACK, - USER_INTENT_OUT_OF_SCOPE, - ] - intent = parse_data["intent"]["name"] if intent: intent_is_recognized = ( domain_is_not_empty and intent in self.domain.intents - ) or intent in default_intents + ) or intent in DEFAULT_INTENTS if not intent_is_recognized: warnings.warn( f"Interpreter parsed an intent '{intent}' " diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index f4d18db17043..35d3156ba3b6 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -5,10 +5,7 @@ from typing import Dict, Text, Any, Optional, Iterator, Generator, Type, List, Deque from rasa.core import events # pytype: disable=pyi-error -from rasa.core.actions.action import ( - ACTION_LISTEN_NAME, - ACTION_SESSION_START_NAME, -) # pytype: disable=pyi-error +from rasa.core.actions.action import ACTION_LISTEN_NAME # pytype: disable=pyi-error from rasa.core.conversation import Dialogue # pytype: disable=pyi-error from rasa.core.events import ( # pytype: disable=pyi-error UserUttered, diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index 68e269018ff2..b2d38099bf05 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -26,7 +26,7 @@ ActionSessionStart, ) from rasa.core.channels import CollectingOutputChannel -from rasa.core.domain import Domain, InvalidDomain +from rasa.core.domain import Domain from rasa.core.events import ( Restarted, SlotSet, @@ -35,13 +35,11 @@ Form, SessionStarted, ActionExecuted, - FollowupAction, ) from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import ClientResponseError, EndpointConfig from tests.utilities import json_of_latest_request, latest_request -from rasa.core.constants import UTTER_PREFIX, RESPOND_PREFIX @pytest.fixture(scope="module") @@ -506,7 +504,7 @@ async def test_action_session_start_without_slots( ) assert events == [ SessionStarted(), - FollowupAction(ACTION_LISTEN_NAME), + ActionExecuted(ACTION_LISTEN_NAME), ] @@ -530,7 +528,7 @@ async def test_action_session_start_with_slots( SessionStarted(), slot_set_event_1, slot_set_event_2, - FollowupAction(ACTION_LISTEN_NAME), + ActionExecuted(ACTION_LISTEN_NAME), ] # make sure that the list of events has ascending timestamps diff --git a/tests/core/test_agent.py b/tests/core/test_agent.py index 530f015017a4..6464ba7f718e 100644 --- a/tests/core/test_agent.py +++ b/tests/core/test_agent.py @@ -259,7 +259,7 @@ async def test_agent_update_model_none_domain(trained_model): tracker = agent.tracker_store.get_or_create_tracker(sender_id) # UserUttered event was added to tracker, with correct intent data - assert tracker.events[1].intent["name"] == "greet" + assert tracker.events[2].intent["name"] == "greet" async def test_load_agent_on_not_existing_path(): diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index c3bd74113530..819abcf5b5a7 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -157,7 +157,7 @@ def test_domain_from_template(): assert not domain.is_empty() assert len(domain.intents) == 10 - assert len(domain.action_names) == 11 + assert len(domain.action_names) == 12 def test_utter_templates(): diff --git a/tests/core/test_dsl.py b/tests/core/test_dsl.py index 220da918f0a8..70e8e5df59d3 100644 --- a/tests/core/test_dsl.py +++ b/tests/core/test_dsl.py @@ -218,7 +218,7 @@ async def test_read_story_file_with_cycles(tmpdir, default_domain): assert len(graph_without_cycles.story_end_checkpoints) == 2 -async def test_generate_training_data_with_cycles(tmpdir, default_domain): +async def test_generate_training_data_with_cycles(default_domain): featurizer = MaxHistoryTrackerFeaturizer( BinarySingleStateFeaturizer(), max_history=4 ) @@ -232,11 +232,11 @@ async def test_generate_training_data_with_cycles(tmpdir, default_domain): # deterministic way but should always be 3 or 4 assert len(training_trackers) == 3 or len(training_trackers) == 4 - # if we have 4 trackers, there is going to be one example more for label 4 - num_threes = len(training_trackers) - 1 + # if we have 4 trackers, there is going to be one example more for label 9 + num_nines = len(training_trackers) - 1 # if new default actions are added the keys of the actions will be changed - assert Counter(y) == {0: 6, 9: 3, 8: num_threes, 1: 2, 10: 1} + assert Counter(y) == {0: 6, 10: 3, 9: num_nines, 1: 2, 11: 1} async def test_generate_training_data_with_unused_checkpoints(tmpdir, default_domain): @@ -481,9 +481,8 @@ def test_user_uttered_to_e2e(parse_data: Dict, expected_story_string: Text): assert event.as_story_string(e2e=True) == expected_story_string -def test_session_started_event_cannot_be_serialised(): - with pytest.raises(NotImplementedError): - SessionStarted().as_story_string() +def test_session_started_event_is_not_serialised(): + assert SessionStarted().as_story_string() is None @pytest.mark.parametrize("line", [" greet{: hi"]) @@ -491,4 +490,5 @@ def test_invalid_end_to_end_format(line: Text): reader = EndToEndReader() with pytest.raises(ValueError): + # noinspection PyProtectedMember _ = reader._parse_item(line) diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 57e3691ea8e8..c2c48068e7c3 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Text import time @@ -24,12 +24,12 @@ UserUttered, SessionStarted, Event, - FollowupAction, + SlotSet, ) from rasa.core.trackers import DialogueStateTracker from rasa.core.slots import Slot from rasa.core.interpreter import RasaNLUHttpInterpreter -from rasa.core.processor import MessageProcessor +from rasa.core.processor import MessageProcessor, DEFAULT_INTENTS from rasa.utils.endpoints import EndpointConfig from tests.utilities import latest_request @@ -86,8 +86,11 @@ async def test_log_unseen_feature(default_processor: MessageProcessor): ) -async def test_default_intent_recognized(default_processor: MessageProcessor): - message = UserMessage("/restart") +@pytest.mark.parametrize("default_intent", DEFAULT_INTENTS) +async def test_default_intent_recognized( + default_processor: MessageProcessor, default_intent: Text +): + message = UserMessage(default_intent) parsed = await default_processor._parse_message(message) with pytest.warns(None) as record: default_processor._log_unseen_features(parsed) @@ -307,16 +310,16 @@ async def test_is_legacy_tracker( # session start is way in the past (SessionStarted(timestamp=1), 60, True), # session start is very recent - (SessionStarted(timestamp=time.time()), 1, False), + (SessionStarted(timestamp=time.time()), 10, False), # there is no session start event (legacy tracker) - (UserUttered("hello", timestamp=time.time()), 1, False), + (UserUttered("hello", timestamp=time.time()), 10, False), # there is no event (None, 1, False), ], ) async def test_has_session_expired( event_to_apply: Optional[Event], - session_length_in_minutes: int, + session_length_in_minutes: float, has_expired: bool, default_processor: MessageProcessor, ): @@ -360,5 +363,52 @@ async def test_update_tracker_session( ActionExecuted(ACTION_LISTEN_NAME), ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted(), - FollowupAction(ACTION_LISTEN_NAME), + ActionExecuted(ACTION_LISTEN_NAME), ] + + +# noinspection PyProtectedMember +async def test_update_tracker_session_with_slots( + default_channel: CollectingOutputChannel, default_processor: MessageProcessor, +): + sender_id = uuid.uuid4().hex + tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) + + # apply a user uttered and five slots + user_event = UserUttered("some utterance") + tracker.update(user_event) + + slot_set_events = [SlotSet(f"slot key {i}", f"test value {i}") for i in range(5)] + + for event in slot_set_events: + tracker.update(event) + + # make sure session expires and run tracker session update + await asyncio.sleep(1e-2) # in seconds + await default_processor._update_tracker_session(tracker, default_channel, 1e-5) + + # the save is not called in _update_tracker_session() + default_processor._save_tracker(tracker) + + # inspect tracker and make sure all events are present + tracker = default_processor.tracker_store.retrieve(sender_id) + events = list(tracker.events) + + # the first three events should be up to the user utterance + assert events[:3] == [ + SessionStarted(), + ActionExecuted(ACTION_LISTEN_NAME), + user_event, + ] + + # next come the five slots + assert events[3:8] == slot_set_events + + # the next two events are the session start sequence + assert events[8:10] == [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()] + + # the five slots should be reapplied + assert events[10:15] == slot_set_events + + # finally an action listen, this should also be the last event + assert events[15] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME) diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index a0d648703e9c..e60f2fcdaaed 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -9,7 +9,7 @@ import rasa.utils.io from rasa.core import training, restore -from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME from rasa.core.domain import Domain from rasa.core.events import ( SlotSet, @@ -85,7 +85,10 @@ def test_tracker_store_storage_and_retrieval(store): assert tracker.sender_id == "some-id" # Action listen should be in there - assert list(tracker.events) == [ActionExecuted(ACTION_LISTEN_NAME)] + assert list(tracker.events) == [ + SessionStarted(), + ActionExecuted(ACTION_LISTEN_NAME), + ] # lets log a test message intent = {"name": "greet", "confidence": 1.0} @@ -96,13 +99,13 @@ def test_tracker_store_storage_and_retrieval(store): # retrieving the same tracker should result in the same tracker retrieved_tracker = store.get_or_create_tracker("some-id") assert retrieved_tracker.sender_id == "some-id" - assert len(retrieved_tracker.events) == 2 + assert len(retrieved_tracker.events) == 3 assert retrieved_tracker.latest_message.intent.get("name") == "greet" # getting another tracker should result in an empty tracker again other_tracker = store.get_or_create_tracker("some-other-id") assert other_tracker.sender_id == "some-other-id" - assert len(other_tracker.events) == 1 + assert len(other_tracker.events) == 2 @pytest.mark.parametrize("store", stores_to_be_tested(), ids=stores_to_be_tested_ids()) @@ -158,6 +161,7 @@ async def test_tracker_state_regression_with_bot_utterance(default_agent): tracker = default_agent.tracker_store.get_or_create_tracker(sender_id) expected = [ + None, "action_listen", "greet", "utter_greet", @@ -179,7 +183,7 @@ async def test_bot_utterance_comes_after_action_event(default_agent): # important is, that the 'bot' comes after the second 'action' and not # before - expected = ["action", "user", "action", "bot", "action"] + expected = ["session_started", "action", "user", "action", "bot", "action"] assert [e.type_name for e in tracker.events] == expected @@ -282,8 +286,8 @@ def test_session_start(default_domain: Domain): # tracker has one event assert len(tracker.events) == 1 - # follow-up action should be 'action_listen' - assert tracker.followup_action == ACTION_LISTEN_NAME + # follow-up action should be 'session_start' + assert tracker.followup_action == ACTION_SESSION_START_NAME def test_revert_action_event(default_domain: Domain): diff --git a/tests/test_server.py b/tests/test_server.py index 28c11c121eb8..aa92ba8d7f24 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -536,13 +536,20 @@ def test_requesting_non_existent_tracker(rasa_app: SanicTestClient): assert content["slots"] == {"location": None, "cuisine": None} assert content["sender_id"] == "madeupid" assert content["events"] == [ + { + "event": "action", + "name": "session_start", + "policy": None, + "confidence": None, + "timestamp": 1514764800, + }, { "event": "action", "name": "action_listen", "policy": None, "confidence": None, "timestamp": 1514764800, - } + }, ] assert content["latest_message"] == { "text": None, @@ -569,7 +576,7 @@ def test_pushing_event(rasa_app, event): _, tracker_response = rasa_app.get(f"/conversations/{cid}/tracker") tracker = tracker_response.json assert tracker is not None - assert len(tracker.get("events")) == 2 + assert len(tracker.get("events")) == 3 evt = tracker.get("events")[1] assert Event.from_parameters(evt) == event @@ -593,7 +600,7 @@ def test_push_multiple_events(rasa_app: SanicTestClient): assert tracker is not None # there is also an `ACTION_LISTEN` event at the start - assert len(tracker.get("events")) == len(test_events) + 1 + assert len(tracker.get("events")) == len(test_events) + 2 assert tracker.get("events")[1:] == events From 59be4ea6eee9e6fda786515b719b0d3bb8100284 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 10:52:50 +0100 Subject: [PATCH 22/74] update changelog --- changelog/4830.feature.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst index ad8135c39fb1..3ddad54229fc 100644 --- a/changelog/4830.feature.rst +++ b/changelog/4830.feature.rst @@ -1,2 +1,8 @@ Added a new event ``SessionStarted`` that marks the beginning of a new conversation session. + +Added a new default action ``ActionSessionStart``. This action takes all ``SlotSet`` +events from the previous session and applies it to the next session. + +Added new default intent ``session_start`` which triggers the start of a new +conversation session. From ba3772b20069c128d23e77d43f93f65ead343a9b Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 11:00:34 +0100 Subject: [PATCH 23/74] update docs --- docs/api/events.rst | 22 ++++++++++++++++++++++ rasa/core/actions/action.py | 32 +++++++++++++++++++------------- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/docs/api/events.rst b/docs/api/events.rst index f3cb0258b18e..98ae4b2806dd 100644 --- a/docs/api/events.rst +++ b/docs/api/events.rst @@ -271,3 +271,25 @@ Log an executed action .. literalinclude:: ../../rasa/core/events/__init__.py :dedent: 4 :pyobject: ActionExecuted.apply_to + +Start a new conversation session +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:Short: Marks the beginning of a new conversation session. Resets the tracker and + triggers an ``ActionSessionStart`` which by default applies the existing + ``SlotSet`` events to the new session. + +:JSON: + .. literalinclude:: ../../tests/core/test_events.py + :start-after: # DOCS MARKER ActionExecuted + :dedent: 4 + :end-before: # DOCS END +:Class: + .. autoclass:: rasa.core.events.SessionStarted + +:Effect: + When added to a tracker, this is the code used to update the tracker: + + .. literalinclude:: ../../rasa/core/events/__init__.py + :dedent: 4 + :pyobject: SessionStarted.apply_to diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 04d2e08f8851..e2e6b2989d3b 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -2,7 +2,7 @@ import json import logging import typing -from typing import List, Text, Optional, Dict, Any +from typing import List, Text, Optional, Dict, Any, Generator import aiohttp @@ -36,6 +36,7 @@ from rasa.core.domain import Domain from rasa.core.nlg import NaturalLanguageGenerator from rasa.core.channels.channel import OutputChannel + from rasa.core.events import SlotSet logger = logging.getLogger(__name__) @@ -317,6 +318,21 @@ class ActionSessionStart(Action): def name(self) -> Text: return ACTION_SESSION_START_NAME + @staticmethod + def _slot_set_events_from_tracker( + tracker: "DialogueStateTracker", + ) -> Generator["SlotSet", None, None]: + """Fetch SlotSet events from tracker and carry over key, value and metadata.""" + + from rasa.core.events import SlotSet + + # use generator so the timestamps are greater than that of the returned + return ( + SlotSet(key=event.key, value=event.value, metadata=event.metadata) + for event in tracker.events + if isinstance(event, SlotSet) + ) + async def run( self, output_channel: "OutputChannel", @@ -324,20 +340,10 @@ async def run( tracker: "DialogueStateTracker", domain: "Domain", ) -> List[Event]: - from rasa.core.events import SessionStarted, SlotSet, FollowupAction + from rasa.core.events import SessionStarted # TODO: check in domain whether slots should be carried over - # fetch SlotSet events from tracker and carry over key, value and metadata - # use generator so the timestamps are greater than that of the returned - # `SessionStarted` event - slot_set_events = ( - SlotSet(key=event.key, value=event.value, metadata=event.metadata) - for event in tracker.events - if isinstance(event, SlotSet) - ) - # slot_set_events = [ - # SlotSet(f"dummy slot {i}", f"dummy slot {i}") for i in range(4) - # ] + slot_set_events = self._slot_set_events_from_tracker(tracker) # noinspection PyTypeChecker return ( From 35af03c0e55e1cfdf0f057d464c645d786b58580 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 11:09:40 +0100 Subject: [PATCH 24/74] update test --- tests/core/test_processor.py | 2 +- tests/importers/test_rasa.py | 2 +- tests/test_server.py | 12 +++--------- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index c2c48068e7c3..3a0ffef9c98e 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -310,7 +310,7 @@ async def test_is_legacy_tracker( # session start is way in the past (SessionStarted(timestamp=1), 60, True), # session start is very recent - (SessionStarted(timestamp=time.time()), 10, False), + (SessionStarted(timestamp=time.time()), 60, False), # there is no session start event (legacy tracker) (UserUttered("hello", timestamp=time.time()), 10, False), # there is no event diff --git a/tests/importers/test_rasa.py b/tests/importers/test_rasa.py index 5328df3b8188..dcfd6a75c638 100644 --- a/tests/importers/test_rasa.py +++ b/tests/importers/test_rasa.py @@ -22,7 +22,7 @@ async def test_rasa_file_importer(project: Text): assert len(domain.intents) == 7 assert domain.slots == [] assert domain.entities == [] - assert len(domain.action_names) == 14 + assert len(domain.action_names) == 15 assert len(domain.templates) == 6 stories = await importer.get_stories() diff --git a/tests/test_server.py b/tests/test_server.py index aa92ba8d7f24..1092b5faabda 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -536,13 +536,7 @@ def test_requesting_non_existent_tracker(rasa_app: SanicTestClient): assert content["slots"] == {"location": None, "cuisine": None} assert content["sender_id"] == "madeupid" assert content["events"] == [ - { - "event": "action", - "name": "session_start", - "policy": None, - "confidence": None, - "timestamp": 1514764800, - }, + {"event": "session_started", "timestamp": 1514764800,}, { "event": "action", "name": "action_listen", @@ -578,7 +572,7 @@ def test_pushing_event(rasa_app, event): assert tracker is not None assert len(tracker.get("events")) == 3 - evt = tracker.get("events")[1] + evt = tracker.get("events")[2] assert Event.from_parameters(evt) == event @@ -601,7 +595,7 @@ def test_push_multiple_events(rasa_app: SanicTestClient): # there is also an `ACTION_LISTEN` event at the start assert len(tracker.get("events")) == len(test_events) + 2 - assert tracker.get("events")[1:] == events + assert tracker.get("events")[2:] == events def test_put_tracker(rasa_app: SanicTestClient): From 7849dbf0c0f37b058d905b80e08b48958816fbda Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 09:35:07 +0100 Subject: [PATCH 25/74] configure session handling through domain --- rasa/core/actions/action.py | 12 ++++---- rasa/core/domain.py | 35 +++++++++++++++++++++-- rasa/core/processor.py | 16 ++++------- rasa/core/schemas/domain.yml | 10 +++++++ tests/core/test_actions.py | 37 +++++++++++++++++++----- tests/core/test_domain.py | 54 +++++++++++++++++++++++++++++++++++- tests/core/test_processor.py | 14 ++++------ 7 files changed, 144 insertions(+), 34 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index afbc62494712..144b9a261dc5 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -330,11 +330,13 @@ async def run( # fetch SlotSet events from tracker and carry over key, value and metadata # use generator so the timestamps are greater than that of the returned # `SessionStarted` event - slot_set_events = ( - SlotSet(key=event.key, value=event.value, metadata=event.metadata) - for event in tracker.events - if isinstance(event, SlotSet) - ) + slot_set_events = [] + if domain.session_config.carry_over_slots: + slot_set_events = ( + SlotSet(key=event.key, value=event.value, metadata=event.metadata) + for event in tracker.events + if isinstance(event, SlotSet) + ) # noinspection PyTypeChecker return ( diff --git a/rasa/core/domain.py b/rasa/core/domain.py index f4ca2fb2b7f7..7957cdfd6e5c 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -5,7 +5,7 @@ import os import typing from pathlib import Path -from typing import Any, Dict, List, Optional, Text, Tuple, Union, Set +from typing import Any, Dict, List, Optional, Text, Tuple, Union, Set, NamedTuple import rasa.core.constants import rasa.utils.common as common_utils @@ -31,6 +31,11 @@ PREV_PREFIX = "prev_" ACTIVE_FORM_PREFIX = "active_form_" +DEFAULT_SESSION_LENGTH = 60 +DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION = True + +CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session" +SESSION_LENGTH_KEY = "session_length" if typing.TYPE_CHECKING: from rasa.core.trackers import DialogueStateTracker @@ -47,6 +52,11 @@ def __str__(self): return bcolors.FAIL + self.message + bcolors.ENDC +class SessionConfig(NamedTuple): + session_length: int + carry_over_slots: bool + + class Domain: """The domain specifies the universe in which the bot's policy acts. @@ -109,6 +119,7 @@ def from_dict(cls, data: Dict) -> "Domain": utter_templates = cls.collect_templates(data.get("templates", {})) slots = cls.collect_slots(data.get("slots", {})) additional_arguments = data.get("config", {}) + session_config = cls._get_session_config(additional_arguments) intents = data.get("intents", {}) return cls( @@ -118,9 +129,21 @@ def from_dict(cls, data: Dict) -> "Domain": utter_templates, data.get("actions", []), data.get("forms", []), + session_config=session_config, **additional_arguments, ) + @staticmethod + def _get_session_config(additional_arguments: Dict) -> SessionConfig: + session_length = additional_arguments.pop( + SESSION_LENGTH_KEY, DEFAULT_SESSION_LENGTH + ) + carry_over_slots = additional_arguments.pop( + CARRY_OVER_SLOTS_KEY, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION + ) + + return SessionConfig(session_length, carry_over_slots) + @classmethod def from_directory(cls, path: Text) -> "Domain": """Loads and merges multiple domain files recursively from a directory tree.""" @@ -278,6 +301,9 @@ def __init__( action_names: List[Text], form_names: List[Text], store_entities_as_slots: bool = True, + session_config: SessionConfig = SessionConfig( + DEFAULT_SESSION_LENGTH, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION + ), ) -> None: self.intent_properties = self.collect_intent_properties(intents) @@ -285,6 +311,7 @@ def __init__( self.form_names = form_names self.slots = slots self.templates = templates + self.session_config = session_config # only includes custom actions and utterance actions self.user_actions = action_names @@ -651,7 +678,11 @@ def _slot_definitions(self): return {slot.name: slot.persistence_info() for slot in self.slots} def as_dict(self) -> Dict[Text, Any]: - additional_config = {"store_entities_as_slots": self.store_entities_as_slots} + additional_config = { + "store_entities_as_slots": self.store_entities_as_slots, + SESSION_LENGTH_KEY: self.session_config.session_length, + CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots, + } return { "config": additional_config, diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 4bf3cb07e69a..daac30091e19 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -614,14 +614,11 @@ def _session_start_timestamp_for( # this also is a legacy tracker (pre-sessions) return tracker.events[0].timestamp - def _has_session_expired( - self, tracker: DialogueStateTracker, session_length_in_minutes: float - ) -> bool: + def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: """Determine whether the latest session in `tracker` has expired. Args: tracker: Tracker to inspect. - session_length_in_minutes: Session length in minutes. Returns: `True` if the session in `tracker` has expired, `False` otherwise. @@ -631,7 +628,7 @@ def _has_session_expired( time_delta_in_seconds = time.time() - session_start_timestamp - return time_delta_in_seconds / 60 > session_length_in_minutes + return time_delta_in_seconds / 60 > self.domain.session_config.session_length @staticmethod def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: @@ -654,15 +651,12 @@ def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: return last_executed_session_started_action is None - def _get_tracker( - self, sender_id: Text, session_length_in_minutes: int = 60 - ) -> Optional[DialogueStateTracker]: + def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID tracker = self.tracker_store.get_or_create_tracker(sender_id) - if self._is_legacy_tracker(tracker) or self._has_session_expired( - tracker, session_length_in_minutes - ): + if self._is_legacy_tracker(tracker) or self._has_session_expired(tracker): + tracker.update(SessionStarted()) return tracker diff --git a/rasa/core/schemas/domain.yml b/rasa/core/schemas/domain.yml index b6c7d6a23574..b86fc29639cd 100644 --- a/rasa/core/schemas/domain.yml +++ b/rasa/core/schemas/domain.yml @@ -31,3 +31,13 @@ mapping: sequence: - type: "str" allowempty: True + config: + type: "map" + allowempty: True + mapping: + store_entities_as_slots: + type: "bool" + session_length: + type: "int" + carry_over_slots_to_new_session: + type: "bool" diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index fd708d33ba7e..7036f4e86644 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -1,3 +1,5 @@ +from typing import List + import pytest from aioresponses import aioresponses @@ -26,7 +28,7 @@ ActionSessionStart, ) from rasa.core.channels import CollectingOutputChannel -from rasa.core.domain import Domain, InvalidDomain +from rasa.core.domain import Domain, InvalidDomain, SessionConfig from rasa.core.events import ( Restarted, SlotSet, @@ -35,6 +37,7 @@ Form, SessionStarted, ActionExecuted, + Event, ) from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator from rasa.core.trackers import DialogueStateTracker @@ -509,11 +512,34 @@ async def test_action_session_start_without_slots( ] +@pytest.mark.parametrize( + "session_config, expected_events", + [ + ( + SessionConfig(123, True), + [ + ActionExecuted(action_name=ACTION_SESSION_START_NAME), + SlotSet("my_slot", "value"), + SlotSet("another-slot", "value2"), + ActionExecuted(action_name=ACTION_LISTEN_NAME), + ], + ), + ( + SessionConfig(123, False), + [ + ActionExecuted(action_name=ACTION_SESSION_START_NAME), + ActionExecuted(action_name=ACTION_LISTEN_NAME), + ], + ), + ], +) async def test_action_session_start_with_slots( default_channel: CollectingOutputChannel, template_nlg: TemplatedNaturalLanguageGenerator, template_sender_tracker: DialogueStateTracker, default_domain: Domain, + session_config: SessionConfig, + expected_events: List[Event], ): # set a few slots on tracker slot_set_event_1 = SlotSet("my_slot", "value") @@ -521,16 +547,13 @@ async def test_action_session_start_with_slots( for event in [slot_set_event_1, slot_set_event_2]: template_sender_tracker.update(event) + default_domain.session_config = session_config + events = await ActionSessionStart().run( default_channel, template_nlg, template_sender_tracker, default_domain ) - assert events == [ - ActionExecuted(action_name=ACTION_SESSION_START_NAME), - slot_set_event_1, - slot_set_event_2, - ActionExecuted(action_name=ACTION_LISTEN_NAME), - ] + assert events == expected_events # make sure that the list of events has ascending timestamps assert sorted(events, key=lambda x: x.timestamp) == events diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index c3bd74113530..f4df115a5d4a 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -11,7 +11,7 @@ SLOT_LAST_OBJECT_TYPE, ) from rasa.core import training, utils -from rasa.core.domain import Domain, InvalidDomain +from rasa.core.domain import Domain, InvalidDomain, SessionConfig from rasa.core.featurizers import MaxHistoryTrackerFeaturizer from rasa.core.slots import TextSlot, UnfeaturizedSlot from tests.core.conftest import DEFAULT_DOMAIN_PATH_WITH_SLOTS, DEFAULT_STORIES_FILE @@ -564,3 +564,55 @@ def test_add_knowledge_base_slots(default_domain): assert SLOT_LISTED_ITEMS in slot_names assert SLOT_LAST_OBJECT in slot_names assert SLOT_LAST_OBJECT_TYPE in slot_names + + +@pytest.mark.parametrize( + "input_domain, expected_session_length, expected_carry_over_slots, ", + [ + ( + """config: + session_length: 20 + carry_over_slots_to_new_session: true""", + 20, + True, + ), + ("", 60, True), + ( + """config: + carry_over_slots_to_new_session: false""", + 60, + False, + ), + ( + """config: + carry_over_slots_to_new_session: false""", + 60, + False, + ), + ( + """ +config: + session_length: 20 + carry_over_slots_to_new_session: False""", + 20, + False, + ), + ], +) +def test_session_config( + input_domain, expected_session_length: int, expected_carry_over_slots: bool +): + domain = Domain.from_yaml(input_domain) + assert domain.session_config.session_length == expected_session_length + assert domain.session_config.carry_over_slots == expected_carry_over_slots + + +def test_domain_as_dict_with_session_config(): + session_config = SessionConfig(123, False) + domain = Domain.empty() + domain.session_config = session_config + + serialized = domain.as_dict() + deserialized = Domain.from_dict(serialized) + + assert deserialized.session_config == session_config diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 2119d21b8a33..66a70b6eb575 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -40,7 +40,7 @@ from tests.utilities import json_of_latest_request, latest_request from tests.core.conftest import DEFAULT_DOMAIN_PATH_WITH_SLOTS -from rasa.core.domain import Domain +from rasa.core.domain import Domain, SessionConfig import logging @@ -298,7 +298,7 @@ async def test_reminder_restart( ], ) async def test_is_legacy_tracker( - events_to_apply: List[Event], is_legacy: bool, default_processor: MessageProcessor, + events_to_apply: List[Event], is_legacy: bool, default_processor: MessageProcessor ): sender_id = uuid.uuid4().hex @@ -334,6 +334,9 @@ async def test_has_session_expired( ): sender_id = uuid.uuid4().hex + default_processor.domain.session_config = SessionConfig( + session_length_in_minutes, True + ) # create new tracker without events tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) tracker.events.clear() @@ -343,9 +346,4 @@ async def test_has_session_expired( tracker.update(event_to_apply) # noinspection PyProtectedMember - assert ( - default_processor._has_session_expired( - tracker, session_length_in_minutes=session_length_in_minutes - ) - == has_expired - ) + assert default_processor._has_session_expired(tracker) == has_expired From 97d0cf757b2a73e975cee852e9ccfa3a9fea2724 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 11:16:20 +0100 Subject: [PATCH 26/74] remove prints --- rasa/core/processor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 0d7381026ccc..50aae076632d 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -104,15 +104,10 @@ async def handle_message( return None await self._predict_and_execute_next_action(message, tracker) + # save tracker state to continue conversation from this state self._save_tracker(tracker) - print("have tracker events") - for e in tracker.events: - print(e) - print("\napplied events") - for e in tracker.applied_events(): - print(e) if isinstance(message.output_channel, CollectingOutputChannel): return message.output_channel.messages else: From 4b7b101a3b28cfe38a2441ca6c4d05137c9f596a Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 11:24:40 +0100 Subject: [PATCH 27/74] support floats for session lengths --- rasa/core/domain.py | 2 +- rasa/core/schemas/domain.yml | 2 +- tests/core/test_domain.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 7957cdfd6e5c..259d4c4e7f37 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -53,7 +53,7 @@ def __str__(self): class SessionConfig(NamedTuple): - session_length: int + session_length: Union[int, float] carry_over_slots: bool diff --git a/rasa/core/schemas/domain.yml b/rasa/core/schemas/domain.yml index b86fc29639cd..714a7ef803e1 100644 --- a/rasa/core/schemas/domain.yml +++ b/rasa/core/schemas/domain.yml @@ -38,6 +38,6 @@ mapping: store_entities_as_slots: type: "bool" session_length: - type: "int" + type: "number" carry_over_slots_to_new_session: type: "bool" diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index 0e951e8d629c..cc3ce1fb82e0 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -592,9 +592,9 @@ def test_add_knowledge_base_slots(default_domain): ( """ config: - session_length: 20 + session_length: 20.2 carry_over_slots_to_new_session: False""", - 20, + 20.2, False, ), ], From 25160d909c4c0961bafb7c6597e2fc161d8d99f6 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 14:40:55 +0100 Subject: [PATCH 28/74] wip --- rasa/core/tracker_store.py | 43 +++++++++++++++++++++++------ requirements-dev.txt | 1 + setup.py | 1 + tests/core/conftest.py | 3 +-- tests/core/test_tracker_stores.py | 45 +++++++++++++++++++++++-------- 5 files changed, 72 insertions(+), 21 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 8430134db5f9..c2329048caa5 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -7,7 +7,7 @@ import typing from datetime import datetime, timezone -from typing import Iterator, Optional, Text, Iterable, Union, Dict, Callable +from typing import Iterator, Optional, Text, Iterable, Union, Dict, Callable, List import itertools from boto3.dynamodb.conditions import Key @@ -471,12 +471,37 @@ def save(self, tracker, timeout=None): if self.event_broker: self.stream_events(tracker) - state = tracker.current_state(EventVerbosity.ALL) - + additional_events = list(self._additional_events(tracker)) + print("have additional events", len(additional_events)) + for e in additional_events: + print(e) self.conversations.update_one( - {"sender_id": tracker.sender_id}, {"$set": state}, upsert=True + {"sender_id": tracker.sender_id}, + {"$push": {"events": {"$each": additional_events}}}, + upsert=True, ) + def _additional_events(self, tracker: DialogueStateTracker) -> List[Dict]: + """Return events from the tracker which aren't currently stored.""" + + stored = self.conversations.find_one({"sender_id": tracker.sender_id}) + n_events = len(stored.get("events", [])) if stored else 0 + + return [ + event.as_dict() + for event in itertools.islice(tracker.events, n_events, len(tracker.events)) + ] + + @staticmethod + def _events_since_last_session_start(serialised_tracker: Dict) -> List[Dict]: + events = [] + for event in serialised_tracker.get("events", []): + events.append(event) + if event["event"] == SessionStarted.type_name: + break + + return list(reversed(events)) + def retrieve(self, sender_id): """ Args: @@ -499,9 +524,8 @@ def retrieve(self, sender_id): ) if stored is not None: - return DialogueStateTracker.from_dict( - sender_id, stored.get("events"), self.domain.slots - ) + events = self._events_since_last_session_start(stored) + return DialogueStateTracker.from_dict(sender_id, events, self.domain.slots) else: return None @@ -762,7 +786,10 @@ def save(self, tracker: DialogueStateTracker) -> None: with self.session_scope() as session: # only store recent events events = self._additional_events(session, tracker) - + events = list(events) + print("have additional events", len(events)) + for e in events: + print(e) for event in events: data = event.as_dict() diff --git a/requirements-dev.txt b/requirements-dev.txt index 3b400e1f52ce..e724c6ad87c4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,6 +16,7 @@ aioresponses==0.6.0 moto==1.3.8 fakeredis==1.0.3 six>=1.12.0 # upstream - should be removed if fakeredis depends on at least 1.12.0 +mongomock==3.18.0 # lint/format/types black==19.10b0 diff --git a/setup.py b/setup.py index 9d19a743bb78..12a8f0e53d78 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ "aioresponses~=0.6.0", "moto~=1.3.8", "fakeredis~=1.0", + "mongomock~=3.18.0", ] install_requires = [ diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 9de378579669..66c453e9e7da 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,8 +1,8 @@ import asyncio import os + from typing import Text -import matplotlib import pytest import rasa.utils.io @@ -21,7 +21,6 @@ from rasa.core.slots import Slot from rasa.core.tracker_store import InMemoryTrackerStore from rasa.core.trackers import DialogueStateTracker -from rasa.train import train_async DEFAULT_DOMAIN_PATH_WITH_SLOTS = "data/test_domains/default_with_slots.yml" diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 3327ac8b1ff9..21d310290be9 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -1,6 +1,7 @@ import logging import tempfile -from typing import Tuple, Text + +from typing import Tuple, Text, Callable, Type from unittest.mock import Mock import pytest @@ -8,6 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch from moto import mock_dynamodb2 +from rasa.core.brokers.event_channel import EventChannel from rasa.core.channels.channel import UserMessage from rasa.core.domain import Domain from rasa.core.events import ( @@ -25,17 +27,30 @@ SQLTrackerStore, DynamoTrackerStore, FailSafeTrackerStore, + MongoTrackerStore, ) import rasa.core.tracker_store from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import EndpointConfig, read_endpoint_config -from tests.conftest import assert_log_emitted from tests.core.conftest import DEFAULT_ENDPOINTS_FILE domain = Domain.load("data/test_domains/default.yml") -def get_or_create_tracker_store(store: "TrackerStore"): +class MockedMongoTrackerStore(MongoTrackerStore): + """In-memory mocked version of `MongoTrackerStore`.""" + + def __init__( + self, _domain: Domain, + ): + from mongomock import MongoClient + + self.db = MongoClient().rasa + self.collection = "conversations" + super(MongoTrackerStore, self).__init__(domain, None) + + +def get_or_create_tracker_store(store: TrackerStore): slot_key = "location" slot_val = "Easter Island" @@ -388,10 +403,13 @@ def test_set_fail_safe_tracker_store_domain(default_domain: Domain): assert fallback_tracker_store.domain is failsafe_store.domain -def test_sql_tracker_store_retrieve_with_session_started_events(default_domain: Domain): - tracker_store = SQLTrackerStore(default_domain, host="sqlite:///") - - # Create tracker with a SessionStarted event +@pytest.mark.parametrize( + "tracker_store_type", [MockedMongoTrackerStore, SQLTrackerStore], +) +def test_sql_tracker_store_retrieve_with_session_started_events( + tracker_store_type: Type[TrackerStore], default_domain: Domain +): + tracker_store = tracker_store_type(default_domain) events = [ UserUttered("Hola", {"name": "greet"}), BotUttered("Hi"), @@ -406,17 +424,22 @@ def test_sql_tracker_store_retrieve_with_session_started_events(default_domain: other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()]) tracker_store.save(other_tracker) - # Retrieve tracker with events since latest restart + # Retrieve tracker with events since latest SessionStarted tracker = tracker_store.retrieve(sender_id) - + print("fetched") + for e in tracker.events: + print(e) assert len(tracker.events) == 2 assert all((event == tracker.events[i] for i, event in enumerate(events[2:]))) +@pytest.mark.parametrize( + "tracker_store_factory", [SQLTrackerStore], +) def test_sql_tracker_store_retrieve_without_session_started_events( - default_domain: Domain, + tracker_store_factory: Callable, default_domain: Domain ): - tracker_store = SQLTrackerStore(default_domain, host="sqlite:///") + tracker_store = tracker_store_factory(default_domain) # Create tracker with a SessionStarted event events = [ From 3c63765aa32fc5296aac8d5568eb06f85d2a543e Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 16:06:48 +0100 Subject: [PATCH 29/74] update tests --- changelog/4830.improvement.rst | 3 +- rasa/core/tracker_store.py | 42 ++++++++++++++---------- tests/core/test_tracker_stores.py | 53 ++++++++++++++++++++++++------- 3 files changed, 68 insertions(+), 30 deletions(-) diff --git a/changelog/4830.improvement.rst b/changelog/4830.improvement.rst index 6c81fff327a8..7ee96d0d9cec 100644 --- a/changelog/4830.improvement.rst +++ b/changelog/4830.improvement.rst @@ -1 +1,2 @@ -``SQLTrackerStore`` only retrieves events from the last session from the database. \ No newline at end of file +``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve events from the last +session from the database. diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index c2329048caa5..d8979b6b427b 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -471,31 +471,43 @@ def save(self, tracker, timeout=None): if self.event_broker: self.stream_events(tracker) - additional_events = list(self._additional_events(tracker)) - print("have additional events", len(additional_events)) - for e in additional_events: - print(e) + additional_events = self._additional_events(tracker) + self.conversations.update_one( {"sender_id": tracker.sender_id}, - {"$push": {"events": {"$each": additional_events}}}, + {"$push": {"events": {"$each": [e.as_dict() for e in additional_events]}}}, upsert=True, ) - def _additional_events(self, tracker: DialogueStateTracker) -> List[Dict]: - """Return events from the tracker which aren't currently stored.""" + def _additional_events(self, tracker: DialogueStateTracker) -> Iterator: + """Return events from the tracker which aren't currently stored. + Args: + tracker: Tracker to inspect. + + Returns: + List of serialised events that aren't current stored. + + """ stored = self.conversations.find_one({"sender_id": tracker.sender_id}) n_events = len(stored.get("events", [])) if stored else 0 - return [ - event.as_dict() - for event in itertools.islice(tracker.events, n_events, len(tracker.events)) - ] + return itertools.islice(tracker.events, n_events, len(tracker.events)) @staticmethod def _events_since_last_session_start(serialised_tracker: Dict) -> List[Dict]: + """Retrieve events since and including the latest `SessionStart` event. + + Args: + serialised_tracker: Serialised tracker to inspect. + + Returns: + List of serialised events since and including the latest `SessionStarted` + event. Returns all events if no such event is found. + + """ events = [] - for event in serialised_tracker.get("events", []): + for event in reversed(serialised_tracker.get("events", [])): events.append(event) if event["event"] == SessionStarted.type_name: break @@ -786,13 +798,9 @@ def save(self, tracker: DialogueStateTracker) -> None: with self.session_scope() as session: # only store recent events events = self._additional_events(session, tracker) - events = list(events) - print("have additional events", len(events)) - for e in events: - print(e) + for event in events: data = event.as_dict() - intent = data.get("parse_data", {}).get("intent", {}).get("name") action = data.get("name") timestamp = data.get("timestamp") diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 21d310290be9..1e08cfd282eb 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -1,7 +1,8 @@ import logging import tempfile +import uuid -from typing import Tuple, Text, Callable, Type +from typing import Tuple, Text, Callable, Type, Dict from unittest.mock import Mock import pytest @@ -9,6 +10,7 @@ from _pytest.monkeypatch import MonkeyPatch from moto import mock_dynamodb2 +from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.brokers.event_channel import EventChannel from rasa.core.channels.channel import UserMessage from rasa.core.domain import Domain @@ -403,13 +405,39 @@ def test_set_fail_safe_tracker_store_domain(default_domain: Domain): assert fallback_tracker_store.domain is failsafe_store.domain +def test_mongo_additional_events(default_domain: Domain): + tracker_store = MockedMongoTrackerStore(default_domain) + sender_id = uuid.uuid4().hex + + # create tracker with two events and save it + events_1 = [UserUttered("hello"), BotUttered("what")] + tracker = DialogueStateTracker.from_events(sender_id, events_1) + tracker_store.save(tracker) + + # add more events to the tracker, do not yet save it + events_2 = [ + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered("123"), + BotUttered("yes"), + ] + for event in events_2: + tracker.update(event) + + # make sure only new events are returned + # noinspection PyProtectedMember + assert list(tracker_store._additional_events(tracker)) == events_2 + + @pytest.mark.parametrize( - "tracker_store_type", [MockedMongoTrackerStore, SQLTrackerStore], + "tracker_store_type,tracker_store_kwargs", + [(MockedMongoTrackerStore, {}), (SQLTrackerStore, {"host": "sqlite:///"})], ) -def test_sql_tracker_store_retrieve_with_session_started_events( - tracker_store_type: Type[TrackerStore], default_domain: Domain +def test_tracker_store_retrieve_with_session_started_events( + tracker_store_type: Type[TrackerStore], + tracker_store_kwargs: Dict, + default_domain: Domain, ): - tracker_store = tracker_store_type(default_domain) + tracker_store = tracker_store_type(default_domain, **tracker_store_kwargs) events = [ UserUttered("Hola", {"name": "greet"}), BotUttered("Hi"), @@ -426,20 +454,21 @@ def test_sql_tracker_store_retrieve_with_session_started_events( # Retrieve tracker with events since latest SessionStarted tracker = tracker_store.retrieve(sender_id) - print("fetched") - for e in tracker.events: - print(e) + assert len(tracker.events) == 2 assert all((event == tracker.events[i] for i, event in enumerate(events[2:]))) @pytest.mark.parametrize( - "tracker_store_factory", [SQLTrackerStore], + "tracker_store_type,tracker_store_kwargs", + [(MockedMongoTrackerStore, {}), (SQLTrackerStore, {"host": "sqlite:///"})], ) -def test_sql_tracker_store_retrieve_without_session_started_events( - tracker_store_factory: Callable, default_domain: Domain +def test_tracker_store_retrieve_without_session_started_events( + tracker_store_type: Type[TrackerStore], + tracker_store_kwargs: Dict, + default_domain: Domain, ): - tracker_store = tracker_store_factory(default_domain) + tracker_store = tracker_store_type(default_domain, **tracker_store_kwargs) # Create tracker with a SessionStarted event events = [ From 6740ff7f428c2fe08dc487fffbbadacadf6333df Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 16:08:08 +0100 Subject: [PATCH 30/74] use single changelog --- changelog/4830.feature.rst | 3 +++ changelog/4830.improvement.rst | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) delete mode 100644 changelog/4830.improvement.rst diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst index 3ddad54229fc..6457a66f8838 100644 --- a/changelog/4830.feature.rst +++ b/changelog/4830.feature.rst @@ -6,3 +6,6 @@ events from the previous session and applies it to the next session. Added new default intent ``session_start`` which triggers the start of a new conversation session. + +``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve events from the last +session from the database. diff --git a/changelog/4830.improvement.rst b/changelog/4830.improvement.rst deleted file mode 100644 index 7ee96d0d9cec..000000000000 --- a/changelog/4830.improvement.rst +++ /dev/null @@ -1,2 +0,0 @@ -``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve events from the last -session from the database. From f8b84030ac16807052b4694558ae75c74d0f7d0f Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 17:12:51 +0100 Subject: [PATCH 31/74] add explicit test for mongo events since last session start --- tests/core/test_tracker_stores.py | 65 ++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 1e08cfd282eb..fecea0edfa79 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -1,8 +1,9 @@ import logging import tempfile import uuid +from sqlalchemy.orm import Session -from typing import Tuple, Text, Callable, Type, Dict +from typing import Tuple, Text, Callable, Type, Dict, List from unittest.mock import Mock import pytest @@ -21,6 +22,7 @@ UserUttered, SessionStarted, BotUttered, + Event, ) from rasa.core.tracker_store import ( TrackerStore, @@ -32,7 +34,7 @@ MongoTrackerStore, ) import rasa.core.tracker_store -from rasa.core.trackers import DialogueStateTracker +from rasa.core.trackers import DialogueStateTracker, EventVerbosity from rasa.utils.endpoints import EndpointConfig, read_endpoint_config from tests.core.conftest import DEFAULT_ENDPOINTS_FILE @@ -405,27 +407,72 @@ def test_set_fail_safe_tracker_store_domain(default_domain: Domain): assert fallback_tracker_store.domain is failsafe_store.domain -def test_mongo_additional_events(default_domain: Domain): - tracker_store = MockedMongoTrackerStore(default_domain) +def create_tracker_with_partially_saved_events( + tracker_store: TrackerStore, +) -> Tuple[List[Event], DialogueStateTracker]: + # creates a tracker with two events and saved it to the tracker store + # following that, it adds three more events that are not saved to the tracker store sender_id = uuid.uuid4().hex # create tracker with two events and save it - events_1 = [UserUttered("hello"), BotUttered("what")] - tracker = DialogueStateTracker.from_events(sender_id, events_1) + events = [UserUttered("hello"), BotUttered("what")] + tracker = DialogueStateTracker.from_events(sender_id, events) tracker_store.save(tracker) # add more events to the tracker, do not yet save it - events_2 = [ + events = [ ActionExecuted(ACTION_LISTEN_NAME), UserUttered("123"), BotUttered("yes"), ] - for event in events_2: + for event in events: tracker.update(event) + return events, tracker + + +def test_mongo_additional_events(default_domain: Domain): + tracker_store = MockedMongoTrackerStore(default_domain) + events, tracker = create_tracker_with_partially_saved_events(tracker_store) + # make sure only new events are returned # noinspection PyProtectedMember - assert list(tracker_store._additional_events(tracker)) == events_2 + assert list(tracker_store._additional_events(tracker)) == events + + +# we cannot parametrise over this and the previous test due to the different ways of +# calling _additional_events() +def test_sql_additional_events(default_domain: Domain): + tracker_store = SQLTrackerStore(default_domain) + events, tracker = create_tracker_with_partially_saved_events(tracker_store) + + # make sure only new events are returned + with tracker_store.session_scope() as session: + # noinspection PyProtectedMember + assert list(tracker_store._additional_events(session, tracker)) == events + + +def test_mongo_events_since_last_session_start(default_domain: Domain): + tracker_store = MockedMongoTrackerStore(default_domain) + + # create tracker with events + events = [ + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered("123"), + BotUttered("yes"), + SessionStarted(), + UserUttered("hello"), + BotUttered("welcome to your new session"), + ] + + tracker = DialogueStateTracker.from_events("conversation with session", events) + serialised = tracker.current_state(EventVerbosity.ALL) + + # make sure only events post session start are returned + # noinspection PyProtectedMember + assert tracker_store._events_since_last_session_start(serialised) == [ + event.as_dict() for event in events[3:] + ] @pytest.mark.parametrize( From a24fed6b32d7a4d5d36c2ba4d24a7297409ab97c Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 4 Dec 2019 18:19:18 +0100 Subject: [PATCH 32/74] fix test --- tests/core/test_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 3a0ffef9c98e..ecdd0f6a396b 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -312,7 +312,7 @@ async def test_is_legacy_tracker( # session start is very recent (SessionStarted(timestamp=time.time()), 60, False), # there is no session start event (legacy tracker) - (UserUttered("hello", timestamp=time.time()), 10, False), + (UserUttered("hello", timestamp=time.time()), 60, False), # there is no event (None, 1, False), ], From 60a0aef0225ed3b36405d38bbc95e59fa9b62c2a Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 22:37:31 +0100 Subject: [PATCH 33/74] Update rasa/core/tracker_store.py Co-Authored-By: ricwo --- rasa/core/tracker_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 8430134db5f9..93ef9aa70dfb 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -710,7 +710,7 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: from rasa.core.events import SessionStarted with self.session_scope() as session: - # Subquery to find the timestamp of the first `SessionStartedEvent`. + # Subquery to find the timestamp of the first `SessionStarted` event session_start_sub_query = ( session.query( sa.func.max(self.SQLEvent.timestamp).label("session_start") From 9c56a733c160eef0a583febd4a62fddc1958bd8a Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 22:45:03 +0100 Subject: [PATCH 34/74] move session config constants to constants module --- rasa/constants.py | 3 +++ rasa/core/domain.py | 12 +++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/rasa/constants.py b/rasa/constants.py index a270ab695e18..122db69e6f28 100644 --- a/rasa/constants.py +++ b/rasa/constants.py @@ -46,3 +46,6 @@ DEFAULT_SANIC_WORKERS = 1 ENV_SANIC_WORKERS = "SANIC_WORKERS" ENV_SANIC_BACKLOG = "SANIC_BACKLOG" + +DEFAULT_SESSION_LENGTH_IN_MINUTES = 60 +DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION = True diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 259d4c4e7f37..113046c5bfe6 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -11,7 +11,11 @@ import rasa.utils.common as common_utils import rasa.utils.io from rasa.cli.utils import bcolors -from rasa.constants import DOMAIN_SCHEMA_FILE +from rasa.constants import ( + DOMAIN_SCHEMA_FILE, + DEFAULT_SESSION_LENGTH_IN_MINUTES, + DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION, +) from rasa.core import utils from rasa.core.actions import action # pytype: disable=pyi-error from rasa.core.actions.action import Action # pytype: disable=pyi-error @@ -31,8 +35,6 @@ PREV_PREFIX = "prev_" ACTIVE_FORM_PREFIX = "active_form_" -DEFAULT_SESSION_LENGTH = 60 -DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION = True CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session" SESSION_LENGTH_KEY = "session_length" @@ -136,7 +138,7 @@ def from_dict(cls, data: Dict) -> "Domain": @staticmethod def _get_session_config(additional_arguments: Dict) -> SessionConfig: session_length = additional_arguments.pop( - SESSION_LENGTH_KEY, DEFAULT_SESSION_LENGTH + SESSION_LENGTH_KEY, DEFAULT_SESSION_LENGTH_IN_MINUTES ) carry_over_slots = additional_arguments.pop( CARRY_OVER_SLOTS_KEY, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION @@ -302,7 +304,7 @@ def __init__( form_names: List[Text], store_entities_as_slots: bool = True, session_config: SessionConfig = SessionConfig( - DEFAULT_SESSION_LENGTH, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION + DEFAULT_SESSION_LENGTH_IN_MINUTES, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION ), ) -> None: From 96860db47bb630047b9adc7dfd53ca4b028c1e5e Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 22:46:24 +0100 Subject: [PATCH 35/74] use float instead of Union[float, int] due to the numeric tower rule --- rasa/core/domain.py | 2 +- tests/core/test_domain.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 113046c5bfe6..9ba02b53fd6f 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -55,7 +55,7 @@ def __str__(self): class SessionConfig(NamedTuple): - session_length: Union[int, float] + session_length: float carry_over_slots: bool diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index cc3ce1fb82e0..a720c168c8d9 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -600,7 +600,7 @@ def test_add_knowledge_base_slots(default_domain): ], ) def test_session_config( - input_domain, expected_session_length: int, expected_carry_over_slots: bool + input_domain, expected_session_length: float, expected_carry_over_slots: bool ): domain = Domain.from_yaml(input_domain) assert domain.session_config.session_length == expected_session_length From f316964783a851b4de8eb4db696576b98d8022e9 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 22:48:30 +0100 Subject: [PATCH 36/74] update yaml test case due to new config entries --- tests/core/test_domain.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index a720c168c8d9..1003c3b44811 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -230,6 +230,8 @@ def test_domain_to_yaml(): test_yaml = """actions: - utter_greet config: + carry_over_slots_to_new_session: true + session_length: 60 store_entities_as_slots: true entities: [] forms: [] @@ -567,7 +569,7 @@ def test_add_knowledge_base_slots(default_domain): @pytest.mark.parametrize( - "input_domain, expected_session_length, expected_carry_over_slots, ", + "input_domain, expected_session_length, expected_carry_over_slots", [ ( """config: From 22fc7e1e54525a685191891ea71f8e92c4eaac41 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 22:52:46 +0100 Subject: [PATCH 37/74] validate session length range with pykwalify --- rasa/core/schemas/domain.yml | 2 ++ tests/core/test_domain.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/rasa/core/schemas/domain.yml b/rasa/core/schemas/domain.yml index 714a7ef803e1..8f4651c92aa7 100644 --- a/rasa/core/schemas/domain.yml +++ b/rasa/core/schemas/domain.yml @@ -39,5 +39,7 @@ mapping: type: "bool" session_length: type: "number" + range: + min: 0 carry_over_slots_to_new_session: type: "bool" diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index 1003c3b44811..9543bacca922 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -573,9 +573,9 @@ def test_add_knowledge_base_slots(default_domain): [ ( """config: - session_length: 20 + session_length: 0 carry_over_slots_to_new_session: true""", - 20, + 0, True, ), ("", 60, True), From 1958128b82da9db39b170a0c4eb4e43f2e1a8a8d Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 4 Dec 2019 22:58:23 +0100 Subject: [PATCH 38/74] remove extra session length parameter --- rasa/core/processor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index cde9927b6499..5789b5636456 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -191,9 +191,7 @@ async def log_message( tracker = self._get_tracker(message.sender_id) if tracker: - # TODO: get session length from domain - await self._update_tracker_session(tracker, message.output_channel, 1) - + await self._update_tracker_session(tracker, message.output_channel) await self._handle_message_with_tracker(message, tracker) if should_save_tracker: From 2685c629351e01bd0ab103c9e64816386c0548a3 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Thu, 5 Dec 2019 10:15:40 +0100 Subject: [PATCH 39/74] reset tracker store after each test in defaut_agent --- tests/core/conftest.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 9de378579669..2ca79bbca8e1 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -4,10 +4,11 @@ import matplotlib import pytest +from _pytest.tmpdir import TempdirFactory import rasa.utils.io from rasa.core.agent import Agent -from rasa.core.channels.channel import CollectingOutputChannel +from rasa.core.channels.channel import CollectingOutputChannel, OutputChannel from rasa.core.domain import Domain from rasa.core.interpreter import RegexInterpreter from rasa.core.nlg import TemplatedNaturalLanguageGenerator @@ -106,7 +107,7 @@ def default_domain(): @pytest.fixture(scope="session") -async def default_agent(default_domain) -> Agent: +async def _default_agent(default_domain: Domain) -> Agent: agent = Agent( default_domain, policies=[MemoizationPolicy()], @@ -118,15 +119,22 @@ async def default_agent(default_domain) -> Agent: return agent +@pytest.fixture() +async def default_agent(_default_agent: Agent) -> Agent: + # Clean tracker store after each test so tests don't affect each other + _default_agent.tracker_store = InMemoryTrackerStore(_default_agent.domain) + return _default_agent + + @pytest.fixture(scope="session") -def default_agent_path(default_agent, tmpdir_factory): +def default_agent_path(default_agent: Agent, tmpdir_factory: TempdirFactory): path = tmpdir_factory.mktemp("agent").strpath default_agent.persist(path) return path @pytest.fixture -def default_channel(): +def default_channel() -> OutputChannel: return CollectingOutputChannel() From 2db0f83e5537ea7e5f447580df319d6edb02d732 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Thu, 5 Dec 2019 11:50:40 +0100 Subject: [PATCH 40/74] fix tests which affect each others base state --- rasa/core/domain.py | 10 +++++++--- tests/core/conftest.py | 3 ++- tests/core/test_agent.py | 8 ++++++-- tests/core/test_trackers.py | 5 +++-- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 9ba02b53fd6f..82ca13d0f0e6 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -58,6 +58,12 @@ class SessionConfig(NamedTuple): session_length: float carry_over_slots: bool + @staticmethod + def default() -> "SessionConfig": + return SessionConfig( + DEFAULT_SESSION_LENGTH_IN_MINUTES, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION + ) + class Domain: """The domain specifies the universe in which the bot's policy acts. @@ -303,9 +309,7 @@ def __init__( action_names: List[Text], form_names: List[Text], store_entities_as_slots: bool = True, - session_config: SessionConfig = SessionConfig( - DEFAULT_SESSION_LENGTH_IN_MINUTES, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION - ), + session_config: SessionConfig = SessionConfig.default(), ) -> None: self.intent_properties = self.collect_intent_properties(intents) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 2ca79bbca8e1..1a0f04956e9f 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -9,7 +9,7 @@ import rasa.utils.io from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, OutputChannel -from rasa.core.domain import Domain +from rasa.core.domain import Domain, SessionConfig from rasa.core.interpreter import RegexInterpreter from rasa.core.nlg import TemplatedNaturalLanguageGenerator from rasa.core.policies.ensemble import PolicyEnsemble, SimplePolicyEnsemble @@ -123,6 +123,7 @@ async def _default_agent(default_domain: Domain) -> Agent: async def default_agent(_default_agent: Agent) -> Agent: # Clean tracker store after each test so tests don't affect each other _default_agent.tracker_store = InMemoryTrackerStore(_default_agent.domain) + _default_agent.domain.session_config = SessionConfig.default() return _default_agent diff --git a/tests/core/test_agent.py b/tests/core/test_agent.py index 6464ba7f718e..6def31338e3c 100644 --- a/tests/core/test_agent.py +++ b/tests/core/test_agent.py @@ -247,10 +247,14 @@ def test_two_stage_fallback_without_deny_suggestion(domain, policy_config): assert "The intent 'out_of_scope' must be present" in str(execinfo.value) -async def test_agent_update_model_none_domain(trained_model): +async def test_agent_update_model_none_domain(trained_model: Text): agent = await load_agent(model_path=trained_model) agent.update_model( - None, None, agent.fingerprint, agent.interpreter, agent.model_directory + Domain.empty(), + None, + agent.fingerprint, + agent.interpreter, + agent.model_directory, ) sender_id = "test_sender_id" diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index e60f2fcdaaed..912b8b6863ff 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -10,6 +10,7 @@ import rasa.utils.io from rasa.core import training, restore from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME +from rasa.core.agent import Agent from rasa.core.domain import Domain from rasa.core.events import ( SlotSet, @@ -139,7 +140,7 @@ async def test_tracker_write_to_story(tmpdir, moodbot_domain: Domain): assert recovered.events[4].intent == {"confidence": 1.0, "name": "mood_unhappy"} -async def test_tracker_state_regression_without_bot_utterance(default_agent): +async def test_tracker_state_regression_without_bot_utterance(default_agent: Agent): sender_id = "test_tracker_state_regression_without_bot_utterance" for i in range(0, 2): await default_agent.handle_message("/greet", sender_id=sender_id) @@ -154,7 +155,7 @@ async def test_tracker_state_regression_without_bot_utterance(default_agent): ) -async def test_tracker_state_regression_with_bot_utterance(default_agent): +async def test_tracker_state_regression_with_bot_utterance(default_agent: Agent): sender_id = "test_tracker_state_regression_with_bot_utterance" for i in range(0, 2): await default_agent.handle_message("/greet", sender_id=sender_id) From 74ce68233441bab519ba1664df3503b2a8b1db19 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Thu, 5 Dec 2019 13:32:04 +0100 Subject: [PATCH 41/74] fix deprecated use of logger.warn --- rasa/nlu/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index 5a7b97ac7b0e..57610af3e412 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -725,7 +725,7 @@ def do_entities_overlap(entities: List[Dict]) -> bool: next_ent["start"] < curr_ent["end"] and next_ent["entity"] != curr_ent["entity"] ): - logger.warn(f"Overlapping entity {curr_ent} with {next_ent}") + logger.warning(f"Overlapping entity {curr_ent} with {next_ent}") return True return False From a3477429e9b251d1446ae5e307bb2c2e2fc55b18 Mon Sep 17 00:00:00 2001 From: ricwo Date: Thu, 5 Dec 2019 14:44:36 +0100 Subject: [PATCH 42/74] Apply suggestions from code review Co-Authored-By: Tobias Wochinger --- tests/core/test_tracker_stores.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index fecea0edfa79..52082b5cb236 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -1,9 +1,8 @@ import logging import tempfile import uuid -from sqlalchemy.orm import Session -from typing import Tuple, Text, Callable, Type, Dict, List +from typing import Tuple, Text, Type, Dict, List from unittest.mock import Mock import pytest @@ -12,7 +11,6 @@ from moto import mock_dynamodb2 from rasa.core.actions.action import ACTION_LISTEN_NAME -from rasa.core.brokers.event_channel import EventChannel from rasa.core.channels.channel import UserMessage from rasa.core.domain import Domain from rasa.core.events import ( @@ -444,12 +442,12 @@ def test_mongo_additional_events(default_domain: Domain): # calling _additional_events() def test_sql_additional_events(default_domain: Domain): tracker_store = SQLTrackerStore(default_domain) - events, tracker = create_tracker_with_partially_saved_events(tracker_store) + additional_events, tracker = create_tracker_with_partially_saved_events(tracker_store) # make sure only new events are returned with tracker_store.session_scope() as session: # noinspection PyProtectedMember - assert list(tracker_store._additional_events(session, tracker)) == events + assert list(tracker_store._additional_events(session, tracker)) == additional_events def test_mongo_events_since_last_session_start(default_domain: Domain): From 01082a0a80fcfecaa368b0d7ecb7e136caf9f536 Mon Sep 17 00:00:00 2001 From: ricwo Date: Thu, 5 Dec 2019 16:05:49 +0100 Subject: [PATCH 43/74] black --- tests/core/test_tracker_stores.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 52082b5cb236..d7e7bb7283dd 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -442,12 +442,17 @@ def test_mongo_additional_events(default_domain: Domain): # calling _additional_events() def test_sql_additional_events(default_domain: Domain): tracker_store = SQLTrackerStore(default_domain) - additional_events, tracker = create_tracker_with_partially_saved_events(tracker_store) + additional_events, tracker = create_tracker_with_partially_saved_events( + tracker_store + ) # make sure only new events are returned with tracker_store.session_scope() as session: # noinspection PyProtectedMember - assert list(tracker_store._additional_events(session, tracker)) == additional_events + assert ( + list(tracker_store._additional_events(session, tracker)) + == additional_events + ) def test_mongo_events_since_last_session_start(default_domain: Domain): From a21c94800a0aadf5964d1bb14f1c18544611272a Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Thu, 5 Dec 2019 16:26:38 +0100 Subject: [PATCH 44/74] use session scoped default_agent with session scoped tempdir --- tests/core/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 1a0f04956e9f..986667145a5c 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -128,9 +128,9 @@ async def default_agent(_default_agent: Agent) -> Agent: @pytest.fixture(scope="session") -def default_agent_path(default_agent: Agent, tmpdir_factory: TempdirFactory): +def default_agent_path(_default_agent: Agent, tmpdir_factory: TempdirFactory): path = tmpdir_factory.mktemp("agent").strpath - default_agent.persist(path) + _default_agent.persist(path) return path From 443cc166f997678483c3582ccb193d2574cd643f Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Thu, 5 Dec 2019 16:34:42 +0100 Subject: [PATCH 45/74] disable sessions by using a session length <= 0 --- rasa/core/domain.py | 3 +++ rasa/core/processor.py | 5 +++++ tests/core/test_domain.py | 12 ++++++++++++ tests/core/test_processor.py | 2 ++ 4 files changed, 22 insertions(+) diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 82ca13d0f0e6..a966aaa524ed 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -64,6 +64,9 @@ def default() -> "SessionConfig": DEFAULT_SESSION_LENGTH_IN_MINUTES, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION ) + def are_session_enabled(self) -> bool: + return self.session_length > 0 + class Domain: """The domain specifies the universe in which the bot's policy acts. diff --git a/rasa/core/processor.py b/rasa/core/processor.py index df4405a03c5e..536daad2015a 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -648,6 +648,11 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: `True` if the session in `tracker` has expired, `False` otherwise. """ + + if not self.domain.session_config.are_session_enabled(): + # Tracker is never expired when sessions are disabled + return False + session_start_timestamp = self._session_start_timestamp_for(tracker) time_delta_in_seconds = time.time() - session_start_timestamp diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index 9543bacca922..be5dc34275e1 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -618,3 +618,15 @@ def test_domain_as_dict_with_session_config(): deserialized = Domain.from_dict(serialized) assert deserialized.session_config == session_config + + +@pytest.mark.parametrize( + "session_config, enabled", + [ + (SessionConfig(0, True), False), + (SessionConfig(1, True), True), + (SessionConfig(-1, False), False), + ], +) +def test_are_sessions_enabled(session_config: SessionConfig, enabled: bool): + assert session_config.are_session_enabled() == enabled diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 4aae620563ea..3de0fa268a68 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -314,6 +314,8 @@ async def test_is_legacy_tracker( (SessionStarted(timestamp=time.time()), 60, False), # there is no session start event (legacy tracker) (UserUttered("hello", timestamp=time.time()), 60, False), + # Old event, but sessions are disabled + (UserUttered("hello", timestamp=1), 0, False), # there is no event (None, 1, False), ], From 98af47cad7249d5205fe20b6c6e3cd409eff69d2 Mon Sep 17 00:00:00 2001 From: ricwo Date: Fri, 6 Dec 2019 13:46:35 +0100 Subject: [PATCH 46/74] do not dump action_session_start --- rasa/core/training/structures.py | 13 ++++++++++++- tests/core/test_structures.py | 11 +++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index 087dbd995656..e2ba656dce03 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -7,7 +7,7 @@ from typing import List, Text, Dict, Optional, Tuple, Any, Set, ValuesView from rasa.core import utils -from rasa.core.actions.action import ACTION_LISTEN_NAME +from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME from rasa.core.conversation import Dialogue from rasa.core.domain import Domain from rasa.core.events import ( @@ -232,6 +232,8 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: elif isinstance(s, ActionExecuted): if self._is_action_listen(s): pass + elif self._is_action_session_start(s): + pass elif self.story_string_helper.active_form is None: result += self._bot_string(s) else: @@ -317,6 +319,15 @@ def _is_action_listen(event: ActionExecuted) -> bool: # we don't want to allow subclasses here return type(event) == ActionExecuted and event.action_name == ACTION_LISTEN_NAME + @staticmethod + def _is_action_session_start(event: ActionExecuted) -> bool: + # this is not an `isinstance` because + # we don't want to allow subclasses here + return ( + type(event) == ActionExecuted + and event.action_name == ACTION_SESSION_START_NAME + ) + def _add_action_listen(self, events: List[ActionExecuted]) -> None: if not events or not self._is_action_listen(events[-1]): # do not add second action_listen diff --git a/tests/core/test_structures.py b/tests/core/test_structures.py index 58f5d8d4d788..a6407b576bd9 100644 --- a/tests/core/test_structures.py +++ b/tests/core/test_structures.py @@ -1,5 +1,11 @@ +from rasa.core.actions.action import ACTION_SESSION_START_NAME from rasa.core.domain import Domain -from rasa.core.events import SessionStarted, SlotSet, UserUttered +from rasa.core.events import ( + SessionStarted, + SlotSet, + UserUttered, + ActionExecuted, +) from rasa.core.trackers import DialogueStateTracker from rasa.core.training.structures import Story @@ -14,7 +20,8 @@ def test_session_start_is_not_serialised(default_domain: Domain): # add SlotSet event tracker.update(SlotSet("slot", "value")) - # add a SessionStarted event and a user event + # add the two SessionStarted events and a user event + tracker.update(ActionExecuted(ACTION_SESSION_START_NAME)) tracker.update(SessionStarted()) tracker.update(UserUttered("say something")) From 1f73780de7ae52a8ec30da3969e85871f8677e48 Mon Sep 17 00:00:00 2001 From: ricwo Date: Fri, 6 Dec 2019 14:05:16 +0100 Subject: [PATCH 47/74] remove explicit mongo additional events test --- tests/core/test_tracker_stores.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index d7e7bb7283dd..d7db5d6904d4 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -455,29 +455,6 @@ def test_sql_additional_events(default_domain: Domain): ) -def test_mongo_events_since_last_session_start(default_domain: Domain): - tracker_store = MockedMongoTrackerStore(default_domain) - - # create tracker with events - events = [ - ActionExecuted(ACTION_LISTEN_NAME), - UserUttered("123"), - BotUttered("yes"), - SessionStarted(), - UserUttered("hello"), - BotUttered("welcome to your new session"), - ] - - tracker = DialogueStateTracker.from_events("conversation with session", events) - serialised = tracker.current_state(EventVerbosity.ALL) - - # make sure only events post session start are returned - # noinspection PyProtectedMember - assert tracker_store._events_since_last_session_start(serialised) == [ - event.as_dict() for event in events[3:] - ] - - @pytest.mark.parametrize( "tracker_store_type,tracker_store_kwargs", [(MockedMongoTrackerStore, {}), (SQLTrackerStore, {"host": "sqlite:///"})], From 9e237af5cd865b984a272c410fd52bae54af30bf Mon Sep 17 00:00:00 2001 From: ricwo Date: Fri, 6 Dec 2019 15:57:43 +0100 Subject: [PATCH 48/74] import Iterable --- rasa/core/trackers.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 1e5d1511e4e3..84fc1e9b317c 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -2,7 +2,18 @@ import logging from collections import deque from enum import Enum -from typing import Dict, Text, Any, Optional, Iterator, Generator, Type, List, Deque +from typing import ( + Dict, + Text, + Any, + Optional, + Iterator, + Generator, + Type, + List, + Deque, + Iterable, +) from rasa.core import events # pytype: disable=pyi-error from rasa.core.actions.action import ACTION_LISTEN_NAME # pytype: disable=pyi-error From 0c8293a3319c53cb2d28016be80e23bd41beb238 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 09:34:10 +0100 Subject: [PATCH 49/74] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 12a8f0e53d78..cfa664f1349c 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "aioresponses~=0.6.0", "moto~=1.3.8", "fakeredis~=1.0", - "mongomock~=3.18.0", + "mongomock~=3.18", ] install_requires = [ From 3568a91816796af1403b7d8cea74d8f28a682a17 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 14:43:17 +0100 Subject: [PATCH 50/74] save current tracker state in mongo --- rasa/core/tracker_store.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 99b455b1c1d7..e1a902b82871 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -482,9 +482,19 @@ def save(self, tracker, timeout=None): additional_events = self._additional_events(tracker) + # get current tracker state remove `events` key from state as they're pushed + # separately + state = tracker.current_state(EventVerbosity.ALL) + state.pop("events", None) + self.conversations.update_one( {"sender_id": tracker.sender_id}, - {"$push": {"events": {"$each": [e.as_dict() for e in additional_events]}}}, + { + "$set": state, + "$push": { + "events": {"$each": [e.as_dict() for e in additional_events]} + }, + }, upsert=True, ) From ac6961278669c844f137783a5599928777d204df Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 15:42:58 +0100 Subject: [PATCH 51/74] partial review comments and improvements --- changelog/4830.feature.rst | 25 +++++---- rasa/cli/initial_project/domain.yml | 4 ++ rasa/core/actions/action.py | 4 +- rasa/core/domain.py | 23 +++++--- rasa/core/processor.py | 45 ++++------------ rasa/core/trackers.py | 11 ++-- tests/core/test_domain.py | 2 +- tests/core/test_processor.py | 84 +++++++++++------------------ 8 files changed, 87 insertions(+), 111 deletions(-) diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst index 6457a66f8838..3affabb77f03 100644 --- a/changelog/4830.feature.rst +++ b/changelog/4830.feature.rst @@ -1,11 +1,16 @@ -Added a new event ``SessionStarted`` that marks the beginning of a new conversation -session. +Added conversation sessions to trackers. A conversation session represents the +dialog between the agent and a user. Conversation sessions can begin in three ways: 1. +the user begins the conversation with the agent, 2. the user sends their first +message after a configurable period of inactivity, or 3. a manual session start is +triggered with the ``/session_start`` intent message. The period of inactivity after +which a new conversation session is triggered is defined in the domain using the +``session_length`` key in the ``config`` section. The introduction of +conversation sessions comprises the following changes: -Added a new default action ``ActionSessionStart``. This action takes all ``SlotSet`` -events from the previous session and applies it to the next session. - -Added new default intent ``session_start`` which triggers the start of a new -conversation session. - -``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve events from the last -session from the database. +- Added a new event ``SessionStarted`` that marks the beginning of a new conversation + session. +- Added a new default action ``ActionSessionStart``. This action takes all + ``SlotSet`` events from the previous session and applies it to the next session. +- Added a new default intent ``session_start`` which triggers the start of a new + conversation session. ``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve + events from the last session from the database. diff --git a/rasa/cli/initial_project/domain.yml b/rasa/cli/initial_project/domain.yml index c4772bb14737..95fa39485e2f 100644 --- a/rasa/cli/initial_project/domain.yml +++ b/rasa/cli/initial_project/domain.yml @@ -34,3 +34,7 @@ templates: utter_iamabot: - text: "I am a bot, powered by Rasa." + +config: + session_length: 60 + carry_over_slots_to_new_session: true diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index f714ce85c262..578d984aea2d 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -336,9 +336,9 @@ async def run( class ActionSessionStart(Action): - """Applies. + """Applies a conversation session start. - Utters the 'session start' template if available.""" + """ def name(self) -> Text: return ACTION_SESSION_START_NAME diff --git a/rasa/core/domain.py b/rasa/core/domain.py index ba51084b155b..9888d8266afe 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -60,11 +60,10 @@ class SessionConfig(NamedTuple): @staticmethod def default() -> "SessionConfig": - return SessionConfig( - DEFAULT_SESSION_LENGTH_IN_MINUTES, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION - ) + # TODO: 2.0, reconsider how to apply sessions to old projects + return SessionConfig(0, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION) - def are_session_enabled(self) -> bool: + def are_sessions_enabled(self) -> bool: return self.session_length > 0 @@ -146,9 +145,19 @@ def from_dict(cls, data: Dict) -> "Domain": @staticmethod def _get_session_config(additional_arguments: Dict) -> SessionConfig: - session_length = additional_arguments.pop( - SESSION_LENGTH_KEY, DEFAULT_SESSION_LENGTH_IN_MINUTES - ) + + session_length = additional_arguments.pop(SESSION_LENGTH_KEY, None) + + # TODO: 2.0 reconsider how to apply sessions to old projects and legacy trackers + if session_length is None: + warnings.warn( + "No tracker session configuration was found in the loaded domain. " + "Domains without a session config will be deprecated in Rasa " + "version 2.0.", + DeprecationWarning, + ) + session_length = 0 + carry_over_slots = additional_arguments.pop( CARRY_OVER_SLOTS_KEY, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION ) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 3342618182a4..e65286fd1b25 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -157,7 +157,7 @@ async def _update_tracker_session( `ActionSessionStart`. """ - if self._is_legacy_tracker(tracker) or self._has_session_expired(tracker): + if self._has_session_expired(tracker): logger.debug( f"Starting a new session for conversation ID '{tracker.sender_id}'." ) @@ -652,18 +652,23 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: """ - if not self.domain.session_config.are_session_enabled(): - # Tracker is never expired when sessions are disabled + if not self.domain.session_config.are_sessions_enabled(): + # tracker has never expired if sessions are disabled return False - session_start_timestamp = self._session_start_timestamp_for(tracker) + user_uttered_event: Optional[UserUttered] = tracker.get_last_event_for( + UserUttered + ) - time_delta_in_seconds = time.time() - session_start_timestamp + if not user_uttered_event: + # there is no user event so far so the session should not be considered + # expired + return False + time_delta_in_seconds = time.time() - user_uttered_event.timestamp has_expired = ( time_delta_in_seconds / 60 > self.domain.session_config.session_length ) - if has_expired: logger.debug( f"The latest session for conversation ID {tracker.sender_id} has " @@ -672,34 +677,6 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: return has_expired - @staticmethod - def _is_legacy_tracker(tracker: DialogueStateTracker) -> bool: - """Determine whether `tracker` is a legacy tracker. - - A legacy tracker is a tracker that has been created before the introduction - of sessions in release 1.6.0. - - Args: - tracker: Tracker to inspect. - - Returns: - `True` if the tracker contains `SessionStarted` event, `False` otherwise. - - """ - last_session_started_event = tracker.get_last_session_started_event() - - is_legacy_tracker = last_session_started_event is None - - if is_legacy_tracker: - logger.debug( - f"Tracker for conversation ID '{tracker.sender_id}' is a legacy " - f"tracker. A legacy tracker is a tracker that contains no " - f"'SessionStarted' and was last saved before the introduction of " - f"tracker sessions in release 1.6.0." - ) - - return is_legacy_tracker - def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID return self.tracker_store.get_or_create_tracker(sender_id) diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 84fc1e9b317c..fe9eba3784fc 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -516,9 +516,14 @@ def get_last_session_started_event(self) -> Optional[SessionStarted]: otherwise `None`. """ - for event in reversed(self.events): - if isinstance(event, SessionStarted): - return event + return next( + ( + event + for event in reversed(self.events) + if isinstance(event, SessionStarted) + ), + None, + ) ### # Internal methods for the modification of the trackers state. Should diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index be5dc34275e1..0fe235996966 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -629,4 +629,4 @@ def test_domain_as_dict_with_session_config(): ], ) def test_are_sessions_enabled(session_config: SessionConfig, enabled: bool): - assert session_config.are_session_enabled() == enabled + assert session_config.are_sessions_enabled() == enabled diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 3de0fa268a68..3d5206a96417 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -1,20 +1,20 @@ -from typing import Optional, List, Text - -import time - import asyncio -import datetime -import uuid +import logging +import datetime import pytest +import time +import uuid +from _pytest.monkeypatch import MonkeyPatch from aioresponses import aioresponses - +from typing import Optional, Text from unittest.mock import patch from rasa.core import jobs from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, UserMessage +from rasa.core.domain import SessionConfig from rasa.core.events import ( ActionExecuted, BotUttered, @@ -26,17 +26,13 @@ Event, SlotSet, ) -from rasa.core.trackers import DialogueStateTracker -from rasa.core.slots import Slot from rasa.core.interpreter import RasaNLUHttpInterpreter from rasa.core.processor import MessageProcessor, DEFAULT_INTENTS +from rasa.core.slots import Slot +from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import EndpointConfig from tests.utilities import latest_request -from rasa.core.domain import Domain, SessionConfig - -import logging - logger = logging.getLogger(__name__) @@ -278,42 +274,15 @@ async def test_reminder_restart( assert len(t.events) == 5 # nothing should have been executed -@pytest.mark.parametrize( - "events_to_apply,is_legacy", - [ - # just an action listen means it's legacy - ([ActionExecuted(action_name=ACTION_LISTEN_NAME)], True), - # action listen and session at the beginning start means it isn't legacy - ([SessionStarted(), ActionExecuted(action_name=ACTION_LISTEN_NAME)], False), - # just a single event means it's legacy - ([UserUttered("hello")], True), - ], -) -async def test_is_legacy_tracker( - events_to_apply: List[Event], is_legacy: bool, default_processor: MessageProcessor -): - sender_id = uuid.uuid4().hex - - # create a new tracker without events - tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) - tracker.events.clear() - - for event in events_to_apply: - tracker.update(event) - - # noinspection PyProtectedMember - assert default_processor._is_legacy_tracker(tracker) == is_legacy - - @pytest.mark.parametrize( "event_to_apply,session_length_in_minutes,has_expired", [ - # session start is way in the past - (SessionStarted(timestamp=1), 60, True), - # session start is very recent - (SessionStarted(timestamp=time.time()), 60, False), - # there is no session start event (legacy tracker) - (UserUttered("hello", timestamp=time.time()), 60, False), + # last user event is way in the past + (UserUttered(timestamp=1), 60, True), + # user event are very recent + (UserUttered("hello", timestamp=time.time()), 60, False,), + # there is user event + (ActionExecuted(ACTION_LISTEN_NAME, timestamp=time.time()), 60, False), # Old event, but sessions are disabled (UserUttered("hello", timestamp=1), 0, False), # there is no event @@ -345,14 +314,17 @@ async def test_has_session_expired( # noinspection PyProtectedMember async def test_update_tracker_session( - default_channel: CollectingOutputChannel, default_processor: MessageProcessor + default_channel: CollectingOutputChannel, + default_processor: MessageProcessor, + monkeypatch: MonkeyPatch, ): sender_id = uuid.uuid4().hex tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) - # make sure session expires and run tracker session update - await asyncio.sleep(1e-2) # in seconds - default_processor.domain.session_config = SessionConfig(1e-5, True) + # patch `_has_session_expired()` so the `_update_tracker_session()` call actually + # does something + monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True) + await default_processor._update_tracker_session(tracker, default_channel) # the save is not called in _update_tracker_session() @@ -360,6 +332,7 @@ async def test_update_tracker_session( # inspect tracker and make sure all events are present tracker = default_processor.tracker_store.retrieve(sender_id) + assert list(tracker.events) == [ SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME), @@ -371,7 +344,9 @@ async def test_update_tracker_session( # noinspection PyProtectedMember async def test_update_tracker_session_with_slots( - default_channel: CollectingOutputChannel, default_processor: MessageProcessor + default_channel: CollectingOutputChannel, + default_processor: MessageProcessor, + monkeypatch: MonkeyPatch, ): sender_id = uuid.uuid4().hex tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) @@ -385,9 +360,10 @@ async def test_update_tracker_session_with_slots( for event in slot_set_events: tracker.update(event) - # make sure session expires and run tracker session update - await asyncio.sleep(1e-2) # in seconds - default_processor.domain.session_config = SessionConfig(1e-5, True) + # patch `_has_session_expired()` so the `_update_tracker_session()` call actually + # does something + monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True) + await default_processor._update_tracker_session(tracker, default_channel) # the save is not called in _update_tracker_session() From 70e7a5aa24176066ef06456cc0e2c0252c16c66a Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 16:29:37 +0100 Subject: [PATCH 52/74] test tracker state without events --- rasa/core/tracker_store.py | 19 +++++++++++++------ tests/core/test_tracker_stores.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index e1a902b82871..bac55affa108 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -475,6 +475,18 @@ def _ensure_indices(self): """Create an index on the sender_id""" self.conversations.create_index("sender_id") + @staticmethod + def _current_tracker_state_without_events(tracker: DialogueStateTracker) -> Dict: + # get current tracker state and remove `events` key from state + # since events are pushed separately in the `update_one()` operation + state = tracker.current_state(EventVerbosity.ALL) + try: + del state["events"] + except KeyError: + pass + + return state + def save(self, tracker, timeout=None): """Saves the current conversation state""" if self.event_broker: @@ -482,15 +494,10 @@ def save(self, tracker, timeout=None): additional_events = self._additional_events(tracker) - # get current tracker state remove `events` key from state as they're pushed - # separately - state = tracker.current_state(EventVerbosity.ALL) - state.pop("events", None) - self.conversations.update_one( {"sender_id": tracker.sender_id}, { - "$set": state, + "$set": self._current_tracker_state_without_events(tracker), "$push": { "events": {"$each": [e.as_dict() for e in additional_events]} }, diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index d43f97b04707..95682d051e23 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -538,3 +538,25 @@ def test_tracker_store_retrieve_without_session_started_events( assert len(tracker.events) == 4 assert all(event == tracker.events[i] for i, event in enumerate(events)) + + +def test_current_state_without_events(default_domain: Domain,): + tracker_store = MockedMongoTrackerStore(default_domain) + + # insert some events + events = [ + UserUttered("Hola", {"name": "greet"}), + BotUttered("Hi"), + UserUttered("Ciao", {"name": "greet"}), + BotUttered("Hi2"), + ] + + sender_id = "test_sql_tracker_store_retrieve_without_session_started_events" + tracker = DialogueStateTracker.from_events(sender_id, events) + + # get current state without events + # noinspection PyProtectedMember + state = tracker_store._current_tracker_state_without_events(tracker) + + # `events` key should not be in there + assert state and "events" not in state From 4e69fcd3062c51a6cdcdab1b5a716b8ab8e35fcd Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 16:31:18 +0100 Subject: [PATCH 53/74] update sender ID in test --- tests/core/test_tracker_stores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 973e9bfa5c1a..ef68f3464869 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -551,7 +551,7 @@ def test_current_state_without_events(default_domain: Domain): BotUttered("Hi2"), ] - sender_id = "test_sql_tracker_store_retrieve_without_session_started_events" + sender_id = "test_mongo_tracker_store_current_state_without_events" tracker = DialogueStateTracker.from_events(sender_id, events) # get current state without events From 6917bf00bf0d1a2d789c600d43b179330b744c89 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 17:33:36 +0100 Subject: [PATCH 54/74] remove get_last_session_started_event --- rasa/core/domain.py | 2 +- rasa/core/processor.py | 26 -------------------------- rasa/core/trackers.py | 17 ----------------- tests/core/test_trackers.py | 24 ------------------------ 4 files changed, 1 insertion(+), 68 deletions(-) diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 9888d8266afe..9b2e43461276 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -154,7 +154,7 @@ def _get_session_config(additional_arguments: Dict) -> SessionConfig: "No tracker session configuration was found in the loaded domain. " "Domains without a session config will be deprecated in Rasa " "version 2.0.", - DeprecationWarning, + FutureWarning, ) session_length = 0 diff --git a/rasa/core/processor.py b/rasa/core/processor.py index e65286fd1b25..341c6cbd259b 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -615,32 +615,6 @@ def _log_action_on_tracker( e.timestamp = time.time() tracker.update(e, self.domain) - @staticmethod - def _session_start_timestamp_for(tracker: DialogueStateTracker) -> Optional[float]: - """Retrieve timestamp of the beginning of the last session start for - `tracker`. - - Args: - tracker: Tracker to inspect. - - Returns: - Timestamp of last `SessionStarted` event if available, else timestamp of - oldest event. Current time if no events are available. - - """ - if not tracker.events: - # this is a legacy tracker (pre-sessions), return current time - return time.time() - - last_session_started_event = tracker.get_last_session_started_event() - - if last_session_started_event: - return last_session_started_event.timestamp - - # otherwise fetch the timestamp of the first event - # this also is a legacy tracker (pre-sessions) - return tracker.events[0].timestamp - def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: """Determine whether the latest session in `tracker` has expired. diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index fe9eba3784fc..3ea4295e2ef3 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -508,23 +508,6 @@ def last_executed_action_has(self, name: Text, skip=0) -> bool: ) return last is not None and last.action_name == name - def get_last_session_started_event(self) -> Optional[SessionStarted]: - """Get the last `SessionStarted` event. - - Returns: - The last `SessionStarted` marking a session start if available, - otherwise `None`. - - """ - return next( - ( - event - for event in reversed(self.events) - if isinstance(event, SessionStarted) - ), - None, - ) - ### # Internal methods for the modification of the trackers state. Should # only be called by events, not directly. Rather update the tracker diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index 912b8b6863ff..1c66e59ac9e6 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -612,27 +612,3 @@ def test_tracker_without_slots(key, value, caplog): v = tracker.get_slot(key) assert v == value assert len(caplog.records) == 0 - - -@pytest.mark.parametrize( - "events,index_of_last_executed_event", - [ - ([ActionExecuted("one")], None), # no SessionStarted event - ([ActionExecuted("a"), SessionStarted()], 1), - ([ActionExecuted("first"), UserUttered("b"), SessionStarted()], 2), - ([SessionStarted(), UserUttered("b")], 0), - ], -) -def test_last_session_started_event( - events: List[Event], index_of_last_executed_event: int -): - tracker = get_tracker(events) - - # noinspection PyTypeChecker - expected_event: Optional[ActionExecuted] = events[ - index_of_last_executed_event - ] if index_of_last_executed_event is not None else None - - fetched_event = tracker.get_last_session_started_event() if expected_event else None - - assert expected_event == fetched_event From 79568dedff42d969f1433cf035a1d0f1fc5485ae Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 17:51:03 +0100 Subject: [PATCH 55/74] Apply suggestions from code review Co-Authored-By: Tobias Wochinger --- changelog/4830.feature.rst | 7 ++++--- rasa/core/domain.py | 1 - rasa/core/events/__init__.py | 2 -- rasa/core/processor.py | 2 +- rasa/core/tracker_store.py | 6 ++++-- rasa/core/training/structures.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst index 3affabb77f03..9454c074fba4 100644 --- a/changelog/4830.feature.rst +++ b/changelog/4830.feature.rst @@ -1,6 +1,6 @@ Added conversation sessions to trackers. A conversation session represents the -dialog between the agent and a user. Conversation sessions can begin in three ways: 1. -the user begins the conversation with the agent, 2. the user sends their first +dialog between the assistant and a user. Conversation sessions can begin in three ways: 1. +the user begins the conversation with the assistant, 2. the user sends their first message after a configurable period of inactivity, or 3. a manual session start is triggered with the ``/session_start`` intent message. The period of inactivity after which a new conversation session is triggered is defined in the domain using the @@ -12,5 +12,6 @@ conversation sessions comprises the following changes: - Added a new default action ``ActionSessionStart``. This action takes all ``SlotSet`` events from the previous session and applies it to the next session. - Added a new default intent ``session_start`` which triggers the start of a new - conversation session. ``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve + conversation session. + - ``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve events from the last session from the database. diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 9888d8266afe..9fc86820c9d0 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -145,7 +145,6 @@ def from_dict(cls, data: Dict) -> "Domain": @staticmethod def _get_session_config(additional_arguments: Dict) -> SessionConfig: - session_length = additional_arguments.pop(SESSION_LENGTH_KEY, None) # TODO: 2.0 reconsider how to apply sessions to old projects and legacy trackers diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 4dce140330ba..575799846d9f 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -7,11 +7,9 @@ import typing import uuid from dateutil import parser - from datetime import datetime from typing import List, Dict, Text, Any, Type, Optional - from rasa.core import utils from typing import Union diff --git a/rasa/core/processor.py b/rasa/core/processor.py index e65286fd1b25..6ce0595b4efb 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -671,7 +671,7 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: ) if has_expired: logger.debug( - f"The latest session for conversation ID {tracker.sender_id} has " + f"The latest session for conversation ID '{tracker.sender_id}' has " f"expired." ) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index e1a902b82871..58ca97fd8a94 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -505,9 +505,10 @@ def _additional_events(self, tracker: DialogueStateTracker) -> Iterator: tracker: Tracker to inspect. Returns: - List of serialised events that aren't current stored. + List of serialised events that aren't currently stored. """ + stored = self.conversations.find_one({"sender_id": tracker.sender_id}) n_events = len(stored.get("events", [])) if stored else 0 @@ -525,6 +526,7 @@ def _events_since_last_session_start(serialised_tracker: Dict) -> List[Dict]: event. Returns all events if no such event is found. """ + events = [] for event in reversed(serialised_tracker.get("events", [])): events.append(event) @@ -763,7 +765,7 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: from rasa.core.events import SessionStarted with self.session_scope() as session: - # Subquery to find the timestamp of the first `SessionStarted` event + # Subquery to find the timestamp of the latest `SessionStarted` event session_start_sub_query = ( session.query( sa.func.max(self.SQLEvent.timestamp).label("session_start") diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index abc829d259ed..9ed0156f98c0 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -203,7 +203,7 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: else: # form is active # it is not known whether the form will be - # successfully executed, so store this] + # successfully executed, so store this # story string for later self._store_user_strings(s, e2e, FORM_PREFIX) From d93d6760a3616cda54326d677f5e2b50b7598db4 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 17:51:13 +0100 Subject: [PATCH 56/74] update docstring --- rasa/core/processor.py | 3 +-- rasa/core/trackers.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 341c6cbd259b..d135a0937f65 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -148,8 +148,7 @@ async def _update_tracker_session( ) -> None: """Check the current session in `tracker` and update it if expired. - A 'session_start' is run if the tracker is a legacy tracker, or if the latest - tracker session has expired. + A 'session_start' the latest tracker session has expired. Args: tracker: Tracker to inspect. diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 3ea4295e2ef3..54fcec1dbe93 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -278,7 +278,6 @@ def init_copy(self) -> "DialogueStateTracker": UserMessage.DEFAULT_SENDER_ID, self.slots.values(), self._max_event_history ) - # TODO: exclude SessionStart from prior states def generate_all_prior_trackers( self, ) -> Generator["DialogueStateTracker", None, None]: From 205bf214700cd5f90c64eab258202489f6a4ecc3 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 18:16:27 +0100 Subject: [PATCH 57/74] add applied_events() test --- data/test_domains/duplicate_intents.yml | 4 ++++ rasa/core/actions/action.py | 3 +++ tests/core/test_domain.py | 11 ++--------- tests/core/test_trackers.py | 21 +++++++++++++++++++++ 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/data/test_domains/duplicate_intents.yml b/data/test_domains/duplicate_intents.yml index a60f4c935dba..0806a9547e3e 100644 --- a/data/test_domains/duplicate_intents.yml +++ b/data/test_domains/duplicate_intents.yml @@ -26,3 +26,7 @@ actions: - utter_default - utter_greet - utter_goodbye + +config: + session_length: 60 + carry_over_slots_to_new_session: true diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 578d984aea2d..3ce964038b93 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -338,6 +338,8 @@ async def run( class ActionSessionStart(Action): """Applies a conversation session start. + Takes all `SlotSet` events from the previous session and applies them to the new + session. """ def name(self) -> Text: @@ -352,6 +354,7 @@ def _slot_set_events_from_tracker( from rasa.core.events import SlotSet # use generator so the timestamps are greater than that of the returned + # SessionStarted event in the run() call return ( SlotSet(key=event.key, value=event.value, metadata=event.metadata) for event in tracker.events diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index 0fe235996966..48e8d9f2e93b 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -578,22 +578,15 @@ def test_add_knowledge_base_slots(default_domain): 0, True, ), - ("", 60, True), + ("", 0, True), ( """config: carry_over_slots_to_new_session: false""", - 60, + 0, False, ), ( """config: - carry_over_slots_to_new_session: false""", - 60, - False, - ), - ( - """ -config: session_length: 20.2 carry_over_slots_to_new_session: False""", 20.2, diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index 1c66e59ac9e6..2fa8c30cbf6c 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -510,6 +510,27 @@ def test_current_state_applied_events(default_agent): assert state.get("events") == applied_events +def test_session_started_not_part_of_applied_events(default_agent: Agent): + # take tracker dump and insert a SessionStarted event sequence + tracker_dump = "data/test_trackers/tracker_moodbot.json" + tracker_json = json.loads(rasa.utils.io.read_file(tracker_dump)) + tracker_json["events"].insert( + 4, {"event": ActionExecuted.type_name, "name": ACTION_SESSION_START_NAME} + ) + tracker_json["events"].insert(5, {"event": SessionStarted.type_name}) + + # initialise a tracker from this list of events + tracker = DialogueStateTracker.from_dict( + tracker_json.get("sender_id"), + tracker_json.get("events", []), + default_agent.domain.slots, + ) + + # the SessionStart event was at index 5, the tracker's `applied_events()` should + # be the same as the list of events from index 6 onwards + assert tracker.applied_events() == list(tracker.events)[6:] + + async def test_tracker_dump_e2e_story(default_agent): sender_id = "test_tracker_dump_e2e_story" From 5c359f90dc495c6fd0251728e8b56ab70e120f2b Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 9 Dec 2019 18:22:28 +0100 Subject: [PATCH 58/74] general improvements --- tests/core/test_domain.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index 48e8d9f2e93b..ab400d4170df 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -592,6 +592,7 @@ def test_add_knowledge_base_slots(default_domain): 20.2, False, ), + ("""config: {}""", 0, True,), ], ) def test_session_config( From 88e33b4a3646f75cb29785d4d6e798abb2f067e4 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 10:13:59 +0100 Subject: [PATCH 59/74] add a handle_message() test and update `create_tracker` docstring --- rasa/core/tracker_store.py | 22 +++++++++++- tests/core/test_processor.py | 65 ++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 86517d0478f1..484fe968c4e8 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -169,11 +169,31 @@ def create_tracker( append_action_listen: bool = True, should_append_session_started: bool = True, ) -> DialogueStateTracker: - """Creates a new tracker for the sender_id. The tracker is initially listening. + """Creates a new tracker for the sender_id. + + The tracker begins with a `SessionStarted` event and is initially listening. + + Args: + sender_id: Conversation ID associated with the tracker. + append_action_listen: Whether or not to append an initial `action_listen`. + should_append_session_started: Whether or not to append an initial + `session_started` event. If `True` this will be the first event of the + tracker. Note: every tracker should begin with a `session_started` + event. This kwarg is provided only for completeness and in analogy + to the existing `append_action_listen` kwarg. Internal Rasa calls + of this method should never set `should_append_session_started` to + `False`. + + Returns: + The newly created tracker for `sender_id`. + """ tracker = self.init_tracker(sender_id) + if tracker: if should_append_session_started: + # do not set this to `False`, unless required by an external tool + # that creates trackers tracker.update(SessionStarted()) if append_action_listen: diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 3d5206a96417..643bae447444 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -5,6 +5,7 @@ import pytest import time import uuid +import json from _pytest.monkeypatch import MonkeyPatch from aioresponses import aioresponses from typing import Optional, Text @@ -391,3 +392,67 @@ async def test_update_tracker_session_with_slots( # finally an action listen, this should also be the last event assert events[15] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME) + + +async def test_handle_message_with_session_start( + default_channel: CollectingOutputChannel, + default_processor: MessageProcessor, + monkeypatch: MonkeyPatch, +): + sender_id = uuid.uuid4().hex + + entity = "name" + slot_1 = {entity: "Core"} + await default_processor.handle_message( + UserMessage(f"/greet{json.dumps(slot_1)}", default_channel, sender_id) + ) + + assert { + "recipient_id": sender_id, + "text": "hey there Core!", + } == default_channel.latest_output() + + # patch processor so a session start is triggered + monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True) + + slot_2 = {entity: "post-session start hello"} + # handle a new message + await default_processor.handle_message( + UserMessage(f"/greet{json.dumps(slot_2)}", default_channel, sender_id) + ) + + tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) + + # make sure the sequence of events is as expected + assert list(tracker.events) == [ + SessionStarted(), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered( + f"/greet{json.dumps(slot_1)}", + {"name": "greet", "confidence": 1.0}, + [{"entity": entity, "start": 6, "end": 22, "value": "Core"}], + ), + SlotSet(entity, slot_1[entity]), + ActionExecuted("utter_greet"), + BotUttered("hey there Core!"), + ActionExecuted(ACTION_LISTEN_NAME), + ActionExecuted(ACTION_SESSION_START_NAME), + SessionStarted(), + # the initial SlotSet is reapplied after the SessionStarted sequence + SlotSet(entity, slot_1[entity]), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered( + f"/greet{json.dumps(slot_2)}", + {"name": "greet", "confidence": 1.0}, + [ + { + "entity": entity, + "start": 6, + "end": 42, + "value": "post-session start hello", + } + ], + ), + SlotSet(entity, slot_2[entity]), + ActionExecuted(ACTION_LISTEN_NAME), + ] From c32039fc3b941a4d386e15918b7b4d5ebf8d2387 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 10:15:28 +0100 Subject: [PATCH 60/74] fix typo in 'overridden' --- rasa/core/tracker_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 484fe968c4e8..97333d970cfc 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -204,11 +204,11 @@ def create_tracker( return tracker def save(self, tracker): - """Save method that will be overriden by specific tracker""" + """Save method that will be overridden by specific tracker""" raise NotImplementedError() def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: - """Retrieve method that will be overriden by specific tracker""" + """Retrieve method that will be overridden by specific tracker""" raise NotImplementedError() def stream_events(self, tracker: DialogueStateTracker) -> None: From 27e78f19ac18c8b8ef0f658339a58c056303a5b3 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 10:44:15 +0100 Subject: [PATCH 61/74] update docstring --- rasa/core/tracker_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 97333d970cfc..9cfa3daf1139 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -169,7 +169,7 @@ def create_tracker( append_action_listen: bool = True, should_append_session_started: bool = True, ) -> DialogueStateTracker: - """Creates a new tracker for the sender_id. + """Creates a new tracker for `sender_id`. The tracker begins with a `SessionStarted` event and is initially listening. From f3022aa5b124fff43820b6f868f80b067e02d4f0 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 11:52:41 +0100 Subject: [PATCH 62/74] pop events --- rasa/core/tracker_store.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 9cfa3daf1139..308c6be72757 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -500,10 +500,7 @@ def _current_tracker_state_without_events(tracker: DialogueStateTracker) -> Dict # get current tracker state and remove `events` key from state # since events are pushed separately in the `update_one()` operation state = tracker.current_state(EventVerbosity.ALL) - try: - del state["events"] - except KeyError: - pass + state.pop("events", None) return state From 6a22247042935fa703665c3e8bfc1209322d806a Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 18:36:24 +0100 Subject: [PATCH 63/74] review comments --- changelog/4830.feature.rst | 28 ++++++++++++-------- data/test_domains/duplicate_intents.yml | 2 +- docs/core/actions.rst | 10 ++++++++ rasa/cli/initial_project/domain.yml | 2 +- rasa/constants.py | 2 +- rasa/core/actions/action.py | 24 ++++++++--------- rasa/core/domain.py | 25 ++++++++++-------- rasa/core/processor.py | 23 ++++++++++++----- rasa/core/schemas/domain.yml | 2 +- rasa/core/tracker_store.py | 34 +++++++++++-------------- rasa/core/training/structures.py | 6 ++--- tests/core/conftest.py | 15 ++++++++++- tests/core/test_domain.py | 17 ++++++++----- tests/core/test_processor.py | 23 ++++++++--------- tests/core/test_tracker_stores.py | 27 +++++--------------- tests/core/test_trackers.py | 33 +++++++++++++++++------- 16 files changed, 156 insertions(+), 117 deletions(-) diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst index 9454c074fba4..75730ad15880 100644 --- a/changelog/4830.feature.rst +++ b/changelog/4830.feature.rst @@ -1,17 +1,25 @@ -Added conversation sessions to trackers. A conversation session represents the -dialog between the assistant and a user. Conversation sessions can begin in three ways: 1. -the user begins the conversation with the assistant, 2. the user sends their first -message after a configurable period of inactivity, or 3. a manual session start is -triggered with the ``/session_start`` intent message. The period of inactivity after -which a new conversation session is triggered is defined in the domain using the -``session_length`` key in the ``config`` section. The introduction of -conversation sessions comprises the following changes: +Added conversation sessions to trackers. + +A conversation session represents the dialog between the assistant and a user. +Conversation sessions can begin in three ways: 1. the user begins the conversation +with the assistant, 2. the user sends their first message after a configurable period +of inactivity, or 3. a manual session start is triggered with the ``/session_start`` +intent message. The period of inactivity after which a new conversation session is +triggered is defined in the domain using the ``session_expiration_time`` key in the +``config`` section. The introduction of conversation sessions comprises the following +changes: - Added a new event ``SessionStarted`` that marks the beginning of a new conversation session. - Added a new default action ``ActionSessionStart``. This action takes all ``SlotSet`` events from the previous session and applies it to the next session. - Added a new default intent ``session_start`` which triggers the start of a new - conversation session. - - ``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve + conversation session. +- ``SQLTrackerStore`` and ``MongoTrackerStore`` only retrieve events from the last session from the database. + + +.. note:: + + The session behaviour is disabled for existing projects, i.e. existing domains + without session config section. diff --git a/data/test_domains/duplicate_intents.yml b/data/test_domains/duplicate_intents.yml index 0806a9547e3e..e81a2cea9a61 100644 --- a/data/test_domains/duplicate_intents.yml +++ b/data/test_domains/duplicate_intents.yml @@ -28,5 +28,5 @@ actions: - utter_goodbye config: - session_length: 60 + session_expiration_time: 60 carry_over_slots_to_new_session: true diff --git a/docs/core/actions.rst b/docs/core/actions.rst index fb590f77e557..58819ef037f2 100644 --- a/docs/core/actions.rst +++ b/docs/core/actions.rst @@ -222,6 +222,16 @@ There are eight default actions: | | if the :ref:`mapping-policy` is included in | | | the policy configuration. | +-----------------------------------+------------------------------------------------+ +| ``action_session_start`` | Start a new conversation session. Take all set | +| | slots, mark the beginning of a new conversation| +| | session and re-apply the existing ``SlotSet`` | +| | events. This action is triggered automatically | +| | after an inactivity period defined by the | +| | ``session_expiration_time`` parameter in the | +| | domain's session config. Can be triggered | +| | manually during a conversation by entering | +| | ``/session_start``. | ++-----------------------------------+------------------------------------------------+ | ``action_default_fallback`` | Undo the last user message (as if the user did | | | not send it and the bot did not react) and | | | utter a message that the bot did not | diff --git a/rasa/cli/initial_project/domain.yml b/rasa/cli/initial_project/domain.yml index 95fa39485e2f..e9b8470150cf 100644 --- a/rasa/cli/initial_project/domain.yml +++ b/rasa/cli/initial_project/domain.yml @@ -36,5 +36,5 @@ templates: - text: "I am a bot, powered by Rasa." config: - session_length: 60 + session_expiration_time: 60 carry_over_slots_to_new_session: true diff --git a/rasa/constants.py b/rasa/constants.py index 122db69e6f28..2e598681b6b4 100644 --- a/rasa/constants.py +++ b/rasa/constants.py @@ -47,5 +47,5 @@ ENV_SANIC_WORKERS = "SANIC_WORKERS" ENV_SANIC_BACKLOG = "SANIC_BACKLOG" -DEFAULT_SESSION_LENGTH_IN_MINUTES = 60 +DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES = 60 DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION = True diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 3ce964038b93..5fef4c703587 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -348,18 +348,16 @@ def name(self) -> Text: @staticmethod def _slot_set_events_from_tracker( tracker: "DialogueStateTracker", - ) -> Generator["SlotSet", None, None]: + ) -> List["SlotSet"]: """Fetch SlotSet events from tracker and carry over key, value and metadata.""" from rasa.core.events import SlotSet - # use generator so the timestamps are greater than that of the returned - # SessionStarted event in the run() call - return ( + return [ SlotSet(key=event.key, value=event.value, metadata=event.metadata) - for event in tracker.events + for event in tracker.applied_events() if isinstance(event, SlotSet) - ) + ] async def run( self, @@ -370,16 +368,14 @@ async def run( ) -> List[Event]: from rasa.core.events import SessionStarted - slot_set_events = [] + _events = [SessionStarted()] + if domain.session_config.carry_over_slots: - slot_set_events = self._slot_set_events_from_tracker(tracker) + _events.extend(self._slot_set_events_from_tracker(tracker)) - # noinspection PyTypeChecker - return ( - [SessionStarted()] - + list(slot_set_events) - + [ActionExecuted(ACTION_LISTEN_NAME)] - ) + _events.append(ActionExecuted(ACTION_LISTEN_NAME)) + + return _events class ActionDefaultFallback(ActionUtterTemplate): diff --git a/rasa/core/domain.py b/rasa/core/domain.py index ac5103cadde8..416964996ddc 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -13,7 +13,7 @@ from rasa.cli.utils import bcolors from rasa.constants import ( DOMAIN_SCHEMA_FILE, - DEFAULT_SESSION_LENGTH_IN_MINUTES, + DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION, ) from rasa.core import utils @@ -37,7 +37,7 @@ ACTIVE_FORM_PREFIX = "active_form_" CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session" -SESSION_LENGTH_KEY = "session_length" +SESSION_EXPIRATION_TIME_KEY = "session_expiration_time" if typing.TYPE_CHECKING: from rasa.core.trackers import DialogueStateTracker @@ -55,7 +55,7 @@ def __str__(self): class SessionConfig(NamedTuple): - session_length: float + session_expiration_time: float # in minutes carry_over_slots: bool @staticmethod @@ -64,7 +64,7 @@ def default() -> "SessionConfig": return SessionConfig(0, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION) def are_sessions_enabled(self) -> bool: - return self.session_length > 0 + return self.session_expiration_time > 0 class Domain: @@ -145,23 +145,26 @@ def from_dict(cls, data: Dict) -> "Domain": @staticmethod def _get_session_config(additional_arguments: Dict) -> SessionConfig: - session_length = additional_arguments.pop(SESSION_LENGTH_KEY, None) + session_expiration_time = additional_arguments.pop( + SESSION_EXPIRATION_TIME_KEY, None + ) # TODO: 2.0 reconsider how to apply sessions to old projects and legacy trackers - if session_length is None: + if session_expiration_time is None: warnings.warn( "No tracker session configuration was found in the loaded domain. " - "Domains without a session config will be deprecated in Rasa " - "version 2.0.", + "Domains without a session config will automatically receive a " + "session expiration time of 60 minutes in Rasa version 2.0 if not " + "configured otherwise.", FutureWarning, ) - session_length = 0 + session_expiration_time = 0 carry_over_slots = additional_arguments.pop( CARRY_OVER_SLOTS_KEY, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION ) - return SessionConfig(session_length, carry_over_slots) + return SessionConfig(session_expiration_time, carry_over_slots) @classmethod def from_directory(cls, path: Text) -> "Domain": @@ -697,7 +700,7 @@ def _slot_definitions(self) -> Dict[Any, Dict[str, Any]]: def as_dict(self) -> Dict[Text, Any]: additional_config = { "store_entities_as_slots": self.store_entities_as_slots, - SESSION_LENGTH_KEY: self.session_config.session_length, + SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time, CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots, } diff --git a/rasa/core/processor.py b/rasa/core/processor.py index b99c2f4bccc9..acdbb4ce1206 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -143,20 +143,28 @@ def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: "tracker": tracker.current_state(EventVerbosity.AFTER_RESTART), } + @staticmethod + def _contains_no_user_message(tracker: DialogueStateTracker) -> bool: + """Determine `tracker` does not yet contain any user messages.""" + + return tracker.get_last_event_for(UserUttered) is None + async def _update_tracker_session( self, tracker: DialogueStateTracker, output_channel: OutputChannel ) -> None: """Check the current session in `tracker` and update it if expired. - A 'session_start' the latest tracker session has expired. + An 'action_session_start' is run if the latest tracker session has expired, or + if the tracker has not yet received any user messages. Args: tracker: Tracker to inspect. output_channel: Output channel for potential utterances in a custom `ActionSessionStart`. - """ - if self._has_session_expired(tracker): + if self._contains_no_user_message(tracker) or self._has_session_expired( + tracker + ): logger.debug( f"Starting a new session for conversation ID '{tracker.sender_id}'." ) @@ -175,7 +183,6 @@ async def log_message( Optionally save the tracker if `should_save_tracker` is `True`. Tracker saving can be skipped if the tracker returned by this method is used for further processing and saved at a later stage. - """ # preprocess message if necessary @@ -622,7 +629,6 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: Returns: `True` if the session in `tracker` has expired, `False` otherwise. - """ if not self.domain.session_config.are_sessions_enabled(): @@ -640,7 +646,8 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: time_delta_in_seconds = time.time() - user_uttered_event.timestamp has_expired = ( - time_delta_in_seconds / 60 > self.domain.session_config.session_length + time_delta_in_seconds / 60 + > self.domain.session_config.session_expiration_time ) if has_expired: logger.debug( @@ -652,7 +659,9 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool: def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]: sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID - return self.tracker_store.get_or_create_tracker(sender_id) + return self.tracker_store.get_or_create_tracker( + sender_id, append_action_listen=False + ) def _save_tracker(self, tracker: DialogueStateTracker) -> None: self.tracker_store.save(tracker) diff --git a/rasa/core/schemas/domain.yml b/rasa/core/schemas/domain.yml index 8f4651c92aa7..36405ec88ff5 100644 --- a/rasa/core/schemas/domain.yml +++ b/rasa/core/schemas/domain.yml @@ -37,7 +37,7 @@ mapping: mapping: store_entities_as_slots: type: "bool" - session_length: + session_expiration_time: type: "number" range: min: 0 diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 308c6be72757..d5fb2c5630db 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -146,13 +146,24 @@ def load_tracker_from_module_string( return InMemoryTrackerStore(domain) def get_or_create_tracker( - self, sender_id: Text, max_event_history: Optional[int] = None, + self, + sender_id: Text, + max_event_history: Optional[int] = None, + append_action_listen: bool = True, ) -> "DialogueStateTracker": - """Returns tracker or creates one if the retrieval returns None""" + """Returns tracker or creates one if the retrieval returns None. + + Args: + sender_id: pass + max_event_history: pass + append_action_listen: Whether or not to append an initial `action_listen`. + """ tracker = self.retrieve(sender_id) self.max_event_history = max_event_history if tracker is None: - tracker = self.create_tracker(sender_id) + tracker = self.create_tracker( + sender_id, append_action_listen=append_action_listen, + ) return tracker def init_tracker(self, sender_id: Text) -> "DialogueStateTracker": @@ -164,10 +175,7 @@ def init_tracker(self, sender_id: Text) -> "DialogueStateTracker": ) def create_tracker( - self, - sender_id: Text, - append_action_listen: bool = True, - should_append_session_started: bool = True, + self, sender_id: Text, append_action_listen: bool = True, ) -> DialogueStateTracker: """Creates a new tracker for `sender_id`. @@ -176,13 +184,6 @@ def create_tracker( Args: sender_id: Conversation ID associated with the tracker. append_action_listen: Whether or not to append an initial `action_listen`. - should_append_session_started: Whether or not to append an initial - `session_started` event. If `True` this will be the first event of the - tracker. Note: every tracker should begin with a `session_started` - event. This kwarg is provided only for completeness and in analogy - to the existing `append_action_listen` kwarg. Internal Rasa calls - of this method should never set `should_append_session_started` to - `False`. Returns: The newly created tracker for `sender_id`. @@ -191,11 +192,6 @@ def create_tracker( tracker = self.init_tracker(sender_id) if tracker: - if should_append_session_started: - # do not set this to `False`, unless required by an external tool - # that creates trackers - tracker.update(SessionStarted()) - if append_action_listen: tracker.update(ActionExecuted(ACTION_LISTEN_NAME)) diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index 9ed0156f98c0..c9cee4f3f82f 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -339,13 +339,13 @@ def _add_action_listen(self, events: List[ActionExecuted]) -> None: def explicit_events( self, domain: Domain, should_append_final_listen: bool = True ) -> List[Event]: - """Returns events contained in the story step - including implicit events. + """Returns events contained in the story step including implicit events. Not all events are always listed in the story dsl. This includes listen actions as well as implicitly set slots. This functions makes these events explicit and - returns them with the rest of the steps events.""" + returns them with the rest of the steps events. + """ events = [] diff --git a/tests/core/conftest.py b/tests/core/conftest.py index a6d65969248c..3e60254abfd8 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -20,7 +20,7 @@ ) from rasa.core.processor import MessageProcessor from rasa.core.slots import Slot -from rasa.core.tracker_store import InMemoryTrackerStore +from rasa.core.tracker_store import InMemoryTrackerStore, MongoTrackerStore from rasa.core.trackers import DialogueStateTracker @@ -71,6 +71,19 @@ def __init__(self, example_arg): pass +class MockedMongoTrackerStore(MongoTrackerStore): + """In-memory mocked version of `MongoTrackerStore`.""" + + def __init__( + self, _domain: Domain, + ): + from mongomock import MongoClient + + self.db = MongoClient().rasa + self.collection = "conversations" + super(MongoTrackerStore, self).__init__(_domain, None) + + @pytest.fixture(scope="session") def loop(): loop = asyncio.new_event_loop() diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index ab400d4170df..ef871aec11d6 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -231,7 +231,7 @@ def test_domain_to_yaml(): - utter_greet config: carry_over_slots_to_new_session: true - session_length: 60 + session_expiration_time: 60 store_entities_as_slots: true entities: [] forms: [] @@ -569,11 +569,11 @@ def test_add_knowledge_base_slots(default_domain): @pytest.mark.parametrize( - "input_domain, expected_session_length, expected_carry_over_slots", + "input_domain, expected_session_expiration_time, expected_carry_over_slots", [ ( """config: - session_length: 0 + session_expiration_time: 0 carry_over_slots_to_new_session: true""", 0, True, @@ -587,7 +587,7 @@ def test_add_knowledge_base_slots(default_domain): ), ( """config: - session_length: 20.2 + session_expiration_time: 20.2 carry_over_slots_to_new_session: False""", 20.2, False, @@ -596,10 +596,15 @@ def test_add_knowledge_base_slots(default_domain): ], ) def test_session_config( - input_domain, expected_session_length: float, expected_carry_over_slots: bool + input_domain, + expected_session_expiration_time: float, + expected_carry_over_slots: bool, ): domain = Domain.from_yaml(input_domain) - assert domain.session_config.session_length == expected_session_length + assert ( + domain.session_config.session_expiration_time + == expected_session_expiration_time + ) assert domain.session_config.carry_over_slots == expected_carry_over_slots diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 643bae447444..89f2fd6d6e98 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -203,7 +203,7 @@ async def test_reminder_aborted( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) - assert len(t.events) == 4 # nothing should have been executed + assert len(t.events) == 3 # nothing should have been executed async def test_reminder_cancelled( @@ -272,11 +272,11 @@ async def test_reminder_restart( # retrieve the updated tracker t = default_processor.tracker_store.retrieve(sender_id) - assert len(t.events) == 5 # nothing should have been executed + assert len(t.events) == 4 # nothing should have been executed @pytest.mark.parametrize( - "event_to_apply,session_length_in_minutes,has_expired", + "event_to_apply,session_expiration_time_in_minutes,has_expired", [ # last user event is way in the past (UserUttered(timestamp=1), 60, True), @@ -292,14 +292,14 @@ async def test_reminder_restart( ) async def test_has_session_expired( event_to_apply: Optional[Event], - session_length_in_minutes: float, + session_expiration_time_in_minutes: float, has_expired: bool, default_processor: MessageProcessor, ): sender_id = uuid.uuid4().hex default_processor.domain.session_config = SessionConfig( - session_length_in_minutes, True + session_expiration_time_in_minutes, True ) # create new tracker without events tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) @@ -335,7 +335,6 @@ async def test_update_tracker_session( tracker = default_processor.tracker_store.retrieve(sender_id) assert list(tracker.events) == [ - SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME), ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted(), @@ -375,23 +374,22 @@ async def test_update_tracker_session_with_slots( events = list(tracker.events) # the first three events should be up to the user utterance - assert events[:3] == [ - SessionStarted(), + assert events[:2] == [ ActionExecuted(ACTION_LISTEN_NAME), user_event, ] # next come the five slots - assert events[3:8] == slot_set_events + assert events[2:7] == slot_set_events # the next two events are the session start sequence - assert events[8:10] == [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()] + assert events[7:9] == [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()] # the five slots should be reapplied - assert events[10:15] == slot_set_events + assert events[9:14] == slot_set_events # finally an action listen, this should also be the last event - assert events[15] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME) + assert events[14] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME) async def test_handle_message_with_session_start( @@ -425,6 +423,7 @@ async def test_handle_message_with_session_start( # make sure the sequence of events is as expected assert list(tracker.events) == [ + ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME), UserUttered( diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index ef68f3464869..5f966cfe66c8 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -1,15 +1,15 @@ import logging import tempfile -import uuid - -from typing import Tuple, Text, Type, Dict, List -from unittest.mock import Mock import pytest +import uuid from _pytest.logging import LogCaptureFixture from _pytest.monkeypatch import MonkeyPatch from moto import mock_dynamodb2 +from typing import Tuple, Text, Type, Dict, List +from unittest.mock import Mock +import rasa.core.tracker_store from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.channels.channel import UserMessage from rasa.core.domain import Domain @@ -29,29 +29,14 @@ SQLTrackerStore, DynamoTrackerStore, FailSafeTrackerStore, - MongoTrackerStore, ) -import rasa.core.tracker_store -from rasa.core.trackers import DialogueStateTracker, EventVerbosity +from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import EndpointConfig, read_endpoint_config -from tests.core.conftest import DEFAULT_ENDPOINTS_FILE +from tests.core.conftest import DEFAULT_ENDPOINTS_FILE, MockedMongoTrackerStore domain = Domain.load("data/test_domains/default.yml") -class MockedMongoTrackerStore(MongoTrackerStore): - """In-memory mocked version of `MongoTrackerStore`.""" - - def __init__( - self, _domain: Domain, - ): - from mongomock import MongoClient - - self.db = MongoClient().rasa - self.collection = "conversations" - super(MongoTrackerStore, self).__init__(domain, None) - - def get_or_create_tracker_store(store: TrackerStore): slot_key = "location" slot_val = "Easter Island" diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index 2fa8c30cbf6c..bedf07e50dfb 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -2,7 +2,6 @@ import logging import os import tempfile -from typing import List, Optional import fakeredis import pytest @@ -20,7 +19,6 @@ ActionReverted, UserUtteranceReverted, SessionStarted, - Event, ) from rasa.core.tracker_store import ( InMemoryTrackerStore, @@ -29,7 +27,12 @@ ) from rasa.core.tracker_store import TrackerStore from rasa.core.trackers import DialogueStateTracker, EventVerbosity -from tests.core.conftest import DEFAULT_STORIES_FILE, EXAMPLE_DOMAINS, TEST_DIALOGUES +from tests.core.conftest import ( + DEFAULT_STORIES_FILE, + EXAMPLE_DOMAINS, + TEST_DIALOGUES, + MockedMongoTrackerStore, +) from tests.core.utilities import ( tracker_from_dialogue_file, read_dialogue_file, @@ -57,11 +60,12 @@ def stores_to_be_tested(): MockRedisTrackerStore(domain), InMemoryTrackerStore(domain), SQLTrackerStore(domain, db=os.path.join(temp, "rasa.db")), + MockedMongoTrackerStore(domain), ] def stores_to_be_tested_ids(): - return ["redis-tracker", "in-memory-tracker", "SQL-tracker"] + return ["redis-tracker", "in-memory-tracker", "SQL-tracker", "mongo-tracker"] def test_tracker_duplicate(): @@ -87,7 +91,6 @@ def test_tracker_store_storage_and_retrieval(store): # Action listen should be in there assert list(tracker.events) == [ - SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME), ] @@ -100,13 +103,13 @@ def test_tracker_store_storage_and_retrieval(store): # retrieving the same tracker should result in the same tracker retrieved_tracker = store.get_or_create_tracker("some-id") assert retrieved_tracker.sender_id == "some-id" - assert len(retrieved_tracker.events) == 3 + assert len(retrieved_tracker.events) == 2 assert retrieved_tracker.latest_message.intent.get("name") == "greet" # getting another tracker should result in an empty tracker again other_tracker = store.get_or_create_tracker("some-other-id") assert other_tracker.sender_id == "some-other-id" - assert len(other_tracker.events) == 2 + assert len(other_tracker.events) == 1 @pytest.mark.parametrize("store", stores_to_be_tested(), ids=stores_to_be_tested_ids()) @@ -148,7 +151,10 @@ async def test_tracker_state_regression_without_bot_utterance(default_agent: Age # Ensures that the tracker has changed between the utterances # (and wasn't reset in between them) - expected = "action_listen;greet;utter_greet;action_listen;greet;action_listen" + expected = ( + "action_session_start;action_listen;greet;utter_greet;action_listen;" + "greet;action_listen" + ) assert ( ";".join([e.as_story_string() for e in tracker.events if e.as_story_string()]) == expected @@ -162,6 +168,7 @@ async def test_tracker_state_regression_with_bot_utterance(default_agent: Agent) tracker = default_agent.tracker_store.get_or_create_tracker(sender_id) expected = [ + "action_session_start", None, "action_listen", "greet", @@ -184,7 +191,15 @@ async def test_bot_utterance_comes_after_action_event(default_agent): # important is, that the 'bot' comes after the second 'action' and not # before - expected = ["session_started", "action", "user", "action", "bot", "action"] + expected = [ + "action", + "session_started", + "action", + "user", + "action", + "bot", + "action", + ] assert [e.type_name for e in tracker.events] == expected From ebc9edf74fce5d73f5026a71ee53cdfb27441455 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 18:55:40 +0100 Subject: [PATCH 64/74] run session start action at the beginning of every tracker --- rasa/core/processor.py | 37 +++++++++++++++++++++++++----------- rasa/core/trackers.py | 5 +++++ rasa/server.py | 21 +++++++++++++------- tests/core/test_processor.py | 17 +++++++++++++++++ tests/core/test_trackers.py | 27 ++++++++++++++++++++++++++ 5 files changed, 89 insertions(+), 18 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index acdbb4ce1206..189c3af05d95 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -143,12 +143,6 @@ def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: "tracker": tracker.current_state(EventVerbosity.AFTER_RESTART), } - @staticmethod - def _contains_no_user_message(tracker: DialogueStateTracker) -> bool: - """Determine `tracker` does not yet contain any user messages.""" - - return tracker.get_last_event_for(UserUttered) is None - async def _update_tracker_session( self, tracker: DialogueStateTracker, output_channel: OutputChannel ) -> None: @@ -162,9 +156,7 @@ async def _update_tracker_session( output_channel: Output channel for potential utterances in a custom `ActionSessionStart`. """ - if self._contains_no_user_message(tracker) or self._has_session_expired( - tracker - ): + if tracker.contains_no_user_message() or self._has_session_expired(tracker): logger.debug( f"Starting a new session for conversation ID '{tracker.sender_id}'." ) @@ -175,6 +167,28 @@ async def _update_tracker_session( nlg=self.nlg, ) + async def get_tracker_with_session_start( + self, sender_id: Text, output_channel: Optional[OutputChannel] = None, + ) -> Optional[DialogueStateTracker]: + """Get tracker for `sender_id` or create a new tracker for `sender_id`. + + If a new tracker is created, `action_session_start` is run. + + Args: + output_channel: Output channel associated with the incoming user message. + sender_id: Conversation ID for which to fetch the tracker. + Returns: + Tracker for `sender_id` if available, `None` otherwise. + """ + + tracker = self._get_tracker(sender_id) + if not tracker: + return None + + await self._update_tracker_session(tracker, output_channel) + + return tracker + async def log_message( self, message: UserMessage, should_save_tracker: bool = True ) -> Optional[DialogueStateTracker]: @@ -190,10 +204,11 @@ async def log_message( message.text = self.message_preprocessor(message.text) # we have a Tracker instance for each user # which maintains conversation state - tracker = self._get_tracker(message.sender_id) + tracker = await self.get_tracker_with_session_start( + message.sender_id, message.output_channel + ) if tracker: - await self._update_tracker_session(tracker, message.output_channel) await self._handle_message_with_tracker(message, tracker) if should_save_tracker: diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 54fcec1dbe93..d83cb9a5e2a3 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -491,6 +491,11 @@ def filter_function(e: Event): return next(filtered, None) + def contains_no_user_message(self) -> bool: + """Determine whether tracker does not yet contain any user messages.""" + + return self.get_last_event_for(UserUttered) is None + def last_executed_action_has(self, name: Text, skip=0) -> bool: """Returns whether last `ActionExecuted` event had a specific name. diff --git a/rasa/server.py b/rasa/server.py index 072b8e0c6b05..c8cc6499793d 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -46,6 +46,7 @@ if typing.TYPE_CHECKING: from ssl import SSLContext + from rasa.core.processor import MessageProcessor logger = logging.getLogger(__name__) @@ -208,8 +209,10 @@ def event_verbosity_parameter( ) -def get_tracker(agent: "Agent", conversation_id: Text) -> DialogueStateTracker: - tracker = agent.tracker_store.get_or_create_tracker(conversation_id) +async def get_tracker( + processor: "MessageProcessor", conversation_id: Text +) -> Optional[DialogueStateTracker]: + tracker = await processor.get_tracker_with_session_start(conversation_id) if not tracker: raise ErrorResponse( 409, @@ -440,7 +443,7 @@ async def retrieve_tracker(request: Request, conversation_id: Text): verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) until_time = rasa.utils.endpoints.float_arg(request, "until") - tracker = get_tracker(app.agent, conversation_id) + tracker = await get_tracker(app.agent.create_processor(), conversation_id) try: if until_time is not None: @@ -488,7 +491,9 @@ async def append_events(request: Request, conversation_id: Text): try: async with app.agent.lock_store.lock(conversation_id): - tracker = get_tracker(app.agent, conversation_id) + tracker = await get_tracker( + app.agent.create_processor(), conversation_id + ) for event in events: tracker.update(event, app.agent.domain) app.agent.tracker_store.save(tracker) @@ -536,7 +541,7 @@ async def retrieve_story(request: Request, conversation_id: Text): """Get an end-to-end story corresponding to this conversation.""" # retrieve tracker and set to requested state - tracker = get_tracker(app.agent, conversation_id) + tracker = await get_tracker(app.agent.create_processor(), conversation_id) until_time = rasa.utils.endpoints.float_arg(request, "until") @@ -575,7 +580,9 @@ async def execute_action(request: Request, conversation_id: Text): try: async with app.agent.lock_store.lock(conversation_id): - tracker = get_tracker(app.agent, conversation_id) + tracker = await get_tracker( + app.agent.create_processor(), conversation_id + ) output_channel = _get_output_channel(request, tracker) await app.agent.execute_action( conversation_id, @@ -591,7 +598,7 @@ async def execute_action(request: Request, conversation_id: Text): 500, "ConversationError", f"An unexpected error occurred. Error: {e}" ) - tracker = get_tracker(app.agent, conversation_id) + tracker = await get_tracker(app.agent.create_processor(), conversation_id) state = tracker.current_state(verbosity) response_body = {"tracker": state} diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 89f2fd6d6e98..fa73c0478c4d 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -392,6 +392,23 @@ async def test_update_tracker_session_with_slots( assert events[14] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME) +# noinspection PyProtectedMember +async def test_get_tracker_with_session_start( + default_channel: CollectingOutputChannel, default_processor: MessageProcessor, +): + sender_id = uuid.uuid4().hex + tracker = await default_processor.get_tracker_with_session_start( + sender_id, default_channel + ) + + # ensure session start sequence is present + assert list(tracker.events) == [ + ActionExecuted(ACTION_SESSION_START_NAME), + SessionStarted(), + ActionExecuted(ACTION_LISTEN_NAME), + ] + + async def test_handle_message_with_session_start( default_channel: CollectingOutputChannel, default_processor: MessageProcessor, diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index bedf07e50dfb..5012d5bbfc5d 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -2,6 +2,7 @@ import logging import os import tempfile +from typing import List import fakeredis import pytest @@ -19,6 +20,7 @@ ActionReverted, UserUtteranceReverted, SessionStarted, + Event, ) from rasa.core.tracker_store import ( InMemoryTrackerStore, @@ -648,3 +650,28 @@ def test_tracker_without_slots(key, value, caplog): v = tracker.get_slot(key) assert v == value assert len(caplog.records) == 0 + + +@pytest.mark.parametrize( + "events, contains_no_user_message", + [ + ( + [ + ActionExecuted("one"), + UserUttered("two", 1), + ActionExecuted(ACTION_LISTEN_NAME), + ], + False, + ), + ([], True), + ([ActionExecuted("one")], True), + ([UserUttered("two", 1)], False), + ], +) +def test_tracker_contains_no_user_message( + events: List[Event], contains_no_user_message: bool +): + tracker = DialogueStateTracker.from_dict( + "any", [event.as_dict() for event in events] + ) + assert tracker.contains_no_user_message() == contains_no_user_message From 450e652eed1b0b44d96795fc4408675e0267acea Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 22:33:27 +0100 Subject: [PATCH 65/74] fix tests --- rasa/core/processor.py | 12 ++++++++---- rasa/core/trackers.py | 5 ----- tests/core/test_agent.py | 2 +- tests/core/test_trackers.py | 25 ------------------------- tests/test_server.py | 18 +++++++++++++----- 5 files changed, 22 insertions(+), 40 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 189c3af05d95..0d18674b5f80 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -148,15 +148,16 @@ async def _update_tracker_session( ) -> None: """Check the current session in `tracker` and update it if expired. - An 'action_session_start' is run if the latest tracker session has expired, or - if the tracker has not yet received any user messages. + An 'action_session_start' is run if the latest tracker session has expired, + or if the tracker does not yet contain any events (only those after the last + restart are considered). Args: tracker: Tracker to inspect. output_channel: Output channel for potential utterances in a custom `ActionSessionStart`. """ - if tracker.contains_no_user_message() or self._has_session_expired(tracker): + if len(tracker.applied_events()) == 0 or self._has_session_expired(tracker): logger.debug( f"Starting a new session for conversation ID '{tracker.sender_id}'." ) @@ -583,7 +584,10 @@ async def _run_action( def _warn_about_new_slots(self, tracker, action_name, events) -> None: # these are the events from that action we have seen during training - if action_name not in self.policy_ensemble.action_fingerprints: + if ( + not self.policy_ensemble + or action_name not in self.policy_ensemble.action_fingerprints + ): return fp = self.policy_ensemble.action_fingerprints[action_name] diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index d83cb9a5e2a3..54fcec1dbe93 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -491,11 +491,6 @@ def filter_function(e: Event): return next(filtered, None) - def contains_no_user_message(self) -> bool: - """Determine whether tracker does not yet contain any user messages.""" - - return self.get_last_event_for(UserUttered) is None - def last_executed_action_has(self, name: Text, skip=0) -> bool: """Returns whether last `ActionExecuted` event had a specific name. diff --git a/tests/core/test_agent.py b/tests/core/test_agent.py index 6def31338e3c..d5e836edeb0b 100644 --- a/tests/core/test_agent.py +++ b/tests/core/test_agent.py @@ -263,7 +263,7 @@ async def test_agent_update_model_none_domain(trained_model: Text): tracker = agent.tracker_store.get_or_create_tracker(sender_id) # UserUttered event was added to tracker, with correct intent data - assert tracker.events[2].intent["name"] == "greet" + assert tracker.events[3].intent["name"] == "greet" async def test_load_agent_on_not_existing_path(): diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index 5012d5bbfc5d..c29c3bf27802 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -650,28 +650,3 @@ def test_tracker_without_slots(key, value, caplog): v = tracker.get_slot(key) assert v == value assert len(caplog.records) == 0 - - -@pytest.mark.parametrize( - "events, contains_no_user_message", - [ - ( - [ - ActionExecuted("one"), - UserUttered("two", 1), - ActionExecuted(ACTION_LISTEN_NAME), - ], - False, - ), - ([], True), - ([ActionExecuted("one")], True), - ([UserUttered("two", 1)], False), - ], -) -def test_tracker_contains_no_user_message( - events: List[Event], contains_no_user_message: bool -): - tracker = DialogueStateTracker.from_dict( - "any", [event.as_dict() for event in events] - ) - assert tracker.contains_no_user_message() == contains_no_user_message diff --git a/tests/test_server.py b/tests/test_server.py index 1092b5faabda..b319231530fa 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -536,7 +536,14 @@ def test_requesting_non_existent_tracker(rasa_app: SanicTestClient): assert content["slots"] == {"location": None, "cuisine": None} assert content["sender_id"] == "madeupid" assert content["events"] == [ - {"event": "session_started", "timestamp": 1514764800,}, + { + "event": "action", + "name": "action_session_start", + "policy": None, + "confidence": None, + "timestamp": 1514764800, + }, + {"event": "session_started", "timestamp": 1514764800}, { "event": "action", "name": "action_listen", @@ -570,9 +577,10 @@ def test_pushing_event(rasa_app, event): _, tracker_response = rasa_app.get(f"/conversations/{cid}/tracker") tracker = tracker_response.json assert tracker is not None - assert len(tracker.get("events")) == 3 - evt = tracker.get("events")[2] + assert len(tracker.get("events")) == 4 + + evt = tracker.get("events")[3] assert Event.from_parameters(evt) == event @@ -594,8 +602,8 @@ def test_push_multiple_events(rasa_app: SanicTestClient): assert tracker is not None # there is also an `ACTION_LISTEN` event at the start - assert len(tracker.get("events")) == len(test_events) + 2 - assert tracker.get("events")[2:] == events + assert len(tracker.get("events")) == len(test_events) + 3 + assert tracker.get("events")[3:] == events def test_put_tracker(rasa_app: SanicTestClient): From 770e0a5757f157fb93c22fcadc995042c5d85e2d Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 22:40:20 +0100 Subject: [PATCH 66/74] update f-strings --- rasa/core/processor.py | 29 ++++++++++++----------------- rasa/server.py | 4 ++-- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 0d18674b5f80..e2b9e83e010a 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -178,6 +178,7 @@ async def get_tracker_with_session_start( Args: output_channel: Output channel associated with the incoming user message. sender_id: Conversation ID for which to fetch the tracker. + Returns: Tracker for `sender_id` if available, `None` otherwise. """ @@ -217,7 +218,7 @@ async def log_message( self._save_tracker(tracker) else: logger.warning( - "Failed to retrieve or create tracker for conversation ID " + f"Failed to retrieve or create tracker for conversation ID " f"'{message.sender_id}'." ) return tracker @@ -265,9 +266,8 @@ def predict_next_action( max_confidence_index, self.action_endpoint ) logger.debug( - "Predicted next action '{}' with confidence {:.2f}.".format( - action.name(), action_confidences[max_confidence_index] - ) + f"Predicted next action '{action.name()}' with confidence " + f"{action_confidences[max_confidence_index]:.2f}." ) return action, policy, action_confidences[max_confidence_index] @@ -322,10 +322,8 @@ async def handle_reminder( or not self._is_reminder_still_valid(tracker, reminder_event) ): logger.debug( - "Canceled reminder because it is outdated. " - "(event: {} id: {})".format( - reminder_event.action_name, reminder_event.name - ) + f"Canceled reminder because it is outdated. " + f"(event: {reminder_event.action_name} id: {reminder_event.name})" ) else: # necessary for proper featurization, otherwise the previous @@ -431,8 +429,7 @@ async def _handle_message_with_tracker( self._log_slots(tracker) logger.debug( - "Logged UserUtterance - " - "tracker now has {} events".format(len(tracker.events)) + f"Logged UserUtterance - " f"tracker now has {len(tracker.events)} events." ) @staticmethod @@ -561,10 +558,10 @@ async def _run_action( return self.should_predict_another_action(action.name(), events) except Exception as e: logger.error( - "Encountered an exception while running action '{}'. " + f"Encountered an exception while running action '{action.name()}'. " "Bot will continue, but the actions events are lost. " "Please check the logs of your action server for " - "more information.".format(action.name()) + "more information." ) logger.debug(e, exc_info=True) events = [] @@ -621,9 +618,7 @@ def _log_action_on_tracker( events = [] logger.debug( - "Action '{}' ended with events '{}'".format( - action_name, [f"{e}" for e in events] - ) + f"Action '{action_name}' ended with events '{[e for e in events]}'." ) self._warn_about_new_slots(tracker, action_name, events) @@ -709,9 +704,9 @@ def _get_next_action_probabilities( return result else: logger.error( - "Trying to run unknown follow up action '{}'!" + f"Trying to run unknown follow-up action '{followup_action}'!" "Instead of running that, we will ignore the action " - "and predict the next action.".format(followup_action) + "and predict the next action." ) return self.policy_ensemble.probabilities_using_best_policy( diff --git a/rasa/server.py b/rasa/server.py index c8cc6499793d..03e2c0d68ae2 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -217,8 +217,8 @@ async def get_tracker( raise ErrorResponse( 409, "Conflict", - "Could not retrieve tracker with id '{}'. Most likely " - "because there is no domain set on the agent.".format(conversation_id), + f"Could not retrieve tracker with id '{conversation_id}'. Most likely " + f"because there is no domain set on the agent.", ) return tracker From 8a65d811ecc0f85d3ff061040747f6d83e8a83d0 Mon Sep 17 00:00:00 2001 From: ricwo Date: Tue, 10 Dec 2019 22:55:36 +0100 Subject: [PATCH 67/74] update docstring --- rasa/core/tracker_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index d5fb2c5630db..071ec1a781a4 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -154,8 +154,8 @@ def get_or_create_tracker( """Returns tracker or creates one if the retrieval returns None. Args: - sender_id: pass - max_event_history: pass + sender_id: Conversation ID associated with the requested tracker. + max_event_history: Value to update the tracker store's max event history to. append_action_listen: Whether or not to append an initial `action_listen`. """ tracker = self.retrieve(sender_id) From 56259aa3c5a6d96cac4731300eed7f284ba1943a Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 11 Dec 2019 10:30:45 +0100 Subject: [PATCH 68/74] handle session_start intent message properly --- rasa/core/policies/mapping_policy.py | 2 +- rasa/core/processor.py | 23 +++++++--- rasa/core/trackers.py | 15 +++++++ tests/core/test_processor.py | 65 ++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 6 deletions(-) diff --git a/rasa/core/policies/mapping_policy.py b/rasa/core/policies/mapping_policy.py index 8b7ec8d1174c..efcdf401aad8 100644 --- a/rasa/core/policies/mapping_policy.py +++ b/rasa/core/policies/mapping_policy.py @@ -97,7 +97,7 @@ def predict_action_probabilities( elif intent == USER_INTENT_BACK: action = ACTION_BACK_NAME elif intent == USER_INTENT_SESSION_START: - action = ACTION_SESSION_START_NAME + action = ACTION_LISTEN_NAME else: action = domain.intent_properties.get(intent, {}).get("triggers") diff --git a/rasa/core/processor.py b/rasa/core/processor.py index e2b9e83e010a..82ed09fd41f0 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -144,7 +144,10 @@ def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: } async def _update_tracker_session( - self, tracker: DialogueStateTracker, output_channel: OutputChannel + self, + tracker: DialogueStateTracker, + output_channel: OutputChannel, + message_text: Optional[Text] = None, ) -> None: """Check the current session in `tracker` and update it if expired. @@ -156,8 +159,14 @@ async def _update_tracker_session( tracker: Tracker to inspect. output_channel: Output channel for potential utterances in a custom `ActionSessionStart`. + message_text: Text of the incoming user message. """ - if len(tracker.applied_events()) == 0 or self._has_session_expired(tracker): + if ( + not tracker.applied_events() # new tracker + or self._has_session_expired(tracker) # session has expired + # a manual session start was requested + or (message_text == f"{INTENT_MESSAGE_PREFIX}{USER_INTENT_SESSION_START}") + ): logger.debug( f"Starting a new session for conversation ID '{tracker.sender_id}'." ) @@ -169,7 +178,10 @@ async def _update_tracker_session( ) async def get_tracker_with_session_start( - self, sender_id: Text, output_channel: Optional[OutputChannel] = None, + self, + sender_id: Text, + output_channel: Optional[OutputChannel] = None, + message_text: Text = None, ) -> Optional[DialogueStateTracker]: """Get tracker for `sender_id` or create a new tracker for `sender_id`. @@ -178,6 +190,7 @@ async def get_tracker_with_session_start( Args: output_channel: Output channel associated with the incoming user message. sender_id: Conversation ID for which to fetch the tracker. + message_text: Text of the incoming user message. Returns: Tracker for `sender_id` if available, `None` otherwise. @@ -187,7 +200,7 @@ async def get_tracker_with_session_start( if not tracker: return None - await self._update_tracker_session(tracker, output_channel) + await self._update_tracker_session(tracker, output_channel, message_text) return tracker @@ -207,7 +220,7 @@ async def log_message( # we have a Tracker instance for each user # which maintains conversation state tracker = await self.get_tracker_with_session_start( - message.sender_id, message.output_channel + message.sender_id, message.output_channel, message.text ) if tracker: diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 54fcec1dbe93..1d65cddd5c45 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -32,6 +32,7 @@ ) from rasa.core.domain import Domain # pytype: disable=pyi-error from rasa.core.slots import Slot +from rasa.core import constants logger = logging.getLogger(__name__) @@ -358,10 +359,24 @@ def undo_till_previous(event_type, done_events): if isinstance(e, event_type): break + def is_session_start_intent_message(_event: Event) -> bool: + """Determine whether `event` is a `UserUttered` event with the explicit + `/session_start` intent.""" + + return isinstance(_event, UserUttered) and _event.text == ( + f"{constants.INTENT_MESSAGE_PREFIX}" + f"{constants.USER_INTENT_SESSION_START}" + ) + applied_events = [] for event in self.events: if isinstance(event, (Restarted, SessionStarted)): applied_events = [] + elif is_session_start_intent_message(event): + # UserUttered('/session_start') messages need to be excluded from + # featurisation - similar to `restarted` events but the prior SlotSet + # events cannot be ignored + undo_till_previous(ActionExecuted, applied_events) elif isinstance(event, ActionReverted): undo_till_previous(ActionExecuted, applied_events) elif isinstance(event, UserUtteranceReverted): diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index fa73c0478c4d..2386ce00a27b 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -16,6 +16,7 @@ from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, UserMessage from rasa.core.domain import SessionConfig +from rasa.core import constants from rasa.core.events import ( ActionExecuted, BotUttered, @@ -472,3 +473,67 @@ async def test_handle_message_with_session_start( SlotSet(entity, slot_2[entity]), ActionExecuted(ACTION_LISTEN_NAME), ] + + +async def test_handle_message_with_session_start_intent( + default_channel: CollectingOutputChannel, default_processor: MessageProcessor, +): + sender_id = uuid.uuid4().hex + + # first send a normal message + entity = "name" + slot_1 = {entity: "Core"} + await default_processor.handle_message( + UserMessage(f"/greet{json.dumps(slot_1)}", default_channel, sender_id) + ) + + # send session start intent message + await default_processor.handle_message( + UserMessage( + f"/{constants.USER_INTENT_SESSION_START}", default_channel, sender_id + ) + ) + + # and another normal message + await default_processor.handle_message( + UserMessage(f"/greet", default_channel, sender_id) + ) + + tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) + + # make sure the sequence of events is as expected + assert list(tracker.events) == [ + ActionExecuted(ACTION_SESSION_START_NAME), + SessionStarted(), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered( + f"/greet{json.dumps(slot_1)}", + {"name": "greet", "confidence": 1.0}, + [{"entity": entity, "start": 6, "end": 22, "value": "Core"}], + ), + SlotSet(entity, slot_1[entity]), + ActionExecuted("utter_greet"), + BotUttered("hey there Core!"), + ActionExecuted(ACTION_LISTEN_NAME), + ActionExecuted(ACTION_SESSION_START_NAME), + SessionStarted(), + # the initial SlotSet is reapplied after the SessionStarted sequence + SlotSet(entity, slot_1[entity]), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered( + f"/{constants.USER_INTENT_SESSION_START}", + {"name": constants.USER_INTENT_SESSION_START, "confidence": 1.0}, + ), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered(f"/greet", {"name": "greet", "confidence": 1.0},), + ActionExecuted(ACTION_LISTEN_NAME), + ] + + # the applied events should not include the UserUttered event, but they must + # include the SlotSet event + assert tracker.applied_events() == [ + SlotSet(entity, slot_1[entity]), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered(f"/greet", {"name": "greet", "confidence": 1.0},), + ActionExecuted(ACTION_LISTEN_NAME), + ] From 76d5703e26993b882790d3eb40749327de15ace6 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 11 Dec 2019 11:28:05 +0100 Subject: [PATCH 69/74] update config key --- changelog/4830.feature.rst | 4 +- data/test_domains/duplicate_intents.yml | 2 +- rasa/cli/initial_project/domain.yml | 2 +- rasa/core/agent.py | 6 ++- rasa/core/domain.py | 22 ++++----- rasa/core/events/__init__.py | 5 -- rasa/core/policies/mapping_policy.py | 2 +- rasa/core/processor.py | 32 ++++--------- rasa/core/schemas/domain.yml | 4 ++ rasa/core/trackers.py | 16 +------ rasa/server.py | 2 +- tests/core/test_domain.py | 13 ++--- tests/core/test_processor.py | 64 ------------------------- tests/core/test_training.py | 4 +- 14 files changed, 43 insertions(+), 135 deletions(-) diff --git a/changelog/4830.feature.rst b/changelog/4830.feature.rst index 75730ad15880..f47f5fe75925 100644 --- a/changelog/4830.feature.rst +++ b/changelog/4830.feature.rst @@ -6,8 +6,8 @@ with the assistant, 2. the user sends their first message after a configurable p of inactivity, or 3. a manual session start is triggered with the ``/session_start`` intent message. The period of inactivity after which a new conversation session is triggered is defined in the domain using the ``session_expiration_time`` key in the -``config`` section. The introduction of conversation sessions comprises the following -changes: +``session_config`` section. The introduction of conversation sessions comprises the +following changes: - Added a new event ``SessionStarted`` that marks the beginning of a new conversation session. diff --git a/data/test_domains/duplicate_intents.yml b/data/test_domains/duplicate_intents.yml index e81a2cea9a61..b08755195edc 100644 --- a/data/test_domains/duplicate_intents.yml +++ b/data/test_domains/duplicate_intents.yml @@ -27,6 +27,6 @@ actions: - utter_greet - utter_goodbye -config: +session_config: session_expiration_time: 60 carry_over_slots_to_new_session: true diff --git a/rasa/cli/initial_project/domain.yml b/rasa/cli/initial_project/domain.yml index e9b8470150cf..c4d7865fc3c8 100644 --- a/rasa/cli/initial_project/domain.yml +++ b/rasa/cli/initial_project/domain.yml @@ -35,6 +35,6 @@ templates: utter_iamabot: - text: "I am a bot, powered by Rasa." -config: +session_config: session_expiration_time: 60 carry_over_slots_to_new_session: true diff --git a/rasa/core/agent.py b/rasa/core/agent.py index de36d209418b..e75cb03fdc64 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -486,11 +486,13 @@ def noop(_): return await processor.handle_message(message) # noinspection PyUnusedLocal - def predict_next(self, sender_id: Text, **kwargs: Any) -> Optional[Dict[Text, Any]]: + async def predict_next( + self, sender_id: Text, **kwargs: Any + ) -> Optional[Dict[Text, Any]]: """Handle a single message.""" processor = self.create_processor() - return processor.predict_next(sender_id) + return await processor.predict_next(sender_id) # noinspection PyUnusedLocal async def log_message( diff --git a/rasa/core/domain.py b/rasa/core/domain.py index 416964996ddc..62ea52e600a6 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -38,6 +38,7 @@ CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session" SESSION_EXPIRATION_TIME_KEY = "session_expiration_time" +SESSION_CONFIG_KEY = "session_config" if typing.TYPE_CHECKING: from rasa.core.trackers import DialogueStateTracker @@ -129,7 +130,7 @@ def from_dict(cls, data: Dict) -> "Domain": utter_templates = cls.collect_templates(data.get("templates", {})) slots = cls.collect_slots(data.get("slots", {})) additional_arguments = data.get("config", {}) - session_config = cls._get_session_config(additional_arguments) + session_config = cls._get_session_config(data.get(SESSION_CONFIG_KEY, {})) intents = data.get("intents", {}) return cls( @@ -144,10 +145,8 @@ def from_dict(cls, data: Dict) -> "Domain": ) @staticmethod - def _get_session_config(additional_arguments: Dict) -> SessionConfig: - session_expiration_time = additional_arguments.pop( - SESSION_EXPIRATION_TIME_KEY, None - ) + def _get_session_config(session_config: Dict) -> SessionConfig: + session_expiration_time = session_config.get(SESSION_EXPIRATION_TIME_KEY) # TODO: 2.0 reconsider how to apply sessions to old projects and legacy trackers if session_expiration_time is None: @@ -160,7 +159,7 @@ def _get_session_config(additional_arguments: Dict) -> SessionConfig: ) session_expiration_time = 0 - carry_over_slots = additional_arguments.pop( + carry_over_slots = session_config.get( CARRY_OVER_SLOTS_KEY, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION ) @@ -698,14 +697,13 @@ def _slot_definitions(self) -> Dict[Any, Dict[str, Any]]: return {slot.name: slot.persistence_info() for slot in self.slots} def as_dict(self) -> Dict[Text, Any]: - additional_config = { - "store_entities_as_slots": self.store_entities_as_slots, - SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time, - CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots, - } return { - "config": additional_config, + "config": {"store_entities_as_slots": self.store_entities_as_slots}, + SESSION_CONFIG_KEY: { + SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time, + CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots, + }, "intents": [{k: v} for k, v in self.intent_properties.items()], "entities": self.entities, "slots": self._slot_definitions(), diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index 575799846d9f..cafe01824622 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -1204,10 +1204,5 @@ def as_story_string(self) -> None: return None def apply_to(self, tracker: "DialogueStateTracker") -> None: - from rasa.core.actions.action import ( # pytype: disable=pyi-error - ACTION_SESSION_START_NAME, - ) - # noinspection PyProtectedMember tracker._reset() - tracker.trigger_followup_action(ACTION_SESSION_START_NAME) diff --git a/rasa/core/policies/mapping_policy.py b/rasa/core/policies/mapping_policy.py index efcdf401aad8..8b7ec8d1174c 100644 --- a/rasa/core/policies/mapping_policy.py +++ b/rasa/core/policies/mapping_policy.py @@ -97,7 +97,7 @@ def predict_action_probabilities( elif intent == USER_INTENT_BACK: action = ACTION_BACK_NAME elif intent == USER_INTENT_SESSION_START: - action = ACTION_LISTEN_NAME + action = ACTION_SESSION_START_NAME else: action = domain.intent_properties.get(intent, {}).get("triggers") diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 82ed09fd41f0..6c2985fe9e2b 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -44,7 +44,6 @@ from rasa.core.tracker_store import TrackerStore from rasa.core.trackers import DialogueStateTracker, EventVerbosity from rasa.utils.endpoints import EndpointConfig -from typing import Coroutine logger = logging.getLogger(__name__) @@ -111,11 +110,11 @@ async def handle_message( else: return None - def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: + async def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: # we have a Tracker instance for each user # which maintains conversation state - tracker = self._get_tracker(sender_id) + tracker = await self.get_tracker_with_session_start(sender_id) if not tracker: logger.warning( f"Failed to retrieve or create tracker for sender '{sender_id}'." @@ -144,10 +143,7 @@ def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]: } async def _update_tracker_session( - self, - tracker: DialogueStateTracker, - output_channel: OutputChannel, - message_text: Optional[Text] = None, + self, tracker: DialogueStateTracker, output_channel: OutputChannel, ) -> None: """Check the current session in `tracker` and update it if expired. @@ -159,14 +155,8 @@ async def _update_tracker_session( tracker: Tracker to inspect. output_channel: Output channel for potential utterances in a custom `ActionSessionStart`. - message_text: Text of the incoming user message. """ - if ( - not tracker.applied_events() # new tracker - or self._has_session_expired(tracker) # session has expired - # a manual session start was requested - or (message_text == f"{INTENT_MESSAGE_PREFIX}{USER_INTENT_SESSION_START}") - ): + if not tracker.applied_events() or self._has_session_expired(tracker): logger.debug( f"Starting a new session for conversation ID '{tracker.sender_id}'." ) @@ -178,10 +168,7 @@ async def _update_tracker_session( ) async def get_tracker_with_session_start( - self, - sender_id: Text, - output_channel: Optional[OutputChannel] = None, - message_text: Text = None, + self, sender_id: Text, output_channel: Optional[OutputChannel] = None, ) -> Optional[DialogueStateTracker]: """Get tracker for `sender_id` or create a new tracker for `sender_id`. @@ -190,7 +177,6 @@ async def get_tracker_with_session_start( Args: output_channel: Output channel associated with the incoming user message. sender_id: Conversation ID for which to fetch the tracker. - message_text: Text of the incoming user message. Returns: Tracker for `sender_id` if available, `None` otherwise. @@ -200,7 +186,7 @@ async def get_tracker_with_session_start( if not tracker: return None - await self._update_tracker_session(tracker, output_channel, message_text) + await self._update_tracker_session(tracker, output_channel) return tracker @@ -220,7 +206,7 @@ async def log_message( # we have a Tracker instance for each user # which maintains conversation state tracker = await self.get_tracker_with_session_start( - message.sender_id, message.output_channel, message.text + message.sender_id, message.output_channel ) if tracker: @@ -248,7 +234,7 @@ async def execute_action( # we have a Tracker instance for each user # which maintains conversation state - tracker = self._get_tracker(sender_id) + tracker = await self.get_tracker_with_session_start(sender_id, output_channel) if tracker: action = self._get_action(action_name) await self._run_action( @@ -321,7 +307,7 @@ async def handle_reminder( ) -> None: """Handle a reminder that is triggered asynchronously.""" - tracker = self._get_tracker(sender_id) + tracker = await self.get_tracker_with_session_start(sender_id, output_channel) if not tracker: logger.warning( diff --git a/rasa/core/schemas/domain.yml b/rasa/core/schemas/domain.yml index 36405ec88ff5..c2f608d2c31f 100644 --- a/rasa/core/schemas/domain.yml +++ b/rasa/core/schemas/domain.yml @@ -37,6 +37,10 @@ mapping: mapping: store_entities_as_slots: type: "bool" + session_config: + type: "map" + allowempty: True + mapping: session_expiration_time: type: "number" range: diff --git a/rasa/core/trackers.py b/rasa/core/trackers.py index 1d65cddd5c45..30fd6aacc6ab 100644 --- a/rasa/core/trackers.py +++ b/rasa/core/trackers.py @@ -32,7 +32,6 @@ ) from rasa.core.domain import Domain # pytype: disable=pyi-error from rasa.core.slots import Slot -from rasa.core import constants logger = logging.getLogger(__name__) @@ -359,24 +358,10 @@ def undo_till_previous(event_type, done_events): if isinstance(e, event_type): break - def is_session_start_intent_message(_event: Event) -> bool: - """Determine whether `event` is a `UserUttered` event with the explicit - `/session_start` intent.""" - - return isinstance(_event, UserUttered) and _event.text == ( - f"{constants.INTENT_MESSAGE_PREFIX}" - f"{constants.USER_INTENT_SESSION_START}" - ) - applied_events = [] for event in self.events: if isinstance(event, (Restarted, SessionStarted)): applied_events = [] - elif is_session_start_intent_message(event): - # UserUttered('/session_start') messages need to be excluded from - # featurisation - similar to `restarted` events but the prior SlotSet - # events cannot be ignored - undo_till_previous(ActionExecuted, applied_events) elif isinstance(event, ActionReverted): undo_till_previous(ActionExecuted, applied_events) elif isinstance(event, UserUtteranceReverted): @@ -388,6 +373,7 @@ def is_session_start_intent_message(_event: Event) -> bool: undo_till_previous(ActionExecuted, applied_events) else: applied_events.append(event) + return applied_events def replay_events(self) -> None: diff --git a/rasa/server.py b/rasa/server.py index 03e2c0d68ae2..5ad81bc720e3 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -614,7 +614,7 @@ async def execute_action(request: Request, conversation_id: Text): async def predict(request: Request, conversation_id: Text): try: # Fetches the appropriate bot response in a json format - responses = app.agent.predict_next(conversation_id) + responses = await app.agent.predict_next(conversation_id) responses["scores"] = sorted( responses["scores"], key=lambda k: (-k["score"], k["action"]) ) diff --git a/tests/core/test_domain.py b/tests/core/test_domain.py index ef871aec11d6..56edcd3480c7 100644 --- a/tests/core/test_domain.py +++ b/tests/core/test_domain.py @@ -230,12 +230,13 @@ def test_domain_to_yaml(): test_yaml = """actions: - utter_greet config: - carry_over_slots_to_new_session: true - session_expiration_time: 60 store_entities_as_slots: true entities: [] forms: [] intents: [] +session_config: + carry_over_slots_to_new_session: true + session_expiration_time: 60 slots: {} templates: utter_greet: @@ -572,7 +573,7 @@ def test_add_knowledge_base_slots(default_domain): "input_domain, expected_session_expiration_time, expected_carry_over_slots", [ ( - """config: + """session_config: session_expiration_time: 0 carry_over_slots_to_new_session: true""", 0, @@ -580,19 +581,19 @@ def test_add_knowledge_base_slots(default_domain): ), ("", 0, True), ( - """config: + """session_config: carry_over_slots_to_new_session: false""", 0, False, ), ( - """config: + """session_config: session_expiration_time: 20.2 carry_over_slots_to_new_session: False""", 20.2, False, ), - ("""config: {}""", 0, True,), + ("""session_config: {}""", 0, True,), ], ) def test_session_config( diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 2386ce00a27b..3f36d2cc092f 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -473,67 +473,3 @@ async def test_handle_message_with_session_start( SlotSet(entity, slot_2[entity]), ActionExecuted(ACTION_LISTEN_NAME), ] - - -async def test_handle_message_with_session_start_intent( - default_channel: CollectingOutputChannel, default_processor: MessageProcessor, -): - sender_id = uuid.uuid4().hex - - # first send a normal message - entity = "name" - slot_1 = {entity: "Core"} - await default_processor.handle_message( - UserMessage(f"/greet{json.dumps(slot_1)}", default_channel, sender_id) - ) - - # send session start intent message - await default_processor.handle_message( - UserMessage( - f"/{constants.USER_INTENT_SESSION_START}", default_channel, sender_id - ) - ) - - # and another normal message - await default_processor.handle_message( - UserMessage(f"/greet", default_channel, sender_id) - ) - - tracker = default_processor.tracker_store.get_or_create_tracker(sender_id) - - # make sure the sequence of events is as expected - assert list(tracker.events) == [ - ActionExecuted(ACTION_SESSION_START_NAME), - SessionStarted(), - ActionExecuted(ACTION_LISTEN_NAME), - UserUttered( - f"/greet{json.dumps(slot_1)}", - {"name": "greet", "confidence": 1.0}, - [{"entity": entity, "start": 6, "end": 22, "value": "Core"}], - ), - SlotSet(entity, slot_1[entity]), - ActionExecuted("utter_greet"), - BotUttered("hey there Core!"), - ActionExecuted(ACTION_LISTEN_NAME), - ActionExecuted(ACTION_SESSION_START_NAME), - SessionStarted(), - # the initial SlotSet is reapplied after the SessionStarted sequence - SlotSet(entity, slot_1[entity]), - ActionExecuted(ACTION_LISTEN_NAME), - UserUttered( - f"/{constants.USER_INTENT_SESSION_START}", - {"name": constants.USER_INTENT_SESSION_START, "confidence": 1.0}, - ), - ActionExecuted(ACTION_LISTEN_NAME), - UserUttered(f"/greet", {"name": "greet", "confidence": 1.0},), - ActionExecuted(ACTION_LISTEN_NAME), - ] - - # the applied events should not include the UserUttered event, but they must - # include the SlotSet event - assert tracker.applied_events() == [ - SlotSet(entity, slot_1[entity]), - ActionExecuted(ACTION_LISTEN_NAME), - UserUttered(f"/greet", {"name": "greet", "confidence": 1.0},), - ActionExecuted(ACTION_LISTEN_NAME), - ] diff --git a/tests/core/test_training.py b/tests/core/test_training.py index 2e7e5b8b86ea..0714e9791d86 100644 --- a/tests/core/test_training.py +++ b/tests/core/test_training.py @@ -141,6 +141,6 @@ async def test_random_seed(tmpdir, config_file): processor_1 = agent_1.create_processor() processor_2 = agent_2.create_processor() - probs_1 = processor_1.predict_next("1") - probs_2 = processor_2.predict_next("2") + probs_1 = await processor_1.predict_next("1") + probs_2 = await processor_2.predict_next("2") assert probs_1["confidence"] == probs_2["confidence"] From 19146518420a2040e8f95b0c458f75f1fe7cc402 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 11 Dec 2019 11:57:05 +0100 Subject: [PATCH 70/74] remove constants import --- tests/core/test_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 3f36d2cc092f..fa73c0478c4d 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -16,7 +16,6 @@ from rasa.core.agent import Agent from rasa.core.channels.channel import CollectingOutputChannel, UserMessage from rasa.core.domain import SessionConfig -from rasa.core import constants from rasa.core.events import ( ActionExecuted, BotUttered, From c5a3bf6be08613ad40a732272d93501462d4bc84 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 11 Dec 2019 14:56:00 +0100 Subject: [PATCH 71/74] fix test --- tests/core/test_trackers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_trackers.py b/tests/core/test_trackers.py index c29c3bf27802..ade53c5287bb 100644 --- a/tests/core/test_trackers.py +++ b/tests/core/test_trackers.py @@ -304,9 +304,6 @@ def test_session_start(default_domain: Domain): # tracker has one event assert len(tracker.events) == 1 - # follow-up action should be 'session_start' - assert tracker.followup_action == ACTION_SESSION_START_NAME - def test_revert_action_event(default_domain: Domain): tracker = DialogueStateTracker("default", default_domain.slots) From 28688a8ffcf514e4a500428d458cf3ccafa79e11 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Wed, 11 Dec 2019 15:10:22 +0100 Subject: [PATCH 72/74] add test for the case when a user manually triggers a session restart --- tests/core/test_actions.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index ad8d6b2eaca4..af42403ad261 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -38,8 +38,10 @@ SessionStarted, ActionExecuted, Event, + UserUttered, ) from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator +from rasa.core.constants import USER_INTENT_SESSION_START from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import ClientResponseError, EndpointConfig from tests.utilities import json_of_latest_request, latest_request @@ -552,6 +554,32 @@ async def test_action_session_start_with_slots( assert sorted(events, key=lambda x: x.timestamp) == events +async def test_applied_events_after_action_session_start( + default_channel: CollectingOutputChannel, + template_nlg: TemplatedNaturalLanguageGenerator, +): + slot_set = SlotSet("my_slot", "value") + events = [ + slot_set, + ActionExecuted(ACTION_LISTEN_NAME), + # User triggers a restart manually by triggering the intent + UserUttered( + text=f"/{USER_INTENT_SESSION_START}", + intent={"name": USER_INTENT_SESSION_START}, + ), + ] + tracker = DialogueStateTracker.from_events("🕵️‍♀️", events) + + # Mapping Policy kicks in and runs the session restart action + events = await ActionSessionStart().run( + default_channel, template_nlg, tracker, Domain.empty() + ) + for event in events: + tracker.update(event) + + assert tracker.applied_events() == [slot_set, ActionExecuted(ACTION_LISTEN_NAME)] + + async def test_action_default_fallback( default_channel, default_nlg, default_tracker, default_domain ): From 0be88abab42dae309de8080e2b359d7f71613b27 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 11 Dec 2019 18:59:49 +0100 Subject: [PATCH 73/74] Apply suggestions from code review Co-Authored-By: Tobias Wochinger Co-Authored-By: Tom Bocklisch --- rasa/core/lock_store.py | 1 - rasa/core/training/structures.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/rasa/core/lock_store.py b/rasa/core/lock_store.py index 3ebbbc4acbcc..1a7be3a00f32 100644 --- a/rasa/core/lock_store.py +++ b/rasa/core/lock_store.py @@ -3,7 +3,6 @@ import logging import os - from async_generator import asynccontextmanager from typing import Text, Optional, AsyncGenerator diff --git a/rasa/core/training/structures.py b/rasa/core/training/structures.py index c9cee4f3f82f..2866be528a07 100644 --- a/rasa/core/training/structures.py +++ b/rasa/core/training/structures.py @@ -317,13 +317,13 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: return result @staticmethod - def _is_action_listen(event: ActionExecuted) -> bool: + def _is_action_listen(event: Event) -> bool: # this is not an `isinstance` because # we don't want to allow subclasses here return type(event) == ActionExecuted and event.action_name == ACTION_LISTEN_NAME @staticmethod - def _is_action_session_start(event: ActionExecuted) -> bool: + def _is_action_session_start(event: Event) -> bool: # this is not an `isinstance` because # we don't want to allow subclasses here return ( @@ -331,7 +331,7 @@ def _is_action_session_start(event: ActionExecuted) -> bool: and event.action_name == ACTION_SESSION_START_NAME ) - def _add_action_listen(self, events: List[ActionExecuted]) -> None: + def _add_action_listen(self, events: List[Event]) -> None: if not events or not self._is_action_listen(events[-1]): # do not add second action_listen events.append(ActionExecuted(ACTION_LISTEN_NAME)) From 3b167f4700232cf7a733537c0d566b3933af595a Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 11 Dec 2019 19:00:15 +0100 Subject: [PATCH 74/74] warnings.warn->logger.warning --- rasa/core/events/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/events/__init__.py b/rasa/core/events/__init__.py index cafe01824622..35af7c297205 100644 --- a/rasa/core/events/__init__.py +++ b/rasa/core/events/__init__.py @@ -1198,7 +1198,7 @@ def __str__(self) -> Text: return "SessionStarted()" def as_story_string(self) -> None: - warnings.warn( + logger.warning( f"'{self.type_name}' events cannot be serialised as story strings." ) return None