diff --git a/src/migrant/engine.py b/src/migrant/engine.py index 34f81d1..3e5e4b5 100644 --- a/src/migrant/engine.py +++ b/src/migrant/engine.py @@ -5,6 +5,7 @@ ############################################################################### import logging import multiprocessing +import functools log = logging.getLogger(__name__) @@ -42,8 +43,10 @@ def _update(self, db, target_id=None): def update(self, target_id=None): target_id = self.pick_rev_id(target_id) conns = self.backend.generate_connections() + f = functools.partial(self._update, target_id=target_id) + with multiprocessing.Pool() as pool: - pool.imap_unordered(self._update, ((db, target_id) for db in self.initialized_dbs(conns))) + pool.map(f, self.initialized_dbs(conns)) def test(self, target_id=None): target_id = self.pick_rev_id(target_id) diff --git a/src/migrant/tests/test_cli.py b/src/migrant/tests/test_cli.py index f71a9cb..ef8dfdd 100644 --- a/src/migrant/tests/test_cli.py +++ b/src/migrant/tests/test_cli.py @@ -12,6 +12,7 @@ import logging import pytest from configparser import SafeConfigParser +import multiprocessing from migrant import cli, backend, exceptions @@ -47,20 +48,20 @@ class MockedDb: - def __init__(self, name): + def __init__(self, name, manager): self.name = name - self.migrations = [] - self.data = {} + self.migrations = manager.list() + self.data = manager.dict() class MockedBackend(backend.MigrantBackend): - def __init__(self, dbs): + def __init__(self, dbs, manager=None): self.dbs = dbs - self.new_scripts = [] + self.new_scripts = manager.list() if manager else [] self.inits = 0 def list_migrations(self, db): - return db.migrations + return list(db.migrations) def push_migration(self, db, migration): db.migrations.append(migration) @@ -103,10 +104,12 @@ def test_get_db_config(self): class UpgradeTest(unittest.TestCase): + @pytest.fixture(autouse=True) def backend_fixture(self, tmpdir, migrant_backend): - self.db0 = MockedDb("db0") - self.backend = MockedBackend([self.db0]) + m = multiprocessing.Manager() + self.db0 = MockedDb("db0", m) + self.backend = MockedBackend([self.db0], m) migrant_backend.set(self.backend) migrant_ini = tmpdir.join("migrant.ini") @@ -122,7 +125,7 @@ def backend_fixture(self, tmpdir, migrant_backend): logging.root.setLevel(logging.INFO) def test_status_dirty(self): - self.db0.migrations = ["aaaa_first"] + self.db0.migrations.extend(["aaaa_first"]) args = cli.parser.parse_args(["test", "status"]) cli.dispatch(args, self.cfg) @@ -131,7 +134,7 @@ def test_status_dirty(self): self.assertIn("Pending actions: 2", log) def test_status_clean(self): - self.db0.migrations = ["aaaa_first", "bbbb_second", "cccc_third"] + self.db0.migrations.extend(["aaaa_first", "bbbb_second", "cccc_third"]) args = cli.parser.parse_args(["test", "status"]) cli.dispatch(args, self.cfg) @@ -142,14 +145,14 @@ def test_status_clean(self): def test_no_scripts(self): args = cli.parser.parse_args(["virgin", "upgrade"]) cli.dispatch(args, self.cfg) - self.assertEqual(self.db0.migrations, ["INITIAL"]) + self.assertEqual(list(self.db0.migrations), ["INITIAL"]) def test_initial_upgrade(self): args = cli.parser.parse_args(["test", "upgrade"]) cli.dispatch(args, self.cfg) self.assertEqual( - self.db0.migrations, ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"] + list(self.db0.migrations), ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"] ) def test_subsequent_emtpy_upgrade(self): @@ -159,58 +162,57 @@ def test_subsequent_emtpy_upgrade(self): cli.dispatch(args, self.cfg) self.assertEqual( - self.db0.migrations, ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"] + list(self.db0.migrations), ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"] ) def test_upgrade_latest(self): - self.db0.migrations = ["aaaa_first"] + self.db0.migrations.extend(["aaaa_first"]) args = cli.parser.parse_args(["test", "upgrade"]) cli.dispatch(args, self.cfg) self.assertEqual( - self.db0.migrations, ["aaaa_first", "bbbb_second", "cccc_third"] + list(self.db0.migrations), ["aaaa_first", "bbbb_second", "cccc_third"] ) - self.assertEqual(self.db0.data, {"hello": "world", "value": "c"}) + self.assertEqual(dict(self.db0.data), {"hello": "world", "value": "c"}) def test_upgrade_particular(self): - self.db0.migrations = ["aaaa_first"] - + self.db0.migrations.extend(["aaaa_first"]) args = cli.parser.parse_args(["test", "upgrade", "--revision", "bbbb_second"]) cli.dispatch(args, self.cfg) - self.assertEqual(self.db0.migrations, ["aaaa_first", "bbbb_second"]) - self.assertEqual(self.db0.data, {"value": "b"}) + self.assertEqual(list(self.db0.migrations), ["aaaa_first", "bbbb_second"]) + self.assertEqual(dict(self.db0.data), {"value": "b"}) def test_downgrade(self): - self.db0.migrations = ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"] - self.db0.data = {"hello": "world", "value": "c"} + self.db0.migrations.extend(["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]) + self.db0.data.update({"hello": "world", "value": "c"}) args = cli.parser.parse_args(["test", "upgrade", "--revision", "aaaa_first"]) cli.dispatch(args, self.cfg) - self.assertEqual(self.db0.migrations, ["INITIAL", "aaaa_first"]) - self.assertEqual(self.db0.data, {"value": "a"}) + self.assertEqual(list(self.db0.migrations), ["INITIAL", "aaaa_first"]) + self.assertEqual(dict(self.db0.data), {"value": "a"}) def test_downgrade_to_initial(self): - self.db0.migrations = ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"] - self.db0.data = {"hello": "world", "value": "c"} + self.db0.migrations.extend(["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]) + self.db0.data.update({"hello": "world", "value": "c"}) args = cli.parser.parse_args(["test", "upgrade", "--revision", "INITIAL"]) cli.dispatch(args, self.cfg) - self.assertEqual(self.db0.migrations, ["INITIAL"]) - self.assertEqual(self.db0.data, {}) + self.assertEqual(list(self.db0.migrations), ["INITIAL"]) + self.assertEqual(dict(self.db0.data), {}) def test_dry_run_upgrade(self): - self.db0.migrations = ["aaaa_first"] - self.db0.data = {"value": "a"} + self.db0.migrations.extend(["aaaa_first"]) + self.db0.data.update({"value": "a"}) args = cli.parser.parse_args(["test", "upgrade", "--dry-run"]) cli.dispatch(args, self.cfg) - self.assertEqual(self.db0.migrations, ["aaaa_first"]) - self.assertEqual(self.db0.data, {"value": "a"}) + self.assertEqual(list(self.db0.migrations), ["aaaa_first"]) + self.assertEqual(dict(self.db0.data), {"value": "a"}) def test_test(self): self.db0.migrations = ["INITIAL", "aaaa_first", "bbbb_second", "cccc_third"]