Skip to content

Commit

Permalink
Merge pull request optuna#4696 from c-bata/cli-journal-storage-support
Browse files Browse the repository at this point in the history
Support `JournalFileStorage` and `JournalRedisStorage` on CLI
  • Loading branch information
not522 committed Jul 5, 2023
2 parents 1b0c6a8 + fcc1a1d commit 7e58315
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 26 deletions.
89 changes: 63 additions & 26 deletions optuna/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
from typing import Union
import warnings

import sqlalchemy.exc
import yaml

import optuna
from optuna._imports import _LazyImport
from optuna.exceptions import CLIUsageError
from optuna.exceptions import ExperimentalWarning
from optuna.storages import BaseStorage
from optuna.storages import JournalFileStorage
from optuna.storages import JournalRedisStorage
from optuna.storages import JournalStorage
from optuna.storages import RDBStorage
from optuna.trial import TrialState

Expand All @@ -51,6 +56,27 @@ def _check_storage_url(storage_url: Optional[str]) -> str:
raise CLIUsageError("Storage URL is not specified.")


def _get_storage(storage_url: Optional[str], storage_class: Optional[str]) -> BaseStorage:
storage_url = _check_storage_url(storage_url)
if storage_class:
if storage_class == JournalRedisStorage.__name__:
return JournalStorage(JournalRedisStorage(storage_url))
if storage_class == JournalFileStorage.__name__:
return JournalStorage(JournalFileStorage(storage_url))
if storage_class == RDBStorage.__name__:
return RDBStorage(storage_url)
raise CLIUsageError("Unsupported storage class")

if storage_url.startswith("redis"):
return JournalStorage(JournalRedisStorage(storage_url))
if os.path.isfile(storage_url):
return JournalStorage(JournalFileStorage(storage_url))
try:
return RDBStorage(storage_url)
except sqlalchemy.exc.ArgumentError:
raise CLIUsageError("Failed to guess storage class from storage_url")


def _format_value(value: Any) -> Any:
# Format value that can be serialized to JSON or YAML.
if value is None or isinstance(value, (int, float)):
Expand Down Expand Up @@ -293,8 +319,7 @@ def add_arguments(self, parser: ArgumentParser) -> None:
)

def take_action(self, parsed_args: Namespace) -> int:
storage_url = _check_storage_url(parsed_args.storage)
storage = optuna.storages.get_storage(storage_url)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
study_name = optuna.create_study(
storage=storage,
study_name=parsed_args.study_name,
Expand All @@ -313,8 +338,7 @@ def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument("--study-name", default=None, help="The name of the study to delete.")

def take_action(self, parsed_args: Namespace) -> int:
storage_url = _check_storage_url(parsed_args.storage)
storage = optuna.storages.get_storage(storage_url)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
study_id = storage.get_study_id_from_name(parsed_args.study_name)
storage.delete_study(study_id)
return 0
Expand All @@ -336,7 +360,7 @@ def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument("--value", required=True, help="Value to be set.")

def take_action(self, parsed_args: Namespace) -> int:
storage_url = _check_storage_url(parsed_args.storage)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)

if parsed_args.study and parsed_args.study_name:
raise ValueError(
Expand All @@ -346,9 +370,9 @@ def take_action(self, parsed_args: Namespace) -> int:
elif parsed_args.study:
message = "The use of `--study` is deprecated. Please use `--study-name` instead."
warnings.warn(message, FutureWarning)
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study)
study = optuna.load_study(storage=storage, study_name=parsed_args.study)
elif parsed_args.study_name:
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name)
study = optuna.load_study(storage=storage, study_name=parsed_args.study_name)
else:
raise ValueError("Missing study name. Please use `--study-name`.")

Expand Down Expand Up @@ -385,8 +409,8 @@ def add_arguments(self, parser: ArgumentParser) -> None:
)

def take_action(self, parsed_args: Namespace) -> int:
storage_url = _check_storage_url(parsed_args.storage)
summaries = optuna.get_all_study_summaries(storage_url, include_best_trial=False)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
summaries = optuna.get_all_study_summaries(storage, include_best_trial=False)

records = []
for s in summaries:
Expand Down Expand Up @@ -444,8 +468,8 @@ def take_action(self, parsed_args: Namespace) -> int:
ExperimentalWarning,
)

storage_url = _check_storage_url(parsed_args.storage)
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
study = optuna.load_study(storage=storage, study_name=parsed_args.study_name)
attrs = (
"number",
"value" if not study._is_multi_objective() else "values",
Expand Down Expand Up @@ -494,8 +518,8 @@ def take_action(self, parsed_args: Namespace) -> int:
ExperimentalWarning,
)

storage_url = _check_storage_url(parsed_args.storage)
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
study = optuna.load_study(storage=storage, study_name=parsed_args.study_name)
attrs = (
"number",
"value" if not study._is_multi_objective() else "values",
Expand Down Expand Up @@ -548,8 +572,8 @@ def take_action(self, parsed_args: Namespace) -> int:
ExperimentalWarning,
)

storage_url = _check_storage_url(parsed_args.storage)
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
study = optuna.load_study(storage=storage, study_name=parsed_args.study_name)
best_trials = [trial.number for trial in study.best_trials]
attrs = (
"number",
Expand Down Expand Up @@ -613,7 +637,7 @@ def take_action(self, parsed_args: Namespace) -> int:
)
warnings.warn(message, FutureWarning)

storage_url = _check_storage_url(parsed_args.storage)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)

if parsed_args.study and parsed_args.study_name:
raise ValueError(
Expand All @@ -623,9 +647,9 @@ def take_action(self, parsed_args: Namespace) -> int:
elif parsed_args.study:
message = "The use of `--study` is deprecated. Please use `--study-name` instead."
warnings.warn(message, FutureWarning)
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study)
study = optuna.load_study(storage=storage, study_name=parsed_args.study)
elif parsed_args.study_name:
study = optuna.load_study(storage=storage_url, study_name=parsed_args.study_name)
study = optuna.load_study(storage=storage, study_name=parsed_args.study_name)
else:
raise ValueError("Missing study name. Please use `--study-name`.")

Expand Down Expand Up @@ -656,14 +680,17 @@ def take_action(self, parsed_args: Namespace) -> int:


class _StorageUpgrade(_BaseCommand):
"""Upgrade the schema of a storage."""
"""Upgrade the schema of an RDB storage."""

def take_action(self, parsed_args: Namespace) -> int:
storage_url = _check_storage_url(parsed_args.storage)
if storage_url.startswith("redis"):
self.logger.info("This storage does not support upgrade yet.")
try:
storage = RDBStorage(
storage_url, skip_compatibility_check=True, skip_table_creation=True
)
except sqlalchemy.exc.ArgumentError:
self.logger.error("Invalid RDBStorage URL.")
return 1
storage = RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
current_version = storage.get_current_version()
head_version = storage.get_head_version()
known_versions = storage.get_all_versions()
Expand Down Expand Up @@ -741,10 +768,10 @@ def take_action(self, parsed_args: Namespace) -> int:
ExperimentalWarning,
)

storage_url = _check_storage_url(parsed_args.storage)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)

create_study_kwargs = {
"storage": storage_url,
"storage": storage,
"study_name": parsed_args.study_name,
"direction": parsed_args.direction,
"directions": parsed_args.directions,
Expand Down Expand Up @@ -859,10 +886,10 @@ def take_action(self, parsed_args: Namespace) -> int:
ExperimentalWarning,
)

storage_url = _check_storage_url(parsed_args.storage)
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)

study = optuna.load_study(
storage=storage_url,
storage=storage,
study_name=parsed_args.study_name,
)

Expand Down Expand Up @@ -910,6 +937,16 @@ def _add_common_arguments(parser: ArgumentParser) -> ArgumentParser:
"Also can be specified via OPTUNA_STORAGE environment variable."
),
)
parser.add_argument(
"--storage-class",
help="Storage class hint (e.g. JournalFileStorage)",
default=None,
choices=[
RDBStorage.__name__,
JournalFileStorage.__name__,
JournalRedisStorage.__name__,
],
)
verbose_group = parser.add_mutually_exclusive_group()
verbose_group.add_argument(
"-v",
Expand Down
62 changes: 62 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from collections import OrderedDict
import json
import os
import platform
import re
import subprocess
from subprocess import CalledProcessError
import tempfile
from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple
from unittest.mock import MagicMock
from unittest.mock import patch

import fakeredis
import numpy as np
from pandas import Timedelta
from pandas import Timestamp
Expand All @@ -20,6 +24,9 @@
import optuna.cli
from optuna.exceptions import CLIUsageError
from optuna.exceptions import ExperimentalWarning
from optuna.storages import JournalFileStorage
from optuna.storages import JournalRedisStorage
from optuna.storages import JournalStorage
from optuna.storages import RDBStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.study import StudyDirection
Expand Down Expand Up @@ -1049,6 +1056,51 @@ def test_check_storage_url() -> None:
optuna.cli._check_storage_url(None)


@pytest.mark.skipif(platform.system() == "Windows", reason="Skip on Windows")
@patch("optuna.storages._journal.redis.redis")
def test_get_storage_without_storage_class(mock_redis: MagicMock) -> None:
with tempfile.NamedTemporaryFile(suffix=".db") as fp:
storage = optuna.cli._get_storage(f"sqlite:///{fp.name}", storage_class=None)
assert isinstance(storage, RDBStorage)

with tempfile.NamedTemporaryFile(suffix=".log") as fp:
storage = optuna.cli._get_storage(fp.name, storage_class=None)
assert isinstance(storage, JournalStorage)
assert isinstance(storage._backend, JournalFileStorage)

mock_redis.Redis = fakeredis.FakeRedis
storage = optuna.cli._get_storage("redis://localhost:6379", storage_class=None)
assert isinstance(storage, JournalStorage)
assert isinstance(storage._backend, JournalRedisStorage)

with pytest.raises(CLIUsageError):
optuna.cli._get_storage("./file-not-found.log", storage_class=None)


@pytest.mark.skipif(platform.system() == "Windows", reason="Skip on Windows")
@patch("optuna.storages._journal.redis.redis")
def test_get_storage_with_storage_class(mock_redis: MagicMock) -> None:
with tempfile.NamedTemporaryFile(suffix=".db") as fp:
storage = optuna.cli._get_storage(f"sqlite:///{fp.name}", storage_class=None)
assert isinstance(storage, RDBStorage)

with tempfile.NamedTemporaryFile(suffix=".log") as fp:
storage = optuna.cli._get_storage(fp.name, storage_class="JournalFileStorage")
assert isinstance(storage, JournalStorage)
assert isinstance(storage._backend, JournalFileStorage)

mock_redis.Redis = fakeredis.FakeRedis
storage = optuna.cli._get_storage(
"redis:///localhost:6379", storage_class="JournalRedisStorage"
)
assert isinstance(storage, JournalStorage)
assert isinstance(storage._backend, JournalRedisStorage)

with pytest.raises(CLIUsageError):
with tempfile.NamedTemporaryFile(suffix=".db") as fp:
optuna.cli._get_storage(f"sqlite:///{fp.name}", storage_class="InMemoryStorage")


@pytest.mark.skip_coverage
def test_storage_upgrade_command() -> None:
with StorageSupplier("sqlite") as storage:
Expand All @@ -1066,6 +1118,16 @@ def test_storage_upgrade_command() -> None:
subprocess.check_call(command)


@pytest.mark.skip_coverage
def test_storage_upgrade_command_with_invalid_url() -> None:
with StorageSupplier("sqlite") as storage:
assert isinstance(storage, RDBStorage)

command = ["optuna", "storage", "upgrade", "--storage", "invalid-storage-url"]
with pytest.raises(CalledProcessError):
subprocess.check_call(command)


@pytest.mark.skip_coverage
@pytest.mark.parametrize(
"direction,directions,sampler,sampler_kwargs,output_format",
Expand Down

0 comments on commit 7e58315

Please sign in to comment.