Skip to content

Commit

Permalink
fix the "loading model" message which was logged twice when using ras…
Browse files Browse the repository at this point in the history
…a run.

change formatting

fix model path return
  • Loading branch information
Anca Lita authored and ancalita committed Mar 23, 2021
1 parent f56c8b8 commit 2566827
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 7 deletions.
1 change: 1 addition & 0 deletions changelog/7260.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed the 'loading model' message which was logged twice when using `rasa run`.
Binary file added other
Binary file not shown.
4 changes: 2 additions & 2 deletions rasa/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def run(args: argparse.Namespace):
# make sure either a model server, a remote storage, or a local model is
# configured

from rasa.model import get_model
from rasa.model import verify_model_path, get_model
from rasa.core.utils import AvailableEndpoints

# start server if remote storage is configured
Expand All @@ -113,7 +113,7 @@ def run(args: argparse.Namespace):
args.model = _validate_model_path(args.model, "model", DEFAULT_MODELS_PATH)
local_model_set = True
try:
get_model(args.model)
verify_model_path(args.model)
except ModelNotFound:
local_model_set = False

Expand Down
29 changes: 24 additions & 5 deletions rasa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,15 @@ def should_retrain_nlu(self) -> bool:
return self.force_training or self.nlu


def get_model(model_path: Text = DEFAULT_MODELS_PATH) -> TempDirectoryPath:
"""Get a model and unpack it. Raises a `ModelNotFound` exception if
no model could be found at the provided path.
def verify_model_path(model_path: Text = DEFAULT_MODELS_PATH) -> Text:
"""Verifies that a model path exists.
Args:
model_path: Path to the zipped model. If it's a directory, the latest
trained model is returned.
Returns:
Path to the unpacked model.
Raises:
ModelNotFound Exception: When no model could be found at the provided path.
"""
if not model_path:
Expand All @@ -154,10 +153,30 @@ def get_model(model_path: Text = DEFAULT_MODELS_PATH) -> TempDirectoryPath:
elif not model_path.endswith(".tar.gz"):
raise ModelNotFound(f"Path '{model_path}' does not point to a Rasa model file.")

return model_path


def get_model(model_path: Text = DEFAULT_MODELS_PATH) -> TempDirectoryPath:
"""Get a model and unpack it.
Args:
model_path: Path to the zipped model. If it's a directory, the latest
trained model is returned.
Returns:
Path to the unpacked model.
Raises:
ModelNotFound Exception: When no model could be found at the provided path.
"""
model_path = verify_model_path(model_path)

try:
model_relative_path = os.path.relpath(model_path)
except ValueError:
model_relative_path = model_path

logger.info(f"Loading model {model_relative_path}...")

return unpack_model(model_path)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
can_finetune,
create_package_rasa,
get_latest_model,
verify_model_path,
get_model,
get_model_subdirectories,
model_fingerprint,
Expand Down Expand Up @@ -78,6 +79,12 @@ def test_get_model_context_manager(trained_rasa_model: str):
assert not os.path.exists(unpacked)


@pytest.mark.parametrize("model_path", ["foobar", "rasa", "README.md", None])
def test_verify_model_path_exception(model_path: Optional[Text]):
with pytest.raises(ModelNotFound):
verify_model_path(model_path)


@pytest.mark.parametrize("model_path", ["foobar", "rasa", "README.md", None])
def test_get_model_exception(model_path: Optional[Text]):
with pytest.raises(ModelNotFound):
Expand Down

0 comments on commit 2566827

Please sign in to comment.