In [None]:
#hide
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False  # workaround for buggy jedi

In [None]:
# default_exp sql_file_formatter

In [None]:
#export
import re
import os
from sql_formatter.core import format_sql, assert_and_print

# sql_file_formatter

> Functions to format a SQL file with multiple queries and SQL statements

In [None]:
#hide
from nbdev.showdoc import *

## Use-Case

Assume you have a file called sql_file.sql containing SQL statements and queries.

After reading it in Python, we could have something like this:

In [None]:
# Example input: two `use` statements followed by two `create ... as select` queries.
# Note the trailing commas at the end of the SELECT lists and the unspaced
# comparison operators. `.strip()` drops the newlines that surround the
# triple-quoted literal.
sql_file = """
use database my_database;
use schema my_schema;

create or replace view first_view as -- my first view
select a.car_id,
       b.car_name,
       a.price,
from sales as a left join (select car_id, car_name, from cars) as b 
on a.car_id = b.car_id
where car_id>1 and car_id<=100;

create or replace table first_table as -- my first table
select car_id,
       avg(price) as avg_price,
from first_view
group by car_id;
""".strip()

print(sql_file)

use database my_database;
use schema my_schema;

create or replace view first_view as -- my first view
select a.car_id,
       b.car_name,
       a.price,
from sales as a left join (select car_id, car_name, from cars) as b 
on a.car_id = b.car_id
where car_id>1 and car_id<=100;

create or replace table first_table as -- my first table
select car_id,
       avg(price) as avg_price,
from first_view
group by car_id;


Then we would like to format the SQL queries in this file, while leaving every other non-query SQL statement untouched. For the example above we would like to have something like this:

In [None]:
# Expected result for `sql_file`: the `use` statements are unchanged, while each
# `create ... as select` query is reformatted and preceded by two blank lines.
expected_sql_file = """
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales AS a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) AS b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) AS avg_price
FROM   first_view
GROUP BY car_id;
""".strip()

print(expected_sql_file)

use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales AS a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) AS b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) AS avg_price
FROM   first_view
GROUP BY car_id;


To simplify the problem, we assume for the moment that the user is kind enough to separate every SQL statement with a semicolon (`;`).

### Formatting philosophy of SQL files

* Every SQL-query is separated from above and below by two new lines
* Every SQL-query is formatted via `format_sql`

So our first step is to identify the lines corresponding to the SQL queries

## Formatting Functions

### SQL-queries identification

SQL queries always start either with *CREATE* or with *SELECT*.

Everything else (for example `use database ...`) is not treated as a SQL query and will therefore not be formatted.

In [None]:
#export
def check_sql_query(s):
    "Checks whether `s` is a SQL query, i.e. starts with SELECT or CREATE (case-insensitive) after optional leading whitespace"
    # `\s*` skips any leading whitespace (spaces, tabs, "\r\n"), not only "\n",
    # so a statement that follows a semicolon on the same line is recognised too.
    # `\b` requires the keyword to end right there, so words that merely start
    # with it (e.g. "selection", "created_at") are not mistaken for a query.
    return bool(re.match(pattern=r"\s*(?:select|create)\b", string=s, flags=re.I))

In [None]:
# A CREATE statement is recognised regardless of keyword casing and a leading newline
assert check_sql_query("""
creaTe or replace table my_table as
select asdf
from table
where asdf = 1
""")

In [None]:
# A SELECT statement preceded by a newline is recognised
assert check_sql_query("""
SELECT qwer, asdf
""")

In [None]:
# A `use` statement does not start with SELECT or CREATE, so it is not a query
assert not check_sql_query("use database my_database;")

In [None]:
# Leading blank lines do not turn a non-query statement into a query
assert not check_sql_query("""

use schema my_schema;
""")

### Main function formatting SQL commands

In [None]:
#export
# Tokens that may legitimately contain a semicolon (quoted strings / identifiers
# and comments) are listed before the bare semicolon, so a semicolon is only
# matched on its own when it is not part of one of them.
_STATEMENT_TOKEN = re.compile(
    r"'(?:''|[^'])*'"    # single-quoted string ('' is an escaped quote)
    r'|"(?:""|[^"])*"'   # double-quoted identifier or string
    r"|--[^\n]*"         # line comment, up to the end of the line
    r"|/\*.*?\*/"        # block comment, possibly spanning several lines
    r"|;",               # statement separator
    flags=re.S,
)


def _split_sql_statements(s):
    "Split `s` at every semicolon that is not inside a quoted string or a comment"
    statements, start = [], 0
    for token in _STATEMENT_TOKEN.finditer(s):
        if token.group() == ";":
            statements.append(s[start:token.start()])
            start = token.end()
    # Text after the last semicolon (empty when `s` ends with a semicolon),
    # mirroring what `s.split(";")` returns
    statements.append(s[start:])
    return statements


def format_sql_commands(s):
    "Format SQL commands in `s`: queries are formatted via `format_sql`, every other statement is kept as is"
    # A plain `s.split(";")` would also cut at semicolons inside string
    # literals or comments; the helper only splits at real separators and
    # yields the same pieces as `split` when there are no quotes or comments.
    split_s = _split_sql_statements(s)
    formatted_split_s = [
        # queries are set apart from the preceding statement by two newlines
        "\n\n" + format_sql(sp, add_semicolon=False)
        if check_sql_query(sp)
        else sp
        for sp in split_s
    ]
    # re-insert the separators removed by the split
    return ";".join(formatted_split_s)

In [None]:
# End-to-end check on the example: formatting `sql_file` must yield `expected_sql_file`
assert_and_print(
    format_sql_commands(sql_file),
    expected_sql_file
)

Correcting mistake: Comma at the end of SELECT statement
Correcting mistake: Comma at the end of SELECT statement
Correcting mistake: Comma at the end of SELECT statement
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales AS a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) AS b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) AS avg_price
FROM   first_view
GROUP BY car_id;


### Function to format SQL files

In [None]:
#export
def format_sql_file(f):
    "Read the SQL file at path `f` and return its content formatted by `format_sql_commands`"
    with open(f, "r") as sql_handle:
        raw_text = sql_handle.read()
    # the file is only read; the formatted text is returned, not written back
    return format_sql_commands(raw_text)

In [None]:
# Round-trip through a scratch file in the working directory
tmp_path = "tmp"
with open(tmp_path, "w") as file:
    file.write(sql_file)
try:
    assert_and_print(
        format_sql_file(tmp_path),
        expected_sql_file
    )
finally:
    # remove the scratch file even when the assertion above fails
    os.remove(tmp_path)

Correcting mistake: Comma at the end of SELECT statement
Correcting mistake: Comma at the end of SELECT statement
Correcting mistake: Comma at the end of SELECT statement
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
       b.car_name,
       a.price
FROM   sales AS a
    LEFT JOIN (SELECT car_id,
                      car_name
               FROM   cars) AS b
        ON a.car_id = b.car_id
WHERE  car_id > 1
   and car_id <= 100;


CREATE OR REPLACE TABLE first_table AS -- my first table
SELECT car_id,
       avg(price) AS avg_price
FROM   first_view
GROUP BY car_id;


In [None]:
#hide
# Convert the project's notebooks with nbdev (see the "Converted ..." output below)
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_sql_file_formatter.ipynb.
Converted index.ipynb.
