Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mockafka/aiokafka/aiokafka_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def message_to_record(message: Message, offset: int) -> ConsumerRecord[bytes, by
key = key_str.encode() if key_str is not None else None
value = value_str.encode() if value_str is not None else None

headers = message.headers()

return ConsumerRecord(
topic=topic,
partition=partition,
Expand All @@ -51,7 +53,7 @@ def message_to_record(message: Message, offset: int) -> ConsumerRecord[bytes, by
checksum=None, # Deprecated, we won't support it
serialized_key_size=len(key) if key else 0,
serialized_value_size=len(value) if value else 0,
headers=tuple((message.headers() or {}).items()),
headers=tuple(headers) if headers else (),
)


Expand Down
4 changes: 3 additions & 1 deletion mockafka/aiokafka/aiokafka_producer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Optional

from mockafka.kafka_store import KafkaStore
from mockafka.message import Message

Expand Down Expand Up @@ -46,7 +48,7 @@ async def send(
key=None,
partition=0,
timestamp_ms=None,
headers=None,
headers: Optional[list[tuple[str, Optional[bytes]]]] = None,
) -> None:
await self._produce(
topic=topic,
Expand Down
4 changes: 2 additions & 2 deletions mockafka/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class Message:
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._headers: Optional[dict] = kwargs.get("headers", None)
self._headers: Optional[list[tuple[str, Optional[bytes]]]] = kwargs.get("headers", None)
self._key: Optional[str] = kwargs.get("key", None)
self._value: Optional[str] = kwargs.get("value", None)
self._topic: Optional[str] = kwargs.get("topic", None)
Expand All @@ -37,7 +37,7 @@ def latency(self, *args, **kwargs):
def leader_epoch(self, *args, **kwargs):
return self._leader_epoch

def headers(self, *args, **kwargs):
def headers(self) -> Optional[list[tuple[str, Optional[bytes]]]]:
return self._headers

def key(self, *args, **kwargs):
Expand Down
37 changes: 34 additions & 3 deletions mockafka/producer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Optional, Union

from mockafka.cluster_metadata import ClusterMetadata
from mockafka.kafka_store import KafkaStore
from mockafka.message import Message
Expand All @@ -11,10 +13,39 @@ class FakeProducer(object):
def __init__(self, config: dict | None = None):
self.kafka = KafkaStore()

def produce(self, topic, value=None, *args, **kwargs):
def produce(
self,
topic,
value=None,
key=None,
partition=None,
callback=None,
on_delivery=None,
timestamp=None,
headers: Union[
# While Kafka itself supports only list[tuple[...]], confluent_kafka
# allows passing in a dict here.
dict[str, Optional[bytes]],
list[tuple[str, Optional[bytes]]],
None,
] = None,
**kwargs,
) -> None:
if isinstance(headers, dict):
headers = list(headers.items())
# create a message and call produce kafka
message = Message(value=value, topic=topic, *args, **kwargs)
self.kafka.produce(message=message, topic=topic, partition=kwargs["partition"])
message = Message(
topic=topic,
value=value,
key=key,
partition=partition,
callback=callback,
on_delivery=on_delivery,
timestamp=timestamp,
headers=headers,
**kwargs,
)
self.kafka.produce(message=message, topic=topic, partition=partition)

def list_topics(self, topic=None, *args, **kwargs):
return ClusterMetadata(topic)
Expand Down
15 changes: 13 additions & 2 deletions tests/test_aiokafka/test_aiokafka_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ async def test_produce_fail_for_none_partition(self):
async def test_produce_once(self) -> None:
await self._create_mock_topic()
await self.producer.send(
headers={},
headers=[
("header-name1", b"header-value"),
("header-name2", None),
("header-name1", b"duplicate!"),
],
key=self.key,
value=self.value,
topic=self.topic,
Expand All @@ -76,7 +80,14 @@ async def test_produce_once(self) -> None:
self.assertEqual(message.key(), self.key)
self.assertEqual(message.value(payload=None), self.value)
self.assertEqual(message.topic(), self.topic)
self.assertEqual(message.headers(), {})
self.assertEqual(
message.headers(),
[
("header-name1", b"header-value"),
("header-name2", None),
("header-name1", b"duplicate!"),
],
)
self.assertEqual(message.error(), None)
self.assertEqual(message.latency(), None)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_async_mockafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,25 @@ async def test_produce_and_consume_with_decorator(message=None):

assert message.key == b"test_key"
assert message.value == b"test_value"


@pytest.mark.asyncio
@asetup_kafka(topics=[{"topic": "test_topic1", "partition": 2}], clean=True)
async def test_produce_and_consume_with_headers():
producer = FakeAIOKafkaProducer()
consumer = FakeAIOKafkaConsumer()

await producer.start()
await consumer.start()
consumer.subscribe({"test_topic1"})

await producer.send(
topic="test_topic1",
headers=[('header_name', b"test"), ('header_name2', b"test")],
)
await producer.stop()

record = await consumer.getone()
assert record.headers == (('header_name', b"test"), ('header_name2', b"test"))

await consumer.stop()
13 changes: 11 additions & 2 deletions tests/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def test_produce_fail_for_none_partition(self):

def test_produce_once(self) -> None:
self.producer.produce(
headers={},
headers={
"header-name1": b"header-value",
"header-name2": None,
},
key=self.key,
value=self.value,
topic=self.topic,
Expand All @@ -72,7 +75,13 @@ def test_produce_once(self) -> None:
self.assertEqual(message.key(), self.key)
self.assertEqual(message.value(payload=None), self.value)
self.assertEqual(message.topic(), self.topic)
self.assertEqual(message.headers(), {})
self.assertEqual(
message.headers(),
[
("header-name1", b"header-value"),
("header-name2", None),
],
)
self.assertEqual(message.error(), None)
self.assertEqual(message.latency(), None)

Expand Down