In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, count, regexp_extract, coalesce
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator


In [2]:
spark = SparkSession.builder.appName("DataFrame Preprocessing").getOrCreate()
dataset = spark.read.csv("./big_mart.csv",header=True)

In [3]:
dataset.show()

+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|           Item_Type|Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type|Item_Outlet_Sales|   Category|City_Type|Variance_of_sales|  mean_comparison|Data_direction|
+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|          FDP10|       12.6|         Low Fat|    0.127469857|         Snack Foods|107.7622|           OUT027|                     1985|     Medium|              T

In [4]:
dataset.printSchema()

root
 |-- Item_Identifier: string (nullable = true)
 |-- Item_Weight: string (nullable = true)
 |-- Item_Fat_Content: string (nullable = true)
 |-- Item_Visibility: string (nullable = true)
 |-- Item_Type: string (nullable = true)
 |-- Item_MRP: string (nullable = true)
 |-- Outlet_Identifier: string (nullable = true)
 |-- Outlet_Establishment_Year: string (nullable = true)
 |-- Outlet_Size: string (nullable = true)
 |-- Outlet_Location_Type: string (nullable = true)
 |-- Outlet_Type: string (nullable = true)
 |-- Item_Outlet_Sales: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- City_Type: string (nullable = true)
 |-- Variance_of_sales: string (nullable = true)
 |-- mean_comparison: string (nullable = true)
 |-- Data_direction: string (nullable = true)



In [5]:
column_names = ['Item_Weight','Item_MRP','Item_Outlet_Sales']
# for cols in column_names:
#     dataset = dataset.withColumn([cols, col(cols).cast('float')])
dataset = dataset.select(*[col(cols).cast('float').alias(cols) if cols in column_names else col(cols) for cols in dataset.columns])


In [6]:
dataset.printSchema()

root
 |-- Item_Identifier: string (nullable = true)
 |-- Item_Weight: float (nullable = true)
 |-- Item_Fat_Content: string (nullable = true)
 |-- Item_Visibility: string (nullable = true)
 |-- Item_Type: string (nullable = true)
 |-- Item_MRP: float (nullable = true)
 |-- Outlet_Identifier: string (nullable = true)
 |-- Outlet_Establishment_Year: string (nullable = true)
 |-- Outlet_Size: string (nullable = true)
 |-- Outlet_Location_Type: string (nullable = true)
 |-- Outlet_Type: string (nullable = true)
 |-- Item_Outlet_Sales: float (nullable = true)
 |-- Category: string (nullable = true)
 |-- City_Type: string (nullable = true)
 |-- Variance_of_sales: string (nullable = true)
 |-- mean_comparison: string (nullable = true)
 |-- Data_direction: string (nullable = true)



In [7]:
dataset.show()

+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|           Item_Type|Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type|Item_Outlet_Sales|   Category|City_Type|Variance_of_sales|  mean_comparison|Data_direction|
+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|          FDP10|       12.6|         Low Fat|    0.127469857|         Snack Foods|107.7622|           OUT027|                     1985|     Medium|              T

In [8]:
dataset.count()

8523

In [9]:
dataset.first()

Row(Item_Identifier='FDP10', Item_Weight=12.600000381469727, Item_Fat_Content='Low Fat', Item_Visibility='0.127469857', Item_Type='Snack Foods', Item_MRP=107.76219940185547, Outlet_Identifier='OUT027', Outlet_Establishment_Year='1985', Outlet_Size='Medium', Outlet_Location_Type='Tier 3', Outlet_Type='Supermarket Type3', Item_Outlet_Sales=4022.763671875, Category='Healthy', City_Type='Village', Variance_of_sales='1841.47', mean_comparison='greater_than_mean', Data_direction='Left tailed')

In [10]:
dataset.show(15)

+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|           Item_Type|Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type|Item_Outlet_Sales|   Category|City_Type|Variance_of_sales|  mean_comparison|Data_direction|
+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|          FDP10|       12.6|         Low Fat|    0.127469857|         Snack Foods|107.7622|           OUT027|                     1985|     Medium|              T

In [11]:
filter_by_weight = dataset.filter("Item_Weight <= 20")

In [12]:
filter_by_weight.show(21)

+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|           Item_Type|Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type|Item_Outlet_Sales|   Category|City_Type|Variance_of_sales|  mean_comparison|Data_direction|
+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+
|          FDP10|       12.6|         Low Fat|    0.127469857|         Snack Foods|107.7622|           OUT027|                     1985|     Medium|              T

In [13]:
filter_by_weight.count()

8064

In [14]:
dataset.select('Item_Type','Outlet_Size').filter("Item_Outlet_Sales > 2000").show(50)

+--------------------+-----------+
|           Item_Type|Outlet_Size|
+--------------------+-----------+
|         Snack Foods|     Medium|
|         Hard Drinks|     Medium|
|        Baking Goods|     Medium|
|Fruits and Vegeta...|     Medium|
|         Snack Foods|     Medium|
|         Snack Foods|     Medium|
|        Baking Goods|     Medium|
|           Breakfast|     Medium|
|           Household|     Medium|
|         Snack Foods|     Medium|
|         Snack Foods|     Medium|
|              Canned|     Medium|
|  Health and Hygiene|     Medium|
|       Starchy Foods|     Medium|
|           Household|     Medium|
|        Frozen Foods|     Medium|
|        Baking Goods|     Medium|
|           Household|     Medium|
|               Dairy|     Medium|
|              Canned|     Medium|
|        Baking Goods|     Medium|
|              Canned|     Medium|
|           Household|     Medium|
|             Seafood|     Medium|
|         Hard Drinks|     Medium|
|         Snack Food

In [15]:
dataset.describe().show()

+-------+---------------+------------------+----------------+--------------------+-------------+-----------------+-----------------+-------------------------+-----------+--------------------+-----------------+------------------+-----------+---------+--------------------+-----------------+--------------+
|summary|Item_Identifier|       Item_Weight|Item_Fat_Content|     Item_Visibility|    Item_Type|         Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type| Item_Outlet_Sales|   Category|City_Type|   Variance_of_sales|  mean_comparison|Data_direction|
+-------+---------------+------------------+----------------+--------------------+-------------+-----------------+-----------------+-------------------------+-----------+--------------------+-----------------+------------------+-----------+---------+--------------------+-----------------+--------------+
|  count|           8523|              8523|            8523|                8523|   

In [16]:
dataset.groupby('Item_Fat_Content').count().show()

+----------------+-----+
|Item_Fat_Content|count|
+----------------+-----+
|         Low Fat| 5517|
|         Regular| 3006|
+----------------+-----+



In [17]:
# dataset.dropna('all').count()

In [18]:
dataset.count()

8523

In [19]:
dataset.describe().show()

+-------+---------------+------------------+----------------+--------------------+-------------+-----------------+-----------------+-------------------------+-----------+--------------------+-----------------+------------------+-----------+---------+--------------------+-----------------+--------------+
|summary|Item_Identifier|       Item_Weight|Item_Fat_Content|     Item_Visibility|    Item_Type|         Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type| Item_Outlet_Sales|   Category|City_Type|   Variance_of_sales|  mean_comparison|Data_direction|
+-------+---------------+------------------+----------------+--------------------+-------------+-----------------+-----------------+-------------------------+-----------+--------------------+-----------------+------------------+-----------+---------+--------------------+-----------------+--------------+
|  count|           8523|              8523|            8523|                8523|   

In [20]:
dataset.groupby('Item_Fat_Content').count().toPandas()

Unnamed: 0,Item_Fat_Content,count
0,Low Fat,5517
1,Regular,3006


In [21]:
dataset.groupby(['Item_Fat_Content','Outlet_Type']).agg({'Item_Outlet_Sales': 'mean','Item_MRP': "mean"}).show()

+----------------+-----------------+------------------+----------------------+
|Item_Fat_Content|      Outlet_Type|     avg(Item_MRP)|avg(Item_Outlet_Sales)|
+----------------+-----------------+------------------+----------------------+
|         Low Fat|Supermarket Type3| 138.4701006613487|    3643.9465082341976|
|         Low Fat|Supermarket Type2|141.88061356305278|    2008.8711346272241|
|         Low Fat|Supermarket Type1|140.72770716127314|    2288.0356293845025|
|         Low Fat|Supermarket Type4| 141.5806591331536|    341.39201943147265|
|         Regular|Supermarket Type3| 142.2432228897557|     3785.873982654918|
|         Regular|Supermarket Type4|137.89624885276513|    336.91241408655884|
|         Regular|Supermarket Type2|141.31262228994657|    1971.2663408915203|
|         Regular|Supermarket Type1| 142.1054834078967|    2367.7955709007697|
+----------------+-----------------+------------------+----------------------+



In [22]:
dataset.groupby(['Item_Fat_Content','Outlet_Type']).agg({'Item_Outlet_Sales': 'mean','Item_MRP': "mean"}).toPandas().set_index('Item_Fat_Content')

Unnamed: 0_level_0,Outlet_Type,avg(Item_MRP),avg(Item_Outlet_Sales)
Item_Fat_Content,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Low Fat,Supermarket Type3,138.470101,3643.946508
Low Fat,Supermarket Type2,141.880614,2008.871135
Low Fat,Supermarket Type1,140.727707,2288.035629
Low Fat,Supermarket Type4,141.580659,341.392019
Regular,Supermarket Type3,142.243223,3785.873983
Regular,Supermarket Type4,137.896249,336.912414
Regular,Supermarket Type2,141.312622,1971.266341
Regular,Supermarket Type1,142.105483,2367.795571


In [23]:
mean_value = dataset.select('Item_Outlet_Sales').agg({'Item_Outlet_Sales': 'mean'}).collect()[0][0]
new_dataset = dataset.withColumn('reduced_to_mean', dataset.Item_Outlet_Sales / mean_value)

In [24]:
new_dataset.show()

+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+-------------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|           Item_Type|Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|      Outlet_Type|Item_Outlet_Sales|   Category|City_Type|Variance_of_sales|  mean_comparison|Data_direction|    reduced_to_mean|
+---------------+-----------+----------------+---------------+--------------------+--------+-----------------+-------------------------+-----------+--------------------+-----------------+-----------------+-----------+---------+-----------------+-----------------+--------------+-------------------+
|          FDP10|       12.6|         Low Fat|    0.127469857|         Snack Foods|107.7622|           

Null Values In the Dataframe

In [25]:
null_values = dataset.select([count(when(col(c).isNull(),c)).alias(c) for c in dataset.columns])
null_values.show()

+---------------+-----------+----------------+---------------+---------+--------+-----------------+-------------------------+-----------+--------------------+-----------+-----------------+--------+---------+-----------------+---------------+--------------+
|Item_Identifier|Item_Weight|Item_Fat_Content|Item_Visibility|Item_Type|Item_MRP|Outlet_Identifier|Outlet_Establishment_Year|Outlet_Size|Outlet_Location_Type|Outlet_Type|Item_Outlet_Sales|Category|City_Type|Variance_of_sales|mean_comparison|Data_direction|
+---------------+-----------+----------------+---------------+---------+--------+-----------------+-------------------------+-----------+--------------------+-----------+-----------------+--------+---------+-----------------+---------------+--------------+
|              0|          0|               0|              0|        0|       0|                0|                        0|          0|                   0|          0|                0|       0|        0|                0|    

In [26]:
dataset.columns

['Item_Identifier',
 'Item_Weight',
 'Item_Fat_Content',
 'Item_Visibility',
 'Item_Type',
 'Item_MRP',
 'Outlet_Identifier',
 'Outlet_Establishment_Year',
 'Outlet_Size',
 'Outlet_Location_Type',
 'Outlet_Type',
 'Item_Outlet_Sales',
 'Category',
 'City_Type',
 'Variance_of_sales',
 'mean_comparison',
 'Data_direction']

In [27]:
new_dataset = dataset.select("Item_Weight","Item_Type","Outlet_Size","Item_MRP","Item_Outlet_Sales")

In [28]:
new_dataset.show()

+-----------+--------------------+-----------+--------+-----------------+
|Item_Weight|           Item_Type|Outlet_Size|Item_MRP|Item_Outlet_Sales|
+-----------+--------------------+-----------+--------+-----------------+
|       12.6|         Snack Foods|     Medium|107.7622|        4022.7637|
|       12.6|         Hard Drinks|     Medium|113.2834|         2303.668|
|       12.6|        Baking Goods|     Medium|144.5444|        4064.0432|
|       12.6|Fruits and Vegeta...|     Medium|128.0678|        2797.6917|
|       12.6|         Snack Foods|     Medium| 36.9874|         388.1614|
|       12.6|         Snack Foods|     Medium| 87.6198|         2180.495|
|       12.6|Fruits and Vegeta...|     Medium| 38.2848|         484.7024|
|       12.6|         Snack Foods|     Medium|255.8356|         2543.356|
|       12.6|        Baking Goods|     Medium|171.3764|         3091.975|
|       12.6|           Breakfast|     Medium| 155.963|         3285.723|
|       12.6|           Household|    

In [29]:
independent_features = new_dataset.drop("Item_Outlet_Sales")

In [30]:
independent_features.show()

+-----------+--------------------+-----------+--------+
|Item_Weight|           Item_Type|Outlet_Size|Item_MRP|
+-----------+--------------------+-----------+--------+
|       12.6|         Snack Foods|     Medium|107.7622|
|       12.6|         Hard Drinks|     Medium|113.2834|
|       12.6|        Baking Goods|     Medium|144.5444|
|       12.6|Fruits and Vegeta...|     Medium|128.0678|
|       12.6|         Snack Foods|     Medium| 36.9874|
|       12.6|         Snack Foods|     Medium| 87.6198|
|       12.6|Fruits and Vegeta...|     Medium| 38.2848|
|       12.6|         Snack Foods|     Medium|255.8356|
|       12.6|        Baking Goods|     Medium|171.3764|
|       12.6|           Breakfast|     Medium| 155.963|
|       12.6|           Household|     Medium|149.9708|
|       12.6|         Snack Foods|     Medium|178.5344|
|       12.6|         Snack Foods|     Medium|121.7098|
|       12.6|              Canned|     Medium|180.5976|
|       12.6|  Health and Hygiene|     Medium|22

In [31]:
new_dataset.printSchema()

root
 |-- Item_Weight: float (nullable = true)
 |-- Item_Type: string (nullable = true)
 |-- Outlet_Size: string (nullable = true)
 |-- Item_MRP: float (nullable = true)
 |-- Item_Outlet_Sales: float (nullable = true)



In [32]:
new_dataset.printSchema()

root
 |-- Item_Weight: float (nullable = true)
 |-- Item_Type: string (nullable = true)
 |-- Outlet_Size: string (nullable = true)
 |-- Item_MRP: float (nullable = true)
 |-- Item_Outlet_Sales: float (nullable = true)



In [33]:
string_columns = [col_name for col_name, col_type in new_dataset.dtypes if col_type == "string"]

In [34]:
stages = []

for col in string_columns:
   
    string_indexer = StringIndexer(inputCol=col, outputCol=col + "_index")
    one_hot_encoder = OneHotEncoder(inputCol=col + "_index", outputCol=col + "_encoded")
    stages += [string_indexer, one_hot_encoder]

pipeline = Pipeline(stages=stages)
pipeline_model = pipeline.fit(new_dataset)
transformed_df = pipeline_model.transform(new_dataset)
transformed_df.show()

+-----------+--------------------+-----------+--------+-----------------+---------------+-----------------+-----------------+-------------------+
|Item_Weight|           Item_Type|Outlet_Size|Item_MRP|Item_Outlet_Sales|Item_Type_index|Item_Type_encoded|Outlet_Size_index|Outlet_Size_encoded|
+-----------+--------------------+-----------+--------+-----------------+---------------+-----------------+-----------------+-------------------+
|       12.6|         Snack Foods|     Medium|107.7622|        4022.7637|            1.0|   (15,[1],[1.0])|              0.0|      (2,[0],[1.0])|
|       12.6|         Hard Drinks|     Medium|113.2834|         2303.668|           11.0|  (15,[11],[1.0])|              0.0|      (2,[0],[1.0])|
|       12.6|        Baking Goods|     Medium|144.5444|        4064.0432|            6.0|   (15,[6],[1.0])|              0.0|      (2,[0],[1.0])|
|       12.6|Fruits and Vegeta...|     Medium|128.0678|        2797.6917|            0.0|   (15,[0],[1.0])|              0.0

In [35]:
final_dataframe = transformed_df.drop('Item_Type','Outlet_Size','Item_Type_index','Outlet_Size_index')

In [36]:
final_dataframe.show()

+-----------+--------+-----------------+-----------------+-------------------+
|Item_Weight|Item_MRP|Item_Outlet_Sales|Item_Type_encoded|Outlet_Size_encoded|
+-----------+--------+-----------------+-----------------+-------------------+
|       12.6|107.7622|        4022.7637|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6|113.2834|         2303.668|  (15,[11],[1.0])|      (2,[0],[1.0])|
|       12.6|144.5444|        4064.0432|   (15,[6],[1.0])|      (2,[0],[1.0])|
|       12.6|128.0678|        2797.6917|   (15,[0],[1.0])|      (2,[0],[1.0])|
|       12.6| 36.9874|         388.1614|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6| 87.6198|         2180.495|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6| 38.2848|         484.7024|   (15,[0],[1.0])|      (2,[0],[1.0])|
|       12.6|255.8356|         2543.356|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6|171.3764|         3091.975|   (15,[6],[1.0])|      (2,[0],[1.0])|
|       12.6| 155.963|         3285.723|  (15,[14],[

In [37]:
x = final_dataframe.drop('Item_Outlet_Sales')

In [38]:
x.show()

+-----------+--------+-----------------+-------------------+
|Item_Weight|Item_MRP|Item_Type_encoded|Outlet_Size_encoded|
+-----------+--------+-----------------+-------------------+
|       12.6|107.7622|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6|113.2834|  (15,[11],[1.0])|      (2,[0],[1.0])|
|       12.6|144.5444|   (15,[6],[1.0])|      (2,[0],[1.0])|
|       12.6|128.0678|   (15,[0],[1.0])|      (2,[0],[1.0])|
|       12.6| 36.9874|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6| 87.6198|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6| 38.2848|   (15,[0],[1.0])|      (2,[0],[1.0])|
|       12.6|255.8356|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6|171.3764|   (15,[6],[1.0])|      (2,[0],[1.0])|
|       12.6| 155.963|  (15,[14],[1.0])|      (2,[0],[1.0])|
|       12.6|149.9708|   (15,[2],[1.0])|      (2,[0],[1.0])|
|       12.6|178.5344|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6|121.7098|   (15,[1],[1.0])|      (2,[0],[1.0])|
|       12.6|180.5976|  

In [39]:
assembler = VectorAssembler(inputCols=x.columns, outputCol="features")
output = assembler.transform(final_dataframe)


In [40]:
output.show()

+-----------+--------+-----------------+-----------------+-------------------+--------------------+
|Item_Weight|Item_MRP|Item_Outlet_Sales|Item_Type_encoded|Outlet_Size_encoded|            features|
+-----------+--------+-----------------+-----------------+-------------------+--------------------+
|       12.6|107.7622|        4022.7637|   (15,[1],[1.0])|      (2,[0],[1.0])|(19,[0,1,3,17],[1...|
|       12.6|113.2834|         2303.668|  (15,[11],[1.0])|      (2,[0],[1.0])|(19,[0,1,13,17],[...|
|       12.6|144.5444|        4064.0432|   (15,[6],[1.0])|      (2,[0],[1.0])|(19,[0,1,8,17],[1...|
|       12.6|128.0678|        2797.6917|   (15,[0],[1.0])|      (2,[0],[1.0])|(19,[0,1,2,17],[1...|
|       12.6| 36.9874|         388.1614|   (15,[1],[1.0])|      (2,[0],[1.0])|(19,[0,1,3,17],[1...|
|       12.6| 87.6198|         2180.495|   (15,[1],[1.0])|      (2,[0],[1.0])|(19,[0,1,3,17],[1...|
|       12.6| 38.2848|         484.7024|   (15,[0],[1.0])|      (2,[0],[1.0])|(19,[0,1,2,17],[1...|


In [41]:
df = output.select('features','Item_Outlet_Sales')

In [42]:
df.show()

+--------------------+-----------------+
|            features|Item_Outlet_Sales|
+--------------------+-----------------+
|(19,[0,1,3,17],[1...|        4022.7637|
|(19,[0,1,13,17],[...|         2303.668|
|(19,[0,1,8,17],[1...|        4064.0432|
|(19,[0,1,2,17],[1...|        2797.6917|
|(19,[0,1,3,17],[1...|         388.1614|
|(19,[0,1,3,17],[1...|         2180.495|
|(19,[0,1,2,17],[1...|         484.7024|
|(19,[0,1,3,17],[1...|         2543.356|
|(19,[0,1,8,17],[1...|         3091.975|
|(19,[0,1,16,17],[...|         3285.723|
|(19,[0,1,4,17],[1...|        4363.6533|
|(19,[0,1,3,17],[1...|        2854.9504|
|(19,[0,1,3,17],[1...|         4097.333|
|(19,[0,1,7,17],[1...|        7968.2944|
|(19,[0,1,9,17],[1...|        6976.2524|
|(19,[0,1,15,17],[...|        5262.4834|
|(19,[0,1,10,17],[...|           898.83|
|(19,[0,1,2,17],[1...|        1808.9786|
|(19,[0,1,4,17],[1...|         5555.435|
|(19,[0,1,5,17],[1...|         6024.158|
+--------------------+-----------------+
only showing top

In [43]:
scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures")
scalerModel = scaler.fit(df)
scaledData = scalerModel.transform(df)

In [44]:
scaledData.show()

+--------------------+-----------------+--------------------+
|            features|Item_Outlet_Sales|      scaledFeatures|
+--------------------+-----------------+--------------------+
|(19,[0,1,3,17],[1...|        4022.7637|(19,[0,1,3,17],[0...|
|(19,[0,1,13,17],[...|         2303.668|(19,[0,1,13,17],[...|
|(19,[0,1,8,17],[1...|        4064.0432|(19,[0,1,8,17],[0...|
|(19,[0,1,2,17],[1...|        2797.6917|(19,[0,1,2,17],[0...|
|(19,[0,1,3,17],[1...|         388.1614|(19,[0,1,3,17],[0...|
|(19,[0,1,3,17],[1...|         2180.495|(19,[0,1,3,17],[0...|
|(19,[0,1,2,17],[1...|         484.7024|(19,[0,1,2,17],[0...|
|(19,[0,1,3,17],[1...|         2543.356|(19,[0,1,3,17],[0...|
|(19,[0,1,8,17],[1...|         3091.975|(19,[0,1,8,17],[0...|
|(19,[0,1,16,17],[...|         3285.723|(19,[0,1,16,17],[...|
|(19,[0,1,4,17],[1...|        4363.6533|(19,[0,1,4,17],[0...|
|(19,[0,1,3,17],[1...|        2854.9504|(19,[0,1,3,17],[0...|
|(19,[0,1,3,17],[1...|         4097.333|(19,[0,1,3,17],[0...|
|(19,[0,

In [45]:
scaledData = scaledData.drop('features')

In [46]:
scaledData.show()

+-----------------+--------------------+
|Item_Outlet_Sales|      scaledFeatures|
+-----------------+--------------------+
|        4022.7637|(19,[0,1,3,17],[0...|
|         2303.668|(19,[0,1,13,17],[...|
|        4064.0432|(19,[0,1,8,17],[0...|
|        2797.6917|(19,[0,1,2,17],[0...|
|         388.1614|(19,[0,1,3,17],[0...|
|         2180.495|(19,[0,1,3,17],[0...|
|         484.7024|(19,[0,1,2,17],[0...|
|         2543.356|(19,[0,1,3,17],[0...|
|         3091.975|(19,[0,1,8,17],[0...|
|         3285.723|(19,[0,1,16,17],[...|
|        4363.6533|(19,[0,1,4,17],[0...|
|        2854.9504|(19,[0,1,3,17],[0...|
|         4097.333|(19,[0,1,3,17],[0...|
|        7968.2944|(19,[0,1,7,17],[0...|
|        6976.2524|(19,[0,1,9,17],[0...|
|        5262.4834|(19,[0,1,15,17],[...|
|           898.83|(19,[0,10,17],[0....|
|        1808.9786|(19,[0,1,2,17],[0...|
|         5555.435|(19,[0,1,4,17],[0...|
|         6024.158|(19,[0,1,5,17],[0...|
+-----------------+--------------------+
only showing top

In [51]:
train,test = scaledData.randomSplit([0.80,0.20])

In [52]:
train.show()

+-----------------+--------------------+
|Item_Outlet_Sales|      scaledFeatures|
+-----------------+--------------------+
|            33.29|(19,[0,1,4,17],[0...|
|            33.29|(19,[0,1,10,17],[...|
|          33.9558|(19,[0,1,3,18],[0...|
|          34.6216|(19,[0,1,9,18],[0...|
|          35.2874|(19,[0,1,12,17],[...|
|           36.619|(19,[0,1,5,18],[0...|
|          37.2848|(19,[0,1,8,18],[0...|
|          37.9506|(19,[0,1,5,18],[0...|
|          37.9506|(19,[0,1,7,18],[0...|
|          37.9506|(19,[0,1,9,17],[0...|
|          37.9506|(19,[0,1,13,18],[...|
|          38.6164|(19,[0,1,5,18],[0...|
|           39.948|(19,[0,1,14,18],[...|
|           39.948|(19,[0,1,16,17],[...|
|          40.6138|(19,[0,1,10,18],[...|
|          41.2796|(19,[0,1,4,18],[0...|
|          41.2796|(19,[0,1,5,17],[0...|
|          41.2796|(19,[0,1,5,18],[0...|
|          41.9454|(19,[0,1,5,18],[0...|
|          41.9454|(19,[0,1,7,17],[0...|
+-----------------+--------------------+
only showing top

In [53]:
test.show()

+-----------------+--------------------+
|Item_Outlet_Sales|      scaledFeatures|
+-----------------+--------------------+
|           36.619|(19,[0,1,5,17],[0...|
|          37.9506|(19,[0,1,5,18],[0...|
|          38.6164|(19,[0,1,8,18],[0...|
|          40.6138|(19,[0,1,6,18],[0...|
|          41.9454|(19,[0,1,4,17],[0...|
|          47.9376|(19,[0,1,11,17],[...|
|          50.6008|(19,[0,1,3,18],[0...|
|           56.593|(19,[0,1,2,17],[0...|
|          67.9116|(19,[0,1,4,17],[0...|
|          69.2432|(19,[0,1,4,17],[0...|
|          71.9064|(19,[0,1,11,17],[...|
|           73.238|(19,[0,1,13,17],[...|
|          75.9012|(19,[0,1,13,17],[...|
|          78.5644|(19,[0,1,11,17],[...|
|           79.896|(19,[0,1,3,18],[0...|
|          81.2276|(19,[0,1,6,17],[0...|
|          87.2198|(19,[0,1,2,18],[0...|
|          87.2198|(19,[0,1,7,17],[0...|
|          88.5514|(19,[0,1,8,17],[0...|
|          89.2172|(19,[0,1,7,18],[0...|
+-----------------+--------------------+
only showing top

In [71]:
train.count() , test.count()

(6823, 1700)

In [55]:
lin_reg = LinearRegression(featuresCol='scaledFeatures',labelCol='Item_Outlet_Sales')
lin_model = lin_reg.fit(train)

In [56]:
print("coefficients: ",lin_model.coefficients)
print('intercept: ',lin_model.intercept)

coefficients:  [-22.051885611943646,3690.6689084390537,-224.88569986242524,-190.34377891588818,-284.0705501725863,-225.31574000639515,-227.52245876048045,-141.65850847789068,-212.69820645777114,-258.9819706330785,-233.30614893371992,-180.70525082184759,-194.6977022321746,-225.97949526421422,-304.02492123673545,-65.89797307700208,-364.93806393542724,16.206582156800234,-397.8257702074755]
intercept:  798.083041119296


In [57]:
model_summary = lin_model.summary
print('RMSE: ',model_summary.rootMeanSquaredError)
print('R2Score: ',model_summary.r2*100)

RMSE:  1400.0402303970607
R2Score:  33.283146687783216


In [58]:
prediction = lin_model.transform(test)

In [59]:
prediction.show()

+-----------------+--------------------+------------------+
|Item_Outlet_Sales|      scaledFeatures|        prediction|
+-----------------+--------------------+------------------+
|           36.619|(19,[0,1,5,17],[0...| 636.3481479053805|
|          37.9506|(19,[0,1,5,18],[0...| 298.4809524965891|
|          38.6164|(19,[0,1,8,18],[0...|271.39998724133477|
|          40.6138|(19,[0,1,6,18],[0...|333.29392078170605|
|          41.9454|(19,[0,1,4,17],[0...| 708.2331236529742|
|          47.9376|(19,[0,1,11,17],[...| 890.0278876420905|
|          50.6008|(19,[0,1,3,18],[0...| 525.3533124648691|
|           56.593|(19,[0,1,2,17],[0...| 934.5619306367005|
|          67.9116|(19,[0,1,4,17],[0...| 584.0146195156473|
|          69.2432|(19,[0,1,4,17],[0...| 595.5504383574511|
|          71.9064|(19,[0,1,11,17],[...| 726.0222111427948|
|           73.238|(19,[0,1,13,17],[...|1248.1950109974016|
|          75.9012|(19,[0,1,13,17],[...| 708.0110311174966|
|          78.5644|(19,[0,1,11,17],[...|

In [60]:
linear_model_evaluator = RegressionEvaluator(predictionCol='prediction',labelCol='Item_Outlet_Sales',metricName='r2')
print('rsquared_for_test_data: ',linear_model_evaluator.evaluate(prediction))

rsquared_for_test_data:  0.3348655894011082
