Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure Presto database engine spec correctly handles Trino #20729

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
cast,
Dict,
List,
Optional,
Pattern,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from urllib import parse

import pandas as pd
Expand All @@ -35,13 +46,16 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import Row as ResultRow
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DatabaseError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select

from superset import cache_manager, is_feature_enabled
from superset.common.db_query_status import QueryStatus
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, ColumnTypeMapping
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError

from superset.errors import SupersetErrorType
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
Expand Down Expand Up @@ -224,6 +238,15 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
),
}

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
from pyhive.exc import DatabaseError

return {
DatabaseError: SupersetDBAPIDatabaseError,
}

@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
version = extra.get("version")
Expand Down Expand Up @@ -913,21 +936,23 @@ def extra_table_metadata(
indexes = database.get_indexes(table_name, schema_name)
if indexes:
cols = indexes[0].get("column_names", [])
full_table_name = table_name
if schema_name and "." not in table_name:
full_table_name = "{}.{}".format(schema_name, table_name)
pql = cls._partition_query(full_table_name, database)
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)

if not latest_parts:
latest_parts = tuple([None] * len(col_names))
metadata["partitions"] = {
"cols": cols,
"latest": dict(zip(col_names, latest_parts)),
"partitionQuery": pql,
}
if cols:
full_table_name = table_name
Copy link
Member Author

@john-bodley john-bodley Jul 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same logic as previous just indented under the if cols: statement.

Unlike Presto where get_indexes returns [] for a non-partition table, Trino returns [{'name': 'partition', 'column_names': [], 'unique': False}]. Rather than overriding the engine specific normalize_indexes method I though it would be more prudent to make this method more robust given there was already an expectation that there may be no columns associated with the index, i.e., a non-partitioned table.

if schema_name and "." not in table_name:
full_table_name = "{}.{}".format(schema_name, table_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f-string?

Copy link
Member Author

@john-bodley john-bodley Aug 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ktmud this code is unchanged, i.e., it's now nested under the if cols: section and thus I would prefer not make changes to said code, at least in this PR.

pql = cls._partition_query(full_table_name, database)
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)

if not latest_parts:
latest_parts = tuple([None] * len(col_names))
metadata["partitions"] = {
"cols": cols,
"latest": dict(zip(col_names, latest_parts)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we should probably change the signature of latest_partition() itself to return this dict---would worth another PR.

For this code, it seems it can be simplified as:

Suggested change
"latest": dict(zip(col_names, latest_parts)),
"latest": {
col: latest_parts[i] if latest_parts else None
for i, col in enumerate(col_names)
}

"partitionQuery": pql,
}

# flake8 is not matching `Optional[str]` to `Any` for some reason...
metadata["view"] = cast(
Expand All @@ -947,20 +972,16 @@ def get_create_view(
:param schema: Schema name
:param table: Table (view) name
"""
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError

engine = cls.get_engine(database, schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)

except DatabaseError: # not a VIEW
return cls.fetch_data(cursor, 1)[0][0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never have been a pyhive.exc exception to begin with.

except SupersetDBAPIDatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]

@classmethod
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
Expand Down
12 changes: 11 additions & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING

import simplejson as json
from flask import current_app
Expand All @@ -25,6 +25,7 @@

from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.sql_lab import Query
from superset.utils import core as utils
Expand All @@ -45,6 +46,15 @@ class TrinoEngineSpec(PrestoEngineSpec):
engine_aliases = {"trinonative"} # Required for backwards compatibility.
engine_name = "Trino"

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
from trino.exceptions import DatabaseError

return {
DatabaseError: SupersetDBAPIDatabaseError,
}

@classmethod
def update_impersonation_config(
cls,
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

import pandas as pd
from sqlalchemy import types
from sqlalchemy.exc import DatabaseError
from sqlalchemy.sql import select

from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.utils.core import DatasourceName, GenericDataType
Expand Down Expand Up @@ -895,9 +897,7 @@ def test_get_create_view_exception(self):
PrestoEngineSpec.get_create_view(database, schema=schema, table=table)

def test_get_create_view_database_error(self):
from pyhive.exc import DatabaseError

mock_execute = mock.MagicMock(side_effect=DatabaseError())
mock_execute = mock.MagicMock(side_effect=SupersetDBAPIDatabaseError())
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
Expand Down