Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge pull request #982 from RasaHQ/jwt-auth
Browse files Browse the repository at this point in the history
Jwt auth
  • Loading branch information
tmbo committed Sep 11, 2018
2 parents ac18cf2 + 0db4f31 commit 15a4c3c
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 49 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ This project adheres to `Semantic Versioning`_ starting with version 0.2.0.

Added
-----
- JWT support: parameters to allow clients to authenticate requests to
the rasa_core.server using JWT's in addition to normal token based auth

Changed
-------
Expand Down
48 changes: 42 additions & 6 deletions docs/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,26 @@
HTTP API
========

.. warning::

To protect your conversational data, make sure to secure the server.
Either by restricting access to the server (e.g. using firewalls) or
by enabling one of the authentication methods: :ref:`server_security`.

.. note::

Before you can use the server, you need to define a domain, create training
data, and train a model. You can then use the trained model!
See :ref:`quickstart` for an introduction.

If you are looking for documentation on how to run custom actions -
head over to :ref:`customactions`.


The HTTP api exists to make it easy for python and non-python
projects to interact with Rasa Core. The API allows you to modify
the trackers (e.g. push or remote events).

.. note::

If you are looking for documentation on how to run custom actions -
head over to :ref:`customactions`.

.. contents::

Expand Down Expand Up @@ -110,15 +115,21 @@ at :ref:`events`. You need to send these json formats to the endpoint to
log the event.


.. _server_security:

Security Considerations
-----------------------

We recommend to not expose the Rasa Core server to the outside world but
rather connect to it from your backend over a private connection (e.g.
between docker containers).

Nevertheless, there is built in token authentication. If you specify a token
when starting the server, that token needs to be passed with every request:
Nevertheless, there are two authentication methods built in:

**Token Based Auth:**

Pass in the token using ``--auth_token thisismysecret`` when starting
the server:

.. code-block:: bash
Expand All @@ -136,6 +147,31 @@ as a parameter:
$ curl -XGET localhost:5005/conversations/default/tracker?token=thisismysecret
**JWT Based Auth:**

Enable JWT based authentication using ``--jwt_secret thisismysecret``.
Requests to the server need to contain a valid JWT token in
the ``Authorization`` header that is signed using this secret
and the ``HS256`` algorithm.

.. code-block:: bash
$ python -m rasa_core.run \
--enable_api \
--jwt_secret thisismysecret \
-d models/dialogue \
-u models/nlu/current \
-o out.log
Your requests should have set a proper JWT header:

.. code-block:: json
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ"
"zdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIi"
"wiaWF0IjoxNTE2MjM5MDIyfQ.qdrr2_a7Sd80gmCWjnDomO"
"Gl8eZFVfKXA6jhncgRn-I"
Endpoints
---------
Expand Down
34 changes: 29 additions & 5 deletions rasa_core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ def create_argument_parser():
action="store_true",
help="Start the web server api in addition to the input channel")

jwt_auth = parser.add_argument_group('JWT Authentication')
jwt_auth.add_argument(
'--jwt_secret',
type=str,
help="Public key for asymmetric JWT methods or shared secret"
"for symmetric methods. Please also make sure to use "
"--jwt_method to select the method of the signature, "
"otherwise this argument will be ignored.")
jwt_auth.add_argument(
'--jwt_method',
type=str,
default="HS256",
help="Method used for the signature of the JWT authentication "
"payload.")

utils.add_logging_option_arguments(parser)
return parser

Expand Down Expand Up @@ -165,13 +180,17 @@ def start_server(input_channels,
auth_token,
port,
initial_agent,
enable_api=True):
enable_api=True,
jwt_secret=None,
jwt_method=None):
"""Run the agent."""

if enable_api:
app = server.create_app(initial_agent,
cors_origins=cors,
auth_token=auth_token)
auth_token=auth_token,
jwt_secret=jwt_secret,
jwt_method=jwt_method)
else:
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": cors or ""}})
Expand All @@ -198,12 +217,15 @@ def serve_application(initial_agent,
credentials_file=None,
cors=None,
auth_token=None,
enable_api=True
enable_api=True,
jwt_secret=None,
jwt_method=None,
):
input_channels = create_http_input_channels(channel, credentials_file)

http_server = start_server(input_channels, cors, auth_token,
port, initial_agent, enable_api)
port, initial_agent, enable_api,
jwt_secret, jwt_method)

if channel == "cmdline":
start_cmdline_io(constants.DEFAULT_SERVER_FORMAT.format(port),
Expand Down Expand Up @@ -262,4 +284,6 @@ def load_agent(core_model, interpreter, endpoints,
cmdline_args.credentials,
cmdline_args.cors,
cmdline_args.auth_token,
cmdline_args.enable_api)
cmdline_args.enable_api,
cmdline_args.jwt_secret,
cmdline_args.jwt_method)
67 changes: 43 additions & 24 deletions rasa_core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
import zipfile
from flask import Flask, request, abort, Response, jsonify
from flask_cors import CORS, cross_origin
from flask_jwt_simple import JWTManager, view_decorators
from functools import wraps
from typing import List
from typing import Text, Optional
from typing import Union

from rasa_core import utils, constants
from rasa_core.channels import (
CollectingOutputChannel, UserMessage)
CollectingOutputChannel)
from rasa_core.channels import UserMessage
from rasa_core.events import Event
from rasa_core.interpreter import NaturalLanguageInterpreter
from rasa_core.policies import PolicyEnsemble
from rasa_core.trackers import DialogueStateTracker
from rasa_core.version import __version__
from rasa_core.channels import UserMessage


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,15 +59,22 @@ def request_parameters():
raise


def requires_auth(token=None):
# type: (Optional[Text]) -> function
def requires_auth(app, token=None):
# type: (Flask, Optional[Text]) -> function
"""Wraps a request handler with token authentication."""

def decorator(f):
@wraps(f)
def decorated(*args, **kwargs):
provided = request.args.get('token')
if token is None or provided == token:
# noinspection PyProtectedMember
if token is not None and provided == token:
return f(*args, **kwargs)
elif (app.config.get('JWT_ALGORITHM') is not None
and view_decorators._decode_jwt_from_headers()):
return f(*args, **kwargs)
elif token is None and app.config.get('JWT_ALGORITHM') is None:
# authentication is disabled
return f(*args, **kwargs)
abort(401)

Expand All @@ -79,13 +86,25 @@ def decorated(*args, **kwargs):
def create_app(agent,
cors_origins=None, # type: Optional[Union[Text, List[Text]]]
auth_token=None, # type: Optional[Text]
jwt_secret=None, # type: Optional[Text]
jwt_method="HS256", # type: Optional[Text]
):
"""Class representing a Rasa Core HTTP server."""

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})
cors_origins = cors_origins or []

# Setup the Flask-JWT-Simple extension
if jwt_secret and jwt_method:
# since we only want to check signatures, we don't actually care
# about the JWT method and set the passed secret as either symmetric
# or asymmetric key. jwt lib will choose the right one based on method
app.config['JWT_SECRET_KEY'] = jwt_secret
app.config['JWT_PUBLIC_KEY'] = jwt_secret
app.config['JWT_ALGORITHM'] = jwt_method
JWTManager(app)

if not agent.is_ready():
logger.info("The loaded agent is not ready to be used yet "
"(e.g. only the NLU interpreter is configured, "
Expand Down Expand Up @@ -115,7 +134,7 @@ def version():
@app.route("/conversations/<sender_id>/execute",
methods=['POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def execute_action(sender_id):
request_params = request.get_json(force=True)
Expand Down Expand Up @@ -147,7 +166,7 @@ def execute_action(sender_id):
@app.route("/conversations/<sender_id>/tracker/events",
methods=['POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def append_event(sender_id):
"""Append a list of events to the state of a conversation"""
Expand All @@ -167,7 +186,7 @@ def append_event(sender_id):
@app.route("/conversations/<sender_id>/tracker/events",
methods=['PUT'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def replace_events(sender_id):
"""Use a list of events to set a conversations tracker to a state."""
Expand All @@ -183,7 +202,7 @@ def replace_events(sender_id):
@app.route("/conversations",
methods=['GET', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
def list_trackers():
if agent.tracker_store:
return jsonify(list(agent.tracker_store.keys()))
Expand All @@ -193,7 +212,7 @@ def list_trackers():
@app.route("/conversations/<sender_id>/tracker",
methods=['GET', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
def retrieve_tracker(sender_id):
"""Get a dump of a conversations tracker including its events."""

Expand Down Expand Up @@ -227,7 +246,7 @@ def retrieve_tracker(sender_id):
@app.route("/conversations/<sender_id>/respond",
methods=['GET', 'POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def respond(sender_id):
request_params = request_parameters()
Expand Down Expand Up @@ -261,7 +280,7 @@ def respond(sender_id):
@app.route("/conversations/<sender_id>/predict",
methods=['POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def predict(sender_id):
try:
Expand All @@ -278,7 +297,7 @@ def predict(sender_id):

@app.route("/conversations/<sender_id>/messages", methods=['POST'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def log_message(sender_id):
request_params = request.get_json(force=True)
Expand Down Expand Up @@ -308,7 +327,7 @@ def log_message(sender_id):
content_type="application/json")

@app.route("/model", methods=['POST', 'OPTIONS'])
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@cross_origin(origins=cors_origins)
def load_model():
"""Loads a zipped model, replacing the existing one."""
Expand Down Expand Up @@ -342,7 +361,7 @@ def load_model():
@app.route("/domain",
methods=['GET', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def get_domain():
"""Get current domain in yaml or json format."""
Expand All @@ -367,7 +386,7 @@ def get_domain():
@app.route("/finetune",
methods=['POST', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@ensure_loaded_agent(agent)
def continue_training():
request.headers.get("Accept")
Expand All @@ -394,15 +413,15 @@ def continue_training():

@app.route("/status", methods=['GET', 'OPTIONS'])
@cross_origin(origins=cors_origins)
@requires_auth(auth_token)
@requires_auth(app, auth_token)
def status():
return jsonify({
"model_fingerprint": agent.fingerprint,
"is_ready": agent.is_ready()
})

@app.route("/predict", methods=['POST'])
@requires_auth(auth_token)
@requires_auth(app, auth_token)
@cross_origin(origins=cors_origins)
@ensure_loaded_agent(agent)
def tracker_predict():
Expand All @@ -421,10 +440,8 @@ def tracker_predict():
probabilities, _ = policy_ensemble.probabilities_using_best_policy(
tracker, agent.domain)

probability_dict = {}
for idx, probability in enumerate(probabilities):
action_name = agent.domain.action_names[idx]
probability_dict[action_name] = probability
probability_dict = {agent.domain.action_names[idx]: probability
for idx, probability in enumerate(probabilities)}

return jsonify(probability_dict)

Expand Down Expand Up @@ -464,4 +481,6 @@ def tracker_predict():
cmdline_args.credentials,
cmdline_args.cors,
cmdline_args.auth_token,
cmdline_args.enable_api)
cmdline_args.enable_api,
cmdline_args.jwt_secret,
cmdline_args.jwt_method)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ rasa_core_sdk~=0.11.0
pymongo==3.5.1
python-dateutil==2.7.3
rocketchat_API==0.6.22
flask-jwt-simple==0.0.3
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"python-dateutil~=2.7",
"rasa_nlu~=0.13.0",
"rasa_core_sdk~=0.11.0",
"flask-jwt-simple~=0.0.3"
]

extras_requires = {
Expand Down
Loading

0 comments on commit 15a4c3c

Please sign in to comment.