In [1]:
from pyspark.sql import SparkSession
# Start a Spark session (or reuse one that is already running) for this notebook.
spark = (
    SparkSession.builder
    .appName("SparkML")
    .getOrCreate()
)

In [2]:
# Read the CSV, using its first row as column names and letting Spark
# infer each column's type from the data.
df = (
    spark.read
    .option("header", "True")
    .option("inferSchema", "True")
    .csv('test4.csv')
)
df.show()

+---------+---+---+------+
|     name|age| xp|salary|
+---------+---+---+------+
|    Krish| 31| 10| 30000|
|Sudhanshu| 30|  8| 25000|
|    Sunny| 29|  4| 20000|
|     Paul| 24|  3| 20000|
|   Harsha| 21|  1| 15000|
|  Shubham| 23|  2| 18000|
+---------+---+---+------+



In [3]:
from pyspark.ml.feature import VectorAssembler
# Packs the two numeric predictor columns into one vector column, which is
# the feature format the Spark ML estimator below is pointed at.
featureAssembler = VectorAssembler(
    inputCols=["age", "xp"],
    outputCol="IndFeatures",
)

In [4]:
# Append the assembled "IndFeatures" vector column to the loaded rows and preview the result.
output = featureAssembler.transform(df)
output.show()

+---------+---+---+------+-----------+
|     name|age| xp|salary|IndFeatures|
+---------+---+---+------+-----------+
|    Krish| 31| 10| 30000|[31.0,10.0]|
|Sudhanshu| 30|  8| 25000| [30.0,8.0]|
|    Sunny| 29|  4| 20000| [29.0,4.0]|
|     Paul| 24|  3| 20000| [24.0,3.0]|
|   Harsha| 21|  1| 15000| [21.0,1.0]|
|  Shubham| 23|  2| 18000| [23.0,2.0]|
+---------+---+---+------+-----------+



In [5]:
# Keep only what the model needs: the feature vector and the label.
# The CSV column is named `salary`; "Salary" still resolves because Spark's
# column lookup is case-insensitive by default, and the selected column is
# then displayed under the requested spelling "Salary".
model_columns = ["IndFeatures", "Salary"]
final_data = output.select(model_columns)
final_data.show()

+-----------+------+
|IndFeatures|Salary|
+-----------+------+
|[31.0,10.0]| 30000|
| [30.0,8.0]| 25000|
| [29.0,4.0]| 20000|
| [24.0,3.0]| 20000|
| [21.0,1.0]| 15000|
| [23.0,2.0]| 18000|
+-----------+------+



In [6]:
from pyspark.ml.regression import LinearRegression

# Hold out roughly a quarter of the rows for evaluation. randomSplit samples
# row by row, so the proportions are approximate. The fixed seed pins the
# sampling so a fresh Restart & Run All starts from the same seeded split
# instead of a new random one on every run.
train_data, test_data = final_data.randomSplit([0.75, 0.25], seed=42)

# Keep the unfitted estimator under its own name so that `regressor` only ever
# refers to the fitted model that the following cells inspect and evaluate.
lr_estimator = LinearRegression(featuresCol="IndFeatures", labelCol="Salary")
regressor = lr_estimator.fit(train_data)

In [7]:
# Learned weights of the fitted model, one per entry of the "IndFeatures" vector (age, xp).
regressor.coefficients

DenseVector([-1363.6364, 3181.8182])

In [8]:
# Learned intercept (constant term) of the fitted linear model.
regressor.intercept

40454.545454536375

In [9]:
# Score the fitted model on the held-out rows; the returned summary exposes the
# per-row predictions and the aggregate error metric read in the next cells.
pred_results = regressor.evaluate(test_data)

In [10]:
# Show each evaluated row's features and actual Salary next to the model's prediction.
pred_results.predictions.show()

+-----------+------+------------------+
|IndFeatures|Salary|        prediction|
+-----------+------+------------------+
| [23.0,2.0]| 18000|15454.545454545765|
| [24.0,3.0]| 20000|17272.727272727512|
| [29.0,4.0]| 20000|13636.363636365619|
+-----------+------+------------------+



In [11]:
# Mean squared error of the predictions over the evaluated (held-out) rows.
pred_results.meanSquaredError

18137741.04682258