Skip to content
This repository has been archived by the owner on Mar 24, 2021. It is now read-only.

Commit

Permalink
Merge pull request #768 from Parsely/feature/serde
Browse files Browse the repository at this point in the history
SerDe support
  • Loading branch information
Emmett J. Butler committed Feb 22, 2018
2 parents 4bec6cf + 5a951a7 commit 24bdb1a
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 15 deletions.
15 changes: 13 additions & 2 deletions pykafka/balancedconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__(self,
post_rebalance_callback=None,
use_rdkafka=False,
compacted_topic=False,
membership_protocol=RangeProtocol):
membership_protocol=RangeProtocol,
deserializer=None):
"""Create a BalancedConsumer instance
:param topic: The topic this consumer should consume
Expand Down Expand Up @@ -204,6 +205,14 @@ def __init__(self,
:param membership_protocol: The group membership protocol to which this consumer
should adhere
:type membership_protocol: :class:`pykafka.membershipprotocol.GroupMembershipProtocol`
:param deserializer: A function defining how to deserialize messages returned
from Kafka. A function with the signature d(value, partition_key) that
returns a tuple of (deserialized_value, deserialized_partition_key). The
arguments passed to this function are the bytes representations of a
message's value and partition key, and the returned data should be these
fields transformed according to the client code's serialization logic.
See `pykafka.utils.__init__` for stock implementations.
:type deserializer: function
"""
self._cluster = cluster
try:
Expand Down Expand Up @@ -238,6 +247,7 @@ def __init__(self,
self._worker_exception = None
self._is_compacted_topic = compacted_topic
self._membership_protocol = membership_protocol
self._deserializer = deserializer

if not rdkafka and use_rdkafka:
raise ImportError("use_rdkafka requires rdkafka to be installed")
Expand Down Expand Up @@ -436,7 +446,8 @@ def _get_internal_consumer(self, partitions=None, start=True):
auto_start=start,
compacted_topic=self._is_compacted_topic,
generation_id=self._generation_id,
consumer_id=self._consumer_id
consumer_id=self._consumer_id,
deserializer=self._deserializer
)

def _get_participants(self):
Expand Down
12 changes: 11 additions & 1 deletion pykafka/managedbalancedconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def __init__(self,
use_rdkafka=False,
compacted_topic=True,
heartbeat_interval_ms=3000,
membership_protocol=RangeProtocol):
membership_protocol=RangeProtocol,
deserializer=None):
"""Create a ManagedBalancedConsumer instance
:param topic: The topic this consumer should consume
Expand Down Expand Up @@ -168,6 +169,14 @@ def __init__(self,
:param membership_protocol: The group membership protocol to which this consumer
should adhere
:type membership_protocol: :class:`pykafka.membershipprotocol.GroupMembershipProtocol`
:param deserializer: A function defining how to deserialize messages returned
from Kafka. A function with the signature d(value, partition_key) that
returns a tuple of (deserialized_value, deserialized_partition_key). The
arguments passed to this function are the bytes representations of a
message's value and partition key, and the returned data should be these
fields transformed according to the client code's serialization logic.
See `pykafka.utils.__init__` for stock implementations.
:type deserializer: function
"""

self._cluster = cluster
Expand Down Expand Up @@ -199,6 +208,7 @@ def __init__(self,
self._membership_protocol = membership_protocol
self._membership_protocol.metadata.topic_names = [self._topic.name]
self._heartbeat_interval_ms = valid_int(heartbeat_interval_ms)
self._deserializer = deserializer
if use_rdkafka is True:
raise ImportError("use_rdkafka is not available for {}".format(
self.__class__.__name__))
Expand Down
28 changes: 21 additions & 7 deletions pykafka/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def __init__(self,
max_request_size=1000012,
sync=False,
delivery_reports=False,
auto_start=True):
auto_start=True,
serializer=None):
"""Instantiate a new AsyncProducer
:param cluster: The cluster to which to connect
Expand Down Expand Up @@ -154,6 +155,14 @@ def __init__(self,
with kafka after __init__ is complete. If false, communication
can be started with `start()`.
:type auto_start: bool
:param serializer: A function defining how to serialize messages to be sent
to Kafka. A function with the signature d(value, partition_key) that
returns a tuple of (serialized_value, serialized_partition_key). The
arguments passed to this function are a message's value and partition key,
and the returned data should be these fields transformed according to the
client code's serialization logic. See `pykafka.utils.__init__` for stock
implementations.
:type serializer: function
"""
self._cluster = cluster
self._protocol_version = msg_protocol_version(cluster._broker_version)
Expand Down Expand Up @@ -182,6 +191,7 @@ def __init__(self,
if delivery_reports or self._synchronous
else _DeliveryReportNone())
self._auto_start = auto_start
self._serializer = serializer
self._running = False
self._update_lock = self._cluster.handler.Lock()
if self._auto_start:
Expand Down Expand Up @@ -317,16 +327,20 @@ def produce(self, message, partition_key=None, timestamp=None):
:return: The :class:`pykafka.protocol.Message` instance that was
added to the internal message queue
"""
if partition_key is not None and type(partition_key) is not bytes:
raise TypeError("Producer.produce accepts a bytes object as partition_key, "
"but it got '%s'", type(partition_key))
if message is not None and type(message) is not bytes:
raise TypeError("Producer.produce accepts a bytes object as message, but it "
"got '%s'", type(message))
if self._serializer is None:
if partition_key is not None and type(partition_key) is not bytes:
raise TypeError("Producer.produce accepts a bytes object as partition_key, "
"but it got '%s'", type(partition_key))
if message is not None and type(message) is not bytes:
raise TypeError("Producer.produce accepts a bytes object as message, but it "
"got '%s'", type(message))
if timestamp is not None and self._protocol_version < 1:
raise RuntimeError("Producer.produce got a timestamp with protocol 0")
if not self._running:
raise ProducerStoppedException()
if self._serializer is not None:
message, partition_key = self._serializer(message, partition_key)

partitions = list(self._topic.partitions.values())
partition_id = self._partitioner(partitions, partition_key).id

Expand Down
3 changes: 2 additions & 1 deletion pykafka/rdkafka/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def __init__(self,
max_request_size=1000012,
sync=False,
delivery_reports=False,
auto_start=True):
auto_start=True,
serializer=None):
callargs = {k: v for k, v in vars().items()
if k not in ("self", "__class__")}
self._broker_version = cluster._broker_version
Expand Down
5 changes: 3 additions & 2 deletions pykafka/rdkafka/simple_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def __init__(self,
reset_offset_on_start=False,
compacted_topic=False,
generation_id=-1,
consumer_id=b''):
consumer_id=b'',
deserializer=None):
callargs = {k: v for k, v in vars().items()
if k not in ("self", "__class__")}
self._rdk_consumer = None
Expand Down Expand Up @@ -286,4 +287,4 @@ def _mk_rdkafka_config_lists(self):
# librdkafka expects all config values as strings:
conf = [(key, str(conf[key])) for key in conf]
topic_conf = [(key, str(topic_conf[key])) for key in topic_conf]
return conf, topic_conf
return conf, topic_conf
15 changes: 14 additions & 1 deletion pykafka/simpleconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def __init__(self,
reset_offset_on_start=False,
compacted_topic=False,
generation_id=-1,
consumer_id=b''):
consumer_id=b'',
deserializer=None):
"""Create a SimpleConsumer.
Settings and default values are taken from the Scala
Expand Down Expand Up @@ -155,6 +156,14 @@ def __init__(self,
:param consumer_id: The identifying string to use for this consumer on group
requests
:type consumer_id: bytes
:param deserializer: A function defining how to deserialize messages returned
from Kafka. A function with the signature d(value, partition_key) that
returns a tuple of (deserialized_value, deserialized_partition_key). The
arguments passed to this function are the bytes representations of a
message's value and partition key, and the returned data should be these
fields transformed according to the client code's serialization logic.
See `pykafka.utils.__init__` for stock implementations.
:type deserializer: function
"""
self._running = False
self._cluster = cluster
Expand Down Expand Up @@ -187,6 +196,7 @@ def __init__(self,
self._generation_id = valid_int(generation_id, allow_zero=True,
allow_negative=True)
self._consumer_id = consumer_id
self._deserializer = deserializer

# incremented for any message arrival from any partition
# the initial value is 0 (no messages waiting)
Expand Down Expand Up @@ -468,6 +478,9 @@ def consume(self, block=True, unblock_event=None):
if not self._slot_available.is_set():
self._slot_available.set()

if self._deserializer is not None:
ret.value, ret.partition_key = self._deserializer(ret.value,
ret.partition_key)
return ret

def _auto_commit(self):
Expand Down
27 changes: 27 additions & 0 deletions pykafka/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,33 @@ def pack_into(self, buff, offset):
raise NotImplementedError()


def serialize_utf8(value, partition_key):
    """A serializer accepting bytes or str arguments and returning utf-8 encoded bytes

    Can be used as `pykafka.producer.Producer(serializer=serialize_utf8)`

    :param value: The message value to serialize. `None` and bytes instances
        (including instances of bytes subclasses) are returned unchanged;
        anything else is encoded by calling its `encode('utf-8')` method.
    :param partition_key: The partition key to serialize, handled the same
        way as `value`
    :returns: A tuple of (serialized_value, serialized_partition_key)
    :raises UnicodeError: if encoding either argument fails
    """
    # isinstance (rather than an exact type comparison) lets subclasses of
    # bytes pass through untouched instead of failing on a missing `encode`
    if value is not None and not isinstance(value, bytes):
        # allow UnicodeError to be raised here if the encoding fails
        value = value.encode('utf-8')
    if partition_key is not None and not isinstance(partition_key, bytes):
        partition_key = partition_key.encode('utf-8')
    return value, partition_key


def deserialize_utf8(value, partition_key):
    """A deserializer that decodes utf-8 bytes into strings

    Can be passed as
    `pykafka.simpleconsumer.SimpleConsumer(deserializer=deserialize_utf8)`,
    or in the same way to the other consumer classes

    :param value: The bytes of a message's value, or `None`
    :param partition_key: The bytes of a message's partition key, or `None`
    :returns: A tuple of (decoded_value, decoded_partition_key); a `None`
        input stays `None` in the corresponding position
    """
    # a decoding failure is deliberately left to propagate as UnicodeError
    decoded_value = None if value is None else value.decode('utf-8')
    decoded_key = (None if partition_key is None
                   else partition_key.decode('utf-8'))
    return decoded_value, decoded_key


VERSIONS_CACHE = {}


Expand Down
14 changes: 13 additions & 1 deletion tests/pykafka/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import platform
import pytest
import random
import time
import types
import unittest2
Expand All @@ -27,6 +28,7 @@
from pykafka.test.utils import get_cluster, stop_cluster, retry
from pykafka.common import CompressionType
from pykafka.producer import OwnedBroker
from pykafka.utils import serialize_utf8, deserialize_utf8
from tests.pykafka import patch_subclass

kafka_version = os.environ.get('KAFKA_VERSION', '0.8.0')
Expand Down Expand Up @@ -54,11 +56,12 @@ def _get_producer(self, **kwargs):
topic = self.client.topics[self.topic_name]
return topic.get_producer(use_rdkafka=self.USE_RDKAFKA, **kwargs)

def _get_consumer(self):
def _get_consumer(self, **kwargs):
return self.client.topics[self.topic_name].get_simple_consumer(
consumer_timeout_ms=1000,
auto_offset_reset=OffsetType.LATEST,
reset_offset_on_start=True,
**kwargs
)

def test_produce(self):
Expand All @@ -74,6 +77,15 @@ def test_produce(self):
message = consumer.consume()
assert message.value == payload

def test_produce_utf8(self):
    """Round-trip a unicode payload through the stock utf-8 (de)serializers"""
    text = u"{}".format(random.random())
    # NOTE(review): the consumer is built before producing; presumably so that
    # it is already positioned when the message arrives -- confirm
    utf8_consumer = self._get_consumer(deserializer=deserialize_utf8)
    utf8_producer = self._get_producer(
        sync=True,
        min_queued_messages=1,
        serializer=serialize_utf8,
    )
    utf8_producer.produce(text)
    received = utf8_consumer.consume()
    # the deserialized value should equal the original unicode payload
    assert received.value == text

def test_sync_produce_raises(self):
"""Ensure response errors are raised in produce() if sync=True"""
with self._get_producer(sync=True, min_queued_messages=1) as prod:
Expand Down

0 comments on commit 24bdb1a

Please sign in to comment.