Skip to content

Commit

Permalink
Merge 60e2303 into 090f890
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Jun 19, 2023
2 parents 090f890 + 60e2303 commit c8f917b
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 4 deletions.
3 changes: 3 additions & 0 deletions mashumaro/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@
"TO_DICT_ADD_BY_ALIAS_FLAG",
"TO_DICT_ADD_OMIT_NONE_FLAG",
"ADD_DIALECT_SUPPORT",
"ADD_SERIALIZATION_CONTEXT",
"SerializationStrategyValueType",
]


TO_DICT_ADD_BY_ALIAS_FLAG = "TO_DICT_ADD_BY_ALIAS_FLAG"
TO_DICT_ADD_OMIT_NONE_FLAG = "TO_DICT_ADD_OMIT_NONE_FLAG"
ADD_DIALECT_SUPPORT = "ADD_DIALECT_SUPPORT"
ADD_SERIALIZATION_CONTEXT = "ADD_SERIALIZATION_CONTEXT"


CodeGenerationOption = Literal[
"TO_DICT_ADD_BY_ALIAS_FLAG",
"TO_DICT_ADD_OMIT_NONE_FLAG",
"ADD_DIALECT_SUPPORT",
"ADD_SERIALIZATION_CONTEXT",
]


Expand Down
27 changes: 26 additions & 1 deletion mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from mashumaro.config import (
ADD_DIALECT_SUPPORT,
ADD_SERIALIZATION_CONTEXT,
TO_DICT_ADD_BY_ALIAS_FLAG,
TO_DICT_ADD_OMIT_NONE_FLAG,
BaseConfig,
Expand Down Expand Up @@ -547,6 +548,7 @@ def get_pack_method_flags(
(TO_DICT_ADD_OMIT_NONE_FLAG, "omit_none"),
(TO_DICT_ADD_BY_ALIAS_FLAG, "by_alias"),
(ADD_DIALECT_SUPPORT, "dialect"),
(ADD_SERIALIZATION_CONTEXT, "context"),
):
if self.is_code_generation_option_enabled(option, cls):
if self.is_code_generation_option_enabled(option):
Expand Down Expand Up @@ -582,6 +584,7 @@ def get_pack_method_default_flag_values(
for value in self._get_encoder_kwargs(cls).values():
kw_param_names.append(value[0])
kw_param_values.append(value[1])

omit_none_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_OMIT_NONE_FLAG, cls
)
Expand All @@ -591,19 +594,29 @@ def get_pack_method_default_flag_values(
)
kw_param_names.append("omit_none")
kw_param_values.append("True" if omit_none else "False")

by_alias_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_BY_ALIAS_FLAG, cls
)
if by_alias_feature:
serialize_by_alias = self.get_config(cls).serialize_by_alias
kw_param_names.append("by_alias")
kw_param_values.append("True" if serialize_by_alias else "False")

dialects_feature = self.is_code_generation_option_enabled(
ADD_DIALECT_SUPPORT, cls
)
if dialects_feature:
kw_param_names.append("dialect")
kw_param_values.append("None")

context_feature = self.is_code_generation_option_enabled(
ADD_SERIALIZATION_CONTEXT, cls
)
if context_feature:
kw_param_names.append("context")
kw_param_values.append("None")

if pos_param_names:
pluggable_flags_str = ", ".join(
[f"{n}={v}" for n, v in zip(pos_param_names, pos_param_values)]
Expand Down Expand Up @@ -718,7 +731,15 @@ def _add_pack_method_lines(self, method_name: str) -> None:
else:
pre_serialize = self.get_declared_hook(__PRE_SERIALIZE__)
if pre_serialize:
self.add_line(f"self = self.{__PRE_SERIALIZE__}()")
if self.is_code_generation_option_enabled(
ADD_SERIALIZATION_CONTEXT
):
pre_serialize_args = "context=context"
else:
pre_serialize_args = ""
self.add_line(
f"self = self.{__PRE_SERIALIZE__}({pre_serialize_args})"
)
by_alias_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_BY_ALIAS_FLAG
)
Expand Down Expand Up @@ -792,6 +813,10 @@ def _add_pack_method_lines(self, method_name: str) -> None:
else:
return_statement = "return {}"
if post_serialize:
if self.is_code_generation_option_enabled(
ADD_SERIALIZATION_CONTEXT
):
kwargs = f"{kwargs}, context=context"
self.add_line(
return_statement.format(
f"self.{__POST_SERIALIZE__}({kwargs})"
Expand Down
11 changes: 9 additions & 2 deletions mashumaro/mixins/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,15 @@ def __pre_deserialize__(cls: Type[T], d: Dict[Any, Any]) -> Dict[Any, Any]:
def __post_deserialize__(cls: Type[T], obj: T) -> T:
...

def __pre_serialize__(self: T) -> T:
def __pre_serialize__(
self: T,
# context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
) -> T:
...

def __post_serialize__(self: T, d: Dict[Any, Any]) -> Dict[Any, Any]:
def __post_serialize__(
self: T,
d: Dict[Any, Any],
# context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
) -> Dict[Any, Any]:
...
59 changes: 58 additions & 1 deletion tests/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,60 @@
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, no_type_check
from typing import Any, ClassVar, Dict, Optional, no_type_check

import pytest

from mashumaro import DataClassDictMixin
from mashumaro.config import ADD_SERIALIZATION_CONTEXT, BaseConfig
from mashumaro.exceptions import BadHookSignature


class BaseClassWithSerializationContext(DataClassDictMixin):
class Config(BaseConfig):
code_generation_options = [ADD_SERIALIZATION_CONTEXT]


@dataclass
class Foo(BaseClassWithSerializationContext):
baz: int

class Config(BaseConfig):
code_generation_options = []

def __pre_serialize__(self):
return self

def __post_serialize__(self, d: Dict):
return d


@dataclass
class Bar(BaseClassWithSerializationContext):
baz: int

def __pre_serialize__(self, context: Optional[Dict] = None):
return self

def __post_serialize__(self, d: Dict, context: Optional[Dict] = None):
if context and context.get("omit_baz"):
d.pop("baz")
return d


@dataclass
class FooBarBaz(BaseClassWithSerializationContext):
foo: Foo
bar: Bar
baz: int

def __pre_serialize__(self, context: Optional[Dict] = None):
return self

def __post_serialize__(self, d: Dict, context: Optional[Dict] = None):
if context and context.get("omit_baz"):
d.pop("baz")
return d


def test_bad_pre_deserialize_hook():
with pytest.raises(BadHookSignature):

Expand Down Expand Up @@ -143,3 +191,12 @@ class B(A, DataClassDictMixin):
post_deserialize_hook.assert_called_once()
pre_serialize_hook.assert_called_once()
post_serialize_hook.assert_called_once()


def test_passing_context_into_hook():
foo = FooBarBaz(foo=Foo(1), bar=Bar(baz=2), baz=3)
assert foo.to_dict() == {"foo": {"baz": 1}, "bar": {"baz": 2}, "baz": 3}
assert foo.to_dict(context={"omit_baz": True}) == {
"foo": {"baz": 1},
"bar": {},
}

0 comments on commit c8f917b

Please sign in to comment.