Bugfix: BaseEstimator __getstate__ in Python 3.11
Since Python 3.11, objects have a __getstate__ method by default:

python/cpython#70766

Therefore, the AttributeError that BaseEstimator.__getstate__ relies on is
no longer raised, so the code no longer falls back to the object's __dict__:

https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef5ee2a8aea80498388690e2213118efd/sklearn/base.py#L274-L280

If the instance dict of the object is empty, the return value will,
however, be None. Therefore, the line below calling state.items()
results in an error.

This bugfix checks whether the state is None and, if so, uses a copy of the
object's __dict__ instead (which should always be empty in that case).
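
The None check can be sketched outside of scikit-learn as follows. This is only an illustration under stated assumptions: `getstate_with_fallback` and `Empty` are hypothetical names, not part of scikit-learn, and the sketch ignores `__slots__` entirely.

```python
def getstate_with_fallback(obj):
    """Hypothetical helper mirroring the None check described above."""
    # object.__getstate__ only exists on Python >= 3.11.
    default_getstate = getattr(object, "__getstate__", None)
    if default_getstate is None:
        # Older Python: no default method, so use the instance dict directly.
        return obj.__dict__.copy()

    state = default_getstate(obj)
    if state is None:
        # Python >= 3.11 returns None when the instance dict is empty.
        state = obj.__dict__.copy()
    return state


class Empty:
    pass


empty_state = getstate_with_fallback(Empty())
# The fallback yields a dict, so calling .items() on it is safe.
assert list(empty_state.items()) == []
```

Without the `state is None` branch, the `.items()` call would fail on Python 3.11 and later for an object whose instance dict is empty.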

Not addressed in this PR is how to deal with slots (see also discussion
in scikit-learn#10079). When there are __slots__, __getstate__ will actually return
a tuple, as documented here:

https://docs.python.org/3/library/pickle.html#object.__getstate__

The user would thus still get an uninformative error message.
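
The `__slots__` case mentioned above can be observed with a small sketch. The class `Slotted` and the helper `state_kind` are hypothetical names used only for illustration; on Python versions before 3.11 there is no default `object.__getstate__`, so the helper reports that instead.

```python
import sys


class Slotted:
    # A class with __slots__ and therefore no instance __dict__.
    __slots__ = ("x",)

    def __init__(self):
        self.x = 1


def state_kind(obj):
    """Classify what the default __getstate__ returns (hypothetical helper)."""
    default_getstate = getattr(object, "__getstate__", None)
    if default_getstate is None:
        # Python < 3.11 has no default object.__getstate__.
        return "unavailable"
    state = default_getstate(obj)
    if state is None:
        return "none"
    if isinstance(state, tuple):
        return "tuple"
    return "dict"


kind = state_kind(Slotted())
# A tuple has no .items() method, so code that assumes a dict state would
# fail here with an AttributeError on Python >= 3.11.
print(kind)
```

Because a tuple state has no `items()` method, code that assumes a dict-like state fails for such classes, which is the case this commit leaves unaddressed.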
BenjaminBossan committed Dec 13, 2022
1 parent c0eb3d3 commit d5d91ab
Showing 2 changed files with 28 additions and 0 deletions.
4 changes: 4 additions & 0 deletions sklearn/base.py
@@ -273,7 +273,11 @@ def __repr__(self, N_CHAR_MAX=700):
    def __getstate__(self):
        try:
            state = super().__getstate__()
            if state is None:
                state = self.__dict__.copy()
        except AttributeError:
            # TODO: Remove once Python < 3.11 is dropped, as there will never be
            # an AttributeError
            state = self.__dict__.copy()

        if type(self).__module__.startswith("sklearn."):
24 changes: 24 additions & 0 deletions sklearn/tests/test_base.py
@@ -675,3 +675,27 @@ def test_clone_keeps_output_config():
    ss_clone = clone(ss)
    config_clone = _get_output_config("transform", ss_clone)
    assert config == config_clone


def test_parent_object_empty_instance_dict():
    # Since Python 3.11, Python objects have a __getstate__ method by default
    # that returns None if the instance dict is empty
    class Empty:
        pass

    class Estimator(Empty, BaseEstimator):
        pass

    state = Estimator().__getstate__()
    expected = {"_sklearn_version": sklearn.__version__}
    assert state == expected


def test_base_estimator_empty_instance_dict():
    # Since Python 3.11, Python objects have a __getstate__ method by default
    # that returns None if the instance dict is empty

    # this should not raise
    state = BaseEstimator().__getstate__()
    expected = {"_sklearn_version": sklearn.__version__}
    assert state == expected
