From d061ea80de8cccbf5caa1d0305d3e92cee24b1c6 Mon Sep 17 00:00:00 2001 From: Mike Date: Mon, 28 Aug 2023 16:49:01 +1200 Subject: [PATCH 1/3] Add reflection --- src/hdx/database/__init__.py | 28 ++++++++++++++---- tests/fixtures/test.db | Bin 0 -> 3072 bytes tests/hdx/database/test_database.py | 43 +++++++++++++++++++++++----- 3 files changed, 58 insertions(+), 13 deletions(-) create mode 100644 tests/fixtures/test.db diff --git a/src/hdx/database/__init__.py b/src/hdx/database/__init__.py index d7a323b..5646d92 100644 --- a/src/hdx/database/__init__.py +++ b/src/hdx/database/__init__.py @@ -3,6 +3,7 @@ from typing import Any, Optional, Type, Union from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from sqlalchemy.pool import NullPool from sshtunnel import SSHTunnelForwarder @@ -11,7 +12,7 @@ from .dburi import get_connection_uri from .no_timezone import Base as NoTZBase from .postgresql import wait_for_postgresql -from .with_timezone import Base +from .with_timezone import Base as TZBase logger = logging.getLogger(__name__) @@ -32,6 +33,7 @@ class Database: dialect (str): Database dialect. Defaults to "postgresql". driver (Optional[str]): Database driver. Defaults to None (psycopg if postgresql or None) db_has_tz (bool): True if db datetime columns have timezone. Defaults to False. + reflect (bool): Whether to reflect existing tables. Defaults to False. **kwargs: See below ssh_host (str): SSH host (the server to connect to) ssh_port (int): SSH port. Defaults to 22. @@ -52,6 +54,7 @@ def __init__( dialect: str = "postgresql", driver: Optional[str] = None, db_has_tz: bool = False, + reflect: bool = False, **kwargs: Any, ) -> None: if port is not None: @@ -87,10 +90,12 @@ def __init__( if dialect == "postgresql": wait_for_postgresql(db_uri) if db_has_tz: - table_base = Base + table_base = TZBase else: table_base = NoTZBase - self.session = self.get_session(db_uri, table_base=table_base) + self.session = self.get_session( + db_uri, table_base=table_base, reflect=reflect + ) def __enter__(self) -> Session: return self.session @@ -102,7 +107,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @staticmethod def get_session( - db_uri: str, table_base: Type[DeclarativeBase] = NoTZBase + db_uri: str, + table_base: Type[DeclarativeBase] = NoTZBase, + reflect: bool = False, ) -> Session: """Gets SQLAlchemy session given url. Tables must inherit from Base in hdx.utilities.database unless base is defined. @@ -110,11 +117,20 @@ def get_session( Args: db_uri (str): Connection URI table_base (Type[DeclarativeBase]): Base database table class. Defaults to NoTZBase. + reflect (bool): Whether to reflect existing tables. Defaults to False. Returns: sqlalchemy.orm.Session: SQLAlchemy session """ engine = create_engine(db_uri, poolclass=NullPool, echo=False) Session = sessionmaker(bind=engine) - table_base.metadata.create_all(engine) - return Session() + if reflect: + Base = automap_base(declarative_base=table_base) + Base.prepare(autoload_with=engine) + classes = Base.classes + else: + table_base.metadata.create_all(engine) + classes = None + session = Session() + session.reflected_classes = classes + return session diff --git a/tests/fixtures/test.db b/tests/fixtures/test.db new file mode 100644 index 0000000000000000000000000000000000000000..81e2aee079b9184d81702904f394e3f62cbea054 GIT binary patch literal 3072 zcmeH`&riZI6vx|jQNa+7xOjP)7l=S8sLAl46L%oaA)_2PEz@<$lEDIu$W{M6{{;U7 z{{j;)c1+G(jrQgJ=(l}s+E3Ea(2i0LVX~O|DRhvEFhoo;$C*-*y#_)7E~b$pa!*5Wv_CjVHC{*zFy8_@vl99l_neQ<=63k?wbAUgN$=Y z1SaoDxHL&X68Jj=POw@&zNF9Lf``G98b+tlG8*j;(6cUWch4`Hl+vPjFD73|xHL&X M68Jv^R8lU`Z(lu6ssI20 literal 0 HcmV?d00001 diff --git a/tests/hdx/database/test_database.py b/tests/hdx/database/test_database.py index 3b5a0de..531011d 100755 --- a/tests/hdx/database/test_database.py +++ b/tests/hdx/database/test_database.py @@ -1,17 +1,19 @@ """Database Utility Tests""" -import copy -import os from collections import namedtuple +from copy import deepcopy from datetime import datetime, timezone +from os import remove from os.path import join +from shutil import copyfile import pytest from sqlalchemy import select from sshtunnel import SSHTunnelForwarder from .dbtestdate import DBTestDate -from hdx.database import Base, Database +from hdx.database import Database from hdx.database.no_timezone import Base as NoTZBase +from hdx.database.with_timezone import Base as TZBase class TestDatabase: @@ -19,6 +21,7 @@ class TestDatabase: stopped = False table_base = NoTZBase dbpath = join("tests", "test_database.db") + testdb = join("tests", "fixtures", "test.db") params_pg = { "database": "mydatabase", "host": "myserver", @@ -32,7 +35,16 @@ class TestDatabase: @pytest.fixture(scope="function") def nodatabase(self): try: - os.remove(TestDatabase.dbpath) + remove(TestDatabase.dbpath) + except OSError: + pass + return f"sqlite:///{TestDatabase.dbpath}" + + @pytest.fixture(scope="function") + def database_to_reflect(self): + try: + remove(TestDatabase.dbpath) + copyfile(TestDatabase.testdb, TestDatabase.dbpath) except OSError: pass return f"sqlite:///{TestDatabase.dbpath}" @@ -54,7 +66,7 @@ def stop(_): monkeypatch.setattr(SSHTunnelForwarder, "local_bind_host", "0.0.0.0") monkeypatch.setattr(SSHTunnelForwarder, "local_bind_port", 12345) - def get_session(_, db_url, table_base): + def get_session(_, db_url, table_base, reflect): class Session: bind = namedtuple("Bind", "engine") @@ -82,6 +94,23 @@ def test_get_session(self, nodatabase): dbtestdate = dbsession.execute(select(DBTestDate)).scalar_one() assert dbtestdate.test_date == now + def test_get_reflect_session(self, database_to_reflect): + with Database( + database=TestDatabase.dbpath, + port=None, + dialect="sqlite", + reflect=True, + ) as dbsession: + assert TestDatabase.table_base == NoTZBase + assert str(dbsession.bind.engine.url) == database_to_reflect + Table1 = dbsession.reflected_classes.table1 + row = dbsession.execute(select(Table1)).scalar_one() + assert row.id == "1" + assert row.col1 == "wfrefds" + assert row.date1 == datetime( + 1993, 9, 23, 14, 12, 56, 111000, tzinfo=timezone.utc + ) + def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder): with Database( ssh_host="mysshhost", **TestDatabase.params_pg @@ -90,7 +119,7 @@ def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder): str(dbsession.bind.engine.url) == "postgresql+psycopg://myuser:mypass@0.0.0.0:12345/mydatabase" ) - params = copy.deepcopy(TestDatabase.params_pg) + params = deepcopy(TestDatabase.params_pg) del params["password"] with Database( ssh_host="mysshhost", ssh_port=25, **params @@ -107,4 +136,4 @@ def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder): str(dbsession.bind.engine.url) == "postgresql+psycopg://myuser@0.0.0.0:12345/mydatabase" ) - assert TestDatabase.table_base == Base + assert TestDatabase.table_base == TZBase From 2931a2021812bf12cf07b9a02c60696d76c8333f Mon Sep 17 00:00:00 2001 From: Mike Date: Tue, 29 Aug 2023 10:37:28 +1200 Subject: [PATCH 2/3] Add reflection --- documentation/main.md | 12 ++++++++++-- src/hdx/database/__init__.py | 16 +++++++++++----- tests/hdx/database/test_database.py | 5 +++-- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/documentation/main.md b/documentation/main.md index 864c976..69cc793 100755 --- a/documentation/main.md +++ b/documentation/main.md @@ -58,14 +58,22 @@ tunnel: with Database(database="db", host="1.2.3.4", username="user", password="pass", dialect="dialect", driver="driver", ssh_host="5.6.7.8", ssh_port=2222, ssh_username="sshuser", - ssh_private_key="path_to_key", db_has_tz=True) as session: + ssh_private_key="path_to_key", db_has_tz=True, + reflect=False) as session: session.query(...) `db_has_tz` which defaults to `False` indicates whether database datetime columns have timezones. If `db_has_tz` is `True`, use `Base` from `hdx.database.with_timezone`, otherwise use `Base` from `hdx.database.no_timezone`. If `db_has_tz` is `False`, conversion occurs -between Python datetimes with timezones to timezoneless database columns. +between Python datetimes with timezones to timezoneless database columns. + +If `reflect` (which defaults to `False`) is `True`, classes will be reflected +from an existing database and the reflected classes are returned in a variable +`reflected_classes` in the returned Session object. Note that type annotation +maps don't work with reflection and hence `db_has_tz` will be ignored ie. +there will be no conversion between Python datetimes with timezones to +timezoneless database columns. ## Connection URI diff --git a/src/hdx/database/__init__.py b/src/hdx/database/__init__.py index 5646d92..ea2ee06 100644 --- a/src/hdx/database/__init__.py +++ b/src/hdx/database/__init__.py @@ -19,10 +19,12 @@ class Database: """Database helper class to handle ssh tunnels, waiting for PostgreSQL to - be up etc. Can be used in a with statement returning a Session object. - db_has_tz which defaults to False indicates whether database datetime - columns have timezones. If not, conversion occurs between Python datetimes - with timezones to timezoneless database columns. + be up etc. Can be used in a with statement returning a Session object that + if reflect is True will have a variable reflected_classes containing the + reflected classes. db_has_tz which defaults to False indicates whether + database datetime columns have timezones. If not, conversion occurs between + Python datetimes with timezones to timezoneless database columns (but not + when using reflection). Args: database (Optional[str]): Database name @@ -112,7 +114,11 @@ def get_session( reflect: bool = False, ) -> Session: """Gets SQLAlchemy session given url. Tables must inherit from Base in - hdx.utilities.database unless base is defined. + hdx.utilities.database unless base is defined. If reflect is True, + classes will be reflected from an existing database and the reflected + classes are returned in a variable reflected_classes in the returned + Session object. Note that type annotation maps don't work with + reflection. Args: db_uri (str): Connection URI diff --git a/tests/hdx/database/test_database.py b/tests/hdx/database/test_database.py index 531011d..13fb7b1 100755 --- a/tests/hdx/database/test_database.py +++ b/tests/hdx/database/test_database.py @@ -108,8 +108,9 @@ def test_get_reflect_session(self, database_to_reflect): assert row.id == "1" assert row.col1 == "wfrefds" assert row.date1 == datetime( - 1993, 9, 23, 14, 12, 56, 111000, tzinfo=timezone.utc - ) + 1993, 9, 23, 14, 12, 56, 111000 + ) # with reflection, type annotation maps do not work and hence + # we don't have a timezone here def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder): with Database( From 58587a9c109d635a19a2b8309a8dda2a8369660d Mon Sep 17 00:00:00 2001 From: Mike Date: Tue, 29 Aug 2023 10:38:53 +1200 Subject: [PATCH 3/3] Add reflection --- tests/hdx/database/test_database.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/hdx/database/test_database.py b/tests/hdx/database/test_database.py index 13fb7b1..c7023ef 100755 --- a/tests/hdx/database/test_database.py +++ b/tests/hdx/database/test_database.py @@ -107,10 +107,9 @@ def test_get_reflect_session(self, database_to_reflect): row = dbsession.execute(select(Table1)).scalar_one() assert row.id == "1" assert row.col1 == "wfrefds" - assert row.date1 == datetime( - 1993, 9, 23, 14, 12, 56, 111000 - ) # with reflection, type annotation maps do not work and hence + # with reflection, type annotation maps do not work and hence # we don't have a timezone here + assert row.date1 == datetime(1993, 9, 23, 14, 12, 56, 111000) def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder): with Database(