In [None]:
#hide
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False  # workaround for buggy jedi

In [None]:
# default_exp format_file

In [None]:
#export
import re
import os
from fastcore.script import call_parse, Param
from sql_formatter.core import format_sql, assert_and_print

# format_file

> Functions to format a SQL file with multiple queries and SQL statements

In [None]:
#hide
from nbdev.showdoc import *

## Use-Case

Assume you have a file called `sql_file.sql` containing SQL statements and queries.

After reading it into Python, we could have something like this:

In [None]:
# Example contents of `sql_file.sql`: two non-query statements (`use ...`) followed by
# a view and a table definition. Every statement is terminated by `;`, and comments
# are interleaved between the statements.
sql_file = """
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;

create or replace view first_view as -- my first view
select a.car_id,
       b.car_name,
       a.price,
from sales as a left join (select car_id, car_name, from cars) as b 
on a.car_id = b.car_id
where car_id>1 and car_id<=100;

-- Table no. 1 --
create or replace table first_table as -- my first table
select car_id,
       avg(price) as avg_price,
from first_view
group by car_id;

--- End of file ---
""".strip()

We would then like to format the SQL queries in this file, while leaving every other (non-query) SQL statement untouched. For the example above, we would like to get something like this:

In [None]:
# Expected result of formatting `sql_file`: the two queries are rewritten and each is
# preceded by two blank lines, while the `use ...` statements and the trailing comment
# are kept as-is. `.lstrip()` (not `.strip()`) keeps the final newline that the
# formatter appends to the file.
expected_sql_file = """
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id;

--- End of file ---
""".lstrip()

To simplify the problem, we assume for now that every SQL statement is terminated by a semicolon (`;`). Since the file is simply split on semicolons, a `;` inside a comment or a string literal would also split a statement.

### Formatting philosophy of SQL files

* Every SQL query is preceded by two blank lines, separating it from the statement above
* Every SQL query is formatted via `format_sql`

So our first step is to identify which of the semicolon-separated statements are SQL queries.

## Formatting Functions

### SQL-queries identification

SQL queries are recognised by the keywords *CREATE* or *SELECT*.

Every other statement (e.g. `use database my_database`) is not treated as a query and will therefore not be formatted.

In [None]:
#export
def check_sql_query(s):
    """Checks whether `s` is a SQL query.

    `s` counts as a query when the keyword SELECT or CREATE (case-insensitive)
    appears in it as a whole word outside of SQL comments."""
    # Replace `--` line comments and `/* ... */` block comments by a space, so that
    # keywords mentioned only in a comment are ignored. A single alternation scanned
    # left to right correctly handles `--` inside a block comment and vice versa.
    code = re.sub(r"--[^\n]*|/\*.*?\*/", " ", s, flags=re.DOTALL)
    # `\b` restricts the match to whole words, so identifiers such as
    # `created_at` or `selection_db` are not mistaken for keywords.
    return re.search(r"\b(?:select|create)\b", code, flags=re.I) is not None

In [None]:
create_table_query = """
--- Table 1---
creaTe or replace table my_table as
select asdf
from table
where asdf = 1
""".strip()
# a query preceded by a comment line and written in mixed case is still a query
assert check_sql_query(create_table_query)

In [None]:
# surrounding newlines do not prevent a SELECT from being recognised
select_query = "\nSELECT qwer, asdf\n"
assert check_sql_query(select_query)

In [None]:
# `use database` is a plain statement, not a query
use_database = "use database my_database;"
assert not check_sql_query(use_database)

In [None]:
# leading blank lines do not turn a plain statement into a query
use_schema = "\n\nuse schema my_schema;\n"
assert not check_sql_query(use_schema)

### Main function formatting SQL commands

In [None]:
#export
def format_sql_commands(s):
    """Format the SQL queries contained in `s`, leaving all other statements untouched.

    Leading/trailing whitespace of `s` is removed and the text is split on `;`.
    Every piece recognised by `check_sql_query` is passed through `format_sql`
    (without adding a semicolon), stripped and prefixed with three newlines;
    the remaining pieces are kept verbatim. The pieces are re-joined with `;`
    and a single trailing newline is appended."""
    statements = s.strip().split(";")
    pieces = []
    for statement in statements:
        if not check_sql_query(statement):
            # not a query: keep the original text, including its whitespace
            pieces.append(statement)
            continue
        # the three newlines leave two blank lines above each formatted query
        pieces.append("\n\n\n" + format_sql(statement, add_semicolon=False).strip())
    return ";".join(pieces) + "\n"

In [None]:
formatted_sql_file = format_sql_commands(sql_file)
assert_and_print(formatted_sql_file, expected_sql_file)

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id;

--- End of file ---



### Function to format 1 SQL file

In [None]:
#export
def format_sql_file(f):
    """Format file `f` with SQL commands and overwrite the file.
    Return 0 for no change and 1 for formatting adjustments.

    The file is only rewritten when formatting actually changes its contents,
    so files that are already formatted are not touched at all."""
    with open(f, "r") as file:
        sql_commands = file.read()
    formatted_file = format_sql_commands(sql_commands)
    if formatted_file == sql_commands:
        # nothing to do: skip the write instead of rewriting identical contents
        return 0
    # reuse `file` for the handle rather than shadowing the path parameter `f`
    with open(f, "w") as file:
        file.write(formatted_file)
    return 1

In [None]:
import tempfile

# Work inside a throw-away directory: nothing in the working directory is
# clobbered, and the file is cleaned up even if the assertion fails.
with tempfile.TemporaryDirectory() as tmp_dir:
    sql_path = os.path.join(tmp_dir, "sql_file.sql")
    with open(sql_path, "w") as file:
        file.write(sql_file)
    format_sql_file(sql_path)
    with open(sql_path, "r") as file:
        formatted_file = file.read()
    assert_and_print(
        formatted_file,
        expected_sql_file
    )

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id;

--- End of file ---



### Function to format many SQL files

In [None]:
#export
@call_parse
def format_sql_files(
    files: Param(help="Path to SQL files", type=str, nargs="+")
):
    "Format SQL `files` in place; return 1 if any file was reformatted, else 0"
    # Format every file (no short-circuit) and collect the 0/1 results of `format_sql_file`.
    exit_codes = [format_sql_file(file) for file in files]
    # The result used to be discarded; returning it lets a CLI wrapper turn it
    # into the process exit status. NOTE(review): this relies on the installed
    # fastcore `call_parse` forwarding the return value. Confirm for the pinned version.
    if any(exit_codes):
        print("All done!")
        return 1
    print("Nothing to format, everything is fine!")
    return 0

In [None]:
import tempfile

# Format two copies of the example in one call. Using a throw-away directory means
# nothing in the working directory is clobbered or left behind if an assertion fails.
with tempfile.TemporaryDirectory() as tmp_dir:
    sql_paths = [os.path.join(tmp_dir, name) for name in ("tmp", "tmp2")]
    for sql_path in sql_paths:
        with open(sql_path, "w") as file:
            file.write(sql_file)
    format_sql_files(sql_paths)
    for sql_path in sql_paths:
        with open(sql_path, "r") as file:
            formatted_file = file.read()
        assert_and_print(
            formatted_file,
            expected_sql_file
        )

All done!
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) as avg_price
FROM   first_view
GROUP BY car_id;

--- End of file ---

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales as a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) as b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS

In [None]:
#hide
# regenerate the library modules from the `#export` cells of all notebooks
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_format_file.ipynb.
Converted index.ipynb.
