diff --git a/codegen/generate_schema.py b/codegen/generate_schema.py index ad77f87b..bee20ad0 100644 --- a/codegen/generate_schema.py +++ b/codegen/generate_schema.py @@ -72,7 +72,7 @@ def format_default( - type_: Primitive, + type_: Primitive | EntityType | CommonStructType, default: str | int | float | bool, optional: bool, custom_type: CustomTypeDef | None, @@ -99,6 +99,7 @@ def format_default( | Primitive.uint32 | Primitive.uint64 ), str(default): + assert not isinstance(type_, EntityType | CommonStructType) if custom_type_open: return "".join( ( @@ -133,7 +134,7 @@ def format_default( return "None" raise NotImplementedError( - f"Failed parsing default for {type_.value=} field: {default=!r}" + f"Failed parsing default for {type_=} field: {default=!r}" ) @@ -166,8 +167,6 @@ def format_dataclass_field( if isinstance(field_type, PrimitiveArrayType): field_kwargs["default"] = "()" elif default is not None: - assert not isinstance(field_type, EntityType) - assert not isinstance(field_type, CommonStructType) field_kwargs["default"] = format_default( field_type, default, optional, custom_type ) @@ -355,21 +354,27 @@ def generate_entity_array_field( return f" {to_snake_case(field.name)}: tuple[{field.type}, ...]{field_call}\n" +def entity_annotation(field: EntityField | CommonStructField, optional: bool) -> str: + return f"{field.type} | None" if optional else str(field.type) + + def generate_entity_field( field: EntityField | CommonStructField, version: int, ) -> str: + optional = ( + field.nullableVersions.matches(version) if field.nullableVersions else False + ) field_call = format_dataclass_field( field_type=field.type, - default=None, - optional=( - field.nullableVersions.matches(version) if field.nullableVersions else False - ), + default=field.default, + optional=optional, custom_type=None, tag=field.get_tag(version), ignorable=field.ignorable, ) - return f" {to_snake_case(field.name)}: {field.type}{field_call}\n" + annotation = entity_annotation(field, optional) + return f" {to_snake_case(field.name)}: {annotation}{field_call}\n" def generate_common_struct_array_field( diff --git a/codegen/generate_tests.py b/codegen/generate_tests.py index 9029867a..9c577214 100644 --- a/codegen/generate_tests.py +++ b/codegen/generate_tests.py @@ -121,7 +121,6 @@ def main() -> None: "FetchResponse", # Records "FetchSnapshotResponse", # Records "FetchRequest", # Should not output tagged field if its value equals to default (presumably) - "ConsumerGroupHeartbeatResponse", # Nullable `assignment` field }: module_code[module_path].append( test_code_java.format( diff --git a/codegen/parser.py b/codegen/parser.py index 194cf4b6..a6d9b67e 100644 --- a/codegen/parser.py +++ b/codegen/parser.py @@ -372,6 +372,7 @@ class CommonStructArrayField(_BaseField): class CommonStructField(_BaseField): type: CommonStructType + default: Literal["null"] | None = None class EntityArrayField(_BaseField): @@ -382,6 +383,7 @@ class EntityArrayField(_BaseField): class EntityField(_BaseField): type: EntityType fields: tuple[Field, ...] + default: Literal["null"] | None = None class _BaseSchema(BaseModel): diff --git a/src/kio/_utils.py b/src/kio/_utils.py index ac2c09be..c33a499c 100644 --- a/src/kio/_utils.py +++ b/src/kio/_utils.py @@ -1,11 +1,13 @@ -# Work-around for broken support for cache decorators in -# https://github.com/python/typeshed/issues/6347 -# https://stackoverflow.com/a/73517689 from typing import TYPE_CHECKING +from typing import Protocol from typing import TypeVar -__all__ = ("cache",) +__all__ = ("cache", "DataclassInstance") + +# Work-around for broken support for cache decorators in +# https://github.com/python/typeshed/issues/6347 +# https://stackoverflow.com/a/73517689 if TYPE_CHECKING: _C = TypeVar("_C") @@ -14,3 +16,11 @@ def cache(c: _C) -> _C: else: from functools import cache + + +if TYPE_CHECKING: + from _typeshed import DataclassInstance +else: + + class DataclassInstance(Protocol): + ... diff --git a/src/kio/schema/consumer_group_heartbeat/v0/response.py b/src/kio/schema/consumer_group_heartbeat/v0/response.py index 52dfac64..c1c2ddcf 100644 --- a/src/kio/schema/consumer_group_heartbeat/v0/response.py +++ b/src/kio/schema/consumer_group_heartbeat/v0/response.py @@ -68,5 +68,5 @@ class ConsumerGroupHeartbeatResponse(ApiMessage): """True if the member should compute the assignment for the group.""" heartbeat_interval: i32Timedelta = field(metadata={"kafka_type": "timedelta_i32"}) """The heartbeat interval in milliseconds.""" - assignment: Assignment + assignment: Assignment | None = field(default=None) """null if not provided; the assignment otherwise.""" diff --git a/src/kio/serial/_introspect.py b/src/kio/serial/_introspect.py index 8729f5b3..8fe55fed 100644 --- a/src/kio/serial/_introspect.py +++ b/src/kio/serial/_introspect.py @@ -55,6 +55,27 @@ class FieldKind(enum.Enum): def classify_field(field: Field[T]) -> tuple[FieldKind, type[T]]: type_origin = get_origin(field.type) + if type_origin is UnionType: + try: + a, b = get_args(field.type) + except ValueError: + raise SchemaError( + f"Field {field.name} has unsupported union type: {field.type}" + ) from None + + if a is NoneType: + inner_type = b + elif b is NoneType: + inner_type = a + else: + raise SchemaError("Only union with None is supported") + + return ( + (FieldKind.entity, inner_type) + if is_dataclass(inner_type) + else (FieldKind.primitive, inner_type) + ) + if type_origin is not tuple: return ( (FieldKind.entity, field.type) # type: ignore[return-value] diff --git a/src/kio/serial/_parse.py b/src/kio/serial/_parse.py index 5342dc2b..ecc8494c 100644 --- a/src/kio/serial/_parse.py +++ b/src/kio/serial/_parse.py @@ -1,8 +1,10 @@ from dataclasses import Field from dataclasses import fields from typing import IO +from typing import Literal from typing import TypeVar from typing import assert_never +from typing import overload from kio._utils import cache from kio.static.protocol import Entity @@ -13,6 +15,8 @@ from ._introspect import get_field_tag from ._introspect import get_schema_field_type from ._introspect import is_optional +from ._shared import NullableEntityMarker +from .readers import read_int8 def get_reader( @@ -114,7 +118,11 @@ def get_field_reader( ) ) case FieldKind.entity: - return entity_reader(field_type) # type: ignore[type-var] + return ( # type: ignore[no-any-return] + entity_reader(field_type, nullable=True) # type: ignore[call-overload] + if is_optional(field) + else entity_reader(field_type, nullable=False) # type: ignore[call-overload] + ) case FieldKind.entity_tuple: return array_reader( # type: ignore[return-value] entity_reader(field_type) # type: ignore[type-var] @@ -126,8 +134,27 @@ def get_field_reader( E = TypeVar("E", bound=Entity) +@overload +def entity_reader( + entity_type: type[E], + nullable: Literal[False] = ..., +) -> readers.Reader[E]: + ... + + +@overload +def entity_reader( + entity_type: type[E], + nullable: Literal[True], +) -> readers.Reader[E | None]: + ... + + @cache -def entity_reader(entity_type: type[E]) -> readers.Reader[E]: +def entity_reader( + entity_type: type[E], + nullable: bool = False, +) -> readers.Reader[E | None]: field_readers = {} tagged_field_readers = {} is_request_header = entity_type.__name__ == "RequestHeader" @@ -170,4 +197,13 @@ def read_entity(buffer: IO[bytes]) -> E: return entity_type(**kwargs) - return read_entity + if not nullable: + return read_entity + + # This is undocumented behavior, formalized in KIP-893. + # https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs + def read_nullable_entity(buffer: IO[bytes]) -> E | None: + marker = NullableEntityMarker(read_int8(buffer)) + return None if marker is NullableEntityMarker.null else read_entity(buffer) + + return read_nullable_entity diff --git a/src/kio/serial/_serialize.py b/src/kio/serial/_serialize.py index f58e5520..239e7d70 100644 --- a/src/kio/serial/_serialize.py +++ b/src/kio/serial/_serialize.py @@ -1,8 +1,10 @@ import io from dataclasses import Field from dataclasses import fields +from typing import Literal from typing import TypeVar from typing import assert_never +from typing import overload from kio._utils import cache from kio.static.protocol import Entity @@ -13,10 +15,12 @@ from ._introspect import get_field_tag from ._introspect import get_schema_field_type from ._introspect import is_optional +from ._shared import NullableEntityMarker from .writers import Writable from .writers import Writer from .writers import compact_array_writer from .writers import legacy_array_writer +from .writers import write_int8 from .writers import write_tagged_field from .writers import write_unsigned_varint @@ -128,7 +132,11 @@ def get_field_writer( ) ) case FieldKind.entity: - return entity_writer(field_type) # type: ignore[type-var] + return ( # type: ignore[no-any-return] + entity_writer(field_type, nullable=True) # type: ignore[call-overload] + if optional + else entity_writer(field_type, nullable=False) # type: ignore[call-overload] + ) case FieldKind.entity_tuple: return array_writer( # type: ignore[return-value] entity_writer(field_type) # type: ignore[type-var] @@ -140,8 +148,31 @@ def get_field_writer( E = TypeVar("E", bound=Entity) +def _wrap_nullable(write_entity: Writer[E]) -> Writer[E | None]: + # This is undocumented behavior, formalized in KIP-893. + # https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs + def write_nullable(buffer: Writable, entity: E | None) -> None: + if entity is None: + write_int8(buffer, NullableEntityMarker.null.value) + return + write_int8(buffer, NullableEntityMarker.not_null.value) + write_entity(buffer, entity) + + return write_nullable + + +@overload +def entity_writer(entity_type: type[E], nullable: Literal[False] = ...) -> Writer[E]: + ... + + +@overload +def entity_writer(entity_type: type[E], nullable: Literal[True]) -> Writer[E | None]: + ... + + @cache -def entity_writer(entity_type: type[E]) -> Writer[E]: +def entity_writer(entity_type: type[E], nullable: bool = False) -> Writer[E | None]: field_writers = {} tagged_field_writers = {} is_request_header = entity_type.__name__ == "RequestHeader" @@ -204,4 +235,4 @@ def write_entity(buffer: Writable, entity: E) -> None: write_unsigned_varint(buffer, num_tagged_fields) buffer.write(tag_buffer.getvalue()) - return write_entity + return _wrap_nullable(write_entity) if nullable else write_entity # type: ignore[return-value] diff --git a/src/kio/serial/_shared.py b/src/kio/serial/_shared.py new file mode 100644 index 00000000..8c11ce92 --- /dev/null +++ b/src/kio/serial/_shared.py @@ -0,0 +1,8 @@ +import enum + +from kio.static.primitive import i8 + + +class NullableEntityMarker(enum.Enum): + null = i8(-1) + not_null = i8(1) diff --git a/src/kio/static/protocol.py b/src/kio/static/protocol.py index 9260c253..5037385e 100644 --- a/src/kio/static/protocol.py +++ b/src/kio/static/protocol.py @@ -1,16 +1,9 @@ -from typing import TYPE_CHECKING from typing import ClassVar from typing import Protocol -from .primitive import i16 - -if TYPE_CHECKING: - from _typeshed import DataclassInstance -else: - - class DataclassInstance(Protocol): - ... +from kio._utils import DataclassInstance +from .primitive import i16 __all__ = ("ApiMessage", "Entity", "Payload") diff --git a/tests/conftest.py b/tests/conftest.py index b236d72c..27bc34a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,13 +16,18 @@ from subprocess import PIPE from subprocess import Popen from subprocess import TimeoutExpired +from types import NoneType +from types import UnionType from typing import Any +from typing import get_args +from typing import get_origin from uuid import UUID import pytest import pytest_asyncio from hypothesis import settings +from kio._utils import DataclassInstance from kio.serial import entity_writer from kio.static.protocol import Entity @@ -90,11 +95,36 @@ async def stream_writer( return async_buffers[1] +def is_nullable_entity_field(field: dataclasses.Field) -> bool: + if get_origin(field.type) is not UnionType: + return False + try: + a, b = get_args(field.type) + except ValueError: + return False + return (a is NoneType and dataclasses.is_dataclass(b)) or ( + b is NoneType and dataclasses.is_dataclass(a) + ) + + +def map_nullable_entity_fields(obj: DataclassInstance) -> dict[str, bool]: + """Return map of KIP-893 nullable entity fields.""" + return { + field.name: is_nullable_entity_field(field) for field in dataclasses.fields(obj) + } + + class JavaTester: class _Encoder(JSONEncoder): def default(self, o: Any) -> Any: if dataclasses.is_dataclass(o): - return self._replace_tzaware_nulls(dataclasses.asdict(o)) + return self._replace_tzaware_nulls( + { + k: v + for k, v in dataclasses.asdict(o).items() + if (v is not None or not map_nullable_entity_fields(o)[k]) + } + ) if isinstance(o, timedelta): return round(o.total_seconds() * 1000) if isinstance(o, datetime): diff --git a/tests/generated/test_consumer_group_heartbeat_v0_response.py b/tests/generated/test_consumer_group_heartbeat_v0_response.py index f12ecc61..1a730d4e 100644 --- a/tests/generated/test_consumer_group_heartbeat_v0_response.py +++ b/tests/generated/test_consumer_group_heartbeat_v0_response.py @@ -14,6 +14,7 @@ from kio.schema.consumer_group_heartbeat.v0.response import TopicPartitions from kio.serial import entity_reader from kio.serial import entity_writer +from tests.conftest import JavaTester from tests.conftest import setup_buffer read_topic_partitions: Final = entity_reader(TopicPartitions) @@ -63,3 +64,11 @@ def test_consumer_group_heartbeat_response_roundtrip( buffer.seek(0) result = read_consumer_group_heartbeat_response(buffer) assert instance == result + + +@pytest.mark.java +@given(instance=from_type(ConsumerGroupHeartbeatResponse)) +def test_consumer_group_heartbeat_response_java( + instance: ConsumerGroupHeartbeatResponse, java_tester: JavaTester +) -> None: + java_tester.test(instance) diff --git a/tests/serial/test_introspect.py b/tests/serial/test_introspect.py index 220fb341..ae0bc741 100644 --- a/tests/serial/test_introspect.py +++ b/tests/serial/test_introspect.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from dataclasses import field from dataclasses import fields +from uuid import UUID import pytest @@ -29,13 +30,17 @@ class A: pep_604_union_with_none: int | None pep_604_union_without_none: int | str - simple_primitive: int - complex_primitive: int | str + primitive: int primitive_tuple: tuple[int, ...] primitive_tuple_optional: tuple[int | None, ...] entity: Nested entity_tuple: tuple[Nested, ...] + nullable_entity: Nested | None + backwards_nullable_entity: None | Nested unsupported_tuple: tuple[int, str] + unsupported_union: int | str | bool + + uuid_or_none: UUID | None model_fields = {field.name: field for field in fields(A)} @@ -73,6 +78,15 @@ def test_returns_false_for_verbose_union_without_none(self) -> None: def test_returns_false_for_pep_604_union_without_none(self) -> None: assert is_optional(model_fields["pep_604_union_without_none"]) is False + def test_returns_false_for_non_nullable_entity(self) -> None: + assert is_optional(model_fields["entity"]) is False + + def test_returns_true_for_nullable_entity(self) -> None: + assert is_optional(model_fields["nullable_entity"]) is True + + def test_returns_true_for_nullable_uuid(self) -> None: + assert is_optional(model_fields["uuid_or_none"]) is True + def test_raises_schema_error_for_invalid_tuple_type(self) -> None: with pytest.raises(SchemaError, match=r"has invalid tuple type"): is_optional(model_fields["unsupported_tuple"]) @@ -83,15 +97,16 @@ def test_raises_schema_error_for_invalid_tuple_type(self) -> None: with pytest.raises(SchemaError, match=r"has invalid tuple type"): classify_field(model_fields["unsupported_tuple"]) - @pytest.mark.parametrize( - "field", - ( - model_fields["simple_primitive"], - model_fields["complex_primitive"], - ), - ) - def test_can_classify_primitive_field(self, field: Field) -> None: - assert classify_field(field) == (FieldKind.primitive, field.type) + def test_raises_schema_error_for_invalid_union_type(self) -> None: + with pytest.raises(SchemaError, match=r"has unsupported union type"): + classify_field(model_fields["unsupported_union"]) + + def test_raises_schema_error_for_non_none_union(self) -> None: + with pytest.raises(SchemaError, match=r"Only union with None is supported"): + classify_field(model_fields["verbose_union_without_none"]) + + def test_can_classify_primitive_field(self) -> None: + assert classify_field(model_fields["primitive"]) == (FieldKind.primitive, int) def test_can_classify_primitive_tuple_field(self) -> None: assert classify_field(model_fields["primitive_tuple"]) == ( @@ -107,3 +122,20 @@ def test_can_classify_entity_tuple_field(self) -> None: def test_can_classify_simple_nested_entity(self) -> None: assert classify_field(model_fields["entity"]) == (FieldKind.entity, Nested) + + # See KIP-893. + @pytest.mark.parametrize( + "field", + ( + model_fields["nullable_entity"], + model_fields["backwards_nullable_entity"], + ), + ) + def test_can_classify_nullable_nested_entity(self, field: Field) -> None: + assert classify_field(field) == (FieldKind.entity, Nested) + + def test_can_classify_uuid_or_none(self) -> None: + assert classify_field(model_fields["uuid_or_none"]) == ( + FieldKind.primitive, + UUID, + ) diff --git a/tests/serial/test_parse.py b/tests/serial/test_parse.py index 54b033e8..95eb87d6 100644 --- a/tests/serial/test_parse.py +++ b/tests/serial/test_parse.py @@ -17,10 +17,12 @@ from kio.serial import entity_reader from kio.serial import readers from kio.serial._parse import get_reader +from kio.serial._shared import NullableEntityMarker from kio.serial.writers import write_boolean from kio.serial.writers import write_compact_array_length from kio.serial.writers import write_compact_string from kio.serial.writers import write_empty_tagged_fields +from kio.serial.writers import write_int8 from kio.serial.writers import write_int16 from kio.serial.writers import write_int32 from kio.serial.writers import write_legacy_string @@ -363,3 +365,42 @@ def test_raises_value_error_for_tagged_field_on_legacy_model() -> None: match=r"^Found tagged fields on a non-flexible model$", ): entity_reader(LegacyWithTag) + + +@dataclass(frozen=True, slots=True, kw_only=True) +class NestedNullable: + __version__: ClassVar[i16] = i16(0) + __flexible__: ClassVar[bool] = True + __api_key__: ClassVar[i16] = i16(-1) + child: Child | None = field(default=None) + name: str = field(metadata={"kafka_type": "string"}) + + +def test_can_read_populated_nested_nullable_entity(buffer: io.BytesIO) -> None: + write_int8(buffer, NullableEntityMarker.not_null.value) + write_compact_string(buffer, "child name") + write_empty_tagged_fields(buffer) # child fields + write_compact_string(buffer, "parent name") + write_empty_tagged_fields(buffer) # parent fields + buffer.seek(0) + + instance = entity_reader(NestedNullable)(buffer) + + assert instance == NestedNullable( + child=Child(name="child name"), + name="parent name", + ) + + +def test_can_read_empty_nested_nullable_entity(buffer: io.BytesIO) -> None: + write_int8(buffer, NullableEntityMarker.null.value) + write_compact_string(buffer, "parent name") + write_empty_tagged_fields(buffer) # parent fields + buffer.seek(0) + + instance = entity_reader(NestedNullable)(buffer) + + assert instance == NestedNullable( + child=None, + name="parent name", + ) diff --git a/tests/serial/test_serialize.py b/tests/serial/test_serialize.py index a949bb4d..61d3e4b9 100644 --- a/tests/serial/test_serialize.py +++ b/tests/serial/test_serialize.py @@ -16,10 +16,12 @@ from kio.serial import entity_writer from kio.serial import writers from kio.serial._serialize import get_writer +from kio.serial._shared import NullableEntityMarker from kio.serial.readers import read_boolean from kio.serial.readers import read_compact_array_length from kio.serial.readers import read_compact_string from kio.serial.readers import read_compact_string_nullable +from kio.serial.readers import read_int8 from kio.serial.readers import read_int16 from kio.serial.readers import read_int32 from kio.serial.readers import read_unsigned_varint @@ -256,3 +258,50 @@ def test_serialize_complex_entity(buffer: io.BytesIO) -> None: # main entity tagged fields assert read_unsigned_varint(buffer) == 0 + + +@dataclass(frozen=True, slots=True, kw_only=True) +class Child: + __version__: ClassVar[i16] = i16(0) + __flexible__: ClassVar[bool] = True + __api_key__: ClassVar[i16] = i16(-1) + name: str = field(metadata={"kafka_type": "string"}) + + +@dataclass(frozen=True, slots=True, kw_only=True) +class NestedNullable: + __version__: ClassVar[i16] = i16(0) + __flexible__: ClassVar[bool] = True + __api_key__: ClassVar[i16] = i16(-1) + child: Child | None = field(default=None) + name: str = field(metadata={"kafka_type": "string"}) + + +def test_can_write_populated_nested_nullable_entity(buffer: io.BytesIO) -> None: + write_nested_nullable = entity_writer(NestedNullable) + instance = NestedNullable( + child=Child(name="child name"), + name="parent name", + ) + write_nested_nullable(buffer, instance) + buffer.seek(0) + + assert read_int8(buffer) == NullableEntityMarker.not_null.value + assert read_compact_string(buffer) == "child name" + assert read_unsigned_varint(buffer) == 0 # tagged fields + assert read_compact_string(buffer) == "parent name" + assert read_unsigned_varint(buffer) == 0 # tagged fields + + +def test_can_write_empty_nested_nullable_entity(buffer: io.BytesIO) -> None: + write_nested_nullable = entity_writer(NestedNullable) + instance = NestedNullable( + child=None, + name="parent name", + ) + write_nested_nullable(buffer, instance) + buffer.seek(0) + + assert read_int8(buffer) == NullableEntityMarker.null.value + assert read_compact_string(buffer) == "parent name" + assert read_unsigned_varint(buffer) == 0 # tagged fields diff --git a/tests/test_integration.py b/tests/test_integration.py index b446a198..14dc97c4 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -84,7 +84,7 @@ def write_request_header( else: raise NotImplementedError(f"Unknown request header schema: {header_schema}") - entity_writer(header_schema)(buffer, header) # type: ignore[type-var] + entity_writer(header_schema)(buffer, header) # type: ignore[arg-type] async def send(