Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge branch 'master' into tracker_store_documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge authored Oct 31, 2018
2 parents cf65ef2 + 770aa86 commit 6a0c0fb
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Added
the end-to-end story describing a conversation
- docker-compose file to start a rasa core server together with nlu, an action server, and duckling
- http server (``rasa_core.run --enable-api``) evaluation endpoint
- ability to add tracker_store using endpoints.yml

Changed
-------
Expand Down
17 changes: 16 additions & 1 deletion data/example_endpoints.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,19 @@ nlg:
# basic_auth:
# username: user
# password: pass

# example of redis external tracker store config
tracker_store:
store_type: redis
url: localhost
port: 6379
db: 0
password: password
record_exp: 30000
# example of mongoDB external tracker store config
#tracker_store:
#store_type: mongod
#url: mongodb://localhost:27017
#db: rasa
#user: username
#password: password

2 changes: 1 addition & 1 deletion rasa_core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from rasa_core.policies.ensemble import SimplePolicyEnsemble, PolicyEnsemble
from rasa_core.policies.memoization import MemoizationPolicy
from rasa_core.processor import MessageProcessor
from rasa_core.tracker_store import InMemoryTrackerStore, TrackerStore
from rasa_core.tracker_store import InMemoryTrackerStore
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from rasa_core.utils import EndpointConfig
from rasa_nlu.utils import is_url
Expand Down
16 changes: 9 additions & 7 deletions rasa_core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from builtins import str
from collections import namedtuple

import os
import argparse
import logging
from flask import Flask
Expand All @@ -23,8 +23,8 @@
BUILTIN_CHANNELS)
from rasa_core.interpreter import (
NaturalLanguageInterpreter)
from rasa_core.tracker_store import TrackerStore
from rasa_core.utils import read_yaml_file, AvailableEndpoints

logger = logging.getLogger() # get the root logger


Expand Down Expand Up @@ -127,9 +127,9 @@ def _create_single_channel(channel, credentials):
raise Exception(
"Failed to find input channel class for '{}'. Unknown "
"input channel. Check your credentials configuration to "
"make sure the mentioned channel is not misspelled. If you "
"are creating your own channel, make sure it is a proper "
"name of a class in a module.".format(channel))
"make sure the mentioned channel is not misspelled. "
"If you are creating your own channel, make sure it "
"is a proper name of a class in a module.".format(channel))


def start_cmdline_io(server_url, on_finish, **kwargs):
Expand Down Expand Up @@ -217,7 +217,7 @@ def load_agent(core_model, interpreter, endpoints,
generator=endpoints.nlg,
action_endpoint=endpoints.action,
model_server=endpoints.model,
tracker_store=tracker_store,
tracker_store=endpoints.tracker_store,
wait_time_between_pulls=wait_time_between_pulls
)
else:
Expand Down Expand Up @@ -245,10 +245,12 @@ def load_agent(core_model, interpreter, endpoints,
_endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
_endpoints.nlu)
_tracker_store = TrackerStore.find_tracker_store(
None, _endpoints.tracker_store)
_agent = load_agent(cmdline_args.core,
interpreter=_interpreter,
tracker_store=_tracker_store,
endpoints=_endpoints)

serve_application(_agent,
cmdline_args.connector,
cmdline_args.port,
Expand Down
13 changes: 13 additions & 0 deletions rasa_core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ def __init__(self, domain, event_broker=None):
self.domain = domain
self.event_broker = event_broker

@staticmethod
def find_tracker_store(domain, store=None):
if store is None or store.store_type is None:
return InMemoryTrackerStore(domain)
elif store.store_type == 'redis':
return RedisTrackerStore(domain=domain,
host=store.url,
**store.kwargs)
elif store.store_type == 'mongod':
return MongoTrackerStore(domain=domain,
host=store.url,
**store.kwargs)

def get_or_create_tracker(self, sender_id):
tracker = self.retrieve(sender_id)
if tracker is None:
Expand Down
11 changes: 7 additions & 4 deletions rasa_core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_core.interpreter import NaturalLanguageInterpreter
from rasa_core.run import AvailableEndpoints
from rasa_core.training import interactive

from rasa_core.tracker_store import TrackerStore
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -124,9 +124,9 @@ def add_args_to_parser(parser):
type=str,
default=None,
required=False,
help="When a fallback is triggered (e.g. because the ML prediction "
"is of low confidence) this is the name of tje action that "
"will get triggered instead.")
help="When a fallback is triggered (e.g. because the "
"ML prediction is of low confidence) this is the name "
"of the action that will get triggered instead.")
parser.add_argument(
'-c', '--config',
type=str,
Expand Down Expand Up @@ -230,6 +230,8 @@ def _additional_arguments(args):
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
_endpoints.nlu)

_tracker_store = TrackerStore.find_tracker_store(None,
_endpoints.tracker_store)
if cmdline_args.core:
if not cmdline_args.interactive:
raise ValueError("--core can only be used together with the"
Expand All @@ -243,6 +245,7 @@ def _additional_arguments(args):
_agent = Agent.load(cmdline_args.core,
interpreter=_interpreter,
generator=_endpoints.nlg,
tracker_store=_tracker_store,
action_endpoint=_endpoints.action)
else:
if not cmdline_args.out:
Expand Down
29 changes: 17 additions & 12 deletions rasa_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ def extract_args(kwargs, # type: Dict[Text, Any]


def arguments_of(func):
"""Return the parameters of the function `func` as a list of their names."""
"""Return the parameters of the function `func` """
"""as a list of their names."""

try:
# python 3.x is used
Expand Down Expand Up @@ -547,7 +548,7 @@ def all_subclasses(cls):

def read_endpoint_config(filename, endpoint_type):
# type: (Text, Text) -> Optional[EndpointConfig]
"""Read an endpoint configuration file from disk and extract one config. """
"""Read an endpoint configuration file from disk and extract one config."""

if not filename:
return None
Expand Down Expand Up @@ -621,27 +622,37 @@ def read_endpoints(cls, endpoint_file):
endpoint_file, endpoint_type="action_endpoint")
model = read_endpoint_config(
endpoint_file, endpoint_type="models")
tracker_store = read_endpoint_config(
endpoint_file, endpoint_type="tracker_store")

return cls(nlg, nlu, action, model)
return cls(nlg, nlu, action, model, tracker_store)

def __init__(self, nlg=None, nlu=None, action=None, model=None):
def __init__(self,
nlg=None,
nlu=None,
action=None,
model=None,
tracker_store=None):
self.model = model
self.action = action
self.nlu = nlu
self.nlg = nlg
self.tracker_store = tracker_store


class EndpointConfig(object):
"""Configuration for an external HTTP endpoint."""

def __init__(self, url, params=None, headers=None, basic_auth=None,
token=None, token_name="token"):
token=None, token_name="token", **kwargs):
self.url = url
self.params = params if params else {}
self.headers = headers if headers else {}
self.basic_auth = basic_auth
self.token = token
self.token_name = token_name
self.store_type = kwargs.pop('store_type', None)
self.kwargs = kwargs

def request(self,
method="post", # type: Text
Expand Down Expand Up @@ -692,13 +703,7 @@ def request(self,

@classmethod
def from_dict(cls, data):
return EndpointConfig(
data.get("url"),
data.get("params"),
data.get("headers"),
data.get("basic_auth"),
data.get("token"),
data.get("token_name"))
return EndpointConfig(**data)

def __eq__(self, other):
if isinstance(self, type(other)):
Expand Down
36 changes: 35 additions & 1 deletion tests/test_tracker_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
from __future__ import print_function
from __future__ import unicode_literals

from rasa_core import utils
from rasa_core.channels import UserMessage
from rasa_core.domain import Domain
from rasa_core.events import SlotSet, ActionExecuted, Restarted
from rasa_core.tracker_store import InMemoryTrackerStore
from rasa_core.tracker_store import (
TrackerStore,
InMemoryTrackerStore,
RedisTrackerStore)
from rasa_core.utils import EndpointConfig
from tests.conftest import DEFAULT_ENDPOINTS_FILE

domain = Domain.load("data/test_domains/default.yml")

Expand Down Expand Up @@ -42,3 +48,31 @@ def test_restart_after_retrieval_from_tracker_store(default_domain):
tr2 = store.retrieve("myuser")
latest_restart_after_loading = tr2.idx_after_latest_restart()
assert latest_restart == latest_restart_after_loading


def test_tracker_store_endpoint_config_loading():
cfg = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")

assert cfg == EndpointConfig.from_dict({
"store_type": "redis",
"url": "localhost",
"port": 6379,
"db": 0,
"password": "password",
"timeout": 30000
})


def test_find_tracker_store(default_domain):
store = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
tracker_store = RedisTrackerStore(domain=default_domain,
host="localhost",
port=6379,
db=0,
password="password",
record_exp=3000)

assert isinstance(tracker_store,
type(
TrackerStore.find_tracker_store(default_domain, store)
))
7 changes: 6 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def test_endpoint_config():
basic_auth={"username": "user",
"password": "pass"},
token="mytoken",
token_name="letoken"
token_name="letoken",
store_type="redis",
port=6379,
db=0,
password="password",
timeout=30000
)

httpretty.register_uri(
Expand Down

0 comments on commit 6a0c0fb

Please sign in to comment.