Skip to content

Commit

Permalink
Added node_indices argument to get_apoz(). Added input checking. Impr…
Browse files Browse the repository at this point in the history
…oved docs.
  • Loading branch information
BenWhetton committed Aug 21, 2017
1 parent 9eef78f commit 0afa607
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions src/kerasprune/identify.py
Expand Up @@ -7,21 +7,25 @@
from kerasprune import utils


def get_apoz(model, layer, x_val):
def get_apoz(model, layer, x_val, node_indices=None):
"""Identify neurons with high Average Percentage of Zeros (APoZ).
The APoZ a.k.a. (A)verage (P)ercentage (o)f activations equal to (Z)ero,
is a metric for the usefulness of a channel defined in this paper:
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient
Deep Architectures" - [Hu et al. (2016)][]
`high_apoz()` enables the pruning methodology described in this paper to be
replicated
replicated.
If node_indices are not specified and the layer is shared within the model
the APoZ will be calculated over all instances of the shared layer.
Args:
model: A Keras model.
layer: The layer whose channels will be evaluated for pruning.
x_val: The input of the validation set. This will be used to calculate
the activations of the layer of interest.
node_indices(list[int]): (optional) A list of node indices.
Returns:
List of the APoZ values for each channel in the layer.
Expand All @@ -30,8 +34,24 @@ def get_apoz(model, layer, x_val):
if isinstance(layer, str):
layer = model.get_layer(name=layer)

# Check that layer is in the model
if layer not in model.layers:
raise ValueError('layer is not a valid Layer in model.')

layer_node_indices = utils.find_nodes_in_model(model, layer)
# If no nodes are specified, all of the layer's inbound nodes which are
# in model are selected.
if not node_indices:
node_indices = layer_node_indices
# Check for duplicate node indices
elif len(node_indices) != len(set(node_indices)):
raise ValueError('`node_indices` contains duplicate values.')
# Check that all of the selected nodes are in the layer
elif not set(node_indices).issubset(layer_node_indices):
raise ValueError('One or more nodes specified by `layer` and '
'`node_indices` are not in `model`.')

data_format = getattr(layer, 'data_format', 'channels_last')
node_indices = utils.find_nodes_in_model(model, layer)
# Perform the forward pass and get the activations of the layer.
activations = []
for node_index in node_indices:
Expand Down

0 comments on commit 0afa607

Please sign in to comment.