diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..fa6f401b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/psf/black + rev: master + hooks: + - id: black + language_version: python3.6 \ No newline at end of file diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 379fb094..b98bc3d8 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -66,22 +66,27 @@ local development. $ git checkout -b name-of-your-bugfix-or-feature +4. Create your development environment and install the pre-commit hooks + $ # Activate your environment + $ pip install -r dev-requirements.txt + $ pre-commit install + Now you can make your changes locally. -4. When you're done making changes, check that your changes pass style and unit +5. When you're done making changes, check that your changes pass style and unit tests, including testing other Python versions with tox:: $ tox To get tox, just pip install it. -5. Commit your changes and push your branch to GitHub:: +6. Commit your changes and push your branch to GitHub:: $ git add . $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature -6. Submit a pull request through the GitHub website. +7. Submit a pull request through the GitHub website. .. 
_Fork: https://github.com/IDSIA/sacred/fork diff --git a/azure-pipelines.yml b/azure-pipelines.yml index bf0d1193..de7387ec 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,6 +63,10 @@ jobs: IMAGE_NAME: 'ubuntu-16.04' PYTHON_VERSION: '3.7' TOX_CMD: 'flake8' + Linux black: + IMAGE_NAME: 'ubuntu-16.04' + PYTHON_VERSION: '3.6' + TOX_CMD: 'black' Linux coverage: IMAGE_NAME: 'ubuntu-16.04' PYTHON_VERSION: '3.7' diff --git a/dev-requirements.txt b/dev-requirements.txt index 003067be..deb4b86d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -28,4 +28,4 @@ wrapt==1.10.8 scikit-learn==0.20.3 pymongo==3.8.0 py-cpuinfo==4.0 - +pre-commit==1.18.0 diff --git a/docs/conf.py b/docs/conf.py index 5b4c19a8..6ed7fae0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) from sacred import __version__ # -- General configuration ------------------------------------------------ @@ -30,26 +30,23 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon' -] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. 
-project = 'Sacred' -copyright = '2015, Klaus Greff' +project = "Sacred" +copyright = "2015, Klaus Greff" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -61,7 +58,7 @@ # The full version, including alpha/beta/rc tags. release = __version__ -version = '.'.join(release.split('.')[:2]) +version = ".".join(release.split(".")[:2]) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -75,7 +72,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -93,7 +90,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -105,15 +102,16 @@ # -- Options for HTML output ---------------------------------------------- # on_rtd is whether we are on readthedocs.org -on_rtd = os.environ.get('READTHEDOCS') == 'True' +on_rtd = os.environ.get("READTHEDOCS") == "True" if not on_rtd: # only import and set the theme if we're building docs locally try: import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' + + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] except ImportError: - html_theme = 'default' + html_theme = "default" # else: readthedocs.org uses their theme by default, so no need to specify it @@ -144,7 +142,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -193,7 +191,7 @@ # html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'Sacreddoc' +htmlhelp_basename = "Sacreddoc" # -- Options for LaTeX output --------------------------------------------- @@ -201,10 +199,8 @@ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # 'preamble': '', } @@ -213,8 +209,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'Sacred.tex', 'Sacred Documentation', - 'Klaus Greff', 'manual'), + ("index", "Sacred.tex", "Sacred Documentation", "Klaus Greff", "manual") ] # The name of an image file (relative to this directory) to place at the top of @@ -242,10 +237,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'sacred', 'Sacred Documentation', - ['Klaus Greff'], 1) -] +man_pages = [("index", "sacred", "Sacred Documentation", ["Klaus Greff"], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -257,9 +249,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Sacred', 'Sacred Documentation', - 'Klaus Greff', 'Sacred', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Sacred", + "Sacred Documentation", + "Klaus Greff", + "Sacred", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. 
diff --git a/examples/01_hello_world.py b/examples/01_hello_world.py index 0e6c81b6..82fec57b 100755 --- a/examples/01_hello_world.py +++ b/examples/01_hello_world.py @@ -33,4 +33,4 @@ # This function should be executed so we are decorating it with @ex.automain @ex.automain def main(): - print('Hello world!') + print("Hello world!") diff --git a/examples/02_hello_config_dict.py b/examples/02_hello_config_dict.py index 967a819c..0cee0ea5 100755 --- a/examples/02_hello_config_dict.py +++ b/examples/02_hello_config_dict.py @@ -29,9 +29,7 @@ ex = Experiment() # We add message to the configuration of the experiment here -ex.add_config({ - "message": "Hello world!" -}) +ex.add_config({"message": "Hello world!"}) # Equivalent: # ex.add_config( # message="Hello world!" diff --git a/examples/03_hello_config_scope.py b/examples/03_hello_config_scope.py index 1cf0bcd5..9d376c6f 100755 --- a/examples/03_hello_config_scope.py +++ b/examples/03_hello_config_scope.py @@ -37,7 +37,7 @@ from sacred import Experiment -ex = Experiment('hello_cs') # here we name the experiment explicitly +ex = Experiment("hello_cs") # here we name the experiment explicitly # A ConfigScope is a function like this decorated with @ex.config diff --git a/examples/04_captured_functions.py b/examples/04_captured_functions.py index ec8edb2f..82c767f3 100755 --- a/examples/04_captured_functions.py +++ b/examples/04_captured_functions.py @@ -21,7 +21,7 @@ from sacred import Experiment -ex = Experiment('captured_functions') +ex = Experiment("captured_functions") @ex.config @@ -32,12 +32,12 @@ def cfg(): # Captured functions have access to all the configuration parameters @ex.capture def foo(message): - print(message.format('foo')) + print(message.format("foo")) @ex.capture def bar(message): - print(message.format('bar')) + print(message.format("bar")) @ex.automain @@ -45,4 +45,4 @@ def main(): foo() # Notice that we do not pass message here bar() # or here # But we can if we feel like it... 
- foo('Overriding the default message for {}.') + foo("Overriding the default message for {}.") diff --git a/examples/05_my_commands.py b/examples/05_my_commands.py index 8dc214e3..be17b508 100755 --- a/examples/05_my_commands.py +++ b/examples/05_my_commands.py @@ -42,12 +42,12 @@ from sacred import Experiment -ex = Experiment('my_commands') +ex = Experiment("my_commands") @ex.config def cfg(): - name = 'John' + name = "John" @ex.command @@ -57,7 +57,7 @@ def greet(name): Uses the name from config. """ - print('Hello {}! Nice to greet you!'.format(name)) + print("Hello {}! Nice to greet you!".format(name)) @ex.command @@ -65,9 +65,9 @@ def shout(): """ Shout slang question for "what is up?" """ - print('WHAZZZUUUUUUUUUUP!!!????') + print("WHAZZZUUUUUUUUUUP!!!????") @ex.automain def main(): - print('This is just the main command. Try greet or shout.') + print("This is just the main command. Try greet or shout.") diff --git a/examples/06_randomness.py b/examples/06_randomness.py index 1000c1fc..cdb65aa8 100755 --- a/examples/06_randomness.py +++ b/examples/06_randomness.py @@ -53,7 +53,7 @@ from sacred import Experiment -ex = Experiment('randomness') +ex = Experiment("randomness") @ex.config diff --git a/examples/07_magic.py b/examples/07_magic.py index 489b9774..509712f2 100755 --- a/examples/07_magic.py +++ b/examples/07_magic.py @@ -5,9 +5,7 @@ ex = Experiment("svm") -ex.observers.append( - FileStorageObserver.create("my_runs") -) +ex.observers.append(FileStorageObserver.create("my_runs")) @ex.config # Configuration is defined through local variables. @@ -26,7 +24,9 @@ def get_model(C, gamma, kernel): @ex.automain # Using automain to enable command line integration. 
def run(): X, y = datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2) + X_train, X_test, y_train, y_test = model_selection.train_test_split( + X, y, test_size=0.2 + ) clf = get_model() # Parameters are injected automatically. clf.fit(X_train, y_train) return clf.score(X_test, y_test) diff --git a/examples/08_less_magic.py b/examples/08_less_magic.py index 02e572be..5e2eed3b 100755 --- a/examples/08_less_magic.py +++ b/examples/08_less_magic.py @@ -5,15 +5,15 @@ ex = Experiment("svm") -ex.observers.append( - FileStorageObserver.create("my_runs") +ex.observers.append(FileStorageObserver.create("my_runs")) +ex.add_config( + { # Configuration is explicitly defined as dictionary. + "C": 1.0, + "gamma": 0.7, + "kernel": "rbf", + "seed": 42, + } ) -ex.add_config({ # Configuration is explicitly defined as dictionary. - "C": 1.0, - "gamma": 0.7, - "kernel": "rbf", - "seed": 42 -}) def get_model(C, gamma, kernel): @@ -23,8 +23,12 @@ def get_model(C, gamma, kernel): @ex.main # Using main, command-line arguments will not be interpreted in any special way. def run(_config): X, y = datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2) - clf = get_model(_config["C"], _config["gamma"], _config["kernel"]) # Parameters are passed explicitly. + X_train, X_test, y_train, y_test = model_selection.train_test_split( + X, y, test_size=0.2 + ) + clf = get_model( + _config["C"], _config["gamma"], _config["kernel"] + ) # Parameters are passed explicitly. 
clf.fit(X_train, y_train) return clf.score(X_test, y_test) diff --git a/examples/captured_out_filter.py b/examples/captured_out_filter.py index e53f656b..8c4e65dd 100755 --- a/examples/captured_out_filter.py +++ b/examples/captured_out_filter.py @@ -13,7 +13,7 @@ from sacred import Experiment from sacred.utils import apply_backspaces_and_linefeeds -ex = Experiment('progress') +ex = Experiment("progress") # try commenting out the line below to see the difference in captured output ex.captured_out_filter = apply_backspaces_and_linefeeds @@ -31,11 +31,11 @@ def __init__(self, count): def show(self, n=1): self.progress += n - text = 'Completed {}/{} tasks'.format(self.progress, self.count) - write_and_flush('\b' * 80, '\r', text) + text = "Completed {}/{} tasks".format(self.progress, self.count) + write_and_flush("\b" * 80, "\r", text) def done(self): - write_and_flush('\n') + write_and_flush("\n") def progress(items): @@ -52,7 +52,7 @@ def main(): time.sleep(0.05) -if __name__ == '__main__': +if __name__ == "__main__": run = ex.run_commandline() - print('=' * 80) - print('Captured output: ', repr(run.captured_out)) + print("=" * 80) + print("Captured output: ", repr(run.captured_out)) diff --git a/examples/ingredient.py b/examples/ingredient.py index 0300cba6..c0fecae8 100755 --- a/examples/ingredient.py +++ b/examples/ingredient.py @@ -6,12 +6,12 @@ # ================== Dataset Ingredient ======================================= # could be in a separate file -data_ingredient = Ingredient('dataset') +data_ingredient = Ingredient("dataset") @data_ingredient.config def cfg1(): - filename = 'my_dataset.npy' # dataset filename + filename = "my_dataset.npy" # dataset filename normalize = True # normalize dataset @@ -27,30 +27,35 @@ def load_data(filename, normalize): @data_ingredient.command def stats(filename, foo=12): print('Statistics for dataset "{}":'.format(filename)) - print('mean = 42.23') - print('foo=', foo) + print("mean = 42.23") + print("foo=", foo) # 
================== Experiment =============================================== + @data_ingredient.config def cfg2(): - filename = 'foo.npy' + filename = "foo.npy" + # add the Ingredient while creating the experiment -ex = Experiment('my_experiment', ingredients=[data_ingredient]) +ex = Experiment("my_experiment", ingredients=[data_ingredient]) + @ex.config def cfg3(): a = 12 b = 42 + @ex.named_config def fbb(): a = 22 dataset = {"filename": "AwwwJiss.py"} + @ex.automain def run(): data = load_data() # just use the function - print('data={}'.format(data)) + print("data={}".format(data)) diff --git a/examples/log_example.py b/examples/log_example.py index e784f9ac..9246ded9 100755 --- a/examples/log_example.py +++ b/examples/log_example.py @@ -5,16 +5,16 @@ import logging from sacred import Experiment -ex = Experiment('log_example') +ex = Experiment("log_example") # set up a custom logger -logger = logging.getLogger('mylogger') +logger = logging.getLogger("mylogger") logger.handlers = [] ch = logging.StreamHandler() formatter = logging.Formatter('[%(levelname).1s] %(name)s >> "%(message)s"') ch.setFormatter(formatter) logger.addHandler(ch) -logger.setLevel('INFO') +logger.setLevel("INFO") # attach it to the experiment ex.logger = logger @@ -38,7 +38,7 @@ def transmogrify(got_gizmo, number, _log): @ex.automain def main(number, _log): - _log.info('Attempting to transmogrify %d...', number) + _log.info("Attempting to transmogrify %d...", number) result = transmogrify() - _log.info('Transmogrification complete: %d', result) + _log.info("Transmogrification complete: %d", result) return result diff --git a/examples/modular.py b/examples/modular.py index 2a503d6b..964b325a 100755 --- a/examples/modular.py +++ b/examples/modular.py @@ -21,8 +21,8 @@ def cfg1(): @data_paths.config def cfg2(settings): - v = not settings['verbose'] - base = '/home/sacred/' + v = not settings["verbose"] + base = "/home/sacred/" # ============== Ingredient 2: dataset ======================= @@ 
-31,7 +31,7 @@ def cfg2(settings): @data.config def cfg3(paths): - basepath = paths['base'] + 'datasets/' + basepath = paths["base"] + "datasets/" filename = "foo.hdf5" @@ -43,7 +43,7 @@ def foo(basepath, filename, paths, settings): # ============== Experiment ============================== -ex = Experiment('modular_example', ingredients=[data, data_paths]) +ex = Experiment("modular_example", ingredients=[data, data_paths]) @ex.config @@ -51,16 +51,16 @@ def cfg(dataset): a = 10 b = 17 c = a + b - out_base = dataset['paths']['base'] + 'outputs/' - out_filename = dataset['filename'].replace('.hdf5', '.out') + out_base = dataset["paths"]["base"] + "outputs/" + out_filename = dataset["filename"].replace(".hdf5", ".out") @ex.automain def main(a, b, c, out_base, out_filename, dataset): - print('a =', a) - print('b =', b) - print('c =', c) - print('out_base =', out_base, out_filename) + print("a =", a) + print("b =", b) + print("c =", c) + print("out_base =", out_base, out_filename) # print("dataset", dataset) # print("dataset.paths", dataset['paths']) print("foo()", foo()) diff --git a/examples/named_config.py b/examples/named_config.py index cfd4a2c0..a9f75c77 100755 --- a/examples/named_config.py +++ b/examples/named_config.py @@ -4,7 +4,7 @@ from sacred import Experiment -ex = Experiment('hello_config') +ex = Experiment("hello_config") @ex.named_config diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b54bec17 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[tool.black] +target-version = ['py35'] +include = '\.pyi?$' +exclude = ''' +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ +) +''' \ No newline at end of file diff --git a/sacred/__about__.py b/sacred/__about__.py index c9637434..815ca902 100644 --- a/sacred/__about__.py +++ b/sacred/__about__.py @@ -12,7 +12,7 @@ __version__ = "0.7.5" 
-__author__ = 'Klaus Greff' -__author_email__ = 'klaus.greff@startmail.com' +__author__ = "Klaus Greff" +__author_email__ = "klaus.greff@startmail.com" __url__ = "https://github.com/IDSIA/sacred" diff --git a/sacred/__init__.py b/sacred/__init__.py index f5236cc9..57bd2412 100644 --- a/sacred/__init__.py +++ b/sacred/__init__.py @@ -14,6 +14,14 @@ from sacred.host_info import host_info_getter -__all__ = ('Experiment', 'Ingredient', 'observers', 'host_info_getter', - '__version__', '__author__', '__author_email__', '__url__', - 'SETTINGS') +__all__ = ( + "Experiment", + "Ingredient", + "observers", + "host_info_getter", + "__version__", + "__author__", + "__author_email__", + "__url__", + "SETTINGS", +) diff --git a/sacred/arg_parser.py b/sacred/arg_parser.py index 5c4dde47..283aed95 100644 --- a/sacred/arg_parser.py +++ b/sacred/arg_parser.py @@ -19,7 +19,7 @@ from sacred.utils import set_by_dotted_path -__all__ = ('get_config_updates', 'format_usage') +__all__ = ("get_config_updates", "format_usage") USAGE_TEMPLATE = """Usage: @@ -60,11 +60,11 @@ def get_config_updates(updates): if not updates: return config_updates, named_configs for upd in updates: - if upd == '': + if upd == "": continue - path, sep, value = upd.partition('=') - if sep == '=': - path = path.strip() # get rid of surrounding whitespace + path, sep, value = upd.partition("=") + if sep == "=": + path = path.strip() # get rid of surrounding whitespace value = value.strip() # get rid of surrounding whitespace set_by_dotted_path(config_updates, path, _convert_value(value)) else: @@ -92,14 +92,17 @@ def _format_options_usage(options): short, long = op.get_flags() if op.arg: flag = "{short} {arg} {long}={arg}".format( - short=short, long=long, arg=op.arg) + short=short, long=long, arg=op.arg + ) else: flag = "{short} {long}".format(short=short, long=long) - wrapped_description = textwrap.wrap(inspect.cleandoc(op.__doc__), - width=79, - initial_indent=' ' * 32, - subsequent_indent=' ' * 32) + 
wrapped_description = textwrap.wrap( + inspect.cleandoc(op.__doc__), + width=79, + initial_indent=" " * 32, + subsequent_indent=" " * 32, + ) wrapped_description = "\n".join(wrapped_description).strip() options_usage += " {:28} {}\n".format(flag, wrapped_description) @@ -125,13 +128,14 @@ def _format_arguments_usage(options): argument_usage = "" for op in options: if op.arg and op.arg_description: - wrapped_description = textwrap.wrap(op.arg_description, - width=79, - initial_indent=' ' * 12, - subsequent_indent=' ' * 12) + wrapped_description = textwrap.wrap( + op.arg_description, + width=79, + initial_indent=" " * 12, + subsequent_indent=" " * 12, + ) wrapped_description = "\n".join(wrapped_description).strip() - argument_usage += " {:8} {}\n".format(op.arg, - wrapped_description) + argument_usage += " {:8} {}\n".format(op.arg, wrapped_description) return argument_usage @@ -156,8 +160,11 @@ def _format_command_usage(commands): command_usage = "\nCommands:\n" cmd_len = max([len(c) for c in commands] + [8]) command_doc = OrderedDict( - [(cmd_name, _get_first_line_of_docstring(cmd_doc)) - for cmd_name, cmd_doc in commands.items()]) + [ + (cmd_name, _get_first_line_of_docstring(cmd_doc)) + for cmd_name, cmd_doc in commands.items() + ] + ) for cmd_name, cmd_doc in command_doc.items(): command_usage += (" {:%d} {}\n" % cmd_len).format(cmd_name, cmd_doc) return command_usage @@ -188,16 +195,16 @@ def format_usage(program_name, description, commands=None, options=()): """ usage = USAGE_TEMPLATE.format( program_name=quote(program_name), - description=description.strip() if description else '', + description=description.strip() if description else "", options=_format_options_usage(options), arguments=_format_arguments_usage(options), - commands=_format_command_usage(commands) + commands=_format_command_usage(commands), ) return usage def _get_first_line_of_docstring(func): - return textwrap.dedent(func.__doc__ or "").strip().split('\n')[0] + return 
textwrap.dedent(func.__doc__ or "").strip().split("\n")[0] def _convert_value(value): diff --git a/sacred/commandline_options.py b/sacred/commandline_options.py index 8b330af0..303426d5 100644 --- a/sacred/commandline_options.py +++ b/sacred/commandline_options.py @@ -50,14 +50,14 @@ def get_flag(cls): flag = cls.__name__ if flag.endswith("Option"): flag = flag[:-6] - return '--' + convert_camel_case_to_snake_case(flag) + return "--" + convert_camel_case_to_snake_case(flag) @classmethod def get_short_flag(cls): if cls.short_flag is None: - return '-' + cls.get_flag()[2] + return "-" + cls.get_flag()[2] else: - return '-' + cls.short_flag + return "-" + cls.short_flag @classmethod def get_flags(cls): @@ -105,8 +105,11 @@ def gather_command_line_options(filter_disabled=None): """Get a sorted list of all CommandLineOption subclasses.""" if filter_disabled is None: filter_disabled = not SETTINGS.COMMAND_LINE.SHOW_DISABLED_OPTIONS - options = [opt for opt in get_inheritors(CommandLineOption) - if not filter_disabled or opt._enabled] + options = [ + opt + for opt in get_inheritors(CommandLineOption) + if not filter_disabled or opt._enabled + ] return sorted(options, key=lambda opt: opt.__name__) @@ -130,7 +133,7 @@ def apply(cls, args, run): class PDBOption(CommandLineOption): """Automatically enter post-mortem debugging with pdb on failure.""" - short_flag = 'D' + short_flag = "D" @classmethod def apply(cls, args, run): @@ -140,9 +143,11 @@ def apply(cls, args, run): class LoglevelOption(CommandLineOption): """Adjust the loglevel.""" - arg = 'LEVEL' - arg_description = 'Loglevel either as 0 - 50 or as string: DEBUG(10), ' \ - 'INFO(20), WARNING(30), ERROR(40), CRITICAL(50)' + arg = "LEVEL" + arg_description = ( + "Loglevel either as 0 - 50 or as string: DEBUG(10), " + "INFO(20), WARNING(30), ERROR(40), CRITICAL(50)" + ) @classmethod def apply(cls, args, run): @@ -159,19 +164,19 @@ def apply(cls, args, run): class CommentOption(CommandLineOption): """Adds a message to 
the run.""" - arg = 'COMMENT' - arg_description = 'A comment that should be stored along with the run.' + arg = "COMMENT" + arg_description = "A comment that should be stored along with the run." @classmethod def apply(cls, args, run): """Add a comment to this run.""" - run.meta_info['comment'] = args + run.meta_info["comment"] = args class BeatIntervalOption(CommandLineOption): """Control the rate of heartbeat events.""" - arg = 'BEAT_INTERVAL' + arg = "BEAT_INTERVAL" arg_description = "Time between two heartbeat events measured in seconds." @classmethod @@ -210,9 +215,9 @@ def apply(cls, args, run): class PriorityOption(CommandLineOption): """Sets the priority for a queued up experiment.""" - short_flag = 'P' - arg = 'PRIORITY' - arg_description = 'The (numeric) priority for this run.' + short_flag = "P" + arg = "PRIORITY" + arg_description = "The (numeric) priority for this run." @classmethod def apply(cls, args, run): @@ -220,9 +225,10 @@ def apply(cls, args, run): try: priority = float(args) except ValueError: - raise ValueError("The PRIORITY argument must be a number! " - "(but was '{}')".format(args)) - run.meta_info['priority'] = priority + raise ValueError( + "The PRIORITY argument must be a number! " "(but was '{}')".format(args) + ) + run.meta_info["priority"] = priority class EnforceCleanOption(CommandLineOption): @@ -233,20 +239,25 @@ def apply(cls, args, run): try: import git # NOQA except ImportError: - warnings.warn('GitPython must be installed to use the ' - '--enforce-clean option.') + warnings.warn( + "GitPython must be installed to use the " "--enforce-clean option." + ) raise - repos = run.experiment_info['repositories'] + repos = run.experiment_info["repositories"] if not repos: - raise RuntimeError('No version control detected. ' - 'Cannot enforce clean repository.\n' - 'Make sure that your sources under VCS and the ' - 'corresponding python package is installed.') + raise RuntimeError( + "No version control detected. 
" + "Cannot enforce clean repository.\n" + "Make sure that your sources under VCS and the " + "corresponding python package is installed." + ) else: for repo in repos: - if repo['dirty']: - raise RuntimeError('EnforceClean: Uncommited changes in ' - 'the "{}" repository.'.format(repo)) + if repo["dirty"]: + raise RuntimeError( + "EnforceClean: Uncommitted changes in " + 'the "{}" repository.'.format(repo) + ) class PrintConfigOption(CommandLineOption): @@ -255,26 +266,26 @@ class PrintConfigOption(CommandLineOption): @classmethod def apply(cls, args, run): print_config(run) - print('-' * 79) + print("-" * 79) class NameOption(CommandLineOption): """Set the name for this run.""" - arg = 'NAME' - arg_description = 'Name for this run.' + arg = "NAME" + arg_description = "Name for this run." @classmethod def apply(cls, args, run): - run.experiment_info['name'] = args + run.experiment_info["name"] = args run.run_logger = run.root_logger.getChild(args) class CaptureOption(CommandLineOption): """Control the way stdout and stderr are captured.""" - short_flag = 'C' - arg = 'CAPTURE_MODE' + short_flag = "C" + arg = "CAPTURE_MODE" arg_description = "stdout/stderr capture mode. 
One of [no, sys, fd]" @classmethod diff --git a/sacred/commands.py b/sacred/commands.py index 52bac3da..d2c69f92 100644 --- a/sacred/commands.py +++ b/sacred/commands.py @@ -14,8 +14,13 @@ from sacred.serializer import flatten from sacred.utils import PATHCHANGE, iterate_flattened_separately -__all__ = ('print_config', 'print_dependencies', 'save_config', - 'help_for_command', 'print_named_configs') +__all__ = ( + "print_config", + "print_dependencies", + "save_config", + "help_for_command", + "print_named_configs", +) COLOR_DIRTY = Fore.RED COLOR_TYPECHANGED = Fore.RED # prepend Style.BRIGHT for bold @@ -24,15 +29,28 @@ COLOR_DOC = Style.DIM ENDC = Style.RESET_ALL # '\033[0m' -LEGEND = \ - '(' + COLOR_MODIFIED + 'modified' + ENDC +\ - ', ' + COLOR_ADDED + 'added' + ENDC +\ - ', ' + COLOR_TYPECHANGED + 'typechanged' + ENDC +\ - ', ' + COLOR_DOC + 'doc' + ENDC + ')' - -ConfigEntry = namedtuple('ConfigEntry', - 'key value added modified typechanged doc') -PathEntry = namedtuple('PathEntry', 'key added modified typechanged doc') +LEGEND = ( + "(" + + COLOR_MODIFIED + + "modified" + + ENDC + + ", " + + COLOR_ADDED + + "added" + + ENDC + + ", " + + COLOR_TYPECHANGED + + "typechanged" + + ENDC + + ", " + + COLOR_DOC + + "doc" + + ENDC + + ")" +) + +ConfigEntry = namedtuple("ConfigEntry", "key value added modified typechanged doc") +PathEntry = namedtuple("PathEntry", "key added modified typechanged doc") def _non_unicode_repr(objekt, context, maxlevels, level): @@ -41,8 +59,9 @@ def _non_unicode_repr(objekt, context, maxlevels, level): E.g.: 'John' instead of u'John'. 
""" - repr_string, isreadable, isrecursive = pprint._safe_repr(objekt, context, - maxlevels, level) + repr_string, isreadable, isrecursive = pprint._safe_repr( + objekt, context, maxlevels, level + ) if repr_string.startswith('u"') or repr_string.startswith("u'"): repr_string = repr_string[1:] return repr_string, isreadable, isrecursive @@ -67,26 +86,27 @@ def print_config(_run): def _format_named_config(indent, path, named_config): - indent = ' ' * indent + indent = " " * indent assign = path - if hasattr(named_config, '__doc__') and named_config.__doc__ is not None: + if hasattr(named_config, "__doc__") and named_config.__doc__ is not None: doc_string = named_config.__doc__ - if doc_string.strip().count('\n') == 0: - assign += COLOR_DOC + ' # {}'.format(doc_string.strip()) + ENDC + if doc_string.strip().count("\n") == 0: + assign += COLOR_DOC + " # {}".format(doc_string.strip()) + ENDC else: - doc_string = doc_string.replace('\n', '\n' + indent) - assign += COLOR_DOC + '\n{}"""{}"""'.format(indent + ' ', - doc_string) + ENDC + doc_string = doc_string.replace("\n", "\n" + indent) + assign += ( + COLOR_DOC + '\n{}"""{}"""'.format(indent + " ", doc_string) + ENDC + ) return indent + assign def _format_named_configs(named_configs, indent=2): - lines = ['Named Configurations (' + COLOR_DOC + 'doc' + ENDC + '):'] + lines = ["Named Configurations (" + COLOR_DOC + "doc" + ENDC + "):"] for path, named_config in named_configs.items(): lines.append(_format_named_config(indent, path, named_config)) if len(lines) < 2: - lines.append(' ' * indent + 'No named configs') - return '\n'.join(lines) + lines.append(" " * indent + "No named configs") + return "\n".join(lines) def print_named_configs(ingredient): @@ -111,99 +131,100 @@ def help_for_command(command): """Get the help text (signature + docstring) for a command (function).""" help_text = pydoc.text.document(command) # remove backspaces - return re.subn('.\\x08', '', help_text)[0] + return re.subn(".\\x08", "", 
help_text)[0] def print_dependencies(_run): """Print the detected source-files and dependencies.""" - print('Dependencies:') - for dep in _run.experiment_info['dependencies']: - pack, _, version = dep.partition('==') - print(' {:<20} == {}'.format(pack, version)) - - print('\nSources:') - for source, digest in _run.experiment_info['sources']: - print(' {:<43} {}'.format(source, digest)) - - if _run.experiment_info['repositories']: - repos = _run.experiment_info['repositories'] - print('\nVersion Control:') + print("Dependencies:") + for dep in _run.experiment_info["dependencies"]: + pack, _, version = dep.partition("==") + print(" {:<20} == {}".format(pack, version)) + + print("\nSources:") + for source, digest in _run.experiment_info["sources"]: + print(" {:<43} {}".format(source, digest)) + + if _run.experiment_info["repositories"]: + repos = _run.experiment_info["repositories"] + print("\nVersion Control:") for repo in repos: - mod = COLOR_DIRTY + 'M' if repo['dirty'] else ' ' - print('{} {:<43} {}'.format(mod, repo['url'], repo['commit']) + - ENDC) - print('') + mod = COLOR_DIRTY + "M" if repo["dirty"] else " " + print("{} {:<43} {}".format(mod, repo["url"], repo["commit"]) + ENDC) + print("") -def save_config(_config, _log, config_filename='config.json'): +def save_config(_config, _log, config_filename="config.json"): """ Store the updated configuration in a file. By default uses the filename "config.json", but that can be changed by setting the config_filename config entry. 
""" - if 'config_filename' in _config: - del _config['config_filename'] + if "config_filename" in _config: + del _config["config_filename"] _log.info('Saving config to "{}"'.format(config_filename)) save_config_file(flatten(_config), config_filename) def _iterate_marked(cfg, config_mods): - for path, value in iterate_flattened_separately(cfg, ['__doc__']): + for path, value in iterate_flattened_separately(cfg, ["__doc__"]): if value is PATHCHANGE: yield path, PathEntry( - key=path.rpartition('.')[2], + key=path.rpartition(".")[2], added=path in config_mods.added, modified=path in config_mods.modified, typechanged=config_mods.typechanged.get(path), - doc=config_mods.docs.get(path)) + doc=config_mods.docs.get(path), + ) else: yield path, ConfigEntry( - key=path.rpartition('.')[2], + key=path.rpartition(".")[2], value=value, added=path in config_mods.added, modified=path in config_mods.modified, typechanged=config_mods.typechanged.get(path), - doc=config_mods.docs.get(path)) + doc=config_mods.docs.get(path), + ) def _format_entry(indent, entry): color = "" - indent = ' ' * indent + indent = " " * indent if entry.typechanged: color = COLOR_TYPECHANGED # red elif entry.added: color = COLOR_ADDED # green elif entry.modified: color = COLOR_MODIFIED # blue - if entry.key == '__doc__': + if entry.key == "__doc__": color = COLOR_DOC # grey - doc_string = entry.value.replace('\n', '\n' + indent) + doc_string = entry.value.replace("\n", "\n" + indent) assign = '{}"""{}"""'.format(indent, doc_string) elif isinstance(entry, ConfigEntry): assign = indent + entry.key + " = " + PRINTER.pformat(entry.value) else: # isinstance(entry, PathEntry): assign = indent + entry.key + ":" if entry.doc: - doc_string = COLOR_DOC + '# ' + entry.doc + ENDC + doc_string = COLOR_DOC + "# " + entry.doc + ENDC if len(assign) <= 35: assign = "{:<35} {}".format(assign, doc_string) else: - assign += ' ' + doc_string + assign += " " + doc_string end = ENDC if color else "" return color + assign + end def 
_format_config(cfg, config_mods): - lines = ['Configuration ' + LEGEND + ':'] + lines = ["Configuration " + LEGEND + ":"] for path, entry in _iterate_marked(cfg, config_mods): - indent = 2 + 2 * path.count('.') + indent = 2 + 2 * path.count(".") lines.append(_format_entry(indent, entry)) return "\n".join(lines) -def _write_file(base_dir, filename, content, mode='t'): +def _write_file(base_dir, filename, content, mode="t"): full_name = os.path.join(base_dir, filename) os.makedirs(os.path.dirname(full_name), exist_ok=True) - with open(full_name, 'w' + mode) as f: + with open(full_name, "w" + mode) as f: f.write(content) diff --git a/sacred/config/__init__.py b/sacred/config/__init__.py index 19a46fac..f287c2c0 100644 --- a/sacred/config/__init__.py +++ b/sacred/config/__init__.py @@ -5,9 +5,15 @@ from sacred.config.config_scope import ConfigScope from sacred.config.config_files import load_config_file, save_config_file from sacred.config.captured_function import create_captured_function -from sacred.config.utils import ( - chain_evaluate_config_scopes, dogmatize, undogmatize) +from sacred.config.utils import chain_evaluate_config_scopes, dogmatize, undogmatize -__all__ = ('ConfigDict', 'ConfigScope', 'load_config_file', 'save_config_file', - 'create_captured_function', 'chain_evaluate_config_scopes', - 'dogmatize', 'undogmatize') +__all__ = ( + "ConfigDict", + "ConfigScope", + "load_config_file", + "save_config_file", + "create_captured_function", + "chain_evaluate_config_scopes", + "dogmatize", + "undogmatize", +) diff --git a/sacred/config/captured_function.py b/sacred/config/captured_function.py index b60cd4dd..780f9ebf 100644 --- a/sacred/config/captured_function.py +++ b/sacred/config/captured_function.py @@ -14,8 +14,7 @@ def create_captured_function(function, prefix=None): sig = Signature(function) function.signature = sig - function.uses_randomness = ("_seed" in sig.arguments or - "_rnd" in sig.arguments) + function.uses_randomness = "_seed" in sig.arguments 
or "_rnd" in sig.arguments function.logger = None function.config = {} function.rnd = None @@ -27,18 +26,14 @@ def create_captured_function(function, prefix=None): @wrapt.decorator def captured_function(wrapped, instance, args, kwargs): options = fallback_dict( - wrapped.config, - _config=wrapped.config, - _log=wrapped.logger, - _run=wrapped.run + wrapped.config, _config=wrapped.config, _log=wrapped.logger, _run=wrapped.run ) if wrapped.uses_randomness: # only generate _seed and _rnd if needed - options['_seed'] = get_seed(wrapped.rnd) - options['_rnd'] = create_rnd(options['_seed']) + options["_seed"] = get_seed(wrapped.rnd) + options["_rnd"] = create_rnd(options["_seed"]) - bound = (instance is not None) - args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options, - bound) + bound = instance is not None + args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options, bound) if wrapped.logger is not None: wrapped.logger.debug("Started") start_time = time.time() diff --git a/sacred/config/config_dict.py b/sacred/config/config_dict.py index 0942c735..c73d83ff 100644 --- a/sacred/config/config_dict.py +++ b/sacred/config/config_dict.py @@ -2,8 +2,12 @@ # coding=utf-8 from sacred.config.config_summary import ConfigSummary -from sacred.config.utils import (dogmatize, normalize_or_die, undogmatize, - recursive_fill_in) +from sacred.config.utils import ( + dogmatize, + normalize_or_die, + undogmatize, + recursive_fill_in, +) class ConfigDict: @@ -15,7 +19,6 @@ def __call__(self, fixed=None, preset=None, fallback=None): recursive_fill_in(result, self._conf) recursive_fill_in(result, preset or {}) added = result.revelation() - config_summary = ConfigSummary(added, result.modified, - result.typechanges) + config_summary = ConfigSummary(added, result.modified, result.typechanges) config_summary.update(undogmatize(result)) return config_summary diff --git a/sacred/config/config_files.py b/sacred/config/config_files.py index f6e637c6..fd3f4cac 
100644 --- a/sacred/config/config_files.py +++ b/sacred/config/config_files.py @@ -9,7 +9,7 @@ import sacred.optional as opt from sacred.serializer import flatten, restore -__all__ = ('load_config_file', 'save_config_file') +__all__ = ("load_config_file", "save_config_file") class Handler: @@ -20,17 +20,21 @@ def __init__(self, load, dump, mode): HANDLER_BY_EXT = { - '.json': Handler(lambda fp: restore(json.load(fp)), - lambda obj, fp: json.dump(flatten(obj), fp, - sort_keys=True, indent=2), ''), - '.pickle': Handler(pickle.load, pickle.dump, 'b'), + ".json": Handler( + lambda fp: restore(json.load(fp)), + lambda obj, fp: json.dump(flatten(obj), fp, sort_keys=True, indent=2), + "", + ), + ".pickle": Handler(pickle.load, pickle.dump, "b"), } -yaml_extensions = ('.yaml', '.yml') +yaml_extensions = (".yaml", ".yml") if opt.has_yaml: + def load_yaml(filename): return opt.yaml.load(filename, Loader=opt.yaml.FullLoader) - yaml_handler = Handler(load_yaml, opt.yaml.dump, '') + + yaml_handler = Handler(load_yaml, opt.yaml.dump, "") for extension in yaml_extensions: HANDLER_BY_EXT[extension] = yaml_handler @@ -39,8 +43,10 @@ def load_yaml(filename): def get_handler(filename): _, extension = os.path.splitext(filename) if extension in yaml_extensions and not opt.has_yaml: - raise KeyError('Configuration file "{}" cannot be loaded as ' - 'you do not have PyYAML installed.'.format(filename)) + raise KeyError( + 'Configuration file "{}" cannot be loaded as ' + "you do not have PyYAML installed.".format(filename) + ) try: return HANDLER_BY_EXT[extension] except KeyError: @@ -52,11 +58,11 @@ def get_handler(filename): def load_config_file(filename): handler = get_handler(filename) - with open(filename, 'r' + handler.mode) as f: + with open(filename, "r" + handler.mode) as f: return handler.load(f) def save_config_file(config, filename): handler = get_handler(filename) - with open(filename, 'w' + handler.mode) as f: + with open(filename, "w" + handler.mode) as f: 
handler.dump(config, f) diff --git a/sacred/config/config_scope.py b/sacred/config/config_scope.py index 721cfa2a..0441db71 100644 --- a/sacred/config/config_scope.py +++ b/sacred/config/config_scope.py @@ -17,12 +17,9 @@ class ConfigScope: def __init__(self, func): self.args, vararg_name, kw_wildcard, _, kwargs = get_argspec(func) - assert vararg_name is None, \ - "*args not allowed for ConfigScope functions" - assert kw_wildcard is None, \ - "**kwargs not allowed for ConfigScope functions" - assert not kwargs, \ - "default values are not allowed for ConfigScope functions" + assert vararg_name is None, "*args not allowed for ConfigScope functions" + assert kw_wildcard is None, "**kwargs not allowed for ConfigScope functions" + assert not kwargs, "default values are not allowed for ConfigScope functions" self._func = func self._body_code = get_function_body_code(func) @@ -61,9 +58,10 @@ def __call__(self, fixed=None, preset=None, fallback=None): for arg in self.args: if arg not in available_entries: - raise KeyError("'{}' not in preset for ConfigScope. " - "Available options are: {}" - .format(arg, available_entries)) + raise KeyError( + "'{}' not in preset for ConfigScope. 
" + "Available options are: {}".format(arg, available_entries) + ) if arg in preset: cfg_locals[arg] = preset[arg] else: # arg in fallback @@ -73,10 +71,13 @@ def __call__(self, fixed=None, preset=None, fallback=None): eval(self._body_code, copy(self._func.__globals__), cfg_locals) added = cfg_locals.revelation() - config_summary = ConfigSummary(added, cfg_locals.modified, - cfg_locals.typechanges, - cfg_locals.fallback_writes, - docs=self._var_docs) + config_summary = ConfigSummary( + added, + cfg_locals.modified, + cfg_locals.typechanges, + cfg_locals.fallback_writes, + docs=self._var_docs, + ) # fill in the unused presets recursive_fill_in(cfg_locals, preset) @@ -90,26 +91,29 @@ def __call__(self, fixed=None, preset=None, fallback=None): def get_function_body(func): func_code_lines, start_idx = inspect.getsourcelines(func) - func_code = ''.join(func_code_lines) + func_code = "".join(func_code_lines) arg = "(?:[a-zA-Z_][a-zA-Z0-9_]*)" arguments = r"{0}(?:\s*,\s*{0})*".format(arg) func_def = re.compile( r"^[ \t]*def[ \t]*{}[ \t]*\(\s*({})?\s*\)[ \t]*:[ \t]*\n".format( - func.__name__, arguments), flags=re.MULTILINE) + func.__name__, arguments + ), + flags=re.MULTILINE, + ) defs = list(re.finditer(func_def, func_code)) assert defs - line_offset = start_idx + func_code[:defs[0].end()].count('\n') - 1 - func_body = func_code[defs[0].end():] + line_offset = start_idx + func_code[: defs[0].end()].count("\n") - 1 + func_body = func_code[defs[0].end() :] return func_body, line_offset def is_empty_or_comment(line): sline = line.strip() - return sline == '' or sline.startswith('#') + return sline == "" or sline.startswith("#") def iscomment(line): - return line.strip().startswith('#') + return line.strip().startswith("#") def dedent_line(line, indent): @@ -123,18 +127,18 @@ def dedent_line(line, indent): def dedent_function_body(body): - lines = body.split('\n') + lines = body.split("\n") # find indentation by first line - indent = '' + indent = "" for line in lines: if 
is_empty_or_comment(line): continue else: - indent = re.match(r'^\s*', line).group() + indent = re.match(r"^\s*", line).group() break out_lines = [dedent_line(line, indent) for line in lines] - return '\n'.join(out_lines) + return "\n".join(out_lines) def get_function_body_code(func): @@ -148,14 +152,20 @@ def get_function_body_code(func): except SyntaxError as e: if e.args[0] == "'return' outside function": filename, lineno, _, statement = e.args[1] - raise SyntaxError('No return statements allowed in ConfigScopes\n' - '(\'{}\' in File "{}", line {})'.format( - statement.strip(), filename, lineno)) + raise SyntaxError( + "No return statements allowed in ConfigScopes\n" + "('{}' in File \"{}\", line {})".format( + statement.strip(), filename, lineno + ) + ) elif e.args[0] == "'yield' outside function": filename, lineno, _, statement = e.args[1] - raise SyntaxError('No yield statements allowed in ConfigScopes\n' - '(\'{}\' in File "{}", line {})'.format( - statement.strip(), filename, lineno)) + raise SyntaxError( + "No yield statements allowed in ConfigScopes\n" + "('{}' in File \"{}\", line {})".format( + statement.strip(), filename, lineno + ) + ) else: raise return body_code @@ -186,10 +196,10 @@ def find_doc_for(ast_entry, body_lines): lineno -= 1 while lineno >= 0: if iscomment(body_lines[lineno]): - comment = body_lines[lineno].strip('# ') + comment = body_lines[lineno].strip("# ") if not is_ignored(comment): return comment - if not body_lines[lineno].strip() == '': + if not body_lines[lineno].strip() == "": return None lineno -= 1 return None @@ -216,9 +226,9 @@ def get_config_comments(func): func_body, line_offset = get_function_body(func) body_source = dedent_function_body(func_body) body_code = compile(body_source, filename, "exec", ast.PyCF_ONLY_AST) - body_lines = body_source.split('\n') + body_lines = body_source.split("\n") - variables = {'seed': 'the random seed for this experiment'} + variables = {"seed": "the random seed for this experiment"} for 
ast_root in body_code.body: for ast_entry in [ast_root] + list(ast.iter_child_nodes(ast_root)): diff --git a/sacred/config/config_summary.py b/sacred/config/config_summary.py index 09f7f432..7bbf8e9d 100644 --- a/sacred/config/config_summary.py +++ b/sacred/config/config_summary.py @@ -5,8 +5,9 @@ class ConfigSummary(dict): - def __init__(self, added=(), modified=(), typechanged=(), - ignored_fallbacks=(), docs=()): + def __init__( + self, added=(), modified=(), typechanged=(), ignored_fallbacks=(), docs=() + ): super().__init__() self.added = set(added) self.modified = set(modified) # TODO: test for this member @@ -15,38 +16,43 @@ def __init__(self, added=(), modified=(), typechanged=(), self.docs = dict(docs) self.ensure_coherence() - def update_from(self, config_mod, path=''): + def update_from(self, config_mod, path=""): added = config_mod.added updated = config_mod.modified typechanged = config_mod.typechanged self.added &= {join_paths(path, a) for a in added} self.modified |= {join_paths(path, u) for u in updated} - self.typechanged.update({join_paths(path, k): v - for k, v in typechanged.items()}) + self.typechanged.update( + {join_paths(path, k): v for k, v in typechanged.items()} + ) self.ensure_coherence() for k, v in config_mod.docs.items(): - if not self.docs.get(k, ''): + if not self.docs.get(k, ""): self.docs[k] = v - def update_add(self, config_mod, path=''): + def update_add(self, config_mod, path=""): added = config_mod.added updated = config_mod.modified typechanged = config_mod.typechanged self.added |= {join_paths(path, a) for a in added} self.modified |= {join_paths(path, u) for u in updated} - self.typechanged.update({join_paths(path, k): v - for k, v in typechanged.items()}) - self.docs.update({join_paths(path, k): v - for k, v in config_mod.docs.items() - if path == '' or k != 'seed'}) + self.typechanged.update( + {join_paths(path, k): v for k, v in typechanged.items()} + ) + self.docs.update( + { + join_paths(path, k): v + for k, v in 
config_mod.docs.items() + if path == "" or k != "seed" + } + ) self.ensure_coherence() def ensure_coherence(self): # make sure parent paths show up as updated appropriately self.modified |= {p for a in self.added for p in iter_prefixes(a)} self.modified |= {p for u in self.modified for p in iter_prefixes(u)} - self.modified |= {p for t in self.typechanged - for p in iter_prefixes(t)} + self.modified |= {p for t in self.typechanged for p in iter_prefixes(t)} # make sure there is no overlap self.added -= set(self.typechanged.keys()) diff --git a/sacred/config/custom_containers.py b/sacred/config/custom_containers.py index bc4fa17f..d957681f 100644 --- a/sacred/config/custom_containers.py +++ b/sacred/config/custom_containers.py @@ -96,7 +96,7 @@ def __delitem__(self, key): def update(self, iterable=None, **kwargs): if iterable is not None: - if hasattr(iterable, 'keys'): + if hasattr(iterable, "keys"): for key in iterable: self[key] = iterable[key] else: @@ -152,7 +152,7 @@ def __delslice__(self, i, j): pass def pop(self, index=None): - raise TypeError('Cannot pop from DogmaticList') + raise TypeError("Cannot pop from DogmaticList") def remove(self, value): pass @@ -167,19 +167,17 @@ def revelation(self): class ReadOnlyContainer: def __init__(self, *args, message=None, **kwargs): super().__init__(*args, **kwargs) - self.message = message or 'This container is read-only!' + self.message = message or "This container is read-only!" 
def _readonly(self, *args, **kwargs): - raise SacredError( - self.message, - filter_traceback='always' - ) + raise SacredError(self.message, filter_traceback="always") class ReadOnlyDict(ReadOnlyContainer, dict): """ A read-only variant of a `dict` """ + # Overwrite all methods that can modify a dict clear = ReadOnlyContainer._readonly pop = ReadOnlyContainer._readonly @@ -191,7 +189,7 @@ class ReadOnlyDict(ReadOnlyContainer, dict): def __init__(self, *args, message=None, **kwargs): if message is None: - message = 'This ReadOnlyDict is read-only!' + message = "This ReadOnlyDict is read-only!" super().__init__(*args, message=message, **kwargs) def __copy__(self): @@ -206,6 +204,7 @@ class ReadOnlyList(ReadOnlyContainer, list): """ A read-only variant of a `list` """ + append = ReadOnlyContainer._readonly clear = ReadOnlyContainer._readonly extend = ReadOnlyContainer._readonly @@ -219,7 +218,7 @@ class ReadOnlyList(ReadOnlyContainer, list): def __init__(self, *iterable, message=None, **kwargs): if message is None: - message = 'This ReadOnlyList is read-only!' + message = "This ReadOnlyList is read-only!" 
super().__init__(*iterable, message=message, **kwargs) def __copy__(self): @@ -239,11 +238,12 @@ def make_read_only(o, error_message=None): if type(o) == dict: return ReadOnlyDict( {k: make_read_only(v, error_message) for k, v in o.items()}, - message=error_message) + message=error_message, + ) elif type(o) == list: return ReadOnlyList( - [make_read_only(v, error_message) for v in o], - message=error_message) + [make_read_only(v, error_message) for v in o], message=error_message + ) elif type(o) == tuple: return tuple(map(make_read_only, o)) else: @@ -267,13 +267,24 @@ def make_read_only(o, error_message=None): # datatypes to the corresponding python datatype if opt.has_numpy: from sacred.optional import np - NP_FLOATS = ['float', 'float16', 'float32', 'float64', 'float128'] + + NP_FLOATS = ["float", "float16", "float32", "float64", "float128"] for npf in NP_FLOATS: if hasattr(np, npf): SIMPLIFY_TYPE[getattr(np, npf)] = float - NP_INTS = ['int', 'int8', 'int16', 'int32', 'int64', - 'uint', 'uint8', 'uint16', 'uint32', 'uint64'] + NP_INTS = [ + "int", + "int8", + "int16", + "int32", + "int64", + "uint", + "uint8", + "uint16", + "uint32", + "uint64", + ] for npi in NP_INTS: if hasattr(np, npi): SIMPLIFY_TYPE[getattr(np, npi)] = int diff --git a/sacred/config/signature.py b/sacred/config/signature.py index 72bdff9e..d2494ccc 100644 --- a/sacred/config/signature.py +++ b/sacred/config/signature.py @@ -6,31 +6,38 @@ from collections import OrderedDict from sacred.utils import MissingConfigError, SignatureError -ARG_TYPES = [Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD, - Parameter.KEYWORD_ONLY] -POSARG_TYPES = [Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD] +ARG_TYPES = [ + Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD, + Parameter.KEYWORD_ONLY, +] +POSARG_TYPES = [Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD] def get_argspec(f): sig = inspect.signature(f) - args = [n for n, p in sig.parameters.items() - if 
p.kind in ARG_TYPES] - pos_args = [n for n, p in sig.parameters.items() - if p.kind in POSARG_TYPES and p.default == inspect._empty] - varargs = [n for n, p in sig.parameters.items() - if p.kind == Parameter.VAR_POSITIONAL] + args = [n for n, p in sig.parameters.items() if p.kind in ARG_TYPES] + pos_args = [ + n + for n, p in sig.parameters.items() + if p.kind in POSARG_TYPES and p.default == inspect._empty + ] + varargs = [ + n for n, p in sig.parameters.items() if p.kind == Parameter.VAR_POSITIONAL + ] # only use first vararg (how on earth would you have more anyways?) vararg_name = varargs[0] if varargs else None - varkws = [n for n, p in sig.parameters.items() - if p.kind == Parameter.VAR_KEYWORD] + varkws = [n for n, p in sig.parameters.items() if p.kind == Parameter.VAR_KEYWORD] # only use first varkw (how on earth would you have more anyways?) kw_wildcard_name = varkws[0] if varkws else None - kwargs = OrderedDict([(n, p.default) - for n, p in sig.parameters.items() - if p.default != inspect._empty]) + kwargs = OrderedDict( + [ + (n, p.default) + for n, p in sig.parameters.items() + if p.default != inspect._empty + ] + ) return args, vararg_name, kw_wildcard_name, pos_args, kwargs @@ -58,7 +65,7 @@ def __init__(self, f): def get_free_parameters(self, args, kwargs, bound=False): expected_args = self._get_expected_args(bound) - return [a for a in expected_args[len(args):] if a not in kwargs] + return [a for a in expected_args[len(args) :] if a not in kwargs] def construct_arguments(self, args, kwargs, options, bound=False): """ @@ -86,8 +93,7 @@ def construct_arguments(self, args, kwargs, options, bound=False): def __str__(self): pos_args = self.positional_args varg = ["*" + self.vararg_name] if self.vararg_name else [] - kwargs = ["{}={}".format(n, v.__repr__()) - for n, v in self.kwargs.items()] + kwargs = ["{}={}".format(n, v.__repr__()) for n, v in self.kwargs.items()] kw_wc = ["**" + self.kw_wildcard_name] if self.kw_wildcard_name else [] arglist = 
pos_args + varg + kwargs + kw_wc return "{}({})".format(self.name, ", ".join(arglist)) @@ -106,25 +112,31 @@ def _get_expected_args(self, bound): def _assert_no_unexpected_args(self, expected_args, args): if not self.vararg_name and len(args) > len(expected_args): - unexpected_args = args[len(expected_args):] - raise SignatureError("{} got unexpected argument(s): {}".format( - self.name, unexpected_args)) + unexpected_args = args[len(expected_args) :] + raise SignatureError( + "{} got unexpected argument(s): {}".format(self.name, unexpected_args) + ) def _assert_no_unexpected_kwargs(self, expected_args, kwargs): if self.kw_wildcard_name: return unexpected_kwargs = set(kwargs) - set(expected_args) if unexpected_kwargs: - raise SignatureError("{} got unexpected kwarg(s): {}".format( - self.name, sorted(unexpected_kwargs))) + raise SignatureError( + "{} got unexpected kwarg(s): {}".format( + self.name, sorted(unexpected_kwargs) + ) + ) def _assert_no_duplicate_args(self, expected_args, args, kwargs): - positional_arguments = expected_args[:len(args)] + positional_arguments = expected_args[: len(args)] duplicate_arguments = [v for v in positional_arguments if v in kwargs] if duplicate_arguments: raise SignatureError( "{} got multiple values for argument(s) {}".format( - self.name, duplicate_arguments)) + self.name, duplicate_arguments + ) + ) def _fill_in_options(self, args, kwargs, options, bound): free_params = self.get_free_parameters(args, kwargs, bound) @@ -139,6 +151,6 @@ def _assert_no_missing_args(self, args, kwargs, bound): missing_args = [m for m in free_params if m not in self.kwargs] if missing_args: raise MissingConfigError( - '{} is missing value(s):'.format(self.name), - missing_configs=missing_args + "{} is missing value(s):".format(self.name), + missing_configs=missing_args, ) diff --git a/sacred/config/utils.py b/sacred/config/utils.py index 3829e8ac..b12eb68f 100644 --- a/sacred/config/utils.py +++ b/sacred/config/utils.py @@ -37,31 +37,39 @@ def 
assert_is_valid_key(key): """ if SETTINGS.CONFIG.ENFORCE_KEYS_MONGO_COMPATIBLE and ( - isinstance(key, str) and ('.' in key or key[0] == '$')): - raise KeyError('Invalid key "{}". Config-keys cannot ' - 'contain "." or start with "$"'.format(key)) - - if SETTINGS.CONFIG.ENFORCE_KEYS_JSONPICKLE_COMPATIBLE and \ - isinstance(key, str) and ( - key in jsonpickle.tags.RESERVED or key.startswith('json://')): - raise KeyError('Invalid key "{}". Config-keys cannot be one of the' - 'reserved jsonpickle tags: {}' - .format(key, jsonpickle.tags.RESERVED)) - - if SETTINGS.CONFIG.ENFORCE_STRING_KEYS and ( - not isinstance(key, str)): - raise KeyError('Invalid key "{}". Config-keys have to be strings, ' - 'but was {}'.format(key, type(key))) + isinstance(key, str) and ("." in key or key[0] == "$") + ): + raise KeyError( + 'Invalid key "{}". Config-keys cannot ' + 'contain "." or start with "$"'.format(key) + ) + + if ( + SETTINGS.CONFIG.ENFORCE_KEYS_JSONPICKLE_COMPATIBLE + and isinstance(key, str) + and (key in jsonpickle.tags.RESERVED or key.startswith("json://")) + ): + raise KeyError( + 'Invalid key "{}". Config-keys cannot be one of the' + "reserved jsonpickle tags: {}".format(key, jsonpickle.tags.RESERVED) + ) + + if SETTINGS.CONFIG.ENFORCE_STRING_KEYS and (not isinstance(key, str)): + raise KeyError( + 'Invalid key "{}". Config-keys have to be strings, ' + "but was {}".format(key, type(key)) + ) if SETTINGS.CONFIG.ENFORCE_VALID_PYTHON_IDENTIFIER_KEYS and ( - isinstance(key, str) and not PYTHON_IDENTIFIER.match(key)): - raise KeyError('Key "{}" is not a valid python identifier' - .format(key)) + isinstance(key, str) and not PYTHON_IDENTIFIER.match(key) + ): + raise KeyError('Key "{}" is not a valid python identifier'.format(key)) - if SETTINGS.CONFIG.ENFORCE_KEYS_NO_EQUALS and ( - isinstance(key, str) and '=' in key): - raise KeyError('Invalid key "{}". 
Config keys may not contain an' - 'equals sign ("=").'.format('=')) + if SETTINGS.CONFIG.ENFORCE_KEYS_NO_EQUALS and (isinstance(key, str) and "=" in key): + raise KeyError( + 'Invalid key "{}". Config keys may not contain an' + 'equals sign ("=").'.format("=") + ) def normalize_numpy(obj): @@ -93,16 +101,13 @@ def recursive_fill_in(config, preset): recursive_fill_in(config[key], preset[key]) -def chain_evaluate_config_scopes(config_scopes, fixed=None, preset=None, - fallback=None): +def chain_evaluate_config_scopes(config_scopes, fixed=None, preset=None, fallback=None): fixed = fixed or {} fallback = fallback or {} final_config = dict(preset or {}) config_summaries = [] for config in config_scopes: - cfg = config(fixed=fixed, - preset=final_config, - fallback=fallback) + cfg = config(fixed=fixed, preset=final_config, fallback=fallback) config_summaries.append(cfg) final_config.update(cfg) diff --git a/sacred/dependencies.py b/sacred/dependencies.py index a4bba2b9..9ca8c442 100644 --- a/sacred/dependencies.py +++ b/sacred/dependencies.py @@ -18,66 +18,351 @@ MODULE_BLACKLIST = set(sys.builtin_module_names) # sadly many builtins are missing from the above, so we list them manually: MODULE_BLACKLIST |= { - None, '__future__', '_abcoll', '_bootlocale', '_bsddb', '_bz2', - '_codecs_cn', '_codecs_hk', '_codecs_iso2022', '_codecs_jp', '_codecs_kr', - '_codecs_tw', '_collections_abc', '_compat_pickle', '_compression', - '_crypt', '_csv', '_ctypes', '_ctypes_test', '_curses', '_curses_panel', - '_dbm', '_decimal', '_dummy_thread', '_elementtree', '_gdbm', '_hashlib', - '_hotshot', '_json', '_lsprof', '_LWPCookieJar', '_lzma', '_markupbase', - '_MozillaCookieJar', '_multibytecodec', '_multiprocessing', '_opcode', - '_osx_support', '_pydecimal', '_pyio', '_sitebuiltins', '_sqlite3', - '_ssl', '_strptime', '_sysconfigdata', '_sysconfigdata_m', - '_sysconfigdata_nd', '_testbuffer', '_testcapi', '_testimportmultiple', - '_testmultiphase', '_threading_local', '_tkinter', 
'_weakrefset', 'abc', - 'aifc', 'antigravity', 'anydbm', 'argparse', 'ast', 'asynchat', 'asyncio', - 'asyncore', 'atexit', 'audiodev', 'audioop', 'base64', 'BaseHTTPServer', - 'Bastion', 'bdb', 'binhex', 'bisect', 'bsddb', 'bz2', 'calendar', - 'Canvas', 'CDROM', 'cgi', 'CGIHTTPServer', 'cgitb', 'chunk', 'cmath', - 'cmd', 'code', 'codecs', 'codeop', 'collections', 'colorsys', 'commands', - 'compileall', 'compiler', 'concurrent', 'ConfigParser', 'configparser', - 'contextlib', 'Cookie', 'cookielib', 'copy', 'copy_reg', 'copyreg', - 'cProfile', 'crypt', 'csv', 'ctypes', 'curses', 'datetime', 'dbhash', - 'dbm', 'decimal', 'Dialog', 'difflib', 'dircache', 'dis', 'distutils', - 'DLFCN', 'doctest', 'DocXMLRPCServer', 'dumbdbm', 'dummy_thread', - 'dummy_threading', 'easy_install', 'email', 'encodings', 'ensurepip', - 'enum', 'filecmp', 'FileDialog', 'fileinput', 'FixTk', 'fnmatch', - 'formatter', 'fpectl', 'fpformat', 'fractions', 'ftplib', 'functools', - 'future_builtins', 'genericpath', 'getopt', 'getpass', 'gettext', 'glob', - 'gzip', 'hashlib', 'heapq', 'hmac', 'hotshot', 'html', 'htmlentitydefs', - 'htmllib', 'HTMLParser', 'http', 'httplib', 'idlelib', 'ihooks', - 'imaplib', 'imghdr', 'imp', 'importlib', 'imputil', 'IN', 'inspect', 'io', - 'ipaddress', 'json', 'keyword', 'lib2to3', 'linecache', 'linuxaudiodev', - 'locale', 'logging', 'lzma', 'macpath', 'macurl2path', 'mailbox', - 'mailcap', 'markupbase', 'md5', 'mhlib', 'mimetools', 'mimetypes', - 'MimeWriter', 'mimify', 'mmap', 'modulefinder', 'multifile', - 'multiprocessing', 'mutex', 'netrc', 'new', 'nis', 'nntplib', 'ntpath', - 'nturl2path', 'numbers', 'opcode', 'operator', 'optparse', 'os', - 'os2emxpath', 'ossaudiodev', 'parser', 'pathlib', 'pdb', 'pickle', - 'pickletools', 'pip', 'pipes', 'pkg_resources', 'pkgutil', 'platform', - 'plistlib', 'popen2', 'poplib', 'posixfile', 'posixpath', 'pprint', - 'profile', 'pstats', 'pty', 'py_compile', 'pyclbr', 'pydoc', 'pydoc_data', - 'pyexpat', 'Queue', 'queue', 
'quopri', 'random', 're', 'readline', 'repr', - 'reprlib', 'resource', 'rexec', 'rfc822', 'rlcompleter', 'robotparser', - 'runpy', 'sched', 'ScrolledText', 'selectors', 'sets', 'setuptools', - 'sgmllib', 'sha', 'shelve', 'shlex', 'shutil', 'signal', 'SimpleDialog', - 'SimpleHTTPServer', 'SimpleXMLRPCServer', 'site', 'sitecustomize', - 'smtpd', 'smtplib', 'sndhdr', 'socket', 'SocketServer', 'socketserver', - 'sqlite3', 'sre', 'sre_compile', 'sre_constants', 'sre_parse', 'ssl', - 'stat', 'statistics', 'statvfs', 'string', 'StringIO', 'stringold', - 'stringprep', 'struct', 'subprocess', 'sunau', 'sunaudio', 'symbol', - 'symtable', 'sysconfig', 'tabnanny', 'tarfile', 'telnetlib', 'tempfile', - 'termios', 'test', 'textwrap', 'this', 'threading', 'timeit', 'Tix', - 'tkColorChooser', 'tkCommonDialog', 'Tkconstants', 'Tkdnd', - 'tkFileDialog', 'tkFont', 'tkinter', 'Tkinter', 'tkMessageBox', - 'tkSimpleDialog', 'toaiff', 'token', 'tokenize', 'trace', 'traceback', - 'tracemalloc', 'ttk', 'tty', 'turtle', 'types', 'TYPES', 'typing', - 'unittest', 'urllib', 'urllib2', 'urlparse', 'user', 'UserDict', - 'UserList', 'UserString', 'uu', 'uuid', 'venv', 'warnings', 'wave', - 'weakref', 'webbrowser', 'wheel', 'whichdb', 'wsgiref', 'xdrlib', 'xml', - 'xmllib', 'xmlrpc', 'xmlrpclib', 'xxlimited', 'zipapp', 'zipfile'} + None, + "__future__", + "_abcoll", + "_bootlocale", + "_bsddb", + "_bz2", + "_codecs_cn", + "_codecs_hk", + "_codecs_iso2022", + "_codecs_jp", + "_codecs_kr", + "_codecs_tw", + "_collections_abc", + "_compat_pickle", + "_compression", + "_crypt", + "_csv", + "_ctypes", + "_ctypes_test", + "_curses", + "_curses_panel", + "_dbm", + "_decimal", + "_dummy_thread", + "_elementtree", + "_gdbm", + "_hashlib", + "_hotshot", + "_json", + "_lsprof", + "_LWPCookieJar", + "_lzma", + "_markupbase", + "_MozillaCookieJar", + "_multibytecodec", + "_multiprocessing", + "_opcode", + "_osx_support", + "_pydecimal", + "_pyio", + "_sitebuiltins", + "_sqlite3", + "_ssl", + "_strptime", + 
"_sysconfigdata", + "_sysconfigdata_m", + "_sysconfigdata_nd", + "_testbuffer", + "_testcapi", + "_testimportmultiple", + "_testmultiphase", + "_threading_local", + "_tkinter", + "_weakrefset", + "abc", + "aifc", + "antigravity", + "anydbm", + "argparse", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audiodev", + "audioop", + "base64", + "BaseHTTPServer", + "Bastion", + "bdb", + "binhex", + "bisect", + "bsddb", + "bz2", + "calendar", + "Canvas", + "CDROM", + "cgi", + "CGIHTTPServer", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "commands", + "compileall", + "compiler", + "concurrent", + "ConfigParser", + "configparser", + "contextlib", + "Cookie", + "cookielib", + "copy", + "copy_reg", + "copyreg", + "cProfile", + "crypt", + "csv", + "ctypes", + "curses", + "datetime", + "dbhash", + "dbm", + "decimal", + "Dialog", + "difflib", + "dircache", + "dis", + "distutils", + "DLFCN", + "doctest", + "DocXMLRPCServer", + "dumbdbm", + "dummy_thread", + "dummy_threading", + "easy_install", + "email", + "encodings", + "ensurepip", + "enum", + "filecmp", + "FileDialog", + "fileinput", + "FixTk", + "fnmatch", + "formatter", + "fpectl", + "fpformat", + "fractions", + "ftplib", + "functools", + "future_builtins", + "genericpath", + "getopt", + "getpass", + "gettext", + "glob", + "gzip", + "hashlib", + "heapq", + "hmac", + "hotshot", + "html", + "htmlentitydefs", + "htmllib", + "HTMLParser", + "http", + "httplib", + "idlelib", + "ihooks", + "imaplib", + "imghdr", + "imp", + "importlib", + "imputil", + "IN", + "inspect", + "io", + "ipaddress", + "json", + "keyword", + "lib2to3", + "linecache", + "linuxaudiodev", + "locale", + "logging", + "lzma", + "macpath", + "macurl2path", + "mailbox", + "mailcap", + "markupbase", + "md5", + "mhlib", + "mimetools", + "mimetypes", + "MimeWriter", + "mimify", + "mmap", + "modulefinder", + "multifile", + "multiprocessing", + "mutex", + "netrc", + "new", + "nis", + 
"nntplib", + "ntpath", + "nturl2path", + "numbers", + "opcode", + "operator", + "optparse", + "os", + "os2emxpath", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pip", + "pipes", + "pkg_resources", + "pkgutil", + "platform", + "plistlib", + "popen2", + "poplib", + "posixfile", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "py_compile", + "pyclbr", + "pydoc", + "pydoc_data", + "pyexpat", + "Queue", + "queue", + "quopri", + "random", + "re", + "readline", + "repr", + "reprlib", + "resource", + "rexec", + "rfc822", + "rlcompleter", + "robotparser", + "runpy", + "sched", + "ScrolledText", + "selectors", + "sets", + "setuptools", + "sgmllib", + "sha", + "shelve", + "shlex", + "shutil", + "signal", + "SimpleDialog", + "SimpleHTTPServer", + "SimpleXMLRPCServer", + "site", + "sitecustomize", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "SocketServer", + "socketserver", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "statvfs", + "string", + "StringIO", + "stringold", + "stringprep", + "struct", + "subprocess", + "sunau", + "sunaudio", + "symbol", + "symtable", + "sysconfig", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "this", + "threading", + "timeit", + "Tix", + "tkColorChooser", + "tkCommonDialog", + "Tkconstants", + "Tkdnd", + "tkFileDialog", + "tkFont", + "tkinter", + "Tkinter", + "tkMessageBox", + "tkSimpleDialog", + "toaiff", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "ttk", + "tty", + "turtle", + "types", + "TYPES", + "typing", + "unittest", + "urllib", + "urllib2", + "urlparse", + "user", + "UserDict", + "UserList", + "UserString", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "wheel", + "whichdb", + "wsgiref", + "xdrlib", + "xml", + "xmllib", + "xmlrpc", + "xmlrpclib", + "xxlimited", + "zipapp", + "zipfile", +} module = type(sys) 
-PEP440_VERSION_PATTERN = re.compile(r""" +PEP440_VERSION_PATTERN = re.compile( + r""" ^ (\d+!)? # epoch (\d[.\d]*(?<= \d)) # release @@ -85,14 +370,16 @@ (?:(\.post\d+))? # post-release (?:(\.dev\d+))? # development release $ -""", flags=re.VERBOSE) +""", + flags=re.VERBOSE, +) def get_py_file_if_possible(pyc_name): """Try to retrieve a X.py file for a given X.py[c] file.""" - if pyc_name.endswith(('.py', '.so', '.pyd')): + if pyc_name.endswith((".py", ".so", ".pyd")): return pyc_name - assert pyc_name.endswith('.pyc') + assert pyc_name.endswith(".pyc") non_compiled_file = pyc_name[:-1] if os.path.exists(non_compiled_file): return non_compiled_file @@ -102,7 +389,7 @@ def get_py_file_if_possible(pyc_name): def get_digest(filename): """Compute the MD5 hash for a given file.""" h = hashlib.md5() - with open(filename, 'rb') as f: + with open(filename, "rb") as f: data = f.read(1 * MB) while data: h.update(data) @@ -131,13 +418,14 @@ def get_commit_if_possible(filename): # git if opt.has_gitpython: from git import Repo, InvalidGitRepositoryError + try: directory = os.path.dirname(filename) repo = Repo(directory, search_parent_directories=True) try: path = repo.remote().url except ValueError: - path = 'git:/' + repo.working_dir + path = "git:/" + repo.working_dir is_dirty = repo.is_dirty() commit = repo.head.commit.hexsha return path, commit, is_dirty @@ -158,8 +446,7 @@ def __init__(self, filename, digest, repo, commit, isdirty): @staticmethod def create(filename): if not filename or not os.path.exists(filename): - raise ValueError('invalid filename or file not found "{}"' - .format(filename)) + raise ValueError('invalid filename or file not found "{}"'.format(filename)) main_file = get_py_file_if_possible(os.path.abspath(filename)) repo, commit, is_dirty = get_commit_if_possible(main_file) @@ -186,7 +473,7 @@ def __le__(self, other): return self.filename.__le__(other.filename) def __repr__(self): - return ''.format(self.filename) + return "".format(self.filename) 
@functools.total_ordering @@ -204,7 +491,7 @@ def fill_missing_version(self): self.version = dist.version if dist else None def to_json(self): - return '{}=={}'.format(self.name, self.version or '') + return "{}=={}".format(self.name, self.version or "") def __hash__(self): return hash(self.name) @@ -219,19 +506,18 @@ def __le__(self, other): return self.name.__le__(other.name) def __repr__(self): - return ''.format(self.name, self.version) + return "".format(self.name, self.version) @staticmethod def get_version_heuristic(mod): - possible_version_attributes = ['__version__', 'VERSION', 'version'] + possible_version_attributes = ["__version__", "VERSION", "version"] for vattr in possible_version_attributes: if hasattr(mod, vattr): version = getattr(mod, vattr) - if isinstance(version, str) and \ - PEP440_VERSION_PATTERN.match(version): + if isinstance(version, str) and PEP440_VERSION_PATTERN.match(version): return version if isinstance(version, tuple): - version = '.'.join([str(n) for n in version]) + version = ".".join([str(n) for n in version]) if PEP440_VERSION_PATTERN.match(version): return version @@ -244,16 +530,14 @@ def create(cls, mod): # so we set up a dict to map from module name to package name for dist in pkg_resources.working_set: try: - toplevel_names = dist._get_metadata('top_level.txt') + toplevel_names = dist._get_metadata("top_level.txt") for tln in toplevel_names: - cls.modname_to_dist[ - tln] = dist.project_name, dist.version + cls.modname_to_dist[tln] = dist.project_name, dist.version except Exception: pass # version = PackageDependency.get_version_heuristic(mod) - name, version = cls.modname_to_dist.get(mod.__name__, - (mod.__name__, None)) + name, version = cls.modname_to_dist.get(mod.__name__, (mod.__name__, None)) return PackageDependency(name, version) @@ -261,7 +545,7 @@ def create(cls, mod): def convert_path_to_module_parts(path): """Convert path to a python file into list of module names.""" module_parts = list(path.parts) - if 
module_parts[-1] in ['__init__.py', '__init__.pyc']: + if module_parts[-1] in ["__init__.py", "__init__.pyc"]: # remove trailing __init__.py module_parts = module_parts[:-1] else: @@ -301,24 +585,23 @@ def is_local_source(filename, modname, experiment_path): rel_path = filename.relative_to(experiment_path) path_parts = convert_path_to_module_parts(rel_path) - mod_parts = modname.split('.') + mod_parts = modname.split(".") if path_parts == mod_parts: return True if len(path_parts) > len(mod_parts): return False abs_path_parts = convert_path_to_module_parts(filename) - return all([p == m for p, m in zip(reversed(abs_path_parts), - reversed(mod_parts))]) + return all([p == m for p, m in zip(reversed(abs_path_parts), reversed(mod_parts))]) def get_main_file(globs): - filename = globs.get('__file__') + filename = globs.get("__file__") if filename is None: experiment_path = os.path.abspath(os.path.curdir) main = None else: - main = Source.create(globs.get('__file__')) + main = Source.create(globs.get("__file__")) experiment_path = os.path.dirname(main.filename) return experiment_path, main @@ -328,7 +611,7 @@ def iterate_imported_modules(globs): for glob in globs.values(): if isinstance(glob, module): mod_path = glob.__name__ - elif hasattr(glob, '__module__'): + elif hasattr(glob, "__module__"): mod_path = glob.__module__ else: continue # pragma: no cover @@ -348,10 +631,10 @@ def iterate_imported_modules(globs): def iterate_all_python_files(base_path): # TODO support ignored directories/files for dirname, subdirlist, filelist in os.walk(base_path): - if '__pycache__' in dirname: + if "__pycache__" in dirname: continue for filename in filelist: - if filename.endswith('.py'): + if filename.endswith(".py"): yield os.path.join(base_path, dirname, filename) @@ -366,12 +649,11 @@ def get_sources_from_modules(module_iterator, base_path): sources = set() for modname, mod in module_iterator: # hasattr doesn't work with python extensions - if not getattr(mod, '__file__', None): 
+ if not getattr(mod, "__file__", None): continue filename = os.path.abspath(mod.__file__) - if filename not in sources and \ - is_local_source(filename, modname, base_path): + if filename not in sources and is_local_source(filename, modname, base_path): s = Source.create(filename) sources.add(s) return sources @@ -381,10 +663,11 @@ def get_dependencies_from_modules(module_iterator, base_path): dependencies = set() for modname, mod in module_iterator: # hasattr doesn't work with python extensions - if getattr(mod, '__file__', None) and is_local_source( - os.path.abspath(mod.__file__), modname, base_path): + if getattr(mod, "__file__", None) and is_local_source( + os.path.abspath(mod.__file__), modname, base_path + ): continue - if modname.startswith('_') or '.' in modname: + if modname.startswith("_") or "." in modname: continue try: @@ -405,8 +688,7 @@ def get_sources_from_imported_modules(globs, base_path): def get_sources_from_local_dir(globs, base_path): - return {Source.create(filename) - for filename in iterate_all_python_files(base_path)} + return {Source.create(filename) for filename in iterate_all_python_files(base_path)} def get_dependencies_from_sys_modules(globs, base_path): @@ -414,31 +696,30 @@ def get_dependencies_from_sys_modules(globs, base_path): def get_dependencies_from_imported_modules(globs, base_path): - return get_dependencies_from_modules(iterate_imported_modules(globs), - base_path) + return get_dependencies_from_modules(iterate_imported_modules(globs), base_path) def get_dependencies_from_pkg(globs, base_path): dependencies = set() for dist in pkg_resources.working_set: - if dist.version == '0.0.0': + if dist.version == "0.0.0": continue # ugly hack to deal with pkg-resource version bug dependencies.add(PackageDependency(dist.project_name, dist.version)) return dependencies source_discovery_strategies = { - 'none': lambda globs, path: set(), - 'imported': get_sources_from_imported_modules, - 'sys': get_sources_from_sys_modules, - 'dir': 
get_sources_from_local_dir + "none": lambda globs, path: set(), + "imported": get_sources_from_imported_modules, + "sys": get_sources_from_sys_modules, + "dir": get_sources_from_local_dir, } dependency_discovery_strategies = { - 'none': lambda globs, path: set(), - 'imported': get_dependencies_from_imported_modules, - 'sys': get_dependencies_from_sys_modules, - 'pkg': get_dependencies_from_pkg + "none": lambda globs, path: set(), + "imported": get_dependencies_from_imported_modules, + "sys": get_dependencies_from_sys_modules, + "pkg": get_dependencies_from_pkg, } @@ -449,13 +730,14 @@ def gather_sources_and_dependencies(globs, base_dir=None): base_dir = base_dir or experiment_path - gather_sources = source_discovery_strategies[SETTINGS['DISCOVER_SOURCES']] + gather_sources = source_discovery_strategies[SETTINGS["DISCOVER_SOURCES"]] sources = gather_sources(globs, base_dir) if main is not None: sources.add(main) gather_dependencies = dependency_discovery_strategies[ - SETTINGS['DISCOVER_DEPENDENCIES']] + SETTINGS["DISCOVER_DEPENDENCIES"] + ] dependencies = gather_dependencies(globs, base_dir) if opt.has_numpy: diff --git a/sacred/experiment.py b/sacred/experiment.py index d0c1c7ce..253c55a0 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -12,17 +12,29 @@ from sacred.arg_parser import format_usage, get_config_updates from sacred.commandline_options import ( - ForceOption, gather_command_line_options, LoglevelOption) -from sacred.commands import (help_for_command, print_config, - print_dependencies, save_config, - print_named_configs) + ForceOption, + gather_command_line_options, + LoglevelOption, +) +from sacred.commands import ( + help_for_command, + print_config, + print_dependencies, + save_config, + print_named_configs, +) from sacred.config.signature import Signature from sacred.ingredient import Ingredient from sacred.initialize import create_run -from sacred.utils import print_filtered_stacktrace, ensure_wellformed_argv, \ - SacredError, 
format_sacred_error, PathType +from sacred.utils import ( + print_filtered_stacktrace, + ensure_wellformed_argv, + SacredError, + format_sacred_error, + PathType, +) -__all__ = ('Experiment',) +__all__ = ("Experiment",) class Experiment(Ingredient): @@ -36,10 +48,13 @@ class Experiment(Ingredient): things in any experiment-file. """ - def __init__(self, name: Optional[str] = None, - ingredients: Sequence[Ingredient] = (), - interactive: bool = False, - base_dir: Optional[PathType] = None): + def __init__( + self, + name: Optional[str] = None, + ingredients: Sequence[Ingredient] = (), + interactive: bool = False, + base_dir: Optional[PathType] = None, + ): """ Create a new experiment with the given name and optional ingredients. @@ -66,22 +81,26 @@ def __init__(self, name: Optional[str] = None, caller_globals = inspect.stack()[1][0].f_globals if name is None: if interactive: - raise RuntimeError('name is required in interactive mode.') - mainfile = caller_globals.get('__file__') + raise RuntimeError("name is required in interactive mode.") + mainfile = caller_globals.get("__file__") if mainfile is None: - raise RuntimeError('No main-file found. Are you running in ' - 'interactive mode? If so please provide a ' - 'name and set interactive=True.') + raise RuntimeError( + "No main-file found. Are you running in " + "interactive mode? If so please provide a " + "name and set interactive=True." 
+ ) name = os.path.basename(mainfile) - if name.endswith('.py'): + if name.endswith(".py"): name = name[:-3] - elif name.endswith('.pyc'): + elif name.endswith(".pyc"): name = name[:-4] - super().__init__(path=name, - ingredients=ingredients, - interactive=interactive, - base_dir=base_dir, - _caller_globals=caller_globals) + super().__init__( + path=name, + ingredients=ingredients, + interactive=interactive, + base_dir=base_dir, + _caller_globals=caller_globals, + ) self.default_command = None self.command(print_config, unobserved=True) self.command(print_dependencies, unobserved=True) @@ -129,15 +148,19 @@ def my_main(): ex.run_commandline() """ captured = self.main(function) - if function.__module__ == '__main__': + if function.__module__ == "__main__": # Ensure that automain is not used in interactive mode. import inspect + main_filename = inspect.getfile(function) - if (main_filename == '' or - (main_filename.startswith(''))): - raise RuntimeError('Cannot use @ex.automain decorator in ' - 'interactive mode. Use @ex.main instead.') + if main_filename == "" or ( + main_filename.startswith("") + ): + raise RuntimeError( + "Cannot use @ex.automain decorator in " + "interactive mode. Use @ex.main instead." 
+ ) self.run_commandline() return captured @@ -160,8 +183,10 @@ def option_hook(self, function): """ sig = Signature(function) if "options" not in sig.arguments: - raise KeyError("option_hook functions must have an argument called" - " 'options', but got {}".format(sig.arguments)) + raise KeyError( + "option_hook functions must have an argument called" + " 'options', but got {}".format(sig.arguments) + ) self.option_hooks.append(function) return function @@ -169,24 +194,28 @@ def option_hook(self, function): def get_usage(self, program_name=None): """Get the commandline usage string for this experiment.""" - program_name = os.path.relpath(program_name or sys.argv[0] or 'Dummy', - self.base_dir) + program_name = os.path.relpath( + program_name or sys.argv[0] or "Dummy", self.base_dir + ) commands = OrderedDict(self.gather_commands()) options = gather_command_line_options() long_usage = format_usage(program_name, self.doc, commands, options) # internal usage is a workaround because docopt cannot handle spaces # in program names. So for parsing we use 'dummy' as the program name. # for printing help etc. we want to use the actual program name. - internal_usage = format_usage('dummy', self.doc, commands, options) + internal_usage = format_usage("dummy", self.doc, commands, options) short_usage = printable_usage(long_usage) return short_usage, long_usage, internal_usage - def run(self, command_name: Optional[str] = None, - config_updates: Optional[dict] = None, - named_configs: Sequence[str] = (), - info: Optional[dict] = None, - meta_info: Optional[dict] = None, - options: Optional[dict] = None): + def run( + self, + command_name: Optional[str] = None, + config_updates: Optional[dict] = None, + named_configs: Sequence[str] = (), + info: Optional[dict] = None, + meta_info: Optional[dict] = None, + options: Optional[dict] = None, + ): """ Run the main function of the experiment or a given command. 
@@ -216,8 +245,9 @@ def run(self, command_name: Optional[str] = None, the Run object corresponding to the finished run """ - run = self._create_run(command_name, config_updates, named_configs, - info, meta_info, options) + run = self._create_run( + command_name, config_updates, named_configs, info, meta_info, options + ) run() return run @@ -242,11 +272,11 @@ def run_commandline(self, argv=None): short_usage, usage, internal_usage = self.get_usage() args = docopt(internal_usage, [str(a) for a in argv[1:]], help=False) - cmd_name = args.get('COMMAND') or self.default_command - config_updates, named_configs = get_config_updates(args['UPDATE']) + cmd_name = args.get("COMMAND") or self.default_command + config_updates, named_configs = get_config_updates(args["UPDATE"]) err = self._check_command(cmd_name) - if not args['help'] and err: + if not args["help"] and err: print(short_usage) print(err) sys.exit(1) @@ -255,8 +285,14 @@ def run_commandline(self, argv=None): sys.exit() try: - return self.run(cmd_name, config_updates, named_configs, info={}, - meta_info={}, options=args) + return self.run( + cmd_name, + config_updates, + named_configs, + info={}, + meta_info={}, + options=args, + ) except Exception as e: if self.current_run: debug = self.current_run.debug @@ -265,7 +301,7 @@ def run_commandline(self, argv=None): # object is built completely. Some exceptions (e.g. # ConfigAddedError) are raised before this. In these cases, # the debug flag must be checked manually. 
- debug = args.get('--debug', False) + debug = args.get("--debug", False) if debug: # Debug: Don't change behaviour, just re-raise exception @@ -274,6 +310,7 @@ def run_commandline(self, argv=None): # Print exception and attach pdb debugger import traceback import pdb + traceback.print_exception(*sys.exc_info()) pdb.post_mortem() else: @@ -286,7 +323,7 @@ def run_commandline(self, argv=None): print_filtered_stacktrace() sys.exit(1) - def open_resource(self, filename: PathType, mode: str = 'r'): + def open_resource(self, filename: PathType, mode: str = "r"): """Open a file and also save it as a resource. Opens a file, reports it to the observers as a resource, and returns @@ -336,11 +373,11 @@ def add_resource(self, filename: PathType): self.current_run.add_resource(filename) def add_artifact( - self, - filename: PathType, - name: Optional[str] = None, - metadata: Optional[dict] = None, - content_type: Optional[str] = None, + self, + filename: PathType, + name: Optional[str] = None, + metadata: Optional[dict] = None, + content_type: Optional[str] = None, ): """Add a file as an artifact. @@ -383,9 +420,7 @@ def my_captured_function(_run): """ return self.current_run.info - def log_scalar(self, name: str, - value: float, - step: Optional[int] = None): + def log_scalar(self, name: str, value: float, step: Optional[int] = None): """ Add a new measurement. 
@@ -412,7 +447,7 @@ def _gather(self, func): for ingredient, _ in self.traverse_ingredients(): for name, item in func(ingredient): if ingredient == self: - name = name[len(self.path) + 1:] + name = name[len(self.path) + 1 :] yield name, item def get_default_options(self): @@ -428,16 +463,25 @@ def get_default_options(self): """ _, _, internal_usage = self.get_usage() args = docopt(internal_usage, []) - return {k: v for k, v in args.items() if k.startswith('--')} + return {k: v for k, v in args.items() if k.startswith("--")} # =========================== Internal Interface ========================== - def _create_run(self, command_name=None, config_updates=None, - named_configs=(), info=None, meta_info=None, options=None): + def _create_run( + self, + command_name=None, + config_updates=None, + named_configs=(), + info=None, + meta_info=None, + options=None, + ): command_name = command_name or self.default_command if command_name is None: - raise RuntimeError('No command found to be run. Specify a command ' - 'or define a main function.') + raise RuntimeError( + "No command found to be run. Specify a command " + "or define a main function." 
+ ) default_options = self.get_default_options() if options: @@ -448,16 +492,19 @@ def _create_run(self, command_name=None, config_updates=None, for oh in self.option_hooks: oh(options=options) - run = create_run(self, command_name, config_updates, - named_configs=named_configs, - force=options.get(ForceOption.get_flag(), False), - log_level=options.get(LoglevelOption.get_flag(), - None)) + run = create_run( + self, + command_name, + config_updates, + named_configs=named_configs, + force=options.get(ForceOption.get_flag(), False), + log_level=options.get(LoglevelOption.get_flag(), None), + ) if info is not None: run.info.update(info) - run.meta_info['command'] = command_name - run.meta_info['options'] = options + run.meta_info["command"] = command_name + run.meta_info["options"] = options if meta_info: run.meta_info.update(meta_info) @@ -473,21 +520,25 @@ def _create_run(self, command_name=None, config_updates=None, def _check_command(self, cmd_name): commands = dict(self.gather_commands()) if cmd_name is not None and cmd_name not in commands: - return 'Error: Command "{}" not found. Available commands are: ' \ - '{}'.format(cmd_name, ", ".join(commands.keys())) + return ( + 'Error: Command "{}" not found. Available commands are: ' + "{}".format(cmd_name, ", ".join(commands.keys())) + ) if cmd_name is None: - return 'Error: No command found to be run. Specify a command' \ - ' or define main function. Available commands' \ - ' are: {}'.format(", ".join(commands.keys())) + return ( + "Error: No command found to be run. Specify a command" + " or define main function. 
Available commands" + " are: {}".format(", ".join(commands.keys())) + ) def _handle_help(self, args, usage): - if args['help'] or args['--help']: - if args['COMMAND'] is None: + if args["help"] or args["--help"]: + if args["COMMAND"] is None: print(usage) return True else: commands = dict(self.gather_commands()) - print(help_for_command(commands[args['COMMAND']])) + print(help_for_command(commands[args["COMMAND"]])) return True return False diff --git a/sacred/host_info.py b/sacred/host_info.py index fef19418..bea4a3a0 100644 --- a/sacred/host_info.py +++ b/sacred/host_info.py @@ -13,7 +13,7 @@ from sacred.utils import optional_kwargs_decorator from sacred.settings import SETTINGS -__all__ = ('host_info_gatherers', 'get_host_info', 'host_info_getter') +__all__ = ("host_info_gatherers", "get_host_info", "host_info_getter") host_info_gatherers = {} """Global dict of functions that are used to collect the host information.""" @@ -73,22 +73,23 @@ def host_info_getter(func, name=None): # #################### Default Host Information ############################### -@host_info_getter(name='hostname') + +@host_info_getter(name="hostname") def _hostname(): return platform.node() -@host_info_getter(name='os') +@host_info_getter(name="os") def _os(): return [platform.system(), platform.platform()] -@host_info_getter(name='python_version') +@host_info_getter(name="python_version") def _python_version(): return platform.python_version() -@host_info_getter(name='cpu') +@host_info_getter(name="cpu") def _cpu(): if platform.system() == "Windows": return _get_cpu_by_pycpuinfo() @@ -102,35 +103,35 @@ def _cpu(): return _get_cpu_by_pycpuinfo() -@host_info_getter(name='gpus') +@host_info_getter(name="gpus") def _gpus(): if not SETTINGS.HOST_INFO.INCLUDE_GPU_INFO: return try: - xml = subprocess.check_output(['nvidia-smi', '-q', '-x']).decode() + xml = subprocess.check_output(["nvidia-smi", "-q", "-x"]).decode() except (FileNotFoundError, OSError, subprocess.CalledProcessError): raise 
IgnoreHostInfo() - gpu_info = {'gpus': []} + gpu_info = {"gpus": []} for child in ElementTree.fromstring(xml): - if child.tag == 'driver_version': - gpu_info['driver_version'] = child.text - if child.tag != 'gpu': + if child.tag == "driver_version": + gpu_info["driver_version"] = child.text + if child.tag != "gpu": continue gpu = { - 'model': child.find('product_name').text, - 'total_memory': int(child.find('fb_memory_usage').find('total') - .text.split()[0]), - 'persistence_mode': (child.find('persistence_mode').text == - 'Enabled') + "model": child.find("product_name").text, + "total_memory": int( + child.find("fb_memory_usage").find("total").text.split()[0] + ), + "persistence_mode": (child.find("persistence_mode").text == "Enabled"), } - gpu_info['gpus'].append(gpu) + gpu_info["gpus"].append(gpu) return gpu_info -@host_info_getter(name='ENV') +@host_info_getter(name="ENV") def _environment(): keys_to_capture = SETTINGS.HOST_INFO.CAPTURED_ENV return {k: os.environ[k] for k in keys_to_capture if k in os.environ} @@ -140,7 +141,7 @@ def _environment(): def _get_cpu_by_sysctl(): - os.environ['PATH'] += ':/usr/sbin' + os.environ["PATH"] += ":/usr/sbin" command = ["sysctl", "-n", "machdep.cpu.brand_string"] return subprocess.check_output(command).decode().strip() @@ -155,4 +156,4 @@ def _get_cpu_by_proc_cpuinfo(): def _get_cpu_by_pycpuinfo(): - return cpuinfo.get_cpu_info().get('brand', 'Unknown') + return cpuinfo.get_cpu_info().get("brand", "Unknown") diff --git a/sacred/ingredient.py b/sacred/ingredient.py index 768fa4f8..6dccbbc4 100644 --- a/sacred/ingredient.py +++ b/sacred/ingredient.py @@ -10,19 +10,29 @@ import wrapt -from sacred.config import (ConfigDict, ConfigScope, create_captured_function, - load_config_file) -from sacred.dependencies import (PEP440_VERSION_PATTERN, PackageDependency, - Source, gather_sources_and_dependencies) -from sacred.utils import (CircularDependencyError, optional_kwargs_decorator, - join_paths) - -__all__ = ('Ingredient',) +from 
sacred.config import ( + ConfigDict, + ConfigScope, + create_captured_function, + load_config_file, +) +from sacred.dependencies import ( + PEP440_VERSION_PATTERN, + PackageDependency, + Source, + gather_sources_and_dependencies, +) +from sacred.utils import CircularDependencyError, optional_kwargs_decorator, join_paths + +__all__ = ("Ingredient",) def collect_repositories(sources): - return [{'url': s.repo, 'commit': s.commit, 'dirty': s.is_dirty} - for s in sources if s.repo] + return [ + {"url": s.repo, "commit": s.commit, "dirty": s.is_dirty} + for s in sources + if s.repo + ] @wrapt.decorator @@ -49,11 +59,14 @@ class Ingredient: Ingredients can themselves use ingredients. """ - def __init__(self, path: PathType, - ingredients: Sequence['Ingredient'] = (), - interactive: bool = False, - _caller_globals: Optional[dict] = None, - base_dir: Optional[PathType] = None): + def __init__( + self, + path: PathType, + ingredients: Sequence["Ingredient"] = (), + interactive: bool = False, + _caller_globals: Optional[dict] = None, + base_dir: Optional[PathType] = None, + ): self.path = path self.config_hooks = [] self.configurations = [] @@ -67,16 +80,19 @@ def __init__(self, path: PathType, self.commands = OrderedDict() # capture some context information _caller_globals = _caller_globals or inspect.stack()[1][0].f_globals - mainfile_dir = os.path.dirname(_caller_globals.get('__file__', '.')) + mainfile_dir = os.path.dirname(_caller_globals.get("__file__", ".")) self.base_dir = os.path.abspath(base_dir or mainfile_dir) - self.doc = _caller_globals.get('__doc__', "") - self.mainfile, self.sources, self.dependencies = \ - gather_sources_and_dependencies(_caller_globals, self.base_dir) + self.doc = _caller_globals.get("__doc__", "") + self.mainfile, self.sources, self.dependencies = gather_sources_and_dependencies( + _caller_globals, self.base_dir + ) if self.mainfile is None and not interactive: - raise RuntimeError("Defining an experiment in interactive mode! 
" - "The sourcecode cannot be stored and the " - "experiment won't be reproducible. If you still" - " want to run it pass interactive=True") + raise RuntimeError( + "Defining an experiment in interactive mode! " + "The sourcecode cannot be stored and the " + "experiment won't be reproducible. If you still" + " want to run it pass interactive=True" + ) # =========================== Decorators ================================== @optional_kwargs_decorator @@ -183,11 +199,17 @@ def config_hook(self, func): ingredient. """ argspec = inspect.getfullargspec(func) - args = ['config', 'command_name', 'logger'] - if not (argspec.args == args and argspec.varargs is None and - not argspec.kwonlyargs and argspec.defaults is None): - raise ValueError('Wrong signature for config_hook. Expected: ' - '(config, command_name, logger)') + args = ["config", "command_name", "logger"] + if not ( + argspec.args == args + and argspec.varargs is None + and not argspec.kwonlyargs + and argspec.defaults is None + ): + raise ValueError( + "Wrong signature for config_hook. Expected: " + "(config, command_name, logger)" + ) self.config_hooks.append(func) return self.config_hooks[-1] @@ -210,20 +232,19 @@ def add_config(self, cfg_or_file=None, **kw_conf): :param kw_conf: Configuration entries to be added to this ingredient/experiment. """ - self.configurations.append(self._create_config_dict(cfg_or_file, - kw_conf)) + self.configurations.append(self._create_config_dict(cfg_or_file, kw_conf)) def _add_named_config(self, name, conf): if name in self.named_configs: - raise KeyError('Configuration name "{}" already in use!' 
- .format(name)) + raise KeyError('Configuration name "{}" already in use!'.format(name)) self.named_configs[name] = conf @staticmethod def _create_config_dict(cfg_or_file, kw_conf): if cfg_or_file is not None and kw_conf: - raise ValueError("cannot combine keyword config with " - "positional argument") + raise ValueError( + "cannot combine keyword config with " "positional argument" + ) if cfg_or_file is None: if not kw_conf: raise ValueError("attempted to add empty config") @@ -232,12 +253,11 @@ def _create_config_dict(cfg_or_file, kw_conf): return ConfigDict(cfg_or_file) elif isinstance(cfg_or_file, str): if not os.path.exists(cfg_or_file): - raise OSError('File not found {}'.format(cfg_or_file)) + raise OSError("File not found {}".format(cfg_or_file)) abspath = os.path.abspath(cfg_or_file) return ConfigDict(load_config_file(abspath)) else: - raise TypeError("Invalid argument type {}" - .format(type(cfg_or_file))) + raise TypeError("Invalid argument type {}".format(type(cfg_or_file))) def add_named_config(self, name, cfg_or_file=None, **kw_conf): """ @@ -260,8 +280,7 @@ def add_named_config(self, name, cfg_or_file=None, **kw_conf): :param kw_conf: Configuration entries to be added to this ingredient/experiment. 
""" - self._add_named_config(name, self._create_config_dict(cfg_or_file, - kw_conf)) + self._add_named_config(name, self._create_config_dict(cfg_or_file, kw_conf)) def add_source_file(self, filename): """ @@ -347,8 +366,7 @@ def get_experiment_info(self): for dep in dependencies: dep.fill_missing_version() - mainfile = (self.mainfile.to_json(self.base_dir)[0] - if self.mainfile else None) + mainfile = self.mainfile.to_json(self.base_dir)[0] if self.mainfile else None def name_lower(d): return d.name.lower() @@ -357,10 +375,9 @@ def name_lower(d): name=self.path, base_dir=self.base_dir, sources=[s.to_json(self.base_dir) for s in sorted(sources)], - dependencies=[d.to_json() - for d in sorted(dependencies, key=name_lower)], + dependencies=[d.to_json() for d in sorted(dependencies, key=name_lower)], repositories=collect_repositories(sources), - mainfile=mainfile + mainfile=mainfile, ) def traverse_ingredients(self): diff --git a/sacred/initialize.py b/sacred/initialize.py index 020b752c..abaf3675 100755 --- a/sacred/initialize.py +++ b/sacred/initialize.py @@ -5,24 +5,47 @@ from collections import OrderedDict, defaultdict from copy import copy, deepcopy -from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize, - load_config_file, undogmatize) +from sacred.config import ( + ConfigDict, + chain_evaluate_config_scopes, + dogmatize, + load_config_file, + undogmatize, +) from sacred.config.config_summary import ConfigSummary from sacred.config.custom_containers import make_read_only from sacred.host_info import get_host_info from sacred.randomness import create_rnd, get_seed from sacred.run import Run -from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger, - get_by_dotted_path, is_prefix, rel_path, - iterate_flattened, set_by_dotted_path, - recursive_update, iter_prefixes, join_paths, - NamedConfigNotFoundError, ConfigAddedError) +from sacred.utils import ( + convert_to_nested_dict, + create_basic_stream_logger, + 
get_by_dotted_path, + is_prefix, + rel_path, + iterate_flattened, + set_by_dotted_path, + recursive_update, + iter_prefixes, + join_paths, + NamedConfigNotFoundError, + ConfigAddedError, +) from sacred.settings import SETTINGS class Scaffold: - def __init__(self, config_scopes, subrunners, path, captured_functions, - commands, named_configs, config_hooks, generate_seed): + def __init__( + self, + config_scopes, + subrunners, + path, + captured_functions, + commands, + named_configs, + config_hooks, + generate_seed, + ): self.config_scopes = config_scopes self.named_configs = named_configs self.subrunners = subrunners @@ -42,39 +65,40 @@ def __init__(self, config_scopes, subrunners, path, captured_functions, self.commands = commands self.config_mods = None self.summaries = [] - self.captured_args = {join_paths(cf.prefix, n) - for cf in self._captured_functions - for n in cf.signature.arguments} - self.captured_args.add('__doc__') # allow setting the config docstring + self.captured_args = { + join_paths(cf.prefix, n) + for cf in self._captured_functions + for n in cf.signature.arguments + } + self.captured_args.add("__doc__") # allow setting the config docstring def set_up_seed(self, rnd=None): if self.seed is not None: return - self.seed = self.config.get('seed') + self.seed = self.config.get("seed") if self.seed is None: self.seed = get_seed(rnd) self.rnd = create_rnd(self.seed) if self.generate_seed: - self.config['seed'] = self.seed + self.config["seed"] = self.seed - if 'seed' in self.config and 'seed' in self.config_mods.added: - self.config_mods.modified.add('seed') - self.config_mods.added -= {'seed'} + if "seed" in self.config and "seed" in self.config_mods.added: + self.config_mods.modified.add("seed") + self.config_mods.added -= {"seed"} # Hierarchically set the seed of proper subrunners - for subrunner_path, subrunner in reversed(list( - self.subrunners.items())): + for subrunner_path, subrunner in reversed(list(self.subrunners.items())): if 
is_prefix(self.path, subrunner_path): subrunner.set_up_seed(self.rnd) def gather_fallbacks(self): - fallback = {'_log': self.logger} + fallback = {"_log": self.logger} for sr_path, subrunner in self.subrunners.items(): if self.path and is_prefix(self.path, sr_path): - path = sr_path[len(self.path):].strip('.') + path = sr_path[len(self.path) :].strip(".") set_by_dotted_path(fallback, path, subrunner.config) else: set_by_dotted_path(fallback, sr_path, subrunner.config) @@ -90,12 +114,15 @@ def run_named_config(self, config_name): if config_name not in self.named_configs: raise NamedConfigNotFoundError( named_config=config_name, - available_named_configs=tuple(self.named_configs.keys())) + available_named_configs=tuple(self.named_configs.keys()), + ) nc = self.named_configs[config_name] - cfg = nc(fixed=self.get_config_updates_recursive(), - preset=self.presets, - fallback=self.fallback) + cfg = nc( + fixed=self.get_config_updates_recursive(), + preset=self.presets, + fallback=self.fallback, + ) return undogmatize(cfg) @@ -104,7 +131,8 @@ def set_up_config(self): self.config_scopes, fixed=self.config_updates, preset=self.config, - fallback=self.fallback) + fallback=self.fallback, + ) self.get_config_modifications() @@ -119,8 +147,8 @@ def run_config_hooks(self, config, command_name, logger): def get_config_modifications(self): self.config_mods = ConfigSummary( - added={key - for key, value in iterate_flattened(self.config_updates)}) + added={key for key, value in iterate_flattened(self.config_updates)} + ) for cfg_summary in self.summaries: self.config_mods.update_from(cfg_summary) @@ -146,7 +174,7 @@ def get_fixture_recursive(runner): sub_fix = copy(subrunner.config) sub_path = sr_path if is_prefix(self.path, sub_path): - sub_path = sr_path[len(self.path):].strip('.') + sub_path = sr_path[len(self.path) :].strip(".") # Note: This might fail if we allow non-dict fixtures set_by_dotted_path(self.fixture, sub_path, sub_fix) @@ -158,8 +186,8 @@ def 
get_fixture_recursive(runner): def finalize_initialization(self, run): # look at seed again, because it might have changed during the # configuration process - if 'seed' in self.config: - self.seed = self.config['seed'] + if "seed" in self.config: + self.seed = self.config["seed"] self.rnd = create_rnd(self.seed) for cfunc in self._captured_functions: @@ -168,15 +196,17 @@ def finalize_initialization(self, run): seed = get_seed(self.rnd) cfunc.rnd = create_rnd(seed) cfunc.run = run - cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix, - default={}) + cfunc.config = get_by_dotted_path( + self.get_fixture(), cfunc.prefix, default={} + ) # Make configuration read only if enabled in settings if SETTINGS.CONFIG.READ_ONLY_CONFIG: cfunc.config = make_read_only( cfunc.config, - error_message='The configuration is read-only in a ' - 'captured function!') + error_message="The configuration is read-only in a " + "captured function!", + ) if not run.force: self._warn_about_suspicious_changes() @@ -194,14 +224,15 @@ def _warn_about_suspicious_changes(self): if type_old in (int, float) and type_new in (int, float): continue self.logger.warning( - 'Changed type of config entry "%s" from %s to %s' % - (key, type_old.__name__, type_new.__name__)) + 'Changed type of config entry "%s" from %s to %s' + % (key, type_old.__name__, type_new.__name__) + ) for cfg_summary in self.summaries: for key in cfg_summary.ignored_fallbacks: self.logger.warning( 'Ignored attempt to set value of "%s", because it is an ' - 'ingredient.' % key + "ingredient." 
% key ) def __repr__(self): @@ -223,12 +254,13 @@ def get_configuration(scaffolding): def distribute_named_configs(scaffolding, named_configs): for ncfg in named_configs: if os.path.exists(ncfg): - scaffolding[''].use_named_config(ncfg) + scaffolding[""].use_named_config(ncfg) else: - path, _, cfg_name = ncfg.rpartition('.') + path, _, cfg_name = ncfg.rpartition(".") if path not in scaffolding: - raise KeyError('Ingredient for named config "{}" not found' - .format(ncfg)) + raise KeyError( + 'Ingredient for named config "{}" not found'.format(ncfg) + ) scaffolding[path].use_named_config(cfg_name) @@ -260,34 +292,35 @@ def create_scaffolding(experiment, sorted_ingredients): for ingredient in sorted_ingredients[:-1]: scaffolding[ingredient] = Scaffold( config_scopes=ingredient.configurations, - subrunners=OrderedDict([(scaffolding[m].path, scaffolding[m]) - for m in ingredient.ingredients]), + subrunners=OrderedDict( + [(scaffolding[m].path, scaffolding[m]) for m in ingredient.ingredients] + ), path=ingredient.path, captured_functions=ingredient.captured_functions, commands=ingredient.commands, named_configs=ingredient.named_configs, config_hooks=ingredient.config_hooks, - generate_seed=False) + generate_seed=False, + ) scaffolding[experiment] = Scaffold( experiment.configurations, - subrunners=OrderedDict([(scaffolding[m].path, scaffolding[m]) - for m in experiment.ingredients]), - path='', + subrunners=OrderedDict( + [(scaffolding[m].path, scaffolding[m]) for m in experiment.ingredients] + ), + path="", captured_functions=experiment.captured_functions, commands=experiment.commands, named_configs=experiment.named_configs, config_hooks=experiment.config_hooks, - generate_seed=True) + generate_seed=True, + ) - scaffolding_ret = OrderedDict([ - (sc.path, sc) - for sc in scaffolding.values() - ]) + scaffolding_ret = OrderedDict([(sc.path, sc) for sc in scaffolding.values()]) if len(scaffolding_ret) != len(scaffolding): raise ValueError( - 'The pathes of the ingredients 
are not unique. ' - '{}'.format([s.path for s in scaffolding]) + "The pathes of the ingredients are not unique. " + "{}".format([s.path for s in scaffolding]) ) return scaffolding_ret @@ -308,7 +341,7 @@ def get_config_modifications(scaffolding): def get_command(scaffolding, command_path): - path, _, command_name = command_path.rpartition('.') + path, _, command_name = command_path.rpartition(".") if path not in scaffolding: raise KeyError('Ingredient for command "%s" not found.' % command_path) @@ -316,19 +349,20 @@ def get_command(scaffolding, command_path): return scaffolding[path].commands[command_name] else: if path: - raise KeyError('Command "%s" not found in ingredient "%s"' % - (command_name, path)) + raise KeyError( + 'Command "%s" not found in ingredient "%s"' % (command_name, path) + ) else: raise KeyError('Command "%s" not found' % command_name) def find_best_match(path, prefixes): """Find the Ingredient that shares the longest prefix with path.""" - path_parts = path.split('.') + path_parts = path.split(".") for p in prefixes: - if len(p) <= len(path_parts) and p == path_parts[:len(p)]: - return '.'.join(p), '.'.join(path_parts[len(p):]) - return '', path + if len(p) <= len(path_parts) and p == path_parts[: len(p)]: + return ".".join(p), ".".join(path_parts[len(p) :]) + return "", path def distribute_presets(prefixes, scaffolding, config_updates): @@ -347,33 +381,42 @@ def distribute_config_updates(prefixes, scaffolding, config_updates): def get_scaffolding_and_config_name(named_config, scaffolding): if os.path.exists(named_config): - path, cfg_name = '', named_config + path, cfg_name = "", named_config else: - path, _, cfg_name = named_config.rpartition('.') + path, _, cfg_name = named_config.rpartition(".") if path not in scaffolding: - raise KeyError('Ingredient for named config "{}" not found' - .format(named_config)) + raise KeyError( + 'Ingredient for named config "{}" not found'.format(named_config) + ) scaff = scaffolding[path] return scaff, 
cfg_name -def create_run(experiment, command_name, config_updates=None, - named_configs=(), force=False, log_level=None): +def create_run( + experiment, + command_name, + config_updates=None, + named_configs=(), + force=False, + log_level=None, +): sorted_ingredients = gather_ingredients_topological(experiment) scaffolding = create_scaffolding(experiment, sorted_ingredients) # get all split non-empty prefixes sorted from deepest to shallowest - prefixes = sorted([s.split('.') for s in scaffolding if s != ''], - reverse=True, key=lambda p: len(p)) + prefixes = sorted( + [s.split(".") for s in scaffolding if s != ""], + reverse=True, + key=lambda p: len(p), + ) # --------- configuration process ------------------- # Phase 1: Config updates config_updates = config_updates or {} config_updates = convert_to_nested_dict(config_updates) - root_logger, run_logger = initialize_logging(experiment, scaffolding, - log_level) + root_logger, run_logger = initialize_logging(experiment, scaffolding, log_level) distribute_config_updates(prefixes, scaffolding, config_updates) # Phase 2: Named Configs @@ -383,9 +426,7 @@ def create_run(experiment, command_name, config_updates=None, ncfg_updates = scaff.run_named_config(cfg_name) distribute_presets(prefixes, scaffolding, ncfg_updates) for ncfg_key, value in iterate_flattened(ncfg_updates): - set_by_dotted_path(config_updates, - join_paths(scaff.path, ncfg_key), - value) + set_by_dotted_path(config_updates, join_paths(scaff.path, ncfg_key), value) distribute_config_updates(prefixes, scaffolding, config_updates) @@ -398,7 +439,8 @@ def create_run(experiment, command_name, config_updates=None, config = get_configuration(scaffolding) # run config hooks config_hook_updates = scaffold.run_config_hooks( - config, command_name, run_logger) + config, command_name, run_logger + ) recursive_update(scaffold.config, config_hook_updates) # Phase 4: finalize seeding @@ -416,12 +458,21 @@ def create_run(experiment, command_name, config_updates=None, 
pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks] post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks] - run = Run(config, config_modifications, main_function, - copy(experiment.observers), root_logger, run_logger, - experiment_info, host_info, pre_runs, post_runs, - experiment.captured_out_filter) - - if hasattr(main_function, 'unobserved'): + run = Run( + config, + config_modifications, + main_function, + copy(experiment.observers), + root_logger, + run_logger, + experiment_info, + host_info, + pre_runs, + post_runs, + experiment.captured_out_filter, + ) + + if hasattr(main_function, "unobserved"): run.unobserved = main_function.unobserved run.force = force diff --git a/sacred/metrics_logger.py b/sacred/metrics_logger.py index ff51c7e1..ca975760 100644 --- a/sacred/metrics_logger.py +++ b/sacred/metrics_logger.py @@ -45,9 +45,8 @@ def log_scalar_metric(self, metric_name, value, step=None): if step is None: step = self._metric_step_counter.get(metric_name, -1) + 1 self._logged_metrics.put( - ScalarMetricLogEntry(metric_name, step, - datetime.datetime.utcnow(), - value)) + ScalarMetricLogEntry(metric_name, step, datetime.datetime.utcnow(), value) + ) self._metric_step_counter[metric_name] = step def get_last_metrics(self): @@ -65,7 +64,7 @@ def get_last_metrics(self): return messages -class ScalarMetricLogEntry(): +class ScalarMetricLogEntry: """Container for measurements of scalar metrics. There is exactly one ScalarMetricLogEntry per logged scalar metric value. 
@@ -98,12 +97,9 @@ def linearize_metrics(logged_metrics): "steps": [], "values": [], "timestamps": [], - "name": metric_entry.name + "name": metric_entry.name, } - metrics_by_name[metric_entry.name]["steps"] \ - .append(metric_entry.step) - metrics_by_name[metric_entry.name]["values"] \ - .append(metric_entry.value) - metrics_by_name[metric_entry.name]["timestamps"] \ - .append(metric_entry.timestamp) + metrics_by_name[metric_entry.name]["steps"].append(metric_entry.step) + metrics_by_name[metric_entry.name]["values"].append(metric_entry.value) + metrics_by_name[metric_entry.name]["timestamps"].append(metric_entry.timestamp) return metrics_by_name diff --git a/sacred/observers/__init__.py b/sacred/observers/__init__.py index d0b81b52..31f6e9aa 100644 --- a/sacred/observers/__init__.py +++ b/sacred/observers/__init__.py @@ -10,6 +10,13 @@ from sacred.observers.telegram_obs import TelegramObserver -__all__ = ('FileStorageObserver', 'RunObserver', 'MongoObserver', - 'SqlObserver', 'TinyDbObserver', 'TinyDbReader', - 'SlackObserver', 'TelegramObserver') +__all__ = ( + "FileStorageObserver", + "RunObserver", + "MongoObserver", + "SqlObserver", + "TinyDbObserver", + "TinyDbReader", + "SlackObserver", + "TelegramObserver", +) diff --git a/sacred/observers/base.py b/sacred/observers/base.py index c55d55f9..0f3411cc 100644 --- a/sacred/observers/base.py +++ b/sacred/observers/base.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 -__all__ = ('RunObserver', 'td_format') +__all__ = ("RunObserver", "td_format") class RunObserver: @@ -9,12 +9,14 @@ class RunObserver: priority = 0 - def queued_event(self, ex_info, command, host_info, queue_time, config, - meta_info, _id): + def queued_event( + self, ex_info, command, host_info, queue_time, config, meta_info, _id + ): pass - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): pass def 
heartbeat_event(self, info, captured_out, beat_time, result): @@ -43,12 +45,12 @@ def td_format(td_object): return "less than a second" periods = [ - ('year', 60 * 60 * 24 * 365), - ('month', 60 * 60 * 24 * 30), - ('day', 60 * 60 * 24), - ('hour', 60 * 60), - ('minute', 60), - ('second', 1) + ("year", 60 * 60 * 24 * 365), + ("month", 60 * 60 * 24 * 30), + ("day", 60 * 60 * 24), + ("hour", 60 * 60), + ("minute", 60), + ("second", 1), ] strings = [] diff --git a/sacred/observers/file_storage.py b/sacred/observers/file_storage.py index aef84924..4c7de727 100644 --- a/sacred/observers/file_storage.py +++ b/sacred/observers/file_storage.py @@ -21,29 +21,39 @@ class FileStorageObserver(RunObserver): - VERSION = 'FileStorageObserver-0.7.0' + VERSION = "FileStorageObserver-0.7.0" @classmethod - def create(cls, basedir: PathType, - resource_dir: Optional[PathType] = None, - source_dir: Optional[PathType] = None, - template: Optional[PathType] = None, - priority: int = DEFAULT_FILE_STORAGE_PRIORITY): + def create( + cls, + basedir: PathType, + resource_dir: Optional[PathType] = None, + source_dir: Optional[PathType] = None, + template: Optional[PathType] = None, + priority: int = DEFAULT_FILE_STORAGE_PRIORITY, + ): basedir = Path(basedir) - resource_dir = resource_dir or basedir / '_resources' - source_dir = source_dir or basedir / '_sources' + resource_dir = resource_dir or basedir / "_resources" + source_dir = source_dir or basedir / "_sources" if template is not None: if not os.path.exists(template): - raise FileNotFoundError("Couldn't find template file '{}'" - .format(template)) + raise FileNotFoundError( + "Couldn't find template file '{}'".format(template) + ) else: - template = basedir / 'template.html' + template = basedir / "template.html" if not template.exists(): template = None return cls(basedir, resource_dir, source_dir, template, priority) - def __init__(self, basedir, resource_dir, source_dir, template, - priority=DEFAULT_FILE_STORAGE_PRIORITY): + def 
__init__( + self, + basedir, + resource_dir, + source_dir, + template, + priority=DEFAULT_FILE_STORAGE_PRIORITY, + ): self.basedir = str(basedir) self.resource_dir = resource_dir self.source_dir = source_dir @@ -57,9 +67,11 @@ def __init__(self, basedir, resource_dir, source_dir, template, self.cout_write_cursor = 0 def _maximum_existing_run_id(self): - dir_nrs = [int(d) for d in os.listdir(self.basedir) - if os.path.isdir(os.path.join(self.basedir, d)) and - d.isdigit()] + dir_nrs = [ + int(d) + for d in os.listdir(self.basedir) + if os.path.isdir(os.path.join(self.basedir, d)) and d.isdigit() + ] if dir_nrs: return max(dir_nrs) else: @@ -89,32 +101,33 @@ def _make_run_dir(self, _id): self.dir = os.path.join(self.basedir, str(_id)) os.mkdir(self.dir) - def queued_event(self, ex_info, command, host_info, queue_time, config, - meta_info, _id): + def queued_event( + self, ex_info, command, host_info, queue_time, config, meta_info, _id + ): self._make_run_dir(_id) self.run_entry = { - 'experiment': dict(ex_info), - 'command': command, - 'host': dict(host_info), - 'meta': meta_info, - 'status': 'QUEUED', + "experiment": dict(ex_info), + "command": command, + "host": dict(host_info), + "meta": meta_info, + "status": "QUEUED", } self.config = config self.info = {} - self.save_json(self.run_entry, 'run.json') - self.save_json(self.config, 'config.json') + self.save_json(self.run_entry, "run.json") + self.save_json(self.config, "config.json") - for s, m in ex_info['sources']: + for s, m in ex_info["sources"]: self.save_file(s) return os.path.relpath(self.dir, self.basedir) if _id is None else _id def save_sources(self, ex_info): - base_dir = ex_info['base_dir'] + base_dir = ex_info["base_dir"] source_info = [] - for s, m in ex_info['sources']: + for s, m in ex_info["sources"]: abspath = os.path.join(base_dir, s) store_path, md5sum = self.find_or_save(abspath, self.source_dir) # assert m == md5sum @@ -122,30 +135,31 @@ def save_sources(self, ex_info): source_info.append([s, 
relative_source]) return source_info - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): self._make_run_dir(_id) - ex_info['sources'] = self.save_sources(ex_info) + ex_info["sources"] = self.save_sources(ex_info) self.run_entry = { - 'experiment': dict(ex_info), - 'command': command, - 'host': dict(host_info), - 'start_time': start_time.isoformat(), - 'meta': meta_info, - 'status': 'RUNNING', - 'resources': [], - 'artifacts': [], - 'heartbeat': None + "experiment": dict(ex_info), + "command": command, + "host": dict(host_info), + "start_time": start_time.isoformat(), + "meta": meta_info, + "status": "RUNNING", + "resources": [], + "artifacts": [], + "heartbeat": None, } self.config = config self.info = {} self.cout = "" self.cout_write_cursor = 0 - self.save_json(self.run_entry, 'run.json') - self.save_json(self.config, 'config.json') + self.save_json(self.run_entry, "run.json") + self.save_json(self.config, "config.json") self.save_cout() return os.path.relpath(self.dir, self.basedir) if _id is None else _id @@ -154,14 +168,14 @@ def find_or_save(self, filename, store_dir: Path): os.makedirs(str(store_dir), exist_ok=True) source_name, ext = os.path.splitext(os.path.basename(filename)) md5sum = get_digest(filename) - store_name = source_name + '_' + md5sum + ext + store_name = source_name + "_" + md5sum + ext store_path = store_dir / store_name if not store_path.exists(): copyfile(filename, str(store_path)) return store_path, md5sum def save_json(self, obj, filename): - with open(os.path.join(self.dir, filename), 'w') as f: + with open(os.path.join(self.dir, filename), "w") as f: json.dump(flatten(obj), f, sort_keys=True, indent=2) def save_file(self, filename, target_name=None): @@ -169,70 +183,73 @@ def save_file(self, filename, target_name=None): copyfile(filename, os.path.join(self.dir, target_name)) def save_cout(self): - 
with open(os.path.join(self.dir, 'cout.txt'), 'ab') as f: - f.write(self.cout[self.cout_write_cursor:].encode("utf-8")) + with open(os.path.join(self.dir, "cout.txt"), "ab") as f: + f.write(self.cout[self.cout_write_cursor :].encode("utf-8")) self.cout_write_cursor = len(self.cout) def render_template(self): if opt.has_mako and self.template: from mako.template import Template + template = Template(filename=self.template) - report = template.render(run=self.run_entry, - config=self.config, - info=self.info, - cout=self.cout, - savedir=self.dir) + report = template.render( + run=self.run_entry, + config=self.config, + info=self.info, + cout=self.cout, + savedir=self.dir, + ) ext = self.template.suffix - with open(os.path.join(self.dir, 'report' + ext), 'w') as f: + with open(os.path.join(self.dir, "report" + ext), "w") as f: f.write(report) def heartbeat_event(self, info, captured_out, beat_time, result): self.info = info - self.run_entry['heartbeat'] = beat_time.isoformat() - self.run_entry['result'] = result + self.run_entry["heartbeat"] = beat_time.isoformat() + self.run_entry["result"] = result self.cout = captured_out self.save_cout() - self.save_json(self.run_entry, 'run.json') + self.save_json(self.run_entry, "run.json") if self.info: - self.save_json(self.info, 'info.json') + self.save_json(self.info, "info.json") def completed_event(self, stop_time, result): - self.run_entry['stop_time'] = stop_time.isoformat() - self.run_entry['result'] = result - self.run_entry['status'] = 'COMPLETED' + self.run_entry["stop_time"] = stop_time.isoformat() + self.run_entry["result"] = result + self.run_entry["status"] = "COMPLETED" - self.save_json(self.run_entry, 'run.json') + self.save_json(self.run_entry, "run.json") self.render_template() def interrupted_event(self, interrupt_time, status): - self.run_entry['stop_time'] = interrupt_time.isoformat() - self.run_entry['status'] = status - self.save_json(self.run_entry, 'run.json') + self.run_entry["stop_time"] = 
interrupt_time.isoformat() + self.run_entry["status"] = status + self.save_json(self.run_entry, "run.json") self.render_template() def failed_event(self, fail_time, fail_trace): - self.run_entry['stop_time'] = fail_time.isoformat() - self.run_entry['status'] = 'FAILED' - self.run_entry['fail_trace'] = fail_trace - self.save_json(self.run_entry, 'run.json') + self.run_entry["stop_time"] = fail_time.isoformat() + self.run_entry["status"] = "FAILED" + self.run_entry["fail_trace"] = fail_trace + self.save_json(self.run_entry, "run.json") self.render_template() def resource_event(self, filename): store_path, md5sum = self.find_or_save(filename, self.resource_dir) - self.run_entry['resources'].append([filename, str(store_path)]) - self.save_json(self.run_entry, 'run.json') + self.run_entry["resources"].append([filename, str(store_path)]) + self.save_json(self.run_entry, "run.json") def artifact_event(self, name, filename, metadata=None, content_type=None): self.save_file(filename, name) - self.run_entry['artifacts'].append(name) - self.save_json(self.run_entry, 'run.json') + self.run_entry["artifacts"].append(name) + self.save_json(self.run_entry, "run.json") def log_metrics(self, metrics_by_name, info): """Store new measurements into metrics.json. """ try: metrics_path = os.path.join(self.dir, "metrics.json") - with open(metrics_path, 'r') as f: + with open(metrics_path, "r") as f: saved_metrics = json.load(f) except IOError: # We haven't recorded anything yet. Start Collecting. 
@@ -241,20 +258,21 @@ def log_metrics(self, metrics_by_name, info): for metric_name, metric_ptr in metrics_by_name.items(): if metric_name not in saved_metrics: - saved_metrics[metric_name] = {"values": [], - "steps": [], - "timestamps": []} + saved_metrics[metric_name] = { + "values": [], + "steps": [], + "timestamps": [], + } saved_metrics[metric_name]["values"] += metric_ptr["values"] saved_metrics[metric_name]["steps"] += metric_ptr["steps"] # Manually convert them to avoid passing a datetime dtype handler # when we're trying to convert into json. - timestamps_norm = [ts.isoformat() - for ts in metric_ptr["timestamps"]] + timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]] saved_metrics[metric_name]["timestamps"] += timestamps_norm - self.save_json(saved_metrics, 'metrics.json') + self.save_json(saved_metrics, "metrics.json") def __eq__(self, other): if isinstance(other, FileStorageObserver): @@ -265,8 +283,8 @@ def __eq__(self, other): class FileStorageOption(CommandLineOption): """Add a file-storage observer to the experiment.""" - short_flag = 'F' - arg = 'BASEDIR' + short_flag = "F" + arg = "BASEDIR" arg_description = "Base-directory to write the runs to" @classmethod diff --git a/sacred/observers/mongo.py b/sacred/observers/mongo.py index 5e3b8c70..3b95f70b 100644 --- a/sacred/observers/mongo.py +++ b/sacred/observers/mongo.py @@ -25,42 +25,58 @@ def force_valid_bson_key(key): key = str(key) - if key.startswith('$'): - key = '@' + key[1:] - key = key.replace('.', ',') + if key.startswith("$"): + key = "@" + key[1:] + key = key.replace(".", ",") return key def force_bson_encodeable(obj): import bson + if isinstance(obj, dict): try: bson.BSON.encode(obj, check_keys=True) return obj except bson.InvalidDocument: - return {force_valid_bson_key(k): force_bson_encodeable(v) - for k, v in obj.items()} + return { + force_valid_bson_key(k): force_bson_encodeable(v) + for k, v in obj.items() + } elif opt.has_numpy and isinstance(obj, 
opt.np.ndarray): return obj else: try: - bson.BSON.encode({'dict_just_for_testing': obj}) + bson.BSON.encode({"dict_just_for_testing": obj}) return obj except bson.InvalidDocument: return str(obj) class MongoObserver(RunObserver): - COLLECTION_NAME_BLACKLIST = {'fs.files', 'fs.chunks', '_properties', - 'system.indexes', 'search_space', - 'search_spaces'} - VERSION = 'MongoObserver-0.7.0' + COLLECTION_NAME_BLACKLIST = { + "fs.files", + "fs.chunks", + "_properties", + "system.indexes", + "search_space", + "search_spaces", + } + VERSION = "MongoObserver-0.7.0" @classmethod - def create(cls, url=None, db_name='sacred', collection='runs', - overwrite=None, priority=DEFAULT_MONGO_PRIORITY, - client=None, failure_dir=None, **kwargs): + def create( + cls, + url=None, + db_name="sacred", + collection="runs", + overwrite=None, + priority=DEFAULT_MONGO_PRIORITY, + client=None, + failure_dir=None, + **kwargs + ): """Factory method for MongoObserver. Parameters @@ -82,38 +98,51 @@ def create(cls, url=None, db_name='sacred', collection='runs', if client is not None: if not isinstance(client, pymongo.MongoClient): - raise ValueError("client needs to be a pymongo.MongoClient, " - "but is {} instead".format(type(client))) + raise ValueError( + "client needs to be a pymongo.MongoClient, " + "but is {} instead".format(type(client)) + ) if url is not None: - raise ValueError('Cannot pass both a client and a url.') + raise ValueError("Cannot pass both a client and a url.") else: client = pymongo.MongoClient(url, **kwargs) database = client[db_name] if collection in MongoObserver.COLLECTION_NAME_BLACKLIST: - raise KeyError('Collection name "{}" is reserved. ' - 'Please use a different one.'.format(collection)) + raise KeyError( + 'Collection name "{}" is reserved. 
' + "Please use a different one.".format(collection) + ) runs_collection = database[collection] metrics_collection = database["metrics"] fs = gridfs.GridFS(database) - return cls(runs_collection, - fs, overwrite=overwrite, - metrics_collection=metrics_collection, - failure_dir=failure_dir, - priority=priority) - - def __init__(self, runs_collection, - fs, overwrite=None, metrics_collection=None, - failure_dir=None, - priority=DEFAULT_MONGO_PRIORITY): + return cls( + runs_collection, + fs, + overwrite=overwrite, + metrics_collection=metrics_collection, + failure_dir=failure_dir, + priority=priority, + ) + + def __init__( + self, + runs_collection, + fs, + overwrite=None, + metrics_collection=None, + failure_dir=None, + priority=DEFAULT_MONGO_PRIORITY, + ): self.runs = runs_collection self.metrics = metrics_collection self.fs = fs if isinstance(overwrite, (int, str)): overwrite = int(overwrite) - run = self.runs.find_one({'_id': overwrite}) + run = self.runs.find_one({"_id": overwrite}) if run is None: - raise RuntimeError("Couldn't find run to overwrite with " - "_id='{}'".format(overwrite)) + raise RuntimeError( + "Couldn't find run to overwrite with " "_id='{}'".format(overwrite) + ) else: overwrite = run self.overwrite = overwrite @@ -121,79 +150,83 @@ def __init__(self, runs_collection, self.priority = priority self.failure_dir = failure_dir - def queued_event(self, ex_info, command, host_info, queue_time, config, - meta_info, _id): + def queued_event( + self, ex_info, command, host_info, queue_time, config, meta_info, _id + ): if self.overwrite is not None: raise RuntimeError("Can't overwrite with QUEUED run.") self.run_entry = { - 'experiment': dict(ex_info), - 'command': command, - 'host': dict(host_info), - 'config': flatten(config), - 'meta': meta_info, - 'status': 'QUEUED' + "experiment": dict(ex_info), + "command": command, + "host": dict(host_info), + "config": flatten(config), + "meta": meta_info, + "status": "QUEUED", } # set ID if given if _id is not 
None: - self.run_entry['_id'] = _id + self.run_entry["_id"] = _id # save sources - self.run_entry['experiment']['sources'] = self.save_sources(ex_info) + self.run_entry["experiment"]["sources"] = self.save_sources(ex_info) self.insert() - return self.run_entry['_id'] + return self.run_entry["_id"] - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): if self.overwrite is None: - self.run_entry = {'_id': _id} + self.run_entry = {"_id": _id} else: if self.run_entry is not None: raise RuntimeError("Cannot overwrite more than once!") # TODO sanity checks self.run_entry = self.overwrite - self.run_entry.update({ - 'experiment': dict(ex_info), - 'format': self.VERSION, - 'command': command, - 'host': dict(host_info), - 'start_time': start_time, - 'config': flatten(config), - 'meta': meta_info, - 'status': 'RUNNING', - 'resources': [], - 'artifacts': [], - 'captured_out': '', - 'info': {}, - 'heartbeat': None - }) + self.run_entry.update( + { + "experiment": dict(ex_info), + "format": self.VERSION, + "command": command, + "host": dict(host_info), + "start_time": start_time, + "config": flatten(config), + "meta": meta_info, + "status": "RUNNING", + "resources": [], + "artifacts": [], + "captured_out": "", + "info": {}, + "heartbeat": None, + } + ) # save sources - self.run_entry['experiment']['sources'] = self.save_sources(ex_info) + self.run_entry["experiment"]["sources"] = self.save_sources(ex_info) self.insert() - return self.run_entry['_id'] + return self.run_entry["_id"] def heartbeat_event(self, info, captured_out, beat_time, result): - self.run_entry['info'] = flatten(info) - self.run_entry['captured_out'] = captured_out - self.run_entry['heartbeat'] = beat_time - self.run_entry['result'] = flatten(result) + self.run_entry["info"] = flatten(info) + self.run_entry["captured_out"] = captured_out + self.run_entry["heartbeat"] = 
beat_time + self.run_entry["result"] = flatten(result) self.save() def completed_event(self, stop_time, result): - self.run_entry['stop_time'] = stop_time - self.run_entry['result'] = flatten(result) - self.run_entry['status'] = 'COMPLETED' + self.run_entry["stop_time"] = stop_time + self.run_entry["result"] = flatten(result) + self.run_entry["status"] = "COMPLETED" self.final_save(attempts=10) def interrupted_event(self, interrupt_time, status): - self.run_entry['stop_time'] = interrupt_time - self.run_entry['status'] = status + self.run_entry["stop_time"] = interrupt_time + self.run_entry["status"] = status self.final_save(attempts=3) def failed_event(self, fail_time, fail_trace): - self.run_entry['stop_time'] = fail_time - self.run_entry['status'] = 'FAILED' - self.run_entry['fail_trace'] = fail_trace + self.run_entry["stop_time"] = fail_time + self.run_entry["status"] = "FAILED" + self.run_entry["fail_trace"] = fail_trace self.final_save(attempts=1) def resource_event(self, filename): @@ -201,40 +234,42 @@ def resource_event(self, filename): md5hash = get_digest(filename) if self.fs.exists(filename=filename, md5=md5hash): resource = (filename, md5hash) - if resource not in self.run_entry['resources']: - self.run_entry['resources'].append(resource) + if resource not in self.run_entry["resources"]: + self.run_entry["resources"].append(resource) self.save() return - with open(filename, 'rb') as f: + with open(filename, "rb") as f: file_id = self.fs.put(f, filename=filename) md5hash = self.fs.get(file_id).md5 - self.run_entry['resources'].append((filename, md5hash)) + self.run_entry["resources"].append((filename, md5hash)) self.save() def artifact_event(self, name, filename, metadata=None, content_type=None): - with open(filename, 'rb') as f: - run_id = self.run_entry['_id'] - db_filename = 'artifact://{}/{}/{}'.format(self.runs.name, run_id, - name) + with open(filename, "rb") as f: + run_id = self.run_entry["_id"] + db_filename = 
"artifact://{}/{}/{}".format(self.runs.name, run_id, name) if content_type is None: content_type = self._try_to_detect_content_type(filename) - file_id = self.fs.put(f, filename=db_filename, - metadata=metadata, content_type=content_type) + file_id = self.fs.put( + f, filename=db_filename, metadata=metadata, content_type=content_type + ) - self.run_entry['artifacts'].append({'name': name, - 'file_id': file_id}) + self.run_entry["artifacts"].append({"name": name, "file_id": file_id}) self.save() @staticmethod def _try_to_detect_content_type(filename): mime_type, _ = mimetypes.guess_type(filename) if mime_type is not None: - print('Added {} as content-type of artifact {}.'.format( - mime_type, filename)) + print( + "Added {} as content-type of artifact {}.".format(mime_type, filename) + ) else: - print('Failed to detect content-type automatically for ' - 'artifact {}.'.format(filename)) + print( + "Failed to detect content-type automatically for " + "artifact {}.".format(filename) + ) return mime_type def log_metrics(self, metrics_by_name, info): @@ -250,18 +285,19 @@ def log_metrics(self, metrics_by_name, info): # do not try to save anything there. 
return for key in metrics_by_name: - query = {"run_id": self.run_entry['_id'], - "name": key} - push = {"steps": {"$each": metrics_by_name[key]["steps"]}, - "values": {"$each": metrics_by_name[key]["values"]}, - "timestamps": {"$each": metrics_by_name[key]["timestamps"]} - } + query = {"run_id": self.run_entry["_id"], "name": key} + push = { + "steps": {"$each": metrics_by_name[key]["steps"]}, + "values": {"$each": metrics_by_name[key]["values"]}, + "timestamps": {"$each": metrics_by_name[key]["timestamps"]}, + } update = {"$push": push} result = self.metrics.update_one(query, update, upsert=True) if result.upserted_id is not None: # This is the first time we are storing this metric - info.setdefault("metrics", []) \ - .append({"name": key, "id": str(result.upserted_id)}) + info.setdefault("metrics", []).append( + {"name": key, "id": str(result.upserted_id)} + ) def insert(self): import pymongo.errors @@ -269,18 +305,20 @@ def insert(self): if self.overwrite: return self.save() - autoinc_key = self.run_entry.get('_id') is None + autoinc_key = self.run_entry.get("_id") is None while True: if autoinc_key: - c = self.runs.find({}, {'_id': 1}) - c = c.sort('_id', pymongo.DESCENDING).limit(1) - self.run_entry['_id'] = c.next()['_id'] + 1 if c.count() else 1 + c = self.runs.find({}, {"_id": 1}) + c = c.sort("_id", pymongo.DESCENDING).limit(1) + self.run_entry["_id"] = c.next()["_id"] + 1 if c.count() else 1 try: self.runs.insert_one(self.run_entry) return except pymongo.errors.InvalidDocument as e: - raise ObserverError('Run contained an unserializable entry.' - '(most likely in the info)\n{}'.format(e)) + raise ObserverError( + "Run contained an unserializable entry." 
+ "(most likely in the info)\n{}".format(e) + ) except pymongo.errors.DuplicateKeyError: if not autoinc_key: raise @@ -289,21 +327,26 @@ def save(self): import pymongo.errors try: - self.runs.update_one({'_id': self.run_entry['_id']}, - {'$set': self.run_entry}) + self.runs.update_one( + {"_id": self.run_entry["_id"]}, {"$set": self.run_entry} + ) except pymongo.errors.AutoReconnect: pass # just wait for the next save except pymongo.errors.InvalidDocument: - raise ObserverError('Run contained an unserializable entry.' - '(most likely in the info)') + raise ObserverError( + "Run contained an unserializable entry." "(most likely in the info)" + ) def final_save(self, attempts): import pymongo.errors for i in range(attempts): try: - self.runs.update_one({'_id': self.run_entry['_id']}, - {'$set': self.run_entry}, upsert=True) + self.runs.update_one( + {"_id": self.run_entry["_id"]}, + {"$set": self.run_entry}, + upsert=True, + ) return except pymongo.errors.AutoReconnect: if i < attempts - 1: @@ -312,33 +355,38 @@ def final_save(self, attempts): pass except pymongo.errors.InvalidDocument: self.run_entry = force_bson_encodeable(self.run_entry) - print("Warning: Some of the entries of the run were not " - "BSON-serializable!\n They have been altered such that " - "they can be stored, but you should fix your experiment!" - "Most likely it is either the 'info' or the 'result'.", - file=sys.stderr) + print( + "Warning: Some of the entries of the run were not " + "BSON-serializable!\n They have been altered such that " + "they can be stored, but you should fix your experiment!" 
+ "Most likely it is either the 'info' or the 'result'.", + file=sys.stderr, + ) os.makedirs(self.failure_dir, exist_ok=True) - with NamedTemporaryFile(suffix='.pickle', delete=False, - prefix='sacred_mongo_fail_{}_'.format( - self.run_entry["_id"] - ), - dir=self.failure_dir) as f: + with NamedTemporaryFile( + suffix=".pickle", + delete=False, + prefix="sacred_mongo_fail_{}_".format(self.run_entry["_id"]), + dir=self.failure_dir, + ) as f: pickle.dump(self.run_entry, f) - print("Warning: saving to MongoDB failed! " - "Stored experiment entry in '{}'".format(f.name), - file=sys.stderr) + print( + "Warning: saving to MongoDB failed! " + "Stored experiment entry in '{}'".format(f.name), + file=sys.stderr, + ) def save_sources(self, ex_info): - base_dir = ex_info['base_dir'] + base_dir = ex_info["base_dir"] source_info = [] - for source_name, md5 in ex_info['sources']: + for source_name, md5 in ex_info["sources"]: abs_path = os.path.join(base_dir, source_name) - file = self.fs.find_one({'filename': abs_path, 'md5': md5}) + file = self.fs.find_one({"filename": abs_path, "md5": md5}) if file: _id = file._id else: - with open(abs_path, 'rb') as f: + with open(abs_path, "rb") as f: _id = self.fs.put(f, filename=abs_path) source_info.append([source_name, _id]) return source_info @@ -352,36 +400,47 @@ def __eq__(self, other): class MongoDbOption(CommandLineOption): """Add a MongoDB Observer to the experiment.""" - arg = 'DB' - arg_description = "Database specification. Can be " \ - "[host:port:]db_name[.collection[:id]][!priority]" + arg = "DB" + arg_description = ( + "Database specification. Can be " + "[host:port:]db_name[.collection[:id]][!priority]" + ) RUN_ID_PATTERN = r"(?P\d{1,12})" PORT1_PATTERN = r"(?P\d{1,5})" PORT2_PATTERN = r"(?P\d{1,5})" PRIORITY_PATTERN = r"(?P-?\d+)?" 
- DB_NAME_PATTERN = r"(?P[_A-Za-z]" \ - r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})" - COLL_NAME_PATTERN = r"(?P[_A-Za-z]" \ - r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})" - HOSTNAME1_PATTERN = r"(?P" \ - r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \ - r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" \ - r"[0-9A-Za-z])?)*)" - HOSTNAME2_PATTERN = r"(?P" \ - r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \ - r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" \ - r"[0-9A-Za-z])?)*)" - - HOST_ONLY = r"^(?:{host}:{port})$".format(host=HOSTNAME1_PATTERN, - port=PORT1_PATTERN) - FULL = r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?" \ - r"(?:!{priority})?$".format(host=HOSTNAME2_PATTERN, - port=PORT2_PATTERN, - db=DB_NAME_PATTERN, - collection=COLL_NAME_PATTERN, - rid=RUN_ID_PATTERN, - priority=PRIORITY_PATTERN) + DB_NAME_PATTERN = r"(?P[_A-Za-z]" r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})" + COLL_NAME_PATTERN = ( + r"(?P[_A-Za-z]" r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})" + ) + HOSTNAME1_PATTERN = ( + r"(?P" + r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" + r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" + r"[0-9A-Za-z])?)*)" + ) + HOSTNAME2_PATTERN = ( + r"(?P" + r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" + r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" + r"[0-9A-Za-z])?)*)" + ) + + HOST_ONLY = r"^(?:{host}:{port})$".format( + host=HOSTNAME1_PATTERN, port=PORT1_PATTERN + ) + FULL = ( + r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?" 
+ r"(?:!{priority})?$".format( + host=HOSTNAME2_PATTERN, + port=PORT2_PATTERN, + db=DB_NAME_PATTERN, + collection=COLL_NAME_PATTERN, + rid=RUN_ID_PATTERN, + priority=PRIORITY_PATTERN, + ) + ) PATTERN = r"{host_only}|{full}".format(host_only=HOST_ONLY, full=FULL) @@ -395,20 +454,21 @@ def apply(cls, args, run): def parse_mongo_db_arg(cls, mongo_db): g = re.match(cls.PATTERN, mongo_db).groupdict() if g is None: - raise ValueError('mongo_db argument must have the form "db_name" ' - 'or "host:port[:db_name]" but was {}' - .format(mongo_db)) + raise ValueError( + 'mongo_db argument must have the form "db_name" ' + 'or "host:port[:db_name]" but was {}'.format(mongo_db) + ) kwargs = {} - if g['host1']: - kwargs['url'] = '{}:{}'.format(g['host1'], g['port1']) - elif g['host2']: - kwargs['url'] = '{}:{}'.format(g['host2'], g['port2']) + if g["host1"]: + kwargs["url"] = "{}:{}".format(g["host1"], g["port1"]) + elif g["host2"]: + kwargs["url"] = "{}:{}".format(g["host2"], g["port2"]) - if g['priority'] is not None: - kwargs['priority'] = int(g['priority']) + if g["priority"] is not None: + kwargs["priority"] = int(g["priority"]) - for p in ['db_name', 'collection', 'overwrite']: + for p in ["db_name", "collection", "overwrite"]: if g[p] is not None: kwargs[p] = g[p] @@ -416,7 +476,6 @@ def parse_mongo_db_arg(cls, mongo_db): class QueueCompatibleMongoObserver(MongoObserver): - def log_metrics(self, metric_name, metrics_values, info): """Store new measurements to the database. @@ -429,63 +488,88 @@ def log_metrics(self, metric_name, metrics_values, info): # If, for whatever reason, the metrics collection has not been set # do not try to save anything there. 
return - query = {"run_id": self.run_entry['_id'], - "name": metric_name} - push = {"steps": {"$each": metrics_values["steps"]}, - "values": {"$each": metrics_values["values"]}, - "timestamps": {"$each": metrics_values["timestamps"]}} + query = {"run_id": self.run_entry["_id"], "name": metric_name} + push = { + "steps": {"$each": metrics_values["steps"]}, + "values": {"$each": metrics_values["values"]}, + "timestamps": {"$each": metrics_values["timestamps"]}, + } update = {"$push": push} result = self.metrics.update_one(query, update, upsert=True) if result.upserted_id is not None: # This is the first time we are storing this metric - info.setdefault("metrics", []) \ - .append({"name": metric_name, "id": str(result.upserted_id)}) + info.setdefault("metrics", []).append( + {"name": metric_name, "id": str(result.upserted_id)} + ) def save(self): import pymongo + try: - self.runs.update_one({'_id': self.run_entry['_id']}, - {'$set': self.run_entry}) + self.runs.update_one( + {"_id": self.run_entry["_id"]}, {"$set": self.run_entry} + ) except pymongo.errors.InvalidDocument: - raise ObserverError('Run contained an unserializable entry.' - '(most likely in the info)') + raise ObserverError( + "Run contained an unserializable entry." "(most likely in the info)" + ) def final_save(self, attempts): import pymongo + try: - self.runs.update_one({'_id': self.run_entry['_id']}, - {'$set': self.run_entry}, upsert=True) + self.runs.update_one( + {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}, upsert=True + ) return except pymongo.errors.InvalidDocument: self.run_entry = force_bson_encodeable(self.run_entry) - print("Warning: Some of the entries of the run were not " - "BSON-serializable!\n They have been altered such that " - "they can be stored, but you should fix your experiment!" 
- "Most likely it is either the 'info' or the 'result'.", - file=sys.stderr) - - with NamedTemporaryFile(suffix='.pickle', delete=False, - prefix='sacred_mongo_fail_') as f: + print( + "Warning: Some of the entries of the run were not " + "BSON-serializable!\n They have been altered such that " + "they can be stored, but you should fix your experiment!" + "Most likely it is either the 'info' or the 'result'.", + file=sys.stderr, + ) + + with NamedTemporaryFile( + suffix=".pickle", delete=False, prefix="sacred_mongo_fail_" + ) as f: pickle.dump(self.run_entry, f) - print("Warning: saving to MongoDB failed! " - "Stored experiment entry in '{}'".format(f.name), - file=sys.stderr) + print( + "Warning: saving to MongoDB failed! " + "Stored experiment entry in '{}'".format(f.name), + file=sys.stderr, + ) raise ObserverError("Warning: saving to MongoDB failed!") class QueuedMongoObserver(QueueObserver): @classmethod - def create(cls, interval=20, retry_interval=10, url=None, db_name='sacred', - collection='runs', overwrite=None, - priority=DEFAULT_MONGO_PRIORITY, client=None, **kwargs): + def create( + cls, + interval=20, + retry_interval=10, + url=None, + db_name="sacred", + collection="runs", + overwrite=None, + priority=DEFAULT_MONGO_PRIORITY, + client=None, + **kwargs + ): return cls( - QueueCompatibleMongoObserver.create(url=url, db_name=db_name, - collection=collection, - overwrite=overwrite, - priority=priority, - client=client, **kwargs), + QueueCompatibleMongoObserver.create( + url=url, + db_name=db_name, + collection=collection, + overwrite=overwrite, + priority=priority, + client=client, + **kwargs, + ), interval=interval, retry_interval=retry_interval, ) diff --git a/sacred/observers/queue.py b/sacred/observers/queue.py index a5acef7a..4232a5d1 100644 --- a/sacred/observers/queue.py +++ b/sacred/observers/queue.py @@ -11,7 +11,6 @@ class QueueObserver(RunObserver): - def __init__(self, covered_observer, interval=20, retry_interval=10): self._covered_observer = 
covered_observer self._retry_interval = retry_interval @@ -26,8 +25,7 @@ def queued_event(self, *args, **kwargs): def started_event(self, *args, **kwargs): self._queue = Queue() self._stop_worker_event, self._worker = IntervalTimer.create( - self._run, - interval=self._interval, + self._run, interval=self._interval ) self._worker.start() @@ -59,11 +57,7 @@ def artifact_event(self, *args, **kwargs): def log_metrics(self, metrics_by_name, info): for metric_name, metric_values in metrics_by_name.items(): self._queue.put( - WrappedEvent( - "log_metrics", - [metric_name, metric_values, info], - {}, - ) + WrappedEvent("log_metrics", [metric_name, metric_values, info], {}) ) def _run(self): diff --git a/sacred/observers/slack.py b/sacred/observers/slack.py index 32c82231..8e497cf8 100644 --- a/sacred/observers/slack.py +++ b/sacred/observers/slack.py @@ -25,40 +25,56 @@ def from_config(cls, filename): """ d = load_config_file(filename) obs = None - if 'webhook_url' in d: - obs = cls(d['webhook_url']) + if "webhook_url" in d: + obs = cls(d["webhook_url"]) else: - raise ValueError("Slack configuration file must contain " - "an entry for 'webhook_url'!") - for k in ['completed_text', 'interrupted_text', 'failed_text', - 'bot_name', 'icon']: + raise ValueError( + "Slack configuration file must contain " "an entry for 'webhook_url'!" 
+ ) + for k in [ + "completed_text", + "interrupted_text", + "failed_text", + "bot_name", + "icon", + ]: if k in d: setattr(obs, k, d[k]) return obs - def __init__(self, webhook_url, bot_name="sacred-bot", icon=":angel:", - priority=DEFAULT_SLACK_PRIORITY): + def __init__( + self, + webhook_url, + bot_name="sacred-bot", + icon=":angel:", + priority=DEFAULT_SLACK_PRIORITY, + ): self.webhook_url = webhook_url self.bot_name = bot_name self.icon = icon - self.completed_text = ":white_check_mark: *{experiment[name]}* " \ + self.completed_text = ( + ":white_check_mark: *{experiment[name]}* " "completed after _{elapsed_time}_ with result=`{result}`" - self.interrupted_text = ":warning: *{experiment[name]}* " \ - "interrupted after _{elapsed_time}_" - self.failed_text = ":x: *{experiment[name]}* failed after " \ - "_{elapsed_time}_ with `{error}`" + ) + self.interrupted_text = ( + ":warning: *{experiment[name]}* " "interrupted after _{elapsed_time}_" + ) + self.failed_text = ( + ":x: *{experiment[name]}* failed after " "_{elapsed_time}_ with `{error}`" + ) self.run = None self.priority = priority - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): self.run = { - '_id': _id, - 'config': config, - 'start_time': start_time, - 'experiment': ex_info, - 'command': command, - 'host_info': host_info, + "_id": _id, + "config": config, + "start_time": start_time, + "experiment": ex_info, + "command": command, + "host_info": host_info, } def get_completed_text(self): @@ -72,55 +88,55 @@ def get_failed_text(self): def completed_event(self, stop_time, result): import requests + if self.completed_text is None: return - self.run['result'] = result - self.run['stop_time'] = stop_time - self.run['elapsed_time'] = td_format(stop_time - - self.run['start_time']) + self.run["result"] = result + self.run["stop_time"] = stop_time + self.run["elapsed_time"] 
= td_format(stop_time - self.run["start_time"]) data = { "username": self.bot_name, "icon_emoji": self.icon, - "text": self.get_completed_text() + "text": self.get_completed_text(), } - headers = {'Content-type': 'application/json', 'Accept': 'text/plain'} + headers = {"Content-type": "application/json", "Accept": "text/plain"} requests.post(self.webhook_url, data=json.dumps(data), headers=headers) def interrupted_event(self, interrupt_time, status): import requests + if self.interrupted_text is None: return - self.run['status'] = status - self.run['interrupt_time'] = interrupt_time - self.run['elapsed_time'] = td_format(interrupt_time - - self.run['start_time']) + self.run["status"] = status + self.run["interrupt_time"] = interrupt_time + self.run["elapsed_time"] = td_format(interrupt_time - self.run["start_time"]) data = { "username": self.bot_name, "icon_emoji": self.icon, - "text": self.get_interrupted_text() + "text": self.get_interrupted_text(), } - headers = {'Content-type': 'application/json', 'Accept': 'text/plain'} + headers = {"Content-type": "application/json", "Accept": "text/plain"} requests.post(self.webhook_url, data=json.dumps(data), headers=headers) def failed_event(self, fail_time, fail_trace): import requests + if self.failed_text is None: return - self.run['fail_trace'] = fail_trace - self.run['error'] = fail_trace[-1].strip() - self.run['fail_time'] = fail_time - self.run['elapsed_time'] = td_format(fail_time - - self.run['start_time']) + self.run["fail_trace"] = fail_trace + self.run["error"] = fail_trace[-1].strip() + self.run["fail_time"] = fail_time + self.run["elapsed_time"] = td_format(fail_time - self.run["start_time"]) data = { "username": self.bot_name, "icon_emoji": self.icon, - "text": self.get_failed_text() + "text": self.get_failed_text(), } - headers = {'Content-type': 'application/json', 'Accept': 'text/plain'} + headers = {"Content-type": "application/json", "Accept": "text/plain"} requests.post(self.webhook_url, 
data=json.dumps(data), headers=headers) diff --git a/sacred/observers/sql.py b/sacred/observers/sql.py index fedde548..b97810d6 100644 --- a/sacred/observers/sql.py +++ b/sacred/observers/sql.py @@ -13,11 +13,13 @@ # ############################# Observer #################################### # + class SqlObserver(RunObserver): @classmethod def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY): from sqlalchemy.orm import sessionmaker, scoped_session import sqlalchemy as sa + engine = sa.create_engine(url, echo=echo) session_factory = sessionmaker(bind=engine) # make session thread-local to avoid problems with sqlite (see #275) @@ -31,20 +33,32 @@ def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY): self.run = None self.lock = Lock() - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): - return self._add_event(ex_info, command, host_info, config, - meta_info, _id, 'RUNNING', - start_time=start_time) - - def queued_event(self, ex_info, command, host_info, queue_time, config, - meta_info, _id): - return self._add_event(ex_info, command, host_info, config, - meta_info, _id, 'QUEUED') - - def _add_event(self, ex_info, command, host_info, config, - meta_info, _id, status, **kwargs): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): + return self._add_event( + ex_info, + command, + host_info, + config, + meta_info, + _id, + "RUNNING", + start_time=start_time, + ) + + def queued_event( + self, ex_info, command, host_info, queue_time, config, meta_info, _id + ): + return self._add_event( + ex_info, command, host_info, config, meta_info, _id, "QUEUED" + ) + + def _add_event( + self, ex_info, command, host_info, config, meta_info, _id, status, **kwargs + ): from .sql_bases import Base, Experiment, Host, Run + Base.metadata.create_all(self.engine) sql_exp = Experiment.get_or_create(ex_info, self.session) sql_host = Host.get_or_create(host_info, self.session) @@ 
-52,15 +66,17 @@ def _add_event(self, ex_info, command, host_info, config, i = self.session.query(Run).order_by(Run.id.desc()).first() _id = 0 if i is None else i.id + 1 - self.run = Run(run_id=str(_id), - config=json.dumps(flatten(config)), - command=command, - priority=meta_info.get('priority', 0), - comment=meta_info.get('comment', ''), - experiment=sql_exp, - host=sql_host, - status=status, - **kwargs) + self.run = Run( + run_id=str(_id), + config=json.dumps(flatten(config)), + command=command, + priority=meta_info.get("priority", 0), + comment=meta_info.get("comment", ""), + experiment=sql_exp, + host=sql_host, + status=status, + **kwargs, + ) self.session.add(self.run) self.save() return _id or self.run.run_id @@ -75,7 +91,7 @@ def heartbeat_event(self, info, captured_out, beat_time, result): def completed_event(self, stop_time, result): self.run.stop_time = stop_time self.run.result = result - self.run.status = 'COMPLETED' + self.run.status = "COMPLETED" self.save() def interrupted_event(self, interrupt_time, status): @@ -85,18 +101,20 @@ def interrupted_event(self, interrupt_time, status): def failed_event(self, fail_time, fail_trace): self.run.stop_time = fail_time - self.run.fail_trace = '\n'.join(fail_trace) - self.run.status = 'FAILED' + self.run.fail_trace = "\n".join(fail_trace) + self.run.status = "FAILED" self.save() def resource_event(self, filename): from .sql_bases import Resource + res = Resource.get_or_create(filename, self.session) self.run.resources.append(res) self.save() def artifact_event(self, name, filename, metadata=None, content_type=None): from .sql_bases import Artifact + a = Artifact.create(name, filename) self.run.artifacts.append(a) self.save() @@ -107,25 +125,27 @@ def save(self): def query(self, _id): from .sql_bases import Run + run = self.session.query(Run).filter_by(id=_id).first() return run.to_json() def __eq__(self, other): if isinstance(other, SqlObserver): # fixme: this will probably fail to detect two equivalent engines 
- return (self.engine == other.engine and - self.session == other.session) + return self.engine == other.engine and self.session == other.session return False # ######################## Commandline Option ############################### # + class SqlOption(CommandLineOption): """Add a SQL Observer to the experiment.""" - arg = 'DB_URL' - arg_description = \ + arg = "DB_URL" + arg_description = ( "The typical form is: dialect://username:password@host:port/database" + ) @classmethod def apply(cls, args, run): diff --git a/sacred/observers/sql_bases.py b/sacred/observers/sql_bases.py index 61f3daa7..d5d53c4d 100644 --- a/sacred/observers/sql_bases.py +++ b/sacred/observers/sql_bases.py @@ -13,19 +13,21 @@ class Source(Base): - __tablename__ = 'source' + __tablename__ = "source" @classmethod def get_or_create(cls, filename, md5sum, session, basedir): - instance = session.query(cls).filter_by(filename=filename, - md5sum=md5sum).first() + instance = ( + session.query(cls).filter_by(filename=filename, md5sum=md5sum).first() + ) if instance: return instance full_path = os.path.join(basedir, filename) md5sum_ = get_digest(full_path) - assert md5sum_ == md5sum, 'found md5 mismatch for {}: {} != {}'\ - .format(filename, md5sum, md5sum_) - with open(full_path, 'r') as f: + assert md5sum_ == md5sum, "found md5 mismatch for {}: {} != {}".format( + filename, md5sum, md5sum_ + ) + with open(full_path, "r") as f: return cls(filename=filename, md5sum=md5sum, content=f.read()) source_id = sa.Column(sa.Integer, primary_key=True) @@ -34,18 +36,16 @@ def get_or_create(cls, filename, md5sum, session, basedir): content = sa.Column(sa.Text) def to_json(self): - return {'filename': self.filename, - 'md5sum': self.md5sum} + return {"filename": self.filename, "md5sum": self.md5sum} class Dependency(Base): - __tablename__ = 'dependency' + __tablename__ = "dependency" @classmethod def get_or_create(cls, dep, session): - name, _, version = dep.partition('==') - instance = 
session.query(cls).filter_by(name=name, - version=version).first() + name, _, version = dep.partition("==") + instance = session.query(cls).filter_by(name=name, version=version).first() if instance: return instance return cls(name=name, version=version) @@ -59,36 +59,36 @@ def to_json(self): class Artifact(Base): - __tablename__ = 'artifact' + __tablename__ = "artifact" @classmethod def create(cls, name, filename): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return cls(filename=name, content=f.read()) artifact_id = sa.Column(sa.Integer, primary_key=True) filename = sa.Column(sa.String(64)) content = sa.Column(sa.LargeBinary) - run_id = sa.Column(sa.String(24), sa.ForeignKey('run.run_id')) - run = sa.orm.relationship("Run", backref=sa.orm.backref('artifacts')) + run_id = sa.Column(sa.String(24), sa.ForeignKey("run.run_id")) + run = sa.orm.relationship("Run", backref=sa.orm.backref("artifacts")) def to_json(self): - return {'_id': self.artifact_id, - 'filename': self.filename} + return {"_id": self.artifact_id, "filename": self.filename} class Resource(Base): - __tablename__ = 'resource' + __tablename__ = "resource" @classmethod def get_or_create(cls, filename, session): md5sum = get_digest(filename) - instance = session.query(cls).filter_by(filename=filename, - md5sum=md5sum).first() + instance = ( + session.query(cls).filter_by(filename=filename, md5sum=md5sum).first() + ) if instance: return instance - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return cls(filename=filename, md5sum=md5sum, content=f.read()) resource_id = sa.Column(sa.Integer, primary_key=True) @@ -97,21 +97,20 @@ def get_or_create(cls, filename, session): content = sa.Column(sa.LargeBinary) def to_json(self): - return {'filename': self.filename, - 'md5sum': self.md5sum} + return {"filename": self.filename, "md5sum": self.md5sum} class Host(Base): - __tablename__ = 'host' + __tablename__ = "host" @classmethod def get_or_create(cls, host_info, session): 
h = dict( - hostname=host_info['hostname'], - cpu=host_info['cpu'], - os=host_info['os'][0], - os_info=host_info['os'][1], - python_version=host_info['python_version'] + hostname=host_info["hostname"], + cpu=host_info["cpu"], + os=host_info["os"][0], + os_info=host_info["os"][1], + python_version=host_info["python_version"], ) return session.query(cls).filter_by(**h).first() or cls(**h) @@ -124,81 +123,89 @@ def get_or_create(cls, host_info, session): python_version = sa.Column(sa.String(16)) def to_json(self): - return {'cpu': self.cpu, - 'hostname': self.hostname, - 'os': [self.os, self.os_info], - 'python_version': self.python_version} + return { + "cpu": self.cpu, + "hostname": self.hostname, + "os": [self.os, self.os_info], + "python_version": self.python_version, + } experiment_source_association = sa.Table( - 'experiments_sources', Base.metadata, - sa.Column('experiment_id', sa.Integer, - sa.ForeignKey('experiment.experiment_id')), - sa.Column('source_id', sa.Integer, sa.ForeignKey('source.source_id')) + "experiments_sources", + Base.metadata, + sa.Column("experiment_id", sa.Integer, sa.ForeignKey("experiment.experiment_id")), + sa.Column("source_id", sa.Integer, sa.ForeignKey("source.source_id")), ) experiment_dependency_association = sa.Table( - 'experiments_dependencies', Base.metadata, - sa.Column('experiment_id', sa.Integer, - sa.ForeignKey('experiment.experiment_id')), - sa.Column('dependency_id', sa.Integer, - sa.ForeignKey('dependency.dependency_id')) + "experiments_dependencies", + Base.metadata, + sa.Column("experiment_id", sa.Integer, sa.ForeignKey("experiment.experiment_id")), + sa.Column("dependency_id", sa.Integer, sa.ForeignKey("dependency.dependency_id")), ) class Experiment(Base): - __tablename__ = 'experiment' + __tablename__ = "experiment" @classmethod def get_or_create(cls, ex_info, session): - name = ex_info['name'] + name = ex_info["name"] # Compute a MD5sum of the ex_info to determine its uniqueness h = hashlib.md5() 
h.update(json.dumps(ex_info).encode()) md5 = h.hexdigest() - instance = session.query(cls).filter_by(name=name, - md5sum=md5).first() + instance = session.query(cls).filter_by(name=name, md5sum=md5).first() if instance: return instance - dependencies = [Dependency.get_or_create(d, session) - for d in ex_info['dependencies']] - sources = [Source.get_or_create(s, md5sum, session, - ex_info['base_dir']) - for s, md5sum in ex_info['sources']] - - return cls(name=name, dependencies=dependencies, sources=sources, - md5sum=md5, base_dir=ex_info['base_dir']) + dependencies = [ + Dependency.get_or_create(d, session) for d in ex_info["dependencies"] + ] + sources = [ + Source.get_or_create(s, md5sum, session, ex_info["base_dir"]) + for s, md5sum in ex_info["sources"] + ] + + return cls( + name=name, + dependencies=dependencies, + sources=sources, + md5sum=md5, + base_dir=ex_info["base_dir"], + ) experiment_id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(32)) md5sum = sa.Column(sa.String(32)) base_dir = sa.Column(sa.String(64)) - sources = sa.orm.relationship("Source", - secondary=experiment_source_association, - backref="experiments") + sources = sa.orm.relationship( + "Source", secondary=experiment_source_association, backref="experiments" + ) dependencies = sa.orm.relationship( - "Dependency", - secondary=experiment_dependency_association, - backref="experiments") + "Dependency", secondary=experiment_dependency_association, backref="experiments" + ) def to_json(self): - return {'name': self.name, - 'base_dir': self.base_dir, - 'sources': [s.to_json() for s in self.sources], - 'dependencies': [d.to_json() for d in self.dependencies]} + return { + "name": self.name, + "base_dir": self.base_dir, + "sources": [s.to_json() for s in self.sources], + "dependencies": [d.to_json() for d in self.dependencies], + } run_resource_association = sa.Table( - 'runs_resources', Base.metadata, - sa.Column('run_id', sa.String(24), sa.ForeignKey('run.run_id')), - 
sa.Column('resource_id', sa.Integer, - sa.ForeignKey('resource.resource_id')) + "runs_resources", + Base.metadata, + sa.Column("run_id", sa.String(24), sa.ForeignKey("run.run_id")), + sa.Column("resource_id", sa.Integer, sa.ForeignKey("resource.resource_id")), ) class Run(Base): - __tablename__ = 'run' + __tablename__ = "run" id = sa.Column(sa.Integer, primary_key=True) run_id = sa.Column(sa.String(24), unique=True) @@ -226,42 +233,46 @@ class Run(Base): config = sa.Column(sa.Text) info = sa.Column(sa.Text) - status = sa.Column(sa.Enum("RUNNING", "COMPLETED", "INTERRUPTED", - "TIMEOUT", "FAILED", name="status_enum")) + status = sa.Column( + sa.Enum( + "RUNNING", + "COMPLETED", + "INTERRUPTED", + "TIMEOUT", + "FAILED", + name="status_enum", + ) + ) - host_id = sa.Column(sa.Integer, sa.ForeignKey('host.host_id')) - host = sa.orm.relationship("Host", backref=sa.orm.backref('runs')) + host_id = sa.Column(sa.Integer, sa.ForeignKey("host.host_id")) + host = sa.orm.relationship("Host", backref=sa.orm.backref("runs")) - experiment_id = sa.Column(sa.Integer, - sa.ForeignKey('experiment.experiment_id')) - experiment = sa.orm.relationship("Experiment", - backref=sa.orm.backref('runs')) + experiment_id = sa.Column(sa.Integer, sa.ForeignKey("experiment.experiment_id")) + experiment = sa.orm.relationship("Experiment", backref=sa.orm.backref("runs")) # artifacts = backref - resources = sa.orm.relationship("Resource", - secondary=run_resource_association, - backref="runs") + resources = sa.orm.relationship( + "Resource", secondary=run_resource_association, backref="runs" + ) result = sa.Column(sa.Float) def to_json(self): return { - '_id': self.run_id, - 'command': self.command, - 'start_time': self.start_time, - 'heartbeat': self.heartbeat, - 'stop_time': self.stop_time, - 'queue_time': self.queue_time, - 'status': self.status, - 'result': self.result, - 'meta': { - 'comment': self.comment, - 'priority': self.priority}, - 'resources': [r.to_json() for r in self.resources], - 
'artifacts': [a.to_json() for a in self.artifacts], - 'host': self.host.to_json(), - 'experiment': self.experiment.to_json(), - 'config': restore(json.loads(self.config)), - 'captured_out': self.captured_out, - 'fail_trace': self.fail_trace, + "_id": self.run_id, + "command": self.command, + "start_time": self.start_time, + "heartbeat": self.heartbeat, + "stop_time": self.stop_time, + "queue_time": self.queue_time, + "status": self.status, + "result": self.result, + "meta": {"comment": self.comment, "priority": self.priority}, + "resources": [r.to_json() for r in self.resources], + "artifacts": [a.to_json() for a in self.artifacts], + "host": self.host.to_json(), + "experiment": self.experiment.to_json(), + "config": restore(json.loads(self.config)), + "captured_out": self.captured_out, + "fail_trace": self.fail_trace, } diff --git a/sacred/observers/telegram_obs.py b/sacred/observers/telegram_obs.py index 7803120d..c216e1fc 100644 --- a/sacred/observers/telegram_obs.py +++ b/sacred/observers/telegram_obs.py @@ -11,35 +11,39 @@ class TelegramObserver(RunObserver): """Sends a message to Telegram upon completion/failing of an experiment.""" + @staticmethod def get_proxy_request(telegram_config): from telegram.utils.request import Request - if telegram_config['proxy_url'].startswith('socks5'): + if telegram_config["proxy_url"].startswith("socks5"): urllib3_proxy_kwargs = dict() - for key in ['username', 'password']: + for key in ["username", "password"]: if key in telegram_config: urllib3_proxy_kwargs[key] = telegram_config[key] - return Request(proxy_url=telegram_config['proxy_url'], - urllib3_proxy_kwargs=urllib3_proxy_kwargs) - elif telegram_config['proxy_url'].startswith('http'): - cred_string = '' - if 'username' in telegram_config: - cred_string += telegram_config['username'] - if 'password' in telegram_config: - cred_string += ':' + telegram_config['password'] + return Request( + proxy_url=telegram_config["proxy_url"], + 
urllib3_proxy_kwargs=urllib3_proxy_kwargs, + ) + elif telegram_config["proxy_url"].startswith("http"): + cred_string = "" + if "username" in telegram_config: + cred_string += telegram_config["username"] + if "password" in telegram_config: + cred_string += ":" + telegram_config["password"] if len(cred_string) > 0: - domain = (telegram_config['proxy_url'] - .split('/')[-1].split('@')[-1]) - cred_string += '@' - proxy_url = 'http://{}{}'.format(cred_string, domain) + domain = telegram_config["proxy_url"].split("/")[-1].split("@")[-1] + cred_string += "@" + proxy_url = "http://{}{}".format(cred_string, domain) return Request(proxy_url=proxy_url) else: - return Request(proxy_url=telegram_config['proxy_url']) + return Request(proxy_url=telegram_config["proxy_url"]) else: - raise Exception("Proxy URL should be in format " - "PROTOCOL://PROXY_HOST[:PROXY_PORT].\n" - "HTTP and Socks5 are supported.") + raise Exception( + "Proxy URL should be in format " + "PROTOCOL://PROXY_HOST[:PROXY_PORT].\n" + "HTTP and Socks5 are supported." + ) @classmethod def from_config(cls, filename): @@ -53,63 +57,81 @@ def from_config(cls, filename): ``interrupted_text``, and ``failed_text``. """ import telegram + d = load_config_file(filename) - request = cls.get_proxy_request(d) if 'proxy_url' in d else None + request = cls.get_proxy_request(d) if "proxy_url" in d else None - if 'token' in d and 'chat_id' in d: - bot = telegram.Bot(d['token'], request=request) + if "token" in d and "chat_id" in d: + bot = telegram.Bot(d["token"], request=request) obs = cls(bot, **d) else: - raise ValueError("Telegram configuration file must contain " - "entries for 'token' and 'chat_id'!") - for k in ['started_text', 'completed_text', 'interrupted_text', - 'failed_text']: + raise ValueError( + "Telegram configuration file must contain " + "entries for 'token' and 'chat_id'!" 
+ ) + for k in ["started_text", "completed_text", "interrupted_text", "failed_text"]: if k in d: setattr(obs, k, d[k]) return obs - def __init__(self, bot, chat_id, silent_completion=False, - priority=DEFAULT_TELEGRAM_PRIORITY, **kwargs): + def __init__( + self, + bot, + chat_id, + silent_completion=False, + priority=DEFAULT_TELEGRAM_PRIORITY, + **kwargs + ): self.silent_completion = silent_completion self.chat_id = chat_id self.bot = bot - self.started_text = "♻ *{experiment[name]}* " \ - "started at _{start_time}_ " \ - "on host `{host_info[hostname]}`" - self.completed_text = "✅ *{experiment[name]}* " \ - "completed after _{elapsed_time}_ " \ - "with result=`{result}`" - self.interrupted_text = "⚠ *{experiment[name]}* " \ - "interrupted after _{elapsed_time}_" - self.failed_text = "❌ *{experiment[name]}* failed after " \ - "_{elapsed_time}_ with `{error}`\n\n" \ - "Backtrace:\n```{backtrace}```" + self.started_text = ( + "♻ *{experiment[name]}* " + "started at _{start_time}_ " + "on host `{host_info[hostname]}`" + ) + self.completed_text = ( + "✅ *{experiment[name]}* " + "completed after _{elapsed_time}_ " + "with result=`{result}`" + ) + self.interrupted_text = ( + "⚠ *{experiment[name]}* " "interrupted after _{elapsed_time}_" + ) + self.failed_text = ( + "❌ *{experiment[name]}* failed after " + "_{elapsed_time}_ with `{error}`\n\n" + "Backtrace:\n```{backtrace}```" + ) self.run = None self.priority = priority - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): import telegram + self.run = { - '_id': _id, - 'config': config, - 'start_time': start_time, - 'experiment': ex_info, - 'command': command, - 'host_info': host_info, + "_id": _id, + "config": config, + "start_time": start_time, + "experiment": ex_info, + "command": command, + "host_info": host_info, } if self.started_text is None: return try: - 
self.bot.send_message(chat_id=self.chat_id, - text=self.get_started_text(), - disable_notification=True, - parse_mode=telegram.ParseMode.MARKDOWN) + self.bot.send_message( + chat_id=self.chat_id, + text=self.get_started_text(), + disable_notification=True, + parse_mode=telegram.ParseMode.MARKDOWN, + ) except Exception as e: - log = logging.getLogger('telegram-observer') - log.warning('failed to send start_event message via telegram.', - exc_info=e) + log = logging.getLogger("telegram-observer") + log.warning("failed to send start_event message via telegram.", exc_info=e) def get_started_text(self): return self.started_text.format(**self.run) @@ -122,7 +144,8 @@ def get_interrupted_text(self): def get_failed_text(self): return self.failed_text.format( - backtrace=''.join(self.run['fail_trace']), **self.run) + backtrace="".join(self.run["fail_trace"]), **self.run + ) def completed_event(self, stop_time, result): import telegram @@ -130,20 +153,22 @@ def completed_event(self, stop_time, result): if self.completed_text is None: return - self.run['result'] = result - self.run['stop_time'] = stop_time - self.run['elapsed_time'] = td_format(stop_time - - self.run['start_time']) + self.run["result"] = result + self.run["stop_time"] = stop_time + self.run["elapsed_time"] = td_format(stop_time - self.run["start_time"]) try: - self.bot.send_message(chat_id=self.chat_id, - text=self.get_completed_text(), - disable_notification=self.silent_completion, - parse_mode=telegram.ParseMode.MARKDOWN) + self.bot.send_message( + chat_id=self.chat_id, + text=self.get_completed_text(), + disable_notification=self.silent_completion, + parse_mode=telegram.ParseMode.MARKDOWN, + ) except Exception as e: - log = logging.getLogger('telegram-observer') - log.warning('failed to send completed_event message via telegram.', - exc_info=e) + log = logging.getLogger("telegram-observer") + log.warning( + "failed to send completed_event message via telegram.", exc_info=e + ) def interrupted_event(self, 
interrupt_time, status): import telegram @@ -151,20 +176,22 @@ def interrupted_event(self, interrupt_time, status): if self.interrupted_text is None: return - self.run['status'] = status - self.run['interrupt_time'] = interrupt_time - self.run['elapsed_time'] = td_format(interrupt_time - - self.run['start_time']) + self.run["status"] = status + self.run["interrupt_time"] = interrupt_time + self.run["elapsed_time"] = td_format(interrupt_time - self.run["start_time"]) try: - self.bot.send_message(chat_id=self.chat_id, - text=self.get_interrupted_text(), - disable_notification=False, - parse_mode=telegram.ParseMode.MARKDOWN) + self.bot.send_message( + chat_id=self.chat_id, + text=self.get_interrupted_text(), + disable_notification=False, + parse_mode=telegram.ParseMode.MARKDOWN, + ) except Exception as e: - log = logging.getLogger('telegram-observer') - log.warning('failed to send interrupted_event message ' - 'via telegram.', exc_info=e) + log = logging.getLogger("telegram-observer") + log.warning( + "failed to send interrupted_event message " "via telegram.", exc_info=e + ) def failed_event(self, fail_time, fail_trace): import telegram @@ -172,18 +199,18 @@ def failed_event(self, fail_time, fail_trace): if self.failed_text is None: return - self.run['fail_trace'] = fail_trace - self.run['error'] = fail_trace[-1].strip() - self.run['fail_time'] = fail_time - self.run['elapsed_time'] = td_format(fail_time - - self.run['start_time']) + self.run["fail_trace"] = fail_trace + self.run["error"] = fail_trace[-1].strip() + self.run["fail_time"] = fail_time + self.run["elapsed_time"] = td_format(fail_time - self.run["start_time"]) try: - self.bot.send_message(chat_id=self.chat_id, - text=self.get_failed_text(), - disable_notification=False, - parse_mode=telegram.ParseMode.MARKDOWN) + self.bot.send_message( + chat_id=self.chat_id, + text=self.get_failed_text(), + disable_notification=False, + parse_mode=telegram.ParseMode.MARKDOWN, + ) except Exception as e: - log = 
logging.getLogger('telegram-observer') - log.warning('failed to send failed_event message via telegram.', - exc_info=e) + log = logging.getLogger("telegram-observer") + log.warning("failed to send failed_event message via telegram.", exc_info=e) diff --git a/sacred/observers/tinydb_hashfs.py b/sacred/observers/tinydb_hashfs.py index f8244fc4..2c0318d1 100644 --- a/sacred/observers/tinydb_hashfs.py +++ b/sacred/observers/tinydb_hashfs.py @@ -1,8 +1,7 @@ #!/usr/bin/env python # coding=utf-8 -from __future__ import (division, print_function, unicode_literals, - absolute_import) +from __future__ import division, print_function, unicode_literals, absolute_import import os import textwrap @@ -19,8 +18,9 @@ class TinyDbObserver(RunObserver): VERSION = "TinyDbObserver-{}".format(__version__) @staticmethod - def create(path='./runs_db', overwrite=None): + def create(path="./runs_db", overwrite=None): from .tinydb_hashfs_bases import get_db_file_manager + root_dir = os.path.abspath(path) if not os.path.exists(root_dir): os.makedirs(root_dir) @@ -30,7 +30,7 @@ def create(path='./runs_db', overwrite=None): def __init__(self, db, fs, overwrite=None, root=None): self.db = db - self.runs = db.table('runs') + self.runs = db.table("runs") self.fs = fs self.overwrite = overwrite self.run_entry = {} @@ -47,14 +47,15 @@ def save(self): def save_sources(self, ex_info): from .tinydb_hashfs_bases import BufferedReaderWrapper + source_info = [] - for source_name, md5 in ex_info['sources']: + for source_name, md5 in ex_info["sources"]: # Substitute any HOME or Environment Vars to get absolute path - abs_path = os.path.join(ex_info['base_dir'], source_name) + abs_path = os.path.join(ex_info["base_dir"], source_name) abs_path = os.path.expanduser(abs_path) abs_path = os.path.expandvars(abs_path) - handle = BufferedReaderWrapper(open(abs_path, 'rb')) + handle = BufferedReaderWrapper(open(abs_path, "rb")) file = self.fs.get(md5) if file: @@ -65,84 +66,89 @@ def save_sources(self, ex_info): 
source_info.append([source_name, id_, handle]) return source_info - def queued_event(self, ex_info, command, host_info, queue_time, config, - meta_info, _id): - raise NotImplementedError('queued_event method is not implemented for' - ' local TinyDbObserver.') + def queued_event( + self, ex_info, command, host_info, queue_time, config, meta_info, _id + ): + raise NotImplementedError( + "queued_event method is not implemented for" " local TinyDbObserver." + ) - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): self.db_run_id = None self.run_entry = { - 'experiment': dict(ex_info), - 'format': self.VERSION, - 'command': command, - 'host': dict(host_info), - 'start_time': start_time, - 'config': config, - 'meta': meta_info, - 'status': 'RUNNING', - 'resources': [], - 'artifacts': [], - 'captured_out': '', - 'info': {}, - 'heartbeat': None + "experiment": dict(ex_info), + "format": self.VERSION, + "command": command, + "host": dict(host_info), + "start_time": start_time, + "config": config, + "meta": meta_info, + "status": "RUNNING", + "resources": [], + "artifacts": [], + "captured_out": "", + "info": {}, + "heartbeat": None, } # set ID if not given if _id is None: _id = uuid.uuid4().hex - self.run_entry['_id'] = _id + self.run_entry["_id"] = _id # save sources - self.run_entry['experiment']['sources'] = self.save_sources(ex_info) + self.run_entry["experiment"]["sources"] = self.save_sources(ex_info) self.save() - return self.run_entry['_id'] + return self.run_entry["_id"] def heartbeat_event(self, info, captured_out, beat_time, result): - self.run_entry['info'] = info - self.run_entry['captured_out'] = captured_out - self.run_entry['heartbeat'] = beat_time - self.run_entry['result'] = result + self.run_entry["info"] = info + self.run_entry["captured_out"] = captured_out + self.run_entry["heartbeat"] = beat_time + 
self.run_entry["result"] = result self.save() def completed_event(self, stop_time, result): - self.run_entry['stop_time'] = stop_time - self.run_entry['result'] = result - self.run_entry['status'] = 'COMPLETED' + self.run_entry["stop_time"] = stop_time + self.run_entry["result"] = result + self.run_entry["status"] = "COMPLETED" self.save() def interrupted_event(self, interrupt_time, status): - self.run_entry['stop_time'] = interrupt_time - self.run_entry['status'] = status + self.run_entry["stop_time"] = interrupt_time + self.run_entry["status"] = status self.save() def failed_event(self, fail_time, fail_trace): - self.run_entry['stop_time'] = fail_time - self.run_entry['status'] = 'FAILED' - self.run_entry['fail_trace'] = fail_trace + self.run_entry["stop_time"] = fail_time + self.run_entry["status"] = "FAILED" + self.run_entry["fail_trace"] = fail_trace self.save() def resource_event(self, filename): from .tinydb_hashfs_bases import BufferedReaderWrapper + id_ = self.fs.put(filename).id - handle = BufferedReaderWrapper(open(filename, 'rb')) + handle = BufferedReaderWrapper(open(filename, "rb")) resource = [filename, id_, handle] - if resource not in self.run_entry['resources']: - self.run_entry['resources'].append(resource) + if resource not in self.run_entry["resources"]: + self.run_entry["resources"].append(resource) self.save() def artifact_event(self, name, filename, metadata=None, content_type=None): from .tinydb_hashfs_bases import BufferedReaderWrapper + id_ = self.fs.put(filename).id - handle = BufferedReaderWrapper(open(filename, 'rb')) + handle = BufferedReaderWrapper(open(filename, "rb")) artifact = [name, filename, id_, handle] - if artifact not in self.run_entry['artifacts']: - self.run_entry['artifacts'].append(artifact) + if artifact not in self.run_entry["artifacts"]: + self.run_entry["artifacts"].append(artifact) self.save() def __eq__(self, other): @@ -154,7 +160,7 @@ def __eq__(self, other): class TinyDbOption(CommandLineOption): """Add a 
TinyDB Observer to the experiment.""" - arg = 'BASEDIR' + arg = "BASEDIR" @classmethod def apply(cls, args, run): @@ -168,17 +174,17 @@ def parse_tinydb_arg(cls, args): class TinyDbReader: - def __init__(self, path): from .tinydb_hashfs_bases import get_db_file_manager + root_dir = os.path.abspath(path) if not os.path.exists(root_dir): - raise IOError('Path does not exist: %s' % path) + raise IOError("Path does not exist: %s" % path) db, fs = get_db_file_manager(root_dir) self.db = db - self.runs = db.table('runs') + self.runs = db.table("runs") self.fs = fs def search(self, *args, **kwargs): @@ -206,20 +212,22 @@ def fetch_files(self, exp_name=None, query=None, indices=None): all_matched_entries = [] for ent in entries: - rec = dict(exp_name=ent['experiment']['name'], - exp_id=ent['_id'], - date=ent['start_time']) + rec = dict( + exp_name=ent["experiment"]["name"], + exp_id=ent["_id"], + date=ent["start_time"], + ) - source_files = {x[0]: x[2] for x in ent['experiment']['sources']} - resource_files = {x[0]: x[2] for x in ent['resources']} - artifact_files = {x[0]: x[3] for x in ent['artifacts']} + source_files = {x[0]: x[2] for x in ent["experiment"]["sources"]} + resource_files = {x[0]: x[2] for x in ent["resources"]} + artifact_files = {x[0]: x[3] for x in ent["artifacts"]} if source_files: - rec['sources'] = source_files + rec["sources"] = source_files if resource_files: - rec['resources'] = resource_files + rec["resources"] = resource_files if artifact_files: - rec['artifacts'] = artifact_files + rec["artifacts"] = artifact_files all_matched_entries.append(rec) @@ -258,45 +266,47 @@ def fetch_report(self, exp_name=None, query=None, indices=None): all_matched_entries = [] for ent in entries: - date = ent['start_time'] - weekdays = 'Mon Tue Wed Thu Fri Sat Sun'.split() + date = ent["start_time"] + weekdays = "Mon Tue Wed Thu Fri Sat Sun".split() w = weekdays[date.weekday()] - date = ' '.join([w, date.strftime('%d %b %Y')]) + date = " ".join([w, date.strftime("%d 
%b %Y")]) - duration = ent['stop_time'] - ent['start_time'] + duration = ent["stop_time"] - ent["start_time"] secs = duration.total_seconds() hours, remainder = divmod(secs, 3600) minutes, seconds = divmod(remainder, 60) - duration = '%02d:%02d:%04.1f' % (hours, minutes, seconds) + duration = "%02d:%02d:%04.1f" % (hours, minutes, seconds) - parameters = self._dict_to_indented_list(ent['config']) + parameters = self._dict_to_indented_list(ent["config"]) - result = self._indent(ent['result'].__repr__(), prefix=' ') + result = self._indent(ent["result"].__repr__(), prefix=" ") - deps = ent['experiment']['dependencies'] - deps = self._indent('\n'.join(deps), prefix=' ') + deps = ent["experiment"]["dependencies"] + deps = self._indent("\n".join(deps), prefix=" ") - resources = [x[0] for x in ent['resources']] - resources = self._indent('\n'.join(resources), prefix=' ') + resources = [x[0] for x in ent["resources"]] + resources = self._indent("\n".join(resources), prefix=" ") - sources = [x[0] for x in ent['experiment']['sources']] - sources = self._indent('\n'.join(sources), prefix=' ') + sources = [x[0] for x in ent["experiment"]["sources"]] + sources = self._indent("\n".join(sources), prefix=" ") - artifacts = [x[0] for x in ent['artifacts']] - artifacts = self._indent('\n'.join(artifacts), prefix=' ') + artifacts = [x[0] for x in ent["artifacts"]] + artifacts = self._indent("\n".join(artifacts), prefix=" ") - none_str = ' None' + none_str = " None" - rec = dict(exp_name=ent['experiment']['name'], - exp_id=ent['_id'], - start_date=date, - duration=duration, - parameters=parameters if parameters else none_str, - result=result if result else none_str, - dependencies=deps if deps else none_str, - resources=resources if resources else none_str, - sources=sources if sources else none_str, - artifacts=artifacts if artifacts else none_str) + rec = dict( + exp_name=ent["experiment"]["name"], + exp_id=ent["_id"], + start_date=date, + duration=duration, + parameters=parameters 
if parameters else none_str, + result=result if result else none_str, + dependencies=deps if deps else none_str, + resources=resources if resources else none_str, + sources=sources if sources else none_str, + artifacts=artifacts if artifacts else none_str, + ) report = template.format(**rec) @@ -308,6 +318,7 @@ def fetch_metadata(self, exp_name=None, query=None, indices=None): """Return all metadata for matching experiment name, index or query.""" from tinydb import Query from tinydb.queries import QueryImpl + if exp_name or query: if query: assert type(query), QueryImpl @@ -326,14 +337,16 @@ def fetch_metadata(self, exp_name=None, query=None, indices=None): for idx in indices: if idx >= num_recs: raise ValueError( - 'Index value ({}) must be less than ' - 'number of records ({})'.format(idx, num_recs)) + "Index value ({}) must be less than " + "number of records ({})".format(idx, num_recs) + ) entries = [self.runs.all()[ind] for ind in indices] else: - raise ValueError('Must specify an experiment name, indicies or ' - 'pass custom query') + raise ValueError( + "Must specify an experiment name, indicies or " "pass custom query" + ) return entries @@ -341,25 +354,25 @@ def _dict_to_indented_list(self, d): d = OrderedDict(sorted(d.items(), key=lambda t: t[0])) - output_str = '' + output_str = "" for k, v in d.items(): - output_str += '%s: %s' % (k, v) - output_str += '\n' + output_str += "%s: %s" % (k, v) + output_str += "\n" - output_str = self._indent(output_str.strip(), prefix=' ') + output_str = self._indent(output_str.strip(), prefix=" ") return output_str def _indent(self, message, prefix): """Wrapper for indenting strings in Python 2 and 3.""" preferred_width = 150 - wrapper = textwrap.TextWrapper(initial_indent=prefix, - width=preferred_width, - subsequent_indent=prefix) + wrapper = textwrap.TextWrapper( + initial_indent=prefix, width=preferred_width, subsequent_indent=prefix + ) lines = message.splitlines() formatted_lines = [wrapper.fill(lin) for lin in 
lines] - formatted_text = '\n'.join(formatted_lines) + formatted_text = "\n".join(formatted_lines) return formatted_text diff --git a/sacred/observers/tinydb_hashfs_bases.py b/sacred/observers/tinydb_hashfs_bases.py index f91bc5e2..b3df2b52 100644 --- a/sacred/observers/tinydb_hashfs_bases.py +++ b/sacred/observers/tinydb_hashfs_bases.py @@ -44,10 +44,10 @@ class DateTimeSerializer(Serializer): OBJ_CLASS = dt.datetime # The class this serializer handles def encode(self, obj): - return obj.strftime('%Y-%m-%dT%H:%M:%S.%f') + return obj.strftime("%Y-%m-%dT%H:%M:%S.%f") def decode(self, s): - return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f') + return dt.datetime.strptime(s, "%Y-%m-%dT%H:%M:%S.%f") class NdArraySerializer(Serializer): @@ -77,7 +77,7 @@ def encode(self, obj): return obj.to_json() def decode(self, s): - return opt.pandas.read_json(s, typ='series') + return opt.pandas.read_json(s, typ="series") class FileSerializer(Serializer): @@ -99,23 +99,18 @@ def decode(self, s): def get_db_file_manager(root_dir): - fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3, - width=2, algorithm='md5') + fs = HashFS(os.path.join(root_dir, "hashfs"), depth=3, width=2, algorithm="md5") # Setup Serialisation object for non list/dict objects serialization_store = SerializationMiddleware() - serialization_store.register_serializer(DateTimeSerializer(), 'TinyDate') - serialization_store.register_serializer(FileSerializer(fs), 'TinyFile') + serialization_store.register_serializer(DateTimeSerializer(), "TinyDate") + serialization_store.register_serializer(FileSerializer(fs), "TinyFile") if opt.has_numpy: - serialization_store.register_serializer(NdArraySerializer(), - 'TinyArray') + serialization_store.register_serializer(NdArraySerializer(), "TinyArray") if opt.has_pandas: - serialization_store.register_serializer(DataFrameSerializer(), - 'TinyDataFrame') - serialization_store.register_serializer(SeriesSerializer(), - 'TinySeries') + 
serialization_store.register_serializer(DataFrameSerializer(), "TinyDataFrame") + serialization_store.register_serializer(SeriesSerializer(), "TinySeries") - db = TinyDB(os.path.join(root_dir, 'metadata.json'), - storage=serialization_store) + db = TinyDB(os.path.join(root_dir, "metadata.json"), storage=serialization_store) return db, fs diff --git a/sacred/optional.py b/sacred/optional.py index 477e1e0f..5f54e0c9 100644 --- a/sacred/optional.py +++ b/sacred/optional.py @@ -16,11 +16,14 @@ def optional_import(*package_names): def get_tensorflow(): # Ensures backward and forward compatibility with TensorFlow 1 and 2. - if get_package_version('tensorflow') < parse_version('1.13.1'): + if get_package_version("tensorflow") < parse_version("1.13.1"): import warnings - warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " - "Use Tensorflow 1.13 or newer instead.", - DeprecationWarning) + + warnings.warn( + "Use of TensorFlow 1.12 and older is deprecated. " + "Use Tensorflow 1.13 or newer instead.", + DeprecationWarning, + ) import tensorflow as tf else: import tensorflow.compat.v1 as tf @@ -39,15 +42,15 @@ def get_tensorflow(): try: has_libc, libc = True, ctypes.cdll.msvcrt # Windows except OSError: - has_libc, libc = True, ctypes.cdll.LoadLibrary(find_library('c')) + has_libc, libc = True, ctypes.cdll.LoadLibrary(find_library("c")) -has_numpy, np = optional_import('numpy') -has_yaml, yaml = optional_import('yaml') -has_pandas, pandas = optional_import('pandas') +has_numpy, np = optional_import("numpy") +has_yaml, yaml = optional_import("yaml") +has_pandas, pandas = optional_import("pandas") -has_sqlalchemy = modules_exist('sqlalchemy') -has_mako = modules_exist('mako') -has_gitpython = modules_exist('git') -has_tinydb = modules_exist('tinydb', 'tinydb_serialization', 'hashfs') +has_sqlalchemy = modules_exist("sqlalchemy") +has_mako = modules_exist("mako") +has_gitpython = modules_exist("git") +has_tinydb = modules_exist("tinydb", "tinydb_serialization", 
"hashfs") has_tensorflow = modules_exist("tensorflow") diff --git a/sacred/pytee.py b/sacred/pytee.py index dd2450f8..34925feb 100644 --- a/sacred/pytee.py +++ b/sacred/pytee.py @@ -1,10 +1,10 @@ #!/usr/bin/env python # coding=utf-8 -if __name__ == '__main__': +if __name__ == "__main__": import sys - buffer = ' ' + buffer = " " while len(buffer): buffer = sys.stdin.read() sys.stdout.write(buffer) diff --git a/sacred/randomness.py b/sacred/randomness.py index 5fe20d85..5de629f8 100644 --- a/sacred/randomness.py +++ b/sacred/randomness.py @@ -16,8 +16,9 @@ def get_seed(rnd=None): def create_rnd(seed): - assert isinstance(seed, int), \ - "Seed has to be integer but was {} {}".format(repr(seed), type(seed)) + assert isinstance(seed, int), "Seed has to be integer but was {} {}".format( + repr(seed), type(seed) + ) if opt.has_numpy: return opt.np.random.RandomState(seed) else: @@ -28,11 +29,12 @@ def set_global_seed(seed): random.seed(seed) if opt.has_numpy: opt.np.random.seed(seed) - if module_is_in_cache('tensorflow'): + if module_is_in_cache("tensorflow"): tf = opt.get_tensorflow() tf.set_random_seed(seed) - if module_is_in_cache('torch'): + if module_is_in_cache("torch"): import torch + torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) diff --git a/sacred/run.py b/sacred/run.py index 012a3914..a2cc8b9d 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -9,22 +9,32 @@ from sacred import metrics_logger from sacred.metrics_logger import linearize_metrics from sacred.randomness import set_global_seed -from sacred.utils import (SacredInterrupt, join_paths, - IntervalTimer) +from sacred.utils import SacredInterrupt, join_paths, IntervalTimer from sacred.stdout_capturing import get_stdcapturer class Run: """Represent and manage a single run of an experiment.""" - def __init__(self, config, config_modifications, main_function, observers, - root_logger, run_logger, experiment_info, host_info, - pre_run_hooks, post_run_hooks, 
captured_out_filter=None): + def __init__( + self, + config, + config_modifications, + main_function, + observers, + root_logger, + run_logger, + experiment_info, + host_info, + pre_run_hooks, + post_run_hooks, + captured_out_filter=None, + ): self._id = None """The ID of this run as assigned by the first observer""" - self.captured_out = '' + self.captured_out = "" """Captured stdout and stderr""" self.config = config @@ -108,7 +118,7 @@ def __init__(self, config, config_modifications, main_function, observers, self._metrics = metrics_logger.MetricsLogger() - def open_resource(self, filename, mode='r'): + def open_resource(self, filename, mode="r"): """Open a file and also save it as a resource. Opens a file, reports it to the observers as a resource, and returns @@ -156,13 +166,7 @@ def add_resource(self, filename): filename = os.path.abspath(filename) self._emit_resource_added(filename) - def add_artifact( - self, - filename, - name=None, - metadata=None, - content_type=None, - ): + def add_artifact(self, filename, name=None, metadata=None, content_type=None): """Add a file as an artifact. In Sacred terminology an artifact is a file produced by the experiment @@ -203,8 +207,10 @@ def __call__(self, *args): """ if self.start_time is not None: - raise RuntimeError('A run can only be started once. ' - '(Last start was {})'.format(self.start_time)) + raise RuntimeError( + "A run can only be started once. 
" + "(Last start was {})".format(self.start_time) + ) if self.unobserved: self.observers = [] @@ -212,7 +218,7 @@ def __call__(self, *args): self.observers = sorted(self.observers, key=lambda x: -x.priority) self.warn_if_unobserved() - set_global_seed(self.config['seed']) + set_global_seed(self.config["seed"]) if self.capture_mode is None and not self.observers: capture_mode = "no" @@ -232,15 +238,15 @@ def __call__(self, *args): self.result = self.main_function(*args) self._execute_post_run_hooks() if self.result is not None: - self.run_logger.info('Result: {}'.format(self.result)) + self.run_logger.info("Result: {}".format(self.result)) elapsed_time = self._stop_time() - self.run_logger.info('Completed after %s', elapsed_time) + self.run_logger.info("Completed after %s", elapsed_time) self._get_captured_output() self._stop_heartbeat() self._emit_completed(self.result) except (SacredInterrupt, KeyboardInterrupt) as e: self._stop_heartbeat() - status = getattr(e, 'STATUS', 'INTERRUPTED') + status = getattr(e, "STATUS", "INTERRUPTED") self._emit_interrupted(status) raise except BaseException: @@ -258,7 +264,7 @@ def _get_captured_output(self): return text = self._output_file.get() if isinstance(text, bytes): - text = text.decode('utf-8', 'replace') + text = text.decode("utf-8", "replace") if self.captured_out: text = self.captured_out + text if self.captured_out_filter is not None: @@ -266,28 +272,30 @@ def _get_captured_output(self): self.captured_out = text def _start_heartbeat(self): - self.run_logger.debug('Starting Heartbeat') + self.run_logger.debug("Starting Heartbeat") if self.beat_interval > 0: self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create( - self._emit_heartbeat, self.beat_interval) + self._emit_heartbeat, self.beat_interval + ) self._heartbeat.start() def _stop_heartbeat(self): - self.run_logger.debug('Stopping Heartbeat') + self.run_logger.debug("Stopping Heartbeat") # only stop if heartbeat was started if self._heartbeat is not 
None: self._stop_heartbeat_event.set() self._heartbeat.join(timeout=2) def _emit_queued(self): - self.status = 'QUEUED' + self.status = "QUEUED" queue_time = datetime.datetime.utcnow() - self.meta_info['queue_time'] = queue_time - command = join_paths(self.main_function.prefix, - self.main_function.signature.name) + self.meta_info["queue_time"] = queue_time + command = join_paths( + self.main_function.prefix, self.main_function.signature.name + ) self.run_logger.info("Queuing-up command '%s'", command) for observer in self.observers: - if hasattr(observer, 'queued_event'): + if hasattr(observer, "queued_event"): _id = observer.queued_event( ex_info=self.experiment_info, command=command, @@ -295,7 +303,7 @@ def _emit_queued(self): queue_time=queue_time, config=self.config, meta_info=self.meta_info, - _id=self._id + _id=self._id, ) if self._id is None: self._id = _id @@ -303,18 +311,19 @@ def _emit_queued(self): # the experiment SHOULD fail if any of the observers fails if self._id is None: - self.run_logger.info('Queued') + self.run_logger.info("Queued") else: self.run_logger.info('Queued-up run with ID "{}"'.format(self._id)) def _emit_started(self): - self.status = 'RUNNING' + self.status = "RUNNING" self.start_time = datetime.datetime.utcnow() - command = join_paths(self.main_function.prefix, - self.main_function.signature.name) + command = join_paths( + self.main_function.prefix, self.main_function.signature.name + ) self.run_logger.info("Running command '%s'", command) for observer in self.observers: - if hasattr(observer, 'started_event'): + if hasattr(observer, "started_event"): _id = observer.started_event( ex_info=self.experiment_info, command=command, @@ -322,14 +331,14 @@ def _emit_started(self): start_time=self.start_time, config=self.config, meta_info=self.meta_info, - _id=self._id + _id=self._id, ) if self._id is None: self._id = _id # do not catch any exceptions on startup: # the experiment SHOULD fail if any of the observers fails if self._id is 
None: - self.run_logger.info('Started') + self.run_logger.info("Started") else: self.run_logger.info('Started run with ID "{}"'.format(self._id)) @@ -340,58 +349,71 @@ def _emit_heartbeat(self): logged_metrics = self._metrics.get_last_metrics() metrics_by_name = linearize_metrics(logged_metrics) for observer in self.observers: - self._safe_call(observer, 'log_metrics', - metrics_by_name=metrics_by_name, - info=self.info) - self._safe_call(observer, 'heartbeat_event', - info=self.info, - captured_out=self.captured_out, - beat_time=beat_time, - result=self.result) + self._safe_call( + observer, "log_metrics", metrics_by_name=metrics_by_name, info=self.info + ) + self._safe_call( + observer, + "heartbeat_event", + info=self.info, + captured_out=self.captured_out, + beat_time=beat_time, + result=self.result, + ) def _stop_time(self): self.stop_time = datetime.datetime.utcnow() elapsed_time = datetime.timedelta( - seconds=round((self.stop_time - self.start_time).total_seconds())) + seconds=round((self.stop_time - self.start_time).total_seconds()) + ) return elapsed_time def _emit_completed(self, result): - self.status = 'COMPLETED' + self.status = "COMPLETED" for observer in self.observers: - self._final_call(observer, 'completed_event', - stop_time=self.stop_time, - result=result) + self._final_call( + observer, "completed_event", stop_time=self.stop_time, result=result + ) def _emit_interrupted(self, status): self.status = status elapsed_time = self._stop_time() self.run_logger.warning("Aborted after %s!", elapsed_time) for observer in self.observers: - self._final_call(observer, 'interrupted_event', - interrupt_time=self.stop_time, - status=status) + self._final_call( + observer, + "interrupted_event", + interrupt_time=self.stop_time, + status=status, + ) def _emit_failed(self, exc_type, exc_value, trace): - self.status = 'FAILED' + self.status = "FAILED" elapsed_time = self._stop_time() self.run_logger.error("Failed after %s!", elapsed_time) self.fail_trace = 
tb.format_exception(exc_type, exc_value, trace) for observer in self.observers: - self._final_call(observer, 'failed_event', - fail_time=self.stop_time, - fail_trace=self.fail_trace) + self._final_call( + observer, + "failed_event", + fail_time=self.stop_time, + fail_trace=self.fail_trace, + ) def _emit_resource_added(self, filename): for observer in self.observers: - self._safe_call(observer, 'resource_event', filename=filename) + self._safe_call(observer, "resource_event", filename=filename) def _emit_artifact_added(self, name, filename, metadata, content_type): for observer in self.observers: - self._safe_call(observer, 'artifact_event', - name=name, - filename=filename, - metadata=metadata, - content_type=content_type) + self._safe_call( + observer, + "artifact_event", + name=name, + filename=filename, + metadata=metadata, + content_type=content_type, + ) def _safe_call(self, obs, method, **kwargs): if obs not in self._failed_observers and hasattr(obs, method): @@ -399,8 +421,9 @@ def _safe_call(self, obs, method, **kwargs): getattr(obs, method)(**kwargs) except Exception as e: self._failed_observers.append(obs) - self.run_logger.warning("An error ocurred in the '{}' " - "observer: {}".format(obs, e)) + self.run_logger.warning( + "An error ocurred in the '{}' " "observer: {}".format(obs, e) + ) def _final_call(self, observer, method, **kwargs): if hasattr(observer, method): @@ -415,12 +438,14 @@ def _final_call(self, observer, method, **kwargs): def _wait_for_observers(self): """Block until all observers finished processing.""" for observer in self.observers: - self._safe_call(observer, 'join') + self._safe_call(observer, "join") def _warn_about_failed_observers(self): for observer in self._failed_observers: - self.run_logger.warning("The observer '{}' failed at some point " - "during the run.".format(observer)) + self.run_logger.warning( + "The observer '{}' failed at some point " + "during the run.".format(observer) + ) def _execute_pre_run_hooks(self): for 
pr in self.pre_run_hooks: diff --git a/sacred/serializer.py b/sacred/serializer.py index 490e8606..22cb01a8 100644 --- a/sacred/serializer.py +++ b/sacred/serializer.py @@ -7,7 +7,7 @@ from sacred import optional as opt -__all__ = ('flatten', 'restore') +__all__ = ("flatten", "restore") # class DatetimeHandler(BaseHandler): @@ -25,8 +25,8 @@ class NumpyArrayHandler(BaseHandler): def flatten(self, obj, data): - data['values'] = obj.tolist() - data['dtype'] = str(obj.dtype) + data["values"] = obj.tolist() + data["dtype"] = str(obj.dtype) return data def restore(self, obj): @@ -40,9 +40,24 @@ def restore(self, obj): return obj NumpyArrayHandler.handles(np.ndarray) - for t in [np.bool_, np.int_, np.float_, np.intc, np.intp, np.int8, - np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, - np.uint64, np.float16, np.float32, np.float64]: + for t in [ + np.bool_, + np.int_, + np.float_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ]: NumpyGenericHandler.handles(t) @@ -52,19 +67,18 @@ def restore(self, obj): class PandasDataframeHandler(BaseHandler): def flatten(self, obj, data): # TODO: this is slow - data['values'] = json.loads(obj.to_json()) - data['dtypes'] = {k: str(v) for k, v in dict(obj.dtypes).items()} + data["values"] = json.loads(obj.to_json()) + data["dtypes"] = {k: str(v) for k, v in dict(obj.dtypes).items()} return data def restore(self, obj): # TODO: get rid of unnecessary json.dumps - return pd.read_json(json.dumps(obj['values']), - dtype=obj['dtypes']) + return pd.read_json(json.dumps(obj["values"]), dtype=obj["dtypes"]) PandasDataframeHandler.handles(pd.DataFrame) -json.set_encoder_options('simplejson', sort_keys=True, indent=4) -json.set_encoder_options('demjson', compactly=False) +json.set_encoder_options("simplejson", sort_keys=True, indent=4) +json.set_encoder_options("demjson", compactly=False) def flatten(obj): diff --git 
a/sacred/settings.py b/sacred/settings.py index a12ef5a0..fb08d8df 100644 --- a/sacred/settings.py +++ b/sacred/settings.py @@ -4,47 +4,48 @@ import platform from munch import munchify -__all__ = ('SETTINGS', ) +__all__ = ("SETTINGS",) -SETTINGS = munchify({ - 'CONFIG': { - # make sure all config keys are compatible with MongoDB - 'ENFORCE_KEYS_MONGO_COMPATIBLE': True, - # make sure all config keys are serializable with jsonpickle - # THIS IS IMPORTANT. Only deactivate if you know what you're doing. - 'ENFORCE_KEYS_JSONPICKLE_COMPATIBLE': True, - # make sure all config keys are valid python identifiers - 'ENFORCE_VALID_PYTHON_IDENTIFIER_KEYS': False, - # make sure all config keys are strings - 'ENFORCE_STRING_KEYS': False, - # make sure no config key contains an equals sign - 'ENFORCE_KEYS_NO_EQUALS': True, - # if true, all dicts and lists in the configuration of a captured - # function are replaced with a read-only container that raises an - # Exception if it is attempted to write to those containers - 'READ_ONLY_CONFIG': True, - - # regex patterns to filter out certain IDE or linter directives from - # inline comments in the documentation - 'IGNORED_COMMENTS': ['^pylint:', '^noinspection'], - }, - 'HOST_INFO': { - # Collect information about GPUs using the nvidia-smi tool - 'INCLUDE_GPU_INFO': True, - # List of ENVIRONMENT variables to store in host-info - 'CAPTURED_ENV': [] - }, - 'COMMAND_LINE': { - # disallow string fallback, if parsing a value from command-line failed - 'STRICT_PARSING': False, - # show command line options that are disabled (e.g. unmet dependencies) - 'SHOW_DISABLED_OPTIONS': True - }, - # configure how stdout/stderr are captured. ['no', 'sys', 'fd'] - 'CAPTURE_MODE': "sys" if platform.system() == "Windows" else "fd", - # configure how dependencies are discovered. [none, imported, sys, pkg] - 'DISCOVER_DEPENDENCIES': "imported", - # configure how source-files are discovered. 
[none, imported, sys, dir] - 'DISCOVER_SOURCES': "imported" -}) +SETTINGS = munchify( + { + "CONFIG": { + # make sure all config keys are compatible with MongoDB + "ENFORCE_KEYS_MONGO_COMPATIBLE": True, + # make sure all config keys are serializable with jsonpickle + # THIS IS IMPORTANT. Only deactivate if you know what you're doing. + "ENFORCE_KEYS_JSONPICKLE_COMPATIBLE": True, + # make sure all config keys are valid python identifiers + "ENFORCE_VALID_PYTHON_IDENTIFIER_KEYS": False, + # make sure all config keys are strings + "ENFORCE_STRING_KEYS": False, + # make sure no config key contains an equals sign + "ENFORCE_KEYS_NO_EQUALS": True, + # if true, all dicts and lists in the configuration of a captured + # function are replaced with a read-only container that raises an + # Exception if it is attempted to write to those containers + "READ_ONLY_CONFIG": True, + # regex patterns to filter out certain IDE or linter directives from + # inline comments in the documentation + "IGNORED_COMMENTS": ["^pylint:", "^noinspection"], + }, + "HOST_INFO": { + # Collect information about GPUs using the nvidia-smi tool + "INCLUDE_GPU_INFO": True, + # List of ENVIRONMENT variables to store in host-info + "CAPTURED_ENV": [], + }, + "COMMAND_LINE": { + # disallow string fallback, if parsing a value from command-line failed + "STRICT_PARSING": False, + # show command line options that are disabled (e.g. unmet dependencies) + "SHOW_DISABLED_OPTIONS": True, + }, + # configure how stdout/stderr are captured. ['no', 'sys', 'fd'] + "CAPTURE_MODE": "sys" if platform.system() == "Windows" else "fd", + # configure how dependencies are discovered. [none, imported, sys, pkg] + "DISCOVER_DEPENDENCIES": "imported", + # configure how source-files are discovered. 
[none, imported, sys, dir] + "DISCOVER_SOURCES": "imported", + } +) diff --git a/sacred/stdout_capturing.py b/sacred/stdout_capturing.py index 18e6f9f2..ff9f948e 100644 --- a/sacred/stdout_capturing.py +++ b/sacred/stdout_capturing.py @@ -27,13 +27,13 @@ def flush(): def get_stdcapturer(mode=None): mode = mode if mode is not None else SETTINGS.CAPTURE_MODE - capture_options = { - "no": no_tee, - "fd": tee_output_fd, - "sys": tee_output_python} + capture_options = {"no": no_tee, "fd": tee_output_fd, "sys": tee_output_python} if mode not in capture_options: - raise KeyError("Unknown capture mode '{}'. Available options are {}" - .format(mode, sorted(capture_options.keys()))) + raise KeyError( + "Unknown capture mode '{}'. Available options are {}".format( + mode, sorted(capture_options.keys()) + ) + ) return mode, capture_options[mode] @@ -116,7 +116,7 @@ def tee_output_python(): @contextmanager def tee_output_fd(): """Duplicate stdout and stderr to a file on the file descriptor level.""" - with NamedTemporaryFile(mode='w+') as target: + with NamedTemporaryFile(mode="w+") as target: original_stdout_fd = 1 original_stderr_fd = 2 target_fd = target.fileno() @@ -129,20 +129,30 @@ def tee_output_fd(): # start_new_session=True to move process to a new process group # this is done to avoid receiving KeyboardInterrupts (see #149) tee_stdout = subprocess.Popen( - ['tee', '-a', target.name], start_new_session=True, - stdin=subprocess.PIPE, stdout=1) + ["tee", "-a", target.name], + start_new_session=True, + stdin=subprocess.PIPE, + stdout=1, + ) tee_stderr = subprocess.Popen( - ['tee', '-a', target.name], start_new_session=True, - stdin=subprocess.PIPE, stdout=2) + ["tee", "-a", target.name], + start_new_session=True, + stdin=subprocess.PIPE, + stdout=2, + ) except (FileNotFoundError, OSError, AttributeError): # No tee found in this operating system. Trying to use a python # implementation of tee. However this is slow and error-prone. 
tee_stdout = subprocess.Popen( [sys.executable, "-m", "sacred.pytee"], - stdin=subprocess.PIPE, stderr=target_fd) + stdin=subprocess.PIPE, + stderr=target_fd, + ) tee_stderr = subprocess.Popen( [sys.executable, "-m", "sacred.pytee"], - stdin=subprocess.PIPE, stdout=target_fd) + stdin=subprocess.PIPE, + stdout=target_fd, + ) flush() os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd) diff --git a/sacred/stflow/internal.py b/sacred/stflow/internal.py index 1645e1d0..e2387be9 100644 --- a/sacred/stflow/internal.py +++ b/sacred/stflow/internal.py @@ -1,7 +1,7 @@ import functools -class ContextMethodDecorator(): +class ContextMethodDecorator: """A helper ContextManager decorating a method with a custom function.""" def __init__(self, classx, method_name, decorator_func): @@ -31,17 +31,16 @@ def __init__(self, classx, method_name, decorator_func): def __enter__(self): self.original_method = getattr(self.classx, self.method_name) - if not hasattr(self.original_method, - "sacred_patched%s" % self.__class__.__name__): + if not hasattr( + self.original_method, "sacred_patched%s" % self.__class__.__name__ + ): + @functools.wraps(self.original_method) def decorated(instance, *args, **kwargs): - return self.decorator_func(instance, - self.original_method, args, - kwargs) + return self.decorator_func(instance, self.original_method, args, kwargs) setattr(self.classx, self.method_name, decorated) - setattr(decorated, - "sacred_patched%s" % self.__class__.__name__, True) + setattr(decorated, "sacred_patched%s" % self.__class__.__name__, True) self.patched_by_me = True def __exit__(self, type, value, traceback): diff --git a/sacred/stflow/method_interception.py b/sacred/stflow/method_interception.py index 83ba9bd3..b6052352 100644 --- a/sacred/stflow/method_interception.py +++ b/sacred/stflow/method_interception.py @@ -60,19 +60,19 @@ def run_experiment(_run): def __init__(self, experiment): self.experiment = experiment - def log_writer_decorator(instance, original_method, 
original_args, - original_kwargs): - result = original_method(instance, *original_args, - **original_kwargs) + def log_writer_decorator( + instance, original_method, original_args, original_kwargs + ): + result = original_method(instance, *original_args, **original_kwargs) if "logdir" in original_kwargs: logdir = original_kwargs["logdir"] else: logdir = original_args[0] self.experiment.info.setdefault("tensorflow", {}).setdefault( - "logdirs", []).append(logdir) + "logdirs", [] + ).append(logdir) return result - ContextMethodDecorator.__init__(self, - tf.summary.FileWriter, - "__init__", - log_writer_decorator) + ContextMethodDecorator.__init__( + self, tf.summary.FileWriter, "__init__", log_writer_decorator + ) diff --git a/sacred/utils.py b/sacred/utils.py index 631897f0..fb2e32bd 100755 --- a/sacred/utils.py +++ b/sacred/utils.py @@ -19,19 +19,35 @@ import wrapt -__all__ = ["NO_LOGGER", "PYTHON_IDENTIFIER", "CircularDependencyError", - "ObserverError", "SacredInterrupt", "TimeoutInterrupt", - "create_basic_stream_logger", "recursive_update", - "iterate_flattened", "iterate_flattened_separately", - "set_by_dotted_path", "get_by_dotted_path", "iter_path_splits", - "iter_prefixes", "join_paths", "is_prefix", - "convert_to_nested_dict", "convert_camel_case_to_snake_case", - "print_filtered_stacktrace", - "optional_kwargs_decorator", "get_inheritors", - "apply_backspaces_and_linefeeds", "rel_path", "IntervalTimer", - "PathType"] - -NO_LOGGER = logging.getLogger('ignore') +__all__ = [ + "NO_LOGGER", + "PYTHON_IDENTIFIER", + "CircularDependencyError", + "ObserverError", + "SacredInterrupt", + "TimeoutInterrupt", + "create_basic_stream_logger", + "recursive_update", + "iterate_flattened", + "iterate_flattened_separately", + "set_by_dotted_path", + "get_by_dotted_path", + "iter_path_splits", + "iter_prefixes", + "join_paths", + "is_prefix", + "convert_to_nested_dict", + "convert_camel_case_to_snake_case", + "print_filtered_stacktrace", + "optional_kwargs_decorator", + 
"get_inheritors", + "apply_backspaces_and_linefeeds", + "rel_path", + "IntervalTimer", + "PathType", +] + +NO_LOGGER = logging.getLogger("ignore") NO_LOGGER.disabled = 1 PATHCHANGE = object() @@ -68,14 +84,19 @@ class TimeoutInterrupt(SacredInterrupt): class SacredError(Exception): - def __init__(self, message, print_traceback=True, - filter_traceback='default', print_usage=False): + def __init__( + self, + message, + print_traceback=True, + filter_traceback="default", + print_usage=False, + ): super().__init__(message) self.print_traceback = print_traceback - if filter_traceback not in ['always', 'default', 'never']: + if filter_traceback not in ["always", "default", "never"]: raise ValueError( - 'filter_traceback must be one of \'always\', ' - '\'default\' or \'never\', not ' + filter_traceback + "filter_traceback must be one of 'always', " + "'default' or 'never', not " + filter_traceback ) self.filter_traceback = filter_traceback self.print_usage = print_usage @@ -96,14 +117,19 @@ def track(cls, ingredient): e.__ingredients__.append(ingredient) raise e - def __init__(self, message='Circular dependency detected:', - ingredients=None, print_traceback=True, - filter_traceback='default', print_usage=False): + def __init__( + self, + message="Circular dependency detected:", + ingredients=None, + print_traceback=True, + filter_traceback="default", + print_usage=False, + ): super().__init__( message, print_traceback=print_traceback, filter_traceback=filter_traceback, - print_usage=print_usage + print_usage=print_usage, ) if ingredients is None: @@ -112,23 +138,31 @@ def __init__(self, message='Circular dependency detected:', self.__circular_dependency_handled__ = False def __str__(self): - return (super().__str__() + '->'.join( - [i.path for i in reversed(self.__ingredients__)])) + return super().__str__() + "->".join( + [i.path for i in reversed(self.__ingredients__)] + ) class ConfigError(SacredError): """There was an error in the configuration. 
Pretty prints the conflicting configuration values.""" - def __init__(self, message, conflicting_configs=(), - print_conflicting_configs=True, - print_traceback=True, - filter_traceback='default', print_usage=False, - config=None): - super().__init__(message, - print_traceback=print_traceback, - filter_traceback=filter_traceback, - print_usage=print_usage) + def __init__( + self, + message, + conflicting_configs=(), + print_conflicting_configs=True, + print_traceback=True, + filter_traceback="default", + print_usage=False, + config=None, + ): + super().__init__( + message, + print_traceback=print_traceback, + filter_traceback=filter_traceback, + print_usage=print_usage, + ) self.print_conflicting_configs = print_conflicting_configs if isinstance(conflicting_configs, str): @@ -150,8 +184,8 @@ def track(cls, wrapped): if not e.__prefix_handled__: if wrapped.prefix: e.__conflicting_configs__ = ( - join_paths(wrapped.prefix, str(c)) for c in - e.__conflicting_configs__ + join_paths(wrapped.prefix, str(c)) + for c in e.__conflicting_configs__ ) e.__config__ = wrapped.config e.__prefix_handled__ = True @@ -165,11 +199,12 @@ def __str__(self): # Conflicting configuration values: # a=3 # b.c=4 - s += '\nConflicting configuration values:' + s += "\nConflicting configuration values:" for conflicting_config in self.__conflicting_configs__: - s += '\n {}={}'.format(conflicting_config, - get_by_dotted_path(self.__config__, - conflicting_config)) + s += "\n {}={}".format( + conflicting_config, + get_by_dotted_path(self.__config__, conflicting_config), + ) return s @@ -187,6 +222,7 @@ class InvalidConfigError(ConfigError): ... 'Need to be equal', ... 
conflicting_configs=('a', 'b.a')) """ + pass @@ -194,46 +230,63 @@ class MissingConfigError(SacredError): """A config value that is needed by a captured function is not present in the provided config.""" - def __init__(self, message='Configuration values are missing:', - missing_configs=(), - print_traceback=False, filter_traceback='default', - print_usage=True): - message = '{} {}'.format(message, missing_configs) + def __init__( + self, + message="Configuration values are missing:", + missing_configs=(), + print_traceback=False, + filter_traceback="default", + print_usage=True, + ): + message = "{} {}".format(message, missing_configs) super().__init__( - message, print_traceback=print_traceback, - filter_traceback=filter_traceback, print_usage=print_usage + message, + print_traceback=print_traceback, + filter_traceback=filter_traceback, + print_usage=print_usage, ) class NamedConfigNotFoundError(SacredError): """A named config is not found.""" - def __init__(self, named_config, message='Named config not found:', - available_named_configs=(), - print_traceback=False, - filter_traceback='default', print_usage=False): + def __init__( + self, + named_config, + message="Named config not found:", + available_named_configs=(), + print_traceback=False, + filter_traceback="default", + print_usage=False, + ): message = '{} "{}". 
Available config values are: {}'.format( - message, named_config, available_named_configs) + message, named_config, available_named_configs + ) super().__init__( message, print_traceback=print_traceback, filter_traceback=filter_traceback, - print_usage=print_usage) + print_usage=print_usage, + ) class ConfigAddedError(ConfigError): - SPECIAL_ARGS = {'_log', '_config', '_seed', '__doc__', 'config_filename', - '_run'} + SPECIAL_ARGS = {"_log", "_config", "_seed", "__doc__", "config_filename", "_run"} """Special args that show up in the captured args but can never be set by the user""" - def __init__(self, conflicting_configs, - message='Added new config entry that is not used anywhere', - captured_args=(), - print_conflicting_configs=True, print_traceback=False, - filter_traceback='default', print_usage=False, - print_suggestions=True, - config=None): + def __init__( + self, + conflicting_configs, + message="Added new config entry that is not used anywhere", + captured_args=(), + print_conflicting_configs=True, + print_traceback=False, + filter_traceback="default", + print_usage=False, + print_suggestions=True, + config=None, + ): super().__init__( message, conflicting_configs=conflicting_configs, @@ -241,7 +294,7 @@ def __init__(self, conflicting_configs, print_traceback=print_traceback, filter_traceback=filter_traceback, print_usage=print_usage, - config=config + config=config, ) self.captured_args = captured_args self.print_suggestions = print_suggestions @@ -251,7 +304,7 @@ def __str__(self): if self.print_suggestions: possible_keys = set(self.captured_args) - self.SPECIAL_ARGS if possible_keys: - s += '\nPossible config keys are: {}'.format(possible_keys) + s += "\nPossible config keys are: {}".format(possible_keys) return s @@ -260,11 +313,15 @@ class SignatureError(SacredError, TypeError): Error that is raised when the passed arguments do not match the functions signature """ - def __init__(self, message, print_traceback=True, - filter_traceback='always', - 
print_usage=False): - super().__init__( - message, print_traceback, filter_traceback, print_usage) + + def __init__( + self, + message, + print_traceback=True, + filter_traceback="always", + print_usage=False, + ): + super().__init__(message, print_traceback, filter_traceback, print_usage) def create_basic_stream_logger(): @@ -277,8 +334,9 @@ def create_basic_stream_logger(): already is configured (i.e. `len(getLogger().handlers) > 0`) """ logging.basicConfig( - level=logging.INFO, format='%(levelname)s - %(name)s - %(message)s') - return logging.getLogger('') + level=logging.INFO, format="%(levelname)s - %(name)s - %(message)s" + ) + return logging.getLogger("") def recursive_update(d, u): @@ -313,21 +371,26 @@ def iterate_flattened_separately(dictionary, manually_sorted_keys=None): if key in dictionary: yield key, dictionary[key] - single_line_keys = [key for key in dictionary.keys() if - key not in manually_sorted_keys and - (not dictionary[key] or - not isinstance(dictionary[key], dict))] + single_line_keys = [ + key + for key in dictionary.keys() + if key not in manually_sorted_keys + and (not dictionary[key] or not isinstance(dictionary[key], dict)) + ] for key in sorted(single_line_keys): yield key, dictionary[key] - multi_line_keys = [key for key in dictionary.keys() if - key not in manually_sorted_keys and - (dictionary[key] and - isinstance(dictionary[key], dict))] + multi_line_keys = [ + key + for key in dictionary.keys() + if key not in manually_sorted_keys + and (dictionary[key] and isinstance(dictionary[key], dict)) + ] for key in sorted(multi_line_keys): yield key, PATHCHANGE - for k, val in iterate_flattened_separately(dictionary[key], - manually_sorted_keys): + for k, val in iterate_flattened_separately( + dictionary[key], manually_sorted_keys + ): yield join_paths(key, k), val @@ -363,7 +426,7 @@ def set_by_dotted_path(d, path, value): {'foo': {'bar': 10, 'd': {'baz': 3}}} """ - split_path = path.split('.') + split_path = path.split(".") 
current_option = d for p in split_path[:-1]: if p not in current_option: @@ -382,7 +445,7 @@ def get_by_dotted_path(d, path, default=None): """ if not path: return d - split_path = path.split('.') + split_path = path.split(".") current_option = d for p in split_path: if p not in current_option: @@ -403,7 +466,7 @@ def iter_path_splits(path): ('foo', 'bar.baz'), ('foo.bar', 'baz')] """ - split_path = path.split('.') + split_path = path.split(".") for i in range(len(split_path)): p1 = join_paths(*split_path[:i]) p2 = join_paths(*split_path[i:]) @@ -418,29 +481,29 @@ def iter_prefixes(path): >>> list(iter_prefixes('foo.bar.baz')) ['foo', 'foo.bar', 'foo.bar.baz'] """ - split_path = path.split('.') + split_path = path.split(".") for i in range(1, len(split_path) + 1): yield join_paths(*split_path[:i]) def join_paths(*parts): """Join different parts together to a valid dotted path.""" - return '.'.join(str(p).strip('.') for p in parts if p) + return ".".join(str(p).strip(".") for p in parts if p) def is_prefix(pre_path, path): """Return True if pre_path is a path-prefix of path.""" - pre_path = pre_path.strip('.') - path = path.strip('.') - return not pre_path or path.startswith(pre_path + '.') + pre_path = pre_path.strip(".") + path = path.strip(".") + return not pre_path or path.startswith(pre_path + ".") def rel_path(base, path): """Return path relative to base.""" if base == path: - return '' + return "" assert is_prefix(base, path), "{} not a prefix of {}".format(base, path) - return path[len(base):].strip('.') + return path[len(base) :].strip(".") def convert_to_nested_dict(dotted_dict): @@ -452,14 +515,14 @@ def convert_to_nested_dict(dotted_dict): def _is_sacred_frame(frame): - return frame.f_globals["__name__"].split('.')[0] == 'sacred' + return frame.f_globals["__name__"].split(".")[0] == "sacred" -def print_filtered_stacktrace(filter_traceback='default'): +def print_filtered_stacktrace(filter_traceback="default"): 
print(format_filtered_stacktrace(filter_traceback), file=sys.stderr) -def format_filtered_stacktrace(filter_traceback='default'): +def format_filtered_stacktrace(filter_traceback="default"): """ Returns the traceback as `string`. @@ -477,20 +540,21 @@ def format_filtered_stacktrace(filter_traceback='default'): while current_tb.tb_next is not None: current_tb = current_tb.tb_next - if filter_traceback == 'default' \ - and _is_sacred_frame(current_tb.tb_frame): + if filter_traceback == "default" and _is_sacred_frame(current_tb.tb_frame): # just print sacred internal trace - header = ["Exception originated from within Sacred.\n" - "Traceback (most recent calls):\n"] + header = [ + "Exception originated from within Sacred.\n" + "Traceback (most recent calls):\n" + ] texts = tb.format_exception(exc_type, exc_value, current_tb) - return ''.join(header + texts[1:]).strip() - elif filter_traceback in ('default', 'always'): + return "".join(header + texts[1:]).strip() + elif filter_traceback in ("default", "always"): # print filtered stacktrace if sys.version_info >= (3, 5): - tb_exception = \ - tb.TracebackException(exc_type, exc_value, exc_traceback, - limit=None) - return ''.join(filtered_traceback_format(tb_exception)) + tb_exception = tb.TracebackException( + exc_type, exc_value, exc_traceback, limit=None + ) + return "".join(filtered_traceback_format(tb_exception)) else: s = "Traceback (most recent calls WITHOUT Sacred internals):" current_tb = exc_traceback @@ -498,16 +562,13 @@ def format_filtered_stacktrace(filter_traceback='default'): if not _is_sacred_frame(current_tb.tb_frame): tb.print_tb(current_tb, 1) current_tb = current_tb.tb_next - s += "\n".join(tb.format_exception_only(exc_type, - exc_value)).strip() + s += "\n".join(tb.format_exception_only(exc_type, exc_value)).strip() return s - elif filter_traceback == 'never': + elif filter_traceback == "never": # print full stacktrace - return '\n'.join( - tb.format_exception(exc_type, exc_value, exc_traceback)) + 
return "\n".join(tb.format_exception(exc_type, exc_value, exc_traceback)) else: - raise ValueError('Unknown value for filter_traceback: ' + - filter_traceback) + raise ValueError("Unknown value for filter_traceback: " + filter_traceback) def format_sacred_error(e, short_usage): @@ -518,29 +579,29 @@ def format_sacred_error(e, short_usage): lines.append(format_filtered_stacktrace(e.filter_traceback)) else: import traceback as tb - lines.append('\n'.join(tb.format_exception_only(type(e), e))) - return '\n'.join(lines) + + lines.append("\n".join(tb.format_exception_only(type(e), e))) + return "\n".join(lines) def filtered_traceback_format(tb_exception, chain=True): if chain: if tb_exception.__cause__ is not None: - yield from filtered_traceback_format(tb_exception.__cause__, - chain=chain) + yield from filtered_traceback_format(tb_exception.__cause__, chain=chain) yield tb._cause_message - elif (tb_exception.__context__ is not None and - not tb_exception.__suppress_context__): - yield from filtered_traceback_format(tb_exception.__context__, - chain=chain) + elif ( + tb_exception.__context__ is not None + and not tb_exception.__suppress_context__ + ): + yield from filtered_traceback_format(tb_exception.__context__, chain=chain) yield tb._context_message - yield 'Traceback (most recent calls WITHOUT Sacred internals):\n' + yield "Traceback (most recent calls WITHOUT Sacred internals):\n" current_tb = tb_exception.exc_traceback while current_tb is not None: if not _is_sacred_frame(current_tb.tb_frame): - stack = tb.StackSummary.extract(tb.walk_tb(current_tb), - limit=1, - lookup_lines=True, - capture_locals=False) + stack = tb.StackSummary.extract( + tb.walk_tb(current_tb), limit=1, lookup_lines=True, capture_locals=False + ) yield from stack.format() current_tb = current_tb.tb_next yield from tb_exception.format_exception_only() @@ -573,8 +634,8 @@ def get_inheritors(cls): # Taken from http://stackoverflow.com/a/1176023/1388435 def 
convert_camel_case_to_snake_case(name): """Convert CamelCase to snake_case.""" - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() def apply_backspaces_and_linefeeds(text): @@ -587,30 +648,34 @@ def apply_backspaces_and_linefeeds(text): If final line ends with a carriage it keeps it to be concatenable with next output chunk. """ - orig_lines = text.split('\n') + orig_lines = text.split("\n") orig_lines_len = len(orig_lines) new_lines = [] for orig_line_idx, orig_line in enumerate(orig_lines): chars, cursor = [], 0 orig_line_len = len(orig_line) for orig_char_idx, orig_char in enumerate(orig_line): - if orig_char == '\r' and (orig_char_idx != orig_line_len - 1 or - orig_line_idx != orig_lines_len - 1): + if orig_char == "\r" and ( + orig_char_idx != orig_line_len - 1 + or orig_line_idx != orig_lines_len - 1 + ): cursor = 0 - elif orig_char == '\b': + elif orig_char == "\b": cursor = max(0, cursor - 1) else: - if (orig_char == '\r' and - orig_char_idx == orig_line_len - 1 and - orig_line_idx == orig_lines_len - 1): + if ( + orig_char == "\r" + and orig_char_idx == orig_line_len - 1 + and orig_line_idx == orig_lines_len - 1 + ): cursor = len(chars) if cursor == len(chars): chars.append(orig_char) else: chars[cursor] = orig_char cursor += 1 - new_lines.append(''.join(chars)) - return '\n'.join(new_lines) + new_lines.append("".join(chars)) + return "\n".join(new_lines) def module_exists(modname): @@ -666,12 +731,13 @@ def ensure_wellformed_argv(argv): argv = shlex.split(argv) else: if not isinstance(argv, (list, tuple)): - raise ValueError("argv must be str or list, but was {}" - .format(type(argv))) + raise ValueError("argv must be str or list, but was {}".format(type(argv))) if not all([isinstance(a, str) for a in argv]): problems = [a for a in argv if not isinstance(a, str)] - raise ValueError("argv must 
be list of str but contained the " - "following elements: {}".format(problems)) + raise ValueError( + "argv must be list of str but contained the " + "following elements: {}".format(problems) + ) return argv @@ -682,7 +748,7 @@ def create(cls, func, interval=10): timer_thread = cls(stop_event, func, interval) return stop_event, timer_thread - def __init__(self, event, func, interval=10.): + def __init__(self, event, func, interval=10.0): # TODO use super here. threading.Thread.__init__(self) self.stopped = event diff --git a/setup.cfg b/setup.cfg index b6003349..8b56231e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,4 +7,4 @@ pep8ignore = dist/* ALL sacred.egg-info/* ALL [flake8] -ignore = D100,D101,D102,D103,D104,D105,D203,D401,F821,W504,E722 +ignore = D100,D101,D102,D103,D104,D105,D203,D401,F821,W504,E722,W503,E203,E501 \ No newline at end of file diff --git a/setup.py b/setup.py index 67dca023..65f6cf16 100755 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ try: from sacred import __about__ + about = __about__.__dict__ except ImportError: # installing - dependencies are not there yet @@ -31,30 +32,24 @@ setup( - name='sacred', - version=about['__version__'], - - author=about['__author__'], - author_email=about['__author_email__'], - - url=about['__url__'], - - packages=['sacred', 'sacred.observers', 'sacred.config', 'sacred.stflow'], + name="sacred", + version=about["__version__"], + author=about["__author__"], + author_email=about["__author_email__"], + url=about["__url__"], + packages=["sacred", "sacred.observers", "sacred.config", "sacred.stflow"], scripts=[], install_requires=[ - 'docopt>=0.3, <1.0', - 'jsonpickle>=0.7.2, <1.0', - 'munch>=2.0.2, <3.0', - 'wrapt>=1.0, <2.0', - 'py-cpuinfo>=4.0', - 'colorama>=0.4', - 'packaging>=18.0', + "docopt>=0.3, <1.0", + "jsonpickle>=0.7.2, <1.0", + "munch>=2.0.2, <3.0", + "wrapt>=1.0, <2.0", + "py-cpuinfo>=4.0", + "colorama>=0.4", + "packaging>=18.0", ], - tests_require=[ - 'mock>=0.8, <3.0', - 'pytest==4.3.0'], - - 
classifiers=list(filter(None, classifiers.split('\n'))), - description='Facilitates automated and reproducible experimental research', - long_description=codecs.open('README.rst', encoding='utf_8').read() + tests_require=["mock>=0.8, <3.0", "pytest==4.3.0"], + classifiers=list(filter(None, classifiers.split("\n"))), + description="Facilitates automated and reproducible experimental research", + long_description=codecs.open("README.rst", encoding="utf_8").read(), ) diff --git a/tests/conftest.py b/tests/conftest.py index ce679974..1131c09c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,8 @@ from sacred.settings import SETTINGS -EXAMPLES_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'examples') -BLOCK_START = re.compile(r'^\s\s+\$.*$', flags=re.MULTILINE) +EXAMPLES_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples") +BLOCK_START = re.compile(r"^\s\s+\$.*$", flags=re.MULTILINE) def get_calls_from_doc(doc): @@ -25,13 +25,13 @@ def get_calls_from_doc(doc): outputs = [] out = [] block_indent = 2 - for l in doc.split('\n'): + for l in doc.split("\n"): if BLOCK_START.match(l): - block_indent = l.find('$') - calls.append(shlex.split(l[block_indent + 1:])) + block_indent = l.find("$") + calls.append(shlex.split(l[block_indent + 1 :])) out = [] outputs.append(out) - elif l.startswith(' ' * block_indent): + elif l.startswith(" " * block_indent): out.append(l[block_indent:]) else: out = [] @@ -42,10 +42,15 @@ def get_calls_from_doc(doc): def pytest_generate_tests(metafunc): # collects all examples and parses their docstring for calls + outputs # it then parametrizes the function with 'example_test' - if 'example_test' in metafunc.fixturenames: - examples = [os.path.splitext(f)[0] for f in os.listdir(EXAMPLES_PATH) - if os.path.isfile(os.path.join(EXAMPLES_PATH, f)) and - f.endswith('.py') and f != '__init__.py' and re.match(r'^\d', f)] + if "example_test" in metafunc.fixturenames: + examples = [ + 
os.path.splitext(f)[0] + for f in os.listdir(EXAMPLES_PATH) + if os.path.isfile(os.path.join(EXAMPLES_PATH, f)) + and f.endswith(".py") + and f != "__init__.py" + and re.match(r"^\d", f) + ] sys.path.append(EXAMPLES_PATH) example_tests = [] @@ -54,20 +59,27 @@ def pytest_generate_tests(metafunc): try: example = __import__(example_name) except ModuleNotFoundError: - warnings.warn('could not import {name}, skips during test.'.format(name=example_name)) + warnings.warn( + "could not import {name}, skips during test.".format( + name=example_name + ) + ) continue calls_outs = get_calls_from_doc(example.__doc__) for i, (call, out) in enumerate(calls_outs): example = reload(example) example_tests.append((example.ex, call, out)) - example_ids.append('{}_{}'.format(example_name, i)) - metafunc.parametrize('example_test', example_tests, ids=example_ids) + example_ids.append("{}_{}".format(example_name, i)) + metafunc.parametrize("example_test", example_tests, ids=example_ids) def pytest_addoption(parser): - parser.addoption("--sqlalchemy-connect-url", action="store", - default='sqlite://', - help="Name of the database to connect to") + parser.addoption( + "--sqlalchemy-connect-url", + action="store", + default="sqlite://", + help="Name of the database to connect to", + ) # Deactivate GPU info to speed up tests diff --git a/tests/dependency_example.py b/tests/dependency_example.py index 884686bd..70ac39ae 100644 --- a/tests/dependency_example.py +++ b/tests/dependency_example.py @@ -17,4 +17,5 @@ def some_func(): pass + ignore_this = 17 diff --git a/tests/foo/__init__.py b/tests/foo/__init__.py index 134e71ca..2a3cace2 100644 --- a/tests/foo/__init__.py +++ b/tests/foo/__init__.py @@ -3,4 +3,3 @@ """ A local package used to test the gathering of sources by test_dependencies. 
""" - diff --git a/tests/foo/bar.py b/tests/foo/bar.py index ab0e1bc2..ff3e4379 100644 --- a/tests/foo/bar.py +++ b/tests/foo/bar.py @@ -6,6 +6,5 @@ """ - def test_func(): pass diff --git a/tests/test_arg_parser.py b/tests/test_arg_parser.py index 8154e8f5..55e38dd1 100644 --- a/tests/test_arg_parser.py +++ b/tests/test_arg_parser.py @@ -6,33 +6,46 @@ import shlex from docopt import docopt -from sacred.arg_parser import (_convert_value, get_config_updates, format_usage) +from sacred.arg_parser import _convert_value, get_config_updates, format_usage from sacred.commandline_options import gather_command_line_options -@pytest.mark.parametrize("argv,expected", [ - ('', {}), - ('run', {'COMMAND': 'run'}), - ('with 1 2', {'with': True, 'UPDATE': ['1', '2']}), - ('evaluate', {'COMMAND': 'evaluate'}), - ('help', {'help': True}), - ('help evaluate', {'help': True, 'COMMAND': 'evaluate'}), - ('-h', {'--help': True}), - ('--help', {'--help': True}), - ('-m foo', {'--mongo_db': 'foo'}), - ('--mongo_db=bar', {'--mongo_db': 'bar'}), - ('-l 10', {'--loglevel': '10'}), - ('--loglevel=30', {'--loglevel': '30'}), - ('--force', {'--force': True}), - ('run with a=17 b=1 -m localhost:22222', {'COMMAND': 'run', - 'with': True, - 'UPDATE': ['a=17', 'b=1'], - '--mongo_db': 'localhost:22222'}), - ('evaluate with a=18 b=2 -l30', {'COMMAND': 'evaluate', - 'with': True, - 'UPDATE': ['a=18', 'b=2'], - '--loglevel': '30'}), -]) +@pytest.mark.parametrize( + "argv,expected", + [ + ("", {}), + ("run", {"COMMAND": "run"}), + ("with 1 2", {"with": True, "UPDATE": ["1", "2"]}), + ("evaluate", {"COMMAND": "evaluate"}), + ("help", {"help": True}), + ("help evaluate", {"help": True, "COMMAND": "evaluate"}), + ("-h", {"--help": True}), + ("--help", {"--help": True}), + ("-m foo", {"--mongo_db": "foo"}), + ("--mongo_db=bar", {"--mongo_db": "bar"}), + ("-l 10", {"--loglevel": "10"}), + ("--loglevel=30", {"--loglevel": "30"}), + ("--force", {"--force": True}), + ( + "run with a=17 b=1 -m localhost:22222", + 
{ + "COMMAND": "run", + "with": True, + "UPDATE": ["a=17", "b=1"], + "--mongo_db": "localhost:22222", + }, + ), + ( + "evaluate with a=18 b=2 -l30", + { + "COMMAND": "evaluate", + "with": True, + "UPDATE": ["a=18", "b=2"], + "--loglevel": "30", + }, + ), + ], +) def test_parse_individual_arguments(argv, expected): options = gather_command_line_options() usage = format_usage("test.py", "", {}, options) @@ -43,48 +56,56 @@ def test_parse_individual_arguments(argv, expected): assert args == plain -@pytest.mark.parametrize("update,expected", [ - (None, {}), - (['a=5'], {'a': 5}), - (['foo.bar=6'], {'foo': {'bar': 6}}), - (['a=9', 'b=0'], {'a': 9, 'b': 0}), - (["hello='world'"], {'hello': 'world'}), - (['hello="world"'], {'hello': 'world'}), - (["f=23.5"], {'f': 23.5}), - (["n=None"], {'n': None}), - (["t=True"], {'t': True}), - (["f=False"], {'f': False}), -]) +@pytest.mark.parametrize( + "update,expected", + [ + (None, {}), + (["a=5"], {"a": 5}), + (["foo.bar=6"], {"foo": {"bar": 6}}), + (["a=9", "b=0"], {"a": 9, "b": 0}), + (["hello='world'"], {"hello": "world"}), + (['hello="world"'], {"hello": "world"}), + (["f=23.5"], {"f": 23.5}), + (["n=None"], {"n": None}), + (["t=True"], {"t": True}), + (["f=False"], {"f": False}), + ], +) def test_get_config_updates(update, expected): assert get_config_updates(update) == (expected, []) -@pytest.mark.parametrize("value,expected", [ - ('None', None), - ('True', True), - ('False', False), - ('246', 246), - ('1.0', 1.0), - ('1.', 1.0), - ('.1', 0.1), - ('1e3', 1e3), - ('-.4e-12', -0.4e-12), - ('-.4e-12', -0.4e-12), - ('[1,2,3]', [1, 2, 3]), - ('[1.,.1]', [1., .1]), - ('[True, False]', [True, False]), - ('[None, None]', [None, None]), - ('[1.0,2.0,3.0]', [1.0, 2.0, 3.0]), - ('{"a":1}', {'a': 1}), - ('{"foo":1, "bar":2.0}', {'foo': 1, 'bar': 2.0}), - ('{"a":1., "b":.2}', {'a': 1., 'b': .2}), - ('{"a":True, "b":False}', {'a': True, 'b': False}), - ('{"a":None}', {'a': None}), - ('{"a":[1, 2.0, True, None], "b":"foo"}', {"a": [1, 
2.0, True, None], - "b": "foo"}), - ('bob', 'bob'), - ('"hello world"', 'hello world'), - ("'hello world'", 'hello world'), -]) +@pytest.mark.parametrize( + "value,expected", + [ + ("None", None), + ("True", True), + ("False", False), + ("246", 246), + ("1.0", 1.0), + ("1.", 1.0), + (".1", 0.1), + ("1e3", 1e3), + ("-.4e-12", -0.4e-12), + ("-.4e-12", -0.4e-12), + ("[1,2,3]", [1, 2, 3]), + ("[1.,.1]", [1.0, 0.1]), + ("[True, False]", [True, False]), + ("[None, None]", [None, None]), + ("[1.0,2.0,3.0]", [1.0, 2.0, 3.0]), + ('{"a":1}', {"a": 1}), + ('{"foo":1, "bar":2.0}', {"foo": 1, "bar": 2.0}), + ('{"a":1., "b":.2}', {"a": 1.0, "b": 0.2}), + ('{"a":True, "b":False}', {"a": True, "b": False}), + ('{"a":None}', {"a": None}), + ( + '{"a":[1, 2.0, True, None], "b":"foo"}', + {"a": [1, 2.0, True, None], "b": "foo"}, + ), + ("bob", "bob"), + ('"hello world"', "hello world"), + ("'hello world'", "hello world"), + ], +) def test_convert_value(value, expected): assert _convert_value(value) == expected diff --git a/tests/test_commands.py b/tests/test_commands.py index 8ea6e40a..c7e50956 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -6,11 +6,22 @@ import pytest from sacred import Ingredient, Experiment -from sacred.commands import (COLOR_MODIFIED, ENDC, COLOR_DOC, COLOR_ADDED, - COLOR_TYPECHANGED, ConfigEntry, - PathEntry, _format_config, _format_entry, - help_for_command, _iterate_marked, _non_unicode_repr, - _format_named_configs, _format_named_config) +from sacred.commands import ( + COLOR_MODIFIED, + ENDC, + COLOR_DOC, + COLOR_ADDED, + COLOR_TYPECHANGED, + ConfigEntry, + PathEntry, + _format_config, + _format_entry, + help_for_command, + _iterate_marked, + _non_unicode_repr, + _format_named_configs, + _format_named_config, +) from sacred.config import ConfigScope from sacred.config.config_summary import ConfigSummary @@ -19,122 +30,121 @@ def test_non_unicode_repr(): p = pprint.PrettyPrinter() p.format = _non_unicode_repr # make sure there is no u' in 
the representation - assert p.pformat('HelloWorld') == "'HelloWorld'" + assert p.pformat("HelloWorld") == "'HelloWorld'" @pytest.fixture def cfg(): return { - 'a': 0, - 'b': {}, # 1 - 'c': { # 2 - 'cA': 3, - 'cB': 4, - 'cC': { # 5 - 'cC1': 6 - } - }, - 'd': { # 7 - 'dA': 8 - } + "a": 0, + "b": {}, # 1 + "c": {"cA": 3, "cB": 4, "cC": {"cC1": 6}}, # 2 # 5 + "d": {"dA": 8}, # 7 } def test_iterate_marked(cfg): - assert list(_iterate_marked(cfg, ConfigSummary())) == \ - [('a', ConfigEntry('a', 0, False, False, None, None)), - ('b', ConfigEntry('b', {}, False, False, None, None)), - ('c', PathEntry('c', False, False, None, None)), - ('c.cA', ConfigEntry('cA', 3, False, False, None, None)), - ('c.cB', ConfigEntry('cB', 4, False, False, None, None)), - ('c.cC', PathEntry('cC', False, False, None, None)), - ('c.cC.cC1', ConfigEntry('cC1', 6, False, False, None, None)), - ('d', PathEntry('d', False, False, None, None)), - ('d.dA', ConfigEntry('dA', 8, False, False, None, None)) - ] + assert list(_iterate_marked(cfg, ConfigSummary())) == [ + ("a", ConfigEntry("a", 0, False, False, None, None)), + ("b", ConfigEntry("b", {}, False, False, None, None)), + ("c", PathEntry("c", False, False, None, None)), + ("c.cA", ConfigEntry("cA", 3, False, False, None, None)), + ("c.cB", ConfigEntry("cB", 4, False, False, None, None)), + ("c.cC", PathEntry("cC", False, False, None, None)), + ("c.cC.cC1", ConfigEntry("cC1", 6, False, False, None, None)), + ("d", PathEntry("d", False, False, None, None)), + ("d.dA", ConfigEntry("dA", 8, False, False, None, None)), + ] def test_iterate_marked_added(cfg): - added = {'a', 'c.cB', 'c.cC.cC1'} - assert list(_iterate_marked(cfg, ConfigSummary(added=added))) == \ - [('a', ConfigEntry('a', 0, True, False, None, None)), - ('b', ConfigEntry('b', {}, False, False, None, None)), - ('c', PathEntry('c', False, True, None, None)), - ('c.cA', ConfigEntry('cA', 3, False, False, None, None)), - ('c.cB', ConfigEntry('cB', 4, True, False, None, None)), - ('c.cC', 
PathEntry('cC', False, True, None, None)), - ('c.cC.cC1', ConfigEntry('cC1', 6, True, False, None, None)), - ('d', PathEntry('d', False, False, None, None)), - ('d.dA', ConfigEntry('dA', 8, False, False, None, None)) - ] + added = {"a", "c.cB", "c.cC.cC1"} + assert list(_iterate_marked(cfg, ConfigSummary(added=added))) == [ + ("a", ConfigEntry("a", 0, True, False, None, None)), + ("b", ConfigEntry("b", {}, False, False, None, None)), + ("c", PathEntry("c", False, True, None, None)), + ("c.cA", ConfigEntry("cA", 3, False, False, None, None)), + ("c.cB", ConfigEntry("cB", 4, True, False, None, None)), + ("c.cC", PathEntry("cC", False, True, None, None)), + ("c.cC.cC1", ConfigEntry("cC1", 6, True, False, None, None)), + ("d", PathEntry("d", False, False, None, None)), + ("d.dA", ConfigEntry("dA", 8, False, False, None, None)), + ] def test_iterate_marked_updated(cfg): - modified = {'b', 'c', 'c.cC.cC1'} - assert list(_iterate_marked(cfg, ConfigSummary(modified=modified))) == \ - [('a', ConfigEntry('a', 0, False, False, None, None)), - ('b', ConfigEntry('b', {}, False, True, None, None)), - ('c', PathEntry('c', False, True, None, None)), - ('c.cA', ConfigEntry('cA', 3, False, False, None, None)), - ('c.cB', ConfigEntry('cB', 4, False, False, None, None)), - ('c.cC', PathEntry('cC', False, True, None, None)), - ('c.cC.cC1', ConfigEntry('cC1', 6, False, True, None, None)), - ('d', PathEntry('d', False, False, None, None)), - ('d.dA', ConfigEntry('dA', 8, False, False, None, None)) - ] + modified = {"b", "c", "c.cC.cC1"} + assert list(_iterate_marked(cfg, ConfigSummary(modified=modified))) == [ + ("a", ConfigEntry("a", 0, False, False, None, None)), + ("b", ConfigEntry("b", {}, False, True, None, None)), + ("c", PathEntry("c", False, True, None, None)), + ("c.cA", ConfigEntry("cA", 3, False, False, None, None)), + ("c.cB", ConfigEntry("cB", 4, False, False, None, None)), + ("c.cC", PathEntry("cC", False, True, None, None)), + ("c.cC.cC1", ConfigEntry("cC1", 6, False, 
True, None, None)), + ("d", PathEntry("d", False, False, None, None)), + ("d.dA", ConfigEntry("dA", 8, False, False, None, None)), + ] def test_iterate_marked_typechanged(cfg): - typechanged = {'a': (bool, int), - 'd.dA': (float, int)} + typechanged = {"a": (bool, int), "d.dA": (float, int)} result = list(_iterate_marked(cfg, ConfigSummary(typechanged=typechanged))) - assert result == \ - [('a', ConfigEntry('a', 0, False, False, (bool, int), None)), - ('b', ConfigEntry('b', {}, False, False, None, None)), - ('c', PathEntry('c', False, False, None, None)), - ('c.cA', ConfigEntry('cA', 3, False, False, None, None)), - ('c.cB', ConfigEntry('cB', 4, False, False, None, None)), - ('c.cC', PathEntry('cC', False, False, None, None)), - ('c.cC.cC1', ConfigEntry('cC1', 6, False, False, None, None)), - ('d', PathEntry('d', False, True, None, None)), - ('d.dA', ConfigEntry('dA', 8, False, False, (float, int), None)) - ] - - -@pytest.mark.parametrize("entry,expected", [ - (ConfigEntry('a', 0, False, False, None, None), "a = 0"), - (ConfigEntry('foo', 'bar', False, False, None, None), "foo = 'bar'"), - (ConfigEntry('b', [0, 1], False, False, None, None), "b = [0, 1]"), - (ConfigEntry('c', True, False, False, None, None), "c = True"), - (ConfigEntry('d', 0.5, False, False, None, None), "d = 0.5"), - (ConfigEntry('e', {}, False, False, None, None), "e = {}"), - # Path entries - (PathEntry('f', False, False, None, None), "f:"), - # Docstring entry - (ConfigEntry('__doc__', 'multiline\ndocstring', False, False, None, None), - COLOR_DOC + '"""multiline\ndocstring"""' + ENDC), -]) + assert result == [ + ("a", ConfigEntry("a", 0, False, False, (bool, int), None)), + ("b", ConfigEntry("b", {}, False, False, None, None)), + ("c", PathEntry("c", False, False, None, None)), + ("c.cA", ConfigEntry("cA", 3, False, False, None, None)), + ("c.cB", ConfigEntry("cB", 4, False, False, None, None)), + ("c.cC", PathEntry("cC", False, False, None, None)), + ("c.cC.cC1", ConfigEntry("cC1", 6, False, 
False, None, None)), + ("d", PathEntry("d", False, True, None, None)), + ("d.dA", ConfigEntry("dA", 8, False, False, (float, int), None)), + ] + + +@pytest.mark.parametrize( + "entry,expected", + [ + (ConfigEntry("a", 0, False, False, None, None), "a = 0"), + (ConfigEntry("foo", "bar", False, False, None, None), "foo = 'bar'"), + (ConfigEntry("b", [0, 1], False, False, None, None), "b = [0, 1]"), + (ConfigEntry("c", True, False, False, None, None), "c = True"), + (ConfigEntry("d", 0.5, False, False, None, None), "d = 0.5"), + (ConfigEntry("e", {}, False, False, None, None), "e = {}"), + # Path entries + (PathEntry("f", False, False, None, None), "f:"), + # Docstring entry + ( + ConfigEntry("__doc__", "multiline\ndocstring", False, False, None, None), + COLOR_DOC + '"""multiline\ndocstring"""' + ENDC, + ), + ], +) def test_format_entry(entry, expected): assert _format_entry(0, entry) == expected -@pytest.mark.parametrize("entry,color", [ - (ConfigEntry('a', 1, True, False, None, None), COLOR_ADDED), - (ConfigEntry('b', 2, False, True, None, None), COLOR_MODIFIED), - (ConfigEntry('c', 3, False, False, (bool, int), None), COLOR_TYPECHANGED), - (ConfigEntry('d', 4, True, True, None, None), COLOR_ADDED), - (ConfigEntry('e', 5, True, False, (bool, int), None), COLOR_TYPECHANGED), - (ConfigEntry('f', 6, False, True, (bool, int), None), COLOR_TYPECHANGED), - (ConfigEntry('g', 7, True, True, (bool, int), None), COLOR_TYPECHANGED), - # Path entries - (PathEntry('a', True, False, None, None), COLOR_ADDED), - (PathEntry('b', False, True, None, None), COLOR_MODIFIED), - (PathEntry('c', False, False, (bool, int), None), COLOR_TYPECHANGED), - (PathEntry('d', True, True, None, None), COLOR_ADDED), - (PathEntry('e', True, False, (bool, int), None), COLOR_TYPECHANGED), - (PathEntry('f', False, True, (bool, int), None), COLOR_TYPECHANGED), - (PathEntry('g', True, True, (bool, int), None), COLOR_TYPECHANGED), -]) +@pytest.mark.parametrize( + "entry,color", + [ + (ConfigEntry("a", 1, 
True, False, None, None), COLOR_ADDED), + (ConfigEntry("b", 2, False, True, None, None), COLOR_MODIFIED), + (ConfigEntry("c", 3, False, False, (bool, int), None), COLOR_TYPECHANGED), + (ConfigEntry("d", 4, True, True, None, None), COLOR_ADDED), + (ConfigEntry("e", 5, True, False, (bool, int), None), COLOR_TYPECHANGED), + (ConfigEntry("f", 6, False, True, (bool, int), None), COLOR_TYPECHANGED), + (ConfigEntry("g", 7, True, True, (bool, int), None), COLOR_TYPECHANGED), + # Path entries + (PathEntry("a", True, False, None, None), COLOR_ADDED), + (PathEntry("b", False, True, None, None), COLOR_MODIFIED), + (PathEntry("c", False, False, (bool, int), None), COLOR_TYPECHANGED), + (PathEntry("d", True, True, None, None), COLOR_ADDED), + (PathEntry("e", True, False, (bool, int), None), COLOR_TYPECHANGED), + (PathEntry("f", False, True, (bool, int), None), COLOR_TYPECHANGED), + (PathEntry("g", True, True, (bool, int), None), COLOR_TYPECHANGED), + ], +) def test_format_entry_colors(entry, color): s = _format_entry(0, entry) assert s.startswith(color) @@ -143,17 +153,17 @@ def test_format_entry_colors(entry, color): def test_format_config(cfg): cfg_text = _format_config(cfg, ConfigSummary()) - lines = cfg_text.split('\n') - assert lines[0].startswith('Configuration') - assert ' a = 0' in lines[1] - assert ' b = {}' in lines[2] - assert ' c:' in lines[3] - assert ' cA = 3' in lines[4] - assert ' cB = 4' in lines[5] - assert ' cC:' in lines[6] - assert ' cC1 = 6' in lines[7] - assert ' d:' in lines[8] - assert ' dA = 8' in lines[9] + lines = cfg_text.split("\n") + assert lines[0].startswith("Configuration") + assert " a = 0" in lines[1] + assert " b = {}" in lines[2] + assert " c:" in lines[3] + assert " cA = 3" in lines[4] + assert " cB = 4" in lines[5] + assert " cC:" in lines[6] + assert " cC1 = 6" in lines[7] + assert " d:" in lines[8] + assert " dA = 8" in lines[9] def test_help_for_command(): @@ -178,22 +188,33 @@ def _config_scope_with_multiline_doc(): pass 
-@pytest.mark.parametrize('indent, path, named_config, expected', [ - (0, 'a', None, 'a'), - (1, 'b', None, ' b'), - (4, 'a.b.c', None, ' a.b.c'), - (0, 'c', ConfigScope(_config_scope_with_single_line_doc), 'c' + COLOR_DOC - + ' # doc' + ENDC), - (0, 'd', ConfigScope(_config_scope_with_multiline_doc), - 'd' + COLOR_DOC + '\n """Multiline\n docstring!\n """' + ENDC) -]) +@pytest.mark.parametrize( + "indent, path, named_config, expected", + [ + (0, "a", None, "a"), + (1, "b", None, " b"), + (4, "a.b.c", None, " a.b.c"), + ( + 0, + "c", + ConfigScope(_config_scope_with_single_line_doc), + "c" + COLOR_DOC + " # doc" + ENDC, + ), + ( + 0, + "d", + ConfigScope(_config_scope_with_multiline_doc), + "d" + COLOR_DOC + '\n """Multiline\n docstring!\n """' + ENDC, + ), + ], +) def test_format_named_config(indent, path, named_config, expected): assert _format_named_config(indent, path, named_config) == expected def test_format_named_configs(): - ingred = Ingredient('ingred') - ex = Experiment(name='experiment', ingredients=[ingred]) + ingred = Ingredient("ingred") + ex = Experiment(name="experiment", ingredients=[ingred]) @ingred.named_config def named_config1(): @@ -205,13 +226,13 @@ def named_config2(): pass dict_config = dict(v=42) - ingred.add_named_config('dict_config', dict_config) - - named_configs_text = _format_named_configs(OrderedDict( - ex.gather_named_configs())) - assert named_configs_text.startswith('Named Configurations (' + - COLOR_DOC + 'doc' + ENDC + '):') - assert 'named_config2' in named_configs_text - assert '# named config with doc' in named_configs_text - assert 'ingred.named_config1' in named_configs_text - assert 'ingred.dict_config' in named_configs_text + ingred.add_named_config("dict_config", dict_config) + + named_configs_text = _format_named_configs(OrderedDict(ex.gather_named_configs())) + assert named_configs_text.startswith( + "Named Configurations (" + COLOR_DOC + "doc" + ENDC + "):" + ) + assert "named_config2" in named_configs_text + assert 
"# named config with doc" in named_configs_text + assert "ingred.named_config1" in named_configs_text + assert "ingred.dict_config" in named_configs_text diff --git a/tests/test_config/__init__.py b/tests/test_config/__init__.py index 9a947b1a..5c9136ad 100644 --- a/tests/test_config/__init__.py +++ b/tests/test_config/__init__.py @@ -1,3 +1,2 @@ #!/usr/bin/env python # coding=utf-8 - diff --git a/tests/test_config/enclosed_config_scope.py b/tests/test_config/enclosed_config_scope.py index 7252d5fd..c6a30723 100644 --- a/tests/test_config/enclosed_config_scope.py +++ b/tests/test_config/enclosed_config_scope.py @@ -2,7 +2,6 @@ # coding=utf-8 - from sacred.config.config_scope import ConfigScope SIX = 6 diff --git a/tests/test_config/test_captured_functions.py b/tests/test_config/test_captured_functions.py index c160611e..44f76f8c 100644 --- a/tests/test_config/test_captured_functions.py +++ b/tests/test_config/test_captured_functions.py @@ -14,8 +14,8 @@ def foo(): cf = create_captured_function(foo) - assert cf.__name__ == 'foo' - assert cf.__doc__ == 'my docstring' + assert cf.__name__ == "foo" + assert cf.__doc__ == "my docstring" assert cf.prefix is None assert cf.config == {} assert not cf.uses_randomness @@ -28,12 +28,12 @@ def foo(a, b, c, d=4, e=5, f=6): cf = create_captured_function(foo) cf.logger = mock.MagicMock() - cf.config = {'a': 11, 'b': 12, 'd': 14} + cf.config = {"a": 11, "b": 12, "d": 14} assert cf(21, c=23, f=26) == (21, 12, 23, 14, 5, 26) - cf.logger.debug.assert_has_calls([ - mock.call("Started"), - mock.call("Finished after %s.", datetime.timedelta(0))]) + cf.logger.debug.assert_has_calls( + [mock.call("Started"), mock.call("Finished after %s.", datetime.timedelta(0))] + ) def test_captured_function_randomness(): @@ -72,7 +72,7 @@ def foo(_config): cf = create_captured_function(foo) cf.logger = mock.MagicMock() - cf.config = {'a': 2, 'b': 2} + cf.config = {"a": 2, "b": 2} assert cf() == cf.config @@ -97,6 +97,6 @@ def foo(a, _log): cf.logger = 
mock.MagicMock() cf.run = mock.MagicMock() - d = {'a': 7} + d = {"a": 7} assert cf(**d) == 7 - assert d == {'a': 7} + assert d == {"a": 7} diff --git a/tests/test_config/test_config_dict.py b/tests/test_config/test_config_dict.py index f5f3c1a8..edd55215 100644 --- a/tests/test_config/test_config_dict.py +++ b/tests/test_config/test_config_dict.py @@ -10,14 +10,16 @@ @pytest.fixture def conf_dict(): - cfg = ConfigDict({ - "a": 1, - "b": 2.0, - "c": True, - "d": 'string', - "e": [1, 2, 3], - "f": {'a': 'b', 'c': 'd'}, - }) + cfg = ConfigDict( + { + "a": 1, + "b": 2.0, + "c": True, + "d": "string", + "e": [1, 2, 3], + "f": {"a": "b", "c": "d"}, + } + ) return cfg @@ -27,17 +29,17 @@ def test_config_dict_returns_dict(conf_dict): def test_config_dict_result_contains_keys(conf_dict): cfg = conf_dict() - assert set(cfg.keys()) == {'a', 'b', 'c', 'd', 'e', 'f'} - assert cfg['a'] == 1 - assert cfg['b'] == 2.0 - assert cfg['c'] - assert cfg['d'] == 'string' - assert cfg['e'] == [1, 2, 3] - assert cfg['f'] == {'a': 'b', 'c': 'd'} + assert set(cfg.keys()) == {"a", "b", "c", "d", "e", "f"} + assert cfg["a"] == 1 + assert cfg["b"] == 2.0 + assert cfg["c"] + assert cfg["d"] == "string" + assert cfg["e"] == [1, 2, 3] + assert cfg["f"] == {"a": "b", "c": "d"} def test_fixing_values(conf_dict): - assert conf_dict({'a': 100})['a'] == 100 + assert conf_dict({"a": 100})["a"] == 100 @pytest.mark.parametrize("key", ["$f", "contains.dot", "py/tuple", "json://1"]) @@ -46,34 +48,36 @@ def test_config_dict_raises_on_invalid_keys(key): ConfigDict({key: True}) -@pytest.mark.parametrize("value", [lambda x:x, pytest, test_fixing_values]) +@pytest.mark.parametrize("value", [lambda x: x, pytest, test_fixing_values]) def test_config_dict_accepts_special_types(value): - assert ConfigDict({"special": value})()['special'] == value + assert ConfigDict({"special": value})()["special"] == value def test_fixing_nested_dicts(conf_dict): - cfg = conf_dict({'f': {'c': 't'}}) - assert cfg['f']['a'] == 'b' - 
assert cfg['f']['c'] == 't' + cfg = conf_dict({"f": {"c": "t"}}) + assert cfg["f"]["a"] == "b" + assert cfg["f"]["c"] == "t" def test_adding_values(conf_dict): - cfg = conf_dict({'g': 23, 'h': {'i': 10}}) - assert cfg['g'] == 23 - assert cfg['h'] == {'i': 10} - assert cfg.added == {'g', 'h', 'h.i'} + cfg = conf_dict({"g": 23, "h": {"i": 10}}) + assert cfg["g"] == 23 + assert cfg["h"] == {"i": 10} + assert cfg.added == {"g", "h", "h.i"} def test_typechange(conf_dict): - cfg = conf_dict({'a': 'bar', 'b': 'foo', 'c': 1}) - assert cfg.typechanged == {'a': (int, type('bar')), - 'b': (float, type('foo')), - 'c': (bool, int)} + cfg = conf_dict({"a": "bar", "b": "foo", "c": 1}) + assert cfg.typechanged == { + "a": (int, type("bar")), + "b": (float, type("foo")), + "c": (bool, int), + } def test_nested_typechange(conf_dict): - cfg = conf_dict({'f': {'a': 10}}) - assert cfg.typechanged == {'f.a': (type('a'), int)} + cfg = conf_dict({"f": {"a": 10}}) + assert cfg.typechanged == {"f.a": (type("a"), int)} def is_dogmatic(a): @@ -86,73 +90,62 @@ def is_dogmatic(a): def test_result_of_conf_dict_is_not_dogmatic(conf_dict): - cfg = conf_dict({'e': [1, 1, 1]}) + cfg = conf_dict({"e": [1, 1, 1]}) assert not is_dogmatic(cfg) @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") def test_conf_scope_handles_numpy_bools(): - cfg = ConfigDict({ - "a": opt.np.bool_(1) - }) - assert 'a' in cfg() - assert cfg()['a'] + cfg = ConfigDict({"a": opt.np.bool_(1)}) + assert "a" in cfg() + assert cfg()["a"] def test_conf_scope_contains_presets(): - conf_dict = ConfigDict({ - "answer": 42 - }) - cfg = conf_dict(preset={'a': 21, 'unrelated': True}) - assert set(cfg.keys()) == {'a', 'answer', 'unrelated'} - assert cfg['a'] == 21 - assert cfg['answer'] == 42 - assert cfg['unrelated'] is True + conf_dict = ConfigDict({"answer": 42}) + cfg = conf_dict(preset={"a": 21, "unrelated": True}) + assert set(cfg.keys()) == {"a", "answer", "unrelated"} + assert cfg["a"] == 21 + assert cfg["answer"] == 42 
+ assert cfg["unrelated"] is True def test_conf_scope_does_not_contain_fallback(): - config_dict = ConfigDict({ - "answer": 42 - }) + config_dict = ConfigDict({"answer": 42}) - cfg = config_dict(fallback={'a': 21, 'b': 10}) + cfg = config_dict(fallback={"a": 21, "b": 10}) - assert set(cfg.keys()) == {'answer'} + assert set(cfg.keys()) == {"answer"} def test_fixed_subentry_of_preset(): config_dict = ConfigDict({}) - cfg = config_dict(preset={'d': {'a': 1, 'b': 2}}, fixed={'d': {'a': 10}}) + cfg = config_dict(preset={"d": {"a": 1, "b": 2}}, fixed={"d": {"a": 10}}) - assert set(cfg.keys()) == {'d'} - assert set(cfg['d'].keys()) == {'a', 'b'} - assert cfg['d']['a'] == 10 - assert cfg['d']['b'] == 2 + assert set(cfg.keys()) == {"d"} + assert set(cfg["d"].keys()) == {"a", "b"} + assert cfg["d"]["a"] == 10 + assert cfg["d"]["b"] == 2 def test_add_config_dict_sequential(): # https://github.com/IDSIA/sacred/issues/409 - adict = ConfigDict(dict( - dictnest2 = { - 'key_1': 'value_1', - 'key_2': 'value_2' - })) + adict = ConfigDict(dict(dictnest2={"key_1": "value_1", "key_2": "value_2"})) - bdict = ConfigDict(dict( - dictnest2 = { - 'key_2': 'update_value_2', - 'key_3': 'value3', - 'key_4': 'value4' - })) + bdict = ConfigDict( + dict( + dictnest2={"key_2": "update_value_2", "key_3": "value3", "key_4": "value4"} + ) + ) final_config = bdict(preset=adict()) assert final_config == { - 'dictnest2': { - 'key_1': 'value_1', - 'key_2': 'update_value_2', - 'key_3': 'value3', - 'key_4': 'value4' + "dictnest2": { + "key_1": "value_1", + "key_2": "update_value_2", + "key_3": "value3", + "key_4": "value4", } } diff --git a/tests/test_config/test_config_files.py b/tests/test_config/test_config_files.py index ca15713f..28c67160 100644 --- a/tests/test_config/test_config_files.py +++ b/tests/test_config/test_config_files.py @@ -9,25 +9,19 @@ from sacred.config.config_files import HANDLER_BY_EXT, load_config_file -data = { - 'foo': 42, - 'baz': [1, 0.2, 'bar', True, { - 'some_number': -12, - 
'simon': 'hugo' - }] -} +data = {"foo": 42, "baz": [1, 0.2, "bar", True, {"some_number": -12, "simon": "hugo"}]} -@pytest.mark.parametrize('handler', HANDLER_BY_EXT.values()) +@pytest.mark.parametrize("handler", HANDLER_BY_EXT.values()) def test_save_and_load(handler): - with tempfile.TemporaryFile('w+' + handler.mode) as f: + with tempfile.TemporaryFile("w+" + handler.mode) as f: handler.dump(data, f) f.seek(0) # simulates closing and reopening d = handler.load(f) assert d == data -@pytest.mark.parametrize('ext, handler', HANDLER_BY_EXT.items()) +@pytest.mark.parametrize("ext, handler", HANDLER_BY_EXT.items()) def test_load_config_file(ext, handler): handle, f_name = tempfile.mkstemp(suffix=ext) f = os.fdopen(handle, "w" + handler.mode) @@ -39,7 +33,7 @@ def test_load_config_file(ext, handler): def test_load_config_file_exception_msg_invalid_ext(): - handle, f_name = tempfile.mkstemp(suffix='.invalid') + handle, f_name = tempfile.mkstemp(suffix=".invalid") f = os.fdopen(handle, "w") # necessary for windows f.close() try: diff --git a/tests/test_config/test_config_scope.py b/tests/test_config/test_config_scope.py index 9067e3ef..f20f8f10 100644 --- a/tests/test_config/test_config_scope.py +++ b/tests/test_config/test_config_scope.py @@ -4,9 +4,13 @@ import pytest import sacred.optional as opt -from sacred.config.config_scope import (ConfigScope, dedent_function_body, - dedent_line, get_function_body, - is_empty_or_comment) +from sacred.config.config_scope import ( + ConfigScope, + dedent_function_body, + dedent_line, + get_function_body, + is_empty_or_comment, +) from sacred.config.custom_containers import DogmaticDict, DogmaticList @@ -19,19 +23,19 @@ def cfg(): # description for b and c b, c = 2.0, True # d and dd are both strings - d = dd = 'string' + d = dd = "string" e = [1, 2, 3] # inline description for e - f = {'a': 'b', 'c': 'd'} + f = {"a": "b", "c": "d"} composit1 = a + b # pylint: this comment is filtered out - composit2 = f['c'] + "ada" + composit2 = 
f["c"] + "ada" func1 = lambda: 23 deriv = func1() def func2(a): - return 'Nothing to report' + a + return "Nothing to report" + a some_type = int @@ -46,67 +50,81 @@ def test_result_of_config_scope_is_dict(conf_scope): def test_result_of_config_scope_contains_keys(conf_scope): cfg = conf_scope() - assert set(cfg.keys()) == {'a', 'b', 'c', 'd', 'dd', 'e', 'f', - 'composit1', 'composit2', 'deriv', 'func1', - 'func2', 'some_type'} - - assert cfg['a'] == 1 - assert cfg['b'] == 2.0 - assert cfg['c'] - assert cfg['d'] == 'string' - assert cfg['dd'] == 'string' - assert cfg['e'] == [1, 2, 3] - assert cfg['f'] == {'a': 'b', 'c': 'd'} - assert cfg['composit1'] == 3.0 - assert cfg['composit2'] == 'dada' - assert cfg['func1']() == 23 - assert cfg['func2'](', sir!') == 'Nothing to report, sir!' - assert cfg['some_type'] == int - assert cfg['deriv'] == 23 + assert set(cfg.keys()) == { + "a", + "b", + "c", + "d", + "dd", + "e", + "f", + "composit1", + "composit2", + "deriv", + "func1", + "func2", + "some_type", + } + + assert cfg["a"] == 1 + assert cfg["b"] == 2.0 + assert cfg["c"] + assert cfg["d"] == "string" + assert cfg["dd"] == "string" + assert cfg["e"] == [1, 2, 3] + assert cfg["f"] == {"a": "b", "c": "d"} + assert cfg["composit1"] == 3.0 + assert cfg["composit2"] == "dada" + assert cfg["func1"]() == 23 + assert cfg["func2"](", sir!") == "Nothing to report, sir!" 
+ assert cfg["some_type"] == int + assert cfg["deriv"] == 23 def test_fixing_values(conf_scope): - cfg = conf_scope({'a': 100}) - assert cfg['a'] == 100 - assert cfg['composit1'] == 102.0 + cfg = conf_scope({"a": 100}) + assert cfg["a"] == 100 + assert cfg["composit1"] == 102.0 def test_fixing_nested_dicts(conf_scope): - cfg = conf_scope({'f': {'c': 't'}}) - assert cfg['f']['a'] == 'b' - assert cfg['f']['c'] == 't' - assert cfg['composit2'] == 'tada' + cfg = conf_scope({"f": {"c": "t"}}) + assert cfg["f"]["a"] == "b" + assert cfg["f"]["c"] == "t" + assert cfg["composit2"] == "tada" def test_adding_values(conf_scope): - cfg = conf_scope({'g': 23, 'h': {'i': 10}}) - assert cfg['g'] == 23 - assert cfg['h'] == {'i': 10} - assert cfg.added == {'g', 'h', 'h.i'} + cfg = conf_scope({"g": 23, "h": {"i": 10}}) + assert cfg["g"] == 23 + assert cfg["h"] == {"i": 10} + assert cfg.added == {"g", "h", "h.i"} def test_typechange(conf_scope): - cfg = conf_scope({'a': 'bar', 'b': 'foo', 'c': 1}) - assert cfg.typechanged == {'a': (int, type('bar')), - 'b': (float, type('foo')), - 'c': (bool, int)} + cfg = conf_scope({"a": "bar", "b": "foo", "c": 1}) + assert cfg.typechanged == { + "a": (int, type("bar")), + "b": (float, type("foo")), + "c": (bool, int), + } def test_nested_typechange(conf_scope): - cfg = conf_scope({'f': {'a': 10}}) - assert cfg.typechanged == {'f.a': (type('a'), int)} + cfg = conf_scope({"f": {"a": 10}}) + assert cfg.typechanged == {"f.a": (type("a"), int)} def test_config_docs(conf_scope): cfg = conf_scope() assert cfg.docs == { - 'a': 'description for a', - 'b': 'description for b and c', - 'c': 'description for b and c', - 'd': 'd and dd are both strings', - 'dd': 'd and dd are both strings', - 'e': 'inline description for e', - 'seed': 'the random seed for this experiment' + "a": "description for a", + "b": "description for b and c", + "c": "description for b and c", + "d": "d and dd are both strings", + "dd": "d and dd are both strings", + "e": "inline 
description for e", + "seed": "the random seed for this experiment", } @@ -120,7 +138,7 @@ def is_dogmatic(a): def test_conf_scope_is_not_dogmatic(conf_scope): - assert not is_dogmatic(conf_scope({'e': [1, 1, 1]})) + assert not is_dogmatic(conf_scope({"e": [1, 1, 1]})) @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") @@ -130,8 +148,8 @@ def conf_scope(): a = opt.np.bool_(1) cfg = conf_scope() - assert 'a' in cfg - assert cfg['a'] + assert "a" in cfg + assert cfg["a"] def test_conf_scope_can_access_preset(): @@ -139,8 +157,8 @@ def test_conf_scope_can_access_preset(): def conf_scope(a): answer = 2 * a - cfg = conf_scope(preset={'a': 21}) - assert cfg['answer'] == 42 + cfg = conf_scope(preset={"a": 21}) + assert cfg["answer"] == 42 def test_conf_scope_contains_presets(): @@ -148,11 +166,11 @@ def test_conf_scope_contains_presets(): def conf_scope(a): answer = 2 * a - cfg = conf_scope(preset={'a': 21, 'unrelated': True}) - assert set(cfg.keys()) == {'a', 'answer', 'unrelated'} - assert cfg['a'] == 21 - assert cfg['answer'] == 42 - assert cfg['unrelated'] is True + cfg = conf_scope(preset={"a": 21, "unrelated": True}) + assert set(cfg.keys()) == {"a", "answer", "unrelated"} + assert cfg["a"] == 21 + assert cfg["answer"] == 42 + assert cfg["unrelated"] is True def test_conf_scope_cannot_access_undeclared_presets(): @@ -161,7 +179,7 @@ def conf_scope(): answer = 2 * a with pytest.raises(NameError): - conf_scope(preset={'a': 21}) + conf_scope(preset={"a": 21}) def test_conf_scope_can_access_fallback(): @@ -169,8 +187,8 @@ def test_conf_scope_can_access_fallback(): def conf_scope(a): answer = 2 * a - cfg = conf_scope(fallback={'a': 21}) - assert cfg['answer'] == 42 + cfg = conf_scope(fallback={"a": 21}) + assert cfg["answer"] == 42 def test_conf_scope_does_not_contain_fallback(): @@ -178,8 +196,8 @@ def test_conf_scope_does_not_contain_fallback(): def conf_scope(a): answer = 2 * a - cfg = conf_scope(fallback={'a': 21, 'b': 10}) - assert set(cfg.keys()) == 
{'answer'} + cfg = conf_scope(fallback={"a": 21, "b": 10}) + assert set(cfg.keys()) == {"answer"} def test_conf_scope_cannot_access_undeclared_fallback(): @@ -188,7 +206,7 @@ def conf_scope(): answer = 2 * a with pytest.raises(NameError): - conf_scope(fallback={'a': 21}) + conf_scope(fallback={"a": 21}) def test_conf_scope_can_access_fallback_and_preset(): @@ -196,8 +214,8 @@ def test_conf_scope_can_access_fallback_and_preset(): def conf_scope(a, b): answer = a + b - cfg = conf_scope(preset={'b': 40}, fallback={'a': 2}) - assert cfg['answer'] == 42 + cfg = conf_scope(preset={"b": 40}, fallback={"a": 2}) + assert cfg["answer"] == 42 def test_conf_raises_for_unaccessible_arguments(): @@ -206,14 +224,15 @@ def conf_scope(a, b, c): answer = 42 with pytest.raises(KeyError): - conf_scope(preset={'a': 1}, fallback={'b': 2}) + conf_scope(preset={"a": 1}, fallback={"b": 2}) def test_can_access_globals_from_original_scope(): from .enclosed_config_scope import cfg as conf_scope + cfg = conf_scope() - assert set(cfg.keys()) == {'answer'} - assert cfg['answer'] == 42 + assert set(cfg.keys()) == {"answer"} + assert cfg["answer"] == 42 SEVEN = 7 @@ -221,6 +240,7 @@ def test_can_access_globals_from_original_scope(): def test_cannot_access_globals_from_calling_scope(): from .enclosed_config_scope import cfg2 as conf_scope + with pytest.raises(NameError): conf_scope() # would require SEVEN @@ -230,14 +250,15 @@ def test_fixed_subentry_of_preset(): def conf_scope(): pass - cfg = conf_scope(preset={'d': {'a': 1, 'b': 2}}, fixed={'d': {'a': 10}}) + cfg = conf_scope(preset={"d": {"a": 1, "b": 2}}, fixed={"d": {"a": 10}}) - assert set(cfg.keys()) == {'d'} - assert set(cfg['d'].keys()) == {'a', 'b'} - assert cfg['d']['a'] == 10 - assert cfg['d']['b'] == 2 + assert set(cfg.keys()) == {"d"} + assert set(cfg["d"].keys()) == {"a", "b"} + assert cfg["d"]["a"] == 10 + assert cfg["d"]["b"] == 2 +# fmt: off @pytest.mark.parametrize("line,indent,expected", [ (' a=5', ' ', 'a=5'), (' a=5', ' ', 
'a=5'), @@ -346,6 +367,7 @@ def subfunc(): def subfunc(): return 23 ''' +# fmt: on def test_dedent_body(): diff --git a/tests/test_config/test_config_scope_chain.py b/tests/test_config/test_config_scope_chain.py index 34f2587f..9d67e97a 100644 --- a/tests/test_config/test_config_scope_chain.py +++ b/tests/test_config/test_config_scope_chain.py @@ -16,9 +16,9 @@ def cfg2(): b = 20 final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2]) - assert set(final_cfg.keys()) == {'a', 'b'} - assert final_cfg['a'] == 10 - assert final_cfg['b'] == 20 + assert set(final_cfg.keys()) == {"a", "b"} + assert final_cfg["a"] == 10 + assert final_cfg["b"] == 20 def test_chained_config_scopes_can_access_previous_keys(): @@ -31,8 +31,8 @@ def cfg2(a): b = 2 * a final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2]) - assert set(final_cfg.keys()) == {'a', 'b'} - assert final_cfg['a'] == 10 + assert set(final_cfg.keys()) == {"a", "b"} + assert final_cfg["a"] == 10 def test_chained_config_scopes_can_modify_previous_keys(): @@ -47,9 +47,9 @@ def cfg2(a): b = 22 final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2]) - assert set(final_cfg.keys()) == {'a', 'b'} - assert final_cfg['a'] == 20 - assert final_cfg['b'] == 22 + assert set(final_cfg.keys()) == {"a", "b"} + assert final_cfg["a"] == 20 + assert final_cfg["b"] == 22 def test_chained_config_scopes_raise_for_undeclared_previous_keys(): @@ -76,12 +76,11 @@ def cfg2(c): b = 4 * c c *= 3 - final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], - fixed={'c': 5}) - assert set(final_cfg.keys()) == {'a', 'b', 'c'} - assert final_cfg['a'] == 10 - assert final_cfg['b'] == 20 - assert final_cfg['c'] == 5 + final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], fixed={"c": 5}) + assert set(final_cfg.keys()) == {"a", "b", "c"} + assert final_cfg["a"] == 10 + assert final_cfg["b"] == 20 + assert final_cfg["c"] == 5 def test_chained_config_scopes_can_access_preset(): @@ -93,12 +92,11 @@ def cfg1(c): def 
cfg2(a, c): b = a * 2 + c - final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], - preset={'c': 32}) - assert set(final_cfg.keys()) == {'a', 'b', 'c'} - assert final_cfg['a'] == 42 - assert final_cfg['b'] == 116 - assert final_cfg['c'] == 32 + final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], preset={"c": 32}) + assert set(final_cfg.keys()) == {"a", "b", "c"} + assert final_cfg["a"] == 42 + assert final_cfg["b"] == 116 + assert final_cfg["c"] == 32 def test_chained_config_scopes_can_access_fallback(): @@ -110,66 +108,58 @@ def cfg1(c): def cfg2(a, c): b = a * 2 + c - final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], - fallback={'c': 32}) - assert set(final_cfg.keys()) == {'a', 'b'} - assert final_cfg['a'] == 42 - assert final_cfg['b'] == 116 + final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], fallback={"c": 32}) + assert set(final_cfg.keys()) == {"a", "b"} + assert final_cfg["a"] == 42 + assert final_cfg["b"] == 116 def test_chained_config_scopes_fix_subentries(): @ConfigScope def cfg1(): - d = { - 'a': 10, - 'b': 20 - } + d = {"a": 10, "b": 20} @ConfigScope def cfg2(): pass - final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], - fixed={'d': {'a': 0}}) - assert set(final_cfg['d'].keys()) == {'a', 'b'} - assert final_cfg['d']['a'] == 0 - assert final_cfg['d']['b'] == 20 + final_cfg, summary = chain_evaluate_config_scopes( + [cfg1, cfg2], fixed={"d": {"a": 0}} + ) + assert set(final_cfg["d"].keys()) == {"a", "b"} + assert final_cfg["d"]["a"] == 0 + assert final_cfg["d"]["b"] == 20 def test_empty_chain_contains_preset_and_fixed(): - final_cfg, summary = chain_evaluate_config_scopes([], - fixed={'a': 0}, - preset={'a': 1, 'b': 2}) - assert set(final_cfg.keys()) == {'a', 'b'} - assert final_cfg['a'] == 0 - assert final_cfg['b'] == 2 + final_cfg, summary = chain_evaluate_config_scopes( + [], fixed={"a": 0}, preset={"a": 1, "b": 2} + ) + assert set(final_cfg.keys()) == {"a", "b"} + assert final_cfg["a"] == 0 
+ assert final_cfg["b"] == 2 def test_add_config_dict_sequential(): # https://github.com/IDSIA/sacred/issues/409 @ConfigScope def cfg1(): - dictnest2 = { - 'key_1': 'value_1', - 'key_2': 'value_2' - } + dictnest2 = {"key_1": "value_1", "key_2": "value_2"} + cfg1dict = ConfigDict(cfg1()) @ConfigScope def cfg2(): - dictnest2 = { - 'key_2': 'update_value_2', - 'key_3': 'value3', - 'key_4': 'value4' - } + dictnest2 = {"key_2": "update_value_2", "key_3": "value3", "key_4": "value4"} + cfg2dict = ConfigDict(cfg2()) final_config_scope, _ = chain_evaluate_config_scopes([cfg1, cfg2]) assert final_config_scope == { - 'dictnest2': { - 'key_1': 'value_1', - 'key_2': 'update_value_2', - 'key_3': 'value3', - 'key_4': 'value4' + "dictnest2": { + "key_1": "value_1", + "key_2": "update_value_2", + "key_3": "value3", + "key_4": "value4", } } diff --git a/tests/test_config/test_dogmatic_dict.py b/tests/test_config/test_dogmatic_dict.py index 4d30ce64..84e86aae 100644 --- a/tests/test_config/test_dogmatic_dict.py +++ b/tests/test_config/test_dogmatic_dict.py @@ -20,135 +20,133 @@ def test_dict_interface_initialized_empty(): def test_dict_interface_set_item(): d = DogmaticDict() - d['a'] = 12 - d['b'] = 'foo' - assert 'a' in d - assert 'b' in d + d["a"] = 12 + d["b"] = "foo" + assert "a" in d + assert "b" in d - assert d['a'] == 12 - assert d['b'] == 'foo' + assert d["a"] == 12 + assert d["b"] == "foo" - assert set(d.keys()) == {'a', 'b'} - assert set(d.values()) == {12, 'foo'} - assert set(d.items()) == {('a', 12), ('b', 'foo')} + assert set(d.keys()) == {"a", "b"} + assert set(d.values()) == {12, "foo"} + assert set(d.items()) == {("a", 12), ("b", "foo")} def test_dict_interface_del_item(): d = DogmaticDict() - d['a'] = 12 - del d['a'] - assert 'a' not in d + d["a"] = 12 + del d["a"] + assert "a" not in d def test_dict_interface_update_with_dict(): d = DogmaticDict() - d['a'] = 12 - d['b'] = 'foo' + d["a"] = 12 + d["b"] = "foo" - d.update({'a': 1, 'c': 2}) - assert d['a'] == 1 - 
assert d['b'] == 'foo' - assert d['c'] == 2 + d.update({"a": 1, "c": 2}) + assert d["a"] == 1 + assert d["b"] == "foo" + assert d["c"] == 2 def test_dict_interface_update_with_kwargs(): d = DogmaticDict() - d['a'] = 12 - d['b'] = 'foo' + d["a"] = 12 + d["b"] = "foo" d.update(a=2, b=3) - assert d['a'] == 2 - assert d['b'] == 3 + assert d["a"] == 2 + assert d["b"] == 3 def test_dict_interface_update_with_list_of_items(): d = DogmaticDict() - d['a'] = 12 - d['b'] = 'foo' - d.update([('b', 9), ('c', 7)]) - assert d['a'] == 12 - assert d['b'] == 9 - assert d['c'] == 7 + d["a"] = 12 + d["b"] = "foo" + d.update([("b", 9), ("c", 7)]) + assert d["a"] == 12 + assert d["b"] == 9 + assert d["c"] == 7 def test_fixed_value_not_initialized(): - d = DogmaticDict({'a': 7}) - assert 'a' not in d + d = DogmaticDict({"a": 7}) + assert "a" not in d def test_fixed_value_fixed(): - d = DogmaticDict({'a': 7}) - d['a'] = 8 - assert d['a'] == 7 + d = DogmaticDict({"a": 7}) + d["a"] = 8 + assert d["a"] == 7 - del d['a'] - assert 'a' in d - assert d['a'] == 7 + del d["a"] + assert "a" in d + assert d["a"] == 7 - d.update([('a', 9), ('b', 12)]) - assert d['a'] == 7 + d.update([("a", 9), ("b", 12)]) + assert d["a"] == 7 - d.update({'a': 9, 'b': 12}) - assert d['a'] == 7 + d.update({"a": 9, "b": 12}) + assert d["a"] == 7 d.update(a=10, b=13) - assert d['a'] == 7 + assert d["a"] == 7 def test_revelation(): - d = DogmaticDict({'a': 7, 'b': 12}) - d['b'] = 23 - assert 'a' not in d + d = DogmaticDict({"a": 7, "b": 12}) + d["b"] = 23 + assert "a" not in d m = d.revelation() - assert set(m) == {'a'} - assert 'a' in d + assert set(m) == {"a"} + assert "a" in d def test_fallback(): - d = DogmaticDict(fallback={'a': 23}) - assert 'a' in d - assert d['a'] == 23 - assert d.get('a') == 23 + d = DogmaticDict(fallback={"a": 23}) + assert "a" in d + assert d["a"] == 23 + assert d.get("a") == 23 d = DogmaticDict() - d.fallback = {'a': 23} - assert 'a' in d - assert d['a'] == 23 - assert d.get('a') == 23 + 
d.fallback = {"a": 23} + assert "a" in d + assert d["a"] == 23 + assert d.get("a") == 23 def test_fallback_not_iterated(): - d = DogmaticDict(fallback={'a': 23}) - d['b'] = 1234 - assert list(d.keys()) == ['b'] + d = DogmaticDict(fallback={"a": 23}) + d["b"] = 1234 + assert list(d.keys()) == ["b"] assert list(d.values()) == [1234] - assert list(d.items()) == [('b', 1234)] + assert list(d.items()) == [("b", 1234)] def test_overwrite_fallback(): - d = DogmaticDict(fallback={'a': 23}) - d['a'] = 0 - assert d['a'] == 0 - assert list(d.keys()) == ['a'] + d = DogmaticDict(fallback={"a": 23}) + d["a"] = 0 + assert d["a"] == 0 + assert list(d.keys()) == ["a"] assert list(d.values()) == [0] - assert list(d.items()) == [('a', 0)] + assert list(d.items()) == [("a", 0)] def test_fixed_has_precedence_over_fallback(): - d = DogmaticDict(fixed={'a': 0}, fallback={'a': 23}) - assert d['a'] == 0 + d = DogmaticDict(fixed={"a": 0}, fallback={"a": 23}) + assert d["a"] == 0 def test_nested_fixed_merges_with_fallback(): - d = DogmaticDict(fixed={'foo': {'bar': 20}}, - fallback={'foo': {'bar': 10, 'c': 5}}) - assert d['foo']['bar'] == 20 - assert d['foo']['c'] == 5 + d = DogmaticDict(fixed={"foo": {"bar": 20}}, fallback={"foo": {"bar": 10, "c": 5}}) + assert d["foo"]["bar"] == 20 + assert d["foo"]["c"] == 5 def test_nested_fixed_with_fallback_madness(): - d = DogmaticDict(fixed={'foo': {'bar': 20}}, - fallback={'foo': {'bar': 10, 'c': 5}}) - d['foo'] = {'bar': 30, 'a': 1} - assert d['foo']['bar'] == 20 - assert d['foo']['a'] == 1 - assert d['foo']['c'] == 5 + d = DogmaticDict(fixed={"foo": {"bar": 20}}, fallback={"foo": {"bar": 10, "c": 5}}) + d["foo"] = {"bar": 30, "a": 1} + assert d["foo"]["bar"] == 20 + assert d["foo"]["a"] == 1 + assert d["foo"]["c"] == 5 diff --git a/tests/test_config/test_dogmatic_list.py b/tests/test_config/test_dogmatic_list.py index 287fa9d5..c4ca47de 100644 --- a/tests/test_config/test_dogmatic_list.py +++ b/tests/test_config/test_dogmatic_list.py @@ -132,11 
+132,11 @@ def test_empty_revelation(): def test_nested_dict_revelation(): - d1 = DogmaticDict({'a': 7, 'b': 12}) - d2 = DogmaticDict({'c': 7}) + d1 = DogmaticDict({"a": 7, "b": 12}) + d2 = DogmaticDict({"c": 7}) l = DogmaticList([d1, 2, d2]) -# assert l.revelation() == {'0.a', '0.b', '2.c'} + # assert l.revelation() == {'0.a', '0.b', '2.c'} l.revelation() - assert 'a' in l[0] - assert 'b' in l[0] - assert 'c' in l[2] + assert "a" in l[0] + assert "b" in l[0] + assert "c" in l[2] diff --git a/tests/test_config/test_fallback_dict.py b/tests/test_config/test_fallback_dict.py index a51536be..a969a290 100644 --- a/tests/test_config/test_fallback_dict.py +++ b/tests/test_config/test_fallback_dict.py @@ -7,7 +7,7 @@ @pytest.fixture def fbdict(): - return fallback_dict({'fall1': 7, 'fall3': True}) + return fallback_dict({"fall1": 7, "fall3": True}) def test_is_dictionary(fbdict): @@ -15,22 +15,22 @@ def test_is_dictionary(fbdict): def test_getitem(fbdict): - assert 'foo' not in fbdict - fbdict['foo'] = 23 - assert 'foo' in fbdict - assert fbdict['foo'] == 23 + assert "foo" not in fbdict + fbdict["foo"] = 23 + assert "foo" in fbdict + assert fbdict["foo"] == 23 def test_fallback(fbdict): - assert 'fall1' in fbdict - assert fbdict['fall1'] == 7 - fbdict['fall1'] = 8 - assert fbdict['fall1'] == 8 + assert "fall1" in fbdict + assert fbdict["fall1"] == 7 + fbdict["fall1"] = 8 + assert fbdict["fall1"] == 8 def test_get(fbdict): - fbdict['a'] = 'b' - assert fbdict.get('a', 18) == 'b' - assert fbdict.get('fall1', 18) == 7 - assert fbdict.get('notexisting', 18) == 18 - assert fbdict.get('fall3', 18) is True + fbdict["a"] = "b" + assert fbdict.get("a", 18) == "b" + assert fbdict.get("fall1", 18) == 7 + assert fbdict.get("notexisting", 18) == 18 + assert fbdict.get("fall3", 18) is True diff --git a/tests/test_config/test_readonly_containers.py b/tests/test_config/test_readonly_containers.py index 3188d4ff..24ae8040 100644 --- a/tests/test_config/test_readonly_containers.py +++ 
b/tests/test_config/test_readonly_containers.py @@ -1,16 +1,14 @@ import pytest from copy import copy, deepcopy -from sacred.config.custom_containers import (make_read_only, ReadOnlyList, - ReadOnlyDict, ) +from sacred.config.custom_containers import make_read_only, ReadOnlyList, ReadOnlyDict from sacred.utils import SacredError def _check_read_only_dict(d): assert isinstance(d, ReadOnlyDict) - raises_dict = pytest.raises( - SacredError, match='This ReadOnlyDict is read-only!') + raises_dict = pytest.raises(SacredError, match="This ReadOnlyDict is read-only!") if len(d) > 0: # Test removal of entries and overwrite an already present entry @@ -27,27 +25,26 @@ def _check_read_only_dict(d): # Test direct writes with raises_dict: - d['abcdefg'] = 42 + d["abcdefg"] = 42 # Test other functions that modify the dict with raises_dict: d.clear() with raises_dict: - d.update({'abcdefg': 42}) + d.update({"abcdefg": 42}) with raises_dict: d.popitem() with raises_dict: - d.setdefault('a', 0) + d.setdefault("a", 0) def _check_read_only_list(lst): assert isinstance(lst, ReadOnlyList) - raises_list = pytest.raises( - SacredError, match='This ReadOnlyList is read-only!') + raises_list = pytest.raises(SacredError, match="This ReadOnlyList is read-only!") if len(lst): with raises_list: @@ -94,7 +91,7 @@ def test_nested_readonly_dict(): d = dict(a=1, b=dict(c=3)) d = make_read_only(d) _check_read_only_dict(d) - _check_read_only_dict(d['b']) + _check_read_only_dict(d["b"]) def test_readonly_list(): @@ -133,7 +130,9 @@ def test_copy_on_readonly_dict(): d = dict(a=1, b=2, c=3) d = make_read_only(d) copied_d = copy(d) - for (k, v), (k_copied, v_copied) in zip(sorted(d.items()), sorted(copied_d.items())): + for (k, v), (k_copied, v_copied) in zip( + sorted(d.items()), sorted(copied_d.items()) + ): assert k == k_copied assert v == v_copied @@ -142,7 +141,9 @@ def test_copy_on_nested_readonly_dict(): d = dict(a=1, b=dict(c=3)) d = make_read_only(d) copied_d = copy(d) - for (k, v), (k_copied, 
v_copied) in zip(sorted(d.items()), sorted(copied_d.items())): + for (k, v), (k_copied, v_copied) in zip( + sorted(d.items()), sorted(copied_d.items()) + ): assert k == k_copied assert v == v_copied @@ -159,7 +160,9 @@ def test_deepcopy_on_readonly_dict(): d = dict(a=1, b=2, c=3) d = make_read_only(d) copied_d = deepcopy(d) - for (k, v), (k_copied, v_copied) in zip(sorted(d.items()), sorted(copied_d.items())): + for (k, v), (k_copied, v_copied) in zip( + sorted(d.items()), sorted(copied_d.items()) + ): assert k == k_copied assert v == v_copied @@ -168,7 +171,9 @@ def test_deepcopy_on_nested_readonly_dict(): d = dict(a=1, b=dict(c=3)) d = make_read_only(d) copied_d = deepcopy(d) - for (k, v), (k_copied, v_copied) in zip(sorted(d.items()), sorted(copied_d.items())): + for (k, v), (k_copied, v_copied) in zip( + sorted(d.items()), sorted(copied_d.items()) + ): assert k == k_copied assert v == v_copied diff --git a/tests/test_config/test_signature.py b/tests/test_config/test_signature.py index 4b7dd03a..994ee4a2 100644 --- a/tests/test_config/test_signature.py +++ b/tests/test_config/test_signature.py @@ -7,6 +7,7 @@ import pytest from sacred.config.signature import Signature + # ############# function definitions to test on ############################## from sacred.utils import MissingConfigError @@ -19,7 +20,7 @@ def bariza(a: int, b: float, c: str): return a, b, c -def complex_function_name(a: int = 5, b: str = 'fo', c: float = 9): +def complex_function_name(a: int = 5, b: str = "fo", c: float = 9): return a, b, c @@ -47,38 +48,78 @@ def generic(*args, **kwargs): def onlykwrgs(**kwargs): return kwargs + def kwonly_args(a, *, b, c=10): return b renamed = old_name -functions = [foo, bariza, complex_function_name, FunCTIonWithCAPItals, - _name_with_underscore_, __double_underscore__, old_name, - renamed, kwonly_args] - -ids = ['foo', 'bariza', 'complex_function_name', 'FunCTIonWithCAPItals', - '_name_with_underscore_', '__double_underscore__', 'old_name', - 
'renamed','kwonly_args'] - -names = ['foo', 'bariza', 'complex_function_name', 'FunCTIonWithCAPItals', - '_name_with_underscore_', '__double_underscore__', 'old_name', - 'old_name', 'kwonly_args'] - -arguments = [[], ['a', 'b', 'c'], ['a', 'b', 'c'], ['a', 'b', 'c'], - ['fo', 'bar'], ['man', 'o'], ['verylongvariablename'], - ['verylongvariablename'], ['a', 'b', 'c']] - -vararg_names = [None, None, None, None, 'baz', 'men', None, None, None] - -kw_wc_names = [None, None, None, 'kwargs', None, 'oo', None, None, None] - -pos_arguments = [[], ['a', 'b', 'c'], [], ['a', 'b'], ['fo', 'bar'], - ['man', 'o'], ['verylongvariablename'], - ['verylongvariablename'], ['a']] - -kwarg_list = [{}, {}, {'a': 5, 'b': 'fo', 'c': 9}, {'c': 3}, - {}, {}, {}, {}, {'c': 10}] +functions = [ + foo, + bariza, + complex_function_name, + FunCTIonWithCAPItals, + _name_with_underscore_, + __double_underscore__, + old_name, + renamed, + kwonly_args, +] + +ids = [ + "foo", + "bariza", + "complex_function_name", + "FunCTIonWithCAPItals", + "_name_with_underscore_", + "__double_underscore__", + "old_name", + "renamed", + "kwonly_args", +] + +names = [ + "foo", + "bariza", + "complex_function_name", + "FunCTIonWithCAPItals", + "_name_with_underscore_", + "__double_underscore__", + "old_name", + "old_name", + "kwonly_args", +] + +arguments = [ + [], + ["a", "b", "c"], + ["a", "b", "c"], + ["a", "b", "c"], + ["fo", "bar"], + ["man", "o"], + ["verylongvariablename"], + ["verylongvariablename"], + ["a", "b", "c"], +] + +vararg_names = [None, None, None, None, "baz", "men", None, None, None] + +kw_wc_names = [None, None, None, "kwargs", None, "oo", None, None, None] + +pos_arguments = [ + [], + ["a", "b", "c"], + [], + ["a", "b"], + ["fo", "bar"], + ["man", "o"], + ["verylongvariablename"], + ["verylongvariablename"], + ["a"], +] + +kwarg_list = [{}, {}, {"a": 5, "b": "fo", "c": 9}, {"c": 3}, {}, {}, {}, {}, {"c": 10}] class SomeClass: @@ -88,78 +129,78 @@ def bla(self, a, b, c): # 
####################### Tests ############################################# + @pytest.mark.parametrize("function, name", zip(functions, names), ids=ids) def test_constructor_extract_function_name(function, name): - s = Signature(function) - assert s.name == name + s = Signature(function) + assert s.name == name @pytest.mark.parametrize("function, args", zip(functions, arguments), ids=ids) def test_constructor_extracts_all_arguments(function, args): - s = Signature(function) - assert s.arguments == args + s = Signature(function) + assert s.arguments == args -@pytest.mark.parametrize("function, vararg", zip(functions, vararg_names), - ids=ids) +@pytest.mark.parametrize("function, vararg", zip(functions, vararg_names), ids=ids) def test_constructor_extract_vararg_name(function, vararg): - s = Signature(function) - assert s.vararg_name == vararg + s = Signature(function) + assert s.vararg_name == vararg -@pytest.mark.parametrize("function, kw_wc", zip(functions, kw_wc_names), - ids=ids) +@pytest.mark.parametrize("function, kw_wc", zip(functions, kw_wc_names), ids=ids) def test_constructor_extract_kwargs_wildcard_name(function, kw_wc): - s = Signature(function) - assert s.kw_wildcard_name == kw_wc + s = Signature(function) + assert s.kw_wildcard_name == kw_wc -@pytest.mark.parametrize("function, pos_args", zip(functions, pos_arguments), - ids=ids) +@pytest.mark.parametrize("function, pos_args", zip(functions, pos_arguments), ids=ids) def test_constructor_extract_positional_arguments(function, pos_args): - s = Signature(function) - assert s.positional_args == pos_args + s = Signature(function) + assert s.positional_args == pos_args -@pytest.mark.parametrize("function, kwargs", - zip(functions, kwarg_list), - ids=ids) +@pytest.mark.parametrize("function, kwargs", zip(functions, kwarg_list), ids=ids) def test_constructor_extract_kwargs(function, kwargs): - s = Signature(function) - assert s.kwargs == kwargs + s = Signature(function) + assert s.kwargs == kwargs def 
test_get_free_parameters(): free = Signature(foo).get_free_parameters([], {}) assert free == [] - free = Signature(bariza).get_free_parameters([], {'c': 3}) - assert free == ['a', 'b'] + free = Signature(bariza).get_free_parameters([], {"c": 3}) + assert free == ["a", "b"] free = Signature(complex_function_name).get_free_parameters([], {}) - assert free == ['a', 'b', 'c'] + assert free == ["a", "b", "c"] free = Signature(_name_with_underscore_).get_free_parameters([], {}) - assert free == ['fo', 'bar'] + assert free == ["fo", "bar"] s = Signature(__double_underscore__) assert s.get_free_parameters([1, 2, 3], {}) == [] -@pytest.mark.parametrize('function', - [foo, bariza, complex_function_name, - _name_with_underscore_, old_name, renamed]) +@pytest.mark.parametrize( + "function", + [foo, bariza, complex_function_name, _name_with_underscore_, old_name, renamed], +) def test_construct_arguments_with_unexpected_kwargs_raises_typeerror(function): - kwargs = {'zimbabwe': 23} + kwargs = {"zimbabwe": 23} unexpected = re.compile(".*unexpected.*zimbabwe.*") with pytest.raises(TypeError) as excinfo: Signature(function).construct_arguments([], kwargs, {}) assert unexpected.match(excinfo.value.args[0]) -@pytest.mark.parametrize('func,args', [ - (foo, [1]), - (bariza, [1, 2, 3, 4]), - (complex_function_name, [1, 2, 3, 4]), - (old_name, [1, 2]), - (renamed, [1, 2]) -]) +@pytest.mark.parametrize( + "func,args", + [ + (foo, [1]), + (bariza, [1, 2, 3, 4]), + (complex_function_name, [1, 2, 3, 4]), + (old_name, [1, 2]), + (renamed, [1, 2]), + ], +) def test_construct_arguments_with_unexpected_args_raises_typeerror(func, args): unexpected = re.compile(".*unexpected.*") with pytest.raises(TypeError) as excinfo: @@ -168,58 +209,55 @@ def test_construct_arguments_with_unexpected_args_raises_typeerror(func, args): def test_construct_arguments_with_kwargswildcard_doesnt_raise(): - kwargs = {'zimbabwe': 23} + kwargs = {"zimbabwe": 23} Signature(__double_underscore__).construct_arguments([1, 
2], kwargs, {}) - Signature(FunCTIonWithCAPItals).construct_arguments( - [1, 2, 3], kwargs, {}) + Signature(FunCTIonWithCAPItals).construct_arguments([1, 2, 3], kwargs, {}) def test_construct_arguments_with_varargs_doesnt_raise(): Signature(generic).construct_arguments([1, 2, 3], {}, {}) - Signature(__double_underscore__).construct_arguments( - [1, 2, 3, 4, 5], {}, {}) - Signature(_name_with_underscore_).construct_arguments( - [1, 2, 3, 4], {}, {}) + Signature(__double_underscore__).construct_arguments([1, 2, 3, 4, 5], {}, {}) + Signature(_name_with_underscore_).construct_arguments([1, 2, 3, 4], {}, {}) def test_construct_arguments_with_expected_kwargs_does_not_raise(): s = Signature(complex_function_name) - s.construct_arguments([], {'a': 4, 'b': 3, 'c': 2}, {}) + s.construct_arguments([], {"a": 4, "b": 3, "c": 2}, {}) s = Signature(FunCTIonWithCAPItals) - s.construct_arguments([1, 2], {'c': 5}, {}) + s.construct_arguments([1, 2], {"c": 5}, {}) def test_construct_arguments_with_kwargs_for_posargs_does_not_raise(): - Signature(bariza).construct_arguments([], {'a': 4, 'b': 3, 'c': 2}, {}) + Signature(bariza).construct_arguments([], {"a": 4, "b": 3, "c": 2}, {}) s = Signature(FunCTIonWithCAPItals) - s.construct_arguments([], {'a': 4, 'b': 3, 'c': 2, 'd': 6}, {}) + s.construct_arguments([], {"a": 4, "b": 3, "c": 2, "d": 6}, {}) def test_construct_arguments_with_duplicate_args_raises_typeerror(): multiple_values = re.compile(".*multiple values.*") with pytest.raises(TypeError) as excinfo: s = Signature(bariza) - s.construct_arguments([1, 2, 3], {'a': 4, 'b': 5}, {}) + s.construct_arguments([1, 2, 3], {"a": 4, "b": 5}, {}) assert multiple_values.match(excinfo.value.args[0]) with pytest.raises(TypeError) as excinfo: s = Signature(complex_function_name) - s.construct_arguments([1], {'a': 4}, {}) + s.construct_arguments([1], {"a": 4}, {}) assert multiple_values.match(excinfo.value.args[0]) with pytest.raises(TypeError) as excinfo: s = Signature(FunCTIonWithCAPItals) - 
s.construct_arguments([1, 2, 3], {'c': 6}, {}) + s.construct_arguments([1, 2, 3], {"c": 6}, {}) assert multiple_values.match(excinfo.value.args[0]) def test_construct_arguments_without_duplicates_passes(): s = Signature(bariza) - s.construct_arguments([1, 2], {'c': 5}, {}) + s.construct_arguments([1, 2], {"c": 5}, {}) s = Signature(complex_function_name) - s.construct_arguments([1], {'b': 4}, {}) + s.construct_arguments([1], {"b": 4}, {}) s = Signature(FunCTIonWithCAPItals) - s.construct_arguments([], {'a': 6, 'b': 6, 'c': 6}, {}) + s.construct_arguments([], {"a": 6, "b": 6, "c": 6}, {}) def test_construct_arguments_without_options_returns_same_args_kwargs(): @@ -234,53 +272,51 @@ def test_construct_arguments_without_options_returns_same_args_kwargs(): assert kwargs == {} s = Signature(complex_function_name) - args, kwargs = s.construct_arguments([2], {'c': 6, 'b': 7}, {}) + args, kwargs = s.construct_arguments([2], {"c": 6, "b": 7}, {}) assert args == [2] - assert kwargs == {'c': 6, 'b': 7} + assert kwargs == {"c": 6, "b": 7} s = Signature(_name_with_underscore_) - args, kwargs = s.construct_arguments([], {'fo': 7, 'bar': 6}, {}) + args, kwargs = s.construct_arguments([], {"fo": 7, "bar": 6}, {}) assert args == [] - assert kwargs == {'fo': 7, 'bar': 6} + assert kwargs == {"fo": 7, "bar": 6} def test_construct_arguments_completes_kwargs_from_options(): s = Signature(bariza) - args, kwargs = s.construct_arguments([2, 4], {}, {'c': 6}) + args, kwargs = s.construct_arguments([2, 4], {}, {"c": 6}) assert args == [2, 4] - assert kwargs == {'c': 6} + assert kwargs == {"c": 6} s = Signature(complex_function_name) - args, kwargs = s.construct_arguments([], {'c': 6, 'b': 7}, {'a': 1}) + args, kwargs = s.construct_arguments([], {"c": 6, "b": 7}, {"a": 1}) assert args == [] - assert kwargs == {'a': 1, 'c': 6, 'b': 7} + assert kwargs == {"a": 1, "c": 6, "b": 7} s = Signature(_name_with_underscore_) - args, kwargs = s.construct_arguments([], {}, {'fo': 7, 'bar': 6}) + args, 
kwargs = s.construct_arguments([], {}, {"fo": 7, "bar": 6}) assert args == [] - assert kwargs == {'fo': 7, 'bar': 6} + assert kwargs == {"fo": 7, "bar": 6} def test_construct_arguments_ignores_excess_options(): s = Signature(bariza) - args, kwargs = s.construct_arguments([2], {'b': 4}, - {'c': 6, 'foo': 9, 'bar': 0}) + args, kwargs = s.construct_arguments([2], {"b": 4}, {"c": 6, "foo": 9, "bar": 0}) assert args == [2] - assert kwargs == {'b': 4, 'c': 6} + assert kwargs == {"b": 4, "c": 6} def test_construct_arguments_does_not_overwrite_args_and_kwargs(): s = Signature(bariza) - args, kwargs = s.construct_arguments([1, 2], {'c': 3}, - {'a': 6, 'b': 6, 'c': 6}) + args, kwargs = s.construct_arguments([1, 2], {"c": 3}, {"a": 6, "b": 6, "c": 6}) assert args == [1, 2] - assert kwargs == {'c': 3} + assert kwargs == {"c": 3} def test_construct_arguments_overwrites_defaults(): s = Signature(complex_function_name) - args, kwargs = s.construct_arguments([], {}, {'a': 11, 'b': 12, 'c': 7}) + args, kwargs = s.construct_arguments([], {}, {"a": 11, "b": 12, "c": 7}) assert args == [] - assert kwargs == {'a': 11, 'b': 12, 'c': 7} + assert kwargs == {"a": 11, "b": 12, "c": 7} def test_construct_arguments_raises_if_args_unfilled(): @@ -293,19 +329,19 @@ def test_construct_arguments_raises_if_args_unfilled(): s.construct_arguments([1, 2], {}, {}) assert missing.match(excinfo.value.args[0]) with pytest.raises(MissingConfigError) as excinfo: - s.construct_arguments([1], {'b': 3}, {}) + s.construct_arguments([1], {"b": 3}, {}) assert missing.match(excinfo.value.args[0]) with pytest.raises(MissingConfigError) as excinfo: - s.construct_arguments([1], {'c': 5}, {}) + s.construct_arguments([1], {"c": 5}, {}) assert missing.match(excinfo.value.args[0]) def test_construct_arguments_does_not_raise_if_all_args_filled(): s = Signature(bariza) s.construct_arguments([1, 2, 3], {}, {}) - s.construct_arguments([1, 2], {'c': 6}, {}) - s.construct_arguments([1], {'b': 6, 'c': 6}, {}) - 
s.construct_arguments([], {'a': 6, 'b': 6, 'c': 6}, {}) + s.construct_arguments([1, 2], {"c": 6}, {}) + s.construct_arguments([1], {"b": 6, "c": 6}, {}) + s.construct_arguments([], {"a": 6, "b": 6, "c": 6}, {}) def test_construct_arguments_does_not_raise_for_missing_defaults(): @@ -315,22 +351,25 @@ def test_construct_arguments_does_not_raise_for_missing_defaults(): def test_construct_arguments_for_bound_method(): s = Signature(SomeClass.bla) - args, kwargs = s.construct_arguments([1], {'b': 2}, {'c': 3}, bound=True) + args, kwargs = s.construct_arguments([1], {"b": 2}, {"c": 3}, bound=True) assert args == [1] - assert kwargs == {'b': 2, 'c': 3} - - -@pytest.mark.parametrize('func,expected', [ - (foo, "foo()"), - (bariza, "bariza(a, b, c)"), - (FunCTIonWithCAPItals, "FunCTIonWithCAPItals(a, b, c=3, **kwargs)"), - (_name_with_underscore_, "_name_with_underscore_(fo, bar, *baz)"), - (__double_underscore__, "__double_underscore__(man, o, *men, **oo)"), - (old_name, "old_name(verylongvariablename)"), - (renamed, "old_name(verylongvariablename)"), - (generic, "generic(*args, **kwargs)"), - (onlykwrgs, "onlykwrgs(**kwargs)") -]) + assert kwargs == {"b": 2, "c": 3} + + +@pytest.mark.parametrize( + "func,expected", + [ + (foo, "foo()"), + (bariza, "bariza(a, b, c)"), + (FunCTIonWithCAPItals, "FunCTIonWithCAPItals(a, b, c=3, **kwargs)"), + (_name_with_underscore_, "_name_with_underscore_(fo, bar, *baz)"), + (__double_underscore__, "__double_underscore__(man, o, *men, **oo)"), + (old_name, "old_name(verylongvariablename)"), + (renamed, "old_name(verylongvariablename)"), + (generic, "generic(*args, **kwargs)"), + (onlykwrgs, "onlykwrgs(**kwargs)"), + ], +) def test_unicode_(func, expected): assert str(Signature(func)) == expected @@ -340,7 +379,7 @@ def test_unicode_special(): assert str_signature in str(Signature(complex_function_name)) -@pytest.mark.parametrize('name,func', zip(names, functions)) +@pytest.mark.parametrize("name,func", zip(names, functions)) def 
test_repr_(name, func): regex = "" assert re.match(regex % name, Signature(func).__repr__()) diff --git a/tests/test_config/test_utils.py b/tests/test_config/test_utils.py index f8cf86f0..70a68e37 100644 --- a/tests/test_config/test_utils.py +++ b/tests/test_config/test_utils.py @@ -7,20 +7,54 @@ @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") -@pytest.mark.parametrize('typename', [ - 'bool_', 'int_', 'intc', 'intp', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'float_', 'float16', 'float32', - 'float64']) +@pytest.mark.parametrize( + "typename", + [ + "bool_", + "int_", + "intc", + "intp", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float_", + "float16", + "float32", + "float64", + ], +) def test_normalize_or_die_for_numpy_datatypes(typename): dtype = getattr(opt.np, typename) - assert normalize_or_die(dtype(7.)) + assert normalize_or_die(dtype(7.0)) @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") -@pytest.mark.parametrize('typename', [ - 'bool_', 'int_', 'intc', 'intp', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'float_', 'float16', 'float32', - 'float64']) +@pytest.mark.parametrize( + "typename", + [ + "bool_", + "int_", + "intc", + "intp", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float_", + "float16", + "float32", + "float64", + ], +) def test_normalize_or_die_for_numpy_arrays(typename): np = opt.np dtype = getattr(np, typename) diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 5139261e..3ff20473 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -6,59 +6,86 @@ import mock import pytest -from sacred.dependencies import (PEP440_VERSION_PATTERN, PackageDependency, - Source, gather_sources_and_dependencies, - get_digest, get_py_file_if_possible, - is_local_source) +from sacred.dependencies import ( + 
PEP440_VERSION_PATTERN, + PackageDependency, + Source, + gather_sources_and_dependencies, + get_digest, + get_py_file_if_possible, + is_local_source, +) import sacred.optional as opt TEST_DIRECTORY = os.path.dirname(__file__) -EXAMPLE_SOURCE = os.path.join(TEST_DIRECTORY, '__init__.py') +EXAMPLE_SOURCE = os.path.join(TEST_DIRECTORY, "__init__.py") # The digest below is calculated from the test/__init__.py with only Python shebang and coding-information. # This type of hard-coding is most probably a quite bad idea. -EXAMPLE_DIGEST = '9e428c0aa58b75ff150c4f625e32af68' - - -@pytest.mark.parametrize('version', [ - '0.9.11', '2012.04', '1!1.1', '17.10a104', '43.0rc1', '0.9.post3', - '12.4a22.post8', '13.3rc2.dev1515', '1.0.dev456', '1.0a1', '1.0a2.dev456', - '1.0a12.dev456', '1.0a12', '1.0b1.dev456', '1.0b2', '1.0b2.post345.dev456', - '1.0b2.post345', '1.0rc1.dev456', '1.0rc1', '1.0', '1.0.post456.dev34', - '1.0.post456', '1.1.dev1' -]) +EXAMPLE_DIGEST = "9e428c0aa58b75ff150c4f625e32af68" + + +@pytest.mark.parametrize( + "version", + [ + "0.9.11", + "2012.04", + "1!1.1", + "17.10a104", + "43.0rc1", + "0.9.post3", + "12.4a22.post8", + "13.3rc2.dev1515", + "1.0.dev456", + "1.0a1", + "1.0a2.dev456", + "1.0a12.dev456", + "1.0a12", + "1.0b1.dev456", + "1.0b2", + "1.0b2.post345.dev456", + "1.0b2.post345", + "1.0rc1.dev456", + "1.0rc1", + "1.0", + "1.0.post456.dev34", + "1.0.post456", + "1.1.dev1", + ], +) def test_pep440_version_pattern(version): assert PEP440_VERSION_PATTERN.match(version) def test_pep440_version_pattern_invalid(): - assert PEP440_VERSION_PATTERN.match('foo') is None - assert PEP440_VERSION_PATTERN.match('_12_') is None - assert PEP440_VERSION_PATTERN.match('version 4') is None + assert PEP440_VERSION_PATTERN.match("foo") is None + assert PEP440_VERSION_PATTERN.match("_12_") is None + assert PEP440_VERSION_PATTERN.match("version 4") is None -@pytest.mark.skipif(os.name == 'nt', reason='Weird win bug') +@pytest.mark.skipif(os.name == "nt", reason="Weird win 
bug") def test_source_get_digest(): assert get_digest(EXAMPLE_SOURCE) == EXAMPLE_DIGEST def test_source_create_empty(): with pytest.raises(ValueError): - Source.create('') + Source.create("") def test_source_create_non_existing(): with pytest.raises(ValueError): - Source.create('doesnotexist.py') + Source.create("doesnotexist.py") + -@pytest.mark.skipif(os.name == 'nt', reason='Weird win bug') +@pytest.mark.skipif(os.name == "nt", reason="Weird win bug") def test_source_create_py(): s = Source.create(EXAMPLE_SOURCE) assert s.filename == os.path.abspath(EXAMPLE_SOURCE) assert s.digest == EXAMPLE_DIGEST -@pytest.mark.skipif(os.name == 'nt', reason='Weird win bug') +@pytest.mark.skipif(os.name == "nt", reason="Weird win bug") def test_source_to_json(): s = Source.create(EXAMPLE_SOURCE) assert s.to_json() == (os.path.abspath(EXAMPLE_SOURCE), EXAMPLE_DIGEST) @@ -69,7 +96,7 @@ def test_get_py_file_if_possible_with_py_file(): def test_get_py_file_if_possible_with_pyc_file(): - assert get_py_file_if_possible(EXAMPLE_SOURCE + 'c') == EXAMPLE_SOURCE + assert get_py_file_if_possible(EXAMPLE_SOURCE + "c") == EXAMPLE_SOURCE def test_source_repr(): @@ -78,79 +105,80 @@ def test_source_repr(): def test_get_py_file_if_possible_with_pyc_but_nonexistent_py_file(): - assert get_py_file_if_possible('doesnotexist.pyc') == 'doesnotexist.pyc' + assert get_py_file_if_possible("doesnotexist.pyc") == "doesnotexist.pyc" versions = [ - ('0.7.2', '0.7.2'), - ('1.0', '1.0'), - ('foobar', None), + ("0.7.2", "0.7.2"), + ("1.0", "1.0"), + ("foobar", None), (10, None), - ((2, 6), '2.6'), - ((1, 4, 8), '1.4.8') + ((2, 6), "2.6"), + ((1, 4, 8), "1.4.8"), ] -@pytest.mark.parametrize('version,expected', versions) +@pytest.mark.parametrize("version,expected", versions) def test_package_dependency_get_version_heuristic_version__(version, expected): mod = mock.Mock(spec=[], __version__=version) assert PackageDependency.get_version_heuristic(mod) == expected -@pytest.mark.parametrize('version,expected', 
versions) +@pytest.mark.parametrize("version,expected", versions) def test_package_dependency_get_version_heuristic_version(version, expected): mod = mock.Mock(spec=[], version=version) assert PackageDependency.get_version_heuristic(mod) == expected -@pytest.mark.parametrize('version,expected', versions) +@pytest.mark.parametrize("version,expected", versions) def test_package_dependency_get_version_heuristic_VERSION(version, expected): mod = mock.Mock(spec=[], VERSION=version) assert PackageDependency.get_version_heuristic(mod) == expected def test_package_dependency_create_no_version(): - mod = mock.Mock(spec=[], __name__='testmod') + mod = mock.Mock(spec=[], __name__="testmod") pd = PackageDependency.create(mod) - assert pd.name == 'testmod' + assert pd.name == "testmod" assert pd.version is None def test_package_dependency_fill_non_missing_version(): - pd = PackageDependency('mymod', '1.2.3rc4') + pd = PackageDependency("mymod", "1.2.3rc4") pd.fill_missing_version() - assert pd.version == '1.2.3rc4' + assert pd.version == "1.2.3rc4" def test_package_dependency_fill_missing_version_unknown(): - pd = PackageDependency('mymod', None) + pd = PackageDependency("mymod", None) pd.fill_missing_version() assert pd.version == None def test_package_dependency_fill_missing_version(): - pd = PackageDependency('pytest', None) + pd = PackageDependency("pytest", None) pd.fill_missing_version() assert pd.version == pytest.__version__ def test_package_dependency_repr(): - pd = PackageDependency('pytest', '12.4') - assert repr(pd) == '' + pd = PackageDependency("pytest", "12.4") + assert repr(pd) == "" def test_gather_sources_and_dependencies(): from tests.dependency_example import some_func + main, sources, deps = gather_sources_and_dependencies(some_func.__globals__) assert isinstance(main, Source) assert isinstance(sources, set) assert isinstance(deps, set) - assert main == Source.create(os.path.join(TEST_DIRECTORY, 'dependency_example.py')) + assert main == 
Source.create(os.path.join(TEST_DIRECTORY, "dependency_example.py")) expected_sources = { - Source.create(os.path.join(TEST_DIRECTORY, '__init__.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'dependency_example.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'foo', '__init__.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'foo', 'bar.py')) + Source.create(os.path.join(TEST_DIRECTORY, "__init__.py")), + Source.create(os.path.join(TEST_DIRECTORY, "dependency_example.py")), + Source.create(os.path.join(TEST_DIRECTORY, "foo", "__init__.py")), + Source.create(os.path.join(TEST_DIRECTORY, "foo", "bar.py")), } assert sources == expected_sources @@ -167,38 +195,46 @@ def test_gather_sources_and_dependencies(): def test_custom_base_dir(): from tests.basedir.my_experiment import some_func - main, sources, deps = gather_sources_and_dependencies(some_func.__globals__, TEST_DIRECTORY) + + main, sources, deps = gather_sources_and_dependencies( + some_func.__globals__, TEST_DIRECTORY + ) assert isinstance(main, Source) assert isinstance(sources, set) assert isinstance(deps, set) - assert main == Source.create(os.path.join(TEST_DIRECTORY, 'basedir', 'my_experiment.py')) + assert main == Source.create( + os.path.join(TEST_DIRECTORY, "basedir", "my_experiment.py") + ) expected_sources = { - Source.create(os.path.join(TEST_DIRECTORY, '__init__.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'basedir', '__init__.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'basedir', 'my_experiment.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'foo', '__init__.py')), - Source.create(os.path.join(TEST_DIRECTORY, 'foo', 'bar.py')) + Source.create(os.path.join(TEST_DIRECTORY, "__init__.py")), + Source.create(os.path.join(TEST_DIRECTORY, "basedir", "__init__.py")), + Source.create(os.path.join(TEST_DIRECTORY, "basedir", "my_experiment.py")), + Source.create(os.path.join(TEST_DIRECTORY, "foo", "__init__.py")), + Source.create(os.path.join(TEST_DIRECTORY, "foo", "bar.py")), } 
assert sources == expected_sources -@pytest.mark.parametrize('f_name, mod_name, ex_path, is_local', [ - ('./foo.py', 'bar', '.', False), - ('./foo.pyc', 'bar', '.', False), - ('./bar.py', 'bar', '.', True), - ('./bar.pyc', 'bar', '.', True), - ('./venv/py/bar.py', 'bar', '.', False), - ('./venv/py/bar.py', 'venv.py.bar', '.', True), - ('./venv/py/bar.pyc', 'venv.py.bar', '.', True), - ('foo.py', 'bar', '.', False), - ('bar.py', 'bar', '.', True), - ('bar.pyc', 'bar', '.', True), - ('bar.pyc', 'some.bar', '.', False), - ('/home/user/bar.py', 'user.bar', '/home/user/', True), - ('bar/__init__.py', 'bar', '.', True), - ('bar/__init__.py', 'foo', '.', False), - ('/home/user/bar/__init__.py', 'home.user.bar', '/home/user/', True), - ('/home/user/bar/__init__.py', 'home.user.foo', '/home/user/', False), -]) +@pytest.mark.parametrize( + "f_name, mod_name, ex_path, is_local", + [ + ("./foo.py", "bar", ".", False), + ("./foo.pyc", "bar", ".", False), + ("./bar.py", "bar", ".", True), + ("./bar.pyc", "bar", ".", True), + ("./venv/py/bar.py", "bar", ".", False), + ("./venv/py/bar.py", "venv.py.bar", ".", True), + ("./venv/py/bar.pyc", "venv.py.bar", ".", True), + ("foo.py", "bar", ".", False), + ("bar.py", "bar", ".", True), + ("bar.pyc", "bar", ".", True), + ("bar.pyc", "some.bar", ".", False), + ("/home/user/bar.py", "user.bar", "/home/user/", True), + ("bar/__init__.py", "bar", ".", True), + ("bar/__init__.py", "foo", ".", False), + ("/home/user/bar/__init__.py", "home.user.bar", "/home/user/", True), + ("/home/user/bar/__init__.py", "home.user.foo", "/home/user/", False), + ], +) def test_is_local_source(f_name, mod_name, ex_path, is_local): assert is_local_source(f_name, mod_name, ex_path) == is_local diff --git a/tests/test_examples.py b/tests/test_examples.py index eb35aa96..9d30f18e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -14,13 +14,13 @@ def test_example(capsys, example_test): captured_out, captured_err = capsys.readouterr() 
print(captured_out) print(captured_err) - captured_out = captured_out.split('\n') - captured_err = captured_err.split('\n') + captured_out = captured_out.split("\n") + captured_err = captured_err.split("\n") for out_line in out: assert out_line in [captured_out[0], captured_err[0]] if out_line == captured_out[0]: captured_out.pop(0) else: captured_err.pop(0) - assert captured_out == [''] - assert captured_err == [''] + assert captured_out == [""] + assert captured_err == [""] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index a020be1b..491b0107 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,9 +4,15 @@ import re from sacred import Ingredient, Experiment -from sacred.utils import CircularDependencyError, ConfigAddedError, \ - MissingConfigError, NamedConfigNotFoundError, format_filtered_stacktrace, \ - format_sacred_error, SacredError +from sacred.utils import ( + CircularDependencyError, + ConfigAddedError, + MissingConfigError, + NamedConfigNotFoundError, + format_filtered_stacktrace, + format_sacred_error, + SacredError, +) """Global Docstring""" @@ -15,46 +21,48 @@ def test_circular_dependency_raises(): # create experiment with circular dependency - ing = Ingredient('ing') - ex = Experiment('exp', ingredients=[ing]) + ing = Ingredient("ing") + ex = Experiment("exp", ingredients=[ing]) ex.main(lambda: None) ing.ingredients.append(ex) # run and see if it raises - with pytest.raises(CircularDependencyError, match='exp->ing->exp'): + with pytest.raises(CircularDependencyError, match="exp->ing->exp"): ex.run() def test_config_added_raises(): - ex = Experiment('exp') + ex = Experiment("exp") ex.main(lambda: None) with pytest.raises( - ConfigAddedError, - match=r'Added new config entry that is not used anywhere.*\n' - r'\s*Conflicting configuration values:\n' - r'\s*a=42'): - ex.run(config_updates={'a': 42}) + ConfigAddedError, + match=r"Added new config entry that is not used anywhere.*\n" + r"\s*Conflicting configuration 
values:\n" + r"\s*a=42", + ): + ex.run(config_updates={"a": 42}) def test_missing_config_raises(): - ex = Experiment('exp') + ex = Experiment("exp") ex.main(lambda a: None) with pytest.raises(MissingConfigError): ex.run() def test_named_config_not_found_raises(): - ex = Experiment('exp') + ex = Experiment("exp") ex.main(lambda: None) - with pytest.raises(NamedConfigNotFoundError, - match='Named config not found: "not_there". ' - 'Available config values are:'): - ex.run(named_configs=('not_there',)) + with pytest.raises( + NamedConfigNotFoundError, + match='Named config not found: "not_there". ' "Available config values are:", + ): + ex.run(named_configs=("not_there",)) def test_format_filtered_stacktrace_true(): - ex = Experiment('exp') + ex = Experiment("exp") @ex.capture def f(): @@ -63,20 +71,20 @@ def f(): try: f() except: - st = format_filtered_stacktrace(filter_traceback='default') - assert 'captured_function' not in st - assert 'WITHOUT Sacred internals' in st + st = format_filtered_stacktrace(filter_traceback="default") + assert "captured_function" not in st + assert "WITHOUT Sacred internals" in st try: f() except: - st = format_filtered_stacktrace(filter_traceback='always') - assert 'captured_function' not in st - assert 'WITHOUT Sacred internals' in st + st = format_filtered_stacktrace(filter_traceback="always") + assert "captured_function" not in st + assert "WITHOUT Sacred internals" in st def test_format_filtered_stacktrace_false(): - ex = Experiment('exp') + ex = Experiment("exp") @ex.capture def f(): @@ -85,35 +93,50 @@ def f(): try: f() except: - st = format_filtered_stacktrace(filter_traceback='never') - assert 'captured_function' in st + st = format_filtered_stacktrace(filter_traceback="never") + assert "captured_function" in st @pytest.mark.parametrize( - 'print_traceback,filter_traceback,print_usage,expected', [ - (False, 'never', False, '.*SacredError: message'), - (True, 'never', False, r'Traceback \(most recent call last\):\n*' - r'\s*File 
".*", line \d*, in ' - r'test_format_sacred_error\n*' - r'.*\n*' - r'.*SacredError: message'), - (False, 'default', False, r'.*SacredError: message'), - (False, 'always', False, r'.*SacredError: message'), - (False, 'never', True, r'usage\n.*SacredError: message'), - (True, 'default', False, r'Traceback \(most recent calls WITHOUT ' - r'Sacred internals\):\n*' - r'(\n|.)*' - r'.*SacredError: message'), - (True, 'always', False, r'Traceback \(most recent calls WITHOUT ' - r'Sacred internals\):\n*' - r'(\n|.)*' - r'.*SacredError: message') - ]) -def test_format_sacred_error(print_traceback, filter_traceback, print_usage, - expected): + "print_traceback,filter_traceback,print_usage,expected", + [ + (False, "never", False, ".*SacredError: message"), + ( + True, + "never", + False, + r"Traceback \(most recent call last\):\n*" + r'\s*File ".*", line \d*, in ' + r"test_format_sacred_error\n*" + r".*\n*" + r".*SacredError: message", + ), + (False, "default", False, r".*SacredError: message"), + (False, "always", False, r".*SacredError: message"), + (False, "never", True, r"usage\n.*SacredError: message"), + ( + True, + "default", + False, + r"Traceback \(most recent calls WITHOUT " + r"Sacred internals\):\n*" + r"(\n|.)*" + r".*SacredError: message", + ), + ( + True, + "always", + False, + r"Traceback \(most recent calls WITHOUT " + r"Sacred internals\):\n*" + r"(\n|.)*" + r".*SacredError: message", + ), + ], +) +def test_format_sacred_error(print_traceback, filter_traceback, print_usage, expected): try: - raise SacredError('message', print_traceback, filter_traceback, - print_usage) + raise SacredError("message", print_traceback, filter_traceback, print_usage) except SacredError as e: - st = format_sacred_error(e, 'usage') + st = format_sacred_error(e, "usage") assert re.match(expected, st, re.MULTILINE) diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 9069d70f..73f80d20 100755 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -10,13 
+10,12 @@ import sys from sacred.experiment import Experiment -from sacred.utils import apply_backspaces_and_linefeeds, ConfigAddedError, \ - SacredError +from sacred.utils import apply_backspaces_and_linefeeds, ConfigAddedError, SacredError @pytest.fixture def ex(): - return Experiment('ator3000') + return Experiment("ator3000") def test_main(ex): @@ -24,23 +23,23 @@ def test_main(ex): def foo(): pass - assert 'foo' in ex.commands - assert ex.commands['foo'] == foo - assert ex.default_command == 'foo' + assert "foo" in ex.commands + assert ex.commands["foo"] == foo + assert ex.default_command == "foo" def test_automain_imported(ex): main_called = [False] - with patch.object(sys, 'argv', ['test.py']): + with patch.object(sys, "argv", ["test.py"]): @ex.automain def foo(): main_called[0] = True - assert 'foo' in ex.commands - assert ex.commands['foo'] == foo - assert ex.default_command == 'foo' + assert "foo" in ex.commands + assert ex.commands["foo"] == foo + assert ex.default_command == "foo" assert main_called[0] is False @@ -50,15 +49,16 @@ def test_automain_script_runs_main(ex): main_called = [False] try: - __name__ = '__main__' - with patch.object(sys, 'argv', ['test.py']): + __name__ = "__main__" + with patch.object(sys, "argv", ["test.py"]): + @ex.automain def foo(): main_called[0] = True - assert 'foo' in ex.commands - assert ex.commands['foo'] == foo - assert ex.default_command == 'foo' + assert "foo" in ex.commands + assert ex.commands["foo"] == foo + assert ex.default_command == "foo" assert main_called[0] is True finally: __name__ = oldname @@ -75,37 +75,37 @@ def foo(a, b=2): return a + b # normal config updates work - assert ex.run(config_updates={'a': 3}).result == 5 + assert ex.run(config_updates={"a": 3}).result == 5 # not in config but used works - assert ex.run(config_updates={'b': 8}).result == 9 + assert ex.run(config_updates={"b": 8}).result == 9 # unused but in config updates work - assert ex.run(config_updates={'c': 9}).result == 3 + assert 
ex.run(config_updates={"c": 9}).result == 3 # unused config updates raise with pytest.raises(ConfigAddedError): - ex.run(config_updates={'d': 3}) + ex.run(config_updates={"d": 3}) def test_fails_on_nested_unused_config_updates(ex): @ex.config def cfg(): - a = {'b': 1} - d = {'e': 3} + a = {"b": 1} + d = {"e": 3} @ex.main def foo(a): - return a['b'] + return a["b"] # normal config updates work - assert ex.run(config_updates={'a': {'b': 2}}).result == 2 + assert ex.run(config_updates={"a": {"b": 2}}).result == 2 # not in config but parent is works - assert ex.run(config_updates={'a': {'c': 5}}).result == 1 + assert ex.run(config_updates={"a": {"c": 5}}).result == 1 # unused but in config works - assert ex.run(config_updates={'d': {'e': 7}}).result == 1 + assert ex.run(config_updates={"d": {"e": 7}}).result == 1 # unused nested config updates raise with pytest.raises(ConfigAddedError): - ex.run(config_updates={'d': {'f': 3}}) + ex.run(config_updates={"d": {"f": 3}}) def test_considers_captured_functions_for_fail_on_unused_config(ex): @@ -121,19 +121,19 @@ def transmogrify(a, b=0): def foo(): return transmogrify() - assert ex.run(config_updates={'a': 7}).result == 7 - assert ex.run(config_updates={'b': 3}).result == 4 + assert ex.run(config_updates={"a": 7}).result == 7 + assert ex.run(config_updates={"b": 3}).result == 4 with pytest.raises(ConfigAddedError): - ex.run(config_updates={'c': 3}) + ex.run(config_updates={"c": 3}) def test_considers_prefix_for_fail_on_unused_config(ex): @ex.config def cfg(): - a = {'b': 1} + a = {"b": 1} - @ex.capture(prefix='a') + @ex.capture(prefix="a") def transmogrify(b): return b @@ -141,17 +141,17 @@ def transmogrify(b): def foo(): return transmogrify() - assert ex.run(config_updates={'a': {'b': 3}}).result == 3 + assert ex.run(config_updates={"a": {"b": 3}}).result == 3 with pytest.raises(ConfigAddedError): - ex.run(config_updates={'b': 5}) + ex.run(config_updates={"b": 5}) with pytest.raises(ConfigAddedError): - 
ex.run(config_updates={'a': {'c': 5}}) + ex.run(config_updates={"a": {"c": 5}}) def test_non_existing_prefix_is_treatet_as_empty_dict(ex): - @ex.capture(prefix='nonexisting') + @ex.capture(prefix="nonexisting") def transmogrify(b=10): return b @@ -176,21 +176,21 @@ def run(a): return a assert ex.run().result == 1 - assert ex.run(named_configs=['ncfg']).result == 10 + assert ex.run(named_configs=["ncfg"]).result == 10 def test_empty_dict_named_config(ex): @ex.named_config def ncfg(): empty_dict = {} - nested_empty_dict = {'k1': {'k2': {}}} + nested_empty_dict = {"k1": {"k2": {}}} @ex.automain def main(empty_dict=1, nested_empty_dict=2): return empty_dict, nested_empty_dict assert ex.run().result == (1, 2) - assert ex.run(named_configs=['ncfg']).result == ({}, {'k1': {'k2': {}}}) + assert ex.run(named_configs=["ncfg"]).result == ({}, {"k1": {"k2": {}}}) def test_empty_dict_config_updates(ex): @@ -200,18 +200,18 @@ def cfg(): @ex.config def default(): - a = {'b': 1} + a = {"b": 1} @ex.main def main(): pass r = ex.run() - assert r.config['a']['b'] == 1 + assert r.config["a"]["b"] == 1 def test_named_config_and_ingredient(): - ing = Ingredient('foo') + ing = Ingredient("foo") @ing.config def cfg(): @@ -231,26 +231,26 @@ def named(): def main(): pass - r = ex.run(named_configs=['named']) - assert r.config['b'] == 30 - assert r.config['foo'] == {'a': 10} + r = ex.run(named_configs=["named"]) + assert r.config["b"] == 30 + assert r.config["foo"] == {"a": 10} def test_captured_out_filter(ex, capsys): @ex.main def run_print_mock_progress(): - sys.stdout.write('progress 0') + sys.stdout.write("progress 0") sys.stdout.flush() for i in range(10): - sys.stdout.write('\b') + sys.stdout.write("\b") sys.stdout.write("{}".format(i)) sys.stdout.flush() ex.captured_out_filter = apply_backspaces_and_linefeeds # disable logging and set capture mode to python - options = {'--loglevel': 'CRITICAL', '--capture': 'sys'} + options = {"--loglevel": "CRITICAL", "--capture": "sys"} with 
capsys.disabled(): - assert ex.run(options=options).captured_out == 'progress 9' + assert ex.run(options=options).captured_out == "progress 9" def test_adding_option_hooks(ex): @@ -268,20 +268,20 @@ def hook2(options): def test_option_hooks_without_options_arg_raises(ex): with pytest.raises(KeyError): + @ex.option_hook def invalid_hook(wrong_arg_name): pass def test_config_hook_updates_config(ex): - @ex.config def cfg(): - a = 'hello' + a = "hello" @ex.config_hook def hook(config, command_name, logger): - config.update({'a': 'me'}) + config.update({"a": "me"}) return config @ex.main @@ -289,7 +289,7 @@ def foo(): pass r = ex.run() - assert r.config['a'] == 'me' + assert r.config["a"] == "me" def test_info_kwarg_updates_info(ex): @@ -300,8 +300,8 @@ def test_info_kwarg_updates_info(ex): def foo(): pass - run = ex.run(info={'bar': 'baz'}) - assert 'bar' in run.info + run = ex.run(info={"bar": "baz"}) + assert "bar" in run.info def test_info_kwargs_default_behavior(ex): @@ -310,57 +310,60 @@ def test_info_kwargs_default_behavior(ex): @ex.automain def foo(_run): - _run.info['bar'] = 'baz' + _run.info["bar"] = "baz" run = ex.run() - assert 'bar' in run.info + assert "bar" in run.info + def test_fails_on_config_write(ex): @ex.config def cfg(): - a = 'hello' - nested_dict = {'dict': {'dict': 1234, 'list': [1, 2, 3, 4]}} - nested_list = [{'a': 42}, (1, 2, 3, 4), [1, 2, 3, 4]] - nested_tuple = ({'a': 42}, (1, 2, 3, 4), [1, 2, 3, 4]) + a = "hello" + nested_dict = {"dict": {"dict": 1234, "list": [1, 2, 3, 4]}} + nested_list = [{"a": 42}, (1, 2, 3, 4), [1, 2, 3, 4]] + nested_tuple = ({"a": 42}, (1, 2, 3, 4), [1, 2, 3, 4]) @ex.main def main(_config, nested_dict, nested_list, nested_tuple): raises_list = pytest.raises( - SacredError, match='The configuration is read-only in a captured function!') + SacredError, match="The configuration is read-only in a captured function!" 
+ ) raises_dict = pytest.raises( - SacredError, match='The configuration is read-only in a captured function!') + SacredError, match="The configuration is read-only in a captured function!" + ) - print('in main') + print("in main") # Test for ReadOnlyDict with raises_dict: - _config['a'] = 'world!' + _config["a"] = "world!" with raises_dict: - nested_dict['dict'] = 'world!' + nested_dict["dict"] = "world!" with raises_dict: - nested_dict['list'] = 'world!' + nested_dict["list"] = "world!" with raises_dict: nested_dict.clear() with raises_dict: - nested_dict.update({'a': 'world'}) + nested_dict.update({"a": "world"}) # Test ReadOnlyList with raises_list: - nested_dict['dict']['list'][0] = 1 + nested_dict["dict"]["list"][0] = 1 with raises_list: - nested_list[0] = 'world!' + nested_list[0] = "world!" with raises_list: nested_dict.clear() # Test nested tuple with raises_dict: - nested_tuple[0]['a'] = 'world!' + nested_tuple[0]["a"] = "world!" with raises_list: nested_tuple[2][0] = 123 @@ -369,39 +372,24 @@ def main(_config, nested_dict, nested_list, nested_tuple): def test_add_config_dict_chain(ex): - @ex.config def config1(): """This is my demo configuration""" - dictnest_cap = { - 'key_1': 'value_1', - 'key_2': 'value_2' - } - + dictnest_cap = {"key_1": "value_1", "key_2": "value_2"} @ex.config def config2(): """This is my demo configuration""" - dictnest_cap = { - 'key_2': 'update_value_2', - 'key_3': 'value3', - 'key_4': 'value4' - } - + dictnest_cap = {"key_2": "update_value_2", "key_3": "value3", "key_4": "value4"} - adict = { - 'dictnest_dict': { - 'key_1': 'value_1', - 'key_2': 'value_2' - } - } + adict = {"dictnest_dict": {"key_1": "value_1", "key_2": "value_2"}} ex.add_config(adict) bdict = { - 'dictnest_dict': { - 'key_2': 'update_value_2', - 'key_3': 'value3', - 'key_4': 'value4' + "dictnest_dict": { + "key_2": "update_value_2", + "key_3": "value3", + "key_4": "value4", } } ex.add_config(bdict) @@ -411,7 +399,10 @@ def run(): pass final_config = 
ex.run().config - assert final_config['dictnest_cap'] == { - 'key_1': 'value_1', 'key_2': 'update_value_2', - 'key_3': 'value3', 'key_4': 'value4'} - assert final_config['dictnest_cap'] == final_config['dictnest_dict'] + assert final_config["dictnest_cap"] == { + "key_1": "value_1", + "key_2": "update_value_2", + "key_3": "value3", + "key_4": "value4", + } + assert final_config["dictnest_cap"] == final_config["dictnest_dict"] diff --git a/tests/test_host_info.py b/tests/test_host_info.py index 79fd0b81..abc3ca16 100644 --- a/tests/test_host_info.py +++ b/tests/test_host_info.py @@ -1,44 +1,43 @@ #!/usr/bin/env python # coding=utf-8 -from sacred.host_info import (get_host_info, host_info_getter, - host_info_gatherers) +from sacred.host_info import get_host_info, host_info_getter, host_info_gatherers def test_get_host_info(): host_info = get_host_info() - assert isinstance(host_info['hostname'], str) - assert isinstance(host_info['cpu'], str) - assert isinstance(host_info['os'], (tuple, list)) - assert isinstance(host_info['python_version'], str) + assert isinstance(host_info["hostname"], str) + assert isinstance(host_info["cpu"], str) + assert isinstance(host_info["os"], (tuple, list)) + assert isinstance(host_info["python_version"], str) def test_host_info_decorator(): try: - assert 'greeting' not in host_info_gatherers + assert "greeting" not in host_info_gatherers @host_info_getter def greeting(): return "hello" - assert 'greeting' in host_info_gatherers - assert host_info_gatherers['greeting'] == greeting - assert get_host_info()['greeting'] == 'hello' + assert "greeting" in host_info_gatherers + assert host_info_gatherers["greeting"] == greeting + assert get_host_info()["greeting"] == "hello" finally: - del host_info_gatherers['greeting'] + del host_info_gatherers["greeting"] def test_host_info_decorator_with_name(): try: - assert 'foo' not in host_info_gatherers + assert "foo" not in host_info_gatherers - @host_info_getter(name='foo') + 
@host_info_getter(name="foo") def greeting(): return "hello" - assert 'foo' in host_info_gatherers - assert 'greeting' not in host_info_gatherers - assert host_info_gatherers['foo'] == greeting - assert get_host_info()['foo'] == 'hello' + assert "foo" in host_info_gatherers + assert "greeting" not in host_info_gatherers + assert host_info_gatherers["foo"] == greeting + assert get_host_info()["foo"] == "hello" finally: - del host_info_gatherers['foo'] + del host_info_gatherers["foo"] diff --git a/tests/test_ingredients.py b/tests/test_ingredients.py index af05c68c..a191429a 100644 --- a/tests/test_ingredients.py +++ b/tests/test_ingredients.py @@ -16,11 +16,11 @@ @pytest.fixture def ing(): - return Ingredient('tickle') + return Ingredient("tickle") def test_create_ingredient(ing): - assert ing.path == 'tickle' + assert ing.path == "tickle" assert ing.doc == __doc__ assert Source.create(__file__) in ing.sources @@ -29,16 +29,18 @@ def test_capture_function(ing): @ing.capture def foo(something): pass + assert foo in ing.captured_functions assert foo.prefix is None def test_capture_function_with_prefix(ing): - @ing.capture(prefix='bar') + @ing.capture(prefix="bar") def foo(something): pass + assert foo in ing.captured_functions - assert foo.prefix == 'bar' + assert foo.prefix == "bar" def test_capture_function_twice(ing): @@ -55,36 +57,40 @@ def test_add_pre_run_hook(ing): @ing.pre_run_hook def foo(something): pass + assert foo in ing.pre_run_hooks assert foo in ing.captured_functions assert foo.prefix is None def test_add_pre_run_hook_with_prefix(ing): - @ing.pre_run_hook(prefix='bar') + @ing.pre_run_hook(prefix="bar") def foo(something): pass + assert foo in ing.pre_run_hooks assert foo in ing.captured_functions - assert foo.prefix == 'bar' + assert foo.prefix == "bar" def test_add_post_run_hook(ing): @ing.post_run_hook def foo(something): pass + assert foo in ing.post_run_hooks assert foo in ing.captured_functions assert foo.prefix is None def 
test_add_post_run_hook_with_prefix(ing): - @ing.post_run_hook(prefix='bar') + @ing.post_run_hook(prefix="bar") def foo(something): pass + assert foo in ing.post_run_hooks assert foo in ing.captured_functions - assert foo.prefix == 'bar' + assert foo.prefix == "bar" def test_add_command(ing): @@ -92,19 +98,19 @@ def test_add_command(ing): def foo(a, b): pass - assert 'foo' in ing.commands - assert ing.commands['foo'] == foo + assert "foo" in ing.commands + assert ing.commands["foo"] == foo assert foo.prefix is None def test_add_command_with_prefix(ing): - @ing.command(prefix='bar') + @ing.command(prefix="bar") def foo(a, b): pass - assert 'foo' in ing.commands - assert ing.commands['foo'] == foo - assert foo.prefix == 'bar' + assert "foo" in ing.commands + assert ing.commands["foo"] == foo + assert foo.prefix == "bar" def test_add_unobserved_command(ing): @@ -112,14 +118,15 @@ def test_add_unobserved_command(ing): def foo(a, b): pass - assert 'foo' in ing.commands - assert ing.commands['foo'] == foo + assert "foo" in ing.commands + assert ing.commands["foo"] == foo assert foo.unobserved is True def test_add_config_hook(ing): def foo(config, command_name, logger): pass + ch = ing.config_hook(foo) assert ch == foo assert foo in ing.config_hooks @@ -138,35 +145,37 @@ def test_add_named_config(ing): @ing.named_config def foo(): pass + assert isinstance(foo, ConfigScope) - assert 'foo' in ing.named_configs - assert ing.named_configs['foo'] == foo + assert "foo" in ing.named_configs + assert ing.named_configs["foo"] == foo def test_add_config_hook_with_invalid_signature_raises(ing): with pytest.raises(ValueError): + @ing.config_hook def foo(wrong, signature): pass def test_add_config_dict(ing): - ing.add_config({'foo': 12, 'bar': 4}) + ing.add_config({"foo": 12, "bar": 4}) assert len(ing.configurations) == 1 assert isinstance(ing.configurations[0], ConfigDict) - assert ing.configurations[0]() == {'foo': 12, 'bar': 4} + assert ing.configurations[0]() == {"foo": 12, "bar": 
4} def test_add_config_kwargs(ing): ing.add_config(foo=18, bar=3) assert len(ing.configurations) == 1 assert isinstance(ing.configurations[0], ConfigDict) - assert ing.configurations[0]() == {'foo': 18, 'bar': 3} + assert ing.configurations[0]() == {"foo": 18, "bar": 3} def test_add_config_kwargs_and_dict_raises(ing): with pytest.raises(ValueError): - ing.add_config({'foo': 12}, bar=3) + ing.add_config({"foo": 12}, bar=3) def test_add_config_empty_raises(ing): @@ -183,15 +192,15 @@ def test_add_config_non_dict_raises(ing): def test_add_config_file(ing): - handle, f_name = tempfile.mkstemp(suffix='.json') + handle, f_name = tempfile.mkstemp(suffix=".json") f = os.fdopen(handle, "w") - f.write(json.encode({'foo': 15, 'bar': 7})) + f.write(json.encode({"foo": 15, "bar": 7})) f.close() ing.add_config(f_name) assert len(ing.configurations) == 1 assert isinstance(ing.configurations[0], ConfigDict) - assert ing.configurations[0]() == {'foo': 15, 'bar': 7} + assert ing.configurations[0]() == {"foo": 15, "bar": 7} os.remove(f_name) @@ -201,7 +210,7 @@ def test_add_config_file_nonexisting_raises(ing): def test_add_source_file(ing): - handle, f_name = tempfile.mkstemp(suffix='.py') + handle, f_name = tempfile.mkstemp(suffix=".py") f = os.fdopen(handle, "w") f.write("print('Hello World')") f.close() @@ -212,35 +221,35 @@ def test_add_source_file(ing): def test_add_source_file_nonexisting_raises(ing): with pytest.raises(ValueError): - ing.add_source_file('nonexisting.py') + ing.add_source_file("nonexisting.py") def test_add_package_dependency(ing): - ing.add_package_dependency('django', '1.8.2') - assert PackageDependency('django', '1.8.2') in ing.dependencies + ing.add_package_dependency("django", "1.8.2") + assert PackageDependency("django", "1.8.2") in ing.dependencies def test_add_package_dependency_invalid_version_raises(ing): with pytest.raises(ValueError): - ing.add_package_dependency('django', 'foobar') + ing.add_package_dependency("django", "foobar") def 
test_get_experiment_info(ing): info = ing.get_experiment_info() - assert info['name'] == 'tickle' - assert 'dependencies' in info - assert 'sources' in info + assert info["name"] == "tickle" + assert "dependencies" in info + assert "sources" in info def test_get_experiment_info_circular_dependency_raises(ing): - ing2 = Ingredient('other', ingredients=[ing]) + ing2 = Ingredient("other", ingredients=[ing]) ing.ingredients = [ing2] with pytest.raises(CircularDependencyError): ing.get_experiment_info() def test_gather_commands(ing): - ing2 = Ingredient('other', ingredients=[ing]) + ing2 = Ingredient("other", ingredients=[ing]) @ing.command def foo(): @@ -251,12 +260,12 @@ def bar(): pass commands = list(ing2.gather_commands()) - assert ('other.bar', bar) in commands - assert ('tickle.foo', foo) in commands + assert ("other.bar", bar) in commands + assert ("tickle.foo", foo) in commands def test_gather_named_configs(ing): - ing2 = Ingredient('ing2', ingredients=[ing]) + ing2 = Ingredient("ing2", ingredients=[ing]) @ing.named_config def named_config1(): @@ -268,8 +277,8 @@ def named_config2(): pass named_configs = list(ing2.gather_named_configs()) - assert ('ing2.named_config2', named_config2) in named_configs - assert ('tickle.named_config1', named_config1) in named_configs + assert ("ing2.named_config2", named_config2) in named_configs + assert ("tickle.named_config1", named_config1) in named_configs def test_config_docs_are_preserved(ing): @@ -284,5 +293,5 @@ def run(): return 5 run = ex._create_run() - assert 'tickle.a' in run.config_modifications.docs - assert run.config_modifications.docs['tickle.a'] == 'documented entry' + assert "tickle.a" in run.config_modifications.docs + assert run.config_modifications.docs["tickle.a"] == "documented entry" diff --git a/tests/test_metrics_logger.py b/tests/test_metrics_logger.py index 3d82f94a..e914e6b8 100644 --- a/tests/test_metrics_logger.py +++ b/tests/test_metrics_logger.py @@ -17,24 +17,26 @@ def 
test_log_scalar_metric_with_run(ex): END = 100 STEP_SIZE = 5 messages = {} + @ex.main def main_function(_run): # First, make sure the queue is empty: assert len(ex.current_run._metrics.get_last_metrics()) == 0 for i in range(START, END, STEP_SIZE): - val = i*i + val = i * i _run.log_scalar("training.loss", val, i) messages["messages"] = ex.current_run._metrics.get_last_metrics() """Calling get_last_metrics clears the metrics logger internal queue. If we don't call it here, it would be called during Sacred heartbeat event after the run finishes, and the data we want to test would be lost.""" + ex.run() assert ex.current_run is not None messages = messages["messages"] - assert len(messages) == (END - START)/STEP_SIZE - for i in range(len(messages)-1): - assert messages[i].step < messages[i+1].step + assert len(messages) == (END - START) / STEP_SIZE + for i in range(len(messages) - 1): + assert messages[i].step < messages[i + 1].step assert messages[i].step == START + i * STEP_SIZE assert messages[i].timestamp <= messages[i + 1].timestamp @@ -44,36 +46,40 @@ def test_log_scalar_metric_with_ex(ex): START = 10 END = 100 STEP_SIZE = 5 + @ex.main def main_function(_run): for i in range(START, END, STEP_SIZE): - val = i*i + val = i * i ex.log_scalar("training.loss", val, i) messages["messages"] = ex.current_run._metrics.get_last_metrics() + ex.run() assert ex.current_run is not None messages = messages["messages"] assert len(messages) == (END - START) / STEP_SIZE - for i in range(len(messages)-1): - assert messages[i].step < messages[i+1].step + for i in range(len(messages) - 1): + assert messages[i].step < messages[i + 1].step assert messages[i].step == START + i * STEP_SIZE assert messages[i].timestamp <= messages[i + 1].timestamp def test_log_scalar_metric_with_implicit_step(ex): messages = {} + @ex.main def main_function(_run): for i in range(10): - val = i*i + val = i * i ex.log_scalar("training.loss", val) messages["messages"] = 
ex.current_run._metrics.get_last_metrics() + ex.run() assert ex.current_run is not None messages = messages["messages"] assert len(messages) == 10 - for i in range(len(messages)-1): - assert messages[i].step < messages[i+1].step + for i in range(len(messages) - 1): + assert messages[i].step < messages[i + 1].step assert messages[i].step == i assert messages[i].timestamp <= messages[i + 1].timestamp @@ -84,10 +90,11 @@ def test_log_scalar_metrics_with_implicit_step(ex): @ex.main def main_function(_run): for i in range(10): - val = i*i + val = i * i ex.log_scalar("training.loss", val) ex.log_scalar("training.accuracy", val + 1) messages["messages"] = ex.current_run._metrics.get_last_metrics() + ex.run() assert ex.current_run is not None messages = messages["messages"] @@ -108,12 +115,14 @@ def main_function(_run): def test_linearize_metrics(): - entries = [ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 100), - ScalarMetricLogEntry("training.accuracy", 5, datetime.datetime.utcnow(), 50), - ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 200), - ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100), - ScalarMetricLogEntry("training.accuracy", 15, datetime.datetime.utcnow(), 150), - ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300)] + entries = [ + ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 100), + ScalarMetricLogEntry("training.accuracy", 5, datetime.datetime.utcnow(), 50), + ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 200), + ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100), + ScalarMetricLogEntry("training.accuracy", 15, datetime.datetime.utcnow(), 150), + ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300), + ] linearized = linearize_metrics(entries) assert type(linearized) == dict assert len(linearized.keys()) == 2 diff --git a/tests/test_modules.py 
b/tests/test_modules.py index 0290ffbe..3b1e3759 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -11,12 +11,12 @@ def test_ingredient_config(): @m.config def cfg(): a = 5 - b = 'foo' + b = "foo" assert len(m.configurations) == 1 cfg = m.configurations[0] assert isinstance(cfg, ConfigScope) - assert cfg() == {'a': 5, 'b': 'foo'} + assert cfg() == {"a": 5, "b": "foo"} def test_ingredient_captured_functions(): @@ -34,21 +34,22 @@ def get_answer(b): def test_ingredient_command(): m = Ingredient("somemod") - m.add_config(a=42, b='foo{}') + m.add_config(a=42, b="foo{}") @m.command def transmogrify(a, b): return b.format(a) - assert 'transmogrify' in m.commands - assert m.commands['transmogrify'] == transmogrify - ex = Experiment('foo', ingredients=[m]) + assert "transmogrify" in m.commands + assert m.commands["transmogrify"] == transmogrify + ex = Experiment("foo", ingredients=[m]) - assert ex.run('somemod.transmogrify').result == 'foo42' + assert ex.run("somemod.transmogrify").result == "foo42" # ############# Experiment #################################################### + def test_experiment_run(): ex = Experiment("some_experiment") @@ -65,7 +66,7 @@ def test_experiment_run_access_subingredient(): @somemod.config def cfg(): a = 5 - b = 'foo' + b = "foo" ex = Experiment("some_experiment", ingredients=[somemod]) @@ -74,8 +75,8 @@ def main(somemod): return somemod r = ex.run().result - assert r['a'] == 5 - assert r['b'] == 'foo' + assert r["a"] == 5 + assert r["b"] == "foo" def test_experiment_run_subingredient_function(): @@ -84,7 +85,7 @@ def test_experiment_run_subingredient_function(): @somemod.config def cfg(): a = 5 - b = 'foo' + b = "foo" @somemod.capture def get_answer(b): @@ -96,7 +97,7 @@ def get_answer(b): def main(): return get_answer() - assert ex.run().result == 'foo' + assert ex.run().result == "foo" def test_experiment_named_config_subingredient(): @@ -123,17 +124,17 @@ def cfg(): @ex.named_config def ncfg(): a = 2 - somemod = {'a': 25} 
+ somemod = {"a": 25} @ex.main def main(a): return a, get_answer() assert ex.run().result == (1, 15) - assert ex.run(named_configs=['somemod.nsubcfg']).result == (1, 16) - assert ex.run(named_configs=['ncfg']).result == (2, 25) - assert ex.run(named_configs=['ncfg', 'somemod.nsubcfg']).result == (2, 16) - assert ex.run(named_configs=['somemod.nsubcfg', 'ncfg']).result == (2, 25) + assert ex.run(named_configs=["somemod.nsubcfg"]).result == (1, 16) + assert ex.run(named_configs=["ncfg"]).result == (2, 25) + assert ex.run(named_configs=["ncfg", "somemod.nsubcfg"]).result == (2, 16) + assert ex.run(named_configs=["somemod.nsubcfg", "ncfg"]).result == (2, 25) def test_experiment_named_config_subingredient_overwrite(): @@ -147,17 +148,17 @@ def get_answer(a): @ex.named_config def ncfg(): - somemod = {'a': 1} + somemod = {"a": 1} @ex.main def main(): return get_answer() - assert ex.run(named_configs=['ncfg']).result == 1 - assert ex.run(config_updates={'somemod': {'a': 2}}).result == 2 - assert ex.run(named_configs=['ncfg'], - config_updates={'somemod': {'a': 2}} - ).result == 2 + assert ex.run(named_configs=["ncfg"]).result == 1 + assert ex.run(config_updates={"somemod": {"a": 2}}).result == 2 + assert ( + ex.run(named_configs=["ncfg"], config_updates={"somemod": {"a": 2}}).result == 2 + ) def test_experiment_double_named_config(): @@ -166,39 +167,33 @@ def test_experiment_double_named_config(): @ex.config def config(): a = 0 - d = { - 'e': 0, - 'f': 0 - } + d = {"e": 0, "f": 0} @ex.named_config def A(): a = 2 - d = { - 'e': 2, - 'f': 2 - } + d = {"e": 2, "f": 2} @ex.named_config def B(): - d = {'f': -1} + d = {"f": -1} @ex.main def run(a, d): - return a, d['e'], d['f'] + return a, d["e"], d["f"] assert ex.run().result == (0, 0, 0) - assert ex.run(named_configs=['A']).result == (2, 2, 2) - assert ex.run(named_configs=['B']).result == (0, 0, -1) - assert ex.run(named_configs=['A', 'B']).result == (2, 2, -1) - assert ex.run(named_configs=['B', 'A']).result == (2, 2, 2) + 
assert ex.run(named_configs=["A"]).result == (2, 2, 2) + assert ex.run(named_configs=["B"]).result == (0, 0, -1) + assert ex.run(named_configs=["A", "B"]).result == (2, 2, -1) + assert ex.run(named_configs=["B", "A"]).result == (2, 2, 2) def test_double_nested_config(): - sub_sub_ing = Ingredient('sub_sub_ing') - sub_ing = Ingredient('sub_ing', [sub_sub_ing]) - ing = Ingredient('ing', [sub_ing]) - ex = Experiment('ex', [ing]) + sub_sub_ing = Ingredient("sub_sub_ing") + sub_ing = Ingredient("sub_ing", [sub_sub_ing]) + ing = Ingredient("ing", [sub_ing]) + ex = Experiment("ex", [ing]) @ex.config def config(): @@ -219,33 +214,28 @@ def config(): @sub_sub_ing.capture def sub_sub_ing_main(_config): - assert _config == { - 'd': 3 - }, _config + assert _config == {"d": 3}, _config @sub_ing.capture def sub_ing_main(_config): - assert _config == { - 'c': 2, - 'sub_sub_ing': {'d': 3} - }, _config + assert _config == {"c": 2, "sub_sub_ing": {"d": 3}}, _config @ing.capture def ing_main(_config): assert _config == { - 'b': 1, - 'sub_sub_ing': {'d': 3}, - 'sub_ing': {'c': 2} + "b": 1, + "sub_sub_ing": {"d": 3}, + "sub_ing": {"c": 2}, }, _config @ex.main def main(_config): assert _config == { - 'a': 1, - 'sub_sub_ing': {'d': 3}, - 'sub_ing': {'c': 2}, - 'ing': {'b': 1}, - 'seed': 42 + "a": 1, + "sub_sub_ing": {"d": 3}, + "sub_ing": {"c": 2}, + "ing": {"b": 1}, + "seed": 42, }, _config ing_main() diff --git a/tests/test_observers/__init__.py b/tests/test_observers/__init__.py index 9a947b1a..5c9136ad 100644 --- a/tests/test_observers/__init__.py +++ b/tests/test_observers/__init__.py @@ -1,3 +1,2 @@ #!/usr/bin/env python # coding=utf-8 - diff --git a/tests/test_observers/failing_mongo_mock.py b/tests/test_observers/failing_mongo_mock.py index f1aeee30..c0bdb541 100644 --- a/tests/test_observers/failing_mongo_mock.py +++ b/tests/test_observers/failing_mongo_mock.py @@ -4,39 +4,48 @@ class FailingMongoClient(mongomock.MongoClient): - def __init__(self, max_calls_before_failure=2, - 
exception_to_raise=pymongo.errors.AutoReconnect, **kwargs): + def __init__( + self, + max_calls_before_failure=2, + exception_to_raise=pymongo.errors.AutoReconnect, + **kwargs + ): super().__init__(**kwargs) self._max_calls_before_failure = max_calls_before_failure self.exception_to_raise = exception_to_raise self._exception_to_raise = exception_to_raise - def get_database(self, name, codec_options=None, read_preference=None, - write_concern=None): + def get_database( + self, name, codec_options=None, read_preference=None, write_concern=None + ): db = self._databases.get(name) if db is None: db = self._databases[name] = FailingDatabase( max_calls_before_failure=self._max_calls_before_failure, - exception_to_raise=self._exception_to_raise, client=self, - name=name, ) + exception_to_raise=self._exception_to_raise, + client=self, + name=name, + ) return db class FailingDatabase(mongomock.Database): - def __init__(self, max_calls_before_failure, exception_to_raise=None, - **kwargs): + def __init__(self, max_calls_before_failure, exception_to_raise=None, **kwargs): super().__init__(**kwargs) self._max_calls_before_failure = max_calls_before_failure self._exception_to_raise = exception_to_raise - def get_collection(self, name, codec_options=None, read_preference=None, - write_concern=None): + def get_collection( + self, name, codec_options=None, read_preference=None, write_concern=None + ): collection = self._collections.get(name) if collection is None: collection = self._collections[name] = FailingCollection( max_calls_before_failure=self._max_calls_before_failure, - exception_to_raise=self._exception_to_raise, db=self, - name=name, ) + exception_to_raise=self._exception_to_raise, + db=self, + name=name, + ) return collection @@ -59,8 +68,7 @@ def update_one(self, filter, update, upsert=False): if self._calls > self._max_calls_before_failure: raise pymongo.errors.ConnectionFailure else: - return super().update_one(filter, update, - upsert) + return 
super().update_one(filter, update, upsert) class ReconnectingMongoClient(FailingMongoClient): @@ -68,15 +76,18 @@ def __init__(self, max_calls_before_reconnect, **kwargs): super().__init__(**kwargs) self._max_calls_before_reconnect = max_calls_before_reconnect - def get_database(self, name, codec_options=None, read_preference=None, - write_concern=None): + def get_database( + self, name, codec_options=None, read_preference=None, write_concern=None + ): db = self._databases.get(name) if db is None: db = self._databases[name] = ReconnectingDatabase( max_calls_before_reconnect=self._max_calls_before_reconnect, max_calls_before_failure=self._max_calls_before_failure, - exception_to_raise=self._exception_to_raise, client=self, - name=name, ) + exception_to_raise=self._exception_to_raise, + client=self, + name=name, + ) return db @@ -85,15 +96,18 @@ def __init__(self, max_calls_before_reconnect, **kwargs): super().__init__(**kwargs) self._max_calls_before_reconnect = max_calls_before_reconnect - def get_collection(self, name, codec_options=None, read_preference=None, - write_concern=None): + def get_collection( + self, name, codec_options=None, read_preference=None, write_concern=None + ): collection = self._collections.get(name) if collection is None: collection = self._collections[name] = ReconnectingCollection( max_calls_before_reconnect=self._max_calls_before_reconnect, max_calls_before_failure=self._max_calls_before_failure, - exception_to_raise=self._exception_to_raise, db=self, - name=name, ) + exception_to_raise=self._exception_to_raise, + db=self, + name=name, + ) return collection @@ -120,10 +134,11 @@ def update_one(self, filter, update, upsert=False): else: print(self.name, "update connection reestablished") - return mongomock.Collection.update_one(self, filter, update, - upsert) + return mongomock.Collection.update_one(self, filter, update, upsert) def _is_in_failure_range(self): - return (self._max_calls_before_failure - < self._calls - <= 
self._max_calls_before_reconnect) + return ( + self._max_calls_before_failure + < self._calls + <= self._max_calls_before_reconnect + ) diff --git a/tests/test_observers/test_file_storage_observer.py b/tests/test_observers/test_file_storage_observer.py index 39f9a64e..ff45eeca 100644 --- a/tests/test_observers/test_file_storage_observer.py +++ b/tests/test_observers/test_file_storage_observer.py @@ -19,25 +19,25 @@ @pytest.fixture() def sample_run(): - exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} - host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 'test run'} + exp = {"name": "test_exp", "sources": [], "doc": "", "base_dir": "/tmp"} + host = {"hostname": "test_host", "cpu_count": 1, "python_version": "3.4"} + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} return { - '_id': 'FEDCBA9876543210', - 'ex_info': exp, - 'command': command, - 'host_info': host, - 'start_time': T1, - 'config': config, - 'meta_info': meta_info, + "_id": "FEDCBA9876543210", + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, } @pytest.fixture() def dir_obs(tmpdir): - basedir = tmpdir.join('file_storage') + basedir = tmpdir.join("file_storage") return basedir, FileStorageObserver.create(basedir.strpath) @@ -47,9 +47,9 @@ def tmpfile(): # manually deleting the file, such that we can close it before running the # tests. This is necessary since on Windows we can not open the same file # twice, so for the FileStorageObserver to read it, we need to close it. 
- f = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + f = tempfile.NamedTemporaryFile(suffix=".py", delete=False) - f.content = 'import sacred\n' + f.content = "import sacred\n" f.write(f.content.encode()) f.flush() f.seek(0) @@ -70,100 +70,110 @@ def test_fs_observer_create_does_not_create_basedir(dir_obs): def test_fs_observer_queued_event_creates_rundir(dir_obs, sample_run): basedir, obs = dir_obs _id = obs.queued_event( - sample_run['ex_info'], sample_run['command'], sample_run['host_info'], - datetime.datetime.utcnow(), sample_run['config'], - sample_run['meta_info'], sample_run['_id']) + sample_run["ex_info"], + sample_run["command"], + sample_run["host_info"], + datetime.datetime.utcnow(), + sample_run["config"], + sample_run["meta_info"], + sample_run["_id"], + ) assert _id is not None run_dir = basedir.join(str(_id)) assert run_dir.exists() - config = json.loads(run_dir.join('config.json').read()) - assert config == sample_run['config'] + config = json.loads(run_dir.join("config.json").read()) + assert config == sample_run["config"] - run = json.loads(run_dir.join('run.json').read()) + run = json.loads(run_dir.join("run.json").read()) assert run == { - 'experiment': sample_run['ex_info'], - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'meta': sample_run['meta_info'], - 'status': 'QUEUED' + "experiment": sample_run["ex_info"], + "command": sample_run["command"], + "host": sample_run["host_info"], + "meta": sample_run["meta_info"], + "status": "QUEUED", } def test_fs_observer_started_event_creates_rundir(dir_obs, sample_run): basedir, obs = dir_obs - sample_run['_id'] = None + sample_run["_id"] = None _id = obs.started_event(**sample_run) run_dir = basedir.join(str(_id)) assert run_dir.exists() - assert run_dir.join('cout.txt').exists() - config = json.loads(run_dir.join('config.json').read()) - assert config == sample_run['config'] + assert run_dir.join("cout.txt").exists() + config = 
json.loads(run_dir.join("config.json").read()) + assert config == sample_run["config"] - run = json.loads(run_dir.join('run.json').read()) + run = json.loads(run_dir.join("run.json").read()) assert run == { - 'experiment': sample_run['ex_info'], - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'start_time': T1.isoformat(), - 'heartbeat': None, - 'meta': sample_run['meta_info'], + "experiment": sample_run["ex_info"], + "command": sample_run["command"], + "host": sample_run["host_info"], + "start_time": T1.isoformat(), + "heartbeat": None, + "meta": sample_run["meta_info"], "resources": [], "artifacts": [], - "status": "RUNNING" + "status": "RUNNING", } -def test_fs_observer_started_event_creates_rundir_with_filesystem_delay(dir_obs, sample_run, monkeypatch): +def test_fs_observer_started_event_creates_rundir_with_filesystem_delay( + dir_obs, sample_run, monkeypatch +): """ Assumes listdir doesn't show existing file (e.g. due to caching or delay of network storage) """ basedir, obs = dir_obs - sample_run['_id'] = None + sample_run["_id"] = None _id = obs.started_event(**sample_run) - assert _id == '1' + assert _id == "1" assert os.listdir(str(basedir)) == [_id] with monkeypatch.context() as m: - m.setattr('os.listdir', lambda _: []) + m.setattr("os.listdir", lambda _: []) assert os.listdir(str(basedir)) == [] _id2 = obs.started_event(**sample_run) - assert _id2 == '2' + assert _id2 == "2" -def test_fs_observer_started_event_raises_file_exists_error(dir_obs, sample_run, monkeypatch): +def test_fs_observer_started_event_raises_file_exists_error( + dir_obs, sample_run, monkeypatch +): """ Assumes some problem with the filesystem exists therefore run dir creation should stop after some re-tries """ + def mkdir_raises_file_exists(name, mode=0o777): raise FileExistsError("File already exists: " + name) basedir, obs = dir_obs - sample_run['_id'] = None + sample_run["_id"] = None with monkeypatch.context() as m: - m.setattr('os.mkdir', 
mkdir_raises_file_exists) + m.setattr("os.mkdir", mkdir_raises_file_exists) with pytest.raises(FileExistsError): obs.started_event(**sample_run) def test_fs_observer_started_event_stores_source(dir_obs, sample_run, tmpfile): basedir, obs = dir_obs - sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]] + sample_run["ex_info"]["sources"] = [[tmpfile.name, tmpfile.md5sum]] _id = obs.started_event(**sample_run) run_dir = basedir.join(_id) assert run_dir.exists() - run = json.loads(run_dir.join('run.json').read()) - ex_info = copy(run['experiment']) - assert ex_info['sources'][0][0] == tmpfile.name - source_path = ex_info['sources'][0][1] + run = json.loads(run_dir.join("run.json").read()) + ex_info = copy(run["experiment"]) + assert ex_info["sources"][0][0] == tmpfile.name + source_path = ex_info["sources"][0][1] source = basedir.join(source_path) assert source.exists() - assert source.read() == 'import sacred\n' + assert source.read() == "import sacred\n" def test_fs_observer_started_event_uses_given_id(dir_obs, sample_run): basedir, obs = dir_obs _id = obs.started_event(**sample_run) - assert _id == sample_run['_id'] + assert _id == sample_run["_id"] assert basedir.join(_id).exists() @@ -171,18 +181,17 @@ def test_fs_observer_heartbeat_event_updates_run(dir_obs, sample_run): basedir, obs = dir_obs _id = obs.started_event(**sample_run) run_dir = basedir.join(_id) - info = {'my_info': [1, 2, 3], 'nr': 7} - obs.heartbeat_event(info=info, captured_out='some output', beat_time=T2, - result=17) + info = {"my_info": [1, 2, 3], "nr": 7} + obs.heartbeat_event(info=info, captured_out="some output", beat_time=T2, result=17) - assert run_dir.join('cout.txt').read() == 'some output' - run = json.loads(run_dir.join('run.json').read()) + assert run_dir.join("cout.txt").read() == "some output" + run = json.loads(run_dir.join("run.json").read()) - assert run['heartbeat'] == T2.isoformat() - assert run['result'] == 17 + assert run["heartbeat"] == T2.isoformat() + assert 
run["result"] == 17 - assert run_dir.join('info.json').exists() - i = json.loads(run_dir.join('info.json').read()) + assert run_dir.join("info.json").exists() + i = json.loads(run_dir.join("info.json").read()) assert info == i @@ -190,25 +199,30 @@ def test_fs_observer_heartbeat_event_multiple_updates_run(dir_obs, sample_run): basedir, obs = dir_obs _id = obs.started_event(**sample_run) run_dir = basedir.join(_id) - info = {'my_info': [1, 2, 3], 'nr': 7} + info = {"my_info": [1, 2, 3], "nr": 7} captured_outs = [("some output %d\n" % i) for i in range(10)] - beat_times = [(T2 + datetime.timedelta(seconds=i*10)) for i in range(10)] + beat_times = [(T2 + datetime.timedelta(seconds=i * 10)) for i in range(10)] for idx in range(len(beat_times)): - expected_captured_output = "\n".join( - [x.strip() for x in captured_outs[:(idx+1)]]) + "\n" - obs.heartbeat_event(info=info, captured_out=expected_captured_output, - beat_time=beat_times[idx], result=17) - - assert run_dir.join('cout.txt').read() == expected_captured_output - run = json.loads(run_dir.join('run.json').read()) - - assert run['heartbeat'] == beat_times[idx].isoformat() - assert run['result'] == 17 - - assert run_dir.join('info.json').exists() - i = json.loads(run_dir.join('info.json').read()) + expected_captured_output = ( + "\n".join([x.strip() for x in captured_outs[: (idx + 1)]]) + "\n" + ) + obs.heartbeat_event( + info=info, + captured_out=expected_captured_output, + beat_time=beat_times[idx], + result=17, + ) + + assert run_dir.join("cout.txt").read() == expected_captured_output + run = json.loads(run_dir.join("run.json").read()) + + assert run["heartbeat"] == beat_times[idx].isoformat() + assert run["result"] == 17 + + assert run_dir.join("info.json").exists() + i = json.loads(run_dir.join("info.json").read()) assert info == i @@ -219,10 +233,10 @@ def test_fs_observer_completed_event_updates_run(dir_obs, sample_run): obs.completed_event(stop_time=T2, result=42) - run = 
json.loads(run_dir.join('run.json').read()) - assert run['stop_time'] == T2.isoformat() - assert run['status'] == 'COMPLETED' - assert run['result'] == 42 + run = json.loads(run_dir.join("run.json").read()) + assert run["stop_time"] == T2.isoformat() + assert run["status"] == "COMPLETED" + assert run["result"] == 42 def test_fs_observer_interrupted_event_updates_run(dir_obs, sample_run): @@ -230,11 +244,11 @@ def test_fs_observer_interrupted_event_updates_run(dir_obs, sample_run): _id = obs.started_event(**sample_run) run_dir = basedir.join(_id) - obs.interrupted_event(interrupt_time=T2, status='CUSTOM_INTERRUPTION') + obs.interrupted_event(interrupt_time=T2, status="CUSTOM_INTERRUPTION") - run = json.loads(run_dir.join('run.json').read()) - assert run['stop_time'] == T2.isoformat() - assert run['status'] == 'CUSTOM_INTERRUPTION' + run = json.loads(run_dir.join("run.json").read()) + assert run["stop_time"] == T2.isoformat() + assert run["status"] == "CUSTOM_INTERRUPTION" def test_fs_observer_failed_event_updates_run(dir_obs, sample_run): @@ -245,26 +259,26 @@ def test_fs_observer_failed_event_updates_run(dir_obs, sample_run): fail_trace = "lots of errors and\nso\non..." 
obs.failed_event(fail_time=T2, fail_trace=fail_trace) - run = json.loads(run_dir.join('run.json').read()) - assert run['stop_time'] == T2.isoformat() - assert run['status'] == 'FAILED' - assert run['fail_trace'] == fail_trace + run = json.loads(run_dir.join("run.json").read()) + assert run["stop_time"] == T2.isoformat() + assert run["status"] == "FAILED" + assert run["fail_trace"] == fail_trace def test_fs_observer_artifact_event(dir_obs, sample_run, tmpfile): basedir, obs = dir_obs _id = obs.started_event(**sample_run) run_dir = basedir.join(_id) - - obs.artifact_event('my_artifact.py', tmpfile.name) - artifact = run_dir.join('my_artifact.py') + obs.artifact_event("my_artifact.py", tmpfile.name) + + artifact = run_dir.join("my_artifact.py") assert artifact.exists() assert artifact.read() == tmpfile.content - run = json.loads(run_dir.join('run.json').read()) - assert len(run['artifacts']) == 1 - assert run['artifacts'][0] == artifact.relto(run_dir) + run = json.loads(run_dir.join("run.json").read()) + assert len(run["artifacts"]) == 1 + assert run["artifacts"][0] == artifact.relto(run_dir) def test_fs_observer_resource_event(dir_obs, sample_run, tmpfile): @@ -274,37 +288,36 @@ def test_fs_observer_resource_event(dir_obs, sample_run, tmpfile): obs.resource_event(tmpfile.name) - res_dir = basedir.join('_resources') + res_dir = basedir.join("_resources") assert res_dir.exists() assert len(res_dir.listdir()) == 1 assert res_dir.listdir()[0].read() == tmpfile.content - run = json.loads(run_dir.join('run.json').read()) - assert len(run['resources']) == 1 - assert run['resources'][0] == [tmpfile.name, res_dir.listdir()[0].strpath] + run = json.loads(run_dir.join("run.json").read()) + assert len(run["resources"]) == 1 + assert run["resources"][0] == [tmpfile.name, res_dir.listdir()[0].strpath] -def test_fs_observer_resource_event_does_not_duplicate(dir_obs, sample_run, - tmpfile): +def test_fs_observer_resource_event_does_not_duplicate(dir_obs, sample_run, tmpfile): 
basedir, obs = dir_obs obs2 = FileStorageObserver.create(obs.basedir) obs.started_event(**sample_run) obs.resource_event(tmpfile.name) # let's have another run from a different observer - sample_run['_id'] = None + sample_run["_id"] = None _id = obs2.started_event(**sample_run) run_dir = basedir.join(str(_id)) obs2.resource_event(tmpfile.name) - res_dir = basedir.join('_resources') + res_dir = basedir.join("_resources") assert res_dir.exists() assert len(res_dir.listdir()) == 1 assert res_dir.listdir()[0].read() == tmpfile.content - run = json.loads(run_dir.join('run.json').read()) - assert len(run['resources']) == 1 - assert run['resources'][0] == [tmpfile.name, res_dir.listdir()[0].strpath] + run = json.loads(run_dir.join("run.json").read()) + assert len(run["resources"]) == 1 + assert run["resources"][0] == [tmpfile.name, res_dir.listdir()[0].strpath] def test_fs_observer_equality(dir_obs): @@ -313,8 +326,9 @@ def test_fs_observer_equality(dir_obs): assert obs == obs2 assert not obs != obs2 - assert not obs == 'foo' - assert obs != 'foo' + assert not obs == "foo" + assert obs != "foo" + @pytest.fixture def logged_metrics(): @@ -322,14 +336,12 @@ def logged_metrics(): ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 1), ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 2), ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), 3), - ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100), ScalarMetricLogEntry("training.accuracy", 20, datetime.datetime.utcnow(), 200), ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300), - ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), 10), ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), 20), - ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30) + ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30), ] @@ -345,36 +357,32 @@ def 
test_log_metrics(dir_obs, sample_run, logged_metrics): and the timestamp at which the measurement was taken(timestamps) """ - # Start the experiment + # Start the experiment basedir, obs = dir_obs - sample_run['_id'] = None - _id = obs.started_event(**sample_run) + sample_run["_id"] = None + _id = obs.started_event(**sample_run) run_dir = basedir.join(str(_id)) # Initialize the info dictionary and standard output with arbitrary values - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" obs.log_metrics(linearize_metrics(logged_metrics[:6]), info) - obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, - result=0) - - - assert run_dir.join('metrics.json').exists() - metrics = json.loads(run_dir.join('metrics.json').read()) + obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, result=0) + assert run_dir.join("metrics.json").exists() + metrics = json.loads(run_dir.join("metrics.json").read()) # Confirm that we have only two metric names registered. # and they have all the information we need. assert len(metrics) == 2 assert "training.loss" in metrics assert "training.accuracy" in metrics - for v in ["steps","values","timestamps"]: - assert v in metrics["training.loss"] + for v in ["steps", "values", "timestamps"]: + assert v in metrics["training.loss"] assert v in metrics["training.accuracy"] - - # Verify they have all the information + # Verify they have all the information # we logged in the right order. loss = metrics["training.loss"] assert loss["steps"] == [10, 20, 30] @@ -386,15 +394,13 @@ def test_log_metrics(dir_obs, sample_run, logged_metrics): assert accuracy["steps"] == [10, 20, 30] assert accuracy["values"] == [100, 200, 300] - # Now, process the remaining events # The metrics shouldn't be overwritten, but appended instead. 
obs.log_metrics(linearize_metrics(logged_metrics[6:]), info) - obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=0) + obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=0) # Reload the new metrics - metrics = json.loads(run_dir.join('metrics.json').read()) + metrics = json.loads(run_dir.join("metrics.json").read()) # The newly added metrics belong to the same run and have the same names, # so the total number of metrics should not change. @@ -407,7 +413,6 @@ def test_log_metrics(dir_obs, sample_run, logged_metrics): for i in range(len(loss["timestamps"]) - 1): assert loss["timestamps"][i] <= loss["timestamps"][i + 1] - # Read the training.accuracy metric and verify it's unchanged assert "training.accuracy" in metrics accuracy = metrics["training.accuracy"] @@ -416,8 +421,8 @@ def test_log_metrics(dir_obs, sample_run, logged_metrics): def test_observer_equality(tmpdir): - observer_1 = FileStorageObserver.create(str(tmpdir / 'a')) - observer_2 = FileStorageObserver.create(str(tmpdir / 'b')) - observer_3 = FileStorageObserver.create(str(tmpdir / 'a')) + observer_1 = FileStorageObserver.create(str(tmpdir / "a")) + observer_2 = FileStorageObserver.create(str(tmpdir / "b")) + observer_3 = FileStorageObserver.create(str(tmpdir / "a")) assert observer_1 == observer_3 assert observer_1 != observer_2 diff --git a/tests/test_observers/test_mongo_observer.py b/tests/test_observers/test_mongo_observer.py index 39a2314a..83ec53ba 100644 --- a/tests/test_observers/test_mongo_observer.py +++ b/tests/test_observers/test_mongo_observer.py @@ -15,7 +15,7 @@ from .failing_mongo_mock import FailingMongoClient from sacred.dependencies import get_digest -from sacred.observers.mongo import (MongoObserver, force_bson_encodeable) +from sacred.observers.mongo import MongoObserver, force_bson_encodeable T1 = datetime.datetime(1999, 5, 4, 3, 2, 1) T2 = datetime.datetime(1999, 5, 5, 5, 5, 5) @@ -30,8 +30,7 @@ def 
test_create_should_raise_error_on_non_pymongo_client(): def test_create_should_raise_error_on_both_client_and_url(): real_client = pymongo.MongoClient() - with pytest.raises(ValueError, - match="Cannot pass both a client and a url."): + with pytest.raises(ValueError, match="Cannot pass both a client and a url."): MongoObserver.create(client=real_client, url="mymongourl") @@ -46,9 +45,11 @@ def mongo_obs(): @pytest.fixture def failing_mongo_observer(): - db = FailingMongoClient(max_calls_before_failure=2, + db = FailingMongoClient( + max_calls_before_failure=2, # exception_to_raise=pymongo.errors.AutoReconnect - exception_to_raise=pymongo.errors.ServerSelectionTimeoutError).db + exception_to_raise=pymongo.errors.ServerSelectionTimeoutError, + ).db runs = db.runs metrics = db.metrics fs = mock.MagicMock() @@ -57,39 +58,52 @@ def failing_mongo_observer(): @pytest.fixture() def sample_run(): - exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} - host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 'test run'} - return {'_id': 'FEDCBA9876543210', 'ex_info': exp, 'command': command, - 'host_info': host, 'start_time': T1, 'config': config, - 'meta_info': meta_info, } + exp = {"name": "test_exp", "sources": [], "doc": "", "base_dir": "/tmp"} + host = {"hostname": "test_host", "cpu_count": 1, "python_version": "3.4"} + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} + return { + "_id": "FEDCBA9876543210", + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, + } def test_mongo_observer_started_event_creates_run(mongo_obs, sample_run): - sample_run['_id'] = None + sample_run["_id"] = None _id = mongo_obs.started_event(**sample_run) assert _id is not None assert mongo_obs.runs.count() == 1 db_run = 
mongo_obs.runs.find_one() - assert db_run == {'_id': _id, 'experiment': sample_run['ex_info'], - 'format': mongo_obs.VERSION, - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'start_time': sample_run['start_time'], - 'heartbeat': None, 'info': {}, 'captured_out': '', - 'artifacts': [], 'config': sample_run['config'], - 'meta': sample_run['meta_info'], 'status': 'RUNNING', - 'resources': []} + assert db_run == { + "_id": _id, + "experiment": sample_run["ex_info"], + "format": mongo_obs.VERSION, + "command": sample_run["command"], + "host": sample_run["host_info"], + "start_time": sample_run["start_time"], + "heartbeat": None, + "info": {}, + "captured_out": "", + "artifacts": [], + "config": sample_run["config"], + "meta": sample_run["meta_info"], + "status": "RUNNING", + "resources": [], + } def test_mongo_observer_started_event_uses_given_id(mongo_obs, sample_run): _id = mongo_obs.started_event(**sample_run) - assert _id == sample_run['_id'] + assert _id == sample_run["_id"] assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['_id'] == sample_run['_id'] + assert db_run["_id"] == sample_run["_id"] def test_mongo_observer_equality(mongo_obs): @@ -99,53 +113,55 @@ def test_mongo_observer_equality(mongo_obs): assert mongo_obs == m assert not mongo_obs != m - assert not mongo_obs == 'foo' - assert mongo_obs != 'foo' + assert not mongo_obs == "foo" + assert mongo_obs != "foo" def test_mongo_observer_heartbeat_event_updates_run(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=1337) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=1337) assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['heartbeat'] == T2 - assert db_run['result'] == 
1337 - assert db_run['info'] == info - assert db_run['captured_out'] == outp + assert db_run["heartbeat"] == T2 + assert db_run["result"] == 1337 + assert db_run["info"] == info + assert db_run["captured_out"] == outp def test_mongo_observer_fails(failing_mongo_observer, sample_run): failing_mongo_observer.started_event(**sample_run) - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - failing_mongo_observer.heartbeat_event(info=info, captured_out=outp, - beat_time=T2, result=1337, ) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + failing_mongo_observer.heartbeat_event( + info=info, captured_out=outp, beat_time=T2, result=1337 + ) with pytest.raises(pymongo.errors.ConnectionFailure): - failing_mongo_observer.heartbeat_event(info=info, captured_out=outp, - beat_time=T3, result=1337, ) + failing_mongo_observer.heartbeat_event( + info=info, captured_out=outp, beat_time=T3, result=1337 + ) -def test_mongo_observer_saves_after_failure(failing_mongo_observer, - sample_run): +def test_mongo_observer_saves_after_failure(failing_mongo_observer, sample_run): failure_dir = "/tmp/my_failure/dir" failing_mongo_observer.failure_dir = failure_dir failing_mongo_observer.started_event(**sample_run) - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - failing_mongo_observer.heartbeat_event(info=info, captured_out=outp, - beat_time=T2, result=1337, ) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + failing_mongo_observer.heartbeat_event( + info=info, captured_out=outp, beat_time=T2, result=1337 + ) failing_mongo_observer.completed_event(stop_time=T3, result=42) - glob_pattern = "{}/sacred_mongo_fail_{}*.pickle".format(failure_dir, - sample_run["_id"]) + glob_pattern = "{}/sacred_mongo_fail_{}*.pickle".format( + failure_dir, sample_run["_id"] + ) os.path.isfile(glob(glob_pattern)[-1]) @@ -156,20 +172,20 @@ def test_mongo_observer_completed_event_updates_run(mongo_obs, sample_run): assert mongo_obs.runs.count() == 1 db_run = 
mongo_obs.runs.find_one() - assert db_run['stop_time'] == T2 - assert db_run['result'] == 42 - assert db_run['status'] == 'COMPLETED' + assert db_run["stop_time"] == T2 + assert db_run["result"] == 42 + assert db_run["status"] == "COMPLETED" def test_mongo_observer_interrupted_event_updates_run(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) - mongo_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED') + mongo_obs.interrupted_event(interrupt_time=T2, status="INTERRUPTED") assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['stop_time'] == T2 - assert db_run['status'] == 'INTERRUPTED' + assert db_run["stop_time"] == T2 + assert db_run["status"] == "INTERRUPTED" def test_mongo_observer_failed_event_updates_run(mongo_obs, sample_run): @@ -180,24 +196,24 @@ def test_mongo_observer_failed_event_updates_run(mongo_obs, sample_run): assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['stop_time'] == T2 - assert db_run['status'] == 'FAILED' - assert db_run['fail_trace'] == fail_trace + assert db_run["stop_time"] == T2 + assert db_run["status"] == "FAILED" + assert db_run["fail_trace"] == fail_trace def test_mongo_observer_artifact_event(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) filename = "setup.py" - name = 'mysetup' + name = "mysetup" mongo_obs.artifact_event(name, filename) assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['filename'].endswith(name) + assert mongo_obs.fs.put.call_args[1]["filename"].endswith(name) db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] def test_mongo_observer_resource_event(mongo_obs, sample_run): @@ -213,56 +229,57 @@ def test_mongo_observer_resource_event(mongo_obs, sample_run): db_run = mongo_obs.runs.find_one() - assert db_run['resources'] == [(filename, md5)] + assert db_run["resources"] == [(filename, md5)] def 
test_force_bson_encodable_doesnt_change_valid_document(): - d = {'int': 1, 'string': 'foo', 'float': 23.87, 'list': ['a', 1, True], - 'bool': True, 'cr4zy: _but_ [legal) Key!': '$illegal.key.as.value', - 'datetime': datetime.datetime.utcnow(), 'tuple': (1, 2.0, 'three'), - 'none': None} + d = { + "int": 1, + "string": "foo", + "float": 23.87, + "list": ["a", 1, True], + "bool": True, + "cr4zy: _but_ [legal) Key!": "$illegal.key.as.value", + "datetime": datetime.datetime.utcnow(), + "tuple": (1, 2.0, "three"), + "none": None, + } assert force_bson_encodeable(d) == d def test_force_bson_encodable_substitutes_illegal_value_with_strings(): - d = {'a_module': datetime, - 'some_legal_stuff': {'foo': 'bar', 'baz': [1, 23, 4]}, - 'nested': {'dict': {'with': {'illegal_module': mock}}}, - '$illegal': 'because it starts with a $', - 'il.legal': 'because it contains a .', - 12.7: 'illegal because it is not a string key'} - expected = {'a_module': str(datetime), - 'some_legal_stuff': {'foo': 'bar', 'baz': [1, 23, 4]}, - 'nested': {'dict': {'with': {'illegal_module': str(mock)}}}, - '@illegal': 'because it starts with a $', - 'il,legal': 'because it contains a .', - '12,7': 'illegal because it is not a string key'} + d = { + "a_module": datetime, + "some_legal_stuff": {"foo": "bar", "baz": [1, 23, 4]}, + "nested": {"dict": {"with": {"illegal_module": mock}}}, + "$illegal": "because it starts with a $", + "il.legal": "because it contains a .", + 12.7: "illegal because it is not a string key", + } + expected = { + "a_module": str(datetime), + "some_legal_stuff": {"foo": "bar", "baz": [1, 23, 4]}, + "nested": {"dict": {"with": {"illegal_module": str(mock)}}}, + "@illegal": "because it starts with a $", + "il,legal": "because it contains a .", + "12,7": "illegal because it is not a string key", + } assert force_bson_encodeable(d) == expected @pytest.fixture def logged_metrics(): return [ - ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), - 1), - 
ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), - 2), - ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), - 3), - - ScalarMetricLogEntry("training.accuracy", 10, - datetime.datetime.utcnow(), 100), - ScalarMetricLogEntry("training.accuracy", 20, - datetime.datetime.utcnow(), 200), - ScalarMetricLogEntry("training.accuracy", 30, - datetime.datetime.utcnow(), 300), - - ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), - 10), - ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), - 20), - ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), - 30)] + ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 1), + ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 2), + ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), 3), + ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100), + ScalarMetricLogEntry("training.accuracy", 20, datetime.datetime.utcnow(), 200), + ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300), + ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), 10), + ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), 20), + ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30), + ] def test_log_metrics(mongo_obs, sample_run, logged_metrics): @@ -284,23 +301,22 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): mongo_obs.started_event(**sample_run) # Initialize the info dictionary and standard output with arbitrary values - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" # Take first 6 measured events, group them by metric name # and store the measured series to the 'metrics' collection # and reference the newly created records in the 'info' dictionary. 
mongo_obs.log_metrics(linearize_metrics(logged_metrics[:6]), info) # Call standard heartbeat event (store the info dictionary to the database) - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, - result=0) + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, result=0) # There should be only one run stored assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() # ... and the info dictionary should contain a list of created metrics - assert "metrics" in db_run['info'] - assert type(db_run['info']["metrics"]) == list + assert "metrics" in db_run["info"] + assert type(db_run["info"]["metrics"]) == list # The metrics, stored in the metrics collection, # should be two (training.loss and training.accuracy) @@ -308,9 +324,11 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): # Read the training.loss metric and make sure it references the correct run # and that the run (in the info dictionary) references the correct metric record. 
loss = mongo_obs.metrics.find_one( - {"name": "training.loss", "run_id": db_run['_id']}) - assert {"name": "training.loss", "id": str(loss["_id"])} in db_run['info'][ - "metrics"] + {"name": "training.loss", "run_id": db_run["_id"]} + ) + assert {"name": "training.loss", "id": str(loss["_id"])} in db_run["info"][ + "metrics" + ] assert loss["steps"] == [10, 20, 30] assert loss["values"] == [1, 2, 3] for i in range(len(loss["timestamps"]) - 1): @@ -318,29 +336,32 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): # Read the training.accuracy metric and check the references as with the training.loss above accuracy = mongo_obs.metrics.find_one( - {"name": "training.accuracy", "run_id": db_run['_id']}) - assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in \ - db_run['info']["metrics"] + {"name": "training.accuracy", "run_id": db_run["_id"]} + ) + assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in db_run["info"][ + "metrics" + ] assert accuracy["steps"] == [10, 20, 30] assert accuracy["values"] == [100, 200, 300] # Now, process the remaining events # The metrics shouldn't be overwritten, but appended instead. mongo_obs.log_metrics(linearize_metrics(logged_metrics[6:]), info) - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=0) + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=0) assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert "metrics" in db_run['info'] + assert "metrics" in db_run["info"] # The newly added metrics belong to the same run and have the same names, so the total number # of metrics should not change. 
assert mongo_obs.metrics.count() == 2 loss = mongo_obs.metrics.find_one( - {"name": "training.loss", "run_id": db_run['_id']}) - assert {"name": "training.loss", "id": str(loss["_id"])} in db_run['info'][ - "metrics"] + {"name": "training.loss", "run_id": db_run["_id"]} + ) + assert {"name": "training.loss", "id": str(loss["_id"])} in db_run["info"][ + "metrics" + ] # ... but the values should be appended to the original list assert loss["steps"] == [10, 20, 30, 40, 50, 60] assert loss["values"] == [1, 2, 3, 10, 20, 30] @@ -348,9 +369,11 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): assert loss["timestamps"][i] <= loss["timestamps"][i + 1] accuracy = mongo_obs.metrics.find_one( - {"name": "training.accuracy", "run_id": db_run['_id']}) - assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in \ - db_run['info']["metrics"] + {"name": "training.accuracy", "run_id": db_run["_id"]} + ) + assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in db_run["info"][ + "metrics" + ] assert accuracy["steps"] == [10, 20, 30] assert accuracy["values"] == [100, 200, 300] @@ -360,61 +383,58 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): # Start the experiment mongo_obs.started_event(**sample_run) mongo_obs.log_metrics(linearize_metrics(logged_metrics[:4]), info) - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, - result=0) + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, result=0) # A new run has been created assert mongo_obs.runs.count() == 2 # Another 2 metrics have been created assert mongo_obs.metrics.count() == 4 -def test_mongo_observer_artifact_event_content_type_added(mongo_obs, - sample_run): +def test_mongo_observer_artifact_event_content_type_added(mongo_obs, sample_run): """Test that the detected content_type is added to other metadata.""" mongo_obs.started_event(**sample_run) - filename = 'setup.py' - name = 'mysetup' + filename = "setup.py" + name = "mysetup" 
mongo_obs.artifact_event(name, filename) assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['content_type'] == 'text/x-python' + assert mongo_obs.fs.put.call_args[1]["content_type"] == "text/x-python" db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] -def test_mongo_observer_artifact_event_content_type_not_overwritten(mongo_obs, - sample_run): +def test_mongo_observer_artifact_event_content_type_not_overwritten( + mongo_obs, sample_run +): """Test that manually set content_type is not overwritten by automatic detection.""" mongo_obs.started_event(**sample_run) - filename = 'setup.py' - name = 'mysetup' + filename = "setup.py" + name = "mysetup" - mongo_obs.artifact_event(name, filename, content_type='application/json') + mongo_obs.artifact_event(name, filename, content_type="application/json") assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['content_type'] == 'application/json' + assert mongo_obs.fs.put.call_args[1]["content_type"] == "application/json" db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] def test_mongo_observer_artifact_event_metadata(mongo_obs, sample_run): """Test that the detected content-type is added to other metadata.""" mongo_obs.started_event(**sample_run) - filename = 'setup.py' - name = 'mysetup' + filename = "setup.py" + name = "mysetup" - mongo_obs.artifact_event(name, filename, - metadata={'comment': 'the setup file'}) + mongo_obs.artifact_event(name, filename, metadata={"comment": "the setup file"}) assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['metadata'][ - 'comment'] == 'the setup file' + assert mongo_obs.fs.put.call_args[1]["metadata"]["comment"] == "the setup file" db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] diff --git a/tests/test_observers/test_mongo_option.py b/tests/test_observers/test_mongo_option.py index e4fe2bef..1840e87a 
100644 --- a/tests/test_observers/test_mongo_option.py +++ b/tests/test_observers/test_mongo_option.py @@ -10,56 +10,83 @@ def test_parse_mongo_db_arg(): - assert MongoDbOption.parse_mongo_db_arg('foo') == {'db_name': 'foo'} + assert MongoDbOption.parse_mongo_db_arg("foo") == {"db_name": "foo"} def test_parse_mongo_db_arg_collection(): - kwargs = MongoDbOption.parse_mongo_db_arg('foo.bar') - assert kwargs == {'db_name': 'foo', 'collection': 'bar'} + kwargs = MongoDbOption.parse_mongo_db_arg("foo.bar") + assert kwargs == {"db_name": "foo", "collection": "bar"} def test_parse_mongo_db_arg_hostname(): - assert MongoDbOption.parse_mongo_db_arg('localhost:28017') == \ - {'url': 'localhost:28017'} + assert MongoDbOption.parse_mongo_db_arg("localhost:28017") == { + "url": "localhost:28017" + } - assert MongoDbOption.parse_mongo_db_arg('www.mymongo.db:28017') == \ - {'url': 'www.mymongo.db:28017'} + assert MongoDbOption.parse_mongo_db_arg("www.mymongo.db:28017") == { + "url": "www.mymongo.db:28017" + } - assert MongoDbOption.parse_mongo_db_arg('123.45.67.89:27017') == \ - {'url': '123.45.67.89:27017'} + assert MongoDbOption.parse_mongo_db_arg("123.45.67.89:27017") == { + "url": "123.45.67.89:27017" + } def test_parse_mongo_db_arg_hostname_dbname(): - assert MongoDbOption.parse_mongo_db_arg('localhost:28017:foo') == \ - {'url': 'localhost:28017', 'db_name': 'foo'} + assert MongoDbOption.parse_mongo_db_arg("localhost:28017:foo") == { + "url": "localhost:28017", + "db_name": "foo", + } - assert MongoDbOption.parse_mongo_db_arg('www.mymongo.db:28017:bar') == \ - {'url': 'www.mymongo.db:28017', 'db_name': 'bar'} + assert MongoDbOption.parse_mongo_db_arg("www.mymongo.db:28017:bar") == { + "url": "www.mymongo.db:28017", + "db_name": "bar", + } - assert MongoDbOption.parse_mongo_db_arg('123.45.67.89:27017:baz') == \ - {'url': '123.45.67.89:27017', 'db_name': 'baz'} + assert MongoDbOption.parse_mongo_db_arg("123.45.67.89:27017:baz") == { + "url": "123.45.67.89:27017", + "db_name": 
"baz", + } def test_parse_mongo_db_arg_hostname_dbname_collection_name(): - assert MongoDbOption.parse_mongo_db_arg('localhost:28017:foo.bar') == \ - {'url': 'localhost:28017', 'db_name': 'foo', 'collection': 'bar'} + assert MongoDbOption.parse_mongo_db_arg("localhost:28017:foo.bar") == { + "url": "localhost:28017", + "db_name": "foo", + "collection": "bar", + } - assert MongoDbOption.parse_mongo_db_arg('www.mymongo.db:28017:bar.baz') ==\ - {'url': 'www.mymongo.db:28017', 'db_name': 'bar', 'collection': 'baz'} + assert MongoDbOption.parse_mongo_db_arg("www.mymongo.db:28017:bar.baz") == { + "url": "www.mymongo.db:28017", + "db_name": "bar", + "collection": "baz", + } - assert MongoDbOption.parse_mongo_db_arg('123.45.67.89:27017:baz.foo') == \ - {'url': '123.45.67.89:27017', 'db_name': 'baz', 'collection': 'foo'} + assert MongoDbOption.parse_mongo_db_arg("123.45.67.89:27017:baz.foo") == { + "url": "123.45.67.89:27017", + "db_name": "baz", + "collection": "foo", + } def test_parse_mongo_db_arg_priority(): - assert MongoDbOption.parse_mongo_db_arg('localhost:28017:foo.bar!17') == \ - {'url': 'localhost:28017', 'db_name': 'foo', 'collection': 'bar', - 'priority': 17} - - assert MongoDbOption.parse_mongo_db_arg('www.mymongo.db:28017:bar.baz!2') ==\ - {'url': 'www.mymongo.db:28017', 'db_name': 'bar', 'collection': 'baz', - 'priority': 2} - - assert MongoDbOption.parse_mongo_db_arg('123.45.67.89:27017:baz.foo!-123') == \ - {'url': '123.45.67.89:27017', 'db_name': 'baz', 'collection': 'foo', - 'priority': -123} + assert MongoDbOption.parse_mongo_db_arg("localhost:28017:foo.bar!17") == { + "url": "localhost:28017", + "db_name": "foo", + "collection": "bar", + "priority": 17, + } + + assert MongoDbOption.parse_mongo_db_arg("www.mymongo.db:28017:bar.baz!2") == { + "url": "www.mymongo.db:28017", + "db_name": "bar", + "collection": "baz", + "priority": 2, + } + + assert MongoDbOption.parse_mongo_db_arg("123.45.67.89:27017:baz.foo!-123") == { + "url": "123.45.67.89:27017", + 
"db_name": "baz", + "collection": "foo", + "priority": -123, + } diff --git a/tests/test_observers/test_queue_mongo_observer.py b/tests/test_observers/test_queue_mongo_observer.py index 495dd617..ed97dca7 100644 --- a/tests/test_observers/test_queue_mongo_observer.py +++ b/tests/test_observers/test_queue_mongo_observer.py @@ -25,57 +25,54 @@ def mongo_obs(monkeypatch): client = ReconnectingMongoClient( max_calls_before_reconnect=10, max_calls_before_failure=1, - exception_to_raise=pymongo.errors.ServerSelectionTimeoutError + exception_to_raise=pymongo.errors.ServerSelectionTimeoutError, ) monkeypatch.setattr(pymongo, "MongoClient", lambda *args, **kwargs: client) monkeypatch.setattr(gridfs, "GridFS", lambda d: mock.MagicMock()) - return QueuedMongoObserver.create( - interval=0.01, - retry_interval=0.01, - ) + return QueuedMongoObserver.create(interval=0.01, retry_interval=0.01) @pytest.fixture() def sample_run(): - exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} - host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 'test run'} + exp = {"name": "test_exp", "sources": [], "doc": "", "base_dir": "/tmp"} + host = {"hostname": "test_host", "cpu_count": 1, "python_version": "3.4"} + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} return { - '_id': 'FEDCBA9876543210', - 'ex_info': exp, - 'command': command, - 'host_info': host, - 'start_time': T1, - 'config': config, - 'meta_info': meta_info, + "_id": "FEDCBA9876543210", + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, } def test_mongo_observer_started_event_creates_run(mongo_obs, sample_run): - sample_run['_id'] = None + sample_run["_id"] = None _id = mongo_obs.started_event(**sample_run) mongo_obs.join() assert _id is not None assert 
mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() assert db_run == { - '_id': _id, - 'experiment': sample_run['ex_info'], - 'format': mongo_obs.VERSION, - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'start_time': sample_run['start_time'], - 'heartbeat': None, - 'info': {}, - 'captured_out': '', - 'artifacts': [], - 'config': sample_run['config'], - 'meta': sample_run['meta_info'], - 'status': 'RUNNING', - 'resources': [] + "_id": _id, + "experiment": sample_run["ex_info"], + "format": mongo_obs.VERSION, + "command": sample_run["command"], + "host": sample_run["host_info"], + "start_time": sample_run["start_time"], + "heartbeat": None, + "info": {}, + "captured_out": "", + "artifacts": [], + "config": sample_run["config"], + "meta": sample_run["meta_info"], + "status": "RUNNING", + "resources": [], } @@ -83,10 +80,10 @@ def test_mongo_observer_started_event_uses_given_id(mongo_obs, sample_run): _id = mongo_obs.started_event(**sample_run) mongo_obs.join() - assert _id == sample_run['_id'] + assert _id == sample_run["_id"] assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['_id'] == sample_run['_id'] + assert db_run["_id"] == sample_run["_id"] def test_mongo_observer_equality(mongo_obs): @@ -98,24 +95,23 @@ def test_mongo_observer_equality(mongo_obs): assert mongo_obs == m assert not mongo_obs != m - assert not mongo_obs == 'foo' - assert mongo_obs != 'foo' + assert not mongo_obs == "foo" + assert mongo_obs != "foo" def test_mongo_observer_heartbeat_event_updates_run(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=1337) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=1337) mongo_obs.join() assert mongo_obs.runs.count() == 1 db_run = 
mongo_obs.runs.find_one() - assert db_run['heartbeat'] == T2 - assert db_run['result'] == 1337 - assert db_run['info'] == info - assert db_run['captured_out'] == outp + assert db_run["heartbeat"] == T2 + assert db_run["result"] == 1337 + assert db_run["info"] == info + assert db_run["captured_out"] == outp def test_mongo_observer_completed_event_updates_run(mongo_obs, sample_run): @@ -126,51 +122,50 @@ def test_mongo_observer_completed_event_updates_run(mongo_obs, sample_run): assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['stop_time'] == T2 - assert db_run['result'] == 42 - assert db_run['status'] == 'COMPLETED' + assert db_run["stop_time"] == T2 + assert db_run["result"] == 42 + assert db_run["status"] == "COMPLETED" def test_mongo_observer_interrupted_event_updates_run(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) - mongo_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED') + mongo_obs.interrupted_event(interrupt_time=T2, status="INTERRUPTED") mongo_obs.join() assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['stop_time'] == T2 - assert db_run['status'] == 'INTERRUPTED' + assert db_run["stop_time"] == T2 + assert db_run["status"] == "INTERRUPTED" def test_mongo_observer_failed_event_updates_run(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) fail_trace = "lots of errors and\nso\non..." 
- mongo_obs.failed_event(fail_time=T2, - fail_trace=fail_trace) + mongo_obs.failed_event(fail_time=T2, fail_trace=fail_trace) mongo_obs.join() assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert db_run['stop_time'] == T2 - assert db_run['status'] == 'FAILED' - assert db_run['fail_trace'] == fail_trace + assert db_run["stop_time"] == T2 + assert db_run["status"] == "FAILED" + assert db_run["fail_trace"] == fail_trace def test_mongo_observer_artifact_event(mongo_obs, sample_run): mongo_obs.started_event(**sample_run) filename = "setup.py" - name = 'mysetup' + name = "mysetup" mongo_obs.artifact_event(name, filename) mongo_obs.join() assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['filename'].endswith(name) + assert mongo_obs.fs.put.call_args[1]["filename"].endswith(name) db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] def test_mongo_observer_resource_event(mongo_obs, sample_run): @@ -181,17 +176,16 @@ def test_mongo_observer_resource_event(mongo_obs, sample_run): mongo_obs.resource_event(filename) # Add extra heartbeat to make sure that run is updated. 
- info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=1337) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=1337) mongo_obs.join() assert mongo_obs.fs.exists.called mongo_obs.fs.exists.assert_any_call(filename=filename) db_run = mongo_obs.runs.find_one() - assert db_run['resources'] == [(filename, md5)] + assert db_run["resources"] == [(filename, md5)] @pytest.fixture @@ -200,14 +194,12 @@ def logged_metrics(): ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 1), ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 2), ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), 3), - ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100), ScalarMetricLogEntry("training.accuracy", 20, datetime.datetime.utcnow(), 200), ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300), - ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), 10), ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), 20), - ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30) + ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30), ] @@ -230,40 +222,46 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): mongo_obs.started_event(**sample_run) # Initialize the info dictionary and standard output with arbitrary values - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" # Take first 6 measured events, group them by metric name # and store the measured series to the 'metrics' collection # and reference the newly created records in the 'info' dictionary. 
mongo_obs.log_metrics(linearize_metrics(logged_metrics[:6]), info) # Call standard heartbeat event (store the info dictionary to the database) - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, - result=0) + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, result=0) # Now, process the remaining events # The metrics shouldn't be overwritten, but appended instead. mongo_obs.log_metrics(linearize_metrics(logged_metrics[6:]), info) - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=0) + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=0) mongo_obs.join() assert mongo_obs.runs.count() == 1 db_run = mongo_obs.runs.find_one() - assert "metrics" in db_run['info'] + assert "metrics" in db_run["info"] # The newly added metrics belong to the same run and have the same names, so the total number # of metrics should not change. assert mongo_obs.metrics.count() == 2 - loss = mongo_obs.metrics.find_one({"name": "training.loss", "run_id": db_run['_id']}) - assert {"name": "training.loss", "id": str(loss["_id"])} in db_run['info']["metrics"] + loss = mongo_obs.metrics.find_one( + {"name": "training.loss", "run_id": db_run["_id"]} + ) + assert {"name": "training.loss", "id": str(loss["_id"])} in db_run["info"][ + "metrics" + ] # ... 
but the values should be appended to the original list assert loss["steps"] == [10, 20, 30, 40, 50, 60] assert loss["values"] == [1, 2, 3, 10, 20, 30] for i in range(len(loss["timestamps"]) - 1): assert loss["timestamps"][i] <= loss["timestamps"][i + 1] - accuracy = mongo_obs.metrics.find_one({"name": "training.accuracy", "run_id": db_run['_id']}) - assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in db_run['info']["metrics"] + accuracy = mongo_obs.metrics.find_one( + {"name": "training.accuracy", "run_id": db_run["_id"]} + ) + assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in db_run["info"][ + "metrics" + ] assert accuracy["steps"] == [10, 20, 30] assert accuracy["values"] == [100, 200, 300] @@ -273,8 +271,7 @@ def test_log_metrics(mongo_obs, sample_run, logged_metrics): # Start the experiment mongo_obs.started_event(**sample_run) mongo_obs.log_metrics(linearize_metrics(logged_metrics[:4]), info) - mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, - result=0) + mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1, result=0) mongo_obs.join() # A new run has been created assert mongo_obs.runs.count() == 2 @@ -286,48 +283,50 @@ def test_mongo_observer_artifact_event_content_type_added(mongo_obs, sample_run) """Test that the detected content_type is added to other metadata.""" mongo_obs.started_event(**sample_run) - filename = 'setup.py' - name = 'mysetup' + filename = "setup.py" + name = "mysetup" mongo_obs.artifact_event(name, filename) mongo_obs.join() assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['content_type'] == 'text/x-python' + assert mongo_obs.fs.put.call_args[1]["content_type"] == "text/x-python" db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] -def test_mongo_observer_artifact_event_content_type_not_overwritten(mongo_obs, sample_run): +def test_mongo_observer_artifact_event_content_type_not_overwritten( + mongo_obs, 
sample_run +): """Test that manually set content_type is not overwritten by automatic detection.""" mongo_obs.started_event(**sample_run) - filename = 'setup.py' - name = 'mysetup' + filename = "setup.py" + name = "mysetup" - mongo_obs.artifact_event(name, filename, content_type='application/json') + mongo_obs.artifact_event(name, filename, content_type="application/json") mongo_obs.join() assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['content_type'] == 'application/json' + assert mongo_obs.fs.put.call_args[1]["content_type"] == "application/json" db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] def test_mongo_observer_artifact_event_metadata(mongo_obs, sample_run): """Test that the detected content-type is added to other metadata.""" mongo_obs.started_event(**sample_run) - filename = 'setup.py' - name = 'mysetup' + filename = "setup.py" + name = "mysetup" - mongo_obs.artifact_event(name, filename, metadata={'comment': 'the setup file'}) + mongo_obs.artifact_event(name, filename, metadata={"comment": "the setup file"}) mongo_obs.join() assert mongo_obs.fs.put.called - assert mongo_obs.fs.put.call_args[1]['metadata']['comment'] == 'the setup file' + assert mongo_obs.fs.put.call_args[1]["metadata"]["comment"] == "the setup file" db_run = mongo_obs.runs.find_one() - assert db_run['artifacts'] + assert db_run["artifacts"] diff --git a/tests/test_observers/test_queue_observer.py b/tests/test_observers/test_queue_observer.py index eac36296..424ca704 100644 --- a/tests/test_observers/test_queue_observer.py +++ b/tests/test_observers/test_queue_observer.py @@ -7,11 +7,7 @@ @pytest.fixture def queue_observer(): - return QueueObserver( - mock.MagicMock(), - interval=0.01, - retry_interval=0.01, - ) + return QueueObserver(mock.MagicMock(), interval=0.01, retry_interval=0.01) def test_started_event(queue_observer): @@ -24,8 +20,7 @@ def test_started_event(queue_observer): @pytest.mark.parametrize( - 
"event_name", - ["heartbeat_event", "resource_event", "artifact_event"], + "event_name", ["heartbeat_event", "resource_event", "artifact_event"] ) def test_non_terminal_generic_events(queue_observer, event_name): queue_observer.started_event() @@ -37,8 +32,7 @@ def test_non_terminal_generic_events(queue_observer, event_name): @pytest.mark.parametrize( - "event_name", - ["completed_event", "interrupted_event", "failed_event"], + "event_name", ["completed_event", "interrupted_event", "failed_event"] ) def test_terminal_generic_events(queue_observer, event_name): queue_observer.started_event() @@ -56,8 +50,16 @@ def test_log_metrics(queue_observer): queue_observer.log_metrics(OrderedDict([first, second]), "info") queue_observer.join() assert queue_observer._covered_observer.method_calls[1][0] == "log_metrics" - assert queue_observer._covered_observer.method_calls[1][1] == (first[0], first[1], "info") + assert queue_observer._covered_observer.method_calls[1][1] == ( + first[0], + first[1], + "info", + ) assert queue_observer._covered_observer.method_calls[1][2] == {} assert queue_observer._covered_observer.method_calls[2][0] == "log_metrics" - assert queue_observer._covered_observer.method_calls[2][1] == (second[0], second[1], "info") + assert queue_observer._covered_observer.method_calls[2][1] == ( + second[0], + second[1], + "info", + ) assert queue_observer._covered_observer.method_calls[2][2] == {} diff --git a/tests/test_observers/test_run_observer.py b/tests/test_observers/test_run_observer.py index 71ca16f2..6369bec0 100644 --- a/tests/test_observers/test_run_observer.py +++ b/tests/test_observers/test_run_observer.py @@ -8,10 +8,12 @@ def test_run_observer(): # basically to silence coverage r = RunObserver() - assert r.started_event({}, 'run', {}, datetime.utcnow(), {}, 'comment', None) is None - assert r.heartbeat_event({}, '', datetime.utcnow(), 'result') is None + assert ( + r.started_event({}, "run", {}, datetime.utcnow(), {}, "comment", None) is None + ) + 
assert r.heartbeat_event({}, "", datetime.utcnow(), "result") is None assert r.completed_event(datetime.utcnow(), 123) is None assert r.interrupted_event(datetime.utcnow(), "INTERRUPTED") is None - assert r.failed_event(datetime.utcnow(), 'trace') is None - assert r.artifact_event('foo', 'foo.txt') is None - assert r.resource_event('foo.txt') is None + assert r.failed_event(datetime.utcnow(), "trace") is None + assert r.artifact_event("foo", "foo.txt") is None + assert r.resource_event("foo.txt") is None diff --git a/tests/test_observers/test_sql_observer.py b/tests/test_observers/test_sql_observer.py index 06de897f..3daa025f 100644 --- a/tests/test_observers/test_sql_observer.py +++ b/tests/test_observers/test_sql_observer.py @@ -25,6 +25,7 @@ def engine(request): """Engine configuration.""" url = request.config.getoption("--sqlalchemy-connect-url") from sqlalchemy.engine import create_engine + engine = create_engine(url) yield engine engine.dispose() @@ -33,6 +34,7 @@ def engine(request): @pytest.fixture def session(engine): from sqlalchemy.orm import sessionmaker, scoped_session + connection = engine.connect() trans = connection.begin() session_factory = sessionmaker(bind=engine) @@ -51,21 +53,24 @@ def sql_obs(session, engine): @pytest.fixture def sample_run(): - exp = {'name': 'test_exp', 'sources': [], 'dependencies': [], - 'base_dir': '/tmp'} - host = {'hostname': 'test_host', 'cpu': 'Intel', 'os': ['Linux', 'Ubuntu'], - 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 'test run'} + exp = {"name": "test_exp", "sources": [], "dependencies": [], "base_dir": "/tmp"} + host = { + "hostname": "test_host", + "cpu": "Intel", + "os": ["Linux", "Ubuntu"], + "python_version": "3.4", + } + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} return { - '_id': 'FEDCBA9876543210', - 'ex_info': exp, - 'command': command, - 'host_info': 
host, - 'start_time': T1, - 'config': config, - 'meta_info': meta_info, + "_id": "FEDCBA9876543210", + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, } @@ -75,9 +80,9 @@ def tmpfile(): # manually deleting the file, such that we can close it before running the # tests. This is necessary since on Windows we can not open the same file # twice, so for the FileStorageObserver to read it, we need to close it. - f = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + f = tempfile.NamedTemporaryFile(suffix=".py", delete=False) - f.content = 'import sacred\n' + f.content = "import sacred\n" f.write(f.content.encode()) f.flush() f.seek(0) @@ -91,7 +96,7 @@ def tmpfile(): def test_sql_observer_started_event_creates_run(sql_obs, sample_run, session): - sample_run['_id'] = None + sample_run["_id"] = None _id = sql_obs.started_event(**sample_run) assert _id is not None assert session.query(Run).count() == 1 @@ -99,38 +104,35 @@ def test_sql_observer_started_event_creates_run(sql_obs, sample_run, session): assert session.query(Experiment).count() == 1 run = session.query(Run).first() assert run.to_json() == { - '_id': _id, - 'command': sample_run['command'], - 'start_time': sample_run['start_time'], - 'heartbeat': None, - 'stop_time': None, - 'queue_time': None, - 'status': 'RUNNING', - 'result': None, - 'meta': { - 'comment': sample_run['meta_info']['comment'], - 'priority': 0.0}, - 'resources': [], - 'artifacts': [], - 'host': sample_run['host_info'], - 'experiment': sample_run['ex_info'], - 'config': sample_run['config'], - 'captured_out': None, - 'fail_trace': None, - } + "_id": _id, + "command": sample_run["command"], + "start_time": sample_run["start_time"], + "heartbeat": None, + "stop_time": None, + "queue_time": None, + "status": "RUNNING", + "result": None, + "meta": {"comment": sample_run["meta_info"]["comment"], "priority": 0.0}, + "resources": [], + "artifacts": [], + "host": 
sample_run["host_info"], + "experiment": sample_run["ex_info"], + "config": sample_run["config"], + "captured_out": None, + "fail_trace": None, + } def test_sql_observer_started_event_uses_given_id(sql_obs, sample_run, session): _id = sql_obs.started_event(**sample_run) - assert _id == sample_run['_id'] + assert _id == sample_run["_id"] assert session.query(Run).count() == 1 db_run = session.query(Run).first() - assert db_run.run_id == sample_run['_id'] + assert db_run.run_id == sample_run["_id"] -def test_fs_observer_started_event_saves_source(sql_obs, sample_run, session, - tmpfile): - sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]] +def test_fs_observer_started_event_saves_source(sql_obs, sample_run, session, tmpfile): + sample_run["ex_info"]["sources"] = [[tmpfile.name, tmpfile.md5sum]] sql_obs.started_event(**sample_run) @@ -140,17 +142,16 @@ def test_fs_observer_started_event_saves_source(sql_obs, sample_run, session, assert len(db_run.experiment.sources) == 1 source = db_run.experiment.sources[0] assert source.filename == tmpfile.name - assert source.content == 'import sacred\n' + assert source.content == "import sacred\n" assert source.md5sum == tmpfile.md5sum def test_sql_observer_heartbeat_event_updates_run(sql_obs, sample_run, session): sql_obs.started_event(**sample_run) - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - sql_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=23.5) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + sql_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=23.5) assert session.query(Run).count() == 1 db_run = session.query(Run).first() @@ -169,18 +170,18 @@ def test_sql_observer_completed_event_updates_run(sql_obs, sample_run, session): assert db_run.stop_time == T2 assert db_run.result == 42 - assert db_run.status == 'COMPLETED' + assert db_run.status == "COMPLETED" def test_sql_observer_interrupted_event_updates_run(sql_obs, 
sample_run, session): sql_obs.started_event(**sample_run) - sql_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED') + sql_obs.interrupted_event(interrupt_time=T2, status="INTERRUPTED") assert session.query(Run).count() == 1 db_run = session.query(Run).first() assert db_run.stop_time == T2 - assert db_run.status == 'INTERRUPTED' + assert db_run.status == "INTERRUPTED" def test_sql_observer_failed_event_updates_run(sql_obs, sample_run, session): @@ -192,14 +193,14 @@ def test_sql_observer_failed_event_updates_run(sql_obs, sample_run, session): db_run = session.query(Run).first() assert db_run.stop_time == T2 - assert db_run.status == 'FAILED' + assert db_run.status == "FAILED" assert db_run.fail_trace == "lots of errors and\nso\non..." def test_sql_observer_artifact_event(sql_obs, sample_run, session, tmpfile): sql_obs.started_event(**sample_run) - sql_obs.artifact_event('my_artifact.py', tmpfile.name) + sql_obs.artifact_event("my_artifact.py", tmpfile.name) assert session.query(Run).count() == 1 db_run = session.query(Run).first() @@ -207,7 +208,7 @@ def test_sql_observer_artifact_event(sql_obs, sample_run, session, tmpfile): assert len(db_run.artifacts) == 1 artifact = db_run.artifacts[0] - assert artifact.filename == 'my_artifact.py' + assert artifact.filename == "my_artifact.py" assert artifact.content.decode() == tmpfile.content @@ -228,8 +229,8 @@ def test_fs_observer_resource_event(sql_obs, sample_run, session, tmpfile): def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile): sql_obs2 = SqlObserver(sql_obs.engine, session) - sample_run['_id'] = None - sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]] + sample_run["_id"] = None + sample_run["ex_info"]["sources"] = [[tmpfile.name, tmpfile.md5sum]] sql_obs.started_event(**sample_run) sql_obs2.started_event(**sample_run) @@ -240,8 +241,8 @@ def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpf def 
test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile): sql_obs2 = SqlObserver(sql_obs.engine, session) - sample_run['_id'] = None - sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]] + sample_run["_id"] = None + sample_run["ex_info"]["sources"] = [[tmpfile.name, tmpfile.md5sum]] sql_obs.started_event(**sample_run) sql_obs2.started_event(**sample_run) @@ -259,5 +260,5 @@ def test_sql_observer_equality(sql_obs, engine, session): assert not sql_obs != sql_obs2 - assert not sql_obs == 'foo' - assert sql_obs != 'foo' + assert not sql_obs == "foo" + assert sql_obs != "foo" diff --git a/tests/test_observers/test_sql_observer_not_installed.py b/tests/test_observers/test_sql_observer_not_installed.py index a386cdd3..4690b509 100644 --- a/tests/test_observers/test_sql_observer_not_installed.py +++ b/tests/test_observers/test_sql_observer_not_installed.py @@ -6,20 +6,20 @@ @pytest.fixture def ex(): - return Experiment('ator3000') + return Experiment("ator3000") -@pytest.mark.skipif(has_sqlalchemy, reason='We are testing the import error.') +@pytest.mark.skipif(has_sqlalchemy, reason="We are testing the import error.") def test_importerror_sql(ex): with pytest.raises(ImportError): - ex.observers.append(SqlObserver.create('some_uri')) + ex.observers.append(SqlObserver.create("some_uri")) @ex.config def cfg(): - a = {'b': 1} + a = {"b": 1} @ex.main def foo(a): - return a['b'] + return a["b"] ex.run() diff --git a/tests/test_observers/test_tinydb_observer.py b/tests/test_observers/test_tinydb_observer.py index 9704e9b0..db793fb9 100644 --- a/tests/test_observers/test_tinydb_observer.py +++ b/tests/test_observers/test_tinydb_observer.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -from __future__ import (division, print_function, unicode_literals, - absolute_import) +from __future__ import division, print_function, unicode_literals, absolute_import import os import datetime @@ -34,66 +33,69 @@ def tinydb_obs(tmpdir): 
@pytest.fixture() def sample_run(): - exp = {'name': 'test_exp', 'sources': [], 'doc': '', - 'base_dir': os.path.join(os.path.dirname(__file__), '..', '..')} - host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 'test run'} + exp = { + "name": "test_exp", + "sources": [], + "doc": "", + "base_dir": os.path.join(os.path.dirname(__file__), "..", ".."), + } + host = {"hostname": "test_host", "cpu_count": 1, "python_version": "3.4"} + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} return { - '_id': 'FEDCBA9876543210', - 'ex_info': exp, - 'command': command, - 'host_info': host, - 'start_time': T1, - 'config': config, - 'meta_info': meta_info, + "_id": "FEDCBA9876543210", + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, } def test_tinydb_observer_creates_missing_directories(tmpdir): - tinydb_obs = TinyDbObserver.create(path=os.path.join(tmpdir.strpath, 'foo')) - assert tinydb_obs.root == os.path.join(tmpdir.strpath, 'foo') + tinydb_obs = TinyDbObserver.create(path=os.path.join(tmpdir.strpath, "foo")) + assert tinydb_obs.root == os.path.join(tmpdir.strpath, "foo") def test_tinydb_observer_started_event_creates_run(tinydb_obs, sample_run): - sample_run['_id'] = None + sample_run["_id"] = None _id = tinydb_obs.started_event(**sample_run) assert _id is not None assert len(tinydb_obs.runs) == 1 db_run = tinydb_obs.runs.get(eid=1) assert db_run == { - '_id': _id, - 'experiment': sample_run['ex_info'], - 'format': tinydb_obs.VERSION, - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'start_time': sample_run['start_time'], - 'heartbeat': None, - 'info': {}, - 'captured_out': '', - 'artifacts': [], - 'config': sample_run['config'], - 'meta': sample_run['meta_info'], - 'status': 'RUNNING', - 
'resources': [] + "_id": _id, + "experiment": sample_run["ex_info"], + "format": tinydb_obs.VERSION, + "command": sample_run["command"], + "host": sample_run["host_info"], + "start_time": sample_run["start_time"], + "heartbeat": None, + "info": {}, + "captured_out": "", + "artifacts": [], + "config": sample_run["config"], + "meta": sample_run["meta_info"], + "status": "RUNNING", + "resources": [], } def test_tinydb_observer_started_event_uses_given_id(tinydb_obs, sample_run): _id = tinydb_obs.started_event(**sample_run) - assert _id == sample_run['_id'] + assert _id == sample_run["_id"] assert len(tinydb_obs.runs) == 1 db_run = tinydb_obs.runs.get(eid=1) - assert db_run['_id'] == sample_run['_id'] + assert db_run["_id"] == sample_run["_id"] -def test_tinydb_observer_started_event_saves_given_sources(tinydb_obs, - sample_run): - filename = 'setup.py' +def test_tinydb_observer_started_event_saves_given_sources(tinydb_obs, sample_run): + filename = "setup.py" md5 = get_digest(filename) - sample_run['ex_info']['sources'] = [[filename, md5]] + sample_run["ex_info"]["sources"] = [[filename, md5]] _id = tinydb_obs.started_event(**sample_run) assert _id is not None @@ -102,27 +104,27 @@ def test_tinydb_observer_started_event_saves_given_sources(tinydb_obs, # Check all but the experiment section db_run_copy = db_run.copy() - del db_run_copy['experiment'] + del db_run_copy["experiment"] assert db_run_copy == { - '_id': _id, - 'format': tinydb_obs.VERSION, - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'start_time': sample_run['start_time'], - 'heartbeat': None, - 'info': {}, - 'captured_out': '', - 'artifacts': [], - 'config': sample_run['config'], - 'meta': sample_run['meta_info'], - 'status': 'RUNNING', - 'resources': [] + "_id": _id, + "format": tinydb_obs.VERSION, + "command": sample_run["command"], + "host": sample_run["host_info"], + "start_time": sample_run["start_time"], + "heartbeat": None, + "info": {}, + "captured_out": "", + "artifacts": 
[], + "config": sample_run["config"], + "meta": sample_run["meta_info"], + "status": "RUNNING", + "resources": [], } - assert len(db_run['experiment']['sources']) == 1 - assert len(db_run['experiment']['sources'][0]) == 3 - assert db_run['experiment']['sources'][0][:2] == [filename, md5] - assert isinstance(db_run['experiment']['sources'][0][2], io.BufferedReader) + assert len(db_run["experiment"]["sources"]) == 1 + assert len(db_run["experiment"]["sources"][0]) == 3 + assert db_run["experiment"]["sources"][0][:2] == [filename, md5] + assert isinstance(db_run["experiment"]["sources"][0][2], io.BufferedReader) # Check that duplicate source files are still listed in ex_info tinydb_obs.db_run_id = None @@ -130,19 +132,22 @@ def test_tinydb_observer_started_event_saves_given_sources(tinydb_obs, assert len(tinydb_obs.runs) == 2 db_run2 = tinydb_obs.runs.get(eid=2) - assert (db_run['experiment']['sources'][0][:2] == - db_run2['experiment']['sources'][0][:2]) + assert ( + db_run["experiment"]["sources"][0][:2] + == db_run2["experiment"]["sources"][0][:2] + ) -def test_tinydb_observer_started_event_generates_different_run_ids(tinydb_obs, - sample_run): - sample_run['_id'] = None +def test_tinydb_observer_started_event_generates_different_run_ids( + tinydb_obs, sample_run +): + sample_run["_id"] = None _id = tinydb_obs.started_event(**sample_run) assert _id is not None # Check that duplicate source files are still listed in ex_info tinydb_obs.db_run_id = None - sample_run['_id'] = None + sample_run["_id"] = None _id2 = tinydb_obs.started_event(**sample_run) assert len(tinydb_obs.runs) == 2 @@ -150,12 +155,11 @@ def test_tinydb_observer_started_event_generates_different_run_ids(tinydb_obs, assert _id != _id2 -def test_tinydb_observer_queued_event_is_not_implemented(tinydb_obs, - sample_run): +def test_tinydb_observer_queued_event_is_not_implemented(tinydb_obs, sample_run): sample_queued_run = sample_run.copy() - del sample_queued_run['start_time'] - 
sample_queued_run['queue_time'] = T1 + del sample_queued_run["start_time"] + sample_queued_run["queue_time"] = T1 with pytest.raises(NotImplementedError): tinydb_obs.queued_event(**sample_queued_run) @@ -163,32 +167,32 @@ def test_tinydb_observer_queued_event_is_not_implemented(tinydb_obs, def test_tinydb_observer_equality(tmpdir, tinydb_obs): - db = TinyDB(os.path.join(tmpdir.strpath, 'metadata.json')) - fs = HashFS(os.path.join(tmpdir.strpath, 'hashfs'), depth=3, - width=2, algorithm='md5') + db = TinyDB(os.path.join(tmpdir.strpath, "metadata.json")) + fs = HashFS( + os.path.join(tmpdir.strpath, "hashfs"), depth=3, width=2, algorithm="md5" + ) m = TinyDbObserver(db, fs) assert tinydb_obs == m assert not tinydb_obs != m - assert not tinydb_obs == 'foo' - assert tinydb_obs != 'foo' + assert not tinydb_obs == "foo" + assert tinydb_obs != "foo" def test_tinydb_observer_heartbeat_event_updates_run(tinydb_obs, sample_run): tinydb_obs.started_event(**sample_run) - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - tinydb_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, - result=42) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + tinydb_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=42) assert len(tinydb_obs.runs) == 1 db_run = tinydb_obs.runs.get(eid=1) - assert db_run['heartbeat'] == T2 - assert db_run['result'] == 42 - assert db_run['info'] == info - assert db_run['captured_out'] == outp + assert db_run["heartbeat"] == T2 + assert db_run["result"] == 42 + assert db_run["info"] == info + assert db_run["captured_out"] == outp def test_tinydb_observer_completed_event_updates_run(tinydb_obs, sample_run): @@ -198,53 +202,51 @@ def test_tinydb_observer_completed_event_updates_run(tinydb_obs, sample_run): assert len(tinydb_obs.runs) == 1 db_run = tinydb_obs.runs.get(eid=1) - assert db_run['stop_time'] == T2 - assert db_run['result'] == 42 - assert db_run['status'] == 'COMPLETED' + assert 
db_run["stop_time"] == T2 + assert db_run["result"] == 42 + assert db_run["status"] == "COMPLETED" -def test_tinydb_observer_interrupted_event_updates_run(tinydb_obs, - sample_run): +def test_tinydb_observer_interrupted_event_updates_run(tinydb_obs, sample_run): tinydb_obs.started_event(**sample_run) - tinydb_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED') + tinydb_obs.interrupted_event(interrupt_time=T2, status="INTERRUPTED") assert len(tinydb_obs.runs) == 1 db_run = tinydb_obs.runs.get(eid=1) - assert db_run['stop_time'] == T2 - assert db_run['status'] == 'INTERRUPTED' + assert db_run["stop_time"] == T2 + assert db_run["status"] == "INTERRUPTED" def test_tinydb_observer_failed_event_updates_run(tinydb_obs, sample_run): tinydb_obs.started_event(**sample_run) fail_trace = "lots of errors and\nso\non..." - tinydb_obs.failed_event(fail_time=T2, - fail_trace=fail_trace) + tinydb_obs.failed_event(fail_time=T2, fail_trace=fail_trace) assert len(tinydb_obs.runs) == 1 db_run = tinydb_obs.runs.get(eid=1) - assert db_run['stop_time'] == T2 - assert db_run['status'] == 'FAILED' - assert db_run['fail_trace'] == fail_trace + assert db_run["stop_time"] == T2 + assert db_run["status"] == "FAILED" + assert db_run["fail_trace"] == fail_trace def test_tinydb_observer_artifact_event(tinydb_obs, sample_run): tinydb_obs.started_event(**sample_run) filename = "setup.py" - name = 'mysetup' + name = "mysetup" tinydb_obs.artifact_event(name, filename) assert tinydb_obs.fs.exists(filename) db_run = tinydb_obs.runs.get(eid=1) - assert db_run['artifacts'][0][0] == name + assert db_run["artifacts"][0][0] == name - with open(filename, 'rb') as f: + with open(filename, "rb") as f: file_content = f.read() - assert db_run['artifacts'][0][3].read() == file_content + assert db_run["artifacts"][0][3].read() == file_content def test_tinydb_observer_resource_event(tinydb_obs, sample_run): @@ -258,15 +260,14 @@ def test_tinydb_observer_resource_event(tinydb_obs, sample_run): assert 
tinydb_obs.fs.exists(filename) db_run = tinydb_obs.runs.get(eid=1) - assert db_run['resources'][0][:2] == [filename, md5] + assert db_run["resources"][0][:2] == [filename, md5] - with open(filename, 'rb') as f: + with open(filename, "rb") as f: file_content = f.read() - assert db_run['resources'][0][2].read() == file_content + assert db_run["resources"][0][2].read() == file_content -def test_tinydb_observer_resource_event_when_resource_present(tinydb_obs, - sample_run): +def test_tinydb_observer_resource_event_when_resource_present(tinydb_obs, sample_run): tinydb_obs.started_event(**sample_run) filename = "setup.py" @@ -278,15 +279,15 @@ def test_tinydb_observer_resource_event_when_resource_present(tinydb_obs, tinydb_obs.resource_event(filename) db_run = tinydb_obs.runs.get(eid=1) - assert db_run['resources'][0][:2] == [filename, md5] + assert db_run["resources"][0][:2] == [filename, md5] def test_custom_bufferreaderwrapper(tmpdir): import copy - with open(os.path.join(tmpdir.strpath, 'test.txt'), 'w') as f: - f.write('some example text') - with open(os.path.join(tmpdir.strpath, 'test.txt'), 'rb') as f: + with open(os.path.join(tmpdir.strpath, "test.txt"), "w") as f: + f.write("some example text") + with open(os.path.join(tmpdir.strpath, "test.txt"), "rb") as f: custom_fh = BufferedReaderWrapper(f) assert f.name == custom_fh.name assert f.mode == custom_fh.mode @@ -306,7 +307,7 @@ def test_custom_bufferreaderwrapper(tmpdir): assert not custom_fh_deepcopy.closed -@pytest.mark.skipif(not opt.has_numpy, reason='needs numpy') +@pytest.mark.skipif(not opt.has_numpy, reason="needs numpy") def test_serialisation_of_numpy_ndarray(tmpdir): from sacred.observers.tinydb_hashfs_bases import NdArraySerializer from tinydb_serialization import SerializationMiddleware @@ -314,31 +315,26 @@ def test_serialisation_of_numpy_ndarray(tmpdir): # Setup Serialisation object for non list/dict objects serialization_store = SerializationMiddleware() - 
serialization_store.register_serializer(NdArraySerializer(), 'TinyArray') + serialization_store.register_serializer(NdArraySerializer(), "TinyArray") - db = TinyDB(os.path.join(tmpdir.strpath, 'metadata.json'), - storage=serialization_store) + db = TinyDB( + os.path.join(tmpdir.strpath, "metadata.json"), storage=serialization_store + ) eye_mat = np.eye(3) ones_array = np.ones(5) - document = { - 'foo': 'bar', - 'some_array': eye_mat, - 'nested': { - 'ones': ones_array - } - } + document = {"foo": "bar", "some_array": eye_mat, "nested": {"ones": ones_array}} db.insert(document) returned_doc = db.all()[0] - assert returned_doc['foo'] == 'bar' - assert (returned_doc['some_array'] == eye_mat).all() - assert (returned_doc['nested']['ones'] == ones_array).all() + assert returned_doc["foo"] == "bar" + assert (returned_doc["some_array"] == eye_mat).all() + assert (returned_doc["nested"]["ones"] == ones_array).all() -@pytest.mark.skipif(not opt.has_pandas, reason='needs pandas') +@pytest.mark.skipif(not opt.has_pandas, reason="needs pandas") def test_serialisation_of_pandas_dataframe(tmpdir): from sacred.observers.tinydb_hashfs_bases import DataFrameSerializer from sacred.observers.tinydb_hashfs_bases import SeriesSerializer @@ -349,35 +345,28 @@ def test_serialisation_of_pandas_dataframe(tmpdir): # Setup Serialisation object for non list/dict objects serialization_store = SerializationMiddleware() - serialization_store.register_serializer(DataFrameSerializer(), - 'TinyDataFrame') - serialization_store.register_serializer(SeriesSerializer(), - 'TinySeries') + serialization_store.register_serializer(DataFrameSerializer(), "TinyDataFrame") + serialization_store.register_serializer(SeriesSerializer(), "TinySeries") - db = TinyDB(os.path.join(tmpdir.strpath, 'metadata.json'), - storage=serialization_store) + db = TinyDB( + os.path.join(tmpdir.strpath, "metadata.json"), storage=serialization_store + ) - df = pd.DataFrame(np.eye(3), columns=list('ABC')) + df = 
pd.DataFrame(np.eye(3), columns=list("ABC")) series = pd.Series(np.ones(5)) - document = { - 'foo': 'bar', - 'some_dataframe': df, - 'nested': { - 'ones': series - } - } + document = {"foo": "bar", "some_dataframe": df, "nested": {"ones": series}} db.insert(document) returned_doc = db.all()[0] - assert returned_doc['foo'] == 'bar' - assert (returned_doc['some_dataframe'] == df).all().all() - assert (returned_doc['nested']['ones'] == series).all() + assert returned_doc["foo"] == "bar" + assert (returned_doc["some_dataframe"] == df).all().all() + assert (returned_doc["nested"]["ones"] == series).all() def test_parse_tinydb_arg(): - assert TinyDbOption.parse_tinydb_arg('foo') == 'foo' + assert TinyDbOption.parse_tinydb_arg("foo") == "foo" def test_parse_tinydboption_apply(tmpdir): diff --git a/tests/test_observers/test_tinydb_observer_not_installed.py b/tests/test_observers/test_tinydb_observer_not_installed.py index 9202eaaf..3860642c 100644 --- a/tests/test_observers/test_tinydb_observer_not_installed.py +++ b/tests/test_observers/test_tinydb_observer_not_installed.py @@ -6,20 +6,20 @@ @pytest.fixture def ex(): - return Experiment('ator3000') + return Experiment("ator3000") -@pytest.mark.skipif(has_tinydb, reason='We are testing the import error.') +@pytest.mark.skipif(has_tinydb, reason="We are testing the import error.") def test_importerror_sql(ex): with pytest.raises(ImportError): - ex.observers.append(TinyDbObserver.create('some_uri')) + ex.observers.append(TinyDbObserver.create("some_uri")) @ex.config def cfg(): - a = {'b': 1} + a = {"b": 1} @ex.main def foo(a): - return a['b'] + return a["b"] ex.run() diff --git a/tests/test_observers/test_tinydb_reader.py b/tests/test_observers/test_tinydb_reader.py index 92905a1d..4d13eea8 100644 --- a/tests/test_observers/test_tinydb_reader.py +++ b/tests/test_observers/test_tinydb_reader.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -from __future__ import (division, print_function, unicode_literals, - 
absolute_import) +from __future__ import division, print_function, unicode_literals, absolute_import import datetime import os @@ -25,29 +24,29 @@ def sample_run(): T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0) exp = { - 'name': 'test_exp', - 'sources': [], - 'doc': '', - 'base_dir': os.path.join(os.path.dirname(__file__), '..', '..'), - 'dependencies': ['sacred==0.7b0'] + "name": "test_exp", + "sources": [], + "doc": "", + "base_dir": os.path.join(os.path.dirname(__file__), "..", ".."), + "dependencies": ["sacred==0.7b0"], } - host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 'test run'} + host = {"hostname": "test_host", "cpu_count": 1, "python_version": "3.4"} + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} sample_run = { - '_id': 'FED235DA13', - 'ex_info': exp, - 'command': command, - 'host_info': host, - 'start_time': T1, - 'config': config, - 'meta_info': meta_info, + "_id": "FED235DA13", + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, } - filename = 'setup.py' + filename = "setup.py" md5 = get_digest(filename) - sample_run['ex_info']['sources'] = [[filename, md5]] + sample_run["ex_info"]["sources"] = [[filename, md5]] return sample_run @@ -58,8 +57,8 @@ def run_test_experiment(exp_name, exp_id, root_dir): T3 = datetime.datetime(1999, 5, 5, 6, 6, 6, 6) run_date = sample_run() - run_date['ex_info']['name'] = exp_name - run_date['_id'] = exp_id + run_date["ex_info"]["name"] = exp_name + run_date["_id"] = exp_id # Create tinydb_obs = TinyDbObserver.create(path=root_dir) @@ -68,13 +67,12 @@ def run_test_experiment(exp_name, exp_id, root_dir): tinydb_obs.started_event(**run_date) # Heartbeat - info = {'my_info': [1, 2, 3], 'nr': 7} - outp = 'some output' - tinydb_obs.heartbeat_event(info=info, 
captured_out=outp, beat_time=T2, - result=7) + info = {"my_info": [1, 2, 3], "nr": 7} + outp = "some output" + tinydb_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2, result=7) # Add Artifact filename = "sacred/__about__.py" - name = 'about' + name = "about" tinydb_obs.artifact_event(name, filename) # Add Resource @@ -99,9 +97,9 @@ def strip_file_handles(results): cleaned_results = [] for result in results: - sources = result['experiment']['sources'] - artifacts = result['artifacts'] - resources = result['resources'] + sources = result["experiment"]["sources"] + artifacts = result["artifacts"] + resources = result["resources"] if sources: for src in sources: if isinstance(src[-1], io.BufferedReader): @@ -123,7 +121,7 @@ def strip_file_handles(results): def test_tinydb_reader_loads_db_and_fs(tmpdir): root = tmpdir.strpath - tinydb_obs = run_test_experiment(exp_name='exp1', exp_id='1234', root_dir=root) + tinydb_obs = run_test_experiment(exp_name="exp1", exp_id="1234", root_dir=root) tinydb_reader = TinyDbReader(root) assert tinydb_obs.fs.root == tinydb_reader.fs.root @@ -134,19 +132,22 @@ def test_tinydb_reader_loads_db_and_fs(tmpdir): def test_tinydb_reader_raises_exceptions(tmpdir): with pytest.raises(IOError): - TinyDbReader('foo') + TinyDbReader("foo") def test_fetch_metadata_function_with_indices(tmpdir): sample_run_ = sample_run() # Setup and run three experiments root = tmpdir.strpath - tinydb_obs = run_test_experiment(exp_name='experiment 1 alpha', - exp_id='1234', root_dir=root) - tinydb_obs = run_test_experiment(exp_name='experiment 2 beta', - exp_id='5678', root_dir=root) - tinydb_obs = run_test_experiment(exp_name='experiment 3 alpha', - exp_id='9990', root_dir=root) + tinydb_obs = run_test_experiment( + exp_name="experiment 1 alpha", exp_id="1234", root_dir=root + ) + tinydb_obs = run_test_experiment( + exp_name="experiment 2 beta", exp_id="5678", root_dir=root + ) + tinydb_obs = run_test_experiment( + exp_name="experiment 3 alpha", 
exp_id="9990", root_dir=root + ) tinydb_reader = TinyDbReader(root) @@ -159,8 +160,8 @@ def test_fetch_metadata_function_with_indices(tmpdir): exp1_res = tinydb_reader.fetch_metadata(indices=0) assert len(exp1_res) == 1 - assert exp1_res[0]['experiment']['name'] == 'experiment 1 alpha' - assert exp1_res[0]['_id'] == '1234' + assert exp1_res[0]["experiment"]["name"] == "experiment 1 alpha" + assert exp1_res[0]["_id"] == "1234" # Test Exception with pytest.raises(ValueError): @@ -169,32 +170,28 @@ def test_fetch_metadata_function_with_indices(tmpdir): # Test returned values exp1 = strip_file_handles(exp1_res)[0] - sample_run_['ex_info']['name'] = 'experiment 1 alpha' - sample_run_['ex_info']['sources'] = [ - ['setup.py', get_digest('setup.py')] - ] + sample_run_["ex_info"]["name"] = "experiment 1 alpha" + sample_run_["ex_info"]["sources"] = [["setup.py", get_digest("setup.py")]] assert exp1 == { - '_id': '1234', - 'experiment': sample_run_['ex_info'], - 'format': tinydb_obs.VERSION, - 'command': sample_run_['command'], - 'host': sample_run_['host_info'], - 'start_time': sample_run_['start_time'], - 'heartbeat': datetime.datetime(1999, 5, 5, 5, 5, 5, 5), - 'info': {'my_info': [1, 2, 3], 'nr': 7}, - 'captured_out': 'some output', - 'artifacts': [ - ['about', 'sacred/__about__.py', get_digest('sacred/__about__.py')] + "_id": "1234", + "experiment": sample_run_["ex_info"], + "format": tinydb_obs.VERSION, + "command": sample_run_["command"], + "host": sample_run_["host_info"], + "start_time": sample_run_["start_time"], + "heartbeat": datetime.datetime(1999, 5, 5, 5, 5, 5, 5), + "info": {"my_info": [1, 2, 3], "nr": 7}, + "captured_out": "some output", + "artifacts": [ + ["about", "sacred/__about__.py", get_digest("sacred/__about__.py")] ], - 'config': sample_run_['config'], - 'meta': sample_run_['meta_info'], - 'status': 'COMPLETED', - 'resources': [ - ['sacred/__init__.py', get_digest('sacred/__init__.py')] - ], - 'result': 42, - 'stop_time': datetime.datetime(1999, 5, 5, 
6, 6, 6, 6) + "config": sample_run_["config"], + "meta": sample_run_["meta_info"], + "status": "COMPLETED", + "resources": [["sacred/__init__.py", get_digest("sacred/__init__.py")]], + "result": 42, + "stop_time": datetime.datetime(1999, 5, 5, 6, 6, 6, 6), } @@ -202,22 +199,19 @@ def test_fetch_metadata_function_with_exp_name(tmpdir): # Setup and run three experiments root = tmpdir.strpath - run_test_experiment(exp_name='experiment 1 alpha', - exp_id='1234', root_dir=root) - run_test_experiment(exp_name='experiment 2 beta', - exp_id='5678', root_dir=root) - run_test_experiment(exp_name='experiment 3 alpha', - exp_id='9990', root_dir=root) + run_test_experiment(exp_name="experiment 1 alpha", exp_id="1234", root_dir=root) + run_test_experiment(exp_name="experiment 2 beta", exp_id="5678", root_dir=root) + run_test_experiment(exp_name="experiment 3 alpha", exp_id="9990", root_dir=root) tinydb_reader = TinyDbReader(root) # Test Fetch by exp name - res1 = tinydb_reader.fetch_metadata(exp_name='alpha') + res1 = tinydb_reader.fetch_metadata(exp_name="alpha") assert len(res1) == 2 - res2 = tinydb_reader.fetch_metadata(exp_name='experiment 1') + res2 = tinydb_reader.fetch_metadata(exp_name="experiment 1") assert len(res2) == 1 - assert res2[0]['experiment']['name'] == 'experiment 1 alpha' - res2 = tinydb_reader.fetch_metadata(exp_name='foo') + assert res2[0]["experiment"]["name"] == "experiment 1 alpha" + res2 = tinydb_reader.fetch_metadata(exp_name="foo") assert len(res2) == 0 @@ -225,36 +219,33 @@ def test_fetch_metadata_function_with_querry(tmpdir): # Setup and run three experiments root = tmpdir.strpath - run_test_experiment(exp_name='experiment 1 alpha', - exp_id='1234', root_dir=root) - run_test_experiment(exp_name='experiment 2 beta', - exp_id='5678', root_dir=root) - run_test_experiment(exp_name='experiment 3 alpha beta', - exp_id='9990', root_dir=root) + run_test_experiment(exp_name="experiment 1 alpha", exp_id="1234", root_dir=root) + 
run_test_experiment(exp_name="experiment 2 beta", exp_id="5678", root_dir=root) + run_test_experiment( + exp_name="experiment 3 alpha beta", exp_id="9990", root_dir=root + ) tinydb_reader = TinyDbReader(root) record = Query() - exp1_query = record.experiment.name.matches('.*alpha$') + exp1_query = record.experiment.name.matches(".*alpha$") - exp3_query = ( - (record.experiment.name.search('alpha')) & - (record._id == '9990') - ) + exp3_query = (record.experiment.name.search("alpha")) & (record._id == "9990") # Test Fetch by Tinydb Query res1 = tinydb_reader.fetch_metadata(query=exp1_query) assert len(res1) == 1 - assert res1[0]['experiment']['name'] == 'experiment 1 alpha' + assert res1[0]["experiment"]["name"] == "experiment 1 alpha" res2 = tinydb_reader.fetch_metadata( - query=record.experiment.name.search('experiment [23]')) + query=record.experiment.name.search("experiment [23]") + ) assert len(res2) == 2 res3 = tinydb_reader.fetch_metadata(query=exp3_query) assert len(res3) == 1 - assert res3[0]['experiment']['name'] == 'experiment 3 alpha beta' + assert res3[0]["experiment"]["name"] == "experiment 3 alpha beta" # Test Exception with pytest.raises(ValueError): @@ -265,18 +256,17 @@ def test_search_function(tmpdir): # Setup and run three experiments root = tmpdir.strpath - run_test_experiment(exp_name='experiment 1 alpha', - exp_id='1234', root_dir=root) - run_test_experiment(exp_name='experiment 2 beta', - exp_id='5678', root_dir=root) - run_test_experiment(exp_name='experiment 3 alpha beta', - exp_id='9990', root_dir=root) + run_test_experiment(exp_name="experiment 1 alpha", exp_id="1234", root_dir=root) + run_test_experiment(exp_name="experiment 2 beta", exp_id="5678", root_dir=root) + run_test_experiment( + exp_name="experiment 3 alpha beta", exp_id="9990", root_dir=root + ) tinydb_reader = TinyDbReader(root) # Test Fetch by Tinydb Query in search function record = Query() - q = record.experiment.name.search('experiment [23]') + q = 
record.experiment.name.search("experiment [23]") res = tinydb_reader.search(q) assert len(res) == 2 @@ -287,38 +277,36 @@ def test_search_function(tmpdir): def test_fetch_files_function(tmpdir): # Setup and run three experiments root = tmpdir.strpath - run_test_experiment(exp_name='experiment 1 alpha', - exp_id='1234', root_dir=root) - run_test_experiment(exp_name='experiment 2 beta', - exp_id='5678', root_dir=root) - run_test_experiment(exp_name='experiment 3 alpha beta', - exp_id='9990', root_dir=root) + run_test_experiment(exp_name="experiment 1 alpha", exp_id="1234", root_dir=root) + run_test_experiment(exp_name="experiment 2 beta", exp_id="5678", root_dir=root) + run_test_experiment( + exp_name="experiment 3 alpha beta", exp_id="9990", root_dir=root + ) tinydb_reader = TinyDbReader(root) res = tinydb_reader.fetch_files(indices=0) assert len(res) == 1 - assert list(res[0]['artifacts'].keys()) == ['about'] - assert isinstance(res[0]['artifacts']['about'], io.BufferedReader) - assert res[0]['date'] == datetime.datetime(1999, 5, 4, 3, 2, 1) - assert res[0]['exp_id'] == '1234' - assert res[0]['exp_name'] == 'experiment 1 alpha' - assert list(res[0]['resources'].keys()) == ['sacred/__init__.py'] - assert isinstance(res[0]['resources']['sacred/__init__.py'], io.BufferedReader) - assert list(res[0]['sources'].keys()) == ['setup.py'] - assert isinstance(res[0]['sources']['setup.py'], io.BufferedReader) + assert list(res[0]["artifacts"].keys()) == ["about"] + assert isinstance(res[0]["artifacts"]["about"], io.BufferedReader) + assert res[0]["date"] == datetime.datetime(1999, 5, 4, 3, 2, 1) + assert res[0]["exp_id"] == "1234" + assert res[0]["exp_name"] == "experiment 1 alpha" + assert list(res[0]["resources"].keys()) == ["sacred/__init__.py"] + assert isinstance(res[0]["resources"]["sacred/__init__.py"], io.BufferedReader) + assert list(res[0]["sources"].keys()) == ["setup.py"] + assert isinstance(res[0]["sources"]["setup.py"], io.BufferedReader) def 
test_fetch_report_function(tmpdir): # Setup and run three experiments root = tmpdir.strpath - run_test_experiment(exp_name='experiment 1 alpha', - exp_id='1234', root_dir=root) - run_test_experiment(exp_name='experiment 2 beta', - exp_id='5678', root_dir=root) - run_test_experiment(exp_name='experiment 3 alpha beta', - exp_id='9990', root_dir=root) + run_test_experiment(exp_name="experiment 1 alpha", exp_id="1234", root_dir=root) + run_test_experiment(exp_name="experiment 2 beta", exp_id="5678", root_dir=root) + run_test_experiment( + exp_name="experiment 3 alpha beta", exp_id="9990", root_dir=root + ) tinydb_reader = TinyDbReader(root) diff --git a/tests/test_run.py b/tests/test_run.py index 28692cb7..5934ce42 100755 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -10,21 +10,24 @@ from sacred.run import Run from sacred.config.config_summary import ConfigSummary -from sacred.utils import (ObserverError, SacredInterrupt, TimeoutInterrupt, - apply_backspaces_and_linefeeds) +from sacred.utils import ( + ObserverError, + SacredInterrupt, + TimeoutInterrupt, + apply_backspaces_and_linefeeds, +) @pytest.fixture def run(): - config = {'a': 17, 'foo': {'bar': True, 'baz': False}, 'seed': 1234} + config = {"a": 17, "foo": {"bar": True, "baz": False}, "seed": 1234} config_mod = ConfigSummary() signature = mock.Mock() - signature.name = 'main_func' - main_func = mock.Mock(return_value=123, prefix='', signature=signature) + signature.name = "main_func" + main_func = mock.Mock(return_value=123, prefix="", signature=signature) logger = mock.Mock() observer = [mock.Mock(priority=10)] - return Run(config, config_mod, main_func, observer, logger, logger, {}, - {}, [], []) + return Run(config, config_mod, main_func, observer, logger, logger, {}, {}, [], []) def test_run_attributes(run): @@ -38,7 +41,7 @@ def test_run_attributes(run): def test_run_state_attributes(run): assert run.start_time is None assert run.stop_time is None - assert run.captured_out == '' + assert 
run.captured_out == "" assert run.result is None @@ -47,7 +50,7 @@ def test_run_run(run): assert (run.start_time - datetime.utcnow()).total_seconds() < 1 assert (run.stop_time - datetime.utcnow()).total_seconds() < 1 assert run.result == 123 - assert run.captured_out == '' + assert run.captured_out == "" def test_run_emits_events_if_successful(run): @@ -61,11 +64,14 @@ def test_run_emits_events_if_successful(run): assert not observer.failed_event.called -@pytest.mark.parametrize('exception,status', [ - (KeyboardInterrupt, 'INTERRUPTED'), - (SacredInterrupt, 'INTERRUPTED'), - (TimeoutInterrupt, 'TIMEOUT'), -]) +@pytest.mark.parametrize( + "exception,status", + [ + (KeyboardInterrupt, "INTERRUPTED"), + (SacredInterrupt, "INTERRUPTED"), + (TimeoutInterrupt, "TIMEOUT"), + ], +) def test_run_emits_events_if_interrupted(run, exception, status): observer = run.observers[0] run.main_function.side_effect = exception @@ -76,8 +82,8 @@ def test_run_emits_events_if_interrupted(run, exception, status): assert not observer.completed_event.called assert observer.interrupted_event.called observer.interrupted_event.assert_called_with( - interrupt_time=run.stop_time, - status=status) + interrupt_time=run.stop_time, status=status + ) assert not observer.failed_event.called @@ -97,13 +103,13 @@ def test_run_started_event(run): observer = run.observers[0] run() observer.started_event.assert_called_with( - command='main_func', + command="main_func", ex_info=run.experiment_info, host_info=run.host_info, start_time=run.start_time, config=run.config, meta_info={}, - _id=None + _id=None, ) @@ -111,30 +117,30 @@ def test_run_completed_event(run): observer = run.observers[0] run() observer.completed_event.assert_called_with( - stop_time=run.stop_time, - result=run.result + stop_time=run.stop_time, result=run.result ) def test_run_heartbeat_event(run): observer = run.observers[0] - run.info['test'] = 321 + run.info["test"] = 321 run() call_args, call_kwargs = 
observer.heartbeat_event.call_args_list[0] - assert call_kwargs['info'] == run.info - assert call_kwargs['captured_out'] == "" - assert (call_kwargs['beat_time'] - datetime.utcnow()).total_seconds() < 1 + assert call_kwargs["info"] == run.info + assert call_kwargs["captured_out"] == "" + assert (call_kwargs["beat_time"] - datetime.utcnow()).total_seconds() < 1 def test_run_artifact_event(run): observer = run.observers[0] handle, f_name = tempfile.mkstemp() - name = 'foobar' - metadata = {'testkey': 42} - content_type = 'text/plain' + name = "foobar" + metadata = {"testkey": 42} + content_type = "text/plain" run.add_artifact(f_name, name=name, metadata=metadata, content_type=content_type) - observer.artifact_event.assert_called_with(filename=f_name, name=name, - metadata=metadata, content_type=content_type) + observer.artifact_event.assert_called_with( + filename=f_name, name=name, metadata=metadata, content_type=content_type + ) os.close(handle) os.remove(f_name) @@ -226,7 +232,7 @@ def print_mock_progress(): run.capture_mode = "no" with capsys.disabled(): run() - assert run.captured_out == '' + assert run.captured_out == "" def test_stdout_capturing_sys(run, capsys): @@ -239,12 +245,12 @@ def print_mock_progress(): run.capture_mode = "sys" with capsys.disabled(): run() - assert run.captured_out == '0123456789' + assert run.captured_out == "0123456789" # @pytest.mark.skipif(sys.platform.startswith('win'), # reason="does not work on windows") -@pytest.mark.skip('Breaks randomly on test server') +@pytest.mark.skip("Breaks randomly on test server") def test_stdout_capturing_fd(run, capsys): def print_mock_progress(): for i in range(10): @@ -255,15 +261,15 @@ def print_mock_progress(): run.capture_mode = "fd" with capsys.disabled(): run() - assert run.captured_out == '0123456789' + assert run.captured_out == "0123456789" def test_captured_out_filter(run, capsys): def print_mock_progress(): - sys.stdout.write('progress 0') + sys.stdout.write("progress 0") 
sys.stdout.flush() for i in range(10): - sys.stdout.write('\b') + sys.stdout.write("\b") sys.stdout.write(str(i)) sys.stdout.flush() @@ -273,4 +279,4 @@ def print_mock_progress(): with capsys.disabled(): run() sys.stdout.flush() - assert run.captured_out == 'progress 9' + assert run.captured_out == "progress 9" diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 2cd323e9..d4ca68bd 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -7,30 +7,36 @@ import sacred.optional as opt -@pytest.mark.parametrize('obj', [ - 12, - 3.14, - "mystring", - "αβγδ", - [1, 2., "3", [4]], - {'foo': 'bar', 'answer': 42}, - None, - True -]) +@pytest.mark.parametrize( + "obj", + [ + 12, + 3.14, + "mystring", + "αβγδ", + [1, 2.0, "3", [4]], + {"foo": "bar", "answer": 42}, + None, + True, + ], +) def test_flatten_on_json_is_noop(obj): assert flatten(obj) == obj -@pytest.mark.parametrize('obj', [ - 12, - 3.14, - "mystring", - "αβγδ", - [1, 2., "3", [4]], - {'foo': 'bar', 'answer': 42}, - None, - True -]) +@pytest.mark.parametrize( + "obj", + [ + 12, + 3.14, + "mystring", + "αβγδ", + [1, 2.0, "3", [4]], + {"foo": "bar", "answer": 42}, + None, + True, + ], +) def test_restore_on_json_is_noop(obj): assert flatten(obj) == obj @@ -41,10 +47,27 @@ def test_serialize_non_str_keys(): @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") -@pytest.mark.parametrize('typename', [ - 'bool_', 'int_', 'intc', 'intp', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'float_', 'float16', 'float32', - 'float64']) +@pytest.mark.parametrize( + "typename", + [ + "bool_", + "int_", + "intc", + "intp", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float_", + "float16", + "float32", + "float64", + ], +) def test_flatten_normalizes_numpy_scalars(typename): dtype = getattr(opt.np, typename) a = 1 @@ -63,7 +86,7 @@ def test_serialize_numpy_arrays(): def test_serialize_tuples(): - t = (1, 'two') + 
t = (1, "two") assert restore(flatten(t)) == t assert isinstance(restore(flatten(t)), tuple) @@ -71,7 +94,7 @@ def test_serialize_tuples(): @pytest.mark.skipif(not opt.has_pandas, reason="requires pandas") def test_serialize_pandas_dataframes(): pd, np = opt.pandas, opt.np - df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=list('ABCD')) + df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=list("ABCD")) b = restore(flatten(df)) assert np.all(df == b) assert np.all(df.dtypes == b.dtypes) @@ -81,5 +104,3 @@ def test_serialize_pandas_dataframes(): # from datetime import datetime # t = datetime.utcnow() # assert restore(flatten(t)) == t - - diff --git a/tests/test_stdout_capturing.py b/tests/test_stdout_capturing.py index 7d6ef4d1..697eff3f 100644 --- a/tests/test_stdout_capturing.py +++ b/tests/test_stdout_capturing.py @@ -9,51 +9,49 @@ def test_python_tee_output(capsys): - expected_lines = { - "captured stdout", - "captured stderr"} + expected_lines = {"captured stdout", "captured stderr"} capture_mode, capture_stdout = get_stdcapturer("sys") with capsys.disabled(): - print('before (stdout)') - print('before (stderr)') + print("before (stdout)") + print("before (stderr)") with capture_stdout() as out: print("captured stdout") print("captured stderr") output = out.get() - print('after (stdout)') - print('after (stderr)') + print("after (stdout)") + print("after (stderr)") assert set(output.strip().split("\n")) == expected_lines -@pytest.mark.skipif(sys.platform.startswith('win'), - reason="does not run on windows") +@pytest.mark.skipif(sys.platform.startswith("win"), reason="does not run on windows") def test_fd_tee_output(capsys): expected_lines = { "captured stdout", "captured stderr", "stdout from C", - "and this is from echo"} + "and this is from echo", + } capture_mode, capture_stdout = get_stdcapturer("fd") output = "" with capsys.disabled(): - print('before (stdout)') - print('before (stderr)') + print("before (stdout)") + print("before (stderr)") with 
capture_stdout() as out: print("captured stdout") print("captured stderr", file=sys.stderr) output += out.get() - libc.puts(b'stdout from C') + libc.puts(b"stdout from C") libc.fflush(None) - os.system('echo and this is from echo') + os.system("echo and this is from echo") output += out.get() output += out.get() - print('after (stdout)') - print('after (stderr)') + print("after (stdout)") + print("after (stderr)") assert set(output.strip().split("\n")) == expected_lines diff --git a/tests/test_stflow/test_internal.py b/tests/test_stflow/test_internal.py index d83734a6..a3abc818 100644 --- a/tests/test_stflow/test_internal.py +++ b/tests/test_stflow/test_internal.py @@ -5,7 +5,8 @@ def test_context_method_decorator(): """ Ensure that ContextMethodDecorator can intercept method calls. """ - class FooClass(): + + class FooClass: def __init__(self, x): self.x = x @@ -15,8 +16,7 @@ def do_foo(self, y, z): print(z) return y * self.x + z - def decorate_three_times(instance, original_method, original_args, - original_kwargs): + def decorate_three_times(instance, original_method, original_args, original_kwargs): print("three_times") print(original_args) print(original_kwargs) @@ -31,14 +31,16 @@ def decorate_three_times(instance, original_method, original_args, assert foo.do_foo(5, z=6) == (5 * 10 + 6) assert foo.do_foo(y=5, z=6) == (5 * 10 + 6) - def decorate_three_times_with_exception(instance, original_method, - original_args, original_kwargs): + def decorate_three_times_with_exception( + instance, original_method, original_args, original_kwargs + ): raise RuntimeError("This should be caught") exception = False try: - with ContextMethodDecorator(FooClass, "do_foo", - decorate_three_times_with_exception): + with ContextMethodDecorator( + FooClass, "do_foo", decorate_three_times_with_exception + ): foo = FooClass(10) this_should_raise_exception = foo.do_foo(5, 6) except RuntimeError: diff --git a/tests/test_utils.py b/tests/test_utils.py index 98400d74..c3bb05df 100755 --- 
a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,102 +3,124 @@ import pytest -from sacred.utils import (PATHCHANGE, convert_to_nested_dict, - get_by_dotted_path, is_prefix, - iter_path_splits, iter_prefixes, iterate_flattened, - iterate_flattened_separately, join_paths, - recursive_update, set_by_dotted_path, get_inheritors, - convert_camel_case_to_snake_case, - apply_backspaces_and_linefeeds, module_exists, - module_is_imported, module_is_in_cache, - get_package_version, parse_version, rel_path) +from sacred.utils import ( + PATHCHANGE, + convert_to_nested_dict, + get_by_dotted_path, + is_prefix, + iter_path_splits, + iter_prefixes, + iterate_flattened, + iterate_flattened_separately, + join_paths, + recursive_update, + set_by_dotted_path, + get_inheritors, + convert_camel_case_to_snake_case, + apply_backspaces_and_linefeeds, + module_exists, + module_is_imported, + module_is_in_cache, + get_package_version, + parse_version, + rel_path, +) def test_recursive_update(): - d = {'a': {'b': 1}} - res = recursive_update(d, {'c': 2, 'a': {'d': 3}}) + d = {"a": {"b": 1}} + res = recursive_update(d, {"c": 2, "a": {"d": 3}}) assert d is res - assert res == {'a': {'b': 1, 'd': 3}, 'c': 2} + assert res == {"a": {"b": 1, "d": 3}, "c": 2} def test_iterate_flattened_separately(): - d = {'a1': 1, - 'b2': {'bar': 'foo', 'foo': 'bar'}, - 'c1': 'f', - 'd1': [1, 2, 3], - 'e2': {}} - res = list(iterate_flattened_separately(d, ['foo', 'bar'])) - assert res == [('a1', 1), ('c1', 'f'), ('d1', [1, 2, 3]), ('e2', {}), - ('b2', PATHCHANGE), ('b2.foo', 'bar'), ('b2.bar', 'foo')] + d = { + "a1": 1, + "b2": {"bar": "foo", "foo": "bar"}, + "c1": "f", + "d1": [1, 2, 3], + "e2": {}, + } + res = list(iterate_flattened_separately(d, ["foo", "bar"])) + assert res == [ + ("a1", 1), + ("c1", "f"), + ("d1", [1, 2, 3]), + ("e2", {}), + ("b2", PATHCHANGE), + ("b2.foo", "bar"), + ("b2.bar", "foo"), + ] def test_iterate_flattened(): - d = {'a': {'aa': 1, 'ab': {'aba': 8}}, 'b': 3} - assert 
list(iterate_flattened(d)) == \ - [('a.aa', 1), ('a.ab.aba', 8), ('b', 3)] + d = {"a": {"aa": 1, "ab": {"aba": 8}}, "b": 3} + assert list(iterate_flattened(d)) == [("a.aa", 1), ("a.ab.aba", 8), ("b", 3)] def test_set_by_dotted_path(): - d = {'foo': {'bar': 7}} - set_by_dotted_path(d, 'foo.bar', 10) - assert d == {'foo': {'bar': 10}} + d = {"foo": {"bar": 7}} + set_by_dotted_path(d, "foo.bar", 10) + assert d == {"foo": {"bar": 10}} def test_set_by_dotted_path_creates_missing_dicts(): - d = {'foo': {'bar': 7}} - set_by_dotted_path(d, 'foo.d.baz', 3) - assert d == {'foo': {'bar': 7, 'd': {'baz': 3}}} + d = {"foo": {"bar": 7}} + set_by_dotted_path(d, "foo.d.baz", 3) + assert d == {"foo": {"bar": 7, "d": {"baz": 3}}} def test_get_by_dotted_path(): - assert get_by_dotted_path({'a': 12}, 'a') == 12 - assert get_by_dotted_path({'a': 12}, '') == {'a': 12} - assert get_by_dotted_path({'foo': {'a': 12}}, 'foo.a') == 12 - assert get_by_dotted_path({'foo': {'a': 12}}, 'foo.b') is None + assert get_by_dotted_path({"a": 12}, "a") == 12 + assert get_by_dotted_path({"a": 12}, "") == {"a": 12} + assert get_by_dotted_path({"foo": {"a": 12}}, "foo.a") == 12 + assert get_by_dotted_path({"foo": {"a": 12}}, "foo.b") is None def test_iter_path_splits(): - assert list(iter_path_splits('foo.bar.baz')) ==\ - [('', 'foo.bar.baz'), - ('foo', 'bar.baz'), - ('foo.bar', 'baz')] + assert list(iter_path_splits("foo.bar.baz")) == [ + ("", "foo.bar.baz"), + ("foo", "bar.baz"), + ("foo.bar", "baz"), + ] def test_iter_prefixes(): - assert list(iter_prefixes('foo.bar.baz')) == \ - ['foo', 'foo.bar', 'foo.bar.baz'] + assert list(iter_prefixes("foo.bar.baz")) == ["foo", "foo.bar", "foo.bar.baz"] def test_join_paths(): - assert join_paths() == '' - assert join_paths('foo') == 'foo' - assert join_paths('foo', 'bar') == 'foo.bar' - assert join_paths('a', 'b', 'c', 'd') == 'a.b.c.d' - assert join_paths('', 'b', '', 'd') == 'b.d' - assert join_paths('a.b', 'c.d.e') == 'a.b.c.d.e' - assert join_paths('a.b.', 
'c.d.e') == 'a.b.c.d.e' + assert join_paths() == "" + assert join_paths("foo") == "foo" + assert join_paths("foo", "bar") == "foo.bar" + assert join_paths("a", "b", "c", "d") == "a.b.c.d" + assert join_paths("", "b", "", "d") == "b.d" + assert join_paths("a.b", "c.d.e") == "a.b.c.d.e" + assert join_paths("a.b.", "c.d.e") == "a.b.c.d.e" def test_is_prefix(): - assert is_prefix('', 'foo') - assert is_prefix('foo', 'foo.bar') - assert is_prefix('foo.bar', 'foo.bar.baz') + assert is_prefix("", "foo") + assert is_prefix("foo", "foo.bar") + assert is_prefix("foo.bar", "foo.bar.baz") - assert not is_prefix('a', 'foo.bar') - assert not is_prefix('a.bar', 'foo.bar') - assert not is_prefix('foo.b', 'foo.bar') - assert not is_prefix('foo.bar', 'foo.bar') + assert not is_prefix("a", "foo.bar") + assert not is_prefix("a.bar", "foo.bar") + assert not is_prefix("foo.b", "foo.bar") + assert not is_prefix("foo.bar", "foo.bar") def test_convert_to_nested_dict(): - dotted_dict = {'foo.bar': 8, 'foo.baz': 7} - assert convert_to_nested_dict(dotted_dict) == {'foo': {'bar': 8, 'baz': 7}} + dotted_dict = {"foo.bar": 8, "foo.baz": 7} + assert convert_to_nested_dict(dotted_dict) == {"foo": {"bar": 8, "baz": 7}} def test_convert_to_nested_dict_nested(): - dotted_dict = {'a.b': {'foo.bar': 8}, 'a.b.foo.baz': 7} - assert convert_to_nested_dict(dotted_dict) == \ - {'a': {'b': {'foo': {'bar': 8, 'baz': 7}}}} + dotted_dict = {"a.b": {"foo.bar": 8}, "a.b.foo.baz": 7} + assert convert_to_nested_dict(dotted_dict) == { + "a": {"b": {"foo": {"bar": 8, "baz": 7}}} + } def test_get_inheritors(): @@ -120,84 +142,90 @@ class E: assert get_inheritors(A) == {B, C, D} -@pytest.mark.parametrize('name,expected', [ - ('CamelCase', 'camel_case'), - ('snake_case', 'snake_case'), - ('CamelCamelCase', 'camel_camel_case'), - ('Camel2Camel2Case', 'camel2_camel2_case'), - ('getHTTPResponseCode', 'get_http_response_code'), - ('get2HTTPResponseCode', 'get2_http_response_code'), - ('HTTPResponseCode', 
'http_response_code'), - ('HTTPResponseCodeXYZ', 'http_response_code_xyz') -]) +@pytest.mark.parametrize( + "name,expected", + [ + ("CamelCase", "camel_case"), + ("snake_case", "snake_case"), + ("CamelCamelCase", "camel_camel_case"), + ("Camel2Camel2Case", "camel2_camel2_case"), + ("getHTTPResponseCode", "get_http_response_code"), + ("get2HTTPResponseCode", "get2_http_response_code"), + ("HTTPResponseCode", "http_response_code"), + ("HTTPResponseCodeXYZ", "http_response_code_xyz"), + ], +) def test_convert_camel_case_to_snake_case(name, expected): assert convert_camel_case_to_snake_case(name) == expected -@pytest.mark.parametrize('text,expected', [ - ('', ''), - ('\b', ''), - ('\r', '\r'), - ('\r\n', '\n'), - ('ab\bc', 'ac'), - ('\ba', 'a'), - ('ab\nc\b\bd', 'ab\nd'), - ('abc\rdef', 'def'), - ('abc\r', 'abc\r'), - ('abc\rd', 'dbc'), - ('abc\r\nd', 'abc\nd'), - ('abc\ndef\rg', 'abc\ngef'), - ('abc\ndef\r\rg', 'abc\ngef'), - ('abcd\refg\r', 'efgd\r'), - ('abcd\refg\r\n', 'efgd\n') -]) +@pytest.mark.parametrize( + "text,expected", + [ + ("", ""), + ("\b", ""), + ("\r", "\r"), + ("\r\n", "\n"), + ("ab\bc", "ac"), + ("\ba", "a"), + ("ab\nc\b\bd", "ab\nd"), + ("abc\rdef", "def"), + ("abc\r", "abc\r"), + ("abc\rd", "dbc"), + ("abc\r\nd", "abc\nd"), + ("abc\ndef\rg", "abc\ngef"), + ("abc\ndef\r\rg", "abc\ngef"), + ("abcd\refg\r", "efgd\r"), + ("abcd\refg\r\n", "efgd\n"), + ], +) def test_apply_backspaces_and_linefeeds(text, expected): assert apply_backspaces_and_linefeeds(text) == expected def test_module_exists_base_level_modules(): - assert module_exists('pytest') - assert not module_exists('clearly_non_existing_module_name') + assert module_exists("pytest") + assert not module_exists("clearly_non_existing_module_name") def test_module_exists_does_not_import_module(): - assert module_exists('tests.donotimport') + assert module_exists("tests.donotimport") def test_module_is_in_cache(): - assert module_is_in_cache('pytest') - assert module_is_in_cache('pkgutil') - assert 
not module_is_in_cache('does_not_even_exist') + assert module_is_in_cache("pytest") + assert module_is_in_cache("pkgutil") + assert not module_is_in_cache("does_not_even_exist") def test_module_is_imported(): globs = globals() - assert module_is_imported('pytest', scope=globs) - assert not module_is_imported('pkgutil', scope=globs) - assert not module_is_imported('does_not_even_exist', scope=globs) + assert module_is_imported("pytest", scope=globs) + assert not module_is_imported("pkgutil", scope=globs) + assert not module_is_imported("does_not_even_exist", scope=globs) def test_module_is_imported_uses_caller_globals_by_default(): - assert module_is_imported('pytest') - assert not module_is_imported('pkgutil') - assert not module_is_imported('does_not_even_exist') + assert module_is_imported("pytest") + assert not module_is_imported("pkgutil") + assert not module_is_imported("does_not_even_exist") def test_get_package_version(): - package_version = get_package_version('pytest') - assert str(package_version) == '4.3.0' + package_version = get_package_version("pytest") + assert str(package_version) == "4.3.0" def test_parse_version(): - parsed_version = parse_version('4.3.0') - assert str(parsed_version) == '4.3.0' + parsed_version = parse_version("4.3.0") + assert str(parsed_version) == "4.3.0" def test_get_package_version_comparison(): - package_version = get_package_version('pytest') - current_version = parse_version('4.3.0') - old_version = parse_version('4.2.1') - new_version = parse_version('4.4.1') + package_version = get_package_version("pytest") + current_version = parse_version("4.3.0") + old_version = parse_version("4.2.1") + new_version = parse_version("4.4.1") assert package_version == current_version assert not package_version < current_version assert not package_version > current_version @@ -206,8 +234,8 @@ def test_get_package_version_comparison(): def test_rel_path(): - assert rel_path('', 'foo.bar.baz') == 'foo.bar.baz' - assert rel_path('foo', 
'foo.bar.baz') == 'bar.baz' - assert rel_path('foo.bar', 'foo.bar.baz') == 'baz' - assert rel_path('foo.bar.baz', 'foo.bar.baz') == '' - assert rel_path('', '') == '' + assert rel_path("", "foo.bar.baz") == "foo.bar.baz" + assert rel_path("foo", "foo.bar.baz") == "bar.baz" + assert rel_path("foo.bar", "foo.bar.baz") == "baz" + assert rel_path("foo.bar.baz", "foo.bar.baz") == "" + assert rel_path("", "") == "" diff --git a/tox.ini b/tox.ini index 844f4104..e665ab52 100644 --- a/tox.ini +++ b/tox.ini @@ -72,6 +72,12 @@ deps = commands = flake8 --max-complexity 10 sacred +[testenv:black] +basepython = python +deps = git+https://github.com/psf/black +commands = + black --check sacred/ tests/ + [testenv:coverage] passenv = TRAVIS TRAVIS_* basepython = python