In [None]:
%run ./Global_Configurations

In [None]:
%run ./Layer_Utilities_NB

In [None]:
class DataExtractor:
    """
    Extracts data from an AWS DynamoDB table and loads it into a Bronze layer in a data lake.
    Supports both full and incremental extraction based on the `modified_at` timestamp.

    Incremental extraction compares `modified_at` as a DynamoDB string (type ``S``), so it
    assumes the attribute holds ISO-8601 timestamps whose lexicographic order matches
    chronological order — TODO confirm against the table's writers.
    """

    def __init__(self, dynamo_table_name, bronze_layer_path, aws_access_key, aws_secret_access_key, region_name, max_retries=3, retry_delay=5):
        """
        Initializes the DataExtractor class with AWS credentials, Spark session, and table details.

        Args:
            dynamo_table_name (str): Name of the DynamoDB table to extract data from.
            bronze_layer_path (str): Path to the Bronze layer where extracted data will be stored.
            aws_access_key (str): AWS access key for authentication.
            aws_secret_access_key (str): AWS secret access key for authentication.
            region_name (str): AWS region where the DynamoDB table is hosted.
            max_retries (int): Maximum number of attempts for Spark/DynamoDB initialization
                and for each DynamoDB scan page.
            retry_delay (int): Delay (in seconds) between retries.
        """
        self.dynamo_table_name = dynamo_table_name
        self.bronze_layer_path = bronze_layer_path
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Initialize Spark session with retry logic
        self.spark = self._initialize_spark()

        # Initialize DynamoDB client with retry logic (credentials are not kept on the instance)
        self.dynamodb_client = self._initialize_dynamodb(aws_access_key, aws_secret_access_key, region_name)

    def _initialize_spark(self):
        """Create (or reuse) the "DataExtractor" Spark session, retrying on failure.

        Raises:
            Exception: if all ``max_retries`` attempts fail; the last error is chained as the cause.
        """
        last_error = None
        for attempt in range(self.max_retries):
            try:
                return SparkSession.builder.appName("DataExtractor").getOrCreate()
            except Exception as e:
                last_error = e
                print(f"Spark initialization failed (Attempt {attempt + 1}/{self.max_retries}): {str(e)}")
                if attempt + 1 < self.max_retries:
                    # Sleeping after the final attempt would only delay the failure.
                    time.sleep(self.retry_delay)
        raise Exception("Failed to initialize Spark after multiple attempts.") from last_error

    def _initialize_dynamodb(self, aws_access_key, aws_secret_access_key, region_name):
        """Create a low-level boto3 DynamoDB client, retrying on boto/credential errors.

        Raises:
            Exception: if all ``max_retries`` attempts fail; the last error is chained as the cause.
        """
        last_error = None
        for attempt in range(self.max_retries):
            try:
                return boto3.client(
                    "dynamodb",
                    aws_access_key_id=aws_access_key,
                    aws_secret_access_key=aws_secret_access_key,
                    region_name=region_name,
                )
            except (BotoCoreError, NoCredentialsError, ClientError) as e:
                last_error = e
                print(f"DynamoDB initialization failed (Attempt {attempt + 1}/{self.max_retries}): {str(e)}")
                if attempt + 1 < self.max_retries:
                    time.sleep(self.retry_delay)
        raise Exception("Failed to initialize DynamoDB after multiple attempts.") from last_error

    def _scan_page(self, scan_kwargs):
        """Fetch one ``scan`` page, retrying boto/client errors up to ``max_retries`` times.

        Raising (instead of stopping early) matters: returning a partial result would be
        written to Bronze and logged as SUCCESS, and the next incremental run would then
        start from a watermark that skips the missing items.

        Raises:
            Exception: if every attempt fails; the last error is chained as the cause.
        """
        last_error = None
        for attempt in range(self.max_retries):
            try:
                return self.dynamodb_client.scan(TableName=self.dynamo_table_name, **scan_kwargs)
            except (BotoCoreError, ClientError) as e:
                last_error = e
                print(f"Error scanning DynamoDB (Attempt {attempt + 1}/{self.max_retries}): {str(e)}")
                if attempt + 1 < self.max_retries:
                    time.sleep(self.retry_delay)
        raise Exception(f"Failed to scan DynamoDB table {self.dynamo_table_name} after multiple attempts.") from last_error

    @staticmethod
    def _incremental_filter(lower_bound, upper_bound):
        """Build ``_get_dynamo_data`` keyword arguments for ``lower_bound < modified_at <= upper_bound``.

        DynamoDB expressions cannot contain inline literals, so the bounds are bound through
        ``ExpressionAttributeValues`` placeholders (``:start``/``:end``) as strings (``S``).
        The attribute name goes through ``#modified_at`` so it can never clash with a
        DynamoDB reserved word.
        """
        return {
            "filter_expression": "#modified_at > :start AND #modified_at <= :end",
            "expression_attribute_values": {
                ":start": {"S": str(lower_bound)},
                ":end": {"S": str(upper_bound)},
            },
            "expression_attribute_names": {"#modified_at": "modified_at"},
        }

    def _get_dynamo_data(self, filter_expression=None, expression_attribute_values=None, expression_attribute_names=None):
        """Scan the whole DynamoDB table (all pages) and return the items as a Spark DataFrame.

        Args:
            filter_expression (str, optional): DynamoDB ``FilterExpression``.
            expression_attribute_values (dict, optional): ``ExpressionAttributeValues`` for the
                placeholders used in ``filter_expression``.
            expression_attribute_names (dict, optional): ``ExpressionAttributeNames`` for the
                ``#name`` placeholders used in ``filter_expression``.

        Returns:
            DataFrame: one row per item. When no items match, an empty DataFrame with an
            empty schema is returned.

        Raises:
            Exception: if a scan page cannot be fetched (see ``_scan_page``).
        """
        scan_kwargs = {}
        if filter_expression:
            scan_kwargs["FilterExpression"] = filter_expression
        if expression_attribute_values:
            scan_kwargs["ExpressionAttributeValues"] = expression_attribute_values
        if expression_attribute_names:
            scan_kwargs["ExpressionAttributeNames"] = expression_attribute_names

        items = []
        while True:
            response = self._scan_page(scan_kwargs)
            items.extend(response.get("Items", []))

            if "LastEvaluatedKey" not in response:
                break

            # Continue the scan from where the previous page stopped.
            scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]

        if not items:
            from pyspark.sql.types import StructType  # only needed on this path

            return self.spark.createDataFrame([], StructType([]))

        # Strip only the outer DynamoDB type wrapper ({"S": "x"} -> "x"). Numbers (N) stay
        # strings and nested M/L values keep their typed form, as in earlier runs, so the
        # Bronze schema stays stable. NULL attributes are omitted so Spark fills the column
        # with null rather than the boolean True carried by {"NULL": true}.
        data = [
            {
                name: next(iter(attribute.values()))
                for name, attribute in item.items()
                if "NULL" not in attribute
            }
            for item in items
        ]
        return self.spark.createDataFrame(data)

    def extract_data(self):
        """
        Extracts data from the DynamoDB table and loads it into the Bronze layer.

        - If the Bronze layer is empty, performs a full extraction.
        - If the Bronze layer contains data, performs an incremental extraction of items with
          ``max(Bronze modified_at) < modified_at <= now (UTC)``.
        - If no rows were extracted, nothing is written.

        Status ("SUCCESS" with the row count, or "FAILED" with 0) is recorded through
        ``LayerUtils.log_etl_status`` after extraction and writing finish (per the original
        author, the ``gold.etl_tracker`` table). Errors are printed and logged, not re-raised.
        """
        source_layer = "source"
        destination_layer = "bronze"
        source_table = self.dynamo_table_name
        destination_table = self.bronze_layer_path
        # Naive UTC ISO-8601 string; upper bound of the incremental window.
        current_timestamp = datetime.utcnow().isoformat()

        try:
            if LayerUtils.is_layer_empty(self.bronze_layer_path):
                print("Bronze layer is empty. Performing full extraction...")
                data_df = self._get_dynamo_data()
            else:
                print("Bronze layer is not empty. Performing incremental extraction...")
                bronze_data = self._safe_read_parquet(self.bronze_layer_path)

                # Get the latest `modified_at` timestamp from Bronze layer.
                # NOTE(review): assumes the Bronze column is a string; a Spark timestamp
                # would stringify with a space instead of "T" — confirm.
                latest_modified_at = bronze_data.agg({"modified_at": "max"}).collect()[0][0]

                if latest_modified_at is None:
                    # No usable watermark: comparing against the string "None" would match
                    # nothing on every future run. Reload everything instead.
                    # NOTE(review): if Bronze does hold rows with null modified_at, whether they
                    # get duplicated depends on LayerUtils.write_to_layer's write mode.
                    print("No `modified_at` watermark found in Bronze layer. Performing full extraction...")
                    data_df = self._get_dynamo_data()
                else:
                    data_df = self._get_dynamo_data(
                        **self._incremental_filter(latest_modified_at, current_timestamp)
                    )

            extracted_count = data_df.count()
            if extracted_count > 0:
                self._write_to_bronze_layer(data_df)
            else:
                # An empty result has an empty schema, which Spark refuses to write.
                print("No new records to write to Bronze layer.")

            # Log success **AFTER** extraction and writing are complete
            LayerUtils.log_etl_status(self.spark, source_layer, destination_layer, source_table, destination_table, "SUCCESS", extracted_count)

        except Exception as e:
            print(f"Error during extraction: {str(e)}")

            # Log failure in case of an error
            LayerUtils.log_etl_status(self.spark, source_layer, destination_layer, source_table, destination_table, "FAILED", 0)

    def _safe_read_parquet(self, path):
        """Read the Parquet data at ``path``, printing a diagnostic before propagating failures.

        There is no meaningful fallback: an empty frame would need the very schema that
        failed to load, and an incremental run without the Bronze watermark cannot proceed.

        Raises:
            AnalysisException: if the path is missing or unreadable.
        """
        try:
            return self.spark.read.parquet(path)
        except AnalysisException as e:
            print(f"Error reading Parquet file: {str(e)}")
            raise

    def _write_to_bronze_layer(self, data_df):
        """
        Writes extracted data to the Bronze layer via ``LayerUtils.write_to_layer``.

        No write mode is passed here. Presumably the helper overwrites on a full extraction
        and appends on an incremental one — NOTE(review): confirm in LayerUtils.
        """
        LayerUtils.write_to_layer(data_df, self.bronze_layer_path)