Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: create backend routes and API for importing saved queries #13893

Merged
merged 9 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion superset/queries/saved_queries/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any
from zipfile import ZipFile

from flask import g, Response, send_file
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext

from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.databases.filters import DatabaseFilter
from superset.extensions import event_logger
from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.commands.bulk_delete import (
BulkDeleteSavedQueryCommand,
Expand All @@ -36,6 +40,9 @@
SavedQueryNotFoundError,
)
from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand
from superset.queries.saved_queries.commands.importers.dispatcher import (
ImportSavedQueriesCommand,
)
from superset.queries.saved_queries.filters import (
SavedQueryAllTextFilter,
SavedQueryFavoriteFilter,
Expand All @@ -58,6 +65,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
RouteMethod.EXPORT,
RouteMethod.RELATED,
RouteMethod.DISTINCT,
RouteMethod.IMPORT,
"bulk_delete", # not using RouteMethod since locally defined
}
class_permission_name = "SavedQuery"
Expand Down Expand Up @@ -252,3 +260,77 @@ def export(self, **kwargs: Any) -> Response:
as_attachment=True,
attachment_filename=filename,
)

@expose("/import/", methods=["POST"])
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_",
log_to_statsd=False,
)
def import_(self) -> Response:
"""Import Saved Queries with associated databases
---
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
formData:
description: upload file (ZIP)
type: string
format: binary
passwords:
description: JSON map of passwords for each file
type: string
overwrite:
description: overwrite existing saved queries?
type: bool
responses:
200:
description: Saved Query import result
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
upload = request.files.get("formData")
if not upload:
return self.response_400()
with ZipFile(upload) as bundle:
contents = get_contents_from_bundle(bundle)

passwords = (
json.loads(request.form["passwords"])
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"

command = ImportSavedQueriesCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
except CommandInvalidError as exc:
logger.warning("Import Saved Query failed")
return self.response_422(message=exc.normalized_messages())
except Exception as exc: # pylint: disable=broad-except
logger.exception("Import Saved Query failed")
return self.response_500(message=str(exc))
15 changes: 14 additions & 1 deletion superset/queries/saved_queries/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
# under the License.
from flask_babel import lazy_gettext as _

from superset.commands.exceptions import CommandException, DeleteFailedError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
DeleteFailedError,
ImportFailedError,
)


class SavedQueryBulkDeleteFailedError(DeleteFailedError):
Expand All @@ -25,3 +30,11 @@ class SavedQueryBulkDeleteFailedError(DeleteFailedError):

class SavedQueryNotFoundError(CommandException):
message = _("Saved query not found.")


class SavedQueryImportError(ImportFailedError):
message = _("Import saved query failed for an unknown reason.")


class SavedQueryInvalidError(CommandInvalidError):
message = _("Saved query parameters are invalid.")
67 changes: 67 additions & 0 deletions superset/queries/saved_queries/commands/importers/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from typing import Any, Dict

from marshmallow.exceptions import ValidationError

from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.queries.saved_queries.commands.importers import v1

logger = logging.getLogger(__name__)

command_versions = [
v1.ImportSavedQueriesCommand,
]


class ImportSavedQueriesCommand(BaseCommand):
"""
Import Saved Queries

This command dispatches the import to different versions of the command
until it finds one that matches.
"""

# pylint: disable=unused-argument
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs

def run(self) -> None:
# iterate over all commands until we find a version that can
# handle the contents
for version in command_versions:
command = version(self.contents, *self.args, **self.kwargs)
try:
command.run()
return
except IncorrectVersionError:
logger.debug("File not handled by command, skipping")
except (CommandInvalidError, ValidationError) as exc:
# found right version, but file is invalid
logger.exception("Error running import command")
raise exc

raise CommandInvalidError("Could not find a valid command to import file")

def validate(self) -> None:
pass
72 changes: 72 additions & 0 deletions superset/queries/saved_queries/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Set

from marshmallow import Schema
from sqlalchemy.orm import Session

from superset.commands.importers.v1 import ImportModelsCommand
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.importers.v1.utils import import_database
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.queries.saved_queries.commands.exceptions import SavedQueryImportError
from superset.queries.saved_queries.commands.importers.v1.utils import (
import_saved_query,
)
from superset.queries.saved_queries.dao import SavedQueryDAO
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema


class ImportSavedQueriesCommand(ImportModelsCommand):
"""Import Saved Queries"""

dao = SavedQueryDAO
model_name = "saved_queries"
prefix = "queries/"
schemas: Dict[str, Schema] = {
"datasets/": ImportV1DatasetSchema(),
"queries/": ImportV1SavedQuerySchema(),
}
import_error = SavedQueryImportError

@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# discover databases associated with saved queries
database_uuids: Set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("queries/"):
database_uuids.add(config["database_uuid"])

# import related databases
database_ids: Dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id

# import saved queries with the correct parent ref
for file_name, config in configs.items():
if (
file_name.startswith("queries/")
and config["database_uuid"] in database_ids
):
config["db_id"] = database_ids[config["database_uuid"]]
import_saved_query(session, config, overwrite=overwrite)
38 changes: 38 additions & 0 deletions superset/queries/saved_queries/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from sqlalchemy.orm import Session

from superset.models.sql_lab import SavedQuery


def import_saved_query(
session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SavedQuery:
existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite:
return existing
config["id"] = existing.id

saved_query = SavedQuery.import_from_dict(session, config, recursive=False)
if saved_query.id is None:
session.flush()

return saved_query
12 changes: 12 additions & 0 deletions superset/queries/saved_queries/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
from marshmallow.validate import Length

openapi_spec_methods_override = {
"get": {"get": {"description": "Get a saved query",}},
Expand All @@ -32,3 +34,13 @@

get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}


class ImportV1SavedQuerySchema(Schema):
schema = fields.String(allow_none=True, validate=Length(0, 128))
label = fields.String(allow_none=True, validate=Length(0, 256))
description = fields.String(allow_none=True)
sql = fields.String(required=True)
uuid = fields.UUID(required=True)
version = fields.String(required=True)
database_uuid = fields.UUID(required=True)
2 changes: 1 addition & 1 deletion tests/charts/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_import_v1_chart(self):
db.session.commit()

def test_import_v1_chart_multiple(self):
"""Test that a dataset can be imported multiple times"""
"""Test that a chart can be imported multiple times"""
contents = {
"metadata.yaml": yaml.safe_dump(chart_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
Expand Down
15 changes: 14 additions & 1 deletion tests/fixtures/importexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,11 @@
"type": "Dashboard",
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}

saved_queries_metadata_config: Dict[str, Any] = {
"version": "1.0.0",
"type": "SavedQuery",
"timestamp": "2021-03-30T20:37:54.791187+00:00",
}
database_config: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
Expand Down Expand Up @@ -499,3 +503,12 @@
},
"version": "1.0.0",
}
saved_queries_config = {
"schema": "public",
"label": "Test Saved Query",
"description": None,
"sql": "-- Note: Unless you save your query, these tabs will NOT persist if you clear\nyour cookies or change browsers.\n\n\nSELECT * from birth_names",
"uuid": "05b679b5-8eaf-452c-b874-a7a774cfa4e9",
"version": "1.0.0",
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}
Loading