Skip to content

Commit

Permalink
Merge pull request #5173 from RasaHQ/issue_3923
Browse files Browse the repository at this point in the history
refactor `use_entities` and `ignore_entities`
  • Loading branch information
chkoss committed Feb 18, 2020
2 parents 14b8798 + 9039aca commit aaf4e6a
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 102 deletions.
3 changes: 3 additions & 0 deletions changelog/3923.misc.rst
@@ -0,0 +1,3 @@
Internally, intents now have only one property ``used_entities`` to indicate which
entities should be used. For displaying purposes and in ``domain.yml`` files, the
properties ``use_entities`` and/or ``ignore_entities`` will be used as before.
196 changes: 137 additions & 59 deletions rasa/core/domain.py
@@ -1,4 +1,5 @@
import collections
import copy
import json
import logging
import os
Expand Down Expand Up @@ -42,6 +43,9 @@
CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session"
SESSION_EXPIRATION_TIME_KEY = "session_expiration_time"
SESSION_CONFIG_KEY = "session_config"
USED_ENTITIES_KEY = "used_entities"
USE_ENTITIES_KEY = "use_entities"
IGNORE_ENTITIES_KEY = "ignore_entities"

if typing.TYPE_CHECKING:
from rasa.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -75,7 +79,7 @@ class Domain:
"""The domain specifies the universe in which the bot's policy acts.
A Domain subclass provides the actions the bot can take, the intents
and entities it can recognise"""
and entities it can recognise."""

@classmethod
def empty(cls) -> "Domain":
Expand Down Expand Up @@ -264,40 +268,97 @@ def collect_slots(slot_dict: Dict[Text, Any]) -> List[Slot]:
return slots

@staticmethod
def _transform_intent_properties_for_internal_use(
    intent: Dict[Text, Any], entities: List
) -> Dict[Text, Any]:
    """Transform intent properties coming from a domain file for internal use.

    In domain files, `use_entities` or `ignore_entities` is used. Internally, there
    is a single property `used_entities` instead that lists all entities to be used.

    The passed `intent` dictionary is not modified; a transformed copy is returned.

    Args:
        intent: A dict whose first item maps the intent name to its properties
            as provided by a domain file.
        entities: All entities as provided by a domain file.

    Returns:
        The intent with `use_entities` / `ignore_entities` replaced by a sorted
        `used_entities` list.
    """
    name, raw_properties = list(intent.items())[0]

    # Work on a copy so the data loaded from the domain file stays untouched.
    # Missing properties (`None`) are treated like an empty mapping.
    properties = dict(raw_properties or {})

    use_entities = properties.pop(USE_ENTITIES_KEY, True)
    ignore_entities = properties.pop(IGNORE_ENTITIES_KEY, None) or []
    if not use_entities:  # this covers False, None and []
        use_entities = []

    # `use_entities` is either a list of explicitly included entities
    # or `True` if all should be included
    if use_entities is True:
        included_entities = set(entities)
    else:
        included_entities = set(use_entities)
    excluded_entities = set(ignore_entities)
    used_entities = sorted(included_entities - excluded_entities)

    # Only print warning for ambiguous configurations if entities were included
    # explicitly.
    explicitly_included = isinstance(use_entities, list)
    ambiguous_entities = included_entities.intersection(excluded_entities)
    if explicitly_included and ambiguous_entities:
        raise_warning(
            f"Entities: '{ambiguous_entities}' are explicitly included and"
            f" excluded for intent '{name}'. "
            f"Excluding takes precedence in this case. "
            f"Please resolve that ambiguity.",
            docs=f"{DOCS_URL_DOMAINS}#ignoring-entities-for-certain-intents",
        )

    properties[USED_ENTITIES_KEY] = used_entities

    # Shallow-copy the outer dict so any additional items are passed through
    # exactly as before, while only the first intent's properties are replaced.
    transformed_intent = dict(intent)
    transformed_intent[name] = properties
    return transformed_intent

@classmethod
def collect_intent_properties(
intents: List[Union[Text, Dict[Text, Any]]]
cls, intents: List[Union[Text, Dict[Text, Any]]], entities: List[Text]
) -> Dict[Text, Dict[Text, Union[bool, List]]]:
"""Get intent properties for a domain from what is provided by a domain file.
Args:
intents: The intents as provided by a domain file.
entities: All entities as provided by a domain file.
Returns:
The intent properties to be stored in the domain.
"""
intent_properties = {}
duplicates = set()
for intent in intents:
if isinstance(intent, dict):
name = list(intent.keys())[0]
for properties in intent.values():
properties.setdefault("use_entities", True)
properties.setdefault("ignore_entities", [])
if (
properties["use_entities"] is None
or properties["use_entities"] is False
):
properties["use_entities"] = []
else:
name = intent
intent = {intent: {"use_entities": True, "ignore_entities": []}}
if not isinstance(intent, dict):
intent = {intent: {USE_ENTITIES_KEY: True, IGNORE_ENTITIES_KEY: []}}

name = list(intent.keys())[0]
if name in intent_properties.keys():
raise InvalidDomain(
"Intents are not unique! Found two intents with name '{}'. "
"Either rename or remove one of them.".format(name)
)
duplicates.add(name)

intent = cls._transform_intent_properties_for_internal_use(intent, entities)

intent_properties.update(intent)

if duplicates:
raise InvalidDomain(
f"Intents are not unique! Found multiple intents with name(s) {sorted(duplicates)}. "
f"Either rename or remove the duplicate ones."
)

return intent_properties

@staticmethod
def collect_templates(
yml_templates: Dict[Text, List[Any]]
) -> Dict[Text, List[Dict[Text, Any]]]:
"""Go through the templates and make sure they are all in dict format
"""
"""Go through the templates and make sure they are all in dict format."""
templates = {}
for template_key, template_variations in yml_templates.items():
validated_variations = []
Expand Down Expand Up @@ -345,7 +406,7 @@ def __init__(
session_config: SessionConfig = SessionConfig.default(),
) -> None:

self.intent_properties = self.collect_intent_properties(intents)
self.intent_properties = self.collect_intent_properties(intents, entities)
self.entities = entities
self.form_names = form_names
self.slots = slots
Expand Down Expand Up @@ -376,7 +437,7 @@ def __hash__(self) -> int:

@lazy_property
def user_actions_and_forms(self):
"""Returns combination of user actions and forms"""
"""Returns combination of user actions and forms."""

return self.user_actions + self.form_names

Expand All @@ -394,7 +455,7 @@ def num_states(self):
return len(self.input_states)

def add_categorical_slot_default_value(self) -> None:
"""Add a default value to all categorical slots
"""Add a default value to all categorical slots.
All unseen values found for the slot will be mapped to this default value
for featurization.
Expand Down Expand Up @@ -439,7 +500,7 @@ def add_knowledge_base_slots(self) -> None:
def action_for_name(
self, action_name: Text, action_endpoint: Optional[EndpointConfig]
) -> Optional[Action]:
"""Looks up which action corresponds to this action name."""
"""Look up which action corresponds to this action name."""

if action_name not in self.action_names:
self._raise_action_not_found_exception(action_name)
Expand Down Expand Up @@ -470,7 +531,7 @@ def actions(self, action_endpoint) -> List[Optional[Action]]:
]

def index_for_action(self, action_name: Text) -> Optional[int]:
"""Looks up which action index corresponds to this action name"""
"""Look up which action index corresponds to this action name."""

try:
return self.action_names.index(action_name)
Expand Down Expand Up @@ -532,13 +593,13 @@ def form_states(self) -> List[Text]:
return [f"active_form_{f}" for f in self.form_names]

def index_of_state(self, state_name: Text) -> Optional[int]:
"""Provides the index of a state."""
"""Provide the index of a state."""

return self.input_state_map.get(state_name)

@lazy_property
def input_state_map(self) -> Dict[Text, int]:
"""Provides a mapping from state names to indices."""
"""Provide a mapping from state names to indices."""
return {f: i for i, f in enumerate(self.input_states)}

@lazy_property
Expand Down Expand Up @@ -598,32 +659,14 @@ def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
entity["entity"] for entity in entities if "entity" in entity.keys()
}

# `use_entities` is either a list of explicitly included entities
# or `True` if all should be included
include = intent_config.get("use_entities", True)
included_entities = set(entity_names if include is True else include)
excluded_entities = set(intent_config.get("ignore_entities", []))
wanted_entities = included_entities - excluded_entities

# Only print warning for ambiguous configurations if entities were included
# explicitly.
explicitly_included = isinstance(include, list)
ambiguous_entities = included_entities.intersection(excluded_entities)
if explicitly_included and ambiguous_entities:
raise_warning(
f"Entities: '{ambiguous_entities}' are explicitly included and"
f" excluded for intent '{intent_name}'."
f"Excluding takes precedence in this case. "
f"Please resolve that ambiguity.",
docs=DOCS_URL_DOMAINS + "#ignoring-entities-for-certain-intents",
)
wanted_entities = set(intent_config.get(USED_ENTITIES_KEY, entity_names))

return entity_names.intersection(wanted_entities)

def get_prev_action_states(
self, tracker: "DialogueStateTracker"
) -> Dict[Text, float]:
"""Turns the previous taken action into a state name."""
"""Turn the previous taken action into a state name."""

latest_action = tracker.latest_action_name
if latest_action:
Expand All @@ -637,15 +680,15 @@ def get_prev_action_states(

@staticmethod
def get_active_form(tracker: "DialogueStateTracker") -> Dict[Text, float]:
"""Turns tracker's active form into a state name."""
"""Turn tracker's active form into a state name."""
form = tracker.active_form.get("name")
if form is not None:
return {ACTIVE_FORM_PREFIX + form: 1.0}
else:
return {}

def get_active_states(self, tracker: "DialogueStateTracker") -> Dict[Text, float]:
"""Return a bag of active states from the tracker state"""
"""Return a bag of active states from the tracker state."""
state_dict = self.get_parsing_states(tracker)
state_dict.update(self.get_prev_action_states(tracker))
state_dict.update(self.get_active_form(tracker))
Expand Down Expand Up @@ -677,7 +720,7 @@ def slots_for_entities(self, entities: List[Dict[Text, Any]]) -> List[SlotSet]:
return []

def persist_specification(self, model_path: Text) -> None:
"""Persists the domain specification to storage."""
"""Persist the domain specification to storage."""

domain_spec_path = os.path.join(model_path, "domain.json")
rasa.utils.io.create_directory_for_file(domain_spec_path)
Expand All @@ -694,7 +737,7 @@ def load_specification(cls, path: Text) -> Dict[Text, Any]:
return specification

def compare_with_specification(self, path: Text) -> bool:
"""Compares the domain spec of the current and the loaded domain.
"""Compare the domain spec of the current and the loaded domain.
Throws exception if the loaded domain specification is different
to the current domain."""
Expand Down Expand Up @@ -727,7 +770,7 @@ def as_dict(self) -> Dict[Text, Any]:
SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time,
CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots,
},
"intents": [{k: v} for k, v in self.intent_properties.items()],
"intents": self._transform_intents_for_file(),
"entities": self.entities,
"slots": self._slot_definitions(),
"responses": self.templates,
Expand All @@ -741,16 +784,51 @@ def persist(self, filename: Union[Text, Path]) -> None:
domain_data = self.as_dict()
utils.dump_obj_as_yaml_to_file(filename, domain_data)

def _transform_intents_for_file(self) -> List[Union[Text, Dict[Text, Any]]]:
    """Transform intent properties for displaying or writing into a domain file.

    Internally, there is a property `used_entities` that lists all entities to be
    used. In domain files, `use_entities` or `ignore_entities` is used instead to
    list individual entities to include or exclude, because this is easier to read.

    The shorter of the two representations is chosen, and entity lists are
    written in sorted order so that the output is deterministic.

    Returns:
        The intent properties as they are used in domain files.
    """
    all_entities = set(self.entities)
    intents_for_file = []

    # Deep copy so that the internal `intent_properties` are left untouched.
    for intent_name, intent_props in copy.deepcopy(self.intent_properties).items():
        used_entities = set(intent_props.pop(USED_ENTITIES_KEY))

        if used_entities == all_entities:
            # Compare contents rather than sizes: equal lengths do not imply
            # that every domain entity is actually used.
            intent_props[USE_ENTITIES_KEY] = True
        elif (
            len(used_entities) <= len(all_entities) / 2
            or not used_entities <= all_entities
        ):
            # An `ignore_entities` list can only describe subsets of the domain
            # entities, so fall back to the explicit list if that does not hold.
            intent_props[USE_ENTITIES_KEY] = sorted(used_entities)
        else:
            intent_props[IGNORE_ENTITIES_KEY] = sorted(all_entities - used_entities)

        intents_for_file.append({intent_name: intent_props})

    return intents_for_file

def cleaned_domain(self) -> Dict[Text, Any]:
"""Fetch cleaned domain, replacing redundant keys with default values."""
"""Fetch cleaned domain to display or write into a file.
The internal `used_entities` property is replaced by `use_entities` or
`ignore_entities` and redundant keys are replaced with default values
to make the domain easier to read.
Returns:
A cleaned dictionary version of the domain.
"""
domain_data = self.as_dict()

for idx, intent_info in enumerate(domain_data["intents"]):
for name, intent in intent_info.items():
if intent.get("use_entities") is True:
intent.pop("use_entities")
if not intent.get("ignore_entities"):
intent.pop("ignore_entities", None)
if intent.get(USE_ENTITIES_KEY) is True:
del intent[USE_ENTITIES_KEY]
if not intent.get(IGNORE_ENTITIES_KEY):
intent.pop(IGNORE_ENTITIES_KEY, None)
if len(intent) == 0:
domain_data["intents"][idx] = name

Expand Down Expand Up @@ -988,7 +1066,7 @@ def check_missing_templates(self) -> None:
)

def is_empty(self) -> bool:
"""Checks whether the domain is empty."""
"""Check whether the domain is empty."""

return self.as_dict() == Domain.empty().as_dict()

Expand Down

0 comments on commit aaf4e6a

Please sign in to comment.