In [None]:
import sys
import traceback
from datetime import date, datetime, timedelta

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, to_date, to_timestamp, year
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType, TimestampType
from pyspark.sql.utils import AnalysisException

# ----------------------------------------------------------------------------------
# This script demonstrates a 3-stage historical data loading pipeline:
# 1. CSV (Source) -> Parquet (Intermediate)
# 2. Parquet (Intermediate) -> Delta (Target)
# 3. Efficient Incremental Load using MERGE.
# ----------------------------------------------------------------------------------

# --- Configuration for AWS/Simulation ---
# Location of the final Delta Lake table. A local path is used for the demo;
# in AWS this would be an S3 URI such as s3a://datalake-bucket/tables/products.
DATALAKE_DELTA_PATH = "/tmp/delta/product_sales_delta_table"
# Alias of DATALAKE_DELTA_PATH (not referenced by the functions in this script).
DATALAKE_S3_PATH = DATALAKE_DELTA_PATH

# Raw source: CSV files for the petabyte-scale historical load.
RAW_S3_PATH = "s3a://your-raw-bucket/historical_sales_data_csv/"
# Staging area holding the intermediate Parquet copy of the history.
INTERMEDIATE_PARQUET_PATH = "/tmp/parquet/historical_parquet_intermediate/"

# True  -> the history load uses a few in-memory sample rows.
# False -> the history load reads real CSV data from RAW_S3_PATH.
USE_SIMULATED_HISTORY_DATA = True

# 1. Build a SparkSession with the Delta Lake SQL extension and catalog registered.
print("Initializing Spark Session configured for Delta Lake...")
spark = (
    SparkSession.builder
    .appName("DeltaDataLoadingStrategies")
    .master("local[*]")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

# Only surface ERROR-level Spark logs so the demo's own prints stay readable.
spark.sparkContext.setLogLevel("ERROR")

# 2. Intended column layout of the final Delta table (all columns nullable).
# NOTE(review): schema_target is not referenced by the load functions in this
# script; the Delta table's actual column types come from prepare_data_for_load().
# Raw CSV input is typically read with every column as a string and cast afterwards.
schema_target = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("id", IntegerType()),
        ("product_name", StringType()),
        ("price", DoubleType()),
        ("purchase_date", DateType()),
        ("last_updated", TimestampType()),
        ("year", IntegerType()),  # partition column
    )
])

# --- UTILITY FUNCTIONS ---

def prepare_data_for_load(df):
    """
    Cast a raw DataFrame to the target column types and derive the 'year' partition key.

    The input must have the columns id, product_name, price, purchase_date and
    last_updated. After an untyped CSV read (inferSchema=False) every one of them is a
    string, so the numeric columns are cast as well as the date/timestamp columns.
    Depending on spark.sql.ansi.enabled, values that fail to parse either become null
    or raise at execution time.

    NOTE: The date formats below must match the formats in your raw CSV data:
      purchase_date -> 'MM/dd/yyyy', last_updated -> 'yyyy-MM-dd HH:mm:ss'.

    Returns:
        A new DataFrame with id:int, price:double, purchase_date:date,
        last_updated:timestamp and an added year:int column. Column order is unchanged,
        with 'year' appended.
    """
    df_typed = (
        df.withColumn("id", col("id").cast(IntegerType()))
        .withColumn("price", col("price").cast(DoubleType()))
        .withColumn("purchase_date", to_date(col("purchase_date"), "MM/dd/yyyy"))
        .withColumn("last_updated", to_timestamp(col("last_updated"), "yyyy-MM-dd HH:mm:ss"))
    )

    # year() reads the calendar year directly from the DateType value, instead of
    # relying on an implicit date->string cast followed by substr(1, 4).
    return df_typed.withColumn("year", year(col("purchase_date")))

def create_simulated_data(is_history=True):
    """Build a small raw-format DataFrame that stands in for CSV input.

    Args:
        is_history: True builds a multi-year history batch (2022-2024). False builds an
            incremental batch with three rows: an update of id 301, a 2024 insert
            (id 303), and an insert (id 401) dated 1 January of the calendar year after
            the current one.

    Returns:
        A DataFrame created on the module-level ``spark`` session. ``purchase_date``
        ('MM/dd/yyyy') and ``last_updated`` ('yyyy-MM-dd HH:mm:ss') are kept as strings
        to mimic an untyped CSV read.
    """
    # Both batches share this raw layout: dates and timestamps stay strings.
    string_dates_schema = StructType([
        StructField(name, dtype, True)
        for name, dtype in (
            ("id", IntegerType()),
            ("product_name", StringType()),
            ("price", DoubleType()),
            ("purchase_date", StringType()),
            ("last_updated", StringType()),
        )
    ])

    if not is_history:
        # Small incremental batch: one update plus inserts for this year and the next.
        print("Simulating small incremental dataset (new and updated records)...")
        following_year = str(datetime.now().year + 1)
        incremental_ts = "2024-07-20 10:30:01"  # fixed timestamp for reproducibility

        rows = [
            (301, "Keyboard Mech", 99.00, "01/01/2024", incremental_ts),   # update (price change)
            (303, "Webcam HD", 50.00, "10/05/2024", incremental_ts),       # insert, 2024
            (401, "VR Headset", 1500.00, f"01/01/{following_year}", incremental_ts),  # insert, next year
        ]
        return spark.createDataFrame(rows, schema=string_dates_schema)

    # History batch spanning several years.
    print("Simulating large historical dataset across multiple years (raw CSV format)...")
    history_ts = "2024-07-20 10:30:00"  # fixed timestamp for reproducibility

    rows = [
        (101, "Server Rack", 5000.00, "01/15/2022", history_ts),
        (102, "Fiber Cable", 150.00, "12/01/2022", history_ts),
        (201, "Laptop Pro", 2200.00, "03/22/2023", history_ts),
        (202, "Monitor 4K", 650.00, "10/10/2023", history_ts),
        (301, "Keyboard Mech", 120.00, "01/01/2024", history_ts),  # updated by the incremental batch
        (302, "Mouse Erg", 75.00, "06/05/2024", history_ts),
    ]
    return spark.createDataFrame(rows, schema=string_dates_schema)


def check_path_existence(spark_session: SparkSession, path: str, format_type: str = "delta") -> bool:
    """Report whether ``path`` already holds a table of ``format_type``.

    Only the "delta" format is recognised. Any other ``format_type`` returns False
    without inspecting ``path``.
    """
    if format_type != "delta":
        return False
    return bool(DeltaTable.isDeltaTable(spark_session, path))

# --- LOAD STRATEGIES ---

def _read_raw_history(spark):
    """Return the raw history DataFrame: simulated rows, or the CSV files at RAW_S3_PATH.

    Raises:
        AnalysisException: re-raised when Spark cannot resolve RAW_S3_PATH
            (e.g. the path does not exist).
    """
    if USE_SIMULATED_HISTORY_DATA:
        df_raw = create_simulated_data(is_history=True)
        # Counting is cheap here: the simulated batch is a handful of in-memory rows.
        print(f"Reading {df_raw.count()} simulated records (Simulation Mode).")
        return df_raw

    print(f"Attempting to read massive CSV data from S3 Source: {RAW_S3_PATH}...")
    try:
        # header=True: first line holds column names. inferSchema=False: every column
        # is read as a string; prepare_data_for_load() performs the casting.
        df_raw = spark.read.csv(RAW_S3_PATH, header=True, inferSchema=False)
    except AnalysisException:
        print(f"ERROR: Could not read data from {RAW_S3_PATH}. Check S3 path and IAM permissions.")
        raise  # bare raise keeps the original traceback intact

    # Deliberately no count() here. On a petabyte-scale source it would trigger a
    # full extra scan just to print a number; the Parquet write below reads the data once.
    print(f"Successfully opened CSV source at {RAW_S3_PATH} (row count skipped to avoid a full extra scan).")
    return df_raw


def _write_intermediate_parquet(df_raw):
    """Type the raw rows and overwrite INTERMEDIATE_PARQUET_PATH, partitioned by 'year'."""
    df_transformed = prepare_data_for_load(df_raw)

    print("\n--- Writing to Intermediate Parquet ---")
    (
        df_transformed.write
        .format("parquet")
        .mode("overwrite")
        .partitionBy("year")
        .save(INTERMEDIATE_PARQUET_PATH)
    )
    print(f"Intermediate Parquet Load Complete at {INTERMEDIATE_PARQUET_PATH}")
    print("Parquet format offers high read performance for the next stage.")


def _promote_parquet_to_delta(spark):
    """Rewrite the intermediate Parquet copy as the 'year'-partitioned Delta table at DATALAKE_DELTA_PATH."""
    print("\n--- Reading Parquet and Writing to Final Delta ---")
    df_intermediate = spark.read.format("parquet").load(INTERMEDIATE_PARQUET_PATH)

    (
        df_intermediate.write
        .format("delta")
        .mode("overwrite")
        # The history load rebuilds the table from scratch. Let it also replace the
        # schema/partitioning; otherwise the overwrite fails if an earlier run left a
        # table with different column types.
        .option("overwriteSchema", "true")
        .partitionBy("year")
        .save(DATALAKE_DELTA_PATH)
    )
    print(f"Final Delta Load Complete at {DATALAKE_DELTA_PATH}")


def ingest_historical_data_aws_glue(spark):
    """
    1. HISTORY LOAD (Petabytes): CSV (raw) -> Parquet (intermediate) -> Delta (target).

    Overwrites INTERMEDIATE_PARQUET_PATH and DATALAKE_DELTA_PATH, both partitioned by
    'year'. It then prints the most recent Delta history entry and the table contents
    ordered by id.

    Raises:
        AnalysisException: when the real (non-simulated) CSV source cannot be resolved.
    """
    print("\n" + "=" * 60)
    print("STEP 1: AWS GLUE/SPARK HISTORICAL INGESTION (3-STAGE CONVERSION)")
    print("Stage 1: CSV -> Parquet (Intermediate)")
    print("Stage 2: Parquet -> Delta (Target, Partitioned by 'year')")
    print("=" * 60)

    # Stage 1A: read raw CSV (or simulated) data.
    df_history_raw = _read_raw_history(spark)
    # Stage 1B: cast types, derive 'year', write intermediate Parquet.
    _write_intermediate_parquet(df_history_raw)
    # Stage 2: Parquet -> Delta.
    _promote_parquet_to_delta(spark)

    print("Initial Delta Table Content:")
    DeltaTable.forPath(spark, DATALAKE_DELTA_PATH).history().select("version", "operation").show(1, truncate=False)
    spark.read.format("delta").load(DATALAKE_DELTA_PATH).orderBy("id").show()


def load_incremental_data(spark):
    """
    2. INCREMENTAL LOAD: upsert a simulated incremental batch into the Delta table via MERGE.

    Source rows are matched to target rows on (id, year):
      * matched rows whose price differs (compared null-safely) get the new price and
        a last_updated stamp taken from the driver clock at call time;
      * unmatched source rows are inserted with all columns.

    If no Delta table exists at DATALAKE_DELTA_PATH, an error is printed and the
    function returns without writing.
    """
    df_source_increment_raw = create_simulated_data(is_history=False)
    df_source_increment = prepare_data_for_load(df_source_increment_raw)

    print("\n" + "=" * 60)
    print("STEP 2: INCREMENTAL LOAD (Partition Key 'year')")
    print("Applying incremental changes via MERGE (Update, Insert).")
    print("=" * 60)

    if not check_path_existence(spark, DATALAKE_DELTA_PATH, format_type="delta"):
        print(f"FATAL ERROR: Delta Table not initialized at {DATALAKE_DELTA_PATH}. Aborting incremental load.")
        return

    delta_table = DeltaTable.forPath(spark, DATALAKE_DELTA_PATH)

    # Match on both the primary key (id) and the partition key (year).
    merge_condition = "target.id = source.id AND target.year = source.year"

    # A join on target.year = source.year alone does not tell Delta which target
    # partitions it may skip; an explicit literal predicate does. Matching already requires
    # target.year = source.year, so restricting target.year to the batch's years narrows
    # the scan without changing which rows match. The years are ints, so interpolating
    # them into the SQL text is safe.
    source_years = sorted(
        row["year"]
        for row in df_source_increment.select("year").distinct().collect()
        if row["year"] is not None
    )
    if source_years:
        year_list = ", ".join(str(y) for y in source_years)
        merge_condition += f" AND target.year IN ({year_list})"

    (
        delta_table.alias("target")
        .merge(
            source=df_source_increment.alias("source"),
            condition=merge_condition,
        )
        .whenMatchedUpdate(
            # <=> is null-safe equality. A plain != evaluates to NULL when either price is
            # NULL, which silently skipped updates to or from a missing price.
            condition="NOT (target.price <=> source.price)",
            set={
                "price": col("source.price"),
                "last_updated": lit(datetime.now()),
            },
        )
        .whenNotMatchedInsertAll()
        .execute()
    )

    print("Incremental Load Complete.")

    print("Final Delta Table Content after MERGE:")
    spark.read.format("delta").load(DATALAKE_DELTA_PATH).orderBy("id").show()

# --- EXECUTION FLOW ---
try:
    # 1. Initial history load (CSV -> Parquet -> Delta).
    ingest_historical_data_aws_glue(spark)

    # 2. Incremental MERGE on top of the freshly loaded table.
    load_incremental_data(spark)

except Exception as e:
    # Top-level boundary: report the failure, keep the full stack trace (Spark/Py4J
    # errors are rarely diagnosable from str(e) alone), and exit non-zero.
    print("\n--- ERROR ---")
    print(f"An error occurred during data loading: {e}")
    traceback.print_exc()
    sys.exit(1)

finally:
    # Runs on success and also while sys.exit()'s SystemExit propagates,
    # so the SparkSession is always released.
    print("-" * 60)
    spark.stop()
    print("Spark Session stopped.")