Multi-Linear Rwgression using pyspark

In [0]:
df = spark.read.format("csv").option("header", "true").load("dbfs:/FileStore/tables/50_Startups-9.csv")
df.show()

+---------+--------------+---------------+----------+---------+
|R&D Spend|Administration|Marketing Spend|     State|   Profit|
+---------+--------------+---------------+----------+---------+
| 165349.2|      136897.8|       471784.1|  New York|192261.83|
| 162597.7|     151377.59|      443898.53|California|191792.06|
|153441.51|     101145.55|      407934.54|   Florida|191050.39|
|144372.41|     118671.85|      383199.62|  New York|182901.99|
|142107.34|      91391.77|      366168.42|   Florida|166187.94|
| 131876.9|      99814.71|      362861.36|  New York|156991.12|
|134615.46|     147198.87|      127716.82|California|156122.51|
|130298.13|     145530.06|      323876.68|   Florida| 155752.6|
|120542.52|     148718.95|      311613.29|  New York|152211.77|
|123334.88|     108679.17|      304981.62|California|149759.96|
|101913.08|     110594.11|      229160.95|   Florida|146121.95|
|100671.96|      91790.61|      249744.55|California| 144259.4|
| 93863.75|     127320.38|      249839.4

In [0]:
df.printSchema()

root
 |-- R&D Spend: string (nullable = true)
 |-- Administration: string (nullable = true)
 |-- Marketing Spend: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Profit: string (nullable = true)



In [0]:
df.describe()

Out[39]: DataFrame[summary: string, R&D Spend: string, Administration: string, Marketing Spend: string, State: string, Profit: string]

In [0]:
df.dtypes

Out[40]: [('R&D Spend', 'string'),
 ('Administration', 'string'),
 ('Marketing Spend', 'string'),
 ('State', 'string'),
 ('Profit', 'string')]

In [0]:
df = df.withColumn("R&D Spend", df['R&D Spend'].cast("float"))
df = df.withColumn("Administration", df['Administration'].cast("float"))
df = df.withColumn("Marketing Spend", df['Marketing Spend'].cast("float"))
df = df.withColumn("Profit", df['Profit'].cast("float"))

In [0]:
df.show()

+---------+--------------+---------------+----------+---------+
|R&D Spend|Administration|Marketing Spend|     State|   Profit|
+---------+--------------+---------------+----------+---------+
| 165349.2|      136897.8|       471784.1|  New York|192261.83|
| 162597.7|      151377.6|      443898.53|California|191792.06|
|153441.52|     101145.55|      407934.53|   Florida|191050.39|
| 144372.4|     118671.85|      383199.62|  New York|182901.98|
|142107.34|      91391.77|       366168.4|   Florida|166187.94|
| 131876.9|      99814.71|      362861.38|  New York|156991.12|
|134615.45|     147198.88|      127716.82|California|156122.52|
|130298.13|     145530.06|       323876.7|   Florida| 155752.6|
|120542.52|     148718.95|      311613.28|  New York|152211.77|
|123334.88|     108679.17|      304981.62|California|149759.95|
|101913.08|     110594.11|      229160.95|   Florida|146121.95|
|100671.96|      91790.61|      249744.55|California| 144259.4|
| 93863.75|     127320.38|      249839.4

In [0]:
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [0]:
stringIndexer = StringIndexer(inputCol="State",outputCol="State_indexes")
df = stringIndexer.fit(df).transform(df)
df.show()

+---------+--------------+---------------+----------+---------+-----------+------------+-------+-------------+
|R&D Spend|Administration|Marketing Spend|     State|   Profit|State_index|State_indexs|State_i|State_indexes|
+---------+--------------+---------------+----------+---------+-----------+------------+-------+-------------+
| 165349.2|      136897.8|       471784.1|  New York|192261.83|        1.0|         1.0|    1.0|          1.0|
| 162597.7|      151377.6|      443898.53|California|191792.06|        0.0|         0.0|    0.0|          0.0|
|153441.52|     101145.55|      407934.53|   Florida|191050.39|        2.0|         2.0|    2.0|          2.0|
| 144372.4|     118671.85|      383199.62|  New York|182901.98|        1.0|         1.0|    1.0|          1.0|
|142107.34|      91391.77|       366168.4|   Florida|166187.94|        2.0|         2.0|    2.0|          2.0|
| 131876.9|      99814.71|      362861.38|  New York|156991.12|        1.0|         1.0|    1.0|          1.0|
|

In [0]:
df = df.drop('State_indexs','State_i','State_indexes')

In [0]:
from pyspark.ml.feature import VectorAssembler
vectorassembler = VectorAssembler(inputCols=["R&D Spend", "Administration", "Marketing Spend", "Profit", "State_index"],outputCol="Independent Value")
df = vectorassembler.transform(df)

In [0]:
Linear_df = df.select('Independent Value', 'Profit')
Linear_df.show()

+--------------------+---------+
|   Independent Value|   Profit|
+--------------------+---------+
|[165349.203125,13...|192261.83|
|[162597.703125,15...|191792.06|
|[153441.515625,10...|191050.39|
|[144372.40625,118...|182901.98|
|[142107.34375,913...|166187.94|
|[131876.90625,998...|156991.12|
|[134615.453125,14...|156122.52|
|[130298.1328125,1...| 155752.6|
|[120542.5234375,1...|152211.77|
|[123334.8828125,1...|149759.95|
|[101913.078125,11...|146121.95|
|[100671.9609375,9...| 144259.4|
|[93863.75,127320....|141585.52|
|[91992.390625,135...|134307.34|
|[119943.2421875,1...|132602.66|
|[114523.609375,12...|129917.04|
|[78013.109375,121...|126992.93|
|[94657.15625,1450...|125370.37|
|[91749.15625,1141...| 124266.9|
|[86419.703125,153...|122776.86|
+--------------------+---------+
only showing top 20 rows



In [0]:
from pyspark.ml.regression import LinearRegression
train_data,test_data = Linear_df.randomSplit([0.8,0.2])
regressor = LinearRegression(featuresCol = 'Independent Value',labelCol = 'Profit')
regressor = regressor.fit(train_data)

In [0]:
regressor.coefficients

Out[73]: DenseVector([-0.0, 0.0, 0.0, 1.0, 0.0])

In [0]:
regressor.intercept

Out[74]: -1.3580478423148815e-09

In [0]:
pred_resut = regressor.evaluate(test_data)

In [0]:
pred_resut.predictions.show()

+--------------------+---------+------------------+
|   Independent Value|   Profit|        prediction|
+--------------------+---------+------------------+
|(5,[1,3],[135426....| 42559.73| 42559.73046874977|
|[0.0,116983.79687...|  14681.4|14681.400390624265|
|[542.049987792968...| 35673.41|35673.410156249374|
|[1315.4599609375,...| 49490.75| 49490.74999999985|
|[23640.9296875,96...| 71498.49| 71498.49218749984|
|[46014.01953125,8...| 96479.51| 96479.50781249994|
|[61136.37890625,1...| 97483.56| 97483.56249999997|
|[61994.48046875,1...| 99937.59| 99937.59374999988|
|[64664.7109375,13...|107404.34|107404.34375000003|
|[67532.53125,1057...|108733.99|108733.99218749994|
|[73994.5625,12278...|110352.25|110352.24999999994|
|[78389.46875,1537...|111313.02|    111313.0234375|
|[100671.9609375,9...| 144259.4|      144259.40625|
|[120542.5234375,1...|152211.77|152211.76562500012|
|[123334.8828125,1...|149759.95|149759.95312499988|
|[153441.515625,10...|191050.39| 191050.3906250002|
|[162597.703

In [0]:
pred_resut.r2,pred_resut.meanSquaredError,pred_resut.rootMeanSquaredError

Out[80]: (1.0, 6.95874935270022e-20, 2.637944152687888e-10)