Skip to content

Commit

Permalink
[mypy] Enforcing typing for superset.models (#9883)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and john-bodley committed May 23, 2020
1 parent 6d4e236 commit e789a35
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 130 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true

[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
8 changes: 6 additions & 2 deletions superset/connectors/sqla/models.py
Expand Up @@ -18,7 +18,7 @@
import logging
import re
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union

import pandas as pd
Expand Down Expand Up @@ -103,7 +103,11 @@ def query(self, query_obj: Dict[str, Any]) -> QueryResult:
logger.exception(ex)
error_message = utils.error_msg_from_exception(ex)
return QueryResult(
status=status, df=df, duration=0, query="", error_message=error_message
status=status,
df=df,
duration=timedelta(0),
query="",
error_message=error_message,
)

def get_query_str(self, query_obj):
Expand Down
3 changes: 2 additions & 1 deletion superset/legacy.py
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Code related with dealing with legacy / change management"""
from typing import Any, Dict


def update_time_range(form_data):
def update_time_range(form_data: Dict[str, Any]) -> None:
"""Move since and until to time_range."""
if "since" in form_data or "until" in form_data:
form_data["time_range"] = "{} : {}".format(
Expand Down
6 changes: 4 additions & 2 deletions superset/models/annotations.py
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""a collection of Annotation-related models"""
from typing import Any, Dict

from flask_appbuilder import Model
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text
from sqlalchemy.orm import relationship
Expand All @@ -31,7 +33,7 @@ class AnnotationLayer(Model, AuditMixinNullable):
name = Column(String(250))
descr = Column(Text)

def __repr__(self):
def __repr__(self) -> str:
return self.name


Expand All @@ -52,7 +54,7 @@ class Annotation(Model, AuditMixinNullable):
__table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),)

@property
def data(self):
def data(self) -> Dict[str, Any]:
return {
"layer_id": self.layer_id,
"start_dttm": self.start_dttm,
Expand Down
21 changes: 13 additions & 8 deletions superset/models/core.py
Expand Up @@ -152,7 +152,7 @@ class Database(
]
export_children = ["tables"]

def __repr__(self):
def __repr__(self) -> str:
return self.name

@property
Expand Down Expand Up @@ -234,7 +234,9 @@ def default_schemas(self) -> List[str]:
return self.get_extra().get("default_schemas", [])

@classmethod
def get_password_masked_url_from_uri(cls, uri: str): # pylint: disable=invalid-name
def get_password_masked_url_from_uri( # pylint: disable=invalid-name
cls, uri: str
) -> URL:
sqlalchemy_url = make_url(uri)
return cls.get_password_masked_url(sqlalchemy_url)

Expand Down Expand Up @@ -279,7 +281,7 @@ def get_effective_user(
effective_username = g.user.username
return effective_username

@utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
@utils.memoized(watch=["impersonate_user", "sqlalchemy_uri_decrypted", "extra"])
def get_sqla_engine(
self,
schema: Optional[str] = None,
Expand Down Expand Up @@ -339,7 +341,7 @@ def get_sqla_engine(
def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words

def get_quoter(self):
def get_quoter(self) -> Callable:
return self.get_dialect().identifier_preparer.quote

def get_df( # pylint: disable=too-many-locals
Expand Down Expand Up @@ -405,7 +407,7 @@ def select_star( # pylint: disable=too-many-arguments
indent: bool = True,
latest_partition: bool = False,
cols: Optional[List[Dict[str, Any]]] = None,
):
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
return self.db_engine_spec.select_star(
Expand Down Expand Up @@ -436,7 +438,10 @@ def inspector(self) -> Inspector:
attribute_in_key="id",
)
def get_all_table_names_in_database(
self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False
self,
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
Expand Down Expand Up @@ -547,7 +552,7 @@ def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:

@classmethod
def get_db_engine_spec_for_backend(
cls, backend
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)

Expand All @@ -565,7 +570,7 @@ def grains(self) -> Tuple[TimeGrain, ...]:
def get_extra(self) -> Dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)

def get_encrypted_extra(self):
def get_encrypted_extra(self) -> Dict[str, Any]:
encrypted_extra = {}
if self.encrypted_extra:
try:
Expand Down
26 changes: 15 additions & 11 deletions superset/models/dashboard.py
Expand Up @@ -36,7 +36,9 @@
Text,
UniqueConstraint,
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
from sqlalchemy.orm.mapper import Mapper

from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
from superset.models.helpers import AuditMixinNullable, ImportMixin
Expand All @@ -59,7 +61,7 @@
logger = logging.getLogger(__name__)


def copy_dashboard(mapper, connection, target):
def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
# pylint: disable=unused-argument
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
Expand Down Expand Up @@ -140,7 +142,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
"slug",
]

def __repr__(self):
def __repr__(self) -> str:
return self.dashboard_title or str(self.id)

@property
Expand Down Expand Up @@ -202,13 +204,13 @@ def thumbnail_url(self) -> str:
return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/"

@property
def changed_by_name(self):
def changed_by_name(self) -> str:
if not self.changed_by:
return ""
return str(self.changed_by)

@property
def changed_by_url(self):
def changed_by_url(self) -> str:
if not self.changed_by:
return ""
return f"/superset/profile/{self.changed_by.username}"
Expand All @@ -229,8 +231,8 @@ def data(self) -> Dict[str, Any]:
"position_json": positions,
}

@property
def params(self) -> str:
@property # type: ignore
def params(self) -> str: # type: ignore
return self.json_metadata

@params.setter
Expand All @@ -257,7 +259,9 @@ def import_obj( # pylint: disable=too-many-locals,too-many-branches,too-many-st
Audit metadata isn't copied over.
"""

def alter_positions(dashboard, old_to_new_slc_id_dict):
def alter_positions(
dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
) -> None:
""" Updates slice_ids in the position json.
Sample position_json data:
Expand Down Expand Up @@ -291,9 +295,9 @@ def alter_positions(dashboard, old_to_new_slc_id_dict):
if (
isinstance(value, dict)
and value.get("meta")
and value.get("meta").get("chartId")
and value.get("meta", {}).get("chartId")
):
old_slice_id = value.get("meta").get("chartId")
old_slice_id = value["meta"]["chartId"]

if old_slice_id in old_to_new_slc_id_dict:
value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id]
Expand Down Expand Up @@ -470,8 +474,8 @@ def export_dashboards( # pylint: disable=too-many-locals


def event_after_dashboard_changed( # pylint: disable=unused-argument
mapper, connection, target
):
mapper: Mapper, connection: Connection, target: Dashboard
) -> None:
cache_dashboard_thumbnail.delay(target.id, force=True)


Expand Down

0 comments on commit e789a35

Please sign in to comment.