# PySpark MLlib: Linear Regression

## Basic

In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName("linreg").getOrCreate()

24/06/09 20:39:13 WARN Utils: Your hostname, agusrichard.local resolves to a loopback address: 127.0.0.1; using 192.168.0.104 instead (on interface en0)
24/06/09 20:39:13 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/09 20:39:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
df = spark.read.format("libsvm").load("./files/sample_linear_regression_data.txt")
df.show()

24/06/09 20:40:29 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
                                                                                

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
| -9.490009878824548|(10,[0,1,2,3,4,5,...|
| 0.2577820163584905|(10,[0,1,2,3,4,5,...|
| -4.438869807456516|(10,[0,1,2,3,4,5,...|
|-19.782762789614537|(10,[0,1,2,3,4,5,...|
| -7.966593841555266|(10,[0,1,2,3,4,5,...|
| -7.896274316726144|(10,[0,1,2,3,4,5,...|
| -8.464803554195287|(10,[0,1,2,3,4,5,...|
| 2.1214592666251364|(10,[0,1,2,3,4,5,...|
| 1.0720117616524107|(10,[0,1,2,3,4,5,...|
|-13.772441561702871|(10,[0,1,2,3,4,5,...|
| -5.082010756207233|(10,[0,1,2,3,4,5,...|
|  7.887786536531237|(10,[0,1,2,3,4,5,...|
| 14.323146365332388|(10,[0,1,2,3,4,5,...|
|-20.057482615789212|(10,[0,1,2,3,4,5,...|
|-0.8995693247765151|(10,[0,1,2,3,4,5,...|
| -19.16829262296376|(10,[0,1,2,3,4,5,...|
|  5.601801561245534|(10,[0,1,2,3,4,5,...|
|-3.2256352187273354|(10,[0,1,2,3,4,5,...|
| 1.5299675726687754|(10,[0,1,2,3,4,5,...|
| -0.250102447941961|(10,[0,1,2,3,4,5,...|
+----------

In [4]:
train_df, test_df = df.randomSplit([0.7, 0.3])

In [5]:
train_df.describe().show()

+-------+--------------------+
|summary|               label|
+-------+--------------------+
|  count|                 357|
|   mean|-0.07763567996945892|
| stddev|   10.51016656614689|
|    min| -28.571478869743427|
|    max|   27.78383192005107|
+-------+--------------------+


In [6]:
test_df.describe().show()

+-------+-------------------+
|summary|              label|
+-------+-------------------+
|  count|                144|
|   mean|  1.086230817144352|
| stddev|  9.811270936017236|
|    min|-28.046018037776633|
|    max| 27.111027963108548|
+-------+-------------------+


In [8]:
!pip install numpy

Collecting numpy
  Using cached numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl (14.0 MB)
Installing collected packages: numpy
Successfully installed numpy-1.26.4


In [9]:
from pyspark.ml.regression import LinearRegression

In [12]:
model = LinearRegression(featuresCol="features", labelCol="label", predictionCol="prediction")

In [13]:
fitted_model = model.fit(train_df)

24/06/09 20:45:47 WARN Instrumentation: [12678e15] regParam is zero, which might cause numerical instability and overfitting.


In [15]:
fitted_model.coefficients

DenseVector([-1.1687, 1.2786, -0.8491, 0.9672, 0.9713, 1.0932, -0.3845, -0.404, -1.4362, 0.2242])

In [18]:
fitted_model.intercept

-0.08413902846189886

In [20]:
fitted_model.summary.residuals.show()

+-------------------+
|          residuals|
+-------------------+
|-28.534016675785942|
|-24.792409883220454|
| -25.07826198918433|
|-22.644720400640814|
| -24.32624950301568|
| -20.23765290671922|
|-18.498223070427823|
|-21.576982585490068|
| -18.72155480384377|
|-19.645215664475135|
|-19.317118778492365|
|-17.926696373465088|
|-15.520515877417797|
|-20.366784096630415|
|-15.521931849388428|
| -16.21985395517928|
|  -18.0037412583512|
| -16.83734752280589|
| -16.97352385284288|
|-14.434464598184345|
+-------------------+


In [21]:
fitted_model.summary.rootMeanSquaredError

10.356470291946756

In [22]:
unlabeled_data = test_df.select(["features"])
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
+--------------------+


In [27]:
predictions = fitted_model.transform(unlabeled_data)

In [28]:
predictions.show()

+--------------------+-------------------+
|            features|         prediction|
+--------------------+-------------------+
|(10,[0,1,2,3,4,5,...|-2.5251685010985123|
|(10,[0,1,2,3,4,5,...| 0.8072814362914444|
|(10,[0,1,2,3,4,5,...| 1.2370905300996191|
|(10,[0,1,2,3,4,5,...|-1.2939283160735788|
|(10,[0,1,2,3,4,5,...| 1.1437080247632876|
|(10,[0,1,2,3,4,5,...|  1.035863436625075|
|(10,[0,1,2,3,4,5,...| 0.5900970858697474|
|(10,[0,1,2,3,4,5,...|-1.2894366384228835|
|(10,[0,1,2,3,4,5,...| 0.6564263412122748|
|(10,[0,1,2,3,4,5,...|-1.3995039369083064|
|(10,[0,1,2,3,4,5,...|  2.496622461727185|
|(10,[0,1,2,3,4,5,...|-2.1603481509477374|
|(10,[0,1,2,3,4,5,...|-2.4369810169432844|
|(10,[0,1,2,3,4,5,...| 2.5312648055984885|
|(10,[0,1,2,3,4,5,...| 0.6106534513946493|
|(10,[0,1,2,3,4,5,...|  0.999008144606899|
|(10,[0,1,2,3,4,5,...| -2.408961561556007|
|(10,[0,1,2,3,4,5,...| 1.4629839841862406|
|(10,[0,1,2,3,4,5,...|  0.529165852499554|
|(10,[0,1,2,3,4,5,...|-0.7104832125918854|
+----------

## Real Data

In [29]:
spark = SparkSession.builder.appName("lr_real").getOrCreate()

24/06/09 20:58:53 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [30]:
df = spark.read.csv("./files/Ecommerce_Customers.csv", header=True, inferSchema=True)
df.show()

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|
|   hduke@hotmail.com|4547 Archer Commo...|       DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262|  392.2049334443264|
|    pallen@yahoo.com|24645 Valerie Uni...|          Bisque|33.000914755642675|11.330278057777512|37.110597442120856|   4.104543202376424| 487.54750486747207|
|riverarebecca@gma...|1414 David Throug...|   

                                                                                

In [31]:
df.columns

['Email',
 'Address',
 'Avatar',
 'Avg Session Length',
 'Time on App',
 'Time on Website',
 'Length of Membership',
 'Yearly Amount Spent']

In [32]:
from pyspark.ml.feature import VectorAssembler

In [33]:
assembler = VectorAssembler(
    inputCols=['Avg Session Length', 'Time on App', 'Time on Website', 'Length of Membership'],
    outputCol="features"
)
result = assembler.transform(df)

In [35]:
result = result.withColumnRenamed("Yearly Amount Spent", "label")

In [36]:
result.show()

+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+------------------+--------------------+
|               Email|             Address|          Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|             label|            features|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+------------------+--------------------+
|mstephenson@ferna...|835 Frank TunnelW...|          Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615| 587.9510539684005|[34.4972677251122...|
|   hduke@hotmail.com|4547 Archer Commo...|       DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262| 392.2049334443264|[31.9262720263601...|
|    pallen@yahoo.com|24645 Valerie Uni...|          Bisque|33.000914755642675|11.330278057777512|37.1105

In [37]:
selected_df = result.select(["features", "label"])
selected_df.show()

+--------------------+------------------+
|            features|             label|
+--------------------+------------------+
|[34.4972677251122...| 587.9510539684005|
|[31.9262720263601...| 392.2049334443264|
|[33.0009147556426...|487.54750486747207|
|[34.3055566297555...| 581.8523440352177|
|[33.3306725236463...| 599.4060920457634|
|[33.8710378793419...|  637.102447915074|
|[32.0215955013870...| 521.5721747578274|
|[32.7391429383803...| 549.9041461052942|
|[33.9877728956856...| 570.2004089636196|
|[31.9365486184489...| 427.1993848953282|
|[33.9925727749537...| 492.6060127179966|
|[33.8793608248049...| 522.3374046069357|
|[29.5324289670579...| 408.6403510726275|
|[33.1903340437226...| 573.4158673313865|
|[32.3879758531538...| 470.4527333009554|
|[30.7377203726281...| 461.7807421962299|
|[32.1253868972878...|457.84769594494855|
|[32.3388993230671...|407.70454754954415|
|[32.1878120459321...| 452.3156754800354|
|[32.6178560628234...|  605.061038804892|
+--------------------+------------

In [38]:
train_df, test_df = selected_df.randomSplit([0.7, 0.3])

In [39]:
model = LinearRegression().fit(train_df)

24/06/09 21:06:52 WARN Instrumentation: [5a93fe34] regParam is zero, which might cause numerical instability and overfitting.
                                                                                

In [41]:
result = model.evaluate(test_df)

In [42]:
result.rootMeanSquaredError

10.800762892758213

In [44]:
result.predictions.show()

+--------------------+------------------+------------------+
|            features|             label|        prediction|
+--------------------+------------------+------------------+
|[30.5743636841713...|442.06441375806565|441.75214957364346|
|[30.7377203726281...| 461.7807421962299| 450.9857589392798|
|[30.8364326747734...| 467.5019004269896| 471.6062930861515|
|[31.0662181616375...|448.93329320767435| 461.7538546135622|
|[31.1280900496166...| 557.2526867470547| 564.8111086562972|
|[31.3662121671876...| 430.5888825564849|426.65769860212345|
|[31.4474464941278...|  418.602742095224| 426.1491563153488|
|[31.5702008293202...| 545.9454921414049|  563.727622392435|
|[31.6005122003032...|479.17285149109694| 461.0086267138247|
|[31.6098395733896...|444.54554965110816| 427.3126396729497|
|[31.6739155032749...| 475.7250679098812| 502.1879254190269|
|[31.7207699002873...|  538.774933478023| 545.9971015103786|
|[31.7216523605090...|347.77692663187264|349.26389718030237|
|[31.7366356860502...| 4

## Project

In [60]:
spark = SparkSession.builder.appName("lr_project").getOrCreate()

In [61]:
df = spark.read.csv("./files/cruise_ship_info.csv", header=True, inferSchema=True)
df.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [62]:
from pyspark.ml.feature import StringIndexer

In [63]:
indexer = StringIndexer(inputCol="Cruise_line", outputCol="Cruise_index")
indexer = indexer.fit(df)
df = indexer.transform(df)
df.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_index|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|        16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|        16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|         1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|         1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|         1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|         1.0|
|    Elati

In [64]:
selected_df = df.select(["Age", "Tonnage", "passengers", "length", "cabins", "passenger_density", "Cruise_index", "crew"])
selected_df.show()

+---+------------------+----------+------+------+-----------------+------------+----+
|Age|           Tonnage|passengers|length|cabins|passenger_density|Cruise_index|crew|
+---+------------------+----------+------+------+-----------------+------------+----+
|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|        16.0|3.55|
|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|        16.0|3.55|
| 26|            47.262|     14.86|  7.22|  7.43|             31.8|         1.0| 6.7|
| 11|             110.0|     29.74|  9.53| 14.88|            36.99|         1.0|19.1|
| 17|           101.353|     26.42|  8.92| 13.21|            38.36|         1.0|10.0|
| 22|            70.367|     20.52|  8.55|  10.2|            34.29|         1.0| 9.2|
| 15|            70.367|     20.52|  8.55|  10.2|            34.29|         1.0| 9.2|
| 23|            70.367|     20.56|  8.55| 10.22|            34.23|         1.0| 9.2|
| 19|            70.367|     20.52|  8.55|  10.2|     

In [65]:
assembler = VectorAssembler(
    inputCols=["Age", "Tonnage", "passengers", "length", "cabins", "passenger_density", "Cruise_index"],
    outputCol="features"
)

In [66]:
selected_df = assembler.transform(selected_df).select(["features", "crew"])
selected_df = selected_df.withColumnRenamed("crew", "label")

In [67]:
train_df, test_df = selected_df.randomSplit([0.7, 0.3])

In [68]:
model = LinearRegression().fit(train_df)

24/06/09 21:18:59 WARN Instrumentation: [64e2a98d] regParam is zero, which might cause numerical instability and overfitting.


In [69]:
summary = model.evaluate(test_df)

In [71]:
summary.predictions.show()

+--------------------+-----+------------------+
|            features|label|        prediction|
+--------------------+-----+------------------+
|[5.0,86.0,21.04,9...|  8.0|  9.26965985345183|
|[5.0,133.5,39.59,...|13.13|13.333218388854178|
|[5.0,160.0,36.34,...| 13.6|15.164422558125507|
|[6.0,30.276999999...| 3.55| 4.451915924259921|
|[6.0,90.0,20.0,9....|  9.0|10.192598689531298|
|[6.0,112.0,38.0,9...| 10.9|11.458547855232126|
|[6.0,113.0,37.82,...| 12.0|11.768251804923235|
|[6.0,158.0,43.7,1...| 13.6|14.063738143521487|
|[7.0,158.0,43.7,1...| 13.6|14.001826373573234|
|[9.0,85.0,19.68,9...| 8.69|    9.403067892806|
|[9.0,90.09,25.01,...| 8.69| 9.267664294579129|
|[9.0,113.0,26.74,...|12.38|11.378607447608072|
|[10.0,77.0,20.16,...|  9.0| 8.868529116548233|
|[10.0,86.0,21.14,...|  9.2| 9.720592655385367|
|[10.0,91.62700000...|  9.0| 9.190200967842154|
|[11.0,58.6,15.66,...|  7.6| 7.466388826043693|
|[11.0,90.0,22.4,9...| 11.0| 10.07484195301561|
|[11.0,90.09,25.01...| 8.48| 8.883386249