From 058f6bc151eb5b422bf648031be03747fb27ea77 Mon Sep 17 00:00:00 2001 From: Olivier Gagnon Date: Thu, 13 Aug 2015 10:19:17 -0400 Subject: [PATCH] Added a method to reset the neural network to the best state when using n_stable argument --- sknn/backend/pylearn2/ae.py | 9 +++++++++ sknn/nn.py | 10 +++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sknn/backend/pylearn2/ae.py b/sknn/backend/pylearn2/ae.py index cc33130..16ac09e 100644 --- a/sknn/backend/pylearn2/ae.py +++ b/sknn/backend/pylearn2/ae.py @@ -85,3 +85,12 @@ def _create_ae_datasets(self, ds, layers): trds = transformer_dataset.TransformerDataset(raw=ds, transformer=stack) trainsets.append(trds) return trainsets + + def _mlp_get_weights(self, l): + return l.get_weights() + + def _mlp_to_array(self): + return [i.get_value() for i in self.dca.get_params()] + + def _array_to_mlp(self, array, nn): + self.dca.set_params(array) diff --git a/sknn/nn.py b/sknn/nn.py index 0376b6d..71d1160 100644 --- a/sknn/nn.py +++ b/sknn/nn.py @@ -408,7 +408,8 @@ def __init__( self.debug = debug self.verbose = verbose self.weights = None - + self.best_valid_network = None + self._backend = None self._create_logger() self._setup() @@ -449,6 +450,7 @@ def _train_layer(self, trainer, layer, dataset): layer.monitor.channels = {str(k): v for k, v in layer.monitor.channels.items()} best_valid_error = float("inf") + for i in itertools.count(1): start = time.time() trainer.train(dataset=dataset) @@ -473,6 +475,9 @@ def _train_layer(self, trainer, layer, dataset): time.time() - start )) + if best_valid: + self.best_valid_network = self._backend._mlp_to_array() + if not trainer.continue_learning(layer): log.debug("") log.info("Early termination condition fired at %i iterations.", i) @@ -481,3 +486,6 @@ def _train_layer(self, trainer, layer, dataset): log.debug("") log.info("Terminating after specified %i total iterations.", i) break + + def reset_best_network(self): + self._backend._array_to_mlp(self.best_valid_network, self._backend.mlp) \ No newline at end of file