Skip to content

Commit

Permalink
Refactoring part 3: moving config classes to a new 'config' module.
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermechapiewski committed Jul 27, 2009
1 parent 3ccf85a commit aeba159
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 193 deletions.
67 changes: 67 additions & 0 deletions src/config.py
@@ -0,0 +1,67 @@
import codecs
import os

class Config(object):

DB_VERSION_TABLE = "__db_version__"

def __init__(self):
self._config = {}

def __repr__(self):
return str(self._config)

def get(self, config_key):
try:
return self._config[config_key]
except KeyError:
raise Exception("invalid configuration key ('%s')" % config_key)

def put(self, config_key, config_value):
if config_key in self._config:
raise Exception("the configuration key '%s' already exists and you cannot override any configuration" % config_key)
self._config[config_key] = config_value

def _parse_migrations_dir(self, dirs):
abs_dirs = []
for dir in dirs.split(':'):
abs_dirs.append(os.path.abspath(dir))
return abs_dirs

class FileConfig(Config):

def __init__(self, config_file="simple-db-migrate.conf"):
self._config = {}

# read configurations
try:
f = codecs.open(config_file, "rU", "utf-8")
exec(f.read())
except IOError:
raise Exception("%s: file not found" % config_file)
else:
f.close()

try:
self.put("db_host", HOST)
self.put("db_user", USERNAME)
self.put("db_password", PASSWORD)
self.put("db_name", DATABASE)
self.put("db_version_table", self.DB_VERSION_TABLE)
self.put("migrations_dir", self._parse_migrations_dir(MIGRATIONS_DIR))
except NameError, e:
raise Exception("config file error: " + str(e))

class InPlaceConfig(Config):

def __init__(self, db_host, db_user, db_password, db_name, migrations_dir, db_version_table=''):
if not db_version_table or db_version_table == '':
db_version_table = self.DB_VERSION_TABLE
self._config = {
"db_host": db_host,
"db_user": db_user,
"db_password": db_password,
"db_name": db_name,
"db_version_table": db_version_table,
"migrations_dir": self._parse_migrations_dir(migrations_dir)
}
66 changes: 0 additions & 66 deletions src/core.py
Expand Up @@ -3,72 +3,6 @@
import os
import shutil
import re

# TODO: move to a config module (the 3 config classes below)
class Config(object):

DB_VERSION_TABLE = "__db_version__"

def __init__(self):
self._config = {}

def __repr__(self):
return str(self._config)

def get(self, config_key):
try:
return self._config[config_key]
except KeyError:
raise Exception("invalid configuration key ('%s')" % config_key)

def put(self, config_key, config_value):
if config_key in self._config:
raise Exception("the configuration key '%s' already exists and you cannot override any configuration" % config_key)
self._config[config_key] = config_value

def _parse_migrations_dir(self, dirs):
abs_dirs = []
for dir in dirs.split(':'):
abs_dirs.append(os.path.abspath(dir))
return abs_dirs

class FileConfig(Config):

def __init__(self, config_file="simple-db-migrate.conf"):
self._config = {}

# read configurations
try:
f = codecs.open(config_file, "rU", "utf-8")
exec(f.read())
except IOError:
raise Exception("%s: file not found" % config_file)
else:
f.close()

try:
self.put("db_host", HOST)
self.put("db_user", USERNAME)
self.put("db_password", PASSWORD)
self.put("db_name", DATABASE)
self.put("db_version_table", self.DB_VERSION_TABLE)
self.put("migrations_dir", self._parse_migrations_dir(MIGRATIONS_DIR))
except NameError, e:
raise Exception("config file error: " + str(e))

class InPlaceConfig(Config):

def __init__(self, db_host, db_user, db_password, db_name, migrations_dir, db_version_table=''):
if not db_version_table or db_version_table == '':
db_version_table = self.DB_VERSION_TABLE
self._config = {
"db_host": db_host,
"db_user": db_user,
"db_password": db_password,
"db_name": db_name,
"db_version_table": db_version_table,
"migrations_dir": self._parse_migrations_dir(migrations_dir)
}

class Migration(object):

Expand Down
133 changes: 133 additions & 0 deletions tests/config_test.py
@@ -0,0 +1,133 @@
from test import *
from config import *
import codecs
import unittest

class ConfigTest(unittest.TestCase):

def test_it_should_parse_migrations_dir_with_one_relative_dir(self):
config = Config()
dirs = config._parse_migrations_dir('.')
assert len(dirs) == 1
assert dirs[0] == os.path.abspath('.')

def test_it_should_parse_migrations_dir_with_two_relative_dirs(self):
config = Config()
dirs = config._parse_migrations_dir('test:migrations:./a/relative/path:another/path')
assert len(dirs) == 4
assert dirs[0] == os.path.abspath('test')
assert dirs[1] == os.path.abspath('migrations')
assert dirs[2] == os.path.abspath('./a/relative/path')
assert dirs[3] == os.path.abspath('another/path')

def test_it_should_parse_migrations_dir_with_one_absolute_dir(self):
config = Config()
dirs = config._parse_migrations_dir(os.path.abspath('.'))
assert len(dirs) == 1
assert dirs[0] == os.path.abspath('.')

def test_it_should_parse_migrations_dir_with_two_absolute_dirs(self):
config = Config()
dirs = config._parse_migrations_dir('%s:%s:%s:%s' % (
os.path.abspath('test'), os.path.abspath('migrations'),
os.path.abspath('./a/relative/path'), os.path.abspath('another/path'))
)
assert len(dirs) == 4
assert dirs[0] == os.path.abspath('test')
assert dirs[1] == os.path.abspath('migrations')
assert dirs[2] == os.path.abspath('./a/relative/path')
assert dirs[3] == os.path.abspath('another/path')

class FileConfigTest(unittest.TestCase):

def setUp(self):
config_file = '''
HOST = os.getenv('DB_HOST') or 'localhost'
USERNAME = os.getenv('DB_USERNAME') or 'root'
PASSWORD = os.getenv('DB_PASSWORD') or ''
DATABASE = os.getenv('DB_DATABASE') or 'migration_example'
MIGRATIONS_DIR = os.getenv('MIGRATIONS_DIR') or 'example'
'''
f = open('sample.conf', 'w')
f.write(config_file)
f.close()

def tearDown(self):
os.remove('sample.conf')

def test_it_should_read_config_file(self):
config_path = os.path.abspath('sample.conf')
config = FileConfig(config_path)
self.assertEquals(config.get('db_host'), 'localhost')
self.assertEquals(config.get('db_user'), 'root')
self.assertEquals(config.get('db_password'), '')
self.assertEquals(config.get('db_name'), 'migration_example')
self.assertEquals(config.get('db_version_table'), Config.DB_VERSION_TABLE)
self.assertEquals(config.get('migrations_dir'), [os.path.abspath('example')])

def test_it_should_stop_execution_when_an_invalid_key_is_requested(self):
config_path = os.path.abspath('sample.conf')
config = FileConfig(config_path)
try:
config.get('invalid_config')
self.fail('it should not pass here')
except:
pass

def test_it_should_create_new_configs(self):
config_path = os.path.abspath('sample.conf')
config = FileConfig(config_path)

# ensure that the config does not exist
self.assertRaises(Exception, config.get, 'sample_config', 'TEST')

# create the config
config.put('sample_config', 'TEST')

# read the config
self.assertEquals(config.get('sample_config'), 'TEST')

def test_it_should_not_override_existing_configs(self):
config_path = os.path.abspath('sample.conf')
config = FileConfig(config_path)
config.put('sample_config', 'TEST')
self.assertRaises(Exception, config.put, 'sample_config', 'TEST')

class InPlaceConfigTest(unittest.TestCase):

def test_it_should_configure_default_parameters(self):
config = InPlaceConfig('localhost', 'user', 'passwd', 'db', 'dir')
self.assertEquals(config.get('db_host'), 'localhost')
self.assertEquals(config.get('db_user'), 'user')
self.assertEquals(config.get('db_password'), 'passwd')
self.assertEquals(config.get('db_name'), 'db')
self.assertEquals(config.get('db_version_table'), Config.DB_VERSION_TABLE)
self.assertEquals(config.get('migrations_dir'), [os.path.abspath('dir')])

def test_it_should_stop_execution_when_an_invalid_key_is_requested(self):
config = InPlaceConfig('localhost', 'user', 'passwd', 'db', 'dir')
try:
config.get('invalid_config')
self.fail('it should not pass here')
except:
pass

def test_it_should_create_new_configs(self):
config = InPlaceConfig('localhost', 'user', 'passwd', 'db', 'dir')

# ensure that the config does not exist
self.assertRaises(Exception, config.get, 'sample_config', 'TEST')

# create the config
config.put('sample_config', 'TEST')

# read the config
self.assertEquals(config.get('sample_config'), 'TEST')

def test_it_should_not_override_existing_configs(self):
config = InPlaceConfig('localhost', 'user', 'passwd', 'db', 'dir')
config.put('sample_config', 'TEST')
self.assertRaises(Exception, config.put, 'sample_config', 'TEST')

if __name__ == '__main__':
unittest.main()

0 comments on commit aeba159

Please sign in to comment.