Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and fix DecimalAnnotation #738

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 12 additions & 47 deletions beanie/odm/custom_types/decimal.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,12 @@
# check python version
import sys

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

from decimal import Decimal as NativeDecimal
from typing import Any, Callable

from bson import Decimal128
from pydantic import GetJsonSchemaHandler
from pydantic.fields import FieldInfo
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


class DecimalCustomAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable[[Any], core_schema.CoreSchema], # type: ignore
) -> core_schema.CoreSchema: # type: ignore
def validate(value, _: FieldInfo) -> NativeDecimal:
if isinstance(value, Decimal128):
return value.to_decimal()
return value

python_schema = core_schema.general_plain_validator_function(validate) # type: ignore

return core_schema.json_or_python_schema(
json_schema=core_schema.float_schema(),
python_schema=python_schema,
)

@classmethod
def __get_pydantic_json_schema__(
cls,
_core_schema: core_schema.CoreSchema, # type: ignore
handler: GetJsonSchemaHandler,
) -> JsonSchemaValue:
return handler(core_schema.float_schema())


DecimalAnnotation = Annotated[NativeDecimal, DecimalCustomAnnotation]
import decimal

import bson
import pydantic
from typing_extensions import Annotated

DecimalAnnotation = Annotated[
decimal.Decimal,
pydantic.BeforeValidator(
lambda v: v.to_decimal() if isinstance(v, bson.Decimal128) else v
),
]
16 changes: 16 additions & 0 deletions tests/odm/custom_types/test_decimal_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from decimal import Decimal

from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
from tests.odm.models import DocumentWithDecimalField


def test_decimal_deserialize():
m = DocumentWithDecimalField(amt=Decimal("1.4"))
if IS_PYDANTIC_V2:
m_json = m.model_dump_json()
m_from_json = DocumentWithDecimalField.model_validate_json(m_json)
else:
m_json = m.json()
m_from_json = DocumentWithDecimalField.parse_raw(m_json)
assert isinstance(m_from_json.amt, Decimal)
assert m_from_json.amt == Decimal("1.4")
Loading