Skip to content

Commit

Permalink
Merge branch 'serialization-context'
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Jun 22, 2023
2 parents 764c163 + 0166050 commit 5f46479
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 4 deletions.
88 changes: 88 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Table of contents
* [Add `omit_none` keyword argument](#add-omit_none-keyword-argument)
* [Add `by_alias` keyword argument](#add-by_alias-keyword-argument)
* [Add `dialect` keyword argument](#add-dialect-keyword-argument)
* [Add `context` keyword argument](#add-context-keyword-argument)
* [Generic dataclasses](#generic-dataclasses)
* [Generic dataclass inheritance](#generic-dataclass-inheritance)
* [Generic dataclass in a field type](#generic-dataclass-in-a-field-type)
Expand Down Expand Up @@ -1082,6 +1083,7 @@ described below.
| [`TO_DICT_ADD_OMIT_NONE_FLAG`](#add-omit_none-keyword-argument) | Adds `omit_none` keyword-only argument to `to_*` methods. |
| [`TO_DICT_ADD_BY_ALIAS_FLAG`](#add-by_alias-keyword-argument) | Adds `by_alias` keyword-only argument to `to_*` methods. |
| [`ADD_DIALECT_SUPPORT`](#add-dialect-keyword-argument) | Adds `dialect` keyword-only argument to `from_*` and `to_*` methods. |
| [`ADD_SERIALIZATION_CONTEXT`](#add-context-keyword-argument) | Adds `context` keyword-only argument to `to_*` methods. |

#### `serialization_strategy` config option

Expand Down Expand Up @@ -2050,6 +2052,86 @@ class Entity(DataClassDictMixin):
code_generation_options = [ADD_DIALECT_SUPPORT]
```

#### Add `context` keyword argument

Sometimes it's needed to pass a "context" object to the serialization hooks
that will take it into account. For example, you could want to have an option
to remove sensitive data from the serialization result if you need to.
You can add `context` parameter to `to_*` methods that will be passed to
[`__pre_serialize__`](#before-serialization) and
[`__post_serialize__`](#after-serialization) hooks. The type of this context
as well as its mutability is up to you.

```python
from dataclasses import dataclass
from typing import Dict, Optional
from uuid import UUID
from mashumaro import DataClassDictMixin
from mashumaro.config import BaseConfig, ADD_SERIALIZATION_CONTEXT

class BaseModel(DataClassDictMixin):
class Config(BaseConfig):
code_generation_options = [ADD_SERIALIZATION_CONTEXT]

@dataclass
class Account(BaseModel):
id: UUID
username: str
name: str

def __pre_serialize__(self, context: Optional[Dict] = None):
return self

def __post_serialize__(self, d: Dict, context: Optional[Dict] = None):
if context and context.get("remove_sensitive_data"):
d["username"] = "***"
d["name"] = "***"
return d

@dataclass
class Session(BaseModel):
id: UUID
key: str
account: Account

def __pre_serialize__(self, context: Optional[Dict] = None):
return self

def __post_serialize__(self, d: Dict, context: Optional[Dict] = None):
if context and context.get("remove_sensitive_data"):
d["key"] = "***"
return d


foo = Session(
id=UUID('03321c9f-6a97-421e-9869-918ff2867a71'),
key="VQ6Q9bX4c8s",
account=Account(
id=UUID('4ef2baa7-edef-4d6a-b496-71e6d72c58fb'),
username="john_doe",
name="John"
)
)
assert foo.to_dict() == {
'id': '03321c9f-6a97-421e-9869-918ff2867a71',
'key': 'VQ6Q9bX4c8s',
'account': {
'id': '4ef2baa7-edef-4d6a-b496-71e6d72c58fb',
'username': 'john_doe',
'name': 'John'
}
}
assert foo.to_dict(context={"remove_sensitive_data": True}) == {
'id': '03321c9f-6a97-421e-9869-918ff2867a71',
'key': '***',
'account': {
'id': '4ef2baa7-edef-4d6a-b496-71e6d72c58fb',
'username': '***',
'name': '***'
}
}
```

### Generic dataclasses

Along with [user-defined generic types](#user-defined-generic-types)
Expand Down Expand Up @@ -2240,6 +2322,9 @@ obj.to_json()
print(obj.counter) # 2
```

Note that you can add an additional `context` argument using the
[corresponding](#add-context-keyword-argument) code generation option.

#### After serialization

For doing something with a dictionary that was created as a result of
Expand All @@ -2260,6 +2345,9 @@ print(obj.to_dict()) # {"user": "name"}
print(obj.to_json()) # '{"user": "name"}'
```

Note that you can add an additional `context` argument using the
[corresponding](#add-context-keyword-argument) code generation option.

JSON Schema
--------------------------------------------------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions mashumaro/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@
"TO_DICT_ADD_BY_ALIAS_FLAG",
"TO_DICT_ADD_OMIT_NONE_FLAG",
"ADD_DIALECT_SUPPORT",
"ADD_SERIALIZATION_CONTEXT",
"SerializationStrategyValueType",
]


TO_DICT_ADD_BY_ALIAS_FLAG = "TO_DICT_ADD_BY_ALIAS_FLAG"
TO_DICT_ADD_OMIT_NONE_FLAG = "TO_DICT_ADD_OMIT_NONE_FLAG"
ADD_DIALECT_SUPPORT = "ADD_DIALECT_SUPPORT"
ADD_SERIALIZATION_CONTEXT = "ADD_SERIALIZATION_CONTEXT"


CodeGenerationOption = Literal[
"TO_DICT_ADD_BY_ALIAS_FLAG",
"TO_DICT_ADD_OMIT_NONE_FLAG",
"ADD_DIALECT_SUPPORT",
"ADD_SERIALIZATION_CONTEXT",
]


Expand Down
27 changes: 26 additions & 1 deletion mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from mashumaro.config import (
ADD_DIALECT_SUPPORT,
ADD_SERIALIZATION_CONTEXT,
TO_DICT_ADD_BY_ALIAS_FLAG,
TO_DICT_ADD_OMIT_NONE_FLAG,
BaseConfig,
Expand Down Expand Up @@ -627,6 +628,7 @@ def get_pack_method_flags(
(TO_DICT_ADD_OMIT_NONE_FLAG, "omit_none"),
(TO_DICT_ADD_BY_ALIAS_FLAG, "by_alias"),
(ADD_DIALECT_SUPPORT, "dialect"),
(ADD_SERIALIZATION_CONTEXT, "context"),
):
if self.is_code_generation_option_enabled(option, cls):
if self.is_code_generation_option_enabled(option):
Expand Down Expand Up @@ -662,6 +664,7 @@ def get_pack_method_default_flag_values(
for value in self._get_encoder_kwargs(cls).values():
kw_param_names.append(value[0])
kw_param_values.append(value[1])

omit_none_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_OMIT_NONE_FLAG, cls
)
Expand All @@ -671,19 +674,29 @@ def get_pack_method_default_flag_values(
)
kw_param_names.append("omit_none")
kw_param_values.append("True" if omit_none else "False")

by_alias_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_BY_ALIAS_FLAG, cls
)
if by_alias_feature:
serialize_by_alias = self.get_config(cls).serialize_by_alias
kw_param_names.append("by_alias")
kw_param_values.append("True" if serialize_by_alias else "False")

dialects_feature = self.is_code_generation_option_enabled(
ADD_DIALECT_SUPPORT, cls
)
if dialects_feature:
kw_param_names.append("dialect")
kw_param_values.append("None")

context_feature = self.is_code_generation_option_enabled(
ADD_SERIALIZATION_CONTEXT, cls
)
if context_feature:
kw_param_names.append("context")
kw_param_values.append("None")

if pos_param_names:
pluggable_flags_str = ", ".join(
[f"{n}={v}" for n, v in zip(pos_param_names, pos_param_values)]
Expand Down Expand Up @@ -804,7 +817,15 @@ def _add_pack_method_lines(self, method_name: str) -> None:
else:
pre_serialize = self.get_declared_hook(__PRE_SERIALIZE__)
if pre_serialize:
self.add_line(f"self = self.{__PRE_SERIALIZE__}()")
if self.is_code_generation_option_enabled(
ADD_SERIALIZATION_CONTEXT
):
pre_serialize_args = "context=context"
else:
pre_serialize_args = ""
self.add_line(
f"self = self.{__PRE_SERIALIZE__}({pre_serialize_args})"
)
by_alias_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_BY_ALIAS_FLAG
)
Expand Down Expand Up @@ -878,6 +899,10 @@ def _add_pack_method_lines(self, method_name: str) -> None:
else:
return_statement = "return {}"
if post_serialize:
if self.is_code_generation_option_enabled(
ADD_SERIALIZATION_CONTEXT
):
kwargs = f"{kwargs}, context=context"
self.add_line(
return_statement.format(
f"self.{__POST_SERIALIZE__}({kwargs})"
Expand Down
11 changes: 9 additions & 2 deletions mashumaro/mixins/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,15 @@ def __pre_deserialize__(cls: Type[T], d: Dict[Any, Any]) -> Dict[Any, Any]:
def __post_deserialize__(cls: Type[T], obj: T) -> T:
...

def __pre_serialize__(self: T) -> T:
def __pre_serialize__(
self: T,
# context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
) -> T:
...

def __post_serialize__(self: T, d: Dict[Any, Any]) -> Dict[Any, Any]:
def __post_serialize__(
self: T,
d: Dict[Any, Any],
# context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
) -> Dict[Any, Any]:
...
59 changes: 58 additions & 1 deletion tests/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,60 @@
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, no_type_check
from typing import Any, ClassVar, Dict, Optional, no_type_check

import pytest

from mashumaro import DataClassDictMixin
from mashumaro.config import ADD_SERIALIZATION_CONTEXT, BaseConfig
from mashumaro.exceptions import BadHookSignature


class BaseClassWithSerializationContext(DataClassDictMixin):
class Config(BaseConfig):
code_generation_options = [ADD_SERIALIZATION_CONTEXT]


@dataclass
class Foo(BaseClassWithSerializationContext):
baz: int

class Config(BaseConfig):
code_generation_options = []

def __pre_serialize__(self):
return self

def __post_serialize__(self, d: Dict):
return d


@dataclass
class Bar(BaseClassWithSerializationContext):
baz: int

def __pre_serialize__(self, context: Optional[Dict] = None):
return self

def __post_serialize__(self, d: Dict, context: Optional[Dict] = None):
if context and context.get("omit_baz"):
d.pop("baz")
return d


@dataclass
class FooBarBaz(BaseClassWithSerializationContext):
foo: Foo
bar: Bar
baz: int

def __pre_serialize__(self, context: Optional[Dict] = None):
return self

def __post_serialize__(self, d: Dict, context: Optional[Dict] = None):
if context and context.get("omit_baz"):
d.pop("baz")
return d


def test_bad_pre_deserialize_hook():
with pytest.raises(BadHookSignature):

Expand Down Expand Up @@ -143,3 +191,12 @@ class B(A, DataClassDictMixin):
post_deserialize_hook.assert_called_once()
pre_serialize_hook.assert_called_once()
post_serialize_hook.assert_called_once()


def test_passing_context_into_hook():
foo = FooBarBaz(foo=Foo(1), bar=Bar(baz=2), baz=3)
assert foo.to_dict() == {"foo": {"baz": 1}, "bar": {"baz": 2}, "baz": 3}
assert foo.to_dict(context={"omit_baz": True}) == {
"foo": {"baz": 1},
"bar": {},
}

0 comments on commit 5f46479

Please sign in to comment.