Skip to content

Commit

Permalink
Merge 9f9b76f into 37e6fc3
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Jul 19, 2019
2 parents 37e6fc3 + 9f9b76f commit 61d118a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 28 deletions.
14 changes: 14 additions & 0 deletions sacred/optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import importlib
from sacred.utils import modules_exist
from sacred.utils import get_package_version, parse_version


def optional_import(*package_names):
Expand All @@ -13,6 +14,19 @@ def optional_import(*package_names):
return False, None


def get_tensorflow():
# Ensures backward and forward compatibility with TensorFlow 1 and 2.
if get_package_version('tensorflow') < parse_version('1.13.1'):
import warnings
warnings.warn("Use of TensorFlow 1.12 and older is deprecated. "
"Use Tensorflow 1.13 or newer instead.",
DeprecationWarning)
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
return tf


# Get libc in a cross-platform way and use it to also flush the c stdio buffers
# credit to J.F. Sebastians SO answer from here:
# http://stackoverflow.com/a/22434262/1388435
Expand Down
13 changes: 2 additions & 11 deletions sacred/randomness.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import random

import sacred.optional as opt
from sacred.utils import (module_is_in_cache, get_package_version,
parse_version)
from sacred.utils import module_is_in_cache

SEEDRANGE = (1, int(1e9))

Expand All @@ -30,15 +29,7 @@ def set_global_seed(seed):
if opt.has_numpy:
opt.np.random.seed(seed)
if module_is_in_cache('tensorflow'):
# Ensures backward and forward compatibility with TensorFlow 1 and 2.
if get_package_version('tensorflow') < parse_version('1.13.1'):
import warnings
warnings.warn("Use of TensorFlow 1.12 and older is deprecated. "
"Use Tensorflow 1.13 or newer instead.",
DeprecationWarning)
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
tf = opt.get_tensorflow()
tf.set_random_seed(seed)
if module_is_in_cache('torch'):
import torch
Expand Down
13 changes: 3 additions & 10 deletions sacred/stflow/method_interception.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
from .contextlibbackport import ContextDecorator
from .internal import ContextMethodDecorator
import sacred.optional as opt
from sacred.utils import get_package_version, parse_version


if opt.has_tensorflow:
# Ensures backward and forward compatibility with TensorFlow 1 and 2.
if get_package_version('tensorflow') < parse_version('1.13.1'):
import warnings
warnings.warn("Use of TensorFlow 1.12 and older is deprecated. "
"Use Tensorflow 1.13 or newer instead.",
DeprecationWarning)
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
tf = opt.get_tensorflow()
else:
tf = None

Expand Down
9 changes: 2 additions & 7 deletions tests/test_stflow/test_method_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from sacred import Experiment
from sacred.stflow import LogFileWriter
from sacred.utils import get_package_version, parse_version
import sacred.optional as opt


@pytest.fixture
Expand All @@ -19,12 +19,7 @@ def tf():
"""
from sacred.optional import has_tensorflow
if has_tensorflow:
# Ensures backward and forward compatibility with TensorFlow 1 and 2.
if get_package_version('tensorflow') < parse_version('1.13.1'):
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
return tf
return opt.get_tensorflow()
else:
# Let's define a mocked tensorflow
class tensorflow:
Expand Down

0 comments on commit 61d118a

Please sign in to comment.