Skip to content

Commit

Permalink
modify do_extractors_support_overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
EPedrotti committed Mar 5, 2019
1 parent 894f333 commit 60a3f35
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
4 changes: 3 additions & 1 deletion rasa_nlu/test.py
Expand Up @@ -529,7 +529,9 @@ def determine_token_labels(token, entities, extractors):
def do_extractors_support_overlap(extractors):
"""Checks if extractors support overlapping entities
"""
return extractors is None or CRFEntityExtractor.name not in extractors
if extractors is None:
return False
return CRFEntityExtractor.name not in extractors


def align_entity_predictions(targets, predictions, tokens, extractors):
Expand Down
28 changes: 19 additions & 9 deletions tests/base/test_evaluation.py
Expand Up @@ -4,6 +4,8 @@

import pytest

from rasa_nlu.extractors.mitie_entity_extractor import MitieEntityExtractor
from rasa_nlu.extractors.spacy_entity_extractor import SpacyEntityExtractor
from rasa_nlu.test import (
is_token_within_entity, do_entities_overlap,
merge_labels, remove_duckling_entities,
Expand Down Expand Up @@ -175,19 +177,27 @@ def test_entity_overlap():

def test_determine_token_labels_throws_error():
with pytest.raises(ValueError):
determine_token_labels(CH_correct_segmentation,
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity,
CH_wrong_entity], ["CRFEntityExtractor"])


def test_determine_token_labels_no_extractors():
with pytest.raises(ValueError):
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity], None)


def test_determine_token_labels_no_extractors_no_overlap():
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity], None)
EN_targets, None)


def test_determine_token_labels_with_extractors():
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity], ["A", "B"])
[CH_correct_entity, CH_wrong_entity],
[SpacyEntityExtractor.name,
MitieEntityExtractor.name])


def test_label_merging():
Expand Down Expand Up @@ -259,17 +269,17 @@ def test_run_cv_evaluation():
assert len(results.test["Precision"]) == n_folds
assert len(results.test["F1-score"]) == n_folds
assert len(entity_results.train[
'CRFEntityExtractor']["Accuracy"]) == n_folds
'CRFEntityExtractor']["Accuracy"]) == n_folds
assert len(entity_results.train[
'CRFEntityExtractor']["Precision"]) == n_folds
'CRFEntityExtractor']["Precision"]) == n_folds
assert len(entity_results.train[
'CRFEntityExtractor']["F1-score"]) == n_folds
'CRFEntityExtractor']["F1-score"]) == n_folds
assert len(entity_results.test[
'CRFEntityExtractor']["Accuracy"]) == n_folds
'CRFEntityExtractor']["Accuracy"]) == n_folds
assert len(entity_results.test[
'CRFEntityExtractor']["Precision"]) == n_folds
'CRFEntityExtractor']["Precision"]) == n_folds
assert len(entity_results.test[
'CRFEntityExtractor']["F1-score"]) == n_folds
'CRFEntityExtractor']["F1-score"]) == n_folds


def test_intent_evaluation_report(tmpdir_factory):
Expand Down

0 comments on commit 60a3f35

Please sign in to comment.