Skip to content

Commit

Permalink
[ENH] Support class label encoding in fit and predict (#18)
Browse files Browse the repository at this point in the history
* Add label encoder

The label encoder converts the original labels into integers (0, 1, 2, ...).

* check dtype and add comments

* Bug fix

* bug fix

* Disable partial mode

Label encoder does not deal with partial mode yet.

* Label encoder with scikit-learn

* bug fix

* Add utility vars in __init__()

* Bug fix

There is a typo. :)

* black formatting

* Update cascade.py

* modify save and load

* fix format

* Add testing case for label encoder

* fix format

* fix format

* black formatting

* add CHANGELOG.rst
  • Loading branch information
NiMaZi committed Feb 6, 2021
1 parent 5dc0906 commit ad030f4
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Version 0.1.*
.. |Fix| replace:: :raw-html:`<span class="badge badge-danger">Fix</span>` :raw-latex:`{\small\sc [Fix]}`
.. |API| replace:: :raw-html:`<span class="badge badge-warning">API Change</span>` :raw-latex:`{\small\sc [API Change]}`

- |Feature| support class label encoding (`#18 <https://github.com/LAMDA-NJU/Deep-Forest/pull/18>`__) @NiMaZi
- |Feature| support sample weight in :meth:`fit` (`#7 <https://github.com/LAMDA-NJU/Deep-Forest/pull/7>`__) @tczhao
- |Feature| configurable predictor parameter (`#9 <https://github.com/LAMDA-NJU/Deep-Forest/issues/10>`__) @tczhao
- |Enhancement| add base class ``BaseEstimator`` and ``ClassifierMixin`` (`#8 <https://github.com/LAMDA-NJU/Deep-Forest/pull/8>`__) @pjgao
153 changes: 103 additions & 50 deletions deepforest/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numbers
import numpy as np
from abc import ABCMeta, abstractmethod
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.multiclass import type_of_target
from sklearn.base import BaseEstimator, ClassifierMixin

from . import _utils
Expand Down Expand Up @@ -155,7 +157,6 @@ def _build_predictor(
partial_mode : :obj:`bool`, default=False
Whether to train the deep forest in partial mode. For large
datasets, it is recommended to use the partial mode.
- If ``True``, the partial mode is activated and all fitted
estimators will be dumped in a local buffer;
- If ``False``, all fitted estimators are directly stored in the
Expand All @@ -172,7 +173,6 @@ def _build_predictor(
instance used by :mod:`np.random`.
verbose : :obj:`int`, default=1
Controls the verbosity when fitting and predicting.
- If ``<= 0``, silent mode, which means no logging information will
be displayed;
- If ``1``, logging information on the cascade layer level will be
Expand All @@ -181,14 +181,51 @@ def _build_predictor(
"""


def deepforest_model_doc(header):
"""Decorator on obtaining documentation for deep forest models."""
__fit_doc = """
.. note::
Deep forest supports two kinds of modes for training:
- **Full memory mode**, in which the training / testing data and
all fitted estimators are directly stored in the memory.
- **Partial mode**, in which after fitting each estimator using
the training data, it will be dumped in the buffer. During the
evaluating stage, the dumped estimators are reloaded into the
memory sequentially to evaluate the testing data.
By setting the ``partial_mode`` to ``True``, the partial mode is
activated, and a local buffer will be created at the current
directory. The partial mode is able to reduce the running memory
cost when training the deep forest.
Parameters
----------
X : :obj:`numpy.ndarray` of shape (n_samples, n_features)
The training data. Internally, it will be converted to
``np.uint8``.
y : :obj:`numpy.ndarray` of shape (n_samples,)
The class labels of input samples.
sample_weight : :obj:`numpy.ndarray` of shape (n_samples,), default=None
Sample weights. If ``None``, then samples are equally weighted.
"""


def deepforest_model_doc(header, item):
    """
    Decorator on obtaining documentation for deep forest models.

    Parameters
    ----------
    header : string
        Introduction to the decorated class or method.
    item : string
        Type of the docstring item, one of ``"model"`` or ``"fit"``.

    Returns
    -------
    adddoc : callable
        A decorator that overwrites ``__doc__`` of the decorated object with
        ``header``, a blank line, and the selected docstring item, and then
        returns that same object.

    Raises
    ------
    KeyError
        Raised when the returned decorator is applied and ``item`` is not a
        known docstring item.
    """

    def get_doc(item):
        """Return the selected item."""
        __doc = {"model": __model_doc, "fit": __fit_doc}
        return __doc[item]

    def adddoc(cls):
        # Only the selected item follows the header. Unconditionally adding
        # the model documentation as well would leak the constructor
        # parameters into the docstring of decorated methods such as `fit`.
        cls.__doc__ = header + "\n\n" + get_doc(item)

        return cls

    return adddoc
Expand Down Expand Up @@ -422,7 +459,6 @@ def _repr_performance(self, pivot):
def predict(self, X):
"""
Predict class labels or regression values for X.
For classification, the predicted class for each sample in X is
returned. For regression, the predicted value based on X is returned.
"""
Expand All @@ -433,35 +469,7 @@ def n_aug_features_(self):

# flake8: noqa: E501
def fit(self, X, y, sample_weight=None):
"""
Build a deep forest using the training data.
.. note::

Deep forest supports two kinds of modes for training:
- **Full memory mode**, in which the training / testing data and
all fitted estimators are directly stored in the memory.
- **Partial mode**, in which after fitting each estimator using
the training data, it will be dumped in the buffer. During the
evaluating stage, the dumped estimators are reloaded into the
memory sequentially to evaluate the testing data.
By setting the ``partial_mode`` to ``True``, the partial mode is
activated, and a local buffer will be created at the current
directory. The partial mode is able to reduce the running memory
cost when training the deep forest.
Parameters
----------
X : :obj:`numpy.ndarray` of shape (n_samples, n_features)
The training data. Internally, it will be converted to
``np.uint8``.
y : :obj:`numpy.ndarray` of shape (n_samples,)
The class labels of input samples.
sample_weight : :obj:`numpy.ndarray` of shape (n_samples,), default=None
Sample weights. If ``None``, then samples are equally weighted.
"""
self._check_input(X, y)
self._validate_params()
n_counter = 0 # a counter controlling the early stopping
Expand Down Expand Up @@ -700,15 +708,11 @@ def fit(self, X, y, sample_weight=None):
def save(self, dirname="model"):
"""
Save the model to the specified directory.
Parameters
----------
dirname : :obj:`str`, default="model"
The name of the output directory.
.. warning::
Other methods on model serialization such as :mod:`pickle` or
:mod:`joblib` are not recommended, especially when ``partial_mode``
is set to ``True``.
Expand All @@ -726,8 +730,15 @@ def save(self, dirname="model"):
d["buffer"] = self.buffer_
d["verbose"] = self.verbose
d["use_predictor"] = self.use_predictor

if self.use_predictor:
d["predictor_name"] = self.predictor_name

# Save label encoder if labels are encoded.
if hasattr(self, "labels_are_encoded"):
d["labels_are_encoded"] = self.labels_are_encoded
d["label_encoder"] = self.label_encoder_

_io.model_saveobj(dirname, "param", d)
_io.model_saveobj(dirname, "binner", self.binners_)
_io.model_saveobj(dirname, "layer", self.layers_, self.partial_mode)
Expand All @@ -740,15 +751,11 @@ def save(self, dirname="model"):
def load(self, dirname):
"""
Load the model from the specified directory.
Parameters
----------
dirname : :obj:`str`
The name of the input directory.
.. note::
The dumped model after calling :meth:`load_model` is not exactly
the same as the model before saving, because many objects
irrelevant to model inference will not be saved.
Expand All @@ -764,6 +771,11 @@ def load(self, dirname):
self.verbose = d["verbose"]
self.use_predictor = d["use_predictor"]

# Load label encoder if labels are encoded.
if "labels_are_encoded" in d:
self.labels_are_encoded = d["labels_are_encoded"]
self.label_encoder_ = d["label_encoder"]

# Load internal containers
self.binners_ = _io.model_loadobj(dirname, "binner")
self.layers_ = _io.model_loadobj(dirname, "layer", d)
Expand All @@ -789,7 +801,7 @@ def clean(self):


@deepforest_model_doc(
"""Implementation of the deep forest for classification."""
"""Implementation of the deep forest for classification.""", "model"
)
class CascadeForestClassifier(BaseCascadeForest, ClassifierMixin):
def __init__(
Expand Down Expand Up @@ -832,20 +844,63 @@ def __init__(
verbose=verbose,
)

# Used to deal with classification labels
self.labels_are_encoded = False
self.type_of_target_ = None
self.label_encoder_ = None

def _encode_class_labels(self, y):
    """
    Fit the internal label encoder and return encoded labels.

    Parameters
    ----------
    y : array-like
        The original class labels of the training samples.

    Returns
    -------
    encoded_y
        The result of ``LabelEncoder.fit_transform`` on ``y``.

    Raises
    ------
    ValueError
        If ``type_of_target(y)`` is neither ``"binary"`` nor
        ``"multiclass"``.
    """
    self.type_of_target_ = type_of_target(y)

    # Reject unsupported targets before touching the encoder state, so that
    # a failed call never leaves a half-configured label encoder behind.
    if self.type_of_target_ not in ("binary", "multiclass"):
        msg = (
            "CascadeForestClassifier is used for binary and multiclass"
            " classification, whereas the training labels seem not to"
            " be any one of them."
        )
        raise ValueError(msg)

    self.labels_are_encoded = True
    self.label_encoder_ = LabelEncoder()

    return self.label_encoder_.fit_transform(y)

def _decode_class_labels(self, y):
"""
Transform the predicted labels back to original encoding.
"""
if self.labels_are_encoded:
decoded_y = self.label_encoder_.inverse_transform(y)
else:
decoded_y = y

return decoded_y

def _repr_performance(self, pivot):
msg = "Val Acc = {:.3f} %"
return msg.format(pivot * 100)

@deepforest_model_doc(
    """Build a deep forest using the training data.""", "fit"
)
def fit(self, X, y, sample_weight=None):
    # NOTE: no inline docstring here on purpose -- the decorator above
    # overwrites ``__doc__`` with the header plus the "fit" documentation.

    # Check the input for classification: the labels are validated and
    # encoded before being handed to the parent implementation.
    y = self._encode_class_labels(y)

    super().fit(X, y, sample_weight)

    # Return the fitted estimator, as the scikit-learn estimator API
    # expects, so that calls such as ``model.fit(X, y).predict(X)`` work.
    return self

def predict_proba(self, X):
"""
Predict class probabilities for X.
Parameters
----------
X : :obj:`numpy.ndarray` of shape (n_samples, n_features)
The input samples. Internally, its dtype will be converted to
``np.uint8``.
Returns
-------
proba : :obj:`numpy.ndarray` of shape (n_samples, n_classes)
Expand Down Expand Up @@ -917,18 +972,16 @@ def predict_proba(self, X):
def predict(self, X):
    """
    Predict class for X.

    Parameters
    ----------
    X : :obj:`numpy.ndarray` of shape (n_samples, n_features)
        The input samples. Internally, its dtype will be converted to
        ``np.uint8``.

    Returns
    -------
    y : :obj:`numpy.ndarray` of shape (n_samples,)
        The predicted classes, decoded with ``_decode_class_labels`` so that
        they are expressed in the original labels when labels were encoded.
    """
    proba = self.predict_proba(X)

    # The argmax over the class axis is only a column index; it must be
    # passed through the label decoder before being returned, otherwise the
    # caller gets encoded indices instead of the original class labels.
    y = self._decode_class_labels(np.argmax(proba, axis=1))

    return y
32 changes: 32 additions & 0 deletions tests/test_model_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from numpy.testing import assert_array_equal

from sklearn.datasets import load_digits
from deepforest import CascadeForestClassifier


def test_model_input_label_encoder():
    """Test if the model behaves the same with and without label encoding."""

    # Load data
    X, y = load_digits(return_X_y=True)
    y_as_str = np.char.add("label_", y.astype(str))

    # Train model on integer labels. Labels should look like: 0, 1, 2, ...
    int_model = CascadeForestClassifier(random_state=1)
    int_model.fit(X, y)
    y_pred_int_labels = int_model.predict(X)

    # Train model on string labels. Labels should look like:
    # "label_0", "label_1", "label_2", ...
    str_model = CascadeForestClassifier(random_state=1)
    str_model.fit(X, y_as_str)
    y_pred_str_labels = str_model.predict(X)

    # Check if the underlying data are the same: prefixing the integer
    # predictions must reproduce the string predictions exactly.
    y_pred_int_labels_as_str = np.char.add(
        "label_", y_pred_int_labels.astype(str)
    )
    assert_array_equal(y_pred_str_labels, y_pred_int_labels_as_str)

    # Clean up the buffer of both models. Using distinct names keeps the
    # first model reachable, so it is cleaned as well.
    int_model.clean()
    str_model.clean()

0 comments on commit ad030f4

Please sign in to comment.