## Linear Regression with PySpark

* Based on the official Spark documentation for PySpark

In [1]:
!wget https://raw.githubusercontent.com/apache/spark/master/data/mllib/sample_linear_regression_data.txt >> sample_linear_regression_data.txt

--2025-07-14 08:57:56--  https://raw.githubusercontent.com/apache/spark/master/data/mllib/sample_linear_regression_data.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
connected. to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... 
HTTP request sent, awaiting response... 200 OK
Length: 119069 (116K) [text/plain]
Saving to: ‘sample_linear_regression_data.txt.1’


2025-07-14 08:57:57 (559 KB/s) - ‘sample_linear_regression_data.txt.1’ saved [119069/119069]



In [2]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName("lr_example").getOrCreate()

25/07/14 08:58:00 WARN Utils: Your hostname, aditya-HP-Laptop-15s-eq1xxx resolves to a loopback address: 127.0.1.1; using 10.200.82.42 instead (on interface wlo1)
25/07/14 08:58:00 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/14 08:58:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
from pyspark.ml.regression import LinearRegression

In [5]:
training = spark.read.format("libsvm").load("sample_linear_regression_data.txt.2")

25/07/14 08:58:05 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
                                                                                

In [6]:
training.show()

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
| -9.490009878824548|(10,[0,1,2,3,4,5,...|
| 0.2577820163584905|(10,[0,1,2,3,4,5,...|
| -4.438869807456516|(10,[0,1,2,3,4,5,...|
|-19.782762789614537|(10,[0,1,2,3,4,5,...|
| -7.966593841555266|(10,[0,1,2,3,4,5,...|
| -7.896274316726144|(10,[0,1,2,3,4,5,...|
| -8.464803554195287|(10,[0,1,2,3,4,5,...|
| 2.1214592666251364|(10,[0,1,2,3,4,5,...|
| 1.0720117616524107|(10,[0,1,2,3,4,5,...|
|-13.772441561702871|(10,[0,1,2,3,4,5,...|
| -5.082010756207233|(10,[0,1,2,3,4,5,...|
|  7.887786536531237|(10,[0,1,2,3,4,5,...|
| 14.323146365332388|(10,[0,1,2,3,4,5,...|
|-20.057482615789212|(10,[0,1,2,3,4,5,...|
|-0.8995693247765151|(10,[0,1,2,3,4,5,...|
| -19.16829262296376|(10,[0,1,2,3,4,5,...|
|  5.601801561245534|(10,[0,1,2,3,4,5,...|
|-3.2256352187273354|(10,[0,1,2,3,4,5,...|
| 1.5299675726687754|(10,[0,1,2,3,4,5,...|
| -0.250102447941961|(10,[0,1,2,3,4,5,...|
+----------

In [7]:
lr = LinearRegression(featuresCol = "features", labelCol = "label", predictionCol="prediction")

In [8]:
lrModel = lr.fit(training)

25/07/14 08:58:12 WARN Instrumentation: [2dc7abc1] regParam is zero, which might cause numerical instability and overfitting.
25/07/14 08:58:12 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS


In [9]:
print("Coefficients:", str(lrModel.coefficients))
print("Intercept: ", str(lrModel.intercept))

Coefficients: [0.0073350710225801715,0.8313757584337543,-0.8095307954684084,2.441191686884721,0.5191713795290003,1.1534591903547016,-0.2989124112808717,-0.5128514186201779,-0.619712827067017,0.6956151804322931]
Intercept:  0.14228558260358093


In [10]:
trainSummary = lrModel.summary

In [11]:
trainSummary.meanSquaredError

103.28843028724194

In [12]:
trainSummary.rootMeanSquaredError

10.16309157133015

In [13]:
trainSummary.r2

0.027839179518600154

In [14]:
print("MAE: ", trainSummary.meanAbsoluteError)
print("MSE: ", trainSummary.meanSquaredError)

MAE:  8.145215527783876
MSE:  103.28843028724194


In [15]:
print("RMSE: ", trainSummary.rootMeanSquaredError)

RMSE:  10.16309157133015


In [16]:
print("R2: ", trainSummary.r2)

R2:  0.027839179518600154


In [17]:
print("Adj R2: ", trainSummary.r2adj)

Adj R2:  0.007999162774081858


### Train Test Split with PySpark

* Pass in the split between training/test as a list.
* No correct, but generally 70/30 or 60/40 splits are used.
* Depending on how much data you have and how unbalanced it is.

In [18]:
df = spark.read.format("libsvm").load("sample_linear_regression_data.txt.2") # Full Dataset

25/07/14 08:58:13 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.


In [19]:
train_data, test_data = df.randomSplit([0.7, 0.3], seed = 42)

In [20]:
train_data.show()

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
|-28.571478869743427|(10,[0,1,2,3,4,5,...|
|-28.046018037776633|(10,[0,1,2,3,4,5,...|
|-26.736207182601724|(10,[0,1,2,3,4,5,...|
| -23.51088409032297|(10,[0,1,2,3,4,5,...|
|-23.487440120936512|(10,[0,1,2,3,4,5,...|
|-22.837460416919342|(10,[0,1,2,3,4,5,...|
|-20.057482615789212|(10,[0,1,2,3,4,5,...|
|-19.884560774273424|(10,[0,1,2,3,4,5,...|
|-19.872991038068406|(10,[0,1,2,3,4,5,...|
| -19.16829262296376|(10,[0,1,2,3,4,5,...|
|-18.845922472898582|(10,[0,1,2,3,4,5,...|
| -18.27521356600463|(10,[0,1,2,3,4,5,...|
|-17.494200356883344|(10,[0,1,2,3,4,5,...|
| -17.32672073267595|(10,[0,1,2,3,4,5,...|
| -16.71909683360509|(10,[0,1,2,3,4,5,...|
|-16.692207021311106|(10,[0,1,2,3,4,5,...|
| -16.26143027545273|(10,[0,1,2,3,4,5,...|
| -15.86200932757056|(10,[0,1,2,3,4,5,...|
|-15.732088272239245|(10,[0,1,2,3,4,5,...|
|-15.375857723312297|(10,[0,1,2,3,4,5,...|
+----------

In [21]:
test_data.show()

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
|-26.805483428483072|(10,[0,1,2,3,4,5,...|
|-22.949825936196074|(10,[0,1,2,3,4,5,...|
|-21.432387764165806|(10,[0,1,2,3,4,5,...|
|-20.212077258958672|(10,[0,1,2,3,4,5,...|
|-19.782762789614537|(10,[0,1,2,3,4,5,...|
| -19.66731861537172|(10,[0,1,2,3,4,5,...|
|-19.402336030214553|(10,[0,1,2,3,4,5,...|
|-17.803626188664516|(10,[0,1,2,3,4,5,...|
|-17.428674570939506|(10,[0,1,2,3,4,5,...|
|-17.065399625876015|(10,[0,1,2,3,4,5,...|
|-17.026492264209548|(10,[0,1,2,3,4,5,...|
|-16.151349351277112|(10,[0,1,2,3,4,5,...|
| -16.08565904102149|(10,[0,1,2,3,4,5,...|
|-15.951512565794573|(10,[0,1,2,3,4,5,...|
|-15.780685032623301|(10,[0,1,2,3,4,5,...|
| -15.72351561304857|(10,[0,1,2,3,4,5,...|
|-15.437384793431217|(10,[0,1,2,3,4,5,...|
|-15.334767479922341|(10,[0,1,2,3,4,5,...|
|-14.822152909751189|(10,[0,1,2,3,4,5,...|
|-14.762758252931127|(10,[0,1,2,3,4,5,...|
+----------

In [22]:
unlabeled_data = test_data.select('features')

In [23]:
corrected_model = lr.fit(train_data)

25/07/14 08:58:15 WARN Instrumentation: [400722c0] regParam is zero, which might cause numerical instability and overfitting.
25/07/14 08:58:15 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors


In [24]:
res = corrected_model.evaluate(test_data)

In [26]:
print("MAE: ", res.meanAbsoluteError)
print("MSE: ", res.meanSquaredError)
print("RMSE: ", res.rootMeanSquaredError)
print("R2: ", res.r2)
print("Adj R2: ", res.r2adj)

MAE:  9.855750048378727
MSE:  142.31866794563598
RMSE:  11.929738804585622
R2:  -0.14679155085585793
Adj R2:  -0.24651255527810645


In [27]:
predictions = corrected_model.transform(unlabeled_data)

In [28]:
predictions.show()

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...|   1.500419302439231|
|(10,[0,1,2,3,4,5,...|   6.540721556576252|
|(10,[0,1,2,3,4,5,...|  1.4369775273526635|
|(10,[0,1,2,3,4,5,...|  1.3156052948594428|
|(10,[0,1,2,3,4,5,...|-0.09510236182489817|
|(10,[0,1,2,3,4,5,...| 0.12648407749270263|
|(10,[0,1,2,3,4,5,...|-0.40745999229762575|
|(10,[0,1,2,3,4,5,...| -1.3827504557268635|
|(10,[0,1,2,3,4,5,...|  2.6965070486236957|
|(10,[0,1,2,3,4,5,...|    2.42284270742401|
|(10,[0,1,2,3,4,5,...|-0.33620505674116263|
|(10,[0,1,2,3,4,5,...|  1.5811910073932323|
|(10,[0,1,2,3,4,5,...| -0.9126865153126812|
|(10,[0,1,2,3,4,5,...| -2.4337353560269603|
|(10,[0,1,2,3,4,5,...|  4.7238640017384945|
|(10,[0,1,2,3,4,5,...|  1.7972086764514907|
|(10,[0,1,2,3,4,5,...| -0.3727532193177282|
|(10,[0,1,2,3,4,5,...|   3.393593882956883|
|(10,[0,1,2,3,4,5,...|   1.173823533651508|
|(10,[0,1,2,3,4,5,...| 0.4009232