Skip to content

Commit

Permalink
Add multisample dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Tsimfer authored and Sergey Tsimfer committed Sep 2, 2019
1 parent 8504fb1 commit 0fb2423
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions batchflow/models/tf/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,54 @@ def __call__(self, inputs):


class Dropout(Layer):
""" Wrapper for dropout layer. """
def __init__(self, dropout_rate, **kwargs):
""" Wrapper for dropout layer.
Parameters
----------
dropout_prob : float
Fraction of the input units to drop.
multisample: bool, number, sequence
If evaluates to True, then batch is split into multiple parts,
dropout applied to each of them separately and then parts are concatenated back.
If True, then batch is split evenly into two parts.
If integer, then batch is split evenly into that number of parts; must be a divisor of batch size.
If float, then batch is split into parts of `multisample` and `1 - multisample` sizes.
If sequence of ints, then batch is split into parts of given sizes. Must sum up to the batch size.
If sequence of floats, then each float means proportion of sizes in batch and must sum up to 1.
"""
def __init__(self, dropout_rate, multisample=False, **kwargs):
self.dropout_rate = dropout_rate
self.multisample = multisample
self.kwargs = kwargs

def __call__(self, inputs, training):
return K.Dropout(rate=self.dropout_rate, **self.kwargs)(inputs, training)
d_layer = K.Dropout(rate=self.dropout_rate, **self.kwargs)

if self.multisample != False:

This comment has been minimized.

Copy link
@roman-kh

roman-kh Sep 2, 2019

Member

is not

if self.multisample == True:
self.multisample = 2
elif isinstance(self.multisample, float):
self.multisample = [self.multisample, 1 - self.multisample]

if isinstance(self.multisample, int):
sizes = self.multisample
elif isinstance(self.multisample, (tuple, list)):
if all([isinstance(item, int) for item in self.multisample]):
sizes = self.multisample
elif all([isinstance(item, float) for item in self.multisample]):
batch_size = tf.cast(tf.shape(inputs)[0], dtype=tf.float32)
sizes = tf.convert_to_tensor([batch_size*item for item in self.multisample])
sizes = tf.cast(tf.math.round(sizes), dtype=tf.int32)
elif isinstance(self.multisample, tf.Tensor):
sizes = self.multisample
print(sizes)
splitted = tf.split(inputs, sizes, axis=0, name='mdropout_split')
dropped = [d_layer(branch, training) for branch in splitted]
output = tf.concat(dropped, axis=0, name='mdropout_concat')
else:
output = d_layer(inputs, training)
return output



Expand Down

0 comments on commit 0fb2423

Please sign in to comment.