From 080b771eeba1410210b8f515dab588ecc0dfc87a Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 19 Jul 2019 11:21:26 +0200 Subject: [PATCH 1/6] Refactoring of tensorflow. --- sacred/optional.py | 14 ++++++++++++++ sacred/randomness.py | 10 +--------- sacred/stflow/method_interception.py | 13 +++---------- tests/test_stflow/test_method_interception.py | 9 ++------- 4 files changed, 20 insertions(+), 26 deletions(-) diff --git a/sacred/optional.py b/sacred/optional.py index 14fabb0c..477e1e0f 100644 --- a/sacred/optional.py +++ b/sacred/optional.py @@ -3,6 +3,7 @@ import importlib from sacred.utils import modules_exist +from sacred.utils import get_package_version, parse_version def optional_import(*package_names): @@ -13,6 +14,19 @@ def optional_import(*package_names): return False, None +def get_tensorflow(): + # Ensures backward and forward compatibility with TensorFlow 1 and 2. + if get_package_version('tensorflow') < parse_version('1.13.1'): + import warnings + warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " + "Use Tensorflow 1.13 or newer instead.", + DeprecationWarning) + import tensorflow as tf + else: + import tensorflow.compat.v1 as tf + return tf + + # Get libc in a cross-platform way and use it to also flush the c stdio buffers # credit to J.F. Sebastians SO answer from here: # http://stackoverflow.com/a/22434262/1388435 diff --git a/sacred/randomness.py b/sacred/randomness.py index b40cb48a..26870947 100644 --- a/sacred/randomness.py +++ b/sacred/randomness.py @@ -30,15 +30,7 @@ def set_global_seed(seed): if opt.has_numpy: opt.np.random.seed(seed) if module_is_in_cache('tensorflow'): - # Ensures backward and forward compatibility with TensorFlow 1 and 2. - if get_package_version('tensorflow') < parse_version('1.13.1'): - import warnings - warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " - "Use Tensorflow 1.13 or newer instead.", - DeprecationWarning) - import tensorflow as tf - else: - import tensorflow.compat.v1 as tf + tf = opt.get_tensorflow() tf.set_random_seed(seed) if module_is_in_cache('torch'): import torch diff --git a/sacred/stflow/method_interception.py b/sacred/stflow/method_interception.py index aafee383..641e7e47 100644 --- a/sacred/stflow/method_interception.py +++ b/sacred/stflow/method_interception.py @@ -1,17 +1,10 @@ from .contextlibbackport import ContextDecorator from .internal import ContextMethodDecorator import sacred.optional as opt -from sacred.utils import get_package_version, parse_version + + if opt.has_tensorflow: - # Ensures backward and forward compatibility with TensorFlow 1 and 2. - if get_package_version('tensorflow') < parse_version('1.13.1'): - import warnings - warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " - "Use Tensorflow 1.13 or newer instead.", - DeprecationWarning) - import tensorflow as tf - else: - import tensorflow.compat.v1 as tf + tf = opt.get_tensorflow() else: tf = None diff --git a/tests/test_stflow/test_method_interception.py b/tests/test_stflow/test_method_interception.py index c701c034..a74adc5a 100644 --- a/tests/test_stflow/test_method_interception.py +++ b/tests/test_stflow/test_method_interception.py @@ -3,7 +3,7 @@ from sacred import Experiment from sacred.stflow import LogFileWriter -from sacred.utils import get_package_version, parse_version +import sacred.optional as opt @pytest.fixture @@ -19,12 +19,7 @@ def tf(): """ from sacred.optional import has_tensorflow if has_tensorflow: - # Ensures backward and forward compatibility with TensorFlow 1 and 2. - if get_package_version('tensorflow') < parse_version('1.13.1'): - import tensorflow as tf - else: - import tensorflow.compat.v1 as tf - return tf + return opt.get_tensorflow() else: # Let's define a mocked tensorflow class tensorflow: From fa1b0d3a707d2f9f7c125ad3e3f58e6ce80cbe69 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 19 Jul 2019 11:23:10 +0200 Subject: [PATCH 2/6] Removed unused imports. --- sacred/randomness.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sacred/randomness.py b/sacred/randomness.py index 26870947..5fe20d85 100644 --- a/sacred/randomness.py +++ b/sacred/randomness.py @@ -4,8 +4,7 @@ import random import sacred.optional as opt -from sacred.utils import (module_is_in_cache, get_package_version, - parse_version) +from sacred.utils import module_is_in_cache SEEDRANGE = (1, int(1e9)) From 261c517a4832fed7800eb47b879abd2ba0325944 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 19 Jul 2019 12:00:01 +0200 Subject: [PATCH 3/6] Fixed the problem with the tensorflow mock. --- sacred/optional.py | 50 ++++++++++++++---- sacred/stflow/method_interception.py | 7 +-- tests/test_stflow/test_method_interception.py | 51 ++++--------------- 3 files changed, 51 insertions(+), 57 deletions(-) diff --git a/sacred/optional.py b/sacred/optional.py index 477e1e0f..dec066e3 100644 --- a/sacred/optional.py +++ b/sacred/optional.py @@ -6,6 +6,9 @@ from sacred.utils import get_package_version, parse_version +_TENSORFLOW_ALLOW_MOCK = False + + def optional_import(*package_names): try: packages = [importlib.import_module(pn) for pn in package_names] @@ -15,16 +18,43 @@ def optional_import(*package_names): def get_tensorflow(): - # Ensures backward and forward compatibility with TensorFlow 1 and 2. - if get_package_version('tensorflow') < parse_version('1.13.1'): - import warnings - warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " - "Use Tensorflow 1.13 or newer instead.", - DeprecationWarning) - import tensorflow as tf - else: - import tensorflow.compat.v1 as tf - return tf + try: + # Ensures backward and forward compatibility with TensorFlow 1 and 2. + if get_package_version('tensorflow') < parse_version('1.13.1'): + import warnings + warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " + "Use Tensorflow 1.13 or newer instead.", + DeprecationWarning) + import tensorflow as tf + else: + import tensorflow.compat.v1 as tf + return tf + except ImportError: + if not _TENSORFLOW_ALLOW_MOCK: + raise + # Let's define a mocked tensorflow + class tensorflow: + class summary: + class FileWriter: + def __init__(self, logdir, graph): + self.logdir = logdir + self.graph = graph + print("Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph)) + + class Session: + def __init__(self): + self.graph = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + # Set stflow to use the mock as the test + import sacred.stflow.method_interception + sacred.stflow.method_interception.tf = tensorflow + return tensorflow # Get libc in a cross-platform way and use it to also flush the c stdio buffers diff --git a/sacred/stflow/method_interception.py b/sacred/stflow/method_interception.py index 641e7e47..263973b5 100644 --- a/sacred/stflow/method_interception.py +++ b/sacred/stflow/method_interception.py @@ -3,12 +3,6 @@ import sacred.optional as opt -if opt.has_tensorflow: - tf = opt.get_tensorflow() -else: - tf = None - - class LogFileWriter(ContextDecorator, ContextMethodDecorator): """ Intercept ``logdir`` each time a new ``FileWriter`` instance is created. @@ -72,6 +66,7 @@ def log_writer_decorator(instance, original_method, original_args, "logdirs", []).append(logdir) return result + tf = opt.get_tensorflow() ContextMethodDecorator.__init__(self, tf.summary.FileWriter, "__init__", diff --git a/tests/test_stflow/test_method_interception.py b/tests/test_stflow/test_method_interception.py index a74adc5a..d3671bbd 100644 --- a/tests/test_stflow/test_method_interception.py +++ b/tests/test_stflow/test_method_interception.py @@ -6,47 +6,16 @@ import sacred.optional as opt +opt._TENSORFLOW_ALLOW_MOCK = True +tf = opt.get_tensorflow() + + @pytest.fixture def ex(): return Experiment('tensorflow_tests') -@pytest.fixture() -def tf(): - """ - Creates a simplified tensorflow interface if necessary, - so `tensorflow` is not required during the tests. - """ - from sacred.optional import has_tensorflow - if has_tensorflow: - return opt.get_tensorflow() - else: - # Let's define a mocked tensorflow - class tensorflow: - class summary: - class FileWriter: - def __init__(self, logdir, graph): - self.logdir = logdir - self.graph = graph - print("Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph)) - - class Session: - def __init__(self): - self.graph = None - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - # Set stflow to use the mock as the test - import sacred.stflow.method_interception - sacred.stflow.method_interception.tf = tensorflow - return tensorflow - - -def test_log_file_writer(ex, tf): +def test_log_file_writer(ex): """ Tests whether logdir is stored into the info dictionary when creating a new FileWriter object. """ @@ -68,7 +37,7 @@ def run_experiment(_run): ex.run() -def test_log_summary_writer_as_context_manager(ex, tf): +def test_log_summary_writer_as_context_manager(ex): """ Check that Tensorflow log directory is captured by LogFileWriter context manager. """ @@ -99,7 +68,7 @@ def run_experiment(_run): ex.run() -def test_log_file_writer_as_context_manager_with_exception(ex, tf): +def test_log_file_writer_as_context_manager_with_exception(ex): """ Check that Tensorflow log directory is captured by LogFileWriter context manager. """ @@ -125,7 +94,7 @@ def run_experiment(_run): ex.run() -def test_log_summary_writer_class(ex, tf): +def test_log_summary_writer_class(ex): """ Tests whether logdir is stored into the info dictionary when creating a new FileWriter object, but this time on a method of a class. @@ -163,5 +132,5 @@ def run_experiment(_run): ex.run() -if __name__ == "__main__": - test_log_file_writer(ex(), tf()) +if __name__ == '__main__': + pytest.main([__file__]) From 9a34fca4130bcd6c9d834fd5addd0cc91574eede Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 19 Jul 2019 12:03:27 +0200 Subject: [PATCH 4/6] No need for attribute assignment on module. --- sacred/optional.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sacred/optional.py b/sacred/optional.py index dec066e3..aae99226 100644 --- a/sacred/optional.py +++ b/sacred/optional.py @@ -51,9 +51,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - # Set stflow to use the mock as the test - import sacred.stflow.method_interception - sacred.stflow.method_interception.tf = tensorflow return tensorflow From 6e388daf202452c9c9e988d325fb4f5cede2f9be Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 19 Jul 2019 12:18:27 +0200 Subject: [PATCH 5/6] Reverted some changes. --- sacred/optional.py | 47 ++++------------- sacred/stflow/method_interception.py | 7 ++- tests/test_stflow/test_method_interception.py | 51 +++++++++++++++---- 3 files changed, 57 insertions(+), 48 deletions(-) diff --git a/sacred/optional.py b/sacred/optional.py index aae99226..477e1e0f 100644 --- a/sacred/optional.py +++ b/sacred/optional.py @@ -6,9 +6,6 @@ from sacred.utils import get_package_version, parse_version -_TENSORFLOW_ALLOW_MOCK = False - - def optional_import(*package_names): try: packages = [importlib.import_module(pn) for pn in package_names] @@ -18,40 +15,16 @@ def optional_import(*package_names): def get_tensorflow(): - try: - # Ensures backward and forward compatibility with TensorFlow 1 and 2. - if get_package_version('tensorflow') < parse_version('1.13.1'): - import warnings - warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " - "Use Tensorflow 1.13 or newer instead.", - DeprecationWarning) - import tensorflow as tf - else: - import tensorflow.compat.v1 as tf - return tf - except ImportError: - if not _TENSORFLOW_ALLOW_MOCK: - raise - # Let's define a mocked tensorflow - class tensorflow: - class summary: - class FileWriter: - def __init__(self, logdir, graph): - self.logdir = logdir - self.graph = graph - print("Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph)) - - class Session: - def __init__(self): - self.graph = None - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - return tensorflow + # Ensures backward and forward compatibility with TensorFlow 1 and 2. + if get_package_version('tensorflow') < parse_version('1.13.1'): + import warnings + warnings.warn("Use of TensorFlow 1.12 and older is deprecated. " + "Use Tensorflow 1.13 or newer instead.", + DeprecationWarning) + import tensorflow as tf + else: + import tensorflow.compat.v1 as tf + return tf # Get libc in a cross-platform way and use it to also flush the c stdio buffers diff --git a/sacred/stflow/method_interception.py b/sacred/stflow/method_interception.py index 263973b5..9b48586b 100644 --- a/sacred/stflow/method_interception.py +++ b/sacred/stflow/method_interception.py @@ -3,6 +3,12 @@ import sacred.optional as opt +if opt.has_tensorflow: + opt.get_tensorflow() +else: + tf = None + + class LogFileWriter(ContextDecorator, ContextMethodDecorator): """ Intercept ``logdir`` each time a new ``FileWriter`` instance is created. @@ -66,7 +72,6 @@ def log_writer_decorator(instance, original_method, original_args, "logdirs", []).append(logdir) return result - tf = opt.get_tensorflow() ContextMethodDecorator.__init__(self, tf.summary.FileWriter, "__init__", diff --git a/tests/test_stflow/test_method_interception.py b/tests/test_stflow/test_method_interception.py index d3671bbd..a74adc5a 100644 --- a/tests/test_stflow/test_method_interception.py +++ b/tests/test_stflow/test_method_interception.py @@ -6,16 +6,47 @@ import sacred.optional as opt -opt._TENSORFLOW_ALLOW_MOCK = True -tf = opt.get_tensorflow() - - @pytest.fixture def ex(): return Experiment('tensorflow_tests') -def test_log_file_writer(ex): +@pytest.fixture() +def tf(): + """ + Creates a simplified tensorflow interface if necessary, + so `tensorflow` is not required during the tests. + """ + from sacred.optional import has_tensorflow + if has_tensorflow: + return opt.get_tensorflow() + else: + # Let's define a mocked tensorflow + class tensorflow: + class summary: + class FileWriter: + def __init__(self, logdir, graph): + self.logdir = logdir + self.graph = graph + print("Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph)) + + class Session: + def __init__(self): + self.graph = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + # Set stflow to use the mock as the test + import sacred.stflow.method_interception + sacred.stflow.method_interception.tf = tensorflow + return tensorflow + + +def test_log_file_writer(ex, tf): """ Tests whether logdir is stored into the info dictionary when creating a new FileWriter object. """ @@ -37,7 +68,7 @@ def run_experiment(_run): ex.run() -def test_log_summary_writer_as_context_manager(ex): +def test_log_summary_writer_as_context_manager(ex, tf): """ Check that Tensorflow log directory is captured by LogFileWriter context manager. """ @@ -68,7 +99,7 @@ def run_experiment(_run): ex.run() -def test_log_file_writer_as_context_manager_with_exception(ex): +def test_log_file_writer_as_context_manager_with_exception(ex, tf): """ Check that Tensorflow log directory is captured by LogFileWriter context manager. """ @@ -94,7 +125,7 @@ def run_experiment(_run): ex.run() -def test_log_summary_writer_class(ex): +def test_log_summary_writer_class(ex, tf): """ Tests whether logdir is stored into the info dictionary when creating a new FileWriter object, but this time on a method of a class. @@ -132,5 +163,5 @@ def run_experiment(_run): ex.run() -if __name__ == '__main__': - pytest.main([__file__]) +if __name__ == "__main__": + test_log_file_writer(ex(), tf()) From 9f9b76fea12e6e0bbefecc03873a25f20e2acba9 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 19 Jul 2019 12:21:24 +0200 Subject: [PATCH 6/6] Forgot simple variable assignement. --- sacred/stflow/method_interception.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sacred/stflow/method_interception.py b/sacred/stflow/method_interception.py index 9b48586b..641e7e47 100644 --- a/sacred/stflow/method_interception.py +++ b/sacred/stflow/method_interception.py @@ -4,7 +4,7 @@ if opt.has_tensorflow: - opt.get_tensorflow() + tf = opt.get_tensorflow() else: tf = None