Skip to content

Commit

Permalink
Merge pull request #1439 from RasaHQ/tf_py3
Browse files Browse the repository at this point in the history
Threaded tensorflow training on python 3
  • Loading branch information
tmbo committed Oct 4, 2018
2 parents aeca318 + fb0c959 commit c681bde
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -21,6 +21,9 @@ Removed

Fixed
-----
- Allow training of pipelines containing ``EmbeddingIntentClassifier`` in
a separate thread on python 3. This makes http server calls to ``/train``
non-blocking


[0.13.5] - 2018-09-28
Expand Down
15 changes: 12 additions & 3 deletions rasa_nlu/data_router.py
Expand Up @@ -6,10 +6,12 @@
import datetime
import io
import logging
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor as ProcessPool
from typing import Text, Dict, Any, Optional, List

import six
from builtins import object
from twisted.internet import reactor
from twisted.internet.defer import Deferred
Expand Down Expand Up @@ -107,6 +109,13 @@ def __init__(self,
self.component_builder = ComponentBuilder(use_cache=True)

self.project_store = self._create_project_store(project_dir)

if six.PY3:
# tensorflow sessions are not fork-safe,
# and training processes have to be spawned instead of forked.
# See https://github.com/tensorflow/tensorflow/issues/5448#issuecomment-258934405
multiprocessing.set_start_method('spawn', force=True)

self.pool = ProcessPool(self._training_processes)

def __del__(self):
Expand Down Expand Up @@ -358,9 +367,9 @@ def training_errback(failure):
self._current_training_processes += 1
self.project_store[project].current_training_processes += 1

# tensorflow training is not executed in a separate thread, as this may
# cause training to freeze
if self._tf_in_pipeline(train_config):
# tensorflow training is not executed in a separate thread on python 2,
# as this may cause training to freeze
if six.PY2 and self._tf_in_pipeline(train_config):
try:
logger.warning("Training a pipeline with a tensorflow "
"component. This blocks the server during "
Expand Down

0 comments on commit c681bde

Please sign in to comment.