Commit

Merge dda82c3 into 3e84a0d
MetcalfeTom committed Jan 30, 2019
2 parents 3e84a0d + dda82c3 commit 6bff40d
Showing 4 changed files with 103 additions and 59 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.rst
@@ -4,7 +4,13 @@ Change Log
All notable changes to this project will be documented in this file.
This project adheres to `Semantic Versioning`_ starting with version 0.7.0.

.. _v0-14-0:
[Unreleased 0.15.0.aX] - `master`_
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Added
-----
- ``rasa_nlu.evaluate`` now exports reports into a folder and also
includes the entity extractor reports

[0.14.1] - 2018-01-23
^^^^^^^^^^^^^^^^^^^^^
@@ -13,6 +19,8 @@ Fixed
-----
- scikit-learn is a global requirement

.. _v0-14-0:

[0.14.0] - 2018-01-23
^^^^^^^^^^^^^^^^^^^^^

6 changes: 3 additions & 3 deletions docs/evaluation.rst
@@ -80,9 +80,9 @@ Intent Classification
The evaluation script will produce a report, confusion matrix
and confidence histogram for your model.

The report logs precision, recall, and f1 measure for
each intent, as well as provide an overall average. You can save this
report as a JSON file using the `--report` flag.
The report logs precision, recall and f1 measure for
each intent and entity, as well as providing an overall average.
You can save these reports as JSON files using the `--report` flag.

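Once the evaluation has been run with `--report`, the exported files can be
inspected programmatically. A minimal sketch (the folder name matches the value
passed to `--report`; the extractor file name depends on your pipeline):

import json
import os

report_folder = "reports"

# per-intent precision, recall and f1, one entry per intent
with open(os.path.join(report_folder, "intent_report.json")) as f:
    intent_report = json.load(f)

# one further report per entity extractor, e.g. CRFEntityExtractor_report.json
print(sorted(os.listdir(report_folder)))
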
The confusion matrix shows you which
intents are mistaken for others; any samples which have been
106 changes: 56 additions & 50 deletions rasa_nlu/evaluate.py
@@ -4,20 +4,18 @@
from __future__ import unicode_literals

import itertools

import json
import os
import logging
import numpy as np
import shutil
from collections import defaultdict
from collections import namedtuple
from typing import List, Optional, Text

import numpy as np

from rasa_nlu import training_data, utils, config
from rasa_nlu import config, training_data, utils
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.extractors.crf_entity_extractor import CRFEntityExtractor
from rasa_nlu.model import Interpreter
from rasa_nlu.model import Trainer, TrainingData
from rasa_nlu.model import Interpreter, Trainer, TrainingData

logger = logging.getLogger(__name__)

@@ -42,8 +40,8 @@
def create_argument_parser():
import argparse
parser = argparse.ArgumentParser(
description='evaluate a Rasa NLU pipeline with cross '
'validation or on external data')
description='evaluate a Rasa NLU pipeline with cross '
'validation or on external data')

parser.add_argument('-d', '--data', required=True,
help="file containing training/evaluation data")
@@ -64,8 +62,9 @@ def create_argument_parser():
help="number of CV folds (crossvalidation only)")

parser.add_argument('--report', required=False, nargs='?',
const="report.json", default=False,
help="output path to save the metrics report")
const="reports", default=False,
help="output path to save the intent/entity"
"metrics report")

parser.add_argument('--successes', required=False, nargs='?',
const="successes.json", default=False,
@@ -85,14 +84,14 @@ def create_argument_parser():
return parser


def plot_confusion_matrix(cm, classes,
def plot_confusion_matrix(cm,
classes,
normalize=False,
title='Confusion matrix',
cmap=None,
zmin=1,
out=None): # pragma: no cover
"""Print and plot the confusion matrix for the intent classification.
Normalization can be applied by setting `normalize=True`."""
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
@@ -171,7 +170,7 @@ def log_evaluation_table(report, # type: Text
logger.info("Classification report: \n{}".format(report))


def get_evaluation_metrics(targets, predictions, output_dict=False): # pragma: no cover
def get_evaluation_metrics(targets, predictions, output_dict=False):
"""Compute the f1, precision, accuracy and summary report from sklearn."""
from sklearn import metrics

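For reference, a helper of this shape can be sketched directly on top of
scikit-learn (an illustration only, not necessarily the repository's exact
implementation; `output_dict` needs scikit-learn >= 0.20):

from sklearn import metrics

def get_evaluation_metrics_sketch(targets, predictions, output_dict=False):
    """Compute f1, precision, accuracy and a summary report with sklearn."""
    report = metrics.classification_report(targets, predictions,
                                            output_dict=output_dict)
    precision = metrics.precision_score(targets, predictions,
                                        average='weighted')
    f1 = metrics.f1_score(targets, predictions, average='weighted')
    accuracy = metrics.accuracy_score(targets, predictions)
    return report, precision, f1, accuracy
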
@@ -214,7 +213,7 @@ def drop_intents_below_freq(td, cutoff=5):
"""Remove intent groups with less than cutoff instances."""

logger.debug(
"Raw data intent examples: {}".format(len(td.intent_examples)))
"Raw data intent examples: {}".format(len(td.intent_examples)))
keep_examples = [ex
for ex in td.intent_examples
if td.examples_per_intent[ex.get("intent")] >= cutoff]
@@ -284,7 +283,7 @@ def plot_intent_confidences(intent_results, intent_hist_filename):


def evaluate_intents(intent_results,
report_filename,
report_folder,
successes_filename,
errors_filename,
confmat_filename,
@@ -308,10 +307,11 @@ def evaluate_intents(intent_results,

targets, predictions = _targets_predictions_from(intent_results)

if report_filename:
report, precision, f1, accuracy = get_evaluation_metrics(targets,
predictions,
output_dict=True)
if report_folder:
report, precision, f1, accuracy = get_evaluation_metrics(
targets, predictions, output_dict=True)

report_filename = os.path.join(report_folder, 'intent_report.json')

save_json(report, report_filename)
logger.info("Classification report saved to {}."
@@ -367,7 +367,6 @@ def evaluate_intents(intent_results,

def merge_labels(aligned_predictions, extractor=None):
"""Concatenates all labels of the aligned predictions.
Takes the aligned prediction labels which are grouped for each message
and concatenates them."""

Expand All @@ -390,9 +389,9 @@ def substitute_labels(labels, old, new):
def evaluate_entities(targets,
predictions,
tokens,
extractors): # pragma: no cover
extractors,
report_folder): # pragma: no cover
"""Creates summary statistics for each entity extractor.
Logs precision, recall, and F1 per entity type for each extractor."""

aligned_predictions = align_all_entity_predictions(targets, predictions,
@@ -405,11 +404,24 @@ def evaluate_entities(targets,
for extractor in extractors:
merged_predictions = merge_labels(aligned_predictions, extractor)
merged_predictions = substitute_labels(
merged_predictions, "O", "no_entity")
merged_predictions, "O", "no_entity")
logger.info("Evaluation for entity extractor: {} ".format(extractor))
report, precision, f1, accuracy = get_evaluation_metrics(
merged_targets, merged_predictions)
log_evaluation_table(report, precision, f1, accuracy)
if report_folder:
report, precision, f1, accuracy = get_evaluation_metrics(
merged_targets, merged_predictions, output_dict=True)

report_filename = extractor + "_report.json"
extractor_report = os.path.join(report_folder, report_filename)

save_json(report, extractor_report)
logger.info("Classification report for {} saved to {}."
.format(extractor, extractor_report))

else:
report, precision, f1, accuracy = get_evaluation_metrics(
merged_targets, merged_predictions)
log_evaluation_table(report, precision, f1, accuracy)

result[extractor] = {
"report": report,
"precision": precision,
@@ -442,9 +454,7 @@ def determine_intersection(token, entity):

def do_entities_overlap(entities):
"""Checks if entities overlap.
I.e. cross each other's start and end boundaries.
:param entities: list of entities
:return: boolean
"""
@@ -453,16 +463,15 @@ def do_entities_overlap(entities):
for i in range(len(sorted_entities) - 1):
curr_ent = sorted_entities[i]
next_ent = sorted_entities[i + 1]
if (next_ent["start"] < curr_ent["end"]
and next_ent["entity"] != curr_ent["entity"]):
if (next_ent["start"] < curr_ent["end"] and
next_ent["entity"] != curr_ent["entity"]):
return True

return False

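As a quick illustration of the overlap check above (the entity dictionaries are
made up; only the `start`, `end` and `entity` keys matter here):

from rasa_nlu.evaluate import do_entities_overlap

overlapping = [
    {"start": 0, "end": 5, "entity": "city"},
    {"start": 3, "end": 9, "entity": "time"},
]
# the second entity starts before the first one ends and has a different
# type, so the two are treated as overlapping
assert do_entities_overlap(overlapping)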

def find_intersecting_entites(token, entities):
"""Finds the entities that intersect with a token.
:param token: a single token
:param entities: entities found by a single extractor
:return: list of entities
@@ -482,7 +491,6 @@ def find_intersecting_entites(token, entities):

def pick_best_entity_fit(token, candidates):
"""Determines the token label given intersecting entities.
:param token: a single token
:param candidates: entities found by a single extractor
:return: entity type
@@ -510,8 +518,8 @@ def determine_token_labels(token, entities, extractors):

if len(entities) == 0:
return "O"
if not do_extractors_support_overlap(extractors) and \
do_entities_overlap(entities):
if (not do_extractors_support_overlap(extractors) and
do_entities_overlap(entities)):
raise ValueError("The possible entities should not overlap")

candidates = find_intersecting_entites(token, entities)
@@ -526,11 +534,9 @@ def do_extractors_support_overlap(extractors):

def align_entity_predictions(targets, predictions, tokens, extractors):
"""Aligns entity predictions to the message tokens.
Determines for every token the true label based on the
prediction targets and the label assigned by each
single extractor.
:param targets: list of target entities
:param predictions: list of predicted entities
:param tokens: original message tokens
@@ -558,7 +564,6 @@ def align_entity_predictions(targets, predictions, tokens, extractors):
def align_all_entity_predictions(targets, predictions, tokens, extractors):
""" Aligns entity predictions to the message tokens for the whole dataset
using align_entity_predictions
:param targets: list of lists of target entities
:param predictions: list of lists of predicted entities
:param tokens: list of original message tokens
@@ -614,10 +619,10 @@ def get_intent_predictions(targets, interpreter,
for e, target in zip(test_data.training_examples, targets):
res = interpreter.parse(e.text, only_output_properties=False)
intent_results.append(IntentEvaluationResult(
target,
extract_intent(res),
extract_message(res),
extract_confidence(res)))
target,
extract_intent(res),
extract_message(res),
extract_confidence(res)))

return intent_results

@@ -639,7 +644,6 @@ def get_entity_predictions(interpreter, test_data):  # pragma: no cover

def get_entity_extractors(interpreter):
"""Finds the names of entity extractors used by the interpreter.
Processors are removed since they do not
detect the boundaries themselves."""

@@ -663,7 +667,6 @@ def combine_extractor_and_dimension_name(extractor, dim):

def get_duckling_dimensions(interpreter, duckling_extractor_name):
"""Gets the activated dimensions of a duckling extractor.
If there are no activated dimensions, it uses all known
dimensions as a fallback."""

@@ -708,7 +711,7 @@ def remove_duckling_entities(entity_predictions):


def run_evaluation(data_path, model,
report_filename=None,
report_folder=None,
successes_filename=None,
errors_filename='errors.json',
confmat_filename=None,
@@ -736,14 +739,17 @@ def run_evaluation(data_path, model,
"entity_evaluation": None
}

if report_folder:
utils.create_dir(report_folder)

if is_intent_classifier_present(interpreter):
intent_targets = get_intent_targets(test_data)
intent_results = get_intent_predictions(
intent_targets, interpreter, test_data)
intent_targets, interpreter, test_data)

logger.info("Intent evaluation results:")
result['intent_evaluation'] = evaluate_intents(intent_results,
report_filename,
report_folder,
successes_filename,
errors_filename,
confmat_filename,
@@ -756,7 +762,8 @@ def run_evaluation(data_path, model,
result['entity_evaluation'] = evaluate_entities(entity_targets,
entity_predictions,
tokens,
extractors)
extractors,
report_folder)

return result

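A usage sketch of the updated entry point with the new `report_folder` argument
(the data path and model directory below are placeholders):

from rasa_nlu.evaluate import run_evaluation

result = run_evaluation("data/examples.md",
                        "projects/default/model_20190130-120000",
                        report_folder="reports")
# top-level keys: 'intent_evaluation' and 'entity_evaluation'
print(sorted(result.keys()))
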
@@ -802,7 +809,6 @@ def combine_entity_result(results, interpreter, data):
def run_cv_evaluation(data, n_folds, nlu_config):
# type: (TrainingData, int, RasaNLUModelConfig) -> CVEvaluationResult
"""Stratified cross validation on data
:param data: Training Data
:param n_folds: integer, number of cv folds
:param nlu_config: nlu config file
@@ -943,7 +949,7 @@ def main():
data = training_data.load_data(cmdline_args.data)
data = drop_intents_below_freq(data, cutoff=5)
results, entity_results = run_cv_evaluation(
data, int(cmdline_args.folds), nlu_config)
data, int(cmdline_args.folds), nlu_config)
logger.info("CV evaluation (n={})".format(cmdline_args.folds))

if any(results):

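For cross-validation mode the same helpers can be driven from Python, much like
`main()` does above (the config and data file paths are placeholders):

from rasa_nlu import config, training_data
from rasa_nlu.evaluate import drop_intents_below_freq, run_cv_evaluation

nlu_config = config.load("sample_configs/config_spacy.yml")
data = training_data.load_data("data/examples.md")
data = drop_intents_below_freq(data, cutoff=5)
results, entity_results = run_cv_evaluation(data, 10, nlu_config)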