This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Removed the mutator as it can be replaced entirely with the callbacks. Ported tests accordingly. Note, however, that repeatedly mutating data in batches may have unintended side-effects.
alexjc committed Nov 20, 2015
1 parent 63aae4c commit 3d527d3
Showing 4 changed files with 11 additions and 20 deletions.
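
For users updating code across this change, a minimal migration sketch (hypothetical layer and iteration values), based on the test changes below and the ``callback`` dictionary API documented in sknn/nn.py:

    from sknn.mlp import Regressor as MLPR, Layer as L

    def fill_zeros(Xb, **_):
        # The callback receives the whole batch and should modify Xb in place;
        # extra keyword arguments are ignored via **_.
        Xb[Xb == 0.0] = 0.5

    # Previously: MLPR(..., mutator=some_per_sample_function)
    # Now the same hook is registered as a batch-level callback:
    nn = MLPR(layers=[L("Linear")], n_iter=10,
              callback={'on_batch_start': fill_zeros})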
10 changes: 4 additions & 6 deletions sknn/backend/lasagne/mlp.py
@@ -234,12 +234,10 @@ def cast(array):
         for start_idx in range(0, total_size - batch_size + 1, batch_size):
             excerpt = indices[start_idx:start_idx + batch_size]
             Xb, yb = cast(X[excerpt]), cast(y[excerpt])
-            if self.mutator is not None:
-                for x, _ in zip(Xb, yb):
-                    self.mutator(x)
+
             yield Xb, yb

-    def _batch_impl(self, X, y, processor, output, shuffle):
+    def _batch_impl(self, X, y, processor, mode, output, shuffle):
         progress, batches = 0, X.shape[0] / self.batch_size
         loss, count = 0.0, 0
         for Xb, yb in self._iterate_data(X, y, self.batch_size, shuffle):
@@ -255,10 +253,10 @@ def _batch_impl(self, X, y, processor, output, shuffle):
         return loss / count

     def _train_impl(self, X, y):
-        return self._batch_impl(X, y, self.trainer, output='.', shuffle=True)
+        return self._batch_impl(X, y, self.trainer, mode='train', output='.', shuffle=True)

     def _valid_impl(self, X, y):
-        return self._batch_impl(X, y, self.validator, output=' ', shuffle=False)
+        return self._batch_impl(X, y, self.validator, mode='valid', output=' ', shuffle=False)

     @property
     def is_initialized(self):
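
The body of ``_batch_impl`` that consumes the new ``mode`` argument is elided above, so the following is only a hedged illustration: assuming the batch callback is passed a train/valid indicator among its keyword arguments (not confirmed by this diff), augmentation can be restricted to training batches.

    import numpy

    def augment(Xb, mode='train', **_):
        # 'mode' is an assumption here -- the diff only shows that _train_impl
        # and _valid_impl now pass mode='train' / mode='valid' internally.
        if mode == 'train':
            Xb += numpy.random.uniform(-0.1, 0.1, size=Xb.shape)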
7 changes: 0 additions & 7 deletions sknn/nn.py
@@ -342,11 +342,6 @@ class NeuralNetwork(object):
         only be applied to layers of type ``Linear`` or ``Gaussian`` and they must be used as
         the output layer (PyLearn2 only).

-    mutator: callable, optional
-        A function that takes a single training sample ``(X, y)`` at each epoch and returns
-        a modified version. This is useful for dataset augmentation, e.g. mirroring input
-        images or jittering.
-
     callback: callable or dict, optional
         An observer mechanism that exposes information about the inner training loop. This is
         either a single function that takes ``cbs(event, **variables)`` as a parameter, or a
@@ -405,7 +400,6 @@ def __init__(
             valid_set=None,
             valid_size=0.0,
             loss_type=None,
-            mutator=None,
             callback=None,
             debug=False,
             verbose=None,
@@ -457,7 +451,6 @@ def __init__(
         self.valid_set = valid_set
         self.valid_size = valid_size
         self.loss_type = loss_type
-        self.mutator = mutator
         self.debug = debug
         self.verbose = verbose
         self.callback = callback
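
For reference, the single-function form of ``callback`` described in the docstring above looks roughly like this; the exact set of ``**variables`` per event is not visible in this diff, so treating ``Xb`` as one of them is an assumption taken from the tests below.

    from sknn.mlp import Regressor as MLPR, Layer as L

    def cbs(event, **variables):
        # 'event' is a name such as 'on_batch_start'; 'variables' exposes the
        # training loop's local state for that event.
        if event == 'on_batch_start' and 'Xb' in variables:
            Xb = variables['Xb']
            Xb[Xb == 0.0] = 1e-6   # modify the batch in place

    nn = MLPR(layers=[L("Linear")], n_iter=1, callback=cbs)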
8 changes: 4 additions & 4 deletions sknn/tests/test_data.py
@@ -17,12 +17,12 @@ def setUp(self):
         self.nn = MLPR(
             layers=[L("Linear")],
             n_iter=1,
-            batch_size=2,
-            mutator=self._mutate_fn)
+            batch_size=1,
+            callback={'on_batch_start': self._mutate_fn})

-    def _mutate_fn(self, sample):
+    def _mutate_fn(self, Xb, **_):
         self.called += 1
-        sample[sample == 0.0] = self.value
+        Xb[Xb == 0.0] = self.value

     def test_TestCalledOK(self):
         a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
6 changes: 3 additions & 3 deletions sknn/tests/test_types.py
@@ -42,10 +42,10 @@ def test_FitHybrid(self):
         self.nn._fit(X, y)

     def test_FitMutator(self):
-        def mutate(x):
+        def mutate(Xb, **_):
             self.count += 1
-            return x - 0.5
-        self.nn.mutator = mutate
+            Xb -= 0.5
+        self.nn.callback = {'on_batch_start': mutate}

         for t in SPARSE_TYPES:
             sparse_matrix = getattr(scipy.sparse, t)
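
The caveat in the commit message about repeatedly mutating batched data is visible in this very test: ``Xb -= 0.5`` updates the array in place, so if the batch ever aliases the stored training data the shift accumulates across epochs. A short sketch of the distinction, with illustrative values only:

    import numpy

    X = numpy.zeros((8, 16), dtype=numpy.float32)

    # Fancy indexing copies, so mutating such a batch leaves X untouched:
    print(numpy.may_share_memory(X, X[[0, 1, 2, 3]]))   # False
    # A plain slice is a view, so an in-place shift would also modify X,
    # and re-applying it each epoch drifts the data further every time:
    print(numpy.may_share_memory(X, X[0:4]))             # True

    def jitter(Xb, **_):
        # Zero-mean random noise is safe to re-apply every epoch; a
        # deterministic shift like 'Xb -= 0.5' is not when batches alias the data.
        Xb += numpy.random.normal(scale=0.05, size=Xb.shape)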

0 comments on commit 3d527d3
