diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9a36ac667a9d..1d5b8d388c05 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,9 @@ Removed Fixed ----- +- Allow training of pipelines containing ``EmbeddingIntentClassifier`` in + a separate thread on python 3. This makes http server calls to ``/train`` + non-blocking [0.13.5] - 2018-09-28 diff --git a/rasa_nlu/data_router.py b/rasa_nlu/data_router.py index 14b9db0ad97b..826959873eb3 100644 --- a/rasa_nlu/data_router.py +++ b/rasa_nlu/data_router.py @@ -6,10 +6,12 @@ import datetime import io import logging +import multiprocessing import os from concurrent.futures import ProcessPoolExecutor as ProcessPool from typing import Text, Dict, Any, Optional, List +import six from builtins import object from twisted.internet import reactor from twisted.internet.defer import Deferred @@ -107,6 +109,13 @@ def __init__(self, self.component_builder = ComponentBuilder(use_cache=True) self.project_store = self._create_project_store(project_dir) + + if six.PY3: + # tensorflow sessions are not fork-safe, + # and training processes have to be spawned instead of forked. + # See https://github.com/tensorflow/tensorflow/issues/5448#issuecomment-258934405 + multiprocessing.set_start_method('spawn', force=True) + self.pool = ProcessPool(self._training_processes) def __del__(self): @@ -358,9 +367,9 @@ def training_errback(failure): self._current_training_processes += 1 self.project_store[project].current_training_processes += 1 - # tensorflow training is not executed in a separate thread, as this may - # cause training to freeze - if self._tf_in_pipeline(train_config): + # tensorflow training is not executed in a separate thread on python 2, + # as this may cause training to freeze + if six.PY2 and self._tf_in_pipeline(train_config): try: logger.warning("Training a pipeline with a tensorflow " "component. This blocks the server during "