Skip to content

Commit

Permalink
Merge pull request #5614 from RasaHQ/filter-operation
Browse files Browse the repository at this point in the history
Improved filtering for NLU training data examples
  • Loading branch information
alwx committed Apr 15, 2020
2 parents a095714 + c7464a6 commit d55e868
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 25 deletions.
2 changes: 2 additions & 0 deletions changelog/5614.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Replace ``TrainingData.filter_by_intent`` function with a more general function which filters training
examples using a filtering function.
4 changes: 3 additions & 1 deletion rasa/nlu/selectors/response_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,9 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
"""

if self.retrieval_intent:
training_data = training_data.filter_by_intent(self.retrieval_intent)
training_data = training_data.filter_training_examples(
lambda ex: self.retrieval_intent == ex.get(INTENT)
)
else:
# retrieval intent was left to its default value
logger.info(
Expand Down
60 changes: 37 additions & 23 deletions rasa/nlu/training_data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from collections import Counter, OrderedDict
from copy import deepcopy
from os.path import relpath
from typing import Any, Dict, List, Optional, Set, Text, Tuple
from typing import Any, Dict, List, Optional, Set, Text, Tuple, Callable

import rasa.nlu.utils
from rasa.utils.common import raise_warning, lazy_property
from rasa.nlu.constants import RESPONSE, RESPONSE_KEY_ATTRIBUTE
from rasa.nlu.constants import ENTITIES, INTENT, RESPONSE, RESPONSE_KEY_ATTRIBUTE
from rasa.nlu.training_data.message import Message
from rasa.nlu.training_data.util import check_duplicate_synonym
from rasa.nlu.utils import list_to_str
Expand Down Expand Up @@ -75,21 +75,35 @@ def merge(self, *others: "TrainingData") -> "TrainingData":
nlg_stories,
)

def filter_by_intent(self, intent: Text):
"""Filter training examples """
def filter_training_examples(
self, condition: Callable[[Message], bool]
) -> "TrainingData":
"""Filter training examples.
training_examples = []
for ex in self.training_examples:
if ex.get("intent") == intent:
training_examples.append(ex)
Args:
condition: A function that will be applied to filter training examples.
Returns:
TrainingData: A TrainingData with filtered training examples.
"""

return TrainingData(
training_examples,
list(filter(condition, self.training_examples)),
self.entity_synonyms,
self.regex_features,
self.lookup_tables,
)

def filter_by_intent(self, intent: Text) -> "TrainingData":
"""Filter training examples."""
raise_warning(
"The `filter_by_intent` function is deprecated. "
"Please use `filter_training_examples` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.filter_training_examples(lambda ex: intent == ex.get(INTENT))

def __hash__(self) -> int:
from rasa.core import utils as core_utils

Expand All @@ -105,49 +119,49 @@ def sanitize_examples(examples: List[Message]) -> List[Message]:
Remove trailing whitespaces from intent and response annotations and drop duplicate examples."""

for ex in examples:
if ex.get("intent"):
ex.set("intent", ex.get("intent").strip())
if ex.get(INTENT):
ex.set(INTENT, ex.get(INTENT).strip())

if ex.get("response"):
ex.set("response", ex.get("response").strip())
if ex.get(RESPONSE):
ex.set(RESPONSE, ex.get(RESPONSE).strip())

return list(OrderedDict.fromkeys(examples))

@lazy_property
def intent_examples(self) -> List[Message]:
return [ex for ex in self.training_examples if ex.get("intent")]
return [ex for ex in self.training_examples if ex.get(INTENT)]

@lazy_property
def response_examples(self) -> List[Message]:
return [ex for ex in self.training_examples if ex.get("response")]
return [ex for ex in self.training_examples if ex.get(RESPONSE)]

@lazy_property
def entity_examples(self) -> List[Message]:
return [ex for ex in self.training_examples if ex.get("entities")]
return [ex for ex in self.training_examples if ex.get(ENTITIES)]

@lazy_property
def intents(self) -> Set[Text]:
"""Returns the set of intents in the training data."""
return {ex.get("intent") for ex in self.training_examples} - {None}
return {ex.get(INTENT) for ex in self.training_examples} - {None}

@lazy_property
def responses(self) -> Set[Text]:
"""Returns the set of responses in the training data."""
return {ex.get("response") for ex in self.training_examples} - {None}
return {ex.get(RESPONSE) for ex in self.training_examples} - {None}

@lazy_property
def retrieval_intents(self) -> Set[Text]:
"""Returns the total number of response types in the training data"""
return {
ex.get("intent")
ex.get(INTENT)
for ex in self.training_examples
if ex.get("response") is not None
if ex.get(RESPONSE) is not None
}

@lazy_property
def examples_per_intent(self) -> Dict[Text, int]:
"""Calculates the number of examples per intent."""
intents = [ex.get("intent") for ex in self.training_examples]
intents = [ex.get(INTENT) for ex in self.training_examples]
return dict(Counter(intents))

@lazy_property
Expand Down Expand Up @@ -299,7 +313,7 @@ def sorted_intent_examples(self) -> List[Message]:
"""Sorts the intent examples by the name of the intent and then response"""

return sorted(
self.intent_examples, key=lambda e: (e.get("intent"), e.get("response"))
self.intent_examples, key=lambda e: (e.get(INTENT), e.get(RESPONSE))
)

def validate(self) -> None:
Expand Down Expand Up @@ -393,7 +407,7 @@ def split_nlu_examples(
) -> Tuple[list, list]:
train, test = [], []
for intent, count in self.examples_per_intent.items():
ex = [e for e in self.intent_examples if e.data["intent"] == intent]
ex = [e for e in self.intent_examples if e.data[INTENT] == intent]
if random_seed is not None:
random.Random(random_seed).shuffle(ex)
else:
Expand Down
33 changes: 32 additions & 1 deletion tests/nlu/training_data/test_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
from jsonschema import ValidationError

from rasa.nlu.constants import TEXT
from rasa.nlu.constants import TEXT, RESPONSE_KEY_ATTRIBUTE
from rasa.nlu import training_data
from rasa.nlu.convert import convert_training_data
from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
Expand Down Expand Up @@ -174,6 +174,37 @@ def test_demo_data(files):
]


@pytest.mark.parametrize(
"files",
[
[
"data/examples/rasa/demo-rasa.json",
"data/examples/rasa/demo-rasa-responses.md",
],
[
"data/examples/rasa/demo-rasa.md",
"data/examples/rasa/demo-rasa-responses.md",
],
],
)
def test_demo_data_filter_out_retrieval_intents(files):
from rasa.importers.utils import training_data_from_paths

td = training_data_from_paths(files, language="en")
assert len(td.training_examples) == 46

td1 = td.filter_training_examples(lambda ex: ex.get(RESPONSE_KEY_ATTRIBUTE) is None)
assert len(td1.training_examples) == 42

td2 = td.filter_training_examples(
lambda ex: ex.get(RESPONSE_KEY_ATTRIBUTE) is not None
)
assert len(td2.training_examples) == 4

# make sure filtering operation doesn't mutate the source training data
assert len(td.training_examples) == 46


@pytest.mark.parametrize(
"filepaths",
[["data/examples/rasa/demo-rasa.md", "data/examples/rasa/demo-rasa-responses.md"]],
Expand Down

0 comments on commit d55e868

Please sign in to comment.