-
Notifications
You must be signed in to change notification settings - Fork 68
/
actions.py
347 lines (265 loc) · 11.7 KB
/
actions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
# -*- coding: utf-8 -*-
from typing import Text, Dict, Any, List, Union
from rasa_sdk.events import SlotSet
from rasa_sdk import Action, Tracker
from schema import schema
from graph_database import GraphDatabase
def resolve_mention(tracker: Tracker) -> Union[Text, None]:
    """
    Resolves a mention of an entity, such as first, to the actual entity.
    If multiple entities are listed during the conversation, the entities
    are stored in the slot 'listed_items' as a list. The mention is mapped
    to a list index via the "mention-mapping" of the graph database and the
    entity stored at that index is returned.

    :param tracker: tracker
    :return: name of the actual entity, or None if the mention slot is not
             set, no items are listed, the mention cannot be mapped to an
             integer index, or the index lies outside of the listed items
    """
    mention = tracker.get_slot("mention")
    listed_items = tracker.get_slot("listed_items")
    if mention is None or not listed_items:
        return None

    graph_database = GraphDatabase()
    mapped_index = graph_database.map("mention-mapping", mention)
    try:
        idx = int(mapped_index)
    except (TypeError, ValueError):
        # the mention is unknown to the mapping (e.g. None) or the mapped
        # value is not numeric; treat it as "cannot be resolved"
        return None

    # negative indices count from the end of the list; reject anything that
    # would fall outside of the list instead of raising an IndexError
    if -len(listed_items) <= idx < len(listed_items):
        return listed_items[idx]
    return None
def get_entity_type(tracker: Tracker) -> Text:
    """
    Look up the entity type the user talked about.

    The raw value of the slot 'entity_type' is passed through the
    "entity-type-mapping" of the graph database, so that the wording used
    by the user (e.g. a plural form) is translated to the type name used in
    the knowledge base.

    :param tracker: tracker
    :return: entity type as returned by the mapping
    """
    knowledge_base = GraphDatabase()
    return knowledge_base.map(
        "entity-type-mapping", tracker.get_slot("entity_type")
    )
def get_attribute(tracker: Tracker) -> Text:
    """
    Look up the attribute the user talked about.

    The raw value of the slot 'attribute' is passed through the
    "attribute-mapping" of the graph database, so that a synonym used by the
    user is translated to the attribute name used in the knowledge base.

    :param tracker: tracker
    :return: attribute name as returned by the mapping
    """
    knowledge_base = GraphDatabase()
    return knowledge_base.map("attribute-mapping", tracker.get_slot("attribute"))
def get_entity_name(tracker: Tracker, entity_type: Text):
    """
    Determine which entity the user is referring to.

    Three strategies are tried in this order:
    1. the slot 'mention' is set (ordinal reference such as first or last):
       the result of resolve_mention is returned as is,
    2. the slot named after the entity type holds a (truthy) name,
    3. attribute slots are set: the first listed item that the graph
       database validates against those attributes is returned.

    :param tracker: Tracker
    :param entity_type: the entity type
    :return: the name of the actual entity (value of key attribute in the
             knowledge base) or None if no entity could be determined
    """
    # 1. ordinal reference
    if tracker.get_slot("mention") is not None:
        return resolve_mention(tracker)

    # 2. entity was named directly
    named_entity = tracker.get_slot(entity_type)
    if named_entity:
        return named_entity

    # 3. entity is described by its attributes
    candidates = tracker.get_slot("listed_items")
    wanted_attributes = get_attributes_of_entity(entity_type, tracker)
    if not (candidates and wanted_attributes):
        return None

    knowledge_base = GraphDatabase()
    key_attr = schema[entity_type]["key"]
    for candidate in candidates:
        match = knowledge_base.validate_entity(
            entity_type, candidate, key_attr, wanted_attributes
        )
        if match is not None:
            return to_str(match, key_attr)

    return None
def get_attributes_of_entity(entity_type, tracker):
    """
    Collect the attribute values currently stored in the tracker for the
    given entity type.

    For every attribute listed in the schema of the entity type, the slot
    with the same name (dashes replaced by underscores) is read.

    :param entity_type: the entity type
    :param tracker: tracker
    :return: list of {"key": <schema attribute>, "value": <slot value>}
             dicts for all slots that are set; empty if the entity type is
             not part of the schema
    """
    found = []
    if entity_type not in schema:
        return found

    for schema_attr in schema[entity_type]["attributes"]:
        slot_value = tracker.get_slot(schema_attr.replace("-", "_"))
        if slot_value is None:
            continue
        found.append({"key": schema_attr, "value": slot_value})

    return found
def reset_attribute_slots(slots, entity_type, tracker):
    """
    Append events that clear every attribute slot of the given entity type
    that currently holds a value.

    The passed list is extended in place and also returned.

    :param slots: list of events to extend
    :param entity_type: the entity type
    :param tracker: tracker
    :return: the same list that was passed in as 'slots'
    """
    if entity_type not in schema:
        return slots

    for schema_attr in schema[entity_type]["attributes"]:
        # slot names use underscores where the schema uses dashes
        slot_name = schema_attr.replace("-", "_")
        if tracker.get_slot(slot_name) is not None:
            slots.append(SlotSet(slot_name, None))

    return slots
def to_str(entity: Dict[Text, Any], entity_keys: Union[Text, List[Text]]) -> Text:
    """
    Converts an entity to a string by concatenating the values of the provided
    entity keys.

    A key may be a dotted path (e.g. "a.b") that is followed through nested
    dictionaries. Values of keys containing "balance" or "amount" get a
    " €" suffix. Values of keys containing "date" are formatted as
    "dd.mm.YYYY (HH:MM:SS)" if they support strftime; everything else is
    converted with str().

    :param entity: the entity with all its attributes
    :param entity_keys: the name of the key attributes
    :return: a string that represents the entity (values joined by ", ")
    :raises KeyError: if a key (or a part of a dotted key) is missing
    """
    if isinstance(entity_keys, str):
        entity_keys = [entity_keys]

    parts = []
    for key in entity_keys:
        # walk down nested dictionaries along the dotted path
        value = entity
        for step in key.split("."):
            value = value[step]

        if "balance" in key or "amount" in key:
            parts.append(f"{value!s} €")
        elif "date" in key and hasattr(value, "strftime"):
            parts.append(value.strftime("%d.%m.%Y (%H:%M:%S)"))
        else:
            # also reached for keys that merely contain the substring "date"
            # (e.g. "candidate") or date values that are not datetime objects
            parts.append(str(value))

    return ", ".join(parts)
class ActionQueryEntities(Action):
    """Action for listing entities.

    The entities might be filtered by specific attributes."""

    def name(self):
        return "action_query_entities"

    def run(self, dispatcher, tracker, domain):
        """
        Query the knowledge base for entities of the requested type, utter
        them as a numbered list and remember them in the slot 'listed_items'.

        :param dispatcher: dispatcher used to send messages to the user
        :param tracker: tracker
        :param domain: domain (unused)
        :return: list of SlotSet events; empty if nothing was found or the
                 entity type is unknown
        """
        graph_database = GraphDatabase()

        # first need to know the entity type we are looking for
        entity_type = get_entity_type(tracker)
        if entity_type is None:
            dispatcher.utter_template("utter_rephrase", tracker)
            return []

        # check what attributes the NER found for entity type
        attributes = get_attributes_of_entity(entity_type, tracker)

        # query knowledge base
        entities = graph_database.get_entities(entity_type, attributes)

        # filter out transactions that do not belong the set account (if any)
        if entity_type == "transaction":
            account_number = tracker.get_slot("account")
            entities = self._filter_transaction_entities(entities, account_number)

        if not entities:
            # this is a plain text message, not the name of a response
            # template, so it must be sent with utter_message
            dispatcher.utter_message(
                "I could not find any entities for '{}'.".format(entity_type)
            )
            return []

        # utter a response that contains all found entities
        # use the 'representation' attributes to print an entity
        entity_representation = schema[entity_type]["representation"]
        # sort the entities themselves (not only their printed form) so that
        # the numbering shown to the user matches the order stored in
        # 'listed_items' that is used to resolve mentions such as "first"
        labelled_entities = sorted(
            ((to_str(e, entity_representation), e) for e in entities),
            key=lambda pair: pair[0],
        )

        dispatcher.utter_message(
            "Found the following '{}' entities:".format(entity_type)
        )
        for i, (label, _) in enumerate(labelled_entities, start=1):
            dispatcher.utter_message(f"{i}: {label}")

        # set slots
        # set the entities slot in order to resolve references to one of the
        # found entities later on
        entity_key = schema[entity_type]["key"]
        listed_items = [to_str(e, entity_key) for _, e in labelled_entities]

        slots = [
            SlotSet("entity_type", entity_type),
            SlotSet("listed_items", listed_items),
        ]

        # if only one entity was found, set the slot of that entity type to
        # the found entity
        if len(listed_items) == 1:
            slots.append(SlotSet(entity_type, listed_items[0]))

        reset_attribute_slots(slots, entity_type, tracker)

        return slots

    def _filter_transaction_entities(
        self, entities: List[Dict[Text, Any]], account_number: Text
    ) -> List[Dict[Text, Any]]:
        """
        Filter out all transactions that do not belong to the provided account number.

        :param entities: list of transactions
        :param account_number: account number; if None, no filtering is done
        :return: list of filtered transactions with max. 5 entries
        """
        if account_number is not None:
            entities = [
                entity
                for entity in entities
                if entity["account-of-creator"]["account-number"] == account_number
            ]

        return entities[:5]
class ActionQueryAttribute(Action):
    """Action that looks up one attribute of one entity and utters its value."""

    def name(self):
        return "action_query_attribute"

    def run(self, dispatcher, tracker, domain):
        """
        Utter the value of the requested attribute for the entity the user
        referred to.

        :param dispatcher: dispatcher used to send messages to the user
        :param tracker: tracker
        :param domain: domain (unused)
        :return: list of SlotSet events that reset the mention and attribute
                 slots; empty if the entity type is unknown
        """
        knowledge_base = GraphDatabase()

        # without an entity type nothing can be looked up
        entity_type = get_entity_type(tracker)
        if entity_type is None:
            dispatcher.utter_template("utter_rephrase", tracker)
            return []

        # entity and attribute the user is interested in
        entity_name = get_entity_name(tracker, entity_type)
        wanted_attribute = get_attribute(tracker)

        if entity_name is None or wanted_attribute is None:
            dispatcher.utter_template("utter_rephrase", tracker)
            return reset_attribute_slots(
                [SlotSet("mention", None)], entity_type, tracker
            )

        # ask the knowledge base for the attribute value
        result = knowledge_base.get_attribute_of(
            entity_type, schema[entity_type]["key"], entity_name, wanted_attribute
        )

        # only a single, unambiguous value counts as an answer
        if result is None or len(result) != 1:
            dispatcher.utter_message(
                f"Did not found a valid value for attribute {wanted_attribute} for entity '{entity_name}'."
            )
        else:
            dispatcher.utter_message(
                f"{entity_name} has the value '{result[0]}' for attribute '{wanted_attribute}'."
            )

        events = [SlotSet("mention", None), SlotSet(entity_type, entity_name)]
        return reset_attribute_slots(events, entity_type, tracker)
class ActionCompareEntities(Action):
    """Action that utters one attribute for each of the listed entities."""

    def name(self):
        return "action_compare_entities"

    def run(self, dispatcher, tracker, domain):
        """
        Utter the value of the requested attribute for every entity stored
        in the slot 'listed_items'.

        :param dispatcher: dispatcher used to send messages to the user
        :param tracker: tracker
        :param domain: domain (unused)
        :return: always an empty list of events
        """
        knowledge_base = GraphDatabase()

        # entities to compare and their type
        candidates = tracker.get_slot("listed_items")
        entity_type = get_entity_type(tracker)
        if candidates is None or entity_type is None:
            dispatcher.utter_template("utter_rephrase", tracker)
            return []

        # attribute to compare by
        attribute = get_attribute(tracker)
        if attribute is None:
            dispatcher.utter_template("utter_rephrase", tracker)
            return []

        # one message per entity that has exactly one value for the attribute
        for candidate in candidates:
            key_attribute = schema[entity_type]["key"]
            found = knowledge_base.get_attribute_of(
                entity_type, key_attribute, candidate, attribute
            )
            if found is None or len(found) != 1:
                continue
            dispatcher.utter_message(
                f"{candidate} has the value '{found[0]}' for attribute '{attribute}'."
            )

        return []
class ActionResolveEntity(Action):
    """Action for resolving a mention."""

    def name(self):
        return "action_resolve_entity"

    def run(self, dispatcher, tracker, domain):
        """
        Resolve the entity the user refers to and store it in the slot named
        after the entity type.

        :param dispatcher: dispatcher used to send messages to the user
        :param tracker: tracker
        :param domain: domain (unused)
        :return: list of SlotSet events; empty if no entity type is set
        """
        entity_type = tracker.get_slot("entity_type")
        listed_items = tracker.get_slot("listed_items")

        if entity_type is None:
            dispatcher.utter_template("utter_rephrase", tracker)
            return []

        # Check if entity was mentioned as 'first', 'second', etc.
        mention = tracker.get_slot("mention")
        if mention is not None:
            value = resolve_mention(tracker)
            if value is not None:
                return [SlotSet(entity_type, value), SlotSet("mention", None)]

        # Check if NER recognized entity directly
        # (e.g. bank name was mentioned and recognized as 'bank')
        value = tracker.get_slot(entity_type)
        # 'listed_items' is None as long as nothing has been listed yet; guard
        # the membership test so it does not raise a TypeError in that case
        if value is not None and listed_items is not None and value in listed_items:
            return [SlotSet(entity_type, value), SlotSet("mention", None)]

        dispatcher.utter_template("utter_rephrase", tracker)
        return [SlotSet(entity_type, None), SlotSet("mention", None)]