Commit

Separate database names from database connections
Database connections are typically not serializable, so we cannot pass
them to subprocesses. Instead, establish the connection explicitly with the
new Backend.begin() method, only inside the worker process.
kedder committed Apr 24, 2020
1 parent 38557db commit 91c0726
Showing 4 changed files with 117 additions and 78 deletions.
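
To illustrate the problem the commit message describes, the snippet below (not part of the commit; sqlite3 is used only as an example driver) shows that a live DB-API connection cannot be pickled for a worker process, while the plain database name can:

import pickle
import sqlite3

db_name = "app.db"                 # a plain string: safe to send to a subprocess
conn = sqlite3.connect(db_name)    # a live connection object

pickle.dumps(db_name)              # works fine
try:
    pickle.dumps(conn)             # fails: connections are not picklable
except TypeError as err:
    print("cannot ship the connection to a worker:", err)

conn.close()

This is why the engine now passes only database names across the process boundary and lets each worker obtain its own connection via Backend.begin().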
70 changes: 53 additions & 17 deletions src/migrant/backend.py
@@ -3,7 +3,7 @@
# Copyright 2014 by Shoobx, Inc.
#
###############################################################################

from typing import List, Iterable, Generic, TypeVar
import logging
import pkg_resources

@@ -12,24 +12,22 @@
log = logging.getLogger(__name__)


class MigrantBackend:
"""Base interface for backend implementations"""
# Database name type
DBN = TypeVar("DBN")

def list_migrations(self, db):
raise NotImplementedError # pragma: no cover
# Database connection type
DBC = TypeVar("DBC")

def push_migration(self, db, migration):
raise NotImplementedError # pragma: no cover

def pop_migration(self, db, migration):
raise NotImplementedError # pragma: no cover
class MigrantBackend(Generic[DBN, DBC]):
"""Base interface for backend implementations"""

def generate_connections(self):
def generate_connections(self) -> Iterable[DBN]:
"""Generate connections to process
"""
raise NotImplementedError # pragma: no cover

def generate_test_connections(self):
def generate_test_connections(self) -> Iterable[DBN]:
"""Generate connections for migration tests
The same interface as `generate_connections`, however this should not
@@ -40,33 +38,71 @@ def generate_test_connections(self):
"""
return []  # pragma: no cover

def on_new_script(self, rev_name):
def begin(self, db: DBN) -> DBC:
"""Begin the migration
This gives the backend an opportunity to perform the actual connection to the
database before any migration is run.
"""
raise NotImplementedError # pragma: no cover

def commit(self, db: DBC) -> None:
"""Called on successful completion of a migration
This is an opportunity to commit work and close the connection.
"""
pass

def abort(self, db: DBC) -> None:
"""Called when migration have failed"""
pass

def cleanup(self, db: DBC) -> None:
"""Complete the migration
Called when all work on database is done.
"""
pass

def list_migrations(self, db: DBC) -> List[str]:
raise NotImplementedError # pragma: no cover

def push_migration(self, db: DBC, migration: str) -> None:
raise NotImplementedError # pragma: no cover

def pop_migration(self, db: DBC, migration: str) -> None:
raise NotImplementedError # pragma: no cover

def on_new_script(self, rev_name: str) -> None:
"""Called when new script is created
"""
pass # pragma: no cover

def on_repo_init(self):
def on_repo_init(self) -> None:
"""Called when new script repository is initialized
"""
pass # pragma: no cover


class NoopBackend(MigrantBackend):
class NoopBackend(MigrantBackend[str, str]):
def __init__(self, cfg):
self.cfg = cfg

def list_migrations(self, db):
def list_migrations(self, db: str) -> List[str]:
return ["INITIAL"]

def push_migration(self, db, migration):
def push_migration(self, db: str, migration: str) -> None:
log.info("NOOP: pushing migration %s" % migration)

def pop_migration(self, db, migration):
def pop_migration(self, db: str, migration: str) -> None:
log.info("NOOP: popping migration %s" % migration)

def generate_connections(self):
yield "NOOP"

def begin(self, db: str) -> str:
return db


def create_backend(cfg):
name = cfg["backend"]
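
For context, a concrete backend written against the new interface might look roughly like the hypothetical sketch below. It is not part of the commit: the sqlite3 driver, the migrant_db bookkeeping table, and the space-separated "databases" config key are all assumptions made for illustration.

import sqlite3
from typing import Iterable, List

from migrant.backend import MigrantBackend


class SqliteBackend(MigrantBackend[str, sqlite3.Connection]):
    """Hypothetical backend: DBN is a file name, DBC is an open connection."""

    def __init__(self, cfg):
        self.cfg = cfg

    def generate_connections(self) -> Iterable[str]:
        # Yield picklable database *names*; workers connect themselves.
        yield from self.cfg["databases"].split()

    def begin(self, db: str) -> sqlite3.Connection:
        conn = sqlite3.connect(db)
        conn.execute("CREATE TABLE IF NOT EXISTS migrant_db (migration TEXT)")
        return conn

    def commit(self, db: sqlite3.Connection) -> None:
        db.commit()

    def abort(self, db: sqlite3.Connection) -> None:
        db.rollback()

    def cleanup(self, db: sqlite3.Connection) -> None:
        db.close()

    def list_migrations(self, db: sqlite3.Connection) -> List[str]:
        rows = db.execute("SELECT migration FROM migrant_db")
        return [name for (name,) in rows]

    def push_migration(self, db: sqlite3.Connection, migration: str) -> None:
        db.execute("INSERT INTO migrant_db (migration) VALUES (?)", (migration,))

    def pop_migration(self, db: sqlite3.Connection, migration: str) -> None:
        db.execute("DELETE FROM migrant_db WHERE migration = ?", (migration,))

NoopBackend above follows the same shape with DBN == DBC == str, which is why its begin() can simply return the name unchanged.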
116 changes: 56 additions & 60 deletions src/migrant/engine.py
@@ -3,7 +3,7 @@
# Copyright 2014 by Shoobx, Inc.
#
###############################################################################
from typing import NewType, Dict, Iterable, List, Tuple, Generator
from typing import TypeVar, Dict, List, Tuple, Generic
import logging
import multiprocessing
import functools
@@ -15,18 +15,16 @@

log = logging.getLogger(__name__)

DB = NewType("DB", object)
Actions = List[Tuple[str, str]]

DBN = TypeVar("DBN")
DBC = TypeVar("DBC")

# Semaphore to limit consumption of generated database connections
bottleneck: multiprocessing.Semaphore


class MigrantEngine:
class MigrantEngine(Generic[DBN, DBC]):
def __init__(
self,
backend: MigrantBackend,
backend: MigrantBackend[DBN, DBC],
repository: Repository,
config: Dict[str, str],
dry_run: bool = False,
@@ -46,80 +44,72 @@ def status(self, target_id: str = None) -> int:
conns = self.backend.generate_connections()

total_actions = 0
for db in self.initialized_dbs(conns):
actions = self.calc_actions(db, target_id)
for db in conns:
cdb = self.initialized_db(db)
actions = self.calc_actions(cdb, target_id)
total_actions += len(actions)

return total_actions

def _update(self, db: DB, target_id: str) -> None:
log.info(f"Starting migration for {db}")
actions = self.calc_actions(db, target_id)
self.execute_actions(db, actions)
log.info(f"Migration completed for {db}")
def _update(self, db: DBN, target_id: str) -> None:
cdb = self.initialized_db(db)
log.info(f"{_pname()}: Starting migration for {cdb}")
actions = self.calc_actions(cdb, target_id)
try:
self.execute_actions(cdb, actions)
self.backend.commit(cdb)
except:
self.backend.abort(cdb)
raise
finally:
self.backend.cleanup(cdb)
log.info(f"{_pname()}: Migration completed for {cdb}")

def update(self, target_id: str = None) -> None:
global bottleneck
target_id = self.pick_rev_id(target_id)
conns = self.backend.generate_connections()

f = functools.partial(self._update, target_id=target_id)

bottleneck = multiprocessing.Semaphore(self.processes)

with multiprocessing.Pool(self.processes) as pool:
# call all jobs by materialize result generator
results = []
for conn in self.initialized_dbs(conns):
# We want to limit consumption of generated database connection
# (conns) so that the generator is not consumed instantly, but
# only when free pool workers are available for processing. We
# acquire the lock here and release it when job is finished in
# self._finish.
bottleneck.acquire()
res = pool.apply_async(
f, (conn,), callback=self._finished, error_callback=self._finished
)
results.append(res)

# Wait for all results to complete
for res in results:
res.get()

def _finished(self, res) -> None:
global bottleneck
bottleneck.release()
if self.processes == 1:
for conn in conns:
f(conn)
else:
with multiprocessing.Pool(self.processes) as pool:
for _ in pool.imap_unordered(f, conns):
pass
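
Stripped of migrant specifics, the fan-out pattern the reworked update() relies on can be reduced to the standalone sketch below (migrate_one and the sqlite3 usage are illustrative assumptions, not migrant API): only the picklable name crosses the process boundary, and each worker opens its own connection.

import functools
import multiprocessing
import sqlite3


def migrate_one(name: str, target_id: str) -> None:
    # Runs inside the worker process: the connection is created here,
    # so it never has to cross the process boundary.
    conn = sqlite3.connect(name)
    try:
        # ... apply migration scripts up/down until target_id is reached ...
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()


if __name__ == "__main__":
    names = ["db1.sqlite", "db2.sqlite", "db3.sqlite"]
    work = functools.partial(migrate_one, target_id="0003")
    with multiprocessing.Pool(processes=2) as pool:
        # imap_unordered consumes the names lazily, so each one is only
        # taken off the generator when a worker is free to handle it.
        for _ in pool.imap_unordered(work, names):
            pass

The processes == 1 branch in update() above sidesteps the pool entirely for the single-process case.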

def test(self, target_id: str = None) -> None:
target_id = self.pick_rev_id(target_id)
conns = self.backend.generate_test_connections()

for db in self.initialized_dbs(conns):
actions = self.calc_actions(db, target_id)
for db in conns:
cdb = self.initialized_db(db)
actions = self.calc_actions(cdb, target_id)

# Perform 2 passes of up/down to make sure database is still
# upgradeable after being downgraded.
for testpass in range(1, 3):
log.info(f"PASS {testpass}. Testing upgrade for {db}")
self.execute_actions(db, actions, strict=True)
log.info(f"PASS {testpass}. Testing upgrade for {cdb}")
self.execute_actions(cdb, actions, strict=True)

log.info(f"PASS {testpass}. Testing downgrade for {db}")
log.info(f"PASS {testpass}. Testing downgrade for {cdb}")
reverted_actions = self.revert_actions(actions)
self.execute_actions(db, reverted_actions, strict=True)
self.execute_actions(cdb, reverted_actions, strict=True)

log.info("Testing completed for %s" % db)
log.info("Testing completed for %s" % cdb)

def initialized_dbs(self, conns: Iterable[DB]) -> Generator[DB, None, None]:
for db in conns:
log.info("Preparing migrations for %s" % db)
migrations = self.backend.list_migrations(db)
if not migrations:
latest_revid = self.pick_rev_id(None)
self.initialize_db(db, latest_revid)
continue
def initialized_db(self, db: DBN) -> DBC:
log.info(f"{_pname()}: Preparing migrations for {db}")
cdb = self.backend.begin(db)
migrations = self.backend.list_migrations(cdb)
if not migrations:
latest_revid = self.pick_rev_id(None)
self.initialize_db(cdb, latest_revid)

yield db
return cdb

def initialize_db(self, db: DB, initial_revid: str) -> None:
def initialize_db(self, db: DBC, initial_revid: str):
"""Iniitialize database that was never migrated before
Assume it is fully up-to-date.
@@ -133,7 +123,9 @@ def initialize_db(self, db: DB, initial_revid: str) -> None:
sid = script.name
self.backend.push_migration(db, sid)

log.info(f"Initialized migrations for {db}. Assuming database is at {sid}")
log.info(
f"{_pname()}: Initialized migrations for {db}. Assuming database is at {sid}"
)

def pick_rev_id(self, rev_id: str = None) -> str:
if rev_id is None:
@@ -145,7 +137,7 @@ def pick_rev_id(self, rev_id: str = None) -> str:

return rev_id

def calc_actions(self, db: DB, target_revid: str) -> Actions:
def calc_actions(self, db: DBC, target_revid: str) -> Actions:
"""Caclulate actions, required to update to revision `target_revid`
"""
target_revid = canonical_rev_id(target_revid)
@@ -184,10 +176,10 @@ def revert_actions(self, actions: Actions) -> Actions:
reverts = [("+" if a == "-" else "-", script) for a, script in actions]
return list(reversed(reverts))

def list_backend_migrations(self, db: DB) -> List[str]:
def list_backend_migrations(self, db: DBC) -> List[str]:
return [canonical_rev_id(revid) for revid in self.backend.list_migrations(db)]

def execute_actions(self, db: DB, actions: Actions, strict: bool = False) -> None:
def execute_actions(self, db: DBC, actions: Actions, strict: bool = False) -> None:
for action, revid in actions:
script = self.repository.load_script(revid)
assert action in ("+", "-")
@@ -218,6 +210,10 @@ def execute_actions(self, db: DB, actions: Actions, strict: bool = False) -> None:
end(db, script.name)


def _pname() -> str:
return multiprocessing.current_process().name


def canonical_rev_id(migration_name: str) -> str:
if "_" in migration_name:
return migration_name.split("_", 1)[0]
3 changes: 3 additions & 0 deletions src/migrant/tests/test_cli.py
@@ -60,6 +60,9 @@ def __init__(self, dbs, manager=None):
self.new_scripts = manager.list() if manager else []
self.inits = 0

def begin(self, db):
return db

def list_migrations(self, db):
return list(db.migrations)

6 changes: 5 additions & 1 deletion src/migrant/tests/test_engine.py
@@ -157,6 +157,7 @@ def _make_engine(migrations, scripts, log=None):
log = log if log is not None else []
backend = mock.Mock()
backend.list_migrations.return_value = migrations
backend.begin = lambda db: db

backend.generate_test_connections.return_value = ["db1", "db2"]

@@ -169,14 +170,17 @@ def _make_engine(migrations, scripts, log=None):
return engine


class MultiDbBackend(MigrantBackend):
class MultiDbBackend(MigrantBackend[str, str]):
def __init__(self, dbs: List[str], logfname: str) -> None:
self._applied = {}
self.dbs = dbs
self.logfname = logfname
for db in dbs:
self._applied[db] = ["INITIAL"]

def begin(self, db: str) -> str:
return db

def list_migrations(self, db: str) -> List[str]:
return self._applied.get(db, [])
