Skip to content

Commit

Permalink
Merge pull request #5406 from RasaHQ/Ghostvv-patch-1
Browse files Browse the repository at this point in the history
add hidden_layers_sizes to mimic EmbeddingIntentClassifier
  • Loading branch information
Ghostvv committed Mar 11, 2020
2 parents 35b8ed5 + 9faf6c9 commit 3a49f2a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/migration-guide.rst
Expand Up @@ -90,6 +90,8 @@ General
pipeline:
# - ... other components
- name: DIETClassifier
hidden_layers_sizes:
text: [256, 128]
intent_classification: True
entity_recognition: False
use_masked_language_model: False
Expand Down
7 changes: 6 additions & 1 deletion rasa/nlu/config.py
Expand Up @@ -59,7 +59,12 @@ def override_defaults(
cfg = {}

if custom:
cfg.update(custom)
for key in custom.keys():
if isinstance(cfg.get(key), dict):
cfg[key].update(custom[key])
else:
cfg[key] = custom[key]

return cfg


Expand Down
10 changes: 9 additions & 1 deletion tests/nlu/test_config.py
Expand Up @@ -134,7 +134,11 @@ def test_override_defaults_supervised_embeddings_pipeline():
{"name": "SpacyNLP"},
{"name": "SpacyTokenizer"},
{"name": "SpacyFeaturizer", "pooling": "max"},
{"name": "DIETClassifier", "epochs": 10},
{
"name": "DIETClassifier",
"epochs": 10,
"hidden_layers_sizes": {"text": [256, 128]},
},
],
}
)
Expand All @@ -151,6 +155,10 @@ def test_override_defaults_supervised_embeddings_pipeline():
_config.for_component(idx_classifier), _config
)
assert component2.component_config["epochs"] == 10
assert (
component2.defaults["hidden_layers_sizes"].keys()
== component2.component_config["hidden_layers_sizes"].keys()
)


def config_files_in(config_directory: Text):
Expand Down

0 comments on commit 3a49f2a

Please sign in to comment.