diff --git a/.travis.yml b/.travis.yml index 9a3b741..006b76c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,12 +1,12 @@ dist: xenial language: python -python: - - 3.6 +python: "3.6" env: - - KERAS_BACKEND=tensorflow TF_KERAS=1 TF_2=1 - - KERAS_BACKEND=tensorflow TF_KERAS=1 TF_EAGER=1 - - KERAS_BACKEND=tensorflow TF_KERAS=1 - - KERAS_BACKEND=tensorflow + global: + - COVERALLS_PARALLEL=true + matrix: + - KERAS_BACKEND=tensorflow + - KERAS_BACKEND=tensorflow TF_KERAS=1 install: - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda @@ -20,9 +20,10 @@ install: - pip install --upgrade pip - pip install -r requirements.txt - pip install -r requirements-dev.txt - - if [[ $TF_2 == "1" ]]; then pip install tensorflow==2.0.0-beta1; fi - pip install coveralls script: - ./test.sh after_success: coveralls +notifications: + webhooks: https://coveralls.io/webhook diff --git a/keras_trans_mask/__init__.py b/keras_trans_mask/__init__.py index eab0121..0e384c7 100644 --- a/keras_trans_mask/__init__.py +++ b/keras_trans_mask/__init__.py @@ -1 +1,3 @@ from .masks import * + +__version__ = '0.4.0' diff --git a/keras_trans_mask/backend.py b/keras_trans_mask/backend.py index dae3a96..90dd5a9 100644 --- a/keras_trans_mask/backend.py +++ b/keras_trans_mask/backend.py @@ -1,25 +1,17 @@ import os +from distutils.util import strtobool __all__ = [ - 'keras', 'utils', 'activations', 'applications', 'backend', 'datasets', 'engine', + 'keras', 'utils', 'activations', 'applications', 'backend', 'datasets', 'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers', - 'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'TF_KERAS', 'EAGER_MODE' + 'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'TF_KERAS', ] -TF_KERAS = False -EAGER_MODE = False +TF_KERAS = strtobool(os.environ.get('TF_KERAS', '0')) -if os.environ.get('TF_KERAS', '0') != '0': +if TF_KERAS: import tensorflow as tf - from tensorflow.python import keras - TF_KERAS = True - if os.environ.get('TF_EAGER', '0') != '0': - try: - tf.enable_eager_execution() - raise AttributeError() - except AttributeError as e: - pass - EAGER_MODE = tf.executing_eagerly() + keras = tf.keras else: import keras @@ -28,7 +20,6 @@ applications = keras.applications backend = keras.backend datasets = keras.datasets -engine = keras.engine layers = keras.layers preprocessing = keras.preprocessing wrappers = keras.wrappers diff --git a/keras_trans_mask/masks.py b/keras_trans_mask/masks.py index 899c17f..dbe514d 100644 --- a/keras_trans_mask/masks.py +++ b/keras_trans_mask/masks.py @@ -56,7 +56,7 @@ def compute_mask(self, inputs, mask=None): return None def call(self, inputs, **kwargs): - return K.identity(inputs) + return inputs + 0.0 class RestoreMask(keras.layers.Layer): @@ -81,4 +81,4 @@ def compute_mask(self, inputs, mask=None): return mask[1] def call(self, inputs, **kwargs): - return K.identity(inputs[0]) + return inputs[0] + 0.0 diff --git a/setup.py b/setup.py index 2a64861..13989d1 100644 --- a/setup.py +++ b/setup.py @@ -1,30 +1,43 @@ +import os +import re import codecs from setuptools import setup, find_packages +current_path = os.path.abspath(os.path.dirname(__file__)) -with codecs.open('README.md', 'r', 'utf8') as reader: - long_description = reader.read() +def read_file(*parts): + with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: + return reader.read() -with codecs.open('requirements.txt', 'r', 'utf8') as reader: - install_requires = list(map(lambda x: x.strip(), reader.readlines())) + +def get_requirements(*parts): + with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: + return list(map(lambda x: x.strip(), reader.readlines())) + + +def find_version(*file_paths): + version_file = read_file(*file_paths) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) + if version_match: + return version_match.group(1) + raise RuntimeError('Unable to find version string.') setup( name='keras-trans-mask', - version='0.3.0', + version=find_version('keras_trans_mask', '__init__.py'), packages=find_packages(), url='https://github.com/CyberZHG/keras-trans-mask', license='MIT', author='CyberZHG', author_email='CyberZHG@users.noreply.github.com', description='Transfer masking in Keras', - long_description=long_description, + long_description=read_file('README.md'), long_description_content_type='text/markdown', - install_requires=install_requires, + install_requires=get_requirements('requirements.txt'), classifiers=( - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ), diff --git a/tests/test_masks.py b/tests/test_masks.py index ce1d281..88da03a 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -1,4 +1,6 @@ from unittest import TestCase +import os +import tempfile import numpy as np @@ -33,4 +35,11 @@ def test_over_fit(self): [6, 7, 8, 9, 9, 9, 9, 9], ] * 1024) y = np.array([[0], [1]] * 1024) + model_path = os.path.join(tempfile.gettempdir(), 'test_trans_mask_%f.h5' % np.random.random()) + model.save(model_path) + model = keras.models.load_model(model_path, custom_objects={ + 'CreateMask': CreateMask, + 'RemoveMask': RemoveMask, + 'RestoreMask': RestoreMask, + }) model.fit(x, y, epochs=10)