In [0]:
import pandas as pd
import time
from pyspark.sql.functions import col, count, avg, sum as spark_sum, desc, when


In [0]:
df_base = spark.table("default.yellow_tripdata_2015_01")

# Create a larger dataset
df_large = df_base
for i in range(9):
    df_large = df_large.union(df_base)


In [0]:
df_large.sample(withReplacement=False, fraction=0.1)

DataFrame[VendorID: int, tpep_pickup_datetime: timestamp, tpep_dropoff_datetime: timestamp, passenger_count: int, trip_distance: double, pickup_longitude: double, pickup_latitude: double, RateCodeID: int, store_and_fwd_flag: string, dropoff_longitude: double, dropoff_latitude: double, payment_type: int, fare_amount: double, extra: double, mta_tax: double, tip_amount: double, tolls_amount: double, improvement_surcharge: double, total_amount: double]

In [0]:
df_filtered = df_large.select(
    "VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime",
    "passenger_count", "trip_distance", "fare_amount", "tip_amount", 
    "total_amount", "pickup_latitude", "pickup_longitude")


In [0]:
from pyspark.sql.functions import col

# Filter trips with valid passenger count and reasonable trip distance
df_filtered = df_filtered.filter(
    (col("passenger_count") > 0) &
    (col("trip_distance") > 0) &
    (col("trip_distance") < 100)
)

In [0]:
from pyspark.sql.functions import col, mean, when, unix_timestamp

# ----------------------------
# 1. Remove rows with invalid critical values
# ----------------------------
critical_cols = ["passenger_count", "trip_distance", "fare_amount", "total_amount"]

df_filtered = df_large
for c in critical_cols:
    df_filtered = df_filtered.filter(col(c) > 0)

# ----------------------------
# 2. Fill zeros or nulls in optional columns with column average
# ----------------------------
replace_cols = ["tip_amount", "tolls_amount", "improvement_surcharge", "extra", "mta_tax"]

for c in replace_cols:
    avg_val = df_filtered.select(mean(col(c))).collect()[0][0]
    df_filtered = df_filtered.withColumn(
        c,
        when(col(c).isNull() | (col(c) == 0), avg_val).otherwise(col(c))
    )

# ----------------------------
# 3. Ensure datetime columns are valid
# ----------------------------
df_filtered = df_filtered.filter(
    (col("tpep_pickup_datetime").isNotNull()) &
    (col("tpep_dropoff_datetime").isNotNull())
)

# ----------------------------
# 4. Calculate trip_duration safely
# ----------------------------
df_filtered = df_filtered.withColumn(
    "trip_duration",
    (unix_timestamp(col("tpep_dropoff_datetime")) - unix_timestamp(col("tpep_pickup_datetime"))) / 60
).filter(col("trip_duration") > 0)

# ----------------------------
# 5. filter trips within NYC coordinates
# ----------------------------
df_filtered = df_filtered.filter(
    (col("pickup_latitude").between(40.5, 41)) &
    (col("pickup_longitude").between(-74.5, -73.5))
)


In [0]:
from pyspark.sql.functions import unix_timestamp, round

# Trip duration in minutes
df_filtered = df_filtered.withColumn(
    "trip_duration",
    (unix_timestamp(col("tpep_dropoff_datetime")) - unix_timestamp(col("tpep_pickup_datetime"))) / 60
)

# Tip percentage
df_filtered = df_filtered.withColumn(
    "tip_percentage",
    round((col("tip_amount") / col("fare_amount")) * 100, 2)
)

df_filtered.show()

+--------+--------------------+---------------------+---------------+-------------+------------------+------------------+----------+------------------+------------------+------------------+------------+-----------+-------------------+-------+-----------------+------------------+---------------------+------------+------------------+--------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|  pickup_longitude|   pickup_latitude|RateCodeID|store_and_fwd_flag| dropoff_longitude|  dropoff_latitude|payment_type|fare_amount|              extra|mta_tax|       tip_amount|      tolls_amount|improvement_surcharge|total_amount|     trip_duration|tip_percentage|
+--------+--------------------+---------------------+---------------+-------------+------------------+------------------+----------+------------------+------------------+------------------+------------+-----------+-------------------+-------+-----------------+------------------+---------------------+-

In [0]:
df_filtered.take(5)

[Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2015, 1, 3, 14, 7, 26), tpep_dropoff_datetime=datetime.datetime(2015, 1, 3, 14, 12, 56), passenger_count=1, trip_distance=0.64, pickup_longitude=-73.94904327392578, pickup_latitude=40.78160858154297, RateCodeID=1, store_and_fwd_flag='N', dropoff_longitude=-73.95654296875, dropoff_latitude=40.77803039550781, payment_type=1, fare_amount=5.5, extra=0.30903683757945827, mta_tax=0.5, tip_amount=1.1, tolls_amount=0.2416871612181792, improvement_surcharge=0.3, total_amount=7.4, trip_duration=5.5, tip_percentage=20.0),
 Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2015, 1, 3, 14, 7, 26), tpep_dropoff_datetime=datetime.datetime(2015, 1, 3, 14, 14, 49), passenger_count=1, trip_distance=1.42, pickup_longitude=-73.96768951416016, pickup_latitude=40.76264190673828, RateCodeID=1, store_and_fwd_flag='N', dropoff_longitude=-73.97284698486328, dropoff_latitude=40.74626541137695, payment_type=2, fare_amount=7.0, extra=0.30903683757945827,

In [0]:
from pyspark.sql.functions import avg, sum, round

df_agg = df_filtered.groupBy("VendorID").agg(
    round(avg("trip_distance"), 2).alias("avg_trip_distance"),
    round(sum("fare_amount"), 2).alias("total_fare"),
    round(avg("tip_percentage"), 2).alias("avg_tip_pct"),
    round(avg("trip_duration"), 2).alias("avg_trip_duration")
)

df_agg.show()

+--------+-----------------+-------------+-----------+-----------------+
|VendorID|avg_trip_distance|   total_fare|avg_tip_pct|avg_trip_duration|
+--------+-----------------+-------------+-----------+-----------------+
|       1|            25.97|6.903424952E8|      25.54|            12.36|
|       2|             2.85|7.848906811E8|      25.18|            14.24|
+--------+-----------------+-------------+-----------+-----------------+



In [0]:
df_filtered.createOrReplaceTempView("taxi_trips")

In [0]:
from pyspark.sql.functions import when

df_filtered = df_filtered.withColumn(
    "tip_percentage",
    round(when(col("fare_amount") != 0, col("tip_amount") / col("fare_amount") * 100).otherwise(None), 2)
)

In [0]:
%sql
SELECT *
FROM taxi_trips
ORDER BY fare_amount DESC
LIMIT 5

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RateCodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,trip_duration,tip_percentage
1,2015-01-22T21:12:26.000Z,2015-01-22T21:20:36.000Z,1,1.7,-73.96153259277344,40.77063751220703,1,N,-73.97850799560547,40.74951553344727,2,4008.0,0.5,0.5,1.851125934112553,0.2416871612181792,0.3,4009.3,8.166666666666666,0.05
1,2015-01-22T21:12:26.000Z,2015-01-22T21:20:36.000Z,1,1.7,-73.96153259277344,40.77063751220703,1,N,-73.97850799560547,40.74951553344727,2,4008.0,0.5,0.5,1.851125934112553,0.2416871612181792,0.3,4009.3,8.166666666666666,0.05
1,2015-01-22T21:12:26.000Z,2015-01-22T21:20:36.000Z,1,1.7,-73.96153259277344,40.77063751220703,1,N,-73.97850799560547,40.74951553344727,2,4008.0,0.5,0.5,1.851125934112553,0.2416871612181792,0.3,4009.3,8.166666666666666,0.05
1,2015-01-22T21:12:26.000Z,2015-01-22T21:20:36.000Z,1,1.7,-73.96153259277344,40.77063751220703,1,N,-73.97850799560547,40.74951553344727,2,4008.0,0.5,0.5,1.851125934112553,0.2416871612181792,0.3,4009.3,8.166666666666666,0.05
1,2015-01-22T21:12:26.000Z,2015-01-22T21:20:36.000Z,1,1.7,-73.96153259277344,40.77063751220703,1,N,-73.97850799560547,40.74951553344727,2,4008.0,0.5,0.5,1.851125934112553,0.2416871612181792,0.3,4009.3,8.166666666666666,0.05


In [0]:
%sql
SELECT passenger_count, AVG(fare_amount) AS avg_fare
FROM taxi_trips
GROUP BY passenger_count
ORDER BY avg_fare DESC

passenger_count,avg_fare
9,74.41666666666667
8,33.5
7,13.5
2,12.353915909078612
4,12.12988564700566
3,12.06589180693754
5,11.965361292675723
6,11.805331844890976
1,11.74596665091072


In [0]:
%sql
SELECT VendorID, AVG(try_divide(tip_amount, fare_amount) * 100) AS avg_tip_pct
FROM taxi_trips
GROUP BY VendorID
ORDER BY avg_tip_pct DESC;

VendorID,avg_tip_pct
1,25.537875549198585
2,25.179177727395285


In [0]:
%sql
SELECT VendorID, COUNT(*) AS long_trip_count, AVG(fare_amount) AS avg_fare_long_trip,
       AVG(try_divide(tip_amount, fare_amount) * 100) AS avg_tip_pct_long_trip
FROM taxi_trips
WHERE trip_distance > 10
GROUP BY VendorID


VendorID,long_trip_count,avg_fare_long_trip,avg_tip_pct_long_trip
1,2588060,44.87887896725756,34.6448725464399
2,3182850,44.95026347455895,15.70121635420941


In [0]:
%sql
SELECT CASE
           WHEN trip_duration < 10 THEN 'Short'
           WHEN trip_duration BETWEEN 10 AND 30 THEN 'Medium'
           ELSE 'Long'
       END AS trip_type,
       COUNT(*) AS trip_count,
       AVG(fare_amount) AS avg_fare,
       AVG(try_divide(tip_amount, fare_amount) * 100) AS avg_tip_pct
FROM taxi_trips
GROUP BY trip_type
ORDER BY trip_type;

trip_type,trip_count,avg_fare,avg_tip_pct
Long,5977130,39.05219018157575,24.944778656206477
Medium,56280970,14.931591365605742,22.23529089570521
Short,62044130,6.470373798455994,28.211111438534907


In [0]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

In [0]:
# Select important numeric columns for prediction
feature_cols = ["trip_distance", "trip_duration", "passenger_count", "tip_amount", "tolls_amount"]

# Vector assembler for MLlib
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

df_ml = assembler.transform(df_filtered).select("features", "fare_amount")

In [0]:
train_df, test_df = df_ml.randomSplit([0.8, 0.2], seed=42)

In [0]:
lr = LinearRegression(featuresCol="features", labelCol="fare_amount")
lr_model = lr.fit(train_df)

In [0]:
predictions = lr_model.transform(test_df)

evaluator = RegressionEvaluator(
    labelCol="fare_amount", predictionCol="prediction", metricName="rmse"
)

rmse = evaluator.evaluate(predictions)
r2 = RegressionEvaluator(
    labelCol="fare_amount", predictionCol="prediction", metricName="r2"
).evaluate(predictions)

print(f"Root Mean Squared Error (RMSE): {rmse}")
print(f"R-squared (R2): {r2}")

Root Mean Squared Error (RMSE): 8.457218703796803
R-squared (R2): 0.2756538664193098


In [0]:
display(predictions.select("features", "fare_amount", "prediction"))

features,fare_amount,prediction
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.03333333333333333"",""5.0"",""1.851125934112553"",""0.2416871612181792""]}",52.0,10.629422225743868
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.06666666666666667"",""1.0"",""0.9"",""0.2416871612181792""]}",2.5,10.527014552618034
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.06666666666666667"",""1.0"",""19.66"",""0.2416871612181792""]}",98.3,10.52715118566426
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.08333333333333333"",""1.0"",""1.851125934112553"",""0.2416871612181792""]}",2.5,10.527529695239496
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.1"",""6.0"",""1.851125934112553"",""0.2416871612181792""]}",2.5,10.657309381378406
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.11666666666666667"",""1.0"",""1.851125934112553"",""0.2416871612181792""]}",52.0,10.528546125979895
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.11666666666666667"",""6.0"",""6.0"",""0.2416871612181792""]}",2.5,10.65784781387558
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.13333333333333333"",""1.0"",""1.851125934112553"",""0.2416871612181792""]}",2.5,10.529054341350095
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.15"",""1.0"",""1.851125934112553"",""0.2416871612181792""]}",2.5,10.529562556720297
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.01"",""0.15"",""1.0"",""1.851125934112553"",""0.2416871612181792""]}",2.5,10.529562556720297


In [0]:
df_filtered.explain(True)

== Parsed Logical Plan ==
Project [VendorID#14961, tpep_pickup_datetime#14962, tpep_dropoff_datetime#14963, passenger_count#14964, trip_distance#14965, pickup_longitude#14966, pickup_latitude#14967, RateCodeID#14968, store_and_fwd_flag#14969, dropoff_longitude#14970, dropoff_latitude#14971, payment_type#14972, fare_amount#14973, extra#15158, mta_tax#15160, tip_amount#15152, tolls_amount#15154, improvement_surcharge#15156, total_amount#14979, trip_duration#15901, round(CASE WHEN NOT (fare_amount#14973 = cast(0 as double)) THEN ((tip_amount#15152 / fare_amount#14973) * cast(100 as double)) ELSE cast(null as double) END, 2) AS tip_percentage#17354]
+- Project [VendorID#14961, tpep_pickup_datetime#14962, tpep_dropoff_datetime#14963, passenger_count#14964, trip_distance#14965, pickup_longitude#14966, pickup_latitude#14967, RateCodeID#14968, store_and_fwd_flag#14969, dropoff_longitude#14970, dropoff_latitude#14971, payment_type#14972, fare_amount#14973, extra#15158, mta_tax#15160, tip_amount

In [0]:
from pyspark.sql.functions import col

# Transformation (lazy)
df_transformed = df_filtered.withColumn("fare_per_mile", col("fare_amount") / col("trip_distance"))

# No computation happens yet
print("Transformation defined but not executed.")

# Action (eager)
row_count = df_transformed.count()  # triggers execution
print(f"Action executed. Total rows: {row_count}")

Transformation defined but not executed.
Action executed. Total rows: 124302230


In [0]:
# Write as Delta table (default format)
df_filtered.write \
    .mode("overwrite") \
    .partitionBy("VendorID") \
    .saveAsTable("nyc_taxi_processed")

In [0]:
# Read it back
df = spark.table("nyc_taxi_processed")
df.show()

+--------+--------------------+---------------------+---------------+-------------+------------------+------------------+----------+------------------+------------------+------------------+------------+-----------+-------------------+-------+-----------------+------------------+---------------------+------------+------------------+--------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|  pickup_longitude|   pickup_latitude|RateCodeID|store_and_fwd_flag| dropoff_longitude|  dropoff_latitude|payment_type|fare_amount|              extra|mta_tax|       tip_amount|      tolls_amount|improvement_surcharge|total_amount|     trip_duration|tip_percentage|
+--------+--------------------+---------------------+---------------+-------------+------------------+------------------+----------+------------------+------------------+------------------+------------+-----------+-------------------+-------+-----------------+------------------+---------------------+-