In [7]:
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from statsmodels.formula.api import ols, glm
import statsmodels.api as sm
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

In [8]:
spark = (
 SparkSession.builder.appName("project 1 LR")
 .config("spark.sql.repl.eagerEval.enabled", True)
 .config("spark.executor.memory","4G")
 .config("spark.driver.memory","2G")
 .config("spark.sql.parquet.cacheMetadata", "true")
 .config("spark.sql.session.timeZone", "Etc/UTC")
 .config('spark.driver.maxResultSize', '2048m')
 .getOrCreate()
)

In [9]:
data = spark.read.parquet('data/merged_sdf.parquet')

## Adjusting the columns to train the model ##

In [10]:
model_feature = ['duration', 'extra', 'Weekend', 'Airport', 'Congestion','tip_amount']
selected_data = data.select(*[col(col_name) for col_name in model_feature])

In [11]:
input_column = ['duration', 'extra', 'Weekend', 'Airport', 'Congestion']
assembler = VectorAssembler(inputCols=input_column, outputCol='features')
assembled_data = assembler.transform(selected_data)

# Split the data into training and testing sets
train_data, test_data = assembled_data.randomSplit([0.8, 0.2], seed=0)

## Model training ##

In [12]:
# Take 3 mins to run
for i in [0.0001, 0.001, 0.01]:
    lm = LinearRegression(
        featuresCol='features', 
        labelCol='tip_amount',
        regParam=i, 
        elasticNetParam=0.5
    ).fit(train_data)
    predictions = lm.transform(test_data)
    evaluator = RegressionEvaluator(labelCol='tip_amount', predictionCol='prediction', metricName='rmse')
    rmse = evaluator.evaluate(predictions)
    print('R_sqr:', lm.summary.r2, ", Root Mean Squared Error:", rmse, ", regParam:", i)



                                                                                

R_sqr: 0.47586036323747827 , Root Mean Squared Error: 2.8441040944187965 , regParam: 0.0001


                                                                                

R_sqr: 0.4758602439651509 , Root Mean Squared Error: 2.844103178086098 , regParam: 0.001




R_sqr: 0.475848348201902 , Root Mean Squared Error: 2.8441223520029384 , regParam: 0.01


                                                                                