# Preprocess the TLC Data

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pandas as pd

In [2]:
# Session options, applied one by one to the builder before the session is created
spark_options = {
    "spark.sql.repl.eagerEval.enabled": True,
    "spark.sql.parquet.cacheMetadata": "true",
    "spark.sql.session.timeZone": "Etc/UTC",
    "spark.driver.memory": "2g",
    "spark.executor.memory": "4g",
}

session_builder = SparkSession.builder.appName("preprocess")
for option_name, option_value in spark_options.items():
    session_builder = session_builder.config(option_name, option_value)

# Reuses an already running session if there is one, otherwise starts a new one
spark = session_builder.getOrCreate()

In [3]:
# Read every parquet file found under the raw 2019 yellow-taxi directory
raw_2019_path = '../data/raw/yellow_taxi_data_2019/'
sdf_2019 = spark.read.parquet(raw_2019_path)

# Preliminary data analysis

In [4]:
# Count the raw 2019 records (this triggers a full scan of the data)
row_count_2019 = sdf_2019.count()
print(f'total of {row_count_2019:,} rows')

total of 84,598,444 rows


In [5]:
# Show the column names, data types and nullability of the raw 2019 data
sdf_2019.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: integer (nullable = true)



In [6]:
# Summary statistics (count, mean, stddev, min, max) of the numeric attributes
# that are later screened for outliers
numeric_columns = ['tip_amount', 'trip_distance', 'fare_amount', 'tolls_amount']
sdf_2019.describe(numeric_columns).show()

+-------+------------------+------------------+------------------+------------------+
|summary|        tip_amount|     trip_distance|       fare_amount|      tolls_amount|
+-------+------------------+------------------+------------------+------------------+
|  count|          84598444|          84598444|          84598444|          84598444|
|   mean| 2.190078737505638|3.0183506184817515|13.412639732835764|0.3868694478611098|
| stddev|15.638996154306168| 8.093902044464816|174.17668755385404|1.8233435182238924|
|    min|            -221.0|         -37264.53|           -1856.0|             -70.0|
|    max|         141492.02|          45977.22|          943274.8|            3288.0|
+-------+------------------+------------------+------------------+------------------+



In [7]:
# Calculate the bounds for outlier detection: mean +/- 3 standard deviations
columns = ['fare_amount', 'tip_amount', 'tolls_amount', 'trip_distance']

# Every aggregate gets an explicit alias and is read back BY NAME.
# Passing a {column: 'mean'} dict to agg() and indexing the result by position
# is unsafe: the column order of the result is not guaranteed to follow the
# order of `columns`, so a column could silently get another column's bounds.
# Computing both statistics in one agg() also scans the data once instead of twice.
aggregates = (
    [F.mean(column).alias(f'mean_{column}') for column in columns]
    + [F.stddev(column).alias(f'std_{column}') for column in columns]
)
stats = sdf_2019.agg(*aggregates).first()

# Per-column statistics, keyed by column name
mean = {column: stats[f'mean_{column}'] for column in columns}
std = {column: stats[f'std_{column}'] for column in columns}

# bounds[column] == [lower, upper]
bounds = {
    column: [mean[column] - 3*std[column], mean[column] + 3*std[column]]
    for column in columns
}

In [8]:
# Check null values.
# All columns are counted in a single pass over the data: when() yields 1 for a
# null cell and null otherwise, and count() only counts the non-null results.
# (Running one filter + count per column would scan the whole dataset once per column.)
null_counts = sdf_2019.select(
    [F.count(F.when(F.col(column).isNull(), 1)).alias(column)
     for column in sdf_2019.columns]
).first()

for column in sdf_2019.columns:
    print(f'{column} has {null_counts[column]} null value(s)')

VendorID has 0 null value(s)
tpep_pickup_datetime has 0 null value(s)
tpep_dropoff_datetime has 0 null value(s)
passenger_count has 444383 null value(s)
trip_distance has 0 null value(s)
RatecodeID has 444383 null value(s)
store_and_fwd_flag has 444383 null value(s)
PULocationID has 0 null value(s)
DOLocationID has 0 null value(s)
payment_type has 0 null value(s)
fare_amount has 0 null value(s)
extra has 0 null value(s)
mta_tax has 0 null value(s)
tip_amount has 0 null value(s)
tolls_amount has 0 null value(s)
improvement_surcharge has 0 null value(s)
total_amount has 0 null value(s)
congestion_surcharge has 5300601 null value(s)
airport_fee has 84598444 null value(s)


In [9]:
def transform(sdf, year, outlier_bounds=None):
    """
    Return a cleaned copy of a raw yellow-taxi trip DataFrame.

    Parameters
    ----------
    sdf : pyspark.sql.DataFrame
        Raw trip records with the columns shown by the schema of the raw data.
    year : int
        Only trips whose pick-up AND drop-off both fall in this year are kept.
    outlier_bounds : dict, optional
        Maps 'tip_amount', 'trip_distance', 'fare_amount' and 'tolls_amount'
        to a ``[lower, upper]`` pair. Defaults to the notebook-level ``bounds``
        (computed from the 2019 data), so existing calls behave as before.

    Returns
    -------
    pyspark.sql.DataFrame
        The filtered data with two extra columns, ``PUMonth`` and ``DOMonth``.
    """
    if outlier_bounds is None:
        # Fall back to the notebook-level bounds computed in an earlier cell
        outlier_bounds = bounds

    def within_bounds(column):
        """Condition keeping values strictly between the bounds of `column`."""
        lower, upper = outlier_bounds[column]
        return (F.col(column) > lower) & (F.col(column) < upper)

    # Zone 264, 265 are unknown zones
    unknown_zones = [264, 265]

    sdf_mdf = (
        sdf
        # Drop the attributes that we are not interested in
        .drop('VendorID',
              'store_and_fwd_flag',
              'extra',
              'mta_tax',
              'improvement_surcharge',
              'total_amount',
              'congestion_surcharge',
              'airport_fee')
        # Drop the rows with a null in any remaining column
        # (in the 2019 data these are the rows missing passenger_count / RatecodeID)
        .dropna(how='any')
        # Remove data not in the requested year
        .filter((F.year('tpep_pickup_datetime') == year) &
                (F.year('tpep_dropoff_datetime') == year))
        # Remove outliers that are 3 standard deviations away from the mean.
        # NOTE(review): mean - 3*std is negative for these columns in the 2019 data,
        # so negative fares / distances are NOT removed here -- confirm whether an
        # explicit non-negative filter is wanted.
        .filter(within_bounds('tip_amount'))
        .filter(within_bounds('trip_distance'))
        .filter(within_bounds('fare_amount'))
        .filter(within_bounds('tolls_amount'))
        # RatecodeID should be one of the integers in the range 1-6
        .filter(F.col('RatecodeID').isin(list(range(1, 7))))
        # Only trips with payment_type 1 or 2 (credit card or cash) are included
        .filter(F.col('payment_type').isin(1, 2))
        # Passenger count must not be zero
        .filter(F.col('passenger_count') != 0)
        # Neither end of the trip may be in an unknown zone
        .filter(~F.col('PULocationID').isin(unknown_zones) &
                ~F.col('DOLocationID').isin(unknown_zones))
        # Extract the pick-up month and the drop-off month
        .withColumn('PUMonth', F.month(F.col('tpep_pickup_datetime')))
        .withColumn('DOMonth', F.month(F.col('tpep_dropoff_datetime')))
    )

    return sdf_mdf


# Clean the 2019 records, keeping only trips that start and end in 2019
new_sdf_2019 = transform(sdf=sdf_2019, year=2019)

In [10]:
# Number of trips per pick-up zone and pick-up month in 2019
pu_2019_group_keys = ["PULocationID", "PUMonth"]
pu_2019_trip_count = F.count('PULocationID').alias("trip_count")

pu_2019_aggregated_result = (
    new_sdf_2019
    .groupBy(*pu_2019_group_keys)
    .agg(pu_2019_trip_count)
)

pu_2019_aggregated_result.show()

+------------+-------+----------+
|PULocationID|PUMonth|trip_count|
+------------+-------+----------+
|         173|      3|       341|
|         113|      5|    108312|
|          28|      3|       874|
|         158|      2|     67429|
|          94|      3|       174|
|         201|      3|        26|
|         193|      3|      5104|
|           7|      2|      9858|
|         200|      3|       229|
|         123|      3|       349|
|         109|      3|         1|
|         211|      2|     54139|
|         129|      3|      3841|
|         137|      2|     86046|
|          87|      4|     48752|
|          75|      2|     51472|
|         247|      3|      1279|
|         181|      3|      6930|
|         180|      3|       161|
|         162|      5|    252678|
+------------+-------+----------+
only showing top 20 rows



In [11]:
# Number of trips per drop-off zone and drop-off month in 2019
do_2019_group_keys = ["DOLocationID", "DOMonth"]
do_2019_trip_count = F.count('DOLocationID').alias("trip_count")

do_2019_aggregated_result = (
    new_sdf_2019
    .groupBy(*do_2019_group_keys)
    .agg(do_2019_trip_count)
)

do_2019_aggregated_result.show()

+------------+-------+----------+
|DOLocationID|DOMonth|trip_count|
+------------+-------+----------+
|          66|      4|     11802|
|         136|      4|       897|
|          89|      4|      4992|
|         173|      3|      2190|
|          28|      3|      2399|
|          94|      3|       718|
|         201|      3|       323|
|          71|      4|      1250|
|         123|      3|      1409|
|         193|      3|      6545|
|         200|      3|      2361|
|           7|      2|     25954|
|         109|      3|         2|
|         149|      4|       771|
|         216|      4|      4686|
|         129|      3|     13727|
|         211|      2|     55637|
|          87|      4|     51886|
|         135|      4|      1643|
|         247|      3|      3104|
+------------+-------+----------+
only showing top 20 rows



In [12]:
# Persist both 2019 aggregates as parquet, replacing any previous output
curated_outputs_2019 = (
    (pu_2019_aggregated_result, '../data/curated/pu_aggregated_result_2019'),
    (do_2019_aggregated_result, '../data/curated/do_aggregated_result_2019'),
)
for aggregated_result, destination in curated_outputs_2019:
    aggregated_result.write.mode('overwrite').parquet(destination)

# Subsampling

In [13]:
# The 2020 data goes through the same cleaning procedure.
# Note: transform() screens outliers with the notebook-level `bounds`,
# which were computed from the 2019 data.
raw_2020_path = '../data/raw/yellow_taxi_data_2020/'
sdf_2020 = spark.read.parquet(raw_2020_path)
new_sdf_2020 = transform(sdf=sdf_2020, year=2020)

In [14]:
# Due to the lockdown of NYC in March 2020, the number of records declined substantially.
# Only data from January and February is used to validate model predictions.
validation_months = [1, 2]
reduced_sdf = new_sdf_2020.filter(F.col("PUMonth").isin(validation_months))

# Number of trips per pick-up zone and pick-up month
pu_2020_aggregated_result = (
    reduced_sdf
    .groupBy("PULocationID", "PUMonth")
    .agg(F.count('PULocationID').alias("trip_count"))
)

# Number of trips per drop-off zone and drop-off month.
# `reduced_sdf` is filtered on the PICK-UP month only, so a trip picked up at the
# end of February and dropped off in March still has DOMonth == 3. Without the
# extra filter below, the result would contain a tiny, misleading "March" group
# made up only of those spill-over trips.
do_2020_aggregated_result = (
    reduced_sdf
    .filter(F.col("DOMonth").isin(validation_months))
    .groupBy("DOLocationID", "DOMonth")
    .agg(F.count('DOLocationID').alias("trip_count"))
)

# Persist both 2020 aggregates as parquet, replacing any previous output
pu_2020_aggregated_result.write.mode('overwrite').parquet("../data/curated/pu_aggregated_result_2020")
do_2020_aggregated_result.write.mode('overwrite').parquet("../data/curated/do_aggregated_result_2020")