Skip to content

Commit

Permalink
#39: Fix to run both mysql dumps and regular sql in the same way.
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermechapiewski committed Mar 24, 2009
1 parent ad217e1 commit f37356a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
32 changes: 17 additions & 15 deletions src/simple_db_migrate/mysql.py
Expand Up @@ -17,10 +17,10 @@ def __init__(self, db_config_file="simple-db-migrate.conf", mysql_driver=MySQLdb
f.close()

self.__mysql_driver = mysql_driver
self.__mysql_host__ = HOST
self.__mysql_user__ = USERNAME
self.__mysql_passwd__ = PASSWORD
self.__mysql_db__ = DATABASE
self.__mysql_host = HOST
self.__mysql_user = USERNAME
self.__mysql_passwd = PASSWORD
self.__mysql_db = DATABASE
self.__version_table = "__db_version__"

if drop_db_first:
Expand All @@ -31,37 +31,39 @@ def __init__(self, db_config_file="simple-db-migrate.conf", mysql_driver=MySQLdb

def __mysql_connect(self, connect_using_db_name=True):
try:
conn = self.__mysql_driver.connect(host=self.__mysql_host, user=self.__mysql_user, passwd=self.__mysql_passwd)
if connect_using_db_name:
return self.__mysql_driver.connect(host=self.__mysql_host__, user=self.__mysql_user__, passwd=self.__mysql_passwd__, db=self.__mysql_db__)

return self.__mysql_driver.connect(host=self.__mysql_host__, user=self.__mysql_user__, passwd=self.__mysql_passwd__)
conn.select_db(self.__mysql_db)
return conn
except Exception, e:
self.__cli.error_and_exit("could not connect to database (%s)" % e)

def __execute(self, sql):
db = self.__mysql_connect()
db = self.__mysql_connect()
cursor = db.cursor()
#cursor._defer_warnings = True
try:
cursor.execute(sql)
sql_statements = sql.split(";")
sql_statements = [s.strip() for s in sql_statements if s.strip() != ""]
for statement in sql_statements:
cursor.execute(statement)
cursor.close()
db.commit()
db.close()
except Exception, e:
db.rollback()
db.close()
self.__cli.error_and_exit("error executing migration (%s)" % e)

def _drop_database(self):
db = self.__mysql_connect(False)
try:
db.query("drop database %s;" % self.__mysql_db__)
db.query("drop database %s;" % self.__mysql_db)
except Exception, e:
self.__cli.error_and_exit("can't drop database '%s'; database doesn't exist" % self.__mysql_db__)
self.__cli.error_and_exit("can't drop database '%s'; database doesn't exist" % self.__mysql_db)
db.close()

def _create_database_if_not_exists(self):
db = self.__mysql_connect(False)
db.query("create database if not exists %s;" % self.__mysql_db__)
db.query("create database if not exists %s;" % self.__mysql_db)
db.close()

def _create_version_table_if_not_exists(self):
Expand Down
20 changes: 11 additions & 9 deletions tests/mysql_test.py
Expand Up @@ -16,18 +16,15 @@ def tearDown(self):

def __mock_db_init(self, mysql_driver_mock, db_mock, cursor_mock):
mysql_driver_mock.expects(at_least_once()).method("connect").will(return_value(db_mock))
db_mock.expects(at_least_once()).method("select_db")

# create db if not exists
db_mock.expects(once()).method("query").query(eq("create database if not exists migration_test;"))
db_mock.expects(once()).method("close")

# create version table if not exists
create_version_table = "create table if not exists __db_version__ ( version varchar(20) NOT NULL default \"0\" );"
db_mock.expects(once()).method("cursor").will(return_value(cursor_mock))
cursor_mock.expects(once()).method("execute").execute(eq(create_version_table))
cursor_mock.expects(once()).method("close")
db_mock.expects(once()).method("commit")
db_mock.expects(once()).method("close")
self.__mock_db_execute(db_mock, cursor_mock, create_version_table)

# check if exists any version
db_mock.expects(once()).method("cursor").will(return_value(cursor_mock))
Expand All @@ -38,7 +35,12 @@ def __mock_db_init(self, mysql_driver_mock, db_mock, cursor_mock):
def __mock_db_execute(self, db_mock, cursor_mock, query):
# mock a call to __execute
db_mock.expects(once()).method("cursor").will(return_value(cursor_mock))
cursor_mock.expects(once()).method("execute").execute(eq(query))

sql_statements = query.split(";")
sql_statements = [s.strip() for s in sql_statements if s.strip() != ""]
for statement in sql_statements:
cursor_mock.expects(once()).method("execute").execute(eq(statement))

cursor_mock.expects(once()).method("close")
db_mock.expects(once()).method("commit")
db_mock.expects(once()).method("close")
Expand Down Expand Up @@ -87,7 +89,7 @@ def test_it_should_execute_migration_down_and_update_schema_version(self):

mysql = MySQL("test.conf", mysql_driver_mock)
mysql.change("create table spam();", "20090212112104", False)

def test_it_should_get_current_schema_version(self):
mysql_driver_mock = Mock()
db_mock = Mock()
Expand All @@ -102,7 +104,7 @@ def test_it_should_get_current_schema_version(self):

mysql = MySQL("test.conf", mysql_driver_mock)
self.assertEquals("0", mysql.get_current_schema_version())

def test_it_should_get_all_schema_versions(self):
mysql_driver_mock = Mock()
db_mock = Mock()
Expand All @@ -127,6 +129,6 @@ def test_it_should_get_all_schema_versions(self):
self.assertEquals(len(expected_versions), len(schema_versions))
for version in schema_versions:
self.assertTrue(version in expected_versions)

if __name__ == "__main__":
unittest.main()

0 comments on commit f37356a

Please sign in to comment.