# Linear Regression: Predicting total amount

In [1]:
# Import Libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression

In [2]:
# Obtain the Spark entry point for this notebook. getOrCreate() reuses a
# session that is already running instead of building a second one.
spark = (
    SparkSession.builder
    .appName("TLC Linear Regression")
    .getOrCreate()
)

24/12/30 03:09:41 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [3]:
# Input locations: Parquet datasets under the data-warehouse root on HDFS.
_WAREHOUSE_ROOT = "hdfs://10.128.0.59:8020/data_warehouse"
fact_trip = f"{_WAREHOUSE_ROOT}/fact_trip"
dim_datetime = f"{_WAREHOUSE_ROOT}/dim_datetime"

# Output location: BigQuery table that receives the evaluation metrics.
output = "uber-analysis-439804.query_result.model_evaluation"

In [4]:
# Load the trip fact table as-is.
df_fact = spark.read.parquet(fact_trip)

# Load the datetime dimension, restricted to rows whose pick_year is 2024,
# keeping only the join key and the hour / weekday attributes.
df_datetime = (
    spark.read.parquet(dim_datetime)
    .filter(F.col("pick_year") == 2024)
    .select(
        "datetime_id",
        "pick_hour",
        "pick_weekday_id",
        "drop_hour",
        "drop_weekday_id",
    )
)

# Attach the time attributes to every trip. The inner join also discards
# trips that have no matching (2024) row in the datetime dimension.
df_joined = (
    df_fact
    .join(
        df_datetime,
        df_fact.datetimestamp_id == df_datetime.datetime_id,
        "inner",
    )
    # The join keys are dropped by name: older PySpark releases raise
    # TypeError when several Column objects are passed to drop(), whereas
    # string names are accepted by every release. Each key name exists on
    # only one side of the join, so dropping by name is unambiguous.
    .drop("datetimestamp_id", "datetime_id")
)

df_joined.printSchema()

                                                                                

root
 |-- trip_id: long (nullable = true)
 |-- vendor_id: long (nullable = true)
 |-- pu_location_id: long (nullable = true)
 |-- do_location_id: long (nullable = true)
 |-- ratecode_id: long (nullable = true)
 |-- payment_id: long (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- pick_hour: double (nullable = true)
 |-- pick_weekday_id: integer (nullable = true)
 |-- drop_hour: double (nullable = true)
 |-- drop_weekday_id: integer (nullable = true)



In [5]:
# Feature columns, grouped by kind. The concatenation order below defines the
# position of each value inside the assembled "features" vector.
_id_features = [
    "vendor_id",
    "pu_location_id",
    "do_location_id",
    "ratecode_id",
    "payment_id",
]
_trip_features = ["passenger_count", "trip_distance"]
# NOTE(review): these look like additive components of the label
# (total_amount); if so they leak the target and would explain a
# near-perfect R2 — confirm this is intended.
_charge_features = [
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
]
_time_features = [
    "pick_hour",
    "pick_weekday_id",
    "drop_hour",
    "drop_weekday_id",
]

selected_columns = (
    _id_features + _trip_features + _charge_features + _time_features
)

# Pack the selected columns into a single vector column named "features",
# the input format expected by the regression estimator.
assembler = VectorAssembler(inputCols=selected_columns, outputCol="features")

data_transformed = assembler.transform(df_joined)

In [6]:
# Split dataset: weights 0.8 / 0.2 for train / test.
# A fixed seed pins the random sampler so that a re-run over the same input
# (same data and partitioning) yields the same split, which keeps the metrics
# written downstream comparable between runs.
train_data, test_data = data_transformed.randomSplit([0.8, 0.2], seed=42)

In [7]:
# Estimator configuration: regress total_amount on the assembled feature
# vector, with regularisation parameter 0.01.
lr_params = {
    "featuresCol": "features",
    "labelCol": "total_amount",
    "regParam": 0.01,
}
lr_model = LinearRegression(**lr_params)

# Fit on the training split only.
trained_model = lr_model.fit(train_data)

# Score the held-out split; transform() adds the model output column that
# the evaluators read ("prediction").
predictions = trained_model.transform(test_data)

                                                                                

In [8]:
# Evaluation: one evaluator per metric, all comparing the "prediction" column
# against the "total_amount" label.
rmse_evaluator, mae_evaluator, r2_evaluator = (
    RegressionEvaluator(
        labelCol="total_amount",
        predictionCol="prediction",
        metricName=metric,
    )
    for metric in ("rmse", "mae", "r2")
)

# Each evaluate() call is a separate Spark action over the test predictions.
rmse, mae, r2 = (
    evaluator.evaluate(predictions)
    for evaluator in (rmse_evaluator, mae_evaluator, r2_evaluator)
)

                                                                                

In [10]:
# Store in BigQuery
# Build a one-row frame holding the metrics of this model. Plain tuples plus
# an explicit column-name list are used instead of Row, which is not
# explicitly imported in this notebook; this way the cell also works on a
# fresh kernel. Column order and inferred types (string, double, double,
# double) are the same as before.
evaluation_data = spark.createDataFrame(
    [("Linear Regression", rmse, mae, r2)],
    ["name", "rmse", "mae", "r2"],
)

evaluation_data.show()

# NOTE(review): mode("overwrite") replaces the whole target table; if other
# models report into the same table their rows would be lost — confirm.
(
    evaluation_data.write
    .format("bigquery")
    .option("table", output)
    .option("temporaryGcsBucket", "uber-pyspark-jobs/temp")
    .mode("overwrite")
    .save()
)

                                                                                

+-----------------+-----------------+------------------+------------------+
|             name|             rmse|               mae|                r2|
+-----------------+-----------------+------------------+------------------+
|Linear Regression|0.502891116917058|0.2197194188491359|0.9995676032268148|
+-----------------+-----------------+------------------+------------------+



                                                                                

In [11]:
# Stop the Spark session now that the metrics have been written.
spark.stop()