In [ ]:
import pyspark

from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import StringType, StructType, StructField, IntegerType

from pyspark.sql.functions import col, substring, regexp_extract, split, lit, struct, to_date, from_unixtime, date_format
from pyspark.sql.functions import create_map, lit, when, array, coalesce, concat_ws
from pyspark.sql.functions import collect_list, create_map, lit, struct, array, concat
from pyspark.sql.functions import expr
from pyspark.sql.types import DecimalType

from pyspark.sql import functions as F
from pyspark.sql.functions import hash
import pyspark.sql.functions as f
from pyspark.sql.functions import udf
from concurrent.futures import ThreadPoolExecutor

import json
import os
import pandas as pd
import re
import uuid

import copy
from itertools import chain
from datetime import datetime

# Module-level placeholders, all initialised to None. None of them is read or
# written in the visible part of this file -- presumably they are populated by
# other notebook cells (TODO confirm before removing).
global_df_test = None
df_staffs = None
schema_name = None

In [ ]:
%run OEA/modules/Ed-Fi/v0.7/src/utilities/edfi_v0_7_edfi_py

In [ ]:
class SAPEdFiOEAChild(EdFiOEAChild):
    """ 
    NOTE: This class inherits features from the base class OEA and therefore,
    should be created / executed after running the notebook OEA_py
    """
    def __init__(self, workspace='dev', logging_level=logging.INFO, storage_account=None, keyvault=None, timezone=None, sap_pipeline = None, sap_pipelineType = None):
        """Initialise the inherited Ed-Fi OEA state and remember the SAP pipeline identifiers.

        The first five arguments are forwarded positionally to the base class
        constructor. ``sap_pipeline`` and ``sap_pipelineType`` are stored as-is;
        other methods of this class stamp them onto rows as literal columns.
        """
        # Let the base class set up everything it owns first.
        super().__init__(workspace, logging_level, storage_account, keyvault, timezone)

        self.sap_pipeline = sap_pipeline
        self.sap_pipelineType = sap_pipelineType

        # Obtain (or create) the session with a larger Kryo buffer, then enable optimised Delta writes.
        # more info here: https://learn.microsoft.com/en-us/azure/synapse-analytics/spark/optimize-write-for-apache-spark
        session = SparkSession.builder.config("spark.kryoserializer.buffer.max", "3000m").getOrCreate()
        session.conf.set("spark.microsoft.delta.optimizeWrite.enabled", "true")
    
    def get_latest_changes(self, source_path, sink_path, filtering_date = 'rundate',primary_key = ['id'],debugMode = False):
        """Delegate to the parent ``get_latest_changes`` with every argument passed by keyword.

        This override adds no logic of its own; it only pins the default values
        visible in the signature (``filtering_date='rundate'``, ``primary_key=['id']``).
        """
        forwarded = {
            'source_path': source_path,
            'sink_path': sink_path,
            'filtering_date': filtering_date,
            'primary_key': primary_key,
            'debugMode': debugMode,
        }
        return super().get_latest_changes(**forwarded)

    def process(self, source_path,foreach_batch_function, batch_type, natural_key = None,landingDateTimeFormat = 'yyyyMMddHHmmss',options={}):
        """ This simplifies the process of using structured streaming when processing transformations.
            Provide a source_path and a function that receives a dataframe to work with (which will be a dataframe with data from the given source_path).
            Use it like this...
            def refine_contoso_dataset(df_source, batch_id):
                metadata = oea.get_metadata_from_url('https://raw.githubusercontent.com/microsoft/OpenEduAnalytics/gene/v0.7dev/modules/module_catalog/Student_and_School_Data_Systems/metadata.csv')
                df_pseudo, df_lookup = oea.pseudonymize(df_source, metadata['studentattendance'])
                oea.upsert(df_pseudo, 'stage2/Refined/contoso_sis/v0.1/studentattendance/general')
                oea.upsert(df_lookup, 'stage2/Refined/contoso_sis/v0.1/studentattendance/sensitive')
            oea.process('stage2/Ingested/contoso_sis/v0.1/studentattendance', refine_contoso_dataset, 'upsert')

            Args:
                source_path: lake path to stream from; it is also used (with '/_checkpoints' appended) as the checkpoint location.
                foreach_batch_function: callable invoked as ``foreach_batch_function(df, batch_id)`` for every micro-batch,
                    after the bookkeeping columns below have been added.
                batch_type: any value other than 'delete' adds a ``rowIsActive = True`` column to the batch.
                natural_key: optional list of column names; when given, a SHA-256 ``NATURAL_KEY_HASH`` column is derived from them.
                landingDateTimeFormat: pattern used to parse the ``rundate`` column into a timestamp.
                options: keyword options forwarded to ``spark.readStream.load``.

            Returns:
                The ``numInputRows`` value of the query's last progress report, or 0 when the query reported no progress.

            Raises:
                ValueError: if ``source_path`` does not exist.
        """
        # FIXME: 2024-02-08 (Under Dev for high granularity and de-dup)
        if not self.path_exists(source_path):
            raise ValueError(f'The given path does not exist: {source_path} (which resolves to: {self.to_url(source_path)})')

        natural_key_expr = None
        if natural_key is not None:
            natural_key_expr = [f.col(key_component).cast('string') for key_component in natural_key]

        def wrapped_function(df, batch_id):
            current_timestamp = datetime.now()
            df = df.withColumn('LastModifiedDate', F.lit(current_timestamp))
            df = df.withColumn("rundate", F.to_timestamp(F.col("rundate").cast('string'), landingDateTimeFormat))
            df = df.withColumn('sap_pipeline', F.lit(self.sap_pipeline))
            df = df.withColumn('sap_pipelineType', F.lit(self.sap_pipelineType))
            # FIXME: 2024-03-13 TEMP FIX For amount precision
            if 'AMOUNT' in df.columns and 'PU_FC_SUB2' in df.columns:
                df = df.withColumn("AMOUNT", df["AMOUNT"].cast(DecimalType(precision=20, scale=2)).cast('string'))
            if natural_key_expr is not None:
                # Each key component is null-coalesced to '' and suffixed with '_' before hashing.
                df = df.withColumn("NATURAL_KEY_HASH",F.sha2(F.concat(*[F.concat(F.coalesce(column, F.lit('')), F.lit('_')) for column in natural_key_expr]), 256))
            if batch_type != 'delete':
                df = df.withColumn("rowIsActive", F.lit(True))

            df.persist() # cache the df so it doesn't get read in multiple times when we write to multiple destinations. See: https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#foreachbatch
            try:
                foreach_batch_function(df, batch_id)
            finally:
                # Release the cached batch even when the callback raises; the exception still propagates.
                df.unpersist()

        spark.sql("set spark.sql.streaming.schemaInference=true")
        print(f"source_path is: {source_path}")
        streaming_df = spark.readStream.load(self.to_url(source_path), **options)
        streaming_df = streaming_df.withColumn('stage1_source_url', F.input_file_name())

        # for more info on append vs complete vs update modes for structured streaming: https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#basic-concepts
        query = streaming_df.writeStream.format('delta').outputMode('append').trigger(once=True).option('checkpointLocation', self.to_url(source_path) + '/_checkpoints').foreachBatch(wrapped_function).start()
        query.awaitTermination()   # block until query is terminated, with stop() or with error; A StreamingQueryException will be thrown if an exception occurs.

        # lastProgress is None when the query never reported a progress update.
        last_progress = query.lastProgress
        number_of_new_inbound_rows = last_progress["numInputRows"] if last_progress is not None else 0
        logger.info(f'[SAPEDFIOEACHILD INGESTION STRUCTURED STREAMING PROCESS]: Number of new inbound rows processed: {number_of_new_inbound_rows}')
        logger.debug(last_progress)
        return number_of_new_inbound_rows
    
    def add_to_lake_db(self, source_entity_path, overwrite = False, extension = None, table_type = 'non-empty-non-descriptor', pipeline_type = None):
        """ Registers the entity at the given path as a Delta table in the lake db derived from that path.
            The lake db is created when missing; the table is created only if it does not already exist,
            unless overwrite is set, in which case it is dropped first.
            eg: add_to_lake_db('stage2/Ingested/contoso_sis/v0.1/students')

            Which table_type values do something:
              - 'non-empty-non-descriptor': handled only when the path does not contain '/emptySchemas/'.
              - 'ingested-intermediate': always handled.
              - 'non-empty-descriptor': handled, but the extension suffix is never applied.
            Any other combination is a silent no-op.

            extension, when given, is appended to the table name (a leading '_' is added if absent).
            pipeline_type is accepted but not used by this method.

            Note that a spark db that points to source data in the delta format can't be queried via SQL serverless pool. More info here: https://docs.microsoft.com/en-us/azure/synapse-analytics/sql/resources-self-help-sql-on-demand#delta-lake
        """
        source_dict = self.parse_path(source_entity_path)
        db_name = source_dict['ldb_name']

        is_standard_table = (
            (table_type == 'non-empty-non-descriptor' and '/emptySchemas/' not in source_entity_path)
            or table_type == 'ingested-intermediate'
        )
        if not is_standard_table and table_type != 'non-empty-descriptor':
            return

        # Only the standard (non-descriptor) route honours the table-name suffix.
        if is_standard_table and extension is not None:
            suffix = extension if extension.startswith('_') else '_' + extension
            source_dict['entity'] = source_dict['entity'] + str(suffix)

        qualified_table = f"{db_name}.{source_dict['entity']}"
        spark.sql(f'CREATE DATABASE IF NOT EXISTS {db_name}')
        if overwrite:
            spark.sql(f"drop table if exists {qualified_table}")

        spark.sql(f"create table if not exists {qualified_table} using DELTA location '{self.to_url(source_dict['entity_path'])}'")
        
    def merge_deletes_into_delta_lake(self, df, destination_path, func_enabled = False):
        """Flag previously submitted destination records that no longer appear in ``df`` as inactive.

        Steps (only when ``func_enabled`` is true):
          1. Load the destination and LEFT JOIN it to ``df`` on NATURAL_KEY_HASH; destination rows
             without a match in ``df`` are the inactive candidates.
          2. From ``ldb_<workspace>_sap_etl_logs.etlsubmissionslogs`` take the most recent non-delete
             pipeline execution for this entity and keep its rows whose response status code starts
             with '2' but is not '204'.
          3. The NATURAL_KEY_HASH values present in both sets are merged into the Delta table at
             ``destination_path``; matched rows get ``rundate`` (max rundate of ``df``) and
             ``SUBMISSION_RECORD_IS_ACTIVE = False``. Nothing is inserted.

        Args:
            df: source DataFrame; the queries below read its ``rundate`` and ``NATURAL_KEY_HASH`` columns.
            destination_path: lake path of the destination; its last '/'-separated segment is used as the entity name.
            func_enabled: feature switch; when false the method does nothing beyond caching ``df``.

        Returns:
            0 when ``func_enabled`` is false, otherwise 1 -- including the case where the destination
            is not a Delta table and only an error is logged.
        """
        # NOTE(review): the cache is requested before the feature switch is checked and is never released here.
        df = df.cache()
        entity_name = destination_path.split('/')[-1]
        if not func_enabled:
            return 0
        else:
            logger.info('[SAPEDFIOEACHILD REFINEMENT SUBMISSION DELETES] MERGING DELETES BEFORE JOIN BASED COMPARSIONS / DELTA COMPARISONS / UPSERT IS ENABLED')                 
            df_destination = self.load(destination_path)
            # Drop any existing flag so the join below can recompute it under the same name.
            if 'SUBMISSION_RECORD_IS_ACTIVE' in df_destination.columns:   
                df_destination = df_destination.drop('SUBMISSION_RECORD_IS_ACTIVE')
            # The temp view names are fixed strings, so they are shared across the Spark session.
            df.createOrReplaceTempView('temp_vw_df_source_table')
            df_destination.createOrReplaceTempView('temp_vw_df_destination_table')

            # Latest rundate in the incoming data; stamped onto the rows that get deactivated.
            query = f'select max(rundate) maxdatetime from temp_vw_df_source_table'
            maxdatetime = spark.sql(query).first()['maxdatetime']
            # Destination rows with no matching hash in the source are flagged False.
            query = f"""SELECT temp_vw_df_destination_table.*,
                        CASE 
                            WHEN temp_vw_df_source_table.NATURAL_KEY_HASH IS NULL THEN False 
                            ELSE True 
                        END AS SUBMISSION_RECORD_IS_ACTIVE
                        FROM temp_vw_df_destination_table  
                        LEFT JOIN temp_vw_df_source_table
                            ON temp_vw_df_destination_table.NATURAL_KEY_HASH = temp_vw_df_source_table.NATURAL_KEY_HASH
                    """
            df_joined = spark.sql(query)
            # Keep only the rows that disappeared from the source.
            df_joined = df_joined.filter(df_joined["SUBMISSION_RECORD_IS_ACTIVE"] == False)
            df_joined.createOrReplaceTempView('temp_vw_df_inactive_records_staging')
            
            # Rows logged by the most recent non-delete pipeline execution for this entity,
            # restricted to 2xx responses other than 204.
            query = f"""WITH maxPipelineExecutionIdCTE AS
                (
                    SELECT pipelineExecutionId
                    FROM ldb_{self.workspace}_sap_etl_logs.etlsubmissionslogs
                    WHERE entity_name = '{entity_name}'
                      AND operation_type != 'delete'
                    ORDER BY start_time desc, end_time desc
                    LIMIT 1
                )
                SELECT DISTINCT edfi_id,
                                edfi_location,
                                NATURAL_KEY_HASH 
                FROM ldb_{self.workspace}_sap_etl_logs.etlsubmissionslogs
                INNER JOIN maxPipelineExecutionIdCTE 
                    ON etlsubmissionslogs.pipelineExecutionId = maxPipelineExecutionIdCTE.pipelineExecutionId
                WHERE entity_name = '{entity_name}'
                  AND response_status_code LIKE '2%' 
                  AND response_status_code NOT LIKE '204'
                """
            df_last_submission_logs = spark.sql(query)
            df_last_submission_logs.createOrReplaceTempView('temp_vw_df_last_submission_logs')
            
            # Only hashes that were both submitted in that execution and are now missing from the source.
            query = """SELECT temp_vw_df_last_submission_logs.NATURAL_KEY_HASH
                        FROM temp_vw_df_last_submission_logs
                        INNER JOIN  temp_vw_df_inactive_records_staging
                            ON  temp_vw_df_last_submission_logs.NATURAL_KEY_HASH = temp_vw_df_inactive_records_staging.NATURAL_KEY_HASH
                    """
            df_final = spark.sql(query)
            # Add the remaining merge-key columns and the values to write.
            df_final = df_final.withColumn('sap_pipeline', F.lit(self.sap_pipeline))
            df_final = df_final.withColumn('sap_pipelineType', F.lit(self.sap_pipelineType))
            df_final = df_final.withColumn('rundate', F.lit(maxdatetime))
            #df_final = df_final.withColumn('LastModifiedDateTime', F.lit(datetime.now()))
            df_final = df_final.withColumn('SUBMISSION_RECORD_IS_ACTIVE', F.lit(False))

            primary_key = ['NATURAL_KEY_HASH', 'sap_pipeline','sap_pipelineType']
            # Only these two sink columns are overwritten on a match.
            update_cols = {"sink.rundate": "updates.rundate", 
                           #"sink.LastModifiedDateTime": "updates.LastModifiedDateTime", 
                           "sink.SUBMISSION_RECORD_IS_ACTIVE": "updates.SUBMISSION_RECORD_IS_ACTIVE"}
            destination_url = self.to_url(destination_path)
            pk_statement = self.return_pk_statement(primary_key)
            # NOTE: count() is a Spark action, so this log line triggers evaluation of df_final.
            logger.info(f"[SAPEDFIOEACHILD REFINEMENT SUBMISSION DELETES] NUMBER OF RECORDS TO BE MARKED AS INACTIVE FOR SUBMISSION - {df_final.count()}")
            if DeltaTable.isDeltaTable(spark, destination_url):
                logger.info('[SAPEDFIOEACHILD REFINEMENT SUBMISSION DELETES] DELETE MERGES')
                delta_table_sink = DeltaTable.forPath(spark, destination_url)
                # Update-only merge: unmatched rows are neither inserted nor deleted.
                delta_table_sink.alias('sink').merge(df_final.alias('updates'), pk_statement).whenMatchedUpdate(set = update_cols).execute()# .whenNotMatchedInsert(values = insert_cols).execute()
            else:
                logger.error(f'Invalid stage 3 delta location for the item - {entity_name}')
            # 1 is returned on both paths above.
            return 1


In [ ]:
class SAPOpenAPIUtilChild(OpenAPIUtil):
    def __init__(self, swagger_url):
        """Initialise the parent utility for ``swagger_url`` and start with an empty pluralisation map."""
        super().__init__(swagger_url)
        # Filled by create_spark_schemas_from_definitions(): pluralised name -> entity name.
        self.pluralization_mappings = {}

    def depluralize(self, noun):
        """Return a singular form of ``noun`` using a few suffix rules.

        Two irregular words are special-cased ('people', 'surveys'). Otherwise the
        first matching rule wins: a trailing 'es' after s/x/z is removed, a trailing
        'ies' becomes 'y', and a trailing 's' is removed. Words matching no rule are
        returned unchanged.
        """
        irregular = {'people': 'person', 'surveys': 'survey'}
        if noun in irregular:
            return irregular[noun]

        # (pattern that must match, pattern to substitute, replacement) -- order matters.
        suffix_rules = (
            ('[sxz]es$', 'es$', ''),
            ('ies$', 'ies$', 'y'),
            ('s$', 's$', ''),
        )
        for probe, target, replacement in suffix_rules:
            if re.search(probe, noun):
                return re.sub(target, replacement, noun)
        return noun
    
    def create_definitions(self):
        """Build ``self.definitions`` (table name -> {property -> field info}) from the swagger document.

        The schema section is read from ``swagger_json['definitions']``; when that key is
        absent it is first populated from ``swagger_json['components']['schemas']``. The
        fallback is only looked up when actually needed, so a document that has
        'definitions' but no 'components' section no longer raises KeyError.

        For every property the field-info dict from the document is annotated in place:
        'description' is removed, 'required', 'table_name' and 'column_name' are added,
        'x-Ed-Fi-pseudonymization' is renamed to 'pseudonymization', and every header in
        ``self.metadata_headers`` that is still missing is set to None. The table name is
        the part of the entity name after the last underscore.

        Also sets ``self.tables`` to the list of table names.
        """
        if 'definitions' not in self.swagger_json:
            self.swagger_json['definitions'] = self.swagger_json['components']['schemas']

        for entity, entity_spec in self.swagger_json['definitions'].items():
            properties = entity_spec['properties']
            table_name = entity.split('_')[-1]
            # An entity without a 'required' list marks every property as optional.
            required_props = entity_spec.get('required', ())
            table_schema = {}

            for prop, field_info in properties.items():
                field_info.pop('description', None)
                field_info['required'] = prop in required_props
                field_info['table_name'] = table_name
                field_info['column_name'] = prop
                if 'x-Ed-Fi-pseudonymization' in field_info:
                    field_info['pseudonymization'] = field_info.pop('x-Ed-Fi-pseudonymization')
                for header in self.metadata_headers:
                    field_info.setdefault(header, None)
                table_schema[prop] = field_info

            self.definitions[table_name] = table_schema
        self.tables = list(self.definitions)
    
    def create_spark_schemas_from_definitions(self):
        """Create one Spark ``StructType`` per entity, walking ``self.dependency_order``.

        Each column becomes a ``StructField`` whose nullability is the inverse of its
        'required' flag and whose metadata carries the Ed-Fi annotations found on the
        column. Array and '$ref' columns look up the already-built schema of the
        referenced table in ``self.schemas``, so the order of ``self.dependency_order``
        matters. Results are stored in ``self.schemas`` under the pluralised entity name,
        and ``self.pluralization_mappings`` records pluralised name -> entity.
        """
        for entity in self.dependency_order:
            table_schema = self.definitions[entity]
            if entity == 'localEducationAgencyReference':
                print(entity)

            fields = []
            for col_name, col_def in table_schema.items():
                # Optional annotations first, then the mandatory 'required' flag.
                col_metadata = {
                    key: col_def[key]
                    for key in ('pseudonymization', 'x-Ed-Fi-isIdentity')
                    if key in col_def
                }
                col_metadata['required'] = col_def['required']

                referenced_table = self.get_reference(col_def)
                if col_def['type'] == 'array':
                    datatype = ArrayType(self.schemas[self.pluralize(referenced_table)])
                    if 'x-Ed-Fi-explode' in col_def:
                        col_metadata['x-Ed-Fi-explode'] = col_def['x-Ed-Fi-explode']
                elif col_def['$ref'] is not None:
                    datatype = self.schemas[self.pluralize(referenced_table)]
                    if 'x-Ed-Fi-fields-to-pluck' in col_def:
                        col_metadata['x-Ed-Fi-fields-to-pluck'] = col_def['x-Ed-Fi-fields-to-pluck']
                else:
                    datatype = self.get_data_type(col_def['type'], col_def['format'])

                field = StructField(col_name, datatype, not col_def['required'])
                field.metadata = col_metadata
                fields.append(field)

            self.schemas[self.pluralize(entity)] = StructType(fields)
            self.pluralization_mappings[self.pluralize(entity)] = entity

In [ ]:
class SAPUtilities:
    def __init__(self, spark, oea, sap_essential_columns = None):
        self.spark = spark
        self.oea = oea
        self.sap_essential_columns = sap_essential_columns if sap_essential_columns is not None else ['DistrictId', 'SchoolYear', 'sap_pipeline', 'sap_pipelineType', 'lakeId', 'validationRecordId', 'LastModifiedDate', 'rundate', 'stage1_source_url', 'RECORD', 'NATURAL_KEY_HASH', 'RECORD_HASH', 'RECORD_VERSION']
    
    ## Descriptor Utilities
    def loadDescriptors(self, path):
        """Read the Delta table at ``path`` (resolved via ``oea.to_url``) and keep only ``namespace`` and ``codeValue``."""
        delta_reader = self.spark.read.format('delta')
        descriptors = delta_reader.load(self.oea.to_url(path))
        return descriptors.select("namespace", "codeValue")
        
    def returnNamespaces(self, descriptorsDFRef, descriptor):
        """Collect the distinct ``namespace`` values of one descriptor DataFrame.

        Returns ``(descriptor, namespace)`` when exactly one namespace exists,
        otherwise ``(0, namespaces)`` with the full (possibly empty) list.
        """
        distinct_rows = descriptorsDFRef[descriptor].select("namespace").distinct()
        namespaces = distinct_rows.rdd.flatMap(lambda row: row).collect()
        if len(namespaces) != 1:
            return 0,namespaces
        return descriptor, namespaces[0]
    
    ## Column Mapping
    def map_child_elements(self, child_list, mapping_dict):
        """Translate every item of ``child_list`` through ``mapping_dict``, preserving order.

        Raises KeyError for an item that has no mapping.
        """
        translated = []
        for element in child_list:
            translated.append(mapping_dict[element])
        return translated

    def refine_entities_in_order(self, sap_to_edfi_complex, unmapped_child_list):
        master_list = list(sap_to_edfi_complex.values())
        child_list = self.map_child_elements(unmapped_child_list, sap_to_edfi_complex)
        
        # Create a dictionary to store the indices of elements in the Master List
        master_indices = {item: index for index, item in enumerate(master_list)}

        # Sort the Child List based on the indices in the Master List
        sorted_child_list = sorted(child_list, key=lambda item: master_indices[item])

        # Map the sorted_child_list back to the original values
        original_sorted_child_list = [item for item in unmapped_child_list if sap_to_edfi_complex[item] in sorted_child_list]
        return original_sorted_child_list

    def extract_refined_cols_mapping(self, file_path):
        """Read the JSON at ``file_path`` with Spark and return its first row as a dict.

        Raises IndexError when the file yields no rows.
        """
        # FIXME: To Be Revised
        first_rows = self.spark.read.json(file_path).head(1)
        return first_rows[0].asDict()

    def map_columns(self, df, column_mapping):
        """Rename columns of ``df`` according to ``column_mapping`` (old name -> new name).

        Entries are applied in mapping order and the column list is re-read before each
        one, so a later entry can rename a column produced by an earlier entry. Names not
        present in ``df`` are skipped.
        """
        for old_name, new_name in column_mapping.items():
            if old_name not in df.columns:
                continue
            df = df.withColumnRenamed(old_name, new_name)
        return df

    def map_to_hard_values(self, df, edfi_item):
        """Add hard-coded default columns for specific Ed-Fi items.

        Currently only 'staffEducationOrganizationEmploymentAssociations' is affected: when
        the DataFrame has no ``employmentStatusDescriptor`` column, one is added with the
        literal value 'Other'. Every other item is returned unchanged.
        """
        # FIXME: TO BE REVIEWED (DISTRICTID = ???)
        # FIXME: HARDCODED SINCE COLUMN IS ABSENT
        needs_default_status = (
            edfi_item == 'staffEducationOrganizationEmploymentAssociations'
            and 'employmentStatusDescriptor' not in df.columns
        )
        if needs_default_status:
            df = df.withColumn("employmentStatusDescriptor", lit('Other'))
        return df
    
    ## DATA Cleaning
    def infer_descriptor_columns(self, final_columns):
        """Return the names in ``final_columns`` that end with 'Descriptor', in their original order."""
        return [name for name in final_columns if name.endswith("Descriptor")]

    def transform_dataframe(self, df, descriptor_col, df_namespace_mapping):
        """Prefix the values of ``descriptor_col`` with their namespace ('<namespace>#<value>').

        ``df`` is left-joined to ``df_namespace_mapping`` on ``descriptor_col == codeValue``.
        Rows whose result is null are set to 'uriPlaceholder#NA'. Only the original columns
        of ``df`` are kept. If Spark raises an AnalysisException the error is logged and the
        input ``df`` is returned untouched.
        """
        try:
            enriched = df.join(df_namespace_mapping, col(descriptor_col) == col('codeValue'), 'left')
            qualified_value = concat(col('namespace'), lit('#'),col(descriptor_col))
            result = enriched.withColumn(descriptor_col, qualified_value).select(*df.columns)

            # Replace null values in the descriptor_col
            with_placeholder = when(col(descriptor_col).isNull(), lit('uriPlaceholder#NA')).otherwise(col(descriptor_col))
            return result.withColumn(descriptor_col, with_placeholder)
        except AnalysisException as e:
            logger.info(f"An error occurred during transformation: - {e}")
            return df

    def create_struct_id_sets(self, df, beginDate, endDate, descriptor, struct_name, **additional_columns):
        """Add ``struct_name``: a single-element array holding a struct of date/descriptor fields.

        Struct fields, in order:
          - begin date: null when ``beginDate`` is None; aliased 'staffServiceBeginDate' when the
            source column is 'staffService_beginDate', otherwise 'beginDate'.
          - end date: same scheme with 'staffService_endDate' / 'staffServiceEndDate' / 'endDate'.
          - the ``descriptor`` column under its own name (omitted when ``descriptor`` is None).
          - one field per value of ``additional_columns``, each under its own column name.
        """
        # FIXME: Temporary Fix for Staffs
        def _date_field(source_col, staff_service_col, staff_service_alias, default_alias):
            if source_col is None:
                return F.lit(None).alias(default_alias)
            chosen_alias = staff_service_alias if source_col == staff_service_col else default_alias
            return F.col(source_col).alias(chosen_alias)

        struct_fields = [
            _date_field(beginDate, 'staffService_beginDate', 'staffServiceBeginDate', 'beginDate'),
            _date_field(endDate, 'staffService_endDate', 'staffServiceEndDate', 'endDate'),
        ]
        if descriptor is not None:
            struct_fields.append(F.col(descriptor).alias(descriptor))
        struct_fields.extend(F.col(col_name).alias(col_name) for col_name in additional_columns.values())

        # Wrap the struct in a one-element array per row.
        return df.withColumn(struct_name, F.array(F.struct(*struct_fields)).alias(struct_name))

    def convert_to_TX_ext_struct(self, df, columns):
        """Add an ``_ext`` column shaped as ``struct<TX: struct<columns...>>`` built from ``columns``."""
        tx_fields = [col(col_name).alias(col_name) for col_name in columns]
        # Inner struct holds the listed columns; the outer struct exposes it under the key "TX".
        ext_struct = F.struct(struct(*tx_fields).alias("TX"))
        return df.withColumn("_ext", ext_struct)


    def convert_id_to_struct(self, df, struct_name, id_column):
        """Add ``struct_name`` as a one-field struct wrapping ``id_column``; the source column is kept."""
        id_struct = struct(col(id_column).alias(id_column))
        return df.withColumn(struct_name, id_struct)

    def convert_column_to_array(self,
                                df, 
                                source_columns = ["raceDescriptor"],
                                target_key = "raceDescriptor",
                                target_column = 'races'):
        """Add ``target_column``: an array with one entry per source column.

        Each entry is ``struct<target_key: value>`` when the source value is not null.
        Because the ``when`` has no ``otherwise``, a null source value yields a null
        array element rather than being removed from the array.
        """
        entries = [
            F.when(F.col(column).isNotNull(), F.struct(F.col(column).alias(target_key)))
            for column in source_columns
        ]
        return df.withColumn(target_column, F.array(*entries))

    def convert_date_columns_to_standard_format(self, df):
        """Re-parse every column whose name contains 'date' (case-insensitive) from 'yyyyMMdd'.

        Each such column is parsed with ``to_date`` and cast back to string.
        """
        targets = [name for name in df.columns if "date" in name.lower()]
        for name in targets:
            reparsed = to_date(col(name), "yyyyMMdd").cast(StringType())
            df = df.withColumn(name, reparsed)
        return df

    def format_digit_vals(self, df, col_name):
        """Left-pad one- and two-digit values of ``col_name`` with zeros to three characters.

        A value matching exactly one digit gets '00' prepended, exactly two digits gets
        '0' prepended; anything else is left as it is.
        """
        as_text = col(col_name).cast("string")
        padded = (
            when(as_text.rlike("^[0-9]{1}$"), concat_ws("", lit("00"), col(col_name)))
            .when(as_text.rlike("^[0-9]{2}$"), concat_ws("", lit("0"), col(col_name)))
            .otherwise(col(col_name))
        )
        return df.withColumn(col_name, padded)

    def replace_null_with_default(self, df, column, default_value):
        """Fill nulls in ``column`` with ``default_value`` via ``df.na.fill`` and return the new DataFrame."""
        null_handler = df.na
        return null_handler.fill(subset=[column], value=default_value)

    def drop_completely_null_columns(self, df):
        """Return ``df`` without the columns that contain no non-null value in any row.

        The previous implementation looked only at the first row, so a column was
        dropped whenever its first value happened to be null, and it failed on an
        empty DataFrame. This version counts non-null values per column in a single
        aggregation.

        An empty DataFrame is returned unchanged, since there is no data to judge
        the columns by.
        """
        column_names = df.columns
        # count(column) counts only non-null values; count(lit(1)) counts every row.
        non_null_exprs = [F.count(F.col(name)) for name in column_names]
        stats = df.agg(*non_null_exprs, F.count(F.lit(1))).first()
        *non_null_counts, total_rows = stats
        if total_rows == 0:
            return df

        kept_columns = [name for name, non_nulls in zip(column_names, non_null_counts) if non_nulls > 0]
        return df.select(*kept_columns)

    def filter_columns(self, df, column_list):
        """Select the columns of ``column_list`` that exist in ``df``, in ``column_list`` order."""
        available = df.columns
        present = [name for name in column_list if name in available]
        return df.select(present)

    ## Other Utilities
    def get_sink_general_path(self, entity_parent_path, edfi_version, edfi_item, partitioning,SAP_SUB, TEST_MODE = False):
        """Derive the 'general' Refined sink path for ``edfi_item`` from an Ingested parent path.

        'Ingested' becomes 'Refined' and every 'SAP' becomes 'SAP/<SAP_SUB>'. The item is
        placed under 'tx/' when its name ends with 'Exts', otherwise under 'ed-fi/'. With
        ``partitioning`` the 'DistrictId=.../' and 'SchoolYear=.../' segments are removed;
        with ``TEST_MODE`` '/Refined/' is redirected to '/TEST/Refined/'.
        ``edfi_version`` is accepted but not used.
        """
        domain = 'tx' if edfi_item.endswith('Exts') else 'ed-fi'
        refined_parent = entity_parent_path.replace('Ingested', 'Refined').replace('SAP', f'SAP/{SAP_SUB}')
        destination_path = f'{refined_parent}/general/{domain}/{edfi_item}'

        if partitioning:
            # Strip the partition folders so all partitions land in one sink.
            destination_path = re.sub(r'DistrictId=.*?/|SchoolYear=.*?/', '', destination_path)
        if TEST_MODE:
            destination_path = destination_path.replace('/Refined/', '/TEST/Refined/')
        return destination_path

    def get_sink_sensitive_path(self, entity_parent_path, edfi_version, edfi_item, partitioning,SAP_SUB, TEST_MODE = False):
        """Derive the 'sensitive' Refined sink path (the '<item>_lookup' table) for ``edfi_item``.

        'Ingested' becomes 'Refined' and every 'SAP' becomes 'SAP/<SAP_SUB>'. The item is
        placed under 'tx/' when its name ends with 'Exts', otherwise under 'ed-fi/', and
        gets a '_lookup' suffix. With ``partitioning`` the 'DistrictId=.../' and
        'SchoolYear=.../' segments are removed; with ``TEST_MODE`` '/Refined/' is
        redirected to '/TEST/Refined/'. ``edfi_version`` is accepted but not used.
        """
        domain = 'tx' if edfi_item.endswith('Exts') else 'ed-fi'
        refined_parent = entity_parent_path.replace('Ingested', 'Refined').replace('SAP', f'SAP/{SAP_SUB}')
        destination_path = f'{refined_parent}/sensitive/{domain}/{edfi_item}_lookup'

        if partitioning:
            # Strip the partition folders so all partitions land in one sink.
            destination_path = re.sub(r'DistrictId=.*?/|SchoolYear=.*?/', '', destination_path)
        if TEST_MODE:
            destination_path = destination_path.replace('/Refined/', '/TEST/Refined/')
        return destination_path
    def get_sink_general_sensitive_paths(self, 
                                        source_path, 
                                        edfi_version, 
                                        edfi_item, 
                                        partitioning = False,
                                        SAP_SUB = 'FINAL',
                                        TEST_MODE = False):
        
        path_dict = self.oea.parse_path(source_path)  
        sink_general_path = self.get_sink_general_path(path_dict['entity_parent_path'], 
                                                edfi_version, 
                                                edfi_item, 
                                                partitioning,
                                                SAP_SUB,
                                                TEST_MODE)
        sink_sink_path = self.get_sink_sensitive_path(path_dict['entity_parent_path'], 
                                                edfi_version, 
                                                edfi_item, 
                                                partitioning,
                                                SAP_SUB,
                                                TEST_MODE)
        return sink_general_path, sink_sink_path
            
    def extract_district_id(self, input_string):
        """Return the digits following 'DistrictId=' in ``input_string``, or None when absent.

        The value is returned as a string (the first match only). The unused
        "component" extraction that used to run alongside has been removed; its
        result was never returned.
        """
        district_id_match = re.search(r'DistrictId=(\d+)', input_string)
        return district_id_match.group(1) if district_id_match else None
    
    def has_column(self, df, col):
        """Return True when ``df[col]`` can be resolved, False when it raises AnalysisException."""
        try:
            df[col]
        except AnalysisException:
            return False
        else:
            return True
    
    def process_date_column(self, 
                            df, 
                            date_col_name,
                            date_format = "yyyyMMdd", 
                            exclude_cols = []):
        """Ensure ``date_col_name`` exists on ``df`` as a DateType column.

        An existing column is parsed with ``to_date`` using ``date_format``; a missing
        column is created filled with nulls. ``exclude_cols`` is accepted but not used.
        """
        # NOTE: WIP - Subject to Revisions
        if date_col_name in df.columns:
            date_value = to_date(col(date_col_name), date_format).cast(DateType())
        else:
            # TODO: Requires Review
            date_value = F.lit(None).cast(DateType())
        return df.withColumn(date_col_name, date_value)

    def preprocess_race_columns(self, df, raceDescriptorsRef):
        """Turn the five ``race<N>_Descriptor`` flag columns into namespaced descriptor values.

        For N in 1..5: a value of '1' becomes the code '0N', anything else becomes
        'uriPlaceholder#NA'; the column is then passed through transform_dataframe()
        with ``raceDescriptorsRef`` as the namespace mapping.
        """
        for position in range(1, 6):
            race_col = f"race{position}_Descriptor"
            race_code = F.when(F.col(race_col) == '1', F.lit(f"0{position}")).otherwise(F.lit("uriPlaceholder#NA"))
            df = self.transform_dataframe(df.withColumn(race_col, race_code), race_col, raceDescriptorsRef)
        return df

    def combine_race_columns(self, df):
        """Collapse ``race1_Descriptor`` .. ``race5_Descriptor`` into one ``races`` array column.

        ``races`` holds five ``struct<raceDescriptor: string>`` elements (one per source
        column, in order); the five source columns are dropped.
        """
        race_columns = [f'race{i}_Descriptor' for i in range(1, 6)]
        race_structs = [F.struct(F.col(name).alias('raceDescriptor')) for name in race_columns]

        element_type = StructType([StructField('raceDescriptor', StringType(), True)])
        races_type = ArrayType(element_type, True)

        df = df.withColumn('races', F.array(race_structs)).drop(*race_columns)
        # Cast to the explicit array type declared above.
        return df.withColumn('races', F.col('races').cast(races_type))
    
    def create_record_hash_column(self, df):
        """Add ``RECORD_HASH``: SHA-256 over NATURAL_KEY_HASH plus every plain data column.

        Columns listed in ``self.sap_essential_columns`` and columns of struct or array
        type are left out. Each value is cast to string, null-coalesced to '' and
        suffixed with '_' before concatenation.
        """
        # TODO: Under DEV
        hashed_columns = ['NATURAL_KEY_HASH']  # FIXME: UNDER DEV
        for name in df.columns:
            if name in self.sap_essential_columns:
                continue
            if isinstance(df.schema[name].dataType, (StructType, ArrayType)):
                continue
            hashed_columns.append(name)

        logger.info(f'Creating hash column - RECORD_HASH - of the following columns: {hashed_columns}')
        hash_parts = [
            F.concat(F.coalesce(F.col(name).cast('string'), F.lit('')), F.lit('_'))
            for name in hashed_columns
        ]
        return df.withColumn("RECORD_HASH",F.sha2(F.concat(*hash_parts), 256))

In [ ]:
class SAPProcessClient:
    """Shape raw SAP entity DataFrames into the column layout written downstream.

    Every ``process*`` method receives the DataFrame of one entity, converts
    its ``yyyyMMdd`` date strings and numeric/boolean columns, delegates the
    construction of reference / set / ``_ext`` structures to ``sap_utilities``,
    then keeps the columns configured in ``final_columns`` for that entity and
    appends the record hash.
    """

    def __init__(self, 
                 spark, 
                 oea,
                 sap_utilities,
                 sap_to_edfi_complex,
                 final_columns,
                 _ext_TX_cols, 
                 descriptorsDFRef,
                 descriptors):
        """Store the collaborators and the per-entity configuration.

        Args:
            spark: Spark session; ``processStaffs`` sets a session option on it.
            oea: stored as-is; not used by the methods of this class.
            sap_utilities: helper object providing ``process_date_column``,
                ``convert_id_to_struct``, ``filter_columns``,
                ``create_record_hash_column`` and the other helpers called below.
            sap_to_edfi_complex: stored as-is; not used by the methods of this class.
            final_columns: mapping of entity name to the columns handed to
                ``filter_columns``.
            _ext_TX_cols: mapping of entity name to the list of columns handed
                to ``convert_to_TX_ext_struct``.
            descriptorsDFRef: mapping; its ``'raceDescriptors'`` entry is used
                by ``processStaffs``.
            descriptors: stored as-is; not used by the methods of this class.
        """
        self.spark = spark
        self.oea = oea
        self.sap_utilities = sap_utilities
        self.sap_to_edfi_complex = sap_to_edfi_complex
        self.final_columns = final_columns 
        self._ext_TX_cols = _ext_TX_cols 
        self.descriptorsDFRef = descriptorsDFRef
        self.descriptors = descriptors

    # ------------------------------------------------------------------
    # Private helpers shared by the process* methods
    # ------------------------------------------------------------------
    def _parse_dates(self, df, *column_names):
        """Parse the ``yyyyMMdd`` strings of the named columns into dates."""
        for name in column_names:
            df = df.withColumn(name, to_date(col(name), "yyyyMMdd"))
        return df

    def _add_education_organization_reference(self, df, rename_local_education_agency = True):
        """Cast ``educationOrganizationId`` to int and build its reference struct.

        When ``rename_local_education_agency`` is true the source column
        ``localEducationAgencyId`` is first renamed to ``educationOrganizationId``.
        """
        if rename_local_education_agency:
            df = df.withColumnRenamed("localEducationAgencyId", "educationOrganizationId")
        df = df.withColumn('educationOrganizationId', col('educationOrganizationId').cast(IntegerType()))
        return self.sap_utilities.convert_id_to_struct(df, 'educationOrganizationReference', 'educationOrganizationId')

    def _finalize(self, df, entity):
        """Keep the entity's configured final columns and append RECORD_HASH."""
        df = self.sap_utilities.filter_columns(df, self.final_columns[entity])
        return self.sap_utilities.create_record_hash_column(df)

    def _process_financial_exts(self, df, entity):
        """Common pipeline of budgetExts and actualExts (they differ only in ``entity``)."""
        # Use the injected helper, not a notebook-level global of the same name.
        for date_col_name in ('beginDate', 'endDate'):
            df = self.sap_utilities.process_date_column(df = df,
                                                        date_col_name = date_col_name,
                                                        date_format = "yyyyMMdd",
                                                        exclude_cols = [])
        df = self._add_education_organization_reference(df, rename_local_education_agency = False)
        return self._finalize(df, entity)

    # ------------------------------------------------------------------
    # Entity processors
    # ------------------------------------------------------------------
    def processBudgetExts(self, df):
        """Process the ``budgetExts`` entity and return the shaped DataFrame."""
        return self._process_financial_exts(df, 'budgetExts')

    def processActualExts(self, df):
        """Process the ``actualExts`` entity and return the shaped DataFrame."""
        return self._process_financial_exts(df, 'actualExts')

    def processStaffs(self, df):
        """Process the ``staffs`` entity and return the shaped DataFrame.

        Side effects:
            * sets ``spark.sql.legacy.timeParserPolicy`` to ``LEGACY`` on the
              session (it stays set after this call);
            * appends ``paraprofessionalCertificationSet`` and
              ``staffDoNotReportTSDS`` to ``self._ext_TX_cols['staffs']`` if
              they are missing.
        """
        self.spark.conf.set("spark.sql.legacy.timeParserPolicy","LEGACY")
        if 'birthDate' in df.columns:
            df = self._parse_dates(df, 'birthDate')
        else:
            df = df.withColumn('birthDate', 
                               to_date(F.lit(None),"yyyyMMdd"))
        
        # Optional columns get a default so the later steps can rely on them.
        # TODO: To Be Reviewed (pkTeacherRequirementDescriptor)
        for name in ('pkTeacherRequirementDescriptor', 'staffTypeDescriptor', 'generationCodeDescriptor'):
            if name not in df.columns:
                df = df.withColumn(name, F.lit('uriPlaceholder#NA'))
        for name in ('totalYearsProfExperience', 'yearsExperienceInDistrict'):
            if name not in df.columns:
                df = df.withColumn(name, F.lit(None).cast('int'))
        
        # TYPESET / PARAPROFESSIONAL
        # to_date already yields a date column, so no further casts are needed.
        # paraprofessional_endDate is parsed with the same pattern as the other
        # three; previously the begin date was parsed twice and the end date
        # was never parsed.
        df = self._parse_dates(df,
                               'staffTypeSet_beginDate',
                               'staffTypeSet_endDate',
                               'paraprofessional_beginDate',
                               'paraprofessional_endDate')

        # BOOLEAN TYPES
        for name in ('staffDoNotReportTSDS', 'hispanicLatinoEthnicity', 'paraprofessionalCertification'):
            df = df.withColumn(name, col(name).cast(BooleanType()))

        # RACES
        df = self.sap_utilities.preprocess_race_columns(df, self.descriptorsDFRef['raceDescriptors'])
        df = self.sap_utilities.combine_race_columns(df)
        
        if 'raceDescriptor' in df.columns:
            df = self.sap_utilities.convert_column_to_array(df, 
                                            source_columns = ["raceDescriptor"],
                                            target_key = "raceDescriptor",
                                            target_column = 'races')
        df = self.sap_utilities.create_struct_id_sets(df = df, 
                                    beginDate = 'staffTypeSet_beginDate', 
                                    endDate = 'staffTypeSet_endDate',
                                    descriptor = 'staffTypeDescriptor',
                                    struct_name = 'typeSets')
        
        df = self.sap_utilities.create_struct_id_sets(df = df, 
                                    beginDate = 'paraprofessional_beginDate', 
                                    endDate = 'paraprofessional_endDate',
                                    descriptor = None,
                                    struct_name = 'paraprofessionalCertificationSet',
                                    paraprofessionalCertification = 'paraprofessionalCertification')
        # TODO: Potentially Debug
        # The configured list is extended in place on purpose: other holders of
        # the same _ext_TX_cols mapping see the two extra columns as well.
        _ext_TX_cols_staffs = self._ext_TX_cols['staffs']
        for name in ('paraprofessionalCertificationSet', 'staffDoNotReportTSDS'):
            if name not in _ext_TX_cols_staffs:
                _ext_TX_cols_staffs.append(name)
        df = self.sap_utilities.convert_to_TX_ext_struct(df, 
                                          columns = _ext_TX_cols_staffs)
        return self._finalize(df, 'staffs')

    def processContractedInstructionalStaffFTEExts(self, df):
        """Process the ``contractedInstructionalStaffFTEExts`` entity."""
        df = self._add_education_organization_reference(df)
        
        df = df.withColumn('schoolId', col('schoolId').cast(IntegerType()))
        df = self.sap_utilities.convert_id_to_struct(df, 'schoolReference', 'schoolId')
        return self._finalize(df, 'contractedInstructionalStaffFTEExts')

    def processStaffEducationOrganizationEmploymentAssociations(self, df):
        """Process the ``staffEducationOrganizationEmploymentAssociations`` entity."""
        df = df.withColumn('percentDayEmployed', col('percentDayEmployed').cast('integer'))
        df = df.withColumn('numberDaysEmployed', col('numberDaysEmployed').cast('integer'))

        df = self._add_education_organization_reference(df)
        
        df = self._parse_dates(df,
                               'auxiliaryRoleIdSet_beginDate',
                               'hireDate',
                               'endDate',
                               'auxiliaryRoleIdSet_endDate')
        
        df = self.sap_utilities.create_struct_id_sets(df = df, 
                                                      beginDate = 'auxiliaryRoleIdSet_beginDate', 
                                                      endDate = 'auxiliaryRoleIdSet_endDate',
                                                      descriptor = 'auxiliaryRoleIdDescriptor',
                                                      struct_name = 'auxiliaryRoleIdSets')
                            
        df = self.sap_utilities.convert_id_to_struct(df, 'staffReference', 'staffUniqueId')
        df = self.sap_utilities.convert_to_TX_ext_struct(df, 
                                        columns = self._ext_TX_cols['staffEducationOrganizationEmploymentAssociations']
                                        )
        return self._finalize(df, 'staffEducationOrganizationEmploymentAssociations')

    def processStaffEducationOrganizationAssignmentAssociations(self, df):
        """Process the ``staffEducationOrganizationAssignmentAssociations`` entity."""
        df = self._parse_dates(df, 'staffService_beginDate', 'staffService_endDate')

        df = self._add_education_organization_reference(df)

        df = self._parse_dates(df, 'beginDate', 'endDate')

        df = df.withColumn('monthlyMinutes', col('monthlyMinutes').cast(IntegerType()))

        df = self.sap_utilities.format_digit_vals(df, 'staffClassificationDescriptor')
        df = self.sap_utilities.create_struct_id_sets(df = df, 
                                    beginDate = 'staffService_beginDate', 
                                    endDate = 'staffService_endDate',
                                    descriptor = 'staffServiceDescriptor',
                                    struct_name = 'staffServiceSets',
                                    monthlyMinutes = "monthlyMinutes",
                                    populationServedDescriptor = "populationServedDescriptor")
        df = self.sap_utilities.convert_id_to_struct(df, 'schoolReference', 'schoolId')
        
        df = self.sap_utilities.convert_id_to_struct(df, 'staffReference', 'staffUniqueId')
        df = self.sap_utilities.convert_to_TX_ext_struct(df, 
                                        columns = self._ext_TX_cols['staffEducationOrganizationAssignmentAssociations']
                                        )
        return self._finalize(df, 'staffEducationOrganizationAssignmentAssociations')

    def processPayrollExts(self, df):
        """Process the ``payrollExts`` entity."""
        df = self._add_education_organization_reference(df)

        df = self._parse_dates(df, 'beginDate', 'endDate')

        df = df.withColumn("fiscalYear", col("fiscalYear").cast(IntegerType()))
        df = df.withColumn("organization", col("organization").cast(IntegerType()))
        df = self.sap_utilities.convert_id_to_struct(df, 'staffReference', 'staffUniqueId')
        return self._finalize(df, 'payrollExts')


In [ ]:
class SAPToEdFiRefine(EdFiRefine):
    def __init__(self, 
                 workspace, 
                 oea,
                 spark, 
                 sap_oea_utils,
                 sap_process_client,
                 logger,
                 schema_gen, 
                 moduleName, 
                 authUrl, 
                 swaggerUrl, 
                 dataManagementUrl, 
                 changeQueriesUrl, 
                 dependenciesUrl, 
                 apiVersion, 
                 schoolYear,
                 districtId, 
                 pipelineExecutionId,
                 error_logger,
                 test_mode,
                 natural_upsert_mode = False,
                 sap_essential_columns = None):
        """Initialise the base refiner, then attach the SAP-specific state.

        The base-class constructor receives every argument except
        ``sap_oea_utils``, ``sap_process_client``, ``logger``,
        ``natural_upsert_mode`` and ``sap_essential_columns``.

        Args:
            sap_oea_utils: utilities object; ``create_spark_schemas()`` is
                called on it once and the result kept in ``self.schemas``.
            sap_process_client: only its ``_ext_TX_cols`` mapping is kept
                (the same object, not a copy).
            logger: logger stored on the instance.
            natural_upsert_mode: when true, upserts key on ``NATURAL_KEY_HASH``
                instead of ``RECORD``.
            sap_essential_columns: bookkeeping columns carried into every
                written table; a built-in list is used when ``None``.
        """
        base_arguments = dict(workspace = workspace, 
                              oea = oea, 
                              spark = spark,
                              schema_gen = schema_gen, 
                              moduleName = moduleName, 
                              authUrl = authUrl, 
                              swaggerUrl = swaggerUrl, 
                              dataManagementUrl = dataManagementUrl, 
                              changeQueriesUrl = changeQueriesUrl, 
                              dependenciesUrl = dependenciesUrl, 
                              apiVersion = apiVersion, 
                              schoolYear = schoolYear, 
                              districtId = districtId, 
                              pipelineExecutionId = pipelineExecutionId,
                              error_logger = error_logger,
                              test_mode = test_mode)
        super().__init__(**base_arguments)

        self.sap_oea_utils = sap_oea_utils
        self.logger = logger
        self._ext_TX_cols = sap_process_client._ext_TX_cols
        self.schemas = sap_oea_utils.create_spark_schemas() 
        self.primitive_datatypes = ['timestamp', 'date', 'decimal', 'boolean', 'integer', 'string', 'long'] 
        self.natural_upsert_mode = natural_upsert_mode

        if sap_essential_columns is None:
            sap_essential_columns = ['DistrictId', 'SchoolYear', 'sap_pipeline', 'sap_pipelineType', 'lakeId', 'validationRecordId', 'LastModifiedDate', 'rundate', 'stage1_source_url', 'RECORD', 'NATURAL_KEY_HASH', 'RECORD_HASH', 'RECORD_VERSION']
        self.sap_essential_columns = sap_essential_columns

    def set_params(self, params = None):
        """Apply runtime parameters to the instance.

        Only the ``'sap_pipeline'`` key is recognised; its value is stored in
        ``self.sap_pipeline``. Every other key is ignored.

        Args:
            params: mapping of parameter name to value. ``None`` or an empty
                mapping leaves the instance untouched.
        """
        # ``None`` replaces the former shared mutable ``{}`` default.
        if params and 'sap_pipeline' in params:
            self.sap_pipeline = params['sap_pipeline']
    
    def modify_target_schema(self, target_schema, incl_columns, excl_columns):
        """Append extra string/timestamp fields to ``target_schema``.

        Every name in ``incl_columns`` that is not in ``excl_columns`` is added
        as a new field: ``TimestampType`` when the name ends with ``date``
        (case-insensitive), ``StringType`` otherwise.

        Args:
            target_schema: schema object exposing ``add(StructField)``.
            incl_columns: candidate column names, processed in order.
            excl_columns: names that must not be added.

        Returns:
            The schema returned by the last ``add`` call (or ``target_schema``
            itself when nothing was added).
        """
        for column in incl_columns:
            if column in excl_columns:
                continue
            field_type = TimestampType() if column.lower().endswith('date') else StringType()
            target_schema = target_schema.add(StructField(column, field_type))
        return target_schema
    
    def modify_descriptor_value(self, df, col_name):
        """Add the ``{col_name}LakeId`` companion column for a descriptor.

        When ``col_name`` exists in ``df`` the new column is
        ``DistrictId``, ``SchoolYear`` and the descriptor value (with ``#``
        replaced by ``_``) joined by ``_``. When it does not exist the new
        column is a null string. The descriptor column itself is kept.

        Args:
            df: input DataFrame.
            col_name: name of the descriptor column.

        Returns:
            The DataFrame with ``{col_name}LakeId`` added.
        """
        lake_id_column = f"{col_name}LakeId"
        if col_name not in df.columns:
            return df.withColumn(lake_id_column, f.lit(None).cast("String"))

        # TODO: @Abhinav, I do not see where you made the changes to use the descriptorId instead of Namespace/CodeValue
        descriptor_value = f.regexp_replace(col_name, '#', '_')
        lake_id = f.concat_ws('_', f.col('DistrictId'), f.col('SchoolYear'), descriptor_value)
        return df.withColumn(lake_id_column, lake_id)

    
    def extract_pk_from_definitions_exp(self, table_name, depluralize = True):
        """Derive identity and clean-up columns from field nullability.

        NOTE: This function is under review.

        Works on a deep copy of ``self.schemas[table_name]``. Every
        non-nullable field goes into the clean-up list. A non-nullable field
        whose name ends in ``Reference`` contributes one dotted identity entry
        ``<field>.<key>`` per key returned by ``extract_reference_key``; any
        other non-nullable field contributes its own name.

        Example:
            pk, _ = extract_pk_from_definitions('payrollExts', depluralize = True)
            fk = self.extract_reference_key(schemas['payrollExts']['educationOrganizationReference'])

        Args:
            table_name: key into ``self.schemas``.
            depluralize: accepted for signature parity; not used here.

        Returns:
            ``(sorted identity columns, clean-up columns in schema order)``.
        """
        schema = copy.deepcopy(self.schemas[table_name])
        identity_cols = []
        cleanup_cols = []
        for field in schema:
            if field.nullable:
                continue
            name = field.name
            cleanup_cols.append(name)
            if re.search('Reference$', name):
                nested_keys = self.extract_reference_key(schema[name])
                identity_cols.extend(f'{name}.{key}' for key in nested_keys)
            else:
                identity_cols.append(name)

        return sorted(identity_cols), cleanup_cols
    
    def extract_pk_from_definitions(self, table_name, depluralize = True):
        '''
        Derive identity and clean-up columns for a table from its definitions.

        Iterates the columns of ``self.sap_oea_utils.definitions[<name>]`` and
        builds two lists:

        * identity columns - columns flagged ``x-Ed-Fi-isIdentity`` (with a value
          other than ``["*"]``), plus, for every required ``...Reference``
          column, one dotted entry ``<column>.<key>`` per key returned by
          ``extract_reference_key``;
        * clean-up columns - rebuilt from the identity list each time one of
          the two conditions matches (see the inline notes), so its final
          content depends on the order of the definition columns.

        The two conditions are independent ``if`` statements, so a column that
        satisfies both is added in both forms.

        Args:
            table_name: key into ``self.schemas``; when ``depluralize`` is true
                it is also mapped through
                ``self.sap_oea_utils.pluralization_mappings`` before it is used
                as a key into ``definitions``.
            depluralize: whether to apply that mapping.

        Returns:
            ``(sorted identity columns, clean-up columns)``.

        EXAMPLE:
        pk, _ = extract_pk_from_definitions('payrollExts', depluralize = True)
        fk = self.extract_reference_key(schemas['payrollExts']['educationOrganizationReference'])
        '''
        # The schema is looked up with the ORIGINAL table name, before any mapping.
        target_schema = copy.deepcopy(self.schemas[table_name])
        if depluralize:
            table_name = self.sap_oea_utils.pluralization_mappings[table_name] #self.sap_oea_utils.depluralize(table_name)
        identity_cols = []
        cleanup_cols = []
        for col_name in self.sap_oea_utils.definitions[table_name]:
            # Looked up for every definition column, not only for references.
            # NOTE(review): presumably every definition column exists in the schema - confirm.
            target_col = target_schema[col_name]
            col_schema = self.sap_oea_utils.definitions[table_name][col_name]
            # NOTE(review): `key` is never read afterwards; kept because it is unclear
            # from here whether pluralize() has side effects.
            key = self.sap_oea_utils.pluralize(table_name)
            if 'x-Ed-Fi-isIdentity' in col_schema and col_schema['x-Ed-Fi-isIdentity'] != ["*"]:
                identity_cols = identity_cols + [col_name]
                # Same list object as identity_cols at this point; later updates rebind
                # identity_cols to a new list, so cleanup_cols keeps this snapshot.
                cleanup_cols = identity_cols
            # 'required' is subscripted directly for every column whose name ends in 'Reference'.
            if (re.search('Reference$', col_name) is not None) and (col_schema['required']):
                # Built from identity_cols, so it also carries dotted entries added by
                # earlier reference columns, plus this column's bare name.
                cleanup_cols = identity_cols + [col_name]
                identity_cols = identity_cols + [f'{col_name}.{x}' for x in self.extract_reference_key(target_col)]
                
        return sorted(identity_cols), cleanup_cols

    # Use this function to extract reference Key columns and create a FK
    def extract_reference_key(self, target_col):
        """Return the sorted names of the non-nullable sub-fields of a reference.

        Args:
            target_col: a field whose ``dataType`` exposes ``fields`` (each
                with ``name`` and ``nullable``).

        Returns:
            Alphabetically sorted list of the names of the sub-fields whose
            ``nullable`` equals ``False``; empty when there are none.
        """
        required_names = [
            f"{candidate.name}"
            for candidate in target_col.dataType.fields
            if candidate.nullable == False  # noqa: E712 - equality test kept as in the original
        ]
        required_names.sort()
        return required_names

    def upsert_with_logging(self, 
                            df, 
                            sap_pipeline,
                            destination_path, 
                            primary_key, 
                            partitioning, 
                            partitioning_cols,
                            table_name,
                            ext_entity,
                            parent = True):
        """Write ``df`` to ``destination_path`` and record one entity log entry.

        Parent tables are written with ``self.oea.upsert``; child tables
        (``parent=False``) with ``self.oea.delete_then_insert``. The row counts
        returned by the write are folded into a log dictionary that is handed
        to ``error_logger.consolidate_logs``.

        The ``primary_key``, ``partitioning`` and ``partitioning_cols``
        arguments are accepted but not honoured: the key is always
        ``[NATURAL_KEY_HASH | RECORD, 'DistrictId', 'SchoolYear']`` (chosen by
        ``self.natural_upsert_mode``), partitioning is always on, and the
        partition columns are the instance's district / school-year columns.

        This method reads the module-level names ``error_logger``, ``spark``,
        ``sap_pipelineType`` and ``ingestionHistoryMode``; they must exist in
        the notebook namespace when it runs.

        Args:
            df: DataFrame to write.
            sap_pipeline: pipeline name recorded in the log entry.
            destination_path: target path; a path containing
                ``/emptySchemas/`` is flagged as empty-schema metadata.
            table_name: entity name; also the key of the stored start time.
            ext_entity: recorded (lower-cased) as the entity type.
            parent: selects upsert (true) or delete-then-insert (false).
        """
        # NOTE: partitioning_cols is a legacy param
        start_time = self.thread_local.start_times.get(table_name, datetime.now())

        key_column = 'NATURAL_KEY_HASH' if self.natural_upsert_mode else 'RECORD'
        primary_key = [key_column, 'DistrictId', 'SchoolYear']

        write = self.oea.upsert if parent else self.oea.delete_then_insert
        numInputRows, numOutputRows, numTargetRowsInserted, numTargetRowsUpdated = write(
            df = df,
            destination_path = destination_path,
            primary_key = primary_key,
            partitioning = True,
            partitioning_cols = [self.districtId_col_name, self.schoolYear_col_name],
            surrogate_key = False)
        end_time = datetime.now()

        # FIXME: 2024-02-15: flagging empty tables
        emptySchemaMetadata = '/emptySchemas/' in destination_path

        log_data = error_logger.create_log_dict(
            uniqueId = error_logger.generate_random_alphanumeric(10),  # random 10-character alphanumeric value
            pipelineExecutionId = self.pipelineExecutionId,
            sparkSessionId = spark.sparkContext.applicationId,
            sap_pipeline = sap_pipeline,
            sap_pipelineType = sap_pipelineType,
            stageName = "Refinement",
            schemaFormat = 'ed-fi',
            entityType = ext_entity.lower(),  # TODO: To Be Reviewed
            entityName = table_name,
            numInputRows = numInputRows,
            totalNumOutputRows = numOutputRows,
            numTargetRowsInserted = numTargetRowsInserted,
            numTargetRowsUpdated = numTargetRowsUpdated,
            numRecordsSkipped = 0,
            numRecordsDeleted = 0,
            start_time = start_time,
            end_time = end_time,
            insertionType = 'append' if ingestionHistoryMode else 'upsert',
            emptySchemaMetadata = emptySchemaMetadata)
        error_logger.consolidate_logs(log_data,'entity')
    
    def process_ext_column(self,
                           df,
                           schema_name,
                           table_name,
                           ext_column_name,
                           ext_entity,
                           target_schema,
                           sink_general_path):    
        """Write the extension (``_ext``) part of a table to its own location.

        Does nothing when ``df`` has no ``_ext`` column. Otherwise the
        bookkeeping columns plus ``_ext`` (and ``staffUniqueId`` for tables
        whose name starts with ``staffs``) are selected, the extension column
        is flattened, passed through ``transform_sub_module``, written with
        ``upsert_with_logging`` and registered with ``add_to_lake_db``.

        Args:
            df: DataFrame of the parent entity.
            schema_name: forwarded to ``transform_sub_module``.
            table_name: entity name.
            ext_column_name: name of the extension column to flatten.
            ext_entity: extension name; its lower-cased form replaces
                ``/ed-fi/`` in ``sink_general_path``.
            target_schema: ignored - it is replaced by the result of
                ``get_ext_entities_schemas``.
            sink_general_path: base destination path.

        Returns:
            None.
        """
        if '_ext' not in df.columns:
            return

        self.logger.info(f"Writing EXT Tables - {table_name}")
        # FIXME: Revise Logic
        sink_general_path = sink_general_path.replace('/ed-fi/', f'/{ext_entity.lower()}/')

        key_columns = ['staffUniqueId'] if table_name.startswith('staffs') else []
        df = df.select(self.sap_essential_columns + key_columns + ['_ext'])
        
        target_schema = self.get_ext_entities_schemas(table_name = table_name,
                                                ext_column_name = ext_column_name,
                                                default_value = ext_entity)
        # Fall back to the schema's own fields only when the table has no
        # configured extension columns; any other error should surface.
        try:
            ext_inner_cols = self._ext_TX_cols[table_name]
        except KeyError:
            ext_inner_cols = target_schema.fieldNames()
        
        df = self.flatten_ext_column(df = df, 
                                     table_name = table_name, 
                                     ext_col = ext_column_name, 
                                     inner_key = ext_entity,
                                     ext_inner_cols = ext_inner_cols)
        
        df = self.transform_sub_module(df, 
                                    target_schema, 
                                    sink_general_path, 
                                    schema_name,
                                    table_name,
                                    True,
                                    ext_entity)

        self.logger.info(f"Writing EXT Table - {table_name}")
        self.upsert_with_logging(df = df, 
                                 sap_pipeline = self.sap_pipeline,
                                 destination_path = f"{sink_general_path}", 
                                 primary_key = 'lakeId', 
                                 partitioning = True,
                                 partitioning_cols = ['DistrictId', 'SchoolYear'], 
                                 table_name = table_name,
                                 ext_entity = ext_entity,
                                 parent = True)
        
        self.oea.add_to_lake_db(sink_general_path, overwrite = True, extension = ext_entity)

    def flatten_reference_col(self, df, table_name,target_col, reference_key):
        """Add the ``<prefix>LakeId`` column for one reference struct.

        ``<prefix>`` is the reference column name with ``Reference`` removed.
        The value is ``DistrictId_SchoolYear_<key1>_..._<keyN>`` built from the
        reference's key sub-fields, and is null when the reference struct is
        null. ``credentialReference`` is skipped.

        Args:
            df: DataFrame holding the reference column.
            table_name: not used by this method.
            target_col: schema field of the reference column.
            reference_key: non-empty sequence of key sub-field names.

        Returns:
            The DataFrame, with the LakeId column added unless skipped.
        """
        ref_name = target_col.name
        if ref_name in ['credentialReference']:
            return df

        lake_id_column = f"{ref_name.replace('Reference', '')}LakeId"
        # Every key but the last carries its own trailing '_' separator.
        key_parts = [F.concat(F.col(f'{ref_name}.{key}'), F.lit('_')) for key in reference_key[:-1]]
        key_parts.append(F.col(f'{ref_name}.{reference_key[-1]}'))

        lake_id = f.concat(f.col('DistrictId'), f.lit('_'), f.col('SchoolYear'), f.lit('_'), *key_parts)
        return df.withColumn(lake_id_column, f.when(f.col(ref_name).isNotNull(), lake_id))

    def modify_references_and_descriptors(self, df, table_name,target_col):
        """Add LakeId columns for all reference and descriptor columns of ``df``.

        Columns whose name ends in ``Reference`` are processed first through
        ``flatten_reference_col`` (their schema is taken from the element type
        of ``target_col``); afterwards every column whose name ends in
        ``Descriptor`` goes through ``modify_descriptor_value``. The source
        columns themselves are kept.

        Args:
            df: exploded child DataFrame.
            table_name: forwarded to ``flatten_reference_col``.
            target_col: array field whose element type describes ``df``.

        Returns:
            The DataFrame with the LakeId columns added.
        """
        reference_columns = [name for name in df.columns if re.search('Reference$', name) is not None]
        for ref_col in reference_columns:
            ref_field = target_col.dataType.elementType[ref_col]
            ref_keys = self.extract_reference_key(ref_field)
            df = self.flatten_reference_col(df, table_name, ref_field, ref_keys)

        # Collected after the reference pass, from the updated column list.
        descriptor_columns = [name for name in df.columns if re.search('Descriptor$', name) is not None]
        for desc_col in descriptor_columns:
            # TODO: Test Run - dropping desc_col afterwards is still under evaluation.
            df = self.modify_descriptor_value(df, desc_col)
        return df

    def explode_arrays(self, df, sink_general_path,target_col, schema_name, table_name, ext_entity):
        """Split one array column of ``df`` out into child (and grand-child) tables.

        The array ``target_col`` is exploded into one row per element, carrying
        the bookkeeping columns of the parent; reference/descriptor LakeId
        columns are added and the result is written to
        ``{sink_general_path}_{array}``. Every array field nested inside the
        element type is exploded once more from the child rows and written to
        ``{sink_general_path}_{array}_{nested}``. Grand-child tables are
        written before the child table. All writes use ``parent=False``.

        Args:
            df: parent DataFrame containing ``target_col``.
            sink_general_path: base destination path of the parent table.
            target_col: schema field of the array column (array of structs).
            schema_name: not used by this method.
            table_name: parent entity name, used to build child table names.
            ext_entity: extension name; ``'ed-fi'`` is registered with
                ``extension=None``, anything else with that name.

        Returns:
            ``df`` without the array column.
        """
        cols = self.sap_essential_columns
        array_name = target_col.name
        child_name = f"{table_name}_{array_name}"
        child_path = f"{sink_general_path}_{array_name}"
        self.store_start_time(child_name)

        child_df = df.select(cols + [array_name])
        child_df = child_df.withColumn("exploded", f.explode(array_name))

        # NOTE(review): a "{array}LakeId" column derived from the element fields
        # flagged 'x-Ed-Fi-isIdentity' used to be attempted here. It never ran:
        # the field list was assigned from list.sort(), which returns None. The
        # dead branch is removed so the written tables stay exactly as before.
        # Re-introducing it needs the fields addressed as 'exploded.<name>' and
        # the new column kept in the select below.
        child_df = child_df.select(cols + ['exploded.*']).drop(array_name) #FIXME: TO BE REVIEWED
        child_df = self.modify_references_and_descriptors(child_df, table_name,target_col)

        extension = None if ext_entity == 'ed-fi' else ext_entity

        nested_arrays = [x for x in target_col.dataType.elementType.fields if x.dataType.typeName() == 'array']
        for array_sub_col in nested_arrays:
            grand_child_name = f"{child_name}_{array_sub_col.name}"
            grand_child_path = f"{child_path}_{array_sub_col.name}"
            self.store_start_time(grand_child_name)
            
            grand_child_df = child_df.withColumn('exploded', f.explode(array_sub_col.name)).select(child_df.columns + ['exploded.*']).drop(array_sub_col.name)
            grand_child_df = self.modify_references_and_descriptors(grand_child_df, table_name,array_sub_col)

            self.logger.info(f"Writing Grand Child Table - {grand_child_name}") 
            self.upsert_with_logging(df = grand_child_df, 
                                     sap_pipeline = self.sap_pipeline, #TODO: Generalize
                                     destination_path = grand_child_path, 
                                     primary_key = 'lakeId', 
                                     partitioning = True,
                                     partitioning_cols = ['DistrictId', 'SchoolYear'], 
                                     table_name = grand_child_name,
                                     ext_entity = ext_entity,
                                     parent = False)
            self.oea.add_to_lake_db(grand_child_path, 
                                    overwrite = True,
                                    extension = extension)
        
        self.logger.info(f"Writing Child Table - {child_name}")
        self.upsert_with_logging(df = child_df, 
                                 sap_pipeline = self.sap_pipeline, #TODO: Generalize
                                 destination_path = child_path, 
                                 primary_key = 'lakeId', 
                                 partitioning = True,
                                 partitioning_cols = ['DistrictId', 'SchoolYear'], 
                                 table_name = child_name,
                                 ext_entity = ext_entity,
                                 parent = False)
        self.oea.add_to_lake_db(child_path, 
                                overwrite = True,
                                extension = extension)
    
        return df.drop(array_name)

    def transform(self,
                df, 
                schema_name, 
                table_name, 
                primary_key,
                ext_entity,
                sink_general_path,
                parent_schema_name, 
                parent_table_name):
        """
        Refine one entity's DataFrame and write it (and its child tables) to the lake.

        Steps, in order:
          1. Record the start time for `table_name`.
          2. Choose the target schema: `get_descriptor_schema` for tables whose name
             ends in 'Descriptors', otherwise a deep copy of `self.schemas[table_name]`.
          3. Add `lakeId` and `validationRecordId`, built as
             `DistrictId_SchoolYear_<primary_key>` when `df` has the `primary_key`
             column, or as null strings when it does not.
          4. Adjust the schema through `modify_target_schema`, run
             `transform_sub_module` and `process_ext_column`.
          5. Null out `_ext` (when present), upsert the main table and register it
             through `self.oea.add_to_lake_db`.

        Args:
            df: Source DataFrame for the entity.
            schema_name: Forwarded to `transform_sub_module` and `process_ext_column`.
            table_name: Entity name; also the key into `self.schemas`.
            primary_key: Name of the column used as the last part of the lake key.
            ext_entity: Extension entity name; used for the write only when
                `table_name` ends in 'exts' (case-insensitive), and always passed
                to `process_ext_column`.
            sink_general_path: Destination path of the main table.
            parent_schema_name: Accepted but not used by this method.
            parent_table_name: Accepted but not used by this method.

        Returns:
            The transformed DataFrame that was written as the main table.
        """
        # TODO: Pseudominization Pending
        self.store_start_time(table_name)

        is_descriptor_table = re.search('Descriptors$', table_name) is not None
        if is_descriptor_table:
            target_schema = self.get_descriptor_schema(table_name)
        else:
            target_schema = copy.deepcopy(self.schemas[table_name])

        # Add primary key: both key columns carry the same value.
        key_columns = ('lakeId', 'validationRecordId')
        if self.has_column(df, primary_key):
            for key_column in key_columns:
                df = df.withColumn(key_column,
                                   f.concat_ws('_', f.col('DistrictId'), f.col('SchoolYear'), f.col(primary_key)).cast("String"))
        else:
            for key_column in key_columns:
                df = df.withColumn(key_column, f.lit(None).cast("String"))

        # FIXME: Automate best on self.sap_essential_columns
        target_schema = self.modify_target_schema(target_schema = target_schema,
                                                  incl_columns = self.sap_essential_columns,
                                                  excl_columns = ['lakeId'])

        # FIXME: Temporary Fix
        # Only tables whose name ends in 'exts' keep the extension entity.
        ext_entity_flag = ext_entity if table_name.lower().endswith('exts') else None
        write_entity = 'ed-fi' if ext_entity_flag is None else ext_entity_flag

        df = self.transform_sub_module(df = df,
                                       target_schema = target_schema,
                                       sink_general_path = sink_general_path,
                                       schema_name = schema_name,
                                       table_name = table_name,
                                       extension = False,
                                       ext_entity = write_entity) #NOTE: This represents None for all Exts ending tables

        self.process_ext_column(df = df,
                                schema_name = schema_name,
                                table_name = table_name,
                                ext_column_name = '_ext',
                                ext_entity = ext_entity, #NOTE: Should always represent 'TX'
                                target_schema = target_schema,
                                sink_general_path = sink_general_path)

        self.logger.info(f"Writing Main Table - {table_name}")
        if '_ext' in df.columns:
            # The main table keeps the `_ext` column but with null content.
            df = df.withColumn('_ext', f.lit(None).cast(target_schema['_ext'].dataType))
        self.upsert_with_logging(df = df,
                                 sap_pipeline = self.sap_pipeline,
                                 destination_path = f"{sink_general_path}",
                                 primary_key = 'lakeId',
                                 partitioning = True,
                                 partitioning_cols = ['DistrictId', 'SchoolYear'],
                                 table_name = table_name,
                                 ext_entity = write_entity,
                                 parent = True)

        # NOTE: Reverting to non-None logic of else
        if table_name.lower().endswith('descriptors'):
            table_type = 'non-empty-descriptor'
        else:
            table_type = 'non-empty-non-descriptor'
        self.oea.add_to_lake_db(sink_general_path, overwrite = True, extension = ext_entity_flag, table_type = table_type, pipeline_type = None)
        return df

    def transform_sub_module(self, 
                             df, 
                             target_schema, 
                             sink_general_path, 
                             schema_name, 
                             table_name, 
                             extension = False, 
                             ext_entity = 'ed-fi'):
        """
        Conform `df` to `target_schema`, flatten references, rebuild the lake key
        and write array columns out as child tables.

        For every field of `target_schema`:
          * primitive fields (per `self.primitive_datatypes`) are cast to the target
            type, or added as typed nulls when missing; fields named `*Descriptor`
            go through `modify_descriptor_value` instead and are dropped at the end;
          * complex fields are added as typed nulls when missing, otherwise
            round-tripped through JSON (`to_json` / `from_json`) to coerce them to
            the target type (`_ext` is left as-is);
          * complex fields named `*Reference` are flattened with
            `flatten_reference_col`;
          * array fields are handed to `explode_arrays`.

        Args:
            df: DataFrame to conform.
            target_schema: StructType describing the desired columns. The
                `paraprofessionalcertificationset` field may be rewritten in place
                to an array type.
            sink_general_path: Base destination path, forwarded to `explode_arrays`.
            schema_name: Forwarded to `explode_arrays`.
            table_name: Entity name; key into `self.schemas` and used to find the
                primary key definition.
            extension: When False, the `LakeId` column is rebuilt from
                DistrictId, SchoolYear and the primary key columns.
            ext_entity: Forwarded to `explode_arrays`.

        Returns:
            The conformed DataFrame.
        """
        # ArrayType is defined in pyspark.sql.types; reaching it through the
        # functions module relies on an incidental re-export.
        from pyspark.sql.types import ArrayType

        if not(table_name.lower().endswith('descriptors')):
            primary_key, clean_up_cols = self.extract_pk_from_definitions(table_name, 
                                                                      depluralize = True)
        else:
            primary_key = ['namespace', 'codeValue']
            clean_up_cols = []
        flatten_cols = []
        arr_cols = []
        descriptor_cols = []
        for col_name in target_schema.fieldNames():
            target_col = target_schema[col_name]
            if target_col.dataType.typeName() in self.primitive_datatypes:
                # If it is a Descriptor
                if re.search('Descriptor$', col_name) is not None:
                    df = self.modify_descriptor_value(df, col_name)
                    descriptor_cols.append(col_name)
                else:
                    if col_name in df.columns:
                        # Casting columns to primitive data types
                        df = df.withColumn(col_name, f.col(col_name).cast(target_col.dataType))
                    else:
                        # If Column not present in dataframe, add column with None values.
                        df = df.withColumn(col_name, f.lit(None).cast(target_col.dataType))
            # If Complex datatype, i.e. Object, Array
            else:
                if col_name not in df.columns:
                    df = df.withColumn(col_name, f.lit(None).cast(target_col.dataType))
                elif (col_name == '_ext'):# or (extension == True):
                    # NOTE(review): the helper column is dropped straight away, so
                    # this leaves the data unchanged; kept as in the original.
                    df = df.withColumn(f"{col_name}_json", f.to_json(f.col(col_name))).drop(f"{col_name}_json")
                else:
                    # Generate JSON column as a Complex Type
                    if (col_name.lower() == 'paraprofessionalcertificationset') and (target_col.dataType.typeName() != 'array'):
                        # FIXME: Temporary Fix to deal with paraprofessionalcertificationset
                        # Wrap the single object in an array and widen the target
                        # type to match, so it is exploded like other arrays below.
                        df = df.withColumn(col_name, f.array(f.col(col_name)))
                        target_col.dataType = ArrayType(target_col.dataType)
                    
                    df = df.withColumn(f"{col_name}_json", f.to_json(f.col(col_name))) \
                        .withColumn(col_name, f.from_json(f.col(f"{col_name}_json"), target_col.dataType)) \
                        .drop(f"{col_name}_json")
                
                # Modify the links with surrogate keys
                if re.search('Reference$', col_name) is not None:
                    flatten_cols.append(target_col)
        
                if target_col.dataType.typeName() == 'array':
                    arr_cols.append(target_col)

        for target_col in flatten_cols:
            reference_key =  self.extract_reference_key(self.schemas[table_name][target_col.name])
            df = self.flatten_reference_col(df, table_name,target_col, reference_key)
        
        if not(extension):
            if "LakeId" in df.columns:
                df = df.drop('LakeId')

            # LakeId = DistrictId_SchoolYear_<pk1>_<pk2>..._<pkN>
            leading_key_parts = [f.concat(f.col(f'{x}'), f.lit('_')) for x in primary_key[:-1]]
            df = df.withColumn("LakeId", f.concat(f.col('DistrictId'), f.lit('_'), f.col('SchoolYear'), f.lit('_'), *leading_key_parts, f.col(f'{primary_key[-1]}')))
            
        for col_name in clean_up_cols:
            if col_name.lower().endswith('reference') or col_name.lower().endswith('descriptor'):
                if col_name in df.columns:
                    df = df.drop(col_name)
            else:
                self.logger.info(f"{col_name} may be required for analytics - not being dropped")
        
        for col_name in descriptor_cols:
            if col_name in df.columns:
                df = df.drop(col_name)

        for target_col in arr_cols:
            df = self.explode_arrays(df, sink_general_path,target_col, schema_name, table_name, ext_entity)
        
        return df

    def get_ext_entities_schemas(self,
                                table_name = 'staffs',
                                ext_column_name = '_ext',
                                default_value = 'TPDM'):
        """
        Return the schema of an extension namespace nested in a table's ext column.

        Looks up `ext_column_name` in a deep copy of `self.schemas[table_name]`
        and inspects only the FIRST field of that column's struct: when its name
        equals `default_value`, that field's data type is returned.

        Returns:
            The data type of the matching first field, or None when the ext column
            is absent or its first field has a different name.
        """
        schema_copy = copy.deepcopy(self.schemas[table_name])
        for field in schema_copy.fields:
            if field.name != ext_column_name:
                continue
            first_ext_field = field.dataType[0]
            if first_ext_field.name == default_value:
                return first_ext_field.dataType
        return None
                    
    def flatten_ext_column(self,
                            df, 
                        table_name, 
                        ext_col, 
                        inner_key,
                        ext_inner_cols
                        ):
        """
        Project selected fields of `df[ext_col][inner_key]` to top-level columns.

        Args:
            df: DataFrame holding the extension column.
            table_name: Entity name; for 'staffs' the `staffUniqueId` column is
                carried along as well.
            ext_col: Name of the extension column (e.g. '_ext').
            inner_key: Key inside the extension column (e.g. the namespace).
            ext_inner_cols: Candidate field names to pull out of the extension.

        Returns:
            A DataFrame with one column per candidate whose name appears in the
            extension column's type string, followed by
            `self.sap_essential_columns` (plus `staffUniqueId` for 'staffs').
        """
        #TODO: Modify the complex sub type field name logic
        carried_cols = list(self.sap_essential_columns)
        if table_name == 'staffs':
            carried_cols.append('staffUniqueId')

        dict_col = F.col(ext_col)[inner_key]
        # Inspect the column that is actually being flattened; this used to be
        # hard-coded to '_ext' regardless of `ext_col`.
        complex_dtype_text = str(df.select(ext_col).dtypes[0][1])

        # NOTE(review): this is a substring match on the type string, so a
        # candidate that is a substring of another field name also matches.
        exprs = [dict_col.getItem(key).alias(key) for key in ext_inner_cols if str(key) in complex_dtype_text]
        flattened_df = df.select(exprs + carried_cols)

        return flattened_df

In [ ]:
class SAPEdFiClient(EdFiClient):
    def __init__(self, 
                 workspace, 
                 kvName, 
                 moduleName, 
                 authUrl, 
                 dataManagementUrl, 
                 changeQueriesUrl, 
                 dependenciesUrl, 
                 apiVersion, 
                 batchLimit, 
                 minChangeVer="", 
                 maxChangeVer="",
                 landingDateTimeFormat = "yyyyMMddHHmmss", 
                 schoolYear=None, 
                 districtId=None, 
                 kvSecret_clientId = None, 
                 kvSecret_clientSecret = None,
                 retry_strategy = None, 
                 threadMode = False, 
                 devMode = False,
                 oea = None, 
                 final_columns = None, 
                 lookup_table_name = '',
                 lookup_table_base_path = '', 
                 lookup_db_name = ''):
        """
        Build a client on top of `EdFiClient` with submission bookkeeping.

        All arguments up to `devMode` are forwarded unchanged to the parent
        constructor. The remaining ones configure this subclass:

        Args:
            oea: Helper object used to read secrets (`_get_secret`) and, in other
                methods, to resolve paths and upsert tables. Required in practice:
                the secrets are read from it here.
            final_columns: Mapping of resource name to the columns to keep for
                submission. Defaults to a fresh empty dict per instance.
            lookup_table_name: Name of the submissions lookup table.
            lookup_table_base_path: Folder holding the lookup table; the table
                path is `<lookup_table_base_path>/<lookup_table_name>`.
            lookup_db_name: Database in which the lookup table is registered.
        """
        # BooleanType is not among the module-level imports visible at the top of
        # this file, so import it here to keep the schema below self-contained.
        from pyspark.sql.types import BooleanType

        # Call the constructor of the parent class
        super().__init__(workspace = workspace, 
                         kvName = kvName, 
                         moduleName = moduleName, 
                         authUrl = authUrl, 
                         dataManagementUrl = dataManagementUrl, 
                         changeQueriesUrl = changeQueriesUrl, 
                         dependenciesUrl = dependenciesUrl, 
                         apiVersion = apiVersion, 
                         batchLimit = batchLimit, 
                         minChangeVer = minChangeVer, 
                         maxChangeVer = maxChangeVer, 
                         landingDateTimeFormat = landingDateTimeFormat,
                         schoolYear = schoolYear, 
                         districtId = districtId, 
                         kvSecret_clientId = kvSecret_clientId, 
                         kvSecret_clientSecret = kvSecret_clientSecret,
                         retry_strategy = retry_strategy, 
                         threadMode = threadMode, 
                         devMode = devMode)

        # Additional arguments specific to SAPEdFiClient
        self.success_logs = []
        self.error_logs = []
        self.edfi_id_records = []
        
        # Schema of the records produced by generate_edfi_id_record.
        self.edfi_id_record_schema = StructType([
                                    StructField("edfi_location", StringType(), True),
                                    StructField("edfi_id", StringType(), True),
                                    StructField("edfi_id_modified", StringType(), True),
                                    StructField("entity_name", StringType(), True),
                                    StructField("resource", StringType(), True),
                                    StructField("staffUniqueId", StringType(), True),
                                    StructField("recordAPIRefreshDateTime", StringType(), True),
                                    StructField("isDeleted", BooleanType(), True)
                                    ])
        # Schema of the request/response log records; every field is a string.
        self.log_schema = StructType([
                        StructField("pipeline_execution_id", StringType(), True),
                        StructField("pipelineExecutionId", StringType(), True),
                        StructField("run_id", StringType(), True),
                        StructField("operation_type", StringType(), True),
                        StructField("request_payload", StringType(), True),
                        StructField("request_url", StringType(), True),
                        StructField("resource", StringType(), True),
                        StructField("entity_name", StringType(), True),
                        StructField("response_headers", StringType(), True),
                        StructField("response_status_code", StringType(), True),
                        StructField("response_text", StringType(), True),
                        StructField("rundate_time", StringType(), True),
                        StructField("start_time", StringType(), True),
                        StructField("end_time", StringType(), True),
                        StructField("record_id", StringType(), True),
                        StructField("lakeId", StringType(), True),
                        StructField("edfi_location", StringType(), True),
                        StructField("edfi_id", StringType(), True),
                        StructField("edfi_id_modified", StringType(), True),                                    
                        StructField("sap_pipeline", StringType(), True),
                        StructField("sap_pipelineType", StringType(), True),
                        StructField("stage1_source_url", StringType(), True),
                        StructField("NATURAL_KEY_HASH", StringType(), True)                    
                        ])
        # A `{}` default would be one dict shared by every instance created
        # without the argument; build a fresh one instead.
        self.final_columns = {} if final_columns is None else final_columns
        self.oea = oea
        # API credentials are always resolved from the key vault by secret name.
        self.clientId = oea._get_secret(kvSecret_clientId)
        self.clientSecret = oea._get_secret(kvSecret_clientSecret)
        self.lookup_table_name = lookup_table_name
        self.lookup_table_path = f"{lookup_table_base_path}/{lookup_table_name}"
        self.lookup_db_name = lookup_db_name
        self.schoolYear = schoolYear 
        self.districtId = districtId
        # Keys stripped from each payload by cleanup_dict and returned separately.
        self.metadata_logging_keys = ['RECORD','lakeId','stage1_source_url', 'sap_pipeline', 'sap_pipelineType', 'NATURAL_KEY_HASH']
        # Filled by store_maxRunDates, persisted by dump_lookup_table.
        self.max_rundates = list()
         
    def cleanup_discriptor_col(self,
                          descriptor_prefix, 
                          descriptor_key, 
                          descriptor_value, 
                          descriptor_type):
        """
        Rewrite the descriptor type inside a descriptor URI to its prefixed form.

        When `descriptor_key`, with its first letter upper-cased, equals
        `descriptor_prefix + descriptor_type`, every `/<descriptor_type>#` in
        `descriptor_value` is replaced by `/<Capitalized key>#`.

        Args:
            descriptor_prefix: Prefix such as 'Payroll', 'Budget' or 'Actual'.
            descriptor_key: Payload key, e.g. 'payrollFundDescriptor'.
            descriptor_value: Descriptor URI to rewrite.
            descriptor_type: Unprefixed type, e.g. 'FundDescriptor'.

        Returns:
            The rewritten value, or `descriptor_value` unchanged when the key does
            not match, the key is empty, or the value is not a string (e.g. None).
        """
        # Missing values are removed later by the caller; there is nothing to
        # rewrite and calling `.replace` on them would raise.
        if not descriptor_key or not isinstance(descriptor_value, str):
            return descriptor_value

        descriptor_col_name = f'{descriptor_prefix}{descriptor_type}'
        capitalized_key = descriptor_key[0].upper() + descriptor_key[1:]
        if descriptor_col_name == capitalized_key:
            descriptor_value = descriptor_value.replace(f'/{descriptor_type}#', f'/{capitalized_key}#')
        return descriptor_value

    def cleanup_dict(self, original_item):
        """
        Turn one record into a submission payload plus its logging metadata.

        Works on copies, so `original_item` (including its `_ext`/`TX` dicts)
        is left untouched. In order, the method:
          * resolves the staff id from `staffUniqueId` or
            `staffReference.staffUniqueId` ('INVALID_COLUMN' when neither exists);
          * for records carrying `actualFundDescriptor` or `budgetFundDescriptor`,
            reduces `fiscalYear` to its last element/character;
          * drops race entries whose values are all 'uriPlaceholder#NA';
          * moves the keys listed in `self.metadata_logging_keys` into the
            metadata dict and removes top-level keys (other than `_ext`) whose
            value is None or 'uriPlaceholder#NA';
          * rewrites Payroll/Budget/Actual and ciStaffProgramIntent descriptor
            URIs to their prefixed descriptor names;
          * cleans `_ext.TX`: removes None/placeholder entries and `*Sets` entries
            containing a placeholder, coerces `staffDoNotReportTSDS` to bool,
            collapses `paraprofessionalCertificationSet` to its first element
            (removing it when that element has no `beginDate`), and drops `_ext`
            entirely when `TX` ends up empty.

        Args:
            original_item: dict describing one record.

        Returns:
            tuple: (cleaned item dict, metadata values dict, staffUniqueId).
        """
        placeholder = 'uriPlaceholder#NA'
        descriptor_prefixes = ('Payroll', 'Budget', 'Actual')
        descriptor_types = ('FundDescriptor', 'FunctionDescriptor', 'ObjectDescriptor', 'ProgramIntentDescriptor')

        metadata_logging_values = dict()
        item = copy.copy(original_item)

        staff_reference = item.get('staffReference')
        if 'staffUniqueId' in item:
            staffUniqueId = item['staffUniqueId']
        elif isinstance(staff_reference, dict) and 'staffUniqueId' in staff_reference:
            staffUniqueId = staff_reference['staffUniqueId']
        else:
            staffUniqueId = 'INVALID_COLUMN'
        
        # FIXME: 2024-02-01 TEMP FIX
        # `fiscalYear` must be present AND one of the fund descriptors must be
        # present; without the parentheses a record with only
        # `budgetFundDescriptor` raised KeyError on `fiscalYear`.
        has_fund_descriptor = ('actualFundDescriptor' in item) or ('budgetFundDescriptor' in item)
        fiscal_year = item.get('fiscalYear')
        if has_fund_descriptor and isinstance(fiscal_year, (str, list, tuple)) and fiscal_year:
            item['fiscalYear'] = fiscal_year[-1]

        races = item.get('races')
        if isinstance(races, list) and races:
            # Keep each race entry once if it has at least one real value.
            item['races'] = [race for race in races
                             if isinstance(race, dict) and any(value != placeholder for value in race.values())]
        
        # A set: a metadata key whose value is None qualifies for removal twice,
        # and popping the same key twice would raise KeyError.
        keys_to_remove = set()
        for key in item:
            value = item[key]
            if key in self.metadata_logging_keys:
                metadata_logging_values[key] = value
                keys_to_remove.add(key)

            if key != '_ext' and (value is None or value == placeholder):
                keys_to_remove.add(key)

            # Descriptor rewrites only make sense on string URIs; None values are
            # removed above anyway.
            if not isinstance(value, str):
                continue

            # FIXME: Temporary
            if key == 'ciStaffProgramIntentDescriptor':
                item[key] = value.replace('/ProgramIntentDescriptor#', '/CIStaffProgramIntentDescriptor#')
                continue

            lowered_key = key.lower()
            if not lowered_key.endswith('descriptor'):
                continue
            for prefix in descriptor_prefixes:
                if lowered_key.startswith(prefix.lower()):
                    for descriptor_type in descriptor_types:
                        value = self.cleanup_discriptor_col(prefix, key, value, descriptor_type)
                    item[key] = value
                    break
            
        for key in keys_to_remove:
            item.pop(key, None)

        ext = item.get('_ext')
        if isinstance(ext, dict) and isinstance(ext.get('TX'), dict):
            # Copy the nested dicts: `copy.copy` above is shallow, so editing them
            # in place would also change the caller's record.
            tx = dict(ext['TX'])
            item['_ext'] = {**ext, 'TX': tx}

            inner_keys_to_remove = set()
            for ext_item, ext_value in list(tx.items()):
                if ext_value is None or ext_value == placeholder:
                    inner_keys_to_remove.add(ext_item)
                    continue
                if ext_item.endswith('Sets') and isinstance(ext_value, (list, tuple)):
                    # One placeholder anywhere in the set invalidates the whole set.
                    if any(isinstance(entry, dict) and placeholder in entry.values() for entry in ext_value):
                        inner_keys_to_remove.add(ext_item)
                if ext_item == 'staffDoNotReportTSDS':
                    # NOTE(review): bool() of a non-empty string such as "0" is
                    # True; confirm the upstream type is numeric/boolean.
                    tx[ext_item] = bool(ext_value)
                if ext_item == 'paraprofessionalCertificationSet':
                    # Only the first entry is submitted, as a single object.
                    if isinstance(ext_value, (list, tuple)):
                        first_entry = ext_value[0] if ext_value else None
                    else:
                        first_entry = ext_value
                    if not isinstance(first_entry, dict) or 'beginDate' not in first_entry:
                        inner_keys_to_remove.add(ext_item)
                    tx[ext_item] = first_entry

            # Remove the keys from the dictionary
            for key in inner_keys_to_remove:
                tx.pop(key, None)

            if tx == {}:
                item.pop('_ext')
        return item, metadata_logging_values, staffUniqueId 

    def cast_column_to_bool(self, df, column_name):
        """
        Replace `column_name` with a boolean column.

        The conversion is a Python UDF comparing each value with the string "1",
        so only the exact string "1" yields True; every other value yields False.

        Returns:
            `df` with `column_name` overwritten by the boolean result.
        """
        from pyspark.sql.functions import udf
        from pyspark.sql.types import BooleanType

        # Python-level equality, evaluated row by row inside the UDF.
        equals_one = udf(lambda raw_value: raw_value == "1", BooleanType())
        return df.withColumn(column_name, equals_one(col(column_name)))

    def get_latest_submission_records(self, df, lookup_table_name, filtering_date = 'rundate',resource_name = '',sap_pipeline = '', sap_pipelineType = '',operationType = 'submission',debugMode = False):
        """
        Keep only the rows of `df` that are newer than the last recorded submission.

        The last submission's run date and record version are read from
        `<self.lookup_db_name>.<lookup_table_name>` for the given resource and
        pipeline. When the table has no matching row, the fallbacks
        '2020-11-25' / 0 are used. When the lookup query fails with an
        AnalysisException, no filtering is applied.

        Args:
            df: DataFrame to filter; needs `filtering_date` and, for
                'submission' / 'delete', a `RECORD_VERSION` column.
            lookup_table_name: Name of the lookup table.
            filtering_date: Column compared against the last run date.
            resource_name, sap_pipeline, sap_pipelineType: Lookup keys.
            operationType: 'submission' keeps RECORD_VERSION strictly greater than
                the stored one, 'delete' keeps greater-or-equal, anything else
                filters on the date only.
            debugMode: When True, `df` is returned unfiltered.

        Returns:
            The filtered (or original) DataFrame.
        """
        # FIXME 2024-01-24: Under Refactoring
        maxdatetime = None
        maxRecordVersion = None
        try:
            # NOTE(review): the lookup keys are interpolated into the SQL text;
            # they are expected to be internal pipeline identifiers.
            lookup_df = spark.sql(f"""
                                    SELECT lastSubmissionMaxRunDate as maxdatetime,
                                           lastSubmissionMaxRecordVersion as maxRecordVersion
                                    FROM {self.lookup_db_name}.{lookup_table_name} 
                                    WHERE resource_name = '{resource_name}'
                                    AND sap_pipeline = '{sap_pipeline}'
                                    AND sap_pipelineType = '{sap_pipelineType}'
                                """)
            lookup_df_first = lookup_df.first()
            if lookup_df_first is not None:
                maxdatetime = lookup_df_first['maxdatetime']
                maxRecordVersion = lookup_df_first['maxRecordVersion']
            else:
                # FIXME: Temporary Fix
                maxdatetime = '2020-11-25'
                maxRecordVersion = 0
        except AnalysisException as e:
            # Best effort: without a usable lookup table every record is treated
            # as new. Log it so the fallback is visible instead of silent.
            logger.warning(f"Lookup table {self.lookup_db_name}.{lookup_table_name} could not be queried; no submission filter applied - {e}")

        if not maxdatetime or debugMode:
            return df

        df = df.where(f"{filtering_date} > '{maxdatetime}'")
        # A NULL stored version would render as `RECORD_VERSION > None`, which is
        # not valid SQL; fall back to the date filter alone in that case.
        if maxRecordVersion is None:
            return df
        if operationType == 'submission':
            df = df.where(f'RECORD_VERSION > {maxRecordVersion}')
        elif operationType == 'delete':
            df = df.where(f'RECORD_VERSION >= {maxRecordVersion}')
        return df
    
    def store_maxRunDates(self, df, resource_name):
        """
        Record the newest run date / record version of `df` for the lookup table.

        `df` is registered as the temp view 'SUBMISSION_SESSION_TEMP_VIEW' and
        aggregated per (DistrictId, SchoolYear, sap_pipeline, sap_pipelineType).
        The first aggregated row, if any, is appended to `self.max_rundates`
        under `self.lock`, together with `resource_name` and the current time.

        Args:
            df: DataFrame with `rundate`, `RECORD_VERSION`, `sap_pipeline`,
                `sap_pipelineType`, `DistrictId` and `SchoolYear` columns.
            resource_name: Resource the aggregate belongs to.
        """
        # FIXME 2024-01-24: Under Refactoring
        # NOTE(review): the view name is fixed, so concurrent calls replace each
        # other's view - confirm this is safe when threadMode is enabled.
        df.createOrReplaceTempView('SUBMISSION_SESSION_TEMP_VIEW')
        result = spark.sql("""SELECT max(rundate) as maxRunDate,
                                     max(RECORD_VERSION) as maxRecordVersion, 
                                     sap_pipeline,
                                     sap_pipelineType, 
                                     DistrictId, 
                                     SchoolYear 
                              FROM SUBMISSION_SESSION_TEMP_VIEW
                              GROUP BY DistrictId, 
                                       SchoolYear, 
                                       sap_pipeline,
                                       sap_pipelineType""")
        # Fetch the row once: every `first()` call runs the query again, and with
        # several groups separate calls need not return the same row.
        latest = result.first()
        if latest is None:
            return

        temp_dict = dict()
        temp_dict['resource_name'] = resource_name
        temp_dict['lastSubmissionMaxRunDate'] = latest['maxRunDate']
        temp_dict['lastSubmissionMaxRecordVersion'] = latest['maxRecordVersion']
        temp_dict['sap_pipeline'] = latest['sap_pipeline']
        temp_dict['sap_pipelineType'] = latest['sap_pipelineType']
        temp_dict['DistrictId'] = latest['DistrictId']
        temp_dict['SchoolYear'] = latest['SchoolYear']
        temp_dict['recordAPIRefreshDateTime'] = datetime.now()

        with self.lock:
            self.max_rundates.append(temp_dict)

    def return_submission_type(self,
                               df):
        """
        Read the pipeline identifiers from the first row of `df`.

        Args:
            df: DataFrame with `sap_pipeline` and `sap_pipelineType` columns.

        Returns:
            tuple: (sap_pipeline, sap_pipelineType) of the first row.

        Raises:
            ValueError: When `df` has no rows.
        """
        # FIXME Under dev and review
        # One `first()` call: both values come from the same row, and the
        # DataFrame is only evaluated once.
        first_row = df.first()
        if first_row is None:
            raise ValueError("Cannot determine sap_pipeline / sap_pipelineType from an empty DataFrame")
        return first_row['sap_pipeline'], first_row['sap_pipelineType']
    
    def load_by_SY_DI(self, path):
        """
        Load a Delta table restricted to this client's district and school year.

        Args:
            path: Lake path of the Delta table; resolved with `self.oea.to_url`.

        Returns:
            DataFrame filtered on `DistrictId == self.districtId` and
            `SchoolYear == self.schoolYear`.
        """
        # FIXME Under dev and review
        # Use the helper handed to the constructor, as the other methods of this
        # class do, rather than relying on a module-level `oea` being defined.
        table_url = self.oea.to_url(path)
        df = spark.read.format('delta').load(table_url).filter(f"DistrictId == {self.districtId}").filter(f"SchoolYear == {self.schoolYear}")
        return df
    
    def process_submission_entities(self, df, resource_name):
        """
        Apply resource-specific type fixes before a DataFrame is submitted.

        * 'budgetExts', 'actualExts', 'payrollExts': the matching
          `<budget|actual|payroll>Amount` column is rounded and cast to int.
        * 'staffs': `hispanicLatinoEthnicity` is converted to boolean with
          `cast_column_to_bool`.
        * Every other resource is returned unchanged.

        Args:
            df: DataFrame of the resource.
            resource_name: Name of the resource being prepared.

        Returns:
            The adjusted DataFrame.
        """
        # FIXME Under dev and review
        amount_columns = {
            'budgetExts': 'budgetAmount',
            'actualExts': 'actualAmount',
            'payrollExts': 'payrollAmount',
        }
        amount_column = amount_columns.get(resource_name)
        if amount_column is not None:
            # F.round is the Spark column function; a bare `round` resolves to the
            # Python builtin unless the Spark one has been imported over it.
            df = df.withColumn(amount_column, F.round(F.col(amount_column)).cast("int"))
        elif resource_name == 'staffs':
            df = self.cast_column_to_bool(df, 'hispanicLatinoEthnicity')
        return df
    
    def filter_out_invalid_vals(self, df, resource_name):
        """
        Drop 'actualExts' rows whose descriptors still hold the NA placeholder.

        For `resource_name == 'actualExts'`, `df` is registered as the temp view
        `temp_vw_actualExts` and only rows where none of the fund, object,
        program-intent or function descriptors contains 'uriPlaceholder#NA' are
        kept. Any other resource is returned untouched.

        Returns:
            The filtered DataFrame for 'actualExts', otherwise `df` itself.
        """
        # FIXME: TEMP FIX to filter out unecessary variables
        if resource_name != 'actualExts':
            return df

        df.createOrReplaceTempView("temp_vw_actualExts")
        valid_rows = spark.sql("""
                            SELECT *
                            FROM temp_vw_actualExts
                            WHERE NOT (
                                actualFundDescriptor LIKE '%uriPlaceholder#NA%' 
                                OR actualObjectDescriptor LIKE '%uriPlaceholder#NA%'
                                OR actualProgramIntentDescriptor LIKE '%uriPlaceholder#NA%'
                                OR actualFunctionDescriptor LIKE '%uriPlaceholder#NA%'
                            )
                        """)
        return valid_rows
        
    def loadDataFromStage3IntoJSON(self,
                                   resource_name, 
                                   file_path):
        """
        Load one resource's pending submission records and serialise them as JSON.

        The table at `file_path` is loaded for this district / school year,
        reduced to records newer than the last submission, restricted to rows
        with `SUBMISSION_RECORD_IS_ACTIVE == True`, type-fixed, projected to
        `self.final_columns[resource_name]` and cleaned of invalid values. The
        newest run date is recorded through `store_maxRunDates` on the way.

        Args:
            resource_name: Resource to load; also the key into `self.final_columns`.
            file_path: Lake path of the resource's table.

        Returns:
            tuple: (list of row dicts, list of row JSON strings, cached DataFrame),
            or (None, None, None) when any step raises; the error is logged.
        """
        # FIXME Under dev and review
        try:
            df = self.load_by_SY_DI(file_path)
            sap_pipeline, sap_pipelineType = self.return_submission_type(df)
            df = self.get_latest_submission_records(df = df, 
                                                    lookup_table_name = self.lookup_table_name,#'submissions_lookup_table', 
                                                    filtering_date = 'rundate',
                                                    resource_name = resource_name,
                                                    sap_pipeline = sap_pipeline,
                                                    sap_pipelineType = sap_pipelineType,
                                                    operationType = 'submission',
                                                    debugMode = False)
            df = df.filter(df['SUBMISSION_RECORD_IS_ACTIVE'] == True)
            self.store_maxRunDates(df, resource_name)
            df = self.process_submission_entities(df, resource_name)
            
            post_cols = self.final_columns[resource_name]
            df = self.filter_columns(df, post_cols)
            df = self.filter_out_invalid_vals(df, resource_name)
            df = df.cache()
            
            # Collect the serialised rows once and parse them on the driver,
            # instead of running `toJSON()` as two separate Spark actions.
            json_str = df.toJSON().collect()
            json_df = [json.loads(row) for row in json_str]

            return json_df, json_str,df
        except Exception as error:
            # Deliberate best effort: the caller treats None as "resource failed".
            logger.exception(f'An Error Occured - {error}')
            return None, None, None

    def getDataForEdFiPosts(self,
                        resource_names = None,
                    file_path = None,
                    resource_json_dict = None
                    ):
        """
        Load the submission payloads of several resources into one dict.

        Each resource is read from `<file_path>/<resource_name>` through
        `loadDataFromStage3IntoJSON`; resources that fail to load are reported
        and skipped.

        Args:
            resource_names: Iterable of resource names; None is treated as empty.
            file_path: Base path holding one folder per resource.
            resource_json_dict: Optional dict to add results to. A new dict is
                created per call when omitted.

        Returns:
            dict mapping resource name to its list of row dicts.
        """
        # A `dict()` default is created once and shared by every call, so results
        # of earlier calls would leak into later ones.
        if resource_json_dict is None:
            resource_json_dict = dict()

        for resource_name in resource_names or []:
            resource_path = f'{file_path}/{resource_name}'            
            logger.info(f"Stage 3 Path - {resource_path}")
            json_list, _, _ = self.loadDataFromStage3IntoJSON(resource_name, 
                                                              resource_path)
            if json_list is not None:
                resource_json_dict[resource_name] = json_list
            else:
                print(f'Error in loading in resource = {resource_name}')
        return resource_json_dict
    
    def return_lookup_table_as_spark_df(self):
        """Return the entries collected in `self.max_rundates` as a Spark DataFrame."""
        # FIXME Under dev and review
        return spark.createDataFrame(self.max_rundates)
    
    def dump_lookup_table(self):
        """
        Upsert the collected max-run-date entries into the lookup table.

        The rows come from `return_lookup_table_as_spark_df` and are merged into
        `self.lookup_table_path` through `self.oea.upsert`, keyed on resource,
        pipeline, pipeline type, district and school year.
        """
        # FIXME Under dev and review
        lookup_df = self.return_lookup_table_as_spark_df()
        merge_keys = ['resource_name', 'sap_pipeline', 'sap_pipelineType', 'DistrictId', 'SchoolYear']
        self.oea.upsert(df = lookup_df,
                        destination_path = self.lookup_table_path,
                        primary_key = merge_keys,
                        partitioning = True,
                        partitioning_cols = [],
                        surrogate_key = False)
    
    def add_lookup_table_to_lake_db(self, overwrite = True):
        """Register the lookup Delta table in the lake database.

        Creates the database ``self.lookup_db_name`` if needed, optionally
        drops an existing table definition, then creates the table pointing at
        the URL of ``self.lookup_table_path``.

        Args:
            overwrite: When true, drop the existing table definition first.
        """
        # FIXME Under dev and review
        spark.sql(f'CREATE DATABASE IF NOT EXISTS {self.lookup_db_name}')
        qualified_table = f"{self.lookup_db_name}.{self.lookup_table_name}"
        if overwrite:
            spark.sql(f"drop table if exists {qualified_table}")
        delta_location = self.oea.to_url(self.lookup_table_path)
        spark.sql(f"create table if not exists {qualified_table} using DELTA location '{delta_location}'")

    def generate_edfi_id_record(self, **kwargs):
        """Build one Ed-Fi id tracking record from keyword arguments.

        Every field except ``isDeleted`` is stringified; a missing keyword
        becomes an empty string. ``edfi_id_modified`` prefixes the id with
        ``self.districtId`` and ``self.schoolYear``. ``isDeleted`` is passed
        through unchanged and defaults to ``False``.

        Returns:
            dict with the record fields.
        """
        def as_text(key):
            # Missing keys become '' (and an explicit None becomes 'None').
            return str(kwargs.get(key, ''))

        id_text = as_text('edfi_id')
        record = {
            'edfi_location': as_text('edfi_location'),
            'edfi_id': id_text,
            'edfi_id_modified': f"{self.districtId}_{self.schoolYear}_{id_text}",
        }
        for plain_key in ('entity_name', 'resource', 'staffUniqueId', 'recordAPIRefreshDateTime'):
            record[plain_key] = as_text(plain_key)
        record['isDeleted'] = kwargs.get('isDeleted', False)
        return record

    def generate_log_record(self, **kwargs):
        """Build one request log record from keyword arguments.

        Every value is stringified; a missing keyword becomes an empty string.
        Most output keys read the keyword of the same name. The exceptions are
        ``pipelineExecutionId`` (read from ``pipeline_execution_id``),
        ``request_payload`` (read from ``json_object``) and
        ``edfi_id_modified`` (``edfi_id`` prefixed with ``self.districtId``
        and ``self.schoolYear``).

        Returns:
            dict with the log fields.
        """
        # TODO: Add RUN ID which is <> pipeline id
        # FIXME: pipelineExecutionId is the alias of pipeline_execution_id
        def as_text(key):
            return str(kwargs.get(key, ''))

        # (output key, keyword argument it is read from), in output order.
        field_sources = (
            ('pipeline_execution_id', 'pipeline_execution_id'),
            ('pipelineExecutionId', 'pipeline_execution_id'),
            ('run_id', 'run_id'),
            ('operation_type', 'operation_type'),
            ('resource', 'resource'),
            ('entity_name', 'entity_name'),
            ('request_url', 'request_url'),
            ('request_payload', 'json_object'),
            ('response_status_code', 'response_status_code'),
            ('response_text', 'response_text'),
            ('response_headers', 'response_headers'),
            ('rundate_time', 'rundate_time'),
            ('start_time', 'start_time'),
            ('end_time', 'end_time'),
            ('record_id', 'record_id'),
            ('lakeId', 'lakeId'),
            ('edfi_location', 'edfi_location'),
            ('edfi_id', 'edfi_id'),
            ('sap_pipeline', 'sap_pipeline'),
            ('sap_pipelineType', 'sap_pipelineType'),
            ('stage1_source_url', 'stage1_source_url'),
            ('NATURAL_KEY_HASH', 'NATURAL_KEY_HASH'),
        )

        log_record = {}
        for output_key, source_key in field_sources:
            log_record[output_key] = as_text(source_key)
            if output_key == 'edfi_id':
                # The derived id sits directly after the raw id in the record.
                log_record['edfi_id_modified'] = f"{self.districtId}_{self.schoolYear}_{as_text('edfi_id')}"
        return log_record
    
    def send_delete_request(self,
                        pipeline_execution_id,
                        run_id,
                        resource,
                        chunk_num, 
                        chunk, 
                        url, 
                        success_logging = True,
                        error_logging = False):
        """Send one DELETE request per record in ``chunk`` and log the outcome.

        Each record's ``edfi_id`` is appended to
        ``{self.dataManagementUrl}{resource}/`` to form the DELETE URL.

        Args:
            pipeline_execution_id: Copied into every log record.
            run_id: Copied into every log record.
            resource: Resource path; its last ``/`` segment is used as the
                entity name.
            chunk_num: Index of the chunk (not used by this method).
            chunk: Iterable of dicts; ``edfi_id`` and ``edfi_location`` are
                read from each one.
            url: Only stored as ``request_url`` in the log record; the request
                itself goes to the per-record DELETE URL.
                NOTE(review): consider logging the per-record URL instead.
            success_logging: When true, responses with status < 400 are added
                to ``self.success_logs`` and a deletion marker is added to
                ``self.edfi_id_records``.
            error_logging: When true, responses with status >= 400 are added
                to ``self.error_logs``.

        Returns:
            None. Any exception is logged and ends processing of the chunk.
        """
        self.init_thread_local_vars()
        rundate_time = datetime.now()
        try:
            for json_object in chunk:
                edfi_id = json_object.get('edfi_id')
                edfi_location = json_object.get('edfi_location')
                delete_url = f"{self.dataManagementUrl}{resource}/{edfi_id}"
                entity_name = resource.split('/')[-1]

                start_time = datetime.now()
                # Send through the session from getSession(), as
                # send_api_request does, instead of a bare requests.delete().
                requests_session = self.getSession()
                response = requests_session.delete(delete_url, headers={"Authorization": f"Bearer {self.getAccessToken()}"})
                end_time = datetime.now()

                # str() keeps the split safe when neither the response header
                # nor the record provides an id (the fallback would be None).
                logged_edfi_id = str(response.headers.get('location', edfi_id)).split('/')[-1]

                logged_record = self.generate_log_record(pipeline_execution_id = pipeline_execution_id,
                                                            run_id = run_id,
                                                            operation_type = 'delete',
                                                            resource = resource,
                                                            entity_name = entity_name,
                                                            request_url = url,
                                                            json_object = json_object,
                                                            response_status_code = response.status_code,
                                                            response_text = response.text,
                                                            response_headers = response.headers,
                                                            rundate_time = rundate_time,
                                                            start_time = start_time,
                                                            end_time = end_time,
                                                            stage1_source_url = 'INVALID_COL_FOR_DELETE', # metadata_logging_values['stage1_source_url'],
                                                            record_id = 'INVALID_COL_FOR_DELETE', # metadata_logging_values['RECORD'],
                                                            lakeId = 'INVALID_COL_FOR_DELETE', # metadata_logging_values['lakeId'],
                                                            edfi_location = response.headers.get('location', edfi_location), # FIXME: Under Review
                                                            edfi_id = logged_edfi_id, # FIXME: Under Review
                                                            sap_pipeline = 'INVALID_COL_FOR_DELETE',
                                                            sap_pipelineType = 'INVALID_COL_FOR_DELETE',
                                                            NATURAL_KEY_HASH = 'INVALID_COL_FOR_DELETE' # str(metadata_logging_values['NATURAL_KEY_HASH'])
                                                            )

                if response.status_code < 400:
                    if success_logging:
                        logger.info(f'SUCCESS - {response.status_code}')
                        with self.lock:
                            self.success_logs.append(logged_record)
                        # Marker row recording that this Ed-Fi id was deleted.
                        edfi_id_record = self.generate_edfi_id_record(edfi_location = edfi_location,
                                                                      edfi_id = edfi_id,
                                                                      resource = resource,
                                                                      entity_name = entity_name,
                                                                      staffUniqueId = 'INVALID_PLACEHOLDER',
                                                                      recordAPIRefreshDateTime = datetime.now(),
                                                                      isDeleted = True)
                        with self.lock:
                            self.edfi_id_records.append(edfi_id_record)
                else:
                    if error_logging:
                        logger.info(f"ERROR Code - {response.status_code}")
                        with self.lock:
                            self.error_logs.append(logged_record)
        except Exception as e:
            # Log with traceback at error level so failures are not mistaken
            # for routine INFO output.
            logger.exception(f"ERROR: {e}")

    def send_api_request(self,
                        pipeline_execution_id,
                        run_id,
                        resource,
                        chunk_num, 
                        chunk, 
                        url, 
                        success_logging = True,
                        error_logging = False):
        """POST every record in ``chunk`` to ``url`` and log the outcome.

        Each record is first passed through ``self.cleanup_dict``, which
        returns the payload to send, a dict of metadata values used for
        logging, and the staff unique id.

        Args:
            pipeline_execution_id: Copied into every log record.
            run_id: Copied into every log record.
            resource: Resource path; its last ``/`` segment is used as the
                entity name.
            chunk_num: Index of the chunk (not used by this method).
            chunk: Iterable of records to submit.
            url: Endpoint the records are posted to.
            success_logging: When true, responses with status < 400 are added
                to ``self.success_logs`` and, if the response carries a
                ``location`` header, an id record is added to
                ``self.edfi_id_records``.
            error_logging: When true, responses with status >= 400 are added
                to ``self.error_logs``.

        Returns:
            None. Any exception is logged and ends processing of the chunk.
        """
        self.init_thread_local_vars()
        rundate_time = datetime.now()
        try:
            for json_object in chunk:
                start_time = datetime.now()
                json_object, metadata_logging_values, staffUniqueId = self.cleanup_dict(json_object)
                headers = {
                        'Authorization': f"Bearer {self.getAccessToken()}",
                        'Content-Type': 'application/json'
                        }
                requests_session = self.getSession()
                response = requests_session.post(url, json=json_object, headers=headers)
                end_time = datetime.now()

                entity_name = resource.split('/')[-1]
                # Read the header once; the log record falls back to a
                # placeholder when it is absent.
                location = response.headers.get('location')
                logged_location = location if location is not None else 'INVALID_VALUE_PLACEHOLDER'

                logged_record = self.generate_log_record(pipeline_execution_id = pipeline_execution_id,
                                                            run_id = run_id,
                                                            resource = resource,
                                                            operation_type = 'insert-or-update',
                                                            entity_name = entity_name,
                                                            request_url = url,
                                                            json_object = json_object,
                                                            response_status_code = response.status_code,
                                                            response_text = response.text,
                                                            response_headers = response.headers,
                                                            rundate_time = rundate_time,
                                                            start_time = start_time,
                                                            end_time = end_time,
                                                            stage1_source_url = metadata_logging_values['stage1_source_url'],
                                                            record_id = metadata_logging_values['RECORD'],
                                                            lakeId = metadata_logging_values['lakeId'],
                                                            edfi_location = logged_location,
                                                            edfi_id = logged_location.split('/')[-1],
                                                            sap_pipeline = metadata_logging_values['sap_pipeline'],
                                                            sap_pipelineType = metadata_logging_values['sap_pipelineType'],
                                                            NATURAL_KEY_HASH = str(metadata_logging_values['NATURAL_KEY_HASH'])
                                                            )

                if response.status_code < 400:
                    if success_logging:
                        logger.info(f'SUCCESS - {response.status_code}')
                        with self.lock:
                            self.success_logs.append(logged_record)

                        if location is None:
                            # Without a location header there is no Ed-Fi id to
                            # record. Previously this raised AttributeError and
                            # aborted the rest of the chunk.
                            logger.warning(f"[POST TO ED-FI] {response.status_code} response without a location header; no edfi id recorded")
                            continue

                        edfi_id_record = self.generate_edfi_id_record(edfi_location = location,
                                                                      edfi_id = location.split('/')[-1],
                                                                      resource = resource,
                                                                      entity_name = entity_name,
                                                                      staffUniqueId = staffUniqueId,
                                                                      recordAPIRefreshDateTime = datetime.now(),
                                                                      isDeleted = False)
                        with self.lock:
                            self.edfi_id_records.append(edfi_id_record)
                else:
                    if error_logging:
                        logger.error(f"[POST TO ED-FI] {response.status_code}")
                        with self.lock:
                            self.error_logs.append(logged_record)
        except Exception as e:
            # Log with traceback at error level so failures are not mistaken
            # for routine INFO output.
            logger.exception(f"ERROR: {e}")

    # Function to upsert records using multi-threading
    def upsert_records(self,
                    pipeline_execution_id,
                    run_id,
                    resource, 
                    resource_name, 
                    records, 
                    chunk_size = 500, 
                    num_threads = 10, 
                    function_name = 'post',
                    success_logging = True, 
                    error_logging = False):
        """Split ``records`` into chunks and process them on a thread pool.

        Args:
            pipeline_execution_id: Forwarded to the worker for logging.
            run_id: Forwarded to the worker for logging.
            resource: Resource path appended to ``self.dataManagementUrl``.
            resource_name: Not used by this method.
            records: Sliceable sequence of records to submit.
            chunk_size: Number of records handed to a worker per task.
            num_threads: Maximum number of worker threads.
            function_name: ``'post'`` runs ``send_api_request`` and
                ``'delete'`` runs ``send_delete_request``. Any other value is
                reported as an error and nothing is submitted.
            success_logging: Forwarded to the worker.
            error_logging: Forwarded to the worker.

        Returns:
            None.
        """
        logger.info(f"Initiating {resource}")
        logger.info(f"Processing {len(records)} records")
        self.init_thread_local_vars()

        workers = {
            'post': self.send_api_request,
            'delete': self.send_delete_request,
        }
        worker = workers.get(function_name)
        if worker is None:
            # Previously an unknown name submitted nothing but still logged
            # that all calls had completed.
            logger.error(f"Unsupported function_name '{function_name}' for {resource}; no requests were sent.")
            return

        url = f"{self.dataManagementUrl}{resource}"
        chunks = [records[i:i + chunk_size] for i in range(0, len(records), chunk_size)]
        # Drop the local reference to the full list; the chunks hold the data now.
        del records

        # Leaving the with-block waits for every submitted task to finish.
        with ThreadPoolExecutor(max_workers=num_threads) as tpe:
            futures = [
                tpe.submit(worker, pipeline_execution_id, run_id, resource, chunk_num, chunk, url, success_logging, error_logging)
                for chunk_num, chunk in enumerate(chunks)
            ]

        # The workers catch exceptions raised inside their own try block, but
        # anything raised outside it is only stored on the future, so report it.
        for chunk_num, future in enumerate(futures):
            failure = future.exception()
            if failure is not None:
                logger.error(f"Chunk {chunk_num} of {resource} failed: {failure}")

        # Log a completion message after all records have been processed
        logger.info(f"All {resource} calls completed.")
        del chunks
    
    def deleteEntityById(self, resource, id):
        """Send a DELETE request for a single entity.

        Args:
            resource: Resource path appended to ``self.dataManagementUrl``.
            id: Identifier appended to the resource path.

        Returns:
            The response object, or ``None`` when the request raised (the
            exception is logged rather than propagated).
        """
        try:
            target_url = f"{self.dataManagementUrl}{resource}/{id}"
            response = requests.delete(
                target_url,
                headers={"Authorization": f"Bearer {self.getAccessToken()}"},
            )
            status = response.status_code
            if status == 404:
                logger.info("RESOURCE NOT FOUND")
            elif status < 400:
                logger.info(f"RESOURCE DELETED - {id}")
            return response
        except Exception as error:
            logger.error(error)
    
    def filter_columns(self, df, column_list):
        """Select the columns of ``column_list`` that exist in ``df``.

        Args:
            df: Object exposing ``columns`` and ``select``.
            column_list: Candidate column names, in the desired output order.

        Returns:
            The result of ``df.select`` on the names from ``column_list``
            found in ``df.columns``, in ``column_list`` order.
        """
        # TODO: USE the one present in the class named SAP Utilities instead
        # Read df.columns once instead of once per candidate name, and avoid
        # shadowing the imported ``col`` function with the loop variable.
        available_columns = df.columns
        existing_columns = [name for name in column_list if name in available_columns]
        return df.select(existing_columns)