diff --git a/optuna/cli.py b/optuna/cli.py index ecd2485d1a..ffba0e3724 100644 --- a/optuna/cli.py +++ b/optuna/cli.py @@ -21,12 +21,17 @@ from typing import Union import warnings +import sqlalchemy.exc import yaml import optuna from optuna._imports import _LazyImport from optuna.exceptions import CLIUsageError from optuna.exceptions import ExperimentalWarning +from optuna.storages import BaseStorage +from optuna.storages import JournalFileStorage +from optuna.storages import JournalRedisStorage +from optuna.storages import JournalStorage from optuna.storages import RDBStorage from optuna.trial import TrialState @@ -51,6 +56,27 @@ def _check_storage_url(storage_url: Optional[str]) -> str: raise CLIUsageError("Storage URL is not specified.") +def _get_storage(storage_url: Optional[str], storage_class: Optional[str]) -> BaseStorage: + storage_url = _check_storage_url(storage_url) + if storage_class: + if storage_class == JournalRedisStorage.__name__: + return JournalStorage(JournalRedisStorage(storage_url)) + if storage_class == JournalFileStorage.__name__: + return JournalStorage(JournalFileStorage(storage_url)) + if storage_class == RDBStorage.__name__: + return RDBStorage(storage_url) + raise CLIUsageError("Unsupported storage class") + + if storage_url.startswith("redis"): + return JournalStorage(JournalRedisStorage(storage_url)) + if os.path.isfile(storage_url): + return JournalStorage(JournalFileStorage(storage_url)) + try: + return RDBStorage(storage_url) + except sqlalchemy.exc.ArgumentError: + raise CLIUsageError("Failed to guess storage class from storage_url") + + def _format_value(value: Any) -> Any: # Format value that can be serialized to JSON or YAML. if value is None or isinstance(value, (int, float)): @@ -293,8 +319,7 @@ def add_arguments(self, parser: ArgumentParser) -> None: ) def take_action(self, parsed_args: Namespace) -> int: - storage_url = _check_storage_url(parsed_args.storage) - storage = optuna.storages.get_storage(storage_url) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) study_name = optuna.create_study( storage=storage, study_name=parsed_args.study_name, @@ -313,8 +338,7 @@ def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument("--study-name", default=None, help="The name of the study to delete.") def take_action(self, parsed_args: Namespace) -> int: - storage_url = _check_storage_url(parsed_args.storage) - storage = optuna.storages.get_storage(storage_url) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) study_id = storage.get_study_id_from_name(parsed_args.study_name) storage.delete_study(study_id) return 0 @@ -336,7 +360,7 @@ def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument("--value", required=True, help="Value to be set.") def take_action(self, parsed_args: Namespace) -> int: - storage_url = _check_storage_url(parsed_args.storage) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) if parsed_args.study and parsed_args.study_name: raise ValueError( @@ -346,9 +370,9 @@ def take_action(self, parsed_args: Namespace) -> int: elif parsed_args.study: message = "The use of `--study` is deprecated. Please use `--study-name` instead." warnings.warn(message, FutureWarning) - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study) + study = optuna.load_study(storage=storage, study_name=parsed_args.study) elif parsed_args.study_name: - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name) + study = optuna.load_study(storage=storage, study_name=parsed_args.study_name) else: raise ValueError("Missing study name. Please use `--study-name`.") @@ -385,8 +409,8 @@ def add_arguments(self, parser: ArgumentParser) -> None: ) def take_action(self, parsed_args: Namespace) -> int: - storage_url = _check_storage_url(parsed_args.storage) - summaries = optuna.get_all_study_summaries(storage_url, include_best_trial=False) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) + summaries = optuna.get_all_study_summaries(storage, include_best_trial=False) records = [] for s in summaries: @@ -444,8 +468,8 @@ def take_action(self, parsed_args: Namespace) -> int: ExperimentalWarning, ) - storage_url = _check_storage_url(parsed_args.storage) - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) + study = optuna.load_study(storage=storage, study_name=parsed_args.study_name) attrs = ( "number", "value" if not study._is_multi_objective() else "values", @@ -494,8 +518,8 @@ def take_action(self, parsed_args: Namespace) -> int: ExperimentalWarning, ) - storage_url = _check_storage_url(parsed_args.storage) - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) + study = optuna.load_study(storage=storage, study_name=parsed_args.study_name) attrs = ( "number", "value" if not study._is_multi_objective() else "values", @@ -548,8 +572,8 @@ def take_action(self, parsed_args: Namespace) -> int: ExperimentalWarning, ) - storage_url = _check_storage_url(parsed_args.storage) - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) + study = optuna.load_study(storage=storage, study_name=parsed_args.study_name) best_trials = [trial.number for trial in study.best_trials] attrs = ( "number", @@ -613,7 +637,7 @@ def take_action(self, parsed_args: Namespace) -> int: ) warnings.warn(message, FutureWarning) - storage_url = _check_storage_url(parsed_args.storage) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) if parsed_args.study and parsed_args.study_name: raise ValueError( @@ -623,9 +647,9 @@ def take_action(self, parsed_args: Namespace) -> int: elif parsed_args.study: message = "The use of `--study` is deprecated. Please use `--study-name` instead." warnings.warn(message, FutureWarning) - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study) + study = optuna.load_study(storage=storage, study_name=parsed_args.study) elif parsed_args.study_name: - study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name) + study = optuna.load_study(storage=storage, study_name=parsed_args.study_name) else: raise ValueError("Missing study name. Please use `--study-name`.") @@ -656,14 +680,17 @@ def take_action(self, parsed_args: Namespace) -> int: class _StorageUpgrade(_BaseCommand): - """Upgrade the schema of a storage.""" + """Upgrade the schema of an RDB storage.""" def take_action(self, parsed_args: Namespace) -> int: storage_url = _check_storage_url(parsed_args.storage) - if storage_url.startswith("redis"): - self.logger.info("This storage does not support upgrade yet.") + try: + storage = RDBStorage( + storage_url, skip_compatibility_check=True, skip_table_creation=True + ) + except sqlalchemy.exc.ArgumentError: + self.logger.error("Invalid RDBStorage URL.") return 1 - storage = RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True) current_version = storage.get_current_version() head_version = storage.get_head_version() known_versions = storage.get_all_versions() @@ -741,10 +768,10 @@ def take_action(self, parsed_args: Namespace) -> int: ExperimentalWarning, ) - storage_url = _check_storage_url(parsed_args.storage) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) create_study_kwargs = { - "storage": storage_url, + "storage": storage, "study_name": parsed_args.study_name, "direction": parsed_args.direction, "directions": parsed_args.directions, @@ -859,10 +886,10 @@ def take_action(self, parsed_args: Namespace) -> int: ExperimentalWarning, ) - storage_url = _check_storage_url(parsed_args.storage) + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) study = optuna.load_study( - storage=storage_url, + storage=storage, study_name=parsed_args.study_name, ) @@ -910,6 +937,16 @@ def _add_common_arguments(parser: ArgumentParser) -> ArgumentParser: "Also can be specified via OPTUNA_STORAGE environment variable." ), ) + parser.add_argument( + "--storage-class", + help="Storage class hint (e.g. JournalFileStorage)", + default=None, + choices=[ + RDBStorage.__name__, + JournalFileStorage.__name__, + JournalRedisStorage.__name__, + ], + ) verbose_group = parser.add_mutually_exclusive_group() verbose_group.add_argument( "-v", diff --git a/tests/test_cli.py b/tests/test_cli.py index cfd9026014..63a5ca29e8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,15 +1,19 @@ from collections import OrderedDict import json import os +import platform import re import subprocess from subprocess import CalledProcessError +import tempfile from typing import Any from typing import Callable from typing import Optional from typing import Tuple +from unittest.mock import MagicMock from unittest.mock import patch +import fakeredis import numpy as np from pandas import Timedelta from pandas import Timestamp @@ -20,6 +24,9 @@ import optuna.cli from optuna.exceptions import CLIUsageError from optuna.exceptions import ExperimentalWarning +from optuna.storages import JournalFileStorage +from optuna.storages import JournalRedisStorage +from optuna.storages import JournalStorage from optuna.storages import RDBStorage from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX from optuna.study import StudyDirection @@ -1049,6 +1056,51 @@ def test_check_storage_url() -> None: optuna.cli._check_storage_url(None) +@pytest.mark.skipif(platform.system() == "Windows", reason="Skip on Windows") +@patch("optuna.storages._journal.redis.redis") +def test_get_storage_without_storage_class(mock_redis: MagicMock) -> None: + with tempfile.NamedTemporaryFile(suffix=".db") as fp: + storage = optuna.cli._get_storage(f"sqlite:///{fp.name}", storage_class=None) + assert isinstance(storage, RDBStorage) + + with tempfile.NamedTemporaryFile(suffix=".log") as fp: + storage = optuna.cli._get_storage(fp.name, storage_class=None) + assert isinstance(storage, JournalStorage) + assert isinstance(storage._backend, JournalFileStorage) + + mock_redis.Redis = fakeredis.FakeRedis + storage = optuna.cli._get_storage("redis://localhost:6379", storage_class=None) + assert isinstance(storage, JournalStorage) + assert isinstance(storage._backend, JournalRedisStorage) + + with pytest.raises(CLIUsageError): + optuna.cli._get_storage("./file-not-found.log", storage_class=None) + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Skip on Windows") +@patch("optuna.storages._journal.redis.redis") +def test_get_storage_with_storage_class(mock_redis: MagicMock) -> None: + with tempfile.NamedTemporaryFile(suffix=".db") as fp: + storage = optuna.cli._get_storage(f"sqlite:///{fp.name}", storage_class=None) + assert isinstance(storage, RDBStorage) + + with tempfile.NamedTemporaryFile(suffix=".log") as fp: + storage = optuna.cli._get_storage(fp.name, storage_class="JournalFileStorage") + assert isinstance(storage, JournalStorage) + assert isinstance(storage._backend, JournalFileStorage) + + mock_redis.Redis = fakeredis.FakeRedis + storage = optuna.cli._get_storage( + "redis:///localhost:6379", storage_class="JournalRedisStorage" + ) + assert isinstance(storage, JournalStorage) + assert isinstance(storage._backend, JournalRedisStorage) + + with pytest.raises(CLIUsageError): + with tempfile.NamedTemporaryFile(suffix=".db") as fp: + optuna.cli._get_storage(f"sqlite:///{fp.name}", storage_class="InMemoryStorage") + + @pytest.mark.skip_coverage def test_storage_upgrade_command() -> None: with StorageSupplier("sqlite") as storage: @@ -1066,6 +1118,16 @@ def test_storage_upgrade_command() -> None: subprocess.check_call(command) +@pytest.mark.skip_coverage +def test_storage_upgrade_command_with_invalid_url() -> None: + with StorageSupplier("sqlite") as storage: + assert isinstance(storage, RDBStorage) + + command = ["optuna", "storage", "upgrade", "--storage", "invalid-storage-url"] + with pytest.raises(CalledProcessError): + subprocess.check_call(command) + + @pytest.mark.skip_coverage @pytest.mark.parametrize( "direction,directions,sampler,sampler_kwargs,output_format",