Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge responses from NLU and Domain when there are no retrieval intents #7390

Merged
merged 6 commits into from Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/7390.bugfix.md
@@ -0,0 +1 @@
Make sure the `responses` are synced between NLU training data and the Domain even if there are no retrieval intents in the NLU training data.
3 changes: 3 additions & 0 deletions data/test_nlg/test_responses.yml
@@ -0,0 +1,3 @@
responses:
utter_rasa:
- text: this is utter_rasa!
42 changes: 21 additions & 21 deletions rasa/shared/importers/importer.py
Expand Up @@ -153,7 +153,7 @@ def load_from_dict(
)
]

return E2EImporter(RetrievalModelsDataImporter(CombinedDataImporter(importers)))
return E2EImporter(ResponsesSyncImporter(CombinedDataImporter(importers)))

@staticmethod
def _importer_from_dict(
Expand Down Expand Up @@ -293,8 +293,8 @@ async def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
)


class RetrievalModelsDataImporter(TrainingDataImporter):
"""A `TrainingDataImporter` that sets up the data for training retrieval models.
class ResponsesSyncImporter(TrainingDataImporter):
"""Importer that syncs `responses` between Domain and NLU training data.

Synchronizes response templates between Domain and NLU
and adds retrieval intent properties from the NLU training data
Expand All @@ -314,19 +314,18 @@ async def get_domain(self) -> Domain:
existing_domain = await self._importer.get_domain()
existing_nlu_data = await self._importer.get_nlu_data()

# Check if NLU data has any retrieval intents, if yes
# add corresponding retrieval actions with `utter_` prefix automatically
# to an empty domain, update the properties of existing retrieval intents
# and merge response templates
if existing_nlu_data.retrieval_intents:

domain_with_retrieval_intents = self._get_domain_with_retrieval_intents(
existing_nlu_data.retrieval_intents,
existing_nlu_data.responses,
existing_domain,
)
# Merge responses from NLU data with responses in the domain.
# If NLU data has any retrieval intents, then add corresponding
# retrieval actions with `utter_` prefix automatically to the
# final domain, update the properties of existing retrieval intents.
domain_with_retrieval_intents = self._get_domain_with_retrieval_intents(
existing_nlu_data.retrieval_intents,
existing_nlu_data.responses,
existing_domain,
)

existing_domain = existing_domain.merge(domain_with_retrieval_intents)
existing_domain = existing_domain.merge(domain_with_retrieval_intents)
existing_domain.check_missing_templates()

return existing_domain

Expand All @@ -351,16 +350,19 @@ def _get_domain_with_retrieval_intents(
response_templates: Dict[Text, List[Dict[Text, Any]]],
existing_domain: Domain,
) -> Domain:
"""Construct a domain consisting of retrieval intents listed in the NLU training data.
"""Construct a domain consisting of retrieval intents.

The result domain will have retrieval intents that are listed
in the NLU training data.

Args:
retrieval_intents: Set of retrieval intents defined in NLU training data.
response_templates: Response templates defined in NLU training data.
existing_domain: Domain which is already loaded from the domain file.

Returns: Domain with retrieval actions added to action names and properties
for retrieval intents updated.
for retrieval intents updated.
"""

# Get all the properties already defined
# for each retrieval intent in other domains
# and add the retrieval intent property to them
Expand All @@ -379,9 +381,7 @@ def _get_domain_with_retrieval_intents(
[],
[],
response_templates,
RetrievalModelsDataImporter._construct_retrieval_action_names(
retrieval_intents
),
ResponsesSyncImporter._construct_retrieval_action_names(retrieval_intents),
{},
)

Expand Down
1 change: 0 additions & 1 deletion rasa/shared/importers/rasa.py
Expand Up @@ -65,7 +65,6 @@ async def get_domain(self) -> Domain:
return domain
try:
domain = Domain.load(self._domain_path)
domain.check_missing_templates()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this check be done somewhere if not here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added it here

except InvalidDomain as e:
rasa.shared.utils.io.raise_warning(
f"Loading domain from '{self._domain_path}' failed. Using "
Expand Down
32 changes: 27 additions & 5 deletions tests/shared/importers/test_importer.py
Expand Up @@ -19,7 +19,7 @@
NluDataImporter,
CoreDataImporter,
E2EImporter,
RetrievalModelsDataImporter,
ResponsesSyncImporter,
)
from rasa.shared.importers.multi_project import MultiProjectImporter
from rasa.shared.importers.rasa import RasaFileImporter
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_load_from_dict(
)

assert isinstance(actual, E2EImporter)
assert isinstance(actual.importer, RetrievalModelsDataImporter)
assert isinstance(actual.importer, ResponsesSyncImporter)

actual_importers = [i.__class__ for i in actual.importer._importer._importers]
assert actual_importers == expected
Expand All @@ -128,7 +128,7 @@ def test_load_from_config(tmpdir: Path):

importer = TrainingDataImporter.load_from_config(config_path)
assert isinstance(importer, E2EImporter)
assert isinstance(importer.importer, RetrievalModelsDataImporter)
assert isinstance(importer.importer, ResponsesSyncImporter)
assert isinstance(importer.importer._importer._importers[0], MultiProjectImporter)


Expand All @@ -140,7 +140,7 @@ async def test_nlu_only(project: Text):
)

assert isinstance(actual, NluDataImporter)
assert isinstance(actual._importer, RetrievalModelsDataImporter)
assert isinstance(actual._importer, ResponsesSyncImporter)

stories = await actual.get_stories()
assert stories.is_empty()
Expand Down Expand Up @@ -350,7 +350,7 @@ async def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
nlu_importer = NluDataImporter(base_data_importer)
core_importer = CoreDataImporter(base_data_importer)

importer = RetrievalModelsDataImporter(
importer = ResponsesSyncImporter(
CombinedDataImporter([nlu_importer, core_importer])
)
domain = await importer.get_domain()
Expand All @@ -361,3 +361,25 @@ async def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
assert domain.retrieval_intent_templates == nlu_data.responses
assert domain.templates != nlu_data.responses
assert "utter_chitchat" in domain.action_names


async def test_nlu_data_domain_sync_responses(project: Text):
config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
domain_path = "data/test_domains/default.yml"
data_paths = ["data/test_nlg/test_responses.yml"]

base_data_importer = TrainingDataImporter.load_from_dict(
{}, config_path, domain_path, data_paths
)

nlu_importer = NluDataImporter(base_data_importer)
core_importer = CoreDataImporter(base_data_importer)

importer = ResponsesSyncImporter(
CombinedDataImporter([nlu_importer, core_importer])
)
with pytest.warns(None):
domain = await importer.get_domain()

# Responses were synced between "test_responses.yml" and the "domain.yml"
assert "utter_rasa" in domain.templates.keys()