Skip to content

Commit

Permalink
make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jweede authored and nerdling committed Mar 13, 2020
1 parent e4840ef commit 644bd71
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
5 changes: 4 additions & 1 deletion src/migrant/engine.py
Expand Up @@ -5,6 +5,7 @@
###############################################################################
import logging
import multiprocessing
import functools

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,8 +43,10 @@ def _update(self, db, target_id=None):
def update(self, target_id=None):
target_id = self.pick_rev_id(target_id)
conns = self.backend.generate_connections()
f = functools.partial(self._update, target_id=target_id)

with multiprocessing.Pool() as pool:
pool.imap_unordered(self._update, ((db, target_id) for db in self.initialized_dbs(conns)))
pool.map(f, self.initialized_dbs(conns))

def test(self, target_id=None):
target_id = self.pick_rev_id(target_id)
Expand Down
66 changes: 34 additions & 32 deletions src/migrant/tests/test_cli.py
Expand Up @@ -12,6 +12,7 @@
import logging
import pytest
from configparser import SafeConfigParser
import multiprocessing

from migrant import cli, backend, exceptions

Expand Down Expand Up @@ -47,20 +48,20 @@


class MockedDb:
def __init__(self, name):
def __init__(self, name, manager):
self.name = name
self.migrations = []
self.data = {}
self.migrations = manager.list()
self.data = manager.dict()


class MockedBackend(backend.MigrantBackend):
def __init__(self, dbs):
def __init__(self, dbs, manager=None):
self.dbs = dbs
self.new_scripts = []
self.new_scripts = manager.list() if manager else []
self.inits = 0

def list_migrations(self, db):
return db.migrations
return list(db.migrations)

def push_migration(self, db, migration):
db.migrations.append(migration)
Expand Down Expand Up @@ -103,10 +104,12 @@ def test_get_db_config(self):


class UpgradeTest(unittest.TestCase):

@pytest.fixture(autouse=True)
def backend_fixture(self, tmpdir, migrant_backend):
self.db0 = MockedDb("db0")
self.backend = MockedBackend([self.db0])
m = multiprocessing.Manager()
self.db0 = MockedDb("db0", m)
self.backend = MockedBackend([self.db0], m)
migrant_backend.set(self.backend)

migrant_ini = tmpdir.join("migrant.ini")
Expand All @@ -122,7 +125,7 @@ def backend_fixture(self, tmpdir, migrant_backend):
logging.root.setLevel(logging.INFO)

def test_status_dirty(self):
self.db0.migrations = ["aaaa_first"]
self.db0.migrations.extend(["aaaa_first"])

args = cli.parser.parse_args(["test", "status"])
cli.dispatch(args, self.cfg)
Expand All @@ -131,7 +134,7 @@ def test_status_dirty(self):
self.assertIn("Pending actions: 2", log)

def test_status_clean(self):
self.db0.migrations = ["aaaa_first", "bbbb_second", "cccc_third"]
self.db0.migrations.extend(["aaaa_first", "bbbb_second", "cccc_third"])

args = cli.parser.parse_args(["test", "status"])
cli.dispatch(args, self.cfg)
Expand All @@ -142,14 +145,14 @@ def test_status_clean(self):
def test_no_scripts(self):
args = cli.parser.parse_args(["virgin", "upgrade"])
cli.dispatch(args, self.cfg)
self.assertEqual(self.db0.migrations, ["INITIAL"])
self.assertEqual(list(self.db0.migrations), ["INITIAL"])

def test_initial_upgrade(self):
args = cli.parser.parse_args(["test", "upgrade"])
cli.dispatch(args, self.cfg)

self.assertEqual(
self.db0.migrations, ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
list(self.db0.migrations), ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
)

def test_subsequent_emtpy_upgrade(self):
Expand All @@ -159,58 +162,57 @@ def test_subsequent_emtpy_upgrade(self):
cli.dispatch(args, self.cfg)

self.assertEqual(
self.db0.migrations, ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
list(self.db0.migrations), ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
)

def test_upgrade_latest(self):
self.db0.migrations = ["aaaa_first"]
self.db0.migrations.extend(["aaaa_first"])

args = cli.parser.parse_args(["test", "upgrade"])
cli.dispatch(args, self.cfg)

self.assertEqual(
self.db0.migrations, ["aaaa_first", "bbbb_second", "cccc_third"]
list(self.db0.migrations), ["aaaa_first", "bbbb_second", "cccc_third"]
)
self.assertEqual(self.db0.data, {"hello": "world", "value": "c"})
self.assertEqual(dict(self.db0.data), {"hello": "world", "value": "c"})

def test_upgrade_particular(self):
self.db0.migrations = ["aaaa_first"]

self.db0.migrations.extend(["aaaa_first"])
args = cli.parser.parse_args(["test", "upgrade", "--revision", "bbbb_second"])
cli.dispatch(args, self.cfg)

self.assertEqual(self.db0.migrations, ["aaaa_first", "bbbb_second"])
self.assertEqual(self.db0.data, {"value": "b"})
self.assertEqual(list(self.db0.migrations), ["aaaa_first", "bbbb_second"])
self.assertEqual(dict(self.db0.data), {"value": "b"})

def test_downgrade(self):
self.db0.migrations = ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
self.db0.data = {"hello": "world", "value": "c"}
self.db0.migrations.extend(["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"])
self.db0.data.update({"hello": "world", "value": "c"})

args = cli.parser.parse_args(["test", "upgrade", "--revision", "aaaa_first"])
cli.dispatch(args, self.cfg)

self.assertEqual(self.db0.migrations, ["INITIAL", "aaaa_first"])
self.assertEqual(self.db0.data, {"value": "a"})
self.assertEqual(list(self.db0.migrations), ["INITIAL", "aaaa_first"])
self.assertEqual(dict(self.db0.data), {"value": "a"})

def test_downgrade_to_initial(self):
self.db0.migrations = ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
self.db0.data = {"hello": "world", "value": "c"}
self.db0.migrations.extend(["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"])
self.db0.data.update({"hello": "world", "value": "c"})

args = cli.parser.parse_args(["test", "upgrade", "--revision", "INITIAL"])
cli.dispatch(args, self.cfg)

self.assertEqual(self.db0.migrations, ["INITIAL"])
self.assertEqual(self.db0.data, {})
self.assertEqual(list(self.db0.migrations), ["INITIAL"])
self.assertEqual(dict(self.db0.data), {})

def test_dry_run_upgrade(self):
self.db0.migrations = ["aaaa_first"]
self.db0.data = {"value": "a"}
self.db0.migrations.extend(["aaaa_first"])
self.db0.data.update({"value": "a"})

args = cli.parser.parse_args(["test", "upgrade", "--dry-run"])
cli.dispatch(args, self.cfg)

self.assertEqual(self.db0.migrations, ["aaaa_first"])
self.assertEqual(self.db0.data, {"value": "a"})
self.assertEqual(list(self.db0.migrations), ["aaaa_first"])
self.assertEqual(dict(self.db0.data), {"value": "a"})

def test_test(self):
self.db0.migrations = ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]
Expand Down

0 comments on commit 644bd71

Please sign in to comment.