Skip to content

Commit

Permalink
refactor: Use closing for DB sessions & add docstrings to `meltano.…
Browse files Browse the repository at this point in the history
…core.db` (meltano#6611)

* refactor: Replace `StaleJobFailer` class with `fail_stale_jobs` function

* refactor: Use `closing` for DB sessions & add docstrings to `meltano.core.db`
  • Loading branch information
WillDaSilva committed Aug 12, 2022
1 parent 175d7a1 commit c416f28
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 77 deletions.
11 changes: 11 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ ignore =
WPS306
# Allow inconsistent returns - permitted because of too many false positives
WPS324
# Allow syntax-level (process when first parsed, rather than at every run) string concatenation
WPS326
# Allow assignment expressions (walrus operator :=)
WPS332
# Allow 'incorrect' order of methods in a class
Expand Down Expand Up @@ -240,6 +242,15 @@ per-file-ignores =
tests/**/__init__.py:
# Allow for using __init__.py to promote imports to the module namespace
F401
# Don't require docstrings in tests
DAR101
DAR201
DAR301
D100
D101
D102
D103
D104
src/meltano/__init__.py:
# Found `__init__.py` module with logic
WPS412
Expand Down
4 changes: 2 additions & 2 deletions src/meltano/cli/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import click

from meltano.core.db import DB, project_engine
from meltano.core.db import ensure_schema_exists, project_engine

from . import cli
from .params import pass_project
Expand All @@ -25,4 +25,4 @@ def schema():
def create(project, schema_name, roles):
"""Create system DB schema, if not exists."""
engine, _ = project_engine(project)
DB.ensure_schema_exists(engine, schema_name, grant_roles=roles)
ensure_schema_exists(engine, schema_name, grant_roles=roles)
7 changes: 3 additions & 4 deletions src/meltano/cli/select.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Extractor selection management CLI."""
from __future__ import annotations

from contextlib import closing

import click

from meltano.core.db import project_engine
Expand Down Expand Up @@ -111,11 +113,8 @@ async def show(project, extractor, show_all=False):
_, Session = project_engine(project) # noqa: N806
select_service = SelectService(project, extractor)

session = Session()
try:
with closing(Session()) as session:
list_all = await select_service.list_all(session)
finally:
session.close()

# legend
click.secho("Legend:")
Expand Down
8 changes: 3 additions & 5 deletions src/meltano/core/block/extract_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncio
import logging
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, closing
from typing import AsyncIterator

import structlog
Expand Down Expand Up @@ -459,11 +459,9 @@ async def run_with_job(self) -> None:
+ "To ignore this check use the '--force' option."
)

try: # noqa: WPS501
async with job.run(self.context.session):
with closing(self.context.session) as session:
async with job.run(session):
await self.execute()
finally:
self.context.session.close()

async def terminate(self, graceful: bool = False) -> None:
"""Terminate an in flight ExtractLoad execution, potentially disruptive.
Expand Down
152 changes: 102 additions & 50 deletions src/meltano/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,34 @@
import time

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import text

from meltano.core.project import Project

from .project_settings_service import ProjectSettingsService

# Keep a Project → Engine mapping to serve
# the same engine for the same Project
_engines = {}


def project_engine(project, default=False) -> tuple(Engine, sessionmaker):
"""Create and register a SQLAlchemy engine for a Meltano project instance."""
def project_engine(
project: Project,
default: bool = False,
) -> tuple[Engine, sessionmaker]:
"""Create and register a SQLAlchemy engine for a Meltano project instance.
Args:
project: The Meltano project that the engine will be connected to.
default: Whether the engine created should be stored as the default
engine for this project.
Returns:
The engine, and a session maker bound to the engine.
"""
existing_engine = _engines.get(project)
if existing_engine:
return existing_engine
Expand All @@ -30,7 +44,8 @@ def project_engine(project, default=False) -> tuple(Engine, sessionmaker):
logging.debug(f"Creating engine {project}@{engine_uri}")
engine = create_engine(engine_uri, pool_pre_ping=True)

check_db_connection(
# Connect to the database to ensure it is available.
connect(
engine,
max_retries=settings.get("database_max_retries"),
retry_timeout=settings.get("database_retry_timeout"),
Expand All @@ -47,64 +62,101 @@ def project_engine(project, default=False) -> tuple(Engine, sessionmaker):
return engine_session


def check_db_connection(engine, max_retries, retry_timeout): # noqa: WPS231
"""Check if the database is available the first time a project's engine is created."""
def connect(
engine: Engine,
max_retries: int,
retry_timeout: float,
) -> Connection:
"""Connect to the database.
Args:
engine: The DB engine with which the check will be performed.
max_retries: The maximum number of retries that will be attempted.
retry_timeout: The number of seconds to wait between retries.
Raises:
OperationalError: Error during DB connection - max retries exceeded.
Returns:
A connection to the database.
"""
attempt = 0
while True:
try: # noqa: WPS503
engine.connect()
try:
return engine.connect()
except OperationalError:
if attempt == max_retries:
if attempt >= max_retries:
logging.error(
"Could not connect to the Database. Max retries exceeded."
f"Could not connect to the database after {attempt} "
"attempts. Max retries exceeded."
)
raise
attempt += 1
logging.info(
f"DB connection failed. Will retry after {retry_timeout}s. Attempt {attempt}/{max_retries}"
f"DB connection failed. Will retry after {retry_timeout}s. "
f"Attempt {attempt}/{max_retries}"
)
time.sleep(retry_timeout)
else:
break


def init_hook(engine):
function_map = {"sqlite": init_sqlite_hook}
init_hooks = {
"sqlite": lambda x: x.execute("PRAGMA journal_mode=WAL"),
}


def init_hook(engine: Engine) -> None:
"""Run the initialization hook for the provided DB engine.
The initialization hooks are taken from the `meltano.core.db.init_hooks`
dictionary, which maps the dialect name of the engine to a unary function
which will be called with the provided DB engine.
Args:
engine: The engine for which the init hook will be run.
Raises:
Exception: The init hook raised an exception.
"""
try:
function_map[engine.dialect.name](engine)
hook = init_hooks[engine.dialect.name]
except KeyError:
pass
except Exception as e:
raise Exception(f"Can't initialize database: {str(e)}") from e


def init_sqlite_hook(engine):
# enable the WAL
engine.execute("PRAGMA journal_mode=WAL")


class DB:
@classmethod
def ensure_schema_exists(cls, engine, schema_name, grant_roles=()):
"""Ensure the given schema_name exists in the database."""
schema_identifier = schema_name
group_identifiers = ",".join(grant_roles)

create_schema = text(f"CREATE SCHEMA IF NOT EXISTS {schema_identifier}")
grant_select_schema = text(
f"ALTER DEFAULT PRIVILEGES IN SCHEMA {schema_identifier} GRANT SELECT ON TABLES TO {group_identifiers}"
)
grant_usage_schema = text(
f"GRANT USAGE ON SCHEMA {schema_identifier} TO {group_identifiers}"
)

with engine.connect() as conn, conn.begin():
conn.execute(create_schema)
if grant_roles:
conn.execute(grant_select_schema)
conn.execute(grant_usage_schema)

logging.info(f"Schema {schema_name} has been created successfully.")
for role in grant_roles:
logging.info(f"Usage has been granted for role: {role}.")
return

try:
hook(engine)
except Exception as ex:
raise Exception(f"Failed to initialize database: {ex!s}") from ex


def ensure_schema_exists(
engine: Engine,
schema_name: str,
grant_roles: tuple[str] = (),
) -> None:
"""Ensure the specified `schema_name` exists in the database.
Args:
engine: The DB engine to be used.
schema_name: The name of the schema.
grant_roles: Roles to grant to the specified schema.
"""
schema_identifier = schema_name
group_identifiers = ",".join(grant_roles)

create_schema = text(f"CREATE SCHEMA IF NOT EXISTS {schema_identifier}")
grant_select_schema = text(
f"ALTER DEFAULT PRIVILEGES IN SCHEMA {schema_identifier} GRANT SELECT ON TABLES TO {group_identifiers}"
)
grant_usage_schema = text(
f"GRANT USAGE ON SCHEMA {schema_identifier} TO {group_identifiers}"
)

with engine.connect() as conn, conn.begin():
conn.execute(create_schema)
if grant_roles:
conn.execute(grant_select_schema)
conn.execute(grant_usage_schema)

logging.info(f"Schema {schema_name} has been created successfully.")
for role in grant_roles:
logging.info(f"Usage has been granted for role: {role}.")
6 changes: 2 additions & 4 deletions src/meltano/core/migration_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import logging
from contextlib import closing

import click
import sqlalchemy
Expand Down Expand Up @@ -120,12 +121,9 @@ def seed(self, project: Project) -> None:
project: The project to seed the database for.
"""
_, session_maker = project_engine(project)
session = session_maker()
try: # noqa: WPS501, WPS229 Found too long try body length and finally without except
with closing(session_maker()) as session:
self._create_user_role(session)
session.commit()
finally:
session.close()

def _create_user_role(self, session: Session) -> None:
"""Actually perform the database seeding creating users/roles.
Expand Down
2 changes: 1 addition & 1 deletion src/meltano/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
interval: str | None = None,
start_date: datetime.datetime | None = None,
job: str | None = None,
env: dict | None = None,
env: dict[str, str] | None = None,
):
"""Initialize a Schedule.
Expand Down
24 changes: 16 additions & 8 deletions tests/fixtures/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import warnings
from contextlib import closing
from typing import Generator

import pytest
Expand All @@ -10,6 +11,8 @@
from sqlalchemy.exc import SAWarning
from sqlalchemy.orm import close_all_sessions, sessionmaker

from meltano.core.project import Project


@pytest.fixture(scope="session", autouse=True)
def engine_uri_env(engine_uri: str) -> Generator:
Expand Down Expand Up @@ -74,12 +77,17 @@ def connection(engine_sessionmaker): # noqa: WPS442


@pytest.fixture()
def session(project, engine_sessionmaker, connection): # noqa: WPS442
"""Create a new database session for a test."""
_, create_session = engine_sessionmaker
def session(project: Project, engine_sessionmaker, connection): # noqa: WPS442
"""Create a new database session for a test.
session = create_session(bind=connection) # noqa: WPS442
try:
yield session
finally:
session.close()
Args:
project: The `project` fixture.
engine_sessionmaker: The `engine_sessionmaker` fixture.
connection: The `connection` fixture.
Yields:
An ORM DB session for the given project, bound to the given connection.
"""
_, create_session = engine_sessionmaker
with closing(create_session(bind=connection)) as fixture_session:
yield fixture_session
6 changes: 3 additions & 3 deletions tests/meltano/core/test_db_reconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mock import Mock
from sqlalchemy.exc import OperationalError

from meltano.core.db import check_db_connection
from meltano.core.db import connect


class TestConnectionRetries:
Expand All @@ -16,7 +16,7 @@ def test_ping_failure(self):
"test_error", "test_error", "test_error"
)
with pytest.raises(OperationalError):
check_db_connection(engine=engine_mock, max_retries=3, retry_timeout=0.1)
connect(engine=engine_mock, max_retries=3, retry_timeout=0.1)

assert engine_mock.connect.call_count == 4

Expand All @@ -27,5 +27,5 @@ def test_ping_failure(self):
None,
]

check_db_connection(engine=engine_mock, max_retries=3, retry_timeout=0.1)
connect(engine=engine_mock, max_retries=3, retry_timeout=0.1)
assert engine_mock.connect.call_count == 2

0 comments on commit c416f28

Please sign in to comment.