From b733e0fe2b0db2838ee7903613a2968c85fc16a9 Mon Sep 17 00:00:00 2001
From: Tom Metcalfe
Date: Wed, 17 Jul 2019 17:11:52 +0200
Subject: [PATCH] Edit serialisation, change WronglyClassifiedUserUtterance
 class initialisation

---
Review notes (free text, ignored by `git am`; not part of the commit message):
 * EvaluationStore.serialise() replaces serialise_targets() and
   serialise_predictions(). Entity targets/predictions are now always
   included and every entity item is passed through json.dumps(); the old
   methods skipped entities by default (include_entities=False).
 * serialise() pads only `predictions`. When predictions is longer than
   targets, `padding` is negative, `["None"] * padding` is an empty list and
   the two returned lists keep different lengths.
 * `predicted_intent == list()` compares the value of
   parse_data["intent"]["name"] with an empty list. Presumably that value is
   a string or None, in which case the branch never runs -- TODO confirm.
 * WronglyClassifiedUserUtterance.__init__ reads eval_store.intent_targets[0]
   without a guard; this assumes at least one intent target was stored --
   TODO confirm against add_to_store().
 * The old call site passed entity_predictions as `correct_entities` and
   entity_targets as `predicted_entities`; the new constructor takes both
   from the store by attribute name. `Event` is imported but is not
   referenced in any hunk of this patch.

 rasa/core/test.py  | 105 +++++++++++++++++----------------------------
 rasa/core/utils.py |   5 ---
 2 files changed, 40 insertions(+), 70 deletions(-)

diff --git a/rasa/core/test.py b/rasa/core/test.py
index e1b523bf55a2..816b13d1973f 100644
--- a/rasa/core/test.py
+++ b/rasa/core/test.py
@@ -7,7 +7,7 @@
 from typing import Any, Dict, List, Optional, Text, Tuple
 
 from rasa.constants import RESULTS_FILE
-from rasa.core.events import ActionExecuted, UserUttered
+from rasa.core.events import ActionExecuted, UserUttered, Event
 
 if typing.TYPE_CHECKING:
     from rasa.core.agent import Agent
@@ -76,32 +76,25 @@ def has_prediction_target_mismatch(self):
             or self.action_predictions != self.action_targets
         )
 
-    def serialise_targets(
-        self, include_actions=True, include_intents=True, include_entities=False
-    ):
-        targets = []
-        if include_actions:
-            targets += self.action_targets
-        if include_intents:
-            targets += self.intent_targets
-        if include_entities:
-            targets += self.entity_targets
-
-        return [json.dumps(t) if isinstance(t, dict) else t for t in targets]
-
-    def serialise_predictions(
-        self, include_actions=True, include_intents=True, include_entities=False
-    ):
-        predictions = []
+    def serialise(self):
+        """Turn targets and predictions to lists of equal size for sklearn"""
+
+        targets = (
+            self.action_targets
+            + self.intent_targets
+            + [json.dumps(t) for t in self.entity_targets]
+        )
+        predictions = (
+            self.action_predictions
+            + self.intent_predictions
+            + [json.dumps(p) for p in self.entity_predictions]
+        )
 
-        if include_actions:
-            predictions += self.action_predictions
-        if include_intents:
-            predictions += self.intent_predictions
-        if include_entities:
-            predictions += self.entity_predictions
+        # sklearn does not cope with lists of unequal size, nor None values
+        padding = len(targets) - len(predictions)
+        predictions += ["None"] * padding
 
-        return [json.dumps(t) if isinstance(t, dict) else t for t in predictions]
+        return targets, predictions
 
 
 class WronglyPredictedAction(ActionExecuted):
@@ -144,24 +137,23 @@ class WronglyClassifiedUserUtterance(UserUttered):
 
     type_name = "wrong_utterance"
 
-    def __init__(
-        self,
-        text,
-        correct_intent,
-        correct_entities,
-        parse_data=None,
-        timestamp=None,
-        input_channel=None,
-        predicted_intent=None,
-        predicted_entities=None,
-    ):
-        self.predicted_intent = predicted_intent
-        self.predicted_entities = predicted_entities
+    def __init__(self, event: UserUttered, eval_store: EvaluationStore):
+
+        if eval_store.intent_predictions == list():
+            self.predicted_intent = None
+        else:
+            self.predicted_intent = eval_store.intent_predictions[0]
+        self.predicted_entities = eval_store.entity_predictions
 
-        intent = {"name": correct_intent}
+        intent = {"name": eval_store.intent_targets[0]}
 
         super(WronglyClassifiedUserUtterance, self).__init__(
-            text, intent, correct_entities, parse_data, timestamp, input_channel
+            event.text,
+            intent,
+            eval_store.entity_targets,
+            event.parse_data,
+            event.timestamp,
+            event.input_channel,
         )
 
     def as_story_string(self, e2e=True):
@@ -207,14 +199,14 @@ def _clean_entity_results(entity_results):
 def _collect_user_uttered_predictions(
     event, partial_tracker, fail_on_prediction_errors
 ):
-    from rasa.core.utils import pad_list_to_size
-
     user_uttered_eval_store = EvaluationStore()
 
     intent_gold = event.parse_data.get("true_intent")
     predicted_intent = event.parse_data.get("intent").get("name")
-    if predicted_intent is None:
-        predicted_intent = "None"
+
+    if predicted_intent == list():
+        predicted_intent = [None]
+
     user_uttered_eval_store.add_to_store(
         intent_predictions=predicted_intent, intent_targets=intent_gold
     )
@@ -223,13 +215,6 @@ def _collect_user_uttered_predictions(
     predicted_entities = event.parse_data.get("entities")
 
     if entity_gold or predicted_entities:
-        if len(entity_gold) > len(predicted_entities):
-            predicted_entities = pad_list_to_size(
-                predicted_entities, len(entity_gold), "None"
-            )
-        elif len(predicted_entities) > len(entity_gold):
-            entity_gold = pad_list_to_size(entity_gold, len(predicted_entities), "None")
-
         user_uttered_eval_store.add_to_store(
             entity_targets=_clean_entity_results(entity_gold),
             entity_predictions=_clean_entity_results(predicted_entities),
@@ -237,16 +222,7 @@ def _collect_user_uttered_predictions(
 
     if user_uttered_eval_store.has_prediction_target_mismatch():
         partial_tracker.update(
-            WronglyClassifiedUserUtterance(
-                event.text,
-                intent_gold,
-                user_uttered_eval_store.entity_predictions,
-                event.parse_data,
-                event.timestamp,
-                event.input_channel,
-                predicted_intent,
-                user_uttered_eval_store.entity_targets,
-            )
+            WronglyClassifiedUserUtterance(event, user_uttered_eval_store)
         )
         if fail_on_prediction_errors:
             raise ValueError(
@@ -493,10 +469,9 @@ async def test(
         from sklearn.exceptions import UndefinedMetricWarning
 
         warnings.simplefilter("ignore", UndefinedMetricWarning)
-        report, precision, f1, accuracy = get_evaluation_metrics(
-            evaluation_store.serialise_targets(),
-            evaluation_store.serialise_predictions(),
-        )
+
+        targets, predictions = evaluation_store.serialise()
+        report, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)
 
     if out_directory:
         plot_story_evaluation(
diff --git a/rasa/core/utils.py b/rasa/core/utils.py
index ac8691b83610..23dcf18080f2 100644
--- a/rasa/core/utils.py
+++ b/rasa/core/utils.py
@@ -376,11 +376,6 @@ def remove_none_values(obj: Dict[Text, Any]) -> Dict[Text, Any]:
     return {k: v for k, v in obj.items() if v is not None}
 
 
-def pad_list_to_size(_list, size, padding_value=None):
-    """Pads _list with padding_value up to size"""
-    return _list + [padding_value] * (size - len(_list))
-
-
 class AvailableEndpoints(object):
     """Collection of configured endpoints."""
 