# In [2]:
import requests
import pandas as pd
import json
import datetime
from datetime import datetime, timedelta
import os
import re
from delta.tables import DeltaTable
import urllib
from urllib.parse import urlparse
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.functions import concat, lit, monotonically_increasing_id, expr,input_file_name, trim, lower
from pyspark.sql.functions import coalesce, col, when, expr, format_number, avg, count, sum
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DecimalType, ArrayType
from pyspark.sql.utils import AnalysisException
import warnings
import traceback

class OEAI:
    """
    The Open Education AI (OEAI) class provides a suite of utilities for handling and manipulating
    data within a lakehouse architecture. It bundles common ingestion chores: flattening nested
    JSON API responses, aligning DataFrame schemas, chunking date ranges for incremental API
    queries, persisting audit/error logs, and retrieving secrets.

    The class offers methods for:
    - Flattening nested JSON data into a simpler tabular format.
    - Matching and aligning DataFrame column types across different DataFrames.
    - Retrieving secrets from Azure Key Vault (Synapse TokenLibrary or Fabric mssparkutils).
    - Adding missing columns to DataFrames and creating structures with null values.
    - Dynamically modifying DataFrame schemas based on specified mappings.
    - Working with date ranges and generating date chunks for batch processing.
    - Reading and writing JSON audit logs and appending entries to an error log.

    Platform-specific code paths are selected by the ``platform`` constructor argument:
    ``"Fabric"`` uses ``mssparkutils`` (not imported in this module; presumably provided by the
    notebook runtime), while any other value uses the Synapse / Hadoop APIs reached through the
    Spark JVM gateway.

    Usage:
        The OEAI class is instantiated and its methods are called with the necessary
        parameters, typically involving Spark DataFrames and other PySpark constructs.

    Example:
        ```
        oeai = OEAI()
        flattened_data = oeai.flatten_nested_json(json_data)
        updated_df = oeai.match_column_types(df1, df2)
        secret_value = oeai.get_secret(spark, "mySecretName", "myKeyvaultLinkedService", "myKeyvault")
        ```
    """

    # Keys whose values flatten_json passes through parse_date instead of recursing into.
    _FLATTEN_DATE_FIELDS = frozenset({"achievement_date", "recorded_date", "created_at", "updated_at"})

    def __init__(self, timezone="UTC", platform="Synapse"):
        """
        Create the helper and configure the active Spark session.

        Args:
            timezone (str, optional): Value assigned to ``spark.sql.session.timeZone``.
                Defaults to "UTC".
            platform (str, optional): Runtime platform. ``"Fabric"`` selects the
                ``mssparkutils`` code paths; any other value selects the Synapse/Hadoop
                code paths. Defaults to "Synapse".
        """
        # Reuse (or create) the Spark session for this application.
        self.spark = SparkSession.builder.appName("oeaiSpark").getOrCreate()

        # Session-level setting: affects every time-zone-sensitive operation on this session.
        self.spark.conf.set("spark.sql.session.timeZone", timezone)

        # Note: this changes the process-wide warnings filter, not only this class.
        warnings.filterwarnings("ignore", category=FutureWarning)

        # Selects relevant code blocks. Defaults to "Synapse".
        self.platform = platform

    # ------------------------------------------------------------------ #
    # Audit log helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _audit_log_schema():
        """Return the all-string schema shared by every audit log read and write."""
        return StructType([
            StructField("school_id", StringType(), True),
            StructField("endpoint", StringType(), True),
            StructField("query", StringType(), True),
            StructField("start_time", StringType(), True),
            StructField("end_time", StringType(), True),
            StructField("duration", StringType(), True),
            StructField("records_returned", StringType(), True)
        ])

    @staticmethod
    def _hadoop_path_exists(spark, file_path):
        """
        Check whether ``file_path`` exists using the Hadoop FileSystem API.

        The FileSystem is resolved from the path itself (``Path.getFileSystem``) rather than
        from the cluster default (``FileSystem.get``). Fully-qualified paths that live on a
        different filesystem/account than the default are therefore checked against the
        right filesystem instead of failing with a "Wrong FS" error. Scheme-less paths still
        resolve to the default filesystem.
        """
        path = spark._jvm.org.apache.hadoop.fs.Path(file_path)
        fs = path.getFileSystem(spark._jsc.hadoopConfiguration())
        return fs.exists(path)

    def _read_audit_log(self, spark, audit_log_file, exists):
        """
        Read the audit log into a list of dicts, or return [] when ``exists`` is false.

        ``collect()`` pulls every entry to the driver. This assumes the audit log stays
        small enough to fit in driver memory.
        """
        if not exists:
            return []
        audit_log_df = spark.read.schema(self._audit_log_schema()).json(audit_log_file)
        return [row.asDict() for row in audit_log_df.collect()]

    def _write_audit_log(self, spark, audit_log, audit_log_file):
        """
        Write audit log entries to ``audit_log_file`` as JSON, replacing any existing output.

        Each entry is read by schema field name:
        - Missing keys become nulls.
        - Non-null values are converted with ``str()`` so they satisfy the all-string schema.
          For example, an integer ``records_returned`` no longer makes the whole write fail.
        - Keys that are not in the schema are ignored.
        """
        schema = self._audit_log_schema()
        field_names = schema.fieldNames()

        def _as_text(value):
            return None if value is None else str(value)

        rows = [tuple(_as_text(record.get(name)) for name in field_names) for record in audit_log]
        spark.createDataFrame(rows, schema).write.mode("overwrite").json(audit_log_file)

    def load_audit_log(self, spark, audit_log_file):
        """
        Load the audit log from a JSON file if it exists, otherwise return an empty list.

        Existence is checked with ``mssparkutils.fs.exists`` when ``self.platform`` is
        "Fabric", and with the Hadoop FileSystem API otherwise.

        Args:
            spark (SparkSession): Active Spark session for file system operations.
            audit_log_file (str): Path to the audit log file.

        Returns:
            list: A list of dictionaries representing the audit log entries. Returns [] when
            the file is absent or any error occurs (the error is printed, not raised).
        """
        try:
            if self.platform == "Fabric":
                exists = mssparkutils.fs.exists(audit_log_file)
            else:
                exists = self._hadoop_path_exists(spark, audit_log_file)
            return self._read_audit_log(spark, audit_log_file, exists)
        except Exception as e:
            # Best-effort: a failed read is reported and treated as an empty log.
            print(f"An error occurred while loading the audit log: {e}")
            return []

    def save_audit_log(self, spark, audit_log, audit_log_file):
        """
        Save the audit log to a JSON file.

        Args:
            spark (SparkSession): Active Spark session for file operations.
            audit_log (list): List of audit log entries, each as a dictionary. Missing keys
                are written as null; non-string values are stored via ``str()``.
            audit_log_file (str): Path to save the audit log file.

        Note:
            This function overwrites the existing output at the audit_log_file path. Errors
            are printed, not raised.
        """
        try:
            self._write_audit_log(spark, audit_log, audit_log_file)
        except Exception as e:
            # Handle potential exceptions (consider using a logging framework)
            print(f"An error occurred while saving the audit log: {e}")

    def save_empty_json(self, spark, file_path):
        """
        Write an empty DataFrame (zero records) as JSON to ``file_path``.

        The DataFrame has a single nullable string column, "message". Existing output at
        ``file_path`` is overwritten. NOTE(review): which files Spark produces for an empty
        DataFrame (e.g. only a ``_SUCCESS`` marker vs. an empty part file) depends on the
        Spark version. Confirm that downstream readers handle it.

        Args:
            spark (SparkSession): Active Spark session for file operations.
            file_path (str): The path where the empty JSON output will be saved.

        Note:
            Errors are printed, not raised.
        """
        # A schema needs at least one field even when there are no rows.
        schema = StructType([
            StructField("message", StringType(), True)
        ])

        try:
            empty_df = spark.createDataFrame([], schema)
            empty_df.write.mode("overwrite").json(file_path)
        except Exception as e:
            # Handle potential exceptions (consider using a logging framework)
            print(f"An error occurred while saving the empty JSON file: {e}")

    # ------------------------------------------------------------------ #
    # JSON / dictionary helpers
    # ------------------------------------------------------------------ #

    def parse_date(self, date_dict):
        """
        Return ``date_dict['date']`` when ``date_dict`` is a dict containing 'date'.

        Args:
            date_dict: Any value; only a dict with a 'date' key yields a result.

        Returns:
            The value under 'date', or None for anything else (including plain strings).
        """
        if isinstance(date_dict, dict) and 'date' in date_dict:
            return date_dict['date']
        return None

    def flatten_json(self, y):
        """
        Flatten a nested JSON structure into a flat dictionary.

        Nested keys are joined with underscores; list elements use their index as the key
        segment (e.g. ``{"a": [{"b": 1}]}`` -> ``{"a_0_b": 1}``). Keys listed in
        ``_FLATTEN_DATE_FIELDS`` are not recursed into. Their value is replaced by
        ``parse_date(value)``, which is None unless the value is a dict with a 'date' key.

        Args:
            y (dict or list): The JSON object (nested dictionary or list) to be flattened.

        Returns:
            dict: A flat dictionary with all nested keys concatenated by underscores.
        """
        out = {}

        def flatten(x, name=''):
            if isinstance(x, dict):
                for key, value in x.items():
                    if key in self._FLATTEN_DATE_FIELDS:
                        out[f"{name}{key}"] = self.parse_date(value)
                    else:
                        flatten(value, f"{name}{key}_")
            elif isinstance(x, list):
                for index, element in enumerate(x):
                    flatten(element, f"{name}{index}_")
            else:
                # Leaf: drop the trailing "_" added by the parent level.
                out[name[:-1]] = x

        flatten(y)
        return out

    # ------------------------------------------------------------------ #
    # Date range / query helpers
    # ------------------------------------------------------------------ #

    def generate_date_chunks(self, start_date, end_date, chunk_size=timedelta(weeks=2)):
        """
        Generate consecutive date ranges from start_date to end_date in chunks.

        The last chunk is truncated so that it ends exactly at ``end_date``. Nothing is
        yielded when ``start_date >= end_date``.

        Args:
            start_date (datetime): The start date of the range.
            end_date (datetime): The end date of the range.
            chunk_size (timedelta, optional): The size of each date chunk. Defaults to two weeks.

        Yields:
            tuple: A tuple containing the start and end date of each chunk.

        Raises:
            ValueError: If ``chunk_size`` does not advance the date (zero or negative).
                Without this check the loop would never terminate.
        """
        while start_date < end_date:
            chunk_end_date = min(start_date + chunk_size, end_date)
            if chunk_end_date <= start_date:
                raise ValueError("chunk_size must be a positive duration")
            yield (start_date, chunk_end_date)
            start_date = chunk_end_date

    def update_query_with_chunks(self, original_query, start_date, end_date):
        """
        Return ``original_query`` with its 'updated_after' / 'updated_before' window replaced.

        Any existing ``updated_after`` and ``updated_before`` parameters are removed. Fresh
        ones are then appended, joined to the remaining parameters with '&'. All other
        parameters keep their original order.

        The prefix is handled as follows:
        - Anything before the first '?' (e.g. an endpoint path) is preserved unchanged.
        - The parameter list starts with '?' when the original contained a '?' or still
          has other parameters.
        - An empty query, or one that held only the date parameters without a '?', yields
          just ``"updated_after=...&updated_before=..."``.

        Args:
            original_query (str): The original query string (None is treated as "").
            start_date (datetime): The start date for the 'updated_after' parameter.
            end_date (datetime): The end date for the 'updated_before' parameter.

        Returns:
            str: The updated query string, e.g.
            ``"?limit=100&updated_after=2024-01-01T00:00:00Z&updated_before=2024-01-15T00:00:00Z"``.
        """
        replaced_params = ("updated_after", "updated_before")
        original_query = original_query or ""

        if "?" in original_query:
            base, _, query_string = original_query.partition("?")
            has_question_mark = True
        else:
            base, query_string, has_question_mark = "", original_query, False

        # Keep every non-empty parameter except the ones being replaced. Matching on the
        # exact key avoids clobbering look-alikes such as "last_updated_after=...".
        params = [
            param
            for param in query_string.split("&")
            if param and param.split("=", 1)[0] not in replaced_params
        ]
        prefix = "?" if (has_question_mark or params) else ""

        # The trailing "Z" is literal and no time zone conversion happens here.
        # NOTE(review): presumably callers pass UTC datetimes; confirm.
        formatted_start_date = start_date.strftime('%Y-%m-%dT%H:%M:%SZ')
        formatted_end_date = end_date.strftime('%Y-%m-%dT%H:%M:%SZ')
        params.append(f"updated_after={formatted_start_date}")
        params.append(f"updated_before={formatted_end_date}")

        return f"{base}{prefix}{'&'.join(params)}"

    # ------------------------------------------------------------------ #
    # Secrets
    # ------------------------------------------------------------------ #

    def get_secret(self, spark, secret_name, keyvault_linked_service, keyvault):
        """
        Retrieve a specified secret from Azure Key Vault.

        On "Fabric" the secret is read with ``mssparkutils.credentials.getSecret`` and
        ``keyvault_linked_service`` is unused. Otherwise Synapse's TokenLibrary is called
        through the Spark JVM gateway.

        Args:
            spark (SparkSession): The SparkSession object.
            secret_name (str): The name of the secret to retrieve.
            keyvault_linked_service (str): The name of the Azure Synapse Analytics linked service.
            keyvault (str): The name of the Azure Key Vault.

        Returns:
            str: The value of the retrieved secret.
        """
        if self.platform == "Fabric":
            return mssparkutils.credentials.getSecret(keyvault, secret_name)

        token_library = spark._jvm.com.microsoft.azure.synapse.tokenlibrary.TokenLibrary
        return token_library.getSecret(keyvault, secret_name, keyvault_linked_service)

    # ------------------------------------------------------------------ #
    # Generic dictionary helpers
    # ------------------------------------------------------------------ #

    def row_to_dict(self, row_obj):
        """
        Convert a PySpark Row object into a Python dictionary.

        Args:
            row_obj (Row): The Row object to be converted.

        Returns:
            dict: A new dictionary built from ``row_obj.asDict()``.
        """
        return dict(row_obj.asDict())

    def safe_get(self, dct, *keys):
        """
        Safely retrieve a value from a nested dict/list structure using a sequence of keys.

        Args:
            dct (dict): The dictionary from which to retrieve the value.
            *keys: A sequence of keys (or list indices) to traverse.

        Returns:
            The value found at the nested key path, or None if any key or index is missing,
            or an intermediate value cannot be indexed.
        """
        current = dct
        for key in keys:
            try:
                current = current[key]
            except (TypeError, KeyError, IndexError):
                return None
        return current

    def safe_get_or_create(self, dct, default_value, *keys):
        """
        Retrieve a value from a nested dictionary, creating the path if it does not exist.

        Intermediate keys that are missing, or whose value is not a dict, are (re)set to
        an empty dict. This overwrites non-dict values. The final key is set to
        ``default_value`` only when absent. ``dct`` is modified in place.

        Args:
            dct (dict): The dictionary to traverse or modify.
            default_value: The value to set if the final key does not exist.
            *keys: A sequence of keys representing the path in the nested dictionary.
                At least one key is required.

        Returns:
            The value found or set at the nested key path.
        """
        for key in keys[:-1]:
            if key not in dct or not isinstance(dct[key], dict):
                dct[key] = {}
            dct = dct[key]

        final_key = keys[-1]
        if final_key not in dct:
            dct[final_key] = default_value
        return dct[final_key]

    # ------------------------------------------------------------------ #
    # Students JSON flattening
    # ------------------------------------------------------------------ #

    def flatten_nested_json(self, data):
        """
        Flatten a list of JSON records, expanding each record's embedded 'students' data.

        Each record is flattened one level with _flatten_dict, excluding 'students'.
        When a record has 'students' with a non-empty 'data' list, one output row is
        produced per student. Each row combines the record's fields with the student's
        fields (prefixed 'student_data_'); student fields win on key collisions.
        A record with no students (no 'students' key, or an empty or missing 'data' list)
        is emitted once with its own fields, so it is not silently dropped.

        Args:
            data (str | bytes | list | dict): JSON text, or already-parsed JSON. A single
                object is treated as a one-element list.

        Returns:
            list: A list of flattened dictionary objects.
        """
        if isinstance(data, (str, bytes, bytearray)):
            data = json.loads(data)
        if isinstance(data, dict):
            data = [data]

        output_data = []
        for item in data:
            base_info = self._flatten_dict(item, exclude_keys=['students'])
            students_info = self._flatten_students_data(item['students']) if 'students' in item else []

            if students_info:
                output_data.extend({**base_info, **student} for student in students_info)
            else:
                output_data.append(base_info)

        return output_data

    def _flatten_dict(self, dct, exclude_keys=None, prefix=None):
        """
        Flatten a dictionary by one level, optionally excluding keys and prefixing names.

        - Scalar (non-dict) values keep their key, or become ``{prefix}_{key}``.
        - Dict values are expanded to ``{key}_{subkey}``. When ``prefix`` is given they
          become ``{prefix}_{subkey}`` instead: the parent key is dropped, so equal subkeys
          under different parents overwrite each other.
        - Deeper nesting is not expanded further.

        Args:
            dct (dict): The dictionary to flatten.
            exclude_keys (list, optional): Keys to skip entirely.
            prefix (str, optional): A prefix to prepend to each key in the flattened dictionary.

        Returns:
            dict: A flattened dictionary.
        """
        excluded = exclude_keys or []
        flattened = {}
        for key, value in dct.items():
            if key in excluded:
                continue
            if isinstance(value, dict):
                head = prefix if prefix else key
                for subkey, subvalue in value.items():
                    flattened[f'{head}_{subkey}'] = subvalue
            else:
                flattened[f'{prefix}_{key}' if prefix else key] = value
        return flattened

    def _flatten_students_data(self, students_data):
        """
        Flatten the 'students' section of a record.

        Args:
            students_data (dict): The 'students' section, expected to hold a 'data' list.

        Returns:
            list: One flattened dict (keys prefixed 'student_data_') per student. Returns []
            when the section is not a dict or its 'data' is missing or empty.
        """
        if not isinstance(students_data, dict):
            return []
        return [
            self._flatten_dict(student_data, prefix='student_data')
            for student_data in students_data.get('data') or []
        ]

    # ------------------------------------------------------------------ #
    # DataFrame schema helpers
    # ------------------------------------------------------------------ #

    def apply_column_mappings(self, df, mappings):
        """
        Apply column mappings to a DataFrame: drop, rename, and add columns with defaults.

        ``mappings`` entries are interpreted as follows:
        - ``{"col": "drop"}`` drops ``col``.
        - ``{"col": {"new_name": "other"}}`` renames ``col`` to ``other``. Renames are
          applied sequentially, in mapping order.
        - ``{"add_columns": {"new_col": default}}`` adds ``new_col`` with a literal default.
          The ``add_columns`` entry itself is never treated as a rename.

        Args:
            df (DataFrame): The DataFrame to be modified.
            mappings (dict): A dictionary containing the mapping instructions.

        Returns:
            DataFrame: The modified DataFrame after applying the mappings.
        """
        drop_cols = [name for name, action in mappings.items() if action == "drop"]
        df = df.drop(*drop_cols)

        for old_name, details in mappings.items():
            if old_name != "add_columns" and isinstance(details, dict) and 'new_name' in details:
                df = df.withColumnRenamed(old_name, details['new_name'])

        for new_col, default_value in (mappings.get("add_columns") or {}).items():
            df = df.withColumn(new_col, lit(default_value))

        return df

    def add_missing_columns(self, df, columns):
        """
        Add any of ``columns`` that ``df`` lacks as null columns of StringType.

        Args:
            df (DataFrame): The DataFrame to which columns will be added.
            columns (list): A list of column names to be added if they are missing.

        Returns:
            DataFrame: The DataFrame with the missing columns added.
        """
        existing = set(df.columns)
        for name in columns:
            if name not in existing:
                df = df.withColumn(name, lit(None).cast(StringType()))
        return df

    @staticmethod
    def create_null_struct(fields):
        """
        Create a Spark SQL struct column whose value is null, typed by ``fields``.

        Declared as a staticmethod so that both ``OEAI.create_null_struct(fields)`` and
        ``instance.create_null_struct(fields)`` work. Without the decorator, the
        instance call would pass ``self`` as ``fields``.

        Args:
            fields (list of StructField): A list of StructField objects defining the schema of the struct.

        Returns:
            Column: A null literal cast to ``StructType(fields)``.
        """
        return lit(None).cast(StructType(fields))

    def get_uuid_column_name(self, delta_table_name):
        """
        Derive the UUID column name for a delta table.

        Every occurrence of "dim_" and "fact_" is removed (case-sensitive, and not only as
        a prefix). The result is lower-cased and "key" is appended,
        e.g. "dim_Student" -> "studentkey".

        Args:
            delta_table_name (str): The name of the delta table.

        Returns:
            str: The standardized UUID column name.
        """
        base_name = delta_table_name.replace("dim_", "").replace("fact_", "").lower()
        return f"{base_name}key"

    def match_column_types(self, df1, df2):
        """
        Cast df2's columns to the data types of the same-named columns in df1.

        - Columns present in both DataFrames are cast to df1's type.
        - When both sides are StructType, or both are ArrayType, the column is left
          unchanged (nested types are not reconciled).
        - Columns in df1 but missing from df2 are added to df2 as typed nulls.
        - Extra columns in df2 are kept as they are.

        Args:
            df1 (DataFrame): The DataFrame with the desired column types.
            df2 (DataFrame): The DataFrame to be modified to match df1's column types.

        Returns:
            DataFrame: The updated df2.
        """
        for col_name in df1.columns:
            target_type = df1.schema[col_name].dataType

            if col_name not in df2.columns:
                df2 = df2.withColumn(col_name, lit(None).cast(target_type))
                continue

            current_type = df2.schema[col_name].dataType
            both_structs = isinstance(target_type, StructType) and isinstance(current_type, StructType)
            both_arrays = isinstance(target_type, ArrayType) and isinstance(current_type, ArrayType)
            if not (both_structs or both_arrays):
                df2 = df2.withColumn(col_name, col(col_name).cast(target_type))

        return df2

    # ------------------------------------------------------------------ #
    # Error log and "silver" audit log
    # ------------------------------------------------------------------ #

    def log_error(self, spark, message, error_log_path):
        """
        Append an error entry (timestamp + message) as JSON to ``error_log_path``.

        The timestamp comes from ``datetime.now()`` (the driver's local clock) and is
        formatted as "%Y-%m-%d %H:%M:%S". The message is stored as text: non-string
        values are converted with ``str()``, and None is stored as null.

        Args:
            spark (SparkSession): Active Spark session for file system operations.
            message: The error message to log.
            error_log_path (str): Path to the error log output.

        Returns:
            None. Failures are printed, not raised.
        """
        try:
            current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # An explicit schema replaces Row schema inference, which fails for a None message.
            schema = StructType([
                StructField("timestamp", StringType(), True),
                StructField("message", StringType(), True),
            ])
            error_df = spark.createDataFrame(
                [(current_time, None if message is None else str(message))], schema
            )

            # A single partition means each call appends one output file.
            error_df.repartition(1).write.mode("append").json(error_log_path)
        except Exception as e:
            print(f"An error occurred while logging the error: {e}")

    def load_audit_log_silver(self, spark, audit_log_file):
        """
        Load the audit log from a JSON file if it exists, otherwise return an empty list.

        Unlike load_audit_log, this always uses the Hadoop FileSystem API for the
        existence check, regardless of ``self.platform``.

        Args:
            spark (SparkSession): Active Spark session for file system operations.
            audit_log_file (str): Path to the audit log file.

        Returns:
            list: A list of dictionaries representing the audit log entries. Returns [] when
            the file is absent or any error occurs (the error is printed, not raised).
        """
        try:
            exists = self._hadoop_path_exists(spark, audit_log_file)
            return self._read_audit_log(spark, audit_log_file, exists)
        except Exception as e:
            print(f"An error occurred while loading the audit log: {e}")
            return []

    def save_audit_log_silver(self, spark, audit_log, audit_log_file):
        """
        Save the audit log to a JSON file.

        Args:
            spark (SparkSession): Active Spark session for file operations.
            audit_log (list): List of audit log entries, each as a dictionary. Missing keys
                are written as null; non-string values are stored via ``str()``.
            audit_log_file (str): Path to save the audit log file.

        Note:
            This function overwrites the existing output at the audit_log_file path. Errors
            are printed, not raised.
        """
        try:
            self._write_audit_log(spark, audit_log, audit_log_file)
        except Exception as e:
            # Handle potential exceptions (consider using a logging framework)
            print(f"An error occurred while saving the audit log: {e}")