Skip to content

Commit

Permalink
FIX scikit-learn#6420: Cloning decision tree estimators breaks criter…
Browse files Browse the repository at this point in the history
…ion objects (scikit-learn#7680)
  • Loading branch information
olologin authored and Sundrique committed Jun 14, 2017
1 parent d41f06c commit 2f7f34c
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 8 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -67,6 +67,10 @@ Bug fixes
<https://github.com/scikit-learn/scikit-learn/pull/6178>`_) by `Bertrand
Thirion`_

- Tree splitting criterion classes' cloning/pickling is now memory safe
(`#7680 <https://github.com/scikit-learn/scikit-learn/pull/7680>`_).
By `Ibraim Ganiev`_.

.. _changes_0_18_1:

Version 0.18.1
Expand Down
1 change: 1 addition & 0 deletions sklearn/tree/_criterion.pxd
Expand Up @@ -34,6 +34,7 @@ cdef class Criterion:
cdef SIZE_t end

cdef SIZE_t n_outputs # Number of outputs
cdef SIZE_t n_samples # Number of samples
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
cdef double weighted_n_samples # Weighted number of samples (in total)
cdef double weighted_n_node_samples # Weighted number of samples in the node
Expand Down
9 changes: 6 additions & 3 deletions sklearn/tree/_criterion.pyx
Expand Up @@ -235,6 +235,7 @@ cdef class ClassificationCriterion(Criterion):
self.end = 0

self.n_outputs = n_outputs
self.n_samples = 0
self.n_node_samples = 0
self.weighted_n_node_samples = 0.0
self.weighted_n_left = 0.0
Expand Down Expand Up @@ -273,11 +274,10 @@ cdef class ClassificationCriterion(Criterion):

def __dealloc__(self):
"""Destructor."""

free(self.n_classes)

def __reduce__(self):
return (ClassificationCriterion,
return (type(self),
(self.n_outputs,
sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)),
self.__getstate__())
Expand Down Expand Up @@ -710,6 +710,7 @@ cdef class RegressionCriterion(Criterion):
self.end = 0

self.n_outputs = n_outputs
self.n_samples = n_samples
self.n_node_samples = 0
self.weighted_n_node_samples = 0.0
self.weighted_n_left = 0.0
Expand All @@ -734,7 +735,7 @@ cdef class RegressionCriterion(Criterion):
raise MemoryError()

def __reduce__(self):
return (RegressionCriterion, (self.n_outputs,), self.__getstate__())
return (type(self), (self.n_outputs, self.n_samples), self.__getstate__())

cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
Expand Down Expand Up @@ -881,6 +882,7 @@ cdef class MSE(RegressionCriterion):
MSE = var_left + var_right
"""

cdef double node_impurity(self) nogil:
"""Evaluate the impurity of the current node, i.e. the impurity of
samples[start:end]."""
Expand Down Expand Up @@ -1004,6 +1006,7 @@ cdef class MAE(RegressionCriterion):
self.end = 0

self.n_outputs = n_outputs
self.n_samples = n_samples
self.n_node_samples = 0
self.weighted_n_node_samples = 0.0
self.weighted_n_left = 0.0
Expand Down
3 changes: 1 addition & 2 deletions sklearn/tree/_tree.pyx
Expand Up @@ -547,8 +547,7 @@ cdef class Tree:
# (i.e. through `_resize` or `__setstate__`)
property n_classes:
def __get__(self):
# it's small; copy for memory safety
return sizet_ptr_to_ndarray(self.n_classes, self.n_outputs).copy()
return sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)

property children_left:
def __get__(self):
Expand Down
4 changes: 2 additions & 2 deletions sklearn/tree/_utils.pyx
Expand Up @@ -62,10 +62,10 @@ cdef inline UINT32_t our_rand_r(UINT32_t* seed) nogil:


cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
"""Encapsulate data into a 1D numpy array of intp's."""
"""Return copied data as 1D numpy array of intp's."""
cdef np.npy_intp shape[1]
shape[0] = <np.npy_intp> size
return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data)
return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data).copy()


cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high,
Expand Down
33 changes: 32 additions & 1 deletion sklearn/tree/tests/test_tree.py
@@ -1,6 +1,7 @@
"""
Testing for the tree module (sklearn.tree).
"""
import copy
import pickle
from functools import partial
from itertools import product
Expand Down Expand Up @@ -42,12 +43,14 @@

from sklearn import tree
from sklearn.tree._tree import TREE_LEAF
from sklearn.tree.tree import CRITERIA_CLF
from sklearn.tree.tree import CRITERIA_REG
from sklearn import datasets

from sklearn.utils import compute_sample_weight

CLF_CRITERIONS = ("gini", "entropy")
REG_CRITERIONS = ("mse", "mae")
REG_CRITERIONS = ("mse", "mae", "friedman_mse")

CLF_TREES = {
"DecisionTreeClassifier": DecisionTreeClassifier,
Expand Down Expand Up @@ -1597,6 +1600,7 @@ def test_no_sparse_y_support():
for name in ALL_TREES:
yield (check_no_sparse_y_support, name)


def test_mae():
# check MAE criterion produces correct results
# on small toy dataset
Expand All @@ -1609,3 +1613,30 @@ def test_mae():
dt_mae.fit([[3],[5],[3],[8],[5]],[6,7,3,4,3], [0.6,0.3,0.1,1.0,0.3])
assert_array_equal(dt_mae.tree_.impurity, [7.0/2.3, 3.0/0.7, 4.0/1.6])
assert_array_equal(dt_mae.tree_.value.flat, [4.0, 6.0, 4.0])


def test_criterion_copy():
# Let's check whether copy of our criterion has the same type
# and properties as original
n_outputs = 3
n_classes = np.arange(3, dtype=np.intp)
n_samples = 100

def _pickle_copy(obj):
return pickle.loads(pickle.dumps(obj))
for copy_func in [copy.copy, copy.deepcopy, _pickle_copy]:
for _, typename in CRITERIA_CLF.items():
criteria = typename(n_outputs, n_classes)
result = copy_func(criteria).__reduce__()
typename_, (n_outputs_, n_classes_), _ = result
assert_equal(typename, typename_)
assert_equal(n_outputs, n_outputs_)
assert_array_equal(n_classes, n_classes_)

for _, typename in CRITERIA_REG.items():
criteria = typename(n_outputs, n_samples)
result = copy_func(criteria).__reduce__()
typename_, (n_outputs_, n_samples_), _ = result
assert_equal(typename, typename_)
assert_equal(n_outputs, n_outputs_)
assert_equal(n_samples, n_samples_)

0 comments on commit 2f7f34c

Please sign in to comment.