Skip to content

Commit

Permalink
Merge pull request #3887 from RasaHQ/temp-files-train
Browse files (browse the repository at this point in the history)
Cleanup all temporary directories after training
  • Loading branch information
federicotdn authored Jul 5, 2019
2 parents 8bc5f65 + bdccf3c commit 644f247
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 110 deletions.
48 changes: 15 additions & 33 deletions rasa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import os
import shutil
import tempfile
from typing import Text, Tuple, Union, Optional, List, Dict, Type
from types import TracebackType
from typing import Text, Tuple, Union, Optional, List, Dict

import yaml.parser

Expand All @@ -19,6 +18,7 @@
from rasa.core.domain import Domain
from rasa.core.utils import get_dict_hash
from rasa.exceptions import ModelNotFound
from rasa.utils.common import TempDirectoryPath

# Type alias for the fingerprint
Fingerprint = Dict[Text, Union[Text, List[Text], int, float]]
Expand All @@ -37,25 +37,7 @@
FINGERPRINT_TRAINED_AT_KEY = "trained_at"


class UnpackedModelPath(str):
    """Represents a path to an unpacked model on disk.

    Instances behave like ordinary strings. When used as a context manager,
    the unpacked model files are erased after the context is exited.
    """

    def __enter__(self) -> "UnpackedModelPath":
        """Return the path itself so it can be bound with `as`."""
        return self

    def __exit__(
        self,
        _exc: Optional[Type[BaseException]],
        _value: Optional[BaseException],
        _tb: Optional[TracebackType],
    ) -> bool:
        """Remove the directory this path points to.

        Args:
            _exc: Type of the exception raised in the context, if any.
            _value: The exception instance raised in the context, if any.
            _tb: Traceback of that exception, if any.

        Returns:
            `False`, so an exception raised inside the `with` body is never
            suppressed.
        """
        shutil.rmtree(self)
        # Explicitly honour the declared `bool` return type instead of
        # implicitly returning `None`.
        return False


def get_model(model_path: Text = DEFAULT_MODELS_PATH) -> UnpackedModelPath:
def get_model(model_path: Text = DEFAULT_MODELS_PATH) -> TempDirectoryPath:
"""Gets a model and unpacks it. Raises a `ModelNotFound` exception if
no model could be found at the provided path.
Expand Down Expand Up @@ -109,7 +91,7 @@ def get_latest_model(model_path: Text = DEFAULT_MODELS_PATH) -> Optional[Text]:

def unpack_model(
model_file: Text, working_directory: Optional[Text] = None
) -> UnpackedModelPath:
) -> TempDirectoryPath:
"""Unpacks a zipped Rasa model.
Args:
Expand All @@ -136,7 +118,7 @@ def unpack_model(
tar.close()
logger.debug("Extracted model to '{}'.".format(working_directory))

return UnpackedModelPath(working_directory)
return TempDirectoryPath(working_directory)


def get_model_subdirectories(unpacked_model_path: Text) -> Tuple[Text, Text]:
Expand Down Expand Up @@ -398,17 +380,17 @@ def should_retrain(new_fingerprint: Fingerprint, old_model: Text, train_path: Te
if old_model is None or not os.path.exists(old_model):
return retrain_core, retrain_nlu

unpacked = unpack_model(old_model)
last_fingerprint = fingerprint_from_path(unpacked)
with unpack_model(old_model) as unpacked:
last_fingerprint = fingerprint_from_path(unpacked)

old_core, old_nlu = get_model_subdirectories(unpacked)
old_core, old_nlu = get_model_subdirectories(unpacked)

if not core_fingerprint_changed(last_fingerprint, new_fingerprint):
target_path = os.path.join(train_path, "core")
retrain_core = not merge_model(old_core, target_path)
if not core_fingerprint_changed(last_fingerprint, new_fingerprint):
target_path = os.path.join(train_path, "core")
retrain_core = not merge_model(old_core, target_path)

if not nlu_fingerprint_changed(last_fingerprint, new_fingerprint):
target_path = os.path.join(train_path, "nlu")
retrain_nlu = not merge_model(old_nlu, target_path)
if not nlu_fingerprint_changed(last_fingerprint, new_fingerprint):
target_path = os.path.join(train_path, "nlu")
retrain_nlu = not merge_model(old_nlu, target_path)

return retrain_core, retrain_nlu
return retrain_core, retrain_nlu
212 changes: 137 additions & 75 deletions rasa/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import os
import tempfile
from contextlib import ExitStack
from typing import Text, Optional, List, Union, Dict

from rasa import model, data
from rasa.core.domain import Domain, InvalidDomain
from rasa.model import Fingerprint, should_retrain
from rasa.skill import SkillSelector
from rasa.utils.common import TempDirectoryPath

from rasa.cli.utils import (
create_output_path,
Expand Down Expand Up @@ -65,8 +67,6 @@ async def train_async(
Returns:
Path of the trained model archive.
"""
train_path = tempfile.mkdtemp()

skill_imports = SkillSelector.load(config, training_files)
try:
domain = Domain.load(domain, skill_imports)
Expand All @@ -81,6 +81,52 @@ async def train_async(
story_directory, nlu_data_directory = data.get_core_nlu_directories(
training_files, skill_imports
)

with ExitStack() as stack:
train_path = stack.enter_context(TempDirectoryPath(tempfile.mkdtemp()))
nlu_data = stack.enter_context(TempDirectoryPath(nlu_data_directory))
story = stack.enter_context(TempDirectoryPath(story_directory))

return await _train_async_internal(
domain,
config,
train_path,
nlu_data,
story,
output_path,
force_training,
fixed_model_name,
kwargs,
)


async def _train_async_internal(
domain: Union[Domain, Text],
config: Text,
train_path: Text,
nlu_data_directory: Text,
story_directory: Text,
output_path: Text,
force_training: bool,
fixed_model_name: Optional[Text],
kwargs: Optional[Dict],
) -> Optional[Text]:
"""Trains a Rasa model (Core and NLU). Use only from `train_async`.
Args:
domain: Path to the domain file.
config: Path to the config for Core and NLU.
train_path: Directory in which to train the model.
nlu_data_directory: Path to NLU training files.
story_directory: Path to Core training files.
output_path: Output path.
force_training: If `True` retrain model even if data has not changed.
fixed_model_name: Name of model to be stored.
kwargs: Additional training parameters.
Returns:
Path of the trained model archive.
"""
new_fingerprint = model.model_fingerprint(
config, domain, nlu_data_directory, story_directory
)
Expand Down Expand Up @@ -258,24 +304,25 @@ async def train_core_async(
)
return None

story_directory = data.get_core_directory(stories, skill_imports)
train_context = TempDirectoryPath(data.get_core_directory(stories, skill_imports))

if not os.listdir(story_directory):
print_error(
"No stories given. Please provide stories in order to "
"train a Rasa Core model using the '--stories' argument."
)
return
with train_context as story_directory:
if not os.listdir(story_directory):
print_error(
"No stories given. Please provide stories in order to "
"train a Rasa Core model using the '--stories' argument."
)
return

return await _train_core_with_validated_data(
domain=domain,
config=config,
story_directory=story_directory,
output=output,
train_path=train_path,
fixed_model_name=fixed_model_name,
kwargs=kwargs,
)
return await _train_core_with_validated_data(
domain=domain,
config=config,
story_directory=story_directory,
output=output,
train_path=train_path,
fixed_model_name=fixed_model_name,
kwargs=kwargs,
)


async def _train_core_with_validated_data(
Expand All @@ -291,33 +338,39 @@ async def _train_core_with_validated_data(

import rasa.core.train

_train_path = train_path or tempfile.mkdtemp()

# normal (not compare) training
print_color("Training Core model...", color=bcolors.OKBLUE)
await rasa.core.train(
domain_file=domain,
stories_file=story_directory,
output_path=os.path.join(_train_path, "core"),
policy_config=config,
kwargs=kwargs,
)
print_color("Core model training completed.", color=bcolors.OKBLUE)

if train_path is None:
# Only Core was trained.
new_fingerprint = model.model_fingerprint(
config, domain, stories=story_directory
)
return _package_model(
new_fingerprint=new_fingerprint,
output_path=output,
train_path=_train_path,
fixed_model_name=fixed_model_name,
model_prefix="core-",
with ExitStack() as stack:
if train_path:
# If the train path was provided, do nothing on exit.
_train_path = train_path
else:
# Otherwise, create a temp train path and clean it up on exit.
_train_path = stack.enter_context(TempDirectoryPath(tempfile.mkdtemp()))

# normal (not compare) training
print_color("Training Core model...", color=bcolors.OKBLUE)
await rasa.core.train(
domain_file=domain,
stories_file=story_directory,
output_path=os.path.join(_train_path, "core"),
policy_config=config,
kwargs=kwargs,
)
print_color("Core model training completed.", color=bcolors.OKBLUE)

return _train_path
if train_path is None:
# Only Core was trained.
new_fingerprint = model.model_fingerprint(
config, domain, stories=story_directory
)
return _package_model(
new_fingerprint=new_fingerprint,
output_path=output,
train_path=_train_path,
fixed_model_name=fixed_model_name,
model_prefix="core-",
)

return _train_path


def train_nlu(
Expand Down Expand Up @@ -346,22 +399,23 @@ def train_nlu(

# training NLU only hence the training files still have to be selected
skill_imports = SkillSelector.load(config, nlu_data)
nlu_data_directory = data.get_nlu_directory(nlu_data, skill_imports)
train_context = TempDirectoryPath(data.get_nlu_directory(nlu_data, skill_imports))

if not os.listdir(nlu_data_directory):
print_error(
"No NLU data given. Please provide NLU data in order to train "
"a Rasa NLU model using the '--nlu' argument."
)
return
with train_context as nlu_data_directory:
if not os.listdir(nlu_data_directory):
print_error(
"No NLU data given. Please provide NLU data in order to train "
"a Rasa NLU model using the '--nlu' argument."
)
return

return _train_nlu_with_validated_data(
config=config,
nlu_data_directory=nlu_data_directory,
output=output,
train_path=train_path,
fixed_model_name=fixed_model_name,
)
return _train_nlu_with_validated_data(
config=config,
nlu_data_directory=nlu_data_directory,
output=output,
train_path=train_path,
fixed_model_name=fixed_model_name,
)


def _train_nlu_with_validated_data(
Expand All @@ -375,27 +429,35 @@ def _train_nlu_with_validated_data(

import rasa.nlu.train

_train_path = train_path or tempfile.mkdtemp()

print_color("Training NLU model...", color=bcolors.OKBLUE)
_, nlu_model, _ = rasa.nlu.train(
config, nlu_data_directory, _train_path, fixed_model_name="nlu"
)
print_color("NLU model training completed.", color=bcolors.OKBLUE)
with ExitStack() as stack:
if train_path:
# If the train path was provided, do nothing on exit.
_train_path = train_path
else:
# Otherwise, create a temp train path and clean it up on exit.
_train_path = stack.enter_context(TempDirectoryPath(tempfile.mkdtemp()))

print_color("Training NLU model...", color=bcolors.OKBLUE)
_, nlu_model, _ = rasa.nlu.train(
config, nlu_data_directory, _train_path, fixed_model_name="nlu"
)
print_color("NLU model training completed.", color=bcolors.OKBLUE)

if train_path is None:
# Only NLU was trained
new_fingerprint = model.model_fingerprint(config, nlu_data=nlu_data_directory)
if train_path is None:
# Only NLU was trained
new_fingerprint = model.model_fingerprint(
config, nlu_data=nlu_data_directory
)

return _package_model(
new_fingerprint=new_fingerprint,
output_path=output,
train_path=_train_path,
fixed_model_name=fixed_model_name,
model_prefix="nlu-",
)
return _package_model(
new_fingerprint=new_fingerprint,
output_path=output,
train_path=_train_path,
fixed_model_name=fixed_model_name,
model_prefix="nlu-",
)

return _train_path
return _train_path


def _package_model(
Expand Down
Loading

0 comments on commit 644f247

Please sign in to comment.