Skip to content

Commit

Permalink
[Python] support secure mode for python (#584)
Browse files Browse the repository at this point in the history
* support secure mode for python

* add env variable for for secure mode

* lint code
  • Loading branch information
chaokunyang committed Jul 11, 2023
1 parent ffcc761 commit ff69814
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 14 deletions.
52 changes: 49 additions & 3 deletions python/pyfury/_fury.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import enum
import logging
import os
import sys
from dataclasses import dataclass
from typing import Dict, Tuple, TypeVar, Optional, Union, Iterable
Expand Down Expand Up @@ -573,6 +574,7 @@ class Fury:
"reference_resolver",
"class_resolver",
"serialization_context",
"secure_mode",
"buffer",
"pickler",
"unpickler",
Expand All @@ -586,8 +588,14 @@ class Fury:
serialization_context: "SerializationContext"
unpickler: Optional[pickle.Unpickler]

def __init__(self, language=Language.XLANG, reference_tracking: bool = True):
def __init__(
self,
language=Language.XLANG,
reference_tracking: bool = True,
secure_mode: bool = True,
):
self.language = language
self.secure_mode = _ENABLE_SECURITY_MODE_FORCIBLY or secure_mode
self.reference_tracking = reference_tracking
if self.reference_tracking:
self.reference_resolver = MapReferenceResolver()
Expand All @@ -597,7 +605,10 @@ def __init__(self, language=Language.XLANG, reference_tracking: bool = True):
self.class_resolver.initialize()
self.serialization_context = SerializationContext()
self.buffer = Buffer.allocate(32)
self.pickler = pickle.Pickler(self.buffer)
if not secure_mode:
self.pickler = pickle.Pickler(self.buffer)
else:
self.pickler = _PicklerStub(self.buffer)
self.unpickler = None
self._buffer_callback = None
self._buffers = None
Expand Down Expand Up @@ -791,7 +802,10 @@ def _deserialize(
):
if type(buffer) == bytes:
buffer = Buffer(buffer)
self.unpickler = pickle.Unpickler(buffer)
if self.secure_mode:
self.unpickler = _UnpicklerStub(buffer)
else:
self.unpickler = pickle.Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
reader_index = buffer.reader_index
Expand Down Expand Up @@ -980,3 +994,35 @@ def reset_read(self):
def reset(self):
self.reset_write()
self.reset_read()


_ENABLE_SECURITY_MODE_FORCIBLY = os.getenv("ENABLE_SECURITY_MODE_FORCIBLY", "0") in {
"1",
"true",
}


class _PicklerStub:
def __init__(self, buf):
self.buf = buf

def dump(self, o):
raise ValueError(
f"Class {type(o)} is not registered, "
f"pickle is not allowed when secure mode enabled, Please register"
f"the class or pass unsupported_callback"
)

def clear_memo(self):
pass


class _UnpicklerStub:
def __init__(self, buf):
self.buf = buf

def load(self):
raise ValueError(
f"pickle is not allowed when secure mode enabled, Please register"
f"the class or pass unsupported_callback"
)
22 changes: 17 additions & 5 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import sys
from typing import TypeVar, Union, Iterable, get_type_hints

from pyfury._util import get_bit, set_bit, clear_bit
from pyfury._fury import Language, OpaqueObject
from pyfury._fury import Language, OpaqueObject, _PicklerStub, _UnpicklerStub, _ENABLE_SECURITY_MODE_FORCIBLY
from pyfury.error import ClassNotCompatibleError
from pyfury.lib import mmh3
from pyfury.type import is_primitive_type, FuryType, Int8Type, Int16Type, Int32Type, \
Expand Down Expand Up @@ -781,6 +781,7 @@ cdef class ClassInfo:
cdef class Fury:
cdef readonly object language
cdef readonly c_bool reference_tracking
cdef readonly c_bool secure_mode
cdef readonly MapReferenceResolver reference_resolver
cdef readonly ClassResolver class_resolver
cdef readonly SerializationContext serialization_context
Expand All @@ -794,16 +795,24 @@ cdef class Fury:
cdef object _peer_language
cdef list _native_objects

def __init__(self, language=Language.XLANG, reference_tracking: bool = True):
def __init__(
self,
language=Language.XLANG,
reference_tracking: bool = True,
secure_mode: bool = True,
):
self.language = language

self.secure_mode = _ENABLE_SECURITY_MODE_FORCIBLY or secure_mode
self.reference_tracking = reference_tracking
self.reference_resolver = MapReferenceResolver(reference_tracking)
self.class_resolver = ClassResolver(self)
self.class_resolver.initialize()
self.serialization_context = SerializationContext()
self.buffer = Buffer.allocate(32)
self.pickler = pickle.Pickler(self.buffer)
if not secure_mode:
self.pickler = pickle.Pickler(self.buffer)
else:
self.pickler = _PicklerStub(self.buffer)
self.unpickler = None
self._buffer_callback = None
self._buffers = None
Expand Down Expand Up @@ -998,7 +1007,10 @@ cdef class Fury:

cpdef inline _deserialize(
self, Buffer buffer, buffers=None, unsupported_objects=None):
self.unpickler = pickle.Unpickler(buffer)
if self.secure_mode:
self.unpickler = _UnpicklerStub(buffer)
else:
self.unpickler = pickle.Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
cdef int32_t reader_index = buffer.reader_index
Expand Down
1 change: 0 additions & 1 deletion python/pyfury/lib/tests/test_mmh3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@

def test_mmh3():
assert mmh3.hash_buffer(bytearray([1, 2, 3]), seed=47)[0] == -7373655978913577904

10 changes: 5 additions & 5 deletions python/pyfury/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def __init__(self, f1):
def test_reference_cleanup(language):
# FIXME this can't simulate the case where new objects are allocated on memory
# address of released tmp object.
fury_ = Fury(language=language, reference_tracking=True)
fury_ = Fury(language=language, reference_tracking=True, secure_mode=False)
# TODO support Language.XLANG, current unpickler will error for xlang,
o1 = RefTestClass1()
o2 = RefTestClass2(f1=o1)
Expand All @@ -207,7 +207,7 @@ def test_reference_cleanup(language):

@pytest.mark.parametrize("language", [Language.XLANG, Language.PYTHON])
def test_array_serializer(language):
fury_ = Fury(language=language, reference_tracking=True)
fury_ = Fury(language=language, reference_tracking=True, secure_mode=False)
for typecode in PyArraySerializer.typecode_dict.keys():
arr = array.array(typecode, list(range(10)))
assert ser_de(fury_, arr) == arr
Expand Down Expand Up @@ -328,7 +328,7 @@ def __init__(self, f1=None):


def test_register_py_serializer():
fury_ = Fury(language=Language.PYTHON, reference_tracking=True)
fury_ = Fury(language=Language.PYTHON, reference_tracking=True, secure_mode=False)

class Serializer(pyfury.Serializer):
def write(self, buffer, value):
Expand Down Expand Up @@ -380,7 +380,7 @@ def cross_language_read(self, buffer):


def test_pickle_fallback():
fury_ = Fury(language=Language.PYTHON, reference_tracking=True)
fury_ = Fury(language=Language.PYTHON, reference_tracking=True, secure_mode=False)
o1 = [1, True, np.dtype(np.int32)]
data1 = fury_.serialize(o1)
new_o1 = fury_.deserialize(data1)
Expand Down Expand Up @@ -485,7 +485,7 @@ def test_cache_serializer():


def test_pandas_range_index():
fury = Fury(language=Language.PYTHON, reference_tracking=True)
fury = Fury(language=Language.PYTHON, reference_tracking=True, secure_mode=False)
fury.register_serializer(pd.RangeIndex, pyfury.PandasRangeIndexSerializer(fury))
index = pd.RangeIndex(1, 100, 2, name="a")
new_index = ser_de(fury, index)
Expand Down

0 comments on commit ff69814

Please sign in to comment.