diff --git a/sknn/mlp.py b/sknn/mlp.py index 518a484..75edd28 100644 --- a/sknn/mlp.py +++ b/sknn/mlp.py @@ -9,6 +9,7 @@ import time import logging import itertools +import contextlib log = logging.getLogger('sknn') @@ -318,17 +319,20 @@ def _setup(self): super(Classifier, self)._setup() self.label_binarizers = [] + @contextlib.contextmanager + def _patch_sklearn(self): # WARNING: Unfortunately, sklearn's LabelBinarizer handles binary data # as a special case and encodes it very differently to multiclass cases. # In our case, we want to have 2D outputs when there are 2 classes, or # the predicted probabilities (e.g. Softmax) will be incorrect. # The LabelBinarizer is also implemented in a way that this cannot be - # customized without a providing a complete rewrite, so here we patch - # the `type_of_target` function for this to work correctly, + # customized without a providing a near-complete rewrite, so here we patch + # the `type_of_target` function for this to work correctly. import sklearn.preprocessing.label as spl - assert 'type_of_target' in dir(spl),\ - "Could not setup sklearn.preprocessing.label.LabelBinarizer functionality." + backup = spl.type_of_target spl.type_of_target = lambda _: "multiclass" + yield + spl.type_of_target = backup def fit(self, X, y, w=None): """Fit the neural network to symbolic labels as a classification problem. @@ -368,7 +372,8 @@ def fit(self, X, y, w=None): # Deal deal with single- and multi-output classification problems. LB = sklearn.preprocessing.LabelBinarizer self.label_binarizers = [LB() for _ in range(y.shape[1])] - ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)] + with self._patch_sklearn(): + ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)] yp = numpy.concatenate(ys, axis=1) # Also transform the validation set if it was explicitly specified. @@ -376,7 +381,8 @@ def fit(self, X, y, w=None): X_v, y_v = self.valid_set if y_v.ndim == 1: y_v = y_v.reshape((y_v.shape[0], 1)) - ys = [lb.transform(y_v[:,i]) for i, lb in enumerate(self.label_binarizers)] + with self._patch_sklearn(): + ys = [lb.transform(y_v[:,i]) for i, lb in enumerate(self.label_binarizers)] y_vp = numpy.concatenate(ys, axis=1) self.valid_set = (X_v, y_vp) diff --git a/sknn/tests/test_conv.py b/sknn/tests/test_conv.py index 3ff8d38..8dce149 100644 --- a/sknn/tests/test_conv.py +++ b/sknn/tests/test_conv.py @@ -235,7 +235,7 @@ def _run(self, activation): def test_RectifierConv(self): self._run("Rectifier") - def test_RectifierConv(self): + def test_ExponentialLinear(self): self._run("ExpLin") def test_SigmoidConv(self):