From faf3108e433b3131ed88071cd6dc1a6f5bf93440 Mon Sep 17 00:00:00 2001 From: Guilherme Chapiewski Date: Wed, 11 Mar 2009 18:34:38 -0300 Subject: [PATCH] #15: 2 of 4 - Remove from migrations history on migration DOWN. --- src/simple_db_migrate/main.py | 2 +- src/simple_db_migrate/mysql.py | 13 +++++++++---- tests/mysql_test.py | 25 ++++++++++++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/simple_db_migrate/main.py b/src/simple_db_migrate/main.py index 1441a42..aa28538 100644 --- a/src/simple_db_migrate/main.py +++ b/src/simple_db_migrate/main.py @@ -58,7 +58,7 @@ def __migrate(self): print "===== executing %s (%s) =====" % (sql_file, migration) sql = self.__db_migrate.get_sql_command(sql_file, migration_up) - self.__mysql.change(sql, file_version) + self.__mysql.change(sql, file_version, migration_up) #recording the last statement executed sql_statements_executed += sql diff --git a/src/simple_db_migrate/mysql.py b/src/simple_db_migrate/mysql.py index a311cb9..63aa1e0 100644 --- a/src/simple_db_migrate/mysql.py +++ b/src/simple_db_migrate/mysql.py @@ -59,13 +59,18 @@ def __create_version_table_if_not_exists(self): sql = "insert into __db_version__ values (0);" self.__execute(sql) - def __insert_new_db_version(self, version): - sql = "insert into __db_version__ (version) values (\"%s\");" % str(version) + def __change_db_version(self, version, up=True): + if up: + # moving up and storing history + sql = "insert into __db_version__ (version) values (\"%s\");" % str(version) + else: + # moving down and deleting from history + sql = "delete from __db_version__ where version > \"%s\";" % str(version) self.__execute(sql) - def change(self, sql, new_db_version): + def change(self, sql, new_db_version, up=True): self.__execute(sql) - self.__insert_new_db_version(new_db_version) + self.__change_db_version(new_db_version, up) def get_current_schema_version(self): db = self.__mysql_connect() diff --git a/tests/mysql_test.py b/tests/mysql_test.py index b31ae6f..60a06f1 100644 --- a/tests/mysql_test.py +++ b/tests/mysql_test.py @@ -47,21 +47,36 @@ def test_it_should_create_database_and_version_table_on_init_if_not_exists(self) mysql = MySQL("test.conf", mysql_driver_mock) - def test_it_should_execute_changes_and_update_schema_version(self): + def test_it_should_execute_migration_up_and_remove_from_schema_version(self): mysql_driver_mock = Mock() db_mock = Mock() cursor_mock = Mock() - + self.__create_init_expectations(mysql_driver_mock, db_mock, cursor_mock) - + db_mock.expects(once()).method("query").query(eq("create table spam();")) db_mock.expects(once()).method("close") db_mock.expects(once()).method("query").query(eq("insert into __db_version__ (version) values (\"20090212112104\");")) db_mock.expects(once()).method("close") - + mysql = MySQL("test.conf", mysql_driver_mock) mysql.change("create table spam();", "20090212112104") - + + def test_it_should_execute_migration_down_and_update_schema_version(self): + mysql_driver_mock = Mock() + db_mock = Mock() + cursor_mock = Mock() + + self.__create_init_expectations(mysql_driver_mock, db_mock, cursor_mock) + + db_mock.expects(once()).method("query").query(eq("create table spam();")) + db_mock.expects(once()).method("close") + db_mock.expects(once()).method("query").query(eq("delete from __db_version__ where version > \"20090212112104\";")) + db_mock.expects(once()).method("close") + + mysql = MySQL("test.conf", mysql_driver_mock) + mysql.change("create table spam();", "20090212112104", False) + def test_it_should_get_current_schema_version(self): mysql_driver_mock = Mock() db_mock = Mock()