In [0]:
import sys
sys.path.insert(0, "../utils")
from logger import log_silver_ingestion
import pyspark.sql.functions as F
import uuid
from delta.tables import DeltaTable
import datetime

In [0]:
# Setting Paths
# Fully-qualified (catalog.schema.table) names used by this pipeline.
BRONZE_TABLE, SILVER_TABLE, QUARANTINE_TABLE = (
    f"nyc_taxi.{layer}.yellow_taxi_trips" for layer in ("bronze", "silver", "quarantine")
)
LOG_TABLE_SILVER = "nyc_taxi.logs.silver_ingestion_logs"

# Dataset identifier recorded in (and used to filter) the ingestion audit log.
dataset_name = "yellow_taxi_trips"

# Unique id for this execution; stamped on every written row and on the audit-log entry.
run_id = f"{uuid.uuid4()}"

In [0]:
# Setting Types
# Silver column name -> Spark SQL type it is cast to. Insertion order is also the
# column order of the silver projection, so keep entries in this sequence.
type_mapping = dict(
    year="int",
    month="int",
    vendor_id="int",
    pickup_datetime="timestamp",
    dropoff_datetime="timestamp",
    trip_distance="double",
    passenger_count="int",
    rate_code_id="int",
    pickup_location_id="int",
    dropoff_location_id="int",
    payment_type_id="int",
    fare_amount="double",
    extra_charge="double",
    mta_tax="double",
    tip_amount="double",
    tolls_amount="double",
    improvement_surcharge="double",
    congestion_surcharge="double",
    cbd_congestion_fee="double",
    airport_fee="double",
    total_amount="double",
    store_and_fwd_flag="boolean",
    bronze_id="string",
)

# Setting Columns to build hash
# Natural-key columns concatenated and hashed into `trip_id`.
hash_columns = [
    "vendor_id",
    "pickup_datetime",
    "dropoff_datetime",
    "rate_code_id",
    "pickup_location_id",
    "dropoff_location_id",
    "payment_type_id",
]

# Validation rules for silver table.
# Each rule is an F.when(...) with no otherwise(): it evaluates to its label string when the
# row violates the rule and to null when it passes. The validation step collects the non-null
# labels into `quarantine_reasons`; any row with at least one label is quarantined.
# Comparison rules on a null column also evaluate to null, so nulls are reported only by the
# dedicated *_null rule for that column.
silver_rules = [
    F.when(F.col("vendor_id").isNull(), F.lit("vendor_id_null")),
    F.when(F.col("passenger_count").isNull(), F.lit("passenger_count_null")),
    F.when(~F.col("passenger_count").between(1,6), F.lit("invalid_passenger_count")),
    F.when(F.col("trip_distance").isNull(), F.lit("trip_distance_null")),
    F.when(F.col("trip_distance") < 0, F.lit("invalid_trip_distance")),
    F.when(F.col("store_and_fwd_flag").isNull(), F.lit("store_and_fwd_flag_null")),
    F.when(F.col("rate_code_id").isNull(), F.lit("rate_code_id_null")),
    F.when(F.col("pickup_location_id").isNull() | F.col("dropoff_location_id").isNull(), F.lit("pickup_dropoff_location_null")),
    F.when(F.col("pickup_datetime").isNull(), F.lit("pickup_datetime_null")),
    F.when(F.col("dropoff_datetime").isNull(), F.lit("dropoff_datetime_null")),
    F.when(F.col("pickup_datetime") > F.col("dropoff_datetime"), F.lit("pickup_datetime_after_dropoff")),
    F.when(F.col("payment_type_id").isNull(), F.lit("payment_type_id_null")),
    F.when(F.col("fare_amount").isNull(), F.lit("fare_amount_null")),
    F.when(F.col("fare_amount") < 0, F.lit("invalid_fare_amount")),
    F.when(F.col("extra_charge").isNull(), F.lit("extra_charge_null")),
    F.when(F.col("extra_charge")<0, F.lit("extra_charge_invalid")),
    F.when(F.col("mta_tax").isNull(), F.lit("mta_tax_null")),
    # Only the listed exact amounts are accepted for the fixed-rate surcharges below.
    F.when(~F.col("mta_tax").isin(0,0.5), F.lit("mta_tax_invalid")),
    F.when(F.col("tip_amount").isNull(), F.lit("tip_amount_null")),
    # A negative tip is an invalid value, not a missing one; label it distinctly from the null rule above.
    F.when(F.col("tip_amount")<0, F.lit("tip_amount_invalid")),
    F.when(F.col("tolls_amount").isNull(), F.lit("tolls_amount_null")),
    F.when(F.col("tolls_amount")<0, F.lit("tolls_amount_invalid")),
    F.when(F.col("improvement_surcharge").isNull(), F.lit("improvement_surcharge_null")),
    F.when(~F.col("improvement_surcharge").isin(0,0.3,1), F.lit("improvement_surcharge_invalid")),
    F.when(F.col("congestion_surcharge").isNull(), F.lit("congestion_surcharge_null")),
    F.when(~F.col("congestion_surcharge").isin(0,2.5), F.lit("congestion_surcharge_invalid")),
    F.when(F.col("airport_fee").isNull(), F.lit("airport_fee_null")),
    F.when(~F.col("airport_fee").isin(0,1.75), F.lit("airport_fee_invalid")),
    # trip_duration is derived during transformation (dropoff - pickup, in seconds).
    F.when(F.col("trip_duration").isNull(), F.lit("trip_duration_null")),
    F.when(F.col("trip_duration") < 0, F.lit("invalid_trip_duration"))
]

In [0]:
# --- INITIALIZATION & CONFIGURATION ---
start_ts = datetime.datetime.now()

# Initialize metrics & state flags.
# Every value read by the audit-log call in `finally` is defined here, so the log is
# still written when the job fails before reaching the step that would set it.
bronze_df = None
silver_count = quarantine_count = duplicates_dropped = bronze_count = 0
silver_survivors = quarantine_survivors = 0
success = silver_load_success = quarantine_load_success = False
error_msg = ""
last_ingested_ts = None
max_bronze_ts = None      # watermark recorded in the audit log for this run
pre_write_versions = {}   # table name -> Delta version just before this run's MERGE (rollback target)
init_run = 0 # Set to 1 to skip watermark retrieval for first-time setup
look_back_minutes = 5

# Ensure Logging Table exists
# NOTE(review): no schema is declared here; presumably log_silver_ingestion defines it on
# first write — confirm the watermark query below works against a freshly created table.
spark.sql(f"CREATE TABLE IF NOT EXISTS {LOG_TABLE_SILVER} USING DELTA")

try:
    # 1. WATERMARKING: Incremental Data Retrieval
    if init_run == 0:
        print("Retrieving last ingested timestamp...")
        log_history = spark.read.table(LOG_TABLE_SILVER).filter(f"status = 'Success' AND dataset_name = '{dataset_name}'")
        # A global max() over an empty frame returns one row holding null, so no separate count() pass is needed.
        last_ingested_ts = log_history.agg(F.max("max_bronze_ts")).collect()[0][0]
    else:
        last_ingested_ts = None

    # 2. INGESTION: Read with Lookback for late-arriving data
    print("Ingesting bronze data...")
    bronze_df = spark.read.table(BRONZE_TABLE)
    if last_ingested_ts:
        read_watermark = last_ingested_ts - datetime.timedelta(minutes=look_back_minutes)
        # Typed literal comparison instead of formatting the datetime into SQL text.
        bronze_df = bronze_df.filter(F.col("_ingest_ts") > F.lit(read_watermark))

    # Row count and the batch's highest bronze `_ingest_ts` in a single pass. The latter becomes
    # this run's watermark. It is taken from the source rows (not the silver processing clock)
    # and before the merges, so bronze rows that land while this job runs are read next run
    # rather than skipped.
    bronze_stats = bronze_df.agg(
        F.count(F.lit(1)).alias("row_count"),
        F.max("_ingest_ts").alias("max_ingest_ts"),
    ).collect()[0]
    bronze_count = bronze_stats["row_count"]
    batch_max_ts = bronze_stats["max_ingest_ts"]

    if bronze_count > 0:
        print(f"Processing {bronze_count} records...")

        # 3. TRANSFORMATIONS: Schema Standardization & Feature Engineering
        # Rename columns to snake_case and cast types via predefined 'type_mapping'.
        # concat_ws silently skips nulls, so nulls are mapped to "" first; otherwise rows with nulls
        # in different key columns could hash to the same trip_id. Rows without nulls in
        # hash_columns produce exactly the same hash as before.
        trip_key_parts = [F.coalesce(F.col(c).cast("string"), F.lit("")) for c in hash_columns]
        df_transformed = (bronze_df
            .withColumnsRenamed({
                "tpep_pickup_datetime": "pickup_datetime", "tpep_dropoff_datetime": "dropoff_datetime",
                "VendorID": "vendor_id", "RatecodeID": "rate_code_id",
                "PULocationID": "pickup_location_id", "DOLocationID": "dropoff_location_id",
                "payment_type": "payment_type_id", "extra": "extra_charge",
                "Airport_fee": "airport_fee", "run_id": "bronze_id"
            })
            .withColumn("store_and_fwd_flag", F.col("store_and_fwd_flag") == "Y")
            .select([F.col(c).cast(t) for c, t in type_mapping.items()])
            .withColumn("trip_id", F.sha2(F.concat_ws("||", *trip_key_parts), 256))
            .withColumn("driver_pay", (F.col("fare_amount") + F.col("tip_amount")).cast("double"))
            .withColumn("trip_duration", (F.col('dropoff_datetime').cast("long") - F.col('pickup_datetime').cast("long")).cast("int"))
            .withColumn("ingest_ts", F.current_timestamp())
            .withColumn("run_id", F.lit(run_id))
            .dropDuplicates(["trip_id"]))

        # 4. VALIDATION: Apply DQ Rules and identify Quarantine reasons.
        # Each rule yields its label on violation and null otherwise. Keep the distinct non-null
        # labels in rule order, built as one expression rather than ~30 stacked withColumn projections.
        df_validated = df_transformed.withColumn(
            "quarantine_reasons",
            F.array_distinct(F.filter(F.array(*silver_rules), lambda reason: reason.isNotNull())),
        )

        # 5. SPLIT: Divide data into Silver (Clean) and Quarantine (Dirty)
        silver_df = df_validated.filter(F.size("quarantine_reasons") == 0).drop("quarantine_reasons")
        quarantine_df = df_validated.filter(F.size("quarantine_reasons") > 0)

        # 6. WRITES: Idempotent Merge into Target Tables (insert only unseen trip_ids).
        # The pre-merge version of each table is recorded so a failure can restore exactly
        # that version, even if other commits land on the table in the meantime.
        # Write Silver
        print("Merging Silver data...")
        s_table = DeltaTable.forName(spark, SILVER_TABLE)
        pre_write_versions[SILVER_TABLE] = s_table.history(1).select("version").collect()[0][0]
        s_table.alias("t").merge(silver_df.alias("s"), "t.trip_id = s.trip_id").whenNotMatchedInsertAll().execute()
        silver_load_success = True

        # Write Quarantine
        print("Merging Quarantine data...")
        q_table = DeltaTable.forName(spark, QUARANTINE_TABLE)
        pre_write_versions[QUARANTINE_TABLE] = q_table.history(1).select("version").collect()[0][0]
        q_table.alias("t").merge(quarantine_df.alias("s"), "t.trip_id = s.trip_id").whenNotMatchedInsertAll().execute()
        quarantine_load_success = True

        # 7. METRICS: read from the latest Delta commit of each table (assumed to be this run's MERGE)
        s_metrics = s_table.history(1).select("operationMetrics").collect()[0][0]
        q_metrics = q_table.history(1).select("operationMetrics").collect()[0][0]

        silver_count = int(s_metrics.get("numTargetRowsInserted", 0))
        quarantine_count = int(q_metrics.get("numTargetRowsInserted", 0))

        # Calculate duplicates by comparing Bronze Input vs records sent to Merge
        silver_survivors = int(s_metrics.get("numSourceRows", 0))
        quarantine_survivors = int(q_metrics.get("numSourceRows", 0))
        duplicates_dropped = bronze_count - (silver_survivors + quarantine_survivors)

        max_bronze_ts = batch_max_ts
        success = True
        print(f"Job Successful. Silver: {silver_count}, Quarantine: {quarantine_count}, Dropped: {duplicates_dropped}")

    else:
        max_bronze_ts = last_ingested_ts
        error_msg = "No new data to process"
        success=True
        print(error_msg)

except Exception as e:
    error_msg = str(e)
    print(f"Job Failed: {error_msg}")
    # NOTE(review): the exception is not re-raised, so the cell (and any job running it)
    # completes normally; the failure is visible only through the audit log.

    # Automatic Rollback using Delta Time Travel: restore each table this run wrote to
    # the version captured just before its MERGE.
    for table, flag in [(SILVER_TABLE, silver_load_success), (QUARANTINE_TABLE, quarantine_load_success)]:
        if flag and table in pre_write_versions:
            print(f"Rolling back {table}...")
            try:
                spark.sql(f"RESTORE TABLE {table} TO VERSION AS OF {pre_write_versions[table]}")
            except Exception as rollback_err:
                # Keep trying the remaining tables and surface the failure in the audit log.
                print(f"Rollback of {table} failed: {rollback_err}")
                error_msg += f" | Rollback of {table} failed: {rollback_err}"

finally:
    # Record audit log
    end_ts = datetime.datetime.now()

    log_silver_ingestion(
        spark=spark, run_id=run_id, dataset_name=dataset_name, start_ts=start_ts, 
        end_ts=end_ts, max_bronze_ts=max_bronze_ts, success=success, 
        bronze_count=bronze_count, silver_count=silver_count, 
        quarantine_count=quarantine_count, duplicates_dropped=duplicates_dropped,
        error_msg=error_msg, log_table=LOG_TABLE_SILVER, catalog_table=SILVER_TABLE
    )