Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement KIP-893 nullable entity fields #101

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions codegen/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@


def format_default(
type_: Primitive,
type_: Primitive | EntityType | CommonStructType,
default: str | int | float | bool,
optional: bool,
custom_type: CustomTypeDef | None,
Expand All @@ -99,6 +99,7 @@ def format_default(
| Primitive.uint32
| Primitive.uint64
), str(default):
assert not isinstance(type_, EntityType | CommonStructType)
if custom_type_open:
return "".join(
(
Expand Down Expand Up @@ -133,7 +134,7 @@ def format_default(
return "None"

raise NotImplementedError(
f"Failed parsing default for {type_.value=} field: {default=!r}"
f"Failed parsing default for {type_=} field: {default=!r}"
)


Expand Down Expand Up @@ -166,8 +167,6 @@ def format_dataclass_field(
if isinstance(field_type, PrimitiveArrayType):
field_kwargs["default"] = "()"
elif default is not None:
assert not isinstance(field_type, EntityType)
assert not isinstance(field_type, CommonStructType)
field_kwargs["default"] = format_default(
field_type, default, optional, custom_type
)
Expand Down Expand Up @@ -355,21 +354,27 @@ def generate_entity_array_field(
return f" {to_snake_case(field.name)}: tuple[{field.type}, ...]{field_call}\n"


def entity_annotation(field: EntityField | CommonStructField, optional: bool) -> str:
return f"{field.type} | None" if optional else str(field.type)


def generate_entity_field(
field: EntityField | CommonStructField,
version: int,
) -> str:
optional = (
field.nullableVersions.matches(version) if field.nullableVersions else False
)
field_call = format_dataclass_field(
field_type=field.type,
default=None,
optional=(
field.nullableVersions.matches(version) if field.nullableVersions else False
),
default=field.default,
optional=optional,
custom_type=None,
tag=field.get_tag(version),
ignorable=field.ignorable,
)
return f" {to_snake_case(field.name)}: {field.type}{field_call}\n"
annotation = entity_annotation(field, optional)
return f" {to_snake_case(field.name)}: {annotation}{field_call}\n"


def generate_common_struct_array_field(
Expand Down
1 change: 0 additions & 1 deletion codegen/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def main() -> None:
"FetchResponse", # Records
"FetchSnapshotResponse", # Records
"FetchRequest", # Should not output tagged field if its value equals to default (presumably)
"ConsumerGroupHeartbeatResponse", # Nullable `assignment` field
}:
module_code[module_path].append(
test_code_java.format(
Expand Down
2 changes: 2 additions & 0 deletions codegen/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ class CommonStructArrayField(_BaseField):

class CommonStructField(_BaseField):
type: CommonStructType
default: Literal["null"] | None = None


class EntityArrayField(_BaseField):
Expand All @@ -382,6 +383,7 @@ class EntityArrayField(_BaseField):
class EntityField(_BaseField):
type: EntityType
fields: tuple[Field, ...]
default: Literal["null"] | None = None


class _BaseSchema(BaseModel):
Expand Down
18 changes: 14 additions & 4 deletions src/kio/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Work-around for broken support for cache decorators in
# https://github.com/python/typeshed/issues/6347
# https://stackoverflow.com/a/73517689
from typing import TYPE_CHECKING
from typing import Protocol
from typing import TypeVar

__all__ = ("cache",)
__all__ = ("cache", "DataclassInstance")


# Work-around for broken support for cache decorators in
# https://github.com/python/typeshed/issues/6347
# https://stackoverflow.com/a/73517689
if TYPE_CHECKING:
_C = TypeVar("_C")

Expand All @@ -14,3 +16,11 @@ def cache(c: _C) -> _C:

else:
from functools import cache


if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:

class DataclassInstance(Protocol):
...
2 changes: 1 addition & 1 deletion src/kio/schema/consumer_group_heartbeat/v0/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ class ConsumerGroupHeartbeatResponse(ApiMessage):
"""True if the member should compute the assignment for the group."""
heartbeat_interval: i32Timedelta = field(metadata={"kafka_type": "timedelta_i32"})
"""The heartbeat interval in milliseconds."""
assignment: Assignment
assignment: Assignment | None = field(default=None)
"""null if not provided; the assignment otherwise."""
21 changes: 21 additions & 0 deletions src/kio/serial/_introspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ class FieldKind(enum.Enum):
def classify_field(field: Field[T]) -> tuple[FieldKind, type[T]]:
type_origin = get_origin(field.type)

if type_origin is UnionType:
try:
a, b = get_args(field.type)
except ValueError:
raise SchemaError(
f"Field {field.name} has unsupported union type: {field.type}"
) from None

if a is NoneType:
inner_type = b
elif b is NoneType:
inner_type = a
else:
raise SchemaError("Only union with None is supported")

return (
(FieldKind.entity, inner_type)
if is_dataclass(inner_type)
else (FieldKind.primitive, inner_type)
)

if type_origin is not tuple:
return (
(FieldKind.entity, field.type) # type: ignore[return-value]
Expand Down
42 changes: 39 additions & 3 deletions src/kio/serial/_parse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import Field
from dataclasses import fields
from typing import IO
from typing import Literal
from typing import TypeVar
from typing import assert_never
from typing import overload

from kio._utils import cache
from kio.static.protocol import Entity
Expand All @@ -13,6 +15,8 @@
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
from ._introspect import is_optional
from ._shared import NullableEntityMarker
from .readers import read_int8


def get_reader(
Expand Down Expand Up @@ -114,7 +118,11 @@ def get_field_reader(
)
)
case FieldKind.entity:
return entity_reader(field_type) # type: ignore[type-var]
return ( # type: ignore[no-any-return]
entity_reader(field_type, nullable=True) # type: ignore[call-overload]
if is_optional(field)
else entity_reader(field_type, nullable=False) # type: ignore[call-overload]
)
case FieldKind.entity_tuple:
return array_reader( # type: ignore[return-value]
entity_reader(field_type) # type: ignore[type-var]
Expand All @@ -126,8 +134,27 @@ def get_field_reader(
E = TypeVar("E", bound=Entity)


@overload
def entity_reader(
entity_type: type[E],
nullable: Literal[False] = ...,
) -> readers.Reader[E]:
...


@overload
def entity_reader(
entity_type: type[E],
nullable: Literal[True],
) -> readers.Reader[E | None]:
...


@cache
def entity_reader(entity_type: type[E]) -> readers.Reader[E]:
def entity_reader(
entity_type: type[E],
nullable: bool = False,
) -> readers.Reader[E | None]:
field_readers = {}
tagged_field_readers = {}
is_request_header = entity_type.__name__ == "RequestHeader"
Expand Down Expand Up @@ -170,4 +197,13 @@ def read_entity(buffer: IO[bytes]) -> E:

return entity_type(**kwargs)

return read_entity
if not nullable:
return read_entity

# This is undocumented behavior, formalized in KIP-893.
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs
def read_nullable_entity(buffer: IO[bytes]) -> E | None:
marker = NullableEntityMarker(read_int8(buffer))
return None if marker is NullableEntityMarker.null else read_entity(buffer)

return read_nullable_entity
37 changes: 34 additions & 3 deletions src/kio/serial/_serialize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import io
from dataclasses import Field
from dataclasses import fields
from typing import Literal
from typing import TypeVar
from typing import assert_never
from typing import overload

from kio._utils import cache
from kio.static.protocol import Entity
Expand All @@ -13,10 +15,12 @@
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
from ._introspect import is_optional
from ._shared import NullableEntityMarker
from .writers import Writable
from .writers import Writer
from .writers import compact_array_writer
from .writers import legacy_array_writer
from .writers import write_int8
from .writers import write_tagged_field
from .writers import write_unsigned_varint

Expand Down Expand Up @@ -128,7 +132,11 @@ def get_field_writer(
)
)
case FieldKind.entity:
return entity_writer(field_type) # type: ignore[type-var]
return ( # type: ignore[no-any-return]
entity_writer(field_type, nullable=True) # type: ignore[call-overload]
if optional
else entity_writer(field_type, nullable=False) # type: ignore[call-overload]
)
case FieldKind.entity_tuple:
return array_writer( # type: ignore[return-value]
entity_writer(field_type) # type: ignore[type-var]
Expand All @@ -140,8 +148,31 @@ def get_field_writer(
E = TypeVar("E", bound=Entity)


def _wrap_nullable(write_entity: Writer[E]) -> Writer[E | None]:
# This is undocumented behavior, formalized in KIP-893.
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs
def write_nullable(buffer: Writable, entity: E | None) -> None:
if entity is None:
write_int8(buffer, NullableEntityMarker.null.value)
return
write_int8(buffer, NullableEntityMarker.not_null.value)
write_entity(buffer, entity)

return write_nullable


@overload
def entity_writer(entity_type: type[E], nullable: Literal[False] = ...) -> Writer[E]:
...


@overload
def entity_writer(entity_type: type[E], nullable: Literal[True]) -> Writer[E | None]:
...


@cache
def entity_writer(entity_type: type[E]) -> Writer[E]:
def entity_writer(entity_type: type[E], nullable: bool = False) -> Writer[E | None]:
field_writers = {}
tagged_field_writers = {}
is_request_header = entity_type.__name__ == "RequestHeader"
Expand Down Expand Up @@ -204,4 +235,4 @@ def write_entity(buffer: Writable, entity: E) -> None:
write_unsigned_varint(buffer, num_tagged_fields)
buffer.write(tag_buffer.getvalue())

return write_entity
return _wrap_nullable(write_entity) if nullable else write_entity # type: ignore[return-value]
8 changes: 8 additions & 0 deletions src/kio/serial/_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import enum

from kio.static.primitive import i8


class NullableEntityMarker(enum.Enum):
null = i8(-1)
not_null = i8(1)
11 changes: 2 additions & 9 deletions src/kio/static/protocol.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
from typing import TYPE_CHECKING
from typing import ClassVar
from typing import Protocol

from .primitive import i16

if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:

class DataclassInstance(Protocol):
...
from kio._utils import DataclassInstance

from .primitive import i16

__all__ = ("ApiMessage", "Entity", "Payload")

Expand Down
32 changes: 31 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
from subprocess import PIPE
from subprocess import Popen
from subprocess import TimeoutExpired
from types import NoneType
from types import UnionType
from typing import Any
from typing import get_args
from typing import get_origin
from uuid import UUID

import pytest
import pytest_asyncio
from hypothesis import settings

from kio._utils import DataclassInstance
from kio.serial import entity_writer
from kio.static.protocol import Entity

Expand Down Expand Up @@ -90,11 +95,36 @@ async def stream_writer(
return async_buffers[1]


def is_nullable_entity_field(field: dataclasses.Field) -> bool:
if get_origin(field.type) is not UnionType:
return False
try:
a, b = get_args(field.type)
except ValueError:
return False
return (a is NoneType and dataclasses.is_dataclass(b)) or (
b is NoneType and dataclasses.is_dataclass(a)
)


def map_nullable_entity_fields(obj: DataclassInstance) -> dict[str, bool]:
"""Return map of KIP-893 nullable entity fields."""
return {
field.name: is_nullable_entity_field(field) for field in dataclasses.fields(obj)
}


class JavaTester:
class _Encoder(JSONEncoder):
def default(self, o: Any) -> Any:
if dataclasses.is_dataclass(o):
return self._replace_tzaware_nulls(dataclasses.asdict(o))
return self._replace_tzaware_nulls(
{
k: v
for k, v in dataclasses.asdict(o).items()
if (v is not None or not map_nullable_entity_fields(o)[k])
aiven-anton marked this conversation as resolved.
Show resolved Hide resolved
}
)
if isinstance(o, timedelta):
return round(o.total_seconds() * 1000)
if isinstance(o, datetime):
Expand Down
Loading