Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(QueryContext): add QueryContextFactory to meet SRP #17495

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 14 additions & 5 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from typing import Any, Dict
from __future__ import annotations

from typing import Any, Dict, Optional, TYPE_CHECKING

from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
Expand All @@ -24,7 +26,7 @@

from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import schema as utils
from superset.utils.core import (
Expand All @@ -35,6 +37,9 @@
TimeRangeEndpoint,
)

if TYPE_CHECKING:
from superset.common.query_context import QueryContext

config = app.config

#
Expand Down Expand Up @@ -1129,6 +1134,7 @@ class Meta: # pylint: disable=too-few-public-methods


class ChartDataQueryContextSchema(Schema):
query_context_factory: Optional[QueryContextFactory] = None
datasource = fields.Nested(ChartDataDatasourceSchema)
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
force = fields.Boolean(
Expand All @@ -1139,13 +1145,16 @@ class ChartDataQueryContextSchema(Schema):
result_type = EnumField(ChartDataResultType, by_value=True)
result_format = EnumField(ChartDataResultFormat, by_value=True)

# pylint: disable=no-self-use,unused-argument
# pylint: disable=unused-argument
@post_load
def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
query_context = QueryContext(**data)
query_context = self.get_query_context_factory().create(**data)
return query_context

# pylint: enable=no-self-use,unused-argument
def get_query_context_factory(self) -> QueryContextFactory:
if self.query_context_factory is None:
self.query_context_factory = QueryContextFactory()
return self.query_context_factory


class AnnotationDataSchema(Schema):
Expand Down
45 changes: 15 additions & 30 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,21 @@
from pandas import DateOffset
from typing_extensions import TypedDict

from superset import app, db, is_feature_enabled
from superset import app, is_feature_enabled
from superset.annotation_layers.dao import AnnotationLayerDAO
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.common.utils import QueryCacheManager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import CacheRegion
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.utils import csv
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import (
DatasourceDict,
DTTM_ALIAS,
error_msg_from_exception,
get_column_names_from_columns,
Expand All @@ -57,6 +53,7 @@
from superset.views.utils import get_viz

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
from superset.stats_logger import BaseStatsLogger

config = app.config
Expand All @@ -70,10 +67,6 @@ class CachedTimeOffset(TypedDict):
cache_keys: List[Optional[str]]


def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(config, ConnectorRegistry(), db.session)


class QueryContext:
"""
The query context contains the query object and additional fields necessary
Expand All @@ -90,36 +83,28 @@ class QueryContext:
force: bool
custom_cache_timeout: Optional[int]

cache_values: Dict[str, Any]

# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
# pylint: disable=too-many-arguments
def __init__(
self,
datasource: DatasourceDict,
queries: List[Dict[str, Any]],
result_type: Optional[ChartDataResultType] = None,
result_format: Optional[ChartDataResultFormat] = None,
*,
datasource: BaseDatasource,
queries: List[QueryObject],
result_type: ChartDataResultType,
result_format: ChartDataResultFormat,
force: bool = False,
custom_cache_timeout: Optional[int] = None,
cache_values: Dict[str, Any]
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
query_object_factory = create_query_object_factory()
self.queries = [
query_object_factory.create(self.result_type, **query_obj)
for query_obj in queries
]
self.datasource = datasource
self.result_type = result_type
self.result_format = result_format
self.queries = queries
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.cache_values = {
"datasource": datasource,
"queries": queries,
"result_type": self.result_type,
"result_format": self.result_format,
}
self.cache_values = cache_values

@staticmethod
def left_join_df(
Expand Down
83 changes: 83 additions & 0 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any, Dict, List, Optional, TYPE_CHECKING

from superset import app, db
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object_factory import QueryObjectFactory
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import DatasourceDict

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource

config = app.config


def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(config, ConnectorRegistry(), db.session)


class QueryContextFactory: # pylint: disable=too-few-public-methods
_query_object_factory: QueryObjectFactory

def __init__(self) -> None:
self._query_object_factory = create_query_object_factory()

def create(
self,
*,
datasource: DatasourceDict,
queries: List[Dict[str, Any]],
result_type: Optional[ChartDataResultType] = None,
result_format: Optional[ChartDataResultFormat] = None,
force: bool = False,
custom_cache_timeout: Optional[int] = None
) -> QueryContext:
datasource_model_instance = None
if datasource:
datasource_model_instance = self._convert_to_model(datasource)
result_type = result_type or ChartDataResultType.FULL
result_format = result_format or ChartDataResultFormat.JSON
queries_ = [
self._query_object_factory.create(result_type, **query_obj)
for query_obj in queries
]
cache_values = {
"datasource": datasource,
"queries": queries,
"result_type": result_type,
"result_format": result_format,
}
return QueryContext(
datasource=datasource_model_instance,
queries=queries_,
result_type=result_type,
result_format=result_format,
force=force,
custom_cache_timeout=custom_cache_timeout,
cache_values=cache_values,
)

# pylint: disable=no-self-use
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
22 changes: 17 additions & 5 deletions superset/models/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
Expand Down Expand Up @@ -41,6 +43,7 @@

if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.connectors.base.models import BaseDatasource

metadata = Model.metadata # pylint: disable=no-member
Expand All @@ -59,6 +62,8 @@ class Slice( # pylint: disable=too-many-public-methods
):
"""A slice is essentially a report or a view on data"""

query_context_factory: Optional[QueryContextFactory] = None

__tablename__ = "slices"
id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
Expand Down Expand Up @@ -248,13 +253,12 @@ def form_data(self) -> Dict[str, Any]:
update_time_range(form_data)
return form_data

def get_query_context(self) -> Optional["QueryContext"]:
# pylint: disable=import-outside-toplevel
from superset.common.query_context import QueryContext

def get_query_context(self) -> Optional[QueryContext]:
if self.query_context:
try:
return QueryContext(**json.loads(self.query_context))
return self.get_query_context_factory().create(
**json.loads(self.query_context)
)
except json.decoder.JSONDecodeError as ex:
logger.error("Malformed json in slice's query context", exc_info=True)
logger.exception(ex)
Expand Down Expand Up @@ -313,6 +317,14 @@ def icons(self) -> str:
def url(self) -> str:
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"

def get_query_context_factory(self) -> QueryContextFactory:
if self.query_context_factory is None:
# pylint: disable=import-outside-toplevel
from superset.common.query_context_factory import QueryContextFactory

self.query_context_factory = QueryContextFactory()
return self.query_context_factory


def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None:
src_class = target.cls_model
Expand Down
24 changes: 20 additions & 4 deletions superset/views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from __future__ import annotations

from typing import Any, TYPE_CHECKING

import simplejson as json
from flask import request
Expand All @@ -27,31 +29,37 @@
TimeRangeAmbiguousError,
TimeRangeParseFailError,
)
from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.utils.date_parser import get_since_until
from superset.views.base import api, BaseSupersetView, handle_api_exception

if TYPE_CHECKING:
from superset.common.query_context_factory import QueryContextFactory

get_time_range_schema = {"type": "string"}


class Api(BaseSupersetView):
query_context_factory = None

@event_logger.log_this
@api
@handle_api_exception
@has_access_api
@expose("/v1/query/", methods=["POST"])
def query(self) -> FlaskResponse: # pylint: disable=no-self-use
def query(self) -> FlaskResponse:
"""
Takes a query_obj constructed in the client and returns payload data response
for the given query_obj.

raises SupersetSecurityException: If the user cannot access the resource
"""
query_context = QueryContext(**json.loads(request.form["query_context"]))
query_context = self.get_query_context_factory().create(
**json.loads(request.form["query_context"])
)
query_context.raise_for_access()
result = query_context.get_payload()
payload_json = result["queries"]
Expand Down Expand Up @@ -99,3 +107,11 @@ def time_range(self, **kwargs: Any) -> FlaskResponse:
except (ValueError, TimeRangeParseFailError, TimeRangeAmbiguousError) as error:
error_msg = {"message": f"Unexpected time range: {error}"}
return self.json_response(error_msg, 400)

def get_query_context_factory(self) -> QueryContextFactory:
if self.query_context_factory is None:
# pylint: disable=import-outside-toplevel
from superset.common.query_context_factory import QueryContextFactory

self.query_context_factory = QueryContextFactory()
return self.query_context_factory