Skip to content

Commit

Permalink
Fix invalid HAClient creation in HDFSHook (#30164)
Browse files (browse the repository at this point in the history)
  • Loading branch information
trickysky committed Mar 27, 2023
1 parent 19d3c3d commit e141699
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
9 changes: 6 additions & 3 deletions airflow/providers/apache/hdfs/hooks/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class HDFSHook(BaseHook):
hook_name = "HDFS"

def __init__(
self, hdfs_conn_id: str = "hdfs_default", proxy_user: str | None = None, autoconfig: bool = False
self,
hdfs_conn_id: str | set[str] = "hdfs_default",
proxy_user: str | None = None,
autoconfig: bool = False,
):
super().__init__()
if not snakebite_loaded:
Expand All @@ -60,7 +63,7 @@ def __init__(
"snakebite is not compatible with Python 3 "
"(as of August 2015). Please help by submitting a PR!"
)
self.hdfs_conn_id = hdfs_conn_id
self.hdfs_conn_id = {hdfs_conn_id} if isinstance(hdfs_conn_id, str) else hdfs_conn_id
self.proxy_user = proxy_user
self.autoconfig = autoconfig

Expand All @@ -73,7 +76,7 @@ def get_conn(self) -> Any:
use_sasl = conf.get("core", "security") == "kerberos"

try:
connections = self.get_connections(self.hdfs_conn_id)
connections = [self.get_connection(i) for i in self.hdfs_conn_id]

if not effective_user:
effective_user = connections[0].login
Expand Down
15 changes: 9 additions & 6 deletions tests/providers/apache/hdfs/hooks/test_hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ def test_get_autoconfig_client_no_conn(self, mock_client):
HDFSHook(hdfs_conn_id="hdfs_missing", autoconfig=True).get_conn()
mock_client.assert_called_once_with(effective_user=None, use_sasl=False)

@mock.patch("airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook.get_connections")
def test_get_ha_client(self, mock_get_connections):
conn_1 = Connection(conn_id="hdfs_default", conn_type="hdfs", host="localhost", port=8020)
conn_2 = Connection(conn_id="hdfs_default", conn_type="hdfs", host="localhost2", port=8020)
mock_get_connections.return_value = [conn_1, conn_2]
client = HDFSHook().get_conn()
@mock.patch.dict(
"os.environ",
{
"AIRFLOW_CONN_HDFS1": "hdfs://host1:8020",
"AIRFLOW_CONN_HDFS2": "hdfs://host2:8020",
},
)
def test_get_ha_client(self):
client = HDFSHook(hdfs_conn_id={"hdfs1", "hdfs2"}).get_conn()
assert isinstance(client, snakebite.client.HAClient)

0 comments on commit e141699

Please sign in to comment.