diff --git a/src/simple_db_migrate/mysql.py b/src/simple_db_migrate/mysql.py index 7dc8069..87fa02d 100644 --- a/src/simple_db_migrate/mysql.py +++ b/src/simple_db_migrate/mysql.py @@ -21,6 +21,7 @@ def __init__(self, db_config_file="simple-db-migrate.conf", mysql_driver=MySQLdb self.__mysql_user__ = USERNAME self.__mysql_passwd__ = PASSWORD self.__mysql_db__ = DATABASE + self.__version_table = "__db_version__" if drop_db_first: self._drop_database() @@ -31,11 +32,13 @@ def __init__(self, db_config_file="simple-db-migrate.conf", mysql_driver=MySQLdb def __mysql_connect(self, connect_using_db_name=True): try: if connect_using_db_name: - return self.__mysql_driver.connect(host=self.__mysql_host__, user=self.__mysql_user__, passwd=self.__mysql_passwd__, db=self.__mysql_db__) + conn = self.__mysql_driver.connect(host=self.__mysql_host__, user=self.__mysql_user__, passwd=self.__mysql_passwd__, db=self.__mysql_db__) + conn.autocommit(True) + return conn return self.__mysql_driver.connect(host=self.__mysql_host__, user=self.__mysql_user__, passwd=self.__mysql_passwd__) - except Exception: - self.__cli.error_and_exit("could not connect to database") + except Exception, e: + self.__cli.error_and_exit("could not connect to database (%s)" % e) def __execute(self, sql): db = self.__mysql_connect() @@ -57,28 +60,28 @@ def _create_database_if_not_exists(self): def _create_version_table_if_not_exists(self): # create version table - sql = "create table if not exists __db_version__ ( version varchar(20) NOT NULL default \"0\" );" + sql = "create table if not exists %s ( version varchar(20) NOT NULL default \"0\" );" % self.__version_table self.__execute(sql) # check if there is a register there db = self.__mysql_connect() cursor = db.cursor() - cursor.execute("select count(*) from __db_version__;") + cursor.execute("select count(*) from %s;" % self.__version_table) count = cursor.fetchone()[0] db.close() # if there is not a version register, insert one if count == 0: - sql = "insert into __db_version__ values (0);" + sql = "insert into %s (version) values (\"0\");" % self.__version_table self.__execute(sql) def __change_db_version(self, version, up=True): if up: # moving up and storing history - sql = "insert into __db_version__ (version) values (\"%s\");" % str(version) + sql = "insert into %s (version) values (\"%s\");" % (self.__version_table, str(version)) else: # moving down and deleting from history - sql = "delete from __db_version__ where version >= \"%s\";" % str(version) + sql = "delete from %s where version >= \"%s\";" % (self.__version_table, str(version)) self.__execute(sql) def change(self, sql, new_db_version, up=True): @@ -88,7 +91,7 @@ def change(self, sql, new_db_version, up=True): def get_current_schema_version(self): db = self.__mysql_connect() cursor = db.cursor() - cursor.execute("select version from __db_version__ order by version desc limit 0,1;") + cursor.execute("select version from %s order by version desc limit 0,1;" % self.__version_table) version = cursor.fetchone()[0] db.close() return version @@ -97,7 +100,7 @@ def get_all_schema_versions(self): versions = [] db = self.__mysql_connect() cursor = db.cursor() - cursor.execute("select version from __db_version__ order by version;") + cursor.execute("select version from %s order by version;" % self.__version_table) all_versions = cursor.fetchall() for version in all_versions: versions.append(version[0])