In [1]:
import ast
import re
from dataclasses import dataclass
from datetime import datetime

from delta.tables import DeltaTable
from pyspark.sql import DataFrame
from pyspark.sql.functions import explode, lit, col
from pyspark.sql.utils import AnalysisException


def str_to_snake_case(name: str) -> str:
    """Convert a camelCase or PascalCase identifier to snake_case.

    An underscore is inserted in front of every ASCII uppercase letter that
    is not the first character, then the whole string is lowercased. Each
    capital is split off individually, so ``"HTTPCode"`` becomes
    ``"h_t_t_p_code"``.

    Args:
        name: The identifier to convert.

    Returns:
        The snake_case form of *name*.
    """
    pieces = []
    for position, char in enumerate(name):
        # Only ASCII A-Z trigger a split, and never at the start of the name.
        if position > 0 and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()

def snake_case_columns(data: any) -> any:
    """Return *data* with its column names converted to snake_case.

    Args:
        data: Either a DataFrame, whose columns are re-selected under
            snake_case aliases, or a list of column names.

    Returns:
        A DataFrame with the renamed columns when a DataFrame was given,
        otherwise a new list holding the snake_case names.

    Raises:
        TypeError: If *data* is neither a DataFrame nor a list.
    """
    if not isinstance(data, (DataFrame, list)):
        raise TypeError("Input should be a DataFrame or a list of column names")

    if isinstance(data, list):
        return [str_to_snake_case(column_name) for column_name in data]

    renamed_columns = [
        data[column_name].alias(str_to_snake_case(column_name))
        for column_name in data.columns
    ]
    return data.select(renamed_columns)

def table_exists(table_name: str) -> bool:
    """Report whether the Spark catalog knows a table called *table_name*.

    The lookup goes through the catalog's private ``_jcatalog`` handle on the
    global ``spark`` session. Any exception raised by the lookup is treated
    as "table does not exist".

    Args:
        table_name: Name of the table to look up.

    Returns:
        The catalog's answer, or False when the lookup raised.
    """
    try:
        exists = spark.catalog._jcatalog.tableExists(table_name)
    except Exception:
        # Best-effort check: a failed lookup is reported as a missing table.
        exists = False
    return exists


def explode_df_columns(df: DataFrame, column_to_explode: str) -> DataFrame:
    """Explode one column and promote its nested fields to top-level columns.

    The column is exploded in place, after which only the fields of the
    exploded value are selected, each aliased to its own field name. Other
    columns of *df* are not part of the result.

    Args:
        df: The DataFrame holding the column.
        column_to_explode: Name of the column to explode. Its exploded
            elements must expose ``fields`` on their data type (i.e. be
            structs).

    Returns:
        A DataFrame with one column per field of the exploded elements.
    """
    exploded = df.withColumn(column_to_explode, explode(col(column_to_explode)))

    # The schema is read after exploding, so this is the element type.
    element_type = exploded.schema[column_to_explode].dataType
    field_names = [struct_field.name for struct_field in element_type.fields]

    promoted = [
        col(f"{column_to_explode}.{field_name}").alias(field_name)
        for field_name in field_names
    ]
    return exploded.select(*promoted)


@dataclass
class MetadataFetcher:
    """Look up configuration values for one entity in the ``lh_metadata`` schema.

    All queries run through a ``spark`` session that must already exist as a
    global name when the methods are called.
    """

    # Identifier of the entity whose metadata is looked up.
    entity_id: int

    @staticmethod
    def _build_query(table,
                     value_column,
                     filter_column=None,
                     filter_value=None,
                     parameter_name=None) -> str:
        """Assemble the SELECT statement executed by ``fetch``.

        Conditions are collected first and joined afterwards, so the
        statement is valid whichever combination of filters is supplied.

        NOTE: identifiers and values are interpolated into the SQL text
        verbatim (no quoting or escaping of ``filter_value``), so callers
        must only pass trusted values.
        """
        conditions = []
        # Compare against None so that a legitimate filter value of 0 is
        # not mistaken for "no filter".
        if filter_column and filter_value is not None:
            conditions.append(f"{filter_column} = {filter_value}")
        if parameter_name:
            conditions.append(f"parameter_name = '{parameter_name}'")

        query = f"SELECT {value_column} FROM lh_metadata.{table}"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        return query

    def fetch(self,
              table,
              value_column,
              filter_column=None,
              filter_value=None,
              parameter_name=None) -> str:
        """Return the first value of *value_column* matching the filters.

        Args:
            table: Table name inside the ``lh_metadata`` schema.
            value_column: Column whose value is returned.
            filter_column: Optional column for an equality filter; applied
                only when *filter_value* is also given.
            filter_value: Value compared against *filter_column*.
            parameter_name: Optional value required in the
                ``parameter_name`` column.

        Returns:
            The first column of the first row returned by the query, as
            delivered by Spark.

        Raises:
            ValueError: If the query returns no rows.
            Exception: Any error raised while running the query is printed
                and re-raised unchanged.
        """
        query = self._build_query(table, value_column, filter_column,
                                  filter_value, parameter_name)
        try:
            rows = spark.sql(query).collect()
        except Exception as e:
            print(f"Error executing query: {e}")
            raise

        # Raised outside the try block so an empty result is not reported
        # as a query execution error.
        if not rows:
            raise ValueError(f"No value found for:\n"
                             f"  table: {table}\n"
                             f"  value_column: {value_column}\n"
                             f"  filter_column: {filter_column}\n"
                             f"  filter_value: {filter_value}\n"
                             f"  parameter_name: {parameter_name}")
        return rows[0][0]

    def _fetch_or_report(self, description: str, *fetch_args):
        """Call ``fetch`` and print which lookup failed before re-raising."""
        try:
            return self.fetch(*fetch_args)
        except Exception as e:
            print(f"Failed to fetch {description}: {e}")
            raise

    def _fetch_entity_parameter(self, description: str, parameter_name: str):
        """Fetch ``parameter_value`` of one named parameter of this entity."""
        return self._fetch_or_report(description, "entity_parameters",
                                     "parameter_value", "entity_id",
                                     self.entity_id, parameter_name)

    def fetch_file_format(self) -> str:
        """Return the entity's ``sink_file_format`` parameter value."""
        return self._fetch_entity_parameter("file format", "sink_file_format")

    def fetch_column_to_explode(self) -> str:
        """Return the entity's ``column_to_explode`` parameter value."""
        return self._fetch_entity_parameter("column to explode",
                                            "column_to_explode")

    def fetch_entity_name(self) -> str:
        """Return ``name`` from the ``entity`` row of this entity."""
        return self._fetch_or_report("entity name", "entity", "name", "id",
                                     self.entity_id)

    def fetch_source_id(self) -> int:
        """Return ``source_id`` from the ``entity`` row of this entity."""
        return self._fetch_or_report("source ID", "entity", "source_id", "id",
                                     self.entity_id)

    def fetch_source_name(self) -> str:
        """Return ``name`` of the ``source`` row this entity points to.

        Two lookups are made: the entity's source id first, then the source
        name for that id. A failure in either one is reported here as well.
        """
        try:
            source_id = self.fetch_source_id()
            return self.fetch("source", "name", "id", source_id)
        except Exception as e:
            print(f"Failed to fetch source name: {e}")
            raise

    def fetch_processing_type(self) -> str:
        """Return ``processing_type`` from the ``entity`` row of this entity."""
        return self._fetch_or_report("processing type", "entity",
                                     "processing_type", "id", self.entity_id)

    def fetch_key_columns(self) -> str:
        """Return the entity's ``key_columns`` parameter value, unparsed."""
        return self._fetch_entity_parameter("key columns", "key_columns")


@dataclass
class FileProcessor(MetadataFetcher):
    """Load one source file into the bronze and silver tables of an entity.

    Reader selection, target table names, key columns and processing type
    are all looked up through the inherited metadata fetch methods. The
    class relies on a ``spark`` session being available as a global name.
    """

    # Path of the file to ingest; handed directly to the Spark reader.
    file_path: str

    def __post_init__(self):
        """Register the Spark reader used for each supported file format."""
        self.format_readers = {
            'json': spark.read.json,
            'parquet': spark.read.parquet
            # more formats if needed
        }

    def _key_column_list(self) -> list:
        """Fetch the ``key_columns`` parameter and parse it into a list.

        The stored text is parsed with ``ast.literal_eval`` instead of
        ``eval`` so that content of the metadata table can only ever yield
        a literal value and is never executed as code.
        NOTE(review): presumably the parameter holds a list literal such as
        ``"['id', 'code']"`` - confirm against the metadata table.

        Returns:
            The key column names as a list.
        """
        parsed = ast.literal_eval(self.fetch_key_columns())
        # A lone string names a single column; do not split it into letters.
        if isinstance(parsed, str):
            return [parsed]
        return list(parsed)

    def insert_processing_log(self, layer: str, status: str):
        """
        Inserts a processing log entry into the processing_logs table.

        Logging is best-effort: a failure to write the entry is printed and
        not raised, so it never hides the outcome of the processing step
        that is being logged.

        Args:
            layer (str): The processing layer ('bronze' or 'silver').
            status (str): The processing status.
        """
        log_data = [(self.entity_id, layer, self.file_path, datetime.now(),
                     status)]

        try:
            df = spark.createDataFrame(log_data,
                                       schema=[
                                           "entity_id", "layer", "file_path",
                                           "processing_time", "status"
                                       ])
            df = df.withColumn("entity_id", col("entity_id").cast("integer"))
            df.write.mode("append").saveAsTable("lh_logging.processing_logs")
            print(
                "Log entry successfully written to lh_logging.processing_logs")
        except AnalysisException as ae:
            print(f"AnalysisException: {ae}")
        except Exception as e:
            print(f"An error occurred while writing the log entry: {e}")

    def read_file_to_dataframe(self) -> DataFrame:
        """
        Reads the file into a DataFrame. Supports JSON and Parquet formats.

        For JSON files the column named by the entity's column_to_explode
        parameter is exploded and flattened after reading.

        Returns:
            DataFrame: The data read from the file.

        Raises:
            ValueError: If the file format is unsupported.
        """
        file_format = self.fetch_file_format()
        reader = self.format_readers.get(file_format)
        if reader is None:
            raise ValueError(f"Unsupported file format: {file_format}")

        df = reader(self.file_path)
        if file_format == "json":
            df = explode_df_columns(df, self.fetch_column_to_explode())
        return df

    def business_key_duplicates(self, df: DataFrame) -> bool:
        """
        Checks for duplicates based on business key columns.

        The number of distinct key combinations is compared with the total
        row count; any difference means at least one key occurs twice.

        Args:
            df (DataFrame): The DataFrame to check for duplicates.

        Returns:
            bool: True if duplicates are found, False otherwise.
        """
        key_columns = self._key_column_list()
        return df.select(*key_columns).distinct().count() != df.count()

    def _merge_to_table(self, df: DataFrame, table_name: str,
                        key_columns: list):
        """
        Merges the DataFrame into the specified table using the key columns.

        When the table exists, matched rows are updated and unmatched rows
        are inserted. When it does not exist, the DataFrame is saved as a
        new table under that name.

        Args:
            df (DataFrame): The DataFrame to merge.
            table_name (str): The name of the table to merge into.
            key_columns (list): The key columns for merging.

        Raises:
            Exception: If an error occurs during the merge; the original
                error is attached as the cause.
        """
        try:
            if table_exists(table_name):
                target_table = DeltaTable.forName(spark, table_name)
                merge_condition = " AND ".join(
                    f"tgt.{key} = src.{key}" for key in key_columns)
                target_table.alias("tgt").merge(
                    df.alias("src"), merge_condition).whenMatchedUpdateAll(
                    ).whenNotMatchedInsertAll().execute()
            else:
                df.write.saveAsTable(table_name)
        except Exception as e:
            raise Exception(
                f"Error merging to table {table_name}: {e}") from e

    def write_file_to_bronze(self):
        """
        Reads the file into a DataFrame, adds metadata columns, and writes it to the bronze table.

        The added columns are _processed_datetime, _source_name and
        _file_path. Any error is printed and recorded as a failure entry in
        the processing log; it is not re-raised.
        """
        try:
            df = self.read_file_to_dataframe()
            df = df.withColumn("_processed_datetime", lit(datetime.now()))
            df = df.withColumn("_source_name", lit(self.fetch_source_name()))
            df = df.withColumn("_file_path", lit(self.file_path))

            entity_name = self.fetch_entity_name()

            # Write to bronze table
            df.write.mode("append").saveAsTable(f"lh_bronze.{entity_name}")

            self.insert_processing_log('bronze', 'success')
        except Exception as e:
            print(f"Error writing to bronze: {e}")
            self.insert_processing_log('bronze', f"failure: {str(e)}")

    def write_file_to_silver(self):
        """
        Reads the file into a DataFrame, applies quality assurance checks, and writes it to the silver table.

        Column names are converted to snake_case and exact duplicate rows
        are dropped before the business key check. The entity's processing
        type decides between overwriting the table ("rebuild") and merging
        on the key columns ("merge").

        Raises:
            Exception: If an error occurs during processing or writing. The
                failure is recorded in the processing log before re-raising.
        """
        try:
            df = self.read_file_to_dataframe()
            df = snake_case_columns(df).dropDuplicates()

            if self.business_key_duplicates(df):
                raise ValueError("Business key duplicates found")

            processing_type = self.fetch_processing_type()
            entity_name = self.fetch_entity_name()
            key_columns = self._key_column_list()

            silver_table = f"lh_silver.{entity_name}"
            if processing_type == "rebuild":
                df.write.mode("overwrite").saveAsTable(silver_table)
            elif processing_type == "merge":
                self._merge_to_table(df, silver_table, key_columns)
            else:
                raise ValueError(
                    f"Unsupported processing type: {processing_type}")

            self.insert_processing_log("silver", "success")
        except Exception as e:
            self.insert_processing_log("silver", f"failure: {str(e)}")
            raise


StatementMeta(, f0c4d403-0739-428d-8f7b-24986363e442, 3, Finished, Available, Finished)