<a href="https://colab.research.google.com/github/aristidekanamugire/nyc-taxi-spark-lab/blob/main/spark_nyc_taxi_lab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Lab: Why Spark? Hands‑on with NYC Yellow Taxi (Parquet)

**Duration:** 2–3 hours  
**Goal:** Experience Spark’s advantages on a multi‑million‑row dataset: schema‑on‑read, distributed transforms, joins, window functions, and caching—plus a quick Pandas vs. Spark comparison.

**Dataset (public, free):**
- Monthly NYC Yellow Taxi trip records (Parquet). Example month(s): **2023‑01** (~3M trips).  
  Source: NYC TLC Trip Record Data.  
- Taxi Zone lookup (CSV) for human‑readable zone names.

> This notebook is designed to run in Google Colab or a local Jupyter with PySpark installed.



## 0) Environment Setup

### Option A: Google Colab (simplest)
Just run the following cell to install PySpark and start a session.

### Option B: Local laptop
- Install Java 8+ and `pyspark` via `pip install pyspark`.
- Then run the same cells below.


In [1]:

# If running on Colab or a fresh environment, install PySpark first (no-op if already installed)
!pip -q install pyspark

In [2]:


import pyspark
from pyspark.sql import SparkSession, functions as F, types as T

spark = (
    SparkSession.builder
    .appName("nyc-taxi-spark-lab")
    # Adjust partitions to your environment; 200 is okay for Colab, reduce on small machines
    .config("spark.sql.shuffle.partitions", "200")
    .getOrCreate()
)

spark



## 1) Get the Data (NYC TLC Trip Records + Taxi Zones)

We’ll start with **January 2023** Yellow Taxi trips (Parquet), and the taxi zone lookup table (CSV).  
You can add more months later for the stretch tasks.


In [3]:

# Download monthly Parquet and taxi zone lookup
# Note: These are public files hosted by NYC TLC's CDN.
!wget -q https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-01.parquet
# Optional: uncomment to add another month
# !wget -q https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-02.parquet

!wget -q https://d37ci6vzurychx.cloudfront.net/misc/taxi_zone_lookup.csv

# Quick file listing (optional)
!ls -lh *.parquet *.csv


-rw-r--r-- 1 root root 13K Feb 22  2024 taxi_zone_lookup.csv
-rw-r--r-- 1 root root 46M Mar 20  2023 yellow_tripdata_2023-01.parquet



## 2) First Taste: Read, Schema, and Quick EDA (Spark)


In [4]:

# Read the Parquet file
taxi = spark.read.parquet("yellow_tripdata_2023-01.parquet")

# 1. Inspect schema & a few rows
taxi.printSchema()
taxi.show(5, truncate=False)

# 2. Basic record count & date bounds
taxi_count = taxi.count()
date_bounds = taxi.select(
    F.min("tpep_pickup_datetime").alias("min_pickup"),
    F.max("tpep_dropoff_datetime").alias("max_dropoff")
).collect()

print("Row count (Jan 2023):", taxi_count)
print("Date bounds:", date_bounds)


root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+----


**Checkpoint (answer briefly below):**
1. How many rows are in January 2023?  
2. Which fields are timestamps, categorical codes, and numeric measures? (Hint: `printSchema()` and the TLC data dictionary)



## 3) “This Would Be Annoying in Pandas” (Mini Comparison)

We’ll compute `trip_time_min` and `tip_pct` in **Pandas** on a **sample**, then do the same in **Spark** on **all rows**.


In [5]:

import pandas as pd
import numpy as np


# ---- Pandas on a sample ----
pdf = taxi.limit(200_000).toPandas()  # keep this small enough for your RAM
pdf["trip_time_min"] = (
    pd.to_datetime(pdf["tpep_dropoff_datetime"]) - pd.to_datetime(pdf["tpep_pickup_datetime"])
).dt.total_seconds() / 60



import numpy as np

den = pdf["fare_amount"].replace(0, np.nan)
pdf["tip_pct"] = pdf["tip_amount"] / den
pdf[["trip_time_min","tip_pct"]].describe()



Unnamed: 0,trip_time_min,tip_pct
count,200000.0,199923.0
mean,16.911646,0.189344
std,52.692663,2.25017
min,0.0,-1.172656
25%,6.916667,0.0
50%,11.716667,0.225989
75%,19.483333,0.282645
max,2596.633333,1000.0


In [6]:
from pyspark.sql import functions as F

taxi_spark = (
    taxi
    .withColumn("pickup_ts",  F.col("tpep_pickup_datetime").cast("timestamp"))
    .withColumn("dropoff_ts", F.col("tpep_dropoff_datetime").cast("timestamp"))
    .withColumn(
        "trip_time_min",
        (F.col("dropoff_ts").cast("long") - F.col("pickup_ts").cast("long")) / 60.0
    )
    .withColumn(
        "tip_pct",
        F.when(F.col("fare_amount") > 0,
               F.col("tip_amount") / F.col("fare_amount"))
    )
    .drop("pickup_ts", "dropoff_ts")
)

taxi_spark.select("trip_time_min","tip_pct").summary().show()


+-------+------------------+-------------------+
|summary|     trip_time_min|            tip_pct|
+-------+------------------+-------------------+
|  count|           3066766|            3040607|
|   mean|15.668995167332046|0.21256831393499978|
| stddev|42.594351241955756|  8.515465690030874|
|    min|             -29.2|-0.1297056810403833|
|    25%| 7.116666666666666|0.08264462809917356|
|    50%|11.516666666666667| 0.2404040404040404|
|    75%|              18.3| 0.2930232558139535|
|    max|10029.183333333332|            14500.0|
+-------+------------------+-------------------+




**Prompt:** What differences did you notice (runtime, memory comfort, coding ergonomics) between Pandas and Spark?



## 4) Joins & Aggregations with Taxi Zones
Compute top pickup zones by trips and average tip percentage.


In [7]:

zones = (
    spark.read.csv("taxi_zone_lookup.csv", header=True, inferSchema=True)
    .withColumnRenamed("LocationID", "LocationID_zone")
)

pickup_stats = (
    taxi_spark.groupBy("PULocationID")
    .agg(
        F.count("*").alias("trips"),
        F.mean("tip_pct").alias("avg_tip_pct")
    )
    .orderBy(F.desc("trips"))
)

pickup_stats_named = (
    pickup_stats.join(zones, pickup_stats.PULocationID == zones.LocationID_zone, "left")
    .select("PULocationID","Borough","Zone","trips","avg_tip_pct")
)

pickup_stats_named.show(20, truncate=False)


+------------+---------+---------------------------------+------+--------------------+
|PULocationID|Borough  |Zone                             |trips |avg_tip_pct         |
+------------+---------+---------------------------------+------+--------------------+
|26          |Brooklyn |Borough Park                     |107   |0.20331829873499704 |
|29          |Brooklyn |Brighton Beach                   |76    |0.016304171411700197|
|65          |Brooklyn |Downtown Brooklyn/MetroTech      |1550  |0.16213510093132721 |
|191         |Queens   |Queens Village                   |235   |0.009885360895633376|
|222         |Brooklyn |Starrett City                    |152   |0.05061861899469242 |
|243         |Manhattan|Washington Heights North         |386   |0.16051182201991682 |
|54          |Brooklyn |Columbia Street                  |35    |0.13967661495256645 |
|19          |Queens   |Bellerose                        |79    |0.048844434675436144|
|113         |Manhattan|Greenwich Village N


**Deliverables:**
- Top 10 pickup zones by trip volume (copy the output).  
- Which boroughs have the highest average tip %? Any surprises?



## 5) Window Functions: “Rush Hours” by Day‑of‑Week
Find the top 3 pickup hours for each day of week.


In [8]:

taxi_enriched = (
    taxi_spark
    .withColumn("dow", F.date_format("tpep_pickup_datetime","E"))
    .withColumn("hour", F.hour("tpep_pickup_datetime"))
)

by_hour = taxi_enriched.groupBy("dow","hour").agg(F.count("*").alias("trips"))

from pyspark.sql.window import Window
w = Window.partitionBy("dow").orderBy(F.desc("trips"))

ranked = by_hour.withColumn("rnk", F.row_number().over(w)).filter(F.col("rnk") <= 3)
ranked.orderBy("dow","rnk").show(50, truncate=False)


+---+----+-----+---+
|dow|hour|trips|rnk|
+---+----+-----+---+
|Fri|18  |30489|1  |
|Fri|17  |30091|2  |
|Fri|19  |29100|3  |
|Mon|17  |30162|1  |
|Mon|18  |28902|2  |
|Mon|15  |28831|3  |
|Sat|19  |27950|1  |
|Sat|18  |27918|2  |
|Sat|17  |27483|3  |
|Sun|14  |28081|1  |
|Sun|16  |28033|2  |
|Sun|15  |27358|3  |
|Thu|18  |34409|1  |
|Thu|17  |30992|2  |
|Thu|19  |28670|3  |
|Tue|18  |36405|1  |
|Tue|17  |34731|2  |
|Tue|15  |32007|3  |
|Wed|18  |32941|1  |
|Wed|17  |29555|2  |
|Wed|19  |29181|3  |
+---+----+-----+---+




**Deliverable:** For each weekday, list the **top 3 pickup hours**.



## 6) Data Quality & Filtering
Apply simple quality rules (justify them in words).


In [9]:

clean = taxi_enriched.filter(
    (F.col("trip_time_min").between(1, 180)) &
    (F.col("trip_distance").between(0.1, 60)) &
    (F.col("fare_amount") > 0)
)

clean_count = clean.count()
print("Rows after cleaning:", clean_count)
print("Filtered out %:", round((1 - clean_count / taxi.count())*100, 2))


Rows after cleaning: 2985014
Filtered out %: 2.67



**Deliverable:** What % of trips were filtered out? Provide a 4–6 sentence rationale for your thresholds (e.g., negative times, extreme distances, or zero fares).



## 7) Caching & Partitioning Effects (Performance)
Measure before/after caching, and try repartitioning.


In [10]:

from time import time

# Baseline (no cache)
start = time()
clean.groupBy("PULocationID").count().collect()
t1 = time() - start

# Cache + materialize
clean_cached = clean.cache()
clean_cached.count()
start = time()
clean_cached.groupBy("PULocationID").count().collect()
t2 = time() - start

# Repartition
clean_repart = clean.repartition(8, "PULocationID")  # Adjust 8 for your machine
start = time()
clean_repart.groupBy("PULocationID").count().collect()
t3 = time() - start

print(f"First run (no cache): {t1:.2f}s")
print(f"Second run (cached):  {t2:.2f}s")
print(f"Repartitioned run:    {t3:.2f}s")


First run (no cache): 4.87s
Second run (cached):  1.03s
Repartitioned run:    2.56s



**Deliverable:** Briefly explain why caching and repartitioning help (shuffle reduction, task parallelism, I/O). Include your measured times.



## 8) Stretch Tasks

1. **Union months:** Download Feb 2023 and union with Jan; redo “rush hours” and compare.  
   ```python
   feb = spark.read.parquet("yellow_tripdata_2023-02.parquet")
   all_trips = taxi_spark.unionByName(feb)
   ```
2. **Payment mix:** For each borough, compute percent of trips by payment type.
3. **Tip model (quick & dirty):** Using Spark ML, predict `tip_amount` from distance, time, hour, borough.
4. **Windowed rolling metric:** Average `trip_distance` per zone using a 1‑hour tumbling window.
5. **Outlier detection:** Flag implausible fares (e.g., `fare_amount/trip_distance` above the 99.5th percentile).



## References / Useful Links
- TLC Trip Record Data portal: https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page  
- Yellow Taxi monthly Parquet example (Jan 2023): https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-01.parquet  
- Taxi Zone lookup CSV: https://d37ci6vzurychx.cloudfront.net/misc/taxi_zone_lookup.csv
