diff --git a/src/db-migrate b/src/db-migrate index 4ea6cba..a96b38c 100755 --- a/src/db-migrate +++ b/src/db-migrate @@ -6,4 +6,4 @@ from simple_db_migrate.main import Main (options, args) = CLI().parse() # If CLI was correctly parsed, execute db-migrate. -Main(options, args).execute() +Main(options, args).execute() \ No newline at end of file diff --git a/src/simple_db_migrate/logging.py b/src/simple_db_migrate/logging.py index 6190a18..49a7ff9 100644 --- a/src/simple_db_migrate/logging.py +++ b/src/simple_db_migrate/logging.py @@ -6,6 +6,5 @@ def __print(self, level, msg): print "[%s] %s\n" % (level, msg) def error_and_exit(self, msg): - if not ENVIRONMENT == "TEST": - self.__print("ERROR", msg) - sys.exit(1) + self.__print("ERROR", msg) + sys.exit(1) \ No newline at end of file diff --git a/src/simple_db_migrate/main.py b/src/simple_db_migrate/main.py index 47be25d..1441a42 100644 --- a/src/simple_db_migrate/main.py +++ b/src/simple_db_migrate/main.py @@ -1,9 +1,13 @@ from core import SimpleDBMigrate +from mysql import MySQL +from logging import Log class Main(object): def __init__(self, options=None, args=None): self.__options = options self.__args = args + self.__mysql = MySQL(self.__options.db_config_file) + self.__db_migrate = SimpleDBMigrate(self.__options.migrations_dir) def execute(self): print "\nStarting DB migration..." @@ -14,23 +18,20 @@ def execute(self): print "\nDone.\n" def __create_migration(self): - db_migrate = SimpleDBMigrate(self.__options.migrations_dir) - new_file = db_migrate.create_migration(self.__options.create_migration) + new_file = self.__db_migrate.create_migration(self.__options.create_migration) print "- Created file '%s'" % (new_file) #TODO: too big -- needs refactor def __migrate(self): - db_migrate = SimpleDBMigrate(self.__options.migrations_dir) - destination_version = self.__options.schema_version if destination_version == None: - destination_version = db_migrate.latest_schema_version_available() + destination_version = self.__db_migrate.latest_schema_version_available() - if not db_migrate.check_if_version_exists(destination_version): + if not self.__db_migrate.check_if_version_exists(destination_version): Log().error_and_exit("version not found (%s)" % destination_version) - current_version = mysql.get_current_schema_version() + current_version = self.__mysql.get_current_schema_version() if str(current_version) == str(destination_version): Log().error_and_exit("current and destination versions are the same (%s)" % current_version) @@ -46,18 +47,18 @@ def __migrate(self): print "\nStarting migration %s!\n" % migration # getting only the migration sql files to be executed - migration_files_to_be_executed = db_migrate.get_migration_files_between_versions(current_version, destination_version) + migration_files_to_be_executed = self.__db_migrate.get_migration_files_between_versions(current_version, destination_version) sql_statements_executed = "" for sql_file in migration_files_to_be_executed: - file_version = db_migrate.get_migration_version(sql_file) + file_version = self.__db_migrate.get_migration_version(sql_file) if not migration_up: file_version = destination_version print "===== executing %s (%s) =====" % (sql_file, migration) - sql = db_migrate.get_sql_command(sql_file, migration_up) - mysql.change(sql, file_version) + sql = self.__db_migrate.get_sql_command(sql_file, migration_up) + self.__mysql.change(sql, file_version) #recording the last statement executed sql_statements_executed += sql diff --git a/src/simple_db_migrate/mysql.py b/src/simple_db_migrate/mysql.py index eb4ed70..a311cb9 100644 --- a/src/simple_db_migrate/mysql.py +++ b/src/simple_db_migrate/mysql.py @@ -59,18 +59,18 @@ def __create_version_table_if_not_exists(self): sql = "insert into __db_version__ values (0);" self.__execute(sql) - def __set_new_db_version(self, version): - sql = "update __db_version__ set version = \"%s\";" % str(version) + def __insert_new_db_version(self, version): + sql = "insert into __db_version__ (version) values (\"%s\");" % str(version) self.__execute(sql) def change(self, sql, new_db_version): self.__execute(sql) - self.__set_new_db_version(new_db_version) + self.__insert_new_db_version(new_db_version) def get_current_schema_version(self): db = self.__mysql_connect() cursor = db.cursor() - cursor.execute("select version from __db_version__;") + cursor.execute("select version from __db_version__ order by version desc limit 0,1;") version = cursor.fetchone()[0] db.close() return version diff --git a/tests/mysql_test.py b/tests/mysql_test.py index ecf6907..b31ae6f 100644 --- a/tests/mysql_test.py +++ b/tests/mysql_test.py @@ -56,7 +56,7 @@ def test_it_should_execute_changes_and_update_schema_version(self): db_mock.expects(once()).method("query").query(eq("create table spam();")) db_mock.expects(once()).method("close") - db_mock.expects(once()).method("query").query(eq("update __db_version__ set version = \"20090212112104\";")) + db_mock.expects(once()).method("query").query(eq("insert into __db_version__ (version) values (\"20090212112104\");")) db_mock.expects(once()).method("close") mysql = MySQL("test.conf", mysql_driver_mock) @@ -69,7 +69,7 @@ def test_it_should_get_current_schema_version(self): self.__create_init_expectations(mysql_driver_mock, db_mock, cursor_mock) - cursor_mock.expects(once()).method("execute").execute(eq("select version from __db_version__;")) + cursor_mock.expects(once()).method("execute").execute(eq("select version from __db_version__ order by version desc limit 0,1;")) cursor_mock.expects(once()).method("fetchone").will(return_value("0")) db_mock.expects(once()).method("close") diff --git a/tests/test.py b/tests/test.py index 49021ef..9e824d1 100644 --- a/tests/test.py +++ b/tests/test.py @@ -2,8 +2,6 @@ import sys import unittest -ENVIRONMENT = "TEST" - sys.path.insert(0, os.path.abspath("./src/simple_db_migrate")) sys.path.insert(0, os.path.abspath("./tests")) sys.path.insert(0, os.path.abspath("../src/simple_db_migrate")) @@ -28,7 +26,7 @@ alltests.run(result) if result.wasSuccessful(): - print "\nAll %d tests passed :)\n" % result.testsRun + print "\n*** All %d tests passed :) ***\n" % result.testsRun else: print "\nError in tests (%d runned, %d errors, %d failures)\n" % (result.testsRun, len(result.errors), len(result.failures))