Skip to content

Commit

Permalink
Merge 4e4a299 into f5796ec
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Aug 15, 2019
2 parents f5796ec + 4e4a299 commit 0ccb619
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
15 changes: 6 additions & 9 deletions rasa/nlu/registry.py
Expand Up @@ -117,23 +117,20 @@
{"name": "CountVectorsFeaturizer"},
{
"name": "CountVectorsFeaturizer",
"config": {"analyzer": "char_wb", "min_ngram": 1, "max_ngram": 4},
"analyzer": "char_wb",
"min_ngram": 1,
"max_ngram": 4,
},
{"name": "EmbeddingIntentClassifier"},
],
}


def pipeline_template(s: Text) -> Optional[List[Dict[Text, Any]]]:
components = registered_pipeline_templates.get(s)
import copy

if components:
# converts the list of components in the configuration
# format expected (one json object per component)
return [{"name": c.get("name"), **c.get("config", {})} for c in components]

else:
return None
# do a deepcopy to avoid changing the template configurations
return copy.deepcopy(registered_pipeline_templates.get(s))


def get_component_class(component_name: Text) -> Type["Component"]:
Expand Down
15 changes: 11 additions & 4 deletions tests/nlu/base/test_config.py
@@ -1,4 +1,6 @@
import json
import tempfile
from typing import Text

import pytest

Expand Down Expand Up @@ -40,21 +42,26 @@ def test_invalid_pipeline_template():
assert "unknown pipeline template" in str(execinfo.value)


def test_pipeline_looksup_registry():
pipeline_template = list(registered_pipeline_templates)[0]
@pytest.mark.parametrize(
"pipeline_template", list(registered_pipeline_templates.keys())
)
def test_pipeline_registry_lookup(pipeline_template: Text):
args = {"pipeline": pipeline_template}
f = write_file_config(args)
final_config = config.load(f.name)
components = [c for c in final_config.pipeline]
assert components == registered_pipeline_templates[pipeline_template]

assert json.dumps(components, sort_keys=True) == json.dumps(
registered_pipeline_templates[pipeline_template], sort_keys=True
)


def test_default_config_file():
final_config = config.RasaNLUModelConfig()
assert len(final_config) > 1


def test_set_attr_on_component(default_config):
def test_set_attr_on_component():
cfg = config.load("sample_configs/config_pretrained_embeddings_spacy.yml")
cfg.set_component_attr(6, C=324)

Expand Down

0 comments on commit 0ccb619

Please sign in to comment.