Fix providers tests in main branch with eager upgrades (#18040)
The SQS and DataCatalog tests were failing in the main branch because
recent releases of their dependencies broke them:

1) moto 2.2.6 broke the SQS tests - in 2.2.6+ the queue URL has to
   start with http:// or https:// (see the sketch below)

2) the DataCatalog part of the Google provider incorrectly imported
   types and broke its tests (it used the v1beta1 path instead of the
   GA google.cloud.datacatalog path)
potiuk committed Sep 6, 2021
1 parent 1be3ef6 commit bfad233
Showing 4 changed files with 29 additions and 24 deletions.
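
Before the diffs, a minimal standalone sketch of the moto breakage (not part
of the commit; the queue name and region are illustrative, and it assumes
boto3 plus moto>=2.2.6 are installed). With 2.2.6+ a bare name like 'test'
no longer works as a QueueUrl, while an http(s):// identifier does:

import boto3
from moto import mock_sqs

QUEUE_NAME = 'test-queue'
QUEUE_URL = f'https://{QUEUE_NAME}'  # moto 2.2.6+ requires the http(s):// prefix

@mock_sqs
def roundtrip():
    client = boto3.client('sqs', region_name='us-east-1')
    client.create_queue(QueueName=QUEUE_NAME)  # queues are created by *name*
    client.send_message(QueueUrl=QUEUE_URL, MessageBody='hello')
    return client.receive_message(QueueUrl=QUEUE_URL)  # ...but addressed by *URL*

print(roundtrip()['Messages'][0]['Body'])  # hello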
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/datacatalog.py
@@ -19,7 +19,7 @@

from google.api_core.retry import Retry
from google.cloud import datacatalog
-from google.cloud.datacatalog_v1beta1 import (
+from google.cloud.datacatalog import (
CreateTagRequest,
DataCatalogClient,
Entry,
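
A quick import check for the second breakage (a sketch, assuming
google-cloud-datacatalog 3.x is installed): the GA module re-exports the
same names the hook needs, so only the import path changes.

from google.cloud import datacatalog
from google.cloud.datacatalog import CreateTagRequest, DataCatalogClient, Entry

# The GA package exposes the types previously imported from the
# google.cloud.datacatalog_v1beta1 path.
print(datacatalog.Entry is Entry)  # True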
9 changes: 6 additions & 3 deletions tests/providers/amazon/aws/operators/test_sqs.py
@@ -29,6 +29,9 @@

DEFAULT_DATE = timezone.datetime(2019, 1, 1)

+QUEUE_NAME = 'test-queue'
+QUEUE_URL = f'https://{QUEUE_NAME}'


class TestSQSPublishOperator(unittest.TestCase):
def setUp(self):
@@ -38,7 +41,7 @@ def setUp(self):
self.operator = SQSPublishOperator(
task_id='test_task',
dag=self.dag,
-sqs_queue='test',
+sqs_queue=QUEUE_URL,
message_content='hello',
aws_conn_id='aws_default',
)
@@ -48,13 +51,13 @@ def setUp(self):

@mock_sqs
def test_execute_success(self):
-self.sqs_hook.create_queue('test')
+self.sqs_hook.create_queue(QUEUE_NAME)

result = self.operator.execute(self.mock_context)
assert 'MD5OfMessageBody' in result
assert 'MessageId' in result

-message = self.sqs_hook.get_conn().receive_message(QueueUrl='test')
+message = self.sqs_hook.get_conn().receive_message(QueueUrl=QUEUE_URL)

assert len(message['Messages']) == 1
assert message['Messages'][0]['MessageId'] == result['MessageId']
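
The same name-vs-URL split, sketched with the provider's own hook (an
assumption-laden example, not from the commit: it presumes
apache-airflow-providers-amazon and moto are installed, and sets
AWS_DEFAULT_REGION since no real connection is configured):

import os
from moto import mock_sqs
from airflow.providers.amazon.aws.hooks.sqs import SQSHook

os.environ.setdefault('AWS_DEFAULT_REGION', 'us-east-1')  # boto3 needs a region

@mock_sqs
def demo():
    hook = SQSHook(aws_conn_id='aws_default')
    hook.create_queue('test-queue')  # hook helper takes the queue *name*
    hook.send_message(queue_url='https://test-queue', message_body='hello')
    # raw boto3 calls take the *URL*, which moto 2.2.6+ validates
    return hook.get_conn().receive_message(QueueUrl='https://test-queue')

print(demo()['Messages'][0]['Body'])  # hello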
39 changes: 21 additions & 18 deletions tests/providers/amazon/aws/sensors/test_sqs.py
@@ -32,23 +32,26 @@

DEFAULT_DATE = timezone.datetime(2017, 1, 1)

+QUEUE_NAME = 'test-queue'
+QUEUE_URL = f'https://{QUEUE_NAME}'


class TestSQSSensor(unittest.TestCase):
def setUp(self):
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

self.dag = DAG('test_dag_id', default_args=args)
self.sensor = SQSSensor(
-task_id='test_task', dag=self.dag, sqs_queue='test', aws_conn_id='aws_default'
+task_id='test_task', dag=self.dag, sqs_queue=QUEUE_URL, aws_conn_id='aws_default'
)

self.mock_context = mock.MagicMock()
self.sqs_hook = SQSHook()

@mock_sqs
def test_poke_success(self):
-self.sqs_hook.create_queue('test')
-self.sqs_hook.send_message(queue_url='test', message_body='hello')
+self.sqs_hook.create_queue(QUEUE_NAME)
+self.sqs_hook.send_message(queue_url=QUEUE_URL, message_body='hello')

result = self.sensor.poke(self.mock_context)
assert result
@@ -60,7 +63,7 @@ def test_poke_success(self):
@mock_sqs
def test_poke_no_message_failed(self):

-self.sqs_hook.create_queue('test')
+self.sqs_hook.create_queue(QUEUE_NAME)
result = self.sensor.poke(self.mock_context)
assert not result

@@ -112,40 +115,40 @@ def test_poke_receive_raise_exception(self, mock_conn):
@mock.patch.object(SQSHook, 'get_conn')
def test_poke_visibility_timeout(self, mock_conn):
# Check without visibility_timeout parameter
-self.sqs_hook.create_queue('test')
-self.sqs_hook.send_message(queue_url='test', message_body='hello')
+self.sqs_hook.create_queue(QUEUE_NAME)
+self.sqs_hook.send_message(queue_url=QUEUE_URL, message_body='hello')

self.sensor.poke(self.mock_context)

calls_receive_message = [
-mock.call().receive_message(QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1)
+mock.call().receive_message(QueueUrl=QUEUE_URL, MaxNumberOfMessages=5, WaitTimeSeconds=1)
]
mock_conn.assert_has_calls(calls_receive_message)
# Check with visibility_timeout parameter
self.sensor = SQSSensor(
task_id='test_task2',
dag=self.dag,
-sqs_queue='test',
+sqs_queue=QUEUE_URL,
aws_conn_id='aws_default',
visibility_timeout=42,
)
self.sensor.poke(self.mock_context)

calls_receive_message = [
mock.call().receive_message(
-QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1, VisibilityTimeout=42
+QueueUrl=QUEUE_URL, MaxNumberOfMessages=5, WaitTimeSeconds=1, VisibilityTimeout=42
)
]
mock_conn.assert_has_calls(calls_receive_message)

@mock_sqs
def test_poke_message_invalid_filtering(self):
-self.sqs_hook.create_queue('test')
-self.sqs_hook.send_message(queue_url='test', message_body='hello')
+self.sqs_hook.create_queue(QUEUE_NAME)
+self.sqs_hook.send_message(queue_url=QUEUE_URL, message_body='hello')
sensor = SQSSensor(
task_id='test_task2',
dag=self.dag,
-sqs_queue='test',
+sqs_queue=QUEUE_URL,
aws_conn_id='aws_default',
message_filtering='invalid_option',
)
@@ -155,7 +158,7 @@ def test_poke_message_invalid_filtering(self):

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_literal_values(self, mock_conn):
-self.sqs_hook.create_queue('test')
+self.sqs_hook.create_queue(QUEUE_NAME)
matching = [{"id": 11, "body": "a matching message"}]
non_matching = [{"id": 12, "body": "a non-matching message"}]
all = matching + non_matching
@@ -188,13 +191,13 @@ def mock_delete_message_batch(**kwargs):
# Test that only filtered messages are deleted
delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
calls_delete_message_batch = [
-mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+mock.call().delete_message_batch(QueueUrl=QUEUE_URL, Entries=delete_entries)
]
mock_conn.assert_has_calls(calls_delete_message_batch)

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_jsonpath(self, mock_conn):
-self.sqs_hook.create_queue('test')
+self.sqs_hook.create_queue(QUEUE_NAME)
matching = [
{"id": 11, "key": {"matches": [1, 2]}},
{"id": 12, "key": {"matches": [3, 4, 5]}},
@@ -234,13 +237,13 @@ def mock_delete_message_batch(**kwargs):
# Test that only filtered messages are deleted
delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
calls_delete_message_batch = [
-mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+mock.call().delete_message_batch(QueueUrl=QUEUE_URL, Entries=delete_entries)
]
mock_conn.assert_has_calls(calls_delete_message_batch)

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_jsonpath_values(self, mock_conn):
-self.sqs_hook.create_queue('test')
+self.sqs_hook.create_queue(QUEUE_NAME)
matching = [
{"id": 11, "key": {"matches": [1, 2]}},
{"id": 12, "key": {"matches": [1, 4, 5]}},
@@ -282,6 +285,6 @@ def mock_delete_message_batch(**kwargs):
# Test that only filtered messages are deleted
delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
calls_delete_message_batch = [
-mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+mock.call().delete_message_batch(QueueUrl=QUEUE_URL, Entries=delete_entries)
]
mock_conn.assert_has_calls(calls_delete_message_batch)
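
For context on the jsonpath tests above, a short sketch of the kind of
matching the sensor performs (an approximation, not the provider's exact
code; it assumes jsonpath-ng, which the provider's message_filtering is
built on, is installed):

import json
from jsonpath_ng import parse

messages = [
    {"Body": json.dumps({"id": 11, "key": {"matches": [1, 2]}})},
    {"Body": json.dumps({"id": 14, "key": {"nope": [5, 6]}})},
]
expr = parse("key.matches[*]")
# keep only messages whose JSON body has at least one jsonpath match
matching = [m for m in messages if expr.find(json.loads(m["Body"]))]
print([json.loads(m["Body"])["id"] for m in matching])  # [11]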
3 changes: 1 addition & 2 deletions tests/providers/google/cloud/hooks/test_datacatalog.py
@@ -22,8 +22,7 @@

import pytest
from google.api_core.retry import Retry
-from google.cloud.datacatalog_v1beta1 import CreateTagRequest, CreateTagTemplateRequest
-from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate
+from google.cloud.datacatalog import CreateTagRequest, CreateTagTemplateRequest, Entry, Tag, TagTemplate

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook
