Skip to content

Commit

Permalink
add create table column detail
Browse files Browse the repository at this point in the history
  • Loading branch information
caoshuai03 committed Nov 21, 2019
1 parent 0a655e7 commit 2f5a382
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 46 deletions.
13 changes: 13 additions & 0 deletions mysqltokenparser/constant.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# coding: utf-8

# common
COLUMN_NAME = 'column_name'
COLUMN_DEFINITION = 'column_definition'
COLUMN_TYPE = 'column_type'
COLUMN_TYPE_DATA = 'column_type_data'
TABLE_NAME = 'table_name'
CREATE_DEFINITIONS = 'create_definitions'
INDEX_NAME = 'index_name'
INDEX_DEFINITION = 'index_definition'
PRIMARY_KEY = 'primary_key'
UNIQUE_KEY = 'unique_key'
COMMON_KEY = 'common_key'

# sql type
SQL_TYPE_DDL = 'ddl'
SQL_TYPE_DML = 'dml'
Expand Down
84 changes: 62 additions & 22 deletions mysqltokenparser/sqltypemixins/altertable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from mysqltokenparser.utils import iterchild
from mysqltokenparser.MySqlParser import MySqlParser
from mysqltokenparser.constant import DDL_TYPE_ALTERTABLE
from mysqltokenparser.constant import *


class AlterTableMixin:
Expand All @@ -18,7 +18,7 @@ def enterAlterTable(self, ctx):
children = ctx.children
for child in children:
if isinstance(child, MySqlParser.TableNameContext):
data['tablename'] = self._get_last_name(child)
data[TABLE_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.AlterByAddColumnContext):
alter_data.append({
"type": 'addcolumn',
Expand Down Expand Up @@ -53,18 +53,18 @@ def enterAlterTable(self, ctx):
@iterchild
def _enterAlterByDropColumn(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
ret['columnname'] = self._get_last_name(child)
ret[COLUMN_NAME] = self._get_last_name(child)

@iterchild
def _enterAlterByAddIndex(self, child, ret):
columnnames = []
if isinstance(child, MySqlParser.UidContext):
ret['indexname'] = self._get_last_name(child)
ret[INDEX_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.IndexColumnNamesContext):
columnnames = self._enterIndexColumnNames(child).get('columns', [])

ret['indexdefinition'] = {
'columnnames': columnnames
ret[INDEX_DEFINITION] = {
COLUMN_NAME: columnnames
}

@iterchild
Expand All @@ -76,26 +76,26 @@ def _enterIndexColumnNames(self, child, ret):
@iterchild
def _enterAlterByChangeColumn(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
if ret.get('columnname'):
if ret.get(COLUMN_NAME):
ret['new_columnname'] = self._get_last_name(child)
else:
ret['columnname'] = self._get_last_name(child)
ret[COLUMN_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.ColumnDefinitionContext):
ret['columndefinition'] = self._enterColumnDefinition(child)
ret[COLUMN_DEFINITION] = self._enterColumnDefinition(child)

@iterchild
def _enterAlterByModifyColumn(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
ret['columnname'] = self._get_last_name(child)
ret[COLUMN_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.ColumnDefinitionContext):
ret['columndefinition'] = self._enterColumnDefinition(child)
ret[COLUMN_DEFINITION] = self._enterColumnDefinition(child)

@iterchild
def _enterAlterByAddColumn(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
ret['columnname'] = self._get_last_name(child)
ret[COLUMN_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.ColumnDefinitionContext):
ret['columndefinition'] = self._enterColumnDefinition(child)
ret[COLUMN_DEFINITION] = self._enterColumnDefinition(child)

@iterchild
def _enterColumnDefinition(self, child, ret):
Expand All @@ -106,28 +106,68 @@ def _enterColumnDefinition(self, child, ret):
if isinstance(child, MySqlParser.SimpleDataTypeContext):
ret.update(self._enterSimpleDataType(child))

if isinstance(child, MySqlParser.NullColumnConstraintContext):
ret.update(self._enterNullColumnConstraint(child))
if isinstance(child, MySqlParser.AutoIncrementColumnConstraintContext):
ret.update(self._enterAutoIncrementColumnConstraint(child))
if isinstance(child, MySqlParser.CommentColumnConstraintContext):
ret.update(self._enterCommentColumnConstraint(child))
if isinstance(child, MySqlParser.PrimaryKeyColumnConstraintContext):
ret.update(self._enterPrimaryKeyColumnConstraint(child))

def _enterPrimaryKeyColumnConstraint(self, ctx):
return {
PRIMARY_KEY: True
}

def _enterCommentColumnConstraint(self, ctx):
return {
'comment': self._get_last_name(ctx.children[1])
}

def _enterAutoIncrementColumnConstraint(self, ctx):
return {
'auto_increment': True
}

@iterchild
def _enterNullColumnConstraint(self, child, ret):
if isinstance(child, MySqlParser.NullNotnullContext):
ret['null'] = False

@iterchild
def _enterSimpleDataType(self, child, ret):
ret.update({
'column_types': self._get_last_name(child),
'data': {}
COLUMN_TYPE: self._get_last_name(child),
COLUMN_TYPE_DATA: {}
})

@iterchild
def _enterDimensionDataType(self, child, ret):
ret.update({
'column_types': self._get_last_name(child),
'data': {}
})
if isinstance(child, antlr4.tree.Tree.TerminalNodeImpl):
ret[COLUMN_TYPE] = self._get_last_name(child)
if isinstance(child, MySqlParser.LengthOneDimensionContext):
ret[COLUMN_TYPE_DATA] = self._enterLengthOneDimension(child)
if isinstance(child, MySqlParser.LengthTwoDimensionContext):
ret[COLUMN_TYPE_DATA] = self._enterLengthTwoDimension(child)
if isinstance(child, MySqlParser.LengthTwoOptionalDimensionContext):
ret[COLUMN_TYPE_DATA] = self._enterLengthTwoDimension(child)

@iterchild
def _enterStringDataType(self, child, ret):
if isinstance(child, antlr4.tree.Tree.TerminalNodeImpl):
ret['column_type'] = self._get_last_name(child)
ret[COLUMN_TYPE] = self._get_last_name(child)
if isinstance(child, MySqlParser.LengthOneDimensionContext):
ret['data'] = self._enterLengthOneDimension(child)
ret[COLUMN_TYPE_DATA] = self._enterLengthOneDimension(child)

@iterchild
def _enterLengthTwoDimension(self, child, ret):
if isinstance(child, MySqlParser.DecimalLiteralContext):
decimal_literal = ret.setdefault('decimal_literal', [])
decimal_literal.append(self._get_last_name(child))

@iterchild
def _enterLengthOneDimension(self, child, ret):
if isinstance(child, MySqlParser.DecimalLiteralContext):
ret['decimalliteral'] = self._get_last_name(child)
decimal_literal = ret.setdefault('decimal_literal', [])
decimal_literal.append(self._get_last_name(child))
33 changes: 21 additions & 12 deletions mysqltokenparser/sqltypemixins/createtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def enterColumnCreateTable(self, ctx):
children = ctx.children
for child in children:
if isinstance(child, MySqlParser.TableNameContext):
data['table_name'] = self._get_last_name(child)
data[TABLE_NAME] = self._get_last_name(child)

if isinstance(child, MySqlParser.CreateDefinitionsContext):
data['create_definitions'] = self._enterCreateDefinitions(child)
data[CREATE_DEFINITIONS] = self._enterCreateDefinitions(child)

if isinstance(child, MySqlParser.TableOptionEngineContext):
data[TABLE_OPTION_ENGINE] = self._enterTableOptionEngine(
Expand All @@ -33,15 +33,24 @@ def enterColumnCreateTable(self, ctx):
data[TABLE_OPTION_CHARSET] = self._enterTableOptionCharset(
child).get(TABLE_OPTION_CHARSET)

if isinstance(child, MySqlParser.TableOptionCommentContext):
data[TABLE_OPTION_COMMENT] = self._enterTableOptionComment(child)

def _enterTableOptionComment(self, ctx):
try:
return self._get_last_name(ctx.children[2])
except Exception as e:
return ''

@iterchild
def _enterTableOptionCharset(self, child, ret):
if isinstance(child, MySqlParser.CharsetNameContext):
ret['charset'] = self._get_last_name(child)
ret[TABLE_OPTION_CHARSET] = self._get_last_name(child)

@iterchild
def _enterTableOptionEngine(self, child, ret):
if isinstance(child, MySqlParser.EngineNameContext):
ret['engine'] = self._get_last_name(child)
ret[TABLE_OPTION_ENGINE] = self._get_last_name(child)

@iterchild
def _enterCreateDefinitions(self, child, ret):
Expand All @@ -55,7 +64,7 @@ def _enterCreateDefinitions(self, child, ret):

if isinstance(child, MySqlParser.IndexDeclarationContext):
indexs = ret.setdefault('indexs', {})
indexs.setdefault('common_key', []).append(self._enterIndexDeclaration(child))
indexs.setdefault(COMMON_KEY, []).append(self._enterIndexDeclaration(child))

@iterchild
def _enterIndexDeclaration(self, child, ret):
Expand All @@ -66,21 +75,21 @@ def _enterIndexDeclaration(self, child, ret):
@iterchild
def _enterSimpleIndexDeclaration(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
ret['indexname'] = self._get_last_name(child)
ret[INDEX_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.IndexColumnNamesContext):
ret['columnnames'] = self._enterIndexColumnNames(child).get('columns', [])
ret['columns'] = self._enterIndexColumnNames(child).get('columns', [])

@iterchild
def _enterConstraintDeclaration(self, child, ret):
if isinstance(child, MySqlParser.PrimaryKeyTableConstraintContext):
ret['primary_key'] = self._enterPrimaryKeyTableConstraint(child)
ret[PRIMARY_KEY] = self._enterPrimaryKeyTableConstraint(child)
if isinstance(child, MySqlParser.UniqueKeyTableConstraintContext):
ret['unique_key'] = self._enterUniqueKeyTableConstraint(child)
ret[UNIQUE_KEY] = self._enterUniqueKeyTableConstraint(child)

@iterchild
def _enterUniqueKeyTableConstraint(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
ret['indexname'] = self._get_last_name(child)
ret[INDEX_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.IndexColumnNamesContext):
ret['columns'] = self._enterIndexColumnNames(child).get('columns', [])

Expand All @@ -92,6 +101,6 @@ def _enterPrimaryKeyTableConstraint(self, child, ret):
@iterchild
def _enterColumnDeclaration(self, child, ret):
if isinstance(child, MySqlParser.UidContext):
ret['columnname'] = self._get_last_name(child)
ret[COLUMN_NAME] = self._get_last_name(child)
if isinstance(child, MySqlParser.ColumnDefinitionContext):
ret['columndefinition'] = self._enterColumnDefinition(child)
ret[COLUMN_DEFINITION] = self._enterColumnDefinition(child)
1 change: 1 addition & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
))

from mysqltokenparser import mysqltokenparser
from mysqltokenparser import constant
26 changes: 14 additions & 12 deletions tests/test_createtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from helper import mysqltokenparser as mtp
from helper import constant as _c


@pytest.fixture
Expand All @@ -25,10 +26,11 @@ def test_createtable(response):

sql = u"""
CREATE TABLE tab_name (
id int NOT NULL AUTO_INCREMENT COMMENT '主键',
uid int NOT NULL COMMENT '唯一流水id',
id int NOT NULL AUTO_INCREMENT PRIMARY KEY COMMENT '主键',
uid int(2) NOT NULL COMMENT '唯一流水id',
okmysite DECIMAL(6,2) NOT NULL COMMENT 'cccc',
name varchar(20) NOT NULL DEFAULT '' COMMENT '名称',
amount int NOT NULL DEFAULT 0 COMMENT '数量',
amount DOUBLE(6,2) NOT NULL DEFAULT 0 COMMENT '数量',
create_date date NOT NULL DEFAULT '1000-01-01' COMMENT '创建日期',
create_time datetime DEFAULT '1000-01-01 00:00:00' COMMENT '创建时间',
update_time timestamp default current_timestamp on update current_timestamp COMMENT '更新时间(会自动更新,不需要刻意程序更新)',
Expand All @@ -44,20 +46,20 @@ def test_createtable(response):
assert isinstance(tokens, dict)

hope_tablename = 'tab_name'
assert hope_tablename == tokens['data']['data']['table_name']
assert hope_tablename == tokens['data']['data'][_c.TABLE_NAME]

hope_engine = 'InnoDB'
assert hope_engine == tokens['data']['data']['engine']
assert hope_engine == tokens['data']['data'][_c.TABLE_OPTION_ENGINE]

hope_charset = 'utf8'
assert hope_charset == tokens['data']['data']['charset']
assert hope_charset == tokens['data']['data'][_c.TABLE_OPTION_CHARSET]

hope_column_len = 7
assert hope_column_len == len(tokens['data']['data']['create_definitions']['columns'])
hope_column_len = 8
assert hope_column_len == len(tokens['data']['data'][_c.CREATE_DEFINITIONS]['columns'])

hope_common_index_len = 2
assert hope_common_index_len == len(tokens['data']['data']['create_definitions']['indexs']['common_key'])
assert hope_common_index_len == len(tokens['data']['data'][_c.CREATE_DEFINITIONS]['indexs'][_c.COMMON_KEY])

hope_columnname = ["id", "uid", "name", "amount", "create_date", "create_time", "update_time"]
for i in tokens['data']['data']['create_definitions']['columns']:
assert i['columnname'] in hope_columnname
hope_columnname = ["id", "uid", "name", "amount", "create_date", "create_time", "update_time", "okmysite"]
for i in tokens['data']['data'][_c.CREATE_DEFINITIONS]['columns']:
assert i[_c.COLUMN_NAME] in hope_columnname

0 comments on commit 2f5a382

Please sign in to comment.