Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Avoid error by checking for existence before close. #454

Merged
merged 2 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 4 additions & 8 deletions databuilder/extractor/athena_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from pyhocon import ConfigFactory, ConfigTree

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

TableKey = namedtuple('TableKey', ['schema', 'table_name'])
Expand Down Expand Up @@ -56,15 +55,12 @@ def init(self, conf: ConfigTree) -> None:

LOGGER.info('SQL for Athena metadata: %s', self.sql_stmt)

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
self._extract_iter: Union[None, Iterator] = None

def close(self) -> None:
self._alchemy_extractor.close()
if getattr(self, '_alchemy_extractor', None) is not None:
self._alchemy_extractor.close()

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
Expand Down
12 changes: 4 additions & 8 deletions databuilder/extractor/druid_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@

from pyhocon import ConfigFactory, ConfigTree

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

TableKey = namedtuple('TableKey', ['schema', 'table_name'])
Expand Down Expand Up @@ -52,15 +51,12 @@ def init(self, conf: ConfigTree) -> None:
where_clause_suffix=conf.get_string(DruidMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY,
default=''))

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
self._extract_iter: Union[None, Iterator] = None

def close(self) -> None:
self._alchemy_extractor.close()
if getattr(self, '_alchemy_extractor', None) is not None:
self._alchemy_extractor.close()

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
Expand Down
13 changes: 4 additions & 9 deletions databuilder/extractor/mssql_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from pyhocon import ConfigFactory, ConfigTree

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

TableKey = namedtuple('TableKey', ['schema_name', 'table_name'])
Expand Down Expand Up @@ -108,16 +107,12 @@ def init(self, conf: ConfigTree) -> None:

LOGGER.info('SQL for MS SQL Metadata: %s', self.sql_stmt)

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped \
.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) \
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
self._extract_iter: Union[None, Iterator] = None

def close(self) -> None:
self._alchemy_extractor.close()
if getattr(self, '_alchemy_extractor', None) is not None:
self._alchemy_extractor.close()

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
Expand Down
12 changes: 4 additions & 8 deletions databuilder/extractor/presto_view_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from pyhocon import ConfigFactory, ConfigTree

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,15 +54,12 @@ def init(self, conf: ConfigTree) -> None:

LOGGER.info('SQL for hive metastore: %s', self.sql_stmt)

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
self._extract_iter: Union[None, Iterator] = None

def close(self) -> None:
self._alchemy_extractor.close()
if getattr(self, '_alchemy_extractor', None) is not None:
self._alchemy_extractor.close()

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
Expand Down
12 changes: 4 additions & 8 deletions databuilder/extractor/snowflake_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from pyhocon import ConfigFactory, ConfigTree
from unidecode import unidecode

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

TableKey = namedtuple('TableKey', ['schema', 'table_name'])
Expand Down Expand Up @@ -99,15 +98,12 @@ def init(self, conf: ConfigTree) -> None:

LOGGER.info('SQL for snowflake metadata: %s', self.sql_stmt)

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) \
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
self._extract_iter: Union[None, Iterator] = None

def close(self) -> None:
self._alchemy_extractor.close()
if getattr(self, '_alchemy_extractor', None) is not None:
self._alchemy_extractor.close()

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

from pyhocon import ConfigFactory, ConfigTree

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_last_updated import TableLastUpdated

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,15 +75,12 @@ def init(self, conf: ConfigTree) -> None:
LOGGER.info('SQL for snowflake table last updated timestamp: %s', self.sql_stmt)

# use an sql_alchemy_extractor to execute sql
self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) \
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
self._extract_iter: Union[None, Iterator] = None

def close(self) -> None:
self._alchemy_extractor.close()
if getattr(self, '_alchemy_extractor', None) is not None:
self._alchemy_extractor.close()

def extract(self) -> Union[TableLastUpdated, None]:
if not self._extract_iter:
Expand Down
22 changes: 20 additions & 2 deletions databuilder/extractor/sql_alchemy_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import importlib
from typing import Any

from pyhocon import ConfigTree
from pyhocon import ConfigFactory, ConfigTree
from sqlalchemy import create_engine

from databuilder import Scoped
from databuilder.extractor.base_extractor import Extractor


Expand Down Expand Up @@ -39,7 +40,8 @@ def init(self, conf: ConfigTree) -> None:
self._execute_query()

def close(self) -> None:
self.connection.close()
if self.connection is not None:
self.connection.close()

def _get_connection(self) -> Any:
"""
Expand Down Expand Up @@ -83,3 +85,19 @@ def extract(self) -> Any:

def get_scope(self) -> str:
return 'extractor.sqlalchemy'


def from_surrounding_config(conf: ConfigTree, sql_stmt: str) -> SQLAlchemyExtractor:
"""
A factory to create SQLAlchemyExtractors that are wrapped by another, specialized
extractor. This function pulls the config from the wrapping extractor's config, and
returns a newly configured SQLAlchemyExtractor.
:param conf: A config tree from which the sqlalchemy config still needs to be taken.
:param conf: The SQL statement to use for extraction. Expected to be set by the
wrapping extractor implementation, and not by the config.
"""
ae = SQLAlchemyExtractor()
c = Scoped.get_scoped_conf(conf, ae.get_scope()) \
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: sql_stmt}))
ae.init(c)
return ae