Skip to content

Commit

Permalink
improving the way to read the migrations, avoiding problems of encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
wandenberg committed Oct 12, 2010
1 parent 4ca42cc commit e9beac9
Showing 1 changed file with 60 additions and 9 deletions.
69 changes: 60 additions & 9 deletions src/core/__init__.py
Expand Up @@ -3,19 +3,23 @@
import os
import shutil
import re
import imp
import tempfile
import sys

class Migration(object):

MIGRATION_FILES_EXTENSION = ".migration"
MIGRATION_FILES_MASK = r"[0-9]{14}\w+%s$" % MIGRATION_FILES_EXTENSION
TEMPLATE = 'SQL_UP = u"""\n\n"""\n\nSQL_DOWN = u"""\n\n"""'
TEMPLATE = '#-*- coding:%s -*-\nSQL_UP = u"""\n\n"""\n\nSQL_DOWN = u"""\n\n"""\n'

def __init__(self, file=None, id=0, file_name="", version="", sql_up="", sql_down=""):
def __init__(self, file=None, id=0, file_name="", version="", sql_up="", sql_down="", script_encoding="utf-8"):
self.id = id
self.file_name = file_name
self.version = version
self.sql_up = sql_up
self.sql_down = sql_down
self.script_encoding = script_encoding
if file:
file_name = os.path.split(file)[1]
if not Migration.is_file_name_valid(file_name):
Expand All @@ -30,9 +34,48 @@ def __init__(self, file=None, id=0, file_name="", version="", sql_up="", sql_dow
self.sql_up, self.sql_down = self._get_commands()

def _get_commands(self):
f = codecs.open(self.abspath, "rU", "utf-8")
exec(f.read())
f.close()
SQL_UP = ''
SQL_DOWN = ''
mod = None
temp_abspath = None

try:
mod = imp.load_source(self.file_name, self.abspath)
SQL_UP = self._check_sql_unicode(mod.SQL_UP)
SQL_DOWN = self._check_sql_unicode(mod.SQL_DOWN)
except Exception:
try:
f = open(self.abspath, "rU")
content = f.read()
f.close()

temp_abspath = "%s/%s" %(tempfile.gettempdir().rstrip('/'), self.file_name)
f = open(temp_abspath, "w")
f.write('#-*- coding:%s -*-\n%s' % (self.script_encoding, content))
f.close()

mod = imp.load_source(self.file_name, temp_abspath)

SQL_UP = self._check_sql_unicode(mod.SQL_UP)
SQL_DOWN = self._check_sql_unicode(mod.SQL_DOWN)

except Exception:
f = codecs.open(self.abspath, "rU", self.script_encoding)
exec(f.read())
f.close()
finally:
#erase temp and compiled files
if temp_abspath and os.path.isfile(temp_abspath):
os.remove(temp_abspath)

if mod and sys.modules.has_key(self.file_name):
sys.modules.pop(self.file_name)

if temp_abspath and os.path.isfile(temp_abspath + "c"):
os.remove(temp_abspath + "c")

if os.path.isfile(self.abspath + "c"):
os.remove(self.abspath + "c")

try:
(SQL_UP, SQL_DOWN)
Expand All @@ -47,6 +90,13 @@ def _get_commands(self):

return SQL_UP, SQL_DOWN

def _check_sql_unicode(self, sql):
try:
sql = unicode(sql.decode(self.script_encoding))
except UnicodeEncodeError:
sql = unicode(sql)
return sql

def compare_to(self, another_migration):
if self.version < another_migration.version:
return -1
Expand All @@ -64,7 +114,7 @@ def is_file_name_valid(file_name):
return match != None

@staticmethod
def create(migration_name, migration_dir='.'):
def create(migration_name, migration_dir='.', script_encoding='utf-8'):
timestamp = strftime("%Y%m%d%H%M%S")
file_name = "%s_%s%s" % (timestamp, migration_name, Migration.MIGRATION_FILES_EXTENSION)

Expand All @@ -74,8 +124,8 @@ def create(migration_name, migration_dir='.'):
new_file_name = "%s/%s" % (migration_dir, file_name)

try:
f = codecs.open(new_file_name, "w", "utf-8")
f.write(Migration.TEMPLATE)
f = codecs.open(new_file_name, "w", script_encoding)
f.write(Migration.TEMPLATE % (script_encoding))
f.close()
except IOError:
raise Exception("could not create file ('%s')" % new_file_name)
Expand All @@ -86,6 +136,7 @@ class SimpleDBMigrate(object):

def __init__(self, config=None):
self._migrations_dir = config.get("migrations_dir")
self._script_encoding=config.get("db_script_encoding", "utf-8")

def get_all_migrations(self):
migrations = []
Expand All @@ -101,7 +152,7 @@ def get_all_migrations(self):

for dir_file in dir_list:
if dir_file.endswith(Migration.MIGRATION_FILES_EXTENSION) and Migration.is_file_name_valid(dir_file):
migration = Migration('%s/%s' % (path, dir_file))
migration = Migration('%s/%s' % (path, dir_file), script_encoding=self._script_encoding)
migrations.append(migration)

if len(migrations) == 0:
Expand Down

0 comments on commit e9beac9

Please sign in to comment.