Skip to content
Browse files

Add a generic database wrapper, supporting different database backend…

…s, to be used by migrate, scheduler, parser (eventually), and maybe others. This will consolidate the multiple database wrappers we have throughout the code and allow us to swap in SQLite for MySQL for unit testing purposes.

-add database/ directory for database libraries.  migrate.py will move here soon.
-add database_connection.py under server_common, a basic database wrapper supporting both MySQL and SQLite.  PostgreSQL should be an easy future addition (any library supporting Python DB-API should be trivial to add).  DatabaseConnection also supports graceful handling of dropped connections.
-add unittest for DatabaseConnection
-change migrate.py to use common DatabaseConnection. Scheduler will be changed to use it in a coming CL and in the future hopefully the TKO parser will be able to use it as well.
-change migrate_unittest.py to use SQLite.

Signed-off-by: Steve Howard <showard@google.com>
  • Loading branch information...
1 parent 4a5c052 commit f730a223a4174951fddbd4d253858490406f01f9 Steve Howard committed Oct 3, 2008
View
0 database/__init__.py
No changes.
View
8 database/common.py
@@ -0,0 +1,8 @@
+import os, sys
+dirname = os.path.dirname(sys.modules[__name__].__file__)
+autotest_dir = os.path.abspath(os.path.join(dirname, ".."))
+client_dir = os.path.join(autotest_dir, "client")
+sys.path.insert(0, client_dir)
+import setup_modules
+sys.path.pop(0)
+setup_modules.setup(base_path=autotest_dir, root_module_name="autotest_lib")
View
250 database/database_connection.py
@@ -0,0 +1,250 @@
+import traceback, time
+import common
+from autotest_lib.client.common_lib import global_config
+
+RECONNECT_FOREVER = object()
+
+_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
+_GLOBAL_CONFIG_NAMES = {
+ 'username' : 'user',
+ 'db_name' : 'database',
+}
+
+def _copy_exceptions(source, destination):
+ for exception_name in _DB_EXCEPTIONS:
+ setattr(destination, exception_name, getattr(source, exception_name))
+
+
+class _GenericBackend(object):
+ def __init__(self, database_module):
+ self._database_module = database_module
+ self._connection = None
+ self._cursor = None
+ self.rowcount = None
+ _copy_exceptions(database_module, self)
+
+
+ def connect(self, host=None, username=None, password=None, db_name=None):
+ """
+ This is assumed to enable autocommit.
+ """
+ raise NotImplementedError
+
+
+ def disconnect(self):
+ if self._connection:
+ self._connection.close()
+ self._connection = None
+ self._cursor = None
+
+
+ def execute(self, query, arguments=None):
+ self._cursor.execute(query, arguments)
+ self.rowcount = self._cursor.rowcount
+ return self._cursor.fetchall()
+
+
+ def get_exception_details(exception):
+ return ExceptionDetails.UNKNOWN
+
+
+class _MySqlBackend(_GenericBackend):
+ def __init__(self):
+ import MySQLdb
+ super(_MySqlBackend, self).__init__(MySQLdb)
+
+
+ @staticmethod
+ def convert_boolean(boolean, conversion_dict):
+ 'Convert booleans to integer strings'
+ return str(int(boolean))
+
+
+ def connect(self, host=None, username=None, password=None, db_name=None):
+ import MySQLdb.converters
+ convert_dict = MySQLdb.converters.conversions
+ convert_dict.setdefault(bool, self.convert_boolean)
+
+ self._connection = self._database_module.connect(
+ host=host, user=username, passwd=password, db=db_name,
+ conv=convert_dict)
+ self._connection.autocommit(True)
+ self._cursor = self._connection.cursor()
+
+
+ def get_exception_details(exception):
+ pass
+
+
+class _SqliteBackend(_GenericBackend):
+ def __init__(self):
+ from pysqlite2 import dbapi2
+ super(_SqliteBackend, self).__init__(dbapi2)
+
+
+ def connect(self, host=None, username=None, password=None, db_name=None):
+ self._connection = self._database_module.connect(db_name)
+ self._connection.isolation_level = None # enable autocommit
+ self._cursor = self._connection.cursor()
+
+
+ def execute(self, query, arguments=None):
+ # pysqlite2 uses paramstyle=qmark
+ # TODO: make this more sophisticated if necessary
+ query = query.replace('%s', '?')
+ return super(_SqliteBackend, self).execute(query, arguments)
+
+
+_BACKEND_MAP = {
+ 'mysql' : _MySqlBackend,
+ 'sqlite' : _SqliteBackend,
+}
+
+
+class DatabaseConnection(object):
+ """
+ Generic wrapper for a database connection. Supports both mysql and sqlite
+ backends.
+
+ Public attributes:
+ * reconnect_enabled: if True, when an OperationalError occurs the class will
+ try to reconnect to the database automatically.
+ * reconnect_delay_sec: seconds to wait before reconnecting
+ * max_reconnect_attempts: maximum number of time to try reconnecting before
+ giving up. Setting to RECONNECT_FOREVER removes the limit.
+ * rowcount - will hold cursor.rowcount after each call to execute().
+ * global_config_section - the section in which to find DB information. this
+ should be passed to the constructor, not set later, and may be None, in
+ which case information must be passed to connect().
+ """
+ _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
+ 'db_name')
+
+ def __init__(self, global_config_section=None):
+ self.global_config_section = global_config_section
+ self._backend = None
+ self.rowcount = None
+
+ # reconnect defaults
+ self.reconnect_enabled = True
+ self.reconnect_delay_sec = 20
+ self.max_reconnect_attempts = 10
+
+ self._read_options()
+
+
+ def _get_option(self, name, provided_value):
+ if provided_value is not None:
+ return provided_value
+ if self.global_config_section:
+ global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
+ return global_config.global_config.get_config_value(
+ self.global_config_section, global_config_name)
+ return getattr(self, name, None)
+
+
+ def _read_options(self, db_type=None, host=None, username=None,
+ password=None, db_name=None):
+ self.db_type = self._get_option('db_type', db_type)
+ self.host = self._get_option('host', host)
+ self.username = self._get_option('username', username)
+ self.password = self._get_option('password', password)
+ self.db_name = self._get_option('db_name', db_name)
+
+
+ def _get_backend(self, db_type):
+ if db_type not in _BACKEND_MAP:
+ raise ValueError('Invalid database type: %s, should be one of %s' %
+ (db_type, ', '.join(_BACKEND_MAP.keys())))
+ backend_class = _BACKEND_MAP[db_type]
+ return backend_class()
+
+
+ def _reached_max_attempts(self, num_attempts):
+ return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
+ num_attempts > self.max_reconnect_attempts)
+
+
+ def _is_reconnect_enabled(self, supplied_param):
+ if supplied_param is not None:
+ return supplied_param
+ return self.reconnect_enabled
+
+
+ def _connect_backend(self, try_reconnecting=None):
+ num_attempts = 0
+ while True:
+ try:
+ self._backend.connect(host=self.host, username=self.username,
+ password=self.password,
+ db_name=self.db_name)
+ return
+ except self._backend.OperationalError:
+ num_attempts += 1
+ if not self._is_reconnect_enabled(try_reconnecting):
+ raise
+ if self._reached_max_attempts(num_attempts):
+ raise
+ traceback.print_exc()
+ print ("Can't connect to database; reconnecting in %s sec" %
+ self.reconnect_delay_sec)
+ time.sleep(self.reconnect_delay_sec)
+ self.disconnect()
+
+
+ def connect(self, db_type=None, host=None, username=None, password=None,
+ db_name=None, try_reconnecting=None):
+ """
+ Parameters passed to this function will override defaults from global
+ config. try_reconnecting, if passed, will override
+ self.reconnect_enabled.
+ """
+ self.disconnect()
+ self._read_options(db_type, host, username, password, db_name)
+
+ self._backend = self._get_backend(self.db_type)
+ _copy_exceptions(self._backend, self)
+ self._connect_backend(try_reconnecting)
+
+
+ def disconnect(self):
+ if self._backend:
+ self._backend.disconnect()
+
+
+ def execute(self, query, parameters=None, try_reconnecting=None):
+ """
+ Execute a query and return cursor.fetchall(). try_reconnecting, if
+ passed, will override self.reconnect_enabled.
+ """
+ # _connect_backend() contains a retry loop, so don't loop here
+ try:
+ results = self._backend.execute(query, parameters)
+ except self._backend.OperationalError:
+ if not self._is_reconnect_enabled(try_reconnecting):
+ raise
+ traceback.print_exc()
+ print ("MYSQL connection died; reconnecting")
+ self.disconnect()
+ self._connect_backend(try_reconnecting)
+ results = self._backend.execute(query, parameters)
+
+ self.rowcount = self._backend.rowcount
+ return results
+
+
+ def get_database_info(self):
+ return dict((attribute, getattr(self, attribute))
+ for attribute in self._DATABASE_ATTRIBUTES)
+
+
+ @classmethod
+ def get_test_database(cls):
+ """
+ Factory method returning a DatabaseConnection for a temporary in-memory
+ database.
+ """
+ database = cls()
+ database.reconnect_enabled = False
+ database.connect(db_type='sqlite', db_name=':memory:')
+ return database
View
186 database/database_connection_unittest.py
@@ -0,0 +1,186 @@
+#!/usr/bin/python2.4
+
+import unittest, time
+import MySQLdb
+import pysqlite2.dbapi2
+import common
+from autotest_lib.client.common_lib import global_config
+from autotest_lib.client.common_lib.test_utils import mock
+from autotest_lib.database import database_connection
+
+_CONFIG_SECTION = 'TKO'
+_HOST = 'myhost'
+_USER = 'myuser'
+_PASS = 'mypass'
+_DB_NAME = 'mydb'
+_DB_TYPE = 'mydbtype'
+
+_CONNECT_KWARGS = dict(host=_HOST, username=_USER, password=_PASS,
+ db_name=_DB_NAME)
+_RECONNECT_DELAY = 10
+
+class FakeDatabaseError(Exception):
+ pass
+
+
+class DatabaseConnectionTest(unittest.TestCase):
+ def setUp(self):
+ self.god = mock.mock_god()
+ self.god.stub_function(time, 'sleep')
+
+
+ def tearDown(self):
+ global_config.global_config.reset_config_values()
+ self.god.unstub_all()
+
+
+ def _get_database_connection(self, config_section=_CONFIG_SECTION):
+ if config_section == _CONFIG_SECTION:
+ self._override_config()
+ db = database_connection.DatabaseConnection(config_section)
+
+ self._fake_backend = self.god.create_mock_class(
+ database_connection._GenericBackend, 'fake_backend')
+ for exception in database_connection._DB_EXCEPTIONS:
+ setattr(self._fake_backend, exception, FakeDatabaseError)
+ self._fake_backend.rowcount = 0
+
+ def get_fake_backend(db_type):
+ self._db_type = db_type
+ return self._fake_backend
+ self.god.stub_with(db, '_get_backend', get_fake_backend)
+
+ db.reconnect_delay_sec = _RECONNECT_DELAY
+ return db
+
+
+ def _override_config(self):
+ c = global_config.global_config
+ c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
+ c.override_config_value(_CONFIG_SECTION, 'user', _USER)
+ c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
+ c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
+ c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)
+
+
+ def test_connect(self):
+ db = self._get_database_connection(config_section=None)
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+
+ db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
+ password=_PASS, db_name=_DB_NAME)
+
+ self.assertEquals(self._db_type, _DB_TYPE)
+ self.god.check_playback()
+
+
+ def test_global_config(self):
+ db = self._get_database_connection()
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+
+ db.connect()
+
+ self.assertEquals(self._db_type, _DB_TYPE)
+ self.god.check_playback()
+
+
+ def _expect_reconnect(self, fail=False):
+ self._fake_backend.disconnect.expect_call()
+ call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+ if fail:
+ call.and_raises(FakeDatabaseError())
+
+
+ def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
+ FakeDatabaseError())
+ for i in xrange(num_reconnects):
+ time.sleep.expect_call(_RECONNECT_DELAY)
+ if i < num_reconnects - 1:
+ self._expect_reconnect(fail=True)
+ else:
+ self._expect_reconnect(fail=fail_last)
+
+
+ def test_connect_retry(self):
+ db = self._get_database_connection()
+ self._expect_fail_and_reconnect(1)
+
+ db.connect()
+ self.god.check_playback()
+
+ self._fake_backend.disconnect.expect_call()
+ self._expect_fail_and_reconnect(0)
+ self.assertRaises(FakeDatabaseError, db.connect,
+ try_reconnecting=False)
+ self.god.check_playback()
+
+ db.reconnect_enabled = False
+ self._fake_backend.disconnect.expect_call()
+ self._expect_fail_and_reconnect(0)
+ self.assertRaises(FakeDatabaseError, db.connect)
+ self.god.check_playback()
+
+
+ def test_max_reconnect(self):
+ db = self._get_database_connection()
+ db.max_reconnect_attempts = 5
+ self._expect_fail_and_reconnect(5, fail_last=True)
+
+ self.assertRaises(FakeDatabaseError, db.connect)
+ self.god.check_playback()
+
+
+ def test_reconnect_forever(self):
+ db = self._get_database_connection()
+ db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
+ self._expect_fail_and_reconnect(30)
+
+ db.connect()
+ self.god.check_playback()
+
+
+ def _simple_connect(self, db):
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+ db.connect()
+ self.god.check_playback()
+
+
+ def test_disconnect(self):
+ db = self._get_database_connection()
+ self._simple_connect(db)
+ self._fake_backend.disconnect.expect_call()
+
+ db.disconnect()
+ self.god.check_playback()
+
+
+ def test_execute(self):
+ db = self._get_database_connection()
+ self._simple_connect(db)
+ params = object()
+ self._fake_backend.execute.expect_call('query', params)
+
+ db.execute('query', params)
+ self.god.check_playback()
+
+
+ def test_execute_retry(self):
+ db = self._get_database_connection()
+ self._simple_connect(db)
+ self._fake_backend.execute.expect_call('query', None).and_raises(
+ FakeDatabaseError())
+ self._expect_reconnect()
+ self._fake_backend.execute.expect_call('query', None)
+
+ db.execute('query')
+ self.god.check_playback()
+
+ self._fake_backend.execute.expect_call('query', None).and_raises(
+ FakeDatabaseError())
+ self.assertRaises(FakeDatabaseError, db.execute, 'query',
+ try_reconnecting=False)
+
+
+if __name__ == '__main__':
+ unittest.main()
View
4 frontend/migrations/001_initial_db.py
@@ -6,8 +6,8 @@
def migrate_up(manager):
- manager.execute("SHOW TABLES")
- tables = [row[0] for row in manager.cursor.fetchall()]
+ rows = manager.execute("SHOW TABLES")
+ tables = [row[0] for row in rows]
db_initialized = True
for table in required_tables:
if table not in tables:
View
105 migrate/migrate.py
@@ -5,6 +5,7 @@
from optparse import OptionParser
import common
from autotest_lib.client.common_lib import global_config
+from autotest_lib.database import database_connection
MIGRATE_TABLE = 'migrate_info'
@@ -37,50 +38,29 @@ class MigrationManager(object):
cursor = None
migrations_dir = None
- def __init__(self, database, migrations_dir=None, force=False):
- self.database = database
+ def __init__(self, database_connection, migrations_dir=None, force=False):
+ self._database = database_connection
self.force = force
+ self._set_migrations_dir(migrations_dir)
+
+
+ def _set_migrations_dir(self, migrations_dir=None):
+ config_section = self._database.global_config_section
if migrations_dir is None:
migrations_dir = os.path.abspath(
- _MIGRATIONS_DIRS.get(database, _DEFAULT_MIGRATIONS_DIR))
+ _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
self.migrations_dir = migrations_dir
sys.path.append(migrations_dir)
- assert os.path.exists(migrations_dir)
-
- self.db_host = None
- self.db_name = None
- self.username = None
- self.password = None
-
-
- def read_db_info(self):
- # grab the config file and parse for info
- c = global_config.global_config
- self.db_host = c.get_config_value(self.database, "host")
- self.db_name = c.get_config_value(self.database, "database")
- self.username = c.get_config_value(self.database, "user")
- self.password = c.get_config_value(self.database, "password")
-
+ assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"
- def connect(self, host, db_name, username, password):
- return MySQLdb.connect(host=host, db=db_name, user=username,
- passwd=password)
-
- def open_connection(self):
- self.connection = self.connect(self.db_host, self.db_name,
- self.username, self.password)
- self.connection.autocommit(True)
- self.cursor = self.connection.cursor()
-
-
- def close_connection(self):
- self.connection.close()
+ def _get_db_name(self):
+ return self._database.get_database_info()['db_name']
def execute(self, query, *parameters):
#print 'SQL:', query % parameters
- return self.cursor.execute(query, parameters)
+ return self._database.execute(query, parameters)
def execute_script(self, script):
@@ -95,11 +75,10 @@ def check_migrate_table_exists(self):
try:
self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
return True
- except MySQLdb.ProgrammingError, exc:
- error_code, _ = exc.args
- if error_code == MySQLdb.constants.ER.NO_SUCH_TABLE:
- return False
- raise
+ except self._database.DatabaseError, exc:
+ # we can't check for more specifics due to differences between DB
+ # backends (we can't even check for a subclass of DatabaseError)
+ return False
def create_migrate_table(self):
@@ -109,21 +88,20 @@ def create_migrate_table(self):
else:
self.execute("DELETE FROM %s" % MIGRATE_TABLE)
self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
- assert self.cursor.rowcount == 1
+ assert self._database.rowcount == 1
def set_db_version(self, version):
assert isinstance(version, int)
self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
version)
- assert self.cursor.rowcount == 1
+ assert self._database.rowcount == 1
def get_db_version(self):
if not self.check_migrate_table_exists():
return 0
- self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
- rows = self.cursor.fetchall()
+ rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
if len(rows) == 0:
return 0
assert len(rows) == 1 and len(rows[0]) == 1
@@ -190,30 +168,28 @@ def migrate_to_latest(self):
def initialize_test_db(self):
- self.read_db_info()
- test_db_name = 'test_' + self.db_name
+ db_name = self._get_db_name()
+ test_db_name = 'test_' + db_name
# first, connect to no DB so we can create a test DB
- self.db_name = ''
- self.open_connection()
+ self._database.connect(db_name='')
print 'Creating test DB', test_db_name
self.execute('CREATE DATABASE ' + test_db_name)
- self.close_connection()
+ self._database.disconnect()
# now connect to the test DB
- self.db_name = test_db_name
- self.open_connection()
+ self._database.connect(db_name=test_db_name)
def remove_test_db(self):
print 'Removing test DB'
- self.execute('DROP DATABASE ' + self.db_name)
+ self.execute('DROP DATABASE ' + self._get_db_name())
+ # reset connection back to real DB
+ self._database.disconnect()
+ self._database.connect()
def get_mysql_args(self):
- return ('-u %(user)s -p%(password)s -h %(host)s %(db)s' % {
- 'user' : self.username,
- 'password' : self.password,
- 'host' : self.db_host,
- 'db' : self.db_name})
+ return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
+ self._database.get_database_info())
def migrate_to_version_or_latest(self, version):
@@ -224,9 +200,7 @@ def migrate_to_version_or_latest(self, version):
def do_sync_db(self, version=None):
- self.read_db_info()
- self.open_connection()
- print 'Migration starting for database', self.db_name
+ print 'Migration starting for database', self._get_db_name()
self.migrate_to_version_or_latest(version)
print 'Migration complete'
@@ -237,7 +211,7 @@ def test_sync_db(self, version=None):
"""
self.initialize_test_db()
try:
- print 'Starting migration test on DB', self.db_name
+ print 'Starting migration test on DB', self._get_db_name()
self.migrate_to_version_or_latest(version)
# show schema to the user
os.system('mysqldump %s --no-data=true '
@@ -253,28 +227,24 @@ def simulate_sync_db(self, version=None):
Create a fresh DB, copy the existing DB to it, and then
try to synchronize it.
"""
- self.read_db_info()
- self.open_connection()
db_version = self.get_db_version()
- self.close_connection()
# don't do anything if we're already at the latest version
if db_version == self.get_latest_version():
print 'Skipping simulation, already at latest version'
return
# get existing data
- self.read_db_info()
print 'Dumping existing data'
dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
- os.close(dump_fd)
os.system('mysqldump %s >%s' %
(self.get_mysql_args(), dump_file))
# fill in test DB
self.initialize_test_db()
print 'Filling in test DB'
os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
+ os.close(dump_fd)
os.remove(dump_file)
try:
- print 'Starting migration test on DB', self.db_name
+ print 'Starting migration test on DB', self._get_db_name()
self.migrate_to_version_or_latest(version)
finally:
self.remove_test_db()
@@ -299,7 +269,10 @@ def main():
parser.add_option("-f", "--force", help="don't ask for confirmation",
action="store_true")
(options, args) = parser.parse_args()
- manager = MigrationManager(options.database, force=options.force)
+ database = database_connection.DatabaseConnection(options.database)
+ database.reconnect_enabled = False
+ database.connect()
+ manager = MigrationManager(database, force=options.force)
if len(args) > 0:
if len(args) > 1:
View
50 migrate/migrate_unittest.py
@@ -1,10 +1,11 @@
#!/usr/bin/python2.4
-import unittest
+import unittest, tempfile, os
import MySQLdb
import migrate
import common
from autotest_lib.client.common_lib import global_config
+from autotest_lib.database import database_connection
# Which section of the global config to pull info from. We won't actually use
# that DB, we'll use the corresponding test DB (test_<db name>).
@@ -54,18 +55,8 @@ def migrate_down(self, manager):
class TestableMigrationManager(migrate.MigrationManager):
- def __init__(self, database, migrations_dir=None):
- self.database = database
- self.migrations_dir = migrations_dir
- self.db_host = None
- self.db_name = None
- self.username = None
- self.password = None
-
-
- def read_db_info(self):
- migrate.MigrationManager.read_db_info(self)
- self.db_name = 'test_' + self.db_name
+ def _set_migrations_dir(self, migrations_dir=None):
+ pass
def get_migrations(self, minimum_version=None, maximum_version=None):
@@ -75,39 +66,16 @@ def get_migrations(self, minimum_version=None, maximum_version=None):
class MigrateManagerTest(unittest.TestCase):
- config = global_config.global_config
- host = config.get_config_value(CONFIG_DB, 'host')
- db_name = 'test_' + config.get_config_value(CONFIG_DB, 'database')
- user = config.get_config_value(CONFIG_DB, 'user')
- password = config.get_config_value(CONFIG_DB, 'password')
-
- def do_sql(self, sql):
- self.con = MySQLdb.connect(host=self.host, user=self.user,
- passwd=self.password)
- self.con.autocommit(True)
- self.cur = self.con.cursor()
- try:
- self.cur.execute(sql)
- finally:
- self.con.close()
-
-
- def remove_db(self):
- self.do_sql('DROP DATABASE ' + self.db_name)
-
-
def setUp(self):
- self.do_sql('CREATE DATABASE ' + self.db_name)
- try:
- self.manager = TestableMigrationManager(CONFIG_DB)
- except MySQLdb.OperationalError:
- self.remove_db()
- raise
+ self._database = (
+ database_connection.DatabaseConnection.get_test_database())
+ self._database.connect()
+ self.manager = TestableMigrationManager(self._database)
DummyMigration.clear_migrations_done()
def tearDown(self):
- self.remove_db()
+ self._database.disconnect()
def test_sync(self):
View
8 scheduler/monitor_db_unittest.py
@@ -6,6 +6,7 @@
from autotest_lib.client.common_lib import global_config, host_protections
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.migrate import migrate
+from autotest_lib.database import database_connection
import monitor_db
@@ -102,10 +103,9 @@ def _open_test_db(self):
self._do_query('CREATE DATABASE ' + self._db_name)
self._disconnect_from_db()
- migration_dir = os.path.join(os.path.dirname(__file__),
- '..', 'frontend', 'migrations')
- manager = migrate.MigrationManager('AUTOTEST_WEB', migration_dir,
- force=True)
+ database = database_connection.DatabaseConnection('AUTOTEST_WEB')
+ database.connect(db_name=self._db_name)
+ manager = migrate.MigrationManager(database, force=True)
manager.do_sync_db()
self._connect_to_db(self._db_name)
View
4 tko/migrations/001_initial_db.py
@@ -4,8 +4,8 @@
'iteration_result')
def migrate_up(manager):
- manager.execute("SHOW TABLES")
- tables = [row[0] for row in manager.cursor.fetchall()]
+ rows = manager.execute("SHOW TABLES")
+ tables = [row[0] for row in rows]
db_initialized = True
for table in required_tables:
if table not in tables:

0 comments on commit f730a22

Please sign in to comment.
Something went wrong with that request. Please try again.