Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
#13 Fix eager mode
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed May 29, 2019
1 parent c67f172 commit 59e0a39
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 9 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python:
env:
- KERAS_BACKEND=tensorflow
- KERAS_BACKEND=tensorflow TF_KERAS=1
- KERAS_BACKEND=tensorflow TF_KERAS=1 TF_EAGER=1
- KERAS_BACKEND=theano THEANO_FLAGS=optimizer=fast_compile
# - KERAS_BACKEND=cntk PYTHONWARNINGS=ignore
install:
Expand Down
8 changes: 7 additions & 1 deletion keras_transformer/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
__all__ = [
'keras', 'utils', 'activations', 'applications', 'backend', 'datasets', 'engine',
'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers',
'metrics', 'models', 'losses', 'optimizers', 'regularizers',
'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'EAGER_MODE'
]

EAGER_MODE = False

if 'TF_KERAS' in os.environ and os.environ['TF_KERAS'] != '0':
from tensorflow.python import keras
if 'TF_EAGER' in os.environ and os.environ['TF_EAGER'] != '0':
import tensorflow as tf
tf.enable_eager_execution()
EAGER_MODE = True
else:
import keras

Expand Down
11 changes: 9 additions & 2 deletions keras_transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,14 @@ def _get_max_suffix_repeat_times(tokens, max_len):
return max_repeat


def decode(model, tokens, start_token, end_token, pad_token, max_len=10000, max_repeat=10, max_repeat_block=10):
def decode(model,
tokens,
start_token,
end_token,
pad_token,
max_len=10000,
max_repeat=10,
max_repeat_block=10):
"""Decode with the given model and input tokens.
:param model: The trained model.
Expand Down Expand Up @@ -448,7 +455,7 @@ def decode(model, tokens, start_token, end_token, pad_token, max_len=10000, max_
max_input_len = max(max_input_len, len(tokens[i]))
for i in range(len(batch_inputs)):
batch_inputs[i] += [pad_token] * (max_input_len - len(batch_inputs[i]))
predicts = model.predict([np.asarray(batch_inputs), np.asarray(batch_outputs)])
predicts = model.predict([np.array(batch_inputs), np.array(batch_outputs)])
for i in range(len(predicts)):
last_token = np.argmax(predicts[i][-1])
decoder_inputs[index_map[i]].append(last_token)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ keras-pos-embd==0.10.0
keras-multi-head==0.20.0
keras-layer-normalization==0.12.0
keras-position-wise-feed-forward==0.5.0
keras-embed-sim==0.4.0
keras-embed-sim==0.5.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name='keras-transformer',
version='0.24.0',
version='0.25.0',
packages=find_packages(),
url='https://github.com/CyberZHG/keras-transformer',
license='MIT',
Expand Down
8 changes: 5 additions & 3 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import unittest
import numpy as np
from keras_transformer.backend import keras
from keras_transformer.backend import keras, EAGER_MODE
from keras_transformer import get_custom_objects, get_model, decode


Expand Down Expand Up @@ -57,8 +57,10 @@ def test_decode(self):
epochs=10,
batch_size=128,
)
model.save(model_path)
model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
if not EAGER_MODE:
model.save(model_path)
if not EAGER_MODE:
model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
decoded = decode(
model,
encoder_inputs_no_padding * 2,
Expand Down
5 changes: 4 additions & 1 deletion tests/test_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ def test_sample(self):
try:
results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval(session=K.get_session())
except Exception as e:
results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval()
try:
results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).eval()
except Exception as e:
results = gelu(K.constant([-30.0, -1.0, 0.0, 1.0, 30.0])).numpy()
self.assertEqual(0.0, results[0])
self.assertGreater(0.0, results[1])
self.assertLess(-1.0, results[1])
Expand Down

0 comments on commit 59e0a39

Please sign in to comment.