Merge fd234bd into 660b2db
MetcalfeTom committed Aug 7, 2019
2 parents 660b2db + fd234bd commit 6ea0d2c
Showing 9 changed files with 187 additions and 127 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.rst
@@ -12,10 +12,10 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0.
Added
-----


Changed
-------

- messages with multiple entities are now handled properly in e2e evaluation
- ``data/test_evaluations/end_to_end_story.md`` was rewritten in the restaurantbot domain

Removed
-------
@@ -56,6 +56,9 @@ Added

Changed
-------
- ``Agent.update_model()`` and ``Agent.handle_message()`` now work without needing to set a domain
or a policy ensemble
- Update pytype to ``2019.7.11``
- new event broker class: ``SQLProducer``. This event broker is now used when running locally with
Rasa X
- API requests are no longer logged to ``rasa_core.log`` by default in order to avoid
38 changes: 24 additions & 14 deletions data/test_evaluations/end_to_end_story.md
@@ -1,17 +1,27 @@
## simple_story_with_only_start
> check_greet <!-- checkpoints at the start define entry points -->
* default:/default
- utter_default

## simple_story_with_only_end
* greet:/greet
- utter_greet
> check_greet <!-- checkpoint defining the end of this turn -->
* greet: Hello
- utter_ask_howcanhelp

## simple_story_with_multiple_turns
* greet:/greet
- utter_greet
* default:/default
- utter_default
* goodbye:/goodbye
- utter_goodbye
* greet: good morning
- utter_ask_howcanhelp
* inform: im looking for a [moderately](price:moderate) priced restaurant in the [east](location) part of town
- utter_on_it
- utter_ask_cuisine
* inform: [french](cuisine) food
- utter_ask_numpeople

## story_with_multiple_entities_correction_and_search
* greet: hello
- utter_ask_howcanhelp
* inform: im looking for a [cheap](price:lo) restaurant which has [french](cuisine) food and is located in [bombay](location)
- utter_on_it
- utter_ask_numpeople
* inform: for [six](people:6) please
- utter_ask_moreupdates
* inform: actually i need a [moderately](price:moderate) priced restaurant
- utter_ask_moreupdates
* deny: no
- utter_ack_dosearch
- action_search_restaurants
- action_suggest
2 changes: 1 addition & 1 deletion examples/restaurantbot/config.yml
@@ -11,7 +11,7 @@ pipeline:
policies:
- name: "examples.restaurantbot.policy.RestaurantPolicy"
batch_size: 100
epochs: 400
epochs: 100
validation_split: 0.2
- name: MemoizationPolicy
- name: MappingPolicy
4 changes: 2 additions & 2 deletions examples/restaurantbot/run.py
@@ -37,11 +37,11 @@ async def train_core(
policies=[
MemoizationPolicy(max_history=3),
MappingPolicy(),
RestaurantPolicy(batch_size=100, epochs=400, validation_split=0.2),
RestaurantPolicy(batch_size=100, epochs=100, validation_split=0.2),
],
)

training_data = await agent.load_data(training_data_file)
training_data = await agent.load_data(training_data_file, augmentation_factor=10)
agent.train(training_data)

# Attention: agent.persist stores the model and all meta data into a folder.
144 changes: 67 additions & 77 deletions rasa/core/test.py
@@ -1,17 +1,18 @@
import json
import logging
import os
import typing
import warnings
import typing
from collections import defaultdict, namedtuple
from typing import Any, Dict, List, Optional, Text, Tuple

from rasa.constants import RESULTS_FILE
from rasa.core.utils import pad_lists_to_size
from rasa.core.events import ActionExecuted, UserUttered
from rasa.nlu.training_data.formats.markdown import MarkdownWriter
from rasa.core.trackers import DialogueStateTracker

if typing.TYPE_CHECKING:
from rasa.core.agent import Agent
from rasa.core.trackers import DialogueStateTracker

logger = logging.getLogger(__name__)

@@ -76,32 +77,28 @@ def has_prediction_target_mismatch(self):
or self.action_predictions != self.action_targets
)

def serialise_targets(
self, include_actions=True, include_intents=True, include_entities=False
):
targets = []
if include_actions:
targets += self.action_targets
if include_intents:
targets += self.intent_targets
if include_entities:
targets += self.entity_targets

return [json.dumps(t) if isinstance(t, dict) else t for t in targets]

def serialise_predictions(
self, include_actions=True, include_intents=True, include_entities=False
):
predictions = []
def serialise(self) -> Tuple[List[Text], List[Text]]:
"""Turn targets and predictions to lists of equal size for sklearn."""

if include_actions:
predictions += self.action_predictions
if include_intents:
predictions += self.intent_predictions
if include_entities:
predictions += self.entity_predictions
targets = (
self.action_targets
+ self.intent_targets
+ [
MarkdownWriter._generate_entity_md(gold.get("text"), gold)
for gold in self.entity_targets
]
)
predictions = (
self.action_predictions
+ self.intent_predictions
+ [
MarkdownWriter._generate_entity_md(predicted.get("text"), predicted)
for predicted in self.entity_predictions
]
)

return [json.dumps(t) if isinstance(t, dict) else t for t in predictions]
# sklearn does not cope with lists of unequal size, nor None values
return pad_lists_to_size(targets, predictions, padding_value="None")


class WronglyPredictedAction(ActionExecuted):
@@ -144,24 +141,23 @@ class WronglyClassifiedUserUtterance(UserUttered):

type_name = "wrong_utterance"

def __init__(
self,
text,
correct_intent,
correct_entities,
parse_data=None,
timestamp=None,
input_channel=None,
predicted_intent=None,
predicted_entities=None,
):
self.predicted_intent = predicted_intent
self.predicted_entities = predicted_entities
def __init__(self, event: UserUttered, eval_store: EvaluationStore):

intent = {"name": correct_intent}
if not eval_store.intent_predictions:
self.predicted_intent = None
else:
self.predicted_intent = eval_store.intent_predictions[0]
self.predicted_entities = eval_store.entity_predictions

intent = {"name": eval_store.intent_targets[0]}

super(WronglyClassifiedUserUtterance, self).__init__(
text, intent, correct_entities, parse_data, timestamp, input_channel
event.text,
intent,
eval_store.entity_targets,
event.parse_data,
event.timestamp,
event.input_channel,
)

def as_story_string(self, e2e=True):
@@ -197,24 +193,35 @@ async def _generate_trackers(resource_name, agent, max_stories=None, use_e2e=Fal
return g.generate()


def _clean_entity_results(entity_results):
return [
{k: r[k] for k in ("start", "end", "entity", "value") if k in r}
for r in entity_results
]
def _clean_entity_results(
text: Text, entity_results: List[Dict[Text, Any]]
) -> List[Dict[Text, Any]]:
"""Extract only the token variables from an entity dict."""
cleaned_entities = []

for r in tuple(entity_results):
cleaned_entity = {"text": text}
for k in ("start", "end", "entity", "value"):
if k in set(r):
cleaned_entity[k] = r[k]
cleaned_entities.append(cleaned_entity)

return cleaned_entities

def _collect_user_uttered_predictions(
event, partial_tracker, fail_on_prediction_errors
):
from rasa.core.utils import pad_list_to_size

def _collect_user_uttered_predictions(
event: UserUttered,
partial_tracker: DialogueStateTracker,
fail_on_prediction_errors: bool,
) -> EvaluationStore:
user_uttered_eval_store = EvaluationStore()

intent_gold = event.parse_data.get("true_intent")
predicted_intent = event.parse_data.get("intent").get("name")
if predicted_intent is None:
predicted_intent = "None"
predicted_intent = event.parse_data.get("intent", {}).get("name")

if not predicted_intent:
predicted_intent = [None]

user_uttered_eval_store.add_to_store(
intent_predictions=predicted_intent, intent_targets=intent_gold
)
@@ -223,30 +230,14 @@ def _collect_user_uttered_predictions(
predicted_entities = event.parse_data.get("entities")

if entity_gold or predicted_entities:
if len(entity_gold) > len(predicted_entities):
predicted_entities = pad_list_to_size(
predicted_entities, len(entity_gold), "None"
)
elif len(predicted_entities) > len(entity_gold):
entity_gold = pad_list_to_size(entity_gold, len(predicted_entities), "None")

user_uttered_eval_store.add_to_store(
entity_targets=_clean_entity_results(entity_gold),
entity_predictions=_clean_entity_results(predicted_entities),
entity_targets=_clean_entity_results(event.text, entity_gold),
entity_predictions=_clean_entity_results(event.text, predicted_entities),
)

if user_uttered_eval_store.has_prediction_target_mismatch():
partial_tracker.update(
WronglyClassifiedUserUtterance(
event.text,
intent_gold,
user_uttered_eval_store.entity_predictions,
event.parse_data,
event.timestamp,
event.input_channel,
predicted_intent,
user_uttered_eval_store.entity_targets,
)
WronglyClassifiedUserUtterance(event, user_uttered_eval_store)
)
if fail_on_prediction_errors:
raise ValueError(
@@ -493,10 +484,9 @@ async def test(
from sklearn.exceptions import UndefinedMetricWarning

warnings.simplefilter("ignore", UndefinedMetricWarning)
report, precision, f1, accuracy = get_evaluation_metrics(
evaluation_store.serialise_targets(),
evaluation_store.serialise_predictions(),
)

targets, predictions = evaluation_store.serialise()
report, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)

if out_directory:
plot_story_evaluation(
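For reference, a minimal sketch of how the reworked ``_clean_entity_results`` behaves, based only on the code shown above (the message text, offsets, and confidence value below are made-up example values):

    from rasa.core.test import _clean_entity_results

    text = "im looking for a cheap restaurant"
    entity_results = [
        {"start": 17, "end": 22, "entity": "price", "value": "lo", "confidence": 0.93}
    ]

    cleaned = _clean_entity_results(text, entity_results)
    # each cleaned entity carries the full message text plus only the token
    # variables, so it can later be serialised as a markdown entity annotation:
    # [{"text": "im looking for a cheap restaurant",
    #   "start": 17, "end": 22, "entity": "price", "value": "lo"}]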
16 changes: 13 additions & 3 deletions rasa/core/utils.py
@@ -379,9 +379,19 @@ def remove_none_values(obj: Dict[Text, Any]) -> Dict[Text, Any]:
return {k: v for k, v in obj.items() if v is not None}


def pad_list_to_size(_list, size, padding_value=None):
"""Pads _list with padding_value up to size"""
return _list + [padding_value] * (size - len(_list))
def pad_lists_to_size(
list_x: List, list_y: List, padding_value: Optional[Any] = None
) -> Tuple[List, List]:
"""Compares list sizes and pads them to equal length."""

difference = len(list_x) - len(list_y)

if difference > 0:
return list_x, list_y + [padding_value] * difference
elif difference < 0:
return list_x + [padding_value] * (-difference), list_y
else:
return list_x, list_y


class AvailableEndpoints(object):
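A quick illustration of the new ``pad_lists_to_size`` helper, mirroring how ``EvaluationStore.serialise()`` uses it (the example lists are made up):

    from rasa.core.utils import pad_lists_to_size

    targets = ["utter_greet", "greet", "inform"]
    predictions = ["utter_greet", "greet"]

    padded_targets, padded_predictions = pad_lists_to_size(
        targets, predictions, padding_value="None"
    )
    # padded_targets     -> ["utter_greet", "greet", "inform"]
    # padded_predictions -> ["utter_greet", "greet", "None"]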
18 changes: 16 additions & 2 deletions tests/core/conftest.py
@@ -9,8 +9,7 @@
import rasa.utils.io
from rasa.core import train
from rasa.core.agent import Agent
from rasa.core.channels import channel
from rasa.core.channels.channel import CollectingOutputChannel, RestInput
from rasa.core.channels.channel import CollectingOutputChannel
from rasa.core.domain import Domain
from rasa.core.interpreter import RegexInterpreter
from rasa.core.nlg import TemplatedNaturalLanguageGenerator
@@ -44,6 +43,8 @@

MOODBOT_MODEL_PATH = "examples/moodbot/models/"

RESTAURANTBOT_PATH = "examples/restaurantbot/"

DEFAULT_ENDPOINTS_FILE = "data/test_endpoints/example_endpoints.yml"

TEST_DIALOGUES = [
@@ -237,3 +238,16 @@ def train_model(project: Text, filename: Text = "test.tar.gz"):
@pytest.fixture(scope="session")
def trained_model(project) -> Text:
return train_model(project)


@pytest.fixture
async def restaurantbot(tmpdir_factory) -> Text:
model_path = tmpdir_factory.mktemp("model").strpath
restaurant_domain = os.path.join(RESTAURANTBOT_PATH, "domain.yml")
restaurant_config = os.path.join(RESTAURANTBOT_PATH, "config.yml")
restaurant_data = os.path.join(RESTAURANTBOT_PATH, "data/")

agent = await train_async(
restaurant_domain, restaurant_config, restaurant_data, model_path
)
return agent
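A hypothetical consumer of the new fixture, sketched only to show how it would be requested in a test (the test name and assertion are assumptions, not part of this commit; the async test runner setup is assumed to match the existing suite):

    async def test_restaurantbot_fixture(restaurantbot):
        # the fixture trains a model from the restaurantbot example and returns the agent
        assert restaurantbot is not None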
