Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Pydantic root model as query parameter #11306

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 32 additions & 6 deletions fastapi/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
from pydantic import RootModel, TypeAdapter
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
Expand Down Expand Up @@ -232,11 +232,7 @@ def get_definitions(
return field_mapping, definitions # type: ignore[return-value]

def is_scalar_field(field: ModelField) -> bool:
from fastapi import params

return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check was missing from v1 side. Lifted the params.Body check out from this "scalar field check" into is_body_param check which is more appropriate for checking the params.Body case.

return field_annotation_is_scalar(field.field_info.annotation)

def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation)
Expand Down Expand Up @@ -279,6 +275,11 @@ def create_body_model(
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel

def root_model_inner_type(annotation: Any) -> Any:
if lenient_issubclass(annotation, RootModel):
return annotation.model_fields["root"].annotation
return None

else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
Expand Down Expand Up @@ -387,6 +388,12 @@ def get_model_definitions(
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params

if (
lenient_issubclass(field.type_, BaseModel)
and "__root__" in field.type_.__fields__
):
return is_pv1_scalar_field(field.type_.__fields__["__root__"])

field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
Expand Down Expand Up @@ -513,6 +520,14 @@ def create_body_model(
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel

def root_model_inner_type(annotation: Any) -> Any:
if (
lenient_issubclass(annotation, BaseModel)
and "__root__" in annotation.__fields__
):
return annotation.__fields__["__root__"].annotation
return None


def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
Expand Down Expand Up @@ -549,11 +564,22 @@ def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
)


def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_root_model(arg) for arg in get_args(annotation))

return root_model_inner_type(annotation) is not None


def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))

if inner := root_model_inner_type(annotation):
return field_annotation_is_complex(inner)

return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
Expand Down
11 changes: 9 additions & 2 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
copy_field_info,
create_body_model,
evaluate_forwardref,
field_annotation_is_root_model,
field_annotation_is_scalar,
get_annotation_from_field_info,
get_missing_field_error,
Expand Down Expand Up @@ -414,7 +415,11 @@ def analyze_param(
type_annotation
) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation):
elif (
not field_annotation_is_scalar(type_annotation)
# Root models by default regarded as bodies
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backward compatibility. Root models have been regarded as body params by default eventhough they were scalar.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to explicitly require annotating by Body/Query,etc

or field_annotation_is_root_model(type_annotation)
):
field_info = params.Body(annotation=use_annotation, default=default_value)
else:
field_info = params.Query(annotation=use_annotation, default=default_value)
Expand Down Expand Up @@ -459,7 +464,9 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
field=param_field
), "Path params must be of one of the supported types"
return False
elif is_scalar_field(field=param_field):
elif is_scalar_field(field=param_field) and not isinstance(
param_field.field_info, params.Body
):
return False
elif isinstance(
param_field.field_info, (params.Query, params.Header)
Expand Down
14 changes: 10 additions & 4 deletions fastapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
serialized = _model_dump(
obj,
mode="json",
include=include,
Expand All @@ -230,10 +230,16 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]

if (
not PYDANTIC_V2
and isinstance(serialized, dict)
and "__root__" in serialized
):
serialized = serialized["__root__"]

return jsonable_encoder(
obj_dict,
serialized,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1
Expand Down