diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d88c8be7d..91d04d00f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,6 +27,7 @@ jobs: strategy: matrix: toxenv: [pylint, doc8, docs] + steps: - uses: actions/checkout@v2 - name: Set up Python 3.9 diff --git a/dashboard/src/src/__tests__/flattenObject.test.js b/dashboard/src/src/__tests__/flattenObject.test.js index f7092ca69..0c0c95207 100644 --- a/dashboard/src/src/__tests__/flattenObject.test.js +++ b/dashboard/src/src/__tests__/flattenObject.test.js @@ -4,7 +4,7 @@ test('test flatten object', () => { const input = { a: 1, b: { - ba: 'world', + be: 'world', bb: 1.5, bc: { bd: { @@ -30,7 +30,7 @@ test('test flatten object', () => { const keys = Object.keys(output); expect(keys.length).toBe(12); expect(output.hasOwnProperty('a')).toBeTruthy(); - expect(output.hasOwnProperty('b.ba')).toBeTruthy(); + expect(output.hasOwnProperty('b.be')).toBeTruthy(); expect(output.hasOwnProperty('b.bb')).toBeTruthy(); expect(output.hasOwnProperty('b.bc.bd.a key')).toBeTruthy(); expect(output.hasOwnProperty('b.bc.bd.another key.x')).toBeTruthy(); diff --git a/docs/scripts/build_database_and_plots.py b/docs/scripts/build_database_and_plots.py index 7d75774ff..31f73cdc1 100644 --- a/docs/scripts/build_database_and_plots.py +++ b/docs/scripts/build_database_and_plots.py @@ -5,8 +5,7 @@ import subprocess from orion.client import get_experiment -from orion.core.utils.singleton import update_singletons -from orion.storage.base import get_storage, setup_storage +from orion.storage.base import setup_storage ROOT_DIR = os.path.abspath(os.path.dirname(__file__)) DOC_SRC_DIR = os.path.join(ROOT_DIR, "..", "src") @@ -69,9 +68,8 @@ def prepare_dbs(): def setup_tmp_storage(host): # Clear singletons - update_singletons() - setup_storage( + return setup_storage( storage={ "type": "legacy", "database": { @@ -81,8 +79,6 @@ def setup_tmp_storage(host): } ) - return get_storage() - def load_data(host): print("Loading data from", host) @@ -102,7 +98,7 @@ def copy_data(data, host=TMP_DB_HOST): storage = setup_tmp_storage(host) for exp_id, experiment in data["experiments"].items(): del experiment["_id"] - storage.create_experiment(experiment) + storage.create_experiment(experiment, storage=storage) assert exp_id != experiment["_id"] trials = [] for trial in data["trials"][exp_id]: diff --git a/docs/scripts/filter_database.py b/docs/scripts/filter_database.py index bcf847c3e..8b4b7e1f3 100644 --- a/docs/scripts/filter_database.py +++ b/docs/scripts/filter_database.py @@ -5,11 +5,11 @@ import shutil from orion.core.io.orion_cmdline_parser import OrionCmdlineParser -from orion.storage.base import get_storage, setup_storage +from orion.storage.base import setup_storage shutil.copy("./examples/plotting/database.pkl", "./examples/base_db.pkl") -setup_storage( +storage = setup_storage( dict( type="legacy", database=dict(type="pickleddb", host="./examples/base_db.pkl"), @@ -23,8 +23,6 @@ ("lateral-view-multitask3", 1): "3-dim-cat-shape-exp", } -storage = get_storage() - def update_dropout(experiment_config): metadata = experiment_config["metadata"] diff --git a/examples/benchmark/profet_benchmark.py b/examples/benchmark/profet_benchmark.py index c8795879d..5bfb6c3b1 100644 --- a/examples/benchmark/profet_benchmark.py +++ b/examples/benchmark/profet_benchmark.py @@ -22,6 +22,7 @@ ProfetXgBoostTask, ) from orion.benchmark.task.profet.profet_task import MetaModelConfig, ProfetTask +from orion.storage.base import setup_storage try: from simple_parsing.helpers import choice @@ -103,7 +104,15 @@ def main(config: ProfetExperimentConfig): print(f"Storage file used: {config.storage_pickle_path}") + storage = setup_storage( + { + "type": "legacy", + "database": {"type": "pickleddb", "host": str(config.storage_pickle_path)}, + } + ) + benchmark = get_or_create_benchmark( + storage, name=config.name, algorithms=config.algorithms, targets=[ @@ -114,10 +123,6 @@ def main(config: ProfetExperimentConfig): ], } ], - storage={ - "type": "legacy", - "database": {"type": "pickleddb", "host": str(config.storage_pickle_path)}, - }, debug=config.debug, ) benchmark.setup_studies() diff --git a/src/orion/benchmark/__init__.py b/src/orion/benchmark/__init__.py index a3cb6481b..ce28428a9 100644 --- a/src/orion/benchmark/__init__.py +++ b/src/orion/benchmark/__init__.py @@ -12,6 +12,7 @@ import orion.core from orion.client import create_experiment from orion.executor.base import executor_factory +from orion.storage.base import BaseStorageProtocol class Benchmark: @@ -20,8 +21,12 @@ class Benchmark: Parameters ---------- + storage: Storage + Instance of the storage to use + name: str Name of the benchmark + algorithms: list, optional Algorithms used for benchmark, and for each algorithm, it can be formats as below: @@ -49,19 +54,26 @@ class Benchmark: task: list Task objects - storage: dict, optional - Configuration of the storage backend. executor: `orion.executor.base.BaseExecutor`, optional Executor to run the benchmark experiments """ - def __init__(self, name, algorithms, targets, storage=None, executor=None): + def __init__( + self, + storage, + name, + algorithms, + targets, + executor=None, + ): + assert isinstance(storage, BaseStorageProtocol) + self._id = None self.name = name self.algorithms = algorithms self.targets = targets self.metadata = {} - self.storage_config = storage + self.storage = storage self._executor = executor self._executor_owner = False @@ -353,7 +365,7 @@ def setup_experiments(self): space=space, algorithms=algorithm.experiment_algorithm, max_trials=max_trials, - storage=self.benchmark.storage_config, + storage=self.benchmark.storage, executor=executor, ) self.experiments_info.append((task_index, experiment)) diff --git a/src/orion/benchmark/benchmark_client.py b/src/orion/benchmark/benchmark_client.py index f5062ec6b..507bb76b0 100644 --- a/src/orion/benchmark/benchmark_client.py +++ b/src/orion/benchmark/benchmark_client.py @@ -11,19 +11,24 @@ from orion.benchmark.task.base import bench_task_factory from orion.core.io.database import DuplicateKeyError from orion.core.utils.exceptions import NoConfigurationError -from orion.storage.base import get_storage, setup_storage logger = logging.getLogger(__name__) def get_or_create_benchmark( - name, algorithms=None, targets=None, storage=None, executor=None, debug=False + storage, + name, + algorithms=None, + targets=None, + executor=None, ): """ Create or get a benchmark object. Parameters ---------- + storage: BaseStorageProtocol + Instance of the storage to use name: str Name of the benchmark algorithms: list, optional @@ -35,30 +40,25 @@ def get_or_create_benchmark( Assessment objects task: list Task objects - storage: dict, optional - Configuration of the storage backend. executor: `orion.executor.base.BaseExecutor`, optional Executor to run the benchmark experiments - debug: bool, optional - If using in debug mode, the storage config is overridden with legacy:EphemeralDB. - Defaults to False. Returns ------- An instance of `orion.benchmark.Benchmark` """ - setup_storage(storage=storage, debug=debug) # fetch benchmark from db - db_config = _fetch_benchmark(name) + db_config = _fetch_benchmark(storage, name) benchmark_id = None input_configure = None if db_config: if algorithms or targets: - input_benchmark = Benchmark(name, algorithms, targets) + input_benchmark = Benchmark(storage, name, algorithms, targets) input_configure = input_benchmark.configuration + benchmark_id, algorithms, targets = _resolve_db_config(db_config) if not algorithms or not targets: @@ -68,7 +68,11 @@ def get_or_create_benchmark( ) benchmark = _create_benchmark( - name, algorithms, targets, storage=storage, executor=executor + storage, + name, + algorithms, + targets, + executor=executor, ) if input_configure and input_benchmark.configuration != benchmark.configuration: @@ -80,7 +84,7 @@ def get_or_create_benchmark( if benchmark_id is None: logger.debug("Benchmark not found in DB. Now attempting registration in DB.") try: - _register_benchmark(benchmark) + _register_benchmark(storage, benchmark) logger.debug("Benchmark successfully registered in DB.") except DuplicateKeyError: logger.info( @@ -89,7 +93,11 @@ def get_or_create_benchmark( ) benchmark.close() benchmark = get_or_create_benchmark( - name, algorithms, targets, storage, executor, debug + storage, + name, + algorithms, + targets, + executor, ) return benchmark @@ -132,9 +140,9 @@ def _resolve_db_config(db_config): return benchmark_id, algorithms, targets -def _create_benchmark(name, algorithms, targets, storage, executor): +def _create_benchmark(storage, name, algorithms, targets, executor): - benchmark = Benchmark(name, algorithms, targets, storage, executor) + benchmark = Benchmark(storage, name, algorithms, targets, executor) benchmark.setup_studies() return benchmark @@ -147,12 +155,8 @@ def _create_study(benchmark, algorithms, assess, task): return study -def _fetch_benchmark(name): - - if name: - configs = get_storage().fetch_benchmark({"name": name}) - else: - configs = get_storage().fetch_benchmark({}) +def _fetch_benchmark(storage, name): + configs = storage.fetch_benchmark({"name": name}) if not configs: return {} @@ -160,9 +164,9 @@ def _fetch_benchmark(name): return configs[0] -def _register_benchmark(benchmark): +def _register_benchmark(storage, benchmark): benchmark.metadata["datetime"] = datetime.datetime.utcnow() config = benchmark.configuration # This will raise DuplicateKeyError if a concurrent experiment with # identical (name, metadata.user) is written first in the database. - get_storage().create_benchmark(config) + storage.create_benchmark(config) diff --git a/src/orion/client/__init__.py b/src/orion/client/__init__.py index 222e1f8fa..e48569510 100644 --- a/src/orion/client/__init__.py +++ b/src/orion/client/__init__.py @@ -17,7 +17,6 @@ ) from orion.client.experiment import ExperimentClient from orion.core.utils.exceptions import RaceCondition -from orion.core.utils.singleton import update_singletons from orion.core.worker.producer import Producer from orion.storage.base import setup_storage @@ -184,9 +183,6 @@ def build_experiment( Raises ------ - :class:`orion.core.utils.singleton.SingletonAlreadyInstantiatedError` - If the storage is already instantiated and given configuration is different. - Storage is a singleton, you may only use one instance per process. :class:`orion.core.utils.exceptions.NoConfigurationError` The experiment is not in database and no space is provided by the user. :class:`orion.core.utils.exceptions.RaceCondition` @@ -209,10 +205,10 @@ def build_experiment( "max_idle_time is deprecated. Use experiment.workon(reservation_timeout) instead." ) - setup_storage(storage=storage, debug=debug) + builder = experiment_builder.ExperimentBuilder(storage, debug) try: - experiment = experiment_builder.build( + experiment = builder.build( name, version=version, space=space, @@ -227,7 +223,7 @@ def build_experiment( # Try again, but if it fails again, raise. Race conditions due to version increment should # only occur once in a short window of time unless code version is changing at a crazy pace. try: - experiment = experiment_builder.build( + experiment = builder.build( name, version=version, space=space, @@ -279,9 +275,9 @@ def get_experiment(name, version=None, mode="r", storage=None): `orion.core.utils.exceptions.NoConfigurationError` The experiment is not in the database provided by the user. """ - setup_storage(storage) assert mode in set("rw") - experiment = experiment_builder.load(name, version, mode) + + experiment = experiment_builder.load(name, version, mode, storage=storage) return ExperimentClient(experiment) @@ -323,27 +319,21 @@ def workon( If the algorithm specified is not properly installed. """ - # Clear singletons and keep pointers to restore them. - singletons = update_singletons() - - try: - setup_storage(storage={"type": "legacy", "database": {"type": "EphemeralDB"}}) - - experiment = experiment_builder.build( - name, - version=1, - space=space, - algorithms=algorithms, - max_trials=max_trials, - max_broken=max_broken, - ) - - experiment_client = ExperimentClient(experiment) - with experiment_client.tmp_executor("singleexecutor", n_workers=1): - experiment_client.workon(function, n_workers=1, max_trials=max_trials) - - finally: - # Restore singletons - update_singletons(singletons) + experiment = experiment_builder.build( + name, + version=1, + space=space, + algorithms=algorithms, + max_trials=max_trials, + max_broken=max_broken, + storage={"type": "legacy", "database": {"type": "EphemeralDB"}}, + ) + + producer = Producer(experiment) + + experiment_client = ExperimentClient(experiment, producer) + + with experiment_client.tmp_executor("singleexecutor", n_workers=1): + experiment_client.workon(function, n_workers=1, max_trials=max_trials) return experiment_client diff --git a/src/orion/client/cli.py b/src/orion/client/cli.py index a25c87de2..7ca3ed00c 100644 --- a/src/orion/client/cli.py +++ b/src/orion/client/cli.py @@ -10,6 +10,7 @@ IS_ORION_ON = False _HAS_REPORTED_RESULTS = False RESULTS_FILENAME = os.getenv("ORION_RESULTS_PATH", None) + if RESULTS_FILENAME and os.path.isfile(RESULTS_FILENAME): import json @@ -17,8 +18,8 @@ if RESULTS_FILENAME and not IS_ORION_ON: raise RuntimeWarning( - "Results file path provided in environmental variable " - "does not correspond to an existing file." + f"Results file path ({RESULTS_FILENAME}) provided in environmental variable " + "does not correspond to an existing file. " ) diff --git a/src/orion/client/experiment.py b/src/orion/client/experiment.py index 1052b9506..40356e805 100644 --- a/src/orion/client/experiment.py +++ b/src/orion/client/experiment.py @@ -848,7 +848,7 @@ def _verify_reservation(self, trial): raise RuntimeError(f"Reservation for trial {trial.id} has been lost.") def _maintain_reservation(self, trial): - self._pacemakers[trial.id] = TrialPacemaker(trial) + self._pacemakers[trial.id] = TrialPacemaker(trial, self.storage) self._pacemakers[trial.id].start() def _release_reservation(self, trial, raise_if_unreserved=True): @@ -861,3 +861,8 @@ def _release_reservation(self, trial, raise_if_unreserved=True): return self._pacemakers.pop(trial.id).stop() + + @property + def storage(self): + """Return the storage currently in use by this client""" + return self._experiment.storage diff --git a/src/orion/core/cli/base.py b/src/orion/core/cli/base.py index 9014948fd..2577a8426 100644 --- a/src/orion/core/cli/base.py +++ b/src/orion/core/cli/base.py @@ -4,6 +4,7 @@ """ import argparse import logging +import os import sys import textwrap @@ -50,6 +51,13 @@ def __init__(self, description=CLI_DOC_HEADER): help="logging levels of information about the process (-v: INFO. -vv: DEBUG)", ) + self.parser.add_argument( + "--logdir", + type=str, + default=None, + help="Path to a directory to store logs", + ) + self.parser.add_argument( "-d", "--debug", @@ -69,9 +77,19 @@ def parse(self, argv): verbose = args.pop("verbose", 0) levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + + logdir = args.pop("logdir") + logfile = None + if logdir is not None: + os.makedirs(logdir, exist_ok=True) + + pid = os.getpid() + logfile = os.path.join(logdir, f"orion_{pid}.log") + logging.basicConfig( format="%(asctime)-15s::%(levelname)s::%(name)s::%(message)s", level=levels.get(verbose, logging.DEBUG), + filename=logfile, ) logger.debug("Orion version : %s", orion.core.__version__) diff --git a/src/orion/core/cli/db/release.py b/src/orion/core/cli/db/release.py index 04f963ef6..22871129d 100644 --- a/src/orion/core/cli/db/release.py +++ b/src/orion/core/cli/db/release.py @@ -13,7 +13,6 @@ from orion.core.io import experiment_builder from orion.core.utils.pptree import print_tree from orion.core.utils.terminal import confirm_name -from orion.storage.base import get_storage logger = logging.getLogger(__name__) @@ -100,16 +99,14 @@ def release_locks(storage, root, name, force): def main(args): """Remove the experiment(s) or trial(s).""" config = experiment_builder.get_cmd_config(args) - experiment_builder.setup_storage(config.get("storage")) + builder = experiment_builder.ExperimentBuilder(config.get("storage")) # Find root experiment - root = experiment_builder.load( - name=args["name"], version=args.get("version", None) - ).node + root = builder.load(name=args["name"], version=args.get("version", None)).node # List all experiments with children print_tree(root, nameattr="tree_name") - storage = get_storage() + storage = builder.storage release_locks(storage, root, args["name"], args["force"]) diff --git a/src/orion/core/cli/db/rm.py b/src/orion/core/cli/db/rm.py index ab7634895..a31478b21 100644 --- a/src/orion/core/cli/db/rm.py +++ b/src/orion/core/cli/db/rm.py @@ -13,7 +13,6 @@ from orion.core.io import experiment_builder from orion.core.utils.pptree import print_tree from orion.core.utils.terminal import confirm_name -from orion.storage.base import get_storage logger = logging.getLogger(__name__) @@ -185,17 +184,15 @@ def delete_trials(storage, root, name, status, force): def main(args): """Remove the experiment(s) or trial(s).""" config = experiment_builder.get_cmd_config(args) - experiment_builder.setup_storage(config.get("storage")) + builder = experiment_builder.ExperimentBuilder(config.get("storage")) # Find root experiment - root = experiment_builder.load( - name=args["name"], version=args.get("version", None) - ).node + root = builder.load(name=args["name"], version=args.get("version", None)).node # List all experiments with children print_tree(root, nameattr="tree_name") - storage = get_storage() + storage = builder.storage if args["status"]: delete_trials(storage, root, args["name"], args["status"], args["force"]) diff --git a/src/orion/core/cli/db/set.py b/src/orion/core/cli/db/set.py index 874392392..371d6186b 100644 --- a/src/orion/core/cli/db/set.py +++ b/src/orion/core/cli/db/set.py @@ -14,7 +14,6 @@ import orion.core.io.experiment_builder as experiment_builder from orion.core.utils.pptree import print_tree from orion.core.utils.terminal import confirm_name -from orion.storage.base import get_storage logger = logging.getLogger(__name__) @@ -168,12 +167,10 @@ def build_update(update): def main(args): """Remove the experiment(s) or trial(s).""" config = experiment_builder.get_cmd_config(args) - experiment_builder.setup_storage(config.get("storage")) + builder = experiment_builder.ExperimentBuilder(config.get("storage")) # Find root experiment - root = experiment_builder.load( - name=args["name"], version=args.get("version", None) - ).node + root = builder.load(name=args["name"], version=args.get("version", None)).node try: query = build_query(root.item, args["query"]) @@ -195,8 +192,6 @@ def main(args): print("Confirmation failed, aborting operation.") return 1 - storage = get_storage() - - process_updates(storage, root, query, update) + process_updates(builder.storage, root, query, update) return 0 diff --git a/src/orion/core/cli/db/upgrade.py b/src/orion/core/cli/db/upgrade.py index 59ff4e5ca..a829e8c43 100644 --- a/src/orion/core/cli/db/upgrade.py +++ b/src/orion/core/cli/db/upgrade.py @@ -14,7 +14,6 @@ from orion.core.io.database.ephemeraldb import EphemeralCollection from orion.core.io.database.mongodb import MongoDB from orion.core.io.database.pickleddb import PickledDB -from orion.storage.base import get_storage from orion.storage.legacy import Legacy log = logging.getLogger(__name__) @@ -95,14 +94,12 @@ def main(args): storage_config["setup"] = False - experiment_builder.setup_storage(storage_config) + builder = experiment_builder.ExperimentBuilder(storage_config) - storage = get_storage() - - upgrade_db_specifics(storage) + upgrade_db_specifics(builder.storage) print("Updating documents...") - upgrade_documents(storage) + upgrade_documents(builder.storage) print("Database upgrade completed successfully") diff --git a/src/orion/core/cli/list.py b/src/orion/core/cli/list.py index 79c16d361..6d7e82dcd 100644 --- a/src/orion/core/cli/list.py +++ b/src/orion/core/cli/list.py @@ -11,7 +11,6 @@ from orion.core.cli import base as cli from orion.core.io import experiment_builder from orion.core.utils.pptree import print_tree -from orion.storage.base import get_storage log = logging.getLogger(__name__) SHORT_DESCRIPTION = "Gives a list of experiments" @@ -36,7 +35,7 @@ def add_subparser(parser): def main(args): """List all experiments inside database.""" config = experiment_builder.get_cmd_config(args) - experiment_builder.setup_storage(config.get("storage")) + builder = experiment_builder.ExperimentBuilder(config.get("storage")) query = {} @@ -44,7 +43,7 @@ def main(args): query["name"] = args["name"] query["version"] = args.get("version", None) or 1 - experiments = get_storage().fetch_experiments(query) + experiments = builder.storage.fetch_experiments(query) if args["name"]: root_experiments = experiments @@ -60,7 +59,7 @@ def main(args): return for root_experiment in root_experiments: - root = experiment_builder.load( + root = builder.load( name=root_experiment["name"], version=root_experiment.get("version") ).node print_tree(root, nameattr="tree_name") diff --git a/src/orion/core/cli/serve.py b/src/orion/core/cli/serve.py index ad9e39339..eb16c56f4 100644 --- a/src/orion/core/cli/serve.py +++ b/src/orion/core/cli/serve.py @@ -13,6 +13,7 @@ from orion.core.io import experiment_builder from orion.serving.webapi import WebApi +from orion.storage.base import setup_storage log = logging.getLogger(__name__) DESCRIPTION = "Starts OrĂ­on's REST API server" @@ -39,7 +40,8 @@ def main(args): """Starts an application server to serve http requests""" config = experiment_builder.get_cmd_config(args) - web_api = WebApi(config) + storage = setup_storage(config.get("storage")) + web_api = WebApi(storage, config) gunicorn_app = GunicornApp(web_api) gunicorn_app.run() diff --git a/src/orion/core/cli/status.py b/src/orion/core/cli/status.py index 0a54d7068..0cbcae4b2 100644 --- a/src/orion/core/cli/status.py +++ b/src/orion/core/cli/status.py @@ -13,7 +13,7 @@ from orion.core.cli import base as cli from orion.core.io import experiment_builder -from orion.storage.base import get_storage +from orion.storage.base import setup_storage log = logging.getLogger(__name__) SHORT_DESCRIPTION = "Gives an overview of experiments' trials" @@ -67,11 +67,11 @@ def add_subparser(parser): def main(args): """Fetch config and status experiments""" config = experiment_builder.get_cmd_config(args) - experiment_builder.setup_storage(config.get("storage")) + storage = setup_storage(config.get("storage")) args["all_trials"] = args.pop("all", False) - experiments = get_experiments(args) + experiments = get_experiments(storage, args) if not experiments: print("No experiment found") @@ -121,7 +121,7 @@ def print_evc( print_status(experiment, all_trials=all_trials, collapse=True) -def get_experiments(args): +def get_experiments(storage, args): """Return the different experiments. Parameters @@ -133,7 +133,7 @@ def get_experiments(args): projection = {"name": 1, "version": 1, "refers": 1} query = {"name": args["name"]} if args.get("name") else {} - experiments = get_storage().fetch_experiments(query, projection) + experiments = storage.fetch_experiments(query, projection) if args["name"]: root_experiments = experiments diff --git a/src/orion/core/evc/conflicts.py b/src/orion/core/evc/conflicts.py index 0e5d16e7f..92128254d 100644 --- a/src/orion/core/evc/conflicts.py +++ b/src/orion/core/evc/conflicts.py @@ -58,7 +58,6 @@ from orion.core.io.space_builder import SpaceBuilder from orion.core.utils.diff import colored_diff from orion.core.utils.format_trials import standard_param_name -from orion.storage.base import get_storage log = logging.getLogger(__name__) @@ -288,8 +287,15 @@ def try_resolve(self, conflict, *args, **kwargs): except Exception: # pylint:disable=broad-except conflict.resolution = None conflict._is_resolved = None # pylint:disable=protected-access + + msg = traceback.format_exc() + + # no silence error in debug mode + log.debug("%s", msg) + if not silence_errors: - print(traceback.format_exc()) + print(msg) + return None if resolution: @@ -372,7 +378,7 @@ def get_marked_arguments(self, conflicts, **branching_kwargs): return {} @abstractmethod - def try_resolve(self): + def try_resolve(self, *args, **kwargs): """Try to create a resolution Conflict is then marked as resolved and its attribute `resolution` now points to the @@ -564,7 +570,7 @@ def __init__(self, old_config, new_config, dimension, prior): self.dimension = dimension self.prior = prior - def try_resolve(self, default_value=Dimension.NO_DEFAULT_VALUE): + def try_resolve(self, default_value=Dimension.NO_DEFAULT_VALUE, *args, **kwargs): """Try to create a resolution AddDimensionResolution Parameters @@ -703,7 +709,7 @@ def __init__(self, old_config, new_config, dimension, old_prior, new_prior): self.old_prior = old_prior self.new_prior = new_prior - def try_resolve(self): + def try_resolve(self, *args, **kwargs): """Try to create a resolution ChangeDimensionResolution""" if self.is_resolved: return None @@ -868,7 +874,11 @@ def get_marked_rename_arguments(self, conflicts): return {} def try_resolve( - self, new_dimension_conflict=None, default_value=Dimension.NO_DEFAULT_VALUE + self, + new_dimension_conflict=None, + default_value=Dimension.NO_DEFAULT_VALUE, + *args, + **kwargs, ): """Try to create a resolution RenameDimensionResolution of RemoveDimensionResolution @@ -1076,7 +1086,7 @@ def detect(cls, old_config, new_config, branching_config=None): if old_config["algorithms"] != new_config["algorithms"]: yield cls(old_config, new_config) - def try_resolve(self): + def try_resolve(self, *args, **kwargs): """Try to create a resolution AlgorithmResolution""" if self.is_resolved: return None @@ -1171,7 +1181,7 @@ def get_marked_arguments( return dict(change_type=code_change_type) - def try_resolve(self, change_type=None): + def try_resolve(self, change_type=None, *args, **kwargs): """Try to create a resolution CodeResolution Parameters @@ -1339,7 +1349,7 @@ def get_marked_arguments(self, conflicts, cli_change_type=None, **branching_kwar return dict(change_type=cli_change_type) - def try_resolve(self, change_type=None): + def try_resolve(self, change_type=None, *args, **kwargs): """Try to create a resolution CommandLineResolution Parameters @@ -1492,7 +1502,7 @@ def get_marked_arguments( return dict(change_type=config_change_type) - def try_resolve(self, change_type=None): + def try_resolve(self, change_type=None, *args, **kwargs): """Try to create a resolution ScriptConfigResolution Parameters @@ -1611,7 +1621,7 @@ def version(self): """Retrieve version of configuration""" return self.old_config["version"] - def try_resolve(self, new_name=None): + def try_resolve(self, new_name=None, storage=None, *args, **kwargs): """Try to create a resolution ExperimentNameResolution Parameters @@ -1629,7 +1639,7 @@ def try_resolve(self, new_name=None): if self.is_resolved: return None - return self.ExperimentNameResolution(self, new_name) + return self.ExperimentNameResolution(self, new_name, storage=storage) @property def diff(self): @@ -1660,7 +1670,7 @@ class ExperimentNameResolution(Resolution): ARGUMENT = "--branch-to" - def __init__(self, conflict, new_name): + def __init__(self, conflict, new_name, storage=None): """Initialize resolution and mark conflict as resolved Parameters @@ -1685,18 +1695,18 @@ def __init__(self, conflict, new_name): self.old_name = self.conflict.old_config["name"] self.old_version = self.conflict.old_config.get("version", 1) self.new_version = self.old_version - self.validate() + self.validate(storage=storage) self.conflict.new_config["name"] = self.new_name self.conflict.new_config["version"] = self.new_version - def _validate(self): + def _validate(self, storage=None): """Validate new_name is not in database with a direct child for current version""" # TODO: WARNING!!! _name_is_unique could lead to race conditions, # The resolution may become invalid before the branching experiment is # registered. What should we do in such case? if self.new_name is not None and self.new_name != self.old_name: # If we are trying to actually branch from experiment - if not self._name_is_unique(): + if not self._name_is_unique(storage): raise ValueError( f"Cannot branch from {self.old_name} with name {self.new_name} " "since it already exists." @@ -1706,7 +1716,7 @@ def _validate(self): # If the new name is the same as the old name, we are trying to increment # the version of the experiment. - elif self._check_for_greater_versions(): + elif self._check_for_greater_versions(storage): raise ValueError( f"Experiment name '{self.new_name}' already exist for version " f"'{self.conflict.version}' and has children. Version cannot be " @@ -1716,20 +1726,20 @@ def _validate(self): self.new_name = self.old_name self.new_version = self.conflict.old_config.get("version", 1) + 1 - def _name_is_unique(self): + def _name_is_unique(self, storage): """Return True if given name is not in database for current version""" query = {"name": self.new_name, "version": self.conflict.version} - named_experiments = len(get_storage().fetch_experiments(query)) + named_experiments = len(storage.fetch_experiments(query)) return named_experiments == 0 - def _check_for_greater_versions(self): + def _check_for_greater_versions(self, storage): """Check if experiment has children""" # If we made it this far, new_name is actually the name of the parent. parent = self.conflict.old_config query = {"name": parent["name"], "refers.parent_id": parent["_id"]} - children = len(get_storage().fetch_experiments(query)) + children = len(storage.fetch_experiments(query)) return bool(children) @@ -1772,7 +1782,7 @@ def detect(cls, old_config, new_config, branching_config=None): ): yield cls(old_config, new_config) - def try_resolve(self): + def try_resolve(self, *args, **kwargs): """Try to create a resolution OrionVersionResolution""" if self.is_resolved: return None diff --git a/src/orion/core/evc/experiment.py b/src/orion/core/evc/experiment.py index f75c447be..a4a6393eb 100644 --- a/src/orion/core/evc/experiment.py +++ b/src/orion/core/evc/experiment.py @@ -17,7 +17,6 @@ import logging from orion.core.utils.tree import TreeNode -from orion.storage.base import get_storage log = logging.getLogger(__name__) @@ -47,9 +46,18 @@ class ExperimentNode(TreeNode): "version", "_no_parent_lookup", "_no_children_lookup", + "storage", ) + TreeNode.__slots__ - def __init__(self, name, version, experiment=None, parent=None, children=tuple()): + def __init__( + self, + name, + version, + experiment=None, + parent=None, + children=tuple(), + storage=None, + ): """Initialize experiment node with item, experiment, parent and children .. seealso:: @@ -61,6 +69,7 @@ def __init__(self, name, version, experiment=None, parent=None, children=tuple() self._no_parent_lookup = True self._no_children_lookup = True + self.storage = storage or experiment._storage @property def item(self): @@ -73,7 +82,9 @@ def item(self): # TODO: Find another way around the circular import from orion.core.io import experiment_builder - self._item = experiment_builder.load(name=self.name, version=self.version) + self._item = experiment_builder.load( + name=self.name, version=self.version, storage=self.storage + ) self._item._node = self return self._item @@ -92,12 +103,14 @@ def parent(self): self._no_parent_lookup = False query = {"_id": self.item.refers.get("parent_id")} selection = {"name": 1, "version": 1} - experiments = get_storage().fetch_experiments(query, selection) + experiments = self.storage.fetch_experiments(query, selection) if experiments: parent = experiments[0] exp_node = ExperimentNode( - name=parent["name"], version=parent.get("version", 1) + name=parent["name"], + version=parent.get("version", 1), + storage=self.storage, ) self.set_parent(exp_node) return self._parent @@ -117,10 +130,14 @@ def children(self): self._no_children_lookup = False query = {"refers.parent_id": self.item.id} selection = {"name": 1, "version": 1} - experiments = get_storage().fetch_experiments(query, selection) + experiments = self.storage.fetch_experiments(query, selection) for child in experiments: self.add_children( - ExperimentNode(name=child["name"], version=child.get("version", 1)) + ExperimentNode( + name=child["name"], + version=child.get("version", 1), + storage=self.storage, + ) ) return self._children diff --git a/src/orion/core/io/config.py b/src/orion/core/io/config.py index 808c2b2ce..b67ed4230 100644 --- a/src/orion/core/io/config.py +++ b/src/orion/core/io/config.py @@ -11,6 +11,7 @@ import contextlib import logging import os +import pprint import yaml @@ -451,3 +452,27 @@ def to_dict(self): config[key] = self[key].to_dict() return config + + def from_dict(self, config): + """Set the configuration from a dictionary""" + + logger.debug("Setting config to %s", config) + logger.debug("Config was %s", repr(self)) + + with _disable_logger(): + for key in self._config: # pylint: disable=consider-using-dict-items + value = config.get(key, NOT_SET) + + if value is not NOT_SET: + self[key] = value + else: + self._config[key].pop("value", None) + + for key in self._subconfigs: + value = config.get(key) + self[key].from_dict(value) + + logger.debug("Config is %s", repr(self)) + + def __repr__(self) -> str: + return pprint.pformat(self.to_dict()) diff --git a/src/orion/core/io/database/__init__.py b/src/orion/core/io/database/__init__.py index a62bc75de..2e4c60399 100644 --- a/src/orion/core/io/database/__init__.py +++ b/src/orion/core/io/database/__init__.py @@ -12,7 +12,7 @@ import logging from abc import abstractmethod, abstractproperty -from orion.core.utils.singleton import GenericSingletonFactory +from orion.core.utils import Factory, GenericFactory # pylint: disable=too-many-public-methods @@ -320,7 +320,7 @@ class OutdatedDatabaseError(DatabaseError): """Exception type used when the database is outdated.""" -database_factory = GenericSingletonFactory(Database) +database_factory = GenericFactory(Database) # set per-module log level diff --git a/src/orion/core/io/experiment_branch_builder.py b/src/orion/core/io/experiment_branch_builder.py index ecf9c1719..ae509d782 100644 --- a/src/orion/core/io/experiment_branch_builder.py +++ b/src/orion/core/io/experiment_branch_builder.py @@ -54,8 +54,14 @@ class ExperimentBranchBuilder: """ def __init__( - self, conflicts, enabled=True, manual_resolution=None, **branching_arguments + self, + conflicts, + enabled=True, + manual_resolution=None, + storage=None, + **branching_arguments, ): + self.storage = storage # TODO: handle all other arguments if manual_resolution is None: manual_resolution = orion.core.config.evc.manual_resolution @@ -95,6 +101,7 @@ def resolve_conflicts(self, silence_errors=True): resolution = self.conflicts.try_resolve( conflict, silence_errors=silence_errors, + storage=self.storage, **conflict.get_marked_arguments( self.conflicts, **self.branching_arguments ), @@ -133,7 +140,7 @@ def change_experiment_name(self, name): if not exp_name_conflicts: raise RuntimeError("No experiment name conflict to solve") - self.conflicts.try_resolve(exp_name_conflicts[0], name) + self.conflicts.try_resolve(exp_name_conflicts[0], name, storage=self.storage) def set_code_change_type(self, change_type): """Set code change type diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index 833ebaa41..59912cb58 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -102,7 +102,7 @@ ) from orion.core.worker.experiment import Experiment, Mode from orion.core.worker.primary_algo import create_algo -from orion.storage.base import get_storage, setup_storage +from orion.storage.base import setup_storage log = logging.getLogger(__name__) @@ -112,116 +112,6 @@ ## -def build( - name: str, version: int | None = None, branching: dict | None = None, **config -) -> Experiment: - """Build an experiment object - - If new, ``space`` argument must be provided, else all arguments are fetched from the database - based on (name, version). If any argument given does not match the corresponding ones in the - database for given (name, version), than the version is incremented and the experiment will be a - child of the previous version. - - Parameters - ---------- - name: str - Name of the experiment to build - version: int, optional - Version to select. If None, last version will be selected. If version given is larger than - largest version available, the largest version will be selected. - space: dict, optional - Optimization space of the algorithm. Should have the form ``dict(name='(args)')``. - algorithms: str or dict, optional - Algorithm used for optimization. - strategy: str or dict, optional - Deprecated and will be remove in v0.4. It should now be set in algorithm configuration - directly if it supports it. - max_trials: int, optional - Maximum number of trials before the experiment is considered done. - max_broken: int, optional - Number of broken trials for the experiment to be considered broken. - storage: dict, optional - Configuration of the storage backend. - - branching: dict, optional - Arguments to control the branching. - - branch_from: str, optional - Name of the experiment to branch from. - manual_resolution: bool, optional - Starts the prompt to resolve manually the conflicts. Defaults to False. - non_monitored_arguments: list of str, optional - Will ignore these arguments while looking for differences. Defaults to []. - ignore_code_changes: bool, optional - Will ignore code changes while looking for differences. Defaults to False. - algorithm_change: bool, optional - Whether to automatically solve the algorithm conflict (change of algo config). - Defaults to True. - orion_version_change: bool, optional - Whether to automatically solve the orion version conflict. - Defaults to True. - code_change_type: str, optional - How to resolve code change automatically. Must be one of 'noeffect', 'unsure' or - 'break'. Defaults to 'break'. - cli_change_type: str, optional - How to resolve cli change automatically. Must be one of 'noeffect', 'unsure' or 'break'. - Defaults to 'break'. - config_change_type: str, optional - How to resolve config change automatically. Must be one of 'noeffect', 'unsure' or - 'break'. Defaults to 'break'. - - """ - log.debug(f"Building experiment {name} with {version}") - log.debug(" Passed experiment config:\n%s", pprint.pformat(config)) - log.debug(" Branching config:\n%s", pprint.pformat(branching)) - - name, config, branching = clean_config(name, config, branching) - - config = consolidate_config(name, version, config) - - if "space" not in config: - raise NoConfigurationError( - f"Experiment {name} does not exist in DB and space was not defined." - ) - - if len(config["space"]) == 0: - raise NoConfigurationError("No prior found. Please include at least one.") - - experiment = create_experiment(mode="x", **copy.deepcopy(config)) - if experiment.id is None: - log.debug("Experiment not found in DB. Now attempting registration in DB.") - try: - _register_experiment(experiment) - log.debug("Experiment successfully registered in DB.") - except DuplicateKeyError: - log.debug( - "Experiment registration failed. This is likely due to a race condition. " - "Now rolling back and re-attempting building it." - ) - experiment = build(branching=branching, **config) - - return experiment - - log.debug(f"Experiment {config['name']}-v{config['version']} already existed.") - - conflicts = _get_conflicts(experiment, branching) - assert branching is not None - must_branch = len(conflicts.get()) > 1 or branching.get("branch_to") - - if must_branch and branching.get("enable", orion.core.config.evc.enable): - return _attempt_branching(conflicts, experiment, version, branching) - elif must_branch: - log.warning( - "Running experiment in a different state:\n%s", - _get_branching_status_string(conflicts, branching), - ) - - log.debug("No branching required.") - - _update_experiment(experiment) - return experiment - - def clean_config(name: str, config: dict, branching: dict | None): """Clean configuration from hidden fields (ex: ``_id``) and update branching if necessary""" log.debug("Cleaning config") @@ -250,38 +140,6 @@ def clean_config(name: str, config: dict, branching: dict | None): return name, config, branching -def consolidate_config(name: str, version: int | None, config: dict): - """Merge together given configuration with db configuration matching - for experiment (``name``, ``version``) - """ - db_config = fetch_config_from_db(name, version) - - # Do not merge spaces, the new definition overrides it. - if "space" in config: - db_config.pop("space", None) - - log.debug("Merging user and db configs:") - log.debug(" config from user:\n%s", pprint.pformat(config)) - log.debug(" config from DB:\n%s", pprint.pformat(db_config)) - - new_config = config - config = resolve_config.merge_configs(db_config, config) - - config.setdefault("metadata", {}) - resolve_config.update_metadata(config["metadata"]) - - merge_algorithm_config(config, new_config) - # TODO: Remove for v0.4 - merge_producer_config(config, new_config) - - config.setdefault("name", name) - config.setdefault("version", version) - - log.debug(" Merged config:\n%s", pprint.pformat(config)) - - return config - - def merge_algorithm_config(config: dict, new_config: dict) -> None: """Merge given algorithm configuration with db config""" # TODO: Find a better solution @@ -308,183 +166,6 @@ def merge_producer_config(config: dict, new_config: dict) -> None: config["producer"]["strategy"] = new_config["producer"]["strategy"] -def build_view(name, version=None): - """Load experiment from database - - This function is deprecated and will be remove in v0.3.0. Use `load()` instead. - """ - return load(name, version=version, mode="r") - - -def load(name: str, version: int | None = None, mode: Mode = "r") -> Experiment: - """Load experiment from database - - An experiment view provides all reading operations of standard experiment but prevents the - modification of the experiment and its trials. - - Parameters - ---------- - name: str - Name of the experiment to build - version: int, optional - Version to select. If None, last version will be selected. If version given is larger than - largest version available, the largest version will be selected. - mode: str, optional - The access rights of the experiment on the database. - 'r': read access only - 'w': can read and write to database - Default is 'r' - - """ - assert mode in set("rw") - - log.debug( - f"Loading experiment {name} (version={version}) from database in mode `{mode}`" - ) - db_config = fetch_config_from_db(name, version) - - if not db_config: - raise NoConfigurationError( - f"No experiment with given name '{name}' " - f"and version '{version if version else '*'}' inside database, " - "no view can be created." - ) - - db_config.setdefault("version", 1) - - return create_experiment(mode=mode, **db_config) - - -# pylint: disable=too-many-arguments -def create_experiment( - name: str, - version: int, - mode: Mode, - space: Space | dict[str, str], - algorithms: str | dict | None = None, - max_trials: int | None = None, - max_broken: int | None = None, - working_dir: str | None = None, - metadata: dict | None = None, - refers: dict | None = None, - producer: dict | None = None, - user: str | None = None, - _id: int | str | None = None, - **kwargs, -) -> Experiment: - """Instantiate the experiment and its attribute objects - - All unspecified arguments will be replaced by system's defaults (orion.core.config.*). - - Parameters - ---------- - name: str - Name of the experiment. - version: int - Version of the experiment. - mode: str - The access rights of the experiment on the database. - 'r': read access only - 'w': can read and write to database - 'x': can read and write to database, algo is instantiated and can execute optimization - space: dict or Space object - Optimization space of the algorithm. If dict, should have the form - `dict(name='(args)')`. - algorithms: str or dict, optional - Algorithm used for optimization. - strategy: str or dict, optional - Parallel strategy to use to parallelize the algorithm. - max_trials: int, optional - Maximum number or trials before the experiment is considered done. - max_broken: int, optional - Number of broken trials for the experiment to be considered broken. - storage: dict, optional - Configuration of the storage backend. - - """ - - T = TypeVar("T") - V = TypeVar("V") - - def _default(v: T | None, default: V) -> T | V: - return v if v is not None else default - - space = _instantiate_space(space) - max_trials = _default(max_trials, orion.core.config.experiment.max_trials) - instantiated_algorithm = _instantiate_algo( - space=space, - max_trials=max_trials, - config=algorithms, - ignore_unavailable=mode != "x", - ) - - max_broken = _default(max_broken, orion.core.config.experiment.max_broken) - working_dir = _default(working_dir, orion.core.config.experiment.working_dir) - metadata = _default(metadata, {"user": _default(user, getpass.getuser())}) - refers = _default(refers, dict(parent_id=None, root_id=None, adapter=[])) - refers["adapter"] = _instantiate_adapters(refers.get("adapter", [])) # type: ignore - - # TODO: Remove for v0.4 - _instantiate_strategy((producer or {}).get("strategy")) - - experiment = Experiment( - name=name, - version=version, - mode=mode, - space=space, - _id=_id, - max_trials=max_trials, - algorithms=instantiated_algorithm, - max_broken=max_broken, - working_dir=working_dir, - metadata=metadata, - refers=refers, - ) - log.debug( - "Created experiment with config:\n%s", pprint.pformat(experiment.configuration) - ) - if kwargs: - # TODO: https://github.com/Epistimio/orion/issues/972 - log.debug("create_experiment received some extra unused arguments: %s", kwargs) - - return experiment - - -def fetch_config_from_db(name, version=None): - """Fetch configuration from database - - Parameters - ---------- - name: str - Name of the experiment to fetch - version: int, optional - Version to select. If None, last version will be selected. If version given is larger than - largest version available, the largest version will be selected. - - """ - configs = get_storage().fetch_experiments({"name": name}) - - if not configs: - return {} - - config = _fetch_config_version(configs, version) - - if len(configs) > 1 and version is None: - log.info( - "Many versions for experiment %s have been found. Using latest " - "version %s.", - name, - config["version"], - ) - - log.debug("Config found in DB:\n%s", pprint.pformat(config)) - - backward.populate_space(config, force_update=False) - backward.update_max_broken(config) - - return config - - ## # Private helper functions to build experiments ## @@ -577,141 +258,6 @@ def _instantiate_strategy(config=None): return None -def _register_experiment(experiment): - """Register a new experiment in the database""" - experiment.metadata["datetime"] = datetime.datetime.utcnow() - config = experiment.configuration - # This will raise DuplicateKeyError if a concurrent experiment with - # identical (name, metadata.user) is written first in the database. - - get_storage().create_experiment(config) - - # XXX: Reminder for future DB implementations: - # MongoDB, updates an inserted dict with _id, so should you :P - experiment._id = config["_id"] # pylint:disable=protected-access - - # Update refers in db if experiment is root - if experiment.refers.get("parent_id") is None: - log.debug("update refers (name: %s)", experiment.name) - experiment.refers["root_id"] = experiment.id - get_storage().update_experiment( - experiment, refers=experiment.configuration["refers"] - ) - - -def _update_experiment(experiment: Experiment) -> None: - """Update experiment configuration in database""" - log.debug("Updating experiment (name: %s)", experiment.name) - config = experiment.configuration - - # TODO: Remove since this should not occur anymore without metadata.user in the indices? - # Writing the final config to an already existing experiment raises - # a DuplicatKeyError because of the embedding id `metadata.user`. - # To avoid this `final_config["name"]` is popped out before - # `db.write()`, thus seamingly breaking the compound index - # `(name, metadata.user)` - config.pop("name") - - get_storage().update_experiment(experiment, **config) - - log.debug("Experiment configuration successfully updated in DB.") - - -def _attempt_branching(conflicts, experiment, version, branching): - if len(conflicts.get()) > 1: - log.debug("Experiment must branch because of conflicts") - else: - assert branching.get("branch_to") - log.debug("Experiment branching forced with ``branch_to``") - branched_experiment = _branch_experiment(experiment, conflicts, version, branching) - log.debug("Now attempting registration of branched experiment in DB.") - try: - _register_experiment(branched_experiment) - log.debug("Branched experiment successfully registered in DB.") - except DuplicateKeyError as e: - log.debug( - "Experiment registration failed. This is likely due to a race condition " - "during branching. Now rolling back and re-attempting building " - "the branched experiment." - ) - raise RaceCondition( - "There was a race condition during branching. This error can " - "also occur if you try branching from a specific version that already " - "has a child experiment with the same name. Change the name of the new " - "experiment and use `branch-from` to specify the parent experiment." - ) from e - - return branched_experiment - - -def _get_branching_status_string(conflicts, branching_arguments): - experiment_brancher = ExperimentBranchBuilder( - conflicts, enabled=False, **branching_arguments - ) - branching_prompt = BranchingPrompt(experiment_brancher) - return branching_prompt.get_status() - - -def _branch_experiment(experiment, conflicts, version, branching_arguments): - """Create a new branch experiment with adapters for the given conflicts""" - experiment_brancher = ExperimentBranchBuilder(conflicts, **branching_arguments) - - needs_manual_resolution = ( - not experiment_brancher.is_resolved or experiment_brancher.manual_resolution - ) - - if not experiment_brancher.is_resolved: - name_conflict = conflicts.get([ExperimentNameConflict])[0] - if not name_conflict.is_resolved and not version: - log.debug( - "A race condition likely occurred during conflicts resolutions. " - "Now rolling back and attempting re-building the branched experiment." - ) - raise RaceCondition( - "There was likely a race condition during version increment." - ) - - if needs_manual_resolution: - log.debug("Some conflicts cannot be solved automatically.") - - # TODO: This should only be possible when using cmdline API - branching_prompt = BranchingPrompt(experiment_brancher) - - if not sys.__stdin__.isatty(): - log.debug("No interactive prompt available to manually resolve conflicts.") - raise BranchingEvent(branching_prompt.get_status()) - - branching_prompt.cmdloop() - - if branching_prompt.abort or not experiment_brancher.is_resolved: - sys.exit() - - log.debug("Creating new branched configuration") - config = experiment_brancher.conflicting_config - config["refers"]["adapter"] = experiment_brancher.create_adapters().configuration - config["refers"]["parent_id"] = experiment.id - - config.pop("_id") - - return create_experiment(mode="x", **config) - - -def _get_conflicts(experiment, branching): - """Get conflicts between current experiment and corresponding configuration in database""" - log.debug("Looking for conflicts in new configuration.") - db_experiment = load(experiment.name, experiment.version, mode="r") - conflicts = detect_conflicts( - db_experiment.configuration, experiment.configuration, branching - ) - - log.debug(f"{len(conflicts.get())} conflicts detected:\n {conflicts.get()}") - - # elif must_branch and not enable_branching: - # raise ValueError("Configuration is different and generate a branching event") - - return conflicts - - def _fetch_config_version(configs, version=None): """Fetch the experiment configuration corresponding to the given version @@ -751,6 +297,51 @@ def _fetch_config_version(configs, version=None): ### +def get_cmd_config(cmdargs): + """Fetch configuration defined by commandline and local configuration file. + + Arguments of commandline have priority over options in configuration file. + """ + cmdargs = resolve_config.fetch_config_from_cmdargs(cmdargs) + + cmd_config = resolve_config.fetch_config(cmdargs) + cmd_config = resolve_config.merge_configs(cmd_config, cmdargs) + + cmd_config.update(cmd_config.pop("experiment", {})) + cmd_config["user_script_config"] = cmd_config.get("worker", {}).get( + "user_script_config", None + ) + + cmd_config["branching"] = cmd_config.pop("evc", {}) + + # TODO: We should move branching specific stuff below in a centralized place for EVC stuff. + if ( + cmd_config["branching"].get("auto_resolution", False) + and cmdargs.get("manual_resolution", None) is None + ): + cmd_config["branching"]["manual_resolution"] = False + + non_monitored_arguments = cmdargs.get("non_monitored_arguments") + if non_monitored_arguments: + cmd_config["branching"][ + "non_monitored_arguments" + ] = non_monitored_arguments.split(":") + + # TODO: user_args won't be defined if reading from DB only (`orion hunt -n ` alone) + metadata = resolve_config.fetch_metadata( + cmd_config.get("user"), + cmd_config.get("user_args"), + cmd_config.get("user_script_config"), + ) + cmd_config["metadata"] = metadata + cmd_config.pop("config", None) + + cmd_config["space"] = cmd_config["metadata"].get("priors", None) + + backward.update_db_config(cmd_config) + return cmd_config + + def build_from_args(cmdargs): """Build an experiment based on commandline arguments. @@ -766,12 +357,13 @@ def build_from_args(cmdargs): cmd_config = get_cmd_config(cmdargs) + # breakpoint() if "name" not in cmd_config: raise NoNameError() - setup_storage(cmd_config["storage"], debug=cmd_config.get("debug")) + builder = ExperimentBuilder(cmd_config["storage"], debug=cmd_config.get("debug")) - return build(**cmd_config) + return builder.build(**cmd_config) def get_from_args(cmdargs, mode="r"): @@ -788,54 +380,515 @@ def get_from_args(cmdargs, mode="r"): if "name" not in cmd_config: raise NoNameError() - setup_storage(cmd_config["storage"], debug=cmd_config.get("debug")) + builder = ExperimentBuilder(cmd_config["storage"], debug=cmd_config.get("debug")) name = cmd_config.get("name") version = cmd_config.get("version") - return load(name, version, mode=mode) + return builder.load(name, version, mode=mode) -def get_cmd_config(cmdargs): - """Fetch configuration defined by commandline and local configuration file. +def build(name, version=None, branching=None, storage=None, **config): + """Build an experiment. + + .. seealso:: + + :func:`orion.core.io.experiment_builder.Experiment.build` for more information - Arguments of commandline have priority over options in configuration file. """ - cmdargs = resolve_config.fetch_config_from_cmdargs(cmdargs) - cmd_config = resolve_config.fetch_config(cmdargs) - cmd_config = resolve_config.merge_configs(cmd_config, cmdargs) + if storage is None: + storage = setup_storage() - cmd_config.update(cmd_config.pop("experiment", {})) - cmd_config["user_script_config"] = cmd_config.get("worker", {}).get( - "user_script_config", None - ) + return ExperimentBuilder(storage).build(name, version, branching, **config) - cmd_config["branching"] = cmd_config.pop("evc", {}) - # TODO: We should move branching specific stuff below in a centralized place for EVC stuff. - if ( - cmd_config["branching"].get("auto_resolution", False) - and cmdargs.get("manual_resolution", None) is None - ): - cmd_config["branching"]["manual_resolution"] = False +def load(name, version=None, mode="r", storage=None): + """Load an experiment. - non_monitored_arguments = cmdargs.get("non_monitored_arguments") - if non_monitored_arguments: - cmd_config["branching"][ - "non_monitored_arguments" - ] = non_monitored_arguments.split(":") + .. seealso:: - # TODO: user_args won't be defined if reading from DB only (`orion hunt -n ` alone) - metadata = resolve_config.fetch_metadata( - cmd_config.get("user"), - cmd_config.get("user_args"), - cmd_config.get("user_script_config"), - ) - cmd_config["metadata"] = metadata - cmd_config.pop("config", None) + :func:`orion.core.io.experiment_builder.Experiment.load` for more information - cmd_config["space"] = cmd_config["metadata"].get("priors", None) + """ + if storage is None: + storage = setup_storage() + return ExperimentBuilder(storage).load(name, version, mode) - backward.update_db_config(cmd_config) - return cmd_config +class ExperimentBuilder: + """Utility to make new experiments using the same storage object. + + Parameters + ---------- + storage: dict or BaseStorageProtocol, optional + Storage object or storage configuration. + debug: bool, optional. + If True, force using EphemeralDB for the storage. Default: False + """ + + def __init__(self, storage=None, debug=False) -> None: + singleton = None + log.debug("Using for storage %s", storage) + + if not isinstance(storage, dict): + singleton = storage + storage = None + + if singleton is None: + if storage is None: + log.debug("Setting up storage from default config") + + self.storage_config = storage + self.storage = setup_storage(storage, debug=debug) + else: + self.storage = singleton + + def build( + self, + name: str, + version: int | None = None, + branching: dict | None = None, + **config, + ) -> Experiment: + """Build an experiment object + + If new, ``space`` argument must be provided, else all arguments are fetched from the + database based on (name, version). If any argument given does not match the corresponding + ones in the database for given (name, version), than the version is incremented and the + experiment will be a child of the previous version. + + Parameters + ---------- + name: str + Name of the experiment to build + version: int, optional + Version to select. If None, last version will be selected. + If version given is larger than largest version available, the largest version + will be selected. + space: dict, optional + Optimization space of the algorithm. + Should have the form ``dict(name='(args)')``. + algorithms: str or dict, optional + Algorithm used for optimization. + strategy: str or dict, optional + Deprecated and will be remove in v0.4. It should now be set in algorithm configuration + directly if it supports it. + max_trials: int, optional + Maximum number of trials before the experiment is considered done. + max_broken: int, optional + Number of broken trials for the experiment to be considered broken. + branching: dict, optional + Arguments to control the branching. + + branch_from: str, optional + Name of the experiment to branch from. + manual_resolution: bool, optional + Starts the prompt to resolve manually the conflicts. Defaults to False. + non_monitored_arguments: list of str, optional + Will ignore these arguments while looking for differences. Defaults to []. + ignore_code_changes: bool, optional + Will ignore code changes while looking for differences. Defaults to False. + algorithm_change: bool, optional + Whether to automatically solve the algorithm conflict (change of algo config). + Defaults to True. + orion_version_change: bool, optional + Whether to automatically solve the orion version conflict. + Defaults to True. + code_change_type: str, optional + How to resolve code change automatically. Must be one of 'noeffect', 'unsure' or + 'break'. Defaults to 'break'. + cli_change_type: str, optional + How to resolve cli change automatically. + Must be one of 'noeffect', 'unsure' or 'break'. + Defaults to 'break'. + config_change_type: str, optional + How to resolve config change automatically. Must be one of 'noeffect', 'unsure' or + 'break'. Defaults to 'break'. + + """ + log.debug(f"Building experiment {name} with {version}") + log.debug(" Passed experiment config:\n%s", pprint.pformat(config)) + log.debug(" Branching config:\n%s", pprint.pformat(branching)) + + name, config, branching = clean_config(name, config, branching) + + config = self.consolidate_config(name, version, config) + + if "space" not in config: + raise NoConfigurationError( + f"Experiment {name} does not exist in DB and space was not defined." + ) + + if len(config["space"]) == 0: + raise NoConfigurationError("No prior found. Please include at least one.") + + experiment = self.create_experiment(mode="x", **copy.deepcopy(config)) + if experiment.id is None: + log.debug("Experiment not found in DB. Now attempting registration in DB.") + try: + self._register_experiment(experiment) + log.debug("Experiment successfully registered in DB.") + except DuplicateKeyError: + log.debug( + "Experiment registration failed. This is likely due to a race condition. " + "Now rolling back and re-attempting building it." + ) + experiment = self.build(branching=branching, **config) + + return experiment + + log.debug(f"Experiment {config['name']}-v{config['version']} already existed.") + + conflicts = self._get_conflicts(experiment, branching) + must_branch = len(conflicts.get()) > 1 or branching.get("branch_to") + + if must_branch and branching.get("enable", orion.core.config.evc.enable): + return self._attempt_branching(conflicts, experiment, version, branching) + elif must_branch: + log.warning( + "Running experiment in a different state:\n%s", + self._get_branching_status_string(conflicts, branching), + ) + + log.debug("No branching required.") + + self._update_experiment(experiment) + return experiment + + def _get_conflicts(self, experiment, branching): + """Get conflicts between current experiment and corresponding configuration in database""" + log.debug("Looking for conflicts in new configuration.") + db_experiment = self.load(experiment.name, experiment.version, mode="r") + conflicts = detect_conflicts( + db_experiment.configuration, experiment.configuration, branching + ) + + log.debug(f"{len(conflicts.get())} conflicts detected:\n {conflicts.get()}") + + # elif must_branch and not enable_branching: + # raise ValueError("Configuration is different and generate a branching event") + + return conflicts + + def load(self, name, version=None, mode="r"): + """Load experiment from database + + An experiment view provides all reading operations of standard experiment but prevents the + modification of the experiment and its trials. + + Parameters + ---------- + name: str + Name of the experiment to build + version: int, optional + Version to select. If None, last version will be selected. + If version given is larger than largest version available, + the largest version will be selected. + mode: str, optional + The access rights of the experiment on the database. + 'r': read access only + 'w': can read and write to database + Default is 'r' + + """ + assert mode in set("rw") + + log.debug( + f"Loading experiment {name} (version={version}) from database in mode `{mode}`" + ) + db_config = self.fetch_config_from_db(name, version) + + if not db_config: + version = version if version else "*" + message = ( + f"No experiment with given name '{name}' and version '{version}' inside database, " + "no view can be created." + ) + raise NoConfigurationError(message) + + db_config.setdefault("version", 1) + + return self.create_experiment(mode=mode, **db_config) + + def fetch_config_from_db(self, name, version=None): + """Fetch configuration from database + + Parameters + ---------- + name: str + Name of the experiment to fetch + version: int, optional + Version to select. If None, last version will be selected. + If version given is larger than largest version available, + the largest version will be selected. + + """ + configs = self.storage.fetch_experiments({"name": name}) + + if not configs: + return {} + + config = _fetch_config_version(configs, version) + + if len(configs) > 1 and version is None: + log.info( + "Many versions for experiment %s have been found. Using latest " + "version %s.", + name, + config["version"], + ) + + log.debug("Config found in DB:\n%s", pprint.pformat(config)) + + backward.populate_space(config, force_update=False) + backward.update_max_broken(config) + + return config + + def _register_experiment(self, experiment): + """Register a new experiment in the database""" + experiment.metadata["datetime"] = datetime.datetime.utcnow() + config = experiment.configuration + # This will raise DuplicateKeyError if a concurrent experiment with + # identical (name, metadata.user) is written first in the database. + + self.storage.create_experiment(config) + + # XXX: Reminder for future DB implementations: + # MongoDB, updates an inserted dict with _id, so should you :P + experiment._id = config["_id"] # pylint:disable=protected-access + + # Update refers in db if experiment is root + if experiment.refers.get("parent_id") is None: + log.debug("update refers (name: %s)", experiment.name) + experiment.refers["root_id"] = experiment.id + self.storage.update_experiment( + experiment, refers=experiment.configuration["refers"] + ) + + def _update_experiment(self, experiment: Experiment) -> None: + """Update experiment configuration in database""" + log.debug("Updating experiment (name: %s)", experiment.name) + config = experiment.configuration + + # TODO: Remove since this should not occur anymore without metadata.user in the indices? + # Writing the final config to an already existing experiment raises + # a DuplicatKeyError because of the embedding id `metadata.user`. + # To avoid this `final_config["name"]` is popped out before + # `db.write()`, thus seamingly breaking the compound index + # `(name, metadata.user)` + config.pop("name") + + self.storage.update_experiment(experiment, **config) + + log.debug("Experiment configuration successfully updated in DB.") + + def _attempt_branching(self, conflicts, experiment, version, branching): + if len(conflicts.get()) > 1: + log.debug("Experiment must branch because of conflicts") + else: + assert branching.get("branch_to") + log.debug("Experiment branching forced with ``branch_to``") + + branched_experiment = self._branch_experiment( + experiment, conflicts, version, branching + ) + log.debug("Now attempting registration of branched experiment in DB.") + try: + self._register_experiment(branched_experiment) + log.debug("Branched experiment successfully registered in DB.") + except DuplicateKeyError as e: + log.debug( + "Experiment registration failed. This is likely due to a race condition " + "during branching. Now rolling back and re-attempting building " + "the branched experiment." + ) + raise RaceCondition( + "There was a race condition during branching. This error can " + "also occur if you try branching from a specific version that already " + "has a child experiment with the same name. Change the name of the new " + "experiment and use `branch-from` to specify the parent experiment." + ) from e + + return branched_experiment + + def consolidate_config(self, name: str, version: int | None, config: dict): + """Merge together given configuration with db configuration matching + for experiment (``name``, ``version``) + """ + db_config = self.fetch_config_from_db(name, version) + + # Do not merge spaces, the new definition overrides it. + if "space" in config: + db_config.pop("space", None) + + log.debug("Merging user and db configs:") + log.debug(" config from user:\n%s", pprint.pformat(config)) + log.debug(" config from DB:\n%s", pprint.pformat(db_config)) + + new_config = config + config = resolve_config.merge_configs(db_config, config) + + config.setdefault("metadata", {}) + resolve_config.update_metadata(config["metadata"]) + + merge_algorithm_config(config, new_config) + # TODO: Remove for v0.4 + merge_producer_config(config, new_config) + + config.setdefault("name", name) + config.setdefault("version", version) + + log.debug(" Merged config:\n%s", pprint.pformat(config)) + + return config + + def _get_branching_status_string(self, conflicts, branching_arguments): + experiment_brancher = ExperimentBranchBuilder( + conflicts, enabled=False, storage=self.storage, **branching_arguments + ) + branching_prompt = BranchingPrompt(experiment_brancher) + return branching_prompt.get_status() + + def _branch_experiment(self, experiment, conflicts, version, branching_arguments): + """Create a new branch experiment with adapters for the given conflicts""" + experiment_brancher = ExperimentBranchBuilder( + conflicts, storage=self.storage, **branching_arguments + ) + + needs_manual_resolution = ( + not experiment_brancher.is_resolved or experiment_brancher.manual_resolution + ) + + if not experiment_brancher.is_resolved: + name_conflict = conflicts.get([ExperimentNameConflict])[0] + if not name_conflict.is_resolved and not version: + log.debug( + "A race condition likely occurred during conflicts resolutions. " + "Now rolling back and attempting re-building the branched experiment." + ) + raise RaceCondition( + "There was likely a race condition during version increment." + ) + + if needs_manual_resolution: + log.debug("Some conflicts cannot be solved automatically.") + + # TODO: This should only be possible when using cmdline API + branching_prompt = BranchingPrompt(experiment_brancher) + + if not sys.__stdin__.isatty(): + log.debug( + "No interactive prompt available to manually resolve conflicts." + ) + raise BranchingEvent(branching_prompt.get_status()) + + branching_prompt.cmdloop() + + if branching_prompt.abort or not experiment_brancher.is_resolved: + sys.exit() + + log.debug("Creating new branched configuration") + config = experiment_brancher.conflicting_config + config["refers"][ + "adapter" + ] = experiment_brancher.create_adapters().configuration + config["refers"]["parent_id"] = experiment.id + + config.pop("_id") + + return self.create_experiment(mode="x", **config) + + # pylint: disable=too-many-arguments + def create_experiment( + self, + name: str, + version: int, + mode: Mode, + space: Space | dict[str, str], + algorithms: str | dict | None = None, + max_trials: int | None = None, + max_broken: int | None = None, + working_dir: str | None = None, + metadata: dict | None = None, + refers: dict | None = None, + producer: dict | None = None, + user: str | None = None, + _id: int | str | None = None, + **kwargs, + ) -> Experiment: + """Instantiate the experiment and its attribute objects + + All unspecified arguments will be replaced by system's defaults (orion.core.config.*). + + Parameters + ---------- + name: str + Name of the experiment. + version: int + Version of the experiment. + mode: str + The access rights of the experiment on the database. + 'r': read access only + 'w': can read and write to database + 'x': can read and write to database, algo is instantiated and can execute optimization + space: dict or Space object + Optimization space of the algorithm. If dict, should have the form + `dict(name='(args)')`. + algorithms: str or dict, optional + Algorithm used for optimization. + strategy: str or dict, optional + Parallel strategy to use to parallelize the algorithm. + max_trials: int, optional + Maximum number or trials before the experiment is considered done. + max_broken: int, optional + Number of broken trials for the experiment to be considered broken. + storage: dict, optional + Configuration of the storage backend. + + """ + T = TypeVar("T") + V = TypeVar("V") + + def _default(v: T | None, default: V) -> T | V: + return v if v is not None else default + + space = _instantiate_space(space) + max_trials = _default(max_trials, orion.core.config.experiment.max_trials) + instantiated_algorithm = _instantiate_algo( + space=space, + max_trials=max_trials, + config=algorithms, + ignore_unavailable=mode != "x", + ) + + max_broken = _default(max_broken, orion.core.config.experiment.max_broken) + working_dir = _default(working_dir, orion.core.config.experiment.working_dir) + metadata = _default(metadata, {"user": _default(user, getpass.getuser())}) + refers = _default(refers, dict(parent_id=None, root_id=None, adapter=[])) + refers["adapter"] = _instantiate_adapters(refers.get("adapter", [])) # type: ignore + + _instantiate_strategy((producer or {}).get("strategy")) + + experiment = Experiment( + storage=self.storage, + name=name, + version=version, + mode=mode, + space=space, + _id=_id, + max_trials=max_trials, + algorithms=instantiated_algorithm, + max_broken=max_broken, + working_dir=working_dir, + metadata=metadata, + refers=refers, + ) + + if kwargs: + # TODO: https://github.com/Epistimio/orion/issues/972 + log.debug( + "create_experiment received some extra unused arguments: %s", kwargs + ) + + return experiment diff --git a/src/orion/core/io/interactive_commands/branching_prompt.py b/src/orion/core/io/interactive_commands/branching_prompt.py index b43994817..5c1ab8b58 100644 --- a/src/orion/core/io/interactive_commands/branching_prompt.py +++ b/src/orion/core/io/interactive_commands/branching_prompt.py @@ -19,7 +19,6 @@ from orion.algo.space import Dimension from orion.core.evc import adapters, conflicts from orion.core.utils.diff import green, red -from orion.storage.base import get_storage readline.set_completer_delims(" ") @@ -306,7 +305,7 @@ def complete_name(self, text, line, begidx, endidx): } names = [ experiment["name"] - for experiment in get_storage().fetch_experiments(query) + for experiment in self.branch_builder.storage.fetch_experiments(query) ] return self._get_completions(names, text) diff --git a/src/orion/core/utils/singleton.py b/src/orion/core/utils/singleton.py index f572421be..8e9fe4022 100644 --- a/src/orion/core/utils/singleton.py +++ b/src/orion/core/utils/singleton.py @@ -61,26 +61,6 @@ class SingletonFactory(AbstractSingletonType, Factory): `AbstractSingletonType`.""" -def update_singletons(values=None): - """Replace singletons by given values and return previous singleton objects""" - if values is None: - values = {} - - # Avoiding circular import problems when importing this module. - from orion.core.io.database import database_factory - from orion.storage.base import storage_factory - - singletons = (storage_factory, database_factory) - - updated_singletons = {} - for singleton in singletons: - name = singleton.base.__name__.lower() - updated_singletons[name] = singleton.instance - singleton.instance = values.get(name, None) - - return updated_singletons - - class GenericSingletonFactory(GenericFactory): """Factory to create singleton instances of classes inheriting a given ``base`` class. @@ -127,6 +107,7 @@ def create(self, of_type=None, *args, **kwargs): if self.instance is None and of_type is None: raise SingletonNotInstantiatedError(self.base.__name__) + elif self.instance is None: try: self.instance = super().create(of_type, *args, **kwargs) diff --git a/src/orion/core/worker/consumer.py b/src/orion/core/worker/consumer.py index 25f82b9b0..ae9cf5073 100644 --- a/src/orion/core/worker/consumer.py +++ b/src/orion/core/worker/consumer.py @@ -212,6 +212,7 @@ def _consume(self, trial, workdirname): log.debug("New temp results file: %s", results_file.name) log.debug("Building command line argument and configuration for trial.") + env = self.get_execution_environment(trial, results_file.name) cmd_args = self.template_builder.format( config_file.name, trial, self.experiment @@ -254,13 +255,21 @@ def execute_process(self, cmd_args, environ): try: # pylint: disable = consider-using-with - process = subprocess.Popen(command, env=environ) + process = subprocess.Popen( + command, + env=environ, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) except PermissionError as exc: log.debug("Script is not executable") raise InexecutableUserScript(" ".join(cmd_args)) from exc + stdout, _ = process.communicate() + return_code = process.wait() log.debug(f"Script finished with return code {return_code}") if return_code != 0: + log.debug("%s", stdout.decode("utf-8")) raise ExecutionError(return_code) diff --git a/src/orion/core/worker/experiment.py b/src/orion/core/worker/experiment.py index 6b752a267..f46d42907 100644 --- a/src/orion/core/worker/experiment.py +++ b/src/orion/core/worker/experiment.py @@ -25,8 +25,7 @@ from orion.core.io.database import DuplicateKeyError from orion.core.utils.exceptions import UnsupportedOperation from orion.core.utils.flatten import flatten -from orion.core.utils.singleton import update_singletons -from orion.storage.base import FailedUpdate, get_storage +from orion.storage.base import BaseStorageProtocol, FailedUpdate log = logging.getLogger(__name__) Mode = Literal["r", "w", "x"] @@ -153,6 +152,7 @@ def __init__( working_dir: str | None = None, metadata: dict | None = None, refers: dict | None = None, + storage: BaseStorageProtocol | None = None, ): self._id = _id self.name = name @@ -163,12 +163,19 @@ def __init__( self.metadata = metadata or {} self.max_trials = max_trials self.max_broken = max_broken - self.algorithms = algorithms self.working_dir = working_dir - self._storage = get_storage() - self._node = ExperimentNode(self.name, self.version, experiment=self) + self._storage = storage + + self._node = ExperimentNode( + self.name, self.version, experiment=self, storage=self._storage + ) + + @property + def storage(self): + """Return the storage currently in use by this experiment""" + return self._storage def _check_if_writable(self): if self.mode == "r": @@ -190,20 +197,12 @@ def __getstate__(self): for entry in self.__slots__: state[entry] = getattr(self, entry) - # TODO: This should be removed when singletons and `get_storage()` are removed. - # See https://github.com/Epistimio/orion/issues/606 - singletons = update_singletons() - state["singletons"] = singletons - update_singletons(singletons) - return state def __setstate__(self, state): for entry in self.__slots__: setattr(self, entry, state[entry]) - update_singletons(state.pop("singletons")) - def to_pandas(self, with_evc_tree=False): """Builds a dataframe with the trials of the experiment @@ -443,11 +442,16 @@ def acquire_algorithm_lock( with self._storage.acquire_algorithm_lock( experiment=self, timeout=timeout, retry_interval=retry_interval ) as locked_algorithm_state: + if locked_algorithm_state.configuration != self.algorithms.configuration: log.warning( "Saved configuration: %s", locked_algorithm_state.configuration ) - log.warning("Current configuration: %s", self.algorithms.configuration) + log.warning( + "Current configuration: %s %s", + self.algorithms.configuration, + self._storage._db, + ) raise RuntimeError( "Algorithm configuration changed since last experiment execution. " "Algorithm cannot be resumed with a different configuration. " diff --git a/src/orion/core/worker/trial_pacemaker.py b/src/orion/core/worker/trial_pacemaker.py index 69c5521c7..e06107408 100644 --- a/src/orion/core/worker/trial_pacemaker.py +++ b/src/orion/core/worker/trial_pacemaker.py @@ -7,8 +7,6 @@ """ import threading -from orion.storage.base import get_storage - STOPPED_STATUS = {"completed", "interrupted", "suspended"} @@ -23,12 +21,12 @@ class TrialPacemaker(threading.Thread): """ - def __init__(self, trial, wait_time=60): + def __init__(self, trial, storage, wait_time=60): threading.Thread.__init__(self) self.stopped = threading.Event() self.trial = trial self.wait_time = wait_time - self.storage = get_storage() + self.storage = storage def stop(self): """Stop monitoring.""" diff --git a/src/orion/serving/experiments_resource.py b/src/orion/serving/experiments_resource.py index b757bb6eb..46ffad405 100644 --- a/src/orion/serving/experiments_resource.py +++ b/src/orion/serving/experiments_resource.py @@ -17,14 +17,13 @@ build_experiment_response, build_experiments_response, ) -from orion.storage.base import get_storage class ExperimentsResource: """Handle requests for the experiments/ REST endpoint""" - def __init__(self): - self.storage = get_storage() + def __init__(self, storage): + self.storage = storage def on_get(self, req: Request, resp: Response): """Handle the GET requests for experiments/""" @@ -41,7 +40,7 @@ def on_get_experiment(self, req: Request, resp: Response, name: str): """ verify_query_parameters(req.params, ["version"]) version = req.get_param_as_int("version") - experiment = retrieve_experiment(name, version) + experiment = retrieve_experiment(self.storage, name, version) status = _retrieve_status(experiment) algorithm = _retrieve_algorithm(experiment) diff --git a/src/orion/serving/parameters.py b/src/orion/serving/parameters.py index 3642da26c..579b58397 100644 --- a/src/orion/serving/parameters.py +++ b/src/orion/serving/parameters.py @@ -64,7 +64,7 @@ def _compose_error_message(key: str, supported_parameters: list): def retrieve_experiment( - experiment_name: str, version: int = None + storage, experiment_name: str, version: int = None ) -> Optional[Experiment]: """ Retrieve an experiment from the database with the given name and version. @@ -75,7 +75,7 @@ def retrieve_experiment( When the experiment doesn't exist """ try: - experiment = experiment_builder.load(experiment_name, version) + experiment = experiment_builder.load(experiment_name, version, storage=storage) if version and experiment.version != version: raise falcon.HTTPNotFound( title=ERROR_EXPERIMENT_NOT_FOUND, diff --git a/src/orion/serving/plots_resources.py b/src/orion/serving/plots_resources.py index 0f9cf942d..3d7d46753 100644 --- a/src/orion/serving/plots_resources.py +++ b/src/orion/serving/plots_resources.py @@ -10,21 +10,22 @@ from orion.client import ExperimentClient from orion.serving.parameters import retrieve_experiment -from orion.storage.base import get_storage class PlotsResource: """Serves all the requests made to plots/ REST endpoint""" - def __init__(self): - self.storage = get_storage() + def __init__(self, storage): + self.storage = storage def on_get_lpi(self, req: Request, resp: Response, experiment_name: str): """ Handle GET requests for plotting lpi plots on plots/lpi/:experiment where ``experiment`` is the user-defined name of the experiment. """ - experiment = ExperimentClient(retrieve_experiment(experiment_name), None) + experiment = ExperimentClient( + retrieve_experiment(self.storage, experiment_name), None + ) resp.body = experiment.plot.lpi().to_json() def on_get_parallel_coordinates( @@ -35,7 +36,9 @@ def on_get_parallel_coordinates( plots/parallel_coordinates/:experiment where ``experiment`` is the user-defined name of the experiment. """ - experiment = ExperimentClient(retrieve_experiment(experiment_name), None) + experiment = ExperimentClient( + retrieve_experiment(self.storage, experiment_name), None + ) resp.body = experiment.plot.parallel_coordinates().to_json() def on_get_partial_dependencies( @@ -46,7 +49,9 @@ def on_get_partial_dependencies( plots/partial_dependencies/:experiment where ``experiment`` is the user-defined name of the experiment. """ - experiment = ExperimentClient(retrieve_experiment(experiment_name), None) + experiment = ExperimentClient( + retrieve_experiment(self.storage, experiment_name), None + ) resp.body = experiment.plot.partial_dependencies().to_json() def on_get_regret(self, req: Request, resp: Response, experiment_name: str): @@ -54,5 +59,7 @@ def on_get_regret(self, req: Request, resp: Response, experiment_name: str): Handle GET requests for plotting regret plots on plots/regret/:experiment where ``experiment`` is the user-defined name of the experiment. """ - experiment = ExperimentClient(retrieve_experiment(experiment_name), None) + experiment = ExperimentClient( + retrieve_experiment(self.storage, experiment_name), None + ) resp.body = experiment.plot.regret().to_json() diff --git a/src/orion/serving/runtime.py b/src/orion/serving/runtime.py index 8ecad6257..6dfc9d011 100644 --- a/src/orion/serving/runtime.py +++ b/src/orion/serving/runtime.py @@ -5,14 +5,13 @@ import json import orion.core -from orion.storage.base import get_storage class RuntimeResource: """Handle requests for the '/' REST endpoint""" - def __init__(self): - pass + def __init__(self, storage): + self.storage = storage def on_get(self, req, resp): """Handle the HTTP GET requests for the '/' endpoint @@ -24,7 +23,7 @@ def on_get(self, req, resp): resp The response to send back """ - database = get_storage()._db.__class__.__name__ + database = self.storage._db.__class__.__name__ response = { "orion": orion.core.__version__, "server": "gunicorn", diff --git a/src/orion/serving/trials_resource.py b/src/orion/serving/trials_resource.py index ba52b60dc..e093b0a82 100644 --- a/src/orion/serving/trials_resource.py +++ b/src/orion/serving/trials_resource.py @@ -13,7 +13,6 @@ verify_status, ) from orion.serving.responses import build_trial_response, build_trials_response -from orion.storage.base import get_storage SUPPORTED_PARAMETERS = ["ancestors", "status", "version"] @@ -21,8 +20,8 @@ class TrialsResource: """Serves all the requests made to trials/ REST endpoint""" - def __init__(self): - self.storage = get_storage() + def __init__(self, storage): + self.storage = storage def on_get_trials_in_experiment( self, req: Request, resp: Response, experiment_name: str @@ -38,7 +37,7 @@ def on_get_trials_in_experiment( version = req.get_param_as_int("version") with_ancestors = req.get_param_as_bool("ancestors", default=False) - experiment = retrieve_experiment(experiment_name, version) + experiment = retrieve_experiment(self.storage, experiment_name, version) if status: trials = experiment.fetch_trials_by_status(status, with_ancestors) else: @@ -54,7 +53,7 @@ def on_get_trial_in_experiment( Handle GET requests for trials/:experiment/:trial_id where ``experiment`` is the user-defined name of the experiment and ``trial_id`` the id of the trial. """ - experiment = retrieve_experiment(experiment_name) + experiment = retrieve_experiment(self.storage, experiment_name) trial = retrieve_trial(experiment, trial_id) response = build_trial_response(trial) diff --git a/src/orion/serving/webapi.py b/src/orion/serving/webapi.py index 7c94e71ed..427bfd4c1 100644 --- a/src/orion/serving/webapi.py +++ b/src/orion/serving/webapi.py @@ -16,7 +16,6 @@ from orion.serving.plots_resources import PlotsResource from orion.serving.runtime import RuntimeResource from orion.serving.trials_resource import TrialsResource -from orion.storage.base import setup_storage logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -82,7 +81,7 @@ class WebApi(falcon.API): routing engine. """ - def __init__(self, config=None): + def __init__(self, storage, config=None): # By default, server will reject requests coming from a server # with different origin. E.g., if server is hosted at # http://myorionserver.com, it won't accept an API call @@ -105,14 +104,13 @@ def __init__(self, config=None): cors = MyCORS(allow_origins_list=frontends_uri) super().__init__(middleware=[cors.middleware]) self.config = config - - setup_storage(config.get("storage")) + self.storage = storage # Create our resources - root_resource = RuntimeResource() - experiments_resource = ExperimentsResource() - trials_resource = TrialsResource() - plots_resource = PlotsResource() + root_resource = RuntimeResource(self.storage) + experiments_resource = ExperimentsResource(self.storage) + trials_resource = TrialsResource(self.storage) + plots_resource = PlotsResource(self.storage) # Build routes self.add_route("/", root_resource) diff --git a/src/orion/storage/base.py b/src/orion/storage/base.py index 645f00c1a..b0c2a042b 100644 --- a/src/orion/storage/base.py +++ b/src/orion/storage/base.py @@ -25,7 +25,7 @@ import orion.core from orion.core.io import resolve_config -from orion.core.utils.singleton import GenericSingletonFactory +from orion.core.utils import GenericFactory log = logging.getLogger(__name__) @@ -558,27 +558,7 @@ def acquire_algorithm_lock(self, experiment, timeout=600, retry_interval=1): raise NotImplementedError() -storage_factory = GenericSingletonFactory(BaseStorageProtocol) - - -def get_storage(): - """Return current storage - - This is a wrapper around the Storage Singleton object to provide - better error message when it is used without being initialized. - - Raises - ------ - RuntimeError - If the underlying storage was not initialized prior to calling this function - - Notes - ----- - To initialize the underlying storage you must first call `Storage(...)` - with the appropriate arguments for the chosen backend - - """ - return storage_factory.create() +storage_factory = GenericFactory(BaseStorageProtocol) def setup_storage(storage=None, debug=False): @@ -617,7 +597,7 @@ def setup_storage(storage=None, debug=False): log.debug("Creating %s storage client with args: %s", storage_type, storage) try: - storage_factory.create(of_type=storage_type, **storage) + return storage_factory.create(of_type=storage_type, **storage) except ValueError: if storage_factory.create().__class__.__name__.lower() != storage_type.lower(): raise diff --git a/src/orion/storage/legacy.py b/src/orion/storage/legacy.py index 166c51e45..16484eb83 100644 --- a/src/orion/storage/legacy.py +++ b/src/orion/storage/legacy.py @@ -27,26 +27,6 @@ log = logging.getLogger(__name__) -def get_database(): - """Return current database - - This is a wrapper around the Database Singleton object to provide - better error message when it is used without being initialized. - - Raises - ------ - RuntimeError - If the underlying database was not initialized prior to calling this function - - Notes - ----- - To initialize the underlying database you must first call `Database(...)` - with the appropriate arguments for the chosen backend - - """ - return database_factory.create() - - def setup_database(config=None): """Create the Database instance from a configuration. @@ -65,6 +45,7 @@ def setup_database(config=None): dbtype = db_opts.pop("type") log.debug("Creating %s database client with args: %s", dbtype, db_opts) + return database_factory.create(dbtype, **db_opts) @@ -83,10 +64,7 @@ class Legacy(BaseStorageProtocol): """ def __init__(self, database=None, setup=True): - if database is not None: - setup_database(database) - - self._db = database_factory.create() + self._db = setup_database(database) if setup: self._setup_db() diff --git a/src/orion/testing/__init__.py b/src/orion/testing/__init__.py index b2a6152c4..8396a89eb 100644 --- a/src/orion/testing/__init__.py +++ b/src/orion/testing/__init__.py @@ -13,10 +13,13 @@ import os from contextlib import contextmanager +from falcon import testing + import orion.algo.space import orion.core.io.experiment_builder as experiment_builder from orion.core.io.space_builder import SpaceBuilder from orion.core.worker.producer import Producer +from orion.serving.webapi import WebApi from orion.testing.state import OrionState base_experiment = { @@ -145,9 +148,8 @@ def generate_benchmark_experiments_trials( return gen_exps, gen_trials -@contextmanager def create_study_experiments( - exp_config, trial_config, algorithms, task_number, max_trial, n_workers=(1,) + state, exp_config, trial_config, algorithms, task_number, max_trial, n_workers=(1,) ): gen_exps, gen_trials = generate_benchmark_experiments_trials( algorithms, exp_config, trial_config, task_number * len(n_workers), max_trial @@ -161,20 +163,25 @@ def create_study_experiments( for worker in n_workers: for _ in range(len(algorithms)): workers.append(worker) - with OrionState(experiments=gen_exps, trials=gen_trials): - experiments = [] - experiments_info = [] - for i in range(task_number * len(n_workers) * len(algorithms)): - experiment = experiment_builder.build(f"experiment-name-{i}") - executor = Joblib(n_workers=workers[i], backend="threading") - client = ExperimentClient(experiment, executor=executor) - experiments.append(client) + state.add_trials(*gen_trials) + state.add_experiments(*gen_exps) + + experiments = [] + experiments_info = [] + for i in range(task_number * len(n_workers) * len(algorithms)): + experiment = experiment_builder.build( + f"experiment-name-{i}", storage=state.storage_config + ) + + executor = Joblib(n_workers=workers[i], backend="threading") + client = ExperimentClient(experiment, executor=executor) + experiments.append(client) - for index, exp in enumerate(experiments): - experiments_info.append((int(index / task_number), exp)) + for index, exp in enumerate(experiments): + experiments_info.append((int(index / task_number), exp)) - yield experiments_info + return experiments_info def mock_space_iterate(monkeypatch): @@ -212,7 +219,9 @@ def create_experiment(exp_config=None, trial_config=None, statuses=None): experiments=[exp_config], trials=generate_trials(trial_config, statuses, exp_config), ) as cfg: - experiment = experiment_builder.build(name=exp_config["name"]) + experiment = experiment_builder.build( + name=exp_config["name"], storage=cfg.storage_config + ) if cfg.trials: experiment._id = cfg.trials[0]["experiment"] client = ExperimentClient(experiment) @@ -221,6 +230,20 @@ def create_experiment(exp_config=None, trial_config=None, statuses=None): client.close() +@contextmanager +def falcon_client(exp_config=None, trial_config=None, statuses=None): + """Context manager for the creation of an ExperimentClient and storage init""" + + with create_experiment(exp_config, trial_config, statuses) as ( + cfg, + experiment, + exp_client, + ): + falcon_client = testing.TestClient(WebApi(cfg.storage, {})) + + yield cfg, experiment, exp_client, falcon_client + + class MockDatetime(datetime.datetime): """Fake Datetime""" diff --git a/src/orion/testing/state.py b/src/orion/testing/state.py index 161a44b63..33cae686c 100644 --- a/src/orion/testing/state.py +++ b/src/orion/testing/state.py @@ -7,25 +7,24 @@ """ # pylint: disable=protected-access +import copy import os import tempfile import yaml +import orion from orion.core.io import experiment_builder as experiment_builder -from orion.core.utils.singleton import ( - SingletonAlreadyInstantiatedError, - update_singletons, -) from orion.core.worker.trial import Trial -from orion.storage.base import get_storage, storage_factory +from orion.storage.base import setup_storage, storage_factory # pylint: disable=no-self-use,protected-access class BaseOrionState: - """Setup global variables and singleton for tests. + """Setup global variables and storage for tests. - It swaps the singleton with `None` at startup and restores them after the tests. + It generates a new storage configuration and swaps it, + the previous configuration is restored after the test. It also initializes PickleDB as the storage for testing. We use PickledDB as our storage mock @@ -82,7 +81,10 @@ def __init__( self.tempfile = None self.tempfile_path = None + + self.previous_config = copy.deepcopy(orion.core.config.storage.to_dict()) self.storage_config = _select(storage, _get_default_test_storage()) + self.storage = None self._benchmarks = _select(benchmarks, []) self._experiments = _select(experiments, []) @@ -99,7 +101,7 @@ def __init__( def init(self, config): """Initialize environment before testing""" - self.storage(config) + self.setup_storage(config) self.load_experience_configuration() return self @@ -118,19 +120,31 @@ def cleanup(self): os.close(self.tempfile) _remove(self.tempfile_path) + def add_experiments(self, *experiments): + """Add experiments to the database""" + for exp in experiments: + self.storage.create_experiment(exp) + self._experiments.append(exp) + + def add_trials(self, *trials): + """Add trials to the database""" + for trial in trials: + nt = self.storage.register_trial(Trial(**trial)) + self.trials.append(nt) + def _set_tables(self): self.trials = [] self.lies = [] for exp in self._experiments: - get_storage().create_experiment(exp) + self.storage.create_experiment(exp) for trial in self._trials: - nt = get_storage().register_trial(Trial(**trial)) + nt = self.storage.register_trial(Trial(**trial)) self.trials.append(nt.to_dict()) for lie in self._lies: - nt = get_storage().register_lie(Trial(**lie)) + nt = self.storage.register_lie(Trial(**lie)) self.lies.append(nt.to_dict()) def load_experience_configuration(self): @@ -179,33 +193,33 @@ def replace_file(v): def __enter__(self): """Load a new database state""" - self.singletons = update_singletons() self.cleanup() return self.init(self.make_config()) def __exit__(self, exc_type, exc_val, exc_tb): """Cleanup database state""" self.cleanup() + orion.core.config.storage.from_dict(self.previous_config) - update_singletons(self.singletons) - - def storage(self, config=None): + def setup_storage(self, config=None): """Return test storage""" + self.previous_config = orion.core.config.storage.to_dict() + orion.core.config.storage.from_dict(config) + if config is None: - return get_storage() + self.storage = setup_storage() + return self.storage try: + self.storage_config = copy.deepcopy(config) config["of_type"] = config.pop("type") - db = storage_factory.create(**config) - self.storage_config = config - except SingletonAlreadyInstantiatedError: - db = get_storage() + self.storage = storage_factory.create(**config) except KeyError: print(self.storage_config) raise - return db + return self.storage class LegacyOrionState(BaseOrionState): @@ -218,14 +232,14 @@ def __init__(self, *args, **kwargs): @property def database(self): """Retrieve legacy database handle""" - return get_storage()._db + return self.storage._db def init(self, config): """Initialize environment before testing""" - self.storage(config) + self.setup_storage(config) self.initialized = True - if hasattr(get_storage(), "_db"): + if hasattr(self.storage, "_db"): self.database.remove("experiments", {}) self.database.remove("trials", {}) @@ -244,11 +258,11 @@ def _set_tables(self): if self._experiments: self.database.write("experiments", self._experiments) for experiment in self._experiments: - get_storage().initialize_algorithm_lock( + self.storage.initialize_algorithm_lock( experiment["_id"], experiment.get("algorithms") ) # For tests that need a deterministic experiment id. - get_storage().initialize_algorithm_lock( + self.storage.initialize_algorithm_lock( experiment["name"], experiment.get("algorithms") ) if self._trials: diff --git a/tests/conftest.py b/tests/conftest.py index fc3114435..ca74830d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ """Common fixtures and utils for unittests and functional tests.""" import getpass import os -import tempfile import numpy import pytest @@ -16,9 +15,8 @@ from orion.core.io import resolve_config from orion.core.io.database import database_factory from orion.core.utils import format_trials -from orion.core.utils.singleton import update_singletons from orion.core.worker.trial import Trial -from orion.storage.base import get_storage, setup_storage, storage_factory +from orion.storage.base import storage_factory # So that assert messages show up in tests defined outside testing suite. pytest.register_assert_rewrite("orion.testing") @@ -361,26 +359,6 @@ def fixed_dictionary(user_script): monkeypatch.setattr(resolve_config, "infer_versioning_metadata", fixed_dictionary) -@pytest.fixture(scope="function") -def setup_pickleddb_database(): - """Configure the database""" - update_singletons() - temporary_file = tempfile.NamedTemporaryFile() - - os.environ["ORION_DB_TYPE"] = "pickleddb" - os.environ["ORION_DB_ADDRESS"] = temporary_file.name - yield - temporary_file.close() - del os.environ["ORION_DB_TYPE"] - del os.environ["ORION_DB_ADDRESS"] - - -@pytest.fixture(scope="function") -def storage(setup_pickleddb_database): - setup_storage() - yield get_storage() - - @pytest.fixture() def with_user_userxyz(monkeypatch): """Make ``getpass.getuser()`` return ``'userxyz'``.""" @@ -392,3 +370,15 @@ def random_dt(monkeypatch): """Make ``datetime.datetime.utcnow()`` return an arbitrary date.""" with mocked_datetime(monkeypatch) as datetime: yield datetime.utcnow() + + +@pytest.fixture(scope="function") +def orionstate(): + """Configure the database""" + with OrionState() as cfg: + yield cfg + + +@pytest.fixture(scope="function") +def storage(orionstate): + yield orionstate.storage diff --git a/tests/functional/algos/test_algos.py b/tests/functional/algos/test_algos.py index d49aa7a80..24e4a3722 100644 --- a/tests/functional/algos/test_algos.py +++ b/tests/functional/algos/test_algos.py @@ -302,12 +302,15 @@ def test_with_multidim(algorithm): def test_with_evc(algorithm): """Test a scenario where algos are warm-started with EVC.""" - with OrionState(storage={"type": "legacy", "database": {"type": "PickledDB"}}): + with OrionState( + storage={"type": "legacy", "database": {"type": "PickledDB"}} + ) as cfg: base_exp = create_experiment( name="exp", space=space_with_fidelity, algorithms=algorithm_configs["random"], max_trials=10, + storage=cfg.storage_config, ) base_exp.workon(rosenbrock, max_trials=10) @@ -317,6 +320,7 @@ def test_with_evc(algorithm): algorithms=algorithm, max_trials=30, branching={"branch_from": "exp", "enable": True}, + storage=cfg.storage_config, ) assert exp.version == 2 @@ -370,6 +374,7 @@ def test_parallel_workers(algorithm): name=name, space=space_with_fidelity, algorithms=algorithm, + storage=cfg.storage_config, ) exp.workon(rosenbrock, max_trials=MAX_TRIALS, n_workers=2) diff --git a/tests/functional/backward_compatibility/test_versions.py b/tests/functional/backward_compatibility/test_versions.py index 6fa76bff8..8f0e3b1bc 100644 --- a/tests/functional/backward_compatibility/test_versions.py +++ b/tests/functional/backward_compatibility/test_versions.py @@ -10,7 +10,7 @@ import orion.core.io.experiment_builder as experiment_builder from orion.client import create_experiment from orion.core.io.database import database_factory -from orion.storage.base import get_storage, storage_factory +from orion.storage.base import storage_factory DIRNAME = os.path.dirname(os.path.abspath(__file__)) @@ -235,9 +235,7 @@ def null_db_instances(): def build_storage(): """Build storage from scratch""" null_db_instances() - experiment_builder.setup_storage() - - return get_storage() + return experiment_builder.setup_storage() @pytest.mark.usefixtures("fill_db") diff --git a/tests/functional/benchmark/test_benchmark_flow.py b/tests/functional/benchmark/test_benchmark_flow.py index 809f5f251..4fecfed2b 100644 --- a/tests/functional/benchmark/test_benchmark_flow.py +++ b/tests/functional/benchmark/test_benchmark_flow.py @@ -7,6 +7,7 @@ from orion.benchmark.assessment import AverageRank, AverageResult from orion.benchmark.benchmark_client import get_or_create_benchmark from orion.benchmark.task import BenchmarkTask, Branin +from orion.storage.base import setup_storage algorithms = [ {"algorithm": {"random": {"seed": 1}}}, @@ -48,7 +49,7 @@ def get_search_space(self): return rspace -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") def test_simple(): """Test a end 2 end exucution of benchmark""" task_num = 2 @@ -59,7 +60,9 @@ def test_simple(): BirdLike(max_trials), ] + storage = setup_storage() benchmark = get_or_create_benchmark( + storage, name="bm001", algorithms=algorithms, targets=[{"assess": assessments, "task": tasks}], @@ -83,7 +86,8 @@ def test_simple(): assert_benchmark_figures(figures, 4, assessments, tasks) - benchmark = get_or_create_benchmark(name="bm001") + storage = setup_storage() + benchmark = get_or_create_benchmark(storage, name="bm001") figures = benchmark.analysis() assert_benchmark_figures(figures, 4, assessments, tasks) diff --git a/tests/functional/branching/test_branching.py b/tests/functional/branching/test_branching.py index e08cab406..ab29d9abf 100644 --- a/tests/functional/branching/test_branching.py +++ b/tests/functional/branching/test_branching.py @@ -9,7 +9,7 @@ import orion.core.cli import orion.core.io.experiment_builder as experiment_builder -from orion.storage.base import get_storage +from orion.storage.base import setup_storage def execute(command, assert_code=0): @@ -19,7 +19,7 @@ def execute(command, assert_code=0): @pytest.fixture -def init_full_x(setup_pickleddb_database, monkeypatch): +def init_full_x(orionstate, monkeypatch): """Init original experiment""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) name = "full_x" @@ -59,6 +59,8 @@ def init_no_evc(monkeypatch): @pytest.fixture def init_full_x_full_y(init_full_x): """Add y dimension to original""" + print("init_full_x_full_y start") + name = "full_x" branch = "full_x_full_y" orion.core.cli.main( @@ -393,10 +395,10 @@ def test_init(init_full_x): assert pairs == ((("/x", 0),),) -def test_no_evc_overwrite(setup_pickleddb_database, init_no_evc): +def test_no_evc_overwrite(orionstate, init_no_evc): """Test that the experiment config is overwritten if --enable-evc is not passed""" - storage = get_storage() - assert len(get_storage().fetch_experiments({})) == 1 + storage = setup_storage() + assert len(storage.fetch_experiments({})) == 1 experiment = experiment_builder.load(name="full_x") assert experiment.refers["adapter"].configuration == [] @@ -966,7 +968,7 @@ def test_new_script(init_full_x, monkeypatch): metadata["user_script"] = "oh_oh_idontexist.py" metadata["user_args"][0] = "oh_oh_idontexist.py" metadata["parser"]["parser"]["arguments"][0][1] = "oh_oh_idontexist.py" - get_storage().update_experiment(experiment, metadata=metadata) + setup_storage().update_experiment(experiment, metadata=metadata) orion.core.cli.main( ( @@ -1014,7 +1016,7 @@ def test_missing_config(init_full_x_new_config, monkeypatch): metadata["parser"]["file_config_path"] = bad_config_file metadata["parser"]["parser"]["arguments"][2][1] = bad_config_file metadata["user_args"][3] = bad_config_file - get_storage().update_experiment(experiment, metadata=metadata) + setup_storage().update_experiment(experiment, metadata=metadata) orion.core.cli.main( ( @@ -1060,7 +1062,7 @@ def test_missing_and_new_config(init_full_x_new_config, monkeypatch): ) ) - get_storage().update_experiment(experiment, metadata=metadata) + setup_storage().update_experiment(experiment, metadata=metadata) orion.core.cli.main( ( @@ -1170,9 +1172,7 @@ def test_auto_resolution_with_fidelity(init_full_x_full_y, monkeypatch): ] -def test_init_w_version_from_parent_w_children( - setup_pickleddb_database, monkeypatch, capsys -): +def test_init_w_version_from_parent_w_children(orionstate, monkeypatch, capsys): """Test that init of experiment from version with children fails.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) execute( @@ -1199,7 +1199,7 @@ def test_init_w_version_from_parent_w_children( assert "Experiment name" in captured.err -def test_init_w_version_from_exp_wout_child(setup_pickleddb_database, monkeypatch): +def test_init_w_version_from_exp_wout_child(orionstate, monkeypatch): """Test that init of experiment from version without child works.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) execute( @@ -1219,11 +1219,11 @@ def test_init_w_version_from_exp_wout_child(setup_pickleddb_database, monkeypatc "-x~normal(0,1) -y~+normal(0,1) -z~+normal(0,1)" ) - exp = get_storage().fetch_experiments({"name": "experiment", "version": 3}) + exp = setup_storage().fetch_experiments({"name": "experiment", "version": 3}) assert len(list(exp)) -def test_init_w_version_gt_max(setup_pickleddb_database, monkeypatch): +def test_init_w_version_gt_max(orionstate, monkeypatch): """Test that init of experiment from version higher than max works.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) execute( @@ -1243,11 +1243,11 @@ def test_init_w_version_gt_max(setup_pickleddb_database, monkeypatch): "-x~normal(0,1) -y~+normal(0,1) -z~+normal(0,1)" ) - exp = get_storage().fetch_experiments({"name": "experiment", "version": 3}) + exp = setup_storage().fetch_experiments({"name": "experiment", "version": 3}) assert len(list(exp)) -def test_init_check_increment_w_children(setup_pickleddb_database, monkeypatch): +def test_init_check_increment_w_children(orionstate, monkeypatch): """Test that incrementing version works with not same-named children.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) execute( @@ -1266,11 +1266,11 @@ def test_init_check_increment_w_children(setup_pickleddb_database, monkeypatch): "./black_box.py -x~normal(0,1) -z~+normal(0,1)" ) - exp = get_storage().fetch_experiments({"name": "experiment", "version": 2}) + exp = setup_storage().fetch_experiments({"name": "experiment", "version": 2}) assert len(list(exp)) -def test_branch_from_selected_version(setup_pickleddb_database, monkeypatch): +def test_branch_from_selected_version(orionstate, monkeypatch): """Test that branching from a version passed with `--version` works.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) execute( @@ -1290,7 +1290,7 @@ def test_branch_from_selected_version(setup_pickleddb_database, monkeypatch): "-x~normal(0,1) -z~+normal(0,1)" ) - storage = get_storage() + storage = setup_storage() parent = storage.fetch_experiments({"name": "experiment", "version": 1})[0] exp = storage.fetch_experiments({"name": "experiment_2"})[0] assert exp["refers"]["parent_id"] == parent["_id"] diff --git a/tests/functional/client/test_cli_client.py b/tests/functional/client/test_cli_client.py index e577aadc2..2a0115e7b 100644 --- a/tests/functional/client/test_cli_client.py +++ b/tests/functional/client/test_cli_client.py @@ -13,7 +13,7 @@ def test_interrupt(monkeypatch, capsys): """Test interruption from within user script.""" with OrionState() as cfg: - storage = cfg.storage() + storage = cfg.storage monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -49,7 +49,7 @@ def test_interrupt(monkeypatch, capsys): assert trials[0].status == "interrupted" -def test_interrupt_diff_code(storage, monkeypatch, capsys): +def test_interrupt_diff_code(monkeypatch, capsys, storage): """Test interruption from within user script with custom int code""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -132,87 +132,52 @@ def empty_env(self, trial, results_file=None): @pytest.mark.parametrize("fct", ["report_bad_trial", "report_objective"]) -def test_report_no_name(storage, monkeypatch, fct): +def test_report_no_name(monkeypatch, fct): """Test report helper functions with default names""" - monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - - user_args = ["-x~uniform(-50, 50, precision=5)"] - - orion.core.cli.main( - [ - "hunt", - "--config", - "./orion_config.yaml", - "--exp-max-trials", - "2", - "--worker-trials", - "2", - "python", - "black_box.py", - fct, - "--objective", - "1.0", - ] - + user_args - ) - - exp = list(storage.fetch_experiments({"name": "voila_voici"})) - exp = exp[0] - exp_id = exp["_id"] - trials = list(storage.fetch_trials(uid=exp_id)) - assert len(trials) == 2 - assert trials[0].status == "completed" - assert trials[0].results[0].name == "objective" - assert trials[0].results[0].type == "objective" - assert trials[0].results[0].value == 1.0 - -@pytest.mark.parametrize("fct", ["report_bad_trial", "report_objective"]) -def test_report_with_name(storage, monkeypatch, fct): - """Test report helper functions with custom names""" - monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + with OrionState() as cfg: + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - user_args = ["-x~uniform(-50, 50, precision=5)"] + user_args = ["-x~uniform(-50, 50, precision=5)"] - orion.core.cli.main( - [ - "hunt", - "--config", - "./orion_config.yaml", - "--exp-max-trials", - "2", - "--worker-trials", - "2", - "python", - "black_box.py", - fct, - "--objective", - "1.0", - "--name", - "metric", - ] - + user_args - ) + orion.core.cli.main( + [ + "hunt", + "--config", + "./orion_config.yaml", + "--exp-max-trials", + "2", + "--worker-trials", + "2", + "python", + "black_box.py", + fct, + "--objective", + "1.0", + ] + + user_args + ) - exp = list(storage.fetch_experiments({"name": "voila_voici"})) - exp = exp[0] - exp_id = exp["_id"] - trials = list(storage.fetch_trials(uid=exp_id)) - assert len(trials) == 2 - assert trials[0].status == "completed" - assert trials[0].results[0].name == "metric" - assert trials[0].results[0].type == "objective" - assert trials[0].results[0].value == 1.0 + exp = list(cfg.storage.fetch_experiments({"name": "voila_voici"})) + exp = exp[0] + exp_id = exp["_id"] + trials = list(cfg.storage.fetch_trials(uid=exp_id)) + assert len(trials) == 2 + assert trials[0].status == "completed" + assert trials[0].results[0].name == "objective" + assert trials[0].results[0].type == "objective" + assert trials[0].results[0].value == 1.0 @pytest.mark.parametrize("fct", ["report_bad_trial", "report_objective"]) -def test_report_with_bad_objective(storage, monkeypatch, fct): - """Test report helper functions with bad objective types""" - monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) +def test_report_with_name(monkeypatch, fct): + """Test report helper functions with custom names""" + with OrionState() as cfg: + storage = cfg.storage + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - user_args = ["-x~uniform(-50, 50, precision=5)"] + user_args = ["-x~uniform(-50, 50, precision=5)"] - with pytest.raises(InvalidResult) as exc: orion.core.cli.main( [ "hunt", @@ -226,110 +191,157 @@ def test_report_with_bad_objective(storage, monkeypatch, fct): "black_box.py", fct, "--objective", - "oh oh", + "1.0", + "--name", + "metric", ] + user_args ) - assert "must contain a type `objective` with type float/int" in str(exc.value) + exp = list(storage.fetch_experiments({"name": "voila_voici"})) + exp = exp[0] + exp_id = exp["_id"] + trials = list(storage.fetch_trials(uid=exp_id)) + assert len(trials) == 2 + assert trials[0].status == "completed" + assert trials[0].results[0].name == "metric" + assert trials[0].results[0].type == "objective" + assert trials[0].results[0].value == 1.0 + + +@pytest.mark.parametrize("fct", ["report_bad_trial", "report_objective"]) +def test_report_with_bad_objective(monkeypatch, fct): + """Test report helper functions with bad objective types""" + with OrionState() as cfg: + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + user_args = ["-x~uniform(-50, 50, precision=5)"] -def test_report_with_bad_trial_no_objective(storage, monkeypatch): + with pytest.raises(InvalidResult) as exc: + orion.core.cli.main( + [ + "hunt", + "--config", + "./orion_config.yaml", + "--exp-max-trials", + "2", + "--worker-trials", + "2", + "python", + "black_box.py", + fct, + "--objective", + "oh oh", + ] + + user_args + ) + + assert "must contain a type `objective` with type float/int" in str(exc.value) + + +def test_report_with_bad_trial_no_objective(monkeypatch): """Test bad trial report helper function with default objective.""" - monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + with OrionState() as cfg: + storage = cfg.storage - user_args = ["-x~uniform(-50, 50, precision=5)"] + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - orion.core.cli.main( - [ - "hunt", - "--config", - "./orion_config.yaml", - "--exp-max-trials", - "2", - "--worker-trials", - "2", - "python", - "black_box.py", - "report_bad_trial", - ] - + user_args - ) + user_args = ["-x~uniform(-50, 50, precision=5)"] - exp = list(storage.fetch_experiments({"name": "voila_voici"})) - exp = exp[0] - exp_id = exp["_id"] - trials = list(storage.fetch_trials(uid=exp_id)) - assert len(trials) == 2 - assert trials[0].status == "completed" - assert trials[0].results[0].name == "objective" - assert trials[0].results[0].type == "objective" - assert trials[0].results[0].value == 1e10 + orion.core.cli.main( + [ + "hunt", + "--config", + "./orion_config.yaml", + "--exp-max-trials", + "2", + "--worker-trials", + "2", + "python", + "black_box.py", + "report_bad_trial", + ] + + user_args + ) + exp = list(storage.fetch_experiments({"name": "voila_voici"})) + exp = exp[0] + exp_id = exp["_id"] + trials = list(storage.fetch_trials(uid=exp_id)) + assert len(trials) == 2 + assert trials[0].status == "completed" + assert trials[0].results[0].name == "objective" + assert trials[0].results[0].type == "objective" + assert trials[0].results[0].value == 1e10 -def test_report_with_bad_trial_with_data(storage, monkeypatch): + +def test_report_with_bad_trial_with_data(monkeypatch): """Test bad trial report helper function with additional data.""" - monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + with OrionState() as cfg: + storage = cfg.storage + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - user_args = ["-x~uniform(-50, 50, precision=5)"] + user_args = ["-x~uniform(-50, 50, precision=5)"] - orion.core.cli.main( - [ - "hunt", - "--config", - "./orion_config.yaml", - "--exp-max-trials", - "2", - "--worker-trials", - "2", - "python", - "black_box.py", - "report_bad_trial", - "--data", - "another", - ] - + user_args - ) + orion.core.cli.main( + [ + "hunt", + "--config", + "./orion_config.yaml", + "--exp-max-trials", + "2", + "--worker-trials", + "2", + "python", + "black_box.py", + "report_bad_trial", + "--data", + "another", + ] + + user_args + ) - exp = list(storage.fetch_experiments({"name": "voila_voici"})) - exp = exp[0] - exp_id = exp["_id"] - trials = list(storage.fetch_trials(uid=exp_id)) - assert len(trials) == 2 - assert trials[0].status == "completed" - assert trials[0].results[0].name == "objective" - assert trials[0].results[0].type == "objective" - assert trials[0].results[0].value == 1e10 + exp = list(storage.fetch_experiments({"name": "voila_voici"})) + exp = exp[0] + exp_id = exp["_id"] + trials = list(storage.fetch_trials(uid=exp_id)) + assert len(trials) == 2 + assert trials[0].status == "completed" + assert trials[0].results[0].name == "objective" + assert trials[0].results[0].type == "objective" + assert trials[0].results[0].value == 1e10 - assert trials[0].results[1].name == "another" - assert trials[0].results[1].type == "constraint" - assert trials[0].results[1].value == 1.0 + assert trials[0].results[1].name == "another" + assert trials[0].results[1].type == "constraint" + assert trials[0].results[1].value == 1.0 -def test_no_report(storage, monkeypatch, capsys): +def test_no_report(monkeypatch, capsys): """Test script call without any results reported.""" - monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + with OrionState() as cfg: + storage = cfg.storage + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - user_args = ["-x~uniform(-50, 50, precision=5)"] + user_args = ["-x~uniform(-50, 50, precision=5)"] - errorcode = orion.core.cli.main( - [ - "hunt", - "--config", - "./orion_config.yaml", - "--exp-max-trials", - "2", - "--worker-trials", - "2", - "python", - "black_box.py", - "no_report", - ] - + user_args - ) + errorcode = orion.core.cli.main( + [ + "hunt", + "--config", + "./orion_config.yaml", + "--exp-max-trials", + "2", + "--worker-trials", + "2", + "python", + "black_box.py", + "no_report", + ] + + user_args + ) - assert errorcode == 1 + assert errorcode == 1 - captured = capsys.readouterr() - assert captured.out == "" - assert "Cannot parse result file" in captured.err + captured = capsys.readouterr() + assert captured.out == "" + assert "Cannot parse result file" in captured.err diff --git a/tests/functional/commands/conftest.py b/tests/functional/commands/conftest.py index e4e28cf58..1fe757c3b 100644 --- a/tests/functional/commands/conftest.py +++ b/tests/functional/commands/conftest.py @@ -10,9 +10,7 @@ import orion.core.cli import orion.core.io.experiment_builder as experiment_builder import orion.core.utils.backward as backward -from orion.core.io.database import database_factory from orion.core.worker.trial import Trial -from orion.storage.base import get_storage @pytest.fixture() @@ -62,13 +60,13 @@ def one_experiment(monkeypatch, storage): ["hunt", "--init-only", "-n", name, "./black_box.py", "--x~uniform(0,1)"] ) ensure_deterministic_id(name, storage) - return get_storage().fetch_experiments({"name": name})[0] + return storage.fetch_experiments({"name": name})[0] @pytest.fixture -def one_experiment_changed_vcs(one_experiment): +def one_experiment_changed_vcs(storage, one_experiment): """Create an experiment without trials.""" - experiment = experiment_builder.build(name=one_experiment["name"]) + experiment = experiment_builder.build(name=one_experiment["name"], storage=storage) experiment.metadata["VCS"] = { "type": "git", @@ -78,46 +76,45 @@ def one_experiment_changed_vcs(one_experiment): "diff_sha": None, } - get_storage().update_experiment(experiment, metadata=experiment.metadata) + storage.update_experiment(experiment, metadata=experiment.metadata) @pytest.fixture -def one_experiment_no_version(monkeypatch, one_experiment): +def one_experiment_no_version(monkeypatch, one_experiment, storage): """Create an experiment without trials.""" one_experiment["name"] = one_experiment["name"] + "-no-version" one_experiment.pop("version") - def fetch_without_version(query, selection=None): + def fetch_without_version(self, query, selection=None): if query.get("name") == one_experiment["name"] or query == {}: return [copy.deepcopy(one_experiment)] return [] - monkeypatch.setattr(get_storage(), "fetch_experiments", fetch_without_version) + monkeypatch.setattr(type(storage), "fetch_experiments", fetch_without_version) return one_experiment @pytest.fixture -def with_experiment_using_python_api(monkeypatch, one_experiment): +def with_experiment_using_python_api(storage, monkeypatch, one_experiment): """Create an experiment without trials.""" experiment = experiment_builder.build( - name="from-python-api", space={"x": "uniform(0, 10)"} + name="from-python-api", space={"x": "uniform(0, 10)"}, storage=storage ) return experiment @pytest.fixture -def with_experiment_missing_conf_file(monkeypatch, one_experiment): +def with_experiment_missing_conf_file(monkeypatch, one_experiment, storage, orionstate): """Create an experiment without trials.""" - exp = experiment_builder.build(name="test_single_exp", version=1) + exp = experiment_builder.build(name="test_single_exp", version=1, storage=storage) conf_file = "idontexist.yaml" exp.metadata["user_config"] = conf_file exp.metadata["user_args"] += ["--config", conf_file] - database_factory.create().write( - "experiments", exp.configuration, query={"_id": exp.id} - ) + + orionstate.database.write("experiments", exp.configuration, query={"_id": exp.id}) return exp @@ -131,12 +128,12 @@ def broken_refers(one_experiment, storage): @pytest.fixture -def single_without_success(one_experiment): +def single_without_success(one_experiment, orionstate, storage): """Create an experiment without a successful trial.""" statuses = list(Trial.allowed_stati) statuses.remove("completed") - exp = experiment_builder.build(name="test_single_exp") + exp = experiment_builder.build(name="test_single_exp", storage=storage) x = {"name": "/x", "type": "real"} x_value = 0 @@ -144,18 +141,18 @@ def single_without_success(one_experiment): x["value"] = x_value trial = Trial(experiment=exp.id, params=[x], status=status) x_value += 1 - database_factory.create().write("trials", trial.to_dict()) + orionstate.database.write("trials", trial.to_dict()) @pytest.fixture -def single_with_trials(single_without_success): +def single_with_trials(single_without_success, orionstate, storage): """Create an experiment with all types of trials.""" - exp = experiment_builder.build(name="test_single_exp") + exp = experiment_builder.build(name="test_single_exp", storage=storage) x = {"name": "/x", "type": "real", "value": 100} results = {"name": "obj", "type": "objective", "value": 0} trial = Trial(experiment=exp.id, params=[x], status="completed", results=[results]) - database_factory.create().write("trials", trial.to_dict()) + orionstate.database.write("trials", trial.to_dict()) return exp.configuration @@ -193,10 +190,13 @@ def two_experiments(monkeypatch, storage): @pytest.fixture -def family_with_trials(two_experiments): +def family_with_trials(two_experiments, orionstate): """Create two related experiments with all types of trials.""" - exp = experiment_builder.build(name="test_double_exp") - exp2 = experiment_builder.build(name="test_double_exp_child") + + exp = experiment_builder.build(name="test_double_exp", storage=orionstate.storage) + exp2 = experiment_builder.build( + name="test_double_exp_child", storage=orionstate.storage + ) x = {"name": "/x", "type": "real"} y = {"name": "/y", "type": "real"} @@ -208,17 +208,19 @@ def family_with_trials(two_experiments): x["value"] = x_value + 0.5 # To avoid duplicates trial2 = Trial(experiment=exp2.id, params=[x, y], status=status) x_value += 1 - database_factory.create().write("trials", trial.to_dict()) - database_factory.create().write("trials", trial2.to_dict()) + orionstate.database.write("trials", trial.to_dict()) + orionstate.database.write("trials", trial2.to_dict()) @pytest.fixture -def unrelated_with_trials(family_with_trials, single_with_trials): +def unrelated_with_trials(family_with_trials, single_with_trials, orionstate): """Create two unrelated experiments with all types of trials.""" - exp = experiment_builder.build(name="test_double_exp_child") + exp = experiment_builder.build( + name="test_double_exp_child", storage=orionstate.storage + ) - database_factory.create().remove("trials", {"experiment": exp.id}) - database_factory.create().remove("experiments", {"_id": exp.id}) + orionstate.database.remove("trials", {"experiment": exp.id}) + orionstate.database.remove("experiments", {"_id": exp.id}) @pytest.fixture @@ -252,9 +254,11 @@ def three_experiments_family(two_experiments, storage): @pytest.fixture -def three_family_with_trials(three_experiments_family, family_with_trials): +def three_family_with_trials(three_experiments_family, family_with_trials, orionstate): """Create three experiments, all related, two direct children, with all types of trials.""" - exp = experiment_builder.build(name="test_double_exp_child2") + exp = experiment_builder.build( + name="test_double_exp_child2", storage=orionstate.storage + ) x = {"name": "/x", "type": "real"} z = {"name": "/z", "type": "real"} @@ -264,7 +268,7 @@ def three_family_with_trials(three_experiments_family, family_with_trials): z["value"] = x_value * 100 trial = Trial(experiment=exp.id, params=[x, z], status=status) x_value += 1 - database_factory.create().write("trials", trial.to_dict()) + orionstate.database.write("trials", trial.to_dict()) @pytest.fixture @@ -290,13 +294,15 @@ def three_experiments_family_branch(two_experiments, storage): @pytest.fixture def three_family_branch_with_trials( - three_experiments_family_branch, family_with_trials + three_experiments_family_branch, family_with_trials, orionstate ): """Create three experiments, all related, one child and one grandchild, with all types of trials. """ - exp = experiment_builder.build(name="test_double_exp_grand_child") + exp = experiment_builder.build( + name="test_double_exp_grand_child", storage=orionstate.storage + ) x = {"name": "/x", "type": "real"} y = {"name": "/y", "type": "real"} z = {"name": "/z", "type": "real"} @@ -308,7 +314,7 @@ def three_family_branch_with_trials( z["value"] = x_value * 100 trial = Trial(experiment=exp.id, params=[x, y, z], status=status) x_value += 1 - database_factory.create().write("trials", trial.to_dict()) + orionstate.database.write("trials", trial.to_dict()) @pytest.fixture @@ -396,8 +402,11 @@ def three_experiments_same_name(two_experiments_same_name, storage): @pytest.fixture -def three_experiments_same_name_with_trials(two_experiments_same_name, storage): +def three_experiments_same_name_with_trials( + two_experiments_same_name, orionstate, storage +): """Create three experiments with the same name but different versions.""" + orion.core.cli.main( [ "hunt", @@ -413,9 +422,9 @@ def three_experiments_same_name_with_trials(two_experiments_same_name, storage): ) ensure_deterministic_id("test_single_exp", storage, version=3) - exp = experiment_builder.build(name="test_single_exp", version=1) - exp2 = experiment_builder.build(name="test_single_exp", version=2) - exp3 = experiment_builder.build(name="test_single_exp", version=3) + exp = experiment_builder.build(name="test_single_exp", version=1, storage=storage) + exp2 = experiment_builder.build(name="test_single_exp", version=2, storage=storage) + exp3 = experiment_builder.build(name="test_single_exp", version=3, storage=storage) x = {"name": "/x", "type": "real"} y = {"name": "/y", "type": "real"} @@ -428,7 +437,7 @@ def three_experiments_same_name_with_trials(two_experiments_same_name, storage): trial = Trial(experiment=exp.id, params=[x], status=status) trial2 = Trial(experiment=exp2.id, params=[x, y], status=status) trial3 = Trial(experiment=exp3.id, params=[x, y, z], status=status) - database_factory.create().write("trials", trial.to_dict()) - database_factory.create().write("trials", trial2.to_dict()) - database_factory.create().write("trials", trial3.to_dict()) + orionstate.database.write("trials", trial.to_dict()) + orionstate.database.write("trials", trial2.to_dict()) + orionstate.database.write("trials", trial3.to_dict()) x_value += 1 diff --git a/tests/functional/commands/test_db_release.py b/tests/functional/commands/test_db_release.py index ee2e04992..08a4aa64e 100644 --- a/tests/functional/commands/test_db_release.py +++ b/tests/functional/commands/test_db_release.py @@ -4,7 +4,7 @@ import pytest import orion.core.cli -from orion.storage.base import get_storage +from orion.storage.base import setup_storage def execute(command, assert_code=0): @@ -13,7 +13,7 @@ def execute(command, assert_code=0): assert returncode == assert_code -def test_no_exp(setup_pickleddb_database, capsys): +def test_no_exp(orionstate, capsys): """Test that releasing non-existing exp exits gracefully""" execute("db release i-dont-exist", assert_code=1) @@ -40,35 +40,35 @@ def correct_name(*args): monkeypatch.setattr("builtins.input", correct_name) - experiments = get_storage().fetch_experiments({}) + experiments = setup_storage().fetch_experiments({}) uid = experiments[0]["_id"] - with get_storage().acquire_algorithm_lock(uid=uid) as algo_state_lock: + with setup_storage().acquire_algorithm_lock(uid=uid) as algo_state_lock: assert algo_state_lock.state is None algo_state_lock.set_state({}) - with get_storage().acquire_algorithm_lock(uid=uid) as algo_state_lock: + with setup_storage().acquire_algorithm_lock(uid=uid) as algo_state_lock: assert algo_state_lock.state == {} - assert get_storage().get_algorithm_lock_info(uid=uid).locked == 1 + assert setup_storage().get_algorithm_lock_info(uid=uid).locked == 1 execute("db release test_single_exp") - assert get_storage().get_algorithm_lock_info(uid=uid).locked == 0 - assert get_storage().get_algorithm_lock_info(uid=uid).state == {} + assert setup_storage().get_algorithm_lock_info(uid=uid).locked == 0 + assert setup_storage().get_algorithm_lock_info(uid=uid).state == {} def test_one_exp(single_with_trials): """Test that one exp is deleted properly""" - experiments = get_storage().fetch_experiments({}) + experiments = setup_storage().fetch_experiments({}) uid = experiments[0]["_id"] - assert get_storage().get_algorithm_lock_info(uid=uid).locked == 0 - with get_storage().acquire_algorithm_lock(uid=uid): - assert get_storage().get_algorithm_lock_info(uid=uid).locked == 1 + assert setup_storage().get_algorithm_lock_info(uid=uid).locked == 0 + with setup_storage().acquire_algorithm_lock(uid=uid): + assert setup_storage().get_algorithm_lock_info(uid=uid).locked == 1 execute("db release -f test_single_exp") - assert get_storage().get_algorithm_lock_info(uid=uid).locked == 0 + assert setup_storage().get_algorithm_lock_info(uid=uid).locked == 0 def test_release_name(three_family_branch_with_trials): """Test that deleting an experiment removes all children""" - experiments = get_storage().fetch_experiments({}) - storage = get_storage() + experiments = setup_storage().fetch_experiments({}) + storage = setup_storage() assert len(experiments) == 3 assert len(storage._fetch_trials({})) > 0 uid = None @@ -97,8 +97,8 @@ def test_release_name(three_family_branch_with_trials): def test_release_version(three_experiments_same_name_with_trials): """Test releasing a specific experiment version""" - experiments = get_storage().fetch_experiments({}) - storage = get_storage() + experiments = setup_storage().fetch_experiments({}) + storage = setup_storage() assert len(experiments) == 3 assert len(storage._fetch_trials({})) > 0 uid = None @@ -127,8 +127,8 @@ def test_release_version(three_experiments_same_name_with_trials): def test_release_default_leaf(three_experiments_same_name_with_trials): """Test that release an experiment releases the leaf by default""" - experiments = get_storage().fetch_experiments({}) - storage = get_storage() + experiments = setup_storage().fetch_experiments({}) + storage = setup_storage() assert len(experiments) == 3 assert len(storage._fetch_trials({})) > 0 uid = None diff --git a/tests/functional/commands/test_db_rm.py b/tests/functional/commands/test_db_rm.py index b3cd67763..d54ce4deb 100644 --- a/tests/functional/commands/test_db_rm.py +++ b/tests/functional/commands/test_db_rm.py @@ -5,7 +5,7 @@ import pytest import orion.core.cli -from orion.storage.base import get_storage +from orion.storage.base import setup_storage def hsh(name, version): @@ -18,7 +18,7 @@ def execute(command, assert_code=0): assert returncode == assert_code -def test_no_exp(setup_pickleddb_database, capsys): +def test_no_exp(orionstate, capsys): """Test that rm non-existing exp exits gracefully""" execute("db rm i-dont-exist", assert_code=1) @@ -45,125 +45,142 @@ def correct_name(*args): monkeypatch.setattr("builtins.input", correct_name) - assert len(get_storage().fetch_experiments({})) == 1 + assert len(setup_storage().fetch_experiments({})) == 1 execute("db rm test_single_exp") - assert len(get_storage().fetch_experiments({})) == 0 + assert len(setup_storage().fetch_experiments({})) == 0 def test_one_exp(single_with_trials): """Test that one exp is deleted properly""" - experiments = get_storage().fetch_experiments({}) + experiments = setup_storage().fetch_experiments({}) assert len(experiments) == 1 - assert len(get_storage()._fetch_trials({})) > 0 - assert get_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is not None + assert len(setup_storage()._fetch_trials({})) > 0 + assert ( + setup_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is not None + ) execute("db rm -f test_single_exp") - assert len(get_storage().fetch_experiments({})) == 0 - assert len(get_storage()._fetch_trials({})) == 0 - assert get_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is None + assert len(setup_storage().fetch_experiments({})) == 0 + assert len(setup_storage()._fetch_trials({})) == 0 + assert setup_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is None def test_rm_all_evc(three_family_branch_with_trials): """Test that deleting root removes all experiments""" - experiments = get_storage().fetch_experiments({}) + experiments = setup_storage().fetch_experiments({}) assert len(experiments) == 3 - assert len(get_storage()._fetch_trials({})) > 0 + assert len(setup_storage()._fetch_trials({})) > 0 for experiment in experiments: assert ( - get_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) + setup_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is not None ) execute("db rm -f test_double_exp --version 1") - assert len(get_storage().fetch_experiments({})) == 0 - assert len(get_storage()._fetch_trials({})) == 0 + assert len(setup_storage().fetch_experiments({})) == 0 + assert len(setup_storage()._fetch_trials({})) == 0 for experiment in experiments: - assert get_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is None + assert ( + setup_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is None + ) def test_rm_under_evc(three_family_branch_with_trials): """Test that deleting an experiment removes all children""" - experiments = get_storage().fetch_experiments({}) + experiments = setup_storage().fetch_experiments({}) assert len(experiments) == 3 - assert len(get_storage()._fetch_trials({})) > 0 + assert len(setup_storage()._fetch_trials({})) > 0 for experiment in experiments: assert ( - get_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) + setup_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is not None ) execute("db rm -f test_double_exp_child --version 1") - assert len(get_storage().fetch_experiments({})) == 1 + assert len(setup_storage().fetch_experiments({})) == 1 for experiment in experiments: if experiment["name"] == "test_double_exp": assert ( - len(get_storage()._fetch_trials({"experiment": experiment["_id"]})) > 0 + len(setup_storage()._fetch_trials({"experiment": experiment["_id"]})) + > 0 ) assert ( - get_storage().get_algorithm_lock_info(uid=experiment["_id"]) is not None + setup_storage().get_algorithm_lock_info(uid=experiment["_id"]) + is not None ) else: assert ( - len(get_storage()._fetch_trials({"experiment": experiment["_id"]})) == 0 + len(setup_storage()._fetch_trials({"experiment": experiment["_id"]})) + == 0 + ) + assert ( + setup_storage().get_algorithm_lock_info(uid=experiment["_id"]) is None ) - assert get_storage().get_algorithm_lock_info(uid=experiment["_id"]) is None def test_rm_default_leaf(three_experiments_same_name_with_trials): """Test that deleting an experiment removes the leaf by default""" - experiments = get_storage().fetch_experiments({}) + experiments = setup_storage().fetch_experiments({}) assert len(experiments) == 3 - assert len(get_storage()._fetch_trials({})) > 0 + assert len(setup_storage()._fetch_trials({})) > 0 for experiment in experiments: assert ( - get_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) + setup_storage().get_algorithm_lock_info(uid=experiments[-1]["_id"]) is not None ) execute("db rm -f test_single_exp") - assert len(get_storage().fetch_experiments({})) == 2 + assert len(setup_storage().fetch_experiments({})) == 2 for experiment in experiments: if experiment["version"] == 3: assert ( - len(get_storage()._fetch_trials({"experiment": experiment["_id"]})) == 0 + len(setup_storage()._fetch_trials({"experiment": experiment["_id"]})) + == 0 + ) + assert ( + setup_storage().get_algorithm_lock_info(uid=experiment["_id"]) is None ) - assert get_storage().get_algorithm_lock_info(uid=experiment["_id"]) is None else: assert ( - len(get_storage()._fetch_trials({"experiment": experiment["_id"]})) > 0 + len(setup_storage()._fetch_trials({"experiment": experiment["_id"]})) + > 0 ) assert ( - get_storage().get_algorithm_lock_info(uid=experiment["_id"]) is not None + setup_storage().get_algorithm_lock_info(uid=experiment["_id"]) + is not None ) def test_rm_trials_by_status(single_with_trials): """Test that trials can be deleted by status""" - trials = get_storage()._fetch_trials({}) + trials = setup_storage()._fetch_trials({}) n_broken = sum(trial.status == "broken" for trial in trials) assert n_broken > 0 execute("db rm -f test_single_exp --status broken") - assert len(get_storage()._fetch_trials({})) == len(trials) - n_broken + assert len(setup_storage()._fetch_trials({})) == len(trials) - n_broken def test_rm_trials_all(single_with_trials): """Test that trials all be deleted with '*'""" - assert len(get_storage()._fetch_trials({})) > 0 + assert len(setup_storage()._fetch_trials({})) > 0 execute("db rm -f test_single_exp --status *") - assert len(get_storage()._fetch_trials({})) == 0 + assert len(setup_storage()._fetch_trials({})) == 0 def test_rm_trials_in_evc(three_family_branch_with_trials): """Test that trials of parent experiment are not deleted""" - assert len(get_storage().fetch_experiments({})) == 3 + assert len(setup_storage().fetch_experiments({})) == 3 assert ( - len(get_storage()._fetch_trials({"experiment": hsh("test_double_exp", 1)})) > 0 + len(setup_storage()._fetch_trials({"experiment": hsh("test_double_exp", 1)})) + > 0 ) assert ( len( - get_storage()._fetch_trials({"experiment": hsh("test_double_exp_child", 1)}) + setup_storage()._fetch_trials( + {"experiment": hsh("test_double_exp_child", 1)} + ) ) > 0 ) assert ( len( - get_storage()._fetch_trials( + setup_storage()._fetch_trials( {"experiment": hsh("test_double_exp_grand_child", 1)} ) ) @@ -171,20 +188,23 @@ def test_rm_trials_in_evc(three_family_branch_with_trials): ) execute("db rm -f test_double_exp_child --status *") # Make sure no experiments were deleted - assert len(get_storage().fetch_experiments({})) == 3 + assert len(setup_storage().fetch_experiments({})) == 3 # Make sure only trials of given experiment were deleted assert ( - len(get_storage()._fetch_trials({"experiment": hsh("test_double_exp", 1)})) > 0 + len(setup_storage()._fetch_trials({"experiment": hsh("test_double_exp", 1)})) + > 0 ) assert ( len( - get_storage()._fetch_trials({"experiment": hsh("test_double_exp_child", 1)}) + setup_storage()._fetch_trials( + {"experiment": hsh("test_double_exp_child", 1)} + ) ) == 0 ) assert ( len( - get_storage()._fetch_trials( + setup_storage()._fetch_trials( {"experiment": hsh("test_double_exp_grand_child", 1)} ) ) diff --git a/tests/functional/commands/test_db_set.py b/tests/functional/commands/test_db_set.py index 000674dcb..566c17951 100644 --- a/tests/functional/commands/test_db_set.py +++ b/tests/functional/commands/test_db_set.py @@ -2,7 +2,7 @@ """Perform functional tests for db set.""" import orion.core.cli import orion.core.io.experiment_builder as experiment_builder -from orion.storage.base import get_storage +from orion.storage.base import setup_storage def execute(command, assert_code=0): @@ -37,9 +37,9 @@ def correct_name(*args): monkeypatch.setattr("builtins.input", correct_name) - assert len(get_storage()._fetch_trials({"status": "broken"})) > 0 + assert len(setup_storage()._fetch_trials({"status": "broken"})) > 0 execute("db set test_single_exp status=broken status=interrupted") - assert len(get_storage()._fetch_trials({"status": "broken"})) == 0 + assert len(setup_storage()._fetch_trials({"status": "broken"})) == 0 def test_invalid_query(single_with_trials, capsys): @@ -62,11 +62,11 @@ def test_invalid_update(single_with_trials, capsys): def test_update_trial(single_with_trials, capsys): """Test that trial is updated properly""" - trials = get_storage()._fetch_trials({}) + trials = setup_storage()._fetch_trials({}) assert sum(trial.status == "broken" for trial in trials) > 0 trials = dict(zip((trial.id for trial in trials), trials)) execute("db set -f test_single_exp status=broken status=interrupted") - for trial in get_storage()._fetch_trials({}): + for trial in setup_storage()._fetch_trials({}): if trials[trial.id].status == "broken": assert trial.status == "interrupted", "status not changed properly" else: @@ -81,11 +81,11 @@ def test_update_trial(single_with_trials, capsys): def test_update_trial_with_id(single_with_trials, capsys): """Test that trial is updated properly when querying with the id""" - trials = get_storage()._fetch_trials({}) + trials = setup_storage()._fetch_trials({}) trials = dict(zip((trial.id for trial in trials), trials)) - trial = get_storage()._fetch_trials({"status": "broken"})[0] + trial = setup_storage()._fetch_trials({"status": "broken"})[0] execute(f"db set -f test_single_exp id={trial.id} status=interrupted") - for new_trial in get_storage()._fetch_trials({}): + for new_trial in setup_storage()._fetch_trials({}): if new_trial.id == trial.id: assert new_trial.status == "interrupted", "status not changed properly" else: @@ -99,10 +99,10 @@ def test_update_trial_with_id(single_with_trials, capsys): def test_update_no_match_query(single_with_trials, capsys): """Test that no trials are updated when there is no match""" - trials = get_storage()._fetch_trials({}) + trials = setup_storage()._fetch_trials({}) trials = dict(zip((trial.id for trial in trials), trials)) execute("db set -f test_single_exp status=invalid status=interrupted") - for trial in get_storage()._fetch_trials({}): + for trial in setup_storage()._fetch_trials({}): assert ( trials[trial.id].status == trial.status ), "status should not have been changed" diff --git a/tests/functional/commands/test_list_command.py b/tests/functional/commands/test_list_command.py index 1fcc94dd4..8b5bc81bb 100644 --- a/tests/functional/commands/test_list_command.py +++ b/tests/functional/commands/test_list_command.py @@ -5,7 +5,7 @@ import orion.core.cli -def test_no_exp(monkeypatch, setup_pickleddb_database, capsys): +def test_no_exp(monkeypatch, orionstate, capsys): """Test that nothing is printed when there are no experiments.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) orion.core.cli.main(["list"]) @@ -24,7 +24,7 @@ def test_single_exp(one_experiment, capsys): assert captured == " test_single_exp-v1\n" -def test_no_version_backward_compatible(one_experiment_no_version, capsys): +def test_no_version_backward_compatible(one_experiment_no_version, capsys, storage): """Test status with no experiments.""" orion.core.cli.main(["list"]) diff --git a/tests/functional/commands/test_plot_commands.py b/tests/functional/commands/test_plot_commands.py index 50ec290e2..6e85c9805 100644 --- a/tests/functional/commands/test_plot_commands.py +++ b/tests/functional/commands/test_plot_commands.py @@ -8,7 +8,7 @@ import orion.plotting.backend_plotly from orion.core.cli.plot import IMAGE_TYPES, VALID_TYPES from orion.plotting.base import SINGLE_EXPERIMENT_PLOTS -from orion.storage.base import get_storage +from orion.storage.base import setup_storage from orion.testing import AssertNewFile @@ -158,7 +158,7 @@ def mock_plot(*args, **kwargs): monkeypatch.setattr(f"orion.plotting.backend_plotly.{kind}", mock_plot) - assert len(get_storage().fetch_trials(uid=single_with_trials["_id"])) > 0 + assert len(setup_storage().fetch_trials(uid=single_with_trials["_id"])) > 0 filename = f"test_single_exp-v1_{kind}.png" with AssertNewFile(filename): @@ -195,7 +195,7 @@ def check_args(self, output, scale, **kwargs): def test_no_trials(one_experiment, kind): """Test plotting works with empty experiments""" - assert get_storage().fetch_trials(uid=one_experiment["_id"]) == [] + assert setup_storage().fetch_trials(uid=one_experiment["_id"]) == [] filename = f"test_single_exp-v1_{kind}.png" with AssertNewFile(filename): diff --git a/tests/functional/commands/test_status_command.py b/tests/functional/commands/test_status_command.py index 12001508d..6bdd6ac2d 100644 --- a/tests/functional/commands/test_status_command.py +++ b/tests/functional/commands/test_status_command.py @@ -7,7 +7,7 @@ import orion.core.cli -def test_no_experiments(setup_pickleddb_database, monkeypatch, capsys): +def test_no_experiments(orionstate, monkeypatch, capsys): """Test status with no experiments.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) orion.core.cli.main(["status"]) @@ -17,7 +17,7 @@ def test_no_experiments(setup_pickleddb_database, monkeypatch, capsys): assert captured == "No experiment found\n" -def test_no_version_backward_compatible(one_experiment_no_version, capsys): +def test_no_version_backward_compatible(one_experiment_no_version, capsys, storage): """Test status with no experiments.""" orion.core.cli.main(["status"]) @@ -951,7 +951,7 @@ def test_three_related_branch_w_ac(three_family_branch_with_trials, capsys): assert captured == expected -def test_no_experiments_w_name(setup_pickleddb_database, monkeypatch, capsys): +def test_no_experiments_w_name(orionstate, monkeypatch, capsys): """Test status when --name does not exist.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) orion.core.cli.main(["status", "--name", "test_ghost_exp"]) diff --git a/tests/functional/commands/test_verbose_messages.py b/tests/functional/commands/test_verbose_messages.py index e94152b02..02fe89a21 100644 --- a/tests/functional/commands/test_verbose_messages.py +++ b/tests/functional/commands/test_verbose_messages.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Perform a functional test of the debug verbosity level.""" import logging +import os import pytest @@ -24,3 +25,15 @@ def test_version_print_debug_verbosity(caplog): text.startswith("Orion version : ") and (loggerlevel != logging.DEBUG) ) assert "Orion version : " in caplog.text + + +def test_log_directory(tmp_path): + """Tests that Orion creates a log directory when the option is specified""" + + logging.shutdown() + + logdir = f"{tmp_path}/mydir" + with pytest.raises(SystemExit): + orion.core.cli.main(["-vvv", "--logdir", logdir]) + + assert os.path.exists(logdir), "Log dir was created" diff --git a/tests/functional/configuration/test_all_options.py b/tests/functional/configuration/test_all_options.py index 33387d0e4..07853e2d3 100644 --- a/tests/functional/configuration/test_all_options.py +++ b/tests/functional/configuration/test_all_options.py @@ -16,9 +16,9 @@ import orion.core.evc.conflicts import orion.core.io.resolve_config from orion.client import get_experiment +from orion.core.io import experiment_builder from orion.core.io.database.pickleddb import PickledDB -from orion.core.utils.singleton import SingletonNotInstantiatedError, update_singletons -from orion.storage.base import get_storage +from orion.storage.base import setup_storage from orion.storage.legacy import Legacy from orion.testing.state import OrionState @@ -37,14 +37,16 @@ def with_storage_fork(func): def call(*args, **kwargs): with tempfile.NamedTemporaryFile(delete=True) as tmp_file: - storage = get_storage() + + storage = setup_storage() old_path = storage._db.host - storage._db.host = tmp_file.name + + orion.core.config.storage.database.host = tmp_file.name shutil.copyfile(old_path, tmp_file.name) rval = func(*args, **kwargs) - storage._db.host = old_path + orion.core.config.storage.database.host = old_path return rval @@ -58,21 +60,36 @@ class ConfigurationTestSuite: default_storage = { "type": "legacy", - "database": {"type": "pickleddb", "host": "experiment.pkl"}, + "database": {"type": "pickleddb", "host": "${file}_experiment.pkl"}, } @contextmanager def setup_global_config(self, tmp_path): """Setup temporary yaml file for the global configuration""" - with OrionState(storage=self.default_storage): + with OrionState(storage=self.default_storage) as cfg: conf_file = tmp_path / "config.yaml" - conf_file.write_text(yaml.dump(self.config)) + + if "storage" not in self.config: + self.config["storage"] = self.default_storage + + config_str = yaml.dump(self.config) + config_str = config_str.replace("${tmp_path}", str(tmp_path)) + config_str = config_str.replace("${file}", str(cfg.tempfile_path)) + + conf_file.write_text(config_str) conf_files = orion.core.DEF_CONFIG_FILES_PATHS orion.core.DEF_CONFIG_FILES_PATHS = [conf_file] + orion.core.config = orion.core.build_config() + try: yield conf_file finally: + try: + os.remove(orion.core.config.storage.database.host) + except: + pass + orion.core.DEF_CONFIG_FILES_PATHS = conf_files orion.core.config = orion.core.build_config() @@ -82,6 +99,9 @@ def setup_env_var_config(self, tmp_path): with self.setup_global_config(tmp_path): tmp = {} for key, value in self.env_vars.items(): + if isinstance(value, str): + value = value.replace("${tmp_path}", str(tmp_path)) + tmp[key] = os.environ.pop(key, None) os.environ[key] = str(value) try: @@ -97,16 +117,20 @@ def setup_env_var_config(self, tmp_path): def setup_db_config(self, tmp_path): """Setup database with temporary data""" with self.setup_env_var_config(tmp_path): - storage = get_storage() + storage = setup_storage() storage.create_experiment(self.database) - yield + yield storage @contextmanager def setup_local_config(self, tmp_path): """Setup local configuration on top""" with self.setup_db_config(tmp_path): conf_file = tmp_path / "local.yaml" - conf_file.write_text(yaml.dump(self.local)) + + config_str = yaml.dump(self.local) + config_str = config_str.replace("${tmp_path}", str(tmp_path)) + + conf_file.write_text(config_str) yield conf_file @contextmanager @@ -117,14 +141,14 @@ def setup_cmd_args_config(self, tmp_path): def test_global_config(self, tmp_path, monkeypatch): """Test that global configuration is set properly based on global yaml""" - update_singletons() + self.sanity_check() with self.setup_global_config(tmp_path): self.check_global_config(tmp_path, monkeypatch) def test_env_var_config(self, tmp_path, monkeypatch): """Test that env vars are set properly in global config""" - update_singletons() + self.sanity_check() with self.setup_env_var_config(tmp_path): self.check_env_var_config(tmp_path, monkeypatch) @@ -134,7 +158,7 @@ def test_env_var_config(self, tmp_path, monkeypatch): ) def test_db_config(self, tmp_path): """Test that exp config in db overrides global config""" - update_singletons() + self.sanity_check() with self.setup_db_config(tmp_path): self.check_db_config() @@ -142,7 +166,7 @@ def test_db_config(self, tmp_path): @pytest.mark.usefixtures("with_user_userxyz", "version_XYZ") def test_local_config(self, tmp_path, monkeypatch): """Test that local config overrides db/global config""" - update_singletons() + self.sanity_check() with self.setup_local_config(tmp_path) as conf_file: self.check_local_config(tmp_path, conf_file, monkeypatch) @@ -150,7 +174,7 @@ def test_local_config(self, tmp_path, monkeypatch): @pytest.mark.usefixtures("with_user_userxyz", "version_XYZ") def test_cmd_args_config(self, tmp_path, monkeypatch): """Test that cmd_args config overrides local config""" - update_singletons() + self.sanity_check() with self.setup_cmd_args_config(tmp_path) as conf_file: self.check_cmd_args_config(tmp_path, conf_file, monkeypatch) @@ -165,7 +189,7 @@ class TestStorage(ConfigurationTestSuite): "database": { "name": "test_name", "type": "pickleddb", - "host": "here.pkl", + "host": "${tmp_path}/here.pkl", "port": 101, }, } @@ -175,14 +199,14 @@ class TestStorage(ConfigurationTestSuite): "ORION_STORAGE_TYPE": "legacy", "ORION_DB_NAME": "test_env_var_name", "ORION_DB_TYPE": "pickleddb", - "ORION_DB_ADDRESS": "there.pkl", + "ORION_DB_ADDRESS": "${tmp_path}/there.pkl", "ORION_DB_PORT": "103", } local = { "storage": { "type": "legacy", - "database": {"type": "pickleddb", "host": "local.pkl"}, + "database": {"type": "pickleddb", "host": "${tmp_path}/local.pkl"}, } } @@ -192,73 +216,94 @@ def sanity_check(self): def check_global_config(self, tmp_path, monkeypatch): """Check that global configuration is set properly""" - update_singletons() - assert orion.core.config.storage.to_dict() == self.config["storage"] + storage_config = copy.deepcopy(self.config["storage"]) + storage_config["database"]["host"] = storage_config["database"]["host"].replace( + "${tmp_path}", str(tmp_path) + ) + assert orion.core.config.storage.to_dict() == storage_config - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + # Build storage + storage = setup_storage() + assert len(storage.fetch_experiments({"name": "test"})) == 0 command = f"hunt --exp-max-trials 0 -n test python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + # if hunt worked it should insert its experiment + assert len(storage.fetch_experiments({"name": "test"})) == 1 + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - assert storage._db.host == os.path.abspath("here.pkl") + assert storage._db.host == str(tmp_path / "here.pkl") def check_env_var_config(self, tmp_path, monkeypatch): """Check that env vars overrides global configuration""" - update_singletons() assert orion.core.config.storage.to_dict() == { "type": self.env_vars["ORION_STORAGE_TYPE"], "database": { "name": self.env_vars["ORION_DB_NAME"], "type": self.env_vars["ORION_DB_TYPE"], - "host": self.env_vars["ORION_DB_ADDRESS"], + "host": self.env_vars["ORION_DB_ADDRESS"].replace( + "${tmp_path}", str(tmp_path) + ), "port": int(self.env_vars["ORION_DB_PORT"]), }, } - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + # Build storage + storage = setup_storage() + assert len(storage.fetch_experiments({"name": "test"})) == 0 + # Make sure hunt is picking up the right database command = f"hunt --exp-max-trials 0 -n test python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + # if hunt worked it should insert its experiment + assert len(storage.fetch_experiments({"name": "test"})) == 1 + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - assert storage._db.host == os.path.abspath(self.env_vars["ORION_DB_ADDRESS"]) + assert storage._db.host == self.env_vars["ORION_DB_ADDRESS"].replace( + "${tmp_path}", str(tmp_path) + ) def check_db_config(self): """No Storage config in DB, no test""" def check_local_config(self, tmp_path, conf_file, monkeypatch): """Check that local configuration overrides global/envvars configuration""" - update_singletons() assert orion.core.config.storage.to_dict() == { "type": self.env_vars["ORION_STORAGE_TYPE"], "database": { "name": self.env_vars["ORION_DB_NAME"], "type": self.env_vars["ORION_DB_TYPE"], - "host": self.env_vars["ORION_DB_ADDRESS"], + "host": self.env_vars["ORION_DB_ADDRESS"].replace( + "${tmp_path}", str(tmp_path) + ), "port": int(self.env_vars["ORION_DB_PORT"]), }, } - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + # Build storage with local config + cmd_config = experiment_builder.get_cmd_config(dict(config=open(conf_file))) + builder = experiment_builder.ExperimentBuilder(cmd_config["storage"]) + storage = builder.storage + + assert len(storage.fetch_experiments({"name": "test"})) == 0 + # Make sure hunt is picking up the right database command = f"hunt --exp-max-trials 0 -n test -c {conf_file} python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + # if hunt worked it should insert its experiment + assert len(storage.fetch_experiments({"name": "test"})) == 1 + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - assert storage._db.host == os.path.abspath("local.pkl") + assert storage._db.host == str(tmp_path / "local.pkl") def check_cmd_args_config(self, tmp_path, conf_file, monkeypatch): """No Storage config in cmdline, no test""" @@ -271,7 +316,7 @@ class TestDatabaseDeprecated(ConfigurationTestSuite): "database": { "name": "test_name", "type": "pickleddb", - "host": "dbhere.pkl", + "host": "${tmp_path}/dbhere.pkl", "port": 101, } } @@ -279,11 +324,11 @@ class TestDatabaseDeprecated(ConfigurationTestSuite): env_vars = { "ORION_DB_NAME": "test_env_var_name", "ORION_DB_TYPE": "pickleddb", - "ORION_DB_ADDRESS": "there.pkl", + "ORION_DB_ADDRESS": "${tmp_path}/dbthere.pkl", "ORION_DB_PORT": "103", } - local = {"database": {"type": "pickleddb", "host": "dblocal.pkl"}} + local = {"database": {"type": "pickleddb", "host": "${tmp_path}/dblocal.pkl"}} def sanity_check(self): """Check that defaults are different than testing configuration""" @@ -291,67 +336,80 @@ def sanity_check(self): def check_global_config(self, tmp_path, monkeypatch): """Check that global configuration is set properly""" - update_singletons() - assert orion.core.config.database.to_dict() == self.config["database"] + database = copy.deepcopy(self.config["database"]) + database["host"] = database["host"].replace("${tmp_path}", str(tmp_path)) + assert orion.core.config.database.to_dict() == database - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + storage = setup_storage() + assert len(storage.fetch_experiments({"name": "test"})) == 0 command = f"hunt --exp-max-trials 0 -n test python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + assert len(storage.fetch_experiments({"name": "test"})) == 1 + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - assert storage._db.host == os.path.abspath("dbhere.pkl") + assert storage._db.host == str(tmp_path / "dbhere.pkl") def check_env_var_config(self, tmp_path, monkeypatch): """Check that env vars overrides global configuration""" - update_singletons() assert orion.core.config.database.to_dict() == { "name": self.env_vars["ORION_DB_NAME"], "type": self.env_vars["ORION_DB_TYPE"], - "host": self.env_vars["ORION_DB_ADDRESS"], + "host": self.env_vars["ORION_DB_ADDRESS"].replace( + "${tmp_path}", str(tmp_path) + ), "port": int(self.env_vars["ORION_DB_PORT"]), } - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + storage = setup_storage() + assert len(storage.fetch_experiments({"name": "test"})) == 0 command = f"hunt --exp-max-trials 0 -n test python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + assert len(storage.fetch_experiments({"name": "test"})) == 1 + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - assert storage._db.host == os.path.abspath(self.env_vars["ORION_DB_ADDRESS"]) + assert storage._db.host == self.env_vars["ORION_DB_ADDRESS"].replace( + "${tmp_path}", str(tmp_path) + ) def check_db_config(self): """No Storage config in DB, no test""" def check_local_config(self, tmp_path, conf_file, monkeypatch): """Check that local configuration overrides global/envvars configuration""" - update_singletons() assert orion.core.config.database.to_dict() == { "name": self.env_vars["ORION_DB_NAME"], "type": self.env_vars["ORION_DB_TYPE"], - "host": self.env_vars["ORION_DB_ADDRESS"], + "host": self.env_vars["ORION_DB_ADDRESS"].replace( + "${tmp_path}", str(tmp_path) + ), "port": int(self.env_vars["ORION_DB_PORT"]), } - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + cmd_config = experiment_builder.get_cmd_config(dict(config=open(conf_file))) + builder = experiment_builder.ExperimentBuilder(cmd_config["storage"]) + storage = builder.storage + assert len(storage.fetch_experiments({"name": "test"})) == 0 + + # Make sure hunt is picking up the right database command = f"hunt --exp-max-trials 0 -n test -c {conf_file} python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + # if hunt worked it should insert its experiment + assert len(storage.fetch_experiments({"name": "test"})) == 1 + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - assert storage._db.host == os.path.abspath("dblocal.pkl") + assert storage._db.host == str(tmp_path / "dblocal.pkl") def check_cmd_args_config(self, tmp_path, conf_file, monkeypatch): """No Storage config in cmdline, no test""" @@ -468,11 +526,10 @@ def check_global_config(self, tmp_path, monkeypatch): self._compare( self.config["experiment"], orion.core.config.to_dict()["experiment"] ) - command = f"hunt --init-only -n test python {script} -x~uniform(0,1)" orion.core.cli.main(command.split(" ")) - storage = get_storage() + storage = setup_storage() experiment = get_experiment("test") self._compare( @@ -511,7 +568,7 @@ def check_db_config(self): command = f"hunt --worker-max-trials 0 -n {name}" orion.core.cli.main(command.split(" ")) - storage = get_storage() + storage = setup_storage() experiment = get_experiment(name) self._compare(self.database, experiment.configuration, ignore=["worker_trials"]) @@ -521,7 +578,7 @@ def check_local_config(self, tmp_path, conf_file, monkeypatch): command = f"hunt --worker-max-trials 0 -c {conf_file}" orion.core.cli.main(command.split(" ")) - storage = get_storage() + storage = setup_storage() experiment = get_experiment("test-name") self._compare(self.local["experiment"], experiment.configuration) @@ -534,7 +591,7 @@ def check_cmd_args_config(self, tmp_path, conf_file, monkeypatch): ) orion.core.cli.main(command.split(" ")) - storage = get_storage() + storage = setup_storage() experiment = get_experiment("exp-name") assert experiment.name == "exp-name" diff --git a/tests/functional/demo/test_demo.py b/tests/functional/demo/test_demo.py index b5d69528f..bb400156b 100644 --- a/tests/functional/demo/test_demo.py +++ b/tests/functional/demo/test_demo.py @@ -5,6 +5,7 @@ import subprocess import tempfile from collections import defaultdict +from contextlib import contextmanager import numpy import pytest @@ -13,9 +14,6 @@ import orion.core.cli import orion.core.io.experiment_builder as experiment_builder from orion.core.cli.hunt import workon -from orion.core.io.database.ephemeraldb import EphemeralDB -from orion.storage.base import get_storage -from orion.storage.legacy import Legacy from orion.testing import OrionState @@ -243,30 +241,81 @@ def test_demo_inexecutable_script(storage, monkeypatch, capsys): assert "User script is not executable" in captured -def test_demo_four_workers(storage, monkeypatch): +@contextmanager +def generate_config(template, tmp_path): + """Generate a configuration file inside a temporary directory with the current storage config""" + + with open(template) as file: + conf = yaml.safe_load(file) + + conf["storage"] = orion.core.config.storage.to_dict() + conf_file = os.path.join(tmp_path, "config.yaml") + config_str = yaml.dump(conf) + + with open(conf_file, "w") as file: + file.write(config_str) + + with open(conf_file) as file: + yield file + + +def logging_directory(): + """Default logging directory for testing `/logdir`. + The folder is deleted if it exists at the beginning of testing, + it will not be deleted at the end of the tests to help debugging. + + """ + base_repo = os.path.dirname(os.path.abspath(orion.core.__file__)) + logdir = os.path.abspath(os.path.join(base_repo, "..", "..", "..", "logdir")) + shutil.rmtree(logdir, ignore_errors=True) + return logdir + + +def test_demo_four_workers(tmp_path, storage, monkeypatch): """Test a simple usage scenario.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - processes = [] - for _ in range(4): - process = subprocess.Popen( - [ - "orion", - "hunt", - "-n", - "four_workers_demo", - "--config", - "./orion_config_random.yaml", - "--max-trials", - "20", - "./black_box.py", - "-x~norm(34, 3)", - ] - ) - processes.append(process) - for process in processes: - rcode = process.wait() - assert rcode == 0 + logdir = logging_directory() + print(logdir) + + with generate_config("orion_config_random.yaml", tmp_path) as conf_file: + processes = [] + for _ in range(4): + process = subprocess.Popen( + [ + "orion", + "-vvv", + "--logdir", + logdir, + "hunt", + "--working-dir", + str(tmp_path), + "-n", + "four_workers_demo", + "--config", + f"{conf_file.name}", + "--max-trials", + "20", + "./black_box.py", + "-x~norm(34, 3)", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + processes.append(process) + + for process in processes: + stdout, _ = process.communicate() + + rcode = process.wait() + + if rcode != 0: + print("OUT", stdout.decode("utf-8")) + + assert rcode == 0 + + assert storage._db.host == orion.core.config.storage.database.host + print(storage._db.host) exp = list(storage.fetch_experiments({"name": "four_workers_demo"})) assert len(exp) == 1 @@ -307,8 +356,14 @@ def test_workon(): "-x~uniform(-50, 50, precision=None)", ] - with OrionState(): - experiment = experiment_builder.build_from_args(config) + with OrionState() as cfg: + cmd_config = experiment_builder.get_cmd_config(config) + + builder = experiment_builder.ExperimentBuilder( + cfg.storage, debug=cmd_config.get("debug") + ) + + experiment = builder.build(**cmd_config) workon( experiment, @@ -324,7 +379,7 @@ def test_workon(): executor_configuration={"backend": "threading"}, ) - storage = get_storage() + storage = cfg.storage exp = list(storage.fetch_experiments({"name": name})) assert len(exp) == 1 @@ -360,6 +415,7 @@ def test_stress_unique_folder_creation(storage, monkeypatch, tmpdir, capfd): monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) orion.core.cli.main( [ + "-vvv", "hunt", f"--max-trials={how_many}", "--name=lalala", @@ -642,9 +698,9 @@ def n_completed(): assert n_completed() == 6 -@pytest.mark.usefixtures("storage") -def test_resilience(monkeypatch): +def test_resilience(storage, monkeypatch): """Test if OrĂ­on stops after enough broken trials.""" + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) MAX_BROKEN = 3 @@ -664,27 +720,28 @@ def test_resilience(monkeypatch): assert len(exp.fetch_trials_by_status("broken")) == MAX_BROKEN -@pytest.mark.usefixtures("storage") -def test_demo_with_shutdown_quickly(monkeypatch): +def test_demo_with_shutdown_quickly(storage, monkeypatch, tmp_path): """Check simple pipeline with random search is reasonably fast.""" + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) monkeypatch.setattr(orion.core.config.worker, "heartbeat", 120) - process = subprocess.Popen( - [ - "orion", - "hunt", - "--config", - "./orion_config_random.yaml", - "--max-trials", - "10", - "./black_box.py", - "-x~uniform(-50, 50)", - ] - ) + with generate_config("orion_config_random.yaml", tmp_path) as conf_file: + process = subprocess.Popen( + [ + "orion", + "hunt", + "--config", + f"{conf_file.name}", + "--max-trials", + "10", + "./black_box.py", + "-x~uniform(-50, 50)", + ] + ) - assert process.wait(timeout=40) == 0 + assert process.wait(timeout=40) == 0 def test_demo_with_nondefault_config_keyword(storage, monkeypatch): @@ -773,30 +830,27 @@ def test_demo_precision(storage, monkeypatch): assert value == float(numpy.format_float_scientific(value, precision=4)) -@pytest.mark.usefixtures("setup_pickleddb_database") -def test_debug_mode(monkeypatch): +def test_debug_mode(storage, monkeypatch, tmp_path): """Test debug mode.""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) user_args = ["-x~uniform(-50, 50, precision=5)"] - orion.core.cli.main( - [ - "--debug", - "hunt", - "--config", - "./orion_config.yaml", - "--max-trials", - "2", - "./black_box.py", - ] - + user_args - ) - - storage = get_storage() + with generate_config("orion_config.yaml", tmp_path) as conf_file: + orion.core.cli.main( + [ + "--debug", + "hunt", + "--config", + f"{conf_file.name}", + "--max-trials", + "2", + "./black_box.py", + ] + + user_args + ) - assert isinstance(storage, Legacy) - assert isinstance(storage._db, EphemeralDB) + assert len(list(storage.fetch_experiments({}))) == 0 def test_no_args(capsys): diff --git a/tests/functional/example/test_scikit_learn.py b/tests/functional/example/test_scikit_learn.py index 3d9f7b831..130045e1d 100644 --- a/tests/functional/example/test_scikit_learn.py +++ b/tests/functional/example/test_scikit_learn.py @@ -6,7 +6,7 @@ import orion.core.cli from orion.client import create_experiment -from orion.storage.base import get_storage +from orion.storage.base import setup_storage def test_script_integrity(capsys): @@ -22,7 +22,7 @@ def test_script_integrity(capsys): ), "The example script encountered an error during its execution." -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") def test_orion_runs_script(monkeypatch): """Verifies OrĂ­on can execute the example script.""" script = os.path.abspath("examples/scikitlearn-iris/main.py") @@ -41,7 +41,7 @@ def test_orion_runs_script(monkeypatch): assert len(keys) == 1 assert "/_pos_2" in keys - storage = get_storage() + storage = setup_storage() trials = storage.fetch_trials(uid=experiment.id) assert len(trials) == 1 @@ -50,7 +50,7 @@ def test_orion_runs_script(monkeypatch): assert trial.params["/_pos_2"] == 0.1 -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") def test_result_reproducibility(monkeypatch): """Verifies the script results stays consistent (with respect to the documentation).""" script = os.path.abspath("examples/scikitlearn-iris/main.py") diff --git a/tests/functional/parsing/test_parsing_base.py b/tests/functional/parsing/test_parsing_base.py index 42e65d090..3acdd99c8 100644 --- a/tests/functional/parsing/test_parsing_base.py +++ b/tests/functional/parsing/test_parsing_base.py @@ -16,7 +16,7 @@ def _create_parser(need_subparser=True): return parser -def test_common_group_arguments(setup_pickleddb_database, monkeypatch): +def test_common_group_arguments(orionstate, monkeypatch): """Check the parsing of the common group""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) parser, subparsers = _create_parser() @@ -28,7 +28,7 @@ def test_common_group_arguments(setup_pickleddb_database, monkeypatch): assert args["config"].name == "./orion_config_random.yaml" -def test_user_group_arguments(setup_pickleddb_database, monkeypatch): +def test_user_group_arguments(orionstate, monkeypatch): """Test the parsing of the user group""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) parser = _create_parser(False) @@ -40,7 +40,7 @@ def test_user_group_arguments(setup_pickleddb_database, monkeypatch): assert args["user_args"] == ["./black_box.py", "-x~normal(50,50)"] -def test_common_and_user_group_arguments(setup_pickleddb_database, monkeypatch): +def test_common_and_user_group_arguments(orionstate, monkeypatch): """Test the parsing of the command and user groups""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) parser = _create_parser(False) diff --git a/tests/functional/parsing/test_parsing_hunt.py b/tests/functional/parsing/test_parsing_hunt.py index ddc61a8c8..d878fb084 100644 --- a/tests/functional/parsing/test_parsing_hunt.py +++ b/tests/functional/parsing/test_parsing_hunt.py @@ -16,7 +16,7 @@ def _create_parser(need_subparser=True): return parser -def test_hunt_command_full_parsing(setup_pickleddb_database, monkeypatch): +def test_hunt_command_full_parsing(orionstate, monkeypatch): """Test the parsing of the `hunt` command""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) parser, subparsers = _create_parser() diff --git a/tests/functional/parsing/test_parsing_insert.py b/tests/functional/parsing/test_parsing_insert.py index f8a68b756..2b10e93aa 100644 --- a/tests/functional/parsing/test_parsing_insert.py +++ b/tests/functional/parsing/test_parsing_insert.py @@ -16,7 +16,7 @@ def _create_parser(need_subparser=True): return parser -def test_insert_command_full_parsing(setup_pickleddb_database, monkeypatch): +def test_insert_command_full_parsing(orionstate, monkeypatch): """Test the parsing of all the options of insert""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) parser, subparsers = _create_parser() diff --git a/tests/functional/serving/conftest.py b/tests/functional/serving/conftest.py index fb4a9f479..288a46b16 100644 --- a/tests/functional/serving/conftest.py +++ b/tests/functional/serving/conftest.py @@ -8,23 +8,27 @@ @pytest.fixture() -def client(): - """Mock the falcon.API instance for testing with an in memory database""" +def ephemeral_storage(): storage = {"type": "legacy", "database": {"type": "EphemeralDB"}} - with OrionState(storage=storage): - yield testing.TestClient(WebApi({"storage": storage})) + with OrionState(storage=storage) as cfg: + yield cfg + + +@pytest.fixture() +def client(ephemeral_storage): + """Mock the falcon.API instance for testing with an in memory database""" + yield testing.TestClient(WebApi(ephemeral_storage.storage, {})) @pytest.fixture() -def client_with_frontends_uri(): +def client_with_frontends_uri(ephemeral_storage): """Mock the falcon.API instance for testing with custom frontend_uri""" - storage = {"type": "legacy", "database": {"type": "EphemeralDB"}} - with OrionState(storage=storage): - yield testing.TestClient( - WebApi( - { - "storage": storage, - "frontends_uri": ["http://123.456", "http://example.com"], - } - ) + yield testing.TestClient( + WebApi( + ephemeral_storage.storage, + { + "storage": ephemeral_storage.storage, + "frontends_uri": ["http://123.456", "http://example.com"], + }, ) + ) diff --git a/tests/functional/serving/test_experiments_resource.py b/tests/functional/serving/test_experiments_resource.py index 50218ca03..78500c4a0 100644 --- a/tests/functional/serving/test_experiments_resource.py +++ b/tests/functional/serving/test_experiments_resource.py @@ -3,7 +3,6 @@ import datetime from orion.core.worker.trial import Trial -from orion.storage.base import get_storage current_id = 0 @@ -59,26 +58,32 @@ def test_no_experiments(self, client): assert response.json == [] assert response.status == "200 OK" - def test_send_name_and_versions(self, client): + def test_send_name_and_versions(self, client, ephemeral_storage): """Tests that the API returns all the experiments with their name and version""" expected = [{"name": "a", "version": 1}, {"name": "b", "version": 1}] - _add_experiment(name="a", version=1, _id=1) - _add_experiment(name="b", version=1, _id=2) + storage = ephemeral_storage.storage + + _add_experiment(storage, name="a", version=1, _id=1) + _add_experiment(storage, name="b", version=1, _id=2) + + assert len(storage.fetch_experiments({})) == 2 response = client.simulate_get("/experiments") assert response.json == expected assert response.status == "200 OK" - def test_latest_versions(self, client): + def test_latest_versions(self, client, ephemeral_storage): """Tests that the API return the latest versions of each experiment""" expected = [{"name": "a", "version": 3}, {"name": "b", "version": 1}] - _add_experiment(name="a", version=1, _id=1) - _add_experiment(name="a", version=3, _id=2) - _add_experiment(name="a", version=2, _id=3) - _add_experiment(name="b", version=1, _id=4) + storage = ephemeral_storage.storage + + _add_experiment(storage, name="a", version=1, _id=1) + _add_experiment(storage, name="a", version=3, _id=2) + _add_experiment(storage, name="a", version=2, _id=3) + _add_experiment(storage, name="b", version=1, _id=4) response = client.simulate_get("/experiments") @@ -89,7 +94,7 @@ def test_latest_versions(self, client): class TestItem: """Tests the server's response to experiments/:name""" - def test_non_existent_experiment(self, client): + def test_non_existent_experiment(self, client, ephemeral_storage): """ Tests that a 404 response is returned when the experiment doesn't exist in the database @@ -102,7 +107,9 @@ def test_non_existent_experiment(self, client): "description": 'Experiment "a" does not exist', } - _add_experiment(name="a", version=1, _id=1) + storage = ephemeral_storage.storage + + _add_experiment(storage, name="a", version=1, _id=1) response = client.simulate_get("/experiments/b") assert response.status == "404 Not Found" @@ -111,10 +118,13 @@ def test_non_existent_experiment(self, client): "description": 'Experiment "b" does not exist', } - def test_experiment_specification(self, client): + def test_experiment_specification(self, client, ephemeral_storage): """Tests that the experiment returned is following the specification""" - _add_experiment(name="a", version=1, _id=1) - _add_trial(experiment=1, id_override="ae8", status="completed") + + storage = ephemeral_storage.storage + + _add_experiment(storage, name="a", version=1, _id=1) + _add_trial(storage, experiment=1, id_override="ae8", status="completed") response = client.simulate_get("/experiments/a") @@ -132,42 +142,50 @@ def test_experiment_specification(self, client): _assert_config(response.json["config"]) _assert_best_trial(response.json["bestTrial"]) - def test_default_is_latest_version(self, client): + def test_default_is_latest_version(self, client, ephemeral_storage): """Tests that the latest experiment is returned when no version parameter exists""" - _add_experiment(name="a", version=1, _id=1) - _add_experiment(name="a", version=3, _id=2) - _add_experiment(name="a", version=2, _id=3) + + storage = ephemeral_storage.storage + + _add_experiment(storage, name="a", version=1, _id=1) + _add_experiment(storage, name="a", version=3, _id=2) + _add_experiment(storage, name="a", version=2, _id=3) response = client.simulate_get("/experiments/a") assert response.status == "200 OK" assert response.json["version"] == 3 - def test_specific_version(self, client): + def test_specific_version(self, client, ephemeral_storage): """Tests that the specified version of an experiment is returned""" - _add_experiment(name="a", version=1, _id=1) - _add_experiment(name="a", version=2, _id=2) - _add_experiment(name="a", version=3, _id=3) + + storage = ephemeral_storage.storage + + _add_experiment(storage, name="a", version=1, _id=1) + _add_experiment(storage, name="a", version=2, _id=2) + _add_experiment(storage, name="a", version=3, _id=3) response = client.simulate_get("/experiments/a?version=2") assert response.status == "200 OK" assert response.json["version"] == 2 - def test_unknown_parameter(self, client): + def test_unknown_parameter(self, client, ephemeral_storage): """ Tests that if an unknown parameter is specified in the query string, an error is returned even if the experiment doesn't exist. """ response = client.simulate_get("/experiments/a?unknown=true") + storage = ephemeral_storage.storage + assert response.status == "400 Bad Request" assert response.json == { "title": "Invalid parameter", "description": 'Parameter "unknown" is not supported. Expected parameter "version".', } - _add_experiment(name="a", version=1, _id=1) + _add_experiment(storage, name="a", version=1, _id=1) response = client.simulate_get("/experiments/a?unknown=true") @@ -178,16 +196,16 @@ def test_unknown_parameter(self, client): } -def _add_experiment(**kwargs): +def _add_experiment(storage, **kwargs): """Adds experiment to the dummy orion instance""" base_experiment.update(copy.deepcopy(kwargs)) - get_storage().create_experiment(base_experiment) + storage.create_experiment(base_experiment) -def _add_trial(**kwargs): +def _add_trial(storage, **kwargs): """Add trials to the dummy orion instance""" base_trial.update(copy.deepcopy(kwargs)) - get_storage().register_trial(Trial(**base_trial)) + storage.register_trial(Trial(**base_trial)) def _assert_config(config): diff --git a/tests/functional/serving/test_plots_resource.py b/tests/functional/serving/test_plots_resource.py index 3088f3134..7ff1fc282 100644 --- a/tests/functional/serving/test_plots_resource.py +++ b/tests/functional/serving/test_plots_resource.py @@ -1,7 +1,7 @@ """Perform tests for the REST endpoint `/plots`""" import pytest -from orion.testing import create_experiment +from orion.testing import falcon_client config = dict( name="experiment-name", @@ -58,12 +58,13 @@ def test_unknown_experiment(self, client): "description": 'Experiment "unknown-experiment" does not exist', } - def test_plot(self, client): + def test_plot(self): """Tests that the API returns the plot in json format.""" - with create_experiment(config, trial_config, ["completed"]) as ( + with falcon_client(config, trial_config, ["completed"]) as ( _, _, experiment, + client, ): response = client.simulate_get("/plots/regret/experiment-name") @@ -71,12 +72,13 @@ def test_plot(self, client): assert response.json assert list(response.json.keys()) == ["data", "layout"] - def test_no_trials(self, client): + def test_no_trials(self): """Tests that the API returns an empty figure when no trials are found.""" - with create_experiment(config, trial_config, []) as ( + with falcon_client(config, trial_config, []) as ( _, _, experiment, + client, ): response = client.simulate_get("/plots/regret/experiment-name") @@ -98,12 +100,13 @@ def test_unknown_experiment(self, client): "description": 'Experiment "unknown-experiment" does not exist', } - def test_plot(self, client): + def test_plot(self): """Tests that the API returns the plot in json format.""" - with create_experiment(config, trial_config, ["completed"]) as ( + with falcon_client(config, trial_config, ["completed"]) as ( _, _, experiment, + client, ): response = client.simulate_get( "/plots/parallel_coordinates/experiment-name" @@ -113,12 +116,13 @@ def test_plot(self, client): assert response.json assert list(response.json.keys()) == ["data", "layout"] - def test_no_trials(self, client): + def test_no_trials(self): """Tests that the API returns an empty figure when no trials are found.""" - with create_experiment(config, trial_config, []) as ( + with falcon_client(config, trial_config, []) as ( _, _, experiment, + client, ): response = client.simulate_get( "/plots/parallel_coordinates/experiment-name" @@ -142,12 +146,13 @@ def test_unknown_experiment(self, client): "description": 'Experiment "unknown-experiment" does not exist', } - def test_plot(self, client): + def test_plot(self): """Tests that the API returns the plot in json format.""" - with create_experiment(config, trial_config, ["completed"]) as ( + with falcon_client(config, trial_config, ["completed"]) as ( _, _, experiment, + client, ): response = client.simulate_get( "/plots/partial_dependencies/experiment-name" @@ -157,12 +162,13 @@ def test_plot(self, client): assert response.json assert list(response.json.keys()) == ["data", "layout"] - def test_no_trials(self, client): + def test_no_trials(self): """Tests that the API returns an empty figure when no trials are found.""" - with create_experiment(config, trial_config, []) as ( + with falcon_client(config, trial_config, []) as ( _, _, experiment, + client, ): response = client.simulate_get( "/plots/partial_dependencies/experiment-name" @@ -186,12 +192,13 @@ def test_unknown_experiment(self, client): "description": 'Experiment "unknown-experiment" does not exist', } - def test_plot(self, client): + def test_plot(self): """Tests that the API returns the plot in json format.""" - with create_experiment(config, trial_config, ["completed"]) as ( + with falcon_client(config, trial_config, ["completed"]) as ( _, _, experiment, + client, ): response = client.simulate_get("/plots/lpi/experiment-name") @@ -199,12 +206,13 @@ def test_plot(self, client): assert response.json assert list(response.json.keys()) == ["data", "layout"] - def test_no_trials(self, client): + def test_no_trials(self): """Tests that the API returns an empty figure when no trials are found.""" - with create_experiment(config, trial_config, []) as ( + with falcon_client(config, trial_config, []) as ( _, _, experiment, + client, ): response = client.simulate_get("/plots/lpi/experiment-name") diff --git a/tests/functional/serving/test_trials_resource.py b/tests/functional/serving/test_trials_resource.py index 8d5475540..4633f588e 100644 --- a/tests/functional/serving/test_trials_resource.py +++ b/tests/functional/serving/test_trials_resource.py @@ -7,7 +7,6 @@ from orion.core.io import experiment_builder from orion.core.worker.trial import Trial -from orion.storage.base import get_storage base_experiment = dict( name="experiment-name", @@ -51,16 +50,17 @@ } -def add_experiment(**kwargs): +def add_experiment(storage, **kwargs): """Adds experiment to the dummy orion instance""" base_experiment.update(copy.deepcopy(kwargs)) experiment_builder.build( branching=dict(branch_from=base_experiment["name"], enable=True), + storage=storage, **base_experiment ) -def add_trial(experiment: int, status: str = None, value=10, **kwargs): +def add_trial(storage, experiment: int, status: str = None, value=10, **kwargs): """ Add trials to the dummy orion instance @@ -81,7 +81,7 @@ def add_trial(experiment: int, status: str = None, value=10, **kwargs): base_trial.update(copy.deepcopy(kwargs)) base_trial["params"][0]["value"] = value - get_storage().register_trial(Trial(**base_trial)) + storage.register_trial(Trial(**base_trial)) def test_root_endpoint_not_supported(client): @@ -108,11 +108,13 @@ def test_trials_for_unknown_experiment(self, client): "description": 'Experiment "unknown-experiment" does not exist', } - def test_unknown_parameter(self, client): + def test_unknown_parameter(self, client, ephemeral_storage): """ Tests that if an unknown parameter is specified in the query string, an error is returned even if the experiment doesn't exist. """ + storage = ephemeral_storage.storage + expected_error_message = ( 'Parameter "unknown" is not supported. ' "Expected one of ['ancestors', 'status', 'version']." @@ -126,7 +128,7 @@ def test_unknown_parameter(self, client): "description": expected_error_message, } - add_experiment(name="a", version=1, _id=1) + add_experiment(storage, name="a", version=1, _id=1) response = client.simulate_get("/trials/a?unknown=true") @@ -136,15 +138,17 @@ def test_unknown_parameter(self, client): "description": expected_error_message, } - def test_trials_for_latest_version(self, client): + def test_trials_for_latest_version(self, client, ephemeral_storage): """Tests that it returns the trials of the latest version of the experiment""" - add_experiment(name="a", version=1, _id=1) - add_experiment(name="a", version=2, _id=2) + storage = ephemeral_storage.storage + + add_experiment(storage, name="a", version=1, _id=1) + add_experiment(storage, name="a", version=2, _id=2) - add_trial(experiment=1, id_override="00", value=10) - add_trial(experiment=2, id_override="01", value=10) - add_trial(experiment=1, id_override="02", value=11) - add_trial(experiment=2, id_override="03", value=12) + add_trial(storage, experiment=1, id_override="00", value=10) + add_trial(storage, experiment=2, id_override="01", value=10) + add_trial(storage, experiment=1, id_override="02", value=11) + add_trial(storage, experiment=2, id_override="03", value=12) response = client.simulate_get("/trials/a") @@ -154,15 +158,17 @@ def test_trials_for_latest_version(self, client): {"id": "36b6bb34f0a01764e1793fe2d4de9078"}, ] - def test_trials_for_specific_version(self, client): + def test_trials_for_specific_version(self, client, ephemeral_storage): """Tests specific version of experiment""" - add_experiment(name="a", version=1, _id=1) - add_experiment(name="a", version=2, _id=2) - add_experiment(name="a", version=3, _id=3) + storage = ephemeral_storage.storage - add_trial(experiment=1, id_override="00") - add_trial(experiment=2, id_override="01") - add_trial(experiment=3, id_override="02") + add_experiment(storage, name="a", version=1, _id=1) + add_experiment(storage, name="a", version=2, _id=2) + add_experiment(storage, name="a", version=3, _id=3) + + add_trial(storage, experiment=1, id_override="00") + add_trial(storage, experiment=2, id_override="01") + add_trial(storage, experiment=3, id_override="02") # Happy case response = client.simulate_get("/trials/a?version=2") @@ -179,16 +185,18 @@ def test_trials_for_specific_version(self, client): "description": 'Experiment "a" has no version "4"', } - def test_trials_for_all_versions(self, client): + def test_trials_for_all_versions(self, client, ephemeral_storage): """Tests that trials from all ancestors are shown""" - add_experiment(name="a", version=1, _id=1) - add_experiment(name="a", version=2, _id=2) - add_experiment(name="a", version=3, _id=3) + storage = ephemeral_storage.storage + + add_experiment(storage, name="a", version=1, _id=1) + add_experiment(storage, name="a", version=2, _id=2) + add_experiment(storage, name="a", version=3, _id=3) # Specify values to avoid duplicates - add_trial(experiment=1, id_override="00", value=1) - add_trial(experiment=2, id_override="01", value=2) - add_trial(experiment=3, id_override="02", value=3) + add_trial(storage, experiment=1, id_override="00", value=1) + add_trial(storage, experiment=2, id_override="01", value=2) + add_trial(storage, experiment=3, id_override="02", value=3) # Happy case default response = client.simulate_get("/trials/a?ancestors=true") @@ -216,9 +224,11 @@ def test_trials_for_all_versions(self, client): 'The value of the parameter must be "true" or "false".', } - def test_trials_by_status(self, client): + def test_trials_by_status(self, client, ephemeral_storage): """Tests that trials are returned""" - add_experiment(name="a", version=1, _id=1) + storage = ephemeral_storage.storage + + add_experiment(storage, name="a", version=1, _id=1) # There exist no trial of the given status in an empty experiment response = client.simulate_get("/trials/a?status=completed") @@ -227,14 +237,14 @@ def test_trials_by_status(self, client): assert response.json == [] # There exist no trial of the given status while other status are present - add_trial(experiment=1, id_override="00", status="broken", value=0) + add_trial(storage, experiment=1, id_override="00", status="broken", value=0) response = client.simulate_get("/trials/a?status=completed") assert response.status == "200 OK" assert response.json == [] # There exist at least one trial of the given status - add_trial(experiment=1, id_override="01", status="completed", value=1) + add_trial(storage, experiment=1, id_override="01", status="completed", value=1) response = client.simulate_get("/trials/a?status=completed") @@ -258,17 +268,21 @@ def test_trials_by_status(self, client): "'interrupted', 'broken']", } - def test_trials_by_from_specific_version_by_status_with_ancestors(self, client): + def test_trials_by_from_specific_version_by_status_with_ancestors( + self, client, ephemeral_storage + ): """Tests that mixing parameters work as intended""" - add_experiment(name="a", version=1, _id=1) - add_experiment(name="b", version=1, _id=2) - add_experiment(name="a", version=2, _id=3) - add_experiment(name="a", version=3, _id=4) + storage = ephemeral_storage.storage + + add_experiment(storage, name="a", version=1, _id=1) + add_experiment(storage, name="b", version=1, _id=2) + add_experiment(storage, name="a", version=2, _id=3) + add_experiment(storage, name="a", version=3, _id=4) - add_trial(experiment=1, id_override="00", value=1, status="completed") - add_trial(experiment=3, id_override="01", value=2, status="broken") - add_trial(experiment=3, id_override="02", value=3, status="completed") - add_trial(experiment=2, id_override="03", value=4, status="completed") + add_trial(storage, experiment=1, id_override="00", value=1, status="completed") + add_trial(storage, experiment=3, id_override="01", value=2, status="broken") + add_trial(storage, experiment=3, id_override="02", value=3, status="completed") + add_trial(storage, experiment=2, id_override="03", value=4, status="completed") response = client.simulate_get( "/trials/a?ancestors=true&version=2&status=completed" @@ -284,7 +298,7 @@ def test_trials_by_from_specific_version_by_status_with_ancestors(self, client): class TestTrialItem: """Tests trials/:experiment_name/:trial_id""" - def test_unknown_experiment(self, client): + def test_unknown_experiment(self, client, ephemeral_storage): """Tests that an unknown experiment returns a not found error""" response = client.simulate_get("/trials/unknown-experiment/a-trial") @@ -294,9 +308,11 @@ def test_unknown_experiment(self, client): "description": 'Experiment "unknown-experiment" does not exist', } - def test_unknown_trial(self, client): + def test_unknown_trial(self, client, ephemeral_storage): """Tests that an unknown experiment returns a not found error""" - add_experiment(name="a", version=1, _id=1) + storage = ephemeral_storage.storage + + add_experiment(storage, name="a", version=1, _id=1) response = client.simulate_get("/trials/a/unknown-trial") @@ -306,11 +322,13 @@ def test_unknown_trial(self, client): "description": 'Trial "unknown-trial" does not exist', } - def test_get_trial(self, client): + def test_get_trial(self, client, ephemeral_storage): """Tests that an existing trial is returned according to the API specification""" - add_experiment(name="a", version=1, _id=1) - add_trial(experiment=1, id_override="00", status="completed", value=0) - add_trial(experiment=1, id_override="01", status="completed", value=1) + storage = ephemeral_storage.storage + + add_experiment(storage, name="a", version=1, _id=1) + add_trial(storage, experiment=1, id_override="00", status="completed", value=0) + add_trial(storage, experiment=1, id_override="01", status="completed", value=1) response = client.simulate_get("/trials/a/79a873a1146cbdcc385f53e7c14f41aa") diff --git a/tests/stress/client/stress_experiment.py b/tests/stress/client/stress_experiment.py index d704a27f5..3433e3c40 100644 --- a/tests/stress/client/stress_experiment.py +++ b/tests/stress/client/stress_experiment.py @@ -12,7 +12,6 @@ from orion.client import create_experiment from orion.core.io.database import DatabaseTimeout from orion.core.utils.exceptions import ReservationTimeout -from orion.core.utils.singleton import update_singletons DB_FILE = "stress.pkl" @@ -143,7 +142,6 @@ def stress_test(storage, space_type, workers, size): database.workers.drop() database.resources.drop() client.close() - update_singletons() print("Worker | Point") @@ -183,7 +181,6 @@ def stress_test(storage, space_type, workers, size): database.workers.drop() database.resources.drop() client.close() - update_singletons() return trials diff --git a/tests/unittests/benchmark/test_assessments.py b/tests/unittests/benchmark/test_assessments.py index 685196a9b..4d09c804b 100644 --- a/tests/unittests/benchmark/test_assessments.py +++ b/tests/unittests/benchmark/test_assessments.py @@ -45,22 +45,22 @@ def test_analysis(self, experiment_config, trial_config): ) @pytest.mark.usefixtures("version_XYZ") - def test_figure_layout(self, study_experiments_config): + def test_figure_layout(self, orionstate, study_experiments_config): """Test assessment plot format""" ar1 = AverageRank() - with create_study_experiments(**study_experiments_config) as experiments: - figure = ar1.analysis("task_name", experiments) - - assert_rankings_plot( - figure["AverageRank"]["task_name"]["rankings"], - [ - list(algorithm["algorithm"].keys())[0] - for algorithm in study_experiments_config["algorithms"] - ], - balanced=study_experiments_config["max_trial"], - with_avg=True, - ) + experiments = create_study_experiments(orionstate, **study_experiments_config) + figure = ar1.analysis("task_name", experiments) + + assert_rankings_plot( + figure["AverageRank"]["task_name"]["rankings"], + [ + list(algorithm["algorithm"].keys())[0] + for algorithm in study_experiments_config["algorithms"] + ], + balanced=study_experiments_config["max_trial"], + with_avg=True, + ) class TestAverageResult: @@ -93,22 +93,22 @@ def test_analysis(self, experiment_config, trial_config): ) @pytest.mark.usefixtures("version_XYZ") - def test_figure_layout(self, study_experiments_config): + def test_figure_layout(self, orionstate, study_experiments_config): """Test assessment plot format""" ar1 = AverageResult() - with create_study_experiments(**study_experiments_config) as experiments: - figure = ar1.analysis("task_name", experiments) - - assert_regrets_plot( - figure["AverageResult"]["task_name"]["regrets"], - [ - list(algorithm["algorithm"].keys())[0] - for algorithm in study_experiments_config["algorithms"] - ], - balanced=study_experiments_config["max_trial"], - with_avg=True, - ) + experiments = create_study_experiments(orionstate, **study_experiments_config) + figure = ar1.analysis("task_name", experiments) + + assert_regrets_plot( + figure["AverageResult"]["task_name"]["regrets"], + [ + list(algorithm["algorithm"].keys())[0] + for algorithm in study_experiments_config["algorithms"] + ], + balanced=study_experiments_config["max_trial"], + with_avg=True, + ) class TestParallelAssessment: @@ -132,7 +132,7 @@ def test_creation(self): assert pa3.get_executor(2).n_workers == 4 @pytest.mark.usefixtures("version_XYZ") - def test_analysis(self, study_experiments_config): + def test_analysis(self, orionstate, study_experiments_config): """Test assessment plot format""" task_num = 2 n_workers = [1, 2, 4] @@ -140,32 +140,32 @@ def test_analysis(self, study_experiments_config): study_experiments_config["task_number"] = task_num study_experiments_config["n_workers"] = n_workers - with create_study_experiments(**study_experiments_config) as experiments: - figure = pa1.analysis("task_name", experiments) - - names = [] - algorithms = [] - for algorithm in study_experiments_config["algorithms"]: - algo = list(algorithm["algorithm"].keys())[0] - algorithms.append(algo) - - for worker in n_workers: - names.append(algo + "_workers_" + str(worker)) - - assert len(figure["ParallelAssessment"]["task_name"]) == 3 - assert_regrets_plot( - figure["ParallelAssessment"]["task_name"]["regrets"], - names, - balanced=study_experiments_config["max_trial"], - with_avg=True, - ) - - asset_parallel_assessment_plot( - figure["ParallelAssessment"]["task_name"]["parallel_assessment"], - algorithms, - 3, - ) - - assert_durations_plot( - figure["ParallelAssessment"]["task_name"]["durations"], names - ) + experiments = create_study_experiments(orionstate, **study_experiments_config) + figure = pa1.analysis("task_name", experiments) + + names = [] + algorithms = [] + for algorithm in study_experiments_config["algorithms"]: + algo = list(algorithm["algorithm"].keys())[0] + algorithms.append(algo) + + for worker in n_workers: + names.append(algo + "_workers_" + str(worker)) + + assert len(figure["ParallelAssessment"]["task_name"]) == 3 + assert_regrets_plot( + figure["ParallelAssessment"]["task_name"]["regrets"], + names, + balanced=study_experiments_config["max_trial"], + with_avg=True, + ) + + asset_parallel_assessment_plot( + figure["ParallelAssessment"]["task_name"]["parallel_assessment"], + algorithms, + 3, + ) + + assert_durations_plot( + figure["ParallelAssessment"]["task_name"]["durations"], names + ) diff --git a/tests/unittests/benchmark/test_benchmark.py b/tests/unittests/benchmark/test_benchmark.py index 1fed2695b..0bd7099a5 100644 --- a/tests/unittests/benchmark/test_benchmark.py +++ b/tests/unittests/benchmark/test_benchmark.py @@ -13,9 +13,10 @@ @pytest.fixture -def benchmark(benchmark_algorithms): +def benchmark(storage, benchmark_algorithms): """Return a benchmark instance""" return Benchmark( + storage, name="benchmark007", algorithms=benchmark_algorithms, targets=[ @@ -77,14 +78,12 @@ def test_setup_studies(self, benchmark): def test_process(self, benchmark, study): """Test to process a benchmark""" - with OrionState(): - study.setup_experiments() - benchmark.studies = [study] - benchmark.process() - name = "benchmark007_AverageResult_RosenBrock_0_0" - experiment = experiment_builder.build(name) - - assert experiment is not None + study.setup_experiments() + benchmark.studies = [study] + benchmark.process() + name = "benchmark007_AverageResult_RosenBrock_0_0" + experiment = experiment_builder.build(name) + assert experiment is not None @pytest.mark.usefixtures("version_XYZ") def test_status( @@ -94,91 +93,92 @@ def test_status( study_experiments_config, task_number, max_trial, + orionstate, ): """Test to get the status of a benchmark""" - with create_study_experiments(**study_experiments_config) as experiments: + experiments = create_study_experiments(orionstate, **study_experiments_config) - study.experiments_info = experiments + study.experiments_info = experiments - benchmark.studies = [study] + benchmark.studies = [study] - assert benchmark.status() == [ - { - "Algorithms": "random", - "Assessments": "AverageResult", - "Tasks": "RosenBrock", - "Total Experiments": task_number, - "Completed Experiments": task_number, - "Submitted Trials": task_number * max_trial, - }, - { - "Algorithms": "tpe", - "Assessments": "AverageResult", - "Tasks": "RosenBrock", - "Total Experiments": task_number, - "Completed Experiments": task_number, - "Submitted Trials": task_number * max_trial, - }, - ] + assert benchmark.status() == [ + { + "Algorithms": "random", + "Assessments": "AverageResult", + "Tasks": "RosenBrock", + "Total Experiments": task_number, + "Completed Experiments": task_number, + "Submitted Trials": task_number * max_trial, + }, + { + "Algorithms": "tpe", + "Assessments": "AverageResult", + "Tasks": "RosenBrock", + "Total Experiments": task_number, + "Completed Experiments": task_number, + "Submitted Trials": task_number * max_trial, + }, + ] @pytest.mark.usefixtures("version_XYZ") - def test_analysis(self, benchmark, study, study_experiments_config): + def test_analysis(self, orionstate, benchmark, study, study_experiments_config): """Test to analysis benchmark result""" - with create_study_experiments(**study_experiments_config) as experiments: + experiments = create_study_experiments(orionstate, **study_experiments_config) + study.experiments_info = experiments - study.experiments_info = experiments + benchmark.studies = [study] - benchmark.studies = [study] + figures = benchmark.analysis() - figures = benchmark.analysis() - - assert len(figures) == 1 - assert ( - type(figures[study.assess_name][study.task_name]["regrets"]) - is plotly.graph_objects.Figure - ) + assert len(figures) == 1 + assert ( + type(figures[study.assess_name][study.task_name]["regrets"]) + is plotly.graph_objects.Figure + ) @pytest.mark.usefixtures("version_XYZ") def test_experiments( self, + orionstate, benchmark, study, study_experiments_config, max_trial, ): """Test to get experiments list of a benchmark""" - with create_study_experiments(**study_experiments_config) as experiments: + experiments = create_study_experiments(orionstate, **study_experiments_config) - study.experiments_info = experiments + study.experiments_info = experiments - benchmark.studies = [study] + benchmark.studies = [study] - assert benchmark.experiments() == [ - { - "Algorithm": "random", - "Experiment Name": "experiment-name-0", - "Number Trial": max_trial, - "Best Evaluation": 0, - }, - { - "Algorithm": "tpe", - "Experiment Name": "experiment-name-1", - "Number Trial": max_trial, - "Best Evaluation": 0, - }, - { - "Algorithm": "random", - "Experiment Name": "experiment-name-2", - "Number Trial": max_trial, - "Best Evaluation": 0, - }, - { - "Algorithm": "tpe", - "Experiment Name": "experiment-name-3", - "Number Trial": max_trial, - "Best Evaluation": 0, - }, - ] + assert benchmark.experiments() == [ + { + "Algorithm": "random", + "Experiment Name": "experiment-name-0", + "Number Trial": max_trial, + "Best Evaluation": 0, + }, + { + "Algorithm": "tpe", + "Experiment Name": "experiment-name-1", + "Number Trial": max_trial, + "Best Evaluation": 0, + }, + { + "Algorithm": "random", + "Experiment Name": "experiment-name-2", + "Number Trial": max_trial, + "Best Evaluation": 0, + }, + { + "Algorithm": "tpe", + "Experiment Name": "experiment-name-3", + "Number Trial": max_trial, + "Best Evaluation": 0, + }, + ] class TestStudy: @@ -221,84 +221,84 @@ def test_creation_algorithms(self, benchmark): def test_setup_experiments(self, study): """Test to setup experiments for study""" - with OrionState(): - study.setup_experiments() + study.setup_experiments() - assert len(study.experiments_info) == 4 - assert isinstance(study.experiments_info[0][1], ExperimentClient) + assert len(study.experiments_info) == 4 + assert isinstance(study.experiments_info[0][1], ExperimentClient) def test_execute(self, study): """Test to execute a study""" - with OrionState(): - study.setup_experiments() - study.execute() - name = "benchmark007_AverageResult_RosenBrock_0_0" - experiment = experiment_builder.build(name) + study.setup_experiments() + study.execute() + name = "benchmark007_AverageResult_RosenBrock_0_0" + experiment = experiment_builder.build(name) - assert len(experiment.fetch_trials()) == study.task.max_trials + assert len(experiment.fetch_trials()) == study.task.max_trials - assert experiment is not None + assert experiment is not None @pytest.mark.usefixtures("version_XYZ") def test_status( self, + orionstate, study, study_experiments_config, task_number, max_trial, ): """Test to get status of a study""" - with create_study_experiments(**study_experiments_config) as experiments: + experiments = create_study_experiments(orionstate, **study_experiments_config) - study.experiments_info = experiments + study.experiments_info = experiments - assert study.status() == [ - { - "algorithm": "random", - "assessment": "AverageResult", - "task": "RosenBrock", - "experiments": task_number, - "completed": task_number, - "trials": task_number * max_trial, - }, - { - "algorithm": "tpe", - "assessment": "AverageResult", - "task": "RosenBrock", - "experiments": task_number, - "completed": task_number, - "trials": task_number * max_trial, - }, - ] + assert study.status() == [ + { + "algorithm": "random", + "assessment": "AverageResult", + "task": "RosenBrock", + "experiments": task_number, + "completed": task_number, + "trials": task_number * max_trial, + }, + { + "algorithm": "tpe", + "assessment": "AverageResult", + "task": "RosenBrock", + "experiments": task_number, + "completed": task_number, + "trials": task_number * max_trial, + }, + ] @pytest.mark.usefixtures("version_XYZ") def test_analysis( self, + orionstate, study, study_experiments_config, ): """Test to get the ploty figure of a study""" - with create_study_experiments(**study_experiments_config) as experiments: + experiments = create_study_experiments(orionstate, **study_experiments_config) - study.experiments_info = experiments + study.experiments_info = experiments - figure = study.analysis() + figure = study.analysis() - assert ( - type(figure[study.assess_name][study.task_name]["regrets"]) - is plotly.graph_objects.Figure - ) + assert ( + type(figure[study.assess_name][study.task_name]["regrets"]) + is plotly.graph_objects.Figure + ) - def test_experiments(self, study, study_experiments_config, task_number): + def test_experiments( + self, orionstate, study, study_experiments_config, task_number + ): """Test to get experiments of a study""" algo_num = len(study_experiments_config["algorithms"]) - with create_study_experiments(**study_experiments_config) as experiments: + experiments = create_study_experiments(orionstate, **study_experiments_config) - study.experiments_info = experiments + study.experiments_info = experiments - experiments = study.experiments() + experiments = study.experiments() - assert ( - len(experiments) == study_experiments_config["task_number"] * algo_num - ) - assert isinstance(experiments[0], ExperimentClient) + assert len(experiments) == study_experiments_config["task_number"] * algo_num + assert isinstance(experiments[0], ExperimentClient) diff --git a/tests/unittests/benchmark/test_benchmark_client.py b/tests/unittests/benchmark/test_benchmark_client.py index df3635f39..ee0cfd270 100644 --- a/tests/unittests/benchmark/test_benchmark_client.py +++ b/tests/unittests/benchmark/test_benchmark_client.py @@ -11,13 +11,9 @@ from orion.benchmark.benchmark_client import get_or_create_benchmark from orion.benchmark.task import CarromTable, RosenBrock from orion.client import ExperimentClient -from orion.core.io.database.ephemeraldb import EphemeralDB -from orion.core.io.database.pickleddb import PickledDB from orion.core.utils.exceptions import NoConfigurationError -from orion.core.utils.singleton import SingletonNotInstantiatedError, update_singletons from orion.executor.joblib_backend import Joblib -from orion.storage.base import get_storage -from orion.storage.legacy import Legacy +from orion.storage.base import setup_storage from orion.testing.state import OrionState @@ -31,100 +27,25 @@ class DummyAssess: def count_benchmarks(): """Count experiments in storage""" - return len(get_storage().fetch_benchmark({})) + return len(setup_storage().fetch_benchmark({})) -class TestCreateBenchmark: - """Test Benchmark creation""" - - @pytest.mark.usefixtures("setup_pickleddb_database") - def test_create_benchmark_no_storage(self, benchmark_config_py): - """Test creation if storage is not configured""" - name = "oopsie_forgot_a_storage" - host = orion.core.config.storage.database.host - - with OrionState(storage=orion.core.config.storage.to_dict()) as cfg: - # Reset the Storage and drop instances so that get_storage() would fail. - cfg.cleanup() - cfg.singletons = update_singletons() - - # Make sure storage must be instantiated during `get_or_create_benchmark()` - with pytest.raises(SingletonNotInstantiatedError): - get_storage() - - get_or_create_benchmark(**benchmark_config_py).close() - - storage = get_storage() - - assert isinstance(storage, Legacy) - assert isinstance(storage._db, PickledDB) - assert storage._db.host == host - - def test_create_benchmark_with_storage(self, benchmark_config_py): - """Test benchmark instance has the storage configurations""" - - config = copy.deepcopy(benchmark_config_py) - storage = {"type": "legacy", "database": {"type": "EphemeralDB"}} - with OrionState(storage=storage): - config["storage"] = storage - bm = get_or_create_benchmark(**config) - bm.close() - - assert bm.storage_config == config["storage"] - - def test_create_benchmark_bad_storage(self, benchmark_config_py): - """Test error message if storage is not configured properly""" - name = "oopsie_bad_storage" - # Make sure there is no existing storage singleton - update_singletons() +storage_instance = "" - with pytest.raises(NotImplementedError) as exc: - benchmark_config_py["storage"] = { - "type": "legacy", - "database": {"type": "idontexist"}, - } - get_or_create_benchmark(**benchmark_config_py).close() - assert "Could not find implementation of Database, type = 'idontexist'" in str( - exc.value - ) - - def test_create_experiment_debug_mode(self, tmp_path, benchmark_config_py): - """Test that EphemeralDB is used in debug mode whatever the storage config given""" - update_singletons() - - conf_file = str(tmp_path / "db.pkl") - - config = copy.deepcopy(benchmark_config_py) - config["storage"] = { - "type": "legacy", - "database": {"type": "pickleddb", "host": conf_file}, - } - - get_or_create_benchmark(**config).close() - - storage = get_storage() - - assert isinstance(storage, Legacy) - assert isinstance(storage._db, PickledDB) - - update_singletons() - config["storage"] = {"type": "legacy", "database": {"type": "pickleddb"}} - config["debug"] = True - get_or_create_benchmark(**config).close() - - storage = get_storage() - - assert isinstance(storage, Legacy) - assert isinstance(storage._db, EphemeralDB) +class TestCreateBenchmark: + """Test Benchmark creation""" def test_create_benchmark(self, benchmark_config, benchmark_config_py): """Test creation with valid configuration""" - with OrionState(): - bm1 = get_or_create_benchmark(**benchmark_config_py) + with OrionState() as cfg: + bm1 = get_or_create_benchmark( + cfg.storage, + **benchmark_config_py, + ) bm1.close() - bm2 = get_or_create_benchmark("bm00001") + bm2 = get_or_create_benchmark(cfg.storage, "bm00001") bm2.close() assert bm1.configuration == benchmark_config @@ -133,18 +54,18 @@ def test_create_benchmark(self, benchmark_config, benchmark_config_py): def test_create_with_only_name(self): """Test creation with a non-existing benchmark name""" - with OrionState(): + with OrionState() as cfg: name = "bm00001" with pytest.raises(NoConfigurationError) as exc: - get_or_create_benchmark(name).close() + get_or_create_benchmark(cfg.storage, name).close() assert f"Benchmark {name} does not exist in DB" in str(exc.value) def test_create_with_different_configure(self, benchmark_config_py, caplog): """Test creation with same name but different configure""" - with OrionState(): + with OrionState() as cfg: config = copy.deepcopy(benchmark_config_py) - bm1 = get_or_create_benchmark(**config) + bm1 = get_or_create_benchmark(cfg.storage, **config) bm1.close() config = copy.deepcopy(benchmark_config_py) @@ -153,7 +74,7 @@ def test_create_with_different_configure(self, benchmark_config_py, caplog): with caplog.at_level( logging.WARNING, logger="orion.benchmark.benchmark_client" ): - bm2 = get_or_create_benchmark(**config) + bm2 = get_or_create_benchmark(cfg.storage, **config) bm2.close() assert bm2.configuration == bm1.configuration @@ -168,7 +89,7 @@ def test_create_with_different_configure(self, benchmark_config_py, caplog): with caplog.at_level( logging.WARNING, logger="orion.benchmark.benchmark_client" ): - bm3 = get_or_create_benchmark(**config) + bm3 = get_or_create_benchmark(cfg.storage, **config) bm3.close() assert bm3.configuration == bm1.configuration @@ -179,7 +100,7 @@ def test_create_with_different_configure(self, benchmark_config_py, caplog): def test_create_with_invalid_algorithms(self, benchmark_config_py): """Test creation with a not existed algorithm""" - with OrionState(): + with OrionState() as cfg: with pytest.raises(NotImplementedError) as exc: benchmark_config_py["algorithms"] = [ @@ -187,7 +108,9 @@ def test_create_with_invalid_algorithms(self, benchmark_config_py): ] # Pass executor to close it properly with Joblib(n_workers=2, backend="threading") as executor: - get_or_create_benchmark(**benchmark_config_py, executor=executor) + get_or_create_benchmark( + cfg.storage, **benchmark_config_py, executor=executor + ) assert "Could not find implementation of BaseAlgorithm" in str(exc.value) def test_create_with_deterministic_algorithm(self, benchmark_config_py): @@ -195,10 +118,10 @@ def test_create_with_deterministic_algorithm(self, benchmark_config_py): {"algorithm": {"random": {"seed": 1}}}, {"algorithm": {"gridsearch": {"n_values": 50}}, "deterministic": True}, ] - with OrionState(): + with OrionState() as cfg: config = copy.deepcopy(benchmark_config_py) config["algorithms"] = algorithms - bm = get_or_create_benchmark(**config) + bm = get_or_create_benchmark(cfg.storage, **config) bm.close() for study in bm.studies: @@ -211,14 +134,14 @@ def test_create_with_deterministic_algorithm(self, benchmark_config_py): def test_create_with_invalid_targets(self, benchmark_config_py): """Test creation with invalid Task and Assessment""" - with OrionState(): + with OrionState() as cfg: with pytest.raises(AttributeError) as exc: config = copy.deepcopy(benchmark_config_py) config["targets"] = [ {"assess": [AverageResult(2)], "task": [DummyTask]} ] - get_or_create_benchmark(**config).close() + get_or_create_benchmark(cfg.storage, **config).close() assert "type object '{}' has no attribute ".format("DummyTask") in str( exc.value @@ -229,7 +152,7 @@ def test_create_with_invalid_targets(self, benchmark_config_py): config["targets"] = [ {"assess": [DummyAssess], "task": [RosenBrock(25, dim=3)]} ] - get_or_create_benchmark(**config).close() + get_or_create_benchmark(cfg.storage, **config).close() assert "type object '{}' has no attribute ".format("DummyAssess") in str( exc.value @@ -241,9 +164,12 @@ def test_create_with_not_loaded_targets(self, benchmark_config): cfg_invalid_assess = copy.deepcopy(benchmark_config) cfg_invalid_assess["targets"][0]["assess"]["idontexist"] = {"task_num": 2} - with OrionState(benchmarks=cfg_invalid_assess): + with OrionState(benchmarks=cfg_invalid_assess) as cfg: with pytest.raises(NotImplementedError) as exc: - get_or_create_benchmark(benchmark_config["name"]).close() + get_or_create_benchmark( + cfg.storage, + benchmark_config["name"], + ).close() assert "Could not find implementation of BenchmarkAssessment" in str( exc.value ) @@ -251,9 +177,12 @@ def test_create_with_not_loaded_targets(self, benchmark_config): cfg_invalid_task = copy.deepcopy(benchmark_config) cfg_invalid_task["targets"][0]["task"]["idontexist"] = {"max_trials": 2} - with OrionState(benchmarks=cfg_invalid_task): + with OrionState(benchmarks=cfg_invalid_task) as cfg: with pytest.raises(NotImplementedError) as exc: - get_or_create_benchmark(benchmark_config["name"]) + get_or_create_benchmark( + cfg.storage, + benchmark_config["name"], + ) assert "Could not find implementation of BenchmarkTask" in str(exc.value) def test_create_with_not_exist_targets_parameters(self, benchmark_config): @@ -264,17 +193,17 @@ def test_create_with_not_exist_targets_parameters(self, benchmark_config): "idontexist": 100, } - with OrionState(benchmarks=benchmark_config): + with OrionState(benchmarks=benchmark_config) as cfg: with pytest.raises(TypeError) as exc: - get_or_create_benchmark(benchmark_config["name"]) + get_or_create_benchmark(cfg.storage, benchmark_config["name"]) assert "__init__() got an unexpected keyword argument 'idontexist'" in str( exc.value ) def test_create_from_db_config(self, benchmark_config): """Test creation from existing db configubenchmark_configre""" - with OrionState(benchmarks=copy.deepcopy(benchmark_config)): - bm = get_or_create_benchmark(benchmark_config["name"]) + with OrionState(benchmarks=copy.deepcopy(benchmark_config)) as cfg: + bm = get_or_create_benchmark(cfg.storage, benchmark_config["name"]) bm.close() assert bm.configuration == benchmark_config @@ -282,7 +211,7 @@ def test_create_race_condition( self, benchmark_config, benchmark_config_py, monkeypatch, caplog ): """Test creation in race condition""" - with OrionState(benchmarks=benchmark_config): + with OrionState(benchmarks=benchmark_config) as cfg: def insert_race_condition(*args, **kwargs): if insert_race_condition.count == 0: @@ -302,7 +231,9 @@ def insert_race_condition(*args, **kwargs): with caplog.at_level( logging.INFO, logger="orion.benchmark.benchmark_client" ): - bm = benchmark_client.get_or_create_benchmark(**benchmark_config_py) + bm = benchmark_client.get_or_create_benchmark( + cfg.storage, **benchmark_config_py + ) bm.close() assert ( @@ -318,9 +249,9 @@ def insert_race_condition(*args, **kwargs): def test_create_with_executor(self, benchmark_config, benchmark_config_py): - with OrionState(): + with OrionState() as cfg: config = copy.deepcopy(benchmark_config_py) - bm1 = get_or_create_benchmark(**config) + bm1 = get_or_create_benchmark(cfg.storage, **config) bm1.close() assert bm1.configuration == benchmark_config @@ -328,7 +259,7 @@ def test_create_with_executor(self, benchmark_config, benchmark_config_py): with Joblib(n_workers=2, backend="threading") as executor: config["executor"] = executor - bm2 = get_or_create_benchmark(**config) + bm2 = get_or_create_benchmark(cfg.storage, **config) assert bm2.configuration == benchmark_config assert bm2.executor.n_workers == executor.n_workers @@ -366,7 +297,7 @@ def submit(*args, c=count, **kwargs): c.value += 1 return FakeFuture([dict(name="v", type="objective", value=1)]) - with OrionState(): + with OrionState() as cfg: config = copy.deepcopy(benchmark_config_py) with Joblib(n_workers=5, backend="threading") as executor: @@ -374,7 +305,7 @@ def submit(*args, c=count, **kwargs): monkeypatch.setattr(executor, "submit", submit) config["executor"] = executor - bm1 = get_or_create_benchmark(**config) + bm1 = get_or_create_benchmark(cfg.storage, **config) client = bm1.studies[0].experiments_info[0][1] count.value = 0 diff --git a/tests/unittests/client/test_client.py b/tests/unittests/client/test_client.py index c0dce6dba..689ba8aac 100644 --- a/tests/unittests/client/test_client.py +++ b/tests/unittests/client/test_client.py @@ -19,8 +19,7 @@ RaceCondition, UnsupportedOperation, ) -from orion.core.utils.singleton import SingletonNotInstantiatedError, update_singletons -from orion.storage.base import get_storage +from orion.storage.base import setup_storage from orion.storage.legacy import Legacy from orion.testing import OrionState @@ -147,22 +146,23 @@ def test_call_interface_twice(self, monkeypatch, data): class TestCreateExperiment: """Test creation of experiment with `client.create_experiment()`""" - @pytest.mark.usefixtures("setup_pickleddb_database") + @pytest.mark.usefixtures("orionstate") def test_create_experiment_no_storage(self, monkeypatch): """Test creation if storage is not configured""" name = "oopsie_forgot_a_storage" host = orion.core.config.storage.database.host with OrionState(storage=orion.core.config.storage.to_dict()) as cfg: - # Reset the Storage and drop instances so that get_storage() would fail. + # Reset the Storage and drop instances so that setup_storage() would fail. cfg.cleanup() - cfg.singletons = update_singletons() # Make sure storage must be instantiated during `create_experiment()` - with pytest.raises(SingletonNotInstantiatedError): - get_storage() + # with pytest.raises(SingletonNotInstantiatedError): + # setup_storage() - experiment = create_experiment(name=name, space={"x": "uniform(0, 10)"}) + experiment = create_experiment( + name=name, space={"x": "uniform(0, 10)"}, storage=cfg.storage_config + ) assert isinstance(experiment._experiment._storage, Legacy) assert isinstance(experiment._experiment._storage._db, PickledDB) @@ -170,10 +170,10 @@ def test_create_experiment_no_storage(self, monkeypatch): def test_create_experiment_new_no_space(self): """Test that new experiment needs space""" - with OrionState(): + with OrionState() as cfg: name = "oopsie_forgot_a_space" with pytest.raises(NoConfigurationError) as exc: - create_experiment(name=name) + create_experiment(name=name, storage=cfg.storage_config) assert f"Experiment {name} does not exist in DB" in str(exc.value) @@ -181,7 +181,6 @@ def test_create_experiment_bad_storage(self): """Test error message if storage is not configured properly""" name = "oopsie_bad_storage" # Make sure there is no existing storage singleton - update_singletons() with pytest.raises(NotImplementedError) as exc: create_experiment( @@ -197,8 +196,10 @@ def test_create_experiment_new_default(self): """Test creating a new experiment with all defaults""" name = "all_default" space = {"x": "uniform(0, 10)"} - with OrionState(): - experiment = create_experiment(name="all_default", space=space) + with OrionState() as cfg: + experiment = create_experiment( + name="all_default", space=space, storage=cfg.storage_config + ) assert experiment.name == name assert experiment.space.configuration == space @@ -210,8 +211,8 @@ def test_create_experiment_new_default(self): def test_create_experiment_new_full_config(self, user_config): """Test creating a new experiment by specifying all attributes.""" - with OrionState(): - experiment = create_experiment(**user_config) + with OrionState() as cfg: + experiment = create_experiment(**user_config, storage=cfg.storage_config) exp_config = experiment.configuration @@ -223,8 +224,8 @@ def test_create_experiment_new_full_config(self, user_config): def test_create_experiment_hit_no_branch(self, user_config): """Test creating an existing experiment by specifying all identical attributes.""" - with OrionState(experiments=[config]): - experiment = create_experiment(**user_config) + with OrionState(experiments=[config]) as cfg: + experiment = create_experiment(**user_config, storage=cfg.storage_config) exp_config = experiment.configuration @@ -238,8 +239,8 @@ def test_create_experiment_hit_no_branch(self, user_config): def test_create_experiment_hit_no_config(self): """Test creating an existing experiment by specifying the name only.""" - with OrionState(experiments=[config]): - experiment = create_experiment(config["name"]) + with OrionState(experiments=[config]) as cfg: + experiment = create_experiment(config["name"], storage=cfg.storage_config) assert experiment.name == config["name"] assert experiment.version == 1 @@ -251,11 +252,12 @@ def test_create_experiment_hit_no_config(self): def test_create_experiment_hit_branch(self): """Test creating a differing experiment that cause branching.""" - with OrionState(experiments=[config]): + with OrionState(experiments=[config]) as cfg: experiment = create_experiment( config["name"], space={"y": "uniform(0, 10)"}, branching={"enable": True}, + storage=cfg.storage_config, ) assert experiment.name == config["name"] @@ -272,12 +274,13 @@ def test_create_experiment_race_condition(self, monkeypatch): RaceCondition during registration is already handled by `build()`, therefore we will only test for race conditions during version update. """ - with OrionState(experiments=[config]): + with OrionState(experiments=[config]) as cfg: parent = create_experiment(config["name"]) child = create_experiment( config["name"], space={"y": "uniform(0, 10)"}, branching={"enable": True}, + storage=cfg.storage_config, ) def insert_race_condition(self, query): @@ -300,7 +303,7 @@ def insert_race_condition(self, query): insert_race_condition.count = 0 monkeypatch.setattr( - get_storage().__class__, "fetch_experiments", insert_race_condition + setup_storage().__class__, "fetch_experiments", insert_race_condition ) experiment = create_experiment( @@ -315,12 +318,13 @@ def insert_race_condition(self, query): def test_create_experiment_race_condition_broken(self, monkeypatch): """Test that two or more race condition leads to raise""" - with OrionState(experiments=[config]): + with OrionState(experiments=[config]) as cfg: parent = create_experiment(config["name"]) child = create_experiment( config["name"], space={"y": "uniform(0, 10)"}, branching={"enable": True}, + storage=cfg.storage_config, ) def insert_race_condition(self, query): @@ -341,7 +345,7 @@ def insert_race_condition(self, query): insert_race_condition.count = 0 monkeypatch.setattr( - get_storage().__class__, "fetch_experiments", insert_race_condition + setup_storage().__class__, "fetch_experiments", insert_race_condition ) with pytest.raises(RaceCondition) as exc: @@ -359,9 +363,12 @@ def insert_race_condition(self, query): def test_create_experiment_hit_manual_branch(self): """Test creating a differing experiment that cause branching.""" new_space = {"y": "uniform(0, 10)"} - with OrionState(experiments=[config]): + with OrionState(experiments=[config]) as cfg: create_experiment( - config["name"], space=new_space, branching={"enable": True} + config["name"], + space=new_space, + branching={"enable": True}, + storage=cfg.storage_config, ) with pytest.raises(BranchingEvent) as exc: @@ -376,11 +383,10 @@ def test_create_experiment_hit_manual_branch(self): def test_create_experiment_debug_mode(self, tmp_path): """Test that EphemeralDB is used in debug mode whatever the storage config given""" - update_singletons() conf_file = str(tmp_path / "db.pkl") - create_experiment( + experiment = create_experiment( config["name"], space={"x": "uniform(0, 10)"}, storage={ @@ -389,22 +395,18 @@ def test_create_experiment_debug_mode(self, tmp_path): }, ) - storage = get_storage() - + storage = experiment._experiment._storage assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - update_singletons() - - create_experiment( + experiment = create_experiment( config["name"], space={"x": "uniform(0, 10)"}, storage={"type": "legacy", "database": {"type": "pickleddb"}}, debug=True, ) - storage = get_storage() - + storage = experiment._experiment._storage assert isinstance(storage, Legacy) assert isinstance(storage._db, EphemeralDB) @@ -463,7 +465,6 @@ def build_fail(*args, **kwargs): monkeypatch.setattr("orion.core.io.experiment_builder.build", build_fail) # Flush storage singleton - update_singletons() with pytest.raises(RuntimeError) as exc: experiment = workon( @@ -472,23 +473,17 @@ def build_fail(*args, **kwargs): assert exc.match("You shall not build!") - # Verify that tmp storage was cleared - with pytest.raises(SingletonNotInstantiatedError): - get_storage() - # Now test with a prior storage with OrionState( storage={"type": "legacy", "database": {"type": "EphemeralDB"}} - ): - storage = get_storage() + ) as cfg: + storage = cfg.storage with pytest.raises(RuntimeError) as exc: workon(foo, space={"x": "uniform(0, 10)"}, max_trials=5, name="voici") assert exc.match("You shall not build!") - assert get_storage() is storage - def test_workon_twice(self): """Verify setting the each experiment has its own storage""" @@ -547,26 +542,28 @@ def test_experiment_do_not_exist(self): "no view can be created." == str(exception.value) ) - @pytest.mark.usefixtures("mock_database") - def test_experiment_exist(self): + def test_experiment_exist(self, mock_database): """ Tests that an instance of :class:`orion.client.experiment.ExperimentClient` is returned representing the latest version when none is given. """ - experiment = create_experiment("a", space={"x": "uniform(0, 10)"}) + experiment = create_experiment( + "a", space={"x": "uniform(0, 10)"}, storage=mock_database.storage + ) - experiment = get_experiment("a") + experiment = get_experiment("a", storage=mock_database.storage) assert experiment assert isinstance(experiment, ExperimentClient) assert experiment.mode == "r" - @pytest.mark.usefixtures("mock_database") - def test_version_do_not_exist(self, caplog): + def test_version_do_not_exist(self, caplog, mock_database): """Tests that a warning is printed when the experiment exist but the version doesn't""" - create_experiment("a", space={"x": "uniform(0, 10)"}) + create_experiment( + "a", space={"x": "uniform(0, 10)"}, storage=mock_database.storage + ) - experiment = get_experiment("a", 2) + experiment = get_experiment("a", 2, storage=mock_database.storage) assert experiment.version == 1 assert ( @@ -574,13 +571,14 @@ def test_version_do_not_exist(self, caplog): in caplog.text ) - @pytest.mark.usefixtures("mock_database") - def test_read_write_mode(self): + def test_read_write_mode(self, mock_database): """Tests that experiment can be created in write mode""" - experiment = create_experiment("a", space={"x": "uniform(0, 10)"}) + experiment = create_experiment( + "a", space={"x": "uniform(0, 10)"}, storage=mock_database.storage + ) assert experiment.mode == "x" - experiment = get_experiment("a", 2, mode="r") + experiment = get_experiment("a", 2, mode="r", storage=mock_database.storage) assert experiment.mode == "r" with pytest.raises(UnsupportedOperation) as exc: @@ -588,7 +586,7 @@ def test_read_write_mode(self): assert exc.match("ExperimentClient must have write rights to execute `insert()") - experiment = get_experiment("a", 2, mode="w") + experiment = get_experiment("a", 2, mode="w", storage=mock_database.storage) assert experiment.mode == "w" trial = experiment.insert({"x": 0}) diff --git a/tests/unittests/client/test_experiment_client.py b/tests/unittests/client/test_experiment_client.py index 0d7f91142..d8a242929 100644 --- a/tests/unittests/client/test_experiment_client.py +++ b/tests/unittests/client/test_experiment_client.py @@ -20,7 +20,7 @@ ) from orion.core.worker.trial import AlreadyReleased, Trial from orion.executor.base import ExecutorClosed, executor_factory -from orion.storage.base import get_storage +from orion.storage.base import setup_storage from orion.testing import create_experiment, mock_space_iterate config = dict( @@ -584,11 +584,11 @@ def test_completed_then_interrupted_trial(self): with client.suggest() as trial: assert trial.status == "reserved" assert trial.results == [] - assert get_storage().get_trial(trial).objective is None + assert setup_storage().get_trial(trial).objective is None client.observe( trial, [dict(name="objective", type="objective", value=101)] ) - assert get_storage().get_trial(trial).objective.value == 101 + assert setup_storage().get_trial(trial).objective.value == 101 assert trial.status == "completed" raise KeyboardInterrupt @@ -839,9 +839,9 @@ def test_observe(self): trial = Trial(**cfg.trials[1]) assert trial.results == [] client.reserve(trial) - assert get_storage().get_trial(trial).objective is None + assert setup_storage().get_trial(trial).objective is None client.observe(trial, [dict(name="objective", type="objective", value=101)]) - assert get_storage().get_trial(trial).objective.value == 101 + assert setup_storage().get_trial(trial).objective.value == 101 def test_observe_unreserved(self): """Verify that `observe()` will fail on non-reserved trials""" @@ -903,11 +903,11 @@ def test_observe_under_with(self): with client.suggest() as trial: assert trial.status == "reserved" assert trial.results == [] - assert get_storage().get_trial(trial).objective is None + assert setup_storage().get_trial(trial).objective is None client.observe( trial, [dict(name="objective", type="objective", value=101)] ) - assert get_storage().get_trial(trial).objective.value == 101 + assert setup_storage().get_trial(trial).objective.value == 101 assert trial.status == "completed" assert trial.status == "completed" # Still completed after __exit__ diff --git a/tests/unittests/core/cli/test_checks.py b/tests/unittests/core/cli/test_checks.py index 4f2fc6f14..8882d7d50 100644 --- a/tests/unittests/core/cli/test_checks.py +++ b/tests/unittests/core/cli/test_checks.py @@ -149,7 +149,7 @@ def mock_file_config(self): assert presence.db_config == config["database"] -@pytest.mark.usefixtures("null_db_instances", "setup_pickleddb_database") +@pytest.mark.usefixtures("null_db_instances", "orionstate") def test_creation_pass(presence, config): """Check if test passes with valid database configuration.""" presence.db_config = config["database"] @@ -162,7 +162,7 @@ def test_creation_pass(presence, config): assert creation.instance is not None -@pytest.mark.usefixtures("null_db_instances", "setup_pickleddb_database") +@pytest.mark.usefixtures("null_db_instances", "orionstate") def test_creation_fails(monkeypatch, presence, config): """Check if test fails when not connected.""" presence.db_config = config["database"] diff --git a/tests/unittests/core/database/test_mongodb.py b/tests/unittests/core/database/test_mongodb.py index 16c877629..7d3460acf 100644 --- a/tests/unittests/core/database/test_mongodb.py +++ b/tests/unittests/core/database/test_mongodb.py @@ -7,12 +7,7 @@ import pytest from pymongo import MongoClient -from orion.core.io.database import ( - Database, - DatabaseError, - DuplicateKeyError, - database_factory, -) +from orion.core.io.database import Database, DatabaseError, DuplicateKeyError from orion.core.io.database.mongodb import AUTH_FAILED_MESSAGES, MongoDB from .conftest import insert_test_collection @@ -226,21 +221,6 @@ def test_overwrite_partial_uri(self, monkeypatch): assert orion_db.password == "none" assert orion_db.name == "orion" - def test_singleton(self): - """Test that MongoDB class is a singleton.""" - orion_db = database_factory.create( - of_type="mongodb", - host="mongodb://localhost", - port=27017, - name="orion_test", - username="user", - password="pass", - ) - # reinit connection does not change anything - orion_db.initiate_connection() - orion_db.close_connection() - assert database_factory.create() is orion_db - def test_change_server_timeout(self): """Test that the server timeout is correctly changed.""" assert ( diff --git a/tests/unittests/core/evc/test_conflicts.py b/tests/unittests/core/evc/test_conflicts.py index b60eeafc4..8054cba94 100644 --- a/tests/unittests/core/evc/test_conflicts.py +++ b/tests/unittests/core/evc/test_conflicts.py @@ -450,72 +450,74 @@ def test_comparison_idem(self, yaml_config, script_path): assert list(conflict.ScriptConfigConflict.detect(old_config, new_config)) == [] -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") class TestExperimentNameConflict: """Tests methods related to experiment name conflicts""" - def test_try_resolve_twice(self, experiment_name_conflict): + def test_try_resolve_twice(self, experiment_name_conflict, storage): """Verify that conflict cannot be resolved twice""" assert not experiment_name_conflict.is_resolved assert isinstance( - experiment_name_conflict.try_resolve("dummy"), + experiment_name_conflict.try_resolve("dummy", storage=storage), experiment_name_conflict.ExperimentNameResolution, ) assert experiment_name_conflict.is_resolved - assert experiment_name_conflict.try_resolve() is None + assert experiment_name_conflict.try_resolve(storage=storage) is None - def test_try_resolve(self, experiment_name_conflict): + def test_try_resolve(self, experiment_name_conflict, storage): """Verify that resolution is achievable with a valid name""" new_name = "dummy" assert not experiment_name_conflict.is_resolved - resolution = experiment_name_conflict.try_resolve(new_name) + resolution = experiment_name_conflict.try_resolve(new_name, storage=storage) assert isinstance(resolution, experiment_name_conflict.ExperimentNameResolution) assert experiment_name_conflict.is_resolved assert resolution.conflict is experiment_name_conflict assert resolution.new_name == new_name - def test_branch_w_existing_exp(self, existing_exp_conflict): + def test_branch_w_existing_exp(self, existing_exp_conflict, storage): """Test branching when an existing experiment with the new name already exists""" with pytest.raises(ValueError) as exc: - existing_exp_conflict.try_resolve("dummy") + existing_exp_conflict.try_resolve("dummy", storage=storage) assert "Cannot" in str(exc.value) - def test_conflict_exp_no_child(self, exp_no_child_conflict): + def test_conflict_exp_no_child(self, exp_no_child_conflict, storage): """Verify the version number is incremented when exp has no child.""" new_name = "test" assert not exp_no_child_conflict.is_resolved - resolution = exp_no_child_conflict.try_resolve(new_name) + resolution = exp_no_child_conflict.try_resolve(new_name, storage=storage) assert isinstance(resolution, exp_no_child_conflict.ExperimentNameResolution) assert exp_no_child_conflict.is_resolved assert resolution.conflict is exp_no_child_conflict assert resolution.old_version == 1 assert resolution.new_version == 2 - def test_conflict_exp_w_child(self, exp_w_child_conflict): + def test_conflict_exp_w_child(self, exp_w_child_conflict, storage): """Verify the version number is incremented from child when exp has a child.""" new_name = "test" assert not exp_w_child_conflict.is_resolved - resolution = exp_w_child_conflict.try_resolve(new_name) + resolution = exp_w_child_conflict.try_resolve(new_name, storage=storage) assert isinstance(resolution, exp_w_child_conflict.ExperimentNameResolution) assert exp_w_child_conflict.is_resolved assert resolution.conflict is exp_w_child_conflict assert resolution.new_version == 3 - def test_conflict_exp_w_child_as_parent(self, exp_w_child_as_parent_conflict): + def test_conflict_exp_w_child_as_parent( + self, exp_w_child_as_parent_conflict, storage + ): """Verify that an error is raised when trying to branch from parent.""" new_name = "test" with pytest.raises(ValueError) as exc: - exp_w_child_as_parent_conflict.try_resolve(new_name) + exp_w_child_as_parent_conflict.try_resolve(new_name, storage=storage) assert "Experiment name" in str(exc.value) - def test_conflict_exp_renamed(self, exp_w_child_conflict): + def test_conflict_exp_renamed(self, exp_w_child_conflict, storage): """Verify the version number is not incremented when exp is renamed.""" # It increments from child new_name = "test2" assert not exp_w_child_conflict.is_resolved - resolution = exp_w_child_conflict.try_resolve(new_name) + resolution = exp_w_child_conflict.try_resolve(new_name, storage=storage) assert isinstance(resolution, exp_w_child_conflict.ExperimentNameResolution) assert exp_w_child_conflict.is_resolved assert resolution.conflict is exp_w_child_conflict diff --git a/tests/unittests/core/evc/test_resolutions.py b/tests/unittests/core/evc/test_resolutions.py index 6d2c34464..66a45d368 100644 --- a/tests/unittests/core/evc/test_resolutions.py +++ b/tests/unittests/core/evc/test_resolutions.py @@ -56,10 +56,12 @@ def code_resolution(code_conflict): @pytest.fixture -def experiment_name_resolution(setup_pickleddb_database, experiment_name_conflict): +def experiment_name_resolution(orionstate, experiment_name_conflict): """Create a resolution for a code conflict""" return experiment_name_conflict.ExperimentNameResolution( - experiment_name_conflict, new_name="new-exp-name" + experiment_name_conflict, + new_name="new-exp-name", + storage=orionstate.storage, ) diff --git a/tests/unittests/core/io/conftest.py b/tests/unittests/core/io/conftest.py index 40e2010cf..0a9394c32 100644 --- a/tests/unittests/core/io/conftest.py +++ b/tests/unittests/core/io/conftest.py @@ -4,40 +4,65 @@ import copy import os +import tempfile import pytest from orion.core.evc import conflicts +class _GenerateConfig: + def __init__(self, name) -> None: + self.name = name + self.generated_config = None + self.database_file = None + + def __enter__(self): + file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), self.name) + + self.generated_config = tempfile.NamedTemporaryFile() + self.database_file = tempfile.NamedTemporaryFile() + + with open(file_path) as config: + new_config = config.read().replace("${FILE}", self.database_file.name) + self.generated_config.write(new_config.encode("utf-8")) + self.generated_config.flush() + + return open(self.generated_config.name) + + def __exit__(self, *args, **kwargs): + self.database_file.close() + self.generated_config.close() + + @pytest.fixture() -def config_file(): +def raw_config(): """Open config file with new config""" file_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "orion_config.yaml" ) - return open(file_path) +@pytest.fixture() +def config_file(): + """Open config file with new config""" + with _GenerateConfig("orion_config.yaml") as file: + yield file + + @pytest.fixture() def old_config_file(): """Open config file with original config from an experiment in db""" - file_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "orion_old_config.yaml" - ) - - return open(file_path) + with _GenerateConfig("orion_old_config.yaml") as file: + yield file @pytest.fixture() def incomplete_config_file(): """Open config file with partial database configuration""" - file_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "orion_incomplete_config.yaml" - ) - - return open(file_path) + with _GenerateConfig("orion_incomplete_config.yaml") as file: + yield file @pytest.fixture diff --git a/tests/unittests/core/io/database_test.py b/tests/unittests/core/io/database_test.py index 73ac58a12..8c404e156 100644 --- a/tests/unittests/core/io/database_test.py +++ b/tests/unittests/core/io/database_test.py @@ -4,12 +4,7 @@ import pytest from orion.core.io.database import ReadOnlyDB, database_factory -from orion.core.io.database.pickleddb import PickledDB -from orion.core.utils.singleton import ( - SingletonAlreadyInstantiatedError, - SingletonNotInstantiatedError, -) -from orion.storage.base import get_storage +from orion.storage.base import setup_storage @pytest.mark.usefixtures("null_db_instances") @@ -21,17 +16,6 @@ class TestDatabaseFactory: """ - def test_empty_first_call(self): - """Should not be able to make first call without any arguments. - - Hegelian Ontology Primer - ------------------------ - - Type indeterminate <-> type abstracted from its property <-> No type - """ - with pytest.raises(SingletonNotInstantiatedError): - database_factory.create() - def test_notfound_type_first_call(self): """Raise when supplying not implemented wrapper name.""" with pytest.raises(NotImplementedError) as exc_info: @@ -39,16 +23,6 @@ def test_notfound_type_first_call(self): assert "Database" in str(exc_info.value) - def test_instantiation_and_singleton(self): - """Test create just one object, that object persists between calls.""" - database = database_factory.create(of_type="PickledDB", name="orion_test") - - assert isinstance(database, PickledDB) - assert database is database_factory.create() - - with pytest.raises(SingletonAlreadyInstantiatedError): - database_factory.create("fire", [], {"it_matters": "it's singleton"}) - @pytest.mark.usefixtures("null_db_instances") class TestReadOnlyDatabase: @@ -64,7 +38,7 @@ def test_valid_attributes(self, storage): def test_read(self, hacked_exp): """Test read is coherent from view and wrapped database.""" - database = get_storage()._db + database = setup_storage()._db readonly_database = ReadOnlyDB(database) args = { diff --git a/tests/unittests/core/io/interactive_commands/test_branching_prompt.py b/tests/unittests/core/io/interactive_commands/test_branching_prompt.py index a9f05d3d5..fb498128b 100644 --- a/tests/unittests/core/io/interactive_commands/test_branching_prompt.py +++ b/tests/unittests/core/io/interactive_commands/test_branching_prompt.py @@ -78,9 +78,9 @@ def conflicts( @pytest.fixture -def branch_builder(conflicts): +def branch_builder(storage, conflicts): """Generate the experiment branch builder""" - return ExperimentBranchBuilder(conflicts, manual_resolution=True) + return ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) @pytest.fixture diff --git a/tests/unittests/core/io/orion_config.yaml b/tests/unittests/core/io/orion_config.yaml index f53cbb172..a8c0dc1be 100644 --- a/tests/unittests/core/io/orion_config.yaml +++ b/tests/unittests/core/io/orion_config.yaml @@ -7,6 +7,6 @@ experiment: algorithms: 'random' database: - type: 'mongodb' + type: 'pickleddb' name: 'orion_test' - host: 'mongodb://user:pass@localhost' + host: '${FILE}' diff --git a/tests/unittests/core/io/orion_old_config.yaml b/tests/unittests/core/io/orion_old_config.yaml index c07499fe0..6b761970e 100644 --- a/tests/unittests/core/io/orion_old_config.yaml +++ b/tests/unittests/core/io/orion_old_config.yaml @@ -13,6 +13,6 @@ algorithms: seed: null database: - type: 'mongodb' + type: 'pickleddb' name: 'orion_test' - host: 'mongodb://user:pass@localhost' + host: '${FILE}' diff --git a/tests/unittests/core/io/test_config.py b/tests/unittests/core/io/test_config.py index 068b58322..1bbd38429 100644 --- a/tests/unittests/core/io/test_config.py +++ b/tests/unittests/core/io/test_config.py @@ -624,3 +624,116 @@ def test_get_deprecated_key_ignore_warning(caplog): assert config.get("option", deprecated="ignore") == "hello" assert caplog.record_tuples == [] + + +def test_from_dict_additional_keys_are_ignored(): + config = Configuration() + config.from_dict(dict(undefined="are ignored")) + assert config.to_dict() == dict() + + +def test_from_dict_simple_value(): + config = Configuration() + config.add_option( + "option", + option_type=str, + default="hello", + ) + + values = dict(option="123") + config.from_dict(values) + assert config.to_dict() == values + + +def test_from_dict_simple_value_value_error(): + config = Configuration() + config.add_option( + "option", + option_type=float, + default="hello", + ) + + values = dict(option="1232") + with pytest.raises(ValueError): + config.from_dict(values) + + +def test_from_dict_old_values_are_popped(): + config = Configuration() + config.add_option( + "option1", + option_type=str, + default="hello", + ) + + config.add_option( + "option2", + option_type=str, + default="hello", + ) + + default = config.to_dict() + + # Override everything + overrides = dict(option1="123", option2="123") + config.from_dict(overrides) + assert config.to_dict() == overrides + + # + overrides.pop("option2") + config.from_dict(overrides) + + # Get the expected value + default.update(overrides) + + assert config.to_dict() == default + + +def test_from_dict_nested_config(): + config = Configuration() + nested = Configuration() + nested.add_option( + "option1", + option_type=str, + default="hello", + ) + config.sub = nested + + values = dict(sub=dict(option1="123")) + config.from_dict(values) + assert config.to_dict() == values + + +def test_from_dict_nested_values_are_popped(): + config = Configuration() + nested = Configuration() + nested.add_option( + "option1", + option_type=str, + default="hello", + ) + nested.add_option( + "option2", + option_type=str, + default="hello", + ) + config.sub = nested + default = nested.to_dict() + + # Override everything + overrides = dict(option1="123", option2="123") + all = dict(sub=overrides) + + config.from_dict(all) + assert config.to_dict() == all + + # Remove one override + overrides.pop("option2") + partial = dict(sub=overrides) + + # Get the expected result by applying overrides to the default config + default.update(overrides) + expected = dict(sub=default) + + config.from_dict(partial) + assert config.to_dict() == expected diff --git a/tests/unittests/core/io/test_experiment_builder.py b/tests/unittests/core/io/test_experiment_builder.py index 85a4b4015..3983212e2 100644 --- a/tests/unittests/core/io/test_experiment_builder.py +++ b/tests/unittests/core/io/test_experiment_builder.py @@ -19,16 +19,15 @@ RaceCondition, UnsupportedOperation, ) -from orion.core.utils.singleton import update_singletons from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper -from orion.storage.base import get_storage +from orion.storage.base import setup_storage from orion.storage.legacy import Legacy from orion.testing import OrionState def count_experiments(): """Count experiments in storage""" - return len(get_storage().fetch_experiments({})) + return len(setup_storage().fetch_experiments({})) @pytest.fixture @@ -171,9 +170,9 @@ def child_version_config(parent_version_config): @pytest.mark.usefixtures("with_user_tsirif", "version_XYZ") -def test_get_cmd_config(config_file): +def test_get_cmd_config(raw_config): """Test local config (cmdconfig, cmdargs)""" - cmdargs = {"config": config_file} + cmdargs = {"config": raw_config} local_config = experiment_builder.get_cmd_config(cmdargs) assert local_config["algorithms"] == "random" @@ -182,9 +181,9 @@ def test_get_cmd_config(config_file): assert local_config["name"] == "voila_voici" assert local_config["storage"] == { "database": { - "host": "mongodb://user:pass@localhost", + "host": "${FILE}", "name": "orion_test", - "type": "mongodb", + "type": "pickleddb", } } assert local_config["metadata"] == {"orion_version": "XYZ", "user": "tsirif"} @@ -214,8 +213,10 @@ def test_get_cmd_config_from_incomplete_config(incomplete_config_file): def test_fetch_config_from_db_no_hit(): """Verify that fetch_config_from_db returns an empty dict when the experiment is not in db""" - with OrionState(experiments=[], trials=[]): - db_config = experiment_builder.fetch_config_from_db(name="supernaekei") + with OrionState(experiments=[], trials=[]) as cfg: + db_config = experiment_builder.ExperimentBuilder( + storage=cfg.storage_config + ).fetch_config_from_db(name="supernaekei") assert db_config == {} @@ -223,8 +224,10 @@ def test_fetch_config_from_db_no_hit(): @pytest.mark.usefixtures("with_user_tsirif") def test_fetch_config_from_db_hit(new_config): """Verify db config when experiment is in db""" - with OrionState(experiments=[new_config], trials=[]): - db_config = experiment_builder.fetch_config_from_db(name="supernaekei") + with OrionState(experiments=[new_config], trials=[]) as cfg: + db_config = experiment_builder.ExperimentBuilder( + storage=cfg.storage_config + ).fetch_config_from_db(name="supernaekei") assert db_config["name"] == new_config["name"] assert db_config["refers"] == new_config["refers"] @@ -249,11 +252,18 @@ def test_get_from_args_no_hit(config_file): @pytest.mark.usefixtures("with_user_tsirif") -def test_get_from_args_hit(config_file, random_dt, new_config): +def test_get_from_args_hit(monkeypatch, raw_config, random_dt, new_config): """Try building experiment view when in db""" - cmdargs = {"name": "supernaekei", "config": config_file} + cmdargs = {"name": "supernaekei", "config": raw_config} + + with OrionState(experiments=[new_config], trials=[]) as cfg: + # This is necessary because storage is instantiated inside + # `get_from_args` with its own config not the global config set by OrionState + def get_storage(*args, **kwargs): + return cfg.storage + + monkeypatch.setattr(experiment_builder, "setup_storage", get_storage) - with OrionState(experiments=[new_config], trials=[]): exp_view = experiment_builder.get_from_args(cmdargs) assert exp_view._id == new_config["_id"] @@ -266,7 +276,9 @@ def test_get_from_args_hit(config_file, random_dt, new_config): @pytest.mark.usefixtures("with_user_tsirif") -def test_get_from_args_hit_no_conf_file(config_file, random_dt, new_config): +def test_get_from_args_hit_no_conf_file( + monkeypatch, config_file, random_dt, new_config +): """Try building experiment view when in db, and local config file of user script does not exist """ @@ -277,6 +289,14 @@ def test_get_from_args_hit_no_conf_file(config_file, random_dt, new_config): ] with OrionState(experiments=[new_config], trials=[]) as cfg: + + # This is necessary because storage is instantiated inside + # `get_from_args` with its own config not the global config set by OrionState + def get_storage(*args, **kwargs): + return cfg.storage + + monkeypatch.setattr(experiment_builder, "setup_storage", get_storage) + exp_view = experiment_builder.get_from_args(cmdargs) assert exp_view._id == new_config["_id"] @@ -289,7 +309,9 @@ def test_get_from_args_hit_no_conf_file(config_file, random_dt, new_config): @pytest.mark.usefixtures("with_user_dendi") -def test_build_from_args_no_hit(config_file, random_dt, script_path, new_config): +def test_build_from_args_no_hit( + monkeypatch, config_file, random_dt, script_path, new_config +): """Try building experiment when not in db""" cmdargs = { "name": "supernaekei", @@ -297,13 +319,20 @@ def test_build_from_args_no_hit(config_file, random_dt, script_path, new_config) "user_args": [script_path, "x~uniform(0,10)"], } - with OrionState(experiments=[], trials=[]): + with OrionState(experiments=[], trials=[]) as cfg: with pytest.raises(NoConfigurationError) as exc_info: experiment_builder.get_from_args(cmdargs) assert "No experiment with given name 'supernaekei' and version '*'" in str( exc_info.value ) + # This is necessary because storage is instantiated inside + # `get_from_args` with its own config not the global config set by OrionState + def get_storage(*args, **kwargs): + return cfg.storage + + monkeypatch.setattr(experiment_builder, "setup_storage", get_storage) + exp = experiment_builder.build_from_args(cmdargs) assert exp.name == cmdargs["name"] @@ -324,7 +353,7 @@ def test_build_from_args_no_hit(config_file, random_dt, script_path, new_config) @pytest.mark.usefixtures( "version_XYZ", "with_user_tsirif", "mock_infer_versioning_metadata" ) -def test_build_from_args_hit(old_config_file, script_path, new_config): +def test_build_from_args_hit(monkeypatch, old_config_file, script_path, new_config): """Try building experiment when in db (no branch)""" cmdargs = { "name": "supernaekei", @@ -332,7 +361,15 @@ def test_build_from_args_hit(old_config_file, script_path, new_config): "user_args": [script_path, "--mini-batch~uniform(32, 256, discrete=True)"], } - with OrionState(experiments=[new_config], trials=[]): + with OrionState(experiments=[new_config], trials=[]) as cfg: + + # This is necessary because storage is instantiated inside + # `get_from_args` with its own config not the global config set by OrionState + def get_storage(*args, **kwargs): + return cfg.storage + + monkeypatch.setattr(experiment_builder, "setup_storage", get_storage) + # Test that experiment already exists experiment_builder.get_from_args(cmdargs) @@ -359,61 +396,60 @@ def test_build_from_args_force_user(new_config): assert exp_view.metadata["user"] == "tsirif" -@pytest.mark.usefixtures("setup_pickleddb_database") -def test_build_from_args_debug_mode(script_path): +def test_build_from_args_debug_mode(monkeypatch, script_path, storage): """Try building experiment in debug mode""" - update_singletons() - experiment_builder.build_from_args( + + experiment = experiment_builder.build_from_args( { "name": "whatever", "user_args": [script_path, "--mini-batch~uniform(32, 256)"], } ) - storage = get_storage() - + storage = experiment._storage assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - update_singletons() - - experiment_builder.build_from_args( + experiment = experiment_builder.build_from_args( { "name": "whatever", "user_args": [script_path, "--mini-batch~uniform(32, 256)"], "debug": True, } ) - storage = get_storage() + storage = experiment._storage assert isinstance(storage, Legacy) assert isinstance(storage._db, EphemeralDB) -@pytest.mark.usefixtures("setup_pickleddb_database") -def test_get_from_args_debug_mode(script_path): +storage_instance = "" + + +def test_get_from_args_debug_mode(monkeypatch, script_path, storage): """Try building experiment view in debug mode""" - update_singletons() + + old_factory = experiment_builder.setup_storage + + def retrieve_storage(*args, **kwargs): + global storage_instance + + storage_instance = old_factory(*args, **kwargs) + return storage_instance + + monkeypatch.setattr(experiment_builder, "setup_storage", retrieve_storage) # Can't build view if none exist. It's fine we only want to test the storage creation. with pytest.raises(NoConfigurationError): experiment_builder.get_from_args({"name": "whatever"}) - storage = get_storage() - - assert isinstance(storage, Legacy) - assert isinstance(storage._db, PickledDB) - - update_singletons() + assert isinstance(storage_instance._db, PickledDB) # Can't build view if none exist. It's fine we only want to test the storage creation. with pytest.raises(NoConfigurationError): experiment_builder.get_from_args({"name": "whatever", "debug": True}) - storage = get_storage() - - assert isinstance(storage, Legacy) - assert isinstance(storage._db, EphemeralDB) + assert isinstance(storage_instance._db, EphemeralDB) @pytest.mark.usefixtures("with_user_tsirif", "version_XYZ") @@ -508,13 +544,23 @@ def test_build_without_config_hit(python_api_config): @pytest.mark.usefixtures( "with_user_tsirif", "version_XYZ", "mock_infer_versioning_metadata" ) -def test_build_from_args_without_cmd(old_config_file, script_path, new_config): +def test_build_from_args_without_cmd( + monkeypatch, old_config_file, script_path, new_config +): """Try building experiment without commandline when in db (no branch)""" name = "supernaekei" cmdargs = {"name": name, "config": old_config_file} - with OrionState(experiments=[new_config], trials=[]): + with OrionState(experiments=[new_config], trials=[]) as cfg: + + # This is necessary because storage is instantiated inside + # `get_from_args` with its own config not the global config set by OrionState + def get_storage(*args, **kwargs): + return cfg.storage + + monkeypatch.setattr(experiment_builder, "setup_storage", get_storage) + # Test that experiment already exists (this should fail otherwise) experiment_builder.get_from_args(cmdargs) @@ -711,7 +757,7 @@ def test_good_set_before_init_no_hit(self, random_dt, new_config): with OrionState(experiments=[], trials=[]): exp = experiment_builder.build(**new_config) found_config = list( - get_storage().fetch_experiments( + setup_storage().fetch_experiments( {"name": "supernaekei", "metadata.user": "tsirif"} ) ) @@ -747,7 +793,7 @@ def test_working_dir_is_correctly_set(self, new_config): with OrionState(): new_config["working_dir"] = "./" exp = experiment_builder.build(**new_config) - storage = get_storage() + storage = setup_storage() found_config = list( storage.fetch_experiments( {"name": "supernaekei", "metadata.user": "tsirif"} @@ -762,7 +808,7 @@ def test_working_dir_works_when_db_absent(self, database, new_config): """Check if working_dir is correctly when absent from the database.""" with OrionState(experiments=[], trials=[]): exp = experiment_builder.build(**new_config) - storage = get_storage() + storage = setup_storage() found_config = list( storage.fetch_experiments( {"name": "supernaekei", "metadata.user": "tsirif"} @@ -844,7 +890,7 @@ def test_try_set_after_race_condition(self, new_config, monkeypatch): properly. The experiment which looses the race condition cannot be initialized and needs to be rebuilt. """ - with OrionState(experiments=[new_config], trials=[]): + with OrionState(experiments=[new_config], trials=[]) as cfg: experiment_count_before = count_experiments() def insert_race_condition(*args, **kwargs): @@ -859,11 +905,10 @@ def insert_race_condition(*args, **kwargs): insert_race_condition.count = 0 - monkeypatch.setattr( - experiment_builder, "fetch_config_from_db", insert_race_condition - ) + builder = experiment_builder.ExperimentBuilder(cfg.storage) + monkeypatch.setattr(builder, "fetch_config_from_db", insert_race_condition) - experiment_builder.build(**new_config) + builder.build(**new_config) assert experiment_count_before == count_experiments() @@ -998,7 +1043,7 @@ def insert_race_condition_1(self, query): insert_race_condition_1.count = 0 monkeypatch.setattr( - get_storage().__class__, "fetch_experiments", insert_race_condition_1 + setup_storage().__class__, "fetch_experiments", insert_race_condition_1 ) with pytest.raises(RaceCondition) as exc_info: @@ -1031,7 +1076,7 @@ def insert_race_condition_2(self, query): insert_race_condition_2.count = 0 monkeypatch.setattr( - get_storage().__class__, "fetch_experiments", insert_race_condition_2 + setup_storage().__class__, "fetch_experiments", insert_race_condition_2 ) with pytest.raises(RaceCondition) as exc_info: @@ -1093,7 +1138,7 @@ def insert_race_condition_1(self, query): insert_race_condition_1.count = 0 monkeypatch.setattr( - get_storage().__class__, "fetch_experiments", insert_race_condition_1 + setup_storage().__class__, "fetch_experiments", insert_race_condition_1 ) with pytest.raises(BranchingEvent) as exc_info: @@ -1125,7 +1170,7 @@ def insert_race_condition_2(self, query): insert_race_condition_2.count = 0 monkeypatch.setattr( - get_storage().__class__, "fetch_experiments", insert_race_condition_2 + setup_storage().__class__, "fetch_experiments", insert_race_condition_2 ) with pytest.raises(RaceCondition) as exc_info: diff --git a/tests/unittests/core/io/test_resolve_config.py b/tests/unittests/core/io/test_resolve_config.py index 223ceedc4..9501902c7 100644 --- a/tests/unittests/core/io/test_resolve_config.py +++ b/tests/unittests/core/io/test_resolve_config.py @@ -204,15 +204,15 @@ def test_fetch_config_no_hit(): assert config == {} -def test_fetch_config(config_file): +def test_fetch_config(raw_config): """Verify fetch_config returns valid dictionary""" - config = resolve_config.fetch_config({"config": config_file}) + config = resolve_config.fetch_config({"config": raw_config}) assert config.pop("storage") == { "database": { - "host": "mongodb://user:pass@localhost", + "host": "${FILE}", "name": "orion_test", - "type": "mongodb", + "type": "pickleddb", } } diff --git a/tests/unittests/core/test_branch_config.py b/tests/unittests/core/test_branch_config.py index d1be6c9d7..626e8ba1c 100644 --- a/tests/unittests/core/test_branch_config.py +++ b/tests/unittests/core/test_branch_config.py @@ -399,12 +399,14 @@ def test_cli_ignored_conflict(self, parent_config, changed_cli_config): class TestResolutions: """Test resolution of conflicts""" - def test_add_single_hit(self, parent_config, new_config_with_w): + def test_add_single_hit(self, storage, parent_config, new_config_with_w): """Test if adding a dimension only touches the correct status""" del new_config_with_w["metadata"]["user_args"][2] backward.populate_space(new_config_with_w) conflicts = detect_conflicts(parent_config, new_config_with_w) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) branch_builder.add_dimension("w_d") assert len(conflicts.get()) == 3 @@ -412,10 +414,12 @@ def test_add_single_hit(self, parent_config, new_config_with_w): assert conflicts.get([NewDimensionConflict])[0].is_resolved assert not conflicts.get([MissingDimensionConflict])[0].is_resolved - def test_add_new(self, parent_config, new_config_with_w): + def test_add_new(self, parent_config, new_config_with_w, storage): """Test if adding a new dimension solves the conflict""" conflicts = detect_conflicts(parent_config, new_config_with_w) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) branch_builder.add_dimension("w_d") assert len(conflicts.get()) == 2 @@ -426,10 +430,12 @@ def test_add_new(self, parent_config, new_config_with_w): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.AddDimensionResolution) - def test_add_changed(self, parent_config, changed_config): + def test_add_changed(self, parent_config, changed_config, storage): """Test if adding a changed dimension solves the conflict""" conflicts = detect_conflicts(parent_config, changed_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) branch_builder.add_dimension("y") assert len(conflicts.get()) == 2 @@ -440,10 +446,12 @@ def test_add_changed(self, parent_config, changed_config): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.ChangeDimensionResolution) - def test_remove_missing(self, parent_config, missing_config): + def test_remove_missing(self, parent_config, missing_config, storage): """Test if removing a missing dimension solves the conflict""" conflicts = detect_conflicts(parent_config, missing_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) branch_builder.remove_dimension("x") assert len(conflicts.get()) == 3 @@ -454,12 +462,14 @@ def test_remove_missing(self, parent_config, missing_config): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.RemoveDimensionResolution) - def test_rename_missing(self, parent_config, missing_config): + def test_rename_missing(self, parent_config, missing_config, storage): """Test if renaming a dimension to another solves both conflicts""" missing_config["metadata"]["user_args"].append("-w_d~uniform(0,1)") backward.populate_space(missing_config) conflicts = detect_conflicts(parent_config, missing_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) branch_builder.rename_dimension("x", "w_d") assert len(conflicts.get()) == 4 @@ -482,14 +492,16 @@ def test_rename_missing(self, parent_config, missing_config): == "/w_d" ) - def test_rename_missing_changed(self, parent_config, missing_config): + def test_rename_missing_changed(self, parent_config, missing_config, storage): """Test if renaming a dimension to another with different prior solves both conflicts but creates a new one which is not solved """ missing_config["metadata"]["user_args"].append("-w_d~normal(0,1)") backward.populate_space(missing_config) conflicts = detect_conflicts(parent_config, missing_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 4 @@ -516,10 +528,12 @@ def test_rename_missing_changed(self, parent_config, missing_config): == "/w_d" ) - def test_reset_dimension(self, parent_config, new_config_with_w): + def test_reset_dimension(self, parent_config, new_config_with_w, storage): """Test if resetting a dimension unsolves the conflict""" conflicts = detect_conflicts(parent_config, new_config_with_w) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) branch_builder.add_dimension("w_d") assert len(conflicts.get_resolved()) == 2 @@ -548,7 +562,9 @@ def test_name_experiment( storage.create_experiment(bad_exp_parent_config) storage.create_experiment(bad_exp_child_config) conflicts = detect_conflicts(bad_exp_parent_config, bad_exp_parent_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 1 assert len(conflicts.get_resolved()) == 0 @@ -562,7 +578,9 @@ def test_name_experiment( assert conflict.new_config["name"] == "test2" assert conflict.is_resolved - def test_bad_name_experiment(self, parent_config, child_config, monkeypatch): + def test_bad_name_experiment( + self, parent_config, child_config, monkeypatch, storage + ): """Test if changing the experiment names does not work for invalid name and revert to old one """ @@ -585,7 +603,9 @@ def _versions(self, *args, **kwargs): ) conflicts = detect_conflicts(parent_config, child_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 1 assert len(conflicts.get_resolved()) == 0 @@ -598,10 +618,12 @@ def _versions(self, *args, **kwargs): assert conflict.new_config["name"] == "test" assert not conflict.is_resolved - def test_algo_change(self, parent_config, changed_algo_config): + def test_algo_change(self, parent_config, changed_algo_config, storage): """Test if setting the algorithm conflict solves it""" conflicts = detect_conflicts(parent_config, changed_algo_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 @@ -616,10 +638,14 @@ def test_algo_change(self, parent_config, changed_algo_config): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.AlgorithmResolution) - def test_orion_version_change(self, parent_config, changed_orion_version_config): + def test_orion_version_change( + self, parent_config, changed_orion_version_config, storage + ): """Test if setting the orion version conflict solves it""" conflicts = detect_conflicts(parent_config, changed_orion_version_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 @@ -634,10 +660,12 @@ def test_orion_version_change(self, parent_config, changed_orion_version_config) assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.OrionVersionResolution) - def test_code_change(self, parent_config, changed_code_config): + def test_code_change(self, parent_config, changed_code_config, storage): """Test if giving a proper change-type solves the code conflict""" conflicts = detect_conflicts(parent_config, changed_code_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 @@ -651,10 +679,12 @@ def test_code_change(self, parent_config, changed_code_config): assert conflict.is_resolved assert isinstance(conflict, CodeConflict) - def test_bad_code_change(self, capsys, parent_config, changed_code_config): + def test_bad_code_change(self, capsys, parent_config, changed_code_config, storage): """Test if giving an invalid change-type prints error message and do nothing""" conflicts = detect_conflicts(parent_config, changed_code_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) capsys.readouterr() branch_builder.set_code_change_type("bad-type") out, err = capsys.readouterr() @@ -663,10 +693,12 @@ def test_bad_code_change(self, capsys, parent_config, changed_code_config): assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 - def test_config_change(self, parent_config, changed_userconfig_config): + def test_config_change(self, parent_config, changed_userconfig_config, storage): """Test if giving a proper change-type solves the user script config conflict""" conflicts = detect_conflicts(parent_config, changed_userconfig_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 4 assert len(conflicts.get_resolved()) == 1 @@ -682,10 +714,14 @@ def test_config_change(self, parent_config, changed_userconfig_config): assert conflict.is_resolved assert isinstance(conflict, ScriptConfigConflict) - def test_bad_config_change(self, capsys, parent_config, changed_userconfig_config): + def test_bad_config_change( + self, capsys, parent_config, changed_userconfig_config, storage + ): """Test if giving an invalid change-type prints error message and do nothing""" conflicts = detect_conflicts(parent_config, changed_userconfig_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) capsys.readouterr() branch_builder.set_script_config_change_type("bad-type") out, err = capsys.readouterr() @@ -694,10 +730,12 @@ def test_bad_config_change(self, capsys, parent_config, changed_userconfig_confi assert len(conflicts.get()) == 4 assert len(conflicts.get_resolved()) == 1 - def test_cli_change(self, parent_config, changed_cli_config): + def test_cli_change(self, parent_config, changed_cli_config, storage): """Test if giving a proper change-type solves the command line conflict""" conflicts = detect_conflicts(parent_config, changed_cli_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 @@ -711,10 +749,12 @@ def test_cli_change(self, parent_config, changed_cli_config): assert conflict.is_resolved assert isinstance(conflict, CommandLineConflict) - def test_bad_cli_change(self, capsys, parent_config, changed_cli_config): + def test_bad_cli_change(self, capsys, parent_config, changed_cli_config, storage): """Test if giving an invalid change-type prints error message and do nothing""" conflicts = detect_conflicts(parent_config, changed_cli_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) capsys.readouterr() branch_builder.set_cli_change_type("bad-type") out, err = capsys.readouterr() @@ -723,9 +763,9 @@ def test_bad_cli_change(self, capsys, parent_config, changed_cli_config): assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 - def test_solve_all_automatically(self, conflicts): + def test_solve_all_automatically(self, conflicts, storage): """Test if all conflicts all automatically resolve by the ExperimentBranchBuilder.""" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get_resolved()) == 9 @@ -733,11 +773,11 @@ def test_solve_all_automatically(self, conflicts): class TestResolutionsWithMarkers: """Test resolution of conflicts with markers""" - def test_add_new(self, parent_config, new_config_with_w): + def test_add_new(self, parent_config, new_config_with_w, storage): """Test if new dimension conflict is automatically resolved""" new_config_with_w["metadata"]["user_args"][-1] = "-w_d~+normal(0,1)" conflicts = detect_conflicts(parent_config, new_config_with_w) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -747,14 +787,14 @@ def test_add_new(self, parent_config, new_config_with_w): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.AddDimensionResolution) - def test_add_new_default(self, parent_config, new_config_with_w): + def test_add_new_default(self, parent_config, new_config_with_w, storage): """Test if new dimension conflict is automatically resolved""" new_config_with_w["metadata"]["user_args"][ -1 ] = "-w_d~+normal(0,1,default_value=0)" backward.populate_space(new_config_with_w) conflicts = detect_conflicts(parent_config, new_config_with_w) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -775,13 +815,13 @@ def test_add_bad_default(self, parent_config, new_config_with_w): detect_conflicts(parent_config, new_config_with_w) assert "Parameter '/w_d': Incorrect arguments." in str(exc.value) - def test_add_changed(self, parent_config, changed_config): + def test_add_changed(self, parent_config, changed_config, storage): """Test if changed dimension conflict is automatically resolved""" changed_config["metadata"]["user_args"][3] = changed_config["metadata"][ "user_args" ][3].replace("~", "~+") conflicts = detect_conflicts(parent_config, changed_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -791,12 +831,12 @@ def test_add_changed(self, parent_config, changed_config): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.ChangeDimensionResolution) - def test_remove_missing(self, parent_config, child_config): + def test_remove_missing(self, parent_config, child_config, storage): """Test if missing dimension conflict is automatically resolved""" child_config["metadata"]["user_args"][2] = "-x~-" backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -806,12 +846,12 @@ def test_remove_missing(self, parent_config, child_config): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.RemoveDimensionResolution) - def test_remove_missing_default(self, parent_config, child_config): + def test_remove_missing_default(self, parent_config, child_config, storage): """Test if missing dimension conflict is automatically resolved""" child_config["metadata"]["user_args"][2] = "-x~-0.5" backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -822,12 +862,12 @@ def test_remove_missing_default(self, parent_config, child_config): assert isinstance(conflict.resolution, conflict.RemoveDimensionResolution) assert conflict.resolution.default_value == 0.5 - def test_remove_missing_bad_default(self, parent_config, child_config): + def test_remove_missing_bad_default(self, parent_config, child_config, storage): """Test if missing dimension conflict raises an error if marked with invalid default""" child_config["metadata"]["user_args"][2] = "-x~--100" backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 @@ -837,14 +877,14 @@ def test_remove_missing_bad_default(self, parent_config, child_config): assert not conflict.is_resolved assert isinstance(conflict, MissingDimensionConflict) - def test_rename_missing(self, parent_config, child_config): + def test_rename_missing(self, parent_config, child_config, storage): """Test if renaming is automatically applied with both conflicts resolved""" child_config["metadata"]["user_args"].append("-w_a~uniform(0,1)") child_config["metadata"]["user_args"].append("-w_b~normal(0,1)") child_config["metadata"]["user_args"][2] = "-x~>w_a" backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 4 @@ -866,7 +906,7 @@ def test_rename_missing(self, parent_config, child_config): == "/w_a" ) - def test_rename_invalid(self, parent_config, child_config): + def test_rename_invalid(self, parent_config, child_config, storage): """Test if renaming to invalid dimension raises an error""" child_config["metadata"]["user_args"].append("-w_a~uniform(0,1)") child_config["metadata"]["user_args"].append("-w_b~uniform(0,1)") @@ -874,10 +914,10 @@ def test_rename_invalid(self, parent_config, child_config): backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) with pytest.raises(ValueError) as exc: - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert "Dimension name 'w_c' not found in conflicts" in str(exc.value) - def test_rename_missing_changed(self, parent_config, child_config): + def test_rename_missing_changed(self, parent_config, child_config, storage): """Test if renaming is automatically applied with both conflicts resolved, but not the new one because of prior change """ @@ -886,7 +926,7 @@ def test_rename_missing_changed(self, parent_config, child_config): child_config["metadata"]["user_args"][2] = "-x~>w_b" backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 5 @@ -913,7 +953,7 @@ def test_rename_missing_changed(self, parent_config, child_config): == "/w_b" ) - def test_rename_missing_changed_marked(self, parent_config, child_config): + def test_rename_missing_changed_marked(self, parent_config, child_config, storage): """Test if renaming is automatically applied with all conflicts resolved including the new one caused by prior change """ @@ -922,7 +962,7 @@ def test_rename_missing_changed_marked(self, parent_config, child_config): child_config["metadata"]["user_args"][2] = "-x~>w_b" backward.populate_space(child_config) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 5 @@ -956,7 +996,7 @@ def test_name_experiment_version_update(self, parent_config, child_config, stora storage.create_experiment(parent_config) child_config["version"] = 1 conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 1 assert len(conflicts.get_resolved()) == 1 @@ -977,7 +1017,7 @@ def test_name_experiment_name_change(self, parent_config, child_config, storage) child_config2 = copy.deepcopy(child_config) child_config2["version"] = 1 conflicts = detect_conflicts(parent_config, child_config2) - ExperimentBranchBuilder(conflicts, branch_to=new_name) + ExperimentBranchBuilder(conflicts, branch_to=new_name, storage=storage) assert len(conflicts.get()) == 1 assert len(conflicts.get_resolved()) == 1 @@ -990,7 +1030,9 @@ def test_name_experiment_name_change(self, parent_config, child_config, storage) assert conflict.new_config["version"] == 1 assert conflict.is_resolved - def test_bad_name_experiment(self, parent_config, child_config, monkeypatch): + def test_bad_name_experiment( + self, parent_config, child_config, monkeypatch, storage + ): """Test if experiment name conflict is not resolved when invalid name is marked""" def _is_unique(self, *args, **kwargs): @@ -1003,17 +1045,17 @@ def _is_unique(self, *args, **kwargs): ) conflicts = detect_conflicts(parent_config, child_config) - ExperimentBranchBuilder(conflicts, branch_to="test2") + ExperimentBranchBuilder(conflicts, branch_to="test2", storage=storage) assert len(conflicts.get()) == 1 assert len(conflicts.get_resolved()) == 0 - def test_code_change(self, parent_config, changed_code_config): + def test_code_change(self, parent_config, changed_code_config, storage): """Test if code conflict is resolved automatically""" change_type = evc.adapters.CodeChange.types[0] changed_code_config["code_change_type"] = change_type conflicts = detect_conflicts(parent_config, changed_code_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -1024,11 +1066,11 @@ def test_code_change(self, parent_config, changed_code_config): assert isinstance(conflict.resolution, conflict.CodeResolution) assert conflict.resolution.type == change_type - def test_algo_change(self, parent_config, changed_algo_config): + def test_algo_change(self, parent_config, changed_algo_config, storage): """Test if algorithm conflict is resolved automatically""" changed_algo_config["algorithm_change"] = True conflicts = detect_conflicts(parent_config, changed_algo_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -1038,11 +1080,13 @@ def test_algo_change(self, parent_config, changed_algo_config): assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.AlgorithmResolution) - def test_orion_version_change(self, parent_config, changed_orion_version_config): + def test_orion_version_change( + self, parent_config, changed_orion_version_config, storage + ): """Test if orion version conflict is resolved automatically""" changed_orion_version_config["orion_version_change"] = True conflicts = detect_conflicts(parent_config, changed_orion_version_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -1052,12 +1096,12 @@ def test_orion_version_change(self, parent_config, changed_orion_version_config) assert conflict.is_resolved assert isinstance(conflict.resolution, conflict.OrionVersionResolution) - def test_config_change(self, parent_config, changed_userconfig_config): + def test_config_change(self, parent_config, changed_userconfig_config, storage): """Test if user's script's config conflict is resolved automatically""" change_type = evc.adapters.ScriptConfigChange.types[0] changed_userconfig_config["config_change_type"] = change_type conflicts = detect_conflicts(parent_config, changed_userconfig_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 4 assert len(conflicts.get_resolved()) == 2 @@ -1068,12 +1112,12 @@ def test_config_change(self, parent_config, changed_userconfig_config): assert isinstance(conflict.resolution, conflict.ScriptConfigResolution) assert conflict.resolution.type == change_type - def test_cli_change(self, parent_config, changed_cli_config): + def test_cli_change(self, parent_config, changed_cli_config, storage): """Test if command line conflict is resolved automatically""" change_type = evc.adapters.CommandLineChange.types[0] changed_cli_config["cli_change_type"] = change_type conflicts = detect_conflicts(parent_config, changed_cli_config) - ExperimentBranchBuilder(conflicts, manual_resolution=True) + ExperimentBranchBuilder(conflicts, manual_resolution=True, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -1088,12 +1132,14 @@ def test_cli_change(self, parent_config, changed_cli_config): class TestAdapters: """Test creation of adapters""" - def test_adapter_add_new(self, parent_config, cl_config): + def test_adapter_add_new(self, parent_config, cl_config, storage): """Test if a DimensionAddition is created when solving a new conflict""" cl_config["metadata"]["user_args"] = ["-w_d~+normal(0,1)"] conflicts = detect_conflicts(parent_config, cl_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) adapters = branch_builder.create_adapters().adapters @@ -1101,12 +1147,14 @@ def test_adapter_add_new(self, parent_config, cl_config): assert len(adapters) == 1 assert isinstance(adapters[0], evc.adapters.DimensionAddition) - def test_adapter_add_changed(self, parent_config, cl_config): + def test_adapter_add_changed(self, parent_config, cl_config, storage): """Test if a DimensionPriorChange is created when solving a new conflict""" cl_config["metadata"]["user_args"] = ["-y~+uniform(0,1)"] conflicts = detect_conflicts(parent_config, cl_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) adapters = branch_builder.create_adapters().adapters @@ -1114,12 +1162,14 @@ def test_adapter_add_changed(self, parent_config, cl_config): assert len(adapters) == 1 assert isinstance(adapters[0], evc.adapters.DimensionPriorChange) - def test_adapter_remove_missing(self, parent_config, cl_config): + def test_adapter_remove_missing(self, parent_config, cl_config, storage): """Test if a DimensionDeletion is created when solving a new conflict""" cl_config["metadata"]["user_args"] = ["-z~-"] conflicts = detect_conflicts(parent_config, cl_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) adapters = branch_builder.create_adapters().adapters @@ -1127,13 +1177,15 @@ def test_adapter_remove_missing(self, parent_config, cl_config): assert len(adapters) == 1 assert isinstance(adapters[0], evc.adapters.DimensionDeletion) - def test_adapter_rename_missing(self, parent_config, cl_config): + def test_adapter_rename_missing(self, parent_config, cl_config, storage): """Test if a DimensionRenaming is created when solving a new conflict""" cl_config["metadata"]["user_args"] = ["-x~>w_d", "-w_d~+uniform(0,1)"] backward.populate_space(cl_config) conflicts = detect_conflicts(parent_config, cl_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) adapters = branch_builder.create_adapters().adapters @@ -1141,12 +1193,14 @@ def test_adapter_rename_missing(self, parent_config, cl_config): assert len(adapters) == 1 assert isinstance(adapters[0], evc.adapters.DimensionRenaming) - def test_adapter_rename_different_prior(self, parent_config, cl_config): + def test_adapter_rename_different_prior(self, parent_config, cl_config, storage): """Test if a DimensionRenaming is created when solving a new conflict""" cl_config["metadata"]["user_args"] = ["-x~>w_d", "-w_d~+normal(0,1)"] conflicts = detect_conflicts(parent_config, cl_config) - branch_builder = ExperimentBranchBuilder(conflicts, manual_resolution=True) + branch_builder = ExperimentBranchBuilder( + conflicts, manual_resolution=True, storage=storage + ) adapters = branch_builder.create_adapters().adapters @@ -1159,11 +1213,11 @@ def test_adapter_rename_different_prior(self, parent_config, cl_config): class TestResolutionsConfig: """Test auto-resolution with specific types from orion.core.config.evc""" - def test_cli_change(self, parent_config, changed_cli_config): + def test_cli_change(self, parent_config, changed_cli_config, storage): """Test if giving a proper change-type solves the command line conflict""" conflicts = detect_conflicts(parent_config, changed_cli_config) orion.core.config.evc.cli_change_type = "noeffect" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -1174,21 +1228,21 @@ def test_cli_change(self, parent_config, changed_cli_config): assert conflict.resolution.type == "noeffect" orion.core.config.evc.cli_change_type = "break" - def test_bad_cli_change(self, capsys, parent_config, changed_cli_config): + def test_bad_cli_change(self, capsys, parent_config, changed_cli_config, storage): """Test if giving an invalid change-type fails the the resolution""" conflicts = detect_conflicts(parent_config, changed_cli_config) orion.core.config.evc.cli_change_type = "bad-type" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 orion.core.config.evc.cli_change_type = "break" - def test_code_change(self, parent_config, changed_code_config): + def test_code_change(self, parent_config, changed_code_config, storage): """Test if giving a proper change-type solves the code conflict""" conflicts = detect_conflicts(parent_config, changed_code_config) orion.core.config.evc.code_change_type = "noeffect" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 2 @@ -1199,21 +1253,21 @@ def test_code_change(self, parent_config, changed_code_config): assert conflict.resolution.type == "noeffect" orion.core.config.evc.code_change_type = "break" - def test_bad_code_change(self, capsys, parent_config, changed_code_config): + def test_bad_code_change(self, capsys, parent_config, changed_code_config, storage): """Test if giving an invalid change-type prints error message and do nothing""" conflicts = detect_conflicts(parent_config, changed_code_config) orion.core.config.evc.code_change_type = "bad-type" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 2 assert len(conflicts.get_resolved()) == 1 orion.core.config.evc.code_change_type = "break" - def test_config_change(self, parent_config, changed_userconfig_config): + def test_config_change(self, parent_config, changed_userconfig_config, storage): """Test if giving a proper change-type solves the user script config conflict""" conflicts = detect_conflicts(parent_config, changed_userconfig_config) orion.core.config.evc.config_change_type = "noeffect" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 4 assert len(conflicts.get_resolved()) == 4 @@ -1223,11 +1277,13 @@ def test_config_change(self, parent_config, changed_userconfig_config): assert isinstance(conflict, ScriptConfigConflict) assert conflict.resolution.type == "noeffect" - def test_bad_config_change(self, capsys, parent_config, changed_userconfig_config): + def test_bad_config_change( + self, capsys, parent_config, changed_userconfig_config, storage + ): """Test if giving an invalid change-type prints error message and do nothing""" conflicts = detect_conflicts(parent_config, changed_userconfig_config) orion.core.config.evc.config_change_type = "bad-type" - ExperimentBranchBuilder(conflicts) + ExperimentBranchBuilder(conflicts, storage=storage) assert len(conflicts.get()) == 4 assert len(conflicts.get_resolved()) == 3 diff --git a/tests/unittests/core/worker/test_consumer.py b/tests/unittests/core/worker/test_consumer.py index 39c9f38a4..6c86a4ef1 100644 --- a/tests/unittests/core/worker/test_consumer.py +++ b/tests/unittests/core/worker/test_consumer.py @@ -35,14 +35,13 @@ def config(exp_config): return config -@pytest.mark.usefixtures("storage") -def test_trials_interrupted_sigterm(config, monkeypatch): +def test_trials_interrupted_sigterm(storage, config, monkeypatch): """Check if a trial is set as interrupted when a signal is raised.""" def mock_popen(self, *args, **kwargs): os.kill(os.getpid(), signal.SIGTERM) - exp = experiment_builder.build(**config) + exp = experiment_builder.build(**config, storage=storage) monkeypatch.setattr(subprocess.Popen, "wait", mock_popen) @@ -58,10 +57,10 @@ def mock_popen(self, *args, **kwargs): shutil.rmtree(trial.working_dir) -@pytest.mark.usefixtures("storage") -def test_trial_working_dir_is_created(config): +def test_trial_working_dir_is_created(storage, config): """Check that trial working dir is created.""" - exp = experiment_builder.build(**config) + + exp = experiment_builder.build(**config, storage=storage) trial = tuple_to_trial((1.0,), exp.space) @@ -77,9 +76,9 @@ def test_trial_working_dir_is_created(config): shutil.rmtree(trial.working_dir) -def setup_code_change_mock(config, monkeypatch, ignore_code_changes): +def setup_code_change_mock(storage, config, monkeypatch, ignore_code_changes): """Mock create experiment and trials, and infer_versioning_metadata""" - exp = experiment_builder.build(**config) + exp = experiment_builder.build(**config, storage=storage) trial = tuple_to_trial((1.0,), exp.space) @@ -101,11 +100,12 @@ def code_changed(user_script): return con, trial -@pytest.mark.usefixtures("storage") -def test_code_changed_evc_disabled(config, monkeypatch, caplog): +def test_code_changed_evc_disabled(storage, config, monkeypatch, caplog): """Check that trial has its working_dir attribute changed.""" - con, trial = setup_code_change_mock(config, monkeypatch, ignore_code_changes=True) + con, trial = setup_code_change_mock( + storage, config, monkeypatch, ignore_code_changes=True + ) with caplog.at_level(logging.WARNING): con(trial) @@ -114,11 +114,12 @@ def test_code_changed_evc_disabled(config, monkeypatch, caplog): shutil.rmtree(trial.working_dir) -@pytest.mark.usefixtures("storage") -def test_code_changed_evc_enabled(config, monkeypatch): +def test_code_changed_evc_enabled(storage, config, monkeypatch): """Check that trial has its working_dir attribute changed.""" - con, trial = setup_code_change_mock(config, monkeypatch, ignore_code_changes=False) + con, trial = setup_code_change_mock( + storage, config, monkeypatch, ignore_code_changes=False + ) with pytest.raises(BranchingEvent) as exc: con(trial) @@ -128,14 +129,13 @@ def test_code_changed_evc_enabled(config, monkeypatch): shutil.rmtree(trial.working_dir) -@pytest.mark.usefixtures("storage") -def test_retrieve_result_nofile(config): +def test_retrieve_result_nofile(storage, config): """Test retrieve result""" results_file = tempfile.NamedTemporaryFile( mode="w", prefix="results_", suffix=".log", dir=".", delete=True ) - exp = experiment_builder.build(**config) + exp = experiment_builder.build(storage=storage, **config) con = Consumer(exp) diff --git a/tests/unittests/core/worker/test_experiment.py b/tests/unittests/core/worker/test_experiment.py index 1efda3af4..7bd130e38 100644 --- a/tests/unittests/core/worker/test_experiment.py +++ b/tests/unittests/core/worker/test_experiment.py @@ -22,7 +22,7 @@ from orion.core.worker.experiment import Experiment, Mode from orion.core.worker.primary_algo import create_algo from orion.core.worker.trial import Trial -from orion.storage.base import LockedAlgorithmState, get_storage +from orion.storage.base import LockedAlgorithmState, setup_storage from orion.testing import OrionState @@ -180,11 +180,11 @@ def space(): class TestReserveTrial: """Calls to interface `Experiment.reserve_trial`.""" - @pytest.mark.usefixtures("setup_pickleddb_database") + @pytest.mark.usefixtures("orionstate") def test_reserve_none(self, space: Space): """Find nothing, return None.""" - with OrionState(experiments=[], trials=[]): - exp = Experiment("supernaekei", mode="x", space=space) + with OrionState(experiments=[], trials=[]) as cfg: + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) trial = exp.reserve_trial() assert trial is None @@ -194,7 +194,7 @@ def test_reserve_success(self, random_dt, space: Space): with OrionState( trials=generate_trials(["new", "reserved"]), storage=storage_config ) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] trial = exp.reserve_trial() @@ -210,7 +210,7 @@ def test_reserve_when_exhausted(self, space: Space): """Return None once all the trials have been allocated""" statuses = ["new", "reserved", "interrupted", "completed", "broken"] with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] assert exp.reserve_trial() is not None assert exp.reserve_trial() is not None @@ -224,7 +224,7 @@ def test_fix_lost_trials(self, space: Space): seconds=60 * 10 ) with OrionState(trials=[trial]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] assert len(exp.fetch_trials_by_status("reserved")) == 1 @@ -240,7 +240,7 @@ def test_fix_only_lost_trials(self, space: Space): running_trial["heartbeat"] = datetime.datetime.utcnow() with OrionState(trials=[lost_trial, running_trial]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] assert len(exp.fetch_trials_by_status("reserved")) == 2 @@ -263,7 +263,7 @@ def test_fix_lost_trials_race_condition(self, monkeypatch, caplog, space: Space) seconds=60 * 10 ) with OrionState(trials=[trial]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] assert len(exp.fetch_trials_by_status("interrupted")) == 1 @@ -300,7 +300,7 @@ def test_fix_lost_trials_configurable_hb(self, space: Space): seconds=60 * 2 ) with OrionState(trials=[trial]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] assert len(exp.fetch_trials_by_status("reserved")) == 1 @@ -323,7 +323,7 @@ def test_acquire_algorithm_lock_successful( self, new_config, algorithm, space: Space ): with OrionState(experiments=[new_config]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = 0 exp.algorithms = algorithm @@ -350,7 +350,7 @@ def test_acquire_algorithm_lock_with_different_config( self, new_config, algorithm, space: Space ): with OrionState(experiments=[new_config]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = 0 algorithm_original_config = algorithm.configuration exp.algorithms = algorithm @@ -369,13 +369,11 @@ def test_acquire_algorithm_lock_timeout( self, new_config, algorithm, mocker, space: Space ): with OrionState(experiments=[new_config]) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = 0 exp.algorithms = algorithm - storage_acquisition_mock = mocker.spy( - cfg.storage(), "acquire_algorithm_lock" - ) + storage_acquisition_mock = mocker.spy(cfg.storage, "acquire_algorithm_lock") with exp.acquire_algorithm_lock(timeout=0.2, retry_interval=0.1): pass @@ -388,7 +386,7 @@ def test_acquire_algorithm_lock_timeout( def test_update_completed_trial(random_dt, space: Space): """Successfully push a completed trial into database.""" with OrionState(trials=generate_trials(["new"])) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] trial = exp.reserve_trial() @@ -404,7 +402,7 @@ def test_update_completed_trial(random_dt, space: Space): exp.update_completed_trial(trial, results_file=results_file) - yo = get_storage().fetch_trials(exp)[0].to_dict() + yo = cfg.storage.fetch_trials(exp)[0].to_dict() assert len(yo["results"]) == len(trial.results) assert yo["results"][0] == trial.results[0].to_dict() @@ -417,8 +415,8 @@ def test_update_completed_trial(random_dt, space: Space): @pytest.mark.usefixtures("with_user_tsirif") def test_register_trials(tmp_path, random_dt, space: Space): """Register a list of newly proposed trials/parameters.""" - with OrionState(): - exp = Experiment("supernaekei", mode="x", space=space) + with OrionState() as cfg: + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = 0 exp.working_dir = tmp_path @@ -429,7 +427,7 @@ def test_register_trials(tmp_path, random_dt, space: Space): for trial in trials: exp.register_trial(trial) - yo = list(map(lambda trial: trial.to_dict(), get_storage().fetch_trials(exp))) + yo = list(map(lambda trial: trial.to_dict(), setup_storage().fetch_trials(exp))) assert len(yo) == len(trials) assert yo[0]["params"] == list(map(lambda x: x.to_dict(), trials[0]._params)) assert yo[1]["params"] == list(map(lambda x: x.to_dict(), trials[1]._params)) @@ -446,8 +444,8 @@ class TestToPandas: def test_empty(self, space: Space): """Test panda frame creation when there is no trials""" - with OrionState(): - exp = Experiment("supernaekei", mode="x", space=space) + with OrionState() as cfg: + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) assert exp.to_pandas().shape == (0, 8) assert list(exp.to_pandas().columns) == [ "id", @@ -465,7 +463,7 @@ def test_data(self, space: Space): with OrionState( trials=generate_trials(["new", "reserved", "completed"]) ) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] df = exp.to_pandas() assert df.shape == (3, 8) @@ -488,7 +486,7 @@ def test_data(self, space: Space): def test_fetch_all_trials(space: Space): """Fetch a list of all trials""" with OrionState(trials=generate_trials(["new", "reserved", "completed"])) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] trials = list(map(lambda trial: trial.to_dict(), exp.fetch_trials({}))) @@ -503,7 +501,7 @@ def test_fetch_pending_trials(space: Space): pending_stati = ["new", "interrupted", "suspended"] statuses = pending_stati + ["completed", "broken", "reserved"] with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] trials = exp.fetch_pending_trials() @@ -519,7 +517,7 @@ def test_fetch_non_completed_trials(space: Space): non_completed_stati = ["new", "interrupted", "suspended", "reserved"] statuses = non_completed_stati + ["completed"] with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] trials = exp.fetch_noncompleted_trials() @@ -532,7 +530,7 @@ def test_is_done_property_with_pending(algorithm, space: Space): completed = ["completed"] * 10 reserved = ["reserved"] * 5 with OrionState(trials=generate_trials(completed + reserved)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] exp.algorithms = algorithm @@ -556,7 +554,7 @@ def test_is_done_property_no_pending(algorithm, space: Space): completed = ["completed"] * 10 broken = ["broken"] * 5 with OrionState(trials=generate_trials(completed + broken)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] exp.algorithms = algorithm @@ -578,7 +576,7 @@ def test_broken_property(space: Space): statuses = (["reserved"] * 10) + (["broken"] * (MAX_BROKEN - 1)) with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] exp.max_broken = MAX_BROKEN @@ -587,7 +585,7 @@ def test_broken_property(space: Space): statuses = (["reserved"] * 10) + (["broken"] * (MAX_BROKEN)) with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] exp.max_broken = MAX_BROKEN @@ -601,7 +599,7 @@ def test_configurable_broken_property(space: Space): statuses = (["reserved"] * 10) + (["broken"] * (MAX_BROKEN)) with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", space=space, storage=cfg.storage) exp._id = cfg.trials[0]["experiment"] exp.max_broken = MAX_BROKEN @@ -618,7 +616,7 @@ def test_experiment_stats(space: Space): NUM_COMPLETED = 3 statuses = (["completed"] * NUM_COMPLETED) + (["reserved"] * 2) with OrionState(trials=generate_trials(statuses)) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", space=space, storage=cfg.storage) exp._id = cfg.trials[0]["experiment"] exp.metadata = {"datetime": datetime.datetime.utcnow()} stats = exp.stats @@ -634,7 +632,7 @@ def test_experiment_pickleable(space: Space): """Test experiment instance is pickleable""" with OrionState(trials=generate_trials(["new"])) as cfg: - exp = Experiment("supernaekei", mode="x", space=space) + exp = Experiment("supernaekei", mode="x", storage=cfg.storage, space=space) exp._id = cfg.trials[0]["experiment"] exp_trials = exp.fetch_trials() @@ -673,6 +671,7 @@ def test_experiment_pickleable(space: Space): "to_pandas", "version", "working_dir", + "storage", ] read_write_only_methods = [ "fix_lost_trials", @@ -750,7 +749,7 @@ def compare_unsupported(attr_name, restricted_exp, execution_exp): restricted_attr(**kwargs.get(attr_name, {})) -def create_experiment(mode: Mode, space: Space, algorithm): +def create_experiment(mode: Mode, space: Space, algorithm, storage): experiment = Experiment( "supernaekei", mode=mode, @@ -758,6 +757,7 @@ def create_experiment(mode: Mode, space: Space, algorithm): algorithms=algorithm, max_broken=5, max_trials=5, + storage=storage, _id=1, ) return experiment @@ -780,16 +780,24 @@ class TestReadOnly: @pytest.mark.parametrize("method", read_only_methods) def test_read_only_methods(self, space, algorithm, method): with OrionState(trials=trials) as cfg: - read_only_exp = create_experiment("r", space, algorithm) - execution_exp = create_experiment("x", space, algorithm) + read_only_exp = create_experiment( + "r", space, algorithm, storage=cfg.storage + ) + execution_exp = create_experiment( + "x", space, algorithm, storage=cfg.storage + ) compare_supported(method, read_only_exp, execution_exp) @pytest.mark.parametrize("method", read_write_only_methods + execute_only_methods) def test_read_write_methods(self, space, algorithm, method, monkeypatch): with OrionState(trials=trials) as cfg: - disable_algo_lock(monkeypatch, cfg.storage()) - read_only_exp = create_experiment("r", space, algorithm) - execution_exp = create_experiment("x", space, algorithm) + disable_algo_lock(monkeypatch, cfg.storage) + read_only_exp = create_experiment( + "r", space, algorithm, storage=cfg.storage + ) + execution_exp = create_experiment( + "x", space, algorithm, storage=cfg.storage + ) compare_unsupported(method, read_only_exp, execution_exp) @@ -799,14 +807,22 @@ class TestReadWriteOnly: @pytest.mark.parametrize("method", read_only_methods) def test_read_only_methods(self, space, algorithm, method): with OrionState(trials=trials) as cfg: - read_only_exp = create_experiment("w", space, algorithm) - execution_exp = create_experiment("x", space, algorithm) + read_only_exp = create_experiment( + "w", space, algorithm, storage=cfg.storage + ) + execution_exp = create_experiment( + "x", space, algorithm, storage=cfg.storage + ) compare_supported(method, read_only_exp, execution_exp) @pytest.mark.parametrize("method", execute_only_methods) def test_execution_methods(self, space, algorithm, method, monkeypatch): with OrionState(trials=trials) as cfg: - disable_algo_lock(monkeypatch, cfg.storage()) - read_only_exp = create_experiment("w", space, algorithm) - execution_exp = create_experiment("x", space, algorithm) + disable_algo_lock(monkeypatch, cfg.storage) + read_only_exp = create_experiment( + "w", space, algorithm, storage=cfg.storage + ) + execution_exp = create_experiment( + "x", space, algorithm, storage=cfg.storage + ) compare_unsupported(method, read_only_exp, execution_exp) diff --git a/tests/unittests/core/worker/test_producer.py b/tests/unittests/core/worker/test_producer.py index 3f612896b..c58c0d3f7 100644 --- a/tests/unittests/core/worker/test_producer.py +++ b/tests/unittests/core/worker/test_producer.py @@ -66,7 +66,7 @@ def create_producer(): experiment.algorithms.algorithm.max_trials = 20 producer = Producer(experiment) - yield producer, cfg.storage() + yield producer, cfg.storage def test_produce(): diff --git a/tests/unittests/core/worker/test_trial_pacemaker.py b/tests/unittests/core/worker/test_trial_pacemaker.py index 7e84e6695..947157d02 100644 --- a/tests/unittests/core/worker/test_trial_pacemaker.py +++ b/tests/unittests/core/worker/test_trial_pacemaker.py @@ -8,7 +8,7 @@ import orion.core.io.experiment_builder as experiment_builder from orion.core.utils.format_trials import tuple_to_trial from orion.core.worker.trial_pacemaker import TrialPacemaker -from orion.storage.base import get_storage +from orion.storage.base import setup_storage @pytest.fixture @@ -21,9 +21,9 @@ def config(exp_config): @pytest.fixture -def exp(config): +def exp(storage, config): """Return an Experiment.""" - return experiment_builder.build(**config) + return experiment_builder.build(**config, storage=storage) @pytest.fixture @@ -35,15 +35,17 @@ def trial(exp): trial.status = "reserved" trial.heartbeat = heartbeat - get_storage().register_trial(trial) + setup_storage().register_trial(trial) return trial -@pytest.mark.usefixtures("storage") def test_trial_update_heartbeat(exp, trial): """Test that the heartbeat of a trial has been updated.""" - trial_monitor = TrialPacemaker(trial, wait_time=1) + + storage = setup_storage() + + trial_monitor = TrialPacemaker(trial, wait_time=1, storage=storage) trial_monitor.start() time.sleep(2) @@ -62,10 +64,11 @@ def test_trial_update_heartbeat(exp, trial): trial_monitor.stop() -@pytest.mark.usefixtures("storage") def test_trial_heartbeat_not_updated(exp, trial): """Test that the heartbeat of a trial is not updated when trial is not longer reserved.""" - trial_monitor = TrialPacemaker(trial, wait_time=1) + storage = setup_storage() + + trial_monitor = TrialPacemaker(trial, wait_time=1, storage=storage) trial_monitor.start() time.sleep(2) @@ -74,7 +77,7 @@ def test_trial_heartbeat_not_updated(exp, trial): assert trial.heartbeat != trials[0].heartbeat - get_storage().set_trial_status(trial, status="interrupted") + setup_storage().set_trial_status(trial, status="interrupted") time.sleep(2) @@ -83,10 +86,10 @@ def test_trial_heartbeat_not_updated(exp, trial): assert 1 -@pytest.mark.usefixtures("storage") def test_trial_heartbeat_not_updated_inbetween(exp, trial): """Test that the heartbeat of a trial is not updated before wait time.""" - trial_monitor = TrialPacemaker(trial, wait_time=5) + storage = setup_storage() + trial_monitor = TrialPacemaker(trial, wait_time=5, storage=storage) trial_monitor.start() time.sleep(1) diff --git a/tests/unittests/storage/test_legacy.py b/tests/unittests/storage/test_legacy.py index 9aba4a015..24481abd7 100644 --- a/tests/unittests/storage/test_legacy.py +++ b/tests/unittests/storage/test_legacy.py @@ -6,16 +6,10 @@ import pytest -from orion.core.io.database import database_factory from orion.core.io.database.pickleddb import PickledDB -from orion.core.utils.singleton import ( - SingletonAlreadyInstantiatedError, - SingletonNotInstantiatedError, - update_singletons, -) from orion.core.worker.trial import Trial from orion.storage.base import FailedUpdate -from orion.storage.legacy import get_database, setup_database +from orion.storage.legacy import setup_database from orion.testing import OrionState log = logging.getLogger(__name__) @@ -66,18 +60,17 @@ db_backends = [{"type": "legacy", "database": mongodb_config}] -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") def test_setup_database_default(monkeypatch): """Test that database is setup using default config""" - update_singletons() - setup_database() - database = database_factory.create() + + database = setup_database() assert isinstance(database, PickledDB) def test_setup_database_bad(): """Test how setup fails when configuring with non-existent backends""" - update_singletons() + with pytest.raises(NotImplementedError) as exc: setup_database({"type": "idontexist"}) @@ -86,51 +79,32 @@ def test_setup_database_bad(): def test_setup_database_custom(): """Test setup with local configuration""" - update_singletons() - setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = database_factory.create() + + database = setup_database({"type": "pickleddb", "host": "test.pkl"}) + assert isinstance(database, PickledDB) assert database.host == os.path.abspath("test.pkl") def test_setup_database_bad_override(): """Test setup with different type than existing singleton""" - update_singletons() - setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = database_factory.create() - assert isinstance(database, PickledDB) - with pytest.raises(SingletonAlreadyInstantiatedError) as exc: - setup_database({"type": "mongodb"}) - assert exc.match(r"A singleton instance of \(type: Database\)") + database = setup_database({"type": "pickleddb", "host": "test.pkl"}) + assert isinstance(database, PickledDB) def test_setup_database_bad_config_override(): """Test setup with different config than existing singleton""" - update_singletons() - setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = database_factory.create() - assert isinstance(database, PickledDB) - with pytest.raises(SingletonAlreadyInstantiatedError): - setup_database({"type": "pickleddb", "host": "other.pkl"}) - - -def test_get_database_uninitiated(): - """Test that get database fails if no database singleton exist""" - update_singletons() - with pytest.raises(SingletonNotInstantiatedError) as exc: - get_database() - assert exc.match(r"No singleton instance of \(type: Database\) was created") + database = setup_database({"type": "pickleddb", "host": "test.pkl"}) + assert isinstance(database, PickledDB) def test_get_database(): """Test that get database gets the singleton""" - update_singletons() - setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = get_database() + + database = setup_database({"type": "pickleddb", "host": "test.pkl"}) assert isinstance(database, PickledDB) - assert get_database() == database class TestLegacyStorage: @@ -143,7 +117,7 @@ def test_push_trial_results(self, storage=None): with OrionState( experiments=[], trials=[reserved_trial], storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage trial = storage.get_trial(Trial(**reserved_trial)) results = [Trial.Result(name="loss", type="objective", value=2)] trial.results = results @@ -155,7 +129,7 @@ def test_push_trial_results(self, storage=None): def test_push_trial_results_unreserved(self, storage=None): """Successfully push a completed trial into database.""" with OrionState(experiments=[], trials=[base_trial], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage trial = storage.get_trial(Trial(**base_trial)) results = [Trial.Result(name="loss", type="objective", value=2)] trial.results = results diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index e13d27acf..fb70cc0a8 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -13,20 +13,13 @@ import orion.core from orion.core.io.database import DuplicateKeyError from orion.core.io.database.pickleddb import PickledDB -from orion.core.utils.singleton import ( - SingletonAlreadyInstantiatedError, - SingletonNotInstantiatedError, - update_singletons, -) from orion.core.worker.trial import Trial from orion.storage.base import ( FailedUpdate, LockAcquisitionTimeout, LockedAlgorithmState, MissingArguments, - get_storage, setup_storage, - storage_factory, ) from orion.storage.legacy import Legacy from orion.storage.track import HAS_TRACK, REASON @@ -121,19 +114,18 @@ def generate_experiments(): return [_generate(exp, "name", value=str(i)) for i, exp in enumerate(exps)] -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") def test_setup_storage_default(): """Test that storage is setup using default config""" - update_singletons() - setup_storage() - storage = storage_factory.create() + + storage = setup_storage() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) def test_setup_storage_bad(): """Test how setup fails when configuring with non-existent backends""" - update_singletons() + with pytest.raises(NotImplementedError) as exc: setup_storage({"type": "idontexist"}) @@ -142,11 +134,10 @@ def test_setup_storage_bad(): def test_setup_storage_custom(): """Test setup with local configuration""" - update_singletons() - setup_storage( + + storage = setup_storage( {"type": "legacy", "database": {"type": "pickleddb", "host": "test.pkl"}} ) - storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) assert storage._db.host == os.path.abspath("test.pkl") @@ -154,20 +145,20 @@ def test_setup_storage_custom(): def test_setup_storage_custom_type_missing(): """Test setup with local configuration with type missing""" - update_singletons() - setup_storage({"database": {"type": "pickleddb", "host": "test.pkl"}}) - storage = storage_factory.create() + + storage = setup_storage({"database": {"type": "pickleddb", "host": "test.pkl"}}) + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) assert storage._db.host == os.path.abspath("test.pkl") -@pytest.mark.usefixtures("setup_pickleddb_database") +@pytest.mark.usefixtures("orionstate") def test_setup_storage_custom_legacy_emtpy(): """Test setup with local configuration with legacy but no config""" - update_singletons() - setup_storage({"type": "legacy"}) - storage = storage_factory.create() + + storage = setup_storage({"type": "legacy"}) + assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) assert storage._db.host == orion.core.config.storage.database.host @@ -175,61 +166,23 @@ def test_setup_storage_custom_legacy_emtpy(): def test_setup_storage_bad_override(): """Test setup with different type than existing singleton""" - update_singletons() - setup_storage( + + storage = setup_storage( {"type": "legacy", "database": {"type": "pickleddb", "host": "test.pkl"}} ) - storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) - with pytest.raises(SingletonAlreadyInstantiatedError) as exc: - setup_storage({"type": "track"}) - - assert exc.match(r"A singleton instance of \(type: BaseStorageProtocol\)") - - -@pytest.mark.xfail(reason="Fix this when introducing #135 in v0.2.0") -def test_setup_storage_bad_config_override(): - """Test setup with different config than existing singleton""" - update_singletons() - setup_storage({"database": {"type": "pickleddb", "host": "test.pkl"}}) - storage = storage_factory.create() - assert isinstance(storage, Legacy) - assert isinstance(storage._db, PickledDB) - with pytest.raises(SingletonAlreadyInstantiatedError): - setup_storage({"database": {"type": "mongodb"}}) def test_setup_storage_stateless(): """Test that passed configuration dictionary is not modified by the function""" - update_singletons() + config = {"database": {"type": "pickleddb", "host": "test.pkl"}} passed_config = copy.deepcopy(config) setup_storage(passed_config) assert config == passed_config -def test_get_storage_uninitiated(): - """Test that get storage fails if no storage singleton exist""" - update_singletons() - with pytest.raises(SingletonNotInstantiatedError) as exc: - get_storage() - - assert exc.match( - r"No singleton instance of \(type: BaseStorageProtocol\) was created" - ) - - -def test_get_storage(): - """Test that get storage gets the singleton""" - update_singletons() - setup_storage({"database": {"type": "pickleddb", "host": "test.pkl"}}) - storage = get_storage() - assert isinstance(storage, Legacy) - assert isinstance(storage._db, PickledDB) - assert get_storage() == storage - - @pytest.mark.usefixtures("version_XYZ") @pytest.mark.parametrize("storage", storage_backends) class TestStorage: @@ -238,7 +191,7 @@ class TestStorage: def test_create_experiment(self, storage): """Test create experiment""" with OrionState(experiments=[], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage storage.create_experiment(base_experiment) @@ -255,7 +208,7 @@ def test_create_experiment(self, storage): def test_fetch_experiments(self, storage, name="0", user="a"): """Test fetch experiments""" with OrionState(experiments=generate_experiments(), storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiments = storage.fetch_experiments({}) assert len(experiments) == len(cfg.experiments) @@ -279,7 +232,7 @@ def test_fetch_experiments(self, storage, name="0", user="a"): def test_update_experiment(self, monkeypatch, storage, name="0", user="a"): """Test fetch experiments""" with OrionState(experiments=generate_experiments(), storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage class _Dummy: pass @@ -319,7 +272,7 @@ def test_delete_experiment(self, storage): pytest.xfail("Track does not support deletion yet.") with OrionState(experiments=generate_experiments(), storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage n_experiments = len(storage.fetch_experiments({})) storage.delete_experiment(uid=cfg.experiments[0]["_id"]) @@ -330,7 +283,7 @@ def test_delete_experiment(self, storage): def test_register_trial(self, storage): """Test register trial""" with OrionState(experiments=[base_experiment], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage trial1 = storage.register_trial(Trial(**base_trial)) trial2 = storage.get_trial(trial1) @@ -343,7 +296,7 @@ def test_register_duplicate_trial(self, storage): with OrionState( experiments=[base_experiment], trials=[base_trial], storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage with pytest.raises(DuplicateKeyError): storage.register_trial(Trial(**base_trial)) @@ -355,7 +308,7 @@ def test_update_trials(self, storage): trials=generate_trials(status=["completed", "reserved", "reserved"]), storage=storage, ) as cfg: - storage = cfg.storage() + storage = cfg.storage class _Dummy: pass @@ -375,7 +328,7 @@ def test_update_trial(self, storage): with OrionState( experiments=[base_experiment], trials=[base_trial], storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage trial = Trial(**cfg.trials[0]) @@ -388,7 +341,7 @@ def test_reserve_trial_success(self, storage): with OrionState( experiments=[base_experiment], trials=[base_trial], storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trial = storage.reserve_trial(experiment) @@ -404,7 +357,7 @@ def test_reserve_trial_fail(self, storage): storage=storage, ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trial = storage.reserve_trial(experiment) @@ -415,7 +368,7 @@ def test_fetch_trials(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials1 = storage.fetch_trials(experiment=experiment) @@ -437,7 +390,7 @@ def test_fetch_trials_with_query(self, storage): trials=generate_trials(status=["completed", "reserved", "reserved"]), storage=storage, ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials_all = storage.fetch_trials(experiment=experiment) @@ -464,7 +417,7 @@ def test_delete_all_trials(self, storage): with OrionState( experiments=[base_experiment], trials=trials, storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage # Make sure we have sufficient trials to test deletion trials = storage.fetch_trials(uid="default_name") @@ -489,7 +442,7 @@ def test_delete_trials_with_query(self, storage): with OrionState( experiments=[base_experiment], trials=trials, storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name") # Make sure we have sufficient trials to test deletion @@ -515,7 +468,7 @@ def test_get_trial(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage trial_dict = cfg.trials[0] @@ -545,7 +498,7 @@ def test_fetch_lost_trials(self, storage): with OrionState( experiments=[base_experiment], trials=trials, storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials = storage.fetch_lost_trials(experiment) @@ -559,15 +512,15 @@ def check_status_change(new_status): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = get_storage().get_trial(cfg.get_trial(0)) + trial = setup_storage().get_trial(cfg.get_trial(0)) assert trial is not None, "was not able to retrieve trial for test" - get_storage().set_trial_status(trial, status=new_status) + setup_storage().set_trial_status(trial, status=new_status) assert ( trial.status == new_status ), "Trial status should have been updated locally" - trial = get_storage().get_trial(trial) + trial = setup_storage().get_trial(trial) assert ( trial.status == new_status ), "Trial status should have been updated in the storage" @@ -584,11 +537,11 @@ def test_change_status_invalid(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = get_storage().get_trial(cfg.get_trial(0)) + trial = setup_storage().get_trial(cfg.get_trial(0)) assert trial is not None, "Was not able to retrieve trial for test" with pytest.raises(ValueError) as exc: - get_storage().set_trial_status(trial, status="moo") + setup_storage().set_trial_status(trial, status="moo") assert exc.match("Given status `moo` not one of") @@ -599,7 +552,7 @@ def check_status_change(new_status): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = get_storage().get_trial(cfg.get_trial(1)) + trial = cfg.storage.get_trial(cfg.get_trial(1)) assert trial is not None, "Was not able to retrieve trial for test" assert trial.status != new_status @@ -608,7 +561,7 @@ def check_status_change(new_status): with pytest.raises(FailedUpdate): trial.status = new_status - get_storage().set_trial_status(trial, status=new_status) + setup_storage().set_trial_status(trial, status=new_status) check_status_change("completed") check_status_change("broken") @@ -625,7 +578,7 @@ def check_status_change(new_status): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = get_storage().get_trial(cfg.get_trial(1)) + trial = cfg.storage.get_trial(cfg.get_trial(1)) assert trial is not None, "Was not able to retrieve trial for test" assert trial.status != new_status @@ -636,9 +589,9 @@ def check_status_change(new_status): trial.status = "broken" assert correct_status != "broken" with pytest.raises(FailedUpdate): - get_storage().set_trial_status(trial, status=new_status) + setup_storage().set_trial_status(trial, status=new_status) - get_storage().set_trial_status( + setup_storage().set_trial_status( trial, status=new_status, was=correct_status ) @@ -655,7 +608,7 @@ def check_status_change(new_status): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = get_storage().get_trial(cfg.get_trial(1)) + trial = cfg.storage.get_trial(cfg.get_trial(1)) assert trial is not None, "Was not able to retrieve trial for test" assert trial.status != new_status @@ -663,7 +616,7 @@ def check_status_change(new_status): return with pytest.raises(FailedUpdate): - get_storage().set_trial_status( + setup_storage().set_trial_status( trial, status=new_status, was=new_status ) @@ -678,7 +631,7 @@ def test_fetch_pending_trials(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials = storage.fetch_pending_trials(experiment) @@ -697,7 +650,7 @@ def test_fetch_noncompleted_trials(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials = storage.fetch_noncompleted_trials(experiment) @@ -722,7 +675,7 @@ def test_fetch_trials_by_status(self, storage): if trial["status"] == "completed": count += 1 - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials = storage.fetch_trials_by_status(experiment, "completed") @@ -740,7 +693,7 @@ def test_count_completed_trials(self, storage): if trial["status"] == "completed": count += 1 - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) trials = storage.count_completed_trials(experiment) @@ -756,7 +709,7 @@ def test_count_broken_trials(self, storage): if trial["status"] == "broken": count += 1 - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) @@ -770,7 +723,7 @@ def test_update_heartbeat(self, storage): experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: storage_name = storage - storage = cfg.storage() + storage = cfg.storage exp = cfg.get_experiment("default_name") trial1 = storage.fetch_trials_by_status(exp, status="reserved")[0] @@ -802,7 +755,7 @@ def test_serializable(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - storage = cfg.storage() + storage = cfg.storage serialized = pickle.dumps(storage) deserialized = pickle.loads(serialized) assert storage.fetch_experiments({}) == deserialized.fetch_experiments({}) @@ -812,7 +765,7 @@ def test_get_algorithm_lock_info(self, storage): pytest.xfail("Track does not support algorithm lock yet.") with OrionState(experiments=generate_experiments(), storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiments = storage.fetch_experiments({}) @@ -826,7 +779,7 @@ def test_delete_algorithm_lock(self, storage): pytest.xfail("Track does not support algorithm lock yet.") with OrionState(experiments=generate_experiments(), storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiments = storage.fetch_experiments({}) @@ -838,7 +791,7 @@ def test_acquire_algorithm_lock_successful(self, storage): pytest.xfail("Track does not support algorithm lock yet.") with OrionState(experiments=[base_experiment], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) with storage.acquire_algorithm_lock( @@ -852,7 +805,7 @@ def test_acquire_algorithm_lock_successful(self, storage): def test_acquire_algorithm_lock_timeout(self, storage, mocker): with OrionState(experiments=[base_experiment], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) with storage.acquire_algorithm_lock(experiment) as locked_algo_state: @@ -871,7 +824,7 @@ def test_acquire_algorithm_lock_timeout(self, storage, mocker): def test_acquire_algorithm_lock_handle_fail(self, storage): with OrionState(experiments=[base_experiment], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) with storage.acquire_algorithm_lock(experiment) as locked_algo_state: assert locked_algo_state.state is None @@ -888,7 +841,7 @@ def test_acquire_algorithm_lock_handle_fail(self, storage): def test_acquire_algorithm_lock_not_initialised(self, storage): with OrionState(experiments=[base_experiment], storage=storage) as cfg: - storage = cfg.storage() + storage = cfg.storage experiment = cfg.get_experiment("default_name", version=None) experiment._id = "bad id" with pytest.raises(LockAcquisitionTimeout):