Skip to content

Commit

Permalink
refactor: Deprecate ensure_user_is_set in favor of override_user (#20502
Browse files Browse the repository at this point in the history
)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and John Bodley committed Jul 5, 2022
1 parent ad308fb commit 94b3d2f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 174 deletions.
170 changes: 87 additions & 83 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
security_manager,
)
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import override_user
from superset.views.utils import get_datasource_info, get_viz

if TYPE_CHECKING:
Expand All @@ -44,16 +45,6 @@
] # TODO: new config key


def ensure_user_is_set(user_id: Optional[int]) -> None:
user_is_not_set = not (hasattr(g, "user") and g.user is not None)
if user_is_not_set and user_id is not None:
# pylint: disable=assigning-non-slot
g.user = security_manager.get_user_by_id(user_id)
elif user_is_not_set:
# pylint: disable=assigning-non-slot
g.user = security_manager.get_anonymous_user()


def set_form_data(form_data: Dict[str, Any]) -> None:
# pylint: disable=assigning-non-slot
g.form_data = form_data
Expand All @@ -76,30 +67,35 @@ def load_chart_data_into_cache(
# pylint: disable=import-outside-toplevel
from superset.charts.data.commands.get_data_command import ChartDataCommand

try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
command = ChartDataCommand(query_context)
result = command.run(cache=True)
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while loading chart data, error: %s", ex)
raise ex
except Exception as ex:
# TODO: QueryContext should support SIP-40 style errors
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [{"message": error}]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex
user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)

with override_user(user, force=False):
try:
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
command = ChartDataCommand(query_context)
result = command.run(cache=True)
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while loading chart data, error: %s", ex)
raise ex
except Exception as ex:
# TODO: QueryContext should support SIP-40 style errors
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [{"message": error}]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex


@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
Expand All @@ -110,53 +106,61 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
force: bool = False,
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)

# Perform a deep copy here so that below we can cache the original
# value of the form_data object. This is necessary since the viz
# objects modify the form_data object. If the modified version were
# to be cached here, it will lead to a cache miss when clients
# attempt to retrieve the value of the completed async query.
original_form_data = copy.deepcopy(form_data)

viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=force,
)
# run query & cache results
payload = viz_obj.get_payload()
if viz_obj.has_error(payload):
raise SupersetVizException(errors=payload["errors"])

# Cache the original form_data value for async retrieval
cache_value = {
"form_data": original_form_data,
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while loading explore json, error: %s", ex)
raise ex
except Exception as ex:
if isinstance(ex, SupersetVizException):
errors = ex.errors # pylint: disable=no-member
else:
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [error]

async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex
user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)

with override_user(user, force=False):
try:
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)

# Perform a deep copy here so that below we can cache the original
# value of the form_data object. This is necessary since the viz
# objects modify the form_data object. If the modified version were
# to be cached here, it will lead to a cache miss when clients
# attempt to retrieve the value of the completed async query.
original_form_data = copy.deepcopy(form_data)

viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=force,
)
# run query & cache results
payload = viz_obj.get_payload()
if viz_obj.has_error(payload):
raise SupersetVizException(errors=payload["errors"])

# Cache the original form_data value for async retrieval
cache_value = {
"form_data": original_form_data,
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning(
"A timeout occurred while loading explore json, error: %s", ex
)
raise ex
except Exception as ex:
if isinstance(ex, SupersetVizException):
errors = ex.errors # pylint: disable=no-member
else:
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [error]

async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex
18 changes: 11 additions & 7 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,23 +1453,27 @@ def get_user_id() -> Optional[int]:


@contextmanager
def override_user(user: Optional[User]) -> Iterator[Any]:
def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]:
"""
Temporarily override the current user (if defined) per `flask.g`.
Temporarily override the current user per `flask.g` with the specified user.
Sometimes, often in the context of async Celery tasks, it is useful to switch the
current user (which may be undefined) to different one, execute some SQLAlchemy
tasks and then revert back to the original one.
tasks et al. and then revert back to the original one.
:param user: The override user
:param force: Whether to override the current user if set
"""

# pylint: disable=assigning-non-slot
if hasattr(g, "user"):
current = g.user
g.user = user
yield
g.user = current
if force or g.user is None:
current = g.user
g.user = user
yield
g.user = current
else:
yield
else:
g.user = user
yield
Expand Down
24 changes: 12 additions & 12 deletions tests/integration_tests/access_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,34 +562,34 @@ def test_get_username(
assert get_username() == username


@pytest.mark.parametrize(
"username",
[
None,
"alpha",
"gamma",
],
)
@pytest.mark.parametrize("username", [None, "alpha", "gamma"])
@pytest.mark.parametrize("force", [False, True])
def test_override_user(
app_context: AppContext,
mocker: MockFixture,
username: str,
force: bool,
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
admin = security_manager.find_user(username="admin")
user = security_manager.find_user(username)

with override_user(user, force):
assert mock_g.user == user

assert not hasattr(mock_g, "user")

with override_user(user):
mock_g.user = None

with override_user(user, force):
assert mock_g.user == user

assert not hasattr(mock_g, "user")
assert mock_g.user is None

mock_g.user = admin

with override_user(user):
assert mock_g.user == user
with override_user(user, force):
assert mock_g.user == user if force else admin

assert mock_g.user == admin

Expand Down
Loading

0 comments on commit 94b3d2f

Please sign in to comment.