Commit 0b4343d: review changes

erohmensing committed Feb 26, 2020
1 parent b891e5a
Showing 4 changed files with 76 additions and 49 deletions.
78 changes: 39 additions & 39 deletions rasa/nlu/components.py
@@ -1,13 +1,11 @@
import itertools
import logging
import typing
from typing import Any, Dict, Hashable, List, Optional, Set, Text, Tuple, Iterable

from rasa.nlu.config import RasaNLUModelConfig, override_defaults
from rasa.nlu.constants import (
RESPONSE_ATTRIBUTE,
MESSAGE_ATTRIBUTES,
NOT_PRETRAINED_EXTRACTORS,
)
from typing import Any, Dict, Hashable, List, Optional, Set, Text, Tuple

from rasa.nlu import config
from rasa.nlu.config import RasaNLUModelConfig
from rasa.nlu.constants import TRAINABLE_EXTRACTORS
from rasa.nlu.training_data import Message, TrainingData
from rasa.utils.common import raise_warning

@@ -117,69 +115,70 @@ def validate_requires_any_of(
)


###
def validate_required_components_from_data(
pipeline: List["Component"], data: TrainingData
):
# Check for entity examples but no entity extractor trained on own data
def components_in_pipeline(components: Iterable[Text], pipeline: List["Component"]):
return any(
[any([component.name == c for component in pipeline]) for c in components]
)
"""Check training data for features.
If those features require specific components to featurize or
process them, warn the user if the required component is missing.
"""

if data.entity_examples and not components_in_pipeline(
NOT_PRETRAINED_EXTRACTORS, pipeline
if data.entity_examples and not config.any_components_in_pipeline(
TRAINABLE_EXTRACTORS, pipeline
):
raise_warning(
"You have defined training data consisting of entity examples, but "
"your NLU pipeline does not include an entity extractor trained on "
"your training data. To extract entity examples, add one of "
f"{NOT_PRETRAINED_EXTRACTORS} to your pipeline."
f"{TRAINABLE_EXTRACTORS} to your pipeline."
)

# Check for Regex data but RegexFeaturizer not enabled
if data.regex_features and not components_in_pipeline(
if data.regex_features and not config.any_components_in_pipeline(
["RegexFeaturizer"], pipeline
):
raise_warning(
"You have defined training data with regexes, but "
"your NLU pipeline does not include an RegexFeaturizer. "
"your NLU pipeline does not include a RegexFeaturizer. "
"To featurize regexes for entity extraction, you need "
"to have RegexFeaturizer in your pipeline."
"to have a RegexFeaturizer in your pipeline."
)

# Check for lookup tables but no RegexFeaturizer
if data.lookup_tables and not components_in_pipeline(["RegexFeaturizer"], pipeline):
if data.lookup_tables and not config.any_components_in_pipeline(
["RegexFeaturizer"], pipeline
):
raise_warning(
"You have defined training data consisting of lookup tables, but "
"your NLU pipeline does not include a RegexFeaturizer. "
"To featurize lookup tables, add a RegexFeaturizer to your pipeline."
)

# Lookup Tables config verification
if data.lookup_tables:
if not components_in_pipeline(["CRFEntityExtractor"], pipeline):
if not config.any_components_in_pipeline(["CRFEntityExtractor"], pipeline):
raise_warning(
"You have defined training data consisting of lookup tables, but "
"your NLU pipeline does not include a CRFEntityExtractor. "
"To featurize lookup tables, add a CRFEntityExtractor to your pipeline."
)
else:
crf_components = [c for c in pipeline if c.name == "CRFEntityExtractor"]
crf_component = crf_components[-1]
crf_features = [
f for i in crf_component.component_config["features"] for f in i
]
pattern_feature = "pattern" in crf_features
if not pattern_feature:
# check to see if any of the possible CRFEntityExtractors will featurize `pattern`
has_pattern_feature = False
for crf in crf_components:
crf_features = crf.component_config.get("features")
# iterate through [[before],[word],[after]] features
if "pattern" in itertools.chain(*crf_features):
has_pattern_feature = True

if not has_pattern_feature:
raise_warning(
"You have defined training data consisting of lookup tables, but "
"your NLU pipeline CRFEntityExtractor does not include pattern feature. "
"To featurize lookup tables, add pattern feature to CRFEntityExtractor in pipeline."
"your NLU pipeline's CRFEntityExtractor does not include the `pattern` feature. "
"To featurize lookup tables, add the `pattern` feature to the CRFEntityExtractor in "
"your pipeline."
)

# Check for synonyms but no EntitySynonymMapper
if data.entity_synonyms and not components_in_pipeline(
if data.entity_synonyms and not config.any_components_in_pipeline(
["EntitySynonymMapper"], pipeline
):
raise_warning(
@@ -188,9 +187,8 @@ def components_in_pipeline(components: Iterable[Text], pipeline: List["Component
"To map synonyms, add an EntitySynonymMapper to your pipeline."
)

# Check for response selector but no component for it
if data.response_examples and not any(
[MESSAGE_ATTRIBUTES in component.provides for component in pipeline]
if data.response_examples and not config.any_components_in_pipeline(
["ResponseSelector"], pipeline
):
raise_warning(
"Your training data includes examples for training a response selector, but "
@@ -314,7 +312,9 @@ def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
# this is important for e.g. persistence
component_config["name"] = self.name

self.component_config = override_defaults(self.defaults, component_config)
self.component_config = config.override_defaults(
self.defaults, component_config
)

self.partial_processing_pipeline = None
self.partial_processing_context = None
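
Aside (not part of the commit): the lookup-table check added above flattens the CRFEntityExtractor feature configuration with itertools.chain before looking for the `pattern` feature. A minimal, self-contained sketch of that flattening, using a made-up feature list rather than a real component config:

import itertools

# Illustrative only: CRF features are configured as [[before], [word], [after]] lists.
crf_features = [["low", "title"], ["bias", "pattern"], ["upper"]]

# chain(*crf_features) flattens the nested lists, so the membership test
# succeeds if "pattern" appears in any of the three positions.
has_pattern_feature = "pattern" in itertools.chain(*crf_features)
print(has_pattern_feature)  # True for this made-up configuration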
13 changes: 12 additions & 1 deletion rasa/nlu/config.py
@@ -2,13 +2,17 @@
import logging
import os
import ruamel.yaml as yaml
from typing import Any, Dict, List, Optional, Text, Union, Tuple
import typing
from typing import Any, Dict, Iterable, List, Optional, Text, Union

import rasa.utils.io
from rasa.constants import DEFAULT_CONFIG_PATH, DOCS_URL_PIPELINE
from rasa.nlu.utils import json_to_string
from rasa.utils.common import raise_warning

if typing.TYPE_CHECKING:
from rasa.nlu.components import Component

logger = logging.getLogger(__name__)


@@ -77,6 +81,13 @@ def component_config_from_pipeline(
return override_defaults(defaults, {})


def any_components_in_pipeline(components: Iterable[Text], pipeline: List["Component"]):
"""Check if any of the provided components are listed in the pipeline."""
return any(
[any([component.name == c for component in pipeline]) for c in components]
)


class RasaNLUModelConfig:
def __init__(self, configuration_values: Optional[Dict[Text, Any]] = None) -> None:
"""Create a model configuration, optionally overriding
2 changes: 1 addition & 1 deletion rasa/nlu/constants.py
@@ -12,7 +12,7 @@

PRETRAINED_EXTRACTORS = {"DucklingHTTPExtractor", "SpacyEntityExtractor"}

NOT_PRETRAINED_EXTRACTORS = {"MitieEntityExtractor", "CRFEntityExtractor"}
TRAINABLE_EXTRACTORS = {"MitieEntityExtractor", "CRFEntityExtractor"}

CLS_TOKEN = "__CLS__"

32 changes: 24 additions & 8 deletions tests/nlu/base/test_config.py
@@ -5,8 +5,9 @@
import pytest

import rasa.utils.io
from rasa.nlu import config, load_data
from rasa.nlu.components import ComponentBuilder, validate_required_components_from_data
from rasa.nlu import components, config, load_data
from rasa.nlu.components import ComponentBuilder
from rasa.nlu.constants import TRAINABLE_EXTRACTORS
from rasa.nlu.registry import registered_pipeline_templates
from tests.nlu.conftest import CONFIG_DEFAULTS_PATH, DEFAULT_DATA_PATH
from tests.nlu.utilities import write_file_config
@@ -84,51 +85,66 @@ def test_override_defaults_supervised_embeddings_pipeline():
assert component2.epochs == 10


def test_warn_no_pretrained_extractor():
def test_warn_no_trainable_extractor():
cfg = config.load("sample_configs/config_spacy_entity_extractor.yml")
trainer = Trainer(cfg)
training_data = load_data(DEFAULT_DATA_PATH)
with pytest.warns(UserWarning) as record:
validate_required_components_from_data(trainer.pipeline, training_data)
components.validate_required_components_from_data(
trainer.pipeline, training_data
)

assert len(record) == 1
assert str(TRAINABLE_EXTRACTORS) in record[0].message.args[0]


def test_warn_missing_regex_featurizer():
cfg = config.load("sample_configs/config_crf_no_regex.yml")
trainer = Trainer(cfg)
training_data = load_data(DEFAULT_DATA_PATH)
with pytest.warns(UserWarning) as record:
validate_required_components_from_data(trainer.pipeline, training_data)
components.validate_required_components_from_data(
trainer.pipeline, training_data
)

assert len(record) == 1
assert "RegexFeaturizer" in record[0].message.args[0]


def test_warn_missing_pattern_feature_lookup_tables():
cfg = config.load("sample_configs/config_crf_no_pattern_feature.yml")
trainer = Trainer(cfg)
training_data = load_data("data/test/lookup_tables/lookup_table.md")
with pytest.warns(UserWarning) as record:
validate_required_components_from_data(trainer.pipeline, training_data)
components.validate_required_components_from_data(
trainer.pipeline, training_data
)

assert len(record) == 1
assert "`pattern` feature" in record[0].message.args[0]


def test_warn_missing_synonym_mapper():
cfg = config.load("sample_configs/config_crf_no_synonyms.yml")
trainer = Trainer(cfg)
training_data = load_data("data/test/markdown_single_sections/synonyms_only.md")
with pytest.warns(UserWarning) as record:
validate_required_components_from_data(trainer.pipeline, training_data)
components.validate_required_components_from_data(
trainer.pipeline, training_data
)

assert len(record) == 1
assert "EntitySynonymMapper" in record[0].message.args[0]


def test_warn_missing_response_selector():
cfg = config.load("sample_configs/config_supervised_embeddings.yml")
trainer = Trainer(cfg)
training_data = load_data("data/examples/rasa")
with pytest.warns(UserWarning) as record:
validate_required_components_from_data(trainer.pipeline, training_data)
components.validate_required_components_from_data(
trainer.pipeline, training_data
)

assert len(record) == 1
assert "ResponseSelector" in record[0].message.args[0]
