Skip to content

Commit

Permalink
Simplify and fix DecimalAnnotation
Browse files Browse the repository at this point in the history
  • Loading branch information
gsakkis committed Oct 11, 2023
1 parent 9372046 commit d63a101
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 39 deletions.
49 changes: 10 additions & 39 deletions beanie/odm/custom_types/decimal.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,18 @@
# check python version
import decimal
import sys

import bson
import pydantic

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

from decimal import Decimal as NativeDecimal
from typing import Any, Callable

from bson import Decimal128
from pydantic import GetJsonSchemaHandler
from pydantic.fields import FieldInfo
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


class DecimalCustomAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable[[Any], core_schema.CoreSchema], # type: ignore
) -> core_schema.CoreSchema: # type: ignore
def validate(value, _: FieldInfo) -> NativeDecimal:
if isinstance(value, Decimal128):
return value.to_decimal()
return value

python_schema = core_schema.general_plain_validator_function(validate) # type: ignore

return core_schema.json_or_python_schema(
json_schema=core_schema.float_schema(),
python_schema=python_schema,
)

@classmethod
def __get_pydantic_json_schema__(
cls,
_core_schema: core_schema.CoreSchema, # type: ignore
handler: GetJsonSchemaHandler,
) -> JsonSchemaValue:
return handler(core_schema.float_schema())


DecimalAnnotation = Annotated[NativeDecimal, DecimalCustomAnnotation]
DecimalAnnotation = Annotated[
decimal.Decimal,
pydantic.BeforeValidator(
lambda v: v.to_decimal() if isinstance(v, bson.Decimal128) else v
),
]
16 changes: 16 additions & 0 deletions tests/odm/custom_types/test_decimal_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from decimal import Decimal

from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
from tests.odm.models import DocumentWithDecimalField


def test_decimal_deserialize():
m = DocumentWithDecimalField(amt=Decimal("1.4"))
if IS_PYDANTIC_V2:
m_json = m.model_dump_json()
m_from_json = DocumentWithDecimalField.model_validate_json(m_json)
else:
m_json = m.json()
m_from_json = DocumentWithDecimalField.parse_raw(m_json)
assert isinstance(m_from_json.amt, Decimal)
assert m_from_json.amt == Decimal("1.4")

0 comments on commit d63a101

Please sign in to comment.