From 783646cf10121249ff9b3fbb58c0e993813c0081 Mon Sep 17 00:00:00 2001 From: "Alex J. Champandard" Date: Fri, 1 Jan 2016 12:40:56 +0100 Subject: [PATCH 1/2] Monkey-patching sklearn every use rather than upfront so it works more consistently with pickle and alongside other libraries. May need to copy/paste functionality due to the way the original is implemented :-| --- sknn/mlp.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sknn/mlp.py b/sknn/mlp.py index 518a484..75edd28 100644 --- a/sknn/mlp.py +++ b/sknn/mlp.py @@ -9,6 +9,7 @@ import time import logging import itertools +import contextlib log = logging.getLogger('sknn') @@ -318,17 +319,20 @@ def _setup(self): super(Classifier, self)._setup() self.label_binarizers = [] + @contextlib.contextmanager + def _patch_sklearn(self): # WARNING: Unfortunately, sklearn's LabelBinarizer handles binary data # as a special case and encodes it very differently to multiclass cases. # In our case, we want to have 2D outputs when there are 2 classes, or # the predicted probabilities (e.g. Softmax) will be incorrect. # The LabelBinarizer is also implemented in a way that this cannot be - # customized without a providing a complete rewrite, so here we patch - # the `type_of_target` function for this to work correctly, + # customized without providing a near-complete rewrite, so here we patch + # the `type_of_target` function for this to work correctly. import sklearn.preprocessing.label as spl - assert 'type_of_target' in dir(spl),\ - "Could not setup sklearn.preprocessing.label.LabelBinarizer functionality." + backup = spl.type_of_target spl.type_of_target = lambda _: "multiclass" + yield + spl.type_of_target = backup def fit(self, X, y, w=None): """Fit the neural network to symbolic labels as a classification problem. @@ -368,7 +372,8 @@ def fit(self, X, y, w=None): # Deal deal with single- and multi-output classification problems. 
LB = sklearn.preprocessing.LabelBinarizer self.label_binarizers = [LB() for _ in range(y.shape[1])] - ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)] + with self._patch_sklearn(): + ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)] yp = numpy.concatenate(ys, axis=1) # Also transform the validation set if it was explicitly specified. @@ -376,7 +381,8 @@ def fit(self, X, y, w=None): X_v, y_v = self.valid_set if y_v.ndim == 1: y_v = y_v.reshape((y_v.shape[0], 1)) - ys = [lb.transform(y_v[:,i]) for i, lb in enumerate(self.label_binarizers)] + with self._patch_sklearn(): + ys = [lb.transform(y_v[:,i]) for i, lb in enumerate(self.label_binarizers)] y_vp = numpy.concatenate(ys, axis=1) self.valid_set = (X_v, y_vp) From 5e908c34d9a52738f85ee9430aaf19e1686f7f28 Mon Sep 17 00:00:00 2001 From: "Alex J. Champandard" Date: Fri, 1 Jan 2016 12:49:47 +0100 Subject: [PATCH 2/2] Fix test name, coverage was below 100%! --- sknn/tests/test_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sknn/tests/test_conv.py b/sknn/tests/test_conv.py index 3ff8d38..8dce149 100644 --- a/sknn/tests/test_conv.py +++ b/sknn/tests/test_conv.py @@ -235,7 +235,7 @@ def _run(self, activation): def test_RectifierConv(self): self._run("Rectifier") - def test_RectifierConv(self): + def test_ExponentialLinear(self): self._run("ExpLin") def test_SigmoidConv(self):