Skip to content

Commit

Permalink
feat: create backend routes and API for importing saved queries (#13893)
Browse files Browse the repository at this point in the history
* initial commit

* revisions

* started tests

* added unit tests

* revisions

* tests passing

* fixed api test

* Update superset/queries/saved_queries/commands/importers/v1/utils.py

Co-authored-by: Hugh A. Miles II <hughmil3s@gmail.com>

* Revert "Update superset/queries/saved_queries/commands/importers/v1/utils.py"

This reverts commit 18580aa.

Co-authored-by: Hugh A. Miles II <hughmil3s@gmail.com>
  • Loading branch information
AAfghahi and hughhhh authored Apr 8, 2021
1 parent 806fb73 commit b5e5b3a
Show file tree
Hide file tree
Showing 10 changed files with 468 additions and 5 deletions.
84 changes: 83 additions & 1 deletion superset/queries/saved_queries/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any
from zipfile import ZipFile

from flask import g, Response, send_file
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext

from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.databases.filters import DatabaseFilter
from superset.extensions import event_logger
from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.commands.bulk_delete import (
BulkDeleteSavedQueryCommand,
Expand All @@ -36,6 +40,9 @@
SavedQueryNotFoundError,
)
from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand
from superset.queries.saved_queries.commands.importers.dispatcher import (
ImportSavedQueriesCommand,
)
from superset.queries.saved_queries.filters import (
SavedQueryAllTextFilter,
SavedQueryFavoriteFilter,
Expand All @@ -58,6 +65,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
RouteMethod.EXPORT,
RouteMethod.RELATED,
RouteMethod.DISTINCT,
RouteMethod.IMPORT,
"bulk_delete", # not using RouteMethod since locally defined
}
class_permission_name = "SavedQuery"
Expand Down Expand Up @@ -252,3 +260,77 @@ def export(self, **kwargs: Any) -> Response:
as_attachment=True,
attachment_filename=filename,
)

@expose("/import/", methods=["POST"])
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_",
log_to_statsd=False,
)
def import_(self) -> Response:
"""Import Saved Queries with associated databases
---
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
formData:
description: upload file (ZIP)
type: string
format: binary
passwords:
description: JSON map of passwords for each file
type: string
overwrite:
description: overwrite existing saved queries?
type: bool
responses:
200:
description: Saved Query import result
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
upload = request.files.get("formData")
if not upload:
return self.response_400()
with ZipFile(upload) as bundle:
contents = get_contents_from_bundle(bundle)

passwords = (
json.loads(request.form["passwords"])
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"

command = ImportSavedQueriesCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
except CommandInvalidError as exc:
logger.warning("Import Saved Query failed")
return self.response_422(message=exc.normalized_messages())
except Exception as exc: # pylint: disable=broad-except
logger.exception("Import Saved Query failed")
return self.response_500(message=str(exc))
15 changes: 14 additions & 1 deletion superset/queries/saved_queries/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
# under the License.
from flask_babel import lazy_gettext as _

from superset.commands.exceptions import CommandException, DeleteFailedError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
DeleteFailedError,
ImportFailedError,
)


class SavedQueryBulkDeleteFailedError(DeleteFailedError):
Expand All @@ -25,3 +30,11 @@ class SavedQueryBulkDeleteFailedError(DeleteFailedError):

class SavedQueryNotFoundError(CommandException):
message = _("Saved query not found.")


class SavedQueryImportError(ImportFailedError):
message = _("Import saved query failed for an unknown reason.")


class SavedQueryInvalidError(CommandInvalidError):
message = _("Saved query parameters are invalid.")
67 changes: 67 additions & 0 deletions superset/queries/saved_queries/commands/importers/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from typing import Any, Dict

from marshmallow.exceptions import ValidationError

from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.queries.saved_queries.commands.importers import v1

logger = logging.getLogger(__name__)

command_versions = [
v1.ImportSavedQueriesCommand,
]


class ImportSavedQueriesCommand(BaseCommand):
"""
Import Saved Queries
This command dispatches the import to different versions of the command
until it finds one that matches.
"""

# pylint: disable=unused-argument
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs

def run(self) -> None:
# iterate over all commands until we find a version that can
# handle the contents
for version in command_versions:
command = version(self.contents, *self.args, **self.kwargs)
try:
command.run()
return
except IncorrectVersionError:
logger.debug("File not handled by command, skipping")
except (CommandInvalidError, ValidationError) as exc:
# found right version, but file is invalid
logger.exception("Error running import command")
raise exc

raise CommandInvalidError("Could not find a valid command to import file")

def validate(self) -> None:
pass
71 changes: 71 additions & 0 deletions superset/queries/saved_queries/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Set

from marshmallow import Schema
from sqlalchemy.orm import Session

from superset.commands.importers.v1 import ImportModelsCommand
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.importers.v1.utils import import_database
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.queries.saved_queries.commands.exceptions import SavedQueryImportError
from superset.queries.saved_queries.commands.importers.v1.utils import (
import_saved_query,
)
from superset.queries.saved_queries.dao import SavedQueryDAO
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema


class ImportSavedQueriesCommand(ImportModelsCommand):
"""Import Saved Queries"""

dao = SavedQueryDAO
model_name = "saved_queries"
prefix = "queries/"
schemas: Dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"queries/": ImportV1SavedQuerySchema(),
}
import_error = SavedQueryImportError

@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# discover databases associated with saved queries
database_uuids: Set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("queries/"):
database_uuids.add(config["database_uuid"])

# import related databases
database_ids: Dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id

# import saved queries with the correct parent ref
for file_name, config in configs.items():
if (
file_name.startswith("queries/")
and config["database_uuid"] in database_ids
):
config["db_id"] = database_ids[config["database_uuid"]]
import_saved_query(session, config, overwrite=overwrite)
38 changes: 38 additions & 0 deletions superset/queries/saved_queries/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from sqlalchemy.orm import Session

from superset.models.sql_lab import SavedQuery


def import_saved_query(
session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SavedQuery:
existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite:
return existing
config["id"] = existing.id

saved_query = SavedQuery.import_from_dict(session, config, recursive=False)
if saved_query.id is None:
session.flush()

return saved_query
12 changes: 12 additions & 0 deletions superset/queries/saved_queries/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
from marshmallow.validate import Length

openapi_spec_methods_override = {
"get": {"get": {"description": "Get a saved query",}},
Expand All @@ -32,3 +34,13 @@

get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}


class ImportV1SavedQuerySchema(Schema):
schema = fields.String(allow_none=True, validate=Length(0, 128))
label = fields.String(allow_none=True, validate=Length(0, 256))
description = fields.String(allow_none=True)
sql = fields.String(required=True)
uuid = fields.UUID(required=True)
version = fields.String(required=True)
database_uuid = fields.UUID(required=True)
2 changes: 1 addition & 1 deletion tests/charts/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_import_v1_chart(self):
db.session.commit()

def test_import_v1_chart_multiple(self):
"""Test that a dataset can be imported multiple times"""
"""Test that a chart can be imported multiple times"""
contents = {
"metadata.yaml": yaml.safe_dump(chart_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
Expand Down
15 changes: 14 additions & 1 deletion tests/fixtures/importexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,11 @@
"type": "Dashboard",
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}

saved_queries_metadata_config: Dict[str, Any] = {
"version": "1.0.0",
"type": "SavedQuery",
"timestamp": "2021-03-30T20:37:54.791187+00:00",
}
database_config: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
Expand Down Expand Up @@ -499,3 +503,12 @@
},
"version": "1.0.0",
}
saved_queries_config = {
"schema": "public",
"label": "Test Saved Query",
"description": None,
"sql": "-- Note: Unless you save your query, these tabs will NOT persist if you clear\nyour cookies or change browsers.\n\n\nSELECT * from birth_names",
"uuid": "05b679b5-8eaf-452c-b874-a7a774cfa4e9",
"version": "1.0.0",
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}
Loading

0 comments on commit b5e5b3a

Please sign in to comment.