In [0]:
import os

# Location of the course dataset in Azure Blob Storage.
STORAGE_ACCOUNT = "storagestudent"
CONTAINER = "default"
REVIEWS_PATH = "datasets/S8-5/Cours/reviewstar.csv"

# Never hard-code the storage account key: anyone holding the notebook (or an
# export of it) would get full access to the account. Read it from the
# environment instead; on Databricks, a secret scope (dbutils.secrets.get) is
# the preferred source.
# NOTE(review): the key that was previously committed here is exposed and
# should be rotated.
storage_account_key = os.environ.get("STORAGESTUDENT_ACCOUNT_KEY")
if not storage_account_key:
  raise RuntimeError(
    "Set the STORAGESTUDENT_ACCOUNT_KEY environment variable to the Azure "
    "storage account key before running this notebook."
  )

spark.conf.set(
  f"fs.azure.account.key.{STORAGE_ACCOUNT}.blob.core.windows.net",
  storage_account_key
)

# Raw CSV with a header row; no schema inference is requested, so the columns
# are loaded as strings and cast to integers further down.
reviews = spark.read.load(
  f"wasbs://{CONTAINER}@{STORAGE_ACCOUNT}.blob.core.windows.net/{REVIEWS_PATH}",
  format="csv",
  header="true"
)

In [0]:
# Preview the raw reviews DataFrame exactly as loaded from the CSV (before any type casting).
display(reviews)

user_id,categories,reviews,stars
1,4,1,2
2,4,1,4
3,4,1,4
4,4,1,3
5,4,2,3
6,4,1,4
7,4,1,4
8,4,2,3
9,4,1,4
10,1,1,4


In [0]:
from pyspark.sql.types import IntegerType
from pyspark.sql import functions as F

# ALS needs numeric user ids, item ids and ratings, but the CSV was read as
# strings. Keep only the three columns the recommender uses (the `reviews`
# column is dropped here).
reviews = (
  reviews
  .select(
    F.col("user_id").cast(IntegerType()),
    F.col("categories").cast(IntegerType()),
    F.col("stars").cast(IntegerType())
  )
  # A failed cast produces null, and ALS cannot train on null ids or ratings.
  # Drop such rows explicitly here rather than failing later inside fit().
  .dropna(subset=["user_id", "categories", "stars"])
)

In [0]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Fixed seed so the train/test split, and therefore the reported RMSE, is
# reproducible across Restart & Run All.
SEED = 42

train, test = reviews.randomSplit([0.7, 0.3], seed=SEED)

# Users are the ALS "users", business categories play the role of "items", and
# star ratings are the explicit feedback. coldStartStrategy="drop" removes
# prediction rows that come out NaN (user/category unseen during training), so
# the RMSE below is not NaN.
als = ALS(userCol="user_id", itemCol="categories", ratingCol="stars", coldStartStrategy="drop")

evaluator = RegressionEvaluator(metricName="rmse", labelCol="stars", predictionCol="prediction")

# Small search grid: latent-factor rank x number of ALS iterations.
paramGrid = (
  ParamGridBuilder()
  .addGrid(als.rank, [1, 5])
  .addGrid(als.maxIter, [2, 10])
  .build()
)

In [0]:
# 5-fold cross-validation of the ALS estimator over every combination in
# paramGrid, scored by the RMSE evaluator. The fitted CrossValidatorModel
# exposes the winning ALS model as model.bestModel.
cv = CrossValidator(estimatorParamMaps=paramGrid,
                    estimator=als,
                    evaluator=evaluator,
                    numFolds=5)

model = cv.fit(train)

In [0]:
# Score the held-out split with the best cross-validated model, then report its
# RMSE (the cell's displayed output).
predictions = model.transform(dataset=test)
evaluator.evaluate(dataset=predictions)

In [0]:
# Inspect actual star ratings next to the raw (continuous) ALS predictions on the test split.
display(predictions)

user_id,categories,stars,prediction
1829,1,4,4.2880898
11033,1,5,4.8418674
11458,1,5,4.8505445
17389,1,5,4.8441277
40011,1,2,3.0060859
42834,1,5,4.8739147
46521,1,5,4.844239
71510,1,5,4.827536
78120,1,5,4.8538437
91785,1,5,4.869077


In [0]:
# Convert the continuous ALS scores into whole-star ratings.
# Clamp instead of abs(): abs() would mirror a negative score into a positive
# "rating" and would let scores above the top of the scale through (e.g. 6.0).
# The bounds are taken from the star values actually present in the data.
star_bounds = reviews.agg(
  F.min("stars").alias("lo"),
  F.max("stars").alias("hi")
).first()

predictions = predictions.withColumn(
  "prediction",
  F.least(
    F.greatest(F.round(F.col("prediction"), 0), F.lit(star_bounds["lo"])),
    F.lit(star_bounds["hi"])
  )
)
display(predictions)

user_id,categories,stars,prediction
33868,1,5,5.0
35820,1,4,2.0
38422,1,4,4.0
40011,1,2,3.0
44822,1,5,5.0
78120,1,5,5.0
34103,1,5,5.0
69279,1,5,4.0
74452,1,5,5.0
80424,1,5,5.0


In [0]:
# Top 10 categories for every user, scored by the best cross-validated ALS
# model. Each row holds a user_id plus an array of (category id, predicted
# rating) pairs, highest score first.
userRecommendations = (
  model.bestModel
  .recommendForAllUsers(numItems=10)
)
display(userRecommendations)

user_id,recommendations
463,"List(List(4, 2.938305), List(2, 2.2288127), List(10, 1.9952986), List(11, 1.832077), List(9, 1.828679), List(6, 1.7771266), List(5, 1.7277806), List(3, 1.6683965), List(1, 1.6079025), List(8, 1.388194))"
471,"List(List(5, 4.339985), List(1, 4.261614), List(9, 4.154074), List(10, 3.9263768), List(2, 3.9043157), List(4, 3.8578823), List(6, 3.8521414), List(11, 3.8451447), List(3, 3.7008696), List(7, 2.3474438))"
496,"List(List(4, 1.9588698), List(2, 1.485875), List(10, 1.3301991), List(11, 1.2213846), List(9, 1.2191193), List(6, 1.184751), List(5, 1.1518538), List(3, 1.1122643), List(1, 1.071935), List(8, 0.9254626))"
1088,"List(List(4, 3.9177396), List(2, 2.97175), List(10, 2.6603982), List(11, 2.4427693), List(9, 2.4382386), List(6, 2.369502), List(5, 2.3037076), List(3, 2.2245286), List(1, 2.14387), List(8, 1.8509252))"
1238,"List(List(4, 3.9177396), List(2, 2.97175), List(10, 2.6603982), List(11, 2.4427693), List(9, 2.4382386), List(6, 2.369502), List(5, 2.3037076), List(3, 2.2245286), List(1, 2.14387), List(8, 1.8509252))"
1591,"List(List(10, 3.9041405), List(9, 3.3534553), List(5, 3.263825), List(1, 3.21611), List(6, 3.1158226), List(4, 3.1002183), List(2, 3.052342), List(11, 2.7871292), List(3, 2.3927944), List(8, 1.5305529))"
1645,"List(List(4, 3.9177396), List(2, 2.97175), List(10, 2.6603982), List(11, 2.4427693), List(9, 2.4382386), List(6, 2.369502), List(5, 2.3037076), List(3, 2.2245286), List(1, 2.14387), List(8, 1.8509252))"
1829,"List(List(5, 4.3874836), List(1, 4.350979), List(9, 4.148459), List(10, 3.8256946), List(2, 3.7175894), List(11, 3.6900787), List(6, 3.6598215), List(3, 3.5636132), List(4, 3.1714795), List(8, 2.0053735))"
1959,"List(List(5, 4.8233595), List(1, 4.8030896), List(9, 4.6172543), List(2, 3.9350646), List(11, 3.874587), List(3, 3.8622313), List(6, 3.8557975), List(10, 3.8538835), List(4, 3.1326363), List(8, 1.950526))"
2122,"List(List(4, 3.9177396), List(2, 2.97175), List(10, 2.6603982), List(11, 2.4427693), List(9, 2.4382386), List(6, 2.369502), List(5, 2.3037076), List(3, 2.2245286), List(1, 2.14387), List(8, 1.8509252))"


In [0]:
# Top 10 users for every category, scored by the best cross-validated ALS
# model. Each row holds a category id plus an array of (user_id, predicted
# rating) pairs, highest score first.
itemRecommendations = (
  model.bestModel
  .recommendForAllItems(numUsers=10)
)
display(itemRecommendations)

categories,recommendations
1,"List(List(76678, 5.751957), List(17457, 5.6462007), List(10068, 5.4599476), List(42109, 5.369413), List(92562, 5.3301334), List(80329, 5.3301334), List(66540, 5.327018), List(8316, 5.281198), List(44499, 5.26343), List(32554, 5.161111))"
6,"List(List(76678, 5.1636753), List(82350, 4.8606243), List(32820, 4.8606243), List(23130, 4.8606243), List(46090, 4.8606243), List(48180, 4.8606243), List(24250, 4.8606243), List(13050, 4.8606243), List(28370, 4.8606243), List(23470, 4.8606243))"
3,"List(List(76678, 5.879947), List(47926, 5.3801694), List(26729, 5.3801694), List(10068, 5.164803), List(7178, 5.0066547), List(61064, 5.0066547), List(34249, 5.0066547), List(42515, 5.0066547), List(15137, 4.9638357), List(94424, 4.9266825))"
5,"List(List(76678, 5.9088984), List(17457, 5.57895), List(10068, 5.5401077), List(66540, 5.4249816), List(80329, 5.4115868), List(92562, 5.4115868), List(42109, 5.3891354), List(8316, 5.353559), List(44499, 5.3029366), List(5202, 5.2194448))"
9,"List(List(66540, 5.1925926), List(68213, 5.110622), List(17457, 5.0193863), List(32264, 4.9496946), List(4409, 4.9496946), List(15804, 4.9496946), List(622, 4.913912), List(46425, 4.913912), List(92562, 4.9063344), List(80329, 4.9063344))"
4,"List(List(108452, 5.7907925), List(112354, 5.5196095), List(107987, 5.4520416), List(112995, 5.3496933), List(76678, 5.3259964), List(111827, 5.1556187), List(108442, 5.1011705), List(113448, 5.0758257), List(108644, 5.064835), List(110919, 5.064835))"
8,"List(List(42515, 4.7858086), List(61064, 4.7858086), List(7178, 4.7858086), List(34249, 4.7858086), List(46380, 4.5853376), List(111460, 4.4211135), List(39347, 4.295005), List(76678, 4.282408), List(6884, 4.1994123), List(47926, 4.188794))"
7,"List(List(7178, 5.440984), List(42515, 5.440984), List(61064, 5.440984), List(34249, 5.440984), List(111460, 5.306477), List(46380, 5.2218757), List(8030, 4.839431), List(37983, 4.839431), List(31164, 4.839431), List(46533, 4.839431))"
10,"List(List(572, 4.9144564), List(6333, 4.9139204), List(66540, 4.907971), List(12370, 4.8801756), List(8530, 4.8801756), List(18130, 4.8801756), List(6990, 4.8801756), List(19490, 4.8801756), List(70, 4.8801756), List(15430, 4.8801756))"
11,"List(List(76678, 5.9711223), List(47926, 5.4877176), List(26729, 5.4877176), List(7178, 5.2703156), List(34249, 5.2703156), List(42515, 5.2703156), List(61064, 5.2703156), List(10068, 5.150781), List(46380, 5.0288277), List(25559, 5.0272627))"
