## Linear regression in Spark

In [1]:
from pyspark.sql import SparkSession

In [8]:
# Obtain the Spark session for this notebook (created if none exists yet).
spark = (
    SparkSession.builder
    .appName('linrex')
    .getOrCreate()
)

In [9]:
from pyspark.ml.regression import LinearRegression

In [12]:
# Read the sample regression data set using the 'libsvm' reader.
all_data = (
    spark.read
    .format('libsvm')
    .load('../datasets/sample_lin_reg_data.txt')
)

In [18]:
# Train/test split (weights 0.7 / 0.3).
# A fixed seed makes the partition reproducible: re-running this cell (or the
# whole notebook) puts the same rows in each set, so a model fitted before a
# re-run is not silently evaluated against a reshuffled test set.
train_data, test_data = all_data.randomSplit([0.7, 0.3], seed=42)

In [19]:
# Print summary statistics for the training partition.
(
    train_data
    .describe()
    .show()
)

+-------+-------------------+
|summary|              label|
+-------+-------------------+
|  count|                363|
|   mean|0.44476775326060647|
| stddev| 10.407621414347464|
|    min|-28.571478869743427|
|    max| 27.111027963108548|
+-------+-------------------+



In [20]:
# Print summary statistics for the test partition.
(
    test_data
    .describe()
    .show()
)

+-------+--------------------+
|summary|               label|
+-------+--------------------+
|  count|                 138|
|   mean|-0.23731445299934523|
| stddev|  10.098608765703869|
|    min| -28.046018037776633|
|    max|   27.78383192005107|
+-------+--------------------+



In [16]:
# Linear regression estimator, with the input and output column names spelled out.
lr = LinearRegression(
    featuresCol='features',
    labelCol='label',
    predictionCol='prediction',
)

In [17]:
# Fit the estimator on the training partition to obtain the fitted model.
model = lr.fit(dataset=train_data)

In [21]:
# Evaluate the fitted model on the held-out test partition.
model_result = model.evaluate(dataset=test_data)

In [24]:
# Root-mean-squared error from evaluating the model on test_data.
# Compare it with the label stddev printed by test_data.describe(): an RMSE
# close to that value means the fit is barely better than predicting the mean.
model_result.rootMeanSquaredError

10.203530174925838

In [25]:
# Keep only the 'features' column of the test partition to stand in for
# unlabeled input.
unlabeled_data = (
    test_data
    .select('features')
)

In [26]:
# Preview the unlabeled rows (20 is the default row count for show()).
unlabeled_data.show(n=20)

+--------------------+
|            features|
+--------------------+
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
|(10,[0,1,2,3,4,5,...|
+--------------------+
only showing top 20 rows



In [27]:
# Run the fitted model over the unlabeled rows; the result gains a
# 'prediction' column alongside 'features'.
predictions = model.transform(dataset=unlabeled_data)

In [28]:
# Preview the predictions (20 is the default row count for show()).
predictions.show(n=20)

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...| -1.6151302330745434|
|(10,[0,1,2,3,4,5,...|-0.39734697369053706|
|(10,[0,1,2,3,4,5,...|  -3.957656161786463|
|(10,[0,1,2,3,4,5,...| 0.22098137317181113|
|(10,[0,1,2,3,4,5,...|-0.24402430806440376|
|(10,[0,1,2,3,4,5,...|  0.6086103277991831|
|(10,[0,1,2,3,4,5,...|  3.2708780391351873|
|(10,[0,1,2,3,4,5,...| -1.4249197217511604|
|(10,[0,1,2,3,4,5,...| -2.0688109541739923|
|(10,[0,1,2,3,4,5,...|   3.134633174865831|
|(10,[0,1,2,3,4,5,...| -0.9267013637097196|
|(10,[0,1,2,3,4,5,...|  0.3812103849125552|
|(10,[0,1,2,3,4,5,...| -1.8501975223359817|
|(10,[0,1,2,3,4,5,...|   2.301631854303685|
|(10,[0,1,2,3,4,5,...|   3.220564518143553|
|(10,[0,1,2,3,4,5,...|   -1.66805026817473|
|(10,[0,1,2,3,4,5,...| -3.4834820224630176|
|(10,[0,1,2,3,4,5,...| -1.1576180561131735|
|(10,[0,1,2,3,4,5,...|   2.827183851888351|
|(10,[0,1,2,3,4,5,...|   2.04042