In [1]:
import os
os.environ['JAVA_HOME'] = '/Library/Java/JavaVirtualMachines/jdk-15.0.1.jdk/Contents/Home'

from pyspark import SparkContext
from pyspark.sql import *

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS, ALSModel

In [2]:
spark = SparkSession.builder.config("spark.driver.memory", "4g").getOrCreate()
sc = spark.sparkContext

In [3]:
train = spark.read.load('../Data/appliances_train.parquet')
test = spark.read.load('../Data/appliances_test.parquet')
val = spark.read.load('../Data/appliances_val.parquet')

In [4]:
train.printSchema()
test.printSchema()
val.printSchema()

root
 |-- item_id: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- rating: integer (nullable = true)
 |-- user_int_id: integer (nullable = true)
 |-- item_int_id: integer (nullable = true)

root
 |-- item_id: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- rating: integer (nullable = true)
 |-- user_int_id: integer (nullable = true)
 |-- item_int_id: integer (nullable = true)

root
 |-- item_id: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- rating: integer (nullable = true)
 |-- user_int_id: integer (nullable = true)
 |-- item_int_id: integer (nullable = true)



In [5]:
train.count()

422281

In [6]:
test.count()

120129

In [7]:
val.count()

60367

In [8]:
train.show()

+----------+--------------+------+-----------+-----------+
|   item_id|       user_id|rating|user_int_id|item_int_id|
+----------+--------------+------+-----------+-----------+
|1118461304|A1A7PGN2HLMLOW|     5|      97723|       2229|
|1118461304|A1O690F0T9XR4Z|     4|     143793|       2229|
|1118461304|A1XMYM1OVUNCXX|     5|     175068|       2229|
|1118461304|A21CQSIOV4NG7Y|     5|     187263|       2229|
|1118461304|A2F8QSA7BCK1N6|     5|     233395|       2229|
|1118461304| APP8XWYYV4PAA|     4|     481193|       2229|
|1118461304| AWSLGG21FW7IU|     4|     504891|       2229|
|1906487049|A1O1V5DBXB6DMP|     5|      22256|      12544|
|1906487049|A2BE36GLSNCEOF|     4|     220583|      12544|
|1906487049| AT7W8XDH1EGGL|     5|     493032|      12544|
|7301113188|A24HQ894NFSTF5|     5|     197636|      19551|
|7861850250|A3B0UA9I9CEVBT|     3|     338816|      19552|
|8792559360|A283UCG4U3AUOM|     5|     209745|       7985|
|8792559360|A3VD9JLBEITZFF|     5|     406095|       798

In [9]:
test.show()

+----------+--------------+------+-----------+-----------+
|   item_id|       user_id|rating|user_int_id|item_int_id|
+----------+--------------+------+-----------+-----------+
|1118461304|A21MGW4YYUZ1YW|     5|     188083|       2229|
|1118461304|A3C2VXXYG6ZB7J|     5|     342231|       2229|
|6040985461|A3KBUSJHZO3P6A|     5|     369322|      19550|
|8792559360|A1A7HV01DD1YKW|     5|      97701|       7985|
|8792559360| A9EZ3SN4HHPTP|     5|     427105|       7985|
|B00002N7HY|A2LNP1T89OK7ZS|     5|     254625|       7408|
|B00002N7HY|A3TIWHNJXMSIU7|     5|       3455|       7408|
|B00002N7HY| ATCY52ZP70FKP|     5|      12888|       7408|
|B00002N7HY| AXE83MK90ZEVZ|     4|     506890|       7408|
|B00002N7HY| AY41JC7MWFY1R|     5|      13193|       7408|
|B00002N9OE|A3M66IF7LFUS27|     5|     375364|      19554|
|B00004SQHD|A2N78TLERYNKC8|     5|     259824|       2942|
|B00004SQHD|A3G13F4UZ1P53N|     5|     355346|       2942|
|B00004SQHD| AMH5CUDJI208M|     5|     470601|       294

In [10]:
val.show()

+----------+--------------+------+-----------+-----------+
|   item_id|       user_id|rating|user_int_id|item_int_id|
+----------+--------------+------+-----------+-----------+
|1118461304|A34T20MDA599MA|     5|     317971|       2229|
|8792559360| A74TBP9UDF62J|     5|     419450|       7985|
|8792559360| ALMBAY4H712C7|     5|     467633|       7985|
|9792954481|A12UD9093P4A14|     5|        427|      15065|
|B00002N5EL|A20JBUWRGLX9QK|     5|     184632|      19553|
|B00002N7IL| A5OV35IHA1I9D|     5|      52625|       6080|
|B00002NARC|A13O7AXYRWR7SE|     5|      76041|       9623|
|B00004SQHD|A2BNNNIE7HZDZP|     5|     221503|       2942|
|B00004YWK7| A8RIQ8GS7FRRJ|     2|     424917|       4182|
|B000056J8D|A2D39KOZE8VXBE|     4|     226296|       2164|
|B000056J8D|A30JQPUENNPAIX|     5|     303789|       2164|
|B000056J8D| AFJGWMZTV64RF|     5|     447416|       2164|
|B00005OU6T|A20KLKLE3FNCIK|     5|     184734|        425|
|B00006IV17|A1EYNVHX20ESW4|     5|     113477|         8

In [11]:
ranks = [50, 100]
regParams = [0.01, 0.1, 1]
min_rmse = float('inf')
best_rank = -1
best_regParam = -1

In [6]:
evaluator = RegressionEvaluator(metricName='rmse', labelCol="rating", predictionCol="prediction")

In [13]:
for rank in ranks:
    for regParam in regParams:
        als = ALS(maxIter=10, rank=rank, regParam=regParam, userCol='user_int_id', itemCol='item_int_id',
                  ratingCol='rating', coldStartStrategy='drop')
        model = als.fit(train)
        
        predictions = model.transform(val)
        rmse = evaluator.evaluate(predictions)
        print('Rank = {}\tRegularization = {}\tRMSE = {}'.format(rank, regParam, rmse))
        
        if rmse < min_rmse:
            min_rmse = rmse
            best_rank = rank
            best_regParam = regParam
            model.write().overwrite().save('../Models/appliances_model')
        
        del model

Rank = 50	Regularization = 0.01	RMSE = 3.9331106807395075
Rank = 50	Regularization = 0.1	RMSE = 3.693289723237236
Rank = 50	Regularization = 1	RMSE = 3.772034715387931
Rank = 100	Regularization = 0.01	RMSE = 3.8470659758532477
Rank = 100	Regularization = 0.1	RMSE = 3.68719655744318
Rank = 100	Regularization = 1	RMSE = 3.7694189252855437


In [14]:
print('Optimal rank = {}'.format(best_rank))
print('Optimal regularization = {}'.format(best_regParam))

Optimal rank = 100
Optimal regularization = 0.1


In [15]:
best_model = ALSModel.load('../Models/appliances_model')

In [16]:
test_predictions = best_model.transform(test)
test_rmse = evaluator.evaluate(test_predictions)
print('Test RMSE = {}'.format(test_rmse))

Test RMSE = 3.6842476137934237


In [17]:
test_predictions.show()

+----------+--------------+------+-----------+-----------+-----------+
|   item_id|       user_id|rating|user_int_id|item_int_id| prediction|
+----------+--------------+------+-----------+-----------+-----------+
|B00O7U6I0A| AZEXVY7HN4CBE|     2|      63493|        148| 0.15703431|
|B00O7U6I0A|A1XYY3JY741SP4|     4|       6416|        148| -0.5187491|
|B00O7U6I0A|A1EM4YL5JCXH2Z|     3|      18751|        148| 0.08858597|
|B00O7U6I0A| A96JD9312DHWC|     5|      11511|        148|-0.07179089|
|B00O7U6I0A| AAPSTM0B83J62|     5|      54408|        148|   0.369241|
|B00O7U6I0A|A1BVZP7MGB836D|     3|      17726|        148|  0.6030436|
|B00O7U6I0A|A329QXIE6FCSNZ|     2|      40951|        148|-0.13411051|
|B00O7U6I0A|A3BOUVNTIGAIYT|     5|      44455|        148|  1.2567904|
|B00O7U6I0A|A27K5G8B9Q4Y7Y|     5|      29509|        148|-0.23368415|
|B00O7U6I0A| AKWTL66CUUG2S|     1|      58078|        148| 0.70660526|
|B00O7U6I0A|A339GQCZI5WP06|     1|      41343|        148| -0.6321523|
|B005B

In [4]:
als = ALS(maxIter=25, rank=100, regParam=0.1, userCol='user_int_id', itemCol='item_int_id',
          ratingCol='rating', coldStartStrategy='drop')
model = als.fit(train)

In [9]:
model.write().overwrite().save('../Models/appliances_model')

In [7]:
test_predictions = model.transform(test)
test_rmse = evaluator.evaluate(test_predictions)
print('Test RMSE = {}'.format(test_rmse))

Test RMSE = 2.850876953527476


In [8]:
test_predictions.show()

+----------+--------------+------+-----------+-----------+-----------+
|   item_id|       user_id|rating|user_int_id|item_int_id| prediction|
+----------+--------------+------+-----------+-----------+-----------+
|B00O7U6I0A| AZEXVY7HN4CBE|     2|      63493|        148| 0.95750785|
|B00O7U6I0A|A1XYY3JY741SP4|     4|       6416|        148|  0.6383388|
|B00O7U6I0A|A1EM4YL5JCXH2Z|     3|      18751|        148| 0.93844664|
|B00O7U6I0A| A96JD9312DHWC|     5|      11511|        148| 0.44943675|
|B00O7U6I0A| AAPSTM0B83J62|     5|      54408|        148|  1.3701524|
|B00O7U6I0A|A1BVZP7MGB836D|     3|      17726|        148|  1.8148531|
|B00O7U6I0A|A329QXIE6FCSNZ|     2|      40951|        148| 0.12827846|
|B00O7U6I0A|A3BOUVNTIGAIYT|     5|      44455|        148|  3.3528082|
|B00O7U6I0A|A27K5G8B9Q4Y7Y|     5|      29509|        148|-0.45448148|
|B00O7U6I0A| AKWTL66CUUG2S|     1|      58078|        148|  1.6126176|
|B00O7U6I0A|A339GQCZI5WP06|     1|      41343|        148|  1.6714823|
|B005B

In [10]:
sc.stop()