Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a --pre_load_model command to load specific models #1410

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.rst
Expand Up @@ -12,7 +12,7 @@ Added
- Added a detailed warning showing which entities are overlapping
- Authentication token can be also set with env variable `RASA_NLU_TOKEN`.
- `SpacyEntityExtractor` supports same entity filtering as `DucklingHTTPExtractor`

- ability to preload specific models on startup with the ``--pre_load`` option
Changed
-------
- validate training data only if used for training
Expand Down Expand Up @@ -108,6 +108,12 @@ Changed
- updated cloudpickle version to 0.6.1
- updated requirements to match Core and SDK
- pinned keras dependencies
- use cloudpickle version 0.6.1
- replaced ``yaml`` with ``ruamel.yaml``
- the ``load_model`` method in ``Project`` class now takes an optional ``model_name`` parameter
- ``pre_load`` command line argument is now the path of the model or project, relative to the path argument
- Training data is now validated after loading from files in ``loading.py`` instead of on initialisation of
``TrainingData`` object

Removed
-------
Expand Down
10 changes: 4 additions & 6 deletions rasa_nlu/cli/server.py
Expand Up @@ -10,12 +10,10 @@ def add_server_arguments(parser):
parser.add_argument('--pre_load',
nargs='+',
default=[],
help='Preload models into memory before starting the '
'server. \nIf given `all` as input all the models '
'will be loaded.\nElse you can specify a list of '
'specific project names.\nEg: python -m '
'rasa_nlu.server --pre_load project1 '
'--path projects '
help='Preload specific projects or models into memory '
'before starting the server. \nEg: python -m '
'rasa_nlu.server --pre_load projectX/'
'my_model_1 --pre_load projectY/my_model_2 '
'-c config.yaml')
parser.add_argument('-t', '--token',
help="auth token. If set, reject requests which don't "
Expand Down
31 changes: 26 additions & 5 deletions rasa_nlu/data_router.py
Expand Up @@ -189,11 +189,32 @@ def _create_project_store(self,

return project_store

def _pre_load(self, projects: List[Text]) -> None:
logger.debug("loading %s", projects)
for project in self.project_store:
if project in projects:
self.project_store[project].load_model()
def _pre_load(self, project, model=None, absolute_path=None):
    """Dispatch preloading to the project- or model-level loader.

    When a ``model`` is given, that specific model is loaded; otherwise
    the whole project's latest model is loaded.

    NOTE(review): ``absolute_path`` is accepted for call-site
    compatibility but is not used here — confirm whether it is needed.
    """
    if model is not None:
        self._pre_load_model(project, model)
    else:
        self._pre_load_project(project)

def _pre_load_project(self, project) -> None:
    """Load the current model of *project* if it is in the project store."""
    logger.debug("loading %s", project)
    if project not in self.project_store:
        # unknown project: nothing to do, just record it for debugging
        logger.debug(
            'project %s does not exist in the project store',
            project
        )
        return
    logger.info('loading project %s', project)
    self.project_store[project].load_model()

def _pre_load_model(self, project, model=None) -> None:
    """Load a specific *model* from *project* if the project is known."""
    if project not in self.project_store:
        # unknown project: nothing to do, just record it for debugging
        logger.debug(
            'project %s does not exist in the project store',
            project
        )
        return
    logger.info("loading model %s from project %s", model, project)
    self.project_store[project].load_model(model)

def _list_projects_in_cloud(self) -> List[Text]:
# noinspection PyBroadException
Expand Down
5 changes: 3 additions & 2 deletions rasa_nlu/project.py
Expand Up @@ -265,10 +265,11 @@ def parse(self, text, parsing_time=None, requested_model_name=None):

return response

def load_model(self):
def load_model(self, model_name=None):
self._begin_read()
status = False
model_name = self._dynamic_load_model()
if not model_name:
model_name = self._dynamic_load_model()
logger.debug('Loading model %s', model_name)

self._loader_lock.acquire()
Expand Down
115 changes: 111 additions & 4 deletions rasa_nlu/server.py
Expand Up @@ -8,6 +8,8 @@
from klein import Klein
from twisted.internet import reactor, threads
from twisted.internet.defer import inlineCallbacks, returnValue
from os.path import normpath, basename
import os

from rasa_nlu import config, utils
import rasa_nlu.cli.server as cli
Expand Down Expand Up @@ -359,10 +361,114 @@ def get_token(_clitoken: str) -> str:
return token


def get_absolute_path(model_path, current_path):
    """Resolve *model_path* against *current_path*.

    An already-absolute model path is returned unchanged (minus any
    trailing slash).  A relative model path is appended to
    *current_path*, which is forced to start with '/' first.
    """
    model_path = model_path.rstrip("/")
    if model_path.startswith('/'):
        # already absolute: nothing to resolve
        return model_path
    base = current_path.rstrip("/")
    if not base.startswith('/'):
        base = '/' + base
    return "{}/{}".format(base, model_path)


def parse_model(path):
    """Split the last two path components into ``(project, model)``.

    Raises ValueError (unpacking error) when fewer than two non-empty
    components are present.
    """
    tail = [part for part in path.split('/')[-2:] if part]
    project, model = tail
    return (project, model)


def parse_project(path):
    """Return the final path component, treated as the project name."""
    return path.rsplit('/', 1)[-1]


def parse_pre_load_path(pre_load_path):
    """Split a ``--pre_load`` argument and classify what it may refer to.

    Returns ``(components, could_be_model_or_project, could_be_project)``.

    NOTE(review): ``str.split`` always yields at least one element, so
    the last flag is always True as written.
    """
    components = pre_load_path.lstrip("/").split('/')
    n_components = len(components)
    return (components, n_components > 1, n_components > 0)


def should_fetch_from_cloud(
        is_local_model,
        is_local_project,
        is_potential_model_or_project_path):
    """Decide whether to fall back to cloud storage for a pre_load path.

    True when the path has a project/model shape but was not found
    locally as both a model and a project.
    """
    if not is_potential_model_or_project_path:
        return False
    # De Morgan: (not model or not project) == not (model and project)
    return not (is_local_model and is_local_project)


def parse_project_and_model(
        is_local_model,
        is_local_project,
        pre_load_path,
        absolute_path):
    """Turn one pre_load path into ``(project, model, absolute_path)`` tuples.

    Local models yield one (project, model) entry; local (or plausible)
    projects yield one (project, None) entry.  If nothing conclusive was
    found locally and the path has a project/model shape, both
    interpretations are returned so the cloud can be tried for each.
    """
    (_,
     is_potential_model_or_project_path,
     is_potential_project_path) = parse_pre_load_path(pre_load_path)

    candidates = []
    if is_local_model:
        # the argument is a model
        project, model = parse_model(pre_load_path)
        candidates.append((project, model, absolute_path))
    elif is_local_project or is_potential_project_path:
        # the argument is a project
        candidates.append((parse_project(pre_load_path), None, absolute_path))

    if should_fetch_from_cloud(
            is_local_model,
            is_local_project,
            is_potential_model_or_project_path):
        # if we did not find anything locally, assume the project or model
        # is stored on the cloud; we do not know yet which one it is, so
        # try both.  NOTE(review): this deliberately REPLACES any local
        # candidate collected above, matching the original behaviour —
        # confirm that overwrite is intended when is_local_model is True
        # but is_local_project is False.
        project, model = parse_model(pre_load_path)
        candidates = [
            (project, model, absolute_path),
            (model, None, absolute_path)
        ]
    return candidates


def parse_pre_load(pre_load_args, path):
    """Resolve every ``--pre_load`` argument into load instructions.

    Returns a list of ``(project, model, absolute_path)`` tuples, one or
    two per argument depending on whether it resolved locally.
    """
    resolved = []
    for raw_path in pre_load_args:

        raw_path = raw_path.rstrip("/")
        absolute_path = get_absolute_path(raw_path, path)

        # a model directory is recognised by its metadata.json file
        has_metadata = os.path.isfile(absolute_path + '/metadata.json')
        # a project is a directory at the given (possibly relative) path
        is_directory = os.path.isdir(raw_path)

        entries = parse_project_and_model(
            has_metadata,
            is_directory,
            raw_path,
            absolute_path
        )
        if not entries:
            logger.error(
                'invalid input for parse_pre_load: '
                'at least one element expected')
        resolved.extend(entries)

    return resolved


def main(args):
utils.configure_colored_logging(args.loglevel)
pre_load = args.pre_load

_endpoints = read_endpoints(args.endpoints)

router = DataRouter(
Expand All @@ -374,11 +480,12 @@ def main(args):
model_server=_endpoints.model,
wait_time_between_pulls=args.wait_time_between_pulls
)

if pre_load:
logger.debug('Preloading....')
if 'all' in pre_load:
pre_load = router.project_store.keys()
router._pre_load(pre_load)
parsed_input = parse_pre_load(pre_load, cmdline_args.path)
for (project, model, absolute_path) in parsed_input:
router._pre_load(project, model, absolute_path)

rasa = RasaNLU(
router,
Expand Down
34 changes: 33 additions & 1 deletion tests/base/test_data_router.py
@@ -1,5 +1,5 @@
import mock

from rasa_nlu.project import Project
from rasa_nlu import data_router
from rasa_nlu import persistor

Expand All @@ -25,3 +25,35 @@ def mocked_data_router_init(self, *args, **kwargs):
mocked_get_persistor):
return_value = data_router.DataRouter()._list_projects_in_cloud()
assert isinstance(return_value[0], UniqueValue)


def test_pre_load_model():
    """_pre_load with a model name delegates to _pre_load_model,
    regardless of the absolute_path argument."""
    # same expectation for both absolute_path values -> loop instead of
    # two copy-pasted blocks; the unused `as mock_load_model` alias from
    # the Project.load_model patch is dropped.
    for absolute_path in (None, '/whatever/absolute/path'):
        with mock.patch.object(Project, 'load_model', return_value=None), \
                mock.patch.object(data_router.DataRouter, '_pre_load_model',
                                  return_value=None) as mock_pre_load_model:
            dr = data_router.DataRouter()
            dr.project_store['project_test'] = Project()
            dr._pre_load('project_test', 'model_test', absolute_path)
            mock_pre_load_model.assert_called_once_with(
                'project_test', 'model_test')


def test_pre_load_project():
    """_pre_load without a model name delegates to _pre_load_project,
    regardless of the absolute_path argument."""
    # same expectation for both absolute_path values -> loop instead of
    # two copy-pasted blocks; the unused `as mock_load_model` alias from
    # the Project.load_model patch is dropped.
    for absolute_path in (None, '/some/absolute/path'):
        with mock.patch.object(Project, 'load_model', return_value=None), \
                mock.patch.object(data_router.DataRouter, '_pre_load_project',
                                  return_value=None) as mock_pre_load_project:
            dr = data_router.DataRouter()
            dr.project_store['project_test'] = Project()
            dr._pre_load('project_test', None, absolute_path)
            mock_pre_load_project.assert_called_once_with('project_test')
32 changes: 31 additions & 1 deletion tests/base/test_project.py
Expand Up @@ -4,7 +4,7 @@
import responses
from rasa_nlu.project import Project, load_from_server
from rasa_nlu.utils import EndpointConfig

from rasa_nlu.model import Metadata

def test_dynamic_load_model_with_exists_model():
MODEL_NAME = 'model_name'
Expand Down Expand Up @@ -119,3 +119,33 @@ def test_project_with_model_server(zipped_nlu_model):
stream=True)
project = load_from_server(model_server=model_endpoint)
assert project.fingerprint == fingerprint


def test_load_model_without_args():
    """load_model() without a model name falls back to
    _dynamic_load_model to pick one."""

    @staticmethod
    def mock_load(model_dir):
        data = Project._default_model_metadata()
        return Metadata(data, model_dir)

    # one combined with-statement instead of three nested ones
    with mock.patch.object(Metadata, "load", mock_load), \
            mock.patch.object(Project, "_load_model_from_cloud",
                              return_value=None), \
            mock.patch.object(Project, "_dynamic_load_model",
                              return_value='') as mock_dynamic_load_model:
        project = Project()
        project.load_model()
        mock_dynamic_load_model.assert_called_once_with()


def test_load_model_with_args():
    """load_model('my_model') must not consult _dynamic_load_model."""

    @staticmethod
    def mock_load(model_dir):
        data = Project._default_model_metadata()
        return Metadata(data, model_dir)

    # one combined with-statement instead of three nested ones
    with mock.patch.object(Metadata, "load", mock_load), \
            mock.patch.object(Project, "_load_model_from_cloud",
                              return_value=None), \
            mock.patch.object(Project, "_dynamic_load_model",
                              return_value='') as mock_dynamic_load_model:
        project = Project()
        project.load_model('my_model')
        mock_dynamic_load_model.assert_not_called()