In [2]:
#importing libraries
from pyspark.sql import SparkSession
import pyspark.sql.types as tp
from pyspark.sql import functions as f
import matplotlib.pyplot as plt
%matplotlib inline


In [3]:
# Start (or reuse) the Spark session used by every cell in this notebook.
spark = (
    SparkSession.builder
    .appName("purchase_predictor")
    .getOrCreate()
)

In [4]:
# Read the train and test CSV files. The first row is treated as the header
# and column types are inferred by Spark.
csv_options = {"header": True, "inferSchema": True}
purchase_train = spark.read.csv("/content/train.csv", **csv_options)
purchase_test = spark.read.csv("/content/test.csv", **csv_options)

In [5]:
# Preview the first rows of the training data (includes the Purchase label).
purchase_train.show()

+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|User_ID|Product_ID|Gender|  Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|
+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|1000001| P00069042|     F| 0-17|        10|            A|                         2|             0|                 3|              NULL|              NULL|    8370|
|1000001| P00248942|     F| 0-17|        10|            A|                         2|             0|                 1|                 6|                14|   15200|
|1000001| P00087842|     F| 0-17|        10|            A|                         2|             0|                12|              NULL|              NULL|    1422

In [6]:
# Preview the first rows of the test data (same columns, but no Purchase label).
purchase_test.show()

+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|User_ID|Product_ID|Gender|  Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|
+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|1000004| P00128942|     M|46-50|         7|            B|                         2|             1|                 1|                11|              NULL|
|1000009| P00113442|     M|26-35|        17|            C|                         0|             0|                 3|                 5|              NULL|
|1000010| P00288442|     F|36-45|         1|            B|                        4+|             1|                 5|                14|              NULL|
|1000010| P00145342|     F|36-45|         1|        

In [7]:
# Inspect the inferred column types of the training data.
purchase_train.printSchema()

root
 |-- User_ID: integer (nullable = true)
 |-- Product_ID: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Occupation: integer (nullable = true)
 |-- City_Category: string (nullable = true)
 |-- Stay_In_Current_City_Years: string (nullable = true)
 |-- Marital_Status: integer (nullable = true)
 |-- Product_Category_1: integer (nullable = true)
 |-- Product_Category_2: integer (nullable = true)
 |-- Product_Category_3: integer (nullable = true)
 |-- Purchase: integer (nullable = true)



In [8]:
# Average purchase amount per product.
# The label column is referenced by its exact schema name ("Purchase") so the
# aggregation does not depend on Spark's case-insensitive column resolution.
# NOTE: the alias spelling "Avarage_purchase" is kept as-is because the
# following cell sorts by that exact column name.
average_purchase = purchase_train.groupBy("Product_ID").agg(f.avg("Purchase").alias("Avarage_purchase"))
average_purchase.show()

+----------+------------------+
|Product_ID|  Avarage_purchase|
+----------+------------------+
|  P0098242|12581.417721518987|
| P00281742|            7972.0|
| P00026042| 6745.601851851852|
| P00015342| 10034.42857142857|
| P00159842|10255.973684210527|
| P00162642|            6359.0|
| P00048442| 6990.714285714285|
| P00078842|           8912.95|
| P00313242| 7527.372093023256|
| P00318342|           10732.8|
|  P0096342| 6328.872340425532|
| P00146342| 4951.136363636364|
| P00180642|3247.4736842105262|
| P00256142|            7093.5|
| P00323242|          15783.25|
| P00014542| 8122.818713450292|
| P00331542| 7053.490566037736|
| P00212242| 8396.680555555555|
| P00165442| 9807.792452830188|
| P00163342|12501.904761904761|
+----------+------------------+
only showing top 20 rows



In [9]:
# Products ranked by their average purchase amount, highest first.
average_purchase.orderBy(f.desc("Avarage_purchase")).show()


+----------+------------------+
|Product_ID|  Avarage_purchase|
+----------+------------------+
| P00086242|        21262.7625|
| P00107342|           21148.0|
| P00085342|21020.238993710693|
| P00272342|           20889.0|
| P00162142|           20859.0|
| P00188642| 20792.57894736842|
| P00277342|           20725.0|
| P00200642| 20597.97142857143|
| P00116142|20526.117318435754|
| P00311242|           20488.0|
| P00119342| 20472.18181818182|
| P00131842|           20398.0|
| P00087042| 20297.21518987342|
| P00117642|20222.083333333332|
| P00052842|  20162.4350877193|
| P00343842|         19634.875|
| P00308042|           19206.0|
| P00071442|19133.864485981307|
| P00273342|          19048.25|
| P00124742|19024.535714285714|
+----------+------------------+
only showing top 20 rows



COUNTING AND FILLING NULL VALUES

In [10]:
# Number of NULLs per column in the training data.
# when() produces a non-null literal only for NULL cells, so count() tallies
# exactly the NULL cells of each column.
null_count_exprs = []
for column_name in purchase_train.columns:
    null_count_exprs.append(
        f.count(f.when(f.col(column_name).isNull(), column_name)).alias(column_name)
    )
purchase_train.select(null_count_exprs).show()

+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|User_ID|Product_ID|Gender|Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|      0|         0|     0|  0|         0|            0|                         0|             0|                 0|             49069|            109769|       0|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+



In [11]:
# Number of NULLs per column in the test data (same expression as for train).
purchase_test.select([f.count(f.when(f.isnull(c), c)).alias(c) for c in purchase_test.columns]).show()

+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|User_ID|Product_ID|Gender|Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|      0|         0|     0|  0|         0|            0|                         0|             0|                 0|             72344|            162562|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+



In [12]:
# Frequency of each value in Product_Category_2 and Product_Category_3 (train),
# most frequent first. count(<column>) skips NULLs, which is why the NULL
# group is reported with a count of 0.
purchase_train.groupBy("Product_Category_2").agg(f.count("Product_Category_2")).orderBy("count(Product_Category_2)", ascending=False).show()
purchase_train.groupBy("Product_Category_3").agg(f.count("Product_Category_3")).orderBy("count(Product_Category_3)", ascending=False).show()

+------------------+-------------------------+
|Product_Category_2|count(Product_Category_2)|
+------------------+-------------------------+
|                 8|                    18477|
|                14|                    16004|
|                 2|                    14152|
|                16|                    12663|
|                15|                    10847|
|                 5|                     7612|
|                 4|                     7474|
|                 6|                     4737|
|                11|                     4103|
|                17|                     3885|
|                13|                     3080|
|                 9|                     1682|
|                12|                     1641|
|                10|                      869|
|                 3|                      839|
|                18|                      795|
|                 7|                      177|
|              NULL|                        0|
+------------

In [13]:
# Same frequency tables for the test data, most frequent value first.
purchase_test.groupBy("Product_Category_2").agg(f.count("Product_Category_2")).orderBy("count(Product_Category_2)", ascending=False).show()
purchase_test.groupBy("Product_Category_3").agg(f.count("Product_Category_3")).orderBy("count(Product_Category_3)", ascending=False).show()

+------------------+-------------------------+
|Product_Category_2|count(Product_Category_2)|
+------------------+-------------------------+
|                 8|                    27229|
|                14|                    23726|
|                 2|                    21281|
|                16|                    18432|
|                15|                    16259|
|                 4|                    11028|
|                 5|                    10930|
|                 6|                     7109|
|                11|                     6096|
|                17|                     5784|
|                13|                     4523|
|                 9|                     2484|
|                12|                     2273|
|                10|                     1377|
|                18|                     1257|
|                 3|                     1239|
|                 7|                      228|
|              NULL|                        0|
+------------

In [14]:
# Replace missing category codes with a fixed value per column, applying the
# same values to train and test. 8 is the most frequent Product_Category_2
# value in the tables above; 16 is presumably the most frequent
# Product_Category_3 value -- confirm against the frequency table.
fill_values = {"Product_Category_2": 8, "Product_Category_3": 16}
purchase_train = purchase_train.na.fill(fill_values)
purchase_test = purchase_test.na.fill(fill_values)

In [15]:
# Re-run the NULL counts on both DataFrames to confirm the fill left no NULLs.
purchase_train.select([f.count(f.when(f.isnull(c), c)).alias(c) for c in purchase_train.columns]).show()
purchase_test.select([f.count(f.when(f.isnull(c), c)).alias(c) for c in purchase_test.columns]).show()

+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|User_ID|Product_ID|Gender|Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|      0|         0|     0|  0|         0|            0|                         0|             0|                 0|                 0|                 0|       0|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+

+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|User_ID|Product_I

CHECKING FOR DISTINCT VALUES

In [16]:
# Number of distinct values in every column of the training data.
distinct_exprs = [
    f.countDistinct(f.col(column_name)).alias(column_name)
    for column_name in purchase_train.columns
]
purchase_train.agg(*distinct_exprs).show()

+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|User_ID|Product_ID|Gender|Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+
|   5879|      3428|     2|  7|        21|            3|                         5|             2|                18|                17|                15|   15197|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+



In [17]:
# Number of distinct values in every column of the test data.
purchase_test.agg(*(f.countDistinct(f.col(c)).alias(c) for c in purchase_test.columns)).show()

+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|User_ID|Product_ID|Gender|Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+
|   5891|      3491|     2|  7|        21|            3|                         5|             2|                18|                17|                15|
+-------+----------+------+---+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+



Count category values within each of the following columns:

 ● Gender

 ● Age

 ● City_Category

 ● Stay_In_Current_City_Years

 ● Marital_Status

In [18]:
# Number of training rows per gender, largest group first.
gender_counts = purchase_train.groupBy("Gender").agg(f.count("Gender"))
gender_counts.orderBy(f.desc("count(Gender)")).show()

+------+-------------+
|Gender|count(Gender)|
+------+-------------+
|     M|       119217|
|     F|        38889|
+------+-------------+



In [19]:
# Number of training rows per age bracket, largest group first.
purchase_train.groupBy("Age").agg(f.count("Age")).orderBy("count(Age)", ascending=False).show()

+-----+----------+
|  Age|count(Age)|
+-----+----------+
|26-35|     62940|
|36-45|     31559|
|18-25|     29064|
|46-50|     13060|
|51-55|     11106|
|  55+|      6094|
| 0-17|      4283|
+-----+----------+



In [20]:
# Number of training rows per city category, largest group first.
purchase_train.groupBy("City_Category").agg(f.count("City_Category")).orderBy("count(City_Category)", ascending=False).show()

+-------------+--------------------+
|City_Category|count(City_Category)|
+-------------+--------------------+
|            B|               66617|
|            C|               48990|
|            A|               42499|
+-------------+--------------------+



In [21]:
# Number of training rows per Stay_In_Current_City_Years value, largest group first.
purchase_train.groupBy("Stay_In_Current_City_Years").agg(f.count("Stay_In_Current_City_Years")).orderBy("count(Stay_In_Current_City_Years)", ascending=False).show()

+--------------------------+---------------------------------+
|Stay_In_Current_City_Years|count(Stay_In_Current_City_Years)|
+--------------------------+---------------------------------+
|                         1|                            55521|
|                         2|                            29254|
|                         3|                            27288|
|                        4+|                            24562|
|                         0|                            21481|
+--------------------------+---------------------------------+



In [22]:
# Number of training rows per marital status, largest group first.
# The grouping key uses the exact schema name ("Marital_Status"), matching the
# name already used in count() and orderBy() on the same line.
purchase_train.groupBy("Marital_Status").agg(f.count("Marital_Status")).orderBy("count(Marital_Status)", ascending=False).show()

+--------------+---------------------+
|Marital_status|count(Marital_Status)|
+--------------+---------------------+
|             0|                93265|
|             1|                64841|
+--------------+---------------------+



Calculate average Purchase for each of the following columns:

 ● Gender

 ● Age

 ● City_Category

 ● Stay_In_Current_City_Years

 ● Marital_Status

In [23]:
# Average purchase per gender, rounded to 2 decimals, highest first.
# The sort key is spelled exactly like the alias ("Average_purchase") instead
# of relying on case-insensitive column resolution.
purchase_train.groupBy("Gender").agg(f.round(f.avg("Purchase"),2).alias("Average_purchase")).orderBy("Average_purchase", ascending=False).show()

+------+----------------+
|Gender|Average_purchase|
+------+----------------+
|     M|         9481.98|
|     F|         8796.77|
+------+----------------+



In [24]:
# Average purchase per age bracket, rounded to 2 decimals, highest first.
# The sort key is spelled exactly like the alias ("Average_purchase") instead
# of relying on case-insensitive column resolution.
purchase_train.groupBy("Age").agg(f.round(f.avg("Purchase"),2).alias("Average_purchase")).orderBy("Average_purchase", ascending=False).show()

+-----+----------------+
|  Age|Average_purchase|
+-----+----------------+
|51-55|         9602.01|
|  55+|         9390.19|
|36-45|          9388.2|
|26-35|         9299.17|
|46-50|         9263.53|
|18-25|         9187.43|
| 0-17|         9122.21|
+-----+----------------+



In [25]:
# Average purchase per city category, rounded to 2 decimals, highest first.
# The sort key is spelled exactly like the alias ("Average_purchase") instead
# of relying on case-insensitive column resolution.
purchase_train.groupBy("City_Category").agg(f.round(f.avg("Purchase"),2).alias("Average_purchase")).orderBy("Average_purchase", ascending=False).show()

+-------------+----------------+
|City_Category|Average_purchase|
+-------------+----------------+
|            C|         9835.79|
|            B|         9177.59|
|            A|         8924.26|
+-------------+----------------+



In [26]:
# Average purchase per Stay_In_Current_City_Years value, rounded to 2 decimals,
# highest first. The sort key is spelled exactly like the alias
# ("Average_purchase") instead of relying on case-insensitive column resolution.
purchase_train.groupBy("Stay_In_Current_City_Years").agg(f.round(f.avg("Purchase"),2).alias("Average_purchase")).orderBy("Average_purchase", ascending=False).show()

+--------------------------+----------------+
|Stay_In_Current_City_Years|Average_purchase|
+--------------------------+----------------+
|                         2|         9375.14|
|                         3|         9322.38|
|                         1|          9321.7|
|                        4+|         9304.99|
|                         0|          9206.4|
+--------------------------+----------------+



In [27]:
# Average purchase per marital status, rounded to 2 decimals, highest first.
# The sort key is spelled exactly like the alias ("Average_purchase") instead
# of relying on case-insensitive column resolution.
purchase_train.groupBy("Marital_Status").agg(f.round(f.avg("Purchase"),2).alias("Average_purchase")).orderBy("Average_purchase", ascending=False).show()

+--------------+----------------+
|Marital_Status|Average_purchase|
+--------------+----------------+
|             1|         9327.63|
|             0|         9303.58|
+--------------+----------------+



Label encode the following columns:

 ● Age

 ● Gender

 ● Stay_In_Current_City_Years

 ● City_Category

In [28]:
# Feature transformers used for label / one-hot encoding and vector assembly,
# plus the Pipeline class used later to chain them with the model.
# (The lowercase `pipeline` is the pyspark.ml submodule, not the Pipeline class.)
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline

In [29]:
# Label encoding: one StringIndexer per categorical string column, each
# writing a numeric "<name>_indexed" column.
# NOTE(review): handleInvalid="skip" is used here, while the pipeline built at
# the end of the notebook uses "keep" -- confirm which behaviour is intended
# for labels not seen during fit.
SI_Age = StringIndexer(inputCol="Age", outputCol="Age_indexed", handleInvalid="skip")
SI_Gender = StringIndexer(inputCol="Gender", outputCol="Gender_indexed", handleInvalid="skip")
SI_Stay_In_Current_city_Years = StringIndexer(inputCol="Stay_In_Current_City_Years", outputCol="Stay_indexed", handleInvalid="skip")
SI_city_Category = StringIndexer(inputCol="City_Category", outputCol="City_indexed", handleInvalid="skip")

In [30]:
# Fit every indexer on the training data only; the fitted models are reused
# below to transform both the training and the test data.
SI_Age_Obj = SI_Age.fit(purchase_train)
SI_Gender_Obj = SI_Gender.fit(purchase_train)
SI_Stay_In_Current_city_Years_Obj = SI_Stay_In_Current_city_Years.fit(purchase_train)
SI_city_Obj = SI_city_Category.fit(purchase_train)

In [31]:
# Apply the fitted indexers to the training data one after another
# (Age, Gender, Stay_In_Current_City_Years, City_Category).
purchase_train_encoded = purchase_train
for indexer_model in (
    SI_Age_Obj,
    SI_Gender_Obj,
    SI_Stay_In_Current_city_Years_Obj,
    SI_city_Obj,
):
    purchase_train_encoded = indexer_model.transform(purchase_train_encoded)

In [32]:
# Apply the same train-fitted indexers to the test data, in the same order.
purchase_test_encoded = purchase_test
for indexer_model in (
    SI_Age_Obj,
    SI_Gender_Obj,
    SI_Stay_In_Current_city_Years_Obj,
    SI_city_Obj,
):
    purchase_test_encoded = indexer_model.transform(purchase_test_encoded)

In [33]:
# Preview the training data with the four new *_indexed columns appended.
purchase_train_encoded.show()

+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+-----------+--------------+------------+------------+
|User_ID|Product_ID|Gender|  Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|Age_indexed|Gender_indexed|Stay_indexed|City_indexed|
+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+-----------+--------------+------------+------------+
|1000001| P00069042|     F| 0-17|        10|            A|                         2|             0|                 3|                 8|                16|    8370|        6.0|           1.0|         1.0|         2.0|
|1000001| P00248942|     F| 0-17|        10|            A|                         2|             0|                 1| 

One-Hot encode the following columns:

 ● Gender

 ● City_Category

 ● Occupation

In [34]:
# List the columns available after label encoding.
purchase_train_encoded.columns

['User_ID',
 'Product_ID',
 'Gender',
 'Age',
 'Occupation',
 'City_Category',
 'Stay_In_Current_City_Years',
 'Marital_Status',
 'Product_Category_1',
 'Product_Category_2',
 'Product_Category_3',
 'Purchase',
 'Age_indexed',
 'Gender_indexed',
 'Stay_indexed',
 'City_indexed']

In [35]:
# One-hot encoder for Gender, City_Category and Occupation.
# Gender and City use their StringIndexer outputs; Occupation is passed in
# directly as the raw integer column from the CSV.
OHE_train = OneHotEncoder(inputCols = ["Gender_indexed",
                                       "City_indexed",
                                       "Occupation"],
                          outputCols = ["Gender_ohe",
                                        "City_Category_ohe",
                                        "Occupation_ohe"])

In [36]:
# Fit the encoder on the indexed training data; the fitted model is reused for test.
OHE_Obj = OHE_train.fit(purchase_train_encoded)

In [37]:
# Append the one-hot vector columns to the training data and preview the result.
purchase_train_encoded = OHE_Obj.transform(purchase_train_encoded)
purchase_train_encoded.show()

+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+-----------+--------------+------------+------------+-------------+-----------------+---------------+
|User_ID|Product_ID|Gender|  Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|Age_indexed|Gender_indexed|Stay_indexed|City_indexed|   Gender_ohe|City_Category_ohe| Occupation_ohe|
+-------+----------+------+-----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+-----------+--------------+------------+------------+-------------+-----------------+---------------+
|1000001| P00069042|     F| 0-17|        10|            A|                         2|             0|                 3|                 8|                16|    8370|        6.0|           1.0|   

In [38]:
# Append the same one-hot vector columns to the test data using the train-fitted encoder.
purchase_test_encoded = OHE_Obj.transform(purchase_test_encoded)

In [39]:
# List the columns available after one-hot encoding.
purchase_train_encoded.columns

['User_ID',
 'Product_ID',
 'Gender',
 'Age',
 'Occupation',
 'City_Category',
 'Stay_In_Current_City_Years',
 'Marital_Status',
 'Product_Category_1',
 'Product_Category_2',
 'Product_Category_3',
 'Purchase',
 'Age_indexed',
 'Gender_indexed',
 'Stay_indexed',
 'City_indexed',
 'Gender_ohe',
 'City_Category_ohe',
 'Occupation_ohe']

Build a baseline model using any of the ML algorithms.

In [40]:
# Columns that make up the model input: two label-encoded columns, the three
# product category codes, marital status, and the three one-hot vectors.
model_input_columns = [
    'Age_indexed',
    'Stay_indexed',
    'Product_Category_1',
    'Product_Category_2',
    'Product_Category_3',
    'Marital_Status',
    'Gender_ohe',
    'City_Category_ohe',
    'Occupation_ohe',
]

# Combine them into a single vector column named 'features'.
assembler = VectorAssembler(inputCols=model_input_columns, outputCol='features')

In [41]:
# Add the assembled 'features' vector column to both DataFrames.
# NOTE: both names are rebound to their own transformed result, so this cell
# is meant to be run once per kernel session.
purchase_train_encoded = assembler.transform(purchase_train_encoded)
purchase_test_encoded = assembler.transform(purchase_test_encoded)

In [42]:
# Preview the assembled feature vectors.
purchase_train_encoded.select("features").show()

+--------------------+
|            features|
+--------------------+
|(29,[0,1,2,3,4,19...|
|(29,[0,1,2,3,4,19...|
|(29,[0,1,2,3,4,19...|
|(29,[0,1,2,3,4,19...|
|(29,[0,1,2,3,4,6,...|
|(29,[1,2,3,4,6,24...|
|(29,[0,1,2,3,4,5,...|
|(29,[0,1,2,3,4,5,...|
|(29,[0,1,2,3,4,5,...|
|(29,[2,3,4,5,6],[...|
|(29,[2,3,4,5,6],[...|
|(29,[2,3,4,5,6],[...|
|(29,[2,3,4,5,6],[...|
|(29,[2,3,4,5,6],[...|
|(29,[0,2,3,4,18],...|
|(29,[0,2,3,4,18],...|
|(29,[0,2,3,4,18],...|
|(29,[0,2,3,4,18],...|
|(29,[0,2,3,4,5,6,...|
|(29,[1,2,3,4,5,6,...|
+--------------------+
only showing top 20 rows



In [43]:
# Keep a reference to the encoded data from test.csv (which has no Purchase
# column) under its own name; `purchase_test` is rebound to a labelled
# hold-out split in the next cell.
purchase_valid_encoded = purchase_test_encoded

In [44]:
# 80/20 random split of the labelled, encoded training data with a fixed seed.
# NOTE: this rebinds `purchase_train` / `purchase_test`, which until now held
# the DataFrames loaded from train.csv / test.csv.
purchase_train, purchase_test = purchase_train_encoded.randomSplit([0.8,0.2], seed=42)

In [45]:
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator

In [46]:
# Baseline model: gradient-boosted trees regressor on 'features' predicting
# 'Purchase', with all other hyper-parameters left at their defaults.
model_gbt = GBTRegressor(featuresCol="features", labelCol="Purchase")

In [47]:
# Train the baseline on the 80% split.
# NOTE: `model_gbt` is rebound here from the estimator to the fitted model.
model_gbt = model_gbt.fit(purchase_train)

In [48]:
# Range and mean of the label on the hold-out split, to put the RMSE values
# below into context.
purchase_test.selectExpr("min(Purchase)", "max(Purchase)", "avg(Purchase)").show()


+-------------+-------------+-----------------+
|min(Purchase)|max(Purchase)|    avg(Purchase)|
+-------------+-------------+-----------------+
|          193|        23956|9322.331637650479|
+-------------+-------------+-----------------+



In [49]:
# Baseline RMSE on the training split (in-sample error).
predictions_gbt = model_gbt.transform(purchase_train)
evaluator_gbt = RegressionEvaluator(labelCol="Purchase", predictionCol="prediction", metricName="rmse")
print("RMSE:", evaluator_gbt.evaluate(predictions_gbt))

RMSE: 2966.81964112578


In [50]:
# Baseline RMSE on the 20% hold-out split (out-of-sample error).
predictions_gbt = model_gbt.transform(purchase_test)
evaluator_gbt = RegressionEvaluator(labelCol="Purchase", predictionCol="prediction", metricName="rmse")
print("RMSE:", evaluator_gbt.evaluate(predictions_gbt))

RMSE: 3015.021738117338


Model improvement with Grid-Search CV

In [51]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
paramGrid = ParamGridBuilder() \
    .addGrid(model_gbt.maxDepth, [3, 5, 7]) \
    .addGrid(model_gbt.maxIter, [50, 100]) \
    .addGrid(model_gbt.stepSize, [0.05, 0.1]) \
    .build()

In [52]:
model_gbt = GBTRegressor(featuresCol="features", labelCol="Purchase")

In [53]:
# 5-fold cross-validation over paramGrid, scored with the RMSE evaluator
# defined earlier; the seed fixes the fold assignment.
cv = CrossValidator(estimator=model_gbt,
                    estimatorParamMaps=paramGrid,
                    evaluator= evaluator_gbt,
                    numFolds=5,
                    seed=27)

In [54]:
# Run the cross-validated grid search on the training split.
grid_model_gbt = cv.fit(purchase_train)

In [55]:
# RMSE of the cross-validated model on the training split.
predictions_gbt = grid_model_gbt.transform(purchase_train)
evaluator_gbt = RegressionEvaluator(labelCol="Purchase", predictionCol="prediction", metricName="rmse")
print("RMSE:", evaluator_gbt.evaluate(predictions_gbt))

RMSE: 2966.81964112578


In [56]:
# RMSE of the cross-validated model on the 20% hold-out split.
predictions_gbt = grid_model_gbt.transform(purchase_test)
evaluator_gbt = RegressionEvaluator(labelCol="Purchase", predictionCol="prediction", metricName="rmse")
print("RMSE:", evaluator_gbt.evaluate(predictions_gbt))

RMSE: 3015.021738117338


In [58]:
# Model selected by cross-validation. The printed value is the model object
# itself, so the label says "Best model" (not "regParam").
best_model = grid_model_gbt.bestModel
print("Best model:", best_model)


Best regParam: GBTRegressionModel: uid=GBTRegressor_0c0ef9c0ae1b, numTrees=20, numFeatures=29


In [59]:
# Hyper-parameters and size of the selected model.
print("Best maxIter:", best_model.getMaxIter())
print("Best maxDepth:", best_model.getMaxDepth())
print("Best stepSize:", best_model.getStepSize())
print("Num Trees:", best_model.getNumTrees)
print("Num Features:", best_model.numFeatures)

Best maxIter: 20
Best maxDepth: 5
Best stepSize: 0.1
Num Trees: 20
Num Features: 29


In [61]:
# Final estimator, configured with the hyper-parameter values printed for the
# best cross-validated model above.
final_gbt = GBTRegressor(
    featuresCol="features",
    labelCol="Purchase",
    maxIter=20,
    maxDepth=5,
    stepSize=0.1
)

In [62]:
# Train the final estimator on the training split.
final_model = final_gbt.fit(purchase_train)

In [63]:
# Score the final model on the 20% hold-out split.
predictions = final_model.transform(purchase_test)
evaluator = RegressionEvaluator(labelCol="Purchase", predictionCol="prediction", metricName="rmse")
# Evaluate this model's own predictions (`predictions`), not the leftover
# `predictions_gbt` DataFrame produced by the grid-search model.
print("RMSE:", evaluator.evaluate(predictions))
predictions.select("Purchase", "prediction").show(5)


RMSE: 3015.021738117338
+--------+------------------+
|Purchase|        prediction|
+--------+------------------+
|    8370|  10044.4199284901|
|    2763|3059.4513664714837|
|    8839|   6134.2881992845|
|    6187| 8348.744377311461|
|    7980| 7782.594103663812|
+--------+------------------+
only showing top 5 rows



Create a Spark ML Pipeline for the final model.

In [64]:
# 1. Index categorical columns: one StringIndexer per (input, output) pair,
#    all with handleInvalid="keep".
indexed_column_pairs = [
    ('Gender', 'Gender_indexed'),
    ('Age', 'Age_indexed'),
    ('Stay_In_Current_City_Years', 'Stay_indexed'),
    ('City_Category', 'City_indexed'),
]
indexers = [
    StringIndexer(inputCol=source_col, outputCol=indexed_col, handleInvalid="keep")
    for source_col, indexed_col in indexed_column_pairs
]

In [65]:
# 2. One-hot encode the indexed Gender and City columns plus the raw integer
#    Occupation column, with handleInvalid="keep" to match the indexers above.
encoder = OneHotEncoder(
    inputCols=["Gender_indexed", "City_indexed", "Occupation"],
    outputCols=["Gender_ohe", "City_Category_ohe", "Occupation_ohe"],
    handleInvalid="keep"
)


In [67]:
# 3. Feature columns for the model (same set as the step-by-step baseline):
#    label-encoded Age / Stay, the three product category codes, marital
#    status, and the three one-hot vectors.
feature_cols = [
        'Age_indexed',
        'Stay_indexed',
        'Product_Category_1',
        'Product_Category_2',
        'Product_Category_3',
        'Marital_Status',
        'Gender_ohe',
        'City_Category_ohe',
        'Occupation_ohe'
]

# Combine the feature columns into the single 'features' vector column.
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")


In [68]:
# 4. Final GBT model, using the same hyper-parameter values as `final_gbt`.
gbt = GBTRegressor(
    featuresCol="features",
    labelCol="Purchase",
    maxIter=20,
    maxDepth=5,
    stepSize=0.1
)

In [69]:
# 5. Create the pipeline: the four indexers, then the one-hot encoder, the
#    vector assembler and finally the GBT regressor.
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=indexers + [encoder, assembler, gbt])

In [71]:
# Display the pipeline object.
pipeline

Pipeline_74dadec758e5

In [73]:
# Reload the raw CSVs so the pipeline starts from un-encoded data.
reload_options = {"header": True, "inferSchema": True}
purchase_train_ag = spark.read.csv("/content/train.csv", **reload_options)
purchase_test_ag = spark.read.csv("/content/test.csv", **reload_options)

In [74]:
# Apply the same fixed fill values as earlier to the reloaded DataFrames.
category_fill_values = {"Product_Category_2": 8, "Product_Category_3": 16}
purchase_train_ag = purchase_train_ag.na.fill(category_fill_values)
purchase_test_ag = purchase_test_ag.na.fill(category_fill_values)

In [75]:
# Keep the data from test.csv (no Purchase column) under its own name;
# `purchase_test_ag` is rebound to a labelled hold-out split in the next cell.
purchase_valid_ag = purchase_test_ag

In [76]:
# 80/20 random split of the reloaded training data, same weights and seed as
# the earlier split.
purchase_train_ag, purchase_test_ag = purchase_train_ag.randomSplit([0.8,0.2], seed=42)

In [78]:
# 6. Fit every pipeline stage (indexers, encoder, assembler, GBT) on the
#    training split.
pipeline_model = pipeline.fit(purchase_train_ag)

# 7. Predict on the 20% hold-out split and preview actual vs. predicted values.
predictions = pipeline_model.transform(purchase_test_ag)
predictions.select("Purchase", "prediction").show(5)

+--------+------------------+
|Purchase|        prediction|
+--------+------------------+
|    9938| 8008.996156592192|
|    1057|1498.8480191003503|
|    1422|1717.4039972147618|
|    7882| 7645.854874906616|
|   11039| 9420.069335760385|
+--------+------------------+
only showing top 5 rows



In [79]:
# RMSE of the pipeline model on the hold-out split.
evaluator = RegressionEvaluator(labelCol="Purchase", predictionCol="prediction", metricName="rmse")
print("RMSE:", evaluator.evaluate(predictions))

RMSE: 2960.1725880389063


In [80]:
# Shut down the Spark session once the analysis is finished.
spark.stop()