Skip to content

Commit

Permalink
style(mypy): Enforcing typing for views.database (apache#9920)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
2 people authored and auxten committed Nov 20, 2020
1 parent dfc5da8 commit 6b89897
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 49 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true

[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.dashboard.*]
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.dashboard.*,superset.views.database.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
20 changes: 9 additions & 11 deletions superset/views/database/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from superset import event_logger
from superset.models.core import Database
from superset.typing import FlaskResponse
from superset.utils.core import error_msg_from_exception
from superset.views.base_api import BaseSupersetModelRestApi
from superset.views.database.decorators import check_datasource_access
Expand Down Expand Up @@ -49,7 +50,7 @@ def get_indexes_metadata(
return indexes


def get_col_type(col: Dict) -> str:
def get_col_type(col: Dict[Any, Any]) -> str:
try:
dtype = f"{col['type']}"
except Exception: # pylint: disable=broad-except
Expand Down Expand Up @@ -145,14 +146,14 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):

openapi_spec_tag = "Database"

@expose(
"/<int:pk>/table/<string:table_name>/<string:schema_name>/", methods=["GET"]
)
@expose("/<int:pk>/table/<table_name>/<schema_name>/", methods=["GET"])
@protect()
@check_datasource_access
@safe
@event_logger.log_this
def table_metadata(self, database: Database, table_name: str, schema_name: str):
def table_metadata(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
""" Table schema info
---
get:
Expand Down Expand Up @@ -276,18 +277,15 @@ def table_metadata(self, database: Database, table_name: str, schema_name: str):
self.incr_stats("success", self.table_metadata.__name__)
return self.response(200, **table_info)

@expose("/<int:pk>/select_star/<string:table_name>/", methods=["GET"])
@expose(
"/<int:pk>/select_star/<string:table_name>/<string:schema_name>/",
methods=["GET"],
)
@expose("/<int:pk>/select_star/<table_name>/", methods=["GET"])
@expose("/<int:pk>/select_star/<table_name>/<schema_name>/", methods=["GET"])
@protect()
@check_datasource_access
@safe
@event_logger.log_this
def select_star(
self, database: Database, table_name: str, schema_name: Optional[str] = None
):
) -> FlaskResponse:
""" Table schema info
---
get:
Expand Down
12 changes: 9 additions & 3 deletions superset/views/database/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,30 @@
# under the License.
import functools
import logging
from typing import Optional
from typing import Any, Callable, Optional

from flask import g
from flask_babel import lazy_gettext as _

from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import parse_js_uri_path_item
from superset.views.base_api import BaseSupersetModelRestApi

logger = logging.getLogger(__name__)


def check_datasource_access(f):
def check_datasource_access(f: Callable) -> Callable:
"""
A Decorator that checks if a user has datasource access
"""

def wraps(self, pk: int, table_name: str, schema_name: Optional[str] = None):
def wraps(
self: BaseSupersetModelRestApi,
pk: int,
table_name: str,
schema_name: Optional[str] = None,
) -> Any:
schema_name_parsed = parse_js_uri_path_item(schema_name, eval_undefined=True)
table_name_parsed = parse_js_uri_path_item(table_name)
if not table_name_parsed:
Expand Down
18 changes: 9 additions & 9 deletions superset/views/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Set

from sqlalchemy import or_
from sqlalchemy.orm import Query

from superset import security_manager
from superset.views.base import BaseFilter


class DatabaseFilter(BaseFilter):
# TODO(bogdan): consider caching.
def schema_access_databases(self): # noqa pylint: disable=no-self-use
found_databases = set()
for vm in security_manager.user_view_menu_names("schema_access"):
database_name, _ = security_manager.unpack_schema_perm(vm)
found_databases.add(database_name)
return found_databases
def schema_access_databases(self) -> Set[str]: # noqa pylint: disable=no-self-use
return {
security_manager.unpack_schema_perm(vm)[0]
for vm in security_manager.user_view_menu_names("schema_access")
}

def apply(
self, query, func
): # noqa pylint: disable=unused-argument,arguments-differ
def apply(self, query: Query, value: Any) -> Query:
if security_manager.all_database_access():
return query
database_perms = security_manager.user_view_menu_names("database_access")
Expand Down
20 changes: 11 additions & 9 deletions superset/views/database/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Contains the logic to create cohesive forms on the explore view"""
from typing import List

from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_appbuilder.forms import DynamicForm
from flask_babel import lazy_gettext as _
Expand All @@ -25,25 +27,25 @@

from superset import app, db, security_manager
from superset.forms import CommaSeparatedListField, filter_not_empty_values
from superset.models import core as models
from superset.models.core import Database

config = app.config


class CsvToDatabaseForm(DynamicForm):
# pylint: disable=E0211
def csv_allowed_dbs(): # type: ignore
csv_allowed_dbs = []
def csv_allowed_dbs() -> List[Database]: # type: ignore
csv_enabled_dbs = (
db.session.query(models.Database).filter_by(allow_csv_upload=True).all()
db.session.query(Database).filter_by(allow_csv_upload=True).all()
)
for csv_enabled_db in csv_enabled_dbs:
if CsvToDatabaseForm.at_least_one_schema_is_allowed(csv_enabled_db):
csv_allowed_dbs.append(csv_enabled_db)
return csv_allowed_dbs
return [
csv_enabled_db
for csv_enabled_db in csv_enabled_dbs
if CsvToDatabaseForm.at_least_one_schema_is_allowed(csv_enabled_db)
]

@staticmethod
def at_least_one_schema_is_allowed(database):
def at_least_one_schema_is_allowed(database: Database) -> bool:
"""
If the user has access to the database or all datasource
1. if schemas_allowed_for_csv_upload is empty
Expand Down
19 changes: 11 additions & 8 deletions superset/views/database/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from superset import app, security_manager
from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.utils import core as utils
from superset.views.database.filters import DatabaseFilter
Expand Down Expand Up @@ -199,7 +200,7 @@ class DatabaseMixin:
"backend": _("Backend"),
}

def _pre_add_update(self, database):
def _pre_add_update(self, database: Database) -> None:
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]:
check_sqlalchemy_uri(database.sqlalchemy_uri)
self.check_extra(database)
Expand All @@ -214,23 +215,23 @@ def _pre_add_update(self, database):
"schema_access", security_manager.get_schema_perm(database, schema)
)

def pre_add(self, database):
def pre_add(self, database: Database) -> None:
self._pre_add_update(database)

def pre_update(self, database):
def pre_update(self, database: Database) -> None:
self._pre_add_update(database)

def pre_delete(self, obj): # pylint: disable=no-self-use
if obj.tables:
def pre_delete(self, database: Database) -> None: # pylint: disable=no-self-use
if database.tables:
raise SupersetException(
Markup(
"Cannot delete a database that has tables attached. "
"Here's the list of associated tables: "
+ ", ".join("{}".format(o) for o in obj.tables)
+ ", ".join("{}".format(table) for table in database.tables)
)
)

def check_extra(self, database): # pylint: disable=no-self-use
def check_extra(self, database: Database) -> None: # pylint: disable=no-self-use
# this will check whether json.loads(extra) can succeed
try:
extra = database.get_extra()
Expand All @@ -252,7 +253,9 @@ def check_extra(self, database): # pylint: disable=no-self-use
)
)

def check_encrypted_extra(self, database): # pylint: disable=no-self-use
def check_encrypted_extra( # pylint: disable=no-self-use
self, database: Database
) -> None:
# this will check whether json.loads(secure_extra) can succeed
try:
database.get_encrypted_extra()
Expand Down
5 changes: 3 additions & 2 deletions superset/views/database/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
# specific language governing permissions and limitations
# under the License.

from typing import Type
from typing import Optional, Type

from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError

from superset import security_manager
from superset.models.core import Database


def sqlalchemy_uri_validator(
Expand All @@ -43,7 +44,7 @@ def sqlalchemy_uri_validator(
)


def schema_allows_csv_upload(database, schema):
def schema_allows_csv_upload(database: Database, schema: Optional[str]) -> bool:
if not database.allow_csv_upload:
return False
schemas = database.get_schema_access_for_csv_upload()
Expand Down
20 changes: 14 additions & 6 deletions superset/views/database/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from flask import flash, g, redirect
from flask_appbuilder import SimpleFormView
from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
from wtforms.fields import StringField
Expand All @@ -31,8 +32,10 @@
from superset.constants import RouteMethod
from superset.exceptions import CertificateException
from superset.sql_parse import Table
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import DeleteMixin, SupersetModelView, YamlExportMixin
from superset.views.database.forms import CsvToDatabaseForm

from .forms import CsvToDatabaseForm
from .mixins import DatabaseMixin
Expand All @@ -45,14 +48,19 @@
stats_logger = config["STATS_LOGGER"]


def sqlalchemy_uri_form_validator(_, field: StringField) -> None:
def sqlalchemy_uri_form_validator( # pylint: disable=unused-argument
form: DynamicForm, field: StringField
) -> None:
"""
Check if user has submitted a valid SQLAlchemy URI
"""

sqlalchemy_uri_validator(field.data, exception=ValidationError)


def certificate_form_validator(_, field: StringField) -> None:
def certificate_form_validator( # pylint: disable=unused-argument
form: DynamicForm, field: StringField
) -> None:
"""
Check if user has submitted a valid SSL certificate
"""
Expand All @@ -63,7 +71,7 @@ def certificate_form_validator(_, field: StringField) -> None:
raise ValidationError(ex.message)


def upload_stream_write(form_file_field: "FileStorage", path: str):
def upload_stream_write(form_file_field: "FileStorage", path: str) -> None:
chunk_size = app.config["UPLOAD_CHUNK_SIZE"]
with open(path, "bw") as file_description:
while True:
Expand All @@ -88,7 +96,7 @@ class DatabaseView(

yaml_dict_key = "databases"

def _delete(self, pk):
def _delete(self, pk: int) -> None:
DeleteMixin._delete(self, pk)


Expand All @@ -98,7 +106,7 @@ class CsvToDatabaseView(SimpleFormView):
form_title = _("CSV to Database configuration")
add_columns = ["database", "schema", "table_name"]

def form_get(self, form):
def form_get(self, form: CsvToDatabaseForm) -> None:
form.sep.data = ","
form.header.data = 0
form.mangle_dupe_cols.data = True
Expand All @@ -108,7 +116,7 @@ def form_get(self, form):
form.decimal.data = "."
form.if_exists.data = "fail"

def form_post(self, form):
def form_post(self, form: CsvToDatabaseForm) -> FlaskResponse:
database = form.con.data
csv_table = Table(table=form.name.data, schema=form.schema.data)

Expand Down

0 comments on commit 6b89897

Please sign in to comment.