Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions providers/slack/src/airflow/providers/slack/hooks/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web.async_client import AsyncWebClient
from typing_extensions import NotRequired

from airflow.providers.common.compat.connection import get_async_connection
Expand All @@ -44,8 +41,10 @@
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import RetryHandler
from slack_sdk.web.async_client import AsyncSlackResponse
from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient
from slack_sdk.web.slack_response import SlackResponse

from airflow.providers.common.compat.sdk import Connection
Expand Down Expand Up @@ -152,11 +151,15 @@ def __init__(
@cached_property
def client(self) -> WebClient:
"""Get the underlying slack_sdk.WebClient (cached)."""
from slack_sdk import WebClient

conn = self.get_connection(self.slack_conn_id)
return WebClient(**self._get_conn_params(conn=conn))

async def get_async_client(self) -> AsyncWebClient:
"""Get the underlying `slack_sdk.web.async_client.AsyncWebClient`."""
from slack_sdk.web.async_client import AsyncWebClient

conn = await get_async_connection(self.slack_conn_id)
return AsyncWebClient(**self._get_conn_params(conn))

Expand Down Expand Up @@ -372,6 +375,8 @@ def _call_conversations_list(self, cursor: str | None = None):
:raises SlackApiError: Propagated when errors other than 429 occur.
:return: Slack SDK response for the page requested.
"""
from slack_sdk.errors import SlackApiError

max_retries = 5
for attempt in range(max_retries):
try:
Expand All @@ -397,6 +402,8 @@ def test_connection(self):
.. seealso::
https://api.slack.com/methods/auth.test
"""
from slack_sdk.errors import SlackApiError

try:
response = self.call("auth.test")
response.validate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
from functools import cached_property, wraps
from typing import TYPE_CHECKING, Any

from slack_sdk import WebhookClient
from slack_sdk.webhook.async_client import AsyncWebhookClient

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook
from airflow.providers.slack.utils import ConnectionExtraConfig

if TYPE_CHECKING:
from slack_sdk import WebhookClient
from slack_sdk.http_retry import RetryHandler
from slack_sdk.webhook.async_client import AsyncWebhookClient

LEGACY_INTEGRATION_PARAMS = ("channel", "username", "icon_emoji", "icon_url")

Expand Down Expand Up @@ -154,6 +153,8 @@ def client(self) -> WebhookClient:

async def get_async_client(self) -> AsyncWebhookClient:
"""Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient`."""
from slack_sdk.webhook.async_client import AsyncWebhookClient

return AsyncWebhookClient(**await self._async_get_conn_params())

def get_conn(self) -> WebhookClient:
Expand Down
12 changes: 6 additions & 6 deletions providers/slack/tests/unit/slack/hooks/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def test_empty_password(self):
{},
),
],
)
@mock.patch("airflow.providers.slack.hooks.slack.WebClient")
) #
@mock.patch("slack_sdk.WebhookClient")
def test_client_configuration(
self, mock_webclient_cls, hook_config, conn_extra, expected: dict[str, Any]
):
Expand Down Expand Up @@ -407,7 +407,7 @@ def test_get_channel_id(self, mocked_client):
def test_call_conversations_list_retries_then_succeeds(self, monkeypatch):
ok_resp = self.fake_slack_response(data={"channels": []})
monkeypatch.setattr(
"airflow.providers.slack.hooks.slack.WebClient",
"slack_sdk.WebhookClient",
lambda **_: mock.MagicMock(
conversations_list=mock.Mock(side_effect=[self.make_429(), self.make_429(), ok_resp])
),
Expand All @@ -420,7 +420,7 @@ def test_call_conversations_list_retries_then_succeeds(self, monkeypatch):

def test_call_conversations_list_exceeds_max(self, monkeypatch):
monkeypatch.setattr(
"airflow.providers.slack.hooks.slack.WebClient",
"slack_sdk.WebhookClient",
lambda **_: mock.MagicMock(conversations_list=mock.Mock(side_effect=[self.make_429()] * 5)),
)
with pytest.raises(AirflowException, match="Max retries"):
Expand Down Expand Up @@ -592,7 +592,7 @@ def mock_get_conn(self):
yield m

@pytest.mark.asyncio
@mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient")
@mock.patch("slack_sdk.webhook.async_client.AsyncWebhookClient")
async def test_get_async_client(self, mock_client, mock_get_conn):
"""Test get_async_client creates AsyncWebClient with correct params."""
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
Expand All @@ -601,7 +601,7 @@ async def test_get_async_client(self, mock_client, mock_get_conn):
mock_client.assert_called_once_with(token=MOCK_SLACK_API_TOKEN, logger=mock.ANY)

@pytest.mark.asyncio
@mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient.api_call", new_callable=mock.AsyncMock)
@mock.patch("slack_sdk.webhook.async_client.AsyncWebhookClient.api_call", new_callable=mock.AsyncMock)
async def test_async_call(self, mock_api_call, mock_get_conn):
"""Test async_call is called correctly."""
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
Expand Down
12 changes: 6 additions & 6 deletions providers/slack/tests/unit/slack/hooks/test_slack_webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def test_no_password_in_connection_field(self, conn_id):
),
],
)
@mock.patch("airflow.providers.slack.hooks.slack_webhook.WebhookClient")
@mock.patch("slack_sdk.WebhookClient")
def test_client_configuration(
self, mock_webhook_client_cls, hook_config, conn_extra, expected: dict[str, Any]
):
Expand Down Expand Up @@ -393,7 +393,7 @@ def test_client_configuration(
{"text": "Fallback Text", "blocks": ["Dummy Block"], "unfurl_media": True, "unfurl_links": True},
],
)
@mock.patch("airflow.providers.slack.hooks.slack_webhook.WebhookClient")
@mock.patch("slack_sdk.WebhookClient")
def test_hook_send_dict(self, mock_webhook_client_cls, send_body, headers):
"""Test `SlackWebhookHook.send_dict` method."""
mock_webhook_client = mock_webhook_client_cls.return_value
Expand All @@ -411,7 +411,7 @@ def test_hook_send_dict(self, mock_webhook_client_cls, send_body, headers):
mock_webhook_client_send_dict.assert_called_once_with(send_body, headers=headers)

@pytest.mark.parametrize("send_body", [("text", "Test Text"), 42, "null", "42"])
@mock.patch("airflow.providers.slack.hooks.slack_webhook.WebhookClient")
@mock.patch("slack_sdk.WebhookClient")
def test_hook_send_dict_invalid_type(self, mock_webhook_client_cls, send_body):
"""Test invalid body type for `SlackWebhookHook.send_dict` method."""
mock_webhook_client = mock_webhook_client_cls.return_value
Expand All @@ -424,7 +424,7 @@ def test_hook_send_dict_invalid_type(self, mock_webhook_client_cls, send_body):
assert mock_webhook_client_send_dict.assert_not_called

@pytest.mark.parametrize("json_string", ["{'text': 'Single quotes'}", '{"text": "Missing }"'])
@mock.patch("airflow.providers.slack.hooks.slack_webhook.WebhookClient")
@mock.patch("slack_sdk.WebhookClient")
def test_hook_send_dict_invalid_json_string(self, mock_webhook_client_cls, json_string):
"""Test invalid JSON-string passed to `SlackWebhookHook.send_dict` method."""
mock_webhook_client = mock_webhook_client_cls.return_value
Expand All @@ -446,7 +446,7 @@ def test_hook_send_dict_invalid_json_string(self, mock_webhook_client_cls, json_
"icon_url",
],
)
@mock.patch("airflow.providers.slack.hooks.slack_webhook.WebhookClient")
@mock.patch("slack_sdk.WebhookClient")
def test_hook_send_dict_legacy_slack_integration(self, mock_webhook_client_cls, legacy_attr):
"""Test `SlackWebhookHook.send_dict` warn users about Legacy Slack Integrations."""
mock_webhook_client = mock_webhook_client_cls.return_value
Expand Down Expand Up @@ -571,7 +571,7 @@ async def test_async_client(self, mock_async_get_conn_params):
{"text": "Fallback Text", "blocks": ["Dummy Block"], "unfurl_media": True, "unfurl_links": True},
],
)
@mock.patch("airflow.providers.slack.hooks.slack_webhook.AsyncWebhookClient")
@mock.patch("slack_sdk.webhook.async_client.AsyncWebhookClient")
@mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params")
async def test_async_send_dict(
self, mock_async_get_conn_params, mock_async_webhook_client_cls, send_body, headers
Expand Down
Loading