Tracker_store added in endpoints.yml #1048
Changes from 7 commits
7593507
962dc43
83b3e9b
523cfb2
859836b
e00f7ac
39e8302
57d26d3
bf1b532
d12d1de
d42fd37
771ff80
be75b1c
b6d3dd9
34c46d9
0cd8e78
abee451
477108a
b437f9b
11a0b0e
3d1a378
48b74b3
0d8db77
0b35da7
3a3241d
9bd7a09
d46d480
123fefe
dfeb3c3
de4696c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,7 +32,7 @@ | |
from rasa_core.policies.ensemble import SimplePolicyEnsemble, PolicyEnsemble | ||
from rasa_core.policies.memoization import MemoizationPolicy | ||
from rasa_core.processor import MessageProcessor | ||
from rasa_core.tracker_store import InMemoryTrackerStore, TrackerStore | ||
from rasa_core.tracker_store import InMemoryTrackerStore | ||
from rasa_core.trackers import DialogueStateTracker, EventVerbosity | ||
from rasa_core.utils import EndpointConfig | ||
from rasa_nlu.utils import is_url | ||
|
@@ -648,8 +648,7 @@ def create_tracker_store(store, domain): | |
store.domain = domain | ||
return store | ||
else: | ||
return InMemoryTrackerStore(domain) | ||
|
||
return InMemoryTrackerStore(domain) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a blank line after the method |
||
@staticmethod | ||
def _create_ensemble( | ||
policies # type: Union[List[Policy], PolicyEnsemble, None] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
|
||
from builtins import str | ||
from collections import namedtuple | ||
|
||
import os | ||
import argparse | ||
import logging | ||
from flask import Flask | ||
|
@@ -23,8 +23,9 @@ | |
BUILTIN_CHANNELS) | ||
from rasa_core.interpreter import ( | ||
NaturalLanguageInterpreter) | ||
from rasa_core.tracker_store import TrackerStore | ||
from rasa_core.utils import read_yaml_file, AvailableEndpoints | ||
|
||
from rasa_core.domain import TemplateDomain | ||
logger = logging.getLogger() # get the root logger | ||
|
||
|
||
|
@@ -217,7 +218,7 @@ def load_agent(core_model, interpreter, endpoints, | |
generator=endpoints.nlg, | ||
action_endpoint=endpoints.action, | ||
model_server=endpoints.model, | ||
tracker_store=tracker_store, | ||
tracker_store=endpoints.tracker_store, | ||
wait_time_between_pulls=wait_time_between_pulls | ||
) | ||
else: | ||
|
@@ -245,10 +246,12 @@ def load_agent(core_model, interpreter, endpoints, | |
_endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints) | ||
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu, | ||
_endpoints.nlu) | ||
domain = TemplateDomain.load(os.path.join(cmdline_args.core, "domain.yml")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this file might not exist: the agent might be loaded from a remote server and this directory might be empty It is fine to set the |
||
_tracker_store = TrackerStore(domain).find_tracker_store(_endpoints.tracker_store) | ||
_agent = load_agent(cmdline_args.core, | ||
interpreter=_interpreter, | ||
tracker_store=_tracker_store, | ||
endpoints=_endpoints) | ||
|
||
serve_application(_agent, | ||
cmdline_args.connector, | ||
cmdline_args.port, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,14 @@ def __init__(self, domain, event_broker=None): | |
# type: (Optional[Domain], Optional[EventChannel]) -> None | ||
self.domain = domain | ||
self.event_broker = event_broker | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please re-add an empty line here |
||
def find_tracker_store(self, store=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realised that this should be class method. We are just using the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. true, you don't need to instantiate the TrackerStore class in order to retrieve a tracker_store class. I can change that to return a cls object |
||
if store is None: | ||
souvikg10 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return InMemoryTrackerStore(self.domain) | ||
elif store.store_type == 'redis': | ||
return RedisTrackerStore(domain=self.domain,host=store.url,db=store.db,password=store.password,record_exp=store.timeout) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be great if the user could also define a certain port through the configuration file. return RedisTrackerStore(domain=self.domain,
host=store.url,
port=store.port,
db=store.db,
password=store.password,
record_exp=store.timeout) |
||
elif store.store_type == 'mongod': | ||
return MongoTrackerStore(domain=self.domain,host=store.url,db=store.db,username=store.user,password=store.password) | ||
|
||
def get_or_create_tracker(self, sender_id): | ||
tracker = self.retrieve(sender_id) | ||
if tracker is None: | ||
|
@@ -124,6 +131,7 @@ def __init__(self, domain, host='localhost', | |
port=6379, db=0, password=None, event_broker=None, | ||
record_exp=None): | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please remove the empty line |
||
import redis | ||
self.red = redis.StrictRedis(host=host, port=port, db=db, | ||
password=password) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,8 @@ | |
from rasa_core.policies.memoization import MemoizationPolicy | ||
from rasa_core.run import AvailableEndpoints | ||
from rasa_core.training import interactive | ||
|
||
from rasa_core.tracker_store import TrackerStore | ||
from rasa_core.domain import TemplateDomain | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
|
@@ -216,7 +217,8 @@ def train_dialogue_model(domain_file, stories_file, output_path, | |
_endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints) | ||
_interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu, | ||
_endpoints.nlu) | ||
|
||
domain = TemplateDomain.load(cmdline_args.domain) | ||
_tracker_store = TrackerStore(domain).find_tracker_store(_endpoints.tracker_store) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please insert an empty line here for better readability |
||
if cmdline_args.core: | ||
if not cmdline_args.interactive: | ||
raise ValueError("--core can only be used together with the" | ||
|
@@ -230,6 +232,7 @@ def train_dialogue_model(domain_file, stories_file, output_path, | |
_agent = Agent.load(cmdline_args.core, | ||
interpreter=_interpreter, | ||
generator=_endpoints.nlg, | ||
tracker_store=_tracker_store, | ||
action_endpoint=_endpoints.action) | ||
else: | ||
if not cmdline_args.out: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -595,27 +595,34 @@ def read_endpoints(cls, endpoint_file): | |
endpoint_file, endpoint_type="action_endpoint") | ||
model = read_endpoint_config( | ||
endpoint_file, endpoint_type="models") | ||
tracker_store = read_endpoint_config(endpoint_file, endpoint_type="tracker_store") | ||
|
||
return cls(nlg, nlu, action, model) | ||
return cls(nlg, nlu, action, model, tracker_store) | ||
|
||
def __init__(self, nlg=None, nlu=None, action=None, model=None): | ||
def __init__(self, nlg=None, nlu=None, action=None, model=None,tracker_store=None): | ||
self.model = model | ||
self.action = action | ||
self.nlu = nlu | ||
self.nlg = nlg | ||
self.tracker_store = tracker_store | ||
|
||
|
||
class EndpointConfig(object): | ||
"""Configuration for an external HTTP endpoint.""" | ||
|
||
def __init__(self, url, params=None, headers=None, basic_auth=None, | ||
token=None, token_name="token"): | ||
token=None, token_name="token",store_type=None,db=None,user=None,password=None,timeout=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure we should set all of them as separate arguments. Now that we start to use this for more purposes, I think it makes sense to just take all the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah makes sense, i had some thoughts on that as well. I will try to do that next week 👍 |
||
self.url = url | ||
self.params = params if params else {} | ||
self.headers = headers if headers else {} | ||
self.basic_auth = basic_auth | ||
self.token = token | ||
self.token_name = token_name | ||
self.store_type = store_type | ||
self.db = db | ||
self.user = user | ||
self.password= password | ||
self.timeout = timeout | ||
|
||
def request(self, | ||
method="post", # type: Text | ||
|
@@ -672,7 +679,12 @@ def from_dict(cls, data): | |
data.get("headers"), | ||
data.get("basic_auth"), | ||
data.get("token"), | ||
data.get("token_name")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you have either to pass [...],
token_name=data.get('token_name'),
store_type=data.get('store_type'),
[...] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I noticed the issue comes when I added **kwargs so each parameter became a key value pair, I think it is also best here to change the key value pair as well. Sorry overlooked it during the tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I noticed the issue comes when I added **kwargs so each parameter became a key value pair, I think it is also best here to change the key value pair as well. Sorry overlooked it during the tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alternatively:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that makes sense. I will do the fixes this weekend 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your efforts! Your feature will help a lot of users in the future to use different tracker stores 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the issues, however i haven't found any formatting error in Travis, not sure how to check that though There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I will have another look on them later. Regarding travis: I think we have to let it built first. This takes usually about 20 minutes after you have pushed your changes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The parameters are overlapping here. See https://travis-ci.com/RasaHQ/rasa_core/jobs/152986926 line 1495. A solution to this could be: return EndpointConfig(**data) and a constructor like this: class EndpointConfig(object):
"""Configuration for an external HTTP endpoint."""
def __init__(self, url, params=None, headers=None, basic_auth=None,
token=None, token_name="token",**kwargs):
self.url = url
self.params = params if params else {}
self.headers = headers if headers else {}
self.basic_auth = basic_auth
self.token = token
self.token_name = token_name
[setattr(self, k, v) for k, v in kwargs.items() if not hasattr(self, k)] I used the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regarding the formatting issues: You can also check that locally, e.g. by using a linter which knows pep8 or IDEs such as PyCharm or Visual Studio Code (with the Python plugin). |
||
data.get("token_name"), | ||
data.get("store_type"), | ||
data.get("db"), | ||
data.get("user"), | ||
data.get("password"), | ||
data.get("timeout")) | ||
|
||
def __eq__(self, other): | ||
if isinstance(self, type(other)): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be
url: localhost
(omitting thehttp://
)It would also be great if you include
port: 6379
in your example 😊.