# Preprocess Yellow Taxis

In [21]:
from pyspark.sql import SparkSession

# Attach to the active Spark session if there is one, otherwise start a new one
session_builder = SparkSession.builder
spark = session_builder.getOrCreate()

In [22]:
# Raw TLC parquet files are read from `data/raw/`
tlc_output_dir = '../data/raw'

# Calendar years covered by the study window
YEAR_2023 = '2023'
YEAR_2024 = '2024'

# Study window: November 2023 through May 2024 (range upper bounds are exclusive)
MONTHS_2023 = range(11, 13)
MONTHS_2024 = range(1, 6)

# One parquet path per (year, month), in chronological order; months are zero-padded
yellow_taxi_files = [
    f'{tlc_output_dir}/yellow_taxi_{year}-{month:02d}.parquet'
    for year, months in ((YEAR_2023, MONTHS_2023), (YEAR_2024, MONTHS_2024))
    for month in months
]



In [23]:
# Read every monthly parquet file in one call; the result is a single DataFrame
# covering all of the listed paths
parquet_reader = spark.read
yellow_taxi_sdf = parquet_reader.parquet(*yellow_taxi_files)

# Print the schema as a sanity check on column names and types
yellow_taxi_sdf.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)



                                                                                

In [24]:
# Summary statistics of the raw data, printed vertically (one record per statistic)
yellow_taxi_raw_summary = yellow_taxi_sdf.describe()
yellow_taxi_raw_summary.show(vertical=True)



-RECORD 0------------------------------------
 summary               | count               
 VendorID              | 23509182            
 passenger_count       | 21631300            
 trip_distance         | 23509182            
 RatecodeID            | 21631300            
 store_and_fwd_flag    | 21631300            
 PULocationID          | 23509182            
 DOLocationID          | 23509182            
 payment_type          | 23509182            
 fare_amount           | 23509182            
 extra                 | 23509182            
 mta_tax               | 23509182            
 tip_amount            | 23509182            
 tolls_amount          | 23509182            
 improvement_surcharge | 23509182            
 total_amount          | 23509182            
 congestion_surcharge  | 21631300            
 Airport_fee           | 21631300            
-RECORD 1------------------------------------
 summary               | mean                
 VendorID              | 1.7544913

                                                                                

In [25]:
import pyspark.sql.functions as F

# Count NULLs per column: isNull() yields a boolean, which is cast to 0/1 and summed
null_count_exprs = []
for column_name in yellow_taxi_sdf.columns:
    is_null_flag = F.col(column_name).isNull().cast("int")
    null_count_exprs.append(F.sum(is_null_flag).alias(column_name))

missing_values = yellow_taxi_sdf.select(null_count_exprs)
missing_values.show(vertical=True)



-RECORD 0------------------------
 VendorID              | 0       
 tpep_pickup_datetime  | 0       
 tpep_dropoff_datetime | 0       
 passenger_count       | 1877882 
 trip_distance         | 0       
 RatecodeID            | 1877882 
 store_and_fwd_flag    | 1877882 
 PULocationID          | 0       
 DOLocationID          | 0       
 payment_type          | 0       
 fare_amount           | 0       
 extra                 | 0       
 mta_tax               | 0       
 tip_amount            | 0       
 tolls_amount          | 0       
 improvement_surcharge | 0       
 total_amount          | 0       
 congestion_surcharge  | 1877882 
 Airport_fee           | 1877882 



                                                                                

In [26]:
from pyspark.sql import functions as F

# Study window for pickups; both bounds are inclusive
date1 = "2023-11-01 00:00:00"
date2 = "2024-05-31 23:59:59"

# Timestamp literals built from the strings above (reused by later cells)
date1_ts = F.to_timestamp(F.lit(date1))
date2_ts = F.to_timestamp(F.lit(date2))

# Keep trips picked up inside the window; `between` includes both bounds
picked_up_in_window = F.col("tpep_pickup_datetime").between(date1_ts, date2_ts)
yellow_taxi_sdf = yellow_taxi_sdf.filter(picked_up_in_window)

# Keep trips with a distance of at least 0.2, which also removes zero and
# negative distances
covers_min_distance = F.col("trip_distance") >= 0.2
yellow_taxi_sdf = yellow_taxi_sdf.filter(covers_min_distance)

In [27]:
# Monetary columns that must be non-negative for a trip to be kept
columns_to_check = [
    'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount',
    'improvement_surcharge', 'total_amount', 'congestion_surcharge', 'Airport_fee'
]

# AND together one `>= 0` check per column. A NULL amount does not satisfy the
# comparison, so rows with a NULL in any of these columns are dropped as well.
all_amounts_non_negative = F.lit(True)
for amount_column in columns_to_check:
    all_amounts_non_negative = all_amounts_non_negative & (F.col(amount_column) >= 0)

# On top of that, require a fare of at least $3
meets_minimum_fare = F.col('fare_amount') >= 3

yellow_taxi_sdf = yellow_taxi_sdf.where(all_amounts_non_negative).where(meets_minimum_fare)

In [28]:
# Keep only trips that satisfy all of the following:
#   - trip distance of at least 0.2
#   - drop-off strictly after pickup
#   - at least one passenger
#   - pickup location ID within 1..263 (IDs outside this range are treated as
#     outside the city)
trip_is_valid = (
    (F.col('trip_distance') >= 0.2)
    & (F.col('tpep_dropoff_datetime') > F.col('tpep_pickup_datetime'))
    & (F.col('passenger_count') > 0)
    & F.col('PULocationID').between(1, 263)
)
yellow_taxi_sdf = yellow_taxi_sdf.where(trip_is_valid)

In [29]:
# Rate code IDs 1 through 6 are the ones kept as valid (per the TLC data dictionary)
valid_ratecode_ids = list(range(1, 7))

has_valid_ratecode = F.col('RatecodeID').isin(valid_ratecode_ids)
yellow_taxi_sdf = yellow_taxi_sdf.where(has_valid_ratecode)


In [30]:
from pyspark.sql.functions import col, unix_timestamp

# Calculate trip duration in minutes.
# Spark's rounding function is called as `F.round` instead of being imported by name:
# importing it as `round` would shadow Python's builtin `round` in every later cell.
pickup_epoch_seconds = unix_timestamp(col("tpep_pickup_datetime"))
dropoff_epoch_seconds = unix_timestamp(col("tpep_dropoff_datetime"))

yellow_taxi_sdf = yellow_taxi_sdf.withColumn(
    "trip_duration",
    # seconds -> minutes, rounded to 2 decimal places
    F.round((dropoff_epoch_seconds - pickup_epoch_seconds) / 60, 2)
)

# Show a few records to verify the new column
yellow_taxi_sdf.select("tpep_pickup_datetime", "tpep_dropoff_datetime", "trip_duration").show(5, truncate=False)

+--------------------+---------------------+-------------+
|tpep_pickup_datetime|tpep_dropoff_datetime|trip_duration|
+--------------------+---------------------+-------------+
|2023-11-01 00:03:03 |2023-11-01 01:04:08  |61.08        |
|2023-11-01 00:03:50 |2023-11-01 00:04:59  |1.15         |
|2023-11-01 00:06:30 |2023-11-01 00:14:25  |7.92         |
|2023-11-01 00:17:18 |2023-11-01 00:23:39  |6.35         |
|2023-11-01 00:14:49 |2023-11-01 00:39:44  |24.92        |
+--------------------+---------------------+-------------+
only showing top 5 rows



In [31]:
# Columns whose upper tail is inspected for outliers
columns_of_interest = ["trip_distance", "trip_duration", "tolls_amount", "fare_amount", "total_amount"]

# One approximate 99.98th-percentile aggregate per column of interest
percentile_expr = [
    F.expr(f'percentile_approx({column_name}, 0.9998)').alias(column_name)
    for column_name in columns_of_interest
]

# Run the aggregation and pull the single result row back as {column name: percentile}
percentiles = yellow_taxi_sdf.agg(*percentile_expr).collect()[0].asDict()

# Display the 99.98th percentile values.
# The loop variable is deliberately NOT called `col`: a module-level loop variable with
# that name would overwrite the `col` function imported from pyspark.sql.functions.
print("99.98th Percentile Values:")
for column_name, value in percentiles.items():
    print(f"{column_name}: {value}")



99.98th Percentile Values:
trip_distance: 44.16
trip_duration: 1429.63
tolls_amount: 27.69
fare_amount: 220.0
total_amount: 261.33


                                                                                

In [32]:
# Drop outliers: keep rows at or below the 99.98th percentile of each of these
# monetary columns (trip_distance and trip_duration are not capped by percentile here)
for capped_column in ("tolls_amount", "fare_amount", "total_amount"):
    upper_bound = percentiles[capped_column]
    yellow_taxi_sdf = yellow_taxi_sdf.filter(F.col(capped_column) <= upper_bound)

In [33]:
# Keep trips of at most 1000 miles whose duration is between 1 minute and
# 16 hours (960 minutes), both bounds inclusive
has_plausible_distance = F.col("trip_distance") <= 1000
has_plausible_duration = F.col("trip_duration").between(1, 960)
yellow_taxi_sdf = yellow_taxi_sdf.filter(has_plausible_distance & has_plausible_duration)

# Summary statistics after all filters have been applied
yellow_taxi_clean_summary = yellow_taxi_sdf.describe()
yellow_taxi_clean_summary.show(vertical=True)



-RECORD 0------------------------------------
 summary               | count               
 VendorID              | 20397304            
 passenger_count       | 20397304            
 trip_distance         | 20397304            
 RatecodeID            | 20397304            
 store_and_fwd_flag    | 20397304            
 PULocationID          | 20397304            
 DOLocationID          | 20397304            
 payment_type          | 20397304            
 fare_amount           | 20397304            
 extra                 | 20397304            
 mta_tax               | 20397304            
 tip_amount            | 20397304            
 tolls_amount          | 20397304            
 improvement_surcharge | 20397304            
 total_amount          | 20397304            
 congestion_surcharge  | 20397304            
 Airport_fee           | 20397304            
 trip_duration         | 20397304            
-RECORD 1------------------------------------
 summary               | mean     

                                                                                

## Count yellow taxi pickups per location and 3-hour time bucket

In [34]:
# Reduce each trip to the three grouping keys in a single projection:
#   date        - calendar date of the pickup
#   hour_bucket - start hour of the 3-hour bucket containing the pickup (0, 3, ..., 21)
#   PULocationID - pickup location, unchanged
pickup_ts = F.col("tpep_pickup_datetime")
pickup_hour = F.hour(pickup_ts)

yellow_taxi_sdf = yellow_taxi_sdf.select(
    F.to_date(pickup_ts).alias("date"),
    (F.floor(pickup_hour / 3) * 3).alias("hour_bucket"),
    "PULocationID",
)

In [35]:
from pyspark.sql import Window
from pyspark.sql import DataFrame
import itertools

# Trips per (date, 3-hour bucket, pickup location); combinations without any trip
# do not appear in this aggregate
grid_keys = ["date", "hour_bucket", "PULocationID"]
yellow_taxi_sdf_aggregated = yellow_taxi_sdf.groupBy(*grid_keys).agg(F.count("*").alias("total"))

# Values observed in the data for each axis of the grid, collected to the driver
dates = [row["date"] for row in yellow_taxi_sdf.select("date").distinct().collect()]
hours = list(range(0, 24, 3))  # 3-hour buckets: 0, 3, 6, 9, 12, 15, 18, 21
locations = [row["PULocationID"] for row in yellow_taxi_sdf.select("PULocationID").distinct().collect()]

# Cartesian product of the three axes: every (date, hour_bucket, PULocationID) combination
combinations = [
    (date_value, hour_value, location_id)
    for date_value in dates
    for hour_value in hours
    for location_id in locations
]
combinations_df = spark.createDataFrame(combinations, grid_keys)

# Left-join the counts onto the full grid; combinations with no trips get total = 0
yellow_taxi_sdf_full = (
    combinations_df
    .join(yellow_taxi_sdf_aggregated, on=grid_keys, how="left")
    .fillna(0, subset=["total"])
)

# Show the first few rows to verify
yellow_taxi_sdf_full.show(5, truncate=False)


                                                                                

+----------+-----------+------------+-----+
|date      |hour_bucket|PULocationID|total|
+----------+-----------+------------+-----+
|2023-11-08|0          |85          |0    |
|2023-11-08|0          |137         |22   |
|2023-11-08|0          |65          |2    |
|2023-11-08|0          |31          |0    |
|2023-11-08|0          |148         |58   |
+----------+-----------+------------+-----+
only showing top 5 rows



In [36]:
# Print the first 100 rows of the full grid without truncating cell values
yellow_taxi_sdf_full.show(n=100, truncate=False)



+----------+-----------+------------+-----+
|date      |hour_bucket|PULocationID|total|
+----------+-----------+------------+-----+
|2023-11-08|0          |174         |0    |
|2023-11-08|3          |262         |14   |
|2023-11-08|9          |183         |0    |
|2023-11-08|9          |222         |0    |
|2023-11-08|9          |262         |404  |
|2023-11-08|3          |16          |0    |
|2023-11-08|6          |1           |0    |
|2023-11-08|0          |141         |35   |
|2023-11-08|6          |190         |0    |
|2023-11-08|6          |237         |580  |
|2023-11-08|0          |231         |39   |
|2023-11-08|3          |32          |0    |
|2023-11-08|3          |214         |0    |
|2023-11-08|6          |99          |0    |
|2023-11-08|9          |128         |0    |
|2023-11-08|0          |229         |47   |
|2023-11-08|3          |78          |0    |
|2023-11-08|3          |73          |0    |
|2023-11-08|6          |12          |0    |
|2023-11-08|6          |165     

                                                                                

In [37]:
# Summary statistics of the full (date, hour_bucket, PULocationID) grid
yellow_taxi_full_summary = yellow_taxi_sdf_full.describe()
yellow_taxi_full_summary.show(vertical=True)

                                                                                

-RECORD 0--------------------------
 summary      | count              
 hour_bucket  | 439632             
 PULocationID | 439632             
 total        | 439632             
-RECORD 1--------------------------
 summary      | mean               
 hour_bucket  | 10.5               
 PULocationID | 132.984496124031   
 total        | 46.39631328019799  
-RECORD 2--------------------------
 summary      | stddev             
 hour_bucket  | 6.873871360194142  
 PULocationID | 76.12574610782154  
 total        | 144.28656754050212 
-RECORD 3--------------------------
 summary      | min                
 hour_bucket  | 0                  
 PULocationID | 1                  
 total        | 0                  
-RECORD 4--------------------------
 summary      | max                
 hour_bucket  | 21                 
 PULocationID | 263                
 total        | 1869               



# Preprocess High Volume For-Hire Vehicles (HVFHV)

In [38]:
# One HVFHV parquet path per (year, month), in chronological order; months are zero-padded
hvfhv_files = [
    f'{tlc_output_dir}/hvfhv_{year}-{month:02d}.parquet'
    for year, months in ((YEAR_2023, MONTHS_2023), (YEAR_2024, MONTHS_2024))
    for month in months
]


In [39]:
# Read every monthly HVFHV parquet file in one call into a single DataFrame
hvfhv_reader = spark.read
hvfhv_sdf = hvfhv_reader.parquet(*hvfhv_files)

# Print the schema as a sanity check on column names and types
hvfhv_sdf.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- originating_base_num: string (nullable = true)
 |-- request_datetime: timestamp_ntz (nullable = true)
 |-- on_scene_datetime: timestamp_ntz (nullable = true)
 |-- pickup_datetime: timestamp_ntz (nullable = true)
 |-- dropoff_datetime: timestamp_ntz (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: long (nullable = true)
 |-- base_passenger_fare: double (nullable = true)
 |-- tolls: double (nullable = true)
 |-- bcf: double (nullable = true)
 |-- sales_tax: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)
 |-- tips: double (nullable = true)
 |-- driver_pay: double (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- access_a_

In [40]:
# Summary statistics of the raw HVFHV data, printed vertically
hvfhv_raw_summary = hvfhv_sdf.describe()
hvfhv_raw_summary.show(vertical=True)



-RECORD 0----------------------------------
 summary              | count              
 hvfhs_license_num    | 140526989          
 dispatching_base_num | 140526989          
 originating_base_num | 102871223          
 PULocationID         | 140526989          
 DOLocationID         | 140526989          
 trip_miles           | 140526989          
 trip_time            | 140526989          
 base_passenger_fare  | 140526989          
 tolls                | 140526989          
 bcf                  | 140526989          
 sales_tax            | 140526989          
 congestion_surcharge | 140526989          
 airport_fee          | 140526989          
 tips                 | 140526989          
 driver_pay           | 140526989          
 shared_request_flag  | 140526989          
 shared_match_flag    | 140526989          
 access_a_ride_flag   | 140526989          
 wav_request_flag     | 140526989          
 wav_match_flag       | 140526989          
-RECORD 1-----------------------

                                                                                

In [41]:
# Count NULLs per column: isNull() yields a boolean, which is cast to 0/1 and summed
hvfhv_null_count_exprs = []
for column_name in hvfhv_sdf.columns:
    is_null_flag = F.col(column_name).isNull().cast("int")
    hvfhv_null_count_exprs.append(F.sum(is_null_flag).alias(column_name))

missing_values = hvfhv_sdf.select(hvfhv_null_count_exprs)
missing_values.show(vertical=True)



-RECORD 0------------------------
 hvfhs_license_num    | 0        
 dispatching_base_num | 0        
 originating_base_num | 37655766 
 request_datetime     | 0        
 on_scene_datetime    | 37655378 
 pickup_datetime      | 0        
 dropoff_datetime     | 0        
 PULocationID         | 0        
 DOLocationID         | 0        
 trip_miles           | 0        
 trip_time            | 0        
 base_passenger_fare  | 0        
 tolls                | 0        
 bcf                  | 0        
 sales_tax            | 0        
 congestion_surcharge | 0        
 airport_fee          | 0        
 tips                 | 0        
 driver_pay           | 0        
 shared_request_flag  | 0        
 shared_match_flag    | 0        
 access_a_ride_flag   | 0        
 wav_request_flag     | 0        
 wav_match_flag       | 0        



                                                                                

In [42]:
# Keep trips with a positive passenger fare, positive driver pay, trip_miles above 0.2
# and trip_time above 60 (presumably seconds -- confirm against the TLC data dictionary)
has_positive_fare = F.col("base_passenger_fare") > 0
has_positive_driver_pay = F.col("driver_pay") > 0
exceeds_min_miles = F.col("trip_miles") > 0.2
exceeds_min_time = F.col("trip_time") > 60

hvfhv_sdf_filtered = hvfhv_sdf.filter(
    has_positive_fare & has_positive_driver_pay & exceeds_min_miles & exceeds_min_time
)

In [43]:
# Count the rows that remain after filtering. `count()` is a Spark action, so this
# line triggers the actual read and filter of the HVFHV files.
row_count = hvfhv_sdf_filtered.count() 

                                                                                

In [44]:
# Report the row count computed in the previous cell
print(f"The number of rows in the filtered DataFrame is: {row_count}")

The number of rows in the filtered DataFrame is: 140386155


In [45]:
# Keep only trips picked up inside the study window defined for the yellow taxi data;
# `between` includes both bounds
hvfhv_pickup_in_window = F.col("pickup_datetime").between(date1_ts, date2_ts)
hvfhv_sdf = hvfhv_sdf_filtered.filter(hvfhv_pickup_in_window)

In [46]:
# Summary statistics of the filtered HVFHV data, printed vertically
hvfhv_filtered_summary = hvfhv_sdf.describe()
hvfhv_filtered_summary.show(vertical=True)



-RECORD 0-----------------------------------
 summary              | count               
 hvfhs_license_num    | 140386155           
 dispatching_base_num | 140386155           
 originating_base_num | 102757819           
 PULocationID         | 140386155           
 DOLocationID         | 140386155           
 trip_miles           | 140386155           
 trip_time            | 140386155           
 base_passenger_fare  | 140386155           
 tolls                | 140386155           
 bcf                  | 140386155           
 sales_tax            | 140386155           
 congestion_surcharge | 140386155           
 airport_fee          | 140386155           
 tips                 | 140386155           
 driver_pay           | 140386155           
 shared_request_flag  | 140386155           
 shared_match_flag    | 140386155           
 access_a_ride_flag   | 140386155           
 wav_request_flag     | 140386155           
 wav_match_flag       | 140386155           
-RECORD 1-

                                                                                

## Calculate total HVFHV pickups per location and time bucket

In [47]:
# Reduce each trip to the three grouping keys in a single projection:
#   date        - calendar date of the pickup
#   hour_bucket - start hour of the 3-hour bucket containing the pickup (0, 3, ..., 21)
#   PULocationID - pickup location, unchanged
hvfhv_pickup_ts = F.col("pickup_datetime")
hvfhv_pickup_hour = F.hour(hvfhv_pickup_ts)

hvfhv_sdf = hvfhv_sdf.select(
    F.to_date(hvfhv_pickup_ts).alias("date"),
    (F.floor(hvfhv_pickup_hour / 3) * 3).alias("hour_bucket"),
    "PULocationID",
)

# Number of trips for each (date, hour_bucket, PULocationID) combination
hvfhv_group_keys = ["date", "hour_bucket", "PULocationID"]
hvfhv_sdf_aggregated = hvfhv_sdf.groupBy(*hvfhv_group_keys).agg(F.count("*").alias("total"))

In [48]:
# Print the first 5 aggregated rows without truncating cell values
hvfhv_sdf_aggregated.show(n=5, truncate=False)



+----------+-----------+------------+-----+
|date      |hour_bucket|PULocationID|total|
+----------+-----------+------------+-----+
|2023-11-01|0          |200         |46   |
|2023-11-01|6          |234         |661  |
|2023-11-01|12         |162         |711  |
|2023-11-01|15         |223         |407  |
|2023-11-01|15         |230         |1078 |
+----------+-----------+------------+-----+
only showing top 5 rows



                                                                                

In [49]:
# Values observed in the HVFHV data for each axis of the grid, collected to the driver
dates = [row["date"] for row in hvfhv_sdf.select("date").distinct().collect()]
hours = list(range(0, 24, 3))  # 3-hour buckets: 0, 3, 6, 9, 12, 15, 18, 21
locations = [row["PULocationID"] for row in hvfhv_sdf.select("PULocationID").distinct().collect()]

# Cartesian product of the three axes: every (date, hour_bucket, PULocationID) combination
hvfhv_grid_keys = ["date", "hour_bucket", "PULocationID"]
combinations = [
    (date_value, hour_value, location_id)
    for date_value in dates
    for hour_value in hours
    for location_id in locations
]
combinations_df = spark.createDataFrame(combinations, hvfhv_grid_keys)

# Left-join the counts onto the full grid; combinations with no trips get total = 0
hvfhv_sdf_full = (
    combinations_df
    .join(hvfhv_sdf_aggregated, on=hvfhv_grid_keys, how="left")
    .fillna(0, subset=["total"])
)

# Show the first few rows to verify
hvfhv_sdf_full.show(5, truncate=False)

[Stage 109:>                                                        (0 + 7) / 7]

+----------+-----------+------------+-----+
|date      |hour_bucket|PULocationID|total|
+----------+-----------+------------+-----+
|2023-12-08|9          |34          |60   |
|2024-01-09|6          |65          |279  |
|2024-01-09|6          |137         |496  |
|2024-03-04|12         |133         |210  |
|2024-05-03|18         |101         |85   |
+----------+-----------+------------+-----+
only showing top 5 rows



                                                                                

In [50]:
# Print the first 100 rows of the full HVFHV grid without truncating cell values
hvfhv_sdf_full.show(n=100, truncate=False)



+----------+-----------+------------+-----+
|date      |hour_bucket|PULocationID|total|
+----------+-----------+------------+-----+
|2023-11-08|0          |19          |10   |
|2023-11-08|0          |94          |54   |
|2023-11-08|0          |108         |25   |
|2023-11-08|0          |111         |2    |
|2023-11-08|0          |210         |51   |
|2023-11-08|0          |212         |89   |
|2023-11-08|0          |223         |100  |
|2023-11-08|0          |235         |97   |
|2023-11-08|0          |236         |91   |
|2023-11-08|0          |246         |254  |
|2023-11-08|0          |250         |46   |
|2023-11-08|0          |259         |81   |
|2023-12-08|9          |13          |392  |
|2023-12-08|9          |34          |60   |
|2023-12-08|9          |37          |782  |
|2023-12-08|9          |91          |539  |
|2023-12-08|9          |122         |134  |
|2023-12-08|9          |164         |1035 |
|2023-12-08|9          |182         |270  |
|2023-12-08|9          |202     

                                                                                

## Join the two datasets (yellow and HVFHV), then calculate the pickup share of yellow taxis by location and time

In [51]:
# Inner-join the two full grids on (date, hour_bucket, PULocationID). The aliases let
# the two `total` columns be told apart as `yellow.total` and `hvfhv.total`.
share_join_keys = ["date", "hour_bucket", "PULocationID"]
combined_sdf = yellow_taxi_sdf_full.alias("yellow").join(
    hvfhv_sdf_full.alias("hvfhv"), on=share_join_keys, how="inner"
)

yellow_total = F.col("yellow.total")
hvfhv_total = F.col("hvfhv.total")

# Yellow taxi share of pickups, in percent:
#   only yellow trips  -> 100
#   only HVFHV trips   -> 0
#   both present       -> yellow / (yellow + hvfhv) * 100
#   neither present    -> NULL (such rows are removed by the filter below)
yellow_share_pct = (
    F.when((yellow_total > 0) & (hvfhv_total == 0), 100)
    .when((yellow_total == 0) & (hvfhv_total > 0), 0)
    .when(
        (yellow_total > 0) & (hvfhv_total > 0),
        (yellow_total / (yellow_total + hvfhv_total)) * 100
    )
    .otherwise(None)
)
combined_sdf = combined_sdf.withColumn("market_share_yellow", yellow_share_pct)

# Drop rows where both totals are zero, i.e. keep rows where at least one is non-zero
combined_sdf = combined_sdf.filter((yellow_total != 0) | (hvfhv_total != 0))

# Keep the keys plus the computed share, and show the result
result_sdf = combined_sdf.select(*share_join_keys, "market_share_yellow")
result_sdf.show(100, truncate=False)



+----------+-----------+------------+-------------------+
|date      |hour_bucket|PULocationID|market_share_yellow|
+----------+-----------+------------+-------------------+
|2023-11-01|0          |19          |0.0                |
|2023-11-01|0          |20          |0.6024096385542169 |
|2023-11-01|0          |38          |0.0                |
|2023-11-01|0          |50          |7.119741100323624  |
|2023-11-01|0          |61          |0.0                |
|2023-11-01|0          |73          |0.0                |
|2023-11-01|0          |78          |0.0                |
|2023-11-01|0          |83          |0.0                |
|2023-11-01|0          |98          |0.0                |
|2023-11-01|0          |99          |0.0                |
|2023-11-01|0          |114         |26.745435016111706 |
|2023-11-01|0          |121         |0.0                |
|2023-11-01|0          |122         |0.0                |
|2023-11-01|0          |126         |0.0                |
|2023-11-01|0 

                                                                                

In [52]:
# Summary statistics of the market-share table, printed vertically
market_share_summary = result_sdf.describe()
market_share_summary.show(vertical=True)

[Stage 147:>                                                        (0 + 8) / 8]

-RECORD 0---------------------------------
 summary             | count              
 hour_bucket         | 430757             
 PULocationID        | 430757             
 market_share_yellow | 430757             
-RECORD 1---------------------------------
 summary             | mean               
 hour_bucket         | 10.526503341791312 
 PULocationID        | 133.95188934828684 
 market_share_yellow | 5.0748786625049265 
-RECORD 2---------------------------------
 summary             | stddev             
 hour_bucket         | 6.864751949512822  
 PULocationID        | 75.62328828531652  
 market_share_yellow | 11.24862404203357  
-RECORD 3---------------------------------
 summary             | min                
 hour_bucket         | 0                  
 PULocationID        | 1                  
 market_share_yellow | 0.0                
-RECORD 4---------------------------------
 summary             | max                
 hour_bucket         | 21                 
 PULocation

                                                                                

# Export the curated dataset

In [53]:
# Give each count column a source-specific name so the two exports are unambiguous
yellow_taxi_sdf_full, hvfhv_sdf_full = (
    yellow_taxi_sdf_full.withColumnRenamed("total", "total_yellow"),
    hvfhv_sdf_full.withColumnRenamed("total", "total_hvfhv"),
)

In [None]:
import os

output_dir = '../data/curated'

# Create the output directory if it does not exist yet (no error if it already does)
os.makedirs(output_dir, exist_ok=True)

# Rename the `total` columns. `withColumnRenamed` does nothing when the DataFrame has
# no column with the given name, so the yellow/HVFHV renames are harmless if the
# columns were already renamed in an earlier cell.
# NOTE(review): `combined_sdf` is not written below, so its rename looks vestigial --
# confirm nothing else depends on it before removing.
combined_sdf = combined_sdf.withColumnRenamed("total", "total_combined")
yellow_taxi_sdf_full = yellow_taxi_sdf_full.withColumnRenamed("total", "total_yellow")
hvfhv_sdf_full = hvfhv_sdf_full.withColumnRenamed("total", "total_hvfhv")

# Export the yellow taxi market-share table (result_sdf) in Parquet format
result_sdf.write.parquet(os.path.join(output_dir, 'yellow_hvfhv_sdf.parquet'), mode='overwrite')

# Export yellow_taxi_sdf_full in Parquet format
yellow_taxi_sdf_full.write.parquet(os.path.join(output_dir, 'yellow_taxi_sdf_full.parquet'), mode='overwrite')

# Export hvfhv_sdf_full in Parquet format
hvfhv_sdf_full.write.parquet(os.path.join(output_dir, 'hvfhv_sdf_full.parquet'), mode='overwrite')

# Report the directory that was actually written to
print(f"Data has been successfully exported to the {output_dir} directory in Parquet format.")

24/08/26 14:56:35 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
[Stage 174:>                                                       (0 + 8) / 27]