Skip to content

Commit

Permalink
merged master
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Mar 21, 2018
2 parents 291c870 + 278baf5 commit 4ae9d4d
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 56 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -27,6 +27,9 @@ Changed
of ``"None"`` as intent name in the json result if there's no match
- in teh evaluation results, replaced ``O`` with the string
``no_entity`` for better understanding
- The ``CRFEntityExtractor`` now only trains entity examples that have
``"extractor": "ner_crf"`` or no extractor at all
- Ignore hidden files when listing projects or models

Fixed
-----
Expand Down
5 changes: 3 additions & 2 deletions docs/_templates/layout.html
Expand Up @@ -6,12 +6,13 @@
{% block footer %}
{{ super() }}
<!-- Global Site Tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-87333416-1"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){dataLayer.push(arguments)};
  gtag('js', new Date());

  gtag('config', 'UA-87333416-1');
</script>
<script type="text/javascript" src="//script.crazyegg.com/pages/scripts/0074/3851.js" async="async"></script>
{% endblock %}
3 changes: 1 addition & 2 deletions rasa_nlu/data_router.py
Expand Up @@ -210,8 +210,7 @@ def parse(self, data):
def _list_projects(path):
    """List the projects in the path, ignoring hidden directories.

    Returns the basenames of all non-hidden subdirectories of `path`
    (hidden entries are filtered by `utils.list_subdirectories`)."""
    return [os.path.basename(fn)
            for fn in utils.list_subdirectories(path)]

@staticmethod
def create_temporary_file(data, suffix=""):
Expand Down
27 changes: 27 additions & 0 deletions rasa_nlu/extractors/__init__.py
Expand Up @@ -9,6 +9,7 @@
from typing import Text

from rasa_nlu.components import Component
from rasa_nlu.training_data import Message


class EntityExtractor(Component):
Expand All @@ -24,6 +25,7 @@ def add_processor_name(self, entity):
entity["processors"].append(self.name)
else:
entity["processors"] = [self.name]

return entity

@staticmethod
Expand All @@ -46,3 +48,28 @@ def find_entity(ent, text, tokens):
start = offsets.index(ent["start"])
end = ends.index(ent["end"]) + 1
return start, end

def filter_trainable_entities(self, entity_examples):
    # type: (List[Message]) -> List[Message]
    """Filters out untrainable entity annotations.

    Creates a copy of entity_examples in which entities that have
    `extractor` set to something other than self.name (e.g. 'ner_crf')
    are removed."""

    filtered = []
    for message in entity_examples:
        # Keep only annotations this extractor may train on: entities
        # without an `extractor` field, or tagged with this extractor.
        trainable = [entity
                     for entity in message.get("entities", [])
                     if not entity.get("extractor")
                     or entity.get("extractor") == self.name]

        # Copy the message data so the original example is not mutated.
        updated_data = message.data.copy()
        updated_data['entities'] = trainable
        filtered.append(Message(text=message.text,
                                data=updated_data,
                                output_properties=message.output_properties,
                                time=message.time))

    return filtered
7 changes: 4 additions & 3 deletions rasa_nlu/extractors/crf_entity_extractor.py
Expand Up @@ -17,7 +17,6 @@
from rasa_nlu.config import RasaNLUConfig
from rasa_nlu.extractors import EntityExtractor
from rasa_nlu.model import Metadata
from rasa_nlu.tokenizers import Token
from rasa_nlu.training_data import Message
from rasa_nlu.training_data import TrainingData
from builtins import str
Expand All @@ -27,7 +26,6 @@
if typing.TYPE_CHECKING:
from spacy.language import Language
import sklearn_crfsuite
from spacy.tokens import Doc


class CRFEntityExtractor(EntityExtractor):
Expand Down Expand Up @@ -97,10 +95,13 @@ def train(self, training_data, config, **kwargs):

# checks whether there is at least one example with an entity annotation
if training_data.entity_examples:
# filter out pre-trained entity examples
filtered_entity_examples = self.filter_trainable_entities(
training_data.entity_examples)
# convert the dataset into features
# this will train on ALL examples, even the ones
# without annotations
dataset = self._create_dataset(training_data.training_examples)
dataset = self._create_dataset(filtered_entity_examples)
# train the model
self._train_model(dataset)

Expand Down
4 changes: 2 additions & 2 deletions rasa_nlu/project.py
Expand Up @@ -12,6 +12,7 @@
from builtins import object
from threading import Lock

from rasa_nlu import utils
from typing import Text, List

from rasa_nlu.config import RasaNLUConfig
Expand Down Expand Up @@ -223,5 +224,4 @@ def _list_models_in_dir(path):
return []
else:
return [os.path.relpath(model, path)
for model in glob.glob(os.path.join(path, '*'))
if os.path.isdir(model)]
for model in utils.list_subdirectories(path)]
2 changes: 1 addition & 1 deletion rasa_nlu/training_data/loading.py
Expand Up @@ -42,7 +42,7 @@ def load_data(resource_name, language='en'):
# type: (Text, Optional[Text]) -> TrainingData
"""Loads training data from disk and merges them if multiple files are found."""

files = utils.recursively_find_files(resource_name)
files = utils.list_files(resource_name)
data_sets = [_load(f, language) for f in files]
data_sets = [ds for ds in data_sets if ds]
if len(data_sets) == 0:
Expand Down
49 changes: 35 additions & 14 deletions rasa_nlu/utils/__init__.py
Expand Up @@ -4,6 +4,7 @@
from __future__ import unicode_literals

import errno
import glob
import io
import json
import os
Expand Down Expand Up @@ -52,28 +53,48 @@ def create_dir_for_file(file_path):
raise


def list_directory(path):
    # type: (Text) -> List[Text]
    """Returns all files and folders excluding hidden files.

    If the path points to a file, returns the file. This is a recursive
    implementation returning files in any depth of the path.

    Raises ``ValueError`` if `path` is not a string or does not exist."""

    if not isinstance(path, six.string_types):
        raise ValueError("Resourcename must be a string type")

    if os.path.isfile(path):
        return [path]
    elif os.path.isdir(path):
        results = []
        for base, dirs, files in os.walk(path):
            # remove hidden files
            goodfiles = filter(lambda x: not x.startswith('.'), files)
            results.extend(os.path.join(base, f) for f in goodfiles)
        return results
    else:
        raise ValueError("Could not locate the resource '{}'."
                         "".format(os.path.abspath(path)))

return found

def list_files(path):
    # type: (Text) -> List[Text]
    """Returns all files excluding hidden files.

    If the path points to a file, returns the file."""

    # Delegate the (recursive) traversal and hidden-file filtering to
    # list_directory, then keep only the entries that are regular files.
    return list(filter(os.path.isfile, list_directory(path)))


def list_subdirectories(path):
    # type: (Text) -> List[Text]
    """Returns all folders excluding hidden files.

    If the path points to a file, returns an empty list."""

    # glob with '*' skips dot-prefixed (hidden) entries by default;
    # keep only the matches that are directories.
    candidates = glob.glob(os.path.join(path, '*'))
    return list(filter(os.path.isdir, candidates))


def lazyproperty(fn):
Expand Down
15 changes: 14 additions & 1 deletion tests/base/test_extractors.py
Expand Up @@ -20,7 +20,10 @@ def test_crf_extractor(spacy_nlp):
}),
Message("central indian restaurant", {
"intent": "restaurant_search",
"entities": [{"start": 0, "end": 7, "value": "central", "entity": "location"}],
"entities": [
{"start": 0, "end": 7, "value": "central", "entity": "location", "extractor": "random_extractor"},
{"start": 8, "end": 14, "value": "indian", "entity": "cuisine", "extractor": "ner_crf"}
],
"spacy_doc": spacy_nlp("central indian restaurant")
})]
config = {"ner_crf": {"BILOU_flag": True, "features": ext.crf_features}}
Expand All @@ -34,6 +37,16 @@ def test_crf_extractor(spacy_nlp):
assert feats[1]['0:low'] == "in"
sentence = 'anywhere in the west'
ext.extract_entities(Message(sentence, {"spacy_doc": spacy_nlp(sentence)}))
filtered = ext.filter_trainable_entities(examples)
assert filtered[0].get('entities') == [
{"start": 16, "end": 20, "value": "west", "entity": "location"}
], 'Entity without extractor remains'
assert filtered[1].get('entities') == [
{"start": 8, "end": 14, "value": "indian", "entity": "cuisine", "extractor": "ner_crf"}
], 'Only ner_crf entity annotation remains'
assert examples[1].get('entities')[0] == {
"start": 0, "end": 7, "value": "central", "entity": "location", "extractor": "random_extractor"
}, 'Original examples are not mutated'


def test_crf_json_from_BILOU(spacy_nlp):
Expand Down
16 changes: 8 additions & 8 deletions tests/base/test_multitenancy.py
Expand Up @@ -48,7 +48,7 @@ def app(component_builder):

@pytest.mark.parametrize("response_test", [
ResponseTest(
"http://dummy_uri/parse?q=food&project=test_project_spacy_sklearn",
"http://dummy-uri/parse?q=food&project=test_project_spacy_sklearn",
{"entities": [], "intent": "restaurant_search", "text": "food"}
),
])
Expand All @@ -63,11 +63,11 @@ def test_get_parse(app, response_test):

@pytest.mark.parametrize("response_test", [
ResponseTest(
"http://dummy_uri/parse?q=food",
"http://dummy-uri/parse?q=food",
{"error": "No project found with name 'default'."}
),
ResponseTest(
"http://dummy_uri/parse?q=food&project=umpalumpa",
"http://dummy-uri/parse?q=food&project=umpalumpa",
{"error": "No project found with name 'umpalumpa'."}
)
])
Expand All @@ -81,7 +81,7 @@ def test_get_parse_invalid_model(app, response_test):

@pytest.mark.parametrize("response_test", [
ResponseTest(
"http://dummy_uri/parse",
"http://dummy-uri/parse",
{"entities": [], "intent": "restaurant_search", "text": "food"},
payload={"q": "food", "project": "test_project_spacy_sklearn"}
),
Expand All @@ -96,11 +96,11 @@ def test_post_parse(app, response_test):

@pytest.inlineCallbacks
def test_post_parse_specific_model(app):
status = yield app.get("http://dummy_uri/status")
status = yield app.get("http://dummy-uri/status")
sjs = yield status.json()
project = sjs["available_projects"]["test_project_spacy_sklearn"]
model = project["available_models"][0]
query = ResponseTest("http://dummy_uri/parse",
query = ResponseTest("http://dummy-uri/parse",
{"entities": [], "intent": "affirm", "text": "food"},
payload={"q": "food",
"project": "test_project_spacy_sklearn",
Expand All @@ -111,12 +111,12 @@ def test_post_parse_specific_model(app):

@pytest.mark.parametrize("response_test", [
ResponseTest(
"http://dummy_uri/parse",
"http://dummy-uri/parse",
{"error": "No project found with name 'default'."},
payload={"q": "food"}
),
ResponseTest(
"http://dummy_uri/parse",
"http://dummy-uri/parse",
{"error": "No project found with name 'umpalumpa'."},
payload={"q": "food", "project": "umpalumpa"}
),
Expand Down

0 comments on commit 4ae9d4d

Please sign in to comment.