Skip to content

Commit

Permalink
added sparse filtering model freezing
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Nov 14, 2017
1 parent b5d2039 commit 85b478d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
25 changes: 23 additions & 2 deletions spykes/ml/tensorflow/sparse_filtering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import print_function

import six
import collections

Expand Down Expand Up @@ -60,6 +58,7 @@ def __init__(self, model, layers=None):
'a list of strings. Got: "{}"'.format(layers))

self._submodels = None
self._submodel_map = None

@property
def submodels(self):
Expand All @@ -72,6 +71,17 @@ def submodels(self):
def num_layers(self):
return len(self.layer_names)

def get_submodel(self, model):
if self._submodel_map is None:
raise RuntimeError('This model must be compiled before you can '
'get a particular submodel')

if model not in self._submodel_map:
raise ValueError('Submodel not found: "{}". Must be one of {}'
.format(model, list(self._submodel_map.keys())))

return self._submodel_map[model]

def _clean_maybe_iterable_param(self, it, param):
'''Converts a potential iterable or single value to a list of values.
Expand Down Expand Up @@ -119,13 +129,24 @@ def compile(self, optimizer, freeze=False, **kwargs):
input_layer = self.model.input
for layer_name, o in zip(self.layer_names, optimizer):
output_layer = self.model.get_layer(layer_name).output

submodel = ks.models.Model(
inputs=input_layer,
outputs=output_layer,
)

# Freezes all but the selected layer.
if freeze:
for layer in submodel.layers:
layer.trainable = layer.name == layer_name

submodel.compile(loss=sparse_filtering_loss, optimizer=o, **kwargs)
submodel._make_train_function() # Forces submodel compilation.
self._submodels.append(submodel)

# Maps the layer names ot the submodels.
self._submodel_map = dict(zip(self.layer_names, self._submodels))

def fit(self, x, epochs=1, **kwargs):
'''Fits the model to the provided data.
Expand Down
20 changes: 17 additions & 3 deletions tests/ml/tensorflow/test_sparse_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from nose.tools import (
assert_raises,
assert_true,
assert_false,
assert_equal,
)

Expand Down Expand Up @@ -38,16 +39,29 @@ def test_sparse_filtering():
sf_model = SparseFiltering(model=model, layers=1)
sf_model = SparseFiltering(model=model, layers='a')
assert_equal(sf_model.layer_names, ['a'])

# Checks model compilation.
sf_model.compile('sgd')
assert_raises(RuntimeError, sf_model.compile, 'sgd')

sf_model = SparseFiltering(model=model, layers=['a', 'b'])
assert_equal(sf_model.layer_names, ['a', 'b'])

# Checks that the submodels attribute is not available yet.
with assert_raises(RuntimeError):
print(sf_model.submodels)

# Checks model compilation.
sf_model.compile('sgd')
assert_raises(RuntimeError, sf_model.compile, 'sgd')
# Checks getting a submodel.
with assert_raises(RuntimeError):
sf_model.get_submodel('a')

# Checks model freezing.
sf_model.compile('sgd', freeze=True)
assert_equal(len(sf_model.submodels), 2)

# Checks getting an invalid submodel.
with assert_raises(ValueError):
sf_model.get_submodel('c')

# Checks model fitting.
h = sf_model.fit(x=train_images, epochs=1)
Expand Down

0 comments on commit 85b478d

Please sign in to comment.