In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from datetime import datetime, timedelta
import logging

In [0]:
# Set up logging.
# basicConfig() does nothing when the root logger already has handlers (hosted
# notebook kernels may pre-install some), so the level is also set on this
# module's logger directly; otherwise the INFO progress messages emitted by the
# pipeline could be dropped silently.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [0]:
class DataAcquisitionProcessor:
    """Load a CSV extract into Delta pre-stage and stage tables and reconcile counts.

    The class relies on two notebook-level globals: the ``spark`` session and
    the module-level ``logger``. All tables live in the ``data_acquisition``
    database.

    Args:
        source_path: Location of the CSV extract (a header row is expected).
        source_system: Source-system key written to / read from watermark rows.
        source_table: Source-table key written to / read from watermark rows.
    """

    # Value stored in the ``watermark_column`` field of every watermark row.
    WATERMARK_COLUMN = "transaction_date"

    def __init__(self,
                 source_path="s3://da-process/sample_transactions.csv",
                 source_system="SYSTEM_A",
                 source_table="transactions"):
        self.watermark_table = "data_acquisition.watermark_control"
        self.pre_stage_table = "data_acquisition.pre_stage_data"
        self.stage_table = "data_acquisition.stage_data"
        self.reconciliation_table = "data_acquisition.reconciliation_log"
        self.source_path = source_path
        self.source_system = source_system
        self.source_table = source_table

    @staticmethod
    def _default_watermark():
        """Return the fallback watermark: 24 hours before now."""
        return datetime.now() - timedelta(days=1)

    @staticmethod
    def _classify_variance(variance):
        """Map a source-minus-target count difference to ``(status, remarks)``."""
        if variance == 0:
            return "PASS", "Record counts match perfectly"
        if variance > 0:
            return "WARN", f"Source has {variance} more records than target"
        # The builtin abs() is deliberately avoided: the notebook's
        # `from pyspark.sql.functions import *` rebinds `abs` to the Spark
        # column function, which rejects a plain Python int.
        return "FAIL", f"Target has {-variance} more records than source"

    def setup_database_and_tables(self):
        """Create the database and the four Delta tables if they do not exist."""
        logger.info("Setting up database and tables...")

        # Create database
        spark.sql("CREATE DATABASE IF NOT EXISTS data_acquisition")

        # Create watermark control table
        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.watermark_table} (
                source_system STRING,
                table_name STRING,
                last_processed_timestamp TIMESTAMP,
                watermark_column STRING,
                process_date DATE,
                status STRING,
                created_at TIMESTAMP,
                updated_at TIMESTAMP
            ) USING DELTA
        """)

        # Create pre-stage table
        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.pre_stage_table} (
                id LONG,
                customer_id STRING,
                product_id STRING,
                transaction_amount DOUBLE,
                transaction_date TIMESTAMP,
                status STRING,
                source_system STRING,
                load_timestamp TIMESTAMP,
                batch_id STRING
            ) USING DELTA
        """)

        # Create stage table
        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.stage_table} (
                id LONG,
                customer_id STRING,
                product_id STRING,
                transaction_amount DOUBLE,
                transaction_date TIMESTAMP,
                status STRING,
                source_system STRING,
                load_timestamp TIMESTAMP,
                batch_id STRING,
                processed_timestamp TIMESTAMP
            ) USING DELTA
        """)

        # Create reconciliation log table
        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {self.reconciliation_table} (
                batch_id STRING,
                source_table STRING,
                target_table STRING,
                source_count LONG,
                target_count LONG,
                status STRING,
                variance LONG,
                process_timestamp TIMESTAMP,
                remarks STRING
            ) USING DELTA
        """)

        logger.info("Database and tables setup completed.")

    def get_watermark(self, source_system, table_name):
        """Return the most recent ACTIVE watermark for a source table.

        Args:
            source_system: Source-system key to look up.
            table_name: Source-table key to look up.

        Returns:
            The stored ``last_processed_timestamp``. Falls back to 24 hours
            ago when no usable row exists or when the lookup fails (the
            failure is logged, not raised).
        """
        logger.info(f"Getting watermark for {source_system}.{table_name}")

        try:
            # Column comparisons bind the values as literals, so quotes in the
            # arguments cannot alter the query (unlike SQL text interpolation).
            latest = (
                spark.table(self.watermark_table)
                .filter(
                    (col("source_system") == source_system)
                    & (col("table_name") == table_name)
                    & (col("status") == "ACTIVE")
                )
                .orderBy(col("updated_at").desc())
                .select("last_processed_timestamp")
                .first()  # one Spark job instead of count() + collect()
            )

            if latest is not None and latest["last_processed_timestamp"] is not None:
                watermark = latest["last_processed_timestamp"]
                logger.info(f"Found watermark: {watermark}")
                return watermark

            # Default watermark (24 hours ago)
            default_watermark = self._default_watermark()
            logger.info(f"No watermark found, using default: {default_watermark}")
            return default_watermark
        except Exception as e:
            # Best-effort lookup: a failure here must not abort the cycle.
            logger.error(f"Error getting watermark: {e}")
            return self._default_watermark()

    def update_watermark(self, source_system, table_name, new_watermark, batch_id):
        """Append a new ACTIVE watermark row.

        Args:
            source_system: Source-system key for the row.
            table_name: Source-table key for the row.
            new_watermark: Timestamp to record as ``last_processed_timestamp``.
            batch_id: Accepted for interface compatibility; not stored, because
                the watermark table has no batch column.

        Returns:
            True when the row was written, False when the write failed (the
            error is logged, not raised).
        """
        logger.info(f"Updating watermark for {source_system}.{table_name}")

        try:
            # One clock reading so process_date, created_at and updated_at agree.
            now = datetime.now()
            watermark_data = [(
                source_system,
                table_name,
                new_watermark,
                self.WATERMARK_COLUMN,
                now.date(),
                "ACTIVE",
                now,
                now
            )]

            watermark_df = spark.createDataFrame(watermark_data, [
                "source_system", "table_name", "last_processed_timestamp",
                "watermark_column", "process_date", "status", "created_at", "updated_at"
            ])

            # Rows are only ever appended; get_watermark picks the newest one.
            watermark_df.write.mode("append").insertInto(self.watermark_table)

            logger.info("Watermark updated successfully")
            return True
        except Exception as e:
            logger.error(f"Error updating watermark: {e}")
            return False

    def load_data_to_pre_stage(self, batch_id):
        """Read the source CSV and append it to the pre-stage table.

        CSV columns are mapped to the pre-stage columns by position (which is
        how ``insertInto`` resolves them) and explicitly cast to the declared
        column types.

        Args:
            batch_id: Identifier stamped on every loaded row.

        Returns:
            Number of rows in the prepared source DataFrame.

        Raises:
            Exception: Any read/write error is logged and re-raised.
        """
        logger.info("Loading data to pre-stage table...")

        try:
            # header=true without a schema: every CSV column arrives as a string.
            source_df = (
                spark.read.format("csv")
                .option("header", "true")
                .load(self.source_path)
            )

            # Add metadata columns
            staged_df = (
                source_df
                .withColumn("load_timestamp", current_timestamp())
                .withColumn("batch_id", lit(batch_id))
            )

            # Align names positionally with the target table, then cast each
            # column to the target type so the write does not depend on
            # implicit string -> numeric/timestamp store-assignment casts.
            target_schema = spark.table(self.pre_stage_table).schema
            pre_stage_df = staged_df.toDF(*target_schema.fieldNames()).select(
                [col(field.name).cast(field.dataType) for field in target_schema.fields]
            )

            # Write to pre-stage table
            pre_stage_df.write.mode("append").insertInto(self.pre_stage_table)

            source_count = pre_stage_df.count()
            logger.info(f"Loaded {source_count} records to pre-stage table")

            return source_count
        except Exception as e:
            logger.error(f"Error loading data to pre-stage: {e}")
            raise

    def load_data_to_stage(self, batch_id):
        """Copy one batch from pre-stage to stage, dropping invalid rows.

        Rows are kept only when ``transaction_amount`` is greater than zero and
        both ``customer_id`` and ``product_id`` are non-null.

        Args:
            batch_id: Batch to promote.

        Returns:
            Number of rows that passed the filters.

        Raises:
            Exception: Any read/write error is logged and re-raised.
        """
        logger.info("Loading data to stage table...")

        try:
            # Read from pre-stage (literal-bound filter, no SQL text building)
            pre_stage_df = spark.table(self.pre_stage_table).filter(
                col("batch_id") == batch_id
            )

            # Apply transformations (example: data quality checks, business rules)
            stage_df = (
                pre_stage_df
                .filter(col("transaction_amount") > 0)
                .filter(col("customer_id").isNotNull())
                .filter(col("product_id").isNotNull())
                .withColumn("processed_timestamp", current_timestamp())
            )

            # Write to stage table
            stage_df.write.mode("append").insertInto(self.stage_table)

            target_count = stage_df.count()
            logger.info(f"Loaded {target_count} records to stage table")

            return target_count
        except Exception as e:
            logger.error(f"Error loading data to stage: {e}")
            raise

    def perform_reconciliation(self, batch_id, source_count, target_count):
        """Compare source and target counts and log the outcome.

        Args:
            batch_id: Batch being reconciled.
            source_count: Row count reported by the pre-stage load.
            target_count: Row count reported by the stage load.

        Returns:
            dict with ``status`` (PASS / WARN / FAIL), ``source_count``,
            ``target_count``, ``variance`` (source minus target) and ``remarks``.

        Raises:
            Exception: Any error writing the log row is logged and re-raised.
        """
        logger.info("Performing reconciliation...")

        try:
            variance = source_count - target_count

            # Determine status based on variance
            status, remarks = self._classify_variance(variance)

            # Log reconciliation results
            recon_data = [(
                batch_id,
                "source_data",
                "stage_data",
                source_count,
                target_count,
                status,
                variance,
                datetime.now(),
                remarks
            )]

            recon_df = spark.createDataFrame(recon_data, [
                "batch_id", "source_table", "target_table", "source_count",
                "target_count", "status", "variance", "process_timestamp", "remarks"
            ])

            recon_df.write.mode("append").insertInto(self.reconciliation_table)

            logger.info(f"Reconciliation completed: {status} - {remarks}")

            return {
                "status": status,
                "source_count": source_count,
                "target_count": target_count,
                "variance": variance,
                "remarks": remarks
            }
        except Exception as e:
            logger.error(f"Error during reconciliation: {e}")
            raise

    def get_process_status(self, batch_id):
        """Return the latest reconciliation record for a batch.

        Args:
            batch_id: Batch to look up.

        Returns:
            The newest reconciliation row as a dict; a dict with status
            ``NOT_FOUND`` when there is none; or a dict with status ``ERROR``
            and the message when the lookup fails.
        """
        logger.info(f"Getting process status for batch: {batch_id}")

        try:
            # Get reconciliation status
            latest = (
                spark.table(self.reconciliation_table)
                .filter(col("batch_id") == batch_id)
                .orderBy(col("process_timestamp").desc())
                .select(
                    "batch_id",
                    "source_count",
                    "target_count",
                    "variance",
                    "status",
                    "remarks",
                    "process_timestamp",
                )
                .first()  # one Spark job instead of count() + collect()
            )

            if latest is not None:
                return latest.asDict()
            return {"status": "NOT_FOUND", "message": "No reconciliation record found"}
        except Exception as e:
            logger.error(f"Error getting process status: {e}")
            return {"status": "ERROR", "message": str(e)}

    def run_data_acquisition_cycle(self):
        """Run setup, load, reconciliation and watermark update for one batch.

        Returns:
            On success: dict with ``batch_id``, ``cycle_status`` = COMPLETED,
            ``reconciliation``, ``final_status`` and ``watermark_updated``
            (True only when a new watermark row was actually written).
            On any exception: dict with ``batch_id``, ``cycle_status`` = FAILED
            and ``error``.
        """
        logger.info("Starting Data Acquisition Cycle...")

        batch_id = f"BATCH_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        try:
            # Step 1: Setup
            self.setup_database_and_tables()

            # Step 2: Get watermark
            # NOTE(review): the watermark is fetched (and advanced below) but is
            # not applied as a filter on the source read, so every cycle reloads
            # the whole file. Confirm whether incremental loading is intended.
            watermark = self.get_watermark(self.source_system, self.source_table)

            # Step 3: Load to pre-stage
            source_count = self.load_data_to_pre_stage(batch_id)

            # Step 4: Load to stage
            target_count = self.load_data_to_stage(batch_id)

            # Step 5: Perform reconciliation
            recon_result = self.perform_reconciliation(batch_id, source_count, target_count)

            # Step 6: Update watermark if successful
            watermark_updated = False
            if recon_result["status"] in ("PASS", "WARN"):
                new_watermark = datetime.now()
                watermark_updated = bool(self.update_watermark(
                    self.source_system, self.source_table, new_watermark, batch_id
                ))
                if not watermark_updated:
                    logger.warning("Watermark was not advanced for this batch")

            # Step 7: Get final status
            final_status = self.get_process_status(batch_id)

            logger.info("Data Acquisition Cycle completed successfully")

            return {
                "batch_id": batch_id,
                "cycle_status": "COMPLETED",
                "reconciliation": recon_result,
                "final_status": final_status,
                "watermark_updated": watermark_updated
            }

        except Exception as e:
            logger.error(f"Data Acquisition Cycle failed: {e}")
            return {
                "batch_id": batch_id,
                "cycle_status": "FAILED",
                "error": str(e)
            }


In [0]:
 # Initialize the processor
processor = DataAcquisitionProcessor()

# Execute one full acquisition cycle; the returned dict summarises the outcome
result = processor.run_data_acquisition_cycle()

# Print a short, framed report of the run
print("\n" + "=" * 50)
print("DATA ACQUISITION PROCESS RESULTS")
print("=" * 50)
print(f"Batch ID: {result['batch_id']}")
print(f"Cycle Status: {result['cycle_status']}")

if result['cycle_status'] != 'COMPLETED':
    # Failed cycles carry only an error message
    print(f"Error: {result.get('error', 'Unknown error')}")
else:
    # Completed cycles carry the reconciliation summary
    recon = result['reconciliation']
    print(f"Reconciliation Status: {recon['status']}")
    print(f"Source Count: {recon['source_count']}")
    print(f"Target Count: {recon['target_count']}")
    print(f"Variance: {recon['variance']}")
    print(f"Remarks: {recon['remarks']}")

print("=" * 50)


In [0]:
# Show the first five rows of every pipeline table for a quick visual check
print("\nTABLE CONTENTS:")
print("-" * 30)

for _heading, _query in (
    ("PRE-STAGE TABLE:", "SELECT * FROM data_acquisition.pre_stage_data"),
    ("STAGE TABLE:", "SELECT * FROM data_acquisition.stage_data"),
    ("RECONCILIATION LOG:", "SELECT * FROM data_acquisition.reconciliation_log"),
    ("WATERMARK CONTROL:", "SELECT * FROM data_acquisition.watermark_control"),
):
    print(_heading)
    spark.sql(_query).show(5)