Skip to content

Commit

Permalink
Updated version number to 0.0.1 for release. Some final changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenWhetton committed Aug 23, 2017
1 parent c55f0ce commit e45e931
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 39 deletions.
4 changes: 4 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,4 @@
include LICENSE
include README.md
include tox.ini
graft tests
25 changes: 16 additions & 9 deletions README.md
Expand Up @@ -3,8 +3,8 @@
## Introduction
Keras-surgeon provides simple methods for modifying trained
[Keras][] models. The following functionality is currently implemented:
* Delete neurons/channels from layers
* Delete layers
* delete neurons/channels from layers
* delete layers
* insert layers
* replace layers

Expand All @@ -17,21 +17,28 @@ inspired the name of this package.
The `operations` module contains simple methods to perform network surgery on a
single layer within a model.\
Example usage:

```python
from kerassurgeon.operations import delete_layer, insert_layer, delete_channels
# delete layer_1 from a model
model = delete_layer(model, layer_1)
# insert new_layer_1 before layer_2 in a model
model = insert_layer(model, layer_2, new_layer_3)
# delete channels 0, 4 and 67 from layer_2 in model
model = delete_channels(model, layer_2, [0,4,67])
```

The `Surgeon` class enables many modifications to be performed in a single operation.\
Example usage:
```python
# model is a Keras model
# layer_1 and layer_2 are Keras layers from model
# channels is a list of integers; indices of the channels to be deleted
# delete channels 2, 6 and 8 from layer_1 and insert new_layer_1 before
# layer_2 in a model
from kerassurgeon import Surgeon
surgeon = Surgeon(model)
surgeon.add_job('delete_channels', model, layer_1, channels=channels)
surgeon.add_job('insert_layer', model, layer_2, new_layer=layer_3)
surgeon.add_job('delete_channels', model, layer_1, channels=[2, 6, 8])
surgeon.add_job('insert_layer', model, layer_2, new_layer=new_layer_1)
new_model = surgeon.operate()
```
The `identify` module contains method to identify which channels to prune.
The `identify` module contains methods to identify which channels to prune.


## Documentation
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
@@ -0,0 +1,2 @@
[metadata]
description-file = README.md
6 changes: 2 additions & 4 deletions setup.py
Expand Up @@ -5,17 +5,15 @@

setup(
name='kerassurgeon',
version="0.0.0a1",
url='',
version="0.0.1",
url='https://github.com/BenWhetton/keras-surgeon',
license='MIT',
description='A library for performing network surgery on trained Keras models.'
'Useful for deep neural network pruning.',
author='Ben Whetton',
author_email='Ben.Whetton@gmail.com',
install_requires=['keras'],

extras_require={'pd': ['pandas'], },

tests_require=['pytest'],
packages=find_packages('src'),
package_dir={'': 'src'}
Expand Down
2 changes: 2 additions & 0 deletions src/kerassurgeon/__init__.py
@@ -1 +1,3 @@
from kerassurgeon.surgeon import Surgeon

__version__ = '0.0.1'
21 changes: 0 additions & 21 deletions src/kerassurgeon/operations.py
@@ -1,27 +1,6 @@
from kerassurgeon.surgeon import Surgeon


def rebuild_sequential(layers):
"""Rebuild a sequential model from a list of layers.
Arguments:
layers: List of Keras layers
Returns:
A Keras Sequential model
"""
from keras.models import Sequential

weights = []
for layer in layers:
weights.append(layer.get_weights())

new_model = Sequential(layers=layers)
for i, layer in enumerate(new_model.layers):
layer.set_weights(weights[i])
return new_model


def delete_layer(model, layer, *, node_indices=None, copy=True):
"""Delete instances of a layer from a Keras model.
Expand Down
5 changes: 0 additions & 5 deletions tests/test_surgeon.py
Expand Up @@ -56,11 +56,6 @@ def model_2():
return Model(model.inputs, model.outputs)


def test_rebuild_sequential(model_1):
model_2 = operations.rebuild_sequential(model_1.layers)
assert compare_models_seq(model_1, model_2)


def test_rebuild_submodel(model_2):
output_nodes = [model_2.output_layers[i].inbound_nodes[node_index]
for i, node_index in
Expand Down

0 comments on commit e45e931

Please sign in to comment.