diff --git a/sknn/backend/lasagne/mlp.py b/sknn/backend/lasagne/mlp.py
index 4b73f6f..f7473d5 100644
--- a/sknn/backend/lasagne/mlp.py
+++ b/sknn/backend/lasagne/mlp.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-from __future__ import (absolute_import, unicode_literals, print_function)
+from __future__ import (absolute_import, division, unicode_literals, print_function)
 
 __all__ = ['MultiLayerPerceptronBackend']
 
@@ -261,7 +261,7 @@ def cast(array, indices):
         if shuffle:
             numpy.random.shuffle(indices)
 
-        for start_idx in range(0, total_size - batch_size + 1, batch_size):
+        for start_idx in range(0, total_size, batch_size):
             excerpt = indices[start_idx:start_idx + batch_size]
             Xb, yb, wb = cast(X, excerpt), cast(y, excerpt), cast(w, excerpt)
             yield Xb, yb, wb
diff --git a/sknn/tests/test_training.py b/sknn/tests/test_training.py
index 8adf7f3..d995b6e 100644
--- a/sknn/tests/test_training.py
+++ b/sknn/tests/test_training.py
@@ -53,6 +53,37 @@ def terminate(**_):
         assert_equals(self.counter, 1)
 
 
+class TestBatchSize(unittest.TestCase):
+
+    def setUp(self):
+        self.batch_count = 0
+        self.nn = MLP(
+            layers=[L("Rectifier")],
+            learning_rate=0.001, n_iter=1,
+            callback={'on_batch_start': self.on_batch_start})
+
+    def on_batch_start(self, **args):
+        self.batch_count += 1
+
+    def test_BatchSizeLargerThanInput(self):
+        self.nn.batch_size = 32
+        a_in, a_out = numpy.zeros((8,16)), numpy.ones((8,4))
+        self.nn._fit(a_in, a_out)
+        assert_equals(1, self.batch_count)
+
+    def test_BatchSizeSmallerThanInput(self):
+        self.nn.batch_size = 4
+        a_in, a_out = numpy.ones((8,16)), numpy.zeros((8,4))
+        self.nn._fit(a_in, a_out)
+        assert_equals(2, self.batch_count)
+
+    def test_BatchSizeNonMultiple(self):
+        self.nn.batch_size = 4
+        a_in, a_out = numpy.zeros((9,16)), numpy.ones((9,4))
+        self.nn._fit(a_in, a_out)
+        assert_equals(3, self.batch_count)
+
+
 class TestCustomLogging(unittest.TestCase):
 
     def setUp(self):