Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add full retrieval intent name #5432

Merged
merged 11 commits into from Mar 19, 2020
4 changes: 4 additions & 0 deletions changelog/4826.improvement.rst
@@ -0,0 +1,4 @@
Add full retrieval intent name to message data
``ResponseSelector`` will now add the full retrieval intent name
e.g. ``faq/which_version`` to the prediction, making it accessible
from the tracker.
80 changes: 79 additions & 1 deletion rasa/nlu/selectors/response_selector.py
Expand Up @@ -2,13 +2,16 @@

import numpy as np
import tensorflow as tf
from pathlib import Path

from typing import Any, Dict, Optional, Text, Tuple, Union, List, Type

import rasa.utils.io as io_utils
from rasa.nlu.config import InvalidConfigError
from rasa.nlu.training_data import TrainingData, Message
from rasa.nlu.components import Component
from rasa.nlu.featurizers.featurizer import Featurizer
from rasa.nlu.model import Metadata
from rasa.nlu.classifiers.diet_classifier import (
DIETClassifier,
DIET,
Expand Down Expand Up @@ -66,9 +69,12 @@
from rasa.nlu.constants import (
RESPONSE,
RESPONSE_SELECTOR_PROPERTY_NAME,
RESPONSE_KEY_ATTRIBUTE,
INTENT,
DEFAULT_OPEN_UTTERANCE_TYPE,
TEXT,
)

from rasa.utils.tensorflow.model_data import RasaModelData
from rasa.utils.tensorflow.models import RasaModel

Expand Down Expand Up @@ -203,6 +209,7 @@ def __init__(
index_label_id_mapping: Optional[Dict[int, Text]] = None,
index_tag_id_mapping: Optional[Dict[int, Text]] = None,
model: Optional[RasaModel] = None,
retrieval_intent_mapping: Optional[Dict[Text, Text]] = None,
) -> None:

component_config = component_config or {}
Expand All @@ -211,6 +218,7 @@ def __init__(
component_config[INTENT_CLASSIFICATION] = True
component_config[ENTITY_RECOGNITION] = False
component_config[BILOU_FLAG] = None
self.retrieval_intent_mapping = retrieval_intent_mapping or {}

super().__init__(
component_config, index_label_id_mapping, index_tag_id_mapping, model
Expand All @@ -231,6 +239,20 @@ def _check_config_parameters(self) -> None:
super()._check_config_parameters()
self._load_selector_params(self.component_config)

@staticmethod
def _create_retrieval_intent_mapping(
training_data: TrainingData,
) -> Dict[Text, Text]:
"""Create response_key dictionary"""

retrieval_intent_mapping = {}
for example in training_data.intent_examples:
retrieval_intent_mapping[
example.get(RESPONSE)
] = f"{example.get(INTENT)}/{example.get(RESPONSE_KEY_ATTRIBUTE)}"

return retrieval_intent_mapping

@staticmethod
def _set_message_property(
message: Message, prediction_dict: Dict[Text, Any], selector_key: Text
Expand Down Expand Up @@ -262,6 +284,9 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
label_id_index_mapping = self._label_id_index_mapping(
training_data, attribute=RESPONSE
)
self.retrieval_intent_mapping = self._create_retrieval_intent_mapping(
training_data
)

if not label_id_index_mapping:
# no labels are present to train
Expand All @@ -288,6 +313,7 @@ def process(self, message: Message, **kwargs: Any) -> None:

out = self._predict(message)
label, label_ranking = self._predict_label(out)
retrieval_intent_name = self.retrieval_intent_mapping.get(label.get("name"))

selector_key = (
self.retrieval_intent
Expand All @@ -299,10 +325,62 @@ def process(self, message: Message, **kwargs: Any) -> None:
f"Adding following selector key to message property: {selector_key}"
)

prediction_dict = {"response": label, "ranking": label_ranking}
prediction_dict = {
"response": label,
"ranking": label_ranking,
"full_retrieval_intent": retrieval_intent_name,
}

self._set_message_property(message, prediction_dict, selector_key)

def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
"""Persist this model into the passed directory.

Return the metadata necessary to load the model again.
"""
if self.model is None:
indam23 marked this conversation as resolved.
Show resolved Hide resolved
return {"file": None}

super().persist(file_name, model_dir)

model_dir = Path(model_dir)

io_utils.json_pickle(
model_dir / f"{file_name}.retrieval_intent_mapping.pkl",
self.retrieval_intent_mapping,
)

return {"file": file_name}

@classmethod
def load(
cls,
meta: Dict[Text, Any],
model_dir: Text = None,
model_metadata: Metadata = None,
cached_component: Optional["ResponseSelector"] = None,
**kwargs: Any,
) -> "ResponseSelector":
"""Loads the trained model from the provided directory."""

model = super().load(
meta, model_dir, model_metadata, cached_component, **kwargs
)
if model == cls(component_config=meta):
model.retrieval_intent_mapping = {}
return model # pytype: disable=bad-return-type

file_name = meta.get("file")
model_dir = Path(model_dir)

retrieval_intent_mapping = io_utils.json_unpickle(
model_dir / f"{file_name}.retrieval_intent_mapping.pkl"
)

model.retrieval_intent_mapping = retrieval_intent_mapping

return model # pytype: disable=bad-return-type


class DIET2DIET(DIET):
def _check_data(self) -> None:
Expand Down
9 changes: 8 additions & 1 deletion tests/nlu/selectors/test_selectors.py
Expand Up @@ -4,6 +4,7 @@
from rasa.nlu.training_data import load_data
from rasa.nlu.train import Trainer, Interpreter
from rasa.utils.tensorflow.constants import EPOCHS
from rasa.nlu.constants import RESPONSE_SELECTOR_PROPERTY_NAME


@pytest.mark.parametrize(
Expand Down Expand Up @@ -33,6 +34,12 @@ def test_train_selector(pipeline, component_builder, tmpdir):
assert trainer.pipeline

loaded = Interpreter.load(persisted_path, component_builder)
parsed = loaded.parse("hello")

assert loaded.pipeline
assert loaded.parse("hello") is not None
assert parsed is not None
assert (
parsed.get(RESPONSE_SELECTOR_PROPERTY_NAME)
.get("default")
.get("full_retrieval_intent")
) is not None