Skip to content

Commit

Permalink
Add Amazon Elastic Container Registry (ECR) Hook
Browse files Browse the repository at this point in the history
Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com>
Co-authored-by: Niko <onikolas@amazon.com>
  • Loading branch information
3 people committed Dec 13, 2022
1 parent 7f9727f commit 7fb732d
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 0 deletions.
101 changes: 101 additions & 0 deletions airflow/providers/amazon/aws/hooks/ecr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import base64
import logging
from dataclasses import dataclass
from datetime import datetime

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.log.secrets_masker import mask_secret

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class EcrCredentials:
    """Immutable container for a single set of temporary Amazon ECR credentials."""

    # Login name and secret decoded from the ECR authorization token.
    username: str
    password: str
    # Registry URL as reported by ECR for these credentials.
    proxy_endpoint: str
    # Moment after which the credentials are no longer valid.
    expires_at: datetime

    def __post_init__(self):
        # Register the password with the secrets masker before anything is logged.
        mask_secret(self.password)
        logger.debug("Credentials to Amazon ECR %r expires at %s.", self.proxy_endpoint, self.expires_at)

    @property
    def registry(self) -> str:
        """Registry address with the ``https://`` scheme stripped, as `docker login` expects."""
        # https://github.com/docker/docker-py/issues/2256#issuecomment-824940506
        scheme = "https://"
        return self.proxy_endpoint.replace(scheme, "")


class EcrHook(AwsBaseHook):
    """
    Interact with Amazon Elastic Container Registry (ECR).

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def __init__(self, **kwargs):
        # Always bind the hook to the ECR client, overriding any caller-supplied client type.
        kwargs["client_type"] = "ecr"
        super().__init__(**kwargs)

    def get_temporary_credentials(self, registry_ids: list[str] | str | None = None) -> list[EcrCredentials]:
        """Get temporary credentials for Amazon ECR.

        Return list of :class:`~airflow.providers.amazon.aws.hooks.ecr.EcrCredentials`,
        obtained credentials valid for 12 hours.

        :param registry_ids: Either AWS Account ID or list of AWS Account IDs that are associated
            with the registries from which credentials are obtained. If you do not specify a registry,
            the default registry is assumed.
        :return: One ``EcrCredentials`` per entry of ``authorizationData`` in the service response.
        :raises ValueError: If a decoded authorization token does not contain a ``:`` separator.

        .. seealso::
            - `boto3 ECR client get_authorization_token method <https://boto3.amazonaws.com/v1/documentation/\
api/latest/reference/services/ecr.html#ECR.Client.get_authorization_token>`_.
        """
        # Collapse every "empty" value ("", [], None) to None *before* the str check,
        # otherwise an empty string would be wrapped into [""] and sent to the API.
        registry_ids = registry_ids or None
        if isinstance(registry_ids, str):
            registry_ids = [registry_ids]

        # Only pass ``registryIds`` when there is something to pass; omitting it
        # makes the call fall back to the default registry.
        if registry_ids:
            response = self.conn.get_authorization_token(registryIds=registry_ids)
        else:
            response = self.conn.get_authorization_token()

        creds = []
        for auth_data in response["authorizationData"]:
            token = base64.b64decode(auth_data["authorizationToken"]).decode("utf-8")
            # The token is ``<username>:<password>``. Split on the first colon only so that
            # a password which itself contains ":" is kept intact instead of raising.
            username, password = token.split(":", 1)
            creds.append(
                EcrCredentials(
                    username=username,
                    password=password,
                    proxy_endpoint=auth_data["proxyEndpoint"],
                    expires_at=auth_data["expiresAt"],
                )
            )

        return creds
7 changes: 7 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ integrations:
how-to-guide:
- /docs/apache-airflow-providers-amazon/operators/ec2.rst
tags: [aws]
- integration-name: Amazon Elastic Container Registry (ECR)
external-doc-url: https://aws.amazon.com/ecr/
logo: /integration-logos/aws/Amazon-Elastic-Container-Registry_light-bg@4x.png
tags: [aws]
- integration-name: Amazon ECS
external-doc-url: https://aws.amazon.com/ecs/
logo: /integration-logos/aws/Amazon-Elastic-Container-Service_light-bg@4x.png
Expand Down Expand Up @@ -402,6 +406,9 @@ hooks:
- integration-name: Amazon EC2
python-modules:
- airflow.providers.amazon.aws.hooks.ec2
- integration-name: Amazon Elastic Container Registry (ECR)
python-modules:
- airflow.providers.amazon.aws.hooks.ecr
- integration-name: Amazon ECS
python-modules:
- airflow.providers.amazon.aws.hooks.ecs
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ dat
Databricks
databricks
datacenter
dataclass
Datadog
datadog
Dataflow
Expand Down
103 changes: 103 additions & 0 deletions tests/providers/amazon/aws/hooks/test_ecr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from unittest import mock

import boto3
import pytest
from moto import mock_ecr
from moto.core import DEFAULT_ACCOUNT_ID

from airflow.providers.amazon.aws.hooks.ecr import EcrHook


@pytest.fixture
def patch_hook(monkeypatch):
    """Replace ``EcrHook.conn`` with a boto3 ECR client for the duration of a test."""
    monkeypatch.setattr(EcrHook, "conn", boto3.client("ecr"))
    yield


@mock_ecr
class TestEcrHook:
    """Tests for ``EcrHook`` executed under the ``mock_ecr`` decorator."""

    def test_service_type(self):
        """The hook reports ``ecr`` as its boto3 client type."""
        hook = EcrHook()
        assert hook.client_type == "ecr"

    @pytest.mark.parametrize(
        "accounts_ids",
        [
            pytest.param("", id="empty-string"),
            pytest.param(None, id="none"),
            pytest.param([], id="empty-list"),
        ],
    )
    def test_get_temporary_credentials_default_account_id(self, patch_hook, accounts_ids):
        """Every flavour of "empty" registry id yields credentials for the default account."""
        creds = EcrHook().get_temporary_credentials(registry_ids=accounts_ids)
        assert len(creds) == 1
        first = creds[0]
        assert first.username == "AWS"
        assert first.registry.startswith(DEFAULT_ACCOUNT_ID)
        assert first.password == f"{DEFAULT_ACCOUNT_ID}-auth-token"

    @pytest.mark.parametrize(
        "accounts_id, expected_registry",
        [
            pytest.param(DEFAULT_ACCOUNT_ID, DEFAULT_ACCOUNT_ID, id="moto-default-account"),
            pytest.param("111100002222", "111100002222", id="custom-account-id"),
            pytest.param(["333366669999"], "333366669999", id="custom-account-id-list"),
        ],
    )
    def test_get_temporary_credentials_single_account_id(self, patch_hook, accounts_id, expected_registry):
        """A single registry id, given as a string or a one-element list, yields one credential."""
        creds = EcrHook().get_temporary_credentials(registry_ids=accounts_id)
        assert len(creds) == 1
        first = creds[0]
        assert first.username == "AWS"
        assert first.registry.startswith(expected_registry)
        assert first.password == f"{expected_registry}-auth-token"

    @pytest.mark.parametrize(
        "accounts_ids",
        [
            pytest.param([DEFAULT_ACCOUNT_ID, "111100002222"], id="moto-default-and-custom-account-ids"),
            pytest.param(["999888777666", "333366669999", "777"], id="custom-accounts-ids"),
        ],
    )
    def test_get_temporary_credentials_multiple_account_ids(self, patch_hook, accounts_ids):
        """Several registry ids in one call yield one credential per id, in request order."""
        creds = EcrHook().get_temporary_credentials(registry_ids=accounts_ids)
        requested = len(accounts_ids)
        assert len(creds) == requested
        assert [item.username for item in creds] == ["AWS"] * requested
        # Lengths are already asserted equal, so pairing by position covers every credential.
        assert all(item.registry.startswith(account) for account, item in zip(accounts_ids, creds))

    @pytest.mark.parametrize(
        "accounts_ids",
        [
            pytest.param(None, id="none"),
            pytest.param("111100002222", id="single-account-id"),
            pytest.param(["999888777666", "333366669999", "777"], id="multiple-account-ids"),
        ],
    )
    @mock.patch("airflow.providers.amazon.aws.hooks.ecr.mask_secret")
    def test_get_temporary_credentials_mask_secrets(self, mock_masker, patch_hook, accounts_ids):
        """Each returned password is passed to ``mask_secret`` exactly once, in order."""
        creds = EcrHook().get_temporary_credentials(registry_ids=accounts_ids)
        expected_calls = [mock.call(item.password) for item in creds]
        assert mock_masker.call_args_list == expected_calls

0 comments on commit 7fb732d

Please sign in to comment.