/
pickler.py
236 lines (191 loc) · 8.16 KB
/
pickler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Pickler for values, functions, and classes.
Pickles created by the pickling library contain non-ASCII characters, so
we base64-encode the results so that we can put them in JSON objects.
The pickler is used to embed FlatMap callable objects into the workflow JSON
description.
The pickler module should be used to pickle functions and modules; for values,
the coders.*PickleCoder classes should be used instead.
"""
import base64
import logging
import sys
import traceback
import types
import zlib
import dill
def is_nested_class(cls):
    """Returns true if argument is a class object that appears to be nested.

    A class is treated as nested when its name is not bound in the namespace
    of the module recorded in its ``__module__`` attribute.

    Args:
      cls: Any object. Objects that are not classes are never nested.

    Returns:
      True if cls is a class whose name is missing from its module's
      namespace, False otherwise.
    """
    if not isinstance(cls, type):
        return False
    # The builtins module is called '__builtin__' on Python 2 and 'builtins'
    # on Python 3. Some builtin types (e.g. NoneType) are not bound by name in
    # that module, yet they are not nested classes.
    if cls.__module__ in ('__builtin__', 'builtins'):
        return False
    module = sys.modules.get(cls.__module__)
    if module is None:
        # The defining module is not loaded (or __module__ is None), so there
        # is no namespace in which a containing class could be looked up.
        return False
    return cls.__name__ not in module.__dict__
def find_containing_class(nested_class):
    """Finds containing class of a nested class passed as argument.

    Performs a depth-first search that starts at the module named by
    ``nested_class.__module__`` and descends into every class found in a
    namespace along the way.

    Args:
      nested_class: The class object to locate.

    Returns:
      A (container, attribute_name) tuple such that
      ``getattr(container, attribute_name)`` is nested_class, or None if no
      container was found.

    Raises:
      KeyError: If the module named by nested_class.__module__ is not in
        sys.modules.
    """
    # types.ClassType (old-style classes) only exists on Python 2.
    class_types = (type, getattr(types, 'ClassType', type))
    # ids of namespaces already searched; guards against unbounded recursion
    # when classes reference each other through their attributes.
    visited = set()

    def find_containing_class_inner(outer):
        if id(outer) in visited:
            return None
        visited.add(id(outer))
        for k, v in outer.__dict__.items():
            if v is nested_class:
                return outer, k
            elif isinstance(v, class_types) and hasattr(v, '__dict__'):
                res = find_containing_class_inner(v)
                if res:
                    return res
        return None

    return find_containing_class_inner(sys.modules[nested_class.__module__])
def _nested_type_wrapper(fun):
    """Wraps the standard pickler handler for class objects.

    The returned handler checks whether the object being pickled is a nested
    class. For nested classes only, it saves a reference to the containing
    class so the nested structure is recreated during unpickling.

    Args:
      fun: Original pickler handler for type objects.

    Returns:
      A handler for type objects that also copes with nested classes.
    """

    def wrapper(pickler, obj):
        # Nested classes defined in the __main__ module need no special
        # treatment: the pickler itself saves the constituent parts of the
        # type (name, base classes, dictionary) and rebuilds it on unpickle.
        needs_lookup = is_nested_class(obj) and obj.__module__ != '__main__'
        if needs_lookup:
            container_and_attr = find_containing_class(obj)
            if container_and_attr is not None:
                # Unpickles as getattr(containing_class, attribute_name).
                return pickler.save_reduce(
                    getattr, container_and_attr, obj=obj)
        try:
            return fun(pickler, obj)
        except dill.dill.PicklingError:
            # Fall back to rebuilding the type from its parts.
            # pylint: disable=protected-access
            type_parts = (
                type(obj),
                obj.__name__,
                obj.__bases__,
                dill.dill._dict_from_dictproxy(obj.__dict__),
            )
            return pickler.save_reduce(
                dill.dill._create_type, type_parts, obj=obj)
            # pylint: enable=protected-access

    return wrapper
# Monkey patch the standard pickler dispatch table entry for type objects.
# Dill, for certain types, defers to the standard pickler (including type
# objects). We wrap the standard handler using _nested_type_wrapper() because
# for nested classes we want to pickle the actual enclosing class object so we
# can recreate it during unpickling.
# TODO(silviuc): Make sure we submit the fix upstream to GitHub dill project.
dill.dill.Pickler.dispatch[type] = _nested_type_wrapper(
    dill.dill.Pickler.dispatch[type])
# Dill pickles generator objects without complaint, but on some versions of
# Python unpickling them fails with
# TypeError: object.__new__(generator) is not safe, use generator.__new__()
def reject_generators(unused_pickler, unused_obj):
    """Pickler handler that refuses to pickle generator objects.

    Args:
      unused_pickler: The pickler instance (ignored).
      unused_obj: The generator object (ignored).

    Raises:
      TypeError: Always.
    """
    message = "can't (safely) pickle generator objects"
    raise TypeError(message)
# Register reject_generators as the dispatch handler for generator objects.
dill.dill.Pickler.dispatch[types.GeneratorType] = reject_generators
# This guard protects against dill not being fully initialized when
# generating docs.
if 'save_module' in dir(dill.dill):

    # Always pickle non-main modules by name.
    old_save_module = dill.dill.save_module

    @dill.dill.register(dill.dill.ModuleType)
    def save_module(pickler, obj):
        """Pickles a module: the main module via dill, others by name.

        Args:
          pickler: The pickler instance doing the saving.
          obj: The module object being pickled.
        """
        if dill.dill.is_dill(pickler) and obj is pickler._main:
            return old_save_module(pickler, obj)
        else:
            dill.dill.log.info('M2: %s' % obj)
            # pylint: disable=protected-access
            pickler.save_reduce(
                dill.dill._import_module, (obj.__name__,), obj=obj)
            # pylint: enable=protected-access
            dill.dill.log.info('# M2')

    # Pickle module dictionaries (commonly found in lambda's globals)
    # by referencing their module.
    old_save_module_dict = dill.dill.save_module_dict
    # Maps id(module.__dict__) -> (module, module.__dict__) for every
    # non-main module seen in sys.modules.
    known_module_dicts = {}

    @dill.dill.register(dict)
    def new_save_module_dict(pickler, obj):
        """Pickles a dict, saving known module dicts as module.__dict__.

        Args:
          pickler: The pickler instance doing the saving.
          obj: The dict being pickled.

        Returns:
          The result of pickler.save_reduce for a recognised module dict,
          otherwise whatever the original dill handler returns.
        """
        obj_id = id(obj)
        # Refresh the cache on first use, or when obj looks like a module
        # dict that has not been recorded yet.
        if not known_module_dicts or '__file__' in obj or '__package__' in obj:
            if obj_id not in known_module_dicts:
                # Iterate over a snapshot: touching module attributes can
                # trigger imports that add entries to sys.modules, and
                # mutating a dict while iterating it raises RuntimeError.
                for m in list(sys.modules.values()):
                    try:
                        if m and m.__name__ != '__main__':
                            d = m.__dict__
                            known_module_dicts[id(d)] = m, d
                    except AttributeError:
                        # Skip modules that do not have the __name__
                        # attribute.
                        pass
        if obj_id in known_module_dicts and dill.dill.is_dill(pickler):
            m = known_module_dicts[obj_id][0]
            try:
                # Check that the module is importable by name before
                # pickling a reference to it.
                # pylint: disable=protected-access
                dill.dill._import_module(m.__name__)
                return pickler.save_reduce(getattr, (m, '__dict__'), obj=obj)
            except (ImportError, AttributeError):
                return old_save_module_dict(pickler, obj)
        else:
            return old_save_module_dict(pickler, obj)

    dill.dill.save_module_dict = new_save_module_dict
def _nest_dill_logging():
    """Prefix all dill logging with its depth in the callstack.

    Replaces dill.dill.log.info with a wrapper that prepends a ruler string,
    truncated to the current call stack depth, to every message.

    Useful for debugging pickling of deeply nested structures.
    """
    original_info = dill.dill.log.info
    ruler = '1 2 3 4 5 6 7 8 9 0 ' * 10

    def new_log_info(msg, *args, **kwargs):
        depth = len(traceback.extract_stack())
        original_info(ruler[:depth] + msg, *args, **kwargs)

    dill.dill.log.info = new_log_info
# Silence the dill pickler's verbose output: only warnings and above pass.
logging.getLogger('dill').setLevel(logging.WARNING)
# TODO(ccy): Currently, there are still instances of pickler.dumps() and
# pickler.loads() being used for data, which results in an unnecessary base64
# encoding. This should be cleaned up.
def dumps(o, enable_trace=True):
    """Pickles an object with dill, compresses it and base64-encodes it.

    Args:
      o: The object to pickle.
      enable_trace: If True and the first pickling attempt fails, dill tracing
        is switched on and pickling is attempted a second time; otherwise the
        original exception is re-raised.

    Returns:
      The base64 encoding of the zlib-compressed dill pickle of o.
    """
    try:
        try:
            pickled = dill.dumps(o)
        except Exception:  # pylint: disable=broad-except
            if not enable_trace:
                raise
            dill.dill._trace(True)  # pylint: disable=protected-access
            pickled = dill.dumps(o)
    finally:
        # Tracing is always switched off again, whatever the outcome.
        dill.dill._trace(False)  # pylint: disable=protected-access

    # Use the highest compression level to decrease peak memory usage (of
    # multiple in-memory copies), and drop the uncompressed pickle as soon as
    # it is no longer needed since it may be large.
    compressed = zlib.compress(pickled, 9)
    del pickled

    return base64.b64encode(compressed)
def loads(encoded, enable_trace=True):
    """Decodes, decompresses and unpickles a value produced by dumps().

    NOTE: unpickling can execute arbitrary code; only call this on data from
    a trusted source.

    Args:
      encoded: Base64-encoded, zlib-compressed dill pickle.
      enable_trace: If True and the first unpickling attempt fails, dill
        tracing is switched on and unpickling is attempted a second time;
        otherwise the original exception is re-raised.

    Returns:
      The unpickled object.
    """
    compressed = base64.b64decode(encoded)
    pickled = zlib.decompress(compressed)
    del compressed  # Free up some possibly large and no-longer-needed memory.

    try:
        try:
            return dill.loads(pickled)
        except Exception:  # pylint: disable=broad-except
            if not enable_trace:
                raise
            dill.dill._trace(True)  # pylint: disable=protected-access
            return dill.loads(pickled)
    finally:
        # Tracing is always switched off again, whatever the outcome.
        dill.dill._trace(False)  # pylint: disable=protected-access
def dump_session(file_path):
    """Pickle the current python session to be used in the worker.

    Note: Due to the inconsistency in the first dump of dill dump_session we
    create and load the dump twice to have consistent results in the worker and
    the running session. Check: https://github.com/uqfoundation/dill/issues/195

    Args:
      file_path: Path handed to dill.dump_session and dill.load_session.

    Returns:
      Whatever the second dill.dump_session call returns.
    """
    # Dump, reload, then dump again; the order matters for the workaround
    # described above.
    dill.dump_session(file_path)
    dill.load_session(file_path)
    return dill.dump_session(file_path)
def load_session(file_path):
    """Load a pickled python session by delegating to dill.load_session.

    Args:
      file_path: Path handed to dill.load_session.

    Returns:
      Whatever dill.load_session returns.
    """
    return dill.load_session(file_path)