diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index cb40b9540818..2056109bbff7 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -58,7 +58,9 @@ def create( result_type = result_type or ChartDataResultType.FULL result_format = result_format or ChartDataResultFormat.JSON queries_ = [ - self._query_object_factory.create(result_type, **query_obj) + self._query_object_factory.create( + result_type, datasource=datasource, **query_obj + ) for query_obj in queries ] cache_values = { diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 40d37041b916..a8585fd47e05 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -23,9 +23,11 @@ from pprint import pformat from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING +from flask import g from flask_babel import gettext as _ from pandas import DataFrame +from superset import feature_flag_manager from superset.common.chart_data import ChartDataResultType from superset.exceptions import ( InvalidPostProcessingError, @@ -396,6 +398,24 @@ def cache_key(self, **extra: Any) -> str: if annotation_layers: cache_dict["annotation_layers"] = annotation_layers + # Add an impersonation key to cache if impersonation is enabled on the db + if ( + feature_flag_manager.is_feature_enabled("CACHE_IMPERSONATION") + and self.datasource + and hasattr(self.datasource, "database") + and self.datasource.database.impersonate_user + ): + + if key := self.datasource.database.db_engine_spec.get_impersonation_key( + getattr(g, "user", None) + ): + + logger.debug( + "Adding impersonation key to QueryObject cache dict: %s", key + ) + + cache_dict["impersonation_key"] = key + return md5_sha_from_dict(cache_dict, default=json_int_dttm_ser, ignore_nan=True) def exec_post_processing(self, df: DataFrame) -> DataFrame: diff --git a/superset/config.py b/superset/config.py index 8a5ec248fb8a..17c6a55412db 100644 --- a/superset/config.py +++ b/superset/config.py @@ -429,6 +429,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the # query, and might break queries and/or allow users to bypass RLS. Use with care! "RLS_IN_SQLLAB": False, + # Enable caching per impersonation key (e.g username) in a datasource where user + # impersonation is enabled + "CACHE_IMPERSONATION": False, } # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 6a2ddc5e5c3f..b4f4ec25c451 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -41,6 +41,7 @@ from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask import current_app +from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ from marshmallow import fields, Schema from marshmallow.validate import Range @@ -1537,6 +1538,17 @@ def cancel_query( # pylint: disable=unused-argument def parse_sql(cls, sql: str) -> List[str]: return [str(s).strip(" ;") for s in sqlparse.parse(sql)] + @classmethod + def get_impersonation_key(cls, user: Optional[User]) -> Any: + """ + Construct an impersonation key, by default it's the given username. + + :param user: logged in user + + :returns: username if given user is not null + """ + return user.username if user else None + # schema for adding a database by providing parameters instead of the # full SQLAlchemy URI