# NYC Taxi Data Analysis with PySpark

**Important:** This notebook requires Python 3.11 or 3.12. Python 3.13 is not compatible with PySpark due to Java security manager changes.

If you see the error `getSubject is supported only if a security manager is allowed`, you need to:
1. Select a Python 3.12 kernel: Click on the kernel selector in the top right
2. Choose "Select Another Kernel" → "Python Environments"
3. Select a Python 3.12 interpreter (not 3.13)

In [1]:
import pandas as pd
import pyarrow.parquet as pq

pq_df = pd.read_parquet("data/raw/yellow_tripdata_2023-02.parquet")
print(pq_df.head())

table = pq.read_table("data/raw/yellow_tripdata_2023-02.parquet")
print(table)

   VendorID tpep_pickup_datetime tpep_dropoff_datetime  passenger_count  \
0         1  2023-02-01 00:32:53   2023-02-01 00:34:34              2.0   
1         2  2023-02-01 00:35:16   2023-02-01 00:35:30              1.0   
2         2  2023-02-01 00:35:16   2023-02-01 00:35:30              1.0   
3         1  2023-02-01 00:29:33   2023-02-01 01:01:38              0.0   
4         2  2023-02-01 00:12:28   2023-02-01 00:25:46              1.0   

   trip_distance  RatecodeID store_and_fwd_flag  PULocationID  DOLocationID  \
0           0.30         1.0                  N           142           163   
1           0.00         1.0                  N            71            71   
2           0.00         1.0                  N            71            71   
3          18.80         1.0                  N           132            26   
4           3.22         1.0                  N           161           145   

   payment_type  fare_amount  extra  mta_tax  tip_amount  tolls_amount  \


In [2]:
import os
import sys

# CRITICAL FIX for Java 21+ compatibility with PySpark
# Java 21+ removed security manager, need to enable it with special flag
# Must be set BEFORE any PySpark imports
import subprocess

# Check if JAVA_HOME is set, if not find Java location
java_cmd = "java"
try:
    result = subprocess.run([java_cmd, "-version"], capture_output=True, text=True)
    print(f"Java version detected: {result.stderr.split()[2]}")
except:
    print("Could not detect Java version")

# Set environment variables to enable security manager for Java 21+
os.environ["PYSPARK_SUBMIT_ARGS"] = (
    "--driver-java-options '-Djava.security.manager=allow' pyspark-shell"
)
os.environ["SPARK_SUBMIT_OPTS"] = "-Djava.security.manager=allow"

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    LongType,
    StringType,
    TimestampNTZType,
)

taxi_schema = StructType(
    [
        StructField("VendorID", LongType(), True),
        StructField("tpep_pickup_datetime", TimestampNTZType(), True),
        StructField("tpep_dropoff_datetime", TimestampNTZType(), True),
        StructField("passenger_count", LongType(), True),
        StructField("trip_distance", DoubleType(), True),
        StructField("RatecodeID", LongType(), True),
        StructField("store_and_fwd_flag", StringType(), True),
        StructField("PULocationID", LongType(), True),
        StructField("DOLocationID", LongType(), True),
        StructField("payment_type", LongType(), True),
        StructField("fare_amount", DoubleType(), True),
        StructField("extra", DoubleType(), True),
        StructField("mta_tax", DoubleType(), True),
        StructField("tip_amount", DoubleType(), True),
        StructField("tolls_amount", DoubleType(), True),
        StructField("improvement_surcharge", DoubleType(), True),
        StructField("total_amount", DoubleType(), True),
        StructField("congestion_surcharge", DoubleType(), True),
        StructField("airport_fee", DoubleType(), True),
    ]
)

# Create Spark session with explicit Java options for Java 21+ compatibility
spark = (
    SparkSession.builder.appName("Lecture-Demo")
    .master("local[4]")
    .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow")
    .config("spark.executor.extraJavaOptions", "-Djava.security.manager=allow")
    .getOrCreate()
)
file_paths = [
    "data/raw/yellow_tripdata_2023-02.parquet",
    "data/raw/yellow_tripdata_2023-03.parquet",
]
df = spark.read.schema(taxi_schema).parquet(*file_paths)

Java version detected: "23.0.2"


In [3]:
plan = df.filter(F.col("passenger_count") > 0).groupBy("PULocationID").count()
plan.explain("formatted")  # Nothing runs yet

== Physical Plan ==
AdaptiveSparkPlan (7)
+- HashAggregate (6)
   +- Exchange (5)
      +- HashAggregate (4)
         +- Project (3)
            +- Filter (2)
               +- Scan parquet  (1)


(1) Scan parquet 
Output [2]: [passenger_count#3L, PULocationID#7L]
Batched: true
Location: InMemoryFileIndex [file:/c:/Users/adurs/OneDrive/Documents/repos/WashU/CSE 5114 - Data Manipulation/Assignments/5 - Spark/CSE5114-spark-nyc-taxi/data/raw/yellow_tripdata_2023-02.parquet, ... 1 entries]
PushedFilters: [IsNotNull(passenger_count), GreaterThan(passenger_count,0)]
ReadSchema: struct<passenger_count:bigint,PULocationID:bigint>

(2) Filter
Input [2]: [passenger_count#3L, PULocationID#7L]
Condition : (isnotnull(passenger_count#3L) AND (passenger_count#3L > 0))

(3) Project
Output [1]: [PULocationID#7L]
Input [2]: [passenger_count#3L, PULocationID#7L]

(4) HashAggregate
Input [1]: [PULocationID#7L]
Keys [1]: [PULocationID#7L]
Functions [1]: [partial_count(1)]
Aggregate Attributes [1]: [count#4

In [4]:
plan.orderBy(F.desc("count")).show(5)  # Action => triggers a Job

+------------+------+
|PULocationID| count|
+------------+------+
|         132|297868|
|         161|286788|
|         237|278842|
|         236|254010|
|         162|220572|
+------------+------+
only showing top 5 rows


In [5]:
spark.conf.set("spark.sql.shuffle.partitions", "8")
by_zone = df.groupBy("PULocationID").count()
by_zone.explain("formatted")  # Look for Exchange (shuffle) node
by_zone.orderBy(F.desc("count")).show(10)

== Physical Plan ==
AdaptiveSparkPlan (5)
+- HashAggregate (4)
   +- Exchange (3)
      +- HashAggregate (2)
         +- Scan parquet  (1)


(1) Scan parquet 
Output [1]: [PULocationID#7L]
Batched: true
Location: InMemoryFileIndex [file:/c:/Users/adurs/OneDrive/Documents/repos/WashU/CSE 5114 - Data Manipulation/Assignments/5 - Spark/CSE5114-spark-nyc-taxi/data/raw/yellow_tripdata_2023-02.parquet, ... 1 entries]
ReadSchema: struct<PULocationID:bigint>

(2) HashAggregate
Input [1]: [PULocationID#7L]
Keys [1]: [PULocationID#7L]
Functions [1]: [partial_count(1)]
Aggregate Attributes [1]: [count#71L]
Results [2]: [PULocationID#7L, count#72L]

(3) Exchange
Input [2]: [PULocationID#7L, count#72L]
Arguments: hashpartitioning(PULocationID#7L, 8), ENSURE_REQUIREMENTS, [plan_id=95]

(4) HashAggregate
Input [2]: [PULocationID#7L, count#72L]
Keys [1]: [PULocationID#7L]
Functions [1]: [count(1)]
Aggregate Attributes [1]: [count(1)#70L]
Results [2]: [PULocationID#7L, count(1)#70L AS count#50L]

(5) A

In [6]:
df_clean = df.select(
    F.col("tpep_pickup_datetime").alias("pickup_ts"),
    F.col("passenger_count").cast("int").alias("passengers"),
    F.col("PULocationID").cast("int").alias("PU"),
    F.col("DOLocationID").cast("int").alias("DO"),
    F.col("total_amount").cast("double").alias("total"),
).filter((F.col("passengers") > 0) & (F.col("total") >= 0))

df_clean.createOrReplaceTempView("trips_clean")

zones = (
    spark.read.option("header", True)
    .csv("data/raw/taxi_zone_lookup.csv")
    .select(
        F.col("LocationID").cast("int").alias("LocationID"),
        F.col("Borough"),
        F.col("Zone"),
    )
)

In [7]:
enriched = (
    df_clean.join(F.broadcast(zones), df_clean.PU == zones.LocationID, "left")
    .drop("LocationID")
    .withColumnRenamed("Zone", "PU_Zone")
)
enriched.select("PU", "PU_Zone", "total").show(5)

+---+--------------------+-----+
| PU|             PU_Zone|total|
+---+--------------------+-----+
|142| Lincoln Square East|  9.4|
| 71|East Flatbush/Far...|  5.5|
|161|      Midtown Center| 25.3|
|148|     Lower East Side|32.25|
|137|            Kips Bay| 50.0|
+---+--------------------+-----+
only showing top 5 rows


In [8]:
subset = df_clean.filter(F.col("PU") == 1)
subset_cached = subset.cache()
subset_cached.count()  # materialize

778

In [9]:
# Reused multiple times
subset_cached.groupBy("DO").count().count()
subset_cached.agg(F.sum("total")).collect()
subset_cached.unpersist()

DataFrame[pickup_ts: timestamp_ntz, passengers: int, PU: int, DO: int, total: double]

In [10]:
spark.conf.set("spark.sql.shuffle.partitions", "8")
_ = (
    df_clean.groupBy("PU")
    .agg(F.count("*").alias("trips"))
    .orderBy(F.desc("trips"))
    .show(10)
)

+---+------+
| PU| trips|
+---+------+
|132|293087|
|161|284623|
|237|276938|
|236|252419|
|162|218919|
|186|212418|
|230|206365|
|138|195118|
|142|193450|
|170|178950|
+---+------+
only showing top 10 rows


In [11]:
import time

spark.conf.set("spark.sql.shuffle.partitions", "8")
start = time.perf_counter()
_ = df_clean.groupBy("PU").agg(F.sum("total").alias("revenue")).count()
t1 = time.perf_counter() - start

# Example tweak: increase partitions if underutilized
spark.conf.set("spark.sql.shuffle.partitions", "16")
start = time.perf_counter()
_ = df_clean.groupBy("PU").agg(F.sum("total").alias("revenue")).count()
t2 = time.perf_counter() - start

print(f"Baseline: {t1:.2f}s | Tuned: {t2:.2f}s (lower is better)")

Baseline: 0.84s | Tuned: 0.64s (lower is better)


In [12]:
spark.stop()