Skip to content

Commit

Permalink
Edit serialisation, change WronglyClassifiedUserUtterance class initi…
Browse files Browse the repository at this point in the history
…alisation
  • Loading branch information
MetcalfeTom committed Jul 17, 2019
1 parent 3abf8a4 commit b733e0f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 70 deletions.
105 changes: 40 additions & 65 deletions rasa/core/test.py
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional, Text, Tuple

from rasa.constants import RESULTS_FILE
from rasa.core.events import ActionExecuted, UserUttered
from rasa.core.events import ActionExecuted, UserUttered, Event

if typing.TYPE_CHECKING:
from rasa.core.agent import Agent
Expand Down Expand Up @@ -76,32 +76,25 @@ def has_prediction_target_mismatch(self):
or self.action_predictions != self.action_targets
)

def serialise_targets(
self, include_actions=True, include_intents=True, include_entities=False
):
targets = []
if include_actions:
targets += self.action_targets
if include_intents:
targets += self.intent_targets
if include_entities:
targets += self.entity_targets

return [json.dumps(t) if isinstance(t, dict) else t for t in targets]

def serialise_predictions(
self, include_actions=True, include_intents=True, include_entities=False
):
predictions = []
def serialise(self):
"""Turn targets and predictions to lists of equal size for sklearn"""

targets = (
self.action_targets
+ self.intent_targets
+ [json.dumps(t) for t in self.entity_targets]
)
predictions = (
self.action_predictions
+ self.intent_predictions
+ [json.dumps(p) for p in self.entity_predictions]
)

if include_actions:
predictions += self.action_predictions
if include_intents:
predictions += self.intent_predictions
if include_entities:
predictions += self.entity_predictions
# sklearn does not cope with lists of unequal size, nor None values
padding = len(targets) - len(predictions)
predictions += ["None"] * padding

return [json.dumps(t) if isinstance(t, dict) else t for t in predictions]
return targets, predictions


class WronglyPredictedAction(ActionExecuted):
Expand Down Expand Up @@ -144,24 +137,23 @@ class WronglyClassifiedUserUtterance(UserUttered):

type_name = "wrong_utterance"

def __init__(
self,
text,
correct_intent,
correct_entities,
parse_data=None,
timestamp=None,
input_channel=None,
predicted_intent=None,
predicted_entities=None,
):
self.predicted_intent = predicted_intent
self.predicted_entities = predicted_entities
def __init__(self, event: UserUttered, eval_store: EvaluationStore):

if eval_store.intent_predictions == list():
self.predicted_intent = None
else:
self.predicted_intent = eval_store.intent_predictions[0]
self.predicted_entities = eval_store.entity_predictions

intent = {"name": correct_intent}
intent = {"name": eval_store.intent_targets[0]}

super(WronglyClassifiedUserUtterance, self).__init__(
text, intent, correct_entities, parse_data, timestamp, input_channel
event.text,
intent,
eval_store.entity_targets,
event.parse_data,
event.timestamp,
event.input_channel,
)

def as_story_string(self, e2e=True):
Expand Down Expand Up @@ -207,14 +199,14 @@ def _clean_entity_results(entity_results):
def _collect_user_uttered_predictions(
event, partial_tracker, fail_on_prediction_errors
):
from rasa.core.utils import pad_list_to_size

user_uttered_eval_store = EvaluationStore()

intent_gold = event.parse_data.get("true_intent")
predicted_intent = event.parse_data.get("intent").get("name")
if predicted_intent is None:
predicted_intent = "None"

if predicted_intent == list():
predicted_intent = [None]

user_uttered_eval_store.add_to_store(
intent_predictions=predicted_intent, intent_targets=intent_gold
)
Expand All @@ -223,30 +215,14 @@ def _collect_user_uttered_predictions(
predicted_entities = event.parse_data.get("entities")

if entity_gold or predicted_entities:
if len(entity_gold) > len(predicted_entities):
predicted_entities = pad_list_to_size(
predicted_entities, len(entity_gold), "None"
)
elif len(predicted_entities) > len(entity_gold):
entity_gold = pad_list_to_size(entity_gold, len(predicted_entities), "None")

user_uttered_eval_store.add_to_store(
entity_targets=_clean_entity_results(entity_gold),
entity_predictions=_clean_entity_results(predicted_entities),
)

if user_uttered_eval_store.has_prediction_target_mismatch():
partial_tracker.update(
WronglyClassifiedUserUtterance(
event.text,
intent_gold,
user_uttered_eval_store.entity_predictions,
event.parse_data,
event.timestamp,
event.input_channel,
predicted_intent,
user_uttered_eval_store.entity_targets,
)
WronglyClassifiedUserUtterance(event, user_uttered_eval_store)
)
if fail_on_prediction_errors:
raise ValueError(
Expand Down Expand Up @@ -493,10 +469,9 @@ async def test(
from sklearn.exceptions import UndefinedMetricWarning

warnings.simplefilter("ignore", UndefinedMetricWarning)
report, precision, f1, accuracy = get_evaluation_metrics(
evaluation_store.serialise_targets(),
evaluation_store.serialise_predictions(),
)

targets, predictions = evaluation_store.serialise()
report, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)

if out_directory:
plot_story_evaluation(
Expand Down
5 changes: 0 additions & 5 deletions rasa/core/utils.py
Expand Up @@ -376,11 +376,6 @@ def remove_none_values(obj: Dict[Text, Any]) -> Dict[Text, Any]:
return {k: v for k, v in obj.items() if v is not None}


def pad_list_to_size(_list, size, padding_value=None):
"""Pads _list with padding_value up to size"""
return _list + [padding_value] * (size - len(_list))


class AvailableEndpoints(object):
"""Collection of configured endpoints."""

Expand Down

0 comments on commit b733e0f

Please sign in to comment.