In [1]:
from pyspark.sql import SparkSession

# Spark session for exploring the NYC taxi sample.
# master('local') runs Spark in-process with a single worker thread;
# 'local[*]' would use every available core instead.
spark = (
    SparkSession.builder
    .master('local')
    .appName('nyctaxi')
    .getOrCreate()
)

In [2]:
# Load the sample CSV. Column names come from the header row, and Spark infers
# the column types (schema inference costs one extra pass over the file).
df = spark.read.csv(
    "sample/Sample NYC Data.csv",
    header=True,
    inferSchema=True,
)

In [3]:
# Render a DataFrame as a table when it is the last expression in a cell.
# OK for exploration, not great for performance: every such display runs a Spark job.
spark.conf.set("spark.sql.repl.eagerEval.enabled", True)

In [4]:
# Preview the raw rows as loaded (rendered as a table because eager evaluation is enabled).
df

medallion,hack_license,vendor_id,rate_code,store_and_fwd_flag,pickup_datetime,dropoff_datetime,passenger_count,pickup_longitude,pickup_latitude,dropoff_longitude,dropoff_latitude
89D227B655E5C82AE...,BA96DE419E711691B...,CMT,1,N,01-01-13 15:11,01-01-13 15:18,4,-73.978165,40.757977,-73.989838,40.751171
0BD7C8F5BA12B88E0...,9FD8F69F0804BDB55...,CMT,1,N,06-01-13 00:18,06-01-13 00:22,1,-74.006683,40.731781,-73.994499,40.75066
0BD7C8F5BA12B88E0...,9FD8F69F0804BDB55...,CMT,1,N,05-01-13 18:49,05-01-13 18:54,1,-74.004707,40.73777,-74.009834,40.726002
DFD2202EE08F7A8DC...,51EE87E3205C985EF...,CMT,1,N,07-01-13 23:54,07-01-13 23:58,2,-73.974602,40.759945,-73.984734,40.759388
DFD2202EE08F7A8DC...,51EE87E3205C985EF...,CMT,1,N,07-01-13 23:25,07-01-13 23:34,1,-73.97625,40.748528,-74.002586,40.747868
20D9ECB2CA0767CF7...,598CCE5B9C1918568...,CMT,1,N,07-01-13 15:27,07-01-13 15:38,1,-73.966743,40.764252,-73.983322,40.743763
496644932DF393260...,513189AD756FF14FE...,CMT,1,N,08-01-13 11:01,08-01-13 11:08,1,-73.995804,40.743977,-74.007416,40.744343
0B57B9633A2FECD3D...,CCD4367B417ED6634...,CMT,1,N,07-01-13 12:39,07-01-13 13:10,3,-73.989937,40.756775,-73.86525,40.77063
2C0E91FF20A856C89...,1DA2F6543A62B8ED9...,CMT,1,N,07-01-13 18:15,07-01-13 18:20,1,-73.980072,40.743137,-73.982712,40.735336
2D4B95E2FA7B2E851...,CD2F522EEE1FF5F5A...,CMT,1,N,07-01-13 15:33,07-01-13 15:49,2,-73.977936,40.786983,-73.952919,40.80637


In [5]:
# Inspect the inferred schema. pickup_datetime and dropoff_datetime come back as
# strings: inferSchema did not recognise values such as "01-01-13 15:11" as
# timestamps, so they need explicit parsing before any time-based analysis.
df.printSchema()

root
 |-- medallion: string (nullable = true)
 |-- hack_license: string (nullable = true)
 |-- vendor_id: string (nullable = true)
 |-- rate_code: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- pickup_datetime: string (nullable = true)
 |-- dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- pickup_longitude: double (nullable = true)
 |-- pickup_latitude: double (nullable = true)
 |-- dropoff_longitude: double (nullable = true)
 |-- dropoff_latitude: double (nullable = true)



In [6]:
import pyspark.sql.functions as F

# Keep only the columns needed downstream: the hack_license, the passenger count,
# and the pickup/dropoff timestamps and coordinates.
selected_columns = [
    "hack_license",
    "passenger_count",
    "pickup_datetime",
    "pickup_longitude",
    "pickup_latitude",
    "dropoff_datetime",
    "dropoff_longitude",
    "dropoff_latitude",
]
taxi_df = df.select(*[F.col(name) for name in selected_columns])

In [7]:
# Preview the trimmed-down frame.
taxi_df

hack_license,passenger_count,pickup_datetime,pickup_longitude,pickup_latitude,dropoff_datetime,dropoff_longitude,dropoff_latitude
BA96DE419E711691B...,4,01-01-13 15:11,-73.978165,40.757977,01-01-13 15:18,-73.989838,40.751171
9FD8F69F0804BDB55...,1,06-01-13 00:18,-74.006683,40.731781,06-01-13 00:22,-73.994499,40.75066
9FD8F69F0804BDB55...,1,05-01-13 18:49,-74.004707,40.73777,05-01-13 18:54,-74.009834,40.726002
51EE87E3205C985EF...,2,07-01-13 23:54,-73.974602,40.759945,07-01-13 23:58,-73.984734,40.759388
51EE87E3205C985EF...,1,07-01-13 23:25,-73.97625,40.748528,07-01-13 23:34,-74.002586,40.747868
598CCE5B9C1918568...,1,07-01-13 15:27,-73.966743,40.764252,07-01-13 15:38,-73.983322,40.743763
513189AD756FF14FE...,1,08-01-13 11:01,-73.995804,40.743977,08-01-13 11:08,-74.007416,40.744343
CCD4367B417ED6634...,3,07-01-13 12:39,-73.989937,40.756775,07-01-13 13:10,-73.86525,40.77063
1DA2F6543A62B8ED9...,1,07-01-13 18:15,-73.980072,40.743137,07-01-13 18:20,-73.982712,40.735336
CD2F522EEE1FF5F5A...,2,07-01-13 15:33,-73.977936,40.786983,07-01-13 15:49,-73.952919,40.80637


In [8]:
import json
from shapely.geometry import shape, Point
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf, col

# Load the borough boundaries. GeoJSON is UTF-8 by specification (RFC 7946),
# so decode explicitly rather than relying on the platform's default encoding.
geojson_path = "input/nyc-boroughs.geojson"
with open(geojson_path, "r", encoding="utf-8") as f:
    geojson_data = json.load(f)

# Fail fast on a file with no features: otherwise every borough lookup
# would silently come back as "Unknown".
assert geojson_data.get("features"), f"No GeoJSON features found in {geojson_path}"

In [9]:
# Collect one (borough_code, polygon, borough_name) tuple per GeoJSON feature.
# No ordering is applied here; entries stay in file order.
boroughs = [
    (
        feature["properties"]["boroughCode"],
        shape(feature["geometry"]),
        feature["properties"]["borough"],
    )
    for feature in geojson_data["features"]
]

In [10]:
# Order entries by borough code ascending (Manhattan=1 ... Staten Island=5), then by
# polygon area descending, so that within a borough the larger polygons are tried
# first by the first-match point-in-polygon scan.
def _borough_sort_key(entry):
    code, polygon, _name = entry
    return code, -polygon.area

boroughs = sorted(boroughs, key=_borough_sort_key)

In [11]:
# Function to find the borough for a given lat/lon
def find_borough(lat, lon):
    """Return the name of the borough whose polygon contains the point (lat, lon).

    Scans the module-level ``boroughs`` list in order and returns the name of the
    first polygon that contains the point. Shapely's ``contains`` excludes the
    boundary, so a point lying exactly on a polygon edge is not matched.

    Args:
        lat: Latitude value, or None when the source column is null.
        lon: Longitude value, or None when the source column is null.

    Returns:
        The borough name, or "Unknown" when either coordinate is missing or
        no polygon contains the point.
    """
    # Spark passes SQL NULLs into Python UDFs as None; building a Point from None
    # raises, which would fail the whole Spark task, so treat it as unmatched.
    if lat is None or lon is None:
        return "Unknown"
    # Shapely points are (x, y), i.e. (longitude, latitude).
    point = Point(lon, lat)
    for _code, polygon, borough in boroughs:
        if polygon.contains(point):
            return borough
    return "Unknown"

In [12]:
# Wrap find_borough as a Spark UDF returning strings, for use in DataFrame
# expressions (this does not register it by name for Spark SQL).
find_borough_udf = udf(find_borough, returnType=StringType())

In [13]:
# Enrich taxi data with the borough of each trip endpoint.
# find_borough_udf takes its arguments in (latitude, longitude) order.
pickup_borough_col = find_borough_udf(col("pickup_latitude"), col("pickup_longitude"))
dropoff_borough_col = find_borough_udf(col("dropoff_latitude"), col("dropoff_longitude"))
taxi_df = (
    taxi_df
    .withColumn("pickup_borough", pickup_borough_col)
    .withColumn("dropoff_borough", dropoff_borough_col)
)

In [14]:
# Preview the enriched frame; displaying it evaluates the borough UDF on the shown rows.
taxi_df

hack_license,passenger_count,pickup_datetime,pickup_longitude,pickup_latitude,dropoff_datetime,dropoff_longitude,dropoff_latitude,pickup_borough,dropoff_borough
BA96DE419E711691B...,4,01-01-13 15:11,-73.978165,40.757977,01-01-13 15:18,-73.989838,40.751171,Manhattan,Manhattan
9FD8F69F0804BDB55...,1,06-01-13 00:18,-74.006683,40.731781,06-01-13 00:22,-73.994499,40.75066,Manhattan,Manhattan
9FD8F69F0804BDB55...,1,05-01-13 18:49,-74.004707,40.73777,05-01-13 18:54,-74.009834,40.726002,Manhattan,Manhattan
51EE87E3205C985EF...,2,07-01-13 23:54,-73.974602,40.759945,07-01-13 23:58,-73.984734,40.759388,Manhattan,Manhattan
51EE87E3205C985EF...,1,07-01-13 23:25,-73.97625,40.748528,07-01-13 23:34,-74.002586,40.747868,Manhattan,Manhattan
598CCE5B9C1918568...,1,07-01-13 15:27,-73.966743,40.764252,07-01-13 15:38,-73.983322,40.743763,Manhattan,Manhattan
513189AD756FF14FE...,1,08-01-13 11:01,-73.995804,40.743977,08-01-13 11:08,-74.007416,40.744343,Manhattan,Manhattan
CCD4367B417ED6634...,3,07-01-13 12:39,-73.989937,40.756775,07-01-13 13:10,-73.86525,40.77063,Manhattan,Queens
1DA2F6543A62B8ED9...,1,07-01-13 18:15,-73.980072,40.743137,07-01-13 18:20,-73.982712,40.735336,Manhattan,Manhattan
CD2F522EEE1FF5F5A...,2,07-01-13 15:33,-73.977936,40.786983,07-01-13 15:49,-73.952919,40.80637,Manhattan,Manhattan
