Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@

OAUTH_REQUEST_TIMEOUT = 30 # seconds, avoid hanging tasks on token request
OAUTH_EXPIRY_BUFFER = 30
SUPPORTED_GRANT_TYPES = {"refresh_token", "client_credentials"}

T = TypeVar("T")


Expand Down Expand Up @@ -222,6 +224,18 @@ def _get_field(self, extra_dict, field_name):
return extra_dict[field_name] or None
return extra_dict.get(backcompat_key) or None

def _validate_grant_type(self, grant_type: str | None) -> str:
"""Validate OAuth grant_type."""
if not grant_type:
raise ValueError("Grant type must be provided for OAuth authentication.")

if grant_type not in SUPPORTED_GRANT_TYPES:
supported = ", ".join(sorted(SUPPORTED_GRANT_TYPES))

raise ValueError(f"Unsupported grant_type '{grant_type}'. Supported values: {supported}")

return grant_type

@property
def account_identifier(self) -> str:
"""Get snowflake account identifier."""
Expand Down Expand Up @@ -296,9 +310,8 @@ def _get_conn_params(self) -> dict[str, str | None]:
if azure_conn_id:
conn_config["token"] = self.get_azure_oauth_token(azure_conn_id)
else:
grant_type = conn_config.get("grant_type")
if not grant_type:
raise ValueError("Grant_type not provided")
grant_type = self._validate_grant_type(conn_config.get("grant_type"))

conn_config["token"] = self._get_valid_oauth_token(
conn_config=conn_config,
token_endpoint=conn_config.get("token_endpoint"),
Expand Down Expand Up @@ -494,14 +507,12 @@ def _get_valid_oauth_token(
if scope:
data["scope"] = scope

grant_type = self._validate_grant_type(grant_type)

if grant_type == "refresh_token":
data |= {
"refresh_token": conn_config["refresh_token"],
}
elif grant_type == "client_credentials":
pass # no setup necessary for client credentials grant.
else:
raise ValueError(f"Unknown grant_type: {grant_type}")

response = requests.post(
url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,24 @@ def test_get_snowpark_session(self, mock_session_builder):
}
)

@pytest.mark.parametrize(
("grant_type", "expected", "match"),
[
("refresh_token", "refresh_token", None),
("client_credentials", "client_credentials", None),
("invalid_grant", ValueError, r"Unsupported grant_type"),
(None, ValueError, r"Grant type must be provided"),
],
)
def test_validate_grant_type(self, grant_type, expected, match):
hook = SnowflakeHook(snowflake_conn_id="test")

if expected is ValueError:
with pytest.raises(ValueError, match=match):
hook._validate_grant_type(grant_type)
else:
assert hook._validate_grant_type(grant_type) == expected

@mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
@mock.patch("requests.post")
@mock.patch(
Expand Down