In [None]:
#hide
%load_ext autoreload
%autoreload 2  # pick up edits to the library modules live, without restarting the kernel
%config Completer.use_jedi = False  # workaround for buggy jedi

In [None]:
# default_exp format_file

In [None]:
#export
import re
import os
import tempfile
from glob import glob
from fastcore.script import call_parse, Param, store_true
from sql_formatter.core import *
from sql_formatter.utils import *
from sql_formatter.validation import *

# format_file

> Functions to format a SQL file with multiple queries and SQL statements

In [None]:
#hide
from nbdev.showdoc import *

## Use-Case

Assume you have a file called sql_file.sql containing SQL statements and queries.

After reading it into Python, we could have something like this:

In [None]:
# contents of a hypothetical `sql_file.sql`: comments, non-query statements (`use ...`) and two queries
sql_file = """
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;

create or replace view first_view as -- my first view
select a.car_id,
       b.car_name, sum(a.price) over (partition by b.car_name order by a.car_id) as sum_price, a.price,
from sales as a left join (select car_id, car_name, from cars) as b 
on a.car_id = b.car_id
where car_id>1 and car_id<=100 order by b.car_name;

-- Table no. 1 --
create or replace table first_table as -- my first table
select car_id,
       avg(price) as avg_price,
from first_view
group by car_id order by car_id;

--- End of file ---
""".strip()

We would then like to format the SQL queries in this file while leaving every other (non-query) SQL statement untouched. For the example above, we would like to get something like this:

In [None]:
# desired result: each query formatted and preceded by two blank lines, everything else untouched,
# and the file ending with a single newline
expected_sql_file = """
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       sum(a.price) OVER (PARTITION BY b.car_name
                          ORDER BY a.car_id) as sum_price,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100
ORDER BY b.car_name;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id
ORDER BY car_id;

--- End of file ---
""".lstrip()

### Formatting philosophy of SQL files

* Every SQL query is separated from the preceding text by two blank lines
* Every SQL query is formatted via `format_sql`

### Main function to format the SQL commands in a file

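This function also applies basic validation and aborts formatting if, after splitting by semicolon, a statement contains the `CREATE ... TABLE / VIEW` keywords at least twice (warning the user that they may have forgotten a semicolon), has unbalanced parentheses, or has an unbalanced `case when ... end`. In that case it returns a dictionary describing the errors instead of the formatted text.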

In [None]:
#export
def _offset_val_lines(validations, container=list):
    """Collect the line numbers reported by failing validations, shifted to positions in the whole text.

    `validations` holds one validation result per statement (dicts with the keys
    "exit_code", "val_lines" and "total_lines"). The "val_lines" of a statement are
    shifted by the sum of "total_lines" of all statements before it. Only statements
    whose "exit_code" is 1 are reported; each one becomes `container(...)` of its
    shifted line numbers.
    """
    file_lines = []
    offset = 0  # running sum of "total_lines" of the statements visited so far
    for val in validations:
        if val["exit_code"] == 1:
            file_lines.append(container(line + offset for line in val["val_lines"]))
        offset += val["total_lines"]
    return file_lines


def format_sql_commands(s):
    """Format the SQL queries in `s`, leaving every other SQL statement untouched.

    `s` is stripped and split into statements with `split_by_semicolon`. Every statement
    is checked with `validate_semicolon`, `validate_balanced_parenthesis` and
    `validate_case_when` before anything is formatted.

    Returns:
        str: if all validations pass. Every piece accepted by `check_sql_query` and not
            marked by `check_skip_marker` is formatted with `format_sql` and preceded by
            two blank lines. Runs of 4+ newlines are collapsed to 3. The result is
            stripped and ends with a single newline.
        dict: otherwise, with one entry per failed validation:
            "semicolon" (error_code 2), "unbalanced_parenthesis" (error_code 3) and
            "unbalanced_case" (error_code 4). Each entry's "lines" holds, per failing
            statement, the offending line numbers in the stripped text. These are tuples
            for "semicolon" and lists otherwise.
    """
    s = s.strip()  # strip file contents
    split_s = split_by_semicolon(s)  # split by query
    # validate semicolon (two CREATE statements within one query)
    validations_semicolon = [validate_semicolon(sp) for sp in split_s]
    val_summary_semicolon = sum(val["exit_code"] for val in validations_semicolon)
    # validate balanced parenthesis
    validations_balanced = [validate_balanced_parenthesis(sp) for sp in split_s]
    val_summary_balanced = sum(val["exit_code"] for val in validations_balanced)
    # validate balanced case when ... end
    val_case_end_balanced = [validate_case_when(sp) for sp in split_s]
    val_summary_case = sum(val["exit_code"] for val in val_case_end_balanced)

    if val_summary_semicolon + val_summary_balanced + val_summary_case != 0:
        # validation failed: report the offending lines instead of formatting
        error_dict = {}
        if val_summary_semicolon > 0:
            error_dict["semicolon"] = {
                "error_code": 2,
                "lines": _offset_val_lines(validations_semicolon, tuple),
            }
        if val_summary_balanced > 0:
            error_dict["unbalanced_parenthesis"] = {
                "error_code": 3,
                "lines": _offset_val_lines(validations_balanced, list),
            }
        if val_summary_case > 0:
            error_dict["unbalanced_case"] = {
                "error_code": 4,
                "lines": _offset_val_lines(val_case_end_balanced, list),
            }
        return error_dict

    # a statement that starts with a comment on the same line as the previous ";"
    check_comment_after_semicolon = re.compile(r"[\r\t\f\v ]*(?:\/\*|--)")
    # such a statement is split once at the first newline / "create" / "select", keeping
    # the separator (capturing group), so only the remainder can be formatted.
    # NOTE(review): the split is case-sensitive, so only lowercase create/select match.
    split_comment_after_semicolon = re.compile("((?:\n|create|select))")
    split_s_out = []
    for statement in split_s:
        if check_comment_after_semicolon.match(statement):
            parts = split_comment_after_semicolon.split(statement, maxsplit=1)
        else:
            parts = [statement]
        formatted_parts = [
            "\n\n\n" + format_sql(part).strip()
            if check_sql_query(part) and not check_skip_marker(part)
            else part
            for part in parts
        ]
        split_s_out.append("".join(formatted_parts))
    # join by semicolon and remove starting and ending newlines
    formatted_s = ";".join(split_s_out).strip()
    # collapse runs of 4+ newlines into 3 (at most two blank lines)
    formatted_s = re.sub(r"\n{4,}", "\n\n\n", formatted_s)
    # add newline at the end of file
    return formatted_s + "\n"

Basic file formatting

In [None]:
# the example file above must format exactly into `expected_sql_file`
assert_and_print(format_sql_commands(sql_file), expected_sql_file)

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       sum(a.price) OVER (PARTITION BY b.car_name
                          ORDER BY a.car_id) as sum_price,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100
ORDER BY b.car_name;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id
ORDER BY car_id;

--- End of file ---



Use the `/*skip-formatter*/` marker to leave a query unformatted

In [None]:
# a statement preceded by the /*skip-formatter*/ marker is kept verbatim,
# while the following query is still formatted
assert_and_print(
    format_sql_commands("""
use database my_database;

/*skip-formatter*/
create Or replace View my_view aS
select asdf, qwer
from table1;

create or replace table my_table As
Select asdf, qwer
From table2
group by asdf;
"""),
    """use database my_database;

/*skip-formatter*/
create Or replace View my_view aS
select asdf, qwer
from table1;


CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
       qwer
FROM   table2
GROUP BY asdf;
""")

use database my_database;

/*skip-formatter*/
create Or replace View my_view aS
select asdf, qwer
from table1;


CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
       qwer
FROM   table2
GROUP BY asdf;



In [None]:
# a file with a single query: the result has no leading blank lines
assert_and_print(
    format_sql_commands("""
create or replace table my_table As
Select asdf, qwer
From table2
group by asdf;
"""),
    """
CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
       qwer
FROM   table2
GROUP BY asdf;
""".lstrip())

CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
       qwer
FROM   table2
GROUP BY asdf;



In [None]:
# semicolons inside comments and string literals do not split statements, and a
# comment after a statement's closing semicolon stays on that same line
assert_and_print(
    format_sql_commands(
"""
create table my_table As
select asdf, Qwer, /* ; */
qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1; /* Some comment */

create view my_view As
Select asdf
From my_table; /* Another comment */
"""
    ),
"""
CREATE TABLE my_table AS
SELECT asdf,
       qwer, /* ; */
       qwer2, -- ;
       replace(';', '', qwer3) as qwer4
FROM   table1; /* Some comment */


CREATE VIEW my_view AS
SELECT asdf
FROM   my_table; /* Another comment */
""".lstrip()
)

CREATE TABLE my_table AS
SELECT asdf,
       qwer, /* ; */
       qwer2, -- ;
       replace(';', '', qwer3) as qwer4
FROM   table1; /* Some comment */


CREATE VIEW my_view AS
SELECT asdf
FROM   my_table; /* Another comment */



In [None]:
# multi-word CREATE headers (`or replace transient table`) are upper-cased as a whole
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
Select asdf
From my_table;
""".lstrip()
    ),
"""
CREATE OR REPLACE TRANSIENT TABLE my_table AS
SELECT asdf,
       qwer, /* ; */
       qwer2, -- ;
       replace(';', '', qwer3) as qwer4
FROM   table1;


CREATE VIEW my_view AS
SELECT asdf
FROM   my_table;
""".lstrip()
)

CREATE OR REPLACE TRANSIENT TABLE my_table AS
SELECT asdf,
       qwer, /* ; */
       qwer2, -- ;
       replace(';', '', qwer3) as qwer4
FROM   table1;


CREATE VIEW my_view AS
SELECT asdf
FROM   my_table;



If validation fails, the function returns a dictionary with information about the errors instead of the formatted queries

Semicolon validation error

In [None]:
# missing semicolon after `from table1`: two CREATE keywords end up in one statement,
# reported with the lines of both (1 and 7)
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
Select asdf
From my_table;
""".lstrip()
    ), 
    {"semicolon": {"error_code": 2, "lines": [(1, 7)]}}
)

{'semicolon': {'error_code': 2, 'lines': [(1, 7)]}}


Unbalanced parenthesis error

In [None]:
# unbalanced parentheses in both statements; per the expected lines, parentheses
# inside comments (`-- ; ()`, `/* ) */`) are not counted
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ; ()
( /* ) */
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
(Select asdf
From my_table;
""".lstrip()
    ), 
    {"unbalanced_parenthesis": {"error_code": 3, "lines": [[3, 4], [9]]}}
)

{'unbalanced_parenthesis': {'error_code': 3, 'lines': [[3, 4], [9]]}}


Unbalanced parenthesis + semicolon error

In [None]:
# both errors at once: the missing semicolon merges the two statements, so the
# parenthesis errors are reported as a single group ([[3, 4, 9]])
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ; ()
( /* ) */
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
(Select asdf
From my_table;
""".lstrip()
    ), 
    {
        "semicolon": {"error_code": 2, "lines": [(1, 8)]},        
        "unbalanced_parenthesis": {"error_code": 3, "lines": [[3, 4, 9]]},
    }
)

{'semicolon': {'error_code': 2, 'lines': [(1, 8)]}, 'unbalanced_parenthesis': {'error_code': 3, 'lines': [[3, 4, 9]]}}


Unbalanced case when ... end

In [None]:
# `case when` without a matching `end` on line 3
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
case when asdf = 1 then 1 as qwer,
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
""".lstrip()
    ), 
    {"unbalanced_case": {"error_code": 4, "lines": [[3]]}}
)

{'unbalanced_case': {'error_code': 4, 'lines': [[3]]}}


### Function to format 1 SQL file

In [None]:
#export
def format_sql_file(f):
    """Format file `f` with SQL commands and overwrite the file.

    The contents are passed to `format_sql_commands`. The file is rewritten only when
    the formatted text differs from the current contents.

    Return exit_code:
    * 0 = Everything already formatted (file left untouched)
    * 1 = Formatting applied
    * 2 = Problem detected, formatting aborted (warnings printed, file left untouched)
    """
    with open(f, "r") as file:
        sql_commands = file.read()
    # format SQL statements; a dict means validation failed
    formatted_file = format_sql_commands(sql_commands)
    if isinstance(formatted_file, dict):
        print(f"Something went wrong in file: {f}")
        # (error key, warning headline, hint) for every error `format_sql_commands` reports
        warnings = (
            (
                "semicolon",
                "[WARNING] Identified CREATE keyword at least twice within the same query ",
                "You may have forgotten a semicolon (;) to delimit the queries",
            ),
            (
                "unbalanced_parenthesis",
                "[WARNING] Identified unbalanced parenthesis ",
                "You should check your parenthesis",
            ),
            (
                "unbalanced_case",
                "[WARNING] Identified unbalanced case when ... end ",
                "You should check for missing case or end keywords",
            ),
        )
        for key, headline, hint in warnings:
            if key in formatted_file:
                print(headline + f"at lines {formatted_file[key]['lines']}\n" + hint)
        print(f"Aborting formatting for file {f}")
        return 2
    if formatted_file == sql_commands:
        return 0  # already formatted: do not touch the file
    # overwrite file (without shadowing the path argument `f`)
    with open(f, "w") as file:
        file.write(formatted_file)
    return 1

In [None]:
# round trip through a real file: write the example, format it in place, read it back
with tempfile.NamedTemporaryFile(mode="r+") as tmp_sql:
    tmp_sql.write(sql_file)
    tmp_sql.seek(0)  # flush the write and rewind so the formatted content can be read
    format_sql_file(tmp_sql.name)
    formatted_file = tmp_sql.read()
assert_and_print(formatted_file, expected_sql_file)

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       sum(a.price) OVER (PARTITION BY b.car_name
                          ORDER BY a.car_id) as sum_price,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100
ORDER BY b.car_name;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id
ORDER BY car_id;

--- End of file ---



In [None]:
# on a validation error `format_sql_file` leaves the file unchanged and returns exit code 2
sql_forgotten_semicolon = """
create or replace transient table my_table As
select asdf, Qwer, /* ; */
qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
Select asdf
From my_table;
""".lstrip()
with tempfile.NamedTemporaryFile(mode="r+") as file:
    file.write(sql_forgotten_semicolon)
    file.seek(0)
    exit_code = format_sql_file(file.name)
    formatted_file = file.read()
assert_and_print( # no formatting
    formatted_file,
    sql_forgotten_semicolon
)
assert exit_code == 2

Something went wrong in file: /tmp/tmp6lglp_mv
You may have forgotten a semicolon (;) to delimit the queries
Aborting formatting for file /tmp/tmp6lglp_mv
Aborting formatting for file /tmp/tmp6lglp_mv
create or replace transient table my_table As
select asdf, Qwer, /* ; */
qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
Select asdf
From my_table;



In [None]:
# same check for unbalanced parentheses only. The fixture name is reused from the
# previous cell, although here every statement does end with a semicolon.
sql_forgotten_semicolon = """
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
(Select asdf
From my_table;
""".lstrip()
with tempfile.NamedTemporaryFile(mode="r+") as file:
    file.write(sql_forgotten_semicolon)
    file.seek(0)
    exit_code = format_sql_file(file.name)
    formatted_file = file.read()
assert_and_print( # no formatting
    formatted_file,
    sql_forgotten_semicolon
)
assert exit_code == 2

Something went wrong in file: /tmp/tmp5qkr0ekj
You should check your parenthesis
Aborting formatting for file /tmp/tmp5qkr0ekj
Aborting formatting for file /tmp/tmp5qkr0ekj
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
(Select asdf
From my_table;



In [None]:
# same check with both a missing semicolon and unbalanced parentheses
sql_forgotten_semicolon = """
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
(Select asdf
From my_table;
""".lstrip()
with tempfile.NamedTemporaryFile(mode="r+") as file:
    file.write(sql_forgotten_semicolon)
    file.seek(0)
    exit_code = format_sql_file(file.name)
    formatted_file = file.read()
assert_and_print( # no formatting
    formatted_file,
    sql_forgotten_semicolon
)
assert exit_code == 2

Something went wrong in file: /tmp/tmpw65h7x33
You may have forgotten a semicolon (;) to delimit the queries
You should check your parenthesis
Aborting formatting for file /tmp/tmpw65h7x33
Aborting formatting for file /tmp/tmpw65h7x33
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
(Select asdf
From my_table;



### Function to format many SQL files

With a built-in command-line interface via `fastcore`

In [None]:
#export
@call_parse
def format_sql_files(
    files: Param(help='(Relative) path to SQL files. You can also use wildcard using "*.sql"', type=str, nargs="+"),
    recursive: Param(help="Should files also be searched in subfolders?", type=store_true)=False
):
    "Format SQL `files`"
    # A single argument containing "*" (e.g. a quoted pattern the shell did not expand)
    # is expanded here. With `recursive`, the pattern is prefixed with "**" so that
    # subfolders are searched as well.
    if len(files) == 1 and "*" in files[0]:
        pattern = os.path.join("**", files[0]) if recursive else files[0]
        files = glob(pattern, recursive=recursive)
    # exit codes of format_sql_file: 0 = unchanged, 1 = formatted, 2 = aborted
    exit_codes = [format_sql_file(file) for file in files]
    if sum(exit_codes) == 0:
        print("Nothing to format, everything is fine!")
    elif 2 in exit_codes:
        print("Formatting aborted for some files, see the warnings above!")
    else:
        print("All specified files were formatted!")

In [None]:
# format two copies of the example file in one call and check both results
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_paths = [os.path.join(tmp_dir, name) for name in ("tmp", "tmp2")]
    for path in tmp_paths:
        with open(path, "w") as fh:
            fh.write(sql_file)
    format_sql_files(tmp_paths)
    for path in tmp_paths:
        with open(path, "r") as fh:
            formatted_file = fh.read()
        assert_and_print(formatted_file, expected_sql_file)

All specified files were formatted!
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       sum(a.price) OVER (PARTITION BY b.car_name
                          ORDER BY a.car_id) as sum_price,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100
ORDER BY b.car_name;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id
ORDER BY car_id;

--- End of file ---

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       sum(a.price) OVER (PARTITION BY b.car_name
                          ORDER BY 

In [None]:
#hide
from nbdev.export import notebook2script
# convert the project's notebooks into the library modules (see the "Converted ..." output below)
notebook2script()

Converted 00_core.ipynb.
Converted 01_format_file.ipynb.
Converted 02_utils.ipynb.
Converted 03_validation.ipynb.
Converted 04_release.ipynb.
Converted index.ipynb.
