In [23]:
from pyspark.sql import SparkSession
# Build the notebook's Spark session, or reuse one that is already running
# in this kernel (getOrCreate makes re-running this cell safe).
spark = (
    SparkSession.builder
    .appName("cs131_a4")
    .getOrCreate()
)

In [25]:
# Load the trip CSV: the first row supplies the column names and Spark
# infers each column's type from the data.
df_raw = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("2019-01-h1.csv")
)

# Narrow the frame to the four columns used by the rest of the notebook.
# NOTE(review): the full-frame preview in this notebook lists these columns
# in lowercase (e.g. `pulocationid`), so this select presumably relies on
# Spark's case-insensitive column resolution -- confirm that setting is on.
cols = ["passenger_count", "PULocationID", "DOLocationID", "total_amount"]
taxiDF = df_raw.select(cols)

# Preview the first ten rows of the narrowed frame.
taxiDF.show(n=10)

+---------------+------------+------------+------------+
|passenger_count|PULocationID|DOLocationID|total_amount|
+---------------+------------+------------+------------+
|            1.0|       151.0|       239.0|        9.95|
|            1.0|       239.0|       246.0|        16.3|
|            3.0|       236.0|       236.0|         5.8|
|            5.0|       193.0|       193.0|        7.55|
|            5.0|       193.0|       193.0|       55.55|
|            5.0|       193.0|       193.0|       13.31|
|            5.0|       193.0|       193.0|       55.55|
|            1.0|       163.0|       229.0|        9.05|
|            1.0|       229.0|         7.0|        18.5|
|            2.0|       141.0|       234.0|        13.0|
+---------------+------------+------------+------------+
only showing top 10 rows



In [26]:
# Preview the unfiltered frame with every column (show's defaults, spelled
# out: at most 20 rows, long cell values truncated).
df_raw.show(n=20, truncate=True)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|vendorid|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|ratecodeid|store_and_fwd_flag|pulocationid|dolocationid|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|     1.0| 2019-01-01 00:46:40|  2019-01-01 00:53:20|            1.0|          1.5|       1.0|                 N|       151.0|       239.0|         1.0|        7.0|  0.5|    0.5|      1.65|         0.0|                  0.3

In [27]:
# Split the rows 80/20 into train and test sets; the fixed seed keeps the
# split repeatable between runs.
trainDF, testDF = taxiDF.randomSplit(weights=[0.8, 0.2], seed=42)

# Each count() triggers a Spark job, so compute them once and report both.
train_rows = trainDF.count()
test_rows = testDF.count()
print("Rows for train:", train_rows, " Rows for test:", test_rows)

Rows for train: 2921462  Rows for test: 729537


In [28]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator

# Stage 1: pack the three input columns into a single "features" vector,
# the column the regressor below reads.
assembler = (
    VectorAssembler()
    .setInputCols(["passenger_count", "PULocationID", "DOLocationID"])
    .setOutputCol("features")
)

# Stage 2: decision-tree regressor predicting total_amount from "features",
# with the tree's bin limit raised to 512.
# NOTE(review): the previous comment said the 512 "avoids category overflow",
# but no feature is marked categorical in this cell, so the location IDs are
# presumably split as continuous values -- confirm that is intended.
dtr = DecisionTreeRegressor(
    labelCol="total_amount",
    featuresCol="features",
    maxBins=512,
)

# Chain both stages so one fit()/transform() call runs them in order.
pipeline = Pipeline(stages=[assembler, dtr])

In [29]:
# Fit the assembler + tree pipeline on the training split, then score the
# held-out test split; transform() adds a "prediction" column.
model = pipeline.fit(trainDF)
predDF = model.transform(testDF)

# Show the inputs next to the predicted amount for the first ten test rows.
preview_cols = ["passenger_count", "PULocationID", "DOLocationID", "prediction"]
predDF.select(preview_cols).show(n=10)

# Root-mean-squared error between the predictions and the true total_amount.
evaluator = RegressionEvaluator(
    labelCol="total_amount",
    predictionCol="prediction",
    metricName="rmse",
)
rmse = evaluator.evaluate(predDF)
print("RMSE =", rmse)

# Release the Spark session. NOTE(review): after this runs, this cell (and
# any other cell that uses `spark` or its DataFrames) cannot be re-run
# without first re-running the session-creation cell.
spark.stop()

+---------------+------------+------------+------------------+
|passenger_count|PULocationID|DOLocationID|        prediction|
+---------------+------------+------------+------------------+
|            0.0|         4.0|         4.0|17.922369380316045|
|            0.0|         4.0|        79.0|17.922369380316045|
|            0.0|         4.0|        90.0|17.922369380316045|
|            0.0|         4.0|       170.0|17.922369380316045|
|            0.0|         7.0|         7.0|17.922369380316045|
|            0.0|         7.0|        48.0|17.922369380316045|
|            0.0|         7.0|       164.0|17.922369380316045|
|            0.0|         9.0|        73.0|17.922369380316045|
|            0.0|        13.0|        13.0|17.922369380316045|
|            0.0|        13.0|        13.0|17.922369380316045|
+---------------+------------+------------+------------------+
only showing top 10 rows

RMSE = 24.01571105495559
