Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix component lookup #4244

Merged
merged 14 commits into from Aug 15, 2019
Merged
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Expand Up @@ -11,7 +11,8 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0.

Added
-----
- `FallbackPolicy` can now be configured to trigger when the difference between confidences of two predicted intents is too narrow
- `FallbackPolicy` can now be configured to trigger when the difference between
confidences of two predicted intents is too narrow
- throw error during training when triggers are defined in the domain without
``MappingPolicy`` being present in the policy ensemble
- experimental training data importer which supports training with data of multiple
Expand Down
15 changes: 6 additions & 9 deletions rasa/nlu/registry.py
Expand Up @@ -117,23 +117,20 @@
{"name": "CountVectorsFeaturizer"},
{
"name": "CountVectorsFeaturizer",
"config": {"analyzer": "char_wb", "min_ngram": 1, "max_ngram": 4},
"analyzer": "char_wb",
"min_ngram": 1,
"max_ngram": 4,
},
{"name": "EmbeddingIntentClassifier"},
],
}


def pipeline_template(s: Text) -> Optional[List[Dict[Text, Any]]]:
components = registered_pipeline_templates.get(s)
import copy

if components:
# converts the list of components in the configuration
# format expected (one json object per component)
return [{"name": c.get("name"), **c.get("config", {})} for c in components]

else:
return None
# do a deepcopy to avoid changing the template configurations
return copy.deepcopy(registered_pipeline_templates.get(s))


def get_component_class(component_name: Text) -> Type["Component"]:
Expand Down
15 changes: 11 additions & 4 deletions tests/nlu/base/test_config.py
@@ -1,4 +1,6 @@
import json
import tempfile
from typing import Text

import pytest

Expand Down Expand Up @@ -40,21 +42,26 @@ def test_invalid_pipeline_template():
assert "unknown pipeline template" in str(execinfo.value)


def test_pipeline_looksup_registry():
pipeline_template = list(registered_pipeline_templates)[0]
@pytest.mark.parametrize(
"pipeline_template", list(registered_pipeline_templates.keys())
)
def test_pipeline_registry_lookup(pipeline_template: Text):
args = {"pipeline": pipeline_template}
f = write_file_config(args)
final_config = config.load(f.name)
components = [c for c in final_config.pipeline]
assert components == registered_pipeline_templates[pipeline_template]

assert json.dumps(components, sort_keys=True) == json.dumps(
registered_pipeline_templates[pipeline_template], sort_keys=True
)


def test_default_config_file():
final_config = config.RasaNLUModelConfig()
assert len(final_config) > 1


def test_set_attr_on_component(default_config):
def test_set_attr_on_component():
cfg = config.load("sample_configs/config_pretrained_embeddings_spacy.yml")
cfg.set_component_attr(6, C=324)

Expand Down