From d63a101e28409f4ac7fbd418779b70c702e8b388 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Thu, 12 Oct 2023 02:51:20 +0300 Subject: [PATCH] Simplify and fix DecimalAnnotation --- beanie/odm/custom_types/decimal.py | 49 ++++--------------- .../custom_types/test_decimal_annotation.py | 16 ++++++ 2 files changed, 26 insertions(+), 39 deletions(-) create mode 100644 tests/odm/custom_types/test_decimal_annotation.py diff --git a/beanie/odm/custom_types/decimal.py b/beanie/odm/custom_types/decimal.py index c0027580..ea03680f 100644 --- a/beanie/odm/custom_types/decimal.py +++ b/beanie/odm/custom_types/decimal.py @@ -1,47 +1,18 @@ -# check python version +import decimal import sys +import bson +import pydantic + if sys.version_info >= (3, 9): from typing import Annotated else: from typing_extensions import Annotated -from decimal import Decimal as NativeDecimal -from typing import Any, Callable - -from bson import Decimal128 -from pydantic import GetJsonSchemaHandler -from pydantic.fields import FieldInfo -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import core_schema - - -class DecimalCustomAnnotation: - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: Callable[[Any], core_schema.CoreSchema], # type: ignore - ) -> core_schema.CoreSchema: # type: ignore - def validate(value, _: FieldInfo) -> NativeDecimal: - if isinstance(value, Decimal128): - return value.to_decimal() - return value - - python_schema = core_schema.general_plain_validator_function(validate) # type: ignore - - return core_schema.json_or_python_schema( - json_schema=core_schema.float_schema(), - python_schema=python_schema, - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, - _core_schema: core_schema.CoreSchema, # type: ignore - handler: GetJsonSchemaHandler, - ) -> JsonSchemaValue: - return handler(core_schema.float_schema()) - -DecimalAnnotation = Annotated[NativeDecimal, DecimalCustomAnnotation] +DecimalAnnotation = Annotated[ + decimal.Decimal, + pydantic.BeforeValidator( + lambda v: v.to_decimal() if isinstance(v, bson.Decimal128) else v + ), +] diff --git a/tests/odm/custom_types/test_decimal_annotation.py b/tests/odm/custom_types/test_decimal_annotation.py new file mode 100644 index 00000000..bb3970d5 --- /dev/null +++ b/tests/odm/custom_types/test_decimal_annotation.py @@ -0,0 +1,16 @@ +from decimal import Decimal + +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 +from tests.odm.models import DocumentWithDecimalField + + +def test_decimal_deserialize(): + m = DocumentWithDecimalField(amt=Decimal("1.4")) + if IS_PYDANTIC_V2: + m_json = m.model_dump_json() + m_from_json = DocumentWithDecimalField.model_validate_json(m_json) + else: + m_json = m.json() + m_from_json = DocumentWithDecimalField.parse_raw(m_json) + assert isinstance(m_from_json.amt, Decimal) + assert m_from_json.amt == Decimal("1.4")