Skip to content
This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Permalink
Now considering n_stable based on the training set if there's no vali…
Browse files Browse the repository at this point in the history
…dation set specified. Updates and fixes to documentation and tests accrodingly.
  • Loading branch information
alexjc committed Nov 19, 2015
1 parent 1b46f2b commit 3a06089
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ If you want to use the latest official release, you can do so from PYPI directly

> pip install scikit-neuralnetwork

This will install a copy of `Lasagne` too as a dependency. We recommend you use a virtual environment for Python.
This will install a copy of ``Lasagne`` and other minor packages too as a dependency. We highly recommend you use a virtual environment for Python.

B) Pulling Repositories [Optional]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You'll need to first install some dependencies manually
You'll need some dependencies, which you can install manually as follows::

> pip install numpy scipy theano lasagne

Expand Down
6 changes: 3 additions & 3 deletions sknn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def _reshape(self, X, y=None):
return X, y

def _train(self, X, y):
assert self.n_iter or self.valid_set,\
"Neither n_iter nor valid_set were specified; training would loop forever."
assert self.n_iter or self.n_stable,\
"Neither n_iter nor n_stable were specified; training would loop forever."

best_train_error, best_valid_error = float("inf"), float("inf")
best_params = []
Expand Down Expand Up @@ -159,7 +159,7 @@ def _train(self, X, y):
time.time() - start
))

if best_valid:
if best_valid or (self.valid_set is None and best_train):
best_params = self._backend._mlp_to_array()
n_stable = 0
else:
Expand Down
10 changes: 7 additions & 3 deletions sknn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,16 @@ class NeuralNetwork(object):
n_stable: int, optional
Number of interations after which training should return when the validation
error remains constant. This is a sign that the data has been fitted, or that
optimization may have stalled. Default is ``10``.
error remains (near) constant. This is usually a sign that the data has been
fitted, or that optimization may have stalled. If no validation set is specified,
then stability is judged based on the training error. Default is ``10``.
f_stable: float, optional
Threshold under which the validation error change is assumed to be stable, to
be used in combination with `n_stable`. Default is ``0.001`.
be used in combination with `n_stable`. This is calculated as a relative ratio
of improvement, so if the results are only 0.1% better training is considered
stable. The training set is used as fallback if there's no validation set. Default
is ``0.001`.
valid_set: tuple of array-like, optional
Validation set (X_v, y_v) to be used explicitly while training. Both
Expand Down
2 changes: 1 addition & 1 deletion sknn/tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_FitAutomaticValidation(self):

def test_TrainingInfinite(self):
a_in, a_out = numpy.zeros((8,16)), numpy.zeros((8,4))
self.nn = MLP(layers=[L("Linear")])
self.nn = MLP(layers=[L("Linear")], n_iter=None, n_stable=None)
assert_raises(AssertionError, self.nn._fit, a_in, a_out)


Expand Down

0 comments on commit 3a06089

Please sign in to comment.