In [91]:
import re
import json
import logging
import datetime
import pkg_resources
pkg_resources.require("sqlparse==0.3.0")
import sqlparse
import pandas as pd

In [158]:
class Formatter:
    """
    Normalize a SQL query, split it into CTEs, and trace which database
    columns it references.

    All pattern matching here is case-sensitive and assumes SQL keywords are
    upper-case and identifiers are lower-case, which is what ``format_query``
    produces for keywords (``keyword_case='upper'``). Run raw SQL through
    ``format_query`` before handing it to the other methods.
    """

    # Start of a CTE definition in formatted SQL: "<alias> AS (SELECT",
    # optionally preceded by WITH. "." does not cross newlines, so a match
    # always begins on the line that carries the alias.
    _CTE_HEADER = re.compile(r"(WITH)*(.*AS\s*\(SELECT)")
    # Text between the optional WITH and the last AS on the header line.
    _CTE_ALIAS = re.compile(r"(WITH)*(.*)AS")
    # A table reference after FROM/JOIN. The alias candidate is captured in a
    # lookahead so that a following "JOIN <table>" on the same line is not
    # consumed and can still be matched on the next iteration.
    _TABLE_REF = re.compile(r"\b(?:FROM|JOIN)\s+(\w\S*)(?=(?:\s+(?:AS\s+)?(\S+))?)")
    # Lower-case identifier, optionally dotted (alias.column, db.table).
    _IDENTIFIER = re.compile(r"[a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)*")
    # A "*" at the start of a line or right after whitespace (e.g. "SELECT *").
    _STAR = re.compile(r"(?:^|\s)\*")
    # Words that may directly follow a table name but are never its alias.
    _NOT_AN_ALIAS = frozenset({
        'AS', 'ON', 'USING', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT',
        'JOIN', 'INNER', 'LEFT', 'RIGHT', 'FULL', 'CROSS', 'OUTER', 'NATURAL',
        'UNION', 'EXCEPT', 'INTERSECT', 'WINDOW',
    })
    # Characters trimmed from the right of table/alias tokens.
    _TRAILING_PUNCTUATION = ')|,'

    def __init__(self, raw_query):
        """
        Args:
            raw_query (str): The unformatted SQL text this instance works on.
        """
        self.raw_query = raw_query

    def format_query(self, raw_query=None):
        """
        Format a query using sqlparse.

        Double quotes are removed, then the text is re-indented, keywords are
        upper-cased and comments are stripped.

        Args:
            raw_query (str, optional): The SQL to format. Defaults to the
                query given to the constructor. (Previously this argument was
                accepted but silently ignored.)
        Returns:
            str: A string of formatted query.
        """
        source = self.raw_query if raw_query is None else raw_query
        unquoted = source.replace('"', '')
        return sqlparse.format(
            unquoted,
            reindent=True,
            keyword_case='upper',
            strip_comments=True,
        )

    def parse_cte(self, query):
        """
        Parse the CTE's.

        Each CTE value keeps its header ("WITH alias AS (" or "alias AS (") and
        whatever trails its body (closing parenthesis, comma, newline); use
        ``_cleanup`` to drop the header. Any text that precedes the first CTE
        header is discarded.

        Args:
            query (str): A formatted query, possibly containing CTE's.
        Returns:
            dict: A dict of CTE's and main query, with keys being CTE aliases
            or "main". Without CTE's the dict holds only "main".
        """
        headers = list(self._CTE_HEADER.finditer(query))
        if not headers:
            return {'main': query}

        starts = [match.start() for match in headers]
        cte_dict = {}

        # Every CTE but the last runs up to the start of the next header.
        for begin, end in zip(starts, starts[1:]):
            chunk = query[begin:end]
            cte_dict[self._cte_alias(chunk)] = chunk

        # The last header is followed by both the last CTE and the main query.
        # Split them where the CTE's own parenthesis closes, so SELECTs nested
        # inside the CTE (subqueries, UNION) stay with the CTE.
        last_start = starts[-1]
        tail = query[last_start:]
        open_paren = headers[-1].end() - len("(SELECT") - last_start
        main_start = self._main_query_start(tail, open_paren)

        last_cte = tail[:main_start]
        cte_dict[self._cte_alias(last_cte)] = last_cte
        cte_dict['main'] = tail[main_start:]
        return cte_dict

    def _cte_alias(self, chunk):
        """Return the alias named on the header line of a CTE chunk."""
        return self._CTE_ALIAS.search(chunk).group(2).strip(' ')

    def _main_query_start(self, tail, open_paren):
        """Return the index in ``tail`` where the main query begins."""
        close_paren = self._matching_paren(tail, open_paren)
        if close_paren is not None:
            start = close_paren + 1
            while start < len(tail) and tail[start].isspace():
                start += 1
            return start

        # Unbalanced parentheses: fall back to the legacy heuristic (the main
        # query starts at the second SELECT), or keep everything in the CTE.
        selects = [match.start() for match in re.finditer(r"SELECT", tail)]
        return selects[1] if len(selects) > 1 else len(tail)

    @staticmethod
    def _matching_paren(text, open_pos):
        """Return the index of the ')' closing the '(' at ``open_pos``, or None."""
        depth = 0
        in_string = False
        for idx in range(open_pos, len(text)):
            char = text[idx]
            if char == "'":
                # Parentheses inside single-quoted literals do not nest. An
                # escaped quote ('') toggles twice and so cancels out.
                in_string = not in_string
            elif in_string:
                continue
            elif char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return idx
        return None

    def _cleanup(self, cte_dict):
        """
        Strip the header from every CTE, i.e. everything up to and including
        its first opening parenthesis. The "main" entry is left untouched, as
        is any CTE without a parenthesis.

        Args:
            cte_dict (dict): The dict produced by ``parse_cte``.
        Returns:
            dict: The same dict object, modified in place.
        """
        # Only existing keys are reassigned, so iterating while updating is safe.
        for cte_name, cte in cte_dict.items():
            if cte_name == 'main':
                continue
            paren = cte.find('(')
            if paren != -1:
                # Slice rather than str.replace: replace() would also delete
                # any later occurrence of the same prefix inside the body.
                cte_dict[cte_name] = cte[paren + 1:]

        return cte_dict

    def _get_table_names(self, line_query):
        """
        Get all tables names mapping from a SQL query.

        A table token that ends in ")" or "," is treated as unaliased, because
        what follows it belongs to an enclosing subquery or to the next table.
        Additional tables of a comma-separated FROM list are not picked up.

        Args:
            line_query (list): The list of lines of a query that is split by \\n.
        Returns:
            dict: A dictionary with table names mapping. If a table is being
            aliased (with or without AS), the key is the alias and the value
            is the table/subquery name. Otherwise, key and value are the same.
        """
        table_name_mapping = dict()

        for line in line_query:
            for match in self._TABLE_REF.finditer(line):
                raw_table, candidate = match.group(1), match.group(2)
                table = raw_table.rstrip(self._TRAILING_PUNCTUATION)

                alias = table
                if raw_table == table and candidate is not None:
                    candidate = candidate.rstrip(self._TRAILING_PUNCTUATION)
                    # ON/USING/WHERE/... after a table mean it has no alias.
                    if candidate and candidate.upper() not in self._NOT_AN_ALIAS:
                        alias = candidate

                table_name_mapping[alias] = table

        return table_name_mapping

    def _get_all_variables(self, query):
        """
        Get all variables including: table names, aliases, column names and
        aliases, and all other lower-case words that are not SQL keywords.

        Lines holding a bare "*" (such as "SELECT *") contribute nothing; the
        other lines are still scanned. Words inside string literals are
        returned too.

        Args:
            query (str): A string of any type of complete query; can nest
                CTE's and/or subqueries.
        Returns:
            list: A list of all variables within the query, in order of
            appearance, duplicates included.
        """
        all_variables = []

        for line in query.split('\n'):
            if self._STAR.search(line):
                continue
            all_variables.extend(self._IDENTIFIER.findall(line))

        return all_variables

    def _get_queried_columns(self, table_names, meta_cols):
        """
        Get all columns by looking up referenced table names in the metacolumn
        table.

        Only names of the form "db.table" are looked up; CTE or subquery
        aliases (no dot) are skipped. A table referenced under several aliases
        is reported once.

        Args:
            table_names (dict): A mapping of alias -> table/subquery name, as
                returned by ``_get_table_names``.
            meta_cols (pandas.DataFrame): Column metadata with one row per
                column, in columns "db_table" and "all_columns".
        Returns:
            list: A list of single-entry dicts, each mapping a table name to
            the set of all columns listed for it in ``meta_cols``.
        """
        queried_cols = []

        # dict.fromkeys de-duplicates while keeping first-seen order.
        for table_name in dict.fromkeys(table_names.values()):
            if len(table_name.split('.')) == 2:
                columns = meta_cols.loc[meta_cols['db_table'] == table_name, 'all_columns']
                queried_cols.append({table_name: set(columns)})

        return queried_cols

    def _map_db_columns(self, var_list, queried_cols, table_alias_mapping):
        """
        Map database columns.

        Args:
            var_list (list): The list of all variables (non-sql reserved words) in query.
            queried_cols (list): Single-entry dicts of table name -> set of
                its columns, as returned by ``_get_queried_columns``.
            table_alias_mapping (dict): The mapping of alias -> table name.
        Returns:
            list: A sorted list of unique db.table.column names referenced by
            the query.
        """
        matched = set()

        for var in set(var_list):
            var = var.strip(' ')

            # A bare table name or alias is not a column.
            if var in table_alias_mapping:
                continue

            parts = var.split('.')

            if len(parts) == 1:
                # Unqualified column: credit every queried table that has it.
                for table_cols in queried_cols:
                    for db_table, columns in table_cols.items():
                        if var in columns:
                            matched.add("{}.{}".format(db_table, var))

            elif len(parts) == 2 and parts[0] in table_alias_mapping:
                # alias.column: resolve the alias, then check that one table.
                db_table = table_alias_mapping[parts[0]]
                column = parts[1]
                for table_cols in queried_cols:
                    if column in table_cols.get(db_table, ()):
                        matched.add("{}.{}".format(db_table, column))

        # Sorted so the result does not depend on set iteration order.
        return sorted(matched)

In [159]:
# Sample input: one CTE (opportunity_to_name) wrapped by a SELECT * main query.
# The text is deliberately unformatted and carries `--` SQL comments and a
# '{run_date}' placeholder, so it exercises the formatting step below.
raw_query = """WITH opportunity_to_name AS
(
    SELECT  -- make sure there is only one name per id
    id AS account_id,
    name AS account_name
    FROM
    sfdc.accounts
    WHERE
    dt = '{run_date}'
    -- dt = '2019-08-07'
    GROUP BY
    id,
    name
)

SELECT
*
FROM
opportunity_to_name
"""

In [160]:
# Wrap the sample query in a Formatter; the instance keeps it as `raw_query`.
formatter = Formatter(raw_query)

In [161]:
# Normalize the SQL: double quotes removed, comments stripped, keywords
# upper-cased, statement re-indented.
query = formatter.format_query(raw_query)

In [162]:
# Split the formatted SQL into a dict keyed by CTE alias, plus "main" for the
# final query.
cte_dict = formatter.parse_cte(query)

In [163]:
cte_dict

{'main': 'SELECT *\nFROM opportunity_to_name',
 'opportunity_to_name': "WITH opportunity_to_name AS\n  (SELECT id AS account_id,\n          name AS account_name\n   FROM sfdc.accounts\n   WHERE dt = '{run_date}'\n   GROUP BY id,\n            name)\n"}

In [164]:
# Drop the "WITH <alias> AS (" header from every CTE entry.
# NOTE(review): _cleanup edits the dict in place and this cell rebinds the same
# name, so the cell is not idempotent -- re-running it operates on text that
# has already been trimmed. Re-run the parse_cte cell first.
cte_dict = formatter._cleanup(cte_dict)

In [165]:
# Hand-built column metadata for sfdc.accounts: one row per column name, with
# the scalar 'db_table' value repeated on every row.
# NOTE(review): presumably a stand-in for the Glue metastore export that
# _get_queried_columns expects -- confirm against the real source.
col_meta = pd.DataFrame({'db_table': 'sfdc.accounts', 
            'all_columns': ['account_health_c', 'account_health_flag_c', 'account_health_last_touch_c', 'account_notes_c', 'account_owner_c', 'account_owner_id_c', 'account_segment_c', 'account_source', 'account_start_date_c', 'account_tier_c', 'add_company_tags_single_c', 'annual_revenue', 'billing_city', 'billing_country', 'billing_postal_code', 'billing_state', 'billing_street', 'churned_date_c', 'created_by_id', 'created_date', 'crunchbase_funding_c', 'csm_c', 'customer_tier_c', 'domain_c', 'dscorgpkg_lead_source_c', 'dscorgpkg_naics_codes_c', 'dscorgpkg_sic_codes_c', 'finance_arr_c', 'github_issue_ticket_c', 'health_update_c', 'id', 'industry', 'industry_group_c', 'industry_sector_c', 'initial_deal_arr_c', 'initial_deal_date_c', 'is_deleted', 'last_activity_date', 'last_modified_date', 'lfbn_account_domain_c', 'lost_opportunities_c', 'lost_renewals_c', 'mapbox_username_c', 'naics_code_c', 'name', 'netsuite_conn_channel_tier_c', 'next_renewal_date_c', 'number_of_employees', 'number_of_mapbox_users_c', 'open_opportunities_c', 'open_renewals_c', 'owner_id', 'owner_role_c', 'parent_id', 'partner_status_c', 'partner_type_c', 'primary_contact_c', 'primary_use_case_c', 'rating', 'record_type_id', 'region_c', 'renewal_manager_c', 'sb_pf_company_c', 'sdr_c', 'segmentation_c', 'shipping_city', 'shipping_country', 'shipping_postal_code', 'shipping_state', 'shipping_street', 'sic', 'solution_engineer_c', 'sub_industry_c', 'sub_region_c', 'support_engineer_c', 'type', 'vertical_c', 'vertical_formula_c', 'won_opportunities_c', 'x18_digit_account_id_c', 'zendesk_result_c', 'zendesk_zendesk_organization_c', 'zendesk_zendesk_organization_id_c', 'zisf_zoominfo_industry_c', 'dt']})
col_meta

Unnamed: 0,db_table,all_columns
0,sfdc.accounts,account_health_c
1,sfdc.accounts,account_health_flag_c
2,sfdc.accounts,account_health_last_touch_c
3,sfdc.accounts,account_notes_c
4,sfdc.accounts,account_owner_c
...,...,...
80,sfdc.accounts,zendesk_result_c
81,sfdc.accounts,zendesk_zendesk_organization_c
82,sfdc.accounts,zendesk_zendesk_organization_id_c
83,sfdc.accounts,zisf_zoominfo_industry_c


In [171]:
# NOTE(review): `original_columns_list` is not assigned in any cell visible
# here (execution counts jump from 165 to 171), so on a fresh kernel this cell
# presumably raises NameError. The recorded output also contains duplicates,
# which Formatter._map_db_columns cannot return because it de-duplicates its
# result. Add the cell(s) that build this list, then re-run top to bottom.
original_columns_list

['sfdc.accounts.name',
 'sfdc.accounts.name',
 'sfdc.accounts.dt',
 'sfdc.accounts.id',
 'sfdc.accounts.id']