From 59e0a3969a3799c2eab3ed06b6f540622132f960 Mon Sep 17 00:00:00 2001 From: CyberZHG Date: Wed, 29 May 2019 22:46:20 +0800 Subject: [PATCH] #13 Fix eager mode --- .travis.yml | 1 + keras_transformer/backend.py | 8 +++++++- keras_transformer/transformer.py | 11 +++++++++-- requirements.txt | 2 +- setup.py | 2 +- tests/test_decode.py | 8 +++++--- tests/test_gelu.py | 5 ++++- 7 files changed, 28 insertions(+), 9 deletions(-) diff --git a/.travis.yml b/.travis.yml index 104440f..faa4be6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,6 +6,7 @@ python: env: - KERAS_BACKEND=tensorflow - KERAS_BACKEND=tensorflow TF_KERAS=1 + - KERAS_BACKEND=tensorflow TF_KERAS=1 TF_EAGER=1 - KERAS_BACKEND=theano THEANO_FLAGS=optimizer=fast_compile # - KERAS_BACKEND=cntk PYTHONWARNINGS=ignore install: diff --git a/keras_transformer/backend.py b/keras_transformer/backend.py index f279ccd..a16b960 100644 --- a/keras_transformer/backend.py +++ b/keras_transformer/backend.py @@ -3,11 +3,17 @@ __all__ = [ 'keras', 'utils', 'activations', 'applications', 'backend', 'datasets', 'engine', 'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers', - 'metrics', 'models', 'losses', 'optimizers', 'regularizers', + 'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'EAGER_MODE' ] +EAGER_MODE = False + if 'TF_KERAS' in os.environ and os.environ['TF_KERAS'] != '0': from tensorflow.python import keras + if 'TF_EAGER' in os.environ and os.environ['TF_EAGER'] != '0': + import tensorflow as tf + tf.enable_eager_execution() + EAGER_MODE = True else: import keras diff --git a/keras_transformer/transformer.py b/keras_transformer/transformer.py index 8fe3ccb..2fd1049 100644 --- a/keras_transformer/transformer.py +++ b/keras_transformer/transformer.py @@ -415,7 +415,14 @@ def _get_max_suffix_repeat_times(tokens, max_len): return max_repeat -def decode(model, tokens, start_token, end_token, pad_token, max_len=10000, max_repeat=10, max_repeat_block=10): +def decode(model, + tokens, + start_token, + end_token, + pad_token, + max_len=10000, + max_repeat=10, + max_repeat_block=10): """Decode with the given model and input tokens. :param model: The trained model. @@ -448,7 +455,7 @@ def decode(model, tokens, start_token, end_token, pad_token, max_len=10000, max_ max_input_len = max(max_input_len, len(tokens[i])) for i in range(len(batch_inputs)): batch_inputs[i] += [pad_token] * (max_input_len - len(batch_inputs[i])) - predicts = model.predict([np.asarray(batch_inputs), np.asarray(batch_outputs)]) + predicts = model.predict([np.array(batch_inputs), np.array(batch_outputs)]) for i in range(len(predicts)): last_token = np.argmax(predicts[i][-1]) decoder_inputs[index_map[i]].append(last_token) diff --git a/requirements.txt b/requirements.txt index 6581e39..5bae7cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ keras-pos-embd==0.10.0 keras-multi-head==0.20.0 keras-layer-normalization==0.12.0 keras-position-wise-feed-forward==0.5.0 -keras-embed-sim==0.4.0 +keras-embed-sim==0.5.0 diff --git a/setup.py b/setup.py index 001fc33..151f766 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( name='keras-transformer', - version='0.24.0', + version='0.25.0', packages=find_packages(), url='https://github.com/CyberZHG/keras-transformer', license='MIT', diff --git a/tests/test_decode.py b/tests/test_decode.py index ce1ea35..da0842e 100644 --- a/tests/test_decode.py +++ b/tests/test_decode.py @@ -1,7 +1,7 @@ import os import unittest import numpy as np -from keras_transformer.backend import keras +from keras_transformer.backend import keras, EAGER_MODE from keras_transformer import get_custom_objects, get_model, decode @@ -57,8 +57,10 @@ def test_decode(self): epochs=10, batch_size=128, ) - model.save(model_path) - model = keras.models.load_model(model_path, custom_objects=get_custom_objects()) + if not EAGER_MODE: + model.save(model_path) + if not EAGER_MODE: + model = keras.models.load_model(model_path, custom_objects=get_custom_objects()) decoded = decode( model, encoder_inputs_no_padding * 2, diff --git a/tests/test_gelu.py b/tests/test_gelu.py index 9777571..b6a7830 100644 --- a/tests/test_gelu.py +++ b/tests/test_gelu.py @@ -9,7 +9,10 @@ def test_sample(self): try: results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval(session=K.get_session()) except Exception as e: - results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval() + try: + results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval() + except Exception as e: + results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).numpy() self.assertEqual(0.0, results[0]) self.assertGreater(0.0, results[1]) self.assertLess(-1.0, results[1])