In [20]:
from pyspark.ml.recommendation import ALS

In [1]:
session = SparkSession.builder.master("local").appName("recommandation-system").getOrCreate()

In [5]:
rdd_raw = sc.textFile("data/user-item.txt")
rdd_raw.cache()
rdd_raw.take(5)

['1001,9001,10', '1001,9002,1', '1001,9003,9', '1002,9001,3', '1002,9002,5']

In [14]:
rdd_transf = rdd_raw.map(lambda line: line.split(",")).map(lambda line: (int(line[0]), int(line[1]), float(line[2])))
rdd_transf.take(5)               

[(1001, 9001, 10.0),
 (1001, 9002, 1.0),
 (1001, 9003, 9.0),
 (1002, 9001, 3.0),
 (1002, 9002, 5.0)]

In [16]:
df_base = session.createDataFrame(rdd_transf, ["user", "item", "rating"])
df_base.persist()
df_base.show(5)

+----+----+------+
|user|item|rating|
+----+----+------+
|1001|9001|  10.0|
|1001|9002|   1.0|
|1001|9003|   9.0|
|1002|9001|   3.0|
|1002|9002|   5.0|
+----+----+------+
only showing top 5 rows



In [21]:
als = ALS(rank=10, maxIter=5)
model = als.fit(df_base)

In [24]:
# affinity score 
model.userFactors.orderBy("id").collect()

[Row(id=1001, features=[0.9681084156036377, -0.6955734491348267, 0.7195470929145813, -0.23173299431800842, 0.258536159992218, 0.21683301031589508, -0.3001883029937744, 0.9602264761924744, 0.4306694269180298, 0.176595076918602]),
 Row(id=1002, features=[0.2752208411693573, 0.45148375630378723, 0.007497432176023722, 0.029693691059947014, 0.7466303110122681, 1.5252691507339478, 0.1335616558790207, -0.6253135204315186, -0.16368703544139862, -0.18602345883846283]),
 Row(id=1003, features=[0.31628555059432983, 0.48526066541671753, -0.09361215680837631, 0.6461942791938782, 0.16269101202487946, 1.2503844499588013, 0.22005711495876312, -0.5321128368377686, -0.16695955395698547, 0.0727267861366272]),
 Row(id=1004, features=[0.9316565990447998, -0.7643890976905823, 0.4667227268218994, -0.060695432126522064, 0.544359028339386, 0.6515966057777405, -0.22367969155311584, 0.5234723091125488, 0.09014416486024857, 0.03905295953154564]),
 Row(id=1005, features=[0.7033907771110535, 0.16621631383895874, 0.

In [27]:
df_test = session.createDataFrame([[1001,9003],[1001,9004], [1001,9005]], ["user", "item"])
df_test.show()

+----+----+
|user|item|
+----+----+
|1001|9003|
|1001|9004|
|1001|9005|
+----+----+



In [29]:
predictions = model.transform(df_test)
predictions.collect()

[Row(user=1001, item=9004, prediction=-0.6358490586280823),
 Row(user=1001, item=9005, prediction=-2.2901651859283447),
 Row(user=1001, item=9003, prediction=9.001792907714844)]