diff --git a/sknn/mlp.py b/sknn/mlp.py index 1d5c0e2..6df3f75 100644 --- a/sknn/mlp.py +++ b/sknn/mlp.py @@ -313,8 +313,14 @@ def __getstate__(self): del d[k] return d + def _mlp_get_weights(self, l): + if isinstance(l, mlp.ConvElemwise) or l.requires_reformat: + W, = l.transformer.get_params() + return W.get_value() + return l.get_weights() + def _mlp_to_array(self): - return [(l.get_weights(), l.get_biases()) for l in self.mlp.layers] + return [(self._mlp_get_weights(l), l.get_biases()) for l in self.mlp.layers] def __setstate__(self, d): self.__dict__.update(d) @@ -324,7 +330,7 @@ def __setstate__(self, d): def _array_to_mlp(self, array, nn): for layer, (weights, biases) in zip(nn.layers, array): - assert layer.get_weights().shape == weights.shape + assert self._mlp_get_weights(layer).shape == weights.shape layer.set_weights(weights) assert layer.get_biases().shape == biases.shape diff --git a/sknn/tests/test_conv.py b/sknn/tests/test_conv.py index b4a658c..777979d 100644 --- a/sknn/tests/test_conv.py +++ b/sknn/tests/test_conv.py @@ -1,6 +1,8 @@ import unittest -from nose.tools import (assert_is_not_none, assert_raises, assert_equal) +from nose.tools import (assert_is_not_none, assert_true, assert_raises, assert_equal) +import io +import pickle import numpy from sknn.mlp import Regressor as MLPR @@ -9,13 +11,15 @@ class TestConvolution(unittest.TestCase): - def _run(self, nn, a_in=None): + def _run(self, nn, a_in=None, fit=True): if a_in is None: a_in = numpy.zeros((8,32,16,1)) a_out = numpy.zeros((8,4)) - nn.fit(a_in, a_out) + if fit is True: + nn.fit(a_in, a_out) a_test = nn.predict(a_in) assert_equal(type(a_out), type(a_in)) + return a_test def test_MissingLastDim(self): self._run(MLPR( @@ -205,3 +209,46 @@ def _run(self, nn, a_in=None): nn.fit(a_in, a_out) a_test = nn.predict(a_in) assert_equal(type(a_out), type(a_in)) + + +class TestSerialization(unittest.TestCase): + + def setUp(self): + self.nn = MLPR( + layers=[ + C("Rectifier", channels=6, kernel_shape=(3,3)), + C("Sigmoid", channels=4, kernel_shape=(5,5)), + C("Tanh", channels=8, kernel_shape=(3,3)), + L("Linear")], + n_iter=1) + + def test_SerializeFail(self): + buf = io.BytesIO() + assert_raises(AssertionError, pickle.dump, self.nn, buf) + + def test_SerializeCorrect(self): + a_in, a_out = numpy.zeros((8,32,16,1)), numpy.zeros((8,4)) + self.nn.fit(a_in, a_out) + + buf = io.BytesIO() + pickle.dump(self.nn, buf) + + buf.seek(0) + nn = pickle.load(buf) + + assert_is_not_none(nn.mlp) + assert_equal(nn.layers, self.nn.layers) + + +class TestSerializedNetwork(TestConvolution): + + def _run(self, original, a_in=None): + a_test = super(TestSerializedNetwork, self)._run(original, a_in) + + buf = io.BytesIO() + pickle.dump(original, buf) + buf.seek(0) + nn = pickle.load(buf) + + a_copy = super(TestSerializedNetwork, self)._run(nn, a_in, fit=False) + assert_true((a_test == a_copy).all())