Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: quote column name if db requires #15465

Merged
merged 2 commits into from Jul 2, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 17 additions & 1 deletion superset/connectors/sqla/models.py
Expand Up @@ -49,7 +49,14 @@
)
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql import (
column,
ColumnElement,
literal_column,
quoted_name,
table,
text,
)
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
from sqlalchemy.sql.selectable import Alias, TableClause
Expand Down Expand Up @@ -897,16 +904,25 @@ def make_sqla_column_compatible(
self, sqla_col: Column, label: Optional[str] = None
) -> Column:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
also adds quotes to the column if engine is configured for quotes.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
db_engine_spec = self.db_engine_spec

# add quotes to column
if db_engine_spec.force_column_alias_quotes:
sqla_col = column(
quoted_name(sqla_col.name, True), sqla_col.type, sqla_col.is_literal
)

# add quotes to tables
if db_engine_spec.allows_alias_in_select:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)

sqla_col.key = label_expected
return sqla_col

Expand Down
57 changes: 56 additions & 1 deletion tests/core_tests.py
Expand Up @@ -25,6 +25,9 @@
import logging
from typing import Dict, List
from urllib.parse import quote

from sqlalchemy.sql import column, quoted_name, literal_column
from sqlalchemy import select
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

import pytest
Expand All @@ -38,7 +41,7 @@
import sqlalchemy as sqla
from sqlalchemy.exc import SQLAlchemyError
from superset.models.cache import CacheKey
from superset.utils.core import get_example_database
from superset.utils.core import get_example_database, get_or_create_db
from tests.conftest import with_feature_flags
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
from tests.test_app import app
Expand Down Expand Up @@ -890,6 +893,58 @@ def test_comments_in_sqlatable_query(self):
rendered_query = str(table.get_from_clause())
self.assertEqual(clean_query, rendered_query)

def test_make_column_compatible(self):
"""
DB Eng Specs: Make column compatible
"""

# with force_column_alias_quotes enabled
snowflake_database = get_or_create_db("snowflake", "snowflake://")

table = SqlaTable(
table_name="test_columns_with_alias_quotes", database=snowflake_database,
)

col = table.make_sqla_column_compatible(column("foo"))
s = select([col])
self.assertEqual(str(s), 'SELECT "foo" AS "foo"')

# with literal_column
table = SqlaTable(
table_name="test_columns_with_alias_quotes_on_literal_column",
database=snowflake_database,
)

col = table.make_sqla_column_compatible(literal_column("foo"))
s = select([col])
self.assertEqual(str(s), 'SELECT foo AS "foo"')

# with force_column_alias_quotes NOT enabled
postgres_database = get_or_create_db("postgresql", "postgresql://")

table = SqlaTable(
table_name="test_columns_with_no_quotes", database=postgres_database,
)

col = table.make_sqla_column_compatible(column("foo"))
s = select([col])
self.assertEqual(str(s), "SELECT foo AS foo")

# with literal_column
table = SqlaTable(
table_name="test_columns_with_no_quotes_on_literal_column",
database=postgres_database,
)

col = table.make_sqla_column_compatible(literal_column("foo"))
s = select([col])
self.assertEqual(str(s), "SELECT foo AS foo")

# cleanup
db.session.delete(snowflake_database)
db.session.delete(postgres_database)
db.session.commit()

def test_slice_payload_no_datasource(self):
self.login(username="admin")
data = self.get_json_resp("/superset/explore_json/", raise_on_error=False)
Expand Down
16 changes: 16 additions & 0 deletions tests/db_engine_specs/snowflake_tests.py
Expand Up @@ -16,12 +16,28 @@
# under the License.
import json

from sqlalchemy import column

from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
from superset.models.core import Database
from tests.db_engine_specs.base_tests import TestDbEngineSpec


class TestSnowflakeDbEngineSpec(TestDbEngineSpec):
def test_snowflake_sqla_column_label(self):
"""
DB Eng Specs (snowflake): Test column label
"""
test_cases = {
"Col": "Col",
"SUM(x)": "SUM(x)",
"SUM[x]": "SUM[x]",
"12345_col": "12345_col",
}
for original, expected in test_cases.items():
actual = SnowflakeEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)

def test_convert_dttm(self):
dttm = self.get_dttm()

Expand Down