This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Merge 3d527d3 into a56243c
alexjc committed Nov 20, 2015
2 parents a56243c + 3d527d3 commit 1870226
Showing 6 changed files with 126 additions and 31 deletions.
12 changes: 6 additions & 6 deletions sknn/backend/lasagne/mlp.py
@@ -234,29 +234,29 @@ def cast(array):
for start_idx in range(0, total_size - batch_size + 1, batch_size):
excerpt = indices[start_idx:start_idx + batch_size]
Xb, yb = cast(X[excerpt]), cast(y[excerpt])
- if self.mutator is not None:
-     for x, _ in zip(Xb, yb):
-         self.mutator(x)

yield Xb, yb

- def _batch_impl(self, X, y, processor, output, shuffle):
+ def _batch_impl(self, X, y, processor, mode, output, shuffle):
progress, batches = 0, X.shape[0] / self.batch_size
loss, count = 0.0, 0
for Xb, yb in self._iterate_data(X, y, self.batch_size, shuffle):
self._do_callback('on_batch_start', locals())
loss += processor(Xb, yb)
count += 1
while count / batches > progress / 60:
sys.stdout.write(output)
sys.stdout.flush()
progress += 1
self._do_callback('on_batch_finish', locals())
sys.stdout.write('\r')
return loss / count

def _train_impl(self, X, y):
- return self._batch_impl(X, y, self.trainer, output='.', shuffle=True)
+ return self._batch_impl(X, y, self.trainer, mode='train', output='.', shuffle=True)

def _valid_impl(self, X, y):
- return self._batch_impl(X, y, self.validator, output=' ', shuffle=False)
+ return self._batch_impl(X, y, self.validator, mode='valid', output=' ', shuffle=False)

@property
def is_initialized(self):
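
Since the backend forwards ``locals()`` from ``_batch_impl`` to ``_do_callback``, a batch-level callback receives the current mini-batch and the running totals (``Xb``, ``yb``, ``loss``, ``count``) as keyword arguments. A minimal sketch of wiring one up through the public ``Regressor`` wrapper; the random data and the ``log_batch`` name are purely illustrative:

```python
import numpy
from sknn.mlp import Regressor as MLPR, Layer as L

def log_batch(Xb=None, yb=None, **_):
    # Inspect the mini-batch just before the backend processes it.
    print("batch of", Xb.shape[0], "samples")

nn = MLPR(layers=[L("Linear")], n_iter=1, batch_size=4,
          callback={'on_batch_start': log_batch})
nn.fit(numpy.random.rand(8, 16), numpy.random.rand(8, 4))
```
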
42 changes: 30 additions & 12 deletions sknn/mlp.py
@@ -118,53 +118,69 @@ def _reshape(self, X, y=None):
X = X.reshape((X.shape[0], numpy.product(X.shape[1:])))
return X, y

def _do_callback(self, event, variables):
if self.callback is None:
return

del variables['self']
if isinstance(self.callback, dict):
function = self.callback.get(event, None)
return function(**variables) if function else None
else:
return self.callback(event, **variables)

def _train(self, X, y):
assert self.n_iter or self.n_stable,\
"Neither n_iter nor n_stable were specified; training would loop forever."

best_train_error, best_valid_error = float("inf"), float("inf")
best_params = []
n_stable = 0
self._do_callback('on_train_start', locals())

for i in itertools.count(1):
- start = time.time()
+ start_time = time.time()
self._do_callback('on_epoch_start', locals())

- best_train = False
+ is_best_train = False
avg_train_error = self._backend._train_impl(X, y)
if avg_train_error is not None:
if math.isnan(avg_train_error):
raise RuntimeError("Training diverged and returned NaN.")

best_train_error = min(best_train_error, avg_train_error)
- best_train = bool(avg_train_error < best_train_error * (1.0 + self.f_stable))
+ is_best_train = bool(avg_train_error < best_train_error * (1.0 + self.f_stable))

- best_valid = False
+ is_best_valid = False
avg_valid_error = None
if self.valid_set is not None:
avg_valid_error = self._backend._valid_impl(*self.valid_set)
if avg_valid_error is not None:
best_valid_error = min(best_valid_error, avg_valid_error)
- best_valid = bool(avg_valid_error < best_valid_error * (1.0 + self.f_stable))
+ is_best_valid = bool(avg_valid_error < best_valid_error * (1.0 + self.f_stable))

finish_time = time.time()
log.debug("\r{:>5} {}{}{} {}{}{} {:>5.1f}s".format(
i,
- ansi.BLUE if best_train else "",
+ ansi.BLUE if is_best_train else "",
"{0:>10.3e}".format(float(avg_train_error)) if (avg_train_error is not None) else " N/A ",
- ansi.ENDC if best_train else "",
+ ansi.ENDC if is_best_train else "",

- ansi.GREEN if best_valid else "",
+ ansi.GREEN if is_best_valid else "",
"{:>10.3e}".format(float(avg_valid_error)) if (avg_valid_error is not None) else " N/A ",
- ansi.ENDC if best_valid else "",
+ ansi.ENDC if is_best_valid else "",

- time.time() - start
+ finish_time - start_time
))

- if best_valid or (self.valid_set is None and best_train):
+ if is_best_valid or (self.valid_set is None and is_best_train):
best_params = self._backend._mlp_to_array()
n_stable = 0
else:
n_stable += 1

self._do_callback('on_epoch_finish', locals())

if self.valid_set is not None and n_stable >= self.n_stable:
log.debug("")
log.info("Early termination condition fired at %i iterations.", i)
@@ -173,7 +189,8 @@ def _train(self, X, y):
log.debug("")
log.info("Terminating after specified %i total iterations.", i)
break


self._do_callback('on_train_finish', locals())
self._backend._array_to_mlp(best_params, self._backend.mlp)

def _fit(self, X, y):
@@ -362,6 +379,7 @@ def partial_fit(self, X, y, classes=None):
self.label_binarizers = [LB() for _ in range(y.shape[1])]
for lb, cls in zip(self.label_binarizers, classes):
lb.fit(cls)

return self.fit(X, y)

def predict_proba(self, X):
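
Because ``_do_callback`` passes the locals of ``_train``, epoch-level callbacks can observe quantities such as ``i``, ``avg_train_error`` and ``avg_valid_error`` without touching the logging code. A small sketch of a per-epoch monitor, assuming the public ``Regressor`` wrapper and the variable names visible in the loop above; the ``record_epoch`` name and random data are illustrative:

```python
import numpy
from sknn.mlp import Regressor as MLPR, Layer as L

history = []

def record_epoch(i=None, avg_train_error=None, avg_valid_error=None, **_):
    # Keep a training curve; avg_valid_error stays None when no valid_set is given.
    history.append((i, avg_train_error, avg_valid_error))

nn = MLPR(layers=[L("Linear")], n_iter=5,
          callback={'on_epoch_finish': record_epoch})
nn.fit(numpy.random.rand(8, 16), numpy.random.rand(8, 4))
print(history)
```
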
24 changes: 18 additions & 6 deletions sknn/nn.py
@@ -342,10 +342,22 @@ class NeuralNetwork(object):
only be applied to layers of type ``Linear`` or ``Gaussian`` and they must be used as
the output layer (PyLearn2 only).
- mutator: callable, optional
- A function that takes a single training sample ``(X, y)`` at each epoch and returns
- a modified version. This is useful for dataset augmentation, e.g. mirroring input
- images or jittering.
callback: callable or dict, optional
An observer mechanism that exposes information about the inner training loop. This is
either a single function called as ``cb(event, **variables)``, or a dictionary of
functions keyed by ``event`` string, each called as ``cb(**variables)``.
There are multiple events sent from the inner training loop:
* ``on_train_start`` — Called when the main training function is entered.
* ``on_epoch_start`` — Called first when a new iteration starts.
* ``on_batch_start`` — Called before an individual batch is processed.
* ``on_batch_finish`` — Called after that individual batch is processed.
* ``on_epoch_finish`` — Called last when the iteration is done.
* ``on_train_finish`` — Called just before the training function exits.
For each function, the ``variables`` dictionary passed contains all local variables within
the training implementation.
debug: bool, optional
Should the underlying training algorithms perform validation on the data
@@ -388,7 +400,7 @@ def __init__(
valid_set=None,
valid_size=0.0,
loss_type=None,
- mutator=None,
+ callback=None,
debug=False,
verbose=None,
**params):
@@ -439,9 +451,9 @@ def __init__(
self.valid_set = valid_set
self.valid_size = valid_size
self.loss_type = loss_type
- self.mutator = mutator
self.debug = debug
self.verbose = verbose
+ self.callback = callback

self._backend = None
self._create_logger()
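
The ``callback`` parameter documented above also accepts a single callable that receives the event name first. A minimal sketch of that form; the ``monitor`` function and the random data are illustrative only:

```python
import numpy
from sknn.mlp import Regressor as MLPR, Layer as L

def monitor(event, **variables):
    # One function observes every event emitted by the training loop.
    if event in ('on_epoch_start', 'on_epoch_finish'):
        print(event, "iteration", variables.get('i'))

nn = MLPR(layers=[L("Linear")], n_iter=2, callback=monitor)
nn.fit(numpy.random.rand(8, 16), numpy.random.rand(8, 4))
```
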
65 changes: 65 additions & 0 deletions sknn/tests/test_callback.py
@@ -0,0 +1,65 @@
import unittest
from nose.tools import (assert_in, assert_raises, assert_equals)

import collections
import numpy
from sknn.mlp import MultiLayerPerceptron as MLP, Layer as L

import sknn.mlp


class TestSingleCallback(unittest.TestCase):

def setUp(self):
self.data = collections.defaultdict(list)

def _callback(self, event, **variables):
self.data[event].append(variables)

def test_TrainingCallbacks(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
nn = MLP(layers=[L("Linear")], n_iter=4, callback=self._callback)
nn._fit(a_in, a_out)
assert_equals(len(self.data['on_train_start']), 1)
assert_equals(len(self.data['on_train_finish']), 1)

def test_EpochCallbacks(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
nn = MLP(layers=[L("Linear")], n_iter=4, callback=self._callback)
nn._fit(a_in, a_out)
assert_equals(len(self.data['on_epoch_start']), 4)
assert_equals(len(self.data['on_epoch_finish']), 4)

def test_BatchCallbacks(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
nn = MLP(layers=[L("Linear")], n_iter=1, batch_size=4, callback=self._callback)
nn._fit(a_in, a_out)
assert_equals(len(self.data['on_batch_start']), 2)
assert_equals(len(self.data['on_batch_finish']), 2)


class TestSpecificCallback(unittest.TestCase):

def setUp(self):
self.data = []

def _callback(self, **variables):
self.data.append(variables)

def test_TrainingCallback(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
nn = MLP(layers=[L("Linear")], n_iter=4, callback={'on_train_start': self._callback})
nn._fit(a_in, a_out)
assert_equals(len(self.data), 1)

def test_EpochCallback(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
nn = MLP(layers=[L("Linear")], n_iter=4, callback={'on_epoch_start': self._callback})
nn._fit(a_in, a_out)
assert_equals(len(self.data), 4)

def test_BatchCallbacks(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
nn = MLP(layers=[L("Linear")], n_iter=1, batch_size=4, callback={'on_batch_start': self._callback})
nn._fit(a_in, a_out)
assert_equals(len(self.data), 2)
8 changes: 4 additions & 4 deletions sknn/tests/test_data.py
@@ -17,12 +17,12 @@ def setUp(self):
self.nn = MLPR(
layers=[L("Linear")],
n_iter=1,
- batch_size=2,
- mutator=self._mutate_fn)
+ batch_size=1,
+ callback={'on_batch_start': self._mutate_fn})

- def _mutate_fn(self, sample):
+ def _mutate_fn(self, Xb, **_):
self.called += 1
- sample[sample == 0.0] = self.value
+ Xb[Xb == 0.0] = self.value

def test_TestCalledOK(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
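
The updated test also shows the migration path for the removed ``mutator`` parameter: per-sample mutation becomes an ``on_batch_start`` callback that modifies ``Xb`` in place. A sketch of dataset augmentation in that style, with a purely illustrative jitter and the public ``Regressor`` wrapper:

```python
import numpy
from sknn.mlp import Regressor as MLPR, Layer as L

def augment(Xb=None, **_):
    # Jitter each mini-batch in place before it reaches the network.
    Xb += numpy.random.uniform(-0.01, 0.01, size=Xb.shape).astype(Xb.dtype)

nn = MLPR(layers=[L("Linear")], n_iter=2, batch_size=4,
          callback={'on_batch_start': augment})
nn.fit(numpy.random.rand(8, 16), numpy.random.rand(8, 4))
```
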
6 changes: 3 additions & 3 deletions sknn/tests/test_types.py
@@ -42,10 +42,10 @@ def test_FitHybrid(self):
self.nn._fit(X, y)

def test_FitMutator(self):
- def mutate(x):
+ def mutate(Xb, **_):
self.count += 1
- return x - 0.5
- self.nn.mutator = mutate
+ Xb -= 0.5
+ self.nn.callback = {'on_batch_start': mutate}

for t in SPARSE_TYPES:
sparse_matrix = getattr(scipy.sparse, t)
