Skip to content

Commit

Permalink
Added SASL/OAuthBearer (#630)
Browse files Browse the repository at this point in the history
* added OAUTHBEARER as an authorization method

* fixed flake errors

* made AbstractTokenProvider async friendly

* updated CHANGES

* changed oauth assert to ValueError

* updated CHANGES

* offered our independent AbstractTokenProvider

* avoided calling extensions multiple times

* updated CHANGES

* Consumer memory leak in idle state (#629)

* hotfix/ memory leak in consumer getmany

* hotfix/ Added changes in change log

* hotfix/ Fixed case where pending future set is empty

* hotfix/ renamed pending future variable

* updated CHANGES

Co-authored-by: Andy Luo <andy.luo@tibra.com>
Co-authored-by: Rajat Singh <iamsinghrajat@gmail.com>
  • Loading branch information
3 people committed Jun 17, 2020
1 parent 68be906 commit ea1aab8
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Changelog

628.bugfix
Fix memory leak in kafka consumer when consumer is in idle state not consuming any message

618.feature
Added `OAUTHBEARER` as a new `sasl_mechanism`.

0.6.0 (2020-05-15)
==================
Expand Down
53 changes: 52 additions & 1 deletion aiokafka/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,57 @@ def on_partitions_assigned(self, assigned):
pass


# Helper base class used below by AbstractTokenProvider. Calling the ABCMeta
# metaclass directly creates a class named 'ABC' that derives from object and
# declares empty __slots__.
# This statement is compatible with both Python 2.7 & 3+
ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})


class AbstractTokenProvider(ABC):
    """
    A Token Provider must be used for the SASL OAuthBearer protocol.

    The implementation should ensure token reuse so that multiple
    calls at connect time do not create multiple tokens. The implementation
    should also periodically refresh the token in order to guarantee
    that each call returns an unexpired token. A timeout error should
    be returned after a short period of inactivity so that the
    broker can log debugging info and retry.

    Token Providers MUST implement the token() method
    """

    def __init__(self, **config):
        # Keyword configuration is accepted and ignored here, so subclasses
        # may take arbitrary settings without overriding the signature.
        pass

    @abc.abstractmethod
    async def token(self):
        """
        An async callback returning a (str) ID/Access Token to be sent to
        the Kafka client. In case where a synchronous callback is needed,
        implementations like following can be used:

        .. highlight:: python
        .. code-block:: python

            import asyncio
            from aiokafka.abc import AbstractTokenProvider

            class CustomTokenProvider(AbstractTokenProvider):
                async def token(self):
                    # ``await`` is required: run_in_executor() returns a
                    # future, and token() must resolve to the str itself.
                    return await asyncio.get_running_loop().run_in_executor(
                        None, self._token)

                def _token(self):
                    # The actual synchronous token callback.
                    ...
        """
        pass

    def extensions(self):
        """
        This is an OPTIONAL method that may be implemented.

        Returns a map of key-value pairs that can
        be sent with the SASL/OAUTHBEARER initial client request. If
        not implemented, the values are ignored. This feature is only available
        in Kafka >= 2.1.0.
        """
        return {}


# Public names exported by ``from aiokafka.abc import *``. Each entry needs
# its own trailing comma: adjacent string literals are silently concatenated.
__all__ = [
    "ConsumerRebalanceListener",
    "AbstractTokenProvider",
]
15 changes: 11 additions & 4 deletions aiokafka/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
sasl_plain_username=None,
sasl_plain_password=None,
sasl_kerberos_service_name='kafka',
sasl_kerberos_domain_name=None):
sasl_kerberos_domain_name=None,
sasl_oauth_token_provider=None
):
if loop is None:
loop = get_running_loop()

Expand All @@ -106,10 +108,12 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
"`ssl_context` is mandatory if security_protocol=='SSL'")
if security_protocol in ["SASL_SSL", "SASL_PLAINTEXT"]:
if sasl_mechanism not in (
"PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512"):
"PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512",
"OAUTHBEARER"):
raise ValueError(
"only `PLAIN`, `GSSAPI`, `SCRAM-SHA-256` and "
"`SCRAM-SHA-512` sasl_mechanism are supported "
"only `PLAIN`, `GSSAPI`, `SCRAM-SHA-256`, "
"`SCRAM-SHA-512` and `OAUTHBEARER`"
"sasl_mechanism are supported "
"at the moment")
if sasl_mechanism == "PLAIN" and \
(sasl_plain_username is None or sasl_plain_password is None):
Expand All @@ -133,6 +137,7 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
self._sasl_plain_password = sasl_plain_password
self._sasl_kerberos_service_name = sasl_kerberos_service_name
self._sasl_kerberos_domain_name = sasl_kerberos_domain_name
self._sasl_oauth_token_provider = sasl_oauth_token_provider

self.cluster = ClusterMetadata(metadata_max_age_ms=metadata_max_age_ms)

Expand Down Expand Up @@ -202,6 +207,7 @@ async def bootstrap(self):
sasl_plain_password=self._sasl_plain_password,
sasl_kerberos_service_name=self._sasl_kerberos_service_name, # noqa: ignore=E501
sasl_kerberos_domain_name=self._sasl_kerberos_domain_name,
sasl_oauth_token_provider=self._sasl_oauth_token_provider,
version_hint=version_hint)
except (OSError, asyncio.TimeoutError) as err:
log.error('Unable connect to "%s:%s": %s', host, port, err)
Expand Down Expand Up @@ -438,6 +444,7 @@ async def _get_conn(
sasl_plain_password=self._sasl_plain_password,
sasl_kerberos_service_name=self._sasl_kerberos_service_name, # noqa: ignore=E501
sasl_kerberos_domain_name=self._sasl_kerberos_domain_name,
sasl_oauth_token_provider=self._sasl_oauth_token_provider,
version_hint=version_hint
)
except (OSError, asyncio.TimeoutError, KafkaError) as err:
Expand Down
68 changes: 67 additions & 1 deletion aiokafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import aiokafka.errors as Errors
from aiokafka.util import ensure_future, create_future, PY_36

from aiokafka.abc import AbstractTokenProvider

try:
import gssapi
except ImportError:
Expand Down Expand Up @@ -75,6 +77,7 @@ async def create_conn(
sasl_plain_password=None,
sasl_kerberos_service_name='kafka',
sasl_kerberos_domain_name=None,
sasl_oauth_token_provider=None,
version_hint=None
):
if loop is None:
Expand All @@ -90,6 +93,7 @@ async def create_conn(
sasl_plain_password=sasl_plain_password,
sasl_kerberos_service_name=sasl_kerberos_service_name,
sasl_kerberos_domain_name=sasl_kerberos_domain_name,
sasl_oauth_token_provider=sasl_oauth_token_provider,
version_hint=version_hint)
await conn.connect()
return conn
Expand Down Expand Up @@ -122,10 +126,23 @@ def __init__(self, host, port, *, loop, client_id='aiokafka',
sasl_plain_password=None, sasl_plain_username=None,
sasl_kerberos_service_name='kafka',
sasl_kerberos_domain_name=None,
sasl_oauth_token_provider=None,
version_hint=None):
if sasl_mechanism == "GSSAPI":
assert gssapi is not None, "gssapi library required"

if sasl_mechanism == "OAUTHBEARER":
if sasl_oauth_token_provider is None or \
not isinstance(
sasl_oauth_token_provider, AbstractTokenProvider):
raise ValueError("sasl_oauth_token_provider needs to be \
provided implementing aiokafka.abc.AbstractTokenProvider")
assert callable(
getattr(sasl_oauth_token_provider, "token", None)
), (
'sasl_oauth_token_provider must implement method #token()'
)

self._loop = loop
self._host = host
self._port = port
Expand All @@ -139,6 +156,7 @@ def __init__(self, host, port, *, loop, client_id='aiokafka',
self._sasl_plain_password = sasl_plain_password
self._sasl_kerberos_service_name = sasl_kerberos_service_name
self._sasl_kerberos_domain_name = sasl_kerberos_domain_name
self._sasl_oauth_token_provider = sasl_oauth_token_provider

# Version hint is the version determined by initial client bootstrap
self._version_hint = version_hint
Expand Down Expand Up @@ -262,7 +280,7 @@ async def _do_sasl_handshake(self):
raise exc

assert self._sasl_mechanism in (
'PLAIN', 'GSSAPI', 'SCRAM-SHA-256', 'SCRAM-SHA-512'
'PLAIN', 'GSSAPI', 'SCRAM-SHA-256', 'SCRAM-SHA-512', 'OAUTHBEARER'
)
if self._security_protocol == 'SASL_PLAINTEXT' and \
self._sasl_mechanism == 'PLAIN':
Expand All @@ -273,6 +291,8 @@ async def _do_sasl_handshake(self):
authenticator = self.authenticator_gssapi()
elif self._sasl_mechanism.startswith('SCRAM-SHA-'):
authenticator = self.authenticator_scram()
elif self._sasl_mechanism == 'OAUTHBEARER':
authenticator = self.authenticator_oauth()
else:
authenticator = self.authenticator_plain()

Expand Down Expand Up @@ -312,6 +332,10 @@ async def _do_sasl_handshake(self):
self.log.info(
'Authenticated as %s via GSSAPI',
self.sasl_principal)
elif self._sasl_mechanism == 'OAUTHBEARER':
self.log.info(
'Authenticated via OAUTHBEARER'
)
else:
self.log.info('Authenticated as %s via PLAIN',
self._sasl_plain_username)
Expand All @@ -334,6 +358,10 @@ def authenticator_scram(self):
sasl_plain_username=self._sasl_plain_username,
sasl_mechanism=self._sasl_mechanism)

def authenticator_oauth(self):
return OAuthAuthenticator(
sasl_oauth_token_provider=self._sasl_oauth_token_provider)

@property
def sasl_principal(self):
service = self._sasl_kerberos_service_name
Expand Down Expand Up @@ -692,3 +720,41 @@ def create_salted_password(self, salt, iterations):
@staticmethod
def _xor_bytes(left, right):
return bytes(lb ^ rb for lb, rb in zip(left, right))


class OAuthAuthenticator(BaseSaslAuthenticator):
    """SASL/OAUTHBEARER authenticator backed by a token provider.

    The exchange consists of a single client message: the first ``step()``
    call produces the initial client response carrying the bearer token,
    and every later call returns ``None``.
    """

    def __init__(self, *, sasl_oauth_token_provider):
        self._sasl_oauth_token_provider = sasl_oauth_token_provider
        # Flipped once the initial client response has been produced.
        self._token_sent = False

    async def step(self, payload):
        """Produce the next message of the authentication exchange.

        ``payload`` is not used by this mechanism.

        Returns a ``(bytes, True)`` tuple holding the UTF-8 encoded initial
        client response on the first call, and ``None`` on later calls.
        NOTE(review): the second tuple element presumably tells the caller
        to wait for a broker response -- confirm against the handshake loop.
        """
        if self._token_sent:
            return
        token = await self._sasl_oauth_token_provider.token()
        token_extensions = self._token_extensions()
        self._token_sent = True
        return self._build_oauth_client_request(token, token_extensions)\
            .encode("utf-8"), True

    def _build_oauth_client_request(self, token, token_extensions):
        """Format the initial client response as a ``str``.

        ``token_extensions`` is either empty or already prefixed with the
        ``\\x01`` separator, so it is appended straight after the token.
        """
        return "n,,\x01auth=Bearer {}{}\x01\x01".format(
            token, token_extensions
        )

    def _token_extensions(self):
        """
        Return a string representation of the OPTIONAL key-value pairs
        that can be sent with an OAUTHBEARER initial request.

        Returns an empty string when the provider has no callable
        ``extensions`` attribute or when it yields no pairs.
        """
        provider = self._sasl_oauth_token_provider
        # Only run if the #extensions() method is implemented
        # by the clients Token Provider class
        if callable(getattr(provider, "extensions", None)):
            extensions = provider.extensions()
            # Truthiness covers both an empty mapping and a provider that
            # returns None, which would otherwise fail on len().
            if extensions:
                # Builds up a string separated by \x01 from the pairs
                msg = "\x01".join(
                    ["{}={}".format(k, v) for k, v in extensions.items()])
                return "\x01" + msg

        return ""
12 changes: 9 additions & 3 deletions aiokafka/consumer/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,15 @@ class AIOKafkaConsumer(object):
sasl_mechanism (str): Authentication mechanism when security_protocol
is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are:
PLAIN, GSSAPI, SCRAM-SHA-256, SCRAM-SHA-512. Default: PLAIN
PLAIN, GSSAPI, SCRAM-SHA-256, SCRAM-SHA-512, OAUTHBEARER.
Default: PLAIN
sasl_plain_username (str): username for sasl PLAIN authentication.
Default: None
sasl_plain_password (str): password for sasl PLAIN authentication.
Default: None
sasl_oauth_token_provider (aiokafka.abc.AbstractTokenProvider):
OAuthBearer token provider instance. (See aiokafka.abc).
Default: None
Note:
Many configuration parameters are taken from Java Client:
Expand Down Expand Up @@ -247,7 +251,8 @@ def __init__(self, *topics, loop=None,
sasl_plain_password=None,
sasl_plain_username=None,
sasl_kerberos_service_name='kafka',
sasl_kerberos_domain_name=None):
sasl_kerberos_domain_name=None,
sasl_oauth_token_provider=None):
if loop is None:
loop = get_running_loop()

Expand All @@ -271,7 +276,8 @@ def __init__(self, *topics, loop=None,
sasl_plain_username=sasl_plain_username,
sasl_plain_password=sasl_plain_password,
sasl_kerberos_service_name=sasl_kerberos_service_name,
sasl_kerberos_domain_name=sasl_kerberos_domain_name)
sasl_kerberos_domain_name=sasl_kerberos_domain_name,
sasl_oauth_token_provider=sasl_oauth_token_provider)

self._group_id = group_id
self._heartbeat_interval_ms = heartbeat_interval_ms
Expand Down
12 changes: 9 additions & 3 deletions aiokafka/producer/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,15 @@ class AIOKafkaProducer(object):
New in version 0.5.0.
sasl_mechanism (str): Authentication mechanism when security_protocol
is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are:
PLAIN, GSSAPI, SCRAM-SHA-256, SCRAM-SHA-512. Default: PLAIN
PLAIN, GSSAPI, SCRAM-SHA-256, SCRAM-SHA-512, OAUTHBEARER.
Default: PLAIN
sasl_plain_username (str): username for sasl PLAIN authentication.
Default: None
sasl_plain_password (str): password for sasl PLAIN authentication.
Default: None
sasl_oauth_token_provider (aiokafka.abc.AbstractTokenProvider):
OAuthBearer token provider instance. (See aiokafka.abc).
Default: None
Note:
Many configuration parameters are taken from the Java client:
Expand Down Expand Up @@ -184,7 +188,8 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
transaction_timeout_ms=60000, sasl_mechanism="PLAIN",
sasl_plain_password=None, sasl_plain_username=None,
sasl_kerberos_service_name='kafka',
sasl_kerberos_domain_name=None):
sasl_kerberos_domain_name=None,
sasl_oauth_token_provider=None):
if loop is None:
loop = get_running_loop()

Expand Down Expand Up @@ -246,7 +251,8 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
sasl_plain_username=sasl_plain_username,
sasl_plain_password=sasl_plain_password,
sasl_kerberos_service_name=sasl_kerberos_service_name,
sasl_kerberos_domain_name=sasl_kerberos_domain_name)
sasl_kerberos_domain_name=sasl_kerberos_domain_name,
sasl_oauth_token_provider=sasl_oauth_token_provider)
self._metadata = self.client.cluster
self._message_accumulator = MessageAccumulator(
self._metadata, max_batch_size, compression_attrs,
Expand Down

0 comments on commit ea1aab8

Please sign in to comment.