Skip to content

Commit

Permalink
Merge pull request #3 from OCHA-DAP/reflection
Browse files Browse the repository at this point in the history
Add reflection support
  • Loading branch information
mcarans committed Aug 28, 2023
2 parents 8accaaa + 58587a9 commit 6b90d52
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 20 deletions.
12 changes: 10 additions & 2 deletions documentation/main.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,22 @@ tunnel:
with Database(database="db", host="1.2.3.4", username="user",
password="pass", dialect="dialect", driver="driver",
ssh_host="5.6.7.8", ssh_port=2222, ssh_username="sshuser",
ssh_private_key="path_to_key", db_has_tz=True) as session:
ssh_private_key="path_to_key", db_has_tz=True,
reflect=False) as session:
session.query(...)

`db_has_tz` which defaults to `False` indicates whether database datetime
columns have timezones. If `db_has_tz` is `True`, use `Base` from
`hdx.database.with_timezone`, otherwise use `Base` from
`hdx.database.no_timezone`. If `db_has_tz` is `False`, conversion occurs
between Python datetimes with timezones to timezoneless database columns.
between Python datetimes with timezones to timezoneless database columns.

If `reflect` (which defaults to `False`) is `True`, classes will be reflected
from an existing database and the reflected classes are returned in a variable
`reflected_classes` in the returned Session object. Note that type annotation
maps don't work with reflection and hence `db_has_tz` will be ignored ie.
there will be no conversion between Python datetimes with timezones to
timezoneless database columns.

## Connection URI

Expand Down
44 changes: 33 additions & 11 deletions src/hdx/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional, Type, Union

from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from sqlalchemy.pool import NullPool
from sshtunnel import SSHTunnelForwarder
Expand All @@ -11,17 +12,19 @@
from .dburi import get_connection_uri
from .no_timezone import Base as NoTZBase
from .postgresql import wait_for_postgresql
from .with_timezone import Base
from .with_timezone import Base as TZBase

logger = logging.getLogger(__name__)


class Database:
"""Database helper class to handle ssh tunnels, waiting for PostgreSQL to
be up etc. Can be used in a with statement returning a Session object.
db_has_tz which defaults to False indicates whether database datetime
columns have timezones. If not, conversion occurs between Python datetimes
with timezones to timezoneless database columns.
be up etc. Can be used in a with statement returning a Session object that
if reflect is True will have a variable reflected_classes containing the
reflected classes. db_has_tz which defaults to False indicates whether
database datetime columns have timezones. If not, conversion occurs between
Python datetimes with timezones to timezoneless database columns (but not
when using reflection).
Args:
database (Optional[str]): Database name
Expand All @@ -32,6 +35,7 @@ class Database:
dialect (str): Database dialect. Defaults to "postgresql".
driver (Optional[str]): Database driver. Defaults to None (psycopg if postgresql or None)
db_has_tz (bool): True if db datetime columns have timezone. Defaults to False.
reflect (bool): Whether to reflect existing tables. Defaults to False.
**kwargs: See below
ssh_host (str): SSH host (the server to connect to)
ssh_port (int): SSH port. Defaults to 22.
Expand All @@ -52,6 +56,7 @@ def __init__(
dialect: str = "postgresql",
driver: Optional[str] = None,
db_has_tz: bool = False,
reflect: bool = False,
**kwargs: Any,
) -> None:
if port is not None:
Expand Down Expand Up @@ -87,10 +92,12 @@ def __init__(
if dialect == "postgresql":
wait_for_postgresql(db_uri)
if db_has_tz:
table_base = Base
table_base = TZBase
else:
table_base = NoTZBase
self.session = self.get_session(db_uri, table_base=table_base)
self.session = self.get_session(
db_uri, table_base=table_base, reflect=reflect
)

def __enter__(self) -> Session:
return self.session
Expand All @@ -102,19 +109,34 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:

@staticmethod
def get_session(
db_uri: str, table_base: Type[DeclarativeBase] = NoTZBase
db_uri: str,
table_base: Type[DeclarativeBase] = NoTZBase,
reflect: bool = False,
) -> Session:
"""Gets SQLAlchemy session given url. Tables must inherit from Base in
hdx.utilities.database unless base is defined.
hdx.utilities.database unless base is defined. If reflect is True,
classes will be reflected from an existing database and the reflected
classes are returned in a variable reflected_classes in the returned
Session object. Note that type annotation maps don't work with
reflection.
Args:
db_uri (str): Connection URI
table_base (Type[DeclarativeBase]): Base database table class. Defaults to NoTZBase.
reflect (bool): Whether to reflect existing tables. Defaults to False.
Returns:
sqlalchemy.orm.Session: SQLAlchemy session
"""
engine = create_engine(db_uri, poolclass=NullPool, echo=False)
Session = sessionmaker(bind=engine)
table_base.metadata.create_all(engine)
return Session()
if reflect:
Base = automap_base(declarative_base=table_base)
Base.prepare(autoload_with=engine)
classes = Base.classes
else:
table_base.metadata.create_all(engine)
classes = None
session = Session()
session.reflected_classes = classes
return session
Binary file added tests/fixtures/test.db
Binary file not shown.
43 changes: 36 additions & 7 deletions tests/hdx/database/test_database.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
"""Database Utility Tests"""
import copy
import os
from collections import namedtuple
from copy import deepcopy
from datetime import datetime, timezone
from os import remove
from os.path import join
from shutil import copyfile

import pytest
from sqlalchemy import select
from sshtunnel import SSHTunnelForwarder

from .dbtestdate import DBTestDate
from hdx.database import Base, Database
from hdx.database import Database
from hdx.database.no_timezone import Base as NoTZBase
from hdx.database.with_timezone import Base as TZBase


class TestDatabase:
started = False
stopped = False
table_base = NoTZBase
dbpath = join("tests", "test_database.db")
testdb = join("tests", "fixtures", "test.db")
params_pg = {
"database": "mydatabase",
"host": "myserver",
Expand All @@ -32,7 +35,16 @@ class TestDatabase:
@pytest.fixture(scope="function")
def nodatabase(self):
try:
os.remove(TestDatabase.dbpath)
remove(TestDatabase.dbpath)
except OSError:
pass
return f"sqlite:///{TestDatabase.dbpath}"

@pytest.fixture(scope="function")
def database_to_reflect(self):
try:
remove(TestDatabase.dbpath)
copyfile(TestDatabase.testdb, TestDatabase.dbpath)
except OSError:
pass
return f"sqlite:///{TestDatabase.dbpath}"
Expand All @@ -54,7 +66,7 @@ def stop(_):
monkeypatch.setattr(SSHTunnelForwarder, "local_bind_host", "0.0.0.0")
monkeypatch.setattr(SSHTunnelForwarder, "local_bind_port", 12345)

def get_session(_, db_url, table_base):
def get_session(_, db_url, table_base, reflect):
class Session:
bind = namedtuple("Bind", "engine")

Expand Down Expand Up @@ -82,6 +94,23 @@ def test_get_session(self, nodatabase):
dbtestdate = dbsession.execute(select(DBTestDate)).scalar_one()
assert dbtestdate.test_date == now

def test_get_reflect_session(self, database_to_reflect):
with Database(
database=TestDatabase.dbpath,
port=None,
dialect="sqlite",
reflect=True,
) as dbsession:
assert TestDatabase.table_base == NoTZBase
assert str(dbsession.bind.engine.url) == database_to_reflect
Table1 = dbsession.reflected_classes.table1
row = dbsession.execute(select(Table1)).scalar_one()
assert row.id == "1"
assert row.col1 == "wfrefds"
# with reflection, type annotation maps do not work and hence
# we don't have a timezone here
assert row.date1 == datetime(1993, 9, 23, 14, 12, 56, 111000)

def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder):
with Database(
ssh_host="mysshhost", **TestDatabase.params_pg
Expand All @@ -90,7 +119,7 @@ def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder):
str(dbsession.bind.engine.url)
== "postgresql+psycopg://myuser:mypass@0.0.0.0:12345/mydatabase"
)
params = copy.deepcopy(TestDatabase.params_pg)
params = deepcopy(TestDatabase.params_pg)
del params["password"]
with Database(
ssh_host="mysshhost", ssh_port=25, **params
Expand All @@ -107,4 +136,4 @@ def test_get_session_ssh(self, mock_psycopg, mock_SSHTunnelForwarder):
str(dbsession.bind.engine.url)
== "postgresql+psycopg://myuser@0.0.0.0:12345/mydatabase"
)
assert TestDatabase.table_base == Base
assert TestDatabase.table_base == TZBase

0 comments on commit 6b90d52

Please sign in to comment.