Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split do_extractor_overlap to correctly handle overlapping entities #1737

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -25,6 +25,7 @@ Removed

Fixed
-----
- Overlapping entities cause an exception only if the extractor does not support them.

[0.14.3] - 2019-02-01
^^^^^^^^^^^^^^^^^^^^^
Expand Down
27 changes: 21 additions & 6 deletions rasa_nlu/evaluate.py
@@ -1,13 +1,13 @@
import itertools
from collections import defaultdict, namedtuple

import json
import os
import logging
import numpy as np
import os
import shutil
from collections import defaultdict, namedtuple
from typing import List, Optional, Text

import numpy as np

from rasa_nlu import config, training_data, utils
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.extractors.crf_entity_extractor import CRFEntityExtractor
Expand Down Expand Up @@ -521,10 +521,25 @@ def determine_token_labels(token, entities, extractors):
return pick_best_entity_fit(token, candidates)


def determine_true_token_labels(token, entities):
    """Determines the label of a token from the true (annotated) entities.

    No extractor is involved here: the token is matched against the
    entities it intersects and the best fitting one is picked.

    Args:
        token: a single token
        entities: the annotated entities to match the token against;
            may be empty or ``None``

    Returns:
        the entity type picked by ``pick_best_entity_fit`` among the
        entities intersecting the token, or ``"O"`` when there are no
        entities at all
    """
    # Truthiness covers both an empty collection and ``None``, so a message
    # without annotations is labelled "O" instead of raising a TypeError.
    if not entities:
        return "O"

    candidates = find_intersecting_entites(token, entities)
    return pick_best_entity_fit(token, candidates)


def do_extractors_support_overlap(extractors):
"""Checks if extractors support overlapping entities
"""
return extractors is None or CRFEntityExtractor.name not in extractors
return CRFEntityExtractor.name not in extractors


def align_entity_predictions(targets, predictions, tokens, extractors):
Expand All @@ -547,7 +562,7 @@ def align_entity_predictions(targets, predictions, tokens, extractors):
extractor_labels = {extractor: [] for extractor in extractors}
for t in tokens:
true_token_labels.append(
determine_token_labels(t, targets, None))
determine_true_token_labels(t, targets))
for extractor, entities in entities_by_extractors.items():
extracted = determine_token_labels(t, entities, extractor)
extractor_labels[extractor].append(extracted)
Expand Down
3 changes: 2 additions & 1 deletion rasa_nlu/server.py
@@ -1,7 +1,8 @@
import argparse
import logging
import simplejson
from functools import wraps

import simplejson
from klein import Klein
from twisted.internet import reactor, threads
from twisted.internet.defer import inlineCallbacks, returnValue
Expand Down
9 changes: 4 additions & 5 deletions tests/base/test_evaluation.py
Expand Up @@ -16,6 +16,7 @@
from rasa_nlu.evaluate import align_entity_predictions
from rasa_nlu.evaluate import determine_intersection
from rasa_nlu.evaluate import determine_token_labels
from rasa_nlu.evaluate import determine_true_token_labels
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.tokenizers import Token
from rasa_nlu import utils
Expand Down Expand Up @@ -170,12 +171,12 @@ def test_determine_token_labels_throws_error():
with pytest.raises(ValueError):
determine_token_labels(CH_correct_segmentation,
[CH_correct_entity,
CH_wrong_entity], ["ner_crf"])
CH_wrong_entity], ["A", "B", "ner_crf"])


def test_determine_token_labels_no_extractors():
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity], None)
determine_true_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity])


def test_determine_token_labels_with_extractors():
Expand Down Expand Up @@ -259,7 +260,6 @@ def test_run_cv_evaluation():


def test_intent_evaluation_report(tmpdir_factory):

path = tmpdir_factory.mktemp("evaluation").strpath
report_folder = os.path.join(path, "reports")
report_filename = os.path.join(report_folder, "intent_report.json")
Expand Down Expand Up @@ -297,7 +297,6 @@ def test_intent_evaluation_report(tmpdir_factory):


def test_entity_evaluation_report(tmpdir_factory):

path = tmpdir_factory.mktemp("evaluation").strpath
report_folder = os.path.join(path, "reports")

Expand Down