Skip to content

Commit

Permalink
Add the MSK IAM transport to support AWS MSK cluster instances withou…
Browse files Browse the repository at this point in the history
…t additional custom ones

Signed-off-by: Mattia Bertorello <mattia.bertorello@booking.com>
  • Loading branch information
mattiabertorello committed Feb 29, 2024
1 parent 840dd35 commit 4ff38d0
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 0 deletions.
4 changes: 4 additions & 0 deletions client/python/openlineage/client/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from openlineage.client.transport.file import FileTransport
from openlineage.client.transport.http import HttpConfig, HttpTransport
from openlineage.client.transport.kafka import KafkaConfig, KafkaTransport
from openlineage.client.transport.msk_iam import MSKIAMTransport
from openlineage.client.transport.noop import NoopTransport
from openlineage.client.transport.transport import Config, Transport, TransportFactory

_factory = DefaultTransportFactory()
_factory.register_transport(HttpTransport.kind, HttpTransport)
_factory.register_transport(KafkaTransport.kind, KafkaTransport)
_factory.register_transport(MSKIAMTransport.kind, MSKIAMTransport)
_factory.register_transport(ConsoleTransport.kind, ConsoleTransport)
_factory.register_transport(NoopTransport.kind, NoopTransport)
_factory.register_transport(FileTransport.kind, FileTransport)
Expand All @@ -36,6 +38,8 @@ def register_transport(clazz: type[Transport]) -> type[Transport]:
"HttpTransport",
"KafkaConfig",
"KafkaTransport",
"MSKIAMTransport",
"MSKIAMConfig",
"ConsoleTransport",
"NoopTransport",
"Transport",
Expand Down
123 changes: 123 additions & 0 deletions client/python/openlineage/client/transport/msk_iam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2018-2024 contributors to the OpenLineage project
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import functools
import logging
import os
from typing import Any

import attr
from openlineage.client.transport import KafkaConfig, KafkaTransport

log = logging.getLogger(__name__)


def _detect_running_region() -> None | str:
    """Best-effort detection of the AWS region the current process runs in.

    Sources are consulted from cheapest to most expensive and the first non-empty
    value wins:

    1. the ``AWS_REGION`` and ``AWS_DEFAULT_REGION`` environment variables;
    2. the boto3 default session if one exists, otherwise a fresh ``boto3.Session()``;
    3. the EC2 instance identity document (works for a running Glue job or anything
       else on EC2).

    https://stackoverflow.com/questions/37514810/how-to-get-the-region-of-the-current-user-from-boto

    Returns:
        The region name, or ``None`` when none of the sources yields one.
    """
    # Environment variables need neither boto3 nor the network, so look at them first.
    for env_var in ("AWS_REGION", "AWS_DEFAULT_REGION"):
        region = os.environ.get(env_var)
        if region:
            return region

    import boto3  # type: ignore[import]

    # Only build a new session when the default one does not already provide a region.
    default_session = boto3.DEFAULT_SESSION
    region = default_session.region_name if default_session else None
    if not region:
        region = boto3.Session().region_name
    if region:
        return region

    # else query an external service
    # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
    import requests

    try:
        r = requests.get("http://169.254.169.254/latest/dynamic/instance-identity/document", timeout=1)
        region = r.json().get("region")
    except Exception:  # noqa: BLE001
        # Detection is deliberately best effort: outside EC2 the endpoint is simply
        # unreachable. Keep a trace at debug level instead of dropping the error.
        log.debug("Could not read the region from the EC2 instance identity document", exc_info=True)
        return None

    # Normalise an empty value to None so callers get a single "not found" marker.
    return region or None


@attr.s
class MSKIAMConfig(KafkaConfig):
    """Configuration of the MSK IAM transport: the Kafka settings plus AWS IAM options."""

    # MSK producer config
    # https://github.com/aws/aws-msk-iam-sasl-signer-python

    # AWS region handed to the token provider. When left as None,
    # MSKIAMTransport._setup_producer tries to detect it and raises ValueError if it cannot.
    region: None | str = attr.ib(default=None)
    # When set, the auth token is generated from this named AWS profile
    # (checked before role_arn in _oauth_cb).
    aws_profile: None | str = attr.ib(default=None)
    # When set and no aws_profile is given, the auth token is generated from this role ARN.
    role_arn: None | str = attr.ib(default=None)
    # Forwarded as `aws_debug_creds` to the default token generation
    # (only used when neither aws_profile nor role_arn is set).
    aws_debug_creds: bool = attr.ib(default=False)


def _oauth_cb(config: MSKIAMConfig, *_: Any) -> tuple[str, float]:
    """Generate an MSK IAM auth token for the Kafka client's OAuth callback.

    The credential source is picked from ``config``: a named profile if
    ``aws_profile`` is set, else an assumed role if ``role_arn`` is set, else the
    default token generation. Any extra positional arguments are accepted and ignored.

    Returns:
        A ``(token, expiry)`` tuple where ``expiry`` is expressed in seconds.
    """
    from aws_msk_iam_sasl_signer import MSKAuthTokenProvider  # type: ignore[import]

    aws_region = config.region
    if config.aws_profile:
        token_and_expiry = MSKAuthTokenProvider.generate_auth_token_from_profile(
            aws_region, config.aws_profile
        )
    elif config.role_arn:
        token_and_expiry = MSKAuthTokenProvider.generate_auth_token_from_role_arn(
            aws_region, config.role_arn
        )
    # Implement the version to load a custom `botocore.credentials.CredentialProvider` at runtime
    # and calling the method
    # `MSKAuthTokenProvider.generate_auth_token_from_credentials_provider(region, credentials_provider)`
    else:
        token_and_expiry = MSKAuthTokenProvider.generate_auth_token(
            aws_region, aws_debug_creds=config.aws_debug_creds
        )

    token, expires_at_ms = token_and_expiry
    log.debug("Token expiry time: %s region %s", expires_at_ms, aws_region)
    # The Kafka client expects oauth_cb to report the expiry in seconds since epoch,
    # whereas the token generator reports it in milliseconds.
    return token, expires_at_ms / 1000


class MSKIAMTransport(KafkaTransport):
    """Kafka transport that authenticates against AWS MSK with IAM (SASL_SSL / OAUTHBEARER)."""

    kind = "msk-iam"
    config_class = MSKIAMConfig

    def __init__(self, config: MSKIAMConfig) -> None:
        """Store the MSK configuration, then delegate to the Kafka transport."""
        # Assigned before delegating because `_setup_producer` reads `self.msk_config`;
        # the tests in this change show it being invoked while the transport is constructed.
        self.msk_config = config
        super().__init__(config)

    def _setup_producer(self, config: dict) -> None:  # type: ignore[type-arg]
        """Add the MSK IAM authentication settings to ``config`` and set up the producer.

        Args:
            config: Kafka producer configuration; updated in place with the security
                protocol, the SASL mechanism and the OAuth token callback.

        Raises:
            ValueError: no region is configured and none could be detected.
            ModuleNotFoundError: a required module is missing (logged, then re-raised).
        """
        try:
            log.info("Setup the MSK transport with this configuration: %s", self.msk_config)

            # https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md
            if self.msk_config.region is None:
                region = _detect_running_region()
                if not region:
                    except_message = (
                        "OpenLineage MSK IAM Transport must have a region defined. "
                        "Please use the `region` configuration key to set it."
                    )
                    # No exception is being handled at this point, so `log.exception`
                    # would append a meaningless "NoneType: None" traceback.
                    log.error(except_message)
                    raise ValueError(except_message)
                # Remember the detected region so the token callback can use it.
                self.msk_config.region = region

            config.update(
                {
                    "security.protocol": "SASL_SSL",
                    "sasl.mechanism": "OAUTHBEARER",
                    "oauth_cb": functools.partial(_oauth_cb, self.msk_config),
                }
            )
            super()._setup_producer(config)
        except ModuleNotFoundError:
            # The extra is declared as `msk-iam` in pyproject.toml.
            log.exception(
                "OpenLineage client did not find aws-msk-iam-sasl-signer-python module. "
                "Installing it is required for MSK IAM Transport to work. "
                "You can also get it via `pip install openlineage-python[msk-iam]`",
            )
            raise
4 changes: 4 additions & 0 deletions client/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ dependencies = [
optional-dependencies.kafka = [
"confluent-kafka>=2.1.1",
]
optional-dependencies.msk-iam = [
"aws-msk-iam-sasl-signer-python>=1.0.1",
"confluent-kafka>=2.1.1",
]
optional-dependencies.test = [
"covdefaults>=2.3",
"pytest>=7.3.1",
Expand Down
210 changes: 210 additions & 0 deletions client/python/tests/test_msk_iam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright 2018-2024 contributors to the OpenLineage project
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import datetime
import os
import uuid
from typing import TYPE_CHECKING
from unittest import mock

import pytest
from openlineage.client.run import Job, Run, RunEvent, RunState
from openlineage.client.transport.msk_iam import (
MSKIAMConfig,
MSKIAMTransport,
_detect_running_region,
_oauth_cb,
)

if TYPE_CHECKING:
from pytest_mock import MockerFixture


@pytest.fixture()
def event() -> RunEvent:
    """Provide a START run event for the ``kafka``/``test`` job with a random run id."""
    started_at = datetime.datetime.now().isoformat()
    run = Run(runId=str(uuid.uuid4()))
    job = Job(namespace="kafka", name="test")
    return RunEvent(
        eventType=RunState.START,
        eventTime=started_at,
        run=run,
        job=job,
        producer="prod",
        schemaURL="schema",
    )


def test_msk_loads_full_config() -> None:
    """``MSKIAMConfig.from_dict`` keeps every setting of a complete configuration."""
    bootstrap_servers = "xxx.c2.kafka.us-east-1.amazonaws.com:9098"
    raw_config = {
        "type": "msk",
        "config": {"bootstrap.servers": bootstrap_servers},
        "topic": "random-topic",
        "flush": False,
        "region": "us-east-1",
    }

    loaded = MSKIAMConfig.from_dict(raw_config)

    assert loaded.config["bootstrap.servers"] == bootstrap_servers
    assert loaded.topic == "random-topic"
    assert loaded.region == "us-east-1"
    assert loaded.flush is False


@mock.patch.dict(os.environ, {"AWS_DEFAULT_REGION": "eu-west-1"})
def test_msk_detect_running_default_region() -> None:
    """``AWS_DEFAULT_REGION`` is used when ``AWS_REGION`` is not set."""
    # AWS_REGION outranks AWS_DEFAULT_REGION, so a value inherited from the machine
    # running the tests would mask the variable under test. `patch.dict` restores the
    # original environment (including this key) once the test returns.
    os.environ.pop("AWS_REGION", None)

    region = _detect_running_region()

    assert region == "eu-west-1"


@mock.patch.dict(os.environ, {"AWS_REGION": "eu-central-1"})
def test_msk_detect_running_region() -> None:
    """The region is taken from the ``AWS_REGION`` environment variable."""
    detected = _detect_running_region()
    assert detected == "eu-central-1"


def test_msk_detect_running_region_ec2(mocker: MockerFixture) -> None:
    """The region falls back to the one reported by the instance identity document."""
    # Region variables inherited from the machine running the tests would be returned
    # before the metadata lookup is reached; drop them for the duration of the test.
    # (A region coming from a local boto3 configuration is not neutralised here.)
    mocker.patch.dict(os.environ)
    for env_var in ("AWS_REGION", "AWS_DEFAULT_REGION"):
        os.environ.pop(env_var, None)

    method_json = mocker.MagicMock()
    method_json.json.return_value = {"region": "us-west-1"}
    mocker.patch("requests.get", return_value=method_json)

    region = _detect_running_region()

    assert region == "us-west-1"


def test_msk_detect_running_region_empty(mocker: MockerFixture) -> None:
    """``None`` is returned when the metadata request fails and no other source has a region."""
    # Region variables inherited from the machine running the tests would make the
    # detection succeed; drop them for the duration of the test.
    # (A region coming from a local boto3 configuration is not neutralised here.)
    mocker.patch.dict(os.environ)
    for env_var in ("AWS_REGION", "AWS_DEFAULT_REGION"):
        os.environ.pop(env_var, None)

    mocker.patch("requests.get", side_effect=Exception())

    region = _detect_running_region()

    assert region is None


def test_msk_load_config_fails_on_no_config() -> None:
    """``MSKIAMConfig.from_dict`` raises ``TypeError`` for a dict that carries no topic."""
    incomplete_config = {
        "type": "kafka",
        "config": {"bootstrap.servers": "localhost:9092"},
    }
    with pytest.raises(TypeError):
        MSKIAMConfig.from_dict(incomplete_config)


def test_msk_token_provider(mocker: MockerFixture) -> None:
    """Check the token generator chosen by ``_oauth_cb`` and the transport's region handling.

    Covers the default, profile-based and role-based token generation, then the
    region auto-detection and the error raised when no region can be determined.
    """
    expiry_time_ms = 1000
    # `_oauth_cb` converts the generator's milliseconds into seconds.
    expected_expiry_time_s = 1.0
    mock_methods = {
        method: mocker.patch(
            f"aws_msk_iam_sasl_signer.MSKAuthTokenProvider.{method}",
            return_value=("abc:" + method, expiry_time_ms),
        )
        for method in (
            "generate_auth_token",
            "generate_auth_token_from_profile",
            "generate_auth_token_from_role_arn",
        )
    }

    # Test generate_auth_token (neither profile nor role configured)
    config = MSKIAMConfig(
        config={"bootstrap.servers": "localhost:9092"},
        topic="random-topic",
        region="us-east-1",
    )

    token, expire_time = _oauth_cb(config, None)
    assert token == "abc:generate_auth_token"  # noqa: S105
    assert expire_time == expected_expiry_time_s
    mock_methods["generate_auth_token"].assert_called_once_with(
        config.region, aws_debug_creds=config.aws_debug_creds
    )

    # Test generate_auth_token_from_profile
    config = MSKIAMConfig(
        config={"bootstrap.servers": "localhost:9092"},
        topic="random-topic",
        region="us-east-1",
        aws_profile="default_profile1",
    )

    token, expire_time = _oauth_cb(config, None)
    assert token == "abc:generate_auth_token_from_profile"  # noqa: S105
    assert expire_time == expected_expiry_time_s
    mock_methods["generate_auth_token_from_profile"].assert_called_once_with(
        config.region, config.aws_profile
    )

    # Test generate_auth_token_from_role_arn
    config = MSKIAMConfig(
        config={"bootstrap.servers": "localhost:9092"},
        topic="random-topic",
        region="us-east-1",
        role_arn="arn:aws:iam::1234:role/abc",
    )

    token, expire_time = _oauth_cb(config, None)
    assert token == "abc:generate_auth_token_from_role_arn"  # noqa: S105
    assert expire_time == expected_expiry_time_s
    mock_methods["generate_auth_token_from_role_arn"].assert_called_once_with(config.region, config.role_arn)

    # Test default region
    mocker.patch(
        "openlineage.client.transport.msk_iam._detect_running_region",
        return_value="eu-central-1",
    )
    config = MSKIAMConfig(
        config={"bootstrap.servers": "localhost:9092"},
        topic="random-topic",
    )
    # Build the transport once: the first construction already writes the detected
    # region into `config`, so a second one would no longer exercise the detection.
    transport = MSKIAMTransport(config)
    assert transport.msk_config.region == "eu-central-1"

    # Test no region
    mocker.patch(
        "openlineage.client.transport.msk_iam._detect_running_region",
        return_value=None,
    )
    config = MSKIAMConfig(
        config={"bootstrap.servers": "localhost:9092"},
        topic="random-topic",
    )
    with pytest.raises(
        ValueError,
        match="OpenLineage MSK IAM Transport must have a region defined. "
        "Please use the `region` configuration key to set it.",
    ):
        MSKIAMTransport(config)


def test_setup_producer_configuration(
    mocker: MockerFixture,
) -> None:
    """The producer configuration handed to the Kafka transport carries the MSK IAM settings."""
    mocker.patch(
        "openlineage.client.transport.kafka._check_if_airflow_sqlalchemy_context",
        return_value=False,
    )
    msk_config = MSKIAMConfig(
        config={"bootstrap.servers": "localhost:9092"},
        topic="random-topic",
        region="us-east-1",
    )
    parent_setup_mock = mocker.patch(
        "openlineage.client.transport.kafka.KafkaTransport._setup_producer",
    )
    oauth_cb_mock = mocker.patch(
        "openlineage.client.transport.msk_iam._oauth_cb", return_value=("token", 1000)
    )

    MSKIAMTransport(msk_config)

    # The configuration dict is the first positional argument of the parent call.
    positional_args, _ = parent_setup_mock.call_args
    producer_config = positional_args[0]

    # Three static settings plus the OAuth callback.
    expected_key_count = 4
    assert len(producer_config) == expected_key_count
    assert "oauth_cb" in producer_config

    # Take the callback out so the remaining settings can be compared as plain data.
    oauth_callback = producer_config.pop("oauth_cb")
    assert producer_config == {
        "bootstrap.servers": "localhost:9092",
        "sasl.mechanism": "OAUTHBEARER",
        "security.protocol": "SASL_SSL",
    }
    assert oauth_callback(oauth_cb_mock) == ("token", 1000)
2 changes: 2 additions & 0 deletions client/python/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package = wheel
wheel_build_env = .pkg
extras =
kafka
msk-iam
test
set_env =
COVERAGE_FILE = {toxworkdir}/.coverage.{envname}
Expand Down Expand Up @@ -48,6 +49,7 @@ description = generate a DEV environment
package = editable
extras =
kafka
msk-iam
test
commands =
python -m pip list --format=columns
Expand Down

0 comments on commit 4ff38d0

Please sign in to comment.