# In [ ]:
import copy
import pyspark.sql.types as T
import re
from pyspark.sql import functions as f


class EdgraphDWHBuilder:
    def __init__(self, original_metadata, stage3_db_name,stage_3_path, partitioning, spark, oea, logger, error_logger,lakeTableOverwrite,entity_freq_processor):
        self.original_metadata = original_metadata
        self.processed_metadata = copy.deepcopy(original_metadata)
        self.stage_3_path = stage_3_path
        self.stage3_db_name = stage3_db_name
        self.partitioning = partitioning
        self.spark = spark
        self.oea = oea
        self.logger = logger
        self.base_tables = []
        self.error_logger = error_logger
        self.pipelineExecutionId = 'TEST_' + error_logger.generate_random_alphanumeric(10)
        self.lakeTableOverwrite = lakeTableOverwrite
        self.entity_freq_processor = entity_freq_processor

    def return_primary_key(self, table_name, columns):
        # TODO: Revise Logic
        if table_name == 'DimSchool':
            primary_key = 'SchoolHKey'
        elif table_name == 'DimStudentRace':
            primary_key = 'StudentRaceHKey'
        elif table_name == 'DimDate':
            primary_key = 'DateKey'
        elif table_name == 'DimSchoolYear':
            primary_key = 'SchoolYearShort'
        elif table_name == 'DimLearningStandard':
            primary_key = ['LearningStandardHKey', 'ParentLearningStandardHKey']
        elif table_name == 'DimObjectiveAssessment':
            primary_key = ['ObjectiveAssessmentHKey', 'ParentObjectiveAssessmentHKey']
        elif table_name == 'UserAuthorization':
            primary_key = 'UserKey'
        elif table_name == 'Parameter':
            primary_key = 'ParameterId'
        elif table_name == 'Goals':
            primary_key = 'GoalId'
        elif table_name == 'DescriptorConstant':
            primary_key = 'DescriptorConstantId'
        elif table_name == 'ExecutionDuration':
            primary_key = 'ExecutionDurationId'
        elif table_name == 'ExplicitStudentDataAuthorization':
            primary_key = 'UserKey'
        elif table_name == 'ExecutionAudit':
            primary_key = 'ExecutionAuditId'
        else:
            table_name_mask = (table_name[:4].lower() == 'fact') or (table_name == 'DimAssessmentSection') or (table_name == 'DataAuthorization') or (table_name == 'DimObjectiveAssessmentPerformanceLevel') or (table_name == 'DimAssessmentPerformanceLevel') or (table_name == 'DimAssessmentAcademicSubject') or (table_name == 'DimAssessmentAssessedGradeLevel')
            if table_name_mask:
                fact_pk = []
                for column in columns:
                    if column.lower().endswith('skey'):
                        primary_key = column
                        fact_pk.append(primary_key)
                    else:
                        pass
                return fact_pk

            else:
                for column in columns:
                    if column.lower().endswith('hkey'):
                        primary_key = column
                        break
                    else:
                        primary_key = 'PLACEHOLDER'
            
        return primary_key
    
    def insert_default_record(self, 
                            df, 
                            hKeyDefaults=None):
        """Append one "not specified" default row to ``df`` and return the union.

        A default value is chosen per column, first by name and then by type:
          * ``schoolyear`` -> ``'1900'``; names ending in ``date`` -> ``'1900-01-01'``
          * names ending in ``hkey`` -> SHA-256 of the concatenated ``hKeyDefaults``
          * names ending in ``skey`` -> ``-1``
          * otherwise by Spark type: string/unknown -> ``'Not Specified'``,
            timestamp -> ``'1900-01-01'``, integral/double -> ``-1``,
            boolean -> ``False``, binary -> same hash as the hkey columns.

        Args:
            df: Spark DataFrame to extend.
            hKeyDefaults: Values concatenated and hashed for hkey/binary
                columns. Defaults to ``["1900", "Not Specified", "NotSpecified"]``.

        Returns:
            ``df.unionByName(<single default row>)``.
        """
        if hKeyDefaults is None:
            # Built per call so the default list is never shared between calls.
            hKeyDefaults = ["1900", "Not Specified", "NotSpecified"]

        numeric_types = (T.IntegerType, T.LongType, T.ShortType, T.ByteType, T.DoubleType)
        default_values = {}
        hashed_cols = []   # hkey columns, filled with the default hash
        binary_cols = []   # binary columns, filled with the same hash for now

        for column in df.columns:
            lowered = column.lower()
            # FIXME 08-02-2024 (added separate default record condition for schoolyear and "date")
            if lowered == "schoolyear":
                default_values[column] = '1900'
            elif lowered.endswith("date"):
                default_values[column] = "1900-01-01"
            elif lowered.endswith("hkey"):
                hashed_cols.append(column)
            elif lowered.endswith("skey"):
                default_values[column] = -1
            else:
                column_dtype = df.schema[column].dataType
                if isinstance(column_dtype, T.TimestampType):
                    default_values[column] = "1900-01-01"
                elif isinstance(column_dtype, numeric_types):
                    default_values[column] = -1
                elif isinstance(column_dtype, T.BooleanType):
                    default_values[column] = False
                elif isinstance(column_dtype, T.BinaryType):
                    binary_cols.append(column)
                else:
                    # StringType and every type not handled above.
                    default_values[column] = "Not Specified"

        default_record_df = self.spark.createDataFrame([default_values])

        # The hash expression is identical for every hkey/binary column, so build it once.
        # Uses the ``f`` alias imported at the top of this file.
        default_hash = f.sha2(f.concat(*[f.lit(val) for val in hKeyDefaults]), 256)

        # TODO: Revise Logic for HKey
        # TODO: Revise Logic for Binary (sha2 yields a hex string, not binary data)
        for column in hashed_cols + binary_cols:
            default_record_df = default_record_df.withColumn(column, default_hash)

        return df.unionByName(default_record_df)

    def return_schema_tables_in_order(self, schema_metadata):
        sorted_tables = sorted(schema_metadata['tables'].items(), key=lambda x: x[1]['table_order'])
        sorted_tables_dict = dict(sorted_tables)

        for table_name, table_value in sorted_tables_dict.items():
            if table_value['table_order'] < 0:
                self.base_tables.append(table_name)
            else:
                break
        sorted_schema_metadata = {
            #'parameters': schema_metadata['parameters'],
            'tables': sorted_tables_dict
        }
        return sorted_schema_metadata

    def return_table_queries_in_order(self, table_metadata):
        sorted_queries = sorted(table_metadata['queries'], key=lambda x: x['order'])
        sorted_query_strings = [query['query'] for query in sorted_queries]
        sorted_query_params = [query.get('query_params', False) for query in sorted_queries]
        return sorted_query_strings, sorted_query_params

    def reorder_metadata_schemas(self):
        sorted_schemas = {}
        for schema_name in self.processed_metadata['metadata']['build-assets']['schemas'].keys():
            sorted_schema_metadata = self.return_schema_tables_in_order(
                self.processed_metadata['metadata']['build-assets']['schemas'][schema_name]
            )
            sorted_schemas[schema_name] = sorted_schema_metadata
        self.processed_metadata['metadata']['build-assets']['schemas'] = sorted_schemas

    def return_schema_queries_in_order(self, schema_name):
        schema_metadata = self.processed_metadata['metadata']['build-assets']['schemas'][schema_name]
        sorted_queries = {}
        sorted_params = {}
        for table_name, table_metadata in schema_metadata['tables'].items():
            sorted_table_queries, sorted_table_queries_params = self.return_table_queries_in_order(table_metadata)
            sorted_queries[table_name] = sorted_table_queries
            sorted_params[table_name] = sorted_table_queries_params
        return sorted_queries, sorted_params

    def parameterize_table_queries(self, schema_queries, schema_queries_params, table_name, query_params=None, **kwargs):
        if query_params is None:
            query_params = {}

        queries = schema_queries[table_name]
        params = schema_queries_params[table_name]
        parameterized_queries = []
        for index, query in enumerate(queries):
            
            # query = query.replace("{stage3_db_name}.dbo_", 'dbo_vw_')
            # query = query.replace("{stage3_db_name}.config_", 'config_vw_')
            
            # query = query.replace("{stage3_db_name}.auth_", 'auth_vw_')
            # query = query.replace("{stage3_db_name}.", 'dbo_vw_')
            # #query = query.replace("{stage3_db_name}.EducationOrganization", 'dbo_vw_EducationOrganization')

            # query = query.replace("{base_table_db_name}.dbo_", 'dbo_vw_')
            # query = query.replace("{stage3_db_name}.", 'dbo_vw_')
            # query = query.replace("{base_table_db_name}.config_", 'config_vw_')
            # query = query.replace("{base_table_db_name}.auth_", 'auth_vw_')
            #query = query.replace("{base_table_db_name}.EducationOrganization", 'dbo_vw_EducationOrganization')
            for key, value in kwargs.items():
                if key != 'query_params':
                    query = query.replace(f"{{{key}}}", value)
            if (params[index]) and (query_params):
                for key, value in query_params.items():
                    query = query.replace(f"{{{key}}}", value)

            query = query.replace(f"{stage3_db_name}.dbo_", 'dbo_vw_')
            query = query.replace(f"{stage3_db_name}.config_", 'config_vw_')
            
            query = query.replace(f"{stage3_db_name}.auth_", 'auth_vw_')
            query = query.replace(f"{stage3_db_name}.", 'dbo_vw_')
            #query = query.replace("{stage3_db_name}.EducationOrganization", 'dbo_vw_EducationOrganization')

            query = query.replace(f"{base_table_db_name}.dbo_", 'dbo_vw_')
            query = query.replace(f"{stage3_db_name}.", 'dbo_vw_')
            query = query.replace(f"{base_table_db_name}.config_", 'config_vw_')
            query = query.replace(f"{base_table_db_name}.auth_", 'auth_vw_')
            parameterized_queries.append(query)
        return parameterized_queries

    def execute_table_queries(self, schema_name, table_name, queries,stage3_db_name, surrogate_key = True, insertion_type = 'append', explain = False):
        # TODO: 2024-02-23: Error Logging for etl and staging queries 
        staging_list = self.generate_staging_list(len(queries), schema_name)
        for step_prefix, query in zip(staging_list,queries):
            self.execute_query(table_name, step_prefix,query,stage3_db_name, surrogate_key, insertion_type, explain)

    def execute_query(self, table_name, step_prefix, query,stage3_db_name, surrogate_key = True, insertion_type = 'append', explain = False): 
        """Run one build query and register the result as temp view ``{step_prefix}_{table_name}``.

        Steps, in order:
          1. Strip a dangling comma directly before ``FROM`` and run the query.
          2. Add ``RECORD_HASH`` (skipped for ``DimDate``).
          3. Resolve the primary key; base tables are forced to
             ``insertion_type='overwrite'`` with key ``lakeId``.
          4. If the target Delta table already exists and the result has an
             ``*hkey`` column, switch to ``upsert_EdGraph``: diff the result
             against the sink, keep only rows that are new by primary key, and
             register changed rows as ``{step_prefix}_{table_name}_modified``.
          5. For the final ``dbo_vw`` step, assign surrogate keys by row number
             over the primary key (offset by the sink's current maximum for
             single-column keys).
          6. Register the temp view and write an entity log record.

        Changes from the previous revision:
          * ``max_skey`` now defaults to 0, so the first load of a table no
            longer raises UnboundLocalError; an empty sink (max is NULL) also
            maps to 0.
          * The debug ``show()`` calls ordered by the hard-coded column
            ``ClassroomPositionHkey`` were removed; they fail for other tables.
          * ``type(primary_key) == 'list'`` was always False; only the branch
            that actually ran is kept.
          * The comma clean-up uses a word-boundary regex instead of blanket
            ``replace('from', 'FROM')``, so ``a,from_date`` is left alone.
          * The injected ``self.spark`` / ``self.oea`` / ``self.logger`` /
            ``self.error_logger`` are used instead of notebook globals.

        NOTE(review): ``stage3_db_name`` is accepted but not used in this body.
        NOTE(review): ``datetime``, ``DeltaTable``, ``W`` and ``essential_columns``
        are not defined in this file's visible imports; they are presumably
        supplied by another notebook cell -- confirm.
        """
        # TODO: 2024-02-23: Error Logging for etl and staging queries
        start_time = datetime.now()
        view_name = f"{step_prefix}_{table_name}"
        self.logger.info(f'Creating Temporary View - {view_name}')

        # Generated SELECT lists may end in a trailing comma; drop it when it is
        # followed (after any whitespace) by the FROM keyword as a whole word.
        query = re.sub(r',\s*from\b', ' FROM', query, flags=re.IGNORECASE)
        df = self.spark.sql(query)

        if table_name !='DimDate':
            df = self.create_record_hash_column(df,essential_columns)

        columns = df.columns
        primary_key = self.return_primary_key(table_name, columns)

        if table_name in self.base_tables:
            insertion_type = 'overwrite'
            primary_key = 'lakeId'

        table_path = f"{self.stage_3_path}/{step_prefix[:-3]}/{table_name}"
        table_url = self.oea.to_url(table_path)

        max_skey = 0             # offset for new surrogate keys; 0 when no sink exists yet
        df_modified_rows = None  # only populated on the upsert path

        if DeltaTable.isDeltaTable(self.spark, table_url):
            # Upsert Logic for HKey columns - 04-01-2024
            if any(column.lower().endswith('hkey') for column in columns):
                insertion_type = 'upsert_EdGraph'

            if insertion_type == 'upsert_EdGraph':
                skey_column = table_name.replace('Dim','').replace('Fact','')
                skey_name = f"{skey_column}Skey"

                df_sink = self.spark.read.format('delta').load(table_url)
                sink_max_skey = df_sink.agg({skey_name: "max"}).collect()[0][0]
                if sink_max_skey is not None:
                    max_skey = sink_max_skey
                self.logger.info(f'[UPSERT MODE ACTIVE]: {table_name} ::: {skey_name} ::: max={max_skey} ::: {primary_key}')

                # Compare on business columns only: the surrogate key and the
                # record hash are excluded from the diff.
                df_sink = df_sink.drop(skey_name, 'RECORD_HASH')
                df = df.drop(skey_name, 'RECORD_HASH')
                df_subtract = df.subtract(df_sink)  # rows that differ from the sink

                null_skey = f.lit(None).cast(T.IntegerType())
                # Key already in sink -> modified row; key absent -> new row.
                df_modified_rows = df_subtract.join(df_sink, how='leftsemi', on=primary_key).withColumn(skey_name, null_skey)
                df_new_rows = df_subtract.join(df_sink, how='leftanti', on=primary_key).withColumn(skey_name, null_skey)

                df = df_new_rows

        if ((insertion_type == 'append') or (insertion_type=='upsert_EdGraph')) and (step_prefix == 'dbo_vw') and (surrogate_key == True):
            self.logger.info(f'[APPEND MODE ACTIVE FOR QUERY CONSTRUCT]: {step_prefix} ::: {table_name} ::: {primary_key}')
            if type(primary_key) == list:
                pk_statement = self.oea.return_pk_statement(primary_key)  # currently unused
                skey = [pk_component[:-4] + 'SKey' for pk_component in primary_key]
            else:
                skey = primary_key[:-4] + 'SKey'
                pk_statement = self.oea.return_pk_statement([primary_key])  # currently unused

            #2024-01-30 Change
            skey_exempt_tables = (
                'DimAssessmentSection',
                'DataAuthorization',
                'DimObjectiveAssessmentPerformanceLevel',
                'DimAssessmentPerformanceLevel',
                'DimAssessmentAcademicSubject',
                'DimAssessmentAssessedGradeLevel',
            )
            if table_name[:4].lower() == 'fact' or table_name in skey_exempt_tables:
                surrogate_key = False

            if surrogate_key:
                # FIXME 2024-03-15 Fix for SKey under review Monotonically v/s PK
                if type(primary_key) == list:
                    for pk_component, sk_component in zip(primary_key, skey):
                        df = df.withColumn('row_id_label', f.col(pk_component))
                        windowSpec = W.orderBy("row_id_label")
                        df = df.withColumn("row_id_label", f.row_number().over(windowSpec))

                        # Keep -1 for default/NULL-key rows; everything else gets its row number.
                        df = df.withColumn(sk_component, f.when((f.col(pk_component).isNull()) | (f.col(sk_component) == -1), -1).otherwise(f.col('row_id_label')))
                        df = df.drop('row_id_label')
                else:
                    df = df.withColumn('row_id_label', f.col(primary_key))
                    windowSpec = W.orderBy("row_id_label")
                    df = df.withColumn("row_id_label", f.row_number().over(windowSpec))

                    df = df.withColumn(skey, f.when((f.col(primary_key).isNull()) | (f.col(skey) == -1), -1).otherwise(f.col('row_id_label')))
                    # Continue numbering after the keys already present in the sink.
                    df = df.withColumn(skey, f.when(f.col(skey) == -1, f.col(skey)).otherwise(f.col(skey) + max_skey))
                    df = df.drop('row_id_label')

        if explain and (step_prefix == 'dbo_vw' or step_prefix == 'config_vw' or step_prefix == 'auth_vw'):
            df.explain()
        df.createOrReplaceTempView(view_name)

        if insertion_type == 'upsert_EdGraph' and df_modified_rows is not None:
            df_modified_rows.createOrReplaceTempView(f"{view_name}_modified")

        end_time = datetime.now()
        numInputRows = df.count()

        log_data = self.error_logger.create_log_dict(uniqueId = self.error_logger.generate_random_alphanumeric(10), # Generate a random 10-character alphanumeric value
                                                pipelineExecutionId = self.pipelineExecutionId,
                                                sparkSessionId = self.spark.sparkContext.applicationId,
                                                stageName = "edgraph-dwh: Migrate From S2R To Edgraph DWH Creating Spark Views",
                                                schemaFormat = 'edgraph-dwh',
                                                entityType = 'Spark View',
                                                entityName = table_name,
                                                numInputRows = numInputRows,
                                                totalNumOutputRows = 0,
                                                numTargetRowsInserted = 0,
                                                numTargetRowsUpdated = 0,
                                                numRecordsSkipped = 0,
                                                numRecordsDeleted = 0,
                                                start_time = start_time,
                                                end_time = end_time,
                                                insertionType = 'Creating Spark Views')
        self.error_logger.consolidate_logs(log_data,'entity')


        #spark.catalog.cacheTable(f"{step_prefix}_{table_name}")
    def threaded_dump_to_stage3_delta(self,input_tuple):
        table_name, step_prefix, schema_queries_in_order, essential_columns, surrogate_key = input_tuple
        
        try:
            status = 'Successful'
            if step_prefix in ['dbo_vw']:
                self.dump_to_stage3_delta_lake(step_prefix, table_name,essential_columns, surrogate_key = surrogate_key)
                self.add_to_lake_db_stage3(step_prefix, table_name, overwrite = self.lakeTableOverwrite)
                self.entity_freq_processor.update_lookup_df(table_name,status)
            if step_prefix in ['config_vw', 'auth_vw']:
                # # TODO: To Be Reviewed for skey
                self.dump_to_stage3_delta_lake(step_prefix, table_name,essential_columns, surrogate_key = False)
                self.add_to_lake_db_stage3(step_prefix, table_name, overwrite = self.lakeTableOverwrite)
                self.entity_freq_processor.update_lookup_df(table_name,status)
        except Exception as e:
                logger.error(f"An error occurred: {e}")
                status = "Unsuccessful"
                self.entity_freq_processor.update_lookup_df(table_name,status)
        

    def upsert_with_logging(self,
                            table_name, 
                            df, 
                            df_modified,
                            destination_path, 
                            primary_key, 
                            partitioning, 
                            partitioning_cols,
                            surrogate_key,
                            insertion_type = 'upsert'):
        
        logger.info(f'--- PRIMARY KEY: {primary_key}')
        start_time = datetime.now()

        # FIXME: 2024-03-14: upsert is a misnomer
        if insertion_type == 'upsert_EdGraph':
            numInputRows, numOutputRows, numTargetRowsInserted, numTargetRowsUpdated = self.oea.upsert_EdGraph(df = df,
                                                                    df_modified = df_modified, 
                                                                    destination_path = destination_path,#f"{sink_general_path}",
                                                                    primary_key = primary_key,
                                                                    partitioning = partitioning,
                                                                    partitioning_cols = partitioning_cols,
                                                                    de_duplicate = False)
        if insertion_type == 'upsert':
            #FIX ME - 2024-01-31  - 2024-03-07 Review again 
            numInputRows, numOutputRows, numTargetRowsInserted, numTargetRowsUpdated = self.oea.overwrite(df = df, 
                                                                    destination_path = destination_path,#f"{sink_general_path}",
                                                                    primary_key = primary_key,
                                                                    partitioning = partitioning,
                                                                    partitioning_cols = partitioning_cols,
                                                                    surrogate_key = False,
                                                                    de_duplicate = False)
        elif insertion_type == 'overwrite':
            numInputRows, numOutputRows, numTargetRowsInserted, numTargetRowsUpdated = self.oea.overwrite(df = df, 
                                                                    destination_path = destination_path,#f"{sink_general_path}",
                                                                    primary_key = primary_key,
                                                                    partitioning = partitioning,
                                                                    partitioning_cols = partitioning_cols,
                                                                    surrogate_key = False,
                                                                    de_duplicate = True)
                      
        end_time = datetime.now()
        # NOTE: 2024-02-23: @Harsh - FYI

        log_data = error_logger.create_log_dict(uniqueId = self.error_logger.generate_random_alphanumeric(10), # Generate a random 10-character alphanumeric value
                                                pipelineExecutionId = self.pipelineExecutionId,#'TEST_1234',#executionId,
                                                sparkSessionId = self.spark.sparkContext.applicationId,
                                                stageName = "edgraph-dwh: Migrate From S2R To Edgraph DWH Wrting to Delta",
                                                schemaFormat = 'edgraph-dwh',
                                                entityType = destination_path.split('/')[-2],
                                                entityName = table_name,
                                                numInputRows = numInputRows,
                                                totalNumOutputRows = numOutputRows,
                                                numTargetRowsInserted = numTargetRowsInserted,
                                                numTargetRowsUpdated = numTargetRowsUpdated,
                                                numRecordsSkipped = 0,
                                                numRecordsDeleted = 0,
                                                start_time = start_time,
                                                end_time = end_time,
                                                insertionType = insertion_type)
        error_logger.consolidate_logs(log_data,'entity')
    
    def create_record_hash_column(self, df,essential_columns):
        """Return ``df`` with a ``RECORD_HASH`` column covering its non-essential scalar columns.

        The hash is SHA-256 over the concatenation of every included column
        cast to string, with NULL replaced by ``''`` and each value followed by
        ``'_'``. Columns listed in ``essential_columns`` and columns of struct
        or array type are excluded.

        Uses the ``T`` / ``f`` aliases imported at the top of this file rather
        than bare ``StructType`` / ``ArrayType`` / ``F`` names.
        """
        # TODO: Under DEV
        hashed_columns = [
            name for name in df.columns
            if name not in essential_columns
            and not isinstance(df.schema[name].dataType, (T.StructType, T.ArrayType))
        ]

        self.logger.info(f'Creating hash column - RECORD_HASH - of the following columns: {hashed_columns}')

        # value + '_' per column so adjacent values cannot run together.
        hash_parts = [
            f.concat(f.coalesce(f.col(name).cast('string'), f.lit('')), f.lit('_'))
            for name in hashed_columns
        ]
        return df.withColumn("RECORD_HASH", f.sha2(f.concat(*hash_parts), 256))
        

    def get_df_latest_records_by_join(self, df, destination_path, func_enabled = False):
        if not func_enabled:
            return df
        else:
            logger.info('[EDFIOEACHILD REFINEMENT RECORD HASHING] JOIN BASED COMPARSIONS / DELTA COMPARISONS BEFORE UPSERT IS ENABLED')
            df = df.withColumnRenamed('RECORD_VERSION', 'RECORD_VERSION_LEFT')
            
            df_destination = self.load(destination_path)            
            df.createOrReplaceTempView('temp_vw_df_source_table')
            df_destination.createOrReplaceTempView('temp_vw_df_destination_table')

            query = f"""SELECT temp_vw_df_source_table.*, 
                               temp_vw_df_destination_table.RECORD_VERSION
                        FROM temp_vw_df_source_table 
                        LEFT JOIN temp_vw_df_destination_table 
                            ON temp_vw_df_source_table.NATURAL_KEY_HASH = temp_vw_df_destination_table.NATURAL_KEY_HASH
                        WHERE (temp_vw_df_source_table.RECORD_HASH != temp_vw_df_destination_table.RECORD_HASH)
                        OR (temp_vw_df_destination_table.RECORD_HASH IS NULL)
                    """
            df_joined = spark.sql(query)
            df_joined = df_joined.withColumn('RECORD_VERSION', F.col('RECORD_VERSION') + 1)
            df_joined = df_joined.drop('RECORD_VERSION_LEFT')

            logger.info(f"[EDFIOEACHILD REFINEMENT RECORD HASHING] --- NUM ROWS (SOURCE DELTA LAKE) - {df.count()}")
            logger.info(f"[EDFIOEACHILD REFINEMENT RECORD HASHING] --- NUM ROWS (DESTINATION DELTA LAKE) - {df_destination.count()}")
            logger.info(f'[EDFIOEACHILD REFINEMENT RECORD HASHING] --- NUM ROWS (ACTUALLY MODIFIED) - {df_joined.count()}')
            return df_joined

    def dump_to_stage3_delta_lake(self, step_prefix, table_name,essential_columns, surrogate_key = True):
        """Write the temp view ``{step_prefix}_{table_name}`` to its stage-3 Delta location.

        The view is read back, its primary key is resolved, an insertion type
        is chosen, and the write is delegated to ``upsert_with_logging``.

        Insertion type (decision logic unchanged from the previous revision):
          * any column ending in ``hkey``      -> ``'upsert_EdGraph'``
          * otherwise, table in ``base_tables`` -> ``'overwrite'``
          * otherwise                           -> ``'upsert'``
        The primary key is replaced by ``'lakeId'`` for base tables, except
        when the Delta table already exists and an hkey column is present.

        This revision uses the injected ``self.spark`` / ``self.oea`` throughout
        (the method previously mixed them with notebook globals), computes the
        table path once, and drops branches that only contained ``pass``.

        NOTE(review): hkey tables are routed to ``upsert_EdGraph`` even when no
        Delta table exists yet, and the ``_modified`` view read below is only
        created by ``execute_query`` on its upsert path -- confirm the
        first-load behaviour.
        NOTE(review): ``essential_columns`` is accepted but not used here.
        NOTE(review): ``DeltaTable`` is not imported in the visible part of
        this file; presumably another notebook cell provides it.
        """
        table_path = f"{self.stage_3_path}/{step_prefix[:-3]}/{table_name}"
        table_url = self.oea.to_url(table_path)
        view_name = f'{step_prefix}_{table_name}'

        df = self.spark.sql(f'select * from {view_name}')

        #TODO: ADD primary key in the MetadataProcessor
        columns = df.columns
        primary_key = self.return_primary_key(table_name, columns)
        has_hkey = any(column.lower().endswith('hkey') for column in columns)

        table_exists = DeltaTable.isDeltaTable(self.spark, table_url)
        if table_exists and has_hkey:
            insertion_type = 'upsert_EdGraph'
        elif table_name in self.base_tables:
            insertion_type = 'overwrite'
            primary_key = 'lakeId'
        else:
            #FIX TRY 2024-03-15
            insertion_type = 'upsert'

        # An hkey column forces the EdGraph upsert regardless of the choice above.
        if has_hkey:
            insertion_type = 'upsert_EdGraph'

        if insertion_type == 'upsert_EdGraph':
            df_modified = self.spark.sql(f'select * from {view_name}_modified')
        else:
            df_modified = None

        # FIXME: 2024-01-30  (TEMP FIX) TO PREVENT SKEY OVERWRITES
        skey_exempt_tables = (
            'DimAssessmentSection',
            'DataAuthorization',
            'DimObjectiveAssessmentPerformanceLevel',
            'DimAssessmentPerformanceLevel',
            'DimAssessmentAcademicSubject',
            'DimAssessmentAssessedGradeLevel',
        )
        if table_name[:4].lower() == 'fact' or table_name in skey_exempt_tables:
            surrogate_key = False

        self.upsert_with_logging(table_name = table_name, 
                                 df = df,
                                 df_modified = df_modified,
                                 destination_path = table_path, 
                                 primary_key = primary_key, 
                                 partitioning = self.partitioning, 
                                 partitioning_cols = [],
                                 surrogate_key = surrogate_key,
                                 insertion_type = insertion_type)
    
    def add_to_lake_db_stage3(self, step_prefix, table_name, overwrite = True):
        spark.sql(f'CREATE DATABASE IF NOT EXISTS {self.stage3_db_name}')
        table_prefix = step_prefix[:-3]
        table_path = f"{self.stage_3_path}/{table_prefix}/{table_name}"
        table_url = self.oea.to_url(table_path)

        table_prefix = f'{table_prefix}_'
        if table_name in self.base_tables:
            table_prefix = ''

        logger.info(f'Adding the table - {table_prefix}{table_name} to Lake DB')
        if overwrite:
            spark.sql(F"DROP TABLE IF EXISTS {self.stage3_db_name}.{table_prefix}{table_name}")
        spark.sql(f"CREATE TABLE IF NOT EXISTS {self.stage3_db_name}.{table_prefix}{table_name} using DELTA location '{table_url}'")

    def process_metadata(self):
        """Prepare ``processed_metadata`` for building; currently this only reorders each schema's tables."""
        self.reorder_metadata_schemas()

    def generate_staging_list(self, length, schema_name):
        staging_list = ["etl_vw"]
        if length > 1:
            staging_list.append("staging_vw")
        for i in range(3, length):
            staging_list.append(f'staging_vw_{chr(ord("a") + i - 2)}_')
        staging_list.append(schema_name)

        final_prefix = f'{schema_name}_vw'
        staging_list[-1] = final_prefix
        # FIXME: Automation has bug
        if length == 2:
            staging_list = ['etl_vw', final_prefix]
        if length == 1:
            staging_list = [final_prefix]
        return staging_list

# In [ ]:
class SemanticViewsBuilder(EdgraphDWHBuilder):
    """Builder variant that executes its queries against a SQL Server database over ODBC.

    ``set_server_creds`` must be called before any method that runs a query,
    because it is what opens ``self.connection``.
    """

    def __init__(self, original_metadata, stage3_db_name, stage_3_path, partitioning, spark, oea, logger, error_logger, lakeTableOverwrite, entity_freq_processor=None):
        """Initialise the underlying builder.

        ``entity_freq_processor`` is required by the parent constructor; it is
        optional here (default ``None``) so existing nine-argument callers keep
        working instead of failing with a missing-argument ``TypeError``.
        """
        super().__init__(original_metadata, stage3_db_name, stage_3_path, partitioning, spark, oea, logger, error_logger, lakeTableOverwrite, entity_freq_processor)

    def set_server_creds(self, server_name, database_name, user_name, password, driver='ODBC Driver 18 for SQL Server'):
        """Store the connection settings and immediately open the connection."""
        self.server_name = server_name
        self.database_name = database_name
        self.user_name = user_name
        self.password = password
        self.driver = driver
        self.connection = self.connect_to_database()

    def connect_to_database(self):
        """Open and return a pyodbc connection built from the stored settings.

        NOTE(review): relies on ``pyodbc`` already being imported elsewhere in
        the file/notebook; no import is visible in this section.
        """
        connection_string = f"DRIVER={{{self.driver}}};SERVER={self.server_name};DATABASE={self.database_name};UID={self.user_name};PWD={self.password};"
        return pyodbc.connect(connection_string)

    def execute_table_queries(self, schema_name, table_name, queries):
        """Execute each query of a table's chain, in order.

        The queries are paired with the prefixes from ``generate_staging_list``;
        only the query is used here. ``table_name`` is accepted but unused.
        """
        staging_list = self.generate_staging_list(len(queries), schema_name)
        for _step_prefix, query in zip(staging_list, queries):
            self.execute_query(query)

    def execute_query(self, query, isResult = False):
        """Run one statement, commit, and optionally return a scalar result.

        A dangling comma directly before a ``FROM`` keyword (any case, any
        amount of whitespace in between) is removed before execution.

        Args:
            query: SQL text to execute.
            isResult: When true, the first column of the first row is returned
                as an ``int``; otherwise ``None`` is returned.
        """
        # Only touch "<comma><whitespace>FROM" as a whole word. A plain substring
        # replace of "from" would also rewrite identifiers and literals that
        # merely contain it (e.g. "a,fromDate" would lose its comma).
        query = re.sub(r",\s*from\b", " FROM", query, flags=re.IGNORECASE)

        cursor = self.connection.cursor()
        try:
            cursor.execute(query)
            result = int(cursor.fetchone()[0]) if isResult else None
            self.connection.commit()
        finally:
            # Release the cursor even when execute/fetch raises.
            cursor.close()
        return result

    def create_schema_if_not_exists(self, schema):
        """Create ``schema`` (owned by ``dbo``) when ``sys.schemas`` has no entry for it."""
        # NOTE: Under Dev & Review
        query = f"SELECT count(*) FROM sys.schemas WHERE name = N'{schema}';"
        schema_count = self.execute_query(query, isResult = True)
        if schema_count == 0:
            self.logger.info(f'CREATING SCHEMA {schema}')
            query = f"CREATE SCHEMA [{schema}] AUTHORIZATION [dbo]"
            self.execute_query(query)

    def drop_view(self, schema, view_name):
        """Drop ``[schema].[view_name]`` if it exists."""
        query = f"DROP VIEW IF EXISTS [{schema}].[{view_name}]"
        self.execute_query(query)

    def close_connection(self):
        """Close the connection opened by ``set_server_creds``."""
        self.connection.close()

In [ ]:
from datetime import date, timedelta
from dateutil.relativedelta import relativedelta
import pandas as pd
from pyspark.sql import SparkSession

class SparkTableGenerator:
    """Reads calendar settings from ``config_vw_parameter`` and builds the date dimension."""

    def __init__(self, spark_session, base_table_db_name, current_school_year, current_execution_datetime):
        """Store the session and run context; ``current_school_year`` is coerced to ``int``."""
        self.spark = spark_session
        self.base_table_db_name = base_table_db_name
        self.current_school_year = int(current_school_year)
        self.current_execution_datetime = current_execution_datetime

    def _read_int_parameter(self, name_matches, default):
        """Return the first non-null ``ParameterValue`` whose name matches, as an int.

        ``name_matches`` receives the ``ParameterName`` column and returns the
        boolean column used to select rows. ``default`` is returned when no
        matching row has a non-null value. A non-numeric value still raises
        ``ValueError`` from ``int()`` so misconfiguration is not hidden.
        """
        params_df = self.spark.sql("SELECT * FROM config_vw_parameter")
        if params_df is None:
            return default

        # The null check is applied to the same rows that are read, so a
        # matching row with a null value falls back to the default instead of
        # reaching int(None).
        matching_row = (
            params_df
            .filter(name_matches(params_df['ParameterName']) & params_df['ParameterValue'].isNotNull())
            .select('ParameterValue')
            .first()
        )
        if matching_row is None:
            return default
        return int(matching_row[0])

    def generate_fiscal_month(self):
        """Return the configured fiscal month (name contains ``FiscalMonth``), defaulting to 7."""
        fiscal_month = self._read_int_parameter(
            lambda name: name.contains("FiscalMonth"), default=7
        )
        print("Fiscal Month:", fiscal_month)
        return fiscal_month

    def generate_first_day_of_week(self):
        """Return the configured ``FirstDayOfWeek`` parameter, defaulting to 1."""
        first_day_of_week = self._read_int_parameter(
            lambda name: name == 'FirstDayOfWeek', default=1
        )
        print("First Day of Week:", first_day_of_week)
        return first_day_of_week

    @staticmethod
    def _ordinal_suffix(day):
        """Return the English ordinal suffix ('st', 'nd', 'rd', 'th') for ``day``."""
        # 4-20 (which covers 11, 12, 13) and multiples of ten always take 'th'.
        if 4 <= day % 100 <= 20 or day % 10 == 0:
            return 'th'
        return {1: 'st', 2: 'nd', 3: 'rd'}.get(day % 10, 'th')

    @staticmethod
    def _build_date_row(current_day, school_year, execution_stamp):
        """Build the dimension row (a dict) for one calendar day of ``school_year``."""
        day = current_day.day
        return {
            'DateKey': current_day.year * 10000 + current_day.month * 100 + day,
            'Date': current_day,
            'DayOfMonth': day,
            'DayOfMonthWithSuffix': f"{day}{SparkTableGenerator._ordinal_suffix(day)}",
            'DayOfWeek': current_day.weekday() + 1,  # Monday == 1
            'DayOfYear': current_day.timetuple().tm_yday,
            'WeekDayName': current_day.strftime('%A'),
            'WeekdayNameShort': current_day.strftime('%a').upper(),
            'WeekOfMonth': (day - 1) // 7 + 1,
            'WeekOfYear': current_day.isocalendar()[1],
            'Month': current_day.month,
            'MonthName': current_day.strftime('%B'),
            'MonthNameShort': current_day.strftime('%b').upper(),
            'SchoolYear': f"{school_year - 1}-{school_year}",
            'SchoolYearShort': str(school_year),
            'CalendarYear': current_day.year,
            'DW_CreatedDateTime': execution_stamp,
            'DW_ModifiedDateTime': execution_stamp,
        }

    def generate_dim_date(self, fiscal_month, num_years = 4):
        """Build the date dimension for the current school year and ``num_years`` before it.

        Each school year ``Y`` spans from the first day of ``fiscal_month`` in
        calendar year ``Y - 1`` up to the day before the first day of
        ``fiscal_month`` in calendar year ``Y``.

        Args:
            fiscal_month: Month number (1-12) on which a school year starts.
            num_years: How many school years before the current one to include.

        Returns:
            The result of ``spark.createDataFrame`` over one dict per day.
        """
        execution_stamp = str(self.current_execution_datetime)
        first_school_year = self.current_school_year - num_years
        dim_date_list = []

        for school_year in range(first_school_year, self.current_school_year + 1):
            next_fiscal_start = date(school_year, fiscal_month, 1)
            # Same day one year earlier; always valid because the day is the 1st.
            current_day = date(school_year - 1, fiscal_month, 1)
            while current_day < next_fiscal_start:
                dim_date_list.append(
                    self._build_date_row(current_day, school_year, execution_stamp)
                )
                current_day += timedelta(days=1)

        # NOTE(review): the previous version called withColumn twice here (cast
        # SchoolYearShort to IntegerType; add a null IntegerType column named
        # "DatSkey") but discarded both results, so neither ever took effect.
        # They are left out to keep the returned schema exactly as it has been;
        # reinstate them with assignment if that schema change is wanted.
        return self.spark.createDataFrame(dim_date_list)