Skip to content

Commit

Permalink
Merge pull request tensorflow#539 from wwwind:clustering_registry_imp…
Browse files Browse the repository at this point in the history
…rovement

PiperOrigin-RevId: 333336338
  • Loading branch information
tensorflower-gardener committed Sep 23, 2020
2 parents 87c06eb + 33a8030 commit 270435d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 70 deletions.
Expand Up @@ -65,6 +65,7 @@ def setUp(self):
self.keras_unsupported_layer = layers.ConvLSTM2D(2, (5, 5)) # Unsupported
self.custom_clusterable_layer = CustomClusterableLayer(10)
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))

clustering_registry.ClusteringLookupRegistry.register_new_implementation(
{
Expand All @@ -81,10 +82,10 @@ def setUp(self):
cluster_config.CentroidInitialization.DENSITY_BASED
}

def _build_clustered_layer_model(self, layer):
def _build_clustered_layer_model(self, layer, input_shape=(10, 1)):
wrapped_layer = cluster.cluster_weights(layer, **self.params)
self.model.add(wrapped_layer)
self.model.build(input_shape=(10, 1))
self.model.build(input_shape=input_shape)

return wrapped_layer

Expand Down Expand Up @@ -124,13 +125,32 @@ def testClusterKerasNonClusterableLayer(self):
wrapped_layer)
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())

@keras_parameterized.run_all_keras_modes
def testDepthwiseConv2DLayerNonClusterable(self):
"""
Verifies that we don't cluster a DepthwiseConv2D layer,
because clustering of this type of layer gives
big unrecoverable accuracy loss.
"""
wrapped_layer = self._build_clustered_layer_model(
self.keras_depthwiseconv2d_layer,
input_shape=(1, 10, 10, 10)
)

self._validate_clustered_layer(self.keras_depthwiseconv2d_layer,
wrapped_layer)
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())

def testClusterKerasUnsupportedLayer(self):
"""
Verifies that attempting to cluster an unsupported layer raises an
exception.
"""
keras_unsupported_layer = self.keras_unsupported_layer
# We need to build weights before check.
keras_unsupported_layer.build(input_shape = (10, 10))
with self.assertRaises(ValueError):
cluster.cluster_weights(self.keras_unsupported_layer, **self.params)
cluster.cluster_weights(keras_unsupported_layer, **self.params)

@keras_parameterized.run_all_keras_modes
def testClusterCustomClusterableLayer(self):
Expand All @@ -149,8 +169,14 @@ def testClusterCustomNonClusterableLayer(self):
Verifies that attempting to cluster a custom non-clusterable layer raises
an exception.
"""
custom_non_clusterable_layer = self.custom_non_clusterable_layer
# Once layer is empty with no weights allocated, clustering is supported.
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
**self.params)
# We need to build weights before check that clustering is not supported.
custom_non_clusterable_layer.build(input_shape=(10, 10))
with self.assertRaises(ValueError):
cluster_wrapper.ClusterWeights(self.custom_non_clusterable_layer,
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
**self.params)

@keras_parameterized.run_all_keras_modes
Expand Down Expand Up @@ -206,11 +232,14 @@ def testClusterModelUnsupportedKerasLayerRaisesError(self):
Verifies that attempting to cluster a model that contains an unsupported
layer raises an exception.
"""
keras_unsupported_layer = self.keras_unsupported_layer
# We need to build weights before check.
keras_unsupported_layer.build(input_shape = (10, 10))
with self.assertRaises(ValueError):
cluster.cluster_weights(
keras.Sequential([
self.keras_clusterable_layer, self.keras_non_clusterable_layer,
self.custom_clusterable_layer, self.keras_unsupported_layer
self.custom_clusterable_layer, keras_unsupported_layer
]), **self.params)

def testClusterModelCustomNonClusterableLayerRaisesError(self):
Expand All @@ -219,10 +248,13 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
non-clusterable layer raises an exception.
"""
with self.assertRaises(ValueError):
custom_non_clusterable_layer = self.custom_non_clusterable_layer
# We need to build weights before check.
custom_non_clusterable_layer.build(input_shape = (1, 2))
cluster.cluster_weights(
keras.Sequential([
self.keras_clusterable_layer, self.keras_non_clusterable_layer,
self.custom_clusterable_layer, self.custom_non_clusterable_layer
self.custom_clusterable_layer, custom_non_clusterable_layer
]), **self.params)

@keras_parameterized.run_all_keras_modes
Expand Down
Expand Up @@ -32,9 +32,16 @@
CentroidInitialization = cluster_config.CentroidInitialization


class NonClusterableLayer(layers.Dense):
"""A custom layer that is not clusterable."""

class NonClusterableLayer(layers.Layer):
""""A custom layer with weights that is not clusterable."""
def __init__(self, units=10):
super(NonClusterableLayer, self).__init__()
self.add_weight(shape=(1, units),
initializer='uniform',
name='kernel')

def call(self, inputs):
return tf.matmul(inputs, self.weights)

class AlreadyClusterableLayer(layers.Dense, clusterable_layer.ClusterableLayer):
"""A custom layer that is clusterable."""
Expand Down
Expand Up @@ -17,7 +17,6 @@
import abc
import six
import tensorflow as tf

from tensorflow.keras import layers

from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
Expand Down Expand Up @@ -260,69 +259,21 @@ class ClusteringRegistry(object):
# the variables within the layers which hold the kernel weights. This
# allows the wrapper to access and modify the weights.
_LAYERS_WEIGHTS_MAP = {
layers.ELU: [],
layers.LeakyReLU: [],
layers.ReLU: [],
layers.Softmax: [],
layers.ThresholdedReLU: [],
layers.Conv1D: ['kernel'],
layers.Conv2D: ['kernel'],
layers.Conv2DTranspose: ['kernel'],
layers.Conv3D: ['kernel'],
layers.Conv3DTranspose: ['kernel'],
layers.Cropping1D: [],
layers.Cropping2D: [],
layers.Cropping3D: [],
# non-clusterable due to big unrecoverable accuracy loss
layers.DepthwiseConv2D: [],
layers.SeparableConv1D: ['pointwise_kernel'],
layers.SeparableConv2D: ['pointwise_kernel'],
layers.UpSampling1D: [],
layers.UpSampling2D: [],
layers.UpSampling3D: [],
layers.ZeroPadding1D: [],
layers.ZeroPadding2D: [],
layers.ZeroPadding3D: [],
layers.Activation: [],
layers.ActivityRegularization: [],
layers.Dense: ['kernel'],
layers.Dropout: [],
layers.Flatten: [],
layers.Lambda: [],
layers.Masking: [],
layers.Permute: [],
layers.RepeatVector: [],
layers.Reshape: [],
layers.SpatialDropout1D: [],
layers.SpatialDropout2D: [],
layers.SpatialDropout3D: [],
layers.Embedding: ['embeddings'],
layers.LocallyConnected1D: ['kernel'],
layers.LocallyConnected2D: ['kernel'],
layers.Add: [],
layers.Average: [],
layers.Concatenate: [],
layers.Dot: [],
layers.Maximum: [],
layers.Minimum: [],
layers.Multiply: [],
layers.Subtract: [],
layers.AlphaDropout: [],
layers.GaussianDropout: [],
layers.GaussianNoise: [],
layers.BatchNormalization: [],
layers.LayerNormalization: [],
layers.AveragePooling1D: [],
layers.AveragePooling2D: [],
layers.AveragePooling3D: [],
layers.GlobalAveragePooling1D: [],
layers.GlobalAveragePooling2D: [],
layers.GlobalAveragePooling3D: [],
layers.GlobalMaxPooling1D: [],
layers.GlobalMaxPooling2D: [],
layers.GlobalMaxPooling3D: [],
layers.MaxPooling1D: [],
layers.MaxPooling2D: [],
layers.MaxPooling3D: [],
}

_RNN_CELLS_WEIGHTS_MAP = {
Expand Down Expand Up @@ -369,6 +320,11 @@ def supports(cls, layer):
True/False whether the layer type is supported.
"""
# Automatically enable layers with zero trainable weights.
# Example: Reshape, AveragePooling2D, Maximum/Minimum, etc.
if len(layer.trainable_weights) == 0:
return True

if layer.__class__ in cls._LAYERS_WEIGHTS_MAP:
return True

Expand All @@ -393,6 +349,10 @@ def _is_rnn_layer(cls, layer):

@classmethod
def _weight_names(cls, layer):
# For layers with zero trainable weights, like Reshape, Pooling.
if len(layer.trainable_weights) == 0:
return []

return cls._LAYERS_WEIGHTS_MAP[layer.__class__]

@classmethod
Expand Down
Expand Up @@ -157,7 +157,14 @@ def testConvolutionalWeightsCA(self,

class CustomLayer(layers.Layer):
"""A custom non-clusterable layer class."""
def __init__(self, units=10):
super(CustomLayer, self).__init__()
self.add_weight(shape=(1, units),
initializer='uniform',
name='kernel')

def call(self, inputs):
return tf.matmul(inputs, self.weights)

class ClusteringLookupRegistryTest(test.TestCase, parameterized.TestCase):
"""Unit tests for the ClusteringLookupRegistry class."""
Expand Down Expand Up @@ -256,6 +263,11 @@ class CustomLayerFromClusterableLayer(layers.Dense):
"""A custom layer class derived from a built-in clusterable layer."""
pass

class CustomLayerFromClusterableLayerNoWeights(layers.Reshape):
"""A custom layer class derived from a built-in clusterable layer,
that does not have any weights."""
pass

class MinimalRNNCell(keras.layers.Layer):
"""A minimal RNN cell implementation."""

Expand Down Expand Up @@ -315,7 +327,10 @@ def testDoesNotSupportKerasUnsupportedLayer(self):
Verifies that ClusterRegistry does not support an unknown built-in layer.
"""
# ConvLSTM2D is a built-in keras layer but not supported.
self.assertFalse(ClusterRegistry.supports(layers.ConvLSTM2D(2, (5, 5))))
l = layers.ConvLSTM2D(2, (5, 5))
# We need to build weights
l.build(input_shape = (10, 10))
self.assertFalse(ClusterRegistry.supports(l))

def testSupportsKerasRNNLayers(self):
"""
Expand All @@ -330,8 +345,10 @@ def testDoesNotSupportKerasRNNLayerUnknownCell(self):
Verifies that ClusterRegistry does not support a custom non-clusterable RNN
cell.
"""
self.assertFalse(ClusterRegistry.supports(
keras.layers.RNN(ClusterRegistryTest.MinimalRNNCell(32))))
l = keras.layers.RNN(ClusterRegistryTest.MinimalRNNCell(32))
# We need to build it to have weights
l.build((10,1))
self.assertFalse(ClusterRegistry.supports(l))

def testSupportsKerasRNNLayerClusterableCell(self):
"""
Expand All @@ -350,19 +367,31 @@ def testDoesNotSupportCustomLayer(self):
def testDoesNotSupportCustomLayerInheritedFromClusterableLayer(self):
"""
Verifies that ClusterRegistry does not support a custom layer derived from
a clusterable layer.
a clusterable layer if there are trainable weights.
"""
custom_layer = ClusterRegistryTest.CustomLayerFromClusterableLayer(10)
custom_layer.build(input_shape=(10, 10))
self.assertFalse(ClusterRegistry.supports(custom_layer))

def testSupportsCustomLayerInheritedFromClusterableLayerNoWeights(self):
"""
Verifies that ClusterRegistry supports a custom layer derived from
a clusterable layer that does not have trainable weights.
"""
self.assertFalse(
ClusterRegistry.supports(
ClusterRegistryTest.CustomLayerFromClusterableLayer(10)))
custom_layer = ClusterRegistryTest.\
CustomLayerFromClusterableLayerNoWeights((7, 1))
custom_layer.build(input_shape=(3, 4))
self.assertTrue(ClusterRegistry.supports(custom_layer))

def testMakeClusterableRaisesErrorForKerasUnsupportedLayer(self):
"""
Verifies that an unsupported built-in layer cannot be made clusterable by
calling make_clusterable().
"""
l = layers.ConvLSTM2D(2, (5, 5))
l.build(input_shape = (10, 10))
with self.assertRaises(ValueError):
ClusterRegistry.make_clusterable(layers.ConvLSTM2D(2, (5, 5)))
ClusterRegistry.make_clusterable(l)

def testMakeClusterableRaisesErrorForCustomLayer(self):
"""
Expand All @@ -378,9 +407,10 @@ def testMakeClusterableRaisesErrorForCustomLayerInheritedFromClusterableLayer(
Verifies that a non-clusterable layer derived from a clusterable layer
cannot be made clusterable by calling make_clusterable().
"""
l = ClusterRegistryTest.CustomLayerFromClusterableLayer(10)
l.build(input_shape = (10, 10))
with self.assertRaises(ValueError):
ClusterRegistry.make_clusterable(
ClusterRegistryTest.CustomLayerFromClusterableLayer(10))
ClusterRegistry.make_clusterable(l)

def testMakeClusterableWorksOnKerasClusterableLayer(self):
"""
Expand Down Expand Up @@ -478,9 +508,12 @@ def testMakeClusterableRaisesErrorOnRNNLayersUnsupportedCell(self):
Verifies that make_clusterable() raises an exception when invoked with a
built-in RNN layer that contains a non-clusterable custom RNN cell.
"""
l = ClusterRegistryTest.MinimalRNNCell(5)
# we need to build weights
l.build(input_shape = (10, 1))
with self.assertRaises(ValueError):
ClusterRegistry.make_clusterable(layers.RNN(
[layers.LSTMCell(10), ClusterRegistryTest.MinimalRNNCell(5)]))
[layers.LSTMCell(10), l]))


if __name__ == '__main__':
Expand Down

0 comments on commit 270435d

Please sign in to comment.