diff --git a/sknn/backend/lasagne/mlp.py b/sknn/backend/lasagne/mlp.py index 1e1cfbe..3cbee3d 100644 --- a/sknn/backend/lasagne/mlp.py +++ b/sknn/backend/lasagne/mlp.py @@ -283,9 +283,9 @@ def _array_to_mlp(self, array, nn): ws = tuple(layer.W.shape.eval()) assert ws == weights.shape, "Layer weights shape mismatch: %r != %r" %\ (ws, weights.shape) - layer.W.set_value(weights) + layer.W.set_value(weights.astype(theano.config.floatX)) bs = tuple(layer.b.shape.eval()) assert bs == biases.shape, "Layer biases shape mismatch: %r != %r" %\ (bs, biases.shape) - layer.b.set_value(biases) + layer.b.set_value(biases.astype(theano.config.floatX)) diff --git a/sknn/tests/test_data.py b/sknn/tests/test_data.py index 26de705..8f0a6cf 100644 --- a/sknn/tests/test_data.py +++ b/sknn/tests/test_data.py @@ -60,8 +60,8 @@ def test_SetLayerParamsList(self): nn.set_parameters([(weights, biases)]) p = nn.get_parameters() - assert_true((p[0].weights == weights).all()) - assert_true((p[0].biases == biases).all()) + assert_true((p[0].weights.astype('float32') == weights.astype('float32')).all()) + assert_true((p[0].biases.astype('float32') == biases.astype('float32')).all()) def test_LayerParamsSkipOneWithNone(self): nn = MLPR(layers=[L("Sigmoid", units=32), L("Linear", name='abcd')]) @@ -73,8 +73,8 @@ def test_LayerParamsSkipOneWithNone(self): nn.set_parameters([None, (weights, biases)]) p = nn.get_parameters() - assert_true((p[1].weights == weights).all()) - assert_true((p[1].biases == biases).all()) + assert_true((p[1].weights.astype('float32') == weights.astype('float32')).all()) + assert_true((p[1].biases.astype('float32') == biases.astype('float32')).all()) def test_SetLayerParamsDict(self): nn = MLPR(layers=[L("Sigmoid", units=32), L("Linear", name='abcd')]) @@ -86,5 +86,5 @@ def test_SetLayerParamsDict(self): nn.set_parameters({'abcd': (weights, biases)}) p = nn.get_parameters() - assert_true((p[1].weights == weights).all()) - assert_true((p[1].biases == biases).all()) + assert_true((p[1].weights.astype('float32') == weights.astype('float32')).all()) + assert_true((p[1].biases.astype('float32') == biases.astype('float32')).all())