Skip to content

Commit

Permalink
Add KV lock for refreshing tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 26, 2024
1 parent 9ed7d79 commit 0d27d8c
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 13 deletions.
6 changes: 6 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,9 @@ def __init__(self, error: str):
extra={"error": error},
)
)


class CreateAuthLockFailedException(Exception):
"""
Exception to signalize failure to acquire lock when refreshing token.
"""
1 change: 1 addition & 0 deletions superset/key_value/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class KeyValueResource(StrEnum):
DASHBOARD_PERMALINK = "dashboard_permalink"
EXPLORE_PERMALINK = "explore_permalink"
METASTORE_CACHE = "superset_metastore_cache"
OAUTH2 = "oauth2"


class SharedKey(StrEnum):
Expand Down
107 changes: 95 additions & 12 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,36 @@

from __future__ import annotations

import logging
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

import backoff

from superset import db
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import CreateAuthLockFailedException
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import KeyValueResource, PickleKeyValueCodec

if TYPE_CHECKING:
from superset.models.core import DatabaseUserOAuth2Tokens

Check warning on line 36 in superset/utils/oauth2.py

View check run for this annotation

Codecov / codecov/patch

superset/utils/oauth2.py#L36

Added line #L36 was not covered by tests


LOCK_EXPIRATION = timedelta(seconds=30)
logger = logging.getLogger(__name__)


@backoff.on_exception(
backoff.expo,
CreateAuthLockFailedException,
factor=10,
base=2,
max_tries=5,
)
def get_oauth2_access_token(
database_id: int,
user_id: int,
Expand All @@ -49,21 +73,80 @@ def get_oauth2_access_token(
return token.access_token

if token.refresh_token:
# refresh access token
return refresh_oauth2_token(database_id, user_id, db_engine_spec, token)

# since the access token is expired and there's no refresh token, delete the entry
db.session.delete(token)

return None


def integers_to_uuid(a: int, b: int) -> uuid.UUID: # pylint: disable=invalid-name
"""
Generate UUID based on a namespace UUID and the string representation of integer pair.
"""
pair_str = f"{a}-{b}"
return uuid.uuid5(uuid.NAMESPACE_DNS, pair_str)


@contextmanager
def AuthLock( # pylint: disable=invalid-name
user_id: int,
database_id: int,
) -> Iterator[None]:
"""
KV global lock for refreshing tokens.
"""
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.commands.key_value.delete import DeleteKeyValueCommand
from superset.commands.key_value.delete_expired import DeleteExpiredKeyValueCommand

key = integers_to_uuid(user_id, database_id)
logger.debug(
"Acquiring lock to refresh OAuth2 token for user ID %d and database ID %d",
user_id,
database_id,
)
try:
DeleteExpiredKeyValueCommand(resource=KeyValueResource.OAUTH2).run()
CreateKeyValueCommand(
resource=KeyValueResource.OAUTH2,
codec=PickleKeyValueCodec(),
key=key,
value=True,
expires_on=datetime.now() + LOCK_EXPIRATION,
).run()
yield
except KeyValueCreateFailedError as ex:
raise CreateAuthLockFailedException("Error acquiring lock") from ex
finally:
DeleteKeyValueCommand(resource=KeyValueResource.OAUTH2, key=key).run()
logger.debug(
"Removed lock to refresh OAuth2 token for user ID %d and database ID %d",
user_id,
database_id,
)


def refresh_oauth2_token(
database_id: int,
user_id: int,
db_engine_spec: type[BaseEngineSpec],
token: DatabaseUserOAuth2Tokens,
) -> str | None:
with AuthLock(user_id, database_id):
token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token)

# store new access token; note that the refresh token might be revoked, in which
# case there would be no access token in the response
if "access_token" in token_response:
token.access_token = token_response["access_token"]
token.access_token_expiration = datetime.now() + timedelta(
seconds=token_response["expires_in"]
)
db.session.add(token)
if "access_token" not in token_response:
return None

Check warning on line 144 in superset/utils/oauth2.py

View check run for this annotation

Codecov / codecov/patch

superset/utils/oauth2.py#L144

Added line #L144 was not covered by tests

return token.access_token
token.access_token = token_response["access_token"]
token.access_token_expiration = datetime.now() + timedelta(
seconds=token_response["expires_in"]
)
db.session.add(token)

# since the access token is expired and there's no refresh token, delete the entry
db.session.delete(token)

return None
return token.access_token
75 changes: 74 additions & 1 deletion tests/unit_tests/utils/oauth2_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
# pylint: disable=invalid-name, disallowed-name

from datetime import datetime
from uuid import UUID

import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture

from superset.utils.oauth2 import get_oauth2_access_token
from superset.exceptions import CreateAuthLockFailedException
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import KeyValueResource
from superset.utils.oauth2 import AuthLock, get_oauth2_access_token, integers_to_uuid


def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None:
Expand Down Expand Up @@ -93,3 +98,71 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None:

# check that token was deleted
db.session.delete.assert_called_with(token)


def test_integers_to_uuid() -> None:
"""
Test `integers_to_uuid`.
"""
assert integers_to_uuid(1, 1) == UUID("4a426d86-eae0-53be-8f9a-8113ffc5a445")
assert integers_to_uuid(2, 1) == UUID("0a81e791-1685-5239-bc04-4cdd6aacc18d")
assert integers_to_uuid(1, 2) == UUID("83b19a49-b4f2-5ac5-9b52-5a63907dd160")


def test_AuthLock_happy_path(mocker: MockerFixture) -> None:
"""
Test successfully acquiring the global auth lock.
"""
CreateKeyValueCommand = mocker.patch(
"superset.commands.key_value.create.CreateKeyValueCommand"
)
DeleteKeyValueCommand = mocker.patch(
"superset.commands.key_value.delete.DeleteKeyValueCommand"
)
DeleteExpiredKeyValueCommand = mocker.patch(
"superset.commands.key_value.delete_expired.DeleteExpiredKeyValueCommand"
)
PickleKeyValueCodec = mocker.patch("superset.utils.oauth2.PickleKeyValueCodec")

with freeze_time("2024-01-01"):
with AuthLock(1, 1):
DeleteExpiredKeyValueCommand.assert_called_with(
resource=KeyValueResource.OAUTH2,
)
CreateKeyValueCommand.assert_called_with(
resource=KeyValueResource.OAUTH2,
codec=PickleKeyValueCodec(),
key=integers_to_uuid(1, 1),
value=True,
expires_on=datetime(2024, 1, 1, 0, 0, 30),
)
DeleteKeyValueCommand.assert_not_called()

DeleteKeyValueCommand.assert_called_with(
resource=KeyValueResource.OAUTH2,
key=integers_to_uuid(1, 1),
)


def test_AuthLock_no_lock(mocker: MockerFixture) -> None:
"""
Test unsuccessfully acquiring the global auth lock.
"""
mocker.patch(
"superset.commands.key_value.create.CreateKeyValueCommand",
side_effect=KeyValueCreateFailedError(),
)
DeleteKeyValueCommand = mocker.patch(
"superset.commands.key_value.delete.DeleteKeyValueCommand"
)

with pytest.raises(CreateAuthLockFailedException) as excinfo:
with AuthLock(1, 1):
pass
assert str(excinfo.value) == "Error acquiring lock"

# confirm that key was deleted
DeleteKeyValueCommand.assert_called_with(
resource=KeyValueResource.OAUTH2,
key=integers_to_uuid(1, 1),
)

0 comments on commit 0d27d8c

Please sign in to comment.