-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
test_config.py
83 lines (59 loc) · 2.54 KB
/
test_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
import tempfile
from typing import Text
import pytest
import rasa.utils.io
from rasa.nlu import config
from rasa.nlu.components import ComponentBuilder
from rasa.nlu.registry import registered_pipeline_templates
from tests.nlu.conftest import CONFIG_DEFAULTS_PATH
from tests.nlu.utilities import write_file_config
# Configuration read from CONFIG_DEFAULTS_PATH; test_default_config and
# test_blank_config assert that loaded configs' ``as_dict()`` equals this value.
defaults = rasa.utils.io.read_config_file(CONFIG_DEFAULTS_PATH)
def test_default_config(default_config):
    """The ``default_config`` fixture must serialise to the defaults file contents.

    ``default_config`` is presumably a pytest fixture from the NLU conftest.
    """
    actual = default_config.as_dict()
    expected = defaults
    assert actual == expected
def test_blank_config():
    """Loading an empty config file yields exactly the default configuration."""
    config_file = write_file_config({})
    loaded = config.load(config_file.name)
    assert loaded.as_dict() == defaults
def test_invalid_config_json():
    """A syntactically broken config file must raise ``InvalidConfigError``."""
    # The flow sequence is never closed, so the text is neither valid YAML nor JSON.
    broken_content = """pipeline: [pretrained_embeddings_spacy"""
    with tempfile.NamedTemporaryFile("w+", suffix="_tmp_config_file.json") as tmp:
        tmp.write(broken_content)
        # Flush so that ``config.load``, which reopens the file by name,
        # sees the written text.
        tmp.flush()
        with pytest.raises(config.InvalidConfigError):
            config.load(tmp.name)
def test_invalid_pipeline_template():
    """An unregistered pipeline template name is rejected with a clear message."""
    config_file = write_file_config({"pipeline": "my_made_up_name"})
    with pytest.raises(config.InvalidConfigError) as raised:
        config.load(config_file.name)
    message = str(raised.value)
    assert "unknown pipeline template" in message
@pytest.mark.parametrize("pipeline_template", list(registered_pipeline_templates))
def test_pipeline_registry_lookup(pipeline_template: Text):
    """A template name in the config expands to that template's registered pipeline.

    Runs once for every key in ``registered_pipeline_templates``.
    """
    f = write_file_config({"pipeline": pipeline_template})
    final_config = config.load(f.name)
    # Plain copy of the loaded pipeline. ``list(...)`` replaces the redundant
    # identity comprehension ``[c for c in ...]``.
    components = list(final_config.pipeline)
    # Serialise both sides with sorted keys so that dict key order does not
    # affect the comparison.
    assert json.dumps(components, sort_keys=True) == json.dumps(
        registered_pipeline_templates[pipeline_template], sort_keys=True
    )
def test_default_config_file():
    """A config constructed with no arguments reports a length greater than one."""
    cfg = config.RasaNLUModelConfig()
    entry_count = len(cfg)
    assert entry_count > 1
def test_set_attr_on_component():
    """``set_component_attr`` updates the targeted component and leaves index 1 as loaded."""
    # NOTE(review): relative path, so this presumably requires running pytest
    # from the repository root. Confirm.
    cfg = config.load("sample_configs/config_pretrained_embeddings_spacy.yml")
    classifier_index = 6
    cfg.set_component_attr(classifier_index, C=324)

    tokenizer_expected = {"name": "SpacyTokenizer", "use_cls_token": False}
    classifier_expected = {"name": "SklearnIntentClassifier", "C": 324}
    assert cfg.for_component(1) == tokenizer_expected
    assert cfg.for_component(classifier_index) == classifier_expected
def test_override_defaults_supervised_embeddings_pipeline():
    """Components built from the test config pick up the file's overridden values.

    The expected values (``max_ngram == 3``, ``epochs == 10``) are presumably
    set in ``data/test/config_embedding_test.yml``. Verify against that file.
    """
    cfg = config.load("data/test/config_embedding_test.yml")
    builder = ComponentBuilder()

    first = builder.create_component(cfg.for_component(0), cfg)
    assert first.max_ngram == 3

    second = builder.create_component(cfg.for_component(1), cfg)
    assert second.epochs == 10