In [32]:
import pyspark as sp
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
spark = SparkSession.builder.appName('BigMartSales').getOrCreate()

In [33]:
df = spark.read.csv('data/cleaned_train.csv', header=True, inferSchema=True)

In [34]:
df.show()

+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+---------+-----------+--------------------+-----------+-----------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|           Item_Type|Item_MRP|Outlet_Identifier|outletAge|Outlet_Size|Outlet_Location_Type|Outlet_Type|Item_Outlet_Sales|
+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+---------+-----------+--------------------+-----------+-----------------+
|          FDA15|        9.3|         Low Fat|    0.016047301|               Dairy|249.8092|           OUT049|       19|     Medium|                   1|         S1|         3735.138|
|          DRC01|       5.92|         Regular|    0.019278216|         Soft Drinks| 48.2692|           OUT018|        9|     Medium|                   3|         S2|         443.4228|
|          FDN15|       17.5|         Low Fat|    0.016760075|                Me

In [35]:
df.printSchema()

root
 |-- Item_Identifier: string (nullable = true)
 |-- Item_Weight: double (nullable = true)
 |-- Item_Fat_Content: string (nullable = true)
 |-- Item_Visibility: double (nullable = true)
 |-- Item_Type: string (nullable = true)
 |-- Item_MRP: double (nullable = true)
 |-- Outlet_Identifier: string (nullable = true)
 |-- outletAge: integer (nullable = true)
 |-- Outlet_Size: string (nullable = true)
 |-- Outlet_Location_Type: integer (nullable = true)
 |-- Outlet_Type: string (nullable = true)
 |-- Item_Outlet_Sales: double (nullable = true)



In [36]:
df.describe().show()

+-------+---------------+------------------+----------------+-------------------+-------------+-----------------+-----------------+------------------+-----------+--------------------+-----------+------------------+
|summary|Item_Identifier|       Item_Weight|Item_Fat_Content|    Item_Visibility|    Item_Type|         Item_MRP|Outlet_Identifier|         outletAge|Outlet_Size|Outlet_Location_Type|Outlet_Type| Item_Outlet_Sales|
+-------+---------------+------------------+----------------+-------------------+-------------+-----------------+-----------------+------------------+-----------+--------------------+-----------+------------------+
|  count|           8523|              8523|            8523|               8523|         8523|             8523|             8523|              8523|       8523|                8523|       8523|              8523|
|   mean|           NULL|12.875361375103129|            NULL|0.06613202877895127|         NULL|140.9927819781768|             NULL|20.168133

In [37]:
to_encode = {'Item_Identifier': 'itemID',
             'Item_Fat_Content': 'isLF',
             'Item_Type': 'itemTypeID',
             'Outlet_Size': 'outletSize',
             'Outlet_Type': 'outletType'}

for column, new_column in to_encode.items():
    indexer = StringIndexer(inputCol=column, outputCol=new_column)
    df = indexer.fit(df).transform(df)

In [38]:
df = df.drop('Item_Identifier').drop('Item_Fat_Content').drop('Item_Type').drop('Outlet_Size').drop('Outlet_Identifier').drop('Outlet_Type')

In [39]:
df.show()

+-----------+---------------+--------+---------+--------------------+-----------------+------+----+----------+----------+----------+
|Item_Weight|Item_Visibility|Item_MRP|outletAge|Outlet_Location_Type|Item_Outlet_Sales|itemID|isLF|itemTypeID|outletSize|outletType|
+-----------+---------------+--------+---------+--------------------+-----------------+------+----+----------+----------+----------+
|        9.3|    0.016047301|249.8092|       19|                   1|         3735.138|  40.0| 0.0|       4.0|       0.0|       0.0|
|       5.92|    0.019278216| 48.2692|        9|                   3|         443.4228| 392.0| 1.0|       8.0|       0.0|       3.0|
|       17.5|    0.016760075| 141.618|       19|                   1|          2097.27| 243.0| 0.0|       9.0|       0.0|       0.0|
|       19.2|            0.0| 182.095|       20|                   3|           732.38| 671.0| 1.0|       0.0|       1.0|       1.0|
|       8.93|            0.0| 53.8614|       31|                   3|

In [40]:
df.printSchema()

root
 |-- Item_Weight: double (nullable = true)
 |-- Item_Visibility: double (nullable = true)
 |-- Item_MRP: double (nullable = true)
 |-- outletAge: integer (nullable = true)
 |-- Outlet_Location_Type: integer (nullable = true)
 |-- Item_Outlet_Sales: double (nullable = true)
 |-- itemID: double (nullable = false)
 |-- isLF: double (nullable = false)
 |-- itemTypeID: double (nullable = false)
 |-- outletSize: double (nullable = false)
 |-- outletType: double (nullable = false)



In [42]:
df.show()

+-----------+---------------+--------+---------+--------------------+-----------------+------+----+----------+----------+----------+
|Item_Weight|Item_Visibility|Item_MRP|outletAge|Outlet_Location_Type|Item_Outlet_Sales|itemID|isLF|itemTypeID|outletSize|outletType|
+-----------+---------------+--------+---------+--------------------+-----------------+------+----+----------+----------+----------+
|        9.3|    0.016047301|249.8092|       19|                   1|         3735.138|  40.0| 0.0|       4.0|       0.0|       0.0|
|       5.92|    0.019278216| 48.2692|        9|                   3|         443.4228| 392.0| 1.0|       8.0|       0.0|       3.0|
|       17.5|    0.016760075| 141.618|       19|                   1|          2097.27| 243.0| 0.0|       9.0|       0.0|       0.0|
|       19.2|            0.0| 182.095|       20|                   3|           732.38| 671.0| 1.0|       0.0|       1.0|       1.0|
|       8.93|            0.0| 53.8614|       31|                   3|

In [61]:
independent_variables = ["itemID", "Item_Weight", "isLF", "Item_Visibility", "itemTypeID", "Item_MRP", "outletAge", "outletSize", "Outlet_Location_Type", "outletType"]
featuresVector = VectorAssembler(inputCols=independent_variables,outputCol="featuresVector")
output = featuresVector.transform(df)

In [62]:
output.show()

+-----------+---------------+--------+---------+--------------------+-----------------+------+----+----------+----------+----------+--------------------+
|Item_Weight|Item_Visibility|Item_MRP|outletAge|Outlet_Location_Type|Item_Outlet_Sales|itemID|isLF|itemTypeID|outletSize|outletType|      featuresVector|
+-----------+---------------+--------+---------+--------------------+-----------------+------+----+----------+----------+----------+--------------------+
|        9.3|    0.016047301|249.8092|       19|                   1|         3735.138|  40.0| 0.0|       4.0|       0.0|       0.0|[40.0,9.3,0.0,0.0...|
|       5.92|    0.019278216| 48.2692|        9|                   3|         443.4228| 392.0| 1.0|       8.0|       0.0|       3.0|[392.0,5.92,1.0,0...|
|       17.5|    0.016760075| 141.618|       19|                   1|          2097.27| 243.0| 0.0|       9.0|       0.0|       0.0|[243.0,17.5,0.0,0...|
|       19.2|            0.0| 182.095|       20|                   3|       

In [63]:
for col in independent_variables:
    output = output.drop(col)
output = output.withColumnRenamed('Item_Outlet_Sales', 'totalSales')

In [64]:
output.show()

+----------+--------------------+
|totalSales|      featuresVector|
+----------+--------------------+
|  3735.138|[40.0,9.3,0.0,0.0...|
|  443.4228|[392.0,5.92,1.0,0...|
|   2097.27|[243.0,17.5,0.0,0...|
|    732.38|[671.0,19.2,1.0,0...|
|  994.7052|[719.0,8.93,0.0,0...|
|  556.6088|[1450.0,10.395,1....|
|  343.5528|[72.0,13.65,1.0,0...|
| 4022.7636|[254.0,19.0,0.0,0...|
| 1076.5986|[197.0,16.2,1.0,0...|
|  4710.535|[1024.0,19.2,1.0,...|
| 1516.0266|(10,[0,1,5,6,8],[...|
|  2187.153|[427.0,18.5,1.0,0...|
| 1589.2646|[325.0,15.1,1.0,0...|
| 2145.2076|[1003.0,17.6,1.0,...|
|  1977.426|[872.0,16.35,0.0,...|
| 1547.3192|[258.0,9.0,1.0,0....|
| 1621.8888|[714.0,11.8,0.0,0...|
|  718.3982|[258.0,9.0,1.0,0....|
|  2303.668|[151.0,8.26,0.0,0...|
| 2748.4224|[297.0,13.35,0.0,...|
+----------+--------------------+
only showing top 20 rows



In [65]:
train,test = output.randomSplit([0.8, 0.2])
model = LinearRegression(featuresCol='featuresVector', labelCol='totalSales')
model = model.fit(train)

In [68]:
prediction = model.evaluate(test)

In [70]:
prediction.r2

0.3744801843091541

In [72]:
prediction.meanAbsoluteError

991.8691764484176