-
Notifications
You must be signed in to change notification settings - Fork 263
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the MSK IAM transport to support AWS MSK cluster instances withou…
…t additional custom ones Signed-off-by: Mattia Bertorello <mattia.bertorello@booking.com>
- Loading branch information
1 parent
840dd35
commit 4ff38d0
Showing
5 changed files
with
343 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2018-2024 contributors to the OpenLineage project | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from __future__ import annotations | ||
|
||
import functools | ||
import logging | ||
import os | ||
from typing import Any | ||
|
||
import attr | ||
from openlineage.client.transport import KafkaConfig, KafkaTransport | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def _detect_running_region() -> None | str: | ||
"""Dynamically determine the region from a running Glue job (or anything on EC2 for | ||
that matter). | ||
https://stackoverflow.com/questions/37514810/how-to-get-the-region-of-the-current-user-from-boto | ||
""" | ||
import boto3 # type: ignore[import] | ||
|
||
easy_checks = [ | ||
# check if set through ENV vars | ||
os.environ.get("AWS_REGION"), | ||
os.environ.get("AWS_DEFAULT_REGION"), | ||
# else check if set in config or in boto already | ||
boto3.DEFAULT_SESSION.region_name if boto3.DEFAULT_SESSION else None, | ||
boto3.Session().region_name, | ||
] | ||
region: None | str = None | ||
for region in easy_checks: | ||
if region: | ||
return region | ||
|
||
# else query an external service | ||
# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html | ||
import requests | ||
|
||
try: | ||
r = requests.get("http://169.254.169.254/latest/dynamic/instance-identity/document", timeout=1) | ||
region = r.json().get("region") | ||
except Exception: # noqa: S110 BLE001 | ||
pass | ||
|
||
return region | ||
|
||
|
||
@attr.s | ||
class MSKIAMConfig(KafkaConfig): | ||
# MSK producer config | ||
# https://github.com/aws/aws-msk-iam-sasl-signer-python | ||
|
||
region: str = attr.ib(default=None) | ||
aws_profile: None | str = attr.ib(default=None) | ||
role_arn: None | str = attr.ib(default=None) | ||
aws_debug_creds: bool = attr.ib(default=False) | ||
|
||
|
||
def _oauth_cb(config: MSKIAMConfig, *_: Any) -> tuple[str, float]: | ||
from aws_msk_iam_sasl_signer import MSKAuthTokenProvider # type: ignore[import] | ||
|
||
region = config.region | ||
if config.aws_profile: | ||
auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token_from_profile( | ||
region, config.aws_profile | ||
) | ||
elif config.role_arn: | ||
auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token_from_role_arn( | ||
region, config.role_arn | ||
) | ||
# Implement the version to load a custom `botocore.credentials.CredentialProvider` at runtime | ||
# and calling the method | ||
# `MSKAuthTokenProvider.generate_auth_token_from_credentials_provider(region, credentials_provider)` | ||
else: | ||
auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token( | ||
region, aws_debug_creds=config.aws_debug_creds | ||
) | ||
log.debug("Token expiry time: %s region %s", expiry_ms, region) | ||
# Note that this library expects oauth_cb to return expiry time in seconds since epoch, | ||
# while the token generator returns expiry in ms | ||
return auth_token, expiry_ms / 1000 | ||
|
||
|
||
class MSKIAMTransport(KafkaTransport): | ||
kind = "msk-iam" | ||
config_class = MSKIAMConfig | ||
|
||
def __init__(self, config: MSKIAMConfig) -> None: | ||
self.msk_config = config | ||
super().__init__(config) | ||
|
||
def _setup_producer(self, config: dict) -> None: # type: ignore[type-arg] | ||
try: | ||
log.info("Setup the MSK transport with this configuration: %s", self.msk_config) | ||
|
||
# https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md | ||
if self.msk_config.region is None: | ||
region = _detect_running_region() | ||
if region: | ||
self.msk_config.region = region | ||
else: | ||
except_message = ( | ||
"OpenLineage MSK IAM Transport must have a region defined. " | ||
"Please use the `region` configuration key to set it." | ||
) | ||
log.exception(except_message) | ||
raise ValueError(except_message) | ||
config.update( | ||
{ | ||
"security.protocol": "SASL_SSL", | ||
"sasl.mechanism": "OAUTHBEARER", | ||
"oauth_cb": functools.partial(_oauth_cb, self.msk_config), | ||
} | ||
) | ||
super()._setup_producer(config) | ||
except ModuleNotFoundError: | ||
log.exception( | ||
"OpenLineage client did not found aws-msk-iam-sasl-signer-python module. " | ||
"Installing it is required for MSK IAM Transport to work. " | ||
"You can also get it via `pip install openlineage-python[mskiam]`", | ||
) | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
# Copyright 2018-2024 contributors to the OpenLineage project | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from __future__ import annotations | ||
|
||
import datetime | ||
import os | ||
import uuid | ||
from typing import TYPE_CHECKING | ||
from unittest import mock | ||
|
||
import pytest | ||
from openlineage.client.run import Job, Run, RunEvent, RunState | ||
from openlineage.client.transport.msk_iam import ( | ||
MSKIAMConfig, | ||
MSKIAMTransport, | ||
_detect_running_region, | ||
_oauth_cb, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pytest_mock import MockerFixture | ||
|
||
|
||
@pytest.fixture() | ||
def event() -> RunEvent: | ||
return RunEvent( | ||
eventType=RunState.START, | ||
eventTime=datetime.datetime.now().isoformat(), | ||
run=Run(runId=str(uuid.uuid4())), | ||
job=Job(namespace="kafka", name="test"), | ||
producer="prod", | ||
schemaURL="schema", | ||
) | ||
|
||
|
||
def test_msk_loads_full_config() -> None: | ||
config = MSKIAMConfig.from_dict( | ||
{ | ||
"type": "msk", | ||
"config": {"bootstrap.servers": "xxx.c2.kafka.us-east-1.amazonaws.com:9098"}, | ||
"topic": "random-topic", | ||
"flush": False, | ||
"region": "us-east-1", | ||
}, | ||
) | ||
|
||
assert config.config["bootstrap.servers"] == "xxx.c2.kafka.us-east-1.amazonaws.com:9098" | ||
assert config.topic == "random-topic" | ||
assert config.region == "us-east-1" | ||
assert config.flush is False | ||
|
||
|
||
@mock.patch.dict(os.environ, {"AWS_DEFAULT_REGION": "eu-west-1"}) | ||
def test_msk_detect_running_default_region() -> None: | ||
region = _detect_running_region() | ||
|
||
assert region == "eu-west-1" | ||
|
||
|
||
@mock.patch.dict(os.environ, {"AWS_REGION": "eu-central-1"}) | ||
def test_msk_detect_running_region() -> None: | ||
region = _detect_running_region() | ||
|
||
assert region == "eu-central-1" | ||
|
||
|
||
def test_msk_detect_running_region_ec2(mocker: MockerFixture) -> None: | ||
method_json = mocker.MagicMock() | ||
method_json.json.return_value = {"region": "us-west-1"} | ||
mocker.patch("requests.get", return_value=method_json) | ||
region = _detect_running_region() | ||
|
||
assert region == "us-west-1" | ||
|
||
|
||
def test_msk_detect_running_region_empty(mocker: MockerFixture) -> None: | ||
mocker.patch("requests.get", side_effect=Exception()) | ||
region = _detect_running_region() | ||
assert region is None | ||
|
||
|
||
def test_msk_load_config_fails_on_no_config() -> None: | ||
with pytest.raises(TypeError): | ||
MSKIAMConfig.from_dict( | ||
{ | ||
"type": "kafka", | ||
"config": {"bootstrap.servers": "localhost:9092"}, | ||
}, | ||
) | ||
|
||
|
||
def test_msk_token_provider(mocker: MockerFixture) -> None: | ||
expiry_time_ms = 1000 | ||
expected_expiry_time_ms = 1.0 | ||
mock_methods = {} | ||
for method in [ | ||
"generate_auth_token", | ||
"generate_auth_token_from_profile", | ||
"generate_auth_token_from_role_arn", | ||
]: | ||
mock_methods[method] = mocker.patch( | ||
f"aws_msk_iam_sasl_signer.MSKAuthTokenProvider.{method}", | ||
return_value=("abc:" + method, expiry_time_ms), | ||
) | ||
config = MSKIAMConfig( | ||
config={"bootstrap.servers": "localhost:9092"}, | ||
topic="random-topic", | ||
region="us-east-1", | ||
) | ||
|
||
token, expire_time = _oauth_cb(config, None) | ||
assert token == "abc:generate_auth_token" # noqa: S105 | ||
assert expire_time == expected_expiry_time_ms | ||
mock_methods["generate_auth_token"].assert_called_once_with( | ||
config.region, aws_debug_creds=config.aws_debug_creds | ||
) | ||
|
||
# Test generate_auth_token_from_profile | ||
config = MSKIAMConfig( | ||
config={"bootstrap.servers": "localhost:9092"}, | ||
topic="random-topic", | ||
region="us-east-1", | ||
aws_profile="default_profile1", | ||
) | ||
|
||
token, expire_time = _oauth_cb(config, None) | ||
assert token == "abc:generate_auth_token_from_profile" # noqa: S105 | ||
assert expire_time == expected_expiry_time_ms | ||
mock_methods["generate_auth_token_from_profile"].assert_called_once_with( | ||
config.region, config.aws_profile | ||
) | ||
|
||
# Test generate_auth_token_from_role_arn | ||
config = MSKIAMConfig( | ||
config={"bootstrap.servers": "localhost:9092"}, | ||
topic="random-topic", | ||
region="us-east-1", | ||
role_arn="arn:aws:iam::1234:role/abc", | ||
) | ||
|
||
token, expire_time = _oauth_cb(config, None) | ||
assert token == "abc:generate_auth_token_from_role_arn" # noqa: S105 | ||
assert expire_time == expected_expiry_time_ms | ||
mock_methods["generate_auth_token_from_role_arn"].assert_called_once_with(config.region, config.role_arn) | ||
|
||
# Test default region | ||
mocker.patch( | ||
"openlineage.client.transport.msk_iam._detect_running_region", | ||
return_value="eu-central-1", | ||
) | ||
config = MSKIAMConfig( | ||
config={"bootstrap.servers": "localhost:9092"}, | ||
topic="random-topic", | ||
) | ||
MSKIAMTransport(config) | ||
transport = MSKIAMTransport(config) | ||
assert transport.msk_config.region == "eu-central-1" | ||
|
||
# Test no region | ||
mocker.patch( | ||
"openlineage.client.transport.msk_iam._detect_running_region", | ||
return_value=None, | ||
) | ||
config = MSKIAMConfig( | ||
config={"bootstrap.servers": "localhost:9092"}, | ||
topic="random-topic", | ||
) | ||
with pytest.raises( | ||
ValueError, | ||
match="OpenLineage MSK IAM Transport must have a region defined. " | ||
"Please use the `region` configuration key to set it.", | ||
): | ||
MSKIAMTransport(config) | ||
|
||
|
||
def test_setup_producer_configuration( | ||
mocker: MockerFixture, | ||
) -> None: | ||
mocker.patch( | ||
"openlineage.client.transport.kafka._check_if_airflow_sqlalchemy_context", | ||
return_value=False, | ||
) | ||
config = MSKIAMConfig( | ||
config={"bootstrap.servers": "localhost:9092"}, | ||
topic="random-topic", | ||
region="us-east-1", | ||
) | ||
|
||
setup_producer_mocker = mocker.patch( | ||
"openlineage.client.transport.kafka.KafkaTransport._setup_producer", | ||
) | ||
msk_token_mocker = mocker.patch( | ||
"openlineage.client.transport.msk_iam._oauth_cb", return_value=("token", 1000) | ||
) | ||
MSKIAMTransport(config) | ||
|
||
expected_kafka_config = { | ||
"bootstrap.servers": "localhost:9092", | ||
"sasl.mechanism": "OAUTHBEARER", | ||
"security.protocol": "SASL_SSL", | ||
} | ||
total_producer_configuration_keys = 4 | ||
args, kwargs = setup_producer_mocker.call_args | ||
actual_kafka_config = args[0] | ||
assert len(actual_kafka_config) == total_producer_configuration_keys | ||
assert "oauth_cb" in actual_kafka_config | ||
actual_oauth_cb = actual_kafka_config["oauth_cb"] | ||
del actual_kafka_config["oauth_cb"] | ||
assert actual_kafka_config == expected_kafka_config | ||
assert actual_oauth_cb(msk_token_mocker) == ("token", 1000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters