From 870abc05e4cab61654e25bd6720b50da40d23f73 Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 1 Oct 2018 18:22:45 +0200 Subject: [PATCH 1/4] #1437 run tf training in separate thread on py3 --- CHANGELOG.rst | 3 +++ rasa_nlu/data_router.py | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 96af397dbd9e..92cb084f5420 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,9 @@ Removed Fixed ----- +- Allow training of pipelines containing ``EmbeddingIntentClassifier`` in + a separate thread on python 3. This makes http server calls to ``/train`` + non-blocking [0.13.5] - 2018-09-28 diff --git a/rasa_nlu/data_router.py b/rasa_nlu/data_router.py index 14b9db0ad97b..1c71fb8542ca 100644 --- a/rasa_nlu/data_router.py +++ b/rasa_nlu/data_router.py @@ -10,6 +10,7 @@ from concurrent.futures import ProcessPoolExecutor as ProcessPool from typing import Text, Dict, Any, Optional, List +import six from builtins import object from twisted.internet import reactor from twisted.internet.defer import Deferred @@ -358,9 +359,9 @@ def training_errback(failure): self._current_training_processes += 1 self.project_store[project].current_training_processes += 1 - # tensorflow training is not executed in a separate thread, as this may - # cause training to freeze - if self._tf_in_pipeline(train_config): + # tensorflow training is not executed in a separate thread on python 2, + # as this may cause training to freeze + if self._tf_in_pipeline(train_config) and six.PY2: try: logger.warning("Training a pipeline with a tensorflow " "component. 
This blocks the server during " From 47143c0933c066cfb6016f0ed531136c12e4e16d Mon Sep 17 00:00:00 2001 From: ricwo Date: Mon, 1 Oct 2018 18:23:36 +0200 Subject: [PATCH 2/4] #1437 check py2 first --- rasa_nlu/data_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa_nlu/data_router.py b/rasa_nlu/data_router.py index 1c71fb8542ca..32d41c7fbe45 100644 --- a/rasa_nlu/data_router.py +++ b/rasa_nlu/data_router.py @@ -361,7 +361,7 @@ def training_errback(failure): # tensorflow training is not executed in a separate thread on python 2, # as this may cause training to freeze - if self._tf_in_pipeline(train_config) and six.PY2: + if six.PY2 and self._tf_in_pipeline(train_config): try: logger.warning("Training a pipeline with a tensorflow " "component. This blocks the server during " From 4557f42ad644be2948c5ed0cd92c3389f1819edc Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 3 Oct 2018 18:41:44 +0200 Subject: [PATCH 3/4] #1437 use multiprocessing start method spawn --- rasa_nlu/data_router.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/rasa_nlu/data_router.py b/rasa_nlu/data_router.py index 32d41c7fbe45..e31d70b116d4 100644 --- a/rasa_nlu/data_router.py +++ b/rasa_nlu/data_router.py @@ -6,6 +6,7 @@ import datetime import io import logging +import multiprocessing import os from concurrent.futures import ProcessPoolExecutor as ProcessPool from typing import Text, Dict, Any, Optional, List @@ -108,8 +109,17 @@ def __init__(self, self.component_builder = ComponentBuilder(use_cache=True) self.project_store = self._create_project_store(project_dir) + + if six.PY3: + # tensorflow sessions are not fork-safe, + # and training processes have to be spawned instead of forked.
+ # See https://github.com/tensorflow/tensorflow/issues/5448#issuecomment-258934405 + multiprocessing.set_start_method('spawn', force=True) + self.pool = ProcessPool(self._training_processes) + + def __del__(self): """Terminates workers pool processes""" self.pool.shutdown() From e19cb912bab2a7957df2c69904da3426b0ca2031 Mon Sep 17 00:00:00 2001 From: ricwo Date: Wed, 3 Oct 2018 18:45:20 +0200 Subject: [PATCH 4/4] #1437 remove rogue newlines --- rasa_nlu/data_router.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/rasa_nlu/data_router.py b/rasa_nlu/data_router.py index e31d70b116d4..826959873eb3 100644 --- a/rasa_nlu/data_router.py +++ b/rasa_nlu/data_router.py @@ -118,8 +118,6 @@ def __init__(self, self.pool = ProcessPool(self._training_processes) - - def __del__(self): """Terminates workers pool processes""" self.pool.shutdown()