Skip to content
This repository was archived by the owner on Jul 10, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions sknn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
import logging
import itertools
import contextlib

log = logging.getLogger('sknn')

Expand Down Expand Up @@ -318,17 +319,20 @@ def _setup(self):
super(Classifier, self)._setup()
self.label_binarizers = []

@contextlib.contextmanager
def _patch_sklearn(self):
# WARNING: Unfortunately, sklearn's LabelBinarizer handles binary data
# as a special case and encodes it very differently to multiclass cases.
# In our case, we want to have 2D outputs when there are 2 classes, or
# the predicted probabilities (e.g. Softmax) will be incorrect.
# The LabelBinarizer is also implemented in a way that this cannot be
# customized without a providing a complete rewrite, so here we patch
# the `type_of_target` function for this to work correctly,
# customized without a providing a near-complete rewrite, so here we patch
# the `type_of_target` function for this to work correctly.
import sklearn.preprocessing.label as spl
assert 'type_of_target' in dir(spl),\
"Could not setup sklearn.preprocessing.label.LabelBinarizer functionality."
backup = spl.type_of_target
spl.type_of_target = lambda _: "multiclass"
yield
spl.type_of_target = backup

def fit(self, X, y, w=None):
"""Fit the neural network to symbolic labels as a classification problem.
Expand Down Expand Up @@ -368,15 +372,17 @@ def fit(self, X, y, w=None):
# Deal deal with single- and multi-output classification problems.
LB = sklearn.preprocessing.LabelBinarizer
self.label_binarizers = [LB() for _ in range(y.shape[1])]
ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)]
with self._patch_sklearn():
ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)]
yp = numpy.concatenate(ys, axis=1)

# Also transform the validation set if it was explicitly specified.
if self.valid_set is not None:
X_v, y_v = self.valid_set
if y_v.ndim == 1:
y_v = y_v.reshape((y_v.shape[0], 1))
ys = [lb.transform(y_v[:,i]) for i, lb in enumerate(self.label_binarizers)]
with self._patch_sklearn():
ys = [lb.transform(y_v[:,i]) for i, lb in enumerate(self.label_binarizers)]
y_vp = numpy.concatenate(ys, axis=1)
self.valid_set = (X_v, y_vp)

Expand Down
2 changes: 1 addition & 1 deletion sknn/tests/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _run(self, activation):
def test_RectifierConv(self):
self._run("Rectifier")

def test_RectifierConv(self):
def test_ExponentialLinear(self):
self._run("ExpLin")

def test_SigmoidConv(self):
Expand Down