In [9]:
# Import
from pyspark.sql import SparkSession

# Obtain the Spark session for this notebook (getOrCreate reuses a running one)
spark = (
    SparkSession.builder
    .appName("A4Assignment")
    .getOrCreate()
)

# Read the trip CSV from the GCS bucket: the first line provides the column
# names, and inferSchema asks Spark to derive each column's type from the data.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv("gs://dataproc-staging-us-central1-459220959832-gvez90q4/2019-01-h1.csv")
)

# Print the first 5 rows as a quick sanity check of the load
df.show(n=5)


                                                                                

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|vendorid|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|ratecodeid|store_and_fwd_flag|pulocationid|dolocationid|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|     1.0| 2019-01-01 00:46:40|  2019-01-01 00:53:20|            1.0|          1.5|       1.0|                 N|       151.0|       239.0|         1.0|        7.0|  0.5|    0.5|      1.65|         0.0|                  0.3

In [10]:
# Keep the three model input columns plus total_amount (the regression label)
columns = ["passenger_count", "pulocationid", "dolocationid", "total_amount"]
df_selected = df.select(*columns)

# Random 80% / 20% partition into training and test sets; the fixed seed is
# supplied so that the split is requested identically on every run.
trainDF, testDF = df_selected.randomSplit(weights=[0.8, 0.2], seed=42)

# Print 10 training rows to confirm the projection and split worked
trainDF.show(n=10)


[Stage 6:>                                                          (0 + 1) / 1]

+---------------+------------+------------+------------+
|passenger_count|pulocationid|dolocationid|total_amount|
+---------------+------------+------------+------------+
|            0.0|         1.0|         1.0|        90.0|
|            0.0|         1.0|         1.0|      101.39|
|            0.0|         4.0|         4.0|         4.3|
|            0.0|         4.0|         4.0|         4.8|
|            0.0|         4.0|         4.0|        5.75|
|            0.0|         4.0|        33.0|       17.75|
|            0.0|         4.0|        68.0|        15.8|
|            0.0|         4.0|        68.0|       16.55|
|            0.0|         4.0|        79.0|         5.3|
|            0.0|         4.0|        79.0|         5.8|
+---------------+------------+------------+------------+
only showing top 10 rows



                                                                                

In [12]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml import Pipeline

# Stage 1: pack the three input columns into one vector column named "features".
# NOTE(review): pulocationid / dolocationid are zone identifiers but reach the
# tree as plain numeric values (no indexer stage marks them as categorical) --
# confirm that splitting on them as ordered numbers is intended.
assembler = VectorAssembler(
    outputCol="features",
    inputCols=["passenger_count", "pulocationid", "dolocationid"],
)

# Stage 2: decision-tree regressor that predicts total_amount from "features".
# maxBins=131 is kept exactly as originally configured.
dt = DecisionTreeRegressor(
    maxBins=131,
    labelCol="total_amount",
    featuresCol="features",
)

# Chain both stages so assembling and fitting happen in one fit() call
stages = [assembler, dt]
pipeline = Pipeline(stages=stages)

# Train on the training split; the result is the fitted pipeline model
model = pipeline.fit(trainDF)


25/04/27 05:54:57 ERROR TransportResponseHandler: Still have 1 requests outstanding when connection from /10.128.0.2:41408 is closed
25/04/27 05:54:57 WARN BlockManagerMasterEndpoint: Error trying to remove shuffle 11 from block manager BlockManagerId(11, cluster-3c7f-w-0.c.consummate-gift-449119-b0.internal, 45217, None)
java.io.IOException: Connection from /10.128.0.2:41408 closed
	at org.apache.spark.network.client.TransportResponseHandler.channelInactive(TransportResponseHandler.java:147) ~[spark-network-common_2.12-3.5.3.jar:3.5.3]
	at org.apache.spark.network.server.TransportChannelHandler.channelInactive(TransportChannelHandler.java:117) ~[spark-network-common_2.12-3.5.3.jar:3.5.3]
	at io.netty.channel.AbstractChannelHandlerContext.invokeChannelInactive(AbstractChannelHandlerContext.java:305) ~[netty-transport-4.1.100.Final.jar:4.1.100.Final]
	at io.netty.channel.AbstractChannelHandlerContext.invokeChannelInactive(AbstractChannelHandlerContext.java:281) ~[netty-transport-4.1.100

In [13]:
# Run the fitted pipeline over the held-out test split
predictions = model.transform(testDF)

# Print 10 rows showing the inputs, the actual total and the predicted total
(
    predictions
    .select(
        "passenger_count",
        "pulocationid",
        "dolocationid",
        "total_amount",
        "prediction",
    )
    .show(n=10)
)


[Stage 35:>                                                         (0 + 1) / 1]

+---------------+------------+------------+------------+-----------------+
|passenger_count|pulocationid|dolocationid|total_amount|       prediction|
+---------------+------------+------------+------------+-----------------+
|            0.0|         1.0|         1.0|      116.75|24.03192497663894|
|            0.0|         4.0|        17.0|        20.3|19.23259084947156|
|            0.0|         4.0|        68.0|        12.8|17.01111277291278|
|            0.0|         4.0|        79.0|        6.35|17.01111277291278|
|            0.0|         4.0|        90.0|        15.8|17.01111277291278|
|            0.0|         4.0|       125.0|       13.55|17.01111277291278|
|            0.0|         4.0|       170.0|        14.3|17.01111277291278|
|            0.0|         4.0|       264.0|         8.3|17.01111277291278|
|            0.0|         7.0|         7.0|         7.3|19.23259084947156|
|            0.0|         7.0|         7.0|         7.8|19.23259084947156|
+---------------+--------

                                                                                

In [14]:
from pyspark.ml.evaluation import RegressionEvaluator

# Evaluator that scores the "prediction" column against the "total_amount"
# label using root mean squared error
evaluator = RegressionEvaluator(
    metricName="rmse",
    predictionCol="prediction",
    labelCol="total_amount",
)

# Compute RMSE over the test-set predictions and report it
rmse = evaluator.evaluate(dataset=predictions)
print("Root Mean Squared Error (RMSE):", rmse)




Root Mean Squared Error (RMSE): 60.117668618791946


                                                                                