Skip to content

Commit

Permalink
update slots in 'tracker.update()'
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed May 22, 2019
1 parent a925f77 commit ea36bc5
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
6 changes: 2 additions & 4 deletions rasa/core/processor.py
Expand Up @@ -295,11 +295,9 @@ async def _handle_message_with_tracker(
parse_data,
input_channel=message.input_channel,
message_id=message.message_id,
)
),
self.domain,
)
# store all entities as slots
for e in self.domain.slots_for_entities(parse_data["entities"]):
tracker.update(e)

if parse_data["entities"]:
self._log_slots(tracker)
Expand Down
8 changes: 6 additions & 2 deletions rasa/core/trackers.py
@@ -1,6 +1,5 @@
import copy
import logging
import typing
from collections import deque
from enum import Enum
import typing
Expand Down Expand Up @@ -400,14 +399,19 @@ def as_dialogue(self) -> Dialogue:

return Dialogue(self.sender_id, list(self.events))

def update(self, event: Event) -> None:
def update(self, event: Event, domain: Optional["Domain"] = None) -> None:
"""Modify the state of the tracker according to an ``Event``. """
if not isinstance(event, Event): # pragma: no cover
raise ValueError("event to log must be an instance of a subclass of Event.")

self.events.append(event)
event.apply_to(self)

if domain and isinstance(event, UserUttered):
# store all entities as slots
for e in domain.slots_for_entities(event.parse_data["entities"]):
self.update(e)

def export_stories(self, e2e=False) -> Text:
"""Dump the tracker as a story in the Rasa Core story format.
Expand Down
2 changes: 1 addition & 1 deletion rasa/server.py
Expand Up @@ -376,7 +376,7 @@ async def append_event(request: Request, conversation_id: Text):

if evt:
try:
tracker.update(evt)
tracker.update(evt, app.agent.domain)
app.agent.tracker_store.save(tracker)
return response.json(tracker.current_state(verbosity))
except Exception as e:
Expand Down
27 changes: 27 additions & 0 deletions tests/core/test_trackers.py
Expand Up @@ -209,6 +209,33 @@ def test_tracker_entity_retrieval(default_domain):
assert list(tracker.get_latest_entity_values("unknown")) == []


def test_tracker_update_slots_with_entity(default_domain):
tracker = DialogueStateTracker("default", default_domain.slots)

test_entity = default_domain.entities[0]
expected_slot_value = "test user"

intent = {"name": "greet", "confidence": 1.0}
tracker.update(
UserUttered(
"/greet",
intent,
[
{
"start": 1,
"end": 5,
"value": expected_slot_value,
"entity": test_entity,
"extractor": "manual",
}
],
),
default_domain,
)

assert tracker.get_slot(test_entity) == expected_slot_value


def test_restart_event(default_domain):
tracker = DialogueStateTracker("default", default_domain.slots)
# the retrieved tracker should be empty
Expand Down

0 comments on commit ea36bc5

Please sign in to comment.