diff --git a/rasa/nlu/components.py b/rasa/nlu/components.py
index eaae25c7deb7..e1f31b51bc33 100644
--- a/rasa/nlu/components.py
+++ b/rasa/nlu/components.py
@@ -1,13 +1,11 @@
+import itertools
 import logging
 import typing
-from typing import Any, Dict, Hashable, List, Optional, Set, Text, Tuple, Iterable
-
-from rasa.nlu.config import RasaNLUModelConfig, override_defaults
-from rasa.nlu.constants import (
-    RESPONSE_ATTRIBUTE,
-    MESSAGE_ATTRIBUTES,
-    NOT_PRETRAINED_EXTRACTORS,
-)
+from typing import Any, Dict, Hashable, List, Optional, Set, Text, Tuple
+
+from rasa.nlu import config
+from rasa.nlu.config import RasaNLUModelConfig
+from rasa.nlu.constants import TRAINABLE_EXTRACTORS
 from rasa.nlu.training_data import Message, TrainingData
 from rasa.utils.common import raise_warning
 
@@ -117,48 +115,46 @@ def validate_requires_any_of(
         )
 
 
-###
 def validate_required_components_from_data(
     pipeline: List["Component"], data: TrainingData
 ):
-    # Check for entity examples but no entity extractory trained on own data
-    def components_in_pipeline(components: Iterable[Text], pipeline: List["Component"]):
-        return any(
-            [any([component.name == c for component in pipeline]) for c in components]
-        )
+    """Check training data for features.
+
+    If those features require specific components to featurize or
+    process them, warn the user if the required component is missing.
+    """
 
-    if data.entity_examples and not components_in_pipeline(
-        NOT_PRETRAINED_EXTRACTORS, pipeline
+    if data.entity_examples and not config.any_components_in_pipeline(
+        TRAINABLE_EXTRACTORS, pipeline
     ):
         raise_warning(
             "You have defined training data consisting of entity examples, but "
             "your NLU pipeline does not include an entity extractor trained on "
             "your training data. To extract entity examples, add one of "
-            f"{NOT_PRETRAINED_EXTRACTORS} to your pipeline."
+            f"{TRAINABLE_EXTRACTORS} to your pipeline."
         )
 
-    # Check for Regex data but RegexFeaturizer not enabled
-    if data.regex_features and not components_in_pipeline(
+    if data.regex_features and not config.any_components_in_pipeline(
         ["RegexFeaturizer"], pipeline
     ):
         raise_warning(
             "You have defined training data with regexes, but "
-            "your NLU pipeline does not include an RegexFeaturizer. "
+            "your NLU pipeline does not include a RegexFeaturizer. "
             "To featurize regexes for entity extraction, you need "
-            "to have RegexFeaturizer in your pipeline."
+            "to have a RegexFeaturizer in your pipeline."
         )
 
-    # Check for lookup tables but no RegexFeaturizer
-    if data.lookup_tables and not components_in_pipeline(["RegexFeaturizer"], pipeline):
+    if data.lookup_tables and not config.any_components_in_pipeline(
+        ["RegexFeaturizer"], pipeline
+    ):
         raise_warning(
             "You have defined training data consisting of lookup tables, but "
             "your NLU pipeline does not include a RegexFeaturizer. "
             "To featurize lookup tables, add a RegexFeaturizer to your pipeline."
         )
 
-    # Lookup Tables config verification
     if data.lookup_tables:
-        if not components_in_pipeline(["CRFEntityExtractor"], pipeline):
+        if not config.any_components_in_pipeline(["CRFEntityExtractor"], pipeline):
             raise_warning(
                 "You have defined training data consisting of lookup tables, but "
                 "your NLU pipeline does not include a CRFEntityExtractor. "
@@ -166,20 +162,23 @@ def components_in_pipeline(components: Iterable[Text], pipeline: List["Component
             )
         else:
             crf_components = [c for c in pipeline if c.name == "CRFEntityExtractor"]
-            crf_component = crf_components[-1]
-            crf_features = [
-                f for i in crf_component.component_config["features"] for f in i
-            ]
-            pattern_feature = "pattern" in crf_features
-            if not pattern_feature:
+            # check to see if any of the possible CRFEntityExtractors will featurize `pattern`
+            has_pattern_feature = False
+            for crf in crf_components:
+                crf_features = crf.component_config.get("features")
+                # iterate through [[before],[word],[after]] features
+                if "pattern" in itertools.chain(*crf_features):
+                    has_pattern_feature = True
+
+            if not has_pattern_feature:
                 raise_warning(
                     "You have defined training data consisting of lookup tables, but "
-                    "your NLU pipeline CRFEntityExtractor does not include pattern feature. "
-                    "To featurize lookup tables, add pattern feature to CRFEntityExtractor in pipeline."
+                    "your NLU pipeline's CRFEntityExtractor does not include the `pattern` feature. "
+                    "To featurize lookup tables, add the `pattern` feature to the CRFEntityExtractor in "
+                    "your pipeline."
                 )
 
-    # Check for synonyms but no EntitySynonymMapper
-    if data.entity_synonyms and not components_in_pipeline(
+    if data.entity_synonyms and not config.any_components_in_pipeline(
         ["EntitySynonymMapper"], pipeline
     ):
         raise_warning(
@@ -188,9 +187,8 @@ def components_in_pipeline(components: Iterable[Text], pipeline: List["Component
             "To map synonyms, add an EntitySynonymMapper to your pipeline."
         )
 
-    # Check for response selector but no component for it
-    if data.response_examples and not any(
-        [MESSAGE_ATTRIBUTES in component.provides for component in pipeline]
+    if data.response_examples and not config.any_components_in_pipeline(
+        ["ResponseSelector"], pipeline
    ):
         raise_warning(
             "Your training data includes examples for training a response selector, but "
@@ -314,7 +312,9 @@ def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
         # this is important for e.g. persistence
         component_config["name"] = self.name
 
-        self.component_config = override_defaults(self.defaults, component_config)
+        self.component_config = config.override_defaults(
+            self.defaults, component_config
+        )
 
         self.partial_processing_pipeline = None
         self.partial_processing_context = None
diff --git a/rasa/nlu/config.py b/rasa/nlu/config.py
index a39d5632466a..c7378d36b411 100644
--- a/rasa/nlu/config.py
+++ b/rasa/nlu/config.py
@@ -2,13 +2,17 @@
 import logging
 import os
 import ruamel.yaml as yaml
-from typing import Any, Dict, List, Optional, Text, Union, Tuple
+import typing
+from typing import Any, Dict, Iterable, List, Optional, Text, Union
 
 import rasa.utils.io
 from rasa.constants import DEFAULT_CONFIG_PATH, DOCS_URL_PIPELINE
 from rasa.nlu.utils import json_to_string
 from rasa.utils.common import raise_warning
 
+if typing.TYPE_CHECKING:
+    from rasa.nlu.components import Component
+
 logger = logging.getLogger(__name__)
 
 
@@ -77,6 +81,13 @@ def component_config_from_pipeline(
         return override_defaults(defaults, {})
 
 
+def any_components_in_pipeline(components: Iterable[Text], pipeline: List["Component"]):
+    """Check if any of the provided components are listed in the pipeline."""
+    return any(
+        [any([component.name == c for component in pipeline]) for c in components]
+    )
+
+
 class RasaNLUModelConfig:
     def __init__(self, configuration_values: Optional[Dict[Text, Any]] = None) -> None:
         """Create a model configuration, optionally overriding
diff --git a/rasa/nlu/constants.py b/rasa/nlu/constants.py
index 16522a2003a2..28b4c4bffe65 100644
--- a/rasa/nlu/constants.py
+++ b/rasa/nlu/constants.py
@@ -12,7 +12,7 @@
 
 PRETRAINED_EXTRACTORS = {"DucklingHTTPExtractor", "SpacyEntityExtractor"}
 
-NOT_PRETRAINED_EXTRACTORS = {"MitieEntityExtractor", "CRFEntityExtractor"}
+TRAINABLE_EXTRACTORS = {"MitieEntityExtractor", "CRFEntityExtractor"}
 
 CLS_TOKEN = "__CLS__"
 
diff --git a/tests/nlu/base/test_config.py b/tests/nlu/base/test_config.py
index c0165586a6c7..d0f702cbb226 100644
--- a/tests/nlu/base/test_config.py
+++ b/tests/nlu/base/test_config.py
@@ -5,8 +5,9 @@
 import pytest
 
 import rasa.utils.io
-from rasa.nlu import config, load_data
-from rasa.nlu.components import ComponentBuilder, validate_required_components_from_data
+from rasa.nlu import components, config, load_data
+from rasa.nlu.components import ComponentBuilder
+from rasa.nlu.constants import TRAINABLE_EXTRACTORS
 from rasa.nlu.registry import registered_pipeline_templates
 from tests.nlu.conftest import CONFIG_DEFAULTS_PATH, DEFAULT_DATA_PATH
 from tests.nlu.utilities import write_file_config
@@ -84,14 +85,17 @@ def test_override_defaults_supervised_embeddings_pipeline():
     assert component2.epochs == 10
 
 
-def test_warn_no_pretrained_extractor():
+def test_warn_no_trainable_extractor():
     cfg = config.load("sample_configs/config_spacy_entity_extractor.yml")
     trainer = Trainer(cfg)
     training_data = load_data(DEFAULT_DATA_PATH)
     with pytest.warns(UserWarning) as record:
-        validate_required_components_from_data(trainer.pipeline, training_data)
+        components.validate_required_components_from_data(
+            trainer.pipeline, training_data
+        )
 
     assert len(record) == 1
+    assert str(TRAINABLE_EXTRACTORS) in record[0].message.args[0]
 
 
 def test_warn_missing_regex_featurizer():
@@ -99,9 +103,12 @@
     trainer = Trainer(cfg)
     training_data = load_data(DEFAULT_DATA_PATH)
     with pytest.warns(UserWarning) as record:
-        validate_required_components_from_data(trainer.pipeline, training_data)
+        components.validate_required_components_from_data(
+            trainer.pipeline, training_data
+        )
 
     assert len(record) == 1
+    assert "RegexFeaturizer" in record[0].message.args[0]
 
 
 def test_warn_missing_pattern_feature_lookup_tables():
@@ -109,9 +116,12 @@
     trainer = Trainer(cfg)
     training_data = load_data("data/test/lookup_tables/lookup_table.md")
     with pytest.warns(UserWarning) as record:
-        validate_required_components_from_data(trainer.pipeline, training_data)
+        components.validate_required_components_from_data(
+            trainer.pipeline, training_data
+        )
 
     assert len(record) == 1
+    assert "`pattern` feature" in record[0].message.args[0]
 
 
 def test_warn_missing_synonym_mapper():
@@ -119,9 +129,12 @@
     trainer = Trainer(cfg)
     training_data = load_data("data/test/markdown_single_sections/synonyms_only.md")
     with pytest.warns(UserWarning) as record:
-        validate_required_components_from_data(trainer.pipeline, training_data)
+        components.validate_required_components_from_data(
+            trainer.pipeline, training_data
+        )
 
     assert len(record) == 1
+    assert "EntitySynonymMapper" in record[0].message.args[0]
 
 
 def test_warn_missing_response_selector():
@@ -129,6 +142,9 @@
     trainer = Trainer(cfg)
     training_data = load_data("data/examples/rasa")
     with pytest.warns(UserWarning) as record:
-        validate_required_components_from_data(trainer.pipeline, training_data)
+        components.validate_required_components_from_data(
+            trainer.pipeline, training_data
+        )
 
     assert len(record) == 1
+    assert "ResponseSelector" in record[0].message.args[0]
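
For illustration, the snippet below is a minimal, self-contained sketch of the two checks this change introduces: the `any_components_in_pipeline` membership test added to rasa/nlu/config.py, and the `pattern`-feature scan over CRFEntityExtractor configs added to validate_required_components_from_data. The `DummyComponent` class and the sample pipeline values are invented stand-ins for this sketch, not Rasa classes.

import itertools


class DummyComponent:
    """Invented stand-in: only `name` and `component_config` matter here."""

    def __init__(self, name, component_config=None):
        self.name = name
        self.component_config = component_config or {}


def any_components_in_pipeline(components, pipeline):
    # Same logic as the new helper: True if any of the requested
    # component names appears in the pipeline.
    return any(
        [any([component.name == c for component in pipeline]) for c in components]
    )


pipeline = [
    DummyComponent("WhitespaceTokenizer"),
    DummyComponent(
        "CRFEntityExtractor",
        {"features": [["low", "title"], ["bias", "pattern"], ["upper"]]},
    ),
]

# Membership checks backing the new warnings.
assert any_components_in_pipeline({"MitieEntityExtractor", "CRFEntityExtractor"}, pipeline)
assert not any_components_in_pipeline(["RegexFeaturizer"], pipeline)

# `pattern`-feature scan, mirroring the loop over crf_components:
# flatten the [[before], [word], [after]] feature lists and look for "pattern".
crfs = [c for c in pipeline if c.name == "CRFEntityExtractor"]
has_pattern_feature = any(
    "pattern" in itertools.chain(*crf.component_config.get("features", []))
    for crf in crfs
)
print(has_pattern_feature)  # True for the sample config above

Scanning every CRFEntityExtractor in the pipeline, rather than only the last one as the removed code did, is what lets the warning stay quiet whenever any configured extractor featurizes `pattern`.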