In [0]:
import sys
sys.path.insert(0, "../utils")
from logger import log_silver_ingestion
import pyspark.sql.functions as F
import uuid
from delta.tables import DeltaTable
import datetime

In [0]:
# Setting Paths
# Logical dataset name: it is also the table name in each layer below and the
# key this job's entries are filtered on in the run log.
dataset_name = "yellow_taxi_trips"

# Fully-qualified Delta table names (catalog.schema.table).
BRONZE_TABLE = f"nyc_taxi.bronze.{dataset_name}"
SILVER_TABLE = f"nyc_taxi.silver.{dataset_name}"
QUARANTINE_TABLE = f"nyc_taxi.quarantine.{dataset_name}"
LOG_TABLE_SILVER = "nyc_taxi.logs.silver_ingestion_logs"

# Unique id for this execution; stamped on written rows and on the log entry.
run_id = str(uuid.uuid4())

In [0]:
# Setting Types
# Target Spark type for every silver column. The cast step selects exactly these
# columns (in this order), so any bronze column not listed here is dropped.
type_mapping = {
    "year":"int",
    "month":"int",
    "vendor_id":"int",
    "pickup_datetime":"timestamp",
    "dropoff_datetime":"timestamp",
    "trip_distance":"double",
    "passenger_count":"int",
    "rate_code_id":"int",
    "pickup_location_id":"int",
    "dropoff_location_id":"int",
    "payment_type_id":"int",
    "fare_amount":"double",
    "extra_charge":"double",
    "mta_tax":"double",
    "tip_amount":"double",
    "tolls_amount":"double",
    "improvement_surcharge":"double",
    "congestion_surcharge":"double",
    "cbd_congestion_fee":"double",
    "airport_fee":"double",
    "total_amount":"double",
    "store_and_fwd_flag":"boolean",
    "bronze_id":"string"
}

# Setting Columns to build hash
# These columns are joined with "||" and SHA-256 hashed into `trip_id`, which is
# both the deduplication key and the merge key for the silver/quarantine tables.
hash_columns = ["vendor_id","pickup_datetime","dropoff_datetime",
                "rate_code_id","pickup_location_id","dropoff_location_id",
                "payment_type_id"]

# Validation rules for silver table
# Each rule evaluates to its reason label when it fires and to NULL otherwise.
# The job collects the labels of all fired rules into `quarantine_reasons`, and
# any row with at least one reason goes to quarantine instead of silver.
# Comparisons against NULL yield NULL (rule does not fire), which is why every
# range/value check is paired with an explicit `*_null` rule.
# `trip_duration` is derived by the job (seconds from pickup to dropoff) before
# these rules are applied.
# NOTE(review): the `isin(...)` checks accept only the listed exact amounts; any
# other surcharge/fee rate is quarantined — confirm this is intended.
silver_rules = [
    F.when(F.col("vendor_id").isNull(), F.lit("vendor_id_null")),
    F.when(F.col("passenger_count").isNull(), F.lit("passenger_count_null")),
    F.when(~F.col("passenger_count").between(1,6), F.lit("invalid_passenger_count")),
    F.when(F.col("trip_distance").isNull(), F.lit("trip_distance_null")),
    F.when(F.col("trip_distance") < 0, F.lit("invalid_trip_distance")),
    F.when(F.col("store_and_fwd_flag").isNull(), F.lit("store_and_fwd_flag_null")),
    F.when(F.col("rate_code_id").isNull(), F.lit("rate_code_id_null")),
    F.when(F.col("pickup_location_id").isNull() | F.col("dropoff_location_id").isNull(), F.lit("pickup_dropoff_location_null")),
    F.when(F.col("pickup_datetime").isNull(), F.lit("pickup_datetime_null")),
    F.when(F.col("dropoff_datetime").isNull(), F.lit("dropoff_datetime_null")),
    F.when(F.col("pickup_datetime") > F.col("dropoff_datetime"), F.lit("pickup_datetime_after_dropoff")),
    F.when(F.col("payment_type_id").isNull(), F.lit("payment_type_id_null")),
    F.when(F.col("fare_amount").isNull(), F.lit("fare_amount_null")),
    F.when(F.col("fare_amount") < 0, F.lit("invalid_fare_amount")),
    F.when(F.col("extra_charge").isNull(), F.lit("extra_charge_null")),
    F.when(F.col("extra_charge")<0, F.lit("extra_charge_invalid")),
    F.when(F.col("mta_tax").isNull(), F.lit("mta_tax_null")),
    F.when(~F.col("mta_tax").isin(0,0.5), F.lit("mta_tax_invalid")),
    F.when(F.col("tip_amount").isNull(), F.lit("tip_amount_null")),
    # Negative tips are invalid values, not nulls (previously mislabelled "tip_amount_null").
    F.when(F.col("tip_amount")<0, F.lit("tip_amount_invalid")),
    F.when(F.col("tolls_amount").isNull(), F.lit("tolls_amount_null")),
    F.when(F.col("tolls_amount")<0, F.lit("tolls_amount_invalid")),
    F.when(F.col("improvement_surcharge").isNull(), F.lit("improvement_surcharge_null")),
    F.when(~F.col("improvement_surcharge").isin(0,0.3), F.lit("improvement_surcharge_invalid")),
    F.when(F.col("congestion_surcharge").isNull(), F.lit("congestion_surcharge_null")),
    F.when(~F.col("congestion_surcharge").isin(0,2.5), F.lit("congestion_surcharge_invalid")),
    F.when(F.col("airport_fee").isNull(), F.lit("airport_fee_null")),
    F.when(~F.col("airport_fee").isin(0,1.75), F.lit("airport_fee_invalid")),
    F.when(F.col("trip_duration").isNull(), F.lit("trip_duration_null")),
    F.when(F.col("trip_duration") < 0, F.lit("invalid_trip_duration"))
]

In [0]:
# Logging job start
# Job run: read new bronze rows, clean/validate them, merge them into the silver
# and quarantine tables, and always write one entry to the run log.
start_ts = datetime.datetime.now()
# Initialization: every value the `finally` logger reads is defined here, so a
# failure anywhere in the `try` block can still be logged (and no stale value
# from a previous execution of this cell leaks into the log entry).
bronze_df = None
silver_df = None
quarantine_df = None
max_bronze_ts = None
last_ingested_ts = None
success = False
error_msg = ""
silver_load_success = False
quarantine_load_success = False
bronze_count = 0
silver_count = 0
quarantine_count = 0
# Table versions captured immediately before each merge; rollback restores to these.
silver_version_before = None
quarantine_version_before = None
# Set for initial run indication: 0 reads the watermark from the run log
# (incremental load); any other value skips it and reloads all of bronze.
init_run = 0
# Setting lookback window: re-read this many minutes before the watermark so
# late-committed bronze rows are picked up; re-read rows are skipped by the
# trip_id merge.
look_back_minutes = 5


def _latest_version(table_name):
    """Return the newest commit version of the Delta table `table_name`."""
    return DeltaTable.forName(spark, table_name).history(1).select("version").first()[0]


def _rows_inserted_since(table_name, version_before, source_df):
    """Return how many rows the newest commit of `table_name` inserted.

    Returns 0 when nothing was committed after `version_before`. The value comes
    from the commit's `numTargetRowsInserted` operation metric, so merge source
    rows that already existed in the target are not counted. If the metric is
    unavailable, falls back to `source_df.count()` (an upper bound).
    NOTE(review): assumes no other writer commits between our merge and this call.
    """
    last_commit = (
        DeltaTable.forName(spark, table_name)
        .history(1)
        .select("version", "operationMetrics")
        .first()
    )
    if last_commit["version"] <= version_before:
        return 0
    metrics = last_commit["operationMetrics"] or {}
    inserted = metrics.get("numTargetRowsInserted")
    return int(inserted) if inserted is not None else source_df.count()


def _restore_to_version(table_name, version):
    """Restore `table_name` to `version` if any commit was made after it."""
    if _latest_version(table_name) > version:
        spark.sql(f"RESTORE TABLE {table_name} TO VERSION AS OF {version}")


# NOTE(review): CREATE TABLE ... USING DELTA without a column list may be rejected
# when the table does not exist yet — confirm the log table is pre-created.
spark.sql(f"CREATE TABLE IF NOT EXISTS {LOG_TABLE_SILVER} USING DELTA")

try:
    # Only incremental runs read the watermark; a full reload leaves it as None.
    if init_run == 0:
        # Retrieve the newest bronze ingest_ts recorded by a successful run.
        print("Retrieving last ingested timestamp...")
        log_history = spark.read.table(LOG_TABLE_SILVER).filter(
            (F.col("status") == "Success") & (F.col("dataset_name") == dataset_name)
        )
        # max() over zero rows is NULL (-> None), so no separate count() scan is needed.
        last_ingested_ts = log_history.agg(F.max("max_bronze_ts")).collect()[0][0]

    # Ingest Bronze data incrementally
    print("Ingesting bronze data...")
    bronze_df = spark.read.table(BRONZE_TABLE)
    if last_ingested_ts is not None:
        read_watermark_ts = last_ingested_ts - datetime.timedelta(minutes=look_back_minutes)
        # Column expression rather than a SQL string built from the formatted datetime.
        bronze_df = bronze_df.filter(F.col("ingest_ts") > F.lit(read_watermark_ts))
    # else: first run (or forced full reload) -> ingest all

    # Count the batch and capture its newest bronze ingest_ts in one pass. This is
    # the next run's watermark, so it must come from bronze itself: the silver
    # ingest_ts set below is a fresh current_timestamp(), which would jump past
    # bronze rows committed while this job was running.
    bronze_stats = bronze_df.agg(
        F.count(F.lit(1)).alias("row_count"),
        F.max("ingest_ts").alias("max_ingest_ts"),
    ).first()
    bronze_count = bronze_stats["row_count"]
    batch_max_bronze_ts = bronze_stats["max_ingest_ts"]

    if bronze_count > 0:
        print(f"Ingested {bronze_count} records.")
        # Standardizing Column Names and casting data types
        df_renamed = bronze_df.withColumnsRenamed(
            {
                "tpep_pickup_datetime": "pickup_datetime",
                "tpep_dropoff_datetime": "dropoff_datetime",
                "VendorID":"vendor_id",
                "RatecodeID":"rate_code_id",
                "PULocationID":"pickup_location_id",
                "DOLocationID":"dropoff_location_id",
                "payment_type":"payment_type_id",
                "extra":"extra_charge",
                "Airport_fee":"airport_fee",
                "run_id":"bronze_id"
            }
        )

        # "Y" -> True; every other value, including NULL, -> False.
        # NOTE(review): this means the store_and_fwd_flag_null rule can never fire.
        df_renamed = df_renamed.withColumn(
            "store_and_fwd_flag",
            F.when(F.col("store_and_fwd_flag") == "Y", True).otherwise(False),
        )
        # Keep only the type_mapping columns, cast to their target types.
        df_cast = df_renamed.select([F.col(col).cast(dtype) for col, dtype in type_mapping.items()])

        # Transforming Data
        # trip_id: SHA-256 of hash_columns joined by "||" (concat_ws skips NULLs).
        df_with_hash = df_cast.withColumn("trip_id", F.sha2(F.concat_ws("||", *hash_columns), 256))
        df_with_pay = df_with_hash.withColumn(
            "driver_pay", (F.col("fare_amount") + F.col("tip_amount")).cast("double")
        )
        # trip_duration in whole seconds (difference of the epoch-second casts).
        df_with_duration = df_with_pay.withColumn(
            "trip_duration",
            (F.col("dropoff_datetime").cast("long") - F.col("pickup_datetime").cast("long")).cast("int"),
        )
        # Metadata inclusion: silver-side load time and this run's id.
        df_with_metadata = (
            df_with_duration
            .withColumn("ingest_ts", F.current_timestamp())
            .withColumn("run_id", F.lit(run_id))
        )

        # Deduplication on the content hash.
        # NOTE(review): dropDuplicates keeps an arbitrary row per trip_id, and the
        # plan is recomputed for each action below; consider a deterministic
        # tie-break so the silver and quarantine splits are guaranteed to agree.
        df_deduped = df_with_metadata.dropDuplicates(["trip_id"])

        # Validating business and data quality rules: accumulate each fired rule's label.
        df_validated = df_deduped.withColumn("quarantine_reasons", F.array())
        for rule in silver_rules:
            df_validated = df_validated.withColumn(
                "quarantine_reasons",
                F.when(rule.isNotNull(), F.array_union(F.col("quarantine_reasons"), F.array(rule)))
                .otherwise(F.col("quarantine_reasons")),
            )

        # Splitting Data into Silver (no reasons) and Quarantine (at least one reason)
        silver_df = df_validated.filter(F.size(F.col("quarantine_reasons")) == 0).drop("quarantine_reasons")
        quarantine_df = df_validated.filter(F.size(F.col("quarantine_reasons")) > 0)

        # Writing silver data via merge, inserting new rows only
        print("Writing to silver table...")
        silver_delta = DeltaTable.forName(spark, SILVER_TABLE)
        silver_version_before = _latest_version(SILVER_TABLE)
        (
            silver_delta.alias("t")
            .merge(silver_df.alias("s"), "t.trip_id = s.trip_id")
            .whenNotMatchedInsertAll()
            .execute()
        )
        silver_load_success = True
        # Rows actually inserted (lookback rows already in silver are not counted).
        silver_count = _rows_inserted_since(SILVER_TABLE, silver_version_before, silver_df)

        print("Silver write successful.")
        print(f"Inserted {silver_count} records into silver table.")
        # Writing Quarantine Data
        print("Writing to quarantine table...")
        spark.sql(f"CREATE TABLE IF NOT EXISTS {QUARANTINE_TABLE} USING DELTA")
        quarantine_delta = DeltaTable.forName(spark, QUARANTINE_TABLE)
        quarantine_version_before = _latest_version(QUARANTINE_TABLE)
        (
            quarantine_delta.alias("t")
            .merge(quarantine_df.alias("s"), "t.trip_id = s.trip_id")
            .whenNotMatchedInsertAll()
            .execute()
        )
        quarantine_load_success = True
        quarantine_count = _rows_inserted_since(QUARANTINE_TABLE, quarantine_version_before, quarantine_df)
        print("Quarantine write successful.")
        print(f"Inserted {quarantine_count} records into quarantine table.")

        # Job succeeded: advance the watermark to the newest bronze row of this batch.
        max_bronze_ts = batch_max_bronze_ts
        success = True
        error_msg = ""
    else:
        silver_count = 0
        quarantine_count = 0
        max_bronze_ts = last_ingested_ts
        # NOTE(review): an empty batch is logged with success=False; confirm intended.
        error_msg = "No new data to process"
        print(error_msg)

except Exception as e:
    print(f"Job Failed\nError:{e}")
    success = False
    error_msg = str(e)
    max_bronze_ts = last_ingested_ts
    # Undo any merge that committed so silver and quarantine stay consistent.
    if silver_load_success:
        print("Rolling back Silver table...")
        prev_version = silver_version_before
        _restore_to_version(SILVER_TABLE, prev_version)
        silver_load_success = False  # Reset flag after rollback
        silver_count = 0
        print(f"Silver rollback to version {prev_version} successful.")
    if quarantine_load_success:
        print("Rolling back Quarantine table...")
        prev_version = quarantine_version_before
        _restore_to_version(QUARANTINE_TABLE, prev_version)
        quarantine_load_success = False
        quarantine_count = 0
        print(f"Quarantine rollback to version {prev_version} successful.")
    if not silver_load_success and not quarantine_load_success:
        print("No data written to tables.")

finally:
    end_ts = datetime.datetime.now()
    print(f"Job completed in {end_ts-start_ts}")
    print(f"Logging job run to {LOG_TABLE_SILVER}...")
    log_silver_ingestion(
        spark=spark,
        run_id=run_id,
        dataset_name=dataset_name,
        start_ts=start_ts,
        end_ts=end_ts,
        max_bronze_ts=max_bronze_ts,
        success=success,
        bronze_count=bronze_count,
        silver_count=silver_count,
        quarantine_count=quarantine_count,
        error_msg=error_msg,
        log_table=LOG_TABLE_SILVER,
        catalog_table=SILVER_TABLE,
    )
    print("Log entry completed.")