diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 2a202441a76c..59199cf8cd2e 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -248,7 +248,12 @@ def _get_conn_params(self) -> dict[str, str | None]: "Please remove one." ) elif private_key_file: - private_key_pem = Path(private_key_file).read_bytes() + private_key_file_path = Path(private_key_file) + if not private_key_file_path.is_file() or private_key_file_path.stat().st_size == 0: + raise ValueError("The private_key_file path points to an empty or invalid file.") + if private_key_file_path.stat().st_size > 4096: + raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.") + private_key_pem = Path(private_key_file_path).read_bytes() elif private_key_content: private_key_pem = private_key_content.encode() diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index e1105fbb5948..6a738952d90f 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -393,6 +393,24 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key( ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."): SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() + def test_get_conn_params_should_fail_on_invalid_key(self): + connection_kwargs = { + **BASE_CONNECTION_KWARGS, + "password": None, + "extra": { + "database": "db", + "account": "airflow", + "warehouse": "af_wh", + "region": "af_region", + "role": "af_role", + "private_key_file": "/dev/urandom", + }, + } + with mock.patch.dict( + "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() + ), pytest.raises(ValueError, match="The private_key_file path points to an empty or invalid file."): + SnowflakeHook(snowflake_conn_id="test_conn").get_conn() + def test_should_add_partner_info(self): with mock.patch.dict( "os.environ",