Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge pull request #1733 from therealdeadly/sort-score-array-http-api
Browse files Browse the repository at this point in the history
WIP: sort the scores array in server.py
  • Loading branch information
tmbo committed Feb 25, 2019
2 parents 186c1a4 + e40a41a commit fa847d3
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Changed
-------
- starter packs are now tested in parallel with the unittests,
and only on master and branches ending in ``.x`` (i.e. new version releases)
- ``scores`` array returned by the ``/conversations/{sender_id}/predict``
endpoint is now sorted according to the actions' scores.

Removed
-------
Expand Down
9 changes: 5 additions & 4 deletions docs/_static/spec/server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,11 @@ paths:
- Tracker
summary: Predict the next action
description: >-
Runs the conversations tracker through the models
policies to predict the next action. The action is
not executed, just returned. The state of the tracker
is not modified.
Runs the conversations tracker through the model's
policies to predict the scores of all actions present
in the model's domain. Actions are returned in the
'scores' array, sorted on their 'score' values.
The state of the tracker is not modified.
operationId: predictAction
parameters:
- $ref: '#/components/parameters/senderId'
Expand Down
3 changes: 3 additions & 0 deletions rasa_core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,9 @@ def predict(sender_id):
try:
# Fetches the appropriate bot response in a json format
responses = agent.predict_next(sender_id)
responses['scores'] = sorted(responses['scores'],
key = lambda k: (-k['score'],
k['action']))
return jsonify(responses)

except Exception as e:
Expand Down
5 changes: 1 addition & 4 deletions rasa_core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,13 +602,10 @@ def _request_action_from_user(

_print_history(sender_id, endpoint)

sorted_actions = sorted(predictions,
key=lambda k: (-k['score'], k['action']))

choices = [{"name": "{:03.2f} {:40}".format(a.get("score"),
a.get("action")),
"value": a.get("action")}
for a in sorted_actions]
for a in predictions]

choices = ([{"name": "<create new action>", "value": OTHER_ACTION}] +
choices)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,20 @@ def test_predict(http_app, app):
assert response.status_code == 200


def test_sorted_predict(http_app, app):
client = RasaCoreClient(EndpointConfig(http_app))
cid = str(uuid.uuid1())
for event in test_events[:3]:
client.append_event_to_tracker(cid, event)
response = app.post("http://dummy/conversations/{}/predict".format(cid))
content = response.get_json()
scores = content["scores"]
sorted_scores = sorted(scores,
key = lambda k: (-k['score'],
k['action']))
assert scores == sorted_scores


def test_evaluate(app):
with io.open(DEFAULT_STORIES_FILE, 'r') as f:
stories = f.read()
Expand Down

0 comments on commit fa847d3

Please sign in to comment.