In [None]:
#hide
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False  # workaround for buggy jedi

In [None]:
# default_exp validation

In [None]:
#export
import math
from sql_formatter.utils import *

# validation

> Functions for basic SQL validation

In [None]:
#hide
from nbdev.showdoc import *

## Motivation

Sometimes users make syntax errors that in turn make the formatter fail. We therefore want to catch these errors before formatting and give the user a hint about where the error lies.

### Missing semicolon to separate queries

One mistake that would make the formatter fail is SQL queries that are not properly delimited by semicolons. We therefore perform a basic validation and look for the keyword CREATE appearing twice in a query. As this cannot happen within one query, the validation should fail and point out to the user that she / he may have forgotten a semicolon.

In [None]:
#export
def validate_semicolon(s):
    """Check query `s` for a semicolon that was probably forgotten.

    More than one entry returned by `identify_create_table_view(s)` is taken
    as the sign of two statements that were not separated.

    Returns a dict with the keys
    * exit_code: 0 = at most one entry found, 1 = more than one entry found
    * total_lines: result of `count_lines(s)`
    * val_lines: only if exit_code is 1, the entries returned by
      `identify_create_table_view` (line numbers in the notebook's examples)
    """
    create_lines = identify_create_table_view(s)
    n_lines = count_lines(s)
    # guard clause: zero or one CREATE cannot hide a missing semicolon
    if len(create_lines) <= 1:
        return {"exit_code": 0, "total_lines": n_lines}
    return {
        "exit_code": 1,
        "total_lines": n_lines,
        "val_lines": create_lines
    }

In [None]:
# Two CREATE statements with no semicolon in between: lines 2 and 5 are reported
assert_and_print(
    validate_semicolon(
"""
create or replace table my_table as
select asdf, qwer from table1

create view my_view as select asdf from my_table
"""
    ), {"exit_code": 1, "val_lines": [2, 5], "total_lines": 5}
)

{'exit_code': 1, 'total_lines': 5, 'val_lines': [2, 5]}


This should not throw an error because it is not CREATE TABLE / VIEW twice but CREATE TASK + CREATE TABLE

In [None]:
# CREATE TASK followed by CREATE TABLE must not be flagged
assert_and_print(
    validate_semicolon(
"""
create or replace task my_task as
create or replace table my_table as
select asdf, qwer from table1;

"""
    ), {"exit_code": 0, "total_lines": 5}
)

{'exit_code': 0, 'total_lines': 5}


In [None]:
# A single CREATE TABLE passes, even without a trailing semicolon
assert_and_print(
    validate_semicolon(
"""
create or replace table my_table as
select asdf, qwer from table1

"""
    ), {"exit_code": 0, "total_lines": 4}
)

{'exit_code': 0, 'total_lines': 4}


### Unbalanced parenthesis

In [None]:
#export
def validate_balanced_parenthesis(s):
    """Validate query `s` by looking for unbalanced parenthesis.

    Parentheses inside `/* */` comments, `--` comments, single-quoted text
    and double-quoted text are ignored.

    Returns a dict with the keys
    * exit_code: 0 = balanced parenthesis,
      1 = unbalanced parenthesis (an unclosed `(` or a `)` without partner)
    * total_lines: result of `count_lines(s)`
    * val_lines: only if exit_code is 1, result of `find_line_number`
      for the offending positions

    A `)` without partner is reported immediately; unclosed `(` are
    reported together once the whole string has been scanned.
    """
    open_positions = []  # stack with the indices of "(" that are not closed yet
    # Construct the scanner is currently inside of. At most one can be active:
    # None, "block_comment", "line_comment", "single_quote" or "double_quote"
    context = None
    i = 0
    n = len(s)
    while i < n:
        c = s[i]
        two = s[i:i+2]
        step = 1  # two-character delimiters are consumed as a whole (step = 2)
        if context is None:
            if c == "(":
                open_positions.append(i)
            elif c == ")":
                if len(open_positions) == 0:
                    # closing parenthesis without an opening one
                    return {
                        "exit_code": 1,
                        "val_lines": find_line_number(s, [i]),
                        "total_lines": count_lines(s)
                    }
                open_positions.pop()
            elif two == "/*":  # opening comment /*
                context = "block_comment"
                # skip the "*" so that "/*/" is not read as open + close
                step = 2
            elif two == "--":  # opening comment --
                context = "line_comment"
                step = 2
            elif c == "'":  # opening quote '
                context = "single_quote"
            elif c == '"':  # opening quote "
                context = "double_quote"
        elif context == "block_comment":
            if two == "*/":  # closing comment */
                context = None
                # skip the "/" so that "*/*" does not open a new comment
                step = 2
        elif context == "line_comment":
            # the -- comment ends with the line or with the "[c]" marker
            # NOTE(review): "[c]" looks like an end-of-comment marker set elsewhere
            # in the package -- confirm against the utils module
            if c == "\n" or s[i:i+3] == "[c]":
                context = None
        elif context == "single_quote":
            if c == "'":  # closing quote '
                context = None
        elif context == "double_quote":
            if c == '"':  # closing quote "
                context = None
        i += step
    if len(open_positions) == 0:
        return {
            "exit_code": 0,
            "total_lines": count_lines(s)
        }
    # at least one "(" was never closed
    return {
        "exit_code": 1,
        "val_lines": find_line_number(s, open_positions),
        "total_lines": count_lines(s)
    }

In [None]:
    validate_balanced_parenthesis("() () ( () )")  # bare call on a balanced input, result shown as cell output

{'exit_code': 0, 'total_lines': 0}

In [None]:
# Consecutive and nested pairs on one line are balanced
assert_and_print(
    validate_balanced_parenthesis("() () ( () )"),
    {"exit_code": 0, "total_lines": 0}
)

{'exit_code': 0, 'total_lines': 0}


In [None]:
# The "(" on lines 2 and 5 are never closed, so both lines are reported
assert_and_print(
    validate_balanced_parenthesis(
"""
(
(
)
(
"""
    ),
    {"exit_code": 1, "val_lines": [2, 5], "total_lines": 5}
)

{'exit_code': 1, 'val_lines': [2, 5], 'total_lines': 5}


In [None]:
# Parentheses inside the -- comment are ignored; the ")" on line 4 has no partner
assert_and_print(
    validate_balanced_parenthesis(
"""
( )
-- ) ( )( ) ()
)
"""
    ),
    {"exit_code": 1, "val_lines": [4], "total_lines": 4}
)

{'exit_code': 1, 'val_lines': [4], 'total_lines': 4}


In [None]:
# Adjacent pairs without whitespace in between are balanced
assert_and_print(
    validate_balanced_parenthesis("( )( )"),
    {"exit_code": 0, "total_lines": 0}
)

{'exit_code': 0, 'total_lines': 0}


### Unbalanced `case` ... `end`

Sometimes we may forget to write the end of a case statement

In [None]:
#export
def _whole_word_positions(s, positions, word_len):
    "Keep the `positions` in `s` where the match of length `word_len` is not part of a longer identifier"
    def is_identifier_char(c):
        # the empty string (start / end of `s`) counts as a boundary
        return c.isalnum() or c == "_"
    kept = []
    for pos in positions:
        stop = pos + word_len
        before = s[max(pos - 1, 0):pos]  # "" if the match starts the string
        after = s[stop:stop + 1]  # "" if the match ends the string
        if not is_identifier_char(before) and not is_identifier_char(after):
            kept.append(pos)
    return kept


def validate_case_when(s):
    """Validate query `s` looking for unbalanced case ... end.

    The i-th `case` is paired with the i-th `end`; a pair is reported if one
    of the two is missing or if the `end` comes before the `case`.
    Only whole words are counted, so identifiers that merely contain the
    keywords (e.g. `case_field`, `end_file`) are ignored.

    Returns a dict with the keys
    * exit_code: 0 = balanced, 1 = unbalanced case ... end
    * total_lines: result of `count_lines(s)`
    * val_lines: only if exit_code is 1, one line number per reported pair:
      the line of the `case`, or the line of the `end` if the `case` is missing
    """
    # candidate positions come from `identify_in_sql`; the notebook's examples
    # rely on it skipping comments. Matches inside longer identifiers are dropped here
    case_pos = _whole_word_positions(s, identify_in_sql("case", s), len("case"))
    end_pos = _whole_word_positions(s, identify_in_sql("end", s), len("end"))
    # right padding to the same length (no-op if both already have the same length)
    # missing case -> infinity > any end; missing end -> -1 < any case
    n_pairs = max(len(case_pos), len(end_pos))
    case_pos = case_pos + [math.inf] * (n_pairs - len(case_pos))
    end_pos = end_pos + [-1] * (n_pairs - len(end_pos))
    val_lines = []
    for case, end in zip(case_pos, end_pos):
        if case > end:
            # report the line of the case, or of the end if the case is missing
            anchor = case if case != math.inf else end
            val_lines.append(find_line_number(s, [anchor])[0])
    validation = {
        "exit_code": 0,
        "total_lines": count_lines(s)
    }
    if len(val_lines) > 0:
        validation["exit_code"] = 1
        validation["val_lines"] = val_lines
    return validation

In [None]:
# `case` without `end`: reported once, on the line of the `case`
assert_and_print(
    validate_case_when(
"""
select asdf,
case when bla bla as asdf, -- some case when in comments
qwer
from table1
"""
    ),
    {"exit_code": 1, "val_lines": [3], "total_lines": 5}
)

{'exit_code': 1, 'total_lines': 5, 'val_lines': [3]}


In [None]:
# One `case` with its `end`: passes, the `case` mentioned in the comment is not counted
assert_and_print(
    validate_case_when(
"""
select asdf,
case when bla bla end as asdf, -- some case when in comments
qwer
from table1
"""
    ),
    {"exit_code": 0, "total_lines": 5}
)

{'exit_code': 0, 'total_lines': 5}


In [None]:
# The first `case` is closed, the second one on line 5 is missing its `end`
assert_and_print(
    validate_case_when(
"""
select asdf,
case when bla bla end as asdf, -- some case when in comments
qwer,
case when something else as qwer
from table1
"""
    ),
    {"exit_code": 1, "val_lines": [5], "total_lines": 6}
)

{'exit_code': 1, 'total_lines': 6, 'val_lines': [5]}


In [None]:
# `end` without a `case`: reported on the line of the `end`
assert_and_print(
    validate_case_when(
"""
select asdf,
when bla bla end as asdf, -- some case when in comments
qwer
from table1
"""
    ),
    {"exit_code": 1, "val_lines": [3], "total_lines": 5}
)

{'exit_code': 1, 'total_lines': 5, 'val_lines': [3]}


In [None]:
# Both case ... end pairs are closed, so the query is expected to pass even though
# the column aliases `case_field` and `end_file` contain the keywords
assert_and_print(
    validate_case_when(
"""
create or replace transient table my_table as /* some table */
select asdf,
qwer,
case when asdf >=1 
and -- some comment
asdf<=10 and substr(qwer, 1, 2) = 'abc' 
and -- some comment
substr(qwer, 3, 2) = 'qwerty' then 1 /* another comment */
    else 0 end as case_field,
substr(case when asdf=1 then 'a' else 'b' end, 1, 2) as end_file,
asdf2,
asdf3
from table1
"""
    ),
    {"exit_code": 0, "total_lines": 14}
)

Assertion failed

Observed:

{'exit_code': 1, 'total_lines': 14, 'val_lines': [11]}


Expected:

{'exit_code': 0, 'total_lines': 14}


KeyError: 0

In [None]:
#hide
# Run nbdev's notebook-to-script conversion (the cell output lists the converted notebooks)
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_format_file.ipynb.
Converted 02_utils.ipynb.
Converted 03_validation.ipynb.
Converted 04_release.ipynb.
Converted index.ipynb.
