Skip to content

Commit

Permalink
[mypy] Enforcing typing for some modules (#9416)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and john-bodley committed Apr 4, 2020
1 parent 1cdfb82 commit 5e55e09
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 27 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true

[mypy-superset.charts.*,superset.db_engine_specs.*]
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
5 changes: 4 additions & 1 deletion superset/commands/base.py
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
from typing import Optional

from flask_appbuilder.models.sqla import Model


class BaseCommand(ABC):
Expand All @@ -23,7 +26,7 @@ class BaseCommand(ABC):
"""

@abstractmethod
def run(self):
def run(self) -> Optional[Model]:
"""
Run executes the command. Can raise command exceptions
:raises: CommandException
Expand Down
16 changes: 8 additions & 8 deletions superset/commands/exceptions.py
Expand Up @@ -25,25 +25,25 @@
class CommandException(SupersetException):
""" Common base class for Command exceptions. """

def __repr__(self):
def __repr__(self) -> str:
if self._exception:
return self._exception
return self
return repr(self._exception)
return repr(self)


class CommandInvalidError(CommandException):
""" Common base class for Command Invalid errors. """

status = 422

def __init__(self, message="") -> None:
def __init__(self, message: str = "") -> None:
self._invalid_exceptions: List[ValidationError] = []
super().__init__(self.message)

def add(self, exception: ValidationError):
def add(self, exception: ValidationError) -> None:
self._invalid_exceptions.append(exception)

def add_list(self, exceptions: List[ValidationError]):
def add_list(self, exceptions: List[ValidationError]) -> None:
self._invalid_exceptions.extend(exceptions)

def normalized_messages(self) -> Dict[Any, Any]:
Expand Down Expand Up @@ -76,12 +76,12 @@ class ForbiddenError(CommandException):
class OwnersNotFoundValidationError(ValidationError):
status = 422

def __init__(self):
def __init__(self) -> None:
super().__init__(_("Owners are invalid"), field_names=["owners"])


class DatasourceNotFoundValidationError(ValidationError):
status = 404

def __init__(self):
def __init__(self) -> None:
super().__init__(_("Datasource does not exist"), field_names=["datasource_id"])
4 changes: 2 additions & 2 deletions superset/common/query_context.py
Expand Up @@ -157,7 +157,7 @@ def cache_timeout(self) -> int:
return self.datasource.database.cache_timeout
return config["CACHE_DEFAULT_TIMEOUT"]

def cache_key(self, query_obj: QueryObject, **kwargs) -> Optional[str]:
def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
cache_key = (
query_obj.cache_key(
Expand All @@ -173,7 +173,7 @@ def cache_key(self, query_obj: QueryObject, **kwargs) -> Optional[str]:
return cache_key

def get_df_payload( # pylint: disable=too-many-locals,too-many-statements
self, query_obj: QueryObject, **kwargs
self, query_obj: QueryObject, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
cache_key = self.cache_key(query_obj, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_object.py
Expand Up @@ -122,7 +122,7 @@ def to_dict(self) -> Dict[str, Any]:
}
return query_object_dict

def cache_key(self, **extra) -> str:
def cache_key(self, **extra: Any) -> str:
"""
The cache key is made out of the key/values from to_dict(), plus any
other key/values in `extra`
Expand Down
9 changes: 5 additions & 4 deletions superset/common/tags.py
Expand Up @@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from sqlalchemy import Metadata
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import and_, func, functions, join, literal, select

from superset.models.tags import ObjectTypes, TagTypes


def add_types(engine, metadata):
def add_types(engine: Engine, metadata: Metadata) -> None:
"""
Tag every object according to its type:
Expand Down Expand Up @@ -163,7 +164,7 @@ def add_types(engine, metadata):
engine.execute(query)


def add_owners(engine, metadata):
def add_owners(engine: Engine, metadata: Metadata) -> None:
"""
Tag every object according to its owner:
Expand Down Expand Up @@ -319,7 +320,7 @@ def add_owners(engine, metadata):
engine.execute(query)


def add_favorites(engine, metadata):
def add_favorites(engine: Engine, metadata: Metadata) -> None:
"""
Tag every object that was favorited:
Expand Down
2 changes: 1 addition & 1 deletion superset/dao/base.py
Expand Up @@ -112,7 +112,7 @@ def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
return model

@classmethod
def delete(cls, model: Model, commit=True):
def delete(cls, model: Model, commit: bool = True) -> Model:
"""
Generic delete a model
:raises: DAOCreateFailedError
Expand Down
11 changes: 9 additions & 2 deletions superset/db_engines/hive.py
Expand Up @@ -14,12 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from pyhive.hive import Cursor # pylint: disable=unused-import
from TCLIService.ttypes import TFetchOrientation # pylint: disable=unused-import

# pylint: disable=protected-access
# TODO: contribute back to pyhive.
def fetch_logs(
self, max_rows=1024, orientation=None
): # pylint: disable=unused-argument
self: "Cursor",
max_rows: int = 1024, # pylint: disable=unused-argument
orientation: Optional["TFetchOrientation"] = None,
) -> str: # pylint: disable=unused-argument
"""Mocked. Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after
the previous call.
Expand Down
12 changes: 6 additions & 6 deletions superset/stats_logger.py
Expand Up @@ -24,26 +24,26 @@
class BaseStatsLogger:
"""Base class for logging realtime events"""

def __init__(self, prefix="superset"):
def __init__(self, prefix: str = "superset") -> None:
self.prefix = prefix

def key(self, key):
def key(self, key: str) -> str:
if self.prefix:
return self.prefix + key
return key

def incr(self, key):
def incr(self, key: str) -> None:
"""Increment a counter"""
raise NotImplementedError()

def decr(self, key):
def decr(self, key: str) -> None:
"""Decrement a counter"""
raise NotImplementedError()

def timing(self, key, value):
def timing(self, key, value: float) -> None:
raise NotImplementedError()

def gauge(self, key):
def gauge(self, key: str) -> None:
"""Setup a gauge"""
raise NotImplementedError()

Expand Down
3 changes: 2 additions & 1 deletion superset/utils/core.py
Expand Up @@ -1224,9 +1224,10 @@ class DatasourceName(NamedTuple):
schema: str


def get_stacktrace():
def get_stacktrace() -> Optional[str]:
if current_app.config["SHOW_STACKTRACE"]:
return traceback.format_exc()
return None


def split(
Expand Down

0 comments on commit 5e55e09

Please sign in to comment.