Skip to content

Commit

Permalink
[MRG] add warning when importing old or new pickle. (scikit-learn#7248)
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller authored and TomDLT committed Oct 3, 2016
1 parent 31be86a commit c62b049
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 8 deletions.
19 changes: 19 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .utils.fixes import signature
from .utils.deprecation import deprecated
from .exceptions import ChangedBehaviorWarning as _ChangedBehaviorWarning
from . import __version__


@deprecated("ChangedBehaviorWarning has been moved into the sklearn.exceptions"
Expand Down Expand Up @@ -296,6 +297,24 @@ def __repr__(self):
return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
offset=len(class_name),),)

def __getstate__(self):
if type(self).__module__.startswith('sklearn.'):
return dict(self.__dict__.items(), _sklearn_version=__version__)
else:
return dict(self.__dict__.items())

def __setstate__(self, state):
if type(self).__module__.startswith('sklearn.'):
pickle_version = state.pop("_sklearn_version", "pre-0.18")
if pickle_version != __version__:
warnings.warn(
"Trying to unpickle estimator {0} from version {1} when "
"using version {2}. This might lead to breaking code or "
"invalid results. Use at your own risk.".format(
self.__class__.__name__, pickle_version, __version__),
UserWarning)
self.__dict__.update(state)


###############################################################################
class ClassifierMixin(object):
Expand Down
5 changes: 2 additions & 3 deletions sklearn/isotonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,7 @@ def predict(self, T):

def __getstate__(self):
"""Pickle-protocol - return state of the estimator. """
# copy __dict__
state = dict(self.__dict__)
state = super(IsotonicRegression, self).__getstate__()
# remove interpolation method
state.pop('f_', None)
return state
Expand All @@ -423,6 +422,6 @@ def __setstate__(self, state):
We need to rebuild the interpolation function.
"""
self.__dict__.update(state)
super(IsotonicRegression, self).__setstate__(state)
if hasattr(self, '_necessary_X_') and hasattr(self, '_necessary_y_'):
self._build_f(self._necessary_X_, self._necessary_y_)
56 changes: 51 additions & 5 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,29 @@
import numpy as np
import scipy.sparse as sp

import sklearn
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_no_warnings
from sklearn.utils.testing import assert_warns_message

from sklearn.base import BaseEstimator, clone, is_classifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn import datasets
from sklearn.utils import deprecated

from sklearn.base import TransformerMixin
from sklearn.utils.mocking import MockDataFrame
import pickle


#############################################################################
Expand Down Expand Up @@ -235,8 +242,8 @@ def test_is_classifier():
assert_true(is_classifier(svc))
assert_true(is_classifier(GridSearchCV(svc, {'C': [0.1, 1]})))
assert_true(is_classifier(Pipeline([('svc', svc)])))
assert_true(is_classifier(Pipeline([('svc_cv',
GridSearchCV(svc, {'C': [0.1, 1]}))])))
assert_true(is_classifier(Pipeline(
[('svc_cv', GridSearchCV(svc, {'C': [0.1, 1]}))])))


def test_set_params():
Expand All @@ -253,9 +260,6 @@ def test_set_params():


def test_score_sample_weight():
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn import datasets

rng = np.random.RandomState(0)

Expand Down Expand Up @@ -313,3 +317,45 @@ def transform(self, X, y=None):
# the test
assert_true((e.df == cloned_e.df).values.all())
assert_equal(e.scalar_param, cloned_e.scalar_param)


class TreeNoVersion(DecisionTreeClassifier):
def __getstate__(self):
return self.__dict__


def test_pickle_version_warning():
# check that warnings are raised when unpickling in a different version

# first, check no warning when in the same version:
iris = datasets.load_iris()
tree = DecisionTreeClassifier().fit(iris.data, iris.target)
tree_pickle = pickle.dumps(tree)
assert_true(b"version" in tree_pickle)
assert_no_warnings(pickle.loads, tree_pickle)

# check that warning is raised on different version
tree_pickle_other = tree_pickle.replace(sklearn.__version__.encode(),
b"something")
message = ("Trying to unpickle estimator DecisionTreeClassifier from "
"version {0} when using version {1}. This might lead to "
"breaking code or invalid results. "
"Use at your own risk.".format("something",
sklearn.__version__))
assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other)

# check that not including any version also works:
# TreeNoVersion has no getstate, like pre-0.18
tree = TreeNoVersion().fit(iris.data, iris.target)

tree_pickle_noversion = pickle.dumps(tree)
assert_false(b"version" in tree_pickle_noversion)
message = message.replace("something", "pre-0.18")
message = message.replace("DecisionTreeClassifier", "TreeNoVersion")
# check we got the warning about using pre-0.18 pickle
assert_warns_message(UserWarning, message, pickle.loads,
tree_pickle_noversion)

# check that no warning is raised for external estimators
TreeNoVersion.__module__ = "notsklearn"
assert_no_warnings(pickle.loads, tree_pickle_noversion)
2 changes: 2 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ def check_estimators_pickle(name, Estimator):

# pickle and unpickle!
pickled_estimator = pickle.dumps(estimator)
if Estimator.__module__.startswith('sklearn.'):
assert_true(b"version" in pickled_estimator)
unpickled_estimator = pickle.loads(pickled_estimator)

for method in result:
Expand Down

0 comments on commit c62b049

Please sign in to comment.