Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Support masking
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed Nov 26, 2018
1 parent 7fe0485 commit e704840
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 8 deletions.
34 changes: 27 additions & 7 deletions keras_targeted_dropout/targeted_dropout.py
Expand Up @@ -34,13 +34,33 @@ def compute_mask(self, inputs, mask=None):
def compute_output_shape(self, input_shape):
return input_shape

def _compute_target_mask(self, inputs):
def _compute_target_mask(self, inputs, mask=None):
input_shape = K.shape(inputs)
input_type = K.dtype(inputs)
mask_threshold = K.constant(1e8, dtype=input_type)
channel_num = input_shape[-1]
channel_dim = K.prod(input_shape[:-1])
norm = K.abs(inputs)
channeled_norm = K.transpose(K.reshape(norm, (channel_dim, input_shape[-1])))
idx = K.cast(self.target_rate * K.cast(channel_dim, K.floatx()), 'int32')
threshold = -K.tf.nn.top_k(-channeled_norm, k=idx).values[:, -1]
masked_inputs = inputs
if mask is not None:
masked_inputs = K.switch(
K.cast(mask, K.floatx()) > 0.5,
masked_inputs,
K.ones_like(masked_inputs, dtype=input_type) * mask_threshold
)
norm = K.abs(masked_inputs)
channeled_norm = K.transpose(K.reshape(norm, (channel_dim, channel_num)))
weight_num = K.sum(
K.reshape(K.cast(masked_inputs < mask_threshold, K.floatx()), (channel_dim, channel_num)),
axis=0,
)
indices = K.stack(
[
K.arange(channel_num, dtype='int32'),
K.cast(self.target_rate * weight_num, dtype='int32') - 1,
],
axis=-1,
)
threshold = -K.tf.gather_nd(K.tf.nn.top_k(-channeled_norm, k=K.max(indices[:, 1]) + 1).values, indices)
threshold = K.reshape(K.tile(threshold, [channel_dim]), input_shape)
target_mask = K.switch(
norm <= threshold,
Expand All @@ -49,8 +69,8 @@ def _compute_target_mask(self, inputs):
)
return target_mask

def call(self, inputs, training=None):
target_mask = self._compute_target_mask(inputs)
def call(self, inputs, mask=None, training=None):
target_mask = self._compute_target_mask(inputs, mask=mask)

def dropped_mask():
drop_mask = K.switch(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -2,7 +2,7 @@

setup(
name='keras-targeted-dropout',
version='0.2.0',
version='0.3.0',
packages=find_packages(),
url='https://github.com/CyberZHG/keras-targeted-dropout',
license='MIT',
Expand Down
36 changes: 36 additions & 0 deletions tests/test_targeted_dropout.py
Expand Up @@ -75,3 +75,39 @@ def test_drop_rate(self):
zero_num = np.sum((outputs == 0.0).astype(keras.backend.floatx()))
actual_rate = zero_num / 100000.0
self.assertTrue(0.15 < actual_rate < 0.17)

def test_masked(self):
model = keras.models.Sequential()
model.add(keras.layers.Masking(
batch_size=None,
input_shape=(None, None),
))
model.add(TargetedDropout(
drop_rate=0.4,
target_rate=0.6,
))
model.compile(optimizer='adam', loss='mse')
model_path = os.path.join(tempfile.gettempdir(), 'keras_targeted_dropout_%f.h5' % random.random())
model.save(model_path)
model = keras.models.load_model(
model_path,
custom_objects={'TargetedDropout': TargetedDropout},
)
model.summary()

inputs = np.array([
[[1, 5, 2, 3]],
[[0, 0, 0, 0]],
[[5, 2, 5, 1]],
[[2, 9, 4, 3]],
[[4, 7, 5, 6]],
])
outputs = model.predict(inputs)
expected = np.array([
[[0., 0., 0., 0.]],
[[0., 0., 0., 0.]],
[[5., 0., 5., 0.]],
[[0., 9., 0., 0.]],
[[4., 7., 5., 6.]],
])
self.assertTrue(np.allclose(expected, outputs), (expected, outputs))

0 comments on commit e704840

Please sign in to comment.