In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
import pandas as pd
import os

# Build (or reuse) the Spark session, applying each SQL option in the order listed.
_spark_options = {
    "spark.sql.repl.eagerEval.enabled": True,
    "spark.sql.parquet.cacheMetadata": "true",
    "spark.sql.session.timeZone": "Etc/UTC",
}
_builder = SparkSession.builder.appName("Taxi vs Rideshare Profitability")
for _option, _setting in _spark_options.items():
    _builder = _builder.config(_option, _setting)
spark = _builder.getOrCreate()

# Months covered by the analysis: 2024-01 through 2024-06, formatted as YYYY-MM.
months = [f"2024-{month:02d}" for month in range(1, 7)]

# One parquet file per month for each service.
yellow_files = [f"data/yellow/yellow_tripdata_{m}.parquet" for m in months]
fhvhv_files  = [f"data/fhvhv/fhvhv_tripdata_{m}.parquet" for m in months]


def _read_trips(paths, service_type):
    """Read the parquet files in `paths` as one DataFrame and add a constant `service_type` column."""
    return spark.read.parquet(*paths).withColumn("service_type", lit(service_type))


df_yellow = _read_trips(yellow_files, "yellow")
df_fhvhv = _read_trips(fhvhv_files, "hv_fhv")

# Stack both services, matching columns by name; a column may exist on only one side.
df = df_yellow.unionByName(df_fhvhv, allowMissingColumns=True)

# Optional schema inspection:
# df_yellow.printSchema()
# df_fhvhv.printSchema()

# External electricity and fuel tables, read as CSV with a header row and inferred column types.
_csv_options = {"header": True, "inferSchema": True}
electricity = spark.read.csv("data/external/electricity.csv", **_csv_options)
fuel = spark.read.csv("data/external/fuel.csv", **_csv_options)
# electricity.show(); fuel.show()



Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/08/13 21:27:57 WARN Utils: Your hostname, LAPTOP-E04ANIN1, resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/08/13 21:27:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/13 21:27:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

In [5]:
# Preprocess data
from pyspark.sql import functions as F
from pyspark.sql.functions import col, to_timestamp, coalesce, unix_timestamp, when

# Cleaning pipeline. Each step runs in the same order as before:
#   1. unified pickup/dropoff timestamps, rows with a null timestamp removed
#   2. unified distance column, only non-null positive distances kept
#   3. unified trip duration in seconds, only non-null positive durations kept
df = (
    df
    # Prefer the tpep_* columns and fall back to the un-prefixed ones.
    .withColumn(
        "pickup_ts",
        F.to_timestamp(F.coalesce(F.col("tpep_pickup_datetime"), F.col("pickup_datetime"))),
    )
    .withColumn(
        "dropoff_ts",
        F.to_timestamp(F.coalesce(F.col("tpep_dropoff_datetime"), F.col("dropoff_datetime"))),
    )
    .where(F.col("pickup_ts").isNotNull())
    .where(F.col("dropoff_ts").isNotNull())
    # Distance: trip_distance when present, otherwise trip_miles.
    .withColumn("distance_mi", F.coalesce(F.col("trip_distance"), F.col("trip_miles")))
    .where(F.col("distance_mi").isNotNull() & (F.col("distance_mi") > 0))
    # Duration: the trip_time column when it is set, otherwise the difference
    # between the dropoff and pickup unix timestamps.
    .withColumn(
        "trip_time_s",
        F.coalesce(
            F.col("trip_time"),
            F.unix_timestamp(F.col("dropoff_ts")) - F.unix_timestamp(F.col("pickup_ts")),
        ),
    )
    .where(F.col("trip_time_s").isNotNull() & (F.col("trip_time_s") > 0))
)

# Drop duplicate rows, keyed on whichever of these columns the frame actually has.
dedupe_key = [
    name
    for name in (
        "service_type",
        "pickup_ts",
        "dropoff_ts",
        "PULocationID",
        "DOLocationID",
        "distance_mi",
        "trip_time_s",
    )
    if name in df.columns
]
df = df.dropDuplicates(dedupe_key)

# Print the schema of the merged, cleaned DataFrame.
df.printSchema()


root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- service_type: string (nullable = false)
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: strin