Skip to content

Commit

Permalink
Merge 80cf962 into 854bf59
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Jul 22, 2019
2 parents 854bf59 + 80cf962 commit d317b77
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 40 deletions.
43 changes: 34 additions & 9 deletions sacred/optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,42 @@ def optional_import(*package_names):
return False, None


def get_tensorflow():
def get_tensorflow(allow_mock=False):
# Ensures backward and forward compatibility with TensorFlow 1 and 2.
if get_package_version('tensorflow') < parse_version('1.13.1'):
import warnings
warnings.warn("Use of TensorFlow 1.12 and older is deprecated. "
"Use Tensorflow 1.13 or newer instead.",
DeprecationWarning)
import tensorflow as tf
if has_tensorflow:
if get_package_version('tensorflow') < parse_version('1.13.1'):
import warnings
warnings.warn("Use of TensorFlow 1.12 and older is deprecated. "
"Use Tensorflow 1.13 or newer instead.",
DeprecationWarning)
import tensorflow
else:
import tensorflow.compat.v1 as tensorflow
return tensorflow
elif allow_mock:
# Let's define a mocked tensorflow
class tensorflow:
class summary:
class FileWriter:
def __init__(self, logdir, graph):
self.logdir = logdir
self.graph = graph
print(f'Mocked FileWriter got '
f'logdir={logdir}, graph={graph}')

class Session:
def __init__(self):
self.graph = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pass

return tensorflow
else:
import tensorflow.compat.v1 as tf
return tf
return None


# Get libc in a cross-platform way and use it to also flush the c stdio buffers
Expand Down
5 changes: 1 addition & 4 deletions sacred/stflow/method_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import sacred.optional as opt


if opt.has_tensorflow:
tf = opt.get_tensorflow()
else:
tf = None
tf = opt.get_tensorflow()


class LogFileWriter(ContextDecorator, ContextMethodDecorator):
Expand Down
32 changes: 5 additions & 27 deletions tests/test_stflow/test_method_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,11 @@ def tf():
Creates a simplified tensorflow interface if necessary,
so `tensorflow` is not required during the tests.
"""
from sacred.optional import has_tensorflow
if has_tensorflow:
return opt.get_tensorflow()
else:
# Let's define a mocked tensorflow
class tensorflow:
class summary:
class FileWriter:
def __init__(self, logdir, graph):
self.logdir = logdir
self.graph = graph
print("Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph))

class Session:
def __init__(self):
self.graph = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pass

# Set stflow to use the mock as the test
import sacred.stflow.method_interception
sacred.stflow.method_interception.tf = tensorflow
return tensorflow

tensorflow = opt.get_tensorflow(allow_mock=True)
import sacred.stflow.method_interception
sacred.stflow.method_interception.tf = tensorflow
return tensorflow


def test_log_file_writer(ex, tf):
Expand Down

0 comments on commit d317b77

Please sign in to comment.