Skip to content

Commit

Permalink
Merge 246ee2b into 6e14203
Browse files Browse the repository at this point in the history
  • Loading branch information
EPedrotti committed Feb 25, 2019
2 parents 6e14203 + 246ee2b commit 4f479a4
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ docs/key.pub
secrets.tar
.pytest_cache
src
test_download.zip

2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Changed
- validate training data only if used for training
- applied spacy guidelines on how to disable pipeline components
- starter packs now also tested when attempting to merge a branch to master
- `/train` endpoint now returns a zip file of the trained model. This is done to avoid sharing a folder between the `api`
service and the `nlu` service when running NLU in the platform.

=======
- replace pep8 with pycodestyle
Expand Down
3 changes: 2 additions & 1 deletion docs/http.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ You can also query against a specific model for a project :

You can post your training data to this endpoint to train a new model for a project.
This request will wait for the server answer: either the model
was trained successfully or the training exited with an error.
was trained successfully or the training exited with an error. If the model
is trained successfully, a zip file containing the trained model is returned.
Using the HTTP server, you must specify the project you want to train a
new model for to be able to use it during parse requests later on :
``/train?project=my_project``. The configuration of the model should be
Expand Down
8 changes: 4 additions & 4 deletions rasa_nlu/data_router.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import multiprocessing

import datetime
import io
import logging
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor as ProcessPool
from typing import Any, Dict, List, Optional, Text

from twisted.internet import reactor
from twisted.internet.defer import Deferred
from twisted.logger import Logger, jsonFileLogObserver
from typing import Any, Dict, List, Optional, Text

from rasa_nlu import config, utils
from rasa_nlu.components import ComponentBuilder
Expand Down Expand Up @@ -326,7 +326,7 @@ def training_callback(model_path):
self.project_store[project].current_training_processes ==
0):
self.project_store[project].status = STATUS_READY
return model_dir
return model_path

def training_errback(failure):
logger.warning(failure)
Expand Down
23 changes: 13 additions & 10 deletions rasa_nlu/server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import argparse
import io
import logging
import simplejson
from functools import wraps

import simplejson
from klein import Klein
from twisted.internet import reactor, threads
from twisted.internet.defer import inlineCallbacks, returnValue

from rasa_nlu import config, utils
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.data_router import (
DataRouter, InvalidProjectError,
MaxTrainingError)
DataRouter, InvalidProjectError, MaxTrainingError)
from rasa_nlu.model import MINIMUM_COMPATIBLE_VERSION
from rasa_nlu.train import TrainingException
from rasa_nlu.utils import json_to_string, read_endpoints
Expand Down Expand Up @@ -320,12 +321,11 @@ def get_request_content_type(self, request):
else:
return content_type[0]

@app.route("/train", methods=['POST', 'OPTIONS'])
@app.route("/train", methods=['POST', 'OPTIONS'], branch=True)
@requires_auth
@check_cors
@inlineCallbacks
def train(self, request):

# if not set will use the default project name, e.g. "default"
project = parameter_or_default(request, "project", default=None)
# if set will not generate a model name but use the passed one
Expand All @@ -340,16 +340,19 @@ def train(self, request):

data_file = dump_to_data_file(data)

request.setHeader('Content-Type', 'application/json')
request.setHeader('Content-Type', 'application/zip')

try:
request.setResponseCode(200)

response = yield self.data_router.start_train_process(
request.setHeader("Content-Disposition", "attachment")
path_to_model = yield self.data_router.start_train_process(
data_file, project,
RasaNLUModelConfig(model_config), model_name)
returnValue(json_to_string({'info': 'new model trained',
'model': response}))
zipped_path = utils.zip_folder(path_to_model)

zip_content = io.open(zipped_path, 'r+b').read()
return returnValue(zip_content)

except MaxTrainingError as e:
request.setResponseCode(403)
returnValue(json_to_string({"error": "{}".format(e)}))
Expand Down
11 changes: 5 additions & 6 deletions rasa_nlu/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import errno
from collections import namedtuple

import glob
import io
import json
import logging
import os
import re
import tempfile
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Text, Type

import requests
import ruamel.yaml as yaml
import simplejson
import tempfile
from requests import Response
from requests.auth import HTTPBasicAuth
from typing import Any, Callable, Dict, List, Optional, Text, Type


def add_logging_option_arguments(parser, default=logging.WARNING):
Expand Down Expand Up @@ -375,10 +375,9 @@ def zip_folder(folder: Text) -> Text:
import tempfile
import shutil

# WARN: not thread safe!
zipped_path = tempfile.NamedTemporaryFile(delete=False)
zipped_path.close()

# WARN: not thread safe!
return shutil.make_archive(zipped_path.name, str("zip"), folder)


Expand Down
22 changes: 19 additions & 3 deletions tests/base/test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
import time

import io
import json
import tempfile
import time

import pytest
import ruamel.yaml as yaml
import tempfile
from treq.testing import StubTreq

from rasa_nlu.data_router import DataRouter
Expand Down Expand Up @@ -144,6 +144,22 @@ def test_post_train(app, rasa_default_train_data):
assert "error" in rjs


@utilities.slowtest
@pytest.inlineCallbacks
def test_post_train_success(app, rasa_default_train_data):
    """Check that a successful POST to ``/train`` answers with a valid zip.

    The request body carries the ``keyword`` pipeline together with the
    training data fixture. The response must have status code 200 and its
    payload must open as a zip archive whose members all pass the
    integrity check.
    """
    import zipfile

    payload = {"pipeline": "keyword", "data": rasa_default_train_data}
    pending = app.post(
        "http://dummy-uri/train?project=test&model=test",
        json=payload,
    )

    # Presumably gives the training run time to finish before the stubbed
    # transport is flushed -- TODO confirm the 3 second wait is sufficient.
    time.sleep(3)
    app.flush()

    reply = yield pending
    body = yield reply.content()

    assert reply.code == 200

    # testzip() returns None when no archive member is corrupt.
    archive = zipfile.ZipFile(io.BytesIO(body))
    assert archive.testzip() is None


@utilities.slowtest
@pytest.inlineCallbacks
def test_post_train_internal_error(app, rasa_default_train_data):
Expand Down

0 comments on commit 4f479a4

Please sign in to comment.