diff --git a/.travis.yml b/.travis.yml
index 18e523e..006b76c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,15 +1,29 @@
+dist: xenial
 language: python
-python:
-  - 2.7
-  - 3.6
+python: "3.6"
+env:
+  global:
+    - COVERALLS_PARALLEL=true
+  matrix:
+    - KERAS_BACKEND=tensorflow
+    - KERAS_BACKEND=tensorflow TF_KERAS=1
 install:
+  - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
+  - bash miniconda.sh -b -p $HOME/miniconda
+  - export PATH="$HOME/miniconda/bin:$PATH"
+  - conda config --set always_yes yes --set changeps1 no
+  - conda update -q conda
+  - conda info -a
+  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
+  - source activate test-environment
+  - export LD_LIBRARY_PATH=$HOME/miniconda/envs/test-environment/lib/:$LD_LIBRARY_PATH
   - pip install --upgrade pip
   - pip install -r requirements.txt
   - pip install -r requirements-dev.txt
   - pip install coveralls
-before_script:
-  - bash lint.sh
 script:
-  - bash test.sh
+  - ./test.sh
 after_success: coveralls
+notifications:
+  webhooks: https://coveralls.io/webhook
diff --git a/keras_drop_block/__init__.py b/keras_drop_block/__init__.py
index 8654ccc..456a902 100644
--- a/keras_drop_block/__init__.py
+++ b/keras_drop_block/__init__.py
@@ -1 +1,3 @@
 from .drop_block import DropBlock1D, DropBlock2D
+
+__version__ = '0.5.0'
diff --git a/keras_drop_block/backend.py b/keras_drop_block/backend.py
new file mode 100644
index 0000000..90dd5a9
--- /dev/null
+++ b/keras_drop_block/backend.py
@@ -0,0 +1,33 @@
+import os
+from distutils.util import strtobool
+
+__all__ = [
+    'keras', 'utils', 'activations', 'applications', 'backend', 'datasets',
+    'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers',
+    'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'TF_KERAS',
+]
+
+TF_KERAS = strtobool(os.environ.get('TF_KERAS', '0'))
+
+if TF_KERAS:
+    import tensorflow as tf
+    keras = tf.keras
+else:
+    import keras
+
+utils = keras.utils
+activations = keras.activations
+applications = keras.applications
+backend = keras.backend
+datasets = keras.datasets
+layers = keras.layers
+preprocessing = keras.preprocessing
+wrappers = keras.wrappers
+callbacks = keras.callbacks
+constraints = keras.constraints
+initializers = keras.initializers
+metrics = keras.metrics
+models = keras.models
+losses = keras.losses
+optimizers = keras.optimizers
+regularizers = keras.regularizers
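Note (not part of the patch): a minimal sketch of how the new backend shim is meant to be used. Setting TF_KERAS=1 before the import makes `keras` resolve to `tf.keras`; leaving it unset (the default '0') keeps standalone Keras. The surrounding model code is illustrative only.

import os
os.environ['TF_KERAS'] = '1'  # must be set before the shim is imported

from keras_drop_block.backend import keras

# The same code path is used regardless of which backend was selected.
model = keras.models.Sequential()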
diff --git a/keras_drop_block/drop_block.py b/keras_drop_block/drop_block.py
index 0fce4bd..3769c83 100644
--- a/keras_drop_block/drop_block.py
+++ b/keras_drop_block/drop_block.py
@@ -1,5 +1,5 @@
-import keras
-import keras.backend as K
+from .backend import keras
+from .backend import backend as K


 class DropBlock1D(keras.layers.Layer):
@@ -23,9 +23,18 @@ def __init__(self,
         self.block_size = block_size
         self.keep_prob = keep_prob
         self.sync_channels = sync_channels
-        self.data_format = K.normalize_data_format(data_format)
-        self.input_spec = keras.engine.base_layer.InputSpec(ndim=3)
+        self.data_format = data_format
         self.supports_masking = True
+        self.seq_len = self.ones = self.zeros = None
+
+    def build(self, input_shape):
+        if self.data_format == 'channels_first':
+            self.seq_len = input_shape[-1]
+        else:
+            self.seq_len = input_shape[1]
+        self.ones = K.ones(self.seq_len, name='ones')
+        self.zeros = K.zeros(self.seq_len, name='zeros')
+        super().build(input_shape)

     def get_config(self):
         config = {'block_size': self.block_size,
@@ -41,35 +50,34 @@ def compute_mask(self, inputs, mask=None):
     def compute_output_shape(self, input_shape):
         return input_shape

-    def _get_gamma(self, feature_dim):
+    def _get_gamma(self):
         """Get the number of activation units to drop"""
-        feature_dim = K.cast(feature_dim, K.floatx())
+        feature_dim = K.cast(self.seq_len, K.floatx())
         block_size = K.constant(self.block_size, dtype=K.floatx())
         return ((1.0 - self.keep_prob) / block_size) * (feature_dim / (feature_dim - block_size + 1.0))

-    def _compute_valid_seed_region(self, seq_length):
-        positions = K.arange(seq_length)
+    def _compute_valid_seed_region(self):
+        positions = K.arange(self.seq_len)
         half_block_size = self.block_size // 2
         valid_seed_region = K.switch(
             K.all(
                 K.stack(
                     [
                         positions >= half_block_size,
-                        positions < seq_length - half_block_size,
+                        positions < self.seq_len - half_block_size,
                     ],
                     axis=-1,
                 ),
                 axis=-1,
             ),
-            K.ones((seq_length,)),
-            K.zeros((seq_length,)),
+            self.ones,
+            self.zeros,
         )
         return K.expand_dims(K.expand_dims(valid_seed_region, axis=0), axis=-1)

     def _compute_drop_mask(self, shape):
-        seq_length = shape[1]
-        mask = K.random_binomial(shape, p=self._get_gamma(seq_length))
-        mask *= self._compute_valid_seed_region(seq_length)
+        mask = K.random_binomial(shape, p=self._get_gamma())
+        mask *= self._compute_valid_seed_region()
         mask = keras.layers.MaxPool1D(
             pool_size=self.block_size,
             padding='same',
@@ -119,9 +127,18 @@ def __init__(self,
         self.block_size = block_size
         self.keep_prob = keep_prob
         self.sync_channels = sync_channels
-        self.data_format = K.normalize_data_format(data_format)
-        self.input_spec = keras.engine.base_layer.InputSpec(ndim=4)
+        self.data_format = data_format
         self.supports_masking = True
+        self.height = self.width = self.ones = self.zeros = None
+
+    def build(self, input_shape):
+        if self.data_format == 'channels_first':
+            self.height, self.width = input_shape[2], input_shape[3]
+        else:
+            self.height, self.width = input_shape[1], input_shape[2]
+        self.ones = K.ones((self.height, self.width), name='ones')
+        self.zeros = K.zeros((self.height, self.width), name='zeros')
+        super().build(input_shape)

     def get_config(self):
         config = {'block_size': self.block_size,
@@ -137,17 +154,17 @@ def compute_mask(self, inputs, mask=None):
     def compute_output_shape(self, input_shape):
         return input_shape

-    def _get_gamma(self, height, width):
+    def _get_gamma(self):
         """Get the number of activation units to drop"""
-        height, width = K.cast(height, K.floatx()), K.cast(width, K.floatx())
+        height, width = K.cast(self.height, K.floatx()), K.cast(self.width, K.floatx())
         block_size = K.constant(self.block_size, dtype=K.floatx())
         return ((1.0 - self.keep_prob) / (block_size ** 2)) *\
             (height * width / ((height - block_size + 1.0) * (width - block_size + 1.0)))

-    def _compute_valid_seed_region(self, height, width):
+    def _compute_valid_seed_region(self):
         positions = K.concatenate([
-            K.expand_dims(K.tile(K.expand_dims(K.arange(height), axis=1), [1, width]), axis=-1),
-            K.expand_dims(K.tile(K.expand_dims(K.arange(width), axis=0), [height, 1]), axis=-1),
+            K.expand_dims(K.tile(K.expand_dims(K.arange(self.height), axis=1), [1, self.width]), axis=-1),
+            K.expand_dims(K.tile(K.expand_dims(K.arange(self.width), axis=0), [self.height, 1]), axis=-1),
         ], axis=-1)
         half_block_size = self.block_size // 2
         valid_seed_region = K.switch(
@@ -156,22 +173,21 @@ def _compute_valid_seed_region(self, height, width):
                     [
                         positions[:, :, 0] >= half_block_size,
                         positions[:, :, 1] >= half_block_size,
-                        positions[:, :, 0] < height - half_block_size,
-                        positions[:, :, 1] < width - half_block_size,
+                        positions[:, :, 0] < self.height - half_block_size,
+                        positions[:, :, 1] < self.width - half_block_size,
                     ],
                     axis=-1,
                 ),
                 axis=-1,
             ),
-            K.ones((height, width)),
-            K.zeros((height, width)),
+            self.ones,
+            self.zeros,
         )
         return K.expand_dims(K.expand_dims(valid_seed_region, axis=0), axis=-1)

     def _compute_drop_mask(self, shape):
-        height, width = shape[1], shape[2]
-        mask = K.random_binomial(shape, p=self._get_gamma(height, width))
-        mask *= self._compute_valid_seed_region(height, width)
+        mask = K.random_binomial(shape, p=self._get_gamma())
+        mask *= self._compute_valid_seed_region()
         mask = keras.layers.MaxPool2D(
             pool_size=(self.block_size, self.block_size),
             padding='same',
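Note (not part of the patch): a short sketch of calling the layer directly, which is what the updated tests below now do (no Lambda wrapper needed). The input shape and the block_size/keep_prob values here are illustrative only.

from keras_drop_block.backend import keras
from keras_drop_block import DropBlock2D

inputs = keras.layers.Input(shape=(28, 28, 3))
x = keras.layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
x = DropBlock2D(block_size=5, keep_prob=0.9)(x)  # drops contiguous blocks, like dropout, in the training phase
outputs = keras.layers.GlobalAveragePooling2D()(x)
model = keras.models.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')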
diff --git a/lint.sh b/lint.sh
deleted file mode 100755
index 08f29a0..0000000
--- a/lint.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-#!/usr/bin/env bash
-pycodestyle --max-line-length=120 keras_drop_block tests demo
diff --git a/requirements-dev.txt b/requirements-dev.txt
index e67855c..14d80e3 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,7 @@
+setuptools>=38.6.0
+twine>=1.11.0
+wheel>=0.31.0
+nose
 tensorflow
 pycodestyle
 coverage
diff --git a/setup.py b/setup.py
index 05c7af1..1d982e8 100644
--- a/setup.py
+++ b/setup.py
@@ -1,30 +1,43 @@
+import os
+import re
 import codecs
 from setuptools import setup, find_packages

+current_path = os.path.abspath(os.path.dirname(__file__))

-with codecs.open('README.md', 'r', 'utf8') as reader:
-    long_description = reader.read()

+def read_file(*parts):
+    with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
+        return reader.read()

-with codecs.open('requirements.txt', 'r', 'utf8') as reader:
-    install_requires = list(map(lambda x: x.strip(), reader.readlines()))
+
+def get_requirements(*parts):
+    with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
+        return list(map(lambda x: x.strip(), reader.readlines()))
+
+
+def find_version(*file_paths):
+    version_file = read_file(*file_paths)
+    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
+    if version_match:
+        return version_match.group(1)
+    raise RuntimeError('Unable to find version string.')


 setup(
     name='keras-drop-block',
-    version='0.4.0',
+    version=find_version('keras_drop_block', '__init__.py'),
     packages=find_packages(),
     url='https://github.com/CyberZHG/keras-drop-block',
     license='MIT',
     author='CyberZHG',
-    author_email='CyberZHG@gmail.com',
+    author_email='CyberZHG@users.noreply.github.com',
     description='DropBlock implemented in Keras',
-    long_description=long_description,
+    long_description=read_file('README.md'),
     long_description_content_type='text/markdown',
-    install_requires=install_requires,
+    install_requires=get_requirements('requirements.txt'),
     classifiers=(
-        "Programming Language :: Python :: 2.7",
-        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ),
diff --git a/test.sh b/test.sh
index d718532..a2614e8 100755
--- a/test.sh
+++ b/test.sh
@@ -1,2 +1,3 @@
 #!/usr/bin/env bash
-nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-package="keras_drop_block" tests
\ No newline at end of file
+pycodestyle --max-line-length=120 keras_drop_block tests && \
+    nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-package=keras_drop_block tests
diff --git a/tests/test_drop_block_1d.py b/tests/test_drop_block_1d.py
index b3563e4..a43a11d 100644
--- a/tests/test_drop_block_1d.py
+++ b/tests/test_drop_block_1d.py
@@ -2,8 +2,10 @@
 import random
 import tempfile
 import unittest
-import keras
+
 import numpy as np
+
+from keras_drop_block.backend import keras
 from keras_drop_block import DropBlock1D
@@ -42,9 +44,7 @@ def test_training(self):

     def test_mask_shape(self):
         input_layer = keras.layers.Input(shape=(100, 3))
-        drop_block_layer = keras.layers.Lambda(
-            lambda x: DropBlock1D(block_size=3, keep_prob=0.7)(x, training=True),
-        )(input_layer)
+        drop_block_layer = DropBlock1D(block_size=3, keep_prob=0.7)(input_layer, training=True)
         model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
         model.compile(optimizer='adam', loss='mse', metrics={})
         model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
@@ -65,9 +65,8 @@ def test_mask_shape(self):
         self.assertTrue(0.65 < keep_prob < 0.8, keep_prob)

         input_layer = keras.layers.Input(shape=(3, 100))
-        drop_block_layer = keras.layers.Lambda(
-            lambda x: DropBlock1D(block_size=3, keep_prob=0.7, data_format='channels_first')(x, training=True),
-        )(input_layer)
+        drop_block_layer = DropBlock1D(block_size=3, keep_prob=0.7,
+                                       data_format='channels_first')(input_layer, training=True)
         model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
         model.compile(optimizer='adam', loss='mse', metrics={})
         model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
@@ -89,9 +88,7 @@ def test_mask_shape(self):

     def test_sync_channels(self):
         input_layer = keras.layers.Input(shape=(100, 3))
-        drop_block_layer = keras.layers.Lambda(
-            lambda x: DropBlock1D(block_size=3, keep_prob=0.7, sync_channels=True)(x, training=True),
-        )(input_layer)
+        drop_block_layer = DropBlock1D(block_size=3, keep_prob=0.7, sync_channels=True)(input_layer, training=True)
         model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
         model.compile(optimizer='adam', loss='mse', metrics={})
         model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
diff --git a/tests/test_drop_block_2d.py b/tests/test_drop_block_2d.py
index 55d38d2..ee3ca5e 100644
--- a/tests/test_drop_block_2d.py
+++ b/tests/test_drop_block_2d.py
@@ -2,8 +2,10 @@
 import random
 import tempfile
 import unittest
-import keras
+
 import numpy as np
+
+from keras_drop_block.backend import keras
 from keras_drop_block import DropBlock2D
@@ -42,9 +44,7 @@ def test_training(self):

     def test_mask_shape(self):
         input_layer = keras.layers.Input(shape=(10, 10, 3))
-        drop_block_layer = keras.layers.Lambda(
-            lambda x: DropBlock2D(block_size=3, keep_prob=0.7)(x, training=True),
-        )(input_layer)
+        drop_block_layer = DropBlock2D(block_size=3, keep_prob=0.7)(input_layer, training=True)
         model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
         model.compile(optimizer='adam', loss='mse', metrics={})
         model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
@@ -65,9 +65,8 @@ def test_mask_shape(self):
         self.assertTrue(0.65 < keep_prob < 0.8, keep_prob)

         input_layer = keras.layers.Input(shape=(3, 10, 10))
-        drop_block_layer = keras.layers.Lambda(
-            lambda x: DropBlock2D(block_size=3, keep_prob=0.7, data_format='channels_first')(x, training=True),
-        )(input_layer)
+        drop_block_layer = DropBlock2D(block_size=3, keep_prob=0.7,
+                                       data_format='channels_first')(input_layer, training=True)
         model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
         model.compile(optimizer='adam', loss='mse', metrics={})
         model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
@@ -89,9 +88,7 @@ def test_mask_shape(self):

     def test_sync_channels(self):
         input_layer = keras.layers.Input(shape=(10, 10, 3))
-        drop_block_layer = keras.layers.Lambda(
-            lambda x: DropBlock2D(block_size=3, keep_prob=0.7, sync_channels=True)(x, training=True),
-        )(input_layer)
+        drop_block_layer = DropBlock2D(block_size=3, keep_prob=0.7, sync_channels=True)(input_layer, training=True)
         model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
         model.compile(optimizer='adam', loss='mse', metrics={})
         model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
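Note (not part of the patch): a quick sanity check of the drop probability ("gamma") that _get_gamma now derives from the shapes cached in build(). The numbers below mirror the 2D test settings (block_size=3, keep_prob=0.7 on a 10x10 feature map) and are illustrative only.

block_size, keep_prob, height, width = 3, 0.7, 10.0, 10.0
gamma = ((1.0 - keep_prob) / block_size ** 2) * \
    (height * width / ((height - block_size + 1.0) * (width - block_size + 1.0)))
print(round(gamma, 4))  # ~0.0521: per-position seed probability before the max-pooling step expands it to blocks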