In [0]:
from pyspark.ml.regression import LinearRegressionModel

# Persisted Spark ML linear-regression model used to score the silver trips.
MODEL_PATH = "dbfs:/models/taxi_fare_lr_v1"

# Deserialize the fitted model through its MLReader (what `.load()` delegates to).
model = LinearRegressionModel.read().load(MODEL_PATH)
print("Model loaded successfully")


Model loaded successfully


In [0]:
# Rows to score: the cleaned ("silver") taxi trips table.
silver_df = spark.table("default.taxi_silver")

# count() is a Spark action (runs a job over the table); the preview shows only 5 rows.
rows_to_score = silver_df.count()
print("Rows to score:", rows_to_score)
display(silver_df.limit(5))


Rows to score: 10000


pickup_datetime,dropoff_datetime,passenger_count,fare_amount,trip_distance_km,trip_duration_min,surge_multiplier,hour,dayofweek,month,year,pickup_zone,dropoff_zone
2025-10-04 06:01:44,2025-10-04 06:27:03,1,248.96,10.732,25.32,1.0,6,7,10,2025,Rajajinagar,Koramangala
2025-10-09 11:00:38,2025-10-09 11:28:12,2,340.41,16.604,27.57,1.0,11,5,10,2025,Bannerghatta,Hebbal
2025-10-07 00:20:47,2025-10-07 00:33:47,2,200.3,8.718,13.01,1.0,0,3,10,2025,Yelahanka,Rajajinagar
2025-10-01 08:37:05,2025-10-01 08:44:57,2,95.69,2.259,7.87,1.0,8,4,10,2025,Whitefield,Whitefield
2025-10-01 03:37:41,2025-10-01 03:52:45,1,201.14,8.568,15.07,1.0,3,4,10,2025,MG_Road,Bannerghatta


In [0]:
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import functions as F

# Model inputs, in the order they are packed into the feature vector.
# The model only sees a positional vector, so this list must match the
# features (and order) used at training time — training code is not in
# this notebook; TODO confirm against the training pipeline.
feature_cols = [
    "trip_distance_km",
    "trip_duration_min",
    "hour",
    "dayofweek",
    "passenger_count",
]

# Fail fast with a readable message if the feature list has drifted from the
# loaded model, rather than hitting a less specific vector-size error later,
# inside the scoring job.
if len(feature_cols) != model.numFeatures:
    raise ValueError(
        f"Model at {MODEL_PATH} expects {model.numFeatures} features, "
        f"but feature_cols has {len(feature_cols)}: {feature_cols}"
    )

# handleInvalid="error" is Spark's default, stated explicitly: rows with
# null/NaN feature values fail the job instead of being silently dropped.
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="error",
)

scoring_df = assembler.transform(silver_df)


In [0]:
# Apply the model to every scoring row; it appends a "prediction" column (lazy — no job runs yet).
predictions_df = model.transform(dataset=scoring_df)


In [0]:
# Gold-table columns: trip timestamps, the model inputs, the actual fare,
# and the model output renamed to predicted_fare.
gold_columns = [
    "pickup_datetime",
    "dropoff_datetime",
    "trip_distance_km",
    "trip_duration_min",
    "hour",
    "dayofweek",
    "passenger_count",
    "fare_amount",
    F.col("prediction").alias("predicted_fare"),
]

# current_timestamp() is evaluated once per query execution, so the preview
# below and the later table write each record their own timestamp.
gold_df = (
    predictions_df
    .select(*gold_columns)
    .withColumn("prediction_timestamp", F.current_timestamp())
)

display(gold_df.limit(10))


pickup_datetime,dropoff_datetime,trip_distance_km,trip_duration_min,hour,dayofweek,passenger_count,fare_amount,predicted_fare,prediction_timestamp
2025-10-04 06:01:44,2025-10-04 06:27:03,10.732,25.32,6,7,1,248.96,258.0734401395553,2025-12-23T11:55:11.517Z
2025-10-09 11:00:38,2025-10-09 11:28:12,16.604,27.57,11,5,2,340.41,354.5868282948265,2025-12-23T11:55:11.517Z
2025-10-07 00:20:47,2025-10-07 00:33:47,8.718,13.01,0,3,2,200.3,208.69214471652413,2025-12-23T11:55:11.517Z
2025-10-01 08:37:05,2025-10-01 08:44:57,2.259,7.87,8,4,2,95.69,99.64693777544156,2025-12-23T11:55:11.517Z
2025-10-01 03:37:41,2025-10-01 03:52:45,8.568,15.07,3,4,1,201.14,209.0463365513461,2025-12-23T11:55:11.517Z
2025-10-07 08:45:20,2025-10-07 09:13:11,10.698,27.85,8,3,1,252.25,261.83522226563446,2025-12-23T11:55:11.517Z
2025-10-05 23:10:08,2025-10-05 23:31:16,8.161,21.15,23,1,1,204.13,212.81971119235328,2025-12-23T11:55:11.517Z
2025-10-07 04:50:56,2025-10-07 04:56:27,5.011,5.52,4,3,2,133.44,139.41686663209785,2025-12-23T11:55:11.517Z
2025-10-05 00:29:28,2025-10-05 00:44:26,11.169,14.97,0,1,1,240.0,250.04962176923976,2025-12-23T11:55:11.517Z
2025-10-04 23:30:56,2025-10-04 23:59:18,17.953,28.37,23,7,1,542.78,376.9084597043715,2025-12-23T11:55:11.517Z


In [0]:
# Persist this scoring run, replacing any previous contents of the gold table.
gold_df.write.saveAsTable(
    "default.taxi_gold_predictions",
    mode="overwrite",
)

print("✅ Gold table created: default.taxi_gold_predictions")


✅ Gold table created: default.taxi_gold_predictions
