In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('simpleModel').getOrCreate()

In [3]:
spark

In [4]:
df_spark = spark.read.csv('test2.csv',header=True,inferSchema=True)
df_spark.show()

+---------+----+----------+------+
|     Name| age|Experience|Salary|
+---------+----+----------+------+
|    Krish|  31|        10| 30000|
|Sudhanshu|  30|         8| 25000|
|    Sunny|  29|         4| 20000|
|     Paul|  24|         3| 20000|
|   Harsha|  21|         1| 15000|
|  Shubham|  23|         2| 18000|
|   Mahesh|null|      null| 40000|
|     null|  34|        10| 38000|
|     null|  36|      null|  null|
+---------+----+----------+------+



In [5]:
df_spark = df_spark.na.drop()

In [7]:
df_spark.printSchema()

root
 |-- Name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- Experience: integer (nullable = true)
 |-- Salary: integer (nullable = true)



In [10]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols=['age','Experience'],outputCol='Included Features')

features = assembler.transform(df_spark)

In [12]:
output = features.select(['Included Features','Salary'])

In [13]:
output.show()

+-----------------+------+
|Included Features|Salary|
+-----------------+------+
|      [31.0,10.0]| 30000|
|       [30.0,8.0]| 25000|
|       [29.0,4.0]| 20000|
|       [24.0,3.0]| 20000|
|       [21.0,1.0]| 15000|
|       [23.0,2.0]| 18000|
+-----------------+------+



In [29]:
from pyspark.ml.regression import LinearRegression

train_data,test_data = output.randomSplit([0.6,0.4])

In [30]:
regr = LinearRegression(featuresCol='Included Features',labelCol='Salary')

In [31]:
regr=regr.fit(train_data)

In [32]:
regr.coefficients

DenseVector([1771.9298, -298.2456])

In [33]:
regr.intercept

-21912.28070175451

In [34]:
pred_output = regr.evaluate(test_data)

In [35]:
pred_output.predictions.show()

+-----------------+------+------------------+
|Included Features|Salary|        prediction|
+-----------------+------+------------------+
|       [29.0,4.0]| 20000|28280.701754385987|
|       [30.0,8.0]| 25000|28859.649122807026|
+-----------------+------+------------------+



In [36]:
pred_output.meanAbsoluteError,pred_output.meanSquaredError

(6070.175438596507, 41733456.44813812)