Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: hive metadata extractor does not work on postgresql #394

Merged
merged 2 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 42 additions & 5 deletions databuilder/extractor/hive_table_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pyhocon import ConfigFactory, ConfigTree
from typing import Iterator, Union, Dict, Any
from sqlalchemy.engine.url import make_url

from databuilder import Scoped
from databuilder.extractor.table_metadata_constants import PARTITION_BADGE
Expand Down Expand Up @@ -56,6 +57,34 @@ class HiveTableMetadataExtractor(Extractor):
ORDER by tbl_id, is_partition_col desc;
"""

DEFAULT_POSTGRES_SQL_STATEMENT = """
SELECT source.* FROM
(SELECT t."TBL_ID" as tbl_id, d."NAME" as "schema", t."TBL_NAME" as name, t."TBL_TYPE",
tp."PARAM_VALUE" as description, p."PKEY_NAME" as col_name, p."INTEGER_IDX" as col_sort_order,
p."PKEY_TYPE" as col_type, p."PKEY_COMMENT" as col_description, 1 as "is_partition_col",
CASE WHEN t."TBL_TYPE" = 'VIRTUAL_VIEW' THEN 1
ELSE 0 END as "is_view"
FROM "TBLS" t
JOIN "DBS" d ON t."DB_ID" = d."DB_ID"
JOIN "PARTITION_KEYS" p ON t."TBL_ID" = p."TBL_ID"
LEFT JOIN "TABLE_PARAMS" tp ON (t."TBL_ID" = tp."TBL_ID" AND tp."PARAM_KEY"='comment')
{where_clause_suffix}
UNION
SELECT t."TBL_ID" as tbl_id, d."NAME" as "schema", t."TBL_NAME" as name, t."TBL_TYPE",
tp."PARAM_VALUE" as description, c."COLUMN_NAME" as col_name, c."INTEGER_IDX" as col_sort_order,
c."TYPE_NAME" as col_type, c."COMMENT" as col_description, 0 as "is_partition_col",
CASE WHEN t."TBL_TYPE" = 'VIRTUAL_VIEW' THEN 1
ELSE 0 END as "is_view"
FROM "TBLS" t
JOIN "DBS" d ON t."DB_ID" = d."DB_ID"
JOIN "SDS" s ON t."SD_ID" = s."SD_ID"
JOIN "COLUMNS_V2" c ON s."CD_ID" = c."CD_ID"
LEFT JOIN "TABLE_PARAMS" tp ON (t."TBL_ID" = tp."TBL_ID" AND tp."PARAM_KEY"='comment')
{where_clause_suffix}
) source
ORDER by tbl_id, is_partition_col desc;
"""

# CONFIG KEYS
WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix'
CLUSTER_KEY = 'cluster'
Expand All @@ -67,20 +96,28 @@ def init(self, conf: ConfigTree) -> None:
conf = conf.with_fallback(HiveTableMetadataExtractor.DEFAULT_CONFIG)
self._cluster = '{}'.format(conf.get_string(HiveTableMetadataExtractor.CLUSTER_KEY))

default_sql = HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT.format(
self._alchemy_extractor = SQLAlchemyExtractor()

sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())
default_sql = self._choose_default_sql_stm(sql_alch_conf).format(
where_clause_suffix=conf.get_string(HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY))

self.sql_stmt = conf.get_string(HiveTableMetadataExtractor.EXTRACT_SQL, default=default_sql)

LOGGER.info('SQL for hive metastore: {}'.format(self.sql_stmt))

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

sql_alch_conf = sql_alch_conf.with_fallback(ConfigFactory.from_dict(
{SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))
self._alchemy_extractor.init(sql_alch_conf)
self._extract_iter: Union[None, Iterator] = None

def _choose_default_sql_stm(self, conf: ConfigTree) -> str:
url = make_url(conf.get_string(SQLAlchemyExtractor.CONN_STRING))
if url.drivername.lower() in ['postgresql', 'postgres']:
return HiveTableMetadataExtractor.DEFAULT_POSTGRES_SQL_STATEMENT
else:
return HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
self._extract_iter = self._get_extract_iter()
Expand Down
20 changes: 15 additions & 5 deletions tests/unit/extractor/test_hive_table_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ def test_extraction_with_empty_query_result(self) -> None:
"""
Test Extraction with empty result from query
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
with patch.object(SQLAlchemyExtractor, '_get_connection'), \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
extractor = HiveTableMetadataExtractor()
extractor.init(self.conf)

results = extractor.extract()
self.assertEqual(results, None)

def test_extraction_with_single_result(self) -> None:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection, \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
connection = MagicMock()
mock_connection.return_value = connection
sql_execute = MagicMock()
Expand Down Expand Up @@ -101,7 +105,9 @@ def test_extraction_with_single_result(self) -> None:
self.assertIsNone(extractor.extract())

def test_extraction_with_multiple_result(self) -> None:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection, \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
connection = MagicMock()
mock_connection.return_value = connection
sql_execute = MagicMock()
Expand Down Expand Up @@ -240,7 +246,9 @@ def test_sql_statement(self) -> None:
"""
Test Extraction with empty result from query
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
with patch.object(SQLAlchemyExtractor, '_get_connection'), \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
extractor = HiveTableMetadataExtractor()
extractor.init(self.conf)
self.assertTrue(self.where_clause_suffix in extractor.sql_stmt)
Expand All @@ -250,7 +258,9 @@ def test_hive_sql_statement_with_custom_sql(self) -> None:
Test Extraction by providing a custom sql
:return:
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
with patch.object(SQLAlchemyExtractor, '_get_connection'), \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
config_dict = {
HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix,
'extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING):
Expand Down