Skip to content

Commit

Permalink
Support Pydantic root model as query parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSintonen committed Mar 17, 2024
1 parent ffb4f77 commit d22d6e0
Show file tree
Hide file tree
Showing 3 changed files with 427 additions and 6 deletions.
30 changes: 28 additions & 2 deletions fastapi/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
from pydantic import RootModel, TypeAdapter
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
Expand Down Expand Up @@ -277,6 +277,11 @@ def create_body_model(
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel

def root_model_inner_type(annotation: Any) -> Any:
if lenient_issubclass(annotation, RootModel):
return annotation.model_fields["root"].annotation
return None

else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
Expand Down Expand Up @@ -385,6 +390,12 @@ def get_model_definitions(
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params

if (
lenient_issubclass(field.type_, BaseModel)
and "__root__" in field.type_.__fields__
):
return is_pv1_scalar_field(field.type_.__fields__["__root__"])

field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
Expand Down Expand Up @@ -478,7 +489,11 @@ def get_definitions(
)

def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
from fastapi import params

return is_pv1_scalar_field(field) and not isinstance(
field.field_info, params.Body
)

def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
Expand Down Expand Up @@ -511,6 +526,14 @@ def create_body_model(
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel

def root_model_inner_type(annotation: Any) -> Any:
if (
lenient_issubclass(annotation, BaseModel)
and "__root__" in annotation.__fields__
):
return annotation.__fields__["__root__"].annotation
return None


def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
Expand Down Expand Up @@ -552,6 +575,9 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))

if inner := root_model_inner_type(annotation):
return field_annotation_is_complex(inner)

return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
Expand Down
10 changes: 6 additions & 4 deletions fastapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
serialized = _model_dump(
obj,
mode="json",
include=include,
Expand All @@ -230,10 +230,12 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
if (
isinstance(serialized, dict) and "__root__" in serialized
): # TODO: remove when deprecating Pydantic v1
serialized = serialized["__root__"]
return jsonable_encoder(
obj_dict,
serialized,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1
Expand Down

0 comments on commit d22d6e0

Please sign in to comment.