<a href="https://colab.research.google.com/github/Shalala06/SQL-Formatting-Python/blob/seb/SQL-Formatting-Python.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
import re
def format_sql(script):
    # Defining key words
    keywords = [
        'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'JOIN', 'INNER JOIN',
        'LEFT JOIN', 'RIGHT JOIN', 'OUTER JOIN', 'INSERT INTO', 'UPDATE', 'SET', 'DELETE',
        'ALTER', 'DROP TABLE', 'PARTITION BY', 'OVER', 'ROW_NUMBER', 'IS', 'IS NOT', 'NULL',
        'CASE', 'WHEN', 'THEN', 'END', 'ON', 'AND', 'OR', 'AS', 'CREATE TABLE', 'DESC', 'ASC'
    ]
    primary_keywords = [
        'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING',
        'UPDATE', 'SET', 'DELETE', 'ALTER'
    ]
    join_keywords = ['LEFT JOIN','INNER JOIN', 'RIGHT JOIN', 'OUTER JOIN']

    special_keywords = ['AND', 'ON', ')', 'OR']

    # Normalize the script into one line
    script = ' '.join(script.split())
    # Capitalize keywords
    def capitalize_keywords(match):
        return match.group(0).upper()
    for keyword in keywords:
        pattern = re.compile(r'\b' + re.escape(keyword) + r'\b', re.IGNORECASE)
        script = pattern.sub(capitalize_keywords, script)
    # Handle DROP TABLE statements
    script = re.sub(r'\bDROP TABLE\b', '\n  DROP TABLE', script, flags=re.IGNORECASE)
    # Handle CREATE TABLE statements
    script = re.sub(r'\bCREATE TABLE\b', '\n\nCREATE TABLE', script, flags=re.IGNORECASE)
    # Drop down and indent primary keywords
    for keyword in primary_keywords:
        script = re.sub(r'\b' + re.escape(keyword) + r'\b(?!\s*\n\s*)', f'\n\n{keyword}\n ', script, flags=re.IGNORECASE)
    # Indent SELECT keyword
    script = re.sub(r'\b' + re.escape("SELECT") + r'\b(?!\s*\n\s*)', f'SELECT\n ', script, flags=re.IGNORECASE)

    join_keywords_with_space = ['LEFT JOIN ','INNER JOIN ', 'RIGHT JOIN ', 'OUTER JOIN ']
    # Indent JOIN keywords
    for join_keyword in join_keywords:
        if join_keyword in join_keywords_with_space:
            script = re.sub(r'\b' + re.escape(join_keyword) + r'\b', '\n\n  ' + join_keyword, script, flags=re.IGNORECASE)
        else:
            script = re.sub(r'\b' + re.escape(join_keyword) + r'\b', '\n\n  ' + join_keyword + ' ', script, flags=re.IGNORECASE)

    # Indent Special keywords
    for keyword in special_keywords:
        script = re.sub(r'\b(?!\s*\n\s*)' + re.escape(keyword) + r'\b(?!\s*\n\s*)', '\n    ' + keyword, script, flags=re.IGNORECASE)

    script = re.sub(r'\)', '\n)', script)

    # New line and comma for multiple columns in clauses
    script = re.sub(r'\s*,\s*', '\n  , ', script)

    def row_number_pos(script):
        row_number_positions = [m.start() for m in re.finditer(r'ROW_NUMBER\b', script, re.IGNORECASE)]
        row_number_elements = []

        for index, row_number_start in enumerate(row_number_positions):
            row_number = ""
            parentheses_count = 0
            for i, char in enumerate(script[row_number_start:]):
                row_number += char
                if i == 0 or i == 1:
                    continue
                if row_number[i-2] == 'A' and row_number[i-1] == 'S':
                    row_number_end = i + row_number_start
                    row_number_elements.append([row_number_start, row_number_end, script[row_number_start: row_number_end]])
                    break

        return row_number_elements, script
    row_number_keywords = [
        'ORDER BY', 'PARTITION BY'
    ]

    def format_row_number(script):
        rn_list = []
        adjusted_rn_list = []

        row_number_elements, script = row_number_pos(script)

        for element in row_number_elements:
          rn_list.append(element[2])
          row_number = ' '.join(element[2].split())

          rn_adjusted = re.sub(r' \)', ')', row_number)
          rn_adjusted = re.sub(r'\s*,\s*', '\n        , ', rn_adjusted)

          for keyword in row_number_keywords:
            rn_adjusted = re.sub(r'\b' + re.escape(keyword) + r'\b(?!\s*\n\s*)', f'\n      {keyword}\n       ', rn_adjusted, flags=re.IGNORECASE)
          adjusted_rn_list.append(rn_adjusted)

        for i in range(len(row_number_elements)):
          script = script.replace(rn_list[i], adjusted_rn_list[i])

        return script

    # Identifies each subquery in the input, it's indentation level and the start and end positions
    def subquery_position(script):
        script = format_row_number(script)
        sq_level_list = []
        sq_start = []
        sq_end = []
        sq = []
        sq_element_list = []

        # Identifies positions of each SELECT in script
        select_positions = [m.start() for m in re.finditer(r'\bSELECT\b', script, re.IGNORECASE)]
        # Loops through positions and count parentheses before each SELECT - Takes the difference between open and closed brackets to calculate
        # subquery/indentation level
        for index, subquery_start in enumerate(select_positions):
            sq_elements = []
            before_select = script[:subquery_start]
            open_count = before_select.count('(')
            close_count = before_select.count(')')
            subquery_level = open_count - close_count
            sq_level_list.append(subquery_level)

            # Loops through the script from each subquery start onwards - i is used as a count to identify subquery end which is indicated
            # by balancing parentheses to match 1 - subqueries are then be extracted from script using start and end positions
            subquery = ""
            open_count_sub = 0
            close_count_sub = 0
            for i, char in enumerate(script[subquery_start:]):
                subquery += char
                if char == '(':
                    open_count_sub += 1
                elif char == ')':
                    close_count_sub += 1
                if close_count_sub - open_count_sub == 1:
                    break
            subquery_end = i + 1 + subquery_start  # end position is one past the last character
            subquery = script[subquery_start:subquery_end]
            sq_elements = [subquery_level, subquery_start, subquery_end, subquery]
            sq_element_list.append(sq_elements)

        return sq_element_list, script

    # Uses the output of find subqueries to indent based on the subquery level
    def indent_subqueries(script):
        sq_element_lists = []
        adjusted_sq = []
        subquery_list = []
        subquery_level_sorted =[]
        subquery_indent_list = []
        end_subqueries = []
        end_subquery_indent_list = []
        end_subquery_list = []
        current_level_list = []
        split_level_list = []
        counter = 0

        sq_level = 0
        sq_start = 1
        sq_end = 2
        sq = 3

        sq_elements, script = subquery_position(script)

        select_positions = [m.start() for m in re.finditer(r'\bSELECT\b', script, re.IGNORECASE)]

        # Extracting subquery blocks
        for elements in sq_elements:
          if elements[sq_level] == 1:
            if current_level_list:
              split_level_list.append(current_level_list)
              current_level_list = []
            current_level_list.append(elements)
          else:
            current_level_list.append(elements)

        if current_level_list:
          split_level_list.append(current_level_list)

        del split_level_list[0]

        # Identifying outer sq end position
        for sq_elements in split_level_list:
          string = ""
          for index, line in enumerate(script[sq_elements[0][sq_end]:].split('\n')):
            if line == "":
              break
            string += '\n' + line
            for idx, char in enumerate(string):
              pos = sq_elements[0][sq_end] + idx

          # Extracting end subqueries within subquery blocks
          end_subqueries.append(script[sq_elements[0][sq_end]-1:pos-1])

          for i in range(len(sq_elements) - 1):
            script_end_to_end = script[sq_elements[i+1][sq_end]-1:sq_elements[i][sq_end]-1]
            if "SELECT" in script_end_to_end:
              select_count = script_end_to_end.count("SELECT")
              nested_sq_element = split_level_list[0][-select_count]
              sq_end_to_end = script[sq_elements[i+1][sq_end]:nested_sq_element[sq_start]] + script[nested_sq_element[sq_end]: sq_elements[i][sq_end]]
              return sq_end_to_end
            else:
              sq_end_to_end = script_end_to_end

            end_subqueries.append(sq_end_to_end)

          # Adjusting subqueries - Separates subqueries into constituent parts if it contains nested subqueries -
          # this breaks the subquery into two parts (prior nested subquery & post nested subquery)
          # Loops through the subqueries and skips the index[0] as this is the entire script
          # If current subquery is in previous subquery then the position of the previous SELECT up to 2 characters before the SELECT
          # of the current subquery will be taken - the previous subquery is updated in sq list to its prior nested subquery part

          for idx in range(len(sq_elements)):

              if sq_elements[idx][sq] in sq_elements[idx - 1][sq]:

                  prev_subquery_start = sq_elements[idx - 1][sq_start]
                  adjusted_sub = script[prev_subquery_start:sq_elements[idx][sq_start] - 2]
                  sq_elements[idx - 1][sq] = adjusted_sub

          # Indenting adjusted sq list
          for element in sq_elements:
              line_indented_list = ""
              for line in element[sq].split('\n'):
                  line_indented = 's' * (element[sq_level] * 4) + line  # Indent with 2 spaces per level
                  line_indented_list += '\n' + line_indented
              subquery_indent_list.append(line_indented_list)
              subquery_list.append(element[sq])

              end_line_indented_list = []
              for line in end_subqueries[counter].split('\n'):
                if line[4:6] in ["ON","AN", "OR", ") "]:                        # Non subquery indentation
                  end_line_indented = 's' * ((element[sq_level] * 4)-2) + line
                  end_line_indented_list.append('\n' + end_line_indented)
                else:
                  end_line_indented = 's' * (element[sq_level] * 4) + line  # Indent with 2 spaces per level
                  end_line_indented_list.append('\n' + end_line_indented)

              if element[sq_level] > 1:
                end_line_indented_list[-1] = end_line_indented_list[-1].replace(f'\n', '')

              end_line_indented_list = "".join(end_line_indented_list)
              end_subquery_indent_list.append(end_line_indented_list)
              end_subquery_list.append(end_subqueries[counter])
              counter += 1
        for i in range(len(subquery_list)):
          if end_subquery_list[i] in ["",")"] :
            continue
          script = script.replace(subquery_list[i], subquery_indent_list[i])
          script = script.replace(end_subquery_list[i], end_subquery_indent_list[i])

        return script

    return indent_subqueries(script)

sql_script = """
SELECT
    customers.customer_id,
    customers.customer_name,
    orders.order_id,
    orders.order_date,
    order_details.quantity,
    products.product_name
FROM
    customers
    INNER JOIN (
        SELECT
            orders.order_id,
            orders.attribute_id
        FROM
            orders
            INNER JOIN (
                SELECT
                    attribute_id
                FROM
                    product_attribute
                INNER JOIN (
                    SELECT
                        attribute
                        , row_number() over (partition by person_id order by order) as rn
                    FROM
                        attribute
                    INNER JOIN (
                        SELECT
                            id
                        FROM
                            product
                    ) AS product4
                    ON product_attribute.attribute_id = product4.attribute_id AND order.order_id IS NOT NULL
                ) AS product3
                ON product_attribute.attribute_id = product3.attribute_id
            WHERE order.product IS NULL) AS product2
            ON orders.attribute_id = product2.attribute_id
        WHERE
            orders.order_id IS NOT NULL
        AND EXISTS (SELECT
                    1
                FROM
                    orders
                      INNER JOIN (select hello from bye) AS opinion ON orders.order_id = o.order_id
                WHERE
                    orders.order_id = o.order_id
                  )
        AND person_id IS NOT NULL
    ) AS orders
    ON customers.order_id = orders.order_id
    INNER JOIN (
        SELECT
            o.order_id,
            o.customer_id
        FROM
            orders o
            INNER JOIN (
                SELECT
                    pa.attribute_id
                    , row_number() over (partition by person_id order by order ) as rn
                FROM
                    product_attribute pa

                    INNER JOIN (
                        SELECT
                            pa.house_number
                            , milk_id
                        FROM
                            Local.cornershop
                            INNER JOIN order_details
                            ON orders.order_id = order_details.order_id
                        ) AS local_shop ON local_shop.milk_id = pa.meowmeow
            ) AS filtered_products ON o.attribute_id = filtered_products.attribute_id
        WHERE
            o.order_id IS NOT NULL AND apples IS NULL
    ) AS orders
    ON customers.customer_id = orders.customer_id
    INNER JOIN products
    ON order_details.product_id = products.product_id
WHERE
  cats<rats>dogs

ORDER BY
    products.product_id;
"""
formatted_sql = format_sql(sql_script)
print(formatted_sql)

 AS product2 
    ON orders.attribute_id = product2.attribute_id 

WHERE
  orders.order_id IS NOT NULL 
    AND EXISTS ( 
    AND person_id IS NOT NULL 
)
