Skip to content

Commit

Permalink
fix: show all dbs in available endpoint (#15534)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jul 2, 2021
1 parent 80b8df0 commit 8f92618
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 16 deletions.
29 changes: 20 additions & 9 deletions superset/db_engine_specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def is_engine_spec(attr: Any) -> bool:
)


def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
def load_engine_specs() -> List[Type[BaseEngineSpec]]:
engine_specs: List[Type[BaseEngineSpec]] = []

# load standard engines
Expand All @@ -75,6 +75,12 @@ def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
continue
engine_specs.append(engine_spec)

return engine_specs


def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
engine_specs = load_engine_specs()

# build map from name/alias -> spec
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
for engine_spec in engine_specs:
Expand Down Expand Up @@ -121,11 +127,16 @@ def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]:
except Exception: # pylint: disable=broad-except
logger.warning("Unable to load SQLAlchemy dialect: %s", dialect)
else:
drivers[dialect.name].add(getattr(dialect, "driver", dialect.name))

engine_specs = get_engine_specs()
return {
engine_specs[backend]: drivers
for backend, drivers in drivers.items()
if backend in engine_specs
}
backend = dialect.name
if isinstance(backend, bytes):
backend = backend.decode()
driver = getattr(dialect, "driver", dialect.name)
if isinstance(driver, bytes):
driver = driver.decode()
drivers[backend].add(driver)

available_engines = {}
for engine_spec in load_engine_specs():
available_engines[engine_spec] = drivers[engine_spec.engine]

return available_engines
1 change: 1 addition & 0 deletions superset/db_engine_specs/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class DrillEngineSpec(BaseEngineSpec):

engine = "drill"
engine_name = "Apache Drill"
default_driver = "sadrill"

_time_grain_expressions = {
None: "{col}",
Expand Down
5 changes: 5 additions & 0 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,3 +545,8 @@ def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
or parsed_query.is_set()
or parsed_query.is_show()
)


class SparkEngineSpec(HiveEngineSpec):

engine_name = "Apache Spark SQL"
8 changes: 7 additions & 1 deletion superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

class MssqlEngineSpec(BaseEngineSpec):
engine = "mssql"
engine_name = "Microsoft SQL"
engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128

Expand Down Expand Up @@ -126,3 +126,9 @@ def extract_error_message(cls, ex: Exception) -> str:
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
)
return f"{cls.engine} error: {cls._extract_error_message(ex)}"


class AzureSynapseSpec(MssqlEngineSpec):
engine = "mssql"
engine_name = "Azure Synapse"
default_driver = "pyodbc"
2 changes: 1 addition & 1 deletion tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def test_test_connection_failed(self):
expected_response = {
"errors": [
{
"message": "Could not load database driver: MssqlEngineSpec",
"message": "Could not load database driver: AzureSynapseSpec",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/db_engine_specs/mssql_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_extract_errors(self):
message='The hostname "locahost" cannot be resolved.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"engine_name": "Microsoft SQL Server",
"issue_codes": [
{
"code": 1007,
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_extract_errors(self):
message='Port 12345 on hostname "localhost" refused the connection.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"engine_name": "Microsoft SQL Server",
"issue_codes": [
{"code": 1008, "message": "Issue 1008 - The port is closed."}
],
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_extract_errors(self):
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"engine_name": "Microsoft SQL Server",
"issue_codes": [
{
"code": 1009,
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_extract_errors(self):
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"engine_name": "Microsoft SQL Server",
"issue_codes": [
{
"code": 1009,
Expand Down Expand Up @@ -292,7 +292,7 @@ def test_extract_errors(self):
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"engine_name": "Microsoft SQL Server",
"issue_codes": [
{
"code": 1014,
Expand Down

0 comments on commit 8f92618

Please sign in to comment.