# Pyspark setup

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [2]:
spark = (
    SparkSession.builder.appName("preprocessing of taxi data")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.driver.memory", "15g")
    .getOrCreate()
)

22/08/25 19:23:49 WARN Utils: Your hostname, kams-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.13.117.182 instead (on interface en0)
22/08/25 19:23:49 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/08/25 19:23:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


# Data preparation and gain basic insight of data

In [3]:
# check the data type for each column and size of dataset
sdf_all = spark.read.parquet('../data/raw/')
sdf_all.printSchema()

                                                                                

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: integer (nullable = true)
 |-- airport_fee: integer (nullable = true)



In [4]:
sdf_all.show()

                                                                                

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2016-01-01 00:12:22|  2016-01-01 00:29:14|              1|          3.2|         1|                 N|          48|         262|           1|       14.0|  0.5|    0.5|      3.0

In [5]:
raw_data_count = sdf_all.count()
raw_data_count

                                                                                

180457103

# Filtering for credit card payments
As tip amount contribute greatly to a drivers income, and only tip amounts of credit card payments are recorded.

In [6]:
# Filtered for only payment type = 1

valid_payment_type_sdf = sdf_all.filter(F.col("payment_type") == 1)
valid_payment_type_count = valid_payment_type_sdf.count()
msg = (
        f"Out of {raw_data_count} raw data points, {raw_data_count - valid_payment_type_count} "
        f"payment type that is not 1 has been cleaned."
)
print(msg)
print("Remaining data count:",valid_payment_type_count)



Out of 180457103 raw data points, 61037121 payment type that is not 1 has been cleaned.
Remaining data count: 119419982


                                                                                

# Filtered for fare amount and total amount are >= $2.5
This is the initial charge for every taxi trip as specified by TLC.

In [7]:
valid_amount_payed_sdf = valid_payment_type_sdf.filter((F.col("fare_amount") >= 2.5) & 
                                                        (F.col("total_amount") >= 2.5))
valid_amount_payed_count = valid_amount_payed_sdf.count()
msg = (
        f"Out of {valid_payment_type_count} data points that is cleaned, "
        f"{valid_payment_type_count - valid_amount_payed_count} invalid charges has been removed"
)
print(msg)
print("Remaining data count:",valid_payment_type_count)




Out of 119419982 data points that is cleaned, 10083 invalid charges has been removed
Remaining data count: 119419982


                                                                                

# Filter VendorID that is not 1 or 2
According to TLC there are only two vendor IDs that provide the trip entries.

In [8]:
valid_vendor_id_sdf = valid_amount_payed_sdf.filter((F.col("VendorID") == 1) | 
                                                    (F.col("VendorID") == 2))
valid_vendor_id_count = valid_vendor_id_sdf.count()
msg = (
        f"Out of {valid_amount_payed_count} data points that is cleaned, "
        f"{valid_amount_payed_count - valid_vendor_id_count} invalid VendorID has been removed"
)
print(msg)
print("Remaining data count:",valid_vendor_id_count)



Out of 119409899 data points that is cleaned, 5807 invalid VendorID has been removed
Remaining data count: 119404092


                                                                                

# Filter negative amounts from extra, mta_tax, tip_amount, tolls_amount, improvement_surcharge

In [9]:
valid_charges_sdf = valid_vendor_id_sdf.filter((F.col("extra") >= 0) & 
                                               (F.col("mta_tax") >= 0) & 
                                               (F.col("tip_amount") >= 0) & 
                                               (F.col("tolls_amount") >= 0) & 
                                               (F.col("improvement_surcharge") >= 0))
valid_charges_count = valid_charges_sdf.count()
msg = (
        f"Out of {valid_vendor_id_count} data points that is cleaned, "
        f"{valid_vendor_id_count - valid_charges_count} negative charges has been removed"
)
print(msg)
print("Remaining data count:",valid_charges_count)



Out of 119404092 data points that is cleaned, 49 negative charges has been removed
Remaining data count: 119404043


                                                                                

# Filter negative distance
By common sense, taxi cannot trave a negative distance.

In [10]:
valid_distance_sdf = valid_charges_sdf.filter((F.col("trip_distance") >= 0))
valid_distance_count = valid_distance_sdf.count()
msg = (
        f"Out of {valid_charges_count} data points that is cleaned, "
        f"{valid_charges_count - valid_distance_count} negative distances has been removed"
)
print(msg)
print("Remaining data count:",valid_distance_count) 



Out of 119404043 data points that is cleaned, 0 negative distances has been removed
Remaining data count: 119404043


                                                                                

# Filter locations out of NYC

In [11]:
valid_location_sdf = valid_distance_sdf.filter((F.col('PULocationID').between(1, 263)) & 
                                            (F.col('DOLocationID').between(1, 263)))
valid_loc_count = valid_location_sdf.count()
msg = (
        f"Out of {valid_distance_count} raw data points, {valid_distance_count - valid_loc_count} "
        f"are out of location ID range [1, 263] and has been cleaned"
)
print(msg)
print("Remaining data count:",valid_loc_count)



Out of 119404043 raw data points, 2084487 are out of location ID range [1, 263] and has been cleaned
Remaining data count: 117319556


                                                                                

# Generate trip duration in seconds
This columun is generated for removing invalid entries, and for future feature engineering to create a trip value column.

In [12]:
trip_duration_add_sdf = valid_location_sdf.withColumn('duration',
                                            F.col("tpep_dropoff_datetime").cast("long")\
                                            - F.col('tpep_pickup_datetime').cast("long"))
trip_duration_add_sdf.show()


+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|duration|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+--------+
|       1| 2016-01-01 00:12:22|  2016-01-01 00:29:14|              1|          3.2|         1|                 N|          48|         262|           1|       1

In [13]:
# convert trip duration in seconds to minutes
trip_duration_add_sdf = trip_duration_add_sdf.withColumn("duration", F.round(F.col("duration") / 60, 3))
trip_duration_add_sdf = trip_duration_add_sdf.withColumnRenamed("duration", "duration (minutes)")

# Filter negative trip duration and trip duration greather than 10 hours
According to TLC, the legal maximum consecutive trip hours allowed is 10 hour.

In [14]:
valid_duration_sdf = trip_duration_add_sdf.filter(F.col("duration").between(1, 10*60))
valid_duration_count = valid_duration_sdf.count()

msg = (
        f"Out of {valid_loc_count} valid location data points, {valid_loc_count - valid_duration_count} "
        f"are out of duration range [1 minute, 10 hours] and has been cleaned"
)
print(msg)
print("Remaining data count:", valid_duration_count)



Out of 117319556 valid location data points, 507084 are out of duration range [1 minute, 10 hours] and has been cleaned
Remaining data count: 116812472


                                                                                

In [15]:
# quick check on the removal of invalid trip durations
valid_duration_sdf.describe("duration (minutes)")

                                                                                

summary,duration (minutes)
count,116812472.0
mean,14.746870260309173
stddev,11.755260472955396
min,1.0
max,599.917


# Filter records with datatime out of defined data range

In [17]:
valid_datetime_sdf = valid_duration_sdf.filter((F.col("tpep_pickup_datetime") >= "2016-01-01") & \
                                                (F.col("tpep_pickup_datetime") <= "2017-05-31"))

valid_datetime_count = valid_datetime_sdf.count()

msg = (
        f"Out of {valid_duration_count} valid location data points, {valid_duration_count - valid_datetime_count} "
        f"are out of the range 2016-01-01 to 2017-05-31 and has been cleaned"
)
print(msg)
print("Remaining data count:", valid_datetime_count)



Out of 116812472 valid location data points, 220947 are out of the range 2016-01-01 to 2017-05-31 and has been cleaned
Remaining data count: 116591525


                                                                                

# Retain columns of interest for feature engineerring.

In [18]:
cleaned_trip_sdf = valid_datetime_sdf.select(["tpep_pickup_datetime", 
                                              "PULocationID", 
                                              "fare_amount", 
                                              "extra", 
                                              "tip_amount", 
                                              "duration (minutes)"])

In [19]:
cleaned_trip_sdf.printSchema()

root
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- duration (minutes): double (nullable = true)



# Check for any null values to decided whether imputation should be performed.

In [20]:
Dict_Null = {col:cleaned_trip_sdf.filter(cleaned_trip_sdf[col].isNull()).count() for col in cleaned_trip_sdf.columns}
Dict_Null

                                                                                

{'tpep_pickup_datetime': 0,
 'PULocationID': 0,
 'fare_amount': 0,
 'extra': 0,
 'tip_amount': 0,
 'duration (minutes)': 0}

In [21]:
cleaned_trip_sdf.write.mode('overwrite').parquet('../data/curated/preprocess_taxi_result1')


                                                                                

In [22]:
cleaned_trip_sdf.show()
print(cleaned_trip_sdf.count())

+--------------------+------------+-----------+-----+----------+------------------+
|tpep_pickup_datetime|PULocationID|fare_amount|extra|tip_amount|duration (minutes)|
+--------------------+------------+-----------+-----+----------+------------------+
| 2016-01-01 00:12:22|          48|       14.0|  0.5|      3.06|            16.867|
| 2016-01-01 00:49:47|         141|       11.0|  0.5|      2.45|             14.95|
| 2016-01-01 00:44:33|          68|        7.5|  0.5|       2.2|             9.067|
| 2016-01-01 00:56:47|         234|        9.0|  0.5|       0.0|            11.017|
| 2016-01-01 00:27:11|         239|        9.5|  0.5|      1.75|             9.817|
| 2016-01-01 00:19:09|         239|        6.5|  0.5|      1.95|             6.183|
| 2016-01-01 00:26:54|         143|       28.5|  0.5|       7.0|             28.55|
| 2016-01-01 00:19:36|         164|        6.5|  0.5|       1.0|             6.667|
| 2016-01-01 00:48:09|         246|       12.0|  0.5|      3.95|            



116591525


                                                                                