In [0]:
import time
import json
import inspect
from threading import Thread
from pyspark.sql import SparkSession
from databricks.sdk import WorkspaceClient
from databricks.connect import DatabricksSession

from pyspark.sql.types import Row
from pyspark.sql.functions import col, regexp, current_user, asc, lower, nvl, lit

spark = None
warehouse_id = None
w = WorkspaceClient()
cluster_id = '0627-133437-mtg2jcrp'


class ThreadWithReturnValue(Thread):
    def __init__(self, group=None, target=None, name=None, args=[], kwargs={}, Verbose=None):
        Thread.__init__(self, group, target, name, args, kwargs)
        self._return = None

    def run(self):
        if self._target is not None:
            try:
                self._return = self._target(*self._args, **self._kwargs)
            except Exception as ex:
                self._return = Exception(f'Exception in {self._target.__name__}: {ex}')

    def join(self, *args):
        Thread.join(self, *args)
        return self._return


class ThreadList(list):
    def append(self, thread) -> None:
        if not isinstance(thread, ThreadWithReturnValue):
            raise TypeError(f'thread must be of an instance of ThreadWithReturnValue not {type(thread)}')
        return super().append(thread)
    
    def start_all_threads(self) -> None:
        for thread in self:
            thread.start()

    def join_all_threads(self) -> dict:
        final_res = {}
        exception_lst = []
        for thread in self:
            res = thread.join()
            if res and isinstance(res, dict):
                final_res.update(res)
            elif res and isinstance(res, Exception):
                exception_lst.append(res)
        
        if exception_lst:
            for i, exception in enumerate(exception_lst, start=1):
                print(f'---- EXCEPTION #{i} START ----\n\n{str(exception).strip()}\n\n---- EXCEPTION #{i} END ----')
                if i < len(exception_lst):
                    print('\n\n')
            print()
            raise Exception('GOT THE ABOVE EXCEPTIONS PLEASE REVIEW')
        
        return final_res

    def start_and_join_all_threads(self):
        self.start_all_threads()
        return self.join_all_threads()


class Table():
    def __init__(self, full_tbl_name, full_copy_name: str=None, exclude_cols: list=None, col_mapping: dict=None, column_exprs: dict=None, data_type_exprs: dict=None) -> None:
        self.full_tbl_name = full_tbl_name
        self.catalog_name, self.schema_name, self.tbl_name = full_tbl_name.split('.')
        
        if full_copy_name is not None:
            self.full_copy_name = full_copy_name.replace('%CATALOG_NAME%', self.catalog_name).replace('%SCHEMA_NAME%', self.schema_name).replace('%TBL_NAME%', self.tbl_name)
            self.copy_catalog_name, self.copy_schema_name, self.copy_tbl_name = self.full_copy_name.split('.')
            assert self.full_copy_name != self.full_tbl_name, f'TABLE NAME: {self.full_tbl_name} CANNOT EQUAL COPY NAME: {self.full_copy_name}'
        else:
            self.full_copy_name, self.copy_catalog_name, self.copy_schema_name, self.copy_tbl_name = None, None, None, None

        self.exclude_cols = []
        if exclude_cols is not None:
            self.exclude_cols = [col.lower() for col in exclude_cols]

        self.col_mapping = {}
        if col_mapping is not None:
            self.col_mapping = {from_col.lower(): to_col.lower() for from_col, to_col in col_mapping.items()}

        self.column_exprs = {}
        if column_exprs is not None:
            self.column_exprs = {dt: expr for dt, expr in column_exprs.items()}

        self.data_type_exprs = {}
        if data_type_exprs is not None:
            self.data_type_exprs = {dt: expr for dt, expr in data_type_exprs.items()}

        self.col_info = self.get_tbl_col_info()


    def get_tbl_col_info(self) -> dict:
        col_info = []
        for idx, row in enumerate(execute_sql_query(f'DESC {self.full_tbl_name}', as_df=True).collect()):
            if row['col_name'] == '# Partition Information':
                break
            if row['col_name'].lower() in self.exclude_cols:
                continue

            tmp = {'name': row['col_name'].lower(), 'data_type': row['data_type'].lower(), 'idx': idx}
            if tmp['name'] in self.col_mapping:
                tmp['alias'] = self.col_mapping[tmp['name']]

            # If the column has a expr
            if self.column_exprs.get(tmp['name'], '') != '':
                tmp['col_expr'] = self.column_exprs[tmp['name']]
            
            # If the columns datatype has an expr
            if self.data_type_exprs.get(tmp['data_type'], '') != '':
                tmp['dt_expr'] = self.data_type_exprs[tmp['data_type']]

            col_info.append(tmp)
        return col_info
    

    def get_col_lst(self, aliased: bool=False) -> list:
        return [col.get('alias', col['name']) if aliased else col['name'] for col in self.col_info]


    # def get_col_csv(self, aliased: bool=False, use_dt_expr: bool=False):
    #     ret = []
    #     for col in self.col_info:
    #         col_name = col['name']
    #         col_dt = col['data_type']

    #         # If the columns datatype has an expr
    #         dt_expr = self.data_type_exprs[col_dt.lower()].format(__COLUMN_NAME__=col_name) if use_dt_expr and col_dt in self.data_type_exprs else ''

    #         str_ret = dt_expr if dt_expr else col_name

    #         # If the column has an alias add that if the alised flag is true
    #         # If there is a dt_expr and the alias is false we only want the column name
    #         str_ret += f" AS {col.get('alias', col_name) if aliased else col_name}" if aliased or dt_expr else ''
            
    #         ret.append(str_ret)
            
    #     return ','.join(ret)

    def get_col_expr_str(self, col, aliased: bool=False, use_col_expr: bool=False, use_dt_expr: bool=False):
        col_name = col['name']

        ret = f'{col_name}'
        if use_col_expr and 'col_expr' in col:
            ret = col['col_expr'].format(__COLUMN_NAME__=ret)

        if use_dt_expr and 'dt_expr' in col:
            ret = col['dt_expr'].format(__COLUMN_NAME__=ret)
        
        if aliased and 'alias' in col:
            ret += f" AS {col['alias']}"
        elif (use_col_expr and 'col_expr' in col) or (use_dt_expr and 'dt_expr' in col):
            ret += f" AS {col_name}"

        return ret


    def get_col_csv(self, aliased: bool=False, use_col_expr: bool=False, use_dt_expr: bool=False):
        return ','.join([self.get_col_expr_str(col, aliased=aliased, use_col_expr=use_col_expr, use_dt_expr=use_dt_expr) for col in self.col_info])
    

    def get_copy_ddl(self, copy_type: str, filter_clause: str=None, aliased: bool=False) -> str:
        copy_type = copy_type.upper().strip()
        assert copy_type in ('TABLE', 'VIEW'), f'INVALID COPY TYPE: {copy_type}, ONLY TABLE AND VIEW ARE SUPPORTED'
        ddl = f'CREATE OR REPLACE {copy_type} {self.full_copy_name} AS SELECT {self.get_col_csv(aliased=aliased, use_col_expr=True, use_dt_expr=True)} FROM {self.full_tbl_name}'
        if filter_clause:
            ddl += filter_clause
        return ddl
    

    def set_exclude_cols(self, exclude_cols: list) -> None:
        new_exclude_cols = [col.lower() for col in exclude_cols]
        if len(list(set(self.exclude_cols) - set(new_exclude_cols))):
            self.exclude_cols = new_exclude_cols
            self.col_info = self.get_tbl_col_info()


def validate_function_args(func, args:dict) -> None:
    func_params = inspect.signature(func).parameters
    for arg_name in func_params:
        arg = func_params[arg_name]
        arg_hint_type, arg_default = arg.annotation, arg.default
        
        # This must mean it is a required arg since it does not have a default
        # if arg.default == arg.empty:
        #     print(arg)

        # Make sure the type matches, unless it equals the default value
        if not isinstance(args[arg_name], arg_hint_type) and arg_default != args[arg_name]:
            raise TypeError(f'Invalid type for arg {arg_name}, expected {arg_hint_type}, got {type(args[arg_name])}')


def execute_sql_query(query: str, as_df: bool=False, warehouse_id=None):
    if warehouse_id is None:
        if 'warehouse_id' not in globals() or globals()['warehouse_id'] is None:
            raise Exception('WHEN IMPORTING execute_sql_query warehouse_id MUST BE PASSED AS A PARAM')
        else:
            warehouse_id = globals()['warehouse_id']

    res = w.statement_execution.execute_statement(query, warehouse_id, wait_timeout='0s')
    
    wait_states = ['PENDING', 'RUNNING']

    try:
        time.sleep(1)
        res = w.statement_execution.get_statement(res.statement_id)

        # Wait for query to complete and cancel it if a keyboard interrupt is detected
        while res.status.state.value in wait_states:
            res = w.statement_execution.get_statement(res.statement_id)
            time.sleep(5)

    except KeyboardInterrupt:
        w.statement_execution.cancel_execution(res.statement_id)
        raise
    
    if res.status.state.value == 'FAILED':
        raise Exception(res.status.error.message)
    elif res.status.state.value == 'CANCELED':
        raise Exception(f'QUERY HAS BEEN CANCELED, ID: {res.statement_id}')
    elif res.status.state.value not in wait_states and res.status.state.value != 'SUCCEEDED':
        raise Exception(f'UNKNOWN STATE {res.status.state.value} FOR QUERY ID: {res.statement_id}')

    if as_df:
        res_dict = res.as_dict()
        col_lst = sorted(res_dict['manifest']['schema']['columns'], key=lambda x: x['position'])

        # df = pd.DataFrame(
        #     res_dict.get('result', {}).get('data_array', [])
        #     , columns=[col['name'] for col in col_lst]
        # )

        # dt_map = {
        #     'INT': int
        #     , 'BOOLEAN': bool
        #     , 'STRING': str
        #     , 'DATE': 'datetime64'
        #     , 'TIMESTAMP': 'datetime64'
        # }
        # df = df.astype({col['name']: dt_map[col['type_name']] for col in col_lst}).convert_dtypes()

        # Using spark
        df = spark.createDataFrame(res_dict.get('result', {}).get('data_array', []), schema=','.join([f"`{col['name']}` STRING" for col in col_lst]))
        for col in col_lst:
            df = df.withColumn(col['name'], df[f"`{col['name']}`"].cast(col['type_name']))

        return df
        
    return res


def get_default_arg_map() -> dict:
    return {
        'db_tbl': {'required': True, 'type': str}
        , 'sf_tbl': {'required': True, 'type': str}
        , 'filter_clause': {'required': False, 'type': str}
        , 'warehouse_id': {'required': False, 'type': str, 'default': '0e9c378a506f69a9'}
        # Will set the default for this later, dont set it here
        , 'missmatch_tbl_name': {'required': False, 'type': str}
        , 'cfg_tbl_name': {'required': False, 'type': str, 'default': 'users.pamons.validation_config'}
        , 'sf_copy_obj_type': {'required': False, 'type': str, 'default': 'TABLE'}
        , 'db_copy_obj_type': {'required': False, 'type': str, 'default': 'VIEW'}
        , 'json_ordering_udf': {'required': False, 'type': str, 'default': 'users.pamons.sort_json_str_keys'}
        , 'exclude_cols': {'required': False, 'type': list}
        , 'db_exclude_cols': {'required': False, 'type': list}
        , 'sf_exclude_cols': {'required': False, 'type': list}
        , 'primary_keys': {'required': False, 'type': list}
        , 'skip_checks': {'required': False, 'type': list}
        # List of dicts
        , 'column_mapping': {
            'required': False, 'type': list
            , 'arg_map': {
                'sf': {'required': True, 'type': str} 
                , 'db': {'required': True, 'type': str} 
            }
        }
        , 'column_exprs': {
            'required': False, 'type': dict, 'lower_case_keys': True
            # Since the keys can differ because they are column names we cannot make an arg map
            , 'skip_arg_map_check': True
        }
        , 'data_type_exprs': {
            'required': False, 'type': dict
                , 'arg_map': {
                    'timestamp': {'required': False, 'type': str} 
                    , 'date': {'required': False, 'type': str}
                }
                , 'lower_case_keys': True
                , 'default': {'timestamp': 'date_trunc("SECOND", {__COLUMN_NAME__})'}
        }
        , 'fail_at_first_check': {'required': False, 'type': bool}
        , 'include_queries': {'required': False, 'type': bool}
        , 'use_threads': {'required': False, 'type': bool}
        , 'use_cfg_tbl': {'required': False, 'type': bool}
        # 0 meaning no table will be created
        , 'num_of_sample_rows': {'required': False, 'type': int, 'default': 0}

        # Debug only flags
        , 'create_sf_copy': {'required': False, 'type': bool, 'default': True}
        , 'drop_sf_copy': {'required': False, 'type': bool, 'default': True}
        , 'create_db_copy': {'required': False, 'type': bool, 'default': True}
        , 'drop_db_copy': {'required': False, 'type': bool, 'default': True}
    }


def validate_cfg(cfg: dict, arg_map=None, exclude_checks: list=[]) -> dict:

    if arg_map is None:
        arg_map = get_default_arg_map()

    type_defaults = {
        list: []
        , bool: False
        , str: ''
        , dict: {}
    }

    # Check to make sure we are not missing any required args
    missing_required_args = list(set([arg for arg in arg_map if arg_map[arg]['required'] and arg not in exclude_checks]) - set(cfg))
    assert not missing_required_args, f'Missing required arguments {missing_required_args}'

    # This map will store variable names that used the default value, useful for dicts
    used_default_map = {}
    for arg in arg_map:
        # These checks we should excluded
        if arg in exclude_checks:
            continue

        # Add default values to optional missing args
        if arg not in cfg:
            if 'default' in arg_map[arg]:
                cfg[arg] = arg_map[arg]['default']
            else:
                cfg[arg] = type_defaults[arg_map[arg]['type']]
            used_default_map[arg] = True
        
        if not isinstance(cfg[arg], arg_map[arg]['type']):
            raise TypeError(f"ARG {arg} HAS TYPE {type(cfg[arg])} BUT IS EXPECTED TO BE {arg_map[arg]['type']}")

        if isinstance(cfg[arg], dict) or (isinstance(cfg[arg], list) and 'arg_map' in arg_map[arg]):
            
            # Reuse code to validate dicts or lists of dicts, as_dict will convert it back to a dict after
            as_dict = False
            if isinstance(cfg[arg], dict):
                as_dict = True
                cfg[arg] = [cfg[arg]]

            for i, item in enumerate(cfg[arg]):
                # If the default was used and the null_on_default flg is in the arg_map for this var 
                # we can set args value as null and skip the checks validating the keys/values in the arg
                # This usually means the arg was not passed
                if used_default_map.get(arg, False) and arg_map[arg].get('null_on_default', False):
                    cfg[arg][i] = None
                else:

                    # Lower case the keys
                    if arg_map[arg].get('lower_case_keys', False):
                        item = {key.lower(): val for key, val in item.items()}
                    
                    if not arg_map[arg].get('skip_arg_map_check', False):
                        assert 'arg_map' in arg_map[arg], f'arg_map REQUIRED IN arg_map FOR CONFIG VARIABLE {arg}'
                        cfg[arg][i] = validate_cfg(item, arg_map[arg]['arg_map'])
            

            if as_dict:
                cfg[arg] = cfg[arg][0]

    extra_keys = list(set(list(cfg.keys())) - set(list(arg_map.keys())))
    assert not extra_keys, f'UNKNOWN ARGS {extra_keys}'

    return cfg


def update_cfg_from_cfg_tbl(cfg: dict, cfg_tbl_name: str) -> dict:
    # Stored the keys that were passed so we dont override them
    passed_keys = list(cfg.keys())
    cfg = validate_cfg(cfg)

    # execute_sql_query("""
    #     CREATE TABLE IF NOT EXISTS users.pamons.validation_config (
    #         db_tbl STRING NOT NULL,
    #         sf_tbl STRING NOT NULL,
    #         primary_keys ARRAY<STRING>,
    #         filter_clause STRING,
    #         exclude_cols ARRAY<STRING>,
    #         db_exclude_cols ARRAY<STRING>,
    #         sf_exclude_cols ARRAY<STRING>,
    #         column_mapping ARRAY<STRUCT<db: STRING, sf: STRING>>,
    #         warehouse_id STRING,
    #         cfg_json STRING,
    #         user_id STRING
    #     ) USING delta
    # """)

    cfg_df = spark.sql(f"SELECT * FROM {cfg_tbl_name} WHERE LOWER(DB_TBL) = LOWER('{cfg['db_tbl'].strip()}') AND LOWER(USER_ID) = LOWER(CURRENT_USER())")
    if cfg_df.count():
        print(f"Found entry for tbl {cfg['db_tbl']} in cfg table {cfg_tbl_name}")

        df_collected = cfg_df.collect()[0]
        for arg_name in cfg_df.columns:
            if arg_name in ('cfg_json', 'user_id'):
                continue

            arg_val = df_collected[arg_name]
            # Arg should almost always be in cfg because of validate_cfg adding the defaults
            # Skip any args where the values match or the arg is not using the default value aka passed via user
            if arg_name not in cfg or (arg_val != cfg[arg_name] and arg_name not in passed_keys):
                # Handle array of structs
                if isinstance(arg_val, list) and len(arg_val) > 0 and isinstance(arg_val[0], Row):
                    cfg[arg_name] = [row.asDict() for row in arg_val]
                else:
                    cfg[arg_name] = arg_val

                print(f'Got {arg_name} from cfg tbl with value: {cfg[arg_name]}')
        
        # Any additional_args we want to load from the json string
        additional_args = ['column_exprs']

        if additional_args:
            cfg_json = json.loads(cfg_df.select(col('cfg_json')).collect()[0][0])
            for arg_name in additional_args:
                # This should only ever happen if we have never seen an arg before
                if arg_name not in cfg_json:
                    continue

                arg_val = cfg_json[arg_name]

                if arg_name not in cfg or (arg_val != cfg[arg_name] and arg_name not in passed_keys):
                    cfg[arg_name] = arg_val
                    print(f'Got {arg_name} from cfg tbl with value: {cfg[arg_name]}')
        
    else:
        print(f"Could not find a config record for {cfg['db_tbl']} in {cfg_tbl_name}")

    return cfg


def update_cfg_tbl(cfg: dict, cfg_tbl_name: str) -> None:
    # print(f'Updating {cfg_tbl_name}')
    # Add the cgf in case we need to extract new fields later
    cfg.update({'cfg_json': json.dumps(cfg)})

    spark.createDataFrame(
        [cfg]
        , 'db_tbl string, sf_tbl string, primary_keys array<string>, filter_clause string, exclude_cols array<string>, db_exclude_cols array<string>, sf_exclude_cols array<string>, column_mapping array<struct<db:string,sf:string>>, warehouse_id string, cfg_json string'
    ).withColumn('user_id', current_user()).createOrReplaceTempView('validation_config_tmp')

    spark.sql(f"""
        MERGE INTO {cfg_tbl_name} a
        USING validation_config_tmp b
        ON a.db_tbl = b.db_tbl and a.user_id = b.user_id
        WHEN MATCHED THEN UPDATE SET *
        WHEN NOT MATCHED THEN INSERT *
    """)


#### 1 - Schema check ####
def col_dt_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
    print('Starting check #1, compare the column datatypes')
    result_dict = {}

    # Check using regex because the decimal can have various scales
    sf_db_dt_mappings = [
        {'db': 'bigint', 'sf': 'decimal\(\d+,0\)'}
        ,{'db': 'int', 'sf': 'decimal\(\d+,0\)'}
        ,{'db': 'double', 'sf': 'decimal\(\d+,d+\)'}
        ,{'db': 'timestamp_ntz', 'sf': 'timestamp'}
        ,{'db': 'decimal(38,6)', 'sf': '^decimal\(38,(\d+)\)$'}
        ,{'db': 'decimal(38,9)', 'sf': '^decimal\(38,(\d+)\)$'}
    ]
    sf_db_dt_mapping_df = spark.createDataFrame(sf_db_dt_mappings)

    # db_full_tbl_col_info = [['db'] + col for col in get_tbl_col_info(db_tbl_name) if col[0] in col_lst]
    # sf_full_tbl_col_info = [['sf'] + col for col in get_tbl_col_info(sf_tbl_name) if col[0] in col_lst]
    db_full_tbl_col_info = [['db', col['name'], col['data_type']] for col in db_tbl.col_info]
    sf_full_tbl_col_info = [['sf', col['name'], col['data_type']] for col in sf_tbl.col_info]
    schema_df = spark.createDataFrame(db_full_tbl_col_info + sf_full_tbl_col_info, 'src string, col_name string, data_type string')

    schema_res = (
        schema_df.filter("src = 'db'").alias('db')
        .join(schema_df.filter("src = 'sf'").alias('sf'), on=[col('db.col_name') == col('sf.col_name')], how='inner')
        .join(sf_db_dt_mapping_df.alias('mapping'), on=[(col('db.data_type') == col('db'))], how='left')
        .select(
            col('db.col_name').alias('column_name'), col('db.data_type').alias('db_dt'), col('sf.data_type').alias('sf_dt')
            , ((nvl(col('db.data_type') == col('sf.data_type'), lit(False))) | (nvl(regexp(col('sf.data_type'), col('sf')), lit(False)))).alias('dt_match_flg')
        )
    )
    
    # Old way of doing it, much slower, leaving this here so the query can be populated in the result, the result should be the same
    # schema_check_query = f"""
    #     SELECT db.column_name, db.data_type db_dt, sf.data_type sf_dt, LOWER(db.data_type) = LOWER(sf.data_type) dt_match_flg
    #     FROM {db_catalog_name}.information_schema.columns db
    #     JOIN {sf_catalog_name}.information_schema.columns sf
    #     ON LOWER(db.column_name) = LOWER(sf.column_name)
    #         AND LOWER(db.table_schema) = '{db_schema_name}' AND LOWER(db.table_name) = '{db_tbl_name}' AND LOWER(db.column_name) in {str(col_lst).replace('[', '(').replace(']', ')')}
    #         AND LOWER(sf.table_schema) = '{sf_schema_name}' AND LOWER(sf.table_name) = '{sf_tbl_name}' AND LOWER(sf.column_name) in {str(col_lst).replace('[', '(').replace(']', ')')}
    # """
    # schema_res = execute_sql_query(schema_check_query, as_df=True)

    if schema_res.select('dt_match_flg').distinct().count() != 1:
        if cfg['fail_at_first_check']:
            raise Exception(f'Schemas do not match')
        result_dict['schema_check'] = {'passed': False}
    else:
        result_dict['schema_check'] = {'passed': True}
    result_dict['schema_check'].update({
        'db_col_count': len(db_full_tbl_col_info)
        , 'sf_col_count': len(sf_full_tbl_col_info)
        , 'missmatched_cols': [(row['column_name'], row['db_dt'], row['sf_dt']) for row in schema_res.filter("not dt_match_flg").collect()]
    })

    # if cfg['include_queries']:
    #     result_dict['schema_check']['query'] = schema_check_query

    return result_dict


#### 2 - Check counts ####
def count_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
    print('Starting check #2, compare the counts of the two tables')
    result_dict = {}

    count_check_query = f"""
        SELECT 'DB_COUNT' src, COUNT(*) row_cnt FROM {db_tbl.full_copy_name}
        UNION ALL
        SELECT 'SF_COUNT', COUNT(*) FROM {sf_tbl.full_copy_name}
    """
    count_res = execute_sql_query(count_check_query, as_df=True)

    count_dict = {row['src']: row['row_cnt'] for row in count_res.collect()}
    if count_dict['DB_COUNT'] != count_dict['SF_COUNT']:
        if cfg['fail_at_first_check']:
            raise Exception(f'Counts do not match. DB: {count_dict["DB_COUNT"]}, SF: {count_dict["SF_COUNT"]}')
        result_dict['count_check'] = {'passed': False}
    else:
        result_dict['count_check'] = {'passed': True}
    result_dict['count_check'].update(count_dict)

    if cfg['include_queries']:
        result_dict['count_check']['query'] = count_check_query
    
    return result_dict


#### 3 - Hash check ####
# if 6 not in cfg['skip_checks'] and cfg['primary_keys'] and cfg['num_of_sample_rows']
# def hash_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
#     print('Starting check #3, take a hash of the entire row and compare the diffrence of the two tables')
#     result_dict = {}

#     cols_csv = db_tbl.get_col_csv()

#     hash_check_query = f"""
#         SELECT COUNT(*)
#         FROM (
#             SELECT sha2(concat_ws('||', {cols_csv}), 256) HASHED FROM {db_tbl.full_copy_name}
#             MINUS
#             SELECT sha2(concat_ws('||', {cols_csv}), 256) FROM {sf_tbl.full_copy_name}
#         )
#     """
#     hash_res = execute_sql_query(hash_check_query)
#     total_row_count = int(hash_res.result.data_array[0][0])

#     if total_row_count != 0:
#         if cfg['fail_at_first_check']:
#             raise Exception(f'Hash check failed, {total_row_count} rows are different')
#         result_dict['hash_check'] = {'passed': False}
#     else:
#         result_dict['hash_check'] = {'passed': True}
#     result_dict['hash_check']['row_diff'] = total_row_count
    
#     if cfg['include_queries']:
#         result_dict['hash_check']['query'] = hash_check_query
    
#     return result_dict

def hash_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
    print('Starting check #3, take a hash of the entire row and compare the diffrence of the two tables')
    result_dict = {}

    cols_csv = db_tbl.get_col_csv()

    hash_minus_query = f"""
        SELECT sha2(concat_ws('||', {cols_csv}), 256) HASHED FROM {db_tbl.full_copy_name}
        MINUS
        SELECT sha2(concat_ws('||', {cols_csv}), 256) FROM {sf_tbl.full_copy_name}
    """
    
    hash_check_query = f"SELECT COUNT(*) FROM ({hash_minus_query})"

    hash_res = execute_sql_query(hash_check_query)
    total_row_count = int(hash_res.result.data_array[0][0])

    if total_row_count != 0:
        if cfg['fail_at_first_check']:
            raise Exception(f'Hash check failed, {total_row_count} rows are different')
        result_dict['hash_check'] = {'passed': False}
    else:
        result_dict['hash_check'] = {'passed': True}
    result_dict['hash_check']['row_diff'] = total_row_count
    
    if 6 not in cfg['skip_checks'] and not cfg['primary_keys'] and cfg['num_of_sample_rows'] and not result_dict['hash_check']['passed']:
        print(f"Finding {cfg['num_of_sample_rows']} rows from DB where the hashes did not find a match in SF, saving to {cfg['missmatch_tbl_name']}")
        
        sample_rows_query = f"""CREATE OR REPLACE TABLE {cfg['missmatch_tbl_name']} AS
            SELECT * FROM {db_tbl.full_copy_name}
            WHERE sha2(concat_ws('||', {cols_csv}), 256) IN ({hash_minus_query})
            LIMIT {cfg['num_of_sample_rows']}
        """
        
        execute_sql_query(sample_rows_query)


    if cfg['include_queries']:
        result_dict['hash_check']['query'] = hash_check_query
    
    return result_dict


#### 4 - Check for duplicate primary keys ####
def duplicate_pk_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
    print('Starting check #4, check for duplicate primary keys')
    result_dict = {}

    duplicate_pk_check_query = f"""
        SELECT lkp_src src, COUNT(DISTINCT pk) as duplicate_key_cnt
        FROM (
            SELECT 'DB_COUNT' as src, concat_ws('||', {','.join(cfg['primary_keys'])}) pk 
            FROM {db_tbl.full_copy_name} GROUP BY 2 HAVING COUNT(*) > 1
            UNION ALL
            SELECT 'SF_COUNT' as src, concat_ws('||', {','.join(cfg['primary_keys'])}) pk 
            FROM {sf_tbl.full_copy_name} GROUP BY 2 HAVING COUNT(*) > 1
        )
        -- Add this lookup so we always have a value, even if its 0
        RIGHT JOIN (SELECT 'DB_COUNT' AS lkp_src UNION ALL SELECT 'SF_COUNT')
        ON src = lkp_src
        GROUP BY 1
    """
    duplicate_pk_res = execute_sql_query(duplicate_pk_check_query, as_df=True)

    duplicate_pk_dict = {row['src']: row['duplicate_key_cnt'] for row in duplicate_pk_res.collect()}
    if duplicate_pk_dict['DB_COUNT'] != 0 or duplicate_pk_dict['SF_COUNT'] != 0:
        if cfg['fail_at_first_check']:
            # raise Exception(f'Duplicate primary keys found, {duplicate_pk_res.manifest.total_row_count} rows are different')
            raise Exception(f'Duplicate primary keys found, DB duplicates: {duplicate_pk_dict["DB_COUNT"]}, SF duplicates: {duplicate_pk_dict["SF_COUNT"]}')
        result_dict['duplicate_pk_check'] = {'passed': False}
    else:
        result_dict['duplicate_pk_check'] = {'passed': True}
    result_dict['duplicate_pk_check'].update(duplicate_pk_dict)

    if cfg['include_queries']:
        result_dict['duplicate_pk_check']['query'] = duplicate_pk_check_query
    return result_dict


#### 5 - Check for rows only in one table based on PKs ####
def exclusive_pk_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
    print('Starting check #5, check for PKs that are exclusive to one table')
    result_dict = {}

    exclusive_pk_check_query = f"""
        SELECT 'SF' as src, count(*) as exclusive_rows
        FROM (
            SELECT * FROM {sf_tbl.full_copy_name} sf WHERE NOT EXISTS (
                SELECT 1
                FROM {db_tbl.full_copy_name} db
                WHERE {' AND '.join([f'sf.{key} <=> db.{key}' for key in cfg['primary_keys']])}
            )
        )
        UNION ALL
        SELECT 'DB' as src, count(*) as exclusive_rows
        FROM (
            SELECT * FROM {db_tbl.full_copy_name} db WHERE NOT EXISTS (
                SELECT 1
                FROM {sf_tbl.full_copy_name} sf
                WHERE {' AND '.join([f'sf.{key} <=> db.{key}' for key in cfg['primary_keys']])}
            )
        )
    """
    exclusive_rows_res = execute_sql_query(exclusive_pk_check_query, as_df=True)

    exclusive_rows_dict = {row['src']: row['exclusive_rows'] for row in exclusive_rows_res.collect()}
    if exclusive_rows_dict['SF'] != 0 or exclusive_rows_dict['DB'] != 0:
        if cfg['fail_at_first_check']:
            raise Exception(f'Rows exclusive to one table based on PKs, SF: {exclusive_rows_dict["SF"]}, DB: {exclusive_rows_dict["DB"]}')
        result_dict['exclusive_pk_check'] = {'passed': False}
    else:
        result_dict['exclusive_pk_check'] = {'passed': True}
    result_dict['exclusive_pk_check'].update(exclusive_rows_dict)

    if cfg['include_queries']:
        result_dict['exclusive_pk_check']['query'] = exclusive_pk_check_query
    
    return result_dict


#### 6
def col_row_level_missmatch_check(cfg: dict, db_tbl: Table, sf_tbl: Table) -> dict:
    print('Starting check #6, check where the PKs match but the non PK columns do not')
    result_dict = {}

    col_lst = db_tbl.get_col_lst()

    # Base query that checks for matches in all columns/rows
    col_pk_row_mismatch_base_query = f"""
        SELECT {','.join(
            [
                f'db.{col} {col}_pk' for col in cfg['primary_keys']
            ] + [
                f'db.{col} {col}_db, sf.{col} {col}_sf, NOT db.{col} <=> sf.{col} {col}_missmatch_flg' for col in col_lst if col not in cfg['primary_keys']
            ]
        )}
        FROM {db_tbl.full_copy_name} db
        JOIN {sf_tbl.full_copy_name} sf
        ON {' AND '.join([f'db.{key} <=> sf.{key}' for key in cfg['primary_keys']])}
            AND ({' OR '.join([f'NOT db.{col} <=> sf.{col}' for col in col_lst if col not in cfg['primary_keys']])})
    """

    # Create a view so that we can cache the results and reuse for the sample rows query
    col_pk_row_mismatch_base_tmp_view_name = f"{cfg['user_schema']}.col_pk_row_mismatch_base__recon_tmp"
    execute_sql_query(f"""CREATE OR REPLACE TABLE {col_pk_row_mismatch_base_tmp_view_name} as {col_pk_row_mismatch_base_query}""")

    # Query to get the number of missmatches in a column
    col_pk_row_mismatch_counts_query = f"""
        SELECT col, count(missmatch_flg) missmatch_cnt
        FROM {col_pk_row_mismatch_base_tmp_view_name}
        UNPIVOT(missmatch_flg for col in ({','.join([f'{col}_missmatch_flg as {col}' for col in col_lst if col not in cfg['primary_keys']])}))
        WHERE missmatch_flg
        GROUP BY 1
    """

    col_pk_row_mismatch_counts_res = execute_sql_query(col_pk_row_mismatch_counts_query, as_df=True)

    col_missmatch_counts_dict = {row['col']: row['missmatch_cnt'] for row in col_pk_row_mismatch_counts_res.collect()}
    if col_pk_row_mismatch_counts_res.count() > 0:
        if cfg['fail_at_first_check']:
            print(json.dumps(col_missmatch_counts_dict, indent=4))
            raise Exception(f'The above {col_pk_row_mismatch_counts_res.count()} columns have missmatches in at least 1 row')
        result_dict['col_pk_row_mismatch'] = {'passed': False}
    else:
        result_dict['col_pk_row_mismatch'] = {'passed': True}
    result_dict['col_pk_row_mismatch']['cols_with_missmatched_rows'] = col_missmatch_counts_dict

    if cfg['include_queries']:
        result_dict['col_pk_row_mismatch']['query'] = col_pk_row_mismatch_counts_query

    # Save sample rows
    if cfg['num_of_sample_rows'] and col_missmatch_counts_dict:
        print(f"Finding {cfg['num_of_sample_rows']} sample rows for each column with missmatches, saving to {cfg['missmatch_tbl_name']}")

        # Query to generate the delta table
        missmatched_rows_query = f"""
            CREATE OR REPLACE TABLE {cfg['missmatch_tbl_name']} AS
        """

        # Old query need to see what is more efficient
        missmatched_rows_query += ' UNION ALL '.join([f"""(
                            SELECT DISTINCT {','.join(
                                        [
                                            f'{key}_pk' for key in cfg['primary_keys']
                                        ] + [
                                            f"'{col}' as col_name", f"{col}_db::STRING db_val", f"{col}_sf::STRING sf_val"
                                        ]
                                    )}
                            FROM {col_pk_row_mismatch_base_tmp_view_name}
                            WHERE {col}_missmatch_flg
                            LIMIT {cfg['num_of_sample_rows']}
                        )""" for col in col_missmatch_counts_dict])

        execute_sql_query(missmatched_rows_query)
    
    execute_sql_query(f"DROP TABLE {col_pk_row_mismatch_base_tmp_view_name}")

    return result_dict


def set_warehouse_id(tmp_wh_id):
    global warehouse_id
    warehouse_id = tmp_wh_id

def set_spark():
    global spark
    try:
        spark = SparkSession.builder.getOrCreate()
    except:
        print('Failed to get spark session, creating one using databricks-connect')
        
        # Wait for the cluster to start
        print('Waiting for cluster to start...')
        if w.clusters.get(cluster_id).state.value == 'TERMINATED':
            w.clusters.start(cluster_id)
        w.clusters.wait_get_cluster_running(cluster_id)     
        print('Getting spark session...')   
        spark = DatabricksSession.builder.clusterId(cluster_id).getOrCreate()


def set_globals(warehouse_id=None):
    if ('warehouse_id' not in globals() or globals()['warehouse_id'] is None or globals()['warehouse_id'] != warehouse_id) and warehouse_id is not None:
        set_warehouse_id(warehouse_id)
    if 'spark' not in globals() or globals()['spark'] is None:
        set_spark()


def reconcile(cfg: dict) -> dict:
    # Need to do this twice because we need spark
    set_globals()
    print('Starting reconcile...')

    cfg_copy = validate_cfg({key: val for key, val in cfg.items() if key in ('use_cfg_tbl', 'cfg_tbl_name', 'db_tbl')}, exclude_checks=['sf_tbl'])
    if cfg_copy['use_cfg_tbl'] and cfg_copy['cfg_tbl_name']:
        cfg = update_cfg_from_cfg_tbl(cfg, cfg_copy['cfg_tbl_name'])

    cfg = validate_cfg(cfg)
    set_globals(warehouse_id=cfg['warehouse_id'])

    current_user = execute_sql_query("select replace(current_user(), '@dropbox.com', '')", as_df=True).collect()[0][0]
    user_schema = f'users.{current_user}'

    # Test to make sure the schema exists
    execute_sql_query(f'show tables in {user_schema}')

    cfg['user_schema'] = user_schema

    db_tbl = Table(
        cfg['db_tbl'].lower().strip()
        , full_copy_name=f'{user_schema}.%TBL_NAME%__reconcile_tmp_databricks'
        , exclude_cols=cfg['exclude_cols'] + cfg['db_exclude_cols']
        , col_mapping={mapping['db']: mapping['sf'] for mapping in cfg['column_mapping']}
        , column_exprs=cfg['column_exprs']
        , data_type_exprs=cfg['data_type_exprs']
    )

    sf_tbl = Table(
        cfg['sf_tbl'].lower().strip()
        , full_copy_name=f'{user_schema}.%TBL_NAME%__reconcile_tmp_snowflake'
        , exclude_cols=cfg['exclude_cols'] + cfg['sf_exclude_cols']
        , col_mapping={mapping['sf']: mapping['db'] for mapping in cfg['column_mapping']}
        , column_exprs=cfg['column_exprs']
        , data_type_exprs=cfg['data_type_exprs']
    )

    cfg['missmatch_tbl_name'] = f'{user_schema}.{db_tbl.tbl_name}__missmatches'

    db_col_lst = db_tbl.get_col_lst()
    sf_col_lst_alias = sf_tbl.get_col_lst(aliased=True)

    cfg['primary_keys'] = [key.lower() for key in cfg['primary_keys']]

    # Check to make sure primary keys are not excluded and exist in the tbl
    db_missing_pk_keys = list(set(cfg['primary_keys']) - set(db_col_lst))
    sf_missing_pk_keys = list(set(cfg['primary_keys']) - set(sf_col_lst_alias))
    if len(db_missing_pk_keys) != 0 or len(sf_missing_pk_keys) != 0:
        raise Exception(f'Primary keys excluded or does not exists in db or sf table:\nDB missing keys: {db_missing_pk_keys}\nSF missing keys: {sf_missing_pk_keys}')


    # print((set(db_cols) - set(sf_cols)) | (set(sf_cols) - set(db_cols)))
    if set(db_col_lst) != set(sf_col_lst_alias):
        raise Exception(f'THE COLUMNS DO NOT MATCH, DB EXCLUSIVE COLS: {set(db_col_lst) - set(sf_col_lst_alias)}, SF EXCLUSIVE COLS: {set(sf_col_lst_alias) - set(db_col_lst)}')

    # Create a view in delta in case there is a filter clause
    filter_clause = ''
    if cfg['filter_clause']:
        filter_clause = f" WHERE {cfg['filter_clause']}"


    # Copy the sf table to delta so we can query it faster
    if cfg['create_sf_copy']:
        print(f"Creating {cfg['sf_copy_obj_type'].lower()}: {sf_tbl.full_copy_name} for SF table: {sf_tbl.full_tbl_name}")
        execute_sql_query(sf_tbl.get_copy_ddl(cfg['sf_copy_obj_type'], filter_clause=filter_clause, aliased=True))

    # Create a view for the DB table so that we can have a filter
    if cfg['create_db_copy']:
        print(f"Creating {cfg['db_copy_obj_type'].lower()}: {db_tbl.full_copy_name} for DB table: {db_tbl.full_tbl_name}")
        execute_sql_query(db_tbl.get_copy_ddl(cfg['db_copy_obj_type'], filter_clause=filter_clause))


    # Start the validation
    thread_lst = ThreadList()
    result_dict = {}
    if 1 not in cfg['skip_checks']:
        if cfg['use_threads']:
            thread_lst.append(ThreadWithReturnValue(target=col_dt_check, args=(cfg, db_tbl, sf_tbl)))
        else:
            result_dict.update(col_dt_check(cfg, db_tbl, sf_tbl))

    if 2 not in cfg['skip_checks']:
        if cfg['use_threads']:
            thread_lst.append(ThreadWithReturnValue(target=count_check, args=(cfg, db_tbl, sf_tbl)))
        else:
            result_dict.update(count_check(cfg, db_tbl, sf_tbl))

    if 3 not in cfg['skip_checks']:
        if cfg['use_threads']:
            thread_lst.append(ThreadWithReturnValue(target=hash_check, args=(cfg, db_tbl, sf_tbl)))
        else:
            result_dict.update(hash_check(cfg, db_tbl, sf_tbl))


    # Primary key checks
    if cfg['primary_keys']:
        result_dict['primary_keys'] = cfg['primary_keys']

        if 4 not in cfg['skip_checks']:   
            if cfg['use_threads']:
                thread_lst.append(ThreadWithReturnValue(target=duplicate_pk_check, args=(cfg, db_tbl, sf_tbl)))
            else:
                result_dict.update(duplicate_pk_check(cfg, db_tbl, sf_tbl))

        if 5 not in cfg['skip_checks']:
            if cfg['use_threads']:
                thread_lst.append(ThreadWithReturnValue(target=exclusive_pk_check, args=(cfg, db_tbl, sf_tbl)))
            else:
                result_dict.update(exclusive_pk_check(cfg, db_tbl, sf_tbl))

        if 6 not in cfg['skip_checks']:
            if cfg['use_threads']:
                thread_lst.append(ThreadWithReturnValue(target=col_row_level_missmatch_check, args=(cfg, db_tbl, sf_tbl)))
            else:
                result_dict.update(col_row_level_missmatch_check(cfg, db_tbl, sf_tbl))

    if thread_lst:
        thread_lst.start_all_threads()
        result_dict.update(thread_lst.join_all_threads())

    if cfg['drop_sf_copy']:
        print(f"Dropping SF tmp {cfg['sf_copy_obj_type'].lower()}: {sf_tbl.full_copy_name}")
        execute_sql_query(f"DROP {cfg['sf_copy_obj_type']} {sf_tbl.full_copy_name}")

    if cfg['drop_db_copy']:
        print(f"Dropping DB tmp {cfg['db_copy_obj_type'].lower()}: {db_tbl.full_copy_name}")
        execute_sql_query(f"DROP {cfg['db_copy_obj_type']} {db_tbl.full_copy_name}")

    if cfg['cfg_tbl_name'] is not None:
        update_cfg_tbl(cfg, cfg['cfg_tbl_name'])

    return result_dict


def get_tbl_col_counts(full_tbl_name: str, warehouse_id: str, results_tbl: str='users.pamons.tbl_column_counts_tbl') -> None:
    set_globals(warehouse_id)
    
    res = []
    db_tbl = Table(full_tbl_name.lower().strip())

    tbl_row_count = execute_sql_query(f'SELECT COUNT(*) FROM {db_tbl.full_tbl_name}', as_df=True, warehouse_id=warehouse_id).collect()[0][0]

    # Need to do this before the exclude cols
    all_cols_count = f"COUNT(DISTINCT CONCAT_WS('||', {db_tbl.get_col_csv()})) as all_cols_distinct_cnt"

    # You can auto exclude bool cols from being a key, they can only have 3 values: true, false, null
    db_tbl.set_exclude_cols([col['name'] for col in db_tbl.col_info if col['data_type'] == 'boolean'])

    # Search for columns that have unique values
    individual_col_key_query = f"""
        SELECT col_name, distinct_val_count, {tbl_row_count} = distinct_val_count AS is_pk
        FROM (
            SELECT {','.join(
                [all_cols_count]
                +
                [f'COUNT(DISTINCT {col_name}) as {col_name}_distinct_cnt' for col_name in db_tbl.get_col_lst()]
            )}
            FROM {db_tbl.full_tbl_name}
        )
        UNPIVOT(DISTINCT_VAL_COUNT FOR COL_NAME IN ({','.join(
            ['all_cols_distinct_cnt AS all_cols']
            +
            [f'{col_name}_distinct_cnt AS {col_name}' for col_name in db_tbl.get_col_lst()]
        )}))
        ORDER BY distinct_val_count desc
    """
    individual_col_key_df = execute_sql_query(individual_col_key_query, as_df=True, warehouse_id=warehouse_id)

    for row in individual_col_key_df.collect():
        tmp_res = {'tbl_name': db_tbl.full_tbl_name, 'col_name': row['col_name'].lower()}
        tmp_res.update({col: row[col] for col in individual_col_key_df.columns})
        res.append(tmp_res)

    res_df = spark.createDataFrame(res
        , ', '.join(['tbl_name string'] + [' '.join(col) for col in individual_col_key_df.dtypes])
    )
    
    (
        res_df.write.mode('overwrite')
        .option('replaceWhere', f'tbl_name = "{db_tbl.full_tbl_name}"')
        .option('overwriteSchema', True)
        .saveAsTable(results_tbl)
    )

    primary_key_count = res_df.filter(col('is_pk')).count()
    if primary_key_count == 0:
        print(f'Could not find PK for tbl {db_tbl.full_tbl_name}')
    else:
        print(f'Found {primary_key_count} PKs for tbl {db_tbl.full_tbl_name}')

    print(f'Writing results to {results_tbl}')


#### find_primary_keys helper functions ####
def divide_list_into_chunks(l, n): 
    for i in range(0, len(l), n):  
        yield l[i:i + n] 


def test_col_chunk(tbl_name: str, col_chunk: list, exclude_cols: list, warehouse_id: str):
    except_query = f"except({', '.join(exclude_cols + col_chunk)})"
    query = f"""select count(*) from (select * {except_query} from {tbl_name} group by all having count(*) > 1)"""
    return int(execute_sql_query(query, warehouse_id=warehouse_id).result.data_array[0][0])


def find_primary_keys(tbl_name, warehouse_id: str, update_tbl_counts: bool=True, tbl_column_counts_tbl_name: str='users.pamons.tbl_column_counts_tbl'):
    set_globals(warehouse_id)

    if update_tbl_counts:
        get_tbl_col_counts(tbl_name, warehouse_id=warehouse_id, results_tbl=tbl_column_counts_tbl_name)
    
    tbl_res_exist = spark.table(tbl_column_counts_tbl_name) \
        .filter((col('tbl_name') == tbl_name))
    
    if not update_tbl_counts and tbl_res_exist.count() == 0:
        raise Exception(f'update_tbl_counts IS FALSE BUT THERE ARE NO RESULTS IN {tbl_column_counts_tbl_name} FOR TBL {tbl_name}')
    
    # If all the cols combined do not create a PK then it is impossible to create one
    all_cols_pk = tbl_res_exist.filter((col('is_pk')) & (lower(col('col_name')) == 'all_cols'))
    if all_cols_pk.count() == 0:
        print('No PKs possible, all columns combined do no create a PK')
        return []

    res = [row['col_name'] for row in tbl_res_exist.filter((col('is_pk')) & (lower(col('col_name')) != 'all_cols')).collect()]
    if not res:
        non_pk_cols_df = tbl_res_exist.filter(~col('is_pk')).orderBy(asc('distinct_val_count'))

        col_lst = [row['col_name'] for row in non_pk_cols_df.collect()]
        exclude_cols = []
        exclude_cols_old = []

        chunk_size = len(col_lst)//6
        chunk_size = 1 if chunk_size < 1 else chunk_size 
        print(f'Chunk size: {chunk_size}')
        while exclude_cols == exclude_cols_old and len(list(set(col_lst) - set(exclude_cols))):
            exclude_cols_old = exclude_cols.copy()

            for col_chunk in list(divide_list_into_chunks(col_lst, chunk_size)):
                res_count = test_col_chunk(tbl_name, col_chunk, exclude_cols, warehouse_id)
                if res_count == 0:
                    print('Excluding cols:', ', '.join(col_chunk), 'no duplicates')
                    exclude_cols.extend(col_chunk)
                else:
                    for col_name in col_chunk:
                        res_count = test_col_chunk(tbl_name, [col_name], exclude_cols, warehouse_id)
                        if res_count == 0:
                            print(f'Excluding col: {col_name} no duplicates')
                            exclude_cols.append(col_name)

        res = [[col for col in col_lst if col not in exclude_cols]]
    
    return res

In [0]:
tables = [

{

   'db_tbl': 'bi_dev.pamons.dim_core_team_upsell_events'
       , 'sf_tbl': 'connection__bi_snowflake.dimensions.dim_core_team_upsell_events'
    , 'num_of_sample_rows': 5
    , 'warehouse_id': '65a50cfd6840d9db'
}

# {
#     'db_tbl': 'bi_dev.pamons.int_core_video_file_action_events'
#     , 'sf_tbl': 'connection__bi_snowflake.intermediate.int_core_video_file_action_events'
#     , 'num_of_sample_rows': 5
#     , 'warehouse_id': '65a50cfd6840d9db'
#     , 'filter_clause': ' event_dt <= "2024-08-11"'
# }

# {
#     'db_tbl': 'bi_dev.pamons.int_core_ss_team_churn_volume'
#     , 'sf_tbl': 'connection__bi_snowflake.intermediate.int_core_ss_team_churn_volume'
#     , 'num_of_sample_rows': 5
#     , 'warehouse_id': '65a50cfd6840d9db'
#     ,'column_exprs': {'total_arr_ly': 'round({__COLUMN_NAME__}, 2)', 'total_arr': 'round({__COLUMN_NAME__}, 2)'}

# },

# ,{
#     'db_tbl': 'bi_dev.pamons.stg_core_team_attribute'
#     , 'sf_tbl': 'connection__bi_snowflake.staging.stg_core_team_attribute'
#     , 'num_of_sample_rows': 5
#     , 'warehouse_id': '65a50cfd6840d9db'
#     , 'exclude_cols': ['process_ts']
#     , 'filter_clause': " report_dt = '2024-07-29'"
# }
    
# ,{
#     'db_tbl': 'bi_dev.pamons.stg_core_teams_total_arr'
#     , 'sf_tbl': 'connection__bi_snowflake.staging.stg_core_teams_total_arr'
#     , 'num_of_sample_rows': 5
#     , 'warehouse_id': '65a50cfd6840d9db'
#     , 'db_exclude_cols': ['day']
# }
    
# ,{
#     'db_tbl': 'bi_dev.pamons.stg_core_user_attribute'
#     , 'sf_tbl': 'connection__bi_snowflake.staging.stg_core_user_attribute'
#     , 'num_of_sample_rows': 5
#     , 'warehouse_id': '65a50cfd6840d9db'
#     , 'filter_clause': " report_dt = '2024-07-26'"

# }

]


In [0]:
from validation_tools import reconcile
from validation_tools.html_helpers import count_checks, col_schema_check, col_row_missmatch_check
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, BooleanType, MapType, ArrayType, DateType
from datetime import datetime, timedelta
import traceback

results = []


start_time = datetime.now()
max_duration = timedelta(hours=10)


for table in tables:
    current_timestamp = datetime.now()
    data ={}
    res = {}
    try:
        res = reconcile(table)
    
    except Exception as e:
        print("RECONCILIATION ERRORED OUT")
        print(traceback.format_exc())
        data['comments'] = "ERROR: " + str(e) + str(traceback.format_exc())
    try:
        try:
            html_attrs = [
            f"<h3>Matched using keys: {', '.join(res['primary_keys'])}</h3>"
            , count_checks(res)
            , col_schema_check(res)
            , col_row_missmatch_check(res)
            ]
            displayHTML('</br>'.join(html_attrs))
        except Exception as e:
            try:
                html_attrs = [
                count_checks(res)
                , col_schema_check(res)
                ]
                displayHTML('</br>'.join(html_attrs))
            except Exception as e:
                print("****Could not print HTML version ****" + str(e))
        data['count_check_passed'] = res.get('count_check', {}).get('passed')
        data['schema_check_passed'] = res.get('schema_check', {}).get('passed')
        data['col_pk_row_mismatch_passed'] = res.get('col_pk_row_mismatch', {}).get('passed')
        data['hash_check_passed'] = res.get('hash_check', {}).get('passed')

        if not data['hash_check_passed']:
            data['comments'] = f"Hash check failed, check table users.amuthumani.{table['db_tbl'].split('.')[-1]}__hash_mismatches for samples"

        if 'mismatched_rows_dict' in res.get('col_pk_row_mismatch', {}):   
            data['mismatched_rows'] = res['col_pk_row_mismatch']['mismatched_rows_dict']

        if not res.get('col_pk_row_mismatch', {}).get('passed') and (table.get('primary_keys', None) is None):
            data['comments'] = f"Column PK/Row mismatch check failed. Check table users.amuthumani.{table['db_tbl'].split('.')[-1]} for mismatched rows."

        if (res.get('hash_check', {}).get('passed') and 
            res.get('schema_check', {}).get('passed') and 
            res.get('count_check', {}).get('passed') and 
            res.get('col_pk_row_mismatch', {}).get('passed')):
            data['passed'] = True
        elif (res.get('hash_check', {}).get('passed') and 
            res.get('schema_check', {}).get('passed') and 
            res.get('count_check', {}).get('passed') and 
            (table.get('primary_keys', []) == [])):
            data['passed'] = True
        else:
            data['passed'] = False

        # Add schema_check details
        schema_check = res.get('schema_check', {})
        data['schema_check_db_col_count'] = schema_check.get('db_col_count')
        data['schema_check_sf_col_count'] = schema_check.get('sf_col_count')
        data['schema_check_missmatched_cols'] = str(schema_check.get('missmatched_cols', ''))

        count_check = res.get('count_check', {})
        data['count_check_db_count'] = count_check.get('DB_COUNT')
        data['count_check_sf_count'] = count_check.get('SF_COUNT')

        hash_check = res.get('hash_check', {})
        data['hash_check_row_diff'] = hash_check.get('row_diff')

        data['primary_keys'] = res.get('primary_keys')

        # Add duplicate_pk_check details
        duplicate_pk_check = res.get('duplicate_pk_check', {})
        data['duplicate_pk_check_db_count'] = duplicate_pk_check.get('DB_COUNT')
        data['duplicate_pk_check_sf_count'] = duplicate_pk_check.get('SF_COUNT')

    except Exception as e:
        print("ERRORED OUT")
        print(traceback.format_exc())
        data['comments'] = "ERROR: " + str(e) + str(traceback.format_exc())

    data['db_tbl'] = table['db_tbl']
    data['sf_tbl'] = table['sf_tbl']
    data['report_ts'] = current_timestamp

    results.append(data)
    # # Check if 4 hours have passed
    # if datetime.now() - start_time > max_duration:
    #     print("4 hours have passed. Ending the loop.")
    #     break


Starting reconcile...
Creating table: users.amuthumani.dim_core_team_upsell_events__reconcile_tmp_snowflake for SF table: connection__bi_snowflake.dimensions.dim_core_team_upsell_events
Creating view: users.amuthumani.dim_core_team_upsell_events__reconcile_tmp_databricks for DB table: bi_dev.pamons.dim_core_team_upsell_events
Starting check #1, compare the column datatypes
Starting check #2, compare the counts of the two tables
Starting check #3, take a hash of the entire row and compare the diffrence of the two tables
Finding 5 rows from DB where the hashes did not find a match in SF, saving to users.amuthumani.dim_core_team_upsell_events__missmatches
Dropping SF tmp table: users.amuthumani.dim_core_team_upsell_events__reconcile_tmp_snowflake
Dropping DB tmp view: users.amuthumani.dim_core_team_upsell_events__reconcile_tmp_databricks


0,1,2,3
Test,Databricks,Snowflake,Pass/Fail
Column Count,5,5,Pass
Record Count,3009,3688,Fail
Hash Missmatch Count,25,,Fail

0
Schemas match!


In [0]:
print(results)

[{'count_check_passed': False, 'schema_check_passed': True, 'col_pk_row_mismatch_passed': None, 'hash_check_passed': False, 'comments': 'Hash check failed, check table users.amuthumani.dim_core_team_upsell_events__hash_mismatches for samples', 'passed': False, 'schema_check_db_col_count': 5, 'schema_check_sf_col_count': 5, 'schema_check_missmatched_cols': '[]', 'count_check_db_count': 3009, 'count_check_sf_count': 3688, 'hash_check_row_diff': 25, 'primary_keys': None, 'duplicate_pk_check_db_count': None, 'duplicate_pk_check_sf_count': None, 'db_tbl': 'bi_dev.pamons.dim_core_team_upsell_events', 'sf_tbl': 'connection__bi_snowflake.dimensions.dim_core_team_upsell_events', 'report_ts': datetime.datetime(2024, 8, 15, 19, 41, 4, 362810)}]


In [0]:
spark = SparkSession.builder.getOrCreate()

In [0]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, BooleanType, MapType, ArrayType, DateType, LongType

# Starting check #1, compare the column datatypes
# Starting check #2, compare the counts of the two tables
# Starting check #3, take a hash of the entire row and compare the diffrence of the two tables
# Starting check #4, check for duplicate primary keys
# Starting check #5, check for PKs that are exclusive to one table
# Starting check #6, check where the PKs match but the non PK columns do not
schema = StructType([
    
    StructField("db_tbl", StringType(), True),
    StructField("sf_tbl", StringType(), True),
    StructField("report_ts", TimestampType(), True),
    StructField("passed", BooleanType(), True),
    StructField("primary_keys", ArrayType(StringType()), True),
    StructField("duplicate_pk_check_db_count", LongType(), True),
    StructField("duplicate_pk_check_sf_count", LongType(), True),
    StructField("count_check_passed", BooleanType(), True),
    StructField("count_check_db_count", LongType(), True),
    StructField("count_check_sf_count", LongType(), True),
    StructField("schema_check_passed", BooleanType(), True),
    StructField("schema_check_db_col_count", LongType(), True),
    StructField("schema_check_sf_col_count", LongType(), True),
    StructField("schema_check_missmatched_cols", StringType(), True),
    
    StructField("hash_check_passed", BooleanType(), True),
    StructField("hash_check_row_diff", LongType(), True),
    StructField("col_pk_row_mismatch_passed", BooleanType(), True),
    StructField("mismatched_rows", MapType(
        StringType(), ArrayType(StructType([
            StructField("report_dt_pk", DateType(), True),
            StructField("col_name", StringType(), True),
            StructField("db_val", StringType(), True),
            StructField("sf_val", StringType(), True)
        ]))
    ), True),
    StructField("comments", StringType(), True)
])
# Convert to DataFrame
df = spark.createDataFrame(results, schema)

# Write to a table
output_table_name = "bi_dev.amuthumani.comparison_results_delta"
# df.write.mode("overwrite").saveAsTable(output_table_name)
df.write.format("delta").mode("append").saveAsTable(output_table_name)

In [0]:
def columns_to_join(tables, i):
    # Define the table name
    db_tbl = tables[i]['db_tbl']
    mismtach_table_name = f"users.amuthumani.{db_tbl.split('.')[-1]}__missmatches"

    # Load the table schema
    mismatch_df = spark.table(mismtach_table_name)

    # Get the schema (list of column names and types)
    schema = mismatch_df.schema

    # Define the suffixes to search for
    suffixes = ('_dt', '_id')

    # Find columns ending with the specified suffixes
    matching_columns = [field.name for field in schema if field.name.endswith(suffixes)]


    # Print all column names comma seperated
    print("All column names:")
    print(", d.".join(schema.names))
    print()
    print("join columns", matching_columns)

    return mismtach_table_name, matching_columns, schema.names

In [0]:

def generate_sql(mismatch_table_name, cols, all_column_names, table):
  # Base tables
  sf_table = table['sf_tbl']

  db_table = table['db_tbl']

  # Generate join conditions using cols list
  join_conditions = " and ".join([f"m.{col} = s.{col.upper()}" for col in cols])
  db_join_conditions = " and ".join([f"m.{col} = d.{col.upper()}" for col in cols])

  # SQL query
  sql_text = f"""
  SELECT * FROM (
    SELECT 'SF' AS DB, s.*
    FROM {mismatch_table_name} m 
    LEFT OUTER JOIN {sf_table} s ON {join_conditions}
    UNION ALL
    SELECT 'DB' AS DB, d.* 
    FROM {mismatch_table_name} m 
    LEFT OUTER JOIN {db_table} d ON {db_join_conditions}
  ) 
  ORDER BY all
  LIMIT 100;
  """
  # Exclude columns 'column1' and 'column2'
  excluded_columns = table.get('db_exclude_cols', []) + table.get('sf_exclude_cols', []) + ['hashed']

  # Get the list of all columns
  all_columns = [col for col in all_column_names if col not in excluded_columns]

  sql_text = sql_text.replace(', s.*', f', s.{", s.".join(all_columns)}')
  sql_text = sql_text.replace(', d.*', f', d.{", d.".join(all_columns)}')

  print(sql_text)
  return sql_text


In [0]:
for i, table in enumerate(tables):
    mismtach_table_name, matching_columns, all_column_names = columns_to_join(tables, i)
    sql_text = generate_sql(mismtach_table_name, matching_columns[0:4], all_column_names, table)
    spark.sql(sql_text).show()

All column names:
team_id, d.day, d.transaction_id, d.upsell_subtype, d.is_idaho_upsell, d.HASHED

join columns ['team_id', 'transaction_id']

  SELECT * FROM (
    SELECT 'SF' AS DB, s.team_id, s.day, s.transaction_id, s.upsell_subtype, s.is_idaho_upsell, s.HASHED
    FROM users.amuthumani.dim_core_team_upsell_events__missmatches m 
    LEFT OUTER JOIN connection__bi_snowflake.dimensions.dim_core_team_upsell_events s ON m.team_id = s.TEAM_ID and m.transaction_id = s.TRANSACTION_ID
    UNION ALL
    SELECT 'DB' AS DB, d.team_id, d.day, d.transaction_id, d.upsell_subtype, d.is_idaho_upsell, d.HASHED 
    FROM users.amuthumani.dim_core_team_upsell_events__missmatches m 
    LEFT OUTER JOIN bi_dev.pamons.dim_core_team_upsell_events d ON m.team_id = d.TEAM_ID and m.transaction_id = d.TRANSACTION_ID
  ) 
  ORDER BY all
  LIMIT 100;
  


[0;31m---------------------------------------------------------------------------[0m
[0;31mAnalysisException[0m                         Traceback (most recent call last)
File [0;32m<command-3922396375445145>, line 4[0m
[1;32m      2[0m mismtach_table_name, matching_columns, all_column_names [38;5;241m=[39m columns_to_join(tables, i)
[1;32m      3[0m sql_text [38;5;241m=[39m generate_sql(mismtach_table_name, matching_columns[[38;5;241m0[39m:[38;5;241m4[39m], all_column_names, table)
[0;32m----> 4[0m spark[38;5;241m.[39msql(sql_text)[38;5;241m.[39mshow()

File [0;32m/databricks/spark/python/pyspark/sql/connect/session.py:733[0m, in [0;36mSparkSession.sql[0;34m(self, sqlQuery, args, **kwargs)[0m
[1;32m    730[0m         _views[38;5;241m.[39mappend(SubqueryAlias(df[38;5;241m.[39m_plan, name))
[1;32m    732[0m cmd [38;5;241m=[39m SQL(sqlQuery, _args, _named_args, _views)
[0;32m--> 733[0m data, properties [38;5;241m=[39m [38;5;28mself[39m[38;5;24

In [0]:
%sql
SELECT * FROM (
    SELECT 'SF' AS DB, s.team_id, s.day, s.transaction_id, s.upsell_subtype, s.is_idaho_upsell
    FROM users.amuthumani.dim_core_team_upsell_events__missmatches m 
    LEFT OUTER JOIN connection__bi_snowflake.dimensions.dim_core_team_upsell_events s ON m.team_id = s.TEAM_ID and m.day = s.day
    UNION ALL
    SELECT 'DB' AS DB, d.team_id, d.day, d.transaction_id, d.upsell_subtype, d.is_idaho_upsell
    FROM users.amuthumani.dim_core_team_upsell_events__missmatches m 
    LEFT OUTER JOIN bi_dev.pamons.dim_core_team_upsell_events d ON m.team_id = d.TEAM_ID and m.day = d.day
  ) 
  ORDER BY all
  LIMIT 100;

DB,team_id,day,transaction_id,upsell_subtype,is_idaho_upsell
DB,4134787.0,2023-07-19,751045027,PLAN PRICE UPGRADE,True
DB,6752627.0,2023-09-23,776920715,PLAN PRICE UPGRADE,False
DB,7930371.0,2023-10-13,784091301,PLAN PRICE UPGRADE,False
DB,8173651.0,2023-10-18,785583250,PLAN PRICE UPGRADE,False
DB,8250435.0,2023-10-18,785857965,PLAN PRICE UPGRADE,False
SF,4134787.0,2023-07-19,751044294,PLAN PRICE UPGRADE,True
SF,6752627.0,2023-09-23,776923951,PLAN PRICE UPGRADE,False
SF,7930371.0,2023-10-13,784092050,PLAN PRICE UPGRADE,False
SF,8173651.0,2023-10-18,785583280,PLAN PRICE UPGRADE,False
SF,8250435.0,2023-10-18,785857940,PLAN PRICE UPGRADE,False
