# NYC Taxi ETL - Incremental Ingestion

In [1]:
import json
from pathlib import Path
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, DoubleType, StringType
from functools import reduce

In [2]:
INBOX_PATH = Path("data/inbox")
STATE_DIR = Path("state")
MANIFEST_PATH = STATE_DIR / "manifest.json"

STATE_DIR.mkdir(parents=True, exist_ok=True)

In [3]:
# Initialize Spark Session
spark = (SparkSession.builder
    .appName("NYC_Taxi_Incremental_Ingestion")
    .master("local[*]")
    .getOrCreate())

spark.sparkContext.setLogLevel("WARN")
print(f"Spark session initialized: {spark.version}")

Spark session initialized: 4.1.0


In [4]:
# Manifest management functions
def load_manifest():
    """Load the manifest tracking processed files."""
    if MANIFEST_PATH.exists():
        try:
            with open(MANIFEST_PATH, 'r') as f:
                manifest = json.load(f)

            if "processed_files" not in manifest:
                manifest["processed_files"] = []
            if "last_run" not in manifest:
                manifest["last_run"] = None
            return manifest
        except json.JSONDecodeError as e:
            print(f"Warning: Manifest file corrupted, creating new manifest. Error: {e}")
            return {"processed_files": [], "last_run": None}
        except Exception as e:
            print(f"Error reading manifest: {e}")
            raise
    return {"processed_files": [], "last_run": None}

def save_manifest(manifest):
    """Save the manifest to disk."""
    try:
        manifest["last_run"] = datetime.now().isoformat()
        with open(MANIFEST_PATH, 'w') as f:
            json.dump(manifest, indent=2, fp=f)
        print(f"Manifest saved: {len(manifest['processed_files'])} files tracked")
    except Exception as e:
        print(f"Error saving manifest: {e}")
        raise

def add_to_manifest(manifest, file_info):
    """Add a processed file to the manifest."""
    manifest["processed_files"].append(file_info)

print("Manifest functions defined")

Manifest functions defined


In [5]:
# Detect new
def get_new_files(inbox_path, manifest):
    """Identify files in inbox that haven't been processed yet."""
    inbox_path = Path(inbox_path)

    if not inbox_path.exists():
        raise FileNotFoundError(f"Inbox path does not exist: {inbox_path}")

    processed_filenames = {f["filename"] for f in manifest["processed_files"]}

    all_parquet = sorted(inbox_path.glob("*.parquet"))
    print(f"Parquet files found in inbox: {len(all_parquet)}")

    inbox_files = []
    for file in all_parquet:
        if "zone_lookup" in file.name: # for ignoring the zone lookup
            continue

        if file.name not in processed_filenames:
            file_stat = file.stat()
            inbox_files.append({
                "filename": file.name,
                "path": str(file),
                "size_bytes": file_stat.st_size
            })

    return inbox_files

manifest = load_manifest()
new_files = get_new_files(INBOX_PATH, manifest)

print("Manifest loaded")
print(f"  - Previously processed: {len(manifest['processed_files'])} files")
print(f"  - New files found: {len(new_files)} files")
if new_files:
    for f in new_files:
        print(f"    → {f['filename']} ({f['size_bytes']:,} bytes)")

Parquet files found in inbox: 3
Manifest loaded
  - Previously processed: 2 files
  - New files found: 0 files


In [6]:
# Process new files and update manifest
if len(new_files) == 0:
    print("No new files to process.")
else:
    print(f"Processing {len(new_files)} new file(s)...")
    
    all_dataframes = []
    
    for file_info in new_files:
        print(f"Processing: {file_info['filename']}")
        
        df = spark.read.parquet(file_info['path'])
        row_count = df.count()
        
        # Parse and cast types correctly
        df = (
            df
            .withColumn("VendorID", F.col("VendorID").cast(IntegerType()))
            .withColumn("tpep_pickup_datetime", F.to_timestamp("tpep_pickup_datetime"))
            .withColumn("tpep_dropoff_datetime", F.to_timestamp("tpep_dropoff_datetime"))
            .withColumn("passenger_count", F.col("passenger_count").cast(IntegerType()))
            .withColumn("trip_distance", F.col("trip_distance").cast(DoubleType()))
            .withColumn("RatecodeID", F.col("RatecodeID").cast(IntegerType()))
            .withColumn("store_and_fwd_flag", F.col("store_and_fwd_flag").cast(StringType()))
            .withColumn("PULocationID", F.col("PULocationID").cast(IntegerType()))
            .withColumn("DOLocationID", F.col("DOLocationID").cast(IntegerType()))
            .withColumn("payment_type", F.col("payment_type").cast(IntegerType()))
            .withColumn("fare_amount", F.col("fare_amount").cast(DoubleType()))
            .withColumn("extra", F.col("extra").cast(DoubleType()))
            .withColumn("mta_tax", F.col("mta_tax").cast(DoubleType()))
            .withColumn("tip_amount", F.col("tip_amount").cast(DoubleType()))
            .withColumn("tolls_amount", F.col("tolls_amount").cast(DoubleType()))
            .withColumn("improvement_surcharge", F.col("improvement_surcharge").cast(DoubleType()))
            .withColumn("total_amount", F.col("total_amount").cast(DoubleType()))
            .withColumn("congestion_surcharge", F.col("congestion_surcharge").cast(DoubleType()))
            .withColumn("Airport_fee", F.col("Airport_fee").cast(DoubleType()))
            .withColumn("cbd_congestion_fee", F.col("cbd_congestion_fee").cast(DoubleType()))
        )
        
        print("\nBad Row Examples")
        # negative or zero trip_distance
        bad_distance = df.filter((F.col("trip_distance").isNull()) | (F.col("trip_distance") <= 0)).limit(3)
        bad_distance_count = bad_distance.count()
        if bad_distance_count > 0:
            print(f"  Example 1 - Invalid trip_distance (null or ≤0): {bad_distance_count} rows found")
            bad_distance.select("trip_distance", "tpep_pickup_datetime", "fare_amount").show(3, truncate=False)
            print("Filtered out (trip_distance must be > 0)\n")
        
        # dropoff before pickup
        bad_time = df.filter(
            F.col("tpep_pickup_datetime").isNotNull() &
            F.col("tpep_dropoff_datetime").isNotNull() &
            (F.col("tpep_dropoff_datetime") < F.col("tpep_pickup_datetime"))
        ).limit(3)
        bad_time_count = bad_time.count()
        if bad_time_count > 0:
            print(f"  Example 2 - Time travel (dropoff before pickup): {bad_time_count} rows found")
            bad_time.select("tpep_pickup_datetime", "tpep_dropoff_datetime", "trip_distance").show(3, truncate=False)
            print("Filtered out (dropoff must be >= pickup)\n")
        
        # invalid passenger_count
        bad_passengers = df.filter(
            F.col("passenger_count").isNotNull() & 
            ~F.col("passenger_count").between(0, 8)
        ).limit(3)
        bad_passengers_count = bad_passengers.count()
        if bad_passengers_count > 0:
            print(f"  Example 3 - Invalid passenger_count (outside 0-8): {bad_passengers_count} rows found")
            bad_passengers.select("passenger_count", "trip_distance", "fare_amount").show(3, truncate=False)
            print("Filtered out (passenger_count must be 0-8)\n")
        
        # Apply data cleaning rules
        df = df.filter(
            F.col("tpep_pickup_datetime").isNotNull() &
            F.col("tpep_dropoff_datetime").isNotNull() &
            (F.col("tpep_dropoff_datetime") >= F.col("tpep_pickup_datetime")) &
            F.col("trip_distance").isNotNull() & (F.col("trip_distance") > 0) &
            F.col("passenger_count").isNotNull() & F.col("passenger_count").between(0, 8) &
            F.col("PULocationID").isNotNull() & F.col("DOLocationID").isNotNull()
        )
        # Money related cols should be >=0
        money_related_cols = [
            "extra","mta_tax","tip_amount","tolls_amount",
            "improvement_surcharge","congestion_surcharge",
            "Airport_fee","cbd_congestion_fee"
        ]
        
        for money in money_related_cols:
            df = df.withColumn(money, F.coalesce(F.col(money), F.lit(0.0)))
            df = df.filter(F.col(money) >= 0)


        df = df.withColumn("fare_amount", F.coalesce(F.col("fare_amount"), F.lit(0.0)))

        # Only total amount can not be null
        df = df.filter(
            F.col("total_amount").isNotNull() &
            (F.col("total_amount") >= 0)
        )
        new_row_count = df.count()
        
        # Deduplicate records using a defined key
        dedup_key = [
            "VendorID",
            "tpep_pickup_datetime",
            "tpep_dropoff_datetime",
            "PULocationID",
            "DOLocationID",
            "trip_distance",
            "total_amount",
        ]
        df = df.dropDuplicates(dedup_key)

        after_dedup_row_count = df.count()

        print(f"  Rows: {row_count:,}")
        print(f"  After Cleaning Rows: {new_row_count:,}")
        print(f"  After dedup Rows: {after_dedup_row_count:,}")
        print(f"  Size: {file_info['size_bytes']:,} bytes")

        ingested_at_value = datetime.now().isoformat()

        df = (
            df
            .withColumn("trip_duration_minutes", (F.unix_timestamp("tpep_dropoff_datetime") - F.unix_timestamp("tpep_pickup_datetime")) / 60.0)
            .withColumn("pickup_date", F.to_date("tpep_pickup_datetime"))
            .withColumn("source_file", F.lit(file_info["filename"]))
            .withColumn("ingested_at", F.lit(ingested_at_value))
        )

        all_dataframes.append(df)

        file_metadata = {
            "filename": file_info['filename'],
            "size_bytes": file_info['size_bytes'],
            "raw_row_count": row_count,
            "clean_row_count": new_row_count,
            "after_dedup_row_count": after_dedup_row_count,
            "processed_at": datetime.now().isoformat()
        }
        add_to_manifest(manifest, file_metadata)
        
        print(f"Added to manifest")
    
    save_manifest(manifest)
    
    print("Processing complete!")


No new files to process.


In [7]:
OUTBOX_PATH = Path("data/outbox")
OUT_PATH = OUTBOX_PATH / "trips_enriched.parquet"
LOOKUP_PATH = INBOX_PATH / "taxi_zone_lookup.parquet"
OUTBOX_PATH.mkdir(parents=True, exist_ok=True)

if len(new_files) == 0:
    print("No new cleaned dataframes to enrich")
else:
    print("Enriching files")
    trips_new = reduce(lambda a, b: a.unionByName(b, allowMissingColumns=True), 
                       all_dataframes)

    zones = (spark.read.parquet(str(LOOKUP_PATH)).select("LocationID", "Zone"))

    zones = F.broadcast(zones)

    pu = zones.select(
        F.col("LocationID").alias("PULocationID"),
        F.col("Zone").alias("PU_Zone"),
    )
    do = zones.select(
        F.col("LocationID").alias("DOLocationID"),
        F.col("Zone").alias("DO_Zone"),
    )

    trips_enriched_new = (
        trips_new
        .join(pu, on="PULocationID", how="left")
        .join(do, on="DOLocationID", how="left")
    )

    required_cols = [
        "tpep_pickup_datetime",
        "tpep_dropoff_datetime",
        "PULocationID",
        "DOLocationID",
        "PU_Zone",
        "DO_Zone",
        "passenger_count",
        "trip_distance",
        "trip_duration_minutes",
        "pickup_date",
        "source_file",
        "ingested_at",
        "fare_amount"
    ]
    trips_enriched_new = trips_enriched_new.select(*required_cols)

    # If we already have output data, load it and add the new trips to it
    if OUT_PATH.exists():
        trips_prev = spark.read.parquet(str(OUT_PATH))
        trips_all = trips_prev.unionByName(trips_enriched_new, allowMissingColumns=True)
    else:
        trips_all = trips_enriched_new
    
    trips_all = trips_all.cache()
    row_c = trips_all.count()
    trips_all.write.mode("overwrite").parquet(str(OUT_PATH))
    print("Final row count:", row_c)
    print(f"Wrote enriched dataset to: {OUT_PATH}")

No new cleaned dataframes to enrich


In [8]:
# CUSTOM SCENARIO: Flag suspicious trips
print("\n--- Custom Scenario: Flagging Suspicious Trips ---")

trips_with_suspicious = trips_all.withColumn(
    "is_suspicious",
    F.when(
        (F.col("trip_duration_minutes") > 120) | 
        (F.col("trip_distance") > 50) | 
        (F.col("fare_amount") < 0),
        True
    ).otherwise(False)
)

total_trips = trips_with_suspicious.count()
suspicious_trips = trips_with_suspicious.filter(F.col("is_suspicious") == True)
suspicious_count = suspicious_trips.count()

print(f"Total trips: {total_trips:,}")
print(f"Suspicious trips found: {suspicious_count:,}")
if total_trips > 0:
    print(f"Percentage: {suspicious_count/total_trips*100:.2f}%")

try:
    trips_with_suspicious.write.mode("overwrite").parquet(str(OUT_PATH))
    print(f"✓ Main output (with is_suspicious flag) written to: {OUT_PATH}")
    
    SUSPICIOUS_PATH = OUTBOX_PATH / "suspicious_trips.parquet"
    suspicious_trips.write.mode("overwrite").parquet(str(SUSPICIOUS_PATH))
    print(f"✓ Suspicious trips written to: {SUSPICIOUS_PATH}")
    
except Exception as e:
    print(f"⚠ Warning: Could not write files due to Spark connection timeout")
    print(f"   But data was successfully calculated!")
    print(f"   Try restarting the kernel and running just this cell again")

print("--- Custom Scenario Complete ---\n")


--- Custom Scenario: Flagging Suspicious Trips ---


NameError: name 'trips_all' is not defined