Skip to content

Commit

Permalink
#55: Fixed error when a string in a SQL contains ';' (semicolon).
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermechapiewski committed Aug 10, 2009
1 parent 4d30766 commit 5e36933
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Makefile
Expand Up @@ -17,8 +17,8 @@ clean:
test:
@make clean
@echo "Starting tests..."
@nosetests -s --verbose --with-coverage --cover-erase --cover-package=cli,config,core,helpers,main,mysql tests/* > /dev/null
@#nosetests -s --verbose --with-coverage --cover-erase --cover-inclusive tests/* > /dev/null
@nosetests -s --verbose --with-coverage --cover-erase --cover-package=cli,config,core,helpers,main,mysql tests/*
@#nosetests -s --verbose --with-coverage --cover-erase --cover-inclusive tests/*
@make clean

install:
Expand Down
13 changes: 12 additions & 1 deletion src/helpers.py
Expand Up @@ -4,4 +4,15 @@ class Lists(object):

@classmethod
def subtract(self, list_a, list_b):
return [l for l in list_a if l not in list_b]
return [l for l in list_a if l not in list_b]

class Utils(object):

@classmethod
def how_many(self, string, match):
if not match or len(match) != 1:
raise Exception("match should be a char")
count = {}
for char in string:
count[char] = count.get(char, 0) + 1
return count.get(match, 0)
42 changes: 32 additions & 10 deletions src/mysql.py
@@ -1,7 +1,10 @@
import re
import sys

import MySQLdb

from helpers import Utils

class MySQL(object):

def __init__(self, config=None, mysql_driver=MySQLdb):
Expand Down Expand Up @@ -44,8 +47,36 @@ def __execute(self, sql):
except Exception, e:
raise Exception("error executing migration (%s)" % e)

def __change_db_version(self, version, up=True):
if up:
# moving up and storing history
sql = "insert into %s (version) values (\"%s\");" % (self.__version_table, str(version))
else:
# moving down and deleting from history
sql = "delete from %s where version >= \"%s\";" % (self.__version_table, str(version))
self.__execute(sql)

def _parse_sql_statements(self, migration_sql):
all_statements = migration_sql.split(';')
all_statements = []
last_statement = ''

for statement in migration_sql.split(';'):
if len(last_statement) > 0:
curr_statement = '%s;%s' % (last_statement, statement)
else:
curr_statement = statement

single_quotes = Utils.how_many(curr_statement, "'")
double_quotes = Utils.how_many(curr_statement, '"')
left_parenthesis = Utils.how_many(curr_statement, '(')
right_parenthesis = Utils.how_many(curr_statement, ')')

if single_quotes % 2 == 0 and double_quotes % 2 == 0 and left_parenthesis == right_parenthesis:
all_statements.append(curr_statement)
last_statement = ''
else:
last_statement = curr_statement

return [s.strip() for s in all_statements if s.strip() != ""]

def _drop_database(self):
Expand Down Expand Up @@ -78,15 +109,6 @@ def _create_version_table_if_not_exists(self):
sql = "insert into %s (version) values (\"0\");" % self.__version_table
self.__execute(sql)

def __change_db_version(self, version, up=True):
if up:
# moving up and storing history
sql = "insert into %s (version) values (\"%s\");" % (self.__version_table, str(version))
else:
# moving down and deleting from history
sql = "delete from %s where version >= \"%s\";" % (self.__version_table, str(version))
self.__execute(sql)

def change(self, sql, new_db_version, up=True):
self.__execute(sql)
self.__change_db_version(new_db_version, up)
Expand Down
16 changes: 15 additions & 1 deletion tests/helpers_test.py
Expand Up @@ -20,4 +20,18 @@ def test_it_should_subtract_lists2(self):

result = Lists.subtract(a, b)

self.assertEquals(len(result), 0)
self.assertEquals(len(result), 0)

class UtilsTest(unittest.TestCase):

def test_it_should_count_chars_in_a_string(self):
word = 'abbbcd;;;;;;;;;;;;;;'
assert Utils.how_many(word, 'a') == 1
assert Utils.how_many(word, 'b') == 3
assert Utils.how_many(word, ';') == 14
assert Utils.how_many(word, '%') == 0

def test_it_should_raise_exception_when_char_to_match_is_not_valid(self):
self.assertRaises(Exception, Utils.how_many, 'whatever', 'what')
self.assertRaises(Exception, Utils.how_many, 'whatever', None)
self.assertRaises(Exception, Utils.how_many, 'whatever', '')
30 changes: 30 additions & 0 deletions tests/mysql_test.py
Expand Up @@ -159,6 +159,36 @@ def test_it_should_parse_sql_statements(self):
assert statements[0] == 'create table eggs'
assert statements[1] == 'drop table spam'

def test_it_should_parse_sql_statements_with_html_inside(self):
mysql_driver_mock = Mock()
db_mock = Mock()
cursor_mock = Mock()
self.__mock_db_init(mysql_driver_mock, db_mock, cursor_mock)
mysql = MySQL(self.__config, mysql_driver_mock)

sql = u"""
create table eggs;
INSERT INTO widget_parameter_domain (widget_parameter_id, label, value)
VALUES ((SELECT MAX(widget_parameter_id)
FROM widget_parameter), "Carros", '<div class="box-zap-geral">
<div class="box-zap box-zap-autos">
<a class="logo" target="_blank" title="ZAP" href="http://www.zap.com.br/Parceiros/g1/RedirG1.aspx?CodParceriaLink=42&amp;URL=http://www.zap.com.br">');
drop table spam;
"""
statements = mysql._parse_sql_statements(sql)

expected_sql_with_html = """INSERT INTO widget_parameter_domain (widget_parameter_id, label, value)
VALUES ((SELECT MAX(widget_parameter_id)
FROM widget_parameter), "Carros", '<div class="box-zap-geral">
<div class="box-zap box-zap-autos">
<a class="logo" target="_blank" title="ZAP" href="http://www.zap.com.br/Parceiros/g1/RedirG1.aspx?CodParceriaLink=42&amp;URL=http://www.zap.com.br">')"""

assert len(statements) == 3, 'expected %s, got %s' % (3, len(statements))
assert statements[0] == 'create table eggs'
assert statements[1] == expected_sql_with_html, 'expected "%s", got "%s"' % (expected_sql_with_html, statements[1])
assert statements[2] == 'drop table spam'

if __name__ == "__main__":
unittest.main()

0 comments on commit 5e36933

Please sign in to comment.