Skip to content

Commit

Permalink
Fix reraise outside of try block in AthenaHook.get_output_location (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Dec 1, 2023
1 parent c26aa12 commit fd03dc2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
25 changes: 11 additions & 14 deletions airflow/providers/amazon/aws/hooks/athena.py
Expand Up @@ -292,20 +292,17 @@ def get_output_location(self, query_execution_id: str) -> str:
:param query_execution_id: Id of submitted athena query
"""
if query_execution_id:
response = self.get_query_info(query_execution_id=query_execution_id, use_cache=True)

if response:
try:
return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
except KeyError:
self.log.error(
"Error retrieving OutputLocation. Query execution id: %s", query_execution_id
)
raise
else:
raise
raise ValueError("Invalid Query execution id. Query execution id: %s", query_execution_id)
if not query_execution_id:
raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")

if not (response := self.get_query_info(query_execution_id=query_execution_id, use_cache=True)):
raise ValueError(f"Unable to get query information for execution id: {query_execution_id}")

try:
return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
except KeyError:
self.log.error("Error retrieving OutputLocation. Query execution id: %s", query_execution_id)
raise

def stop_query(self, query_execution_id: str) -> dict:
"""Cancel the submitted query.
Expand Down
25 changes: 25 additions & 0 deletions tests/providers/amazon/aws/hooks/test_athena.py
Expand Up @@ -18,6 +18,8 @@

from unittest import mock

import pytest

from airflow.providers.amazon.aws.hooks.athena import AthenaHook

MOCK_DATA = {
Expand Down Expand Up @@ -197,6 +199,29 @@ def test_hook_get_output_location(self, mock_conn):
result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
assert result == "s3://test_bucket/test.csv"

@pytest.mark.parametrize(
    "query_execution_id",
    [
        pytest.param("", id="empty-string"),
        pytest.param(None, id="none"),
    ],
)
def test_hook_get_output_location_empty_execution_id(self, query_execution_id):
    """An empty or ``None`` execution id must raise ``ValueError``."""
    expect_invalid_id = pytest.raises(ValueError, match="Invalid Query execution id")
    with expect_invalid_id:
        self.athena.get_output_location(query_execution_id=query_execution_id)

@pytest.mark.parametrize(
    "response",
    [
        pytest.param({}, id="empty-dict"),
        pytest.param(None, id="none"),
    ],
)
def test_hook_get_output_location_no_response(self, response):
    """A falsy ``get_query_info`` result must raise ``ValueError``.

    Also checks that the query info was requested exactly once, for the
    given execution id and with caching enabled.
    """
    patched_query_info = mock.patch.object(AthenaHook, "get_query_info", return_value=response)
    with patched_query_info as query_info_mock:
        with pytest.raises(ValueError, match="Unable to get query information"):
            self.athena.get_output_location(query_execution_id="PLACEHOLDER")
        # Checked outside ``pytest.raises``: the call above raises, so nothing
        # placed after it inside that block would run.
        query_info_mock.assert_called_once_with(query_execution_id="PLACEHOLDER", use_cache=True)

def test_hook_get_output_location_invalid_response(self, caplog):
    """A response lacking the expected keys must raise ``KeyError`` and log an error."""
    malformed_response = {"foo": "bar"}
    with mock.patch.object(AthenaHook, "get_query_info", return_value=malformed_response):
        # Start from an empty capture so only records from this call are inspected.
        caplog.clear()
        caplog.set_level("ERROR")
        with pytest.raises(KeyError):
            self.athena.get_output_location(query_execution_id="PLACEHOLDER")
        assert "Error retrieving OutputLocation" in caplog.text

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_get_query_info_caching(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
Expand Down

0 comments on commit fd03dc2

Please sign in to comment.