In [None]:
%run ./Global_Configurations

In [None]:
class LayerUtils:
    """
    Static helpers for moving data between a DBFS layer directory and S3,
    and for logging ETL status rows to a Redshift tracker table.

    All methods are static. The class relies on names that this cell does not
    define itself (``SparkSession``, ``dbutils``, ``boto3``, ``datetime``,
    ``time``, the pyspark types/functions and configuration values such as
    ``bronze_layer_path``, ``bucket_name``, ``bucket_path`` and the AWS
    credentials) -- presumably provided by ``%run ./Global_Configurations``.
    """

    # Name of the most recent file written by write_to_layer(); also read by
    # read_from_bronze_layer().
    folder_name_for_dbfs = ""
    file_name_for_s3 = ""

    # NOTE(review): these look like placeholders. Real credentials should be
    # loaded from a secret store rather than being committed in source.
    redshift_url = "jdbc:redshift://your-cluster-name.region-name.redshift.amazonaws.com:port/database-name"
    etl_trakcer_tbl = 'gold.etl_tracker'
    # Correctly spelled alias; the misspelled name is kept for existing callers.
    etl_tracker_tbl = etl_trakcer_tbl
    redshift_user = "redshift-user"
    redshift_password = "redshift-password"
    redshift_driver = 'com.amazon.redshift.jdbc.Driver'

    # DynamoDB table counted by get_source_count_from_dynamodb().
    dynamodb_source_table = "tbl_healthcare_analytics_data_mini"

    max_retries = 3  # Number of attempts made by initialize_spark()
    retry_delay = 5  # Seconds to wait between attempts

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _split_s3_path(s3_url):
        """
        Split an S3 URL into ``(bucket, prefix)``.

        ``"s3://bucket/a/b"`` -> ``("bucket", "a/b")``. A value without a
        ``scheme://`` part is treated as ``bucket/prefix``.
        """
        remainder = s3_url.partition("://")[2] if "://" in s3_url else s3_url
        bucket, _, prefix = remainder.partition("/")
        return bucket, prefix

    @staticmethod
    def _reconciliation_status(source_count, destination_count):
        """Return "SKIPPED" if either count is -1, else "MATCH" / "MISMATCH"."""
        if source_count == -1 or destination_count == -1:
            return "SKIPPED"
        if source_count == destination_count:
            return "MATCH"
        return "MISMATCH"

    @staticmethod
    def _find_parquet_part(directory):
        """Return the path of the first ``*.parquet`` entry in ``directory``, or None."""
        for entry in dbutils.fs.ls(directory):
            if entry.name.endswith(".parquet"):
                return entry.path
        return None

    @staticmethod
    def _remove_quietly(path):
        """Recursively delete ``path``; report (but do not raise) cleanup errors."""
        try:
            dbutils.fs.rm(path, True)
        except Exception as e:
            print(f"Could not clean up temporary path {path}: {e}")

    # ------------------------------------------------------------------
    # Spark
    # ------------------------------------------------------------------

    @staticmethod
    def initialize_spark():
        """
        Return a Spark session, retrying up to ``max_retries`` times.

        Raises:
            Exception: if every attempt fails; the last error is chained as
                the cause.
        """
        last_error = None
        for attempt in range(1, LayerUtils.max_retries + 1):
            try:
                return SparkSession.builder.appName("DataExtractor").getOrCreate()
            except Exception as e:
                last_error = e
                print(f"Spark initialization failed (Attempt {attempt}/{LayerUtils.max_retries}): {str(e)}")
                # No point in sleeping after the final attempt.
                if attempt < LayerUtils.max_retries:
                    time.sleep(LayerUtils.retry_delay)
        raise Exception("Failed to initialize Spark after multiple attempts.") from last_error

    # ------------------------------------------------------------------
    # DBFS layer
    # ------------------------------------------------------------------

    @staticmethod
    def is_layer_empty(layer_path: str) -> bool:
        """
        Check whether a layer directory in DBFS is empty.

        Args:
            layer_path (str): The path to the layer directory.

        Returns:
            bool: True if listing the directory returns nothing or fails for
            any reason (e.g. the path does not exist); otherwise False.
        """
        try:
            return len(dbutils.fs.ls(layer_path)) == 0
        except Exception:
            # Deliberate best-effort: an unlistable path is treated as empty.
            return True

    @staticmethod
    def write_to_layer(data_df: "DataFrame", layer_path: str):
        """
        Write ``data_df`` into ``layer_path`` as a single named Parquet file.

        - If the layer is empty, the file is named ``full_extraction_mini.parquet``.
        - Otherwise it is named ``incremental_data_<timestamp>.parquet``.

        The chosen name is stored in ``LayerUtils.folder_name_for_dbfs``.

        Args:
            data_df: DataFrame to write.
            layer_path (str): Target layer directory.

        Raises:
            Exception: if Spark produced no ``*.parquet`` part file.
        """
        # Avoid "<layer>/_temp" (inside the layer) and "<layer>//file" when the
        # caller passes a trailing slash.
        base_path = layer_path.rstrip("/")

        if LayerUtils.is_layer_empty(layer_path):
            LayerUtils.folder_name_for_dbfs = "full_extraction_mini.parquet"
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            LayerUtils.folder_name_for_dbfs = f"incremental_data_{timestamp}.parquet"

        temp_output_path = f"{base_path}_temp"
        destination_path = f"{base_path}/{LayerUtils.folder_name_for_dbfs}"

        try:
            # Spark writes a directory of part files; coalesce(1) yields a
            # single part file that can then be copied under the final name.
            data_df.coalesce(1).write.mode("overwrite").parquet(temp_output_path)

            part_file = LayerUtils._find_parquet_part(temp_output_path)
            if part_file is None:
                raise Exception("No Parquet file found in the temporary DBFS location!")

            dbutils.fs.cp(part_file, destination_path)
        finally:
            # Always remove the temporary directory, even on failure.
            LayerUtils._remove_quietly(temp_output_path)

        print(f"Data written to layer as {LayerUtils.folder_name_for_dbfs}.")

    @staticmethod
    def read_from_bronze_layer():
        """
        Read Parquet data from the Bronze layer.

        Reads ``<bronze_layer_path>/<LayerUtils.folder_name_for_dbfs>``; the
        file name is whatever write_to_layer() stored last (an empty string
        until then, in which case the whole directory path is read).

        Returns:
            pyspark.sql.DataFrame: The data read from that path.
        """
        spark = LayerUtils.initialize_spark()
        return spark.read.parquet(f"{bronze_layer_path}/{LayerUtils.folder_name_for_dbfs}")

    # ------------------------------------------------------------------
    # S3
    # ------------------------------------------------------------------

    @staticmethod
    def is_s3_bucket_empty(bucket_name: str, directory: str = "") -> bool:
        """
        Check whether a prefix inside an S3 bucket holds no real objects.

        Keys ending in "/" (folder placeholders) are not counted as objects.
        With an empty ``directory`` the whole bucket is checked.
        """
        s3 = boto3.client('s3')
        # list_objects_v2 returns at most 1000 keys per call, so paginate.
        paginator = s3.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=bucket_name, Prefix=directory):
            for obj in page.get('Contents', []):
                if not obj['Key'].endswith("/"):
                    return False
        return True

    @staticmethod
    def get_latest_modified_at_from_s3(bucket_name, bucket_path):
        """
        Return the most recent ``LastModified`` among objects under a prefix.

        :param bucket_name: Name of the S3 bucket.
        :param bucket_path: Prefix or folder path in the S3 bucket.
        :return: Timestamp formatted as 'YYYY-MM-DD HH:MM:SS', or None if no
                 objects exist or the lookup fails.
        """
        try:
            s3_client = boto3.client('s3')
            paginator = s3_client.get_paginator('list_objects_v2')

            latest_modified_at = None
            for page in paginator.paginate(Bucket=bucket_name, Prefix=bucket_path):
                for obj in page.get('Contents', []):
                    if latest_modified_at is None or obj['LastModified'] > latest_modified_at:
                        latest_modified_at = obj['LastModified']

            if latest_modified_at is None:
                print(f"No files found in S3 path: s3://{bucket_name}/{bucket_path}")
                return None

            return latest_modified_at.strftime('%Y-%m-%d %H:%M:%S')

        except Exception as e:
            print(f"Error retrieving latest modified timestamp from S3: {e}")
            return None

    @staticmethod
    def write_to_s3(data_df: "DataFrame", bucket_name: str, s3_path: str):
        """
        Transfer ``data_df`` to S3 as a single named Parquet file.

        - Writes the data to a temporary DBFS directory.
        - Copies the generated part file to ``s3://<bucket_name>/<prefix>/<file>``.
        - Logs IN_PROGRESS, then SUCCESS or FAILED, via log_etl_status().

        If the target prefix is empty a full transfer is done; otherwise only
        rows whose ``modified_at`` is greater than the latest object
        ``LastModified`` under the prefix are transferred.

        Errors during the transfer are printed and logged as FAILED; they are
        not re-raised.

        Args:
            data_df: DataFrame to transfer.
            bucket_name (str): Target bucket.
            s3_path (str): Prefix inside the bucket, or a full ``s3://`` URL
                (only its prefix part is used). When empty, the global
                ``bucket_path`` is used.
        """
        source_layer = "bronze"
        destination_layer = "silver"
        source_table = ".parquet file"  # Update with actual Bronze table name
        destination_table = ".parquet file"  # Update with actual Silver table name

        # Resolve the key prefix from the argument (previously ignored).
        if s3_path and "://" in s3_path:
            prefix = LayerUtils._split_s3_path(s3_path)[1]
        else:
            prefix = s3_path if s3_path else bucket_path
        prefix = prefix.strip("/")
        # A trailing "/" keeps the listing from matching sibling prefixes
        # such as "<prefix>_old/".
        listing_prefix = f"{prefix}/" if prefix else ""
        destination_root = f"s3://{bucket_name}/{prefix}" if prefix else f"s3://{bucket_name}"

        spark = LayerUtils.initialize_spark()

        # NOTE(review): log_etl_status() only reconciles counts for the
        # "source"->"bronze" and "bronze"->"staging" pairs, so these rows are
        # logged with reconciliation "SKIPPED" -- confirm that is intended.
        LayerUtils.log_etl_status(
            spark, source_layer, destination_layer, source_table, destination_table,
            "IN_PROGRESS", dest_path=destination_root,
        )

        temp_dbfs_path = "dbfs:/tmp/silver_layer_temp"

        try:
            if LayerUtils.is_s3_bucket_empty(bucket_name, listing_prefix):
                print("Silver layer is empty. Performing full transfer...")
                output_filename = "full_extraction_mini.parquet"
            else:
                print("Silver layer is not empty. Performing incremental transfer...")
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                # Same naming scheme as write_to_layer().
                output_filename = f"incremental_data_{timestamp}.parquet"

                # NOTE(review): this is the S3 object LastModified time, not a
                # max over the data's modified_at column -- confirm intent.
                latest_modified_at_from_silver = LayerUtils.get_latest_modified_at_from_s3(
                    bucket_name, listing_prefix
                )
                if latest_modified_at_from_silver is None:
                    # Comparing against None would silently drop every row.
                    raise Exception("Could not determine the latest modified timestamp in the Silver layer.")

                data_df = data_df.filter(data_df["modified_at"] > latest_modified_at_from_silver)

            # Step 1: Write to a temporary location in DBFS
            data_df.coalesce(1).write.mode("overwrite").parquet(temp_dbfs_path)
            print(f"Data written temporarily to {temp_dbfs_path}")

            # Step 2: Locate the generated Parquet part file
            parquet_file = LayerUtils._find_parquet_part(temp_dbfs_path)
            if not parquet_file:
                raise Exception("No Parquet file found in the temporary DBFS location!")

            # Step 3: Copy the file to S3 under its final name
            final_s3_path = f"{destination_root}/{output_filename}"

            print()
            print(f"bucket_name: {bucket_name}")
            print(f"bucket_path: {prefix}")
            print(f"output_filename: {output_filename}")
            print()

            dbutils.fs.cp(parquet_file, final_s3_path)
            print(f"File moved to final destination: {final_s3_path}")

            LayerUtils.log_etl_status(
                spark, source_layer, destination_layer, source_table, destination_table,
                "SUCCESS", dest_path=destination_root,
            )

        except Exception as e:
            print(f"Error during transfer: {str(e)}")

            LayerUtils.log_etl_status(
                spark, source_layer, destination_layer, source_table, destination_table,
                "FAILED", dest_path=destination_root,
            )

        finally:
            # The temporary directory was previously left behind.
            LayerUtils._remove_quietly(temp_dbfs_path)

    # ------------------------------------------------------------------
    # Record counts
    # ------------------------------------------------------------------

    @staticmethod
    def get_source_count_from_dynamodb(source_table):
        """
        Count the items in the DynamoDB source table using a COUNT scan.

        NOTE(review): ``source_table`` is not used; the table scanned is
        always ``LayerUtils.dynamodb_source_table``. Confirm whether callers
        pass a real DynamoDB table name before wiring it in.

        Returns:
            int: Total item count, or -1 if the scan fails.
        """
        try:
            dynamo_db_client = boto3.client(
                "dynamodb",
                aws_access_key_id=aws_access_key,
                aws_secret_access_key=aws_secret_access_key,
                region_name=region_name,
            )

            # A single Scan call covers at most 1 MB of data; sum the count
            # of every page to get the full total.
            paginator = dynamo_db_client.get_paginator("scan")
            pages = paginator.paginate(
                TableName=LayerUtils.dynamodb_source_table,
                Select="COUNT",
            )
            return sum(page.get('Count', 0) for page in pages)

        except Exception as e:
            print(f"Error fetching count from DynamoDB: {e}")
            return -1  # Indicate error with -1

    @staticmethod
    def get_count_from_dbfs(self=None, bronze_table_path=None):
        """
        Count the records of the latest Parquet file in a DBFS directory.

        "Latest" is the file whose path sorts last in descending order.

        Args:
            self: Kept only for backward compatibility (this is a static
                method). If it is a string and ``bronze_table_path`` is not
                given, it is used as the path, so
                ``get_count_from_dbfs(path)`` works as callers expect.
            bronze_table_path (str, optional): Directory to read. Defaults to
                the global ``bronze_layer_path``.

        Returns:
            int: Record count, 0 if the directory is empty/missing or holds
            no data, or -1 on error.
        """
        try:
            if bronze_table_path is None:
                bronze_table_path = self if isinstance(self, str) else bronze_layer_path

            spark = LayerUtils.initialize_spark()

            if LayerUtils.is_layer_empty(bronze_table_path):
                print(f"Bronze layer is empty or does not exist: {bronze_table_path}")
                return 0

            # Read all Parquet files from the directory
            df = spark.read.format("parquet").load(bronze_table_path)

            if df.isEmpty():
                print(f"No data found in the Bronze layer path: {bronze_table_path}")
                return 0

            # Tag each row with the file it came from, then pick the path
            # that sorts last.
            df = df.withColumn("file_path", input_file_name())
            latest_file = (
                df.select("file_path")
                .distinct()
                .orderBy(desc("file_path"))
                .limit(1)
                .collect()
            )

            if not latest_file:
                print("No valid files found in the specified Bronze layer path.")
                return 0

            latest_file_path = latest_file[0]["file_path"]
            print(f"Latest Parquet file detected: {latest_file_path}")

            # Read only the latest file
            return spark.read.format("parquet").load(latest_file_path).count()

        except Exception as e:
            print(f"Error fetching count from DBFS: {e}")
            return -1

    @staticmethod
    def get_count_from_s3(self=None, staging_table_path=None):
        """
        Count records at an S3 path.

        If the data has a ``modified_at`` column, only rows whose
        ``modified_at`` equals the column's maximum are counted.

        NOTE(review): the path is read with format "delta", while
        write_to_s3() copies plain Parquet files -- confirm the format.

        Args:
            self: Kept only for backward compatibility (this is a static
                method). If it is a string and ``staging_table_path`` is not
                given, it is used as the path.
            staging_table_path (str, optional): ``s3://bucket/prefix`` URL.
                Defaults to ``s3://<bucket_name>/<bucket_path>`` from the
                global configuration.

        Returns:
            int | None: Record count, 0 if the prefix is empty, or None on
            error.
        """
        spark = LayerUtils.initialize_spark()

        try:
            if staging_table_path is None:
                if isinstance(self, str):
                    staging_table_path = self
                else:
                    staging_table_path = f"s3://{bucket_name}/{bucket_path}"

            target_bucket, directory = LayerUtils._split_s3_path(staging_table_path)

            if LayerUtils.is_s3_bucket_empty(target_bucket, directory):
                print(f"S3 bucket is empty or does not contain valid data: {staging_table_path}")
                return 0

            df = spark.read.format("delta").load(staging_table_path)

            if "modified_at" in df.columns:
                latest_timestamp = df.select(spark_max(col("modified_at"))).collect()[0][0]

                if latest_timestamp is not None:
                    df = df.filter(col("modified_at") == latest_timestamp)

            return df.count()

        except Exception as e:
            print(f"Error fetching count from S3: {e}")
            return None  # Indicate error

    # ------------------------------------------------------------------
    # ETL tracker
    # ------------------------------------------------------------------

    @staticmethod
    def log_etl_status(spark, source_layer, destination_layer, source_table, destination_table, status, source_path=None, dest_path=None):
        """
        Append one ETL status row, with count reconciliation, to the Redshift
        tracker table.

        Counts are only fetched for the "source"->"bronze" and
        "bronze"->"staging" layer pairs; any other pair is logged with both
        counts set to -1 and reconciliation "SKIPPED".

        Args:
            spark: Spark session to use. If it is not a SparkSession, one is
                obtained via initialize_spark().
            source_layer (str): Source layer name (e.g., "source", "bronze").
            destination_layer (str): Destination layer name (e.g., "bronze", "staging").
            source_table (str): Source table name.
            destination_table (str): Destination table name.
            status (str): Status of the ETL process (e.g., "SUCCESS", "FAILED").
            source_path (str, optional): Path to the source dataset (DBFS for bronze).
            dest_path (str, optional): Path to the destination dataset.
        """
        from datetime import timezone

        # Current UTC timestamp as a string
        utc_now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

        if not isinstance(spark, SparkSession):
            spark = LayerUtils.initialize_spark()

        if source_layer == "source" and destination_layer == "bronze":
            source_count = LayerUtils.get_source_count_from_dynamodb(source_table)
            # Bronze is the destination here; accept the path from either
            # argument since earlier code read it from source_path.
            bronze_path = dest_path if dest_path is not None else source_path
            destination_count = LayerUtils.get_count_from_dbfs(None, bronze_path)
        elif source_layer == "bronze" and destination_layer == "staging":
            source_count = LayerUtils.get_count_from_dbfs(None, source_path)
            destination_count = LayerUtils.get_count_from_s3(None, dest_path)
        else:
            source_count = None
            destination_count = None

        # Normalise "unknown" (None) to the -1 error marker
        source_count = source_count if source_count is not None else -1
        destination_count = destination_count if destination_count is not None else -1

        reconciliation_status = LayerUtils._reconciliation_status(source_count, destination_count)

        schema = StructType([
            StructField("source_layer", StringType(), False),
            StructField("destination_layer", StringType(), False),
            StructField("source_table", StringType(), False),
            StructField("destination_table", StringType(), False),
            StructField("status", StringType(), False),
            StructField("source_count", IntegerType(), True),
            StructField("destination_count", IntegerType(), True),
            StructField("reconciliation_status", StringType(), False),
            StructField("completion_timestamp", StringType(), False)
        ])

        etl_log_df = spark.createDataFrame(
            [(source_layer, destination_layer, source_table, destination_table, status, source_count, destination_count, reconciliation_status, utc_now)],
            schema
        )

        etl_log_df.write.jdbc(
            url=LayerUtils.redshift_url,
            table=LayerUtils.etl_trakcer_tbl,
            mode="append",
            properties={
                "user": LayerUtils.redshift_user,
                "password": LayerUtils.redshift_password,
                "driver": LayerUtils.redshift_driver
            }
        )

        print(f"ETL status logged: {source_layer} → {destination_layer}, {source_table} → {destination_table}, Status: {status}, Recon: {reconciliation_status}")