In [1]:
from pyspark.sql import SparkSession
import pandas as pd
import matplotlib.pyplot as plt
import time



# Create a SparkSession instance (an entry point to all Spark functions)
spark = SparkSession.builder.appName("a4").getOrCreate()




25/04/25 19:20:12 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
                                                                                

+---------------+------------+------------+------------+
|passenger_count|pulocationid|dolocationid|total_amount|
+---------------+------------+------------+------------+
|            1.0|       151.0|       239.0|        9.95|
|            1.0|       239.0|       246.0|        16.3|
|            3.0|       236.0|       236.0|         5.8|
|            5.0|       193.0|       193.0|        7.55|
|            5.0|       193.0|       193.0|       55.55|
|            5.0|       193.0|       193.0|       13.31|
|            5.0|       193.0|       193.0|       55.55|
|            1.0|       163.0|       229.0|        9.05|
|            1.0|       229.0|         7.0|        18.5|
|            2.0|       141.0|       234.0|        13.0|
+---------------+------------+------------+------------+
only showing top 10 rows



In [33]:
df = spark.read.csv('gs://dataproc-staging-us-central1-721720945833-6cwqbnms/2019-01-h1.csv', header=True, inferSchema=True)
# Only care about these rows
filtered = df.select("passenger_count", "pulocationid", "dolocationid", "total_amount")

filtered.show(10)

                                                                                

+---------------+------------+------------+------------+
|passenger_count|pulocationid|dolocationid|total_amount|
+---------------+------------+------------+------------+
|            1.0|       151.0|       239.0|        9.95|
|            1.0|       239.0|       246.0|        16.3|
|            3.0|       236.0|       236.0|         5.8|
|            5.0|       193.0|       193.0|        7.55|
|            5.0|       193.0|       193.0|       55.55|
|            5.0|       193.0|       193.0|       13.31|
|            5.0|       193.0|       193.0|       55.55|
|            1.0|       163.0|       229.0|        9.05|
|            1.0|       229.0|         7.0|        18.5|
|            2.0|       141.0|       234.0|        13.0|
+---------------+------------+------------+------------+
only showing top 10 rows



In [34]:
# Splitting into test and train 
train_df, test_df = filtered.randomSplit([.8, .2], seed=42)
print(train_df.count(), test_df.count())




2920849 730150


                                                                                

In [48]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml import Pipeline

# First, we have to make a vector assembler, so that spark can actually read the columns
assembler = VectorAssembler(
    inputCols=["passenger_count", "pulocationid", "dolocationid"],
    outputCol="features"
)


#next, we make the regressor.
dt = DecisionTreeRegressor(
    featuresCol="features",
    labelCol="total_amount",
    maxBins=200
)

#to make things easier, we can make a pipeline that does the things above
pipeline = Pipeline(stages=[assembler, dt])

#fit the model on the training data set.
pipelineModel = pipeline.fit(train_df)

predDf = pipelineModel.transform(test_df)

predDf.show(10)


[Stage 369:>                                                        (0 + 1) / 1]

+---------------+------------+------------+------------+---------------+------------------+
|passenger_count|pulocationid|dolocationid|total_amount|       features|        prediction|
+---------------+------------+------------+------------+---------------+------------------+
|            0.0|         4.0|         4.0|         4.3|  [0.0,4.0,4.0]|23.931810506566894|
|            0.0|         4.0|        33.0|       17.75| [0.0,4.0,33.0]|19.266741166043648|
|            0.0|         4.0|        68.0|        15.8| [0.0,4.0,68.0]| 18.39775119715606|
|            0.0|         4.0|        79.0|        9.75| [0.0,4.0,79.0]| 18.39775119715606|
|            0.0|         4.0|       125.0|         9.3|[0.0,4.0,125.0]| 18.39775119715606|
|            0.0|         4.0|       170.0|       11.15|[0.0,4.0,170.0]| 18.39775119715606|
|            0.0|         7.0|         7.0|        0.31|  [0.0,7.0,7.0]|19.266741166043648|
|            0.0|         7.0|         7.0|         6.3|  [0.0,7.0,7.0]|19.26674

                                                                                

In [49]:
from pyspark.ml.evaluation import RegressionEvaluator

#Showing predictions alongside the original features
predDf.show(10)

#Evaluat the model using RMSE
evaluator = RegressionEvaluator(
    labelCol="total_amount", predictionCol="prediction", metricName="rmse"
)

RMSE = evaluator.evaluate(predDf)
print(f"(RMSE) on test data = {RMSE:.2f}")

                                                                                

+---------------+------------+------------+------------+---------------+------------------+
|passenger_count|pulocationid|dolocationid|total_amount|       features|        prediction|
+---------------+------------+------------+------------+---------------+------------------+
|            0.0|         1.0|         1.0|      116.75|  [0.0,1.0,1.0]|23.931810506566894|
|            0.0|         4.0|        17.0|        20.3| [0.0,4.0,17.0]|19.266741166043648|
|            0.0|         4.0|        68.0|        12.8| [0.0,4.0,68.0]| 18.39775119715606|
|            0.0|         4.0|        79.0|        6.35| [0.0,4.0,79.0]| 18.39775119715606|
|            0.0|         4.0|        90.0|        15.8| [0.0,4.0,90.0]| 18.39775119715606|
|            0.0|         4.0|       125.0|       13.55|[0.0,4.0,125.0]| 18.39775119715606|
|            0.0|         4.0|       170.0|        14.3|[0.0,4.0,170.0]| 18.39775119715606|
|            0.0|         4.0|       264.0|         8.3|[0.0,4.0,264.0]| 18.3977



(RMSE) on test data = 60.13


                                                                                