Skip to content

Commit

Permalink
Add Anti Zero-Drift functionality for Sparsity-Aware clustering (expe…
Browse files Browse the repository at this point in the history
…rimental)

 * Set the random seed in the sparsity preservation test to a specific value to
   make sure that some of the weights are null
  • Loading branch information
MatteoArm committed Sep 23, 2020
1 parent 006f7d1 commit 9936522
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def testValuesRemainClusteredAfterTraining(self):

@keras_parameterized.run_all_keras_modes
def testSparsityIsPreservedDuringTraining(self):
""" Set a specific random seed to ensure that we get some null weights to test sparsity preservation with. """
tf.random.set_seed(1)

"""Verifies that training a clustered model does not destroy the sparsity of the weights."""
original_model = keras.Sequential([
layers.Dense(5, input_shape=(5,)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def get_updater(for_weight_name):
def fn():
# Get the clustered weights
pulling_indices = self.pulling_indices_tf[for_weight_name]
clustered_weights = self.clustering_impl[for_weight_name].get_clustered_weight(pulling_indices)
clustered_weights = self.clustering_impl[for_weight_name].\
get_clustered_weight(pulling_indices)

if self.preserve_sparsity:
# Get the sparsity mask
Expand Down Expand Up @@ -293,10 +294,20 @@ def call(self, inputs):
# since they are integers and not differentiable. Gradients won't flow back
# through tf.argmin
# Go through all tensors and replace them with their clustered copies.
for weight_name, _ in self.clustered_vars:
# Get the clustered weights
for weight_name in self.ori_weights_vars_tf:
pulling_indices = self.pulling_indices_tf[weight_name]
clustered_weights = self.clustering_impl[weight_name].get_clustered_weight(pulling_indices)

# Update cluster associations
pulling_indices.assign(tf.dtypes.cast(
self.clustering_impl[weight_name].\
get_pulling_indices(self.ori_weights_vars_tf[weight_name]),
pulling_indices.dtype
))

# Get the clustered weights
clustered_weights = self.clustering_impl[weight_name].\
get_clustered_weight_forward(pulling_indices,\
self.ori_weights_vars_tf[weight_name])

if self.preserve_sparsity:
# Get the sparsity mask
Expand Down

0 comments on commit 9936522

Please sign in to comment.