Skip to content

Commit

Permalink
code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ricwo committed Sep 11, 2018
1 parent 8a10f78 commit 54f210a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 67 deletions.
59 changes: 32 additions & 27 deletions docs/migrations.rst
Expand Up @@ -6,51 +6,56 @@ how you can migrate from one version to another.

0.13.x to 0.13.3
----------------

- ``rasa_nlu.server`` needs to be supplied with an ``yml`` file defining the
model endpoint to retrieve training data. The file location has to be passed
with the ``--endpoints`` argument, e.g.
``python rasa_nlu.server --path projects --endpoints endpoints.yml``
``endpoints.yml`` needs to contain the ``model`` key
with a ``url`` and an optional ``token``. Here's an example:
- ``rasa_nlu.server`` has to be supplied with a ``yml`` file defining the
model endpoint from which to retrieve training data. The file location has
to be passed with the ``--endpoints`` argument, e.g.
``python rasa_nlu.server --path projects --endpoints endpoints.yml``
``endpoints.yml`` needs to contain the ``model`` key
with a ``url`` and an optional ``token``. Here's an example:

.. code-block:: yaml
model:
url: http://my_model_server.com/models/default/nlu/tags/latest
token: my_model_server_token
model:
url: http://my_model_server.com/models/default/nlu/tags/latest
token: my_model_server_token
.. note::

If you configure ``rasa_nlu.server`` to pull models from a remote server,
the default project name will be used. It is defined in
``RasaNLUModelConfig.DEFAULT_PROJECT_NAME``.


- ``rasa_nlu.train`` also has to be run with the ``--endpoints`` argument
if you want to pull training data from a URL. This replaces the previous
``--url`` syntax.
- ``rasa_nlu.train`` can also be run with the ``--endpoints`` argument
if you want to pull training data from a URL. Alternatively, the
current ``--url`` syntax is still supported.

.. code-block:: yaml
data:
url: http://my_data_server.com/projects/default/data
token: my_data_server_token
data:
url: http://my_data_server.com/projects/default/data
token: my_data_server_token
.. note::

Your endpoint file may contain entries for both ``model`` and ``data``.
``rasa_nlu.server`` and ``rasa_nlu.train`` will pick the relevant entry.

- If you directly access the ``DataRouter`` class or ``rasa_nlu.train``'s
``do_train()`` method, you can directly create instances of
``EndpointConfig`` without creating a ``yml`` file. Example:
``do_train()`` method, you can directly create instances of
``EndpointConfig`` without creating a ``yml`` file. Example:

.. code-block:: python
.. code-block:: python
from rasa_nlu.utils import EndpointConfig
from rasa_nlu.data_router import DataRouter
from rasa_nlu.utils import EndpointConfig
from rasa_nlu.data_router import DataRouter
model_endpoint = EndpointConfig(
url="http://my_model_server.com/models/default/nlu/tags/latest",
token="my_model_server_token"
)
model_endpoint = EndpointConfig(
url="http://my_model_server.com/models/default/nlu/tags/latest",
token="my_model_server_token"
)
interpreter = DataRouter("projects",
model_server=model_endpoint)
interpreter = DataRouter("projects", model_server=model_endpoint)
0.12.x to 0.13.0
Expand Down
61 changes: 26 additions & 35 deletions rasa_nlu/data_router.py
Expand Up @@ -7,10 +7,14 @@
import io
import logging
import os
from builtins import object
from concurrent.futures import ProcessPoolExecutor as ProcessPool
from typing import Text, Dict, Any, Optional, List

from builtins import object
from twisted.internet import reactor
from twisted.internet.defer import Deferred
from twisted.logger import jsonFileLogObserver, Logger

from rasa_nlu import utils, config
from rasa_nlu.components import ComponentBuilder
from rasa_nlu.config import RasaNLUModelConfig
Expand All @@ -20,9 +24,6 @@
from rasa_nlu.train import do_train_in_worker, TrainingException
from rasa_nlu.training_data import Message
from rasa_nlu.training_data.loading import load_data
from twisted.internet import reactor
from twisted.internet.defer import Deferred
from twisted.logger import jsonFileLogObserver, Logger

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -153,45 +154,35 @@ def _collect_projects(self, project_dir):

def _create_project_store(self,
project_dir):
default_project = RasaNLUModelConfig.DEFAULT_PROJECT_NAME

projects = self._collect_projects(project_dir)

project_store = {}

for project in projects:
if self.model_server is not None:
project_store[project] = load_from_server(
self.component_builder,
project,
self.project_dir,
self.remote_storage,
self.model_server,
self.wait_time_between_pulls
)
else:
project_store[project] = Project(
self.component_builder,
project,
self.project_dir,
self.remote_storage
)

if not project_store:
default_project = RasaNLUModelConfig.DEFAULT_PROJECT_NAME
if self.model_server is not None:
project_store[default_project] = load_from_server(
self.component_builder,
default_project,
self.project_dir,
self.remote_storage,
self.model_server,
self.wait_time_between_pulls
)
else:
if self.model_server is not None:
project_store[default_project] = load_from_server(
self.component_builder,
default_project,
self.project_dir,
self.remote_storage,
self.model_server,
self.wait_time_between_pulls
)
else:
for project in projects:
project_store[project] = Project(self.component_builder,
project,
self.project_dir,
self.remote_storage)

if not project_store:
project_store[default_project] = Project(
project=default_project,
project_dir=self.project_dir,
remote_storage=self.remote_storage)
remote_storage=self.remote_storage
)

return project_store

def _pre_load(self, projects):
Expand Down
17 changes: 12 additions & 5 deletions rasa_nlu/train.py
Expand Up @@ -6,7 +6,6 @@
import argparse
import logging
import typing
from collections import namedtuple
from typing import Optional, Any
from typing import Text
from typing import Tuple
Expand All @@ -18,7 +17,7 @@
from rasa_nlu.model import Trainer
from rasa_nlu.training_data import load_data
from rasa_nlu.training_data.loading import load_data_from_endpoint
from rasa_nlu.utils import read_endpoints
from rasa_nlu.utils import read_endpoints, EndpointConfig

logger = logging.getLogger(__name__)

Expand All @@ -43,9 +42,14 @@ def create_argument_parser():
"or a directory containing multiple training "
"data files.")

group.add_argument('-u', '--url',
default=None,
help="URL from which to retrieve training data.")

group.add_argument('--endpoints',
default=None,
help="EndpointConfig defining the server from which pull training data.")
help="EndpointConfig defining the server from which "
"pull training data.")

parser.add_argument('-c', '--config',
required=True,
Expand Down Expand Up @@ -165,14 +169,17 @@ def do_train(cfg, # type: RasaNLUModelConfig

utils.configure_colored_logging(cmdline_args.loglevel)

_endpoints = read_endpoints(cmdline_args.endpoints)
if cmdline_args.url:
data_endpoint = EndpointConfig(cmdline_args.url)
else:
data_endpoint = read_endpoints(cmdline_args.endpoints).data

do_train(config.load(cmdline_args.config),
cmdline_args.data,
cmdline_args.path,
cmdline_args.project,
cmdline_args.fixed_model_name,
cmdline_args.storage,
data_endpoint=_endpoints.data,
data_endpoint=data_endpoint,
num_threads=cmdline_args.num_threads)
logger.info("Finished training")

0 comments on commit 54f210a

Please sign in to comment.