Skip to content

Commit

Permalink
Upgrade supported Tensorflow version to 2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed May 8, 2020
1 parent 9875d97 commit 94cd726
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 10 deletions.
5 changes: 1 addition & 4 deletions docs/examples/keras_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ def get_loss(self, y_target, pred=None):


if __name__ == "__main__":
# IMPORTANT: Enable eager execution
tf.compat.v1.enable_eager_execution()

tf.random.set_random_seed(42) # Fix random seed
tf.random.set_seed(42) # Fix random seed

# Load data
X, y = load_iris(True)
Expand Down
2 changes: 1 addition & 1 deletion docs/tut_tfkeras.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Since keras is a higher-lever interface for tensorflow and nowadays part of tens
Computing a counterfactual of a tensorflow/keras model is done by using the :func:`ceml.tfkeras.counterfactual.generate_counterfactual` function.

.. note::
We have to run in *eager execution mode* when computing a counterfactual!
We have to run in *eager execution mode* when computing a counterfactual! Since tensorflow 2, eager execution is enabled by default.

We must provide the tensorflow/keras model within a class that is derived from the :class:`ceml.model.model.ModelWithLoss` class.
In this class, we must overwrite the `predict` function and `get_loss` function which returns a loss that we want to use - a couple of differentiable loss functions are implemented in :class:`ceml.backend.tensorflow.costfunctions`.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ jaxlib==0.1.21
cvxpy==1.0.24
scikit-learn==0.22.2
sklearn-lvq==1.1.0
tensorflow==1.15.2
tensorflow==2.2.0
torch==1.5.0
3 changes: 1 addition & 2 deletions tests/tfkeras/test_tfkeras_linearregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
sys.path.insert(0,'..')

import tensorflow as tf
tf.compat.v1.enable_eager_execution()
tf.random.set_random_seed(42)
tf.random.set_seed(42)

import numpy as np
np.random.seed(42)
Expand Down
3 changes: 1 addition & 2 deletions tests/tfkeras/test_tfkeras_softmaxregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
sys.path.insert(0,'..')

import tensorflow as tf
tf.compat.v1.enable_eager_execution()
tf.compat.v1.random.set_random_seed(42)
tf.random.set_seed(42)

import numpy as np
np.random.seed(42)
Expand Down

0 comments on commit 94cd726

Please sign in to comment.