Skip to content

Commit

Permalink
Do not mock isinstance in Amazon Tests (#34800)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Oct 6, 2023
1 parent 7707f4a commit 7e896c5
Show file tree
Hide file tree
Showing 11 changed files with 326 additions and 536 deletions.
73 changes: 32 additions & 41 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,19 @@ def test_get_session_returns_a_boto3_session(self):
assert table.item_count == 0

@pytest.mark.parametrize(
"client_meta",
"hook_params",
[
AwsBaseHook(client_type="s3").get_client_type().meta,
AwsBaseHook(resource_type="dynamodb").get_resource_type().meta.client.meta,
pytest.param({"client_type": "s3"}, id="client-type"),
pytest.param({"resource_type": "dynamodb"}, id="resource-type"),
],
)
def test_user_agent_extra_update(self, client_meta):
def test_user_agent_extra_update(self, hook_params):
"""
We are only looking for the keys appended by the AwsBaseHook. A user_agent string
is a number of key/value pairs such as: `BOTO3/1.25.4 AIRFLOW/2.5.0.DEV0 AMPP/6.0.0`.
"""
client_meta = AwsBaseHook(aws_conn_id=None, client_type="s3").conn_client_meta

expected_user_agent_tag_keys = ["Airflow", "AmPP", "Caller", "DagRunKey"]

result_user_agent_tags = client_meta.config.user_agent.split(" ")
Expand Down Expand Up @@ -477,31 +479,25 @@ def mock_assume_role(**kwargs):
return sts_response

with mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get"
) as mock_get, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.boto3"
) as mock_boto3, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
mock_get.return_value.ok = True

mock_client = mock_boto3.session.Session.return_value.client
"airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._create_basic_session",
spec=boto3.session.Session,
) as mocked_basic_session:
mocked_basic_session.return_value.region_name = "us-east-2"
mock_client = mocked_basic_session.return_value.client
mock_client.return_value.assume_role.side_effect = mock_assume_role

hook = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="s3")
hook.get_client_type("s3")

calls_assume_role = [
mock.call.session.Session().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call.session.Session()
.client()
.assume_role(
RoleArn=role_arn,
RoleSessionName=slugified_role_session_name,
),
]
mock_boto3.assert_has_calls(calls_assume_role)
AwsBaseHook(aws_conn_id=aws_conn_id, client_type="s3").get_client_type()
mocked_basic_session.assert_has_calls(
[
mock.call().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call()
.client()
.assume_role(
RoleArn=role_arn,
RoleSessionName=slugified_role_session_name,
),
]
)

def test_get_credentials_from_gcp_credentials(self):
mock_connection = Connection(
Expand Down Expand Up @@ -684,25 +680,21 @@ def mock_assume_role_with_saml(**kwargs):
with mock.patch("builtins.__import__", side_effect=import_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get"
) as mock_get, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.boto3"
) as mock_boto3, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
mock_get.return_value.ok = True

mock_client = mock_boto3.session.Session.return_value.client
"airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._create_basic_session",
spec=boto3.session.Session,
) as mocked_basic_session:
mocked_basic_session.return_value.region_name = "us-east-2"
mock_client = mocked_basic_session.return_value.client
mock_client.return_value.assume_role_with_saml.side_effect = mock_assume_role_with_saml

hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
hook.get_client_type("s3")
AwsBaseHook(aws_conn_id="aws_default", client_type="s3").get_client_type()

mock_get.assert_called_once_with(idp_url, auth=mock_auth)
mock_xpath.assert_called_once_with(xpath)

calls_assume_role_with_saml = [
mock.call.session.Session().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call.session.Session()
mocked_basic_session.assert_has_calls = [
mock.call().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call()
.client()
.assume_role_with_saml(
DurationSeconds=duration_seconds,
Expand All @@ -711,7 +703,6 @@ def mock_assume_role_with_saml(**kwargs):
SAMLAssertion=encoded_saml_assertion,
),
]
mock_boto3.assert_has_calls(calls_assume_role_with_saml)

@mock_iam
def test_expand_role(self):
Expand Down
72 changes: 38 additions & 34 deletions tests/providers/amazon/aws/hooks/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from unittest import mock

import pytest

from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook

SUBMIT_JOB_SUCCESS_RETURN = {
Expand Down Expand Up @@ -46,6 +48,12 @@
}


@pytest.fixture
def mocked_hook_client():
with mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.conn") as m:
yield m


class TestEmrContainerHook:
def setup_method(self):
self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234")
Expand All @@ -54,14 +62,8 @@ def test_init(self):
assert self.emr_containers.aws_conn_id == "aws_default"
assert self.emr_containers.virtual_cluster_id == "vc1234"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_create_emr_on_eks_cluster(self, mock_session, mock_isinstance):
emr_client_mock = mock.MagicMock()
emr_client_mock.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
def test_create_emr_on_eks_cluster(self, mocked_hook_client):
mocked_hook_client.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN

emr_on_eks_create_cluster_response = self.emr_containers.create_emr_on_eks_cluster(
virtual_cluster_name="test_virtual_cluster",
Expand All @@ -70,15 +72,19 @@ def test_create_emr_on_eks_cluster(self, mock_session, mock_isinstance):
)
assert emr_on_eks_create_cluster_response == "vc1234"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_submit_job(self, mock_session, mock_isinstance):
mocked_hook_client.create_virtual_cluster.assert_called_once_with(
name="test_virtual_cluster",
containerProvider={
"id": "test_eks_cluster",
"type": "EKS",
"info": {"eksInfo": {"namespace": "test_eks_namespace"}},
},
tags={},
)

def test_submit_job(self, mocked_hook_client):
# Mock out the emr_client creator
emr_client_mock = mock.MagicMock()
emr_client_mock.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
mocked_hook_client.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN

emr_containers_job = self.emr_containers.submit_job(
name="test-job-run",
Expand All @@ -90,32 +96,30 @@ def test_submit_job(self, mock_session, mock_isinstance):
)
assert emr_containers_job == "job123456"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_query_status_polling_when_terminal(self, mock_session, mock_isinstance):
emr_client_mock = mock.MagicMock()
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
mocked_hook_client.start_job_run.assert_called_once_with(
name="test-job-run",
virtualClusterId="vc1234",
executionRoleArn="arn:aws:somerole",
releaseLabel="emr-6.3.0-latest",
jobDriver={},
configurationOverrides={},
tags={},
clientToken="uuidtoken",
)

def test_query_status_polling_when_terminal(self, mocked_hook_client):
mocked_hook_client.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
query_status = self.emr_containers.poll_query_status(job_id="job123456")
# should only poll once since query is already in terminal state
emr_client_mock.describe_job_run.assert_called_once()
mocked_hook_client.describe_job_run.assert_called_once()
assert query_status == "COMPLETED"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_query_status_polling_with_timeout(self, mock_session, mock_isinstance):
emr_client_mock = mock.MagicMock()
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
def test_query_status_polling_with_timeout(self, mocked_hook_client):
mocked_hook_client.describe_job_run.return_value = JOB2_RUN_DESCRIPTION

query_status = self.emr_containers.poll_query_status(
job_id="job123456", max_polling_attempts=2, poll_interval=0
)
# should poll until max_tries is reached since query is in non-terminal state
assert emr_client_mock.describe_job_run.call_count == 2
assert mocked_hook_client.describe_job_run.call_count == 2
assert query_status == "RUNNING"
48 changes: 13 additions & 35 deletions tests/providers/amazon/aws/operators/test_cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.cloud_formation import (
CloudFormationCreateStackOperator,
Expand All @@ -31,19 +33,14 @@
DEFAULT_ARGS = {"owner": "airflow", "start_date": DEFAULT_DATE}


class TestCloudFormationCreateStackOperator:
def setup_method(self):
# Mock out the cloudformation_client (moto fails with an exception).
self.cloudformation_client_mock = MagicMock()

# Mock out the emr_client creator
cloudformation_session_mock = MagicMock()
cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)
@pytest.fixture
def mocked_hook_client():
with mock.patch("airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook.conn") as m:
yield m

self.mock_context = MagicMock()

def test_create_stack(self):
class TestCloudFormationCreateStackOperator:
def test_create_stack(self, mocked_hook_client):
stack_name = "myStack"
timeout = 15
template_body = "My stack body"
Expand All @@ -55,30 +52,15 @@ def test_create_stack(self):
dag=DAG("test_dag_id", default_args=DEFAULT_ARGS),
)

with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
operator.execute(self.mock_context)
operator.execute(MagicMock())

self.cloudformation_client_mock.create_stack.assert_any_call(
mocked_hook_client.create_stack.assert_any_call(
StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout
)


class TestCloudFormationDeleteStackOperator:
def setup_method(self):
# Mock out the cloudformation_client (moto fails with an exception).
self.cloudformation_client_mock = MagicMock()

# Mock out the emr_client creator
cloudformation_session_mock = MagicMock()
cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)

self.mock_context = MagicMock()

def test_delete_stack(self):
def test_delete_stack(self, mocked_hook_client):
stack_name = "myStackToBeDeleted"

operator = CloudFormationDeleteStackOperator(
Expand All @@ -87,10 +69,6 @@ def test_delete_stack(self):
dag=DAG("test_dag_id", default_args=DEFAULT_ARGS),
)

with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
operator.execute(self.mock_context)
operator.execute(MagicMock())

self.cloudformation_client_mock.delete_stack.assert_any_call(StackName=stack_name)
mocked_hook_client.delete_stack.assert_any_call(StackName=stack_name)

0 comments on commit 7e896c5

Please sign in to comment.