Skip to content

Commit

Permalink
Merge branch 'master' into merge-1.8.x
Browse files Browse the repository at this point in the history
  • Loading branch information
akelad committed Mar 20, 2020
2 parents 69b11bc + ac8a0e2 commit 13dae99
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 2 deletions.
4 changes: 4 additions & 0 deletions changelog/4826.improvement.rst
@@ -0,0 +1,4 @@
Add full retrieval intent name to message data
``ResponseSelector`` will now add the full retrieval intent name
e.g. ``faq/which_version`` to the prediction, making it accessible
from the tracker.
80 changes: 79 additions & 1 deletion rasa/nlu/selectors/response_selector.py
Expand Up @@ -2,13 +2,16 @@

import numpy as np
import tensorflow as tf
from pathlib import Path

from typing import Any, Dict, Optional, Text, Tuple, Union, List, Type

import rasa.utils.io as io_utils
from rasa.nlu.config import InvalidConfigError
from rasa.nlu.training_data import TrainingData, Message
from rasa.nlu.components import Component
from rasa.nlu.featurizers.featurizer import Featurizer
from rasa.nlu.model import Metadata
from rasa.nlu.classifiers.diet_classifier import (
DIETClassifier,
DIET,
Expand Down Expand Up @@ -66,9 +69,12 @@
from rasa.nlu.constants import (
RESPONSE,
RESPONSE_SELECTOR_PROPERTY_NAME,
RESPONSE_KEY_ATTRIBUTE,
INTENT,
DEFAULT_OPEN_UTTERANCE_TYPE,
TEXT,
)

from rasa.utils.tensorflow.model_data import RasaModelData
from rasa.utils.tensorflow.models import RasaModel

Expand Down Expand Up @@ -203,6 +209,7 @@ def __init__(
index_label_id_mapping: Optional[Dict[int, Text]] = None,
index_tag_id_mapping: Optional[Dict[int, Text]] = None,
model: Optional[RasaModel] = None,
retrieval_intent_mapping: Optional[Dict[Text, Text]] = None,
) -> None:

component_config = component_config or {}
Expand All @@ -211,6 +218,7 @@ def __init__(
component_config[INTENT_CLASSIFICATION] = True
component_config[ENTITY_RECOGNITION] = False
component_config[BILOU_FLAG] = None
self.retrieval_intent_mapping = retrieval_intent_mapping or {}

super().__init__(
component_config, index_label_id_mapping, index_tag_id_mapping, model
Expand All @@ -231,6 +239,20 @@ def _check_config_parameters(self) -> None:
super()._check_config_parameters()
self._load_selector_params(self.component_config)

@staticmethod
def _create_retrieval_intent_mapping(
training_data: TrainingData,
) -> Dict[Text, Text]:
"""Create response_key dictionary"""

retrieval_intent_mapping = {}
for example in training_data.intent_examples:
retrieval_intent_mapping[
example.get(RESPONSE)
] = f"{example.get(INTENT)}/{example.get(RESPONSE_KEY_ATTRIBUTE)}"

return retrieval_intent_mapping

@staticmethod
def _set_message_property(
message: Message, prediction_dict: Dict[Text, Any], selector_key: Text
Expand Down Expand Up @@ -262,6 +284,9 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
label_id_index_mapping = self._label_id_index_mapping(
training_data, attribute=RESPONSE
)
self.retrieval_intent_mapping = self._create_retrieval_intent_mapping(
training_data
)

if not label_id_index_mapping:
# no labels are present to train
Expand All @@ -288,6 +313,7 @@ def process(self, message: Message, **kwargs: Any) -> None:

out = self._predict(message)
label, label_ranking = self._predict_label(out)
retrieval_intent_name = self.retrieval_intent_mapping.get(label.get("name"))

selector_key = (
self.retrieval_intent
Expand All @@ -299,10 +325,62 @@ def process(self, message: Message, **kwargs: Any) -> None:
f"Adding following selector key to message property: {selector_key}"
)

prediction_dict = {"response": label, "ranking": label_ranking}
prediction_dict = {
"response": label,
"ranking": label_ranking,
"full_retrieval_intent": retrieval_intent_name,
}

self._set_message_property(message, prediction_dict, selector_key)

def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
"""Persist this model into the passed directory.
Return the metadata necessary to load the model again.
"""
if self.model is None:
return {"file": None}

super().persist(file_name, model_dir)

model_dir = Path(model_dir)

io_utils.json_pickle(
model_dir / f"{file_name}.retrieval_intent_mapping.pkl",
self.retrieval_intent_mapping,
)

return {"file": file_name}

@classmethod
def load(
cls,
meta: Dict[Text, Any],
model_dir: Text = None,
model_metadata: Metadata = None,
cached_component: Optional["ResponseSelector"] = None,
**kwargs: Any,
) -> "ResponseSelector":
"""Loads the trained model from the provided directory."""

model = super().load(
meta, model_dir, model_metadata, cached_component, **kwargs
)
if model == cls(component_config=meta):
model.retrieval_intent_mapping = {}
return model # pytype: disable=bad-return-type

file_name = meta.get("file")
model_dir = Path(model_dir)

retrieval_intent_mapping = io_utils.json_unpickle(
model_dir / f"{file_name}.retrieval_intent_mapping.pkl"
)

model.retrieval_intent_mapping = retrieval_intent_mapping

return model # pytype: disable=bad-return-type


class DIET2DIET(DIET):
def _check_data(self) -> None:
Expand Down
9 changes: 8 additions & 1 deletion tests/nlu/selectors/test_selectors.py
Expand Up @@ -4,6 +4,7 @@
from rasa.nlu.training_data import load_data
from rasa.nlu.train import Trainer, Interpreter
from rasa.utils.tensorflow.constants import EPOCHS
from rasa.nlu.constants import RESPONSE_SELECTOR_PROPERTY_NAME


@pytest.mark.parametrize(
Expand Down Expand Up @@ -33,6 +34,12 @@ def test_train_selector(pipeline, component_builder, tmpdir):
assert trainer.pipeline

loaded = Interpreter.load(persisted_path, component_builder)
parsed = loaded.parse("hello")

assert loaded.pipeline
assert loaded.parse("hello") is not None
assert parsed is not None
assert (
parsed.get(RESPONSE_SELECTOR_PROPERTY_NAME)
.get("default")
.get("full_retrieval_intent")
) is not None

0 comments on commit 13dae99

Please sign in to comment.