Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Fix init location
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed May 17, 2020
1 parent 40df4f2 commit 0edc5e3
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 67 deletions.
26 changes: 20 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
dist: xenial
language: python
python:
- 2.7
- 3.6
python: "3.6"
env:
global:
- COVERALLS_PARALLEL=true
matrix:
- KERAS_BACKEND=tensorflow
- KERAS_BACKEND=tensorflow TF_KERAS=1
install:
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
- source activate test-environment
- export LD_LIBRARY_PATH=$HOME/miniconda/envs/test-environment/lib/:$LD_LIBRARY_PATH
- pip install --upgrade pip
- pip install -r requirements.txt
- pip install -r requirements-dev.txt
- pip install coveralls
before_script:
- bash lint.sh
script:
- bash test.sh
- ./test.sh
after_success:
coveralls
notifications:
webhooks: https://coveralls.io/webhook
2 changes: 2 additions & 0 deletions keras_drop_block/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .drop_block import DropBlock1D, DropBlock2D

__version__ = '0.5.0'
33 changes: 33 additions & 0 deletions keras_drop_block/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
from distutils.util import strtobool

__all__ = [
'keras', 'utils', 'activations', 'applications', 'backend', 'datasets',
'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers',
'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'TF_KERAS',
]

TF_KERAS = strtobool(os.environ.get('TF_KERAS', '0'))

if TF_KERAS:
import tensorflow as tf
keras = tf.keras
else:
import keras

utils = keras.utils
activations = keras.activations
applications = keras.applications
backend = keras.backend
datasets = keras.datasets
layers = keras.layers
preprocessing = keras.preprocessing
wrappers = keras.wrappers
callbacks = keras.callbacks
constraints = keras.constraints
initializers = keras.initializers
metrics = keras.metrics
models = keras.models
losses = keras.losses
optimizers = keras.optimizers
regularizers = keras.regularizers
72 changes: 44 additions & 28 deletions keras_drop_block/drop_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import keras
import keras.backend as K
from .backend import keras
from .backend import backend as K


class DropBlock1D(keras.layers.Layer):
Expand All @@ -23,9 +23,18 @@ def __init__(self,
self.block_size = block_size
self.keep_prob = keep_prob
self.sync_channels = sync_channels
self.data_format = K.normalize_data_format(data_format)
self.input_spec = keras.engine.base_layer.InputSpec(ndim=3)
self.data_format = data_format
self.supports_masking = True
self.seq_len = self.ones = self.zeros = None

def build(self, input_shape):
if self.data_format == 'channels_first':
self.seq_len = input_shape[-1]
else:
self.seq_len = input_shape[1]
self.ones = K.ones(self.seq_len, name='ones')
self.zeros = K.zeros(self.seq_len, name='zeros')
super().build(input_shape)

def get_config(self):
config = {'block_size': self.block_size,
Expand All @@ -41,35 +50,34 @@ def compute_mask(self, inputs, mask=None):
def compute_output_shape(self, input_shape):
return input_shape

def _get_gamma(self, feature_dim):
def _get_gamma(self):
"""Get the number of activation units to drop"""
feature_dim = K.cast(feature_dim, K.floatx())
feature_dim = K.cast(self.seq_len, K.floatx())
block_size = K.constant(self.block_size, dtype=K.floatx())
return ((1.0 - self.keep_prob) / block_size) * (feature_dim / (feature_dim - block_size + 1.0))

def _compute_valid_seed_region(self, seq_length):
positions = K.arange(seq_length)
def _compute_valid_seed_region(self):
positions = K.arange(self.seq_len)
half_block_size = self.block_size // 2
valid_seed_region = K.switch(
K.all(
K.stack(
[
positions >= half_block_size,
positions < seq_length - half_block_size,
positions < self.seq_len - half_block_size,
],
axis=-1,
),
axis=-1,
),
K.ones((seq_length,)),
K.zeros((seq_length,)),
self.ones,
self.zeros,
)
return K.expand_dims(K.expand_dims(valid_seed_region, axis=0), axis=-1)

def _compute_drop_mask(self, shape):
seq_length = shape[1]
mask = K.random_binomial(shape, p=self._get_gamma(seq_length))
mask *= self._compute_valid_seed_region(seq_length)
mask = K.random_binomial(shape, p=self._get_gamma())
mask *= self._compute_valid_seed_region()
mask = keras.layers.MaxPool1D(
pool_size=self.block_size,
padding='same',
Expand Down Expand Up @@ -119,9 +127,18 @@ def __init__(self,
self.block_size = block_size
self.keep_prob = keep_prob
self.sync_channels = sync_channels
self.data_format = K.normalize_data_format(data_format)
self.input_spec = keras.engine.base_layer.InputSpec(ndim=4)
self.data_format = data_format
self.supports_masking = True
self.height = self.width = self.ones = self.zeros = None

def build(self, input_shape):
if self.data_format == 'channels_first':
self.height, self.width = input_shape[2], input_shape[3]
else:
self.height, self.width = input_shape[1], input_shape[2]
self.ones = K.ones((self.height, self.width), name='ones')
self.zeros = K.zeros((self.height, self.width), name='zeros')
super().build(input_shape)

def get_config(self):
config = {'block_size': self.block_size,
Expand All @@ -137,17 +154,17 @@ def compute_mask(self, inputs, mask=None):
def compute_output_shape(self, input_shape):
return input_shape

def _get_gamma(self, height, width):
def _get_gamma(self):
"""Get the number of activation units to drop"""
height, width = K.cast(height, K.floatx()), K.cast(width, K.floatx())
height, width = K.cast(self.height, K.floatx()), K.cast(self.width, K.floatx())
block_size = K.constant(self.block_size, dtype=K.floatx())
return ((1.0 - self.keep_prob) / (block_size ** 2)) *\
(height * width / ((height - block_size + 1.0) * (width - block_size + 1.0)))

def _compute_valid_seed_region(self, height, width):
def _compute_valid_seed_region(self):
positions = K.concatenate([
K.expand_dims(K.tile(K.expand_dims(K.arange(height), axis=1), [1, width]), axis=-1),
K.expand_dims(K.tile(K.expand_dims(K.arange(width), axis=0), [height, 1]), axis=-1),
K.expand_dims(K.tile(K.expand_dims(K.arange(self.height), axis=1), [1, self.width]), axis=-1),
K.expand_dims(K.tile(K.expand_dims(K.arange(self.width), axis=0), [self.height, 1]), axis=-1),
], axis=-1)
half_block_size = self.block_size // 2
valid_seed_region = K.switch(
Expand All @@ -156,22 +173,21 @@ def _compute_valid_seed_region(self, height, width):
[
positions[:, :, 0] >= half_block_size,
positions[:, :, 1] >= half_block_size,
positions[:, :, 0] < height - half_block_size,
positions[:, :, 1] < width - half_block_size,
positions[:, :, 0] < self.height - half_block_size,
positions[:, :, 1] < self.width - half_block_size,
],
axis=-1,
),
axis=-1,
),
K.ones((height, width)),
K.zeros((height, width)),
self.ones,
self.zeros,
)
return K.expand_dims(K.expand_dims(valid_seed_region, axis=0), axis=-1)

def _compute_drop_mask(self, shape):
height, width = shape[1], shape[2]
mask = K.random_binomial(shape, p=self._get_gamma(height, width))
mask *= self._compute_valid_seed_region(height, width)
mask = K.random_binomial(shape, p=self._get_gamma())
mask *= self._compute_valid_seed_region()
mask = keras.layers.MaxPool2D(
pool_size=(self.block_size, self.block_size),
padding='same',
Expand Down
2 changes: 0 additions & 2 deletions lint.sh

This file was deleted.

4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
setuptools>=38.6.0
twine>=1.11.0
wheel>=0.31.0
nose
tensorflow
pycodestyle
coverage
33 changes: 23 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,43 @@
import os
import re
import codecs
from setuptools import setup, find_packages

current_path = os.path.abspath(os.path.dirname(__file__))

with codecs.open('README.md', 'r', 'utf8') as reader:
long_description = reader.read()

def read_file(*parts):
with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
return reader.read()

with codecs.open('requirements.txt', 'r', 'utf8') as reader:
install_requires = list(map(lambda x: x.strip(), reader.readlines()))

def get_requirements(*parts):
with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
return list(map(lambda x: x.strip(), reader.readlines()))


def find_version(*file_paths):
version_file = read_file(*file_paths)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError('Unable to find version string.')


setup(
name='keras-drop-block',
version='0.4.0',
version=find_version('keras_drop_block', '__init__.py'),
packages=find_packages(),
url='https://github.com/CyberZHG/keras-drop-block',
license='MIT',
author='CyberZHG',
author_email='CyberZHG@gmail.com',
author_email='CyberZHG@users.noreply.github.com',
description='DropBlock implemented in Keras',
long_description=long_description,
long_description=read_file('README.md'),
long_description_content_type='text/markdown',
install_requires=install_requires,
install_requires=get_requirements('requirements.txt'),
classifiers=(
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
),
Expand Down
3 changes: 2 additions & 1 deletion test.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
#!/usr/bin/env bash
nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-package="keras_drop_block" tests
pycodestyle --max-line-length=120 keras_drop_block tests && \
nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-package=keras_drop_block tests
17 changes: 7 additions & 10 deletions tests/test_drop_block_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import random
import tempfile
import unittest
import keras

import numpy as np

from keras_drop_block.backend import keras
from keras_drop_block import DropBlock1D


Expand Down Expand Up @@ -42,9 +44,7 @@ def test_training(self):

def test_mask_shape(self):
input_layer = keras.layers.Input(shape=(100, 3))
drop_block_layer = keras.layers.Lambda(
lambda x: DropBlock1D(block_size=3, keep_prob=0.7)(x, training=True),
)(input_layer)
drop_block_layer = DropBlock1D(block_size=3, keep_prob=0.7)(input_layer, training=True)
model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
model.compile(optimizer='adam', loss='mse', metrics={})
model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
Expand All @@ -65,9 +65,8 @@ def test_mask_shape(self):
self.assertTrue(0.65 < keep_prob < 0.8, keep_prob)

input_layer = keras.layers.Input(shape=(3, 100))
drop_block_layer = keras.layers.Lambda(
lambda x: DropBlock1D(block_size=3, keep_prob=0.7, data_format='channels_first')(x, training=True),
)(input_layer)
drop_block_layer = DropBlock1D(block_size=3, keep_prob=0.7,
data_format='channels_first')(input_layer, training=True)
model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
model.compile(optimizer='adam', loss='mse', metrics={})
model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
Expand All @@ -89,9 +88,7 @@ def test_mask_shape(self):

def test_sync_channels(self):
input_layer = keras.layers.Input(shape=(100, 3))
drop_block_layer = keras.layers.Lambda(
lambda x: DropBlock1D(block_size=3, keep_prob=0.7, sync_channels=True)(x, training=True),
)(input_layer)
drop_block_layer = DropBlock1D(block_size=3, keep_prob=0.7, sync_channels=True)(input_layer, training=True)
model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
model.compile(optimizer='adam', loss='mse', metrics={})
model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
Expand Down
17 changes: 7 additions & 10 deletions tests/test_drop_block_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import random
import tempfile
import unittest
import keras

import numpy as np

from keras_drop_block.backend import keras
from keras_drop_block import DropBlock2D


Expand Down Expand Up @@ -42,9 +44,7 @@ def test_training(self):

def test_mask_shape(self):
input_layer = keras.layers.Input(shape=(10, 10, 3))
drop_block_layer = keras.layers.Lambda(
lambda x: DropBlock2D(block_size=3, keep_prob=0.7)(x, training=True),
)(input_layer)
drop_block_layer = DropBlock2D(block_size=3, keep_prob=0.7)(input_layer, training=True)
model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
model.compile(optimizer='adam', loss='mse', metrics={})
model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
Expand All @@ -65,9 +65,8 @@ def test_mask_shape(self):
self.assertTrue(0.65 < keep_prob < 0.8, keep_prob)

input_layer = keras.layers.Input(shape=(3, 10, 10))
drop_block_layer = keras.layers.Lambda(
lambda x: DropBlock2D(block_size=3, keep_prob=0.7, data_format='channels_first')(x, training=True),
)(input_layer)
drop_block_layer = DropBlock2D(block_size=3, keep_prob=0.7,
data_format='channels_first')(input_layer, training=True)
model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
model.compile(optimizer='adam', loss='mse', metrics={})
model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
Expand All @@ -89,9 +88,7 @@ def test_mask_shape(self):

def test_sync_channels(self):
input_layer = keras.layers.Input(shape=(10, 10, 3))
drop_block_layer = keras.layers.Lambda(
lambda x: DropBlock2D(block_size=3, keep_prob=0.7, sync_channels=True)(x, training=True),
)(input_layer)
drop_block_layer = DropBlock2D(block_size=3, keep_prob=0.7, sync_channels=True)(input_layer, training=True)
model = keras.models.Model(inputs=input_layer, outputs=drop_block_layer)
model.compile(optimizer='adam', loss='mse', metrics={})
model_path = os.path.join(tempfile.gettempdir(), 'keras_drop_block_%f.h5' % random.random())
Expand Down

0 comments on commit 0edc5e3

Please sign in to comment.