-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Amazon Elastic Container Registry (ECR) Hook
Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com> Co-authored-by: Niko <onikolas@amazon.com>
- Loading branch information
1 parent
7f9727f
commit 7fb732d
Showing
5 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import base64 | ||
import logging | ||
from dataclasses import dataclass | ||
from datetime import datetime | ||
|
||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook | ||
from airflow.utils.log.secrets_masker import mask_secret | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class EcrCredentials: | ||
"""Helper (frozen dataclass) for storing temporary ECR credentials.""" | ||
|
||
username: str | ||
password: str | ||
proxy_endpoint: str | ||
expires_at: datetime | ||
|
||
def __post_init__(self): | ||
mask_secret(self.password) | ||
logger.debug("Credentials to Amazon ECR %r expires at %s.", self.proxy_endpoint, self.expires_at) | ||
|
||
@property | ||
def registry(self) -> str: | ||
"""Return registry in appropriate `docker login` format.""" | ||
# https://github.com/docker/docker-py/issues/2256#issuecomment-824940506 | ||
return self.proxy_endpoint.replace("https://", "") | ||
|
||
|
||
class EcrHook(AwsBaseHook): | ||
""" | ||
Interact with Amazon Elastic Container Registry (ECR) | ||
Additional arguments (such as ``aws_conn_id``) may be specified and | ||
are passed down to the underlying AwsBaseHook. | ||
.. seealso:: | ||
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
kwargs["client_type"] = "ecr" | ||
super().__init__(**kwargs) | ||
|
||
def get_temporary_credentials(self, registry_ids: list[str] | str | None = None) -> list[EcrCredentials]: | ||
"""Get temporary credentials for Amazon ECR. | ||
Return list of :class:`~airflow.providers.amazon.aws.hooks.ecr.EcrCredentials`, | ||
obtained credentials valid for 12 hours. | ||
:param registry_ids: Either AWS Account ID or list of AWS Account IDs that are associated | ||
with the registries from which credentials are obtained. If you do not specify a registry, | ||
the default registry is assumed. | ||
.. seealso:: | ||
- `boto3 ECR client get_authorization_token method <https://boto3.amazonaws.com/v1/documentation/\ | ||
api/latest/reference/services/ecr.html#ECR.Client.get_authorization_token>`_. | ||
""" | ||
registry_ids = registry_ids or None | ||
if isinstance(registry_ids, str): | ||
registry_ids = [registry_ids] | ||
|
||
if registry_ids: | ||
response = self.conn.get_authorization_token(registryIds=registry_ids) | ||
else: | ||
response = self.conn.get_authorization_token() | ||
|
||
creds = [] | ||
for auth_data in response["authorizationData"]: | ||
username, password = base64.b64decode(auth_data["authorizationToken"]).decode("utf-8").split(":") | ||
creds.append( | ||
EcrCredentials( | ||
username=username, | ||
password=password, | ||
proxy_endpoint=auth_data["proxyEndpoint"], | ||
expires_at=auth_data["expiresAt"], | ||
) | ||
) | ||
|
||
return creds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+10.8 KB
docs/integration-logos/aws/Amazon-Elastic-Container-Registry_light-bg@4x.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -335,6 +335,7 @@ dat | |
Databricks | ||
databricks | ||
datacenter | ||
dataclass | ||
Datadog | ||
datadog | ||
Dataflow | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from unittest import mock | ||
|
||
import boto3 | ||
import pytest | ||
from moto import mock_ecr | ||
from moto.core import DEFAULT_ACCOUNT_ID | ||
|
||
from airflow.providers.amazon.aws.hooks.ecr import EcrHook | ||
|
||
|
||
@pytest.fixture | ||
def patch_hook(monkeypatch): | ||
"""Patch hook object by dummy boto3 ECR client.""" | ||
ecr_client = boto3.client("ecr") | ||
monkeypatch.setattr(EcrHook, "conn", ecr_client) | ||
yield | ||
|
||
|
||
@mock_ecr | ||
class TestEcrHook: | ||
def test_service_type(self): | ||
"""Test expected boto3 client type.""" | ||
assert EcrHook().client_type == "ecr" | ||
|
||
@pytest.mark.parametrize( | ||
"accounts_ids", | ||
[ | ||
pytest.param("", id="empty-string"), | ||
pytest.param(None, id="none"), | ||
pytest.param([], id="empty-list"), | ||
], | ||
) | ||
def test_get_temporary_credentials_default_account_id(self, patch_hook, accounts_ids): | ||
"""Test different types of empty account/registry ids.""" | ||
result = EcrHook().get_temporary_credentials(registry_ids=accounts_ids) | ||
assert len(result) == 1 | ||
assert result[0].username == "AWS" | ||
assert result[0].registry.startswith(DEFAULT_ACCOUNT_ID) | ||
assert result[0].password == f"{DEFAULT_ACCOUNT_ID}-auth-token" | ||
|
||
@pytest.mark.parametrize( | ||
"accounts_id, expected_registry", | ||
[ | ||
pytest.param(DEFAULT_ACCOUNT_ID, DEFAULT_ACCOUNT_ID, id="moto-default-account"), | ||
pytest.param("111100002222", "111100002222", id="custom-account-id"), | ||
pytest.param(["333366669999"], "333366669999", id="custom-account-id-list"), | ||
], | ||
) | ||
def test_get_temporary_credentials_single_account_id(self, patch_hook, accounts_id, expected_registry): | ||
"""Test different types of single account/registry ids.""" | ||
result = EcrHook().get_temporary_credentials(registry_ids=accounts_id) | ||
assert len(result) == 1 | ||
assert result[0].username == "AWS" | ||
assert result[0].registry.startswith(expected_registry) | ||
assert result[0].password == f"{expected_registry}-auth-token" | ||
|
||
@pytest.mark.parametrize( | ||
"accounts_ids", | ||
[ | ||
pytest.param([DEFAULT_ACCOUNT_ID, "111100002222"], id="moto-default-and-custom-account-ids"), | ||
pytest.param(["999888777666", "333366669999", "777"], id="custom-accounts-ids"), | ||
], | ||
) | ||
def test_get_temporary_credentials_multiple_account_ids(self, patch_hook, accounts_ids): | ||
"""Test multiple account ids in the single method call.""" | ||
expected_creds = len(accounts_ids) | ||
result = EcrHook().get_temporary_credentials(registry_ids=accounts_ids) | ||
assert len(result) == expected_creds | ||
assert [cr.username for cr in result] == ["AWS"] * expected_creds | ||
assert all(cr.registry.startswith(accounts_ids[ix]) for ix, cr in enumerate(result)) | ||
|
||
@pytest.mark.parametrize( | ||
"accounts_ids", | ||
[ | ||
pytest.param(None, id="none"), | ||
pytest.param("111100002222", id="single-account-id"), | ||
pytest.param(["999888777666", "333366669999", "777"], id="multiple-account-ids"), | ||
], | ||
) | ||
@mock.patch("airflow.providers.amazon.aws.hooks.ecr.mask_secret") | ||
def test_get_temporary_credentials_mask_secrets(self, mock_masker, patch_hook, accounts_ids): | ||
"""Test masking passwords.""" | ||
result = EcrHook().get_temporary_credentials(registry_ids=accounts_ids) | ||
assert mock_masker.call_args_list == [mock.call(cr.password) for cr in result] |