In [75]:
from pyspark.sql import SparkSession

# Obtain the notebook's Spark session (getOrCreate reuses an already-running one).
spark = (
    SparkSession.builder
    .appName("MovieLens_Recommendations")
    .getOrCreate()
)

In [76]:
# Load the ratings CSV with an explicit schema, discard the timestamp column,
# and rename the id columns to "user" / "item".
ratings = (
    spark.read.csv(
        "ratings.csv",
        sep=",",
        header=True,
        quote='"',
        schema="userId INT, movieId INT, rating DOUBLE, timestamp INT",
    )
    # Drop by column name so this cell does not depend on
    # `pyspark.sql.functions.col`, which the cells above do not import
    # (using `col(...)` here fails with NameError on a fresh kernel).
    .drop("timestamp")
    .withColumnRenamed("userId", "user")
    .withColumnRenamed("movieId", "item")
)

In [77]:
# Preview the first 20 rows (the defaults, written out), then report the
# total number of ratings as the cell's output value.
ratings.show(n=20, truncate=True)
ratings.count()

+----+----+------+
|user|item|rating|
+----+----+------+
|   1|   1|   4.0|
|   1|   3|   4.0|
|   1|   6|   4.0|
|   1|  47|   5.0|
|   1|  50|   5.0|
|   1|  70|   3.0|
|   1| 101|   5.0|
|   1| 110|   4.0|
|   1| 151|   5.0|
|   1| 157|   5.0|
|   1| 163|   5.0|
|   1| 216|   5.0|
|   1| 223|   3.0|
|   1| 231|   5.0|
|   1| 235|   4.0|
|   1| 260|   5.0|
|   1| 296|   3.0|
|   1| 316|   3.0|
|   1| 333|   5.0|
|   1| 349|   4.0|
+----+----+------+
only showing top 20 rows



100836

In [28]:
from pyspark.ml.recommendation import ALS

In [81]:
# ALS hyperparameters: number of latent factors, iteration count, random seed.
rank, max_iter, seed = 10, 25, 0

# The column names are spelled out to match the `ratings` frame
# ("user", "item", "rating"); these are also the ALS defaults.
als = ALS(
    rank=rank,
    maxIter=max_iter,
    seed=seed,
    userCol="user",
    itemCol="item",
    ratingCol="rating",
)
model = als.fit(ratings)

In [82]:
# Show the learned per-user factor vectors for the first 30 ids, untruncated
# so each full `features` array is visible.
model.userFactors.orderBy("id").show(n=30, truncate=False)

+---+----------------------------------------------------------------------------------------------------------------------------+
|id |features                                                                                                                    |
+---+----------------------------------------------------------------------------------------------------------------------------+
|1  |[-0.29646167, -1.0991681, 0.43098325, 0.11375476, 1.1953692, 0.15712161, 1.2840616, 0.75633204, 0.6297434, 0.1069835]       |
|2  |[-0.49296665, -0.28051114, 0.33050418, -0.41401663, 1.1505803, 0.4131445, 1.1909525, 0.85643035, 0.15713352, 0.035587862]   |
|3  |[-0.57693475, -0.45748603, 0.20840773, -1.19696, 0.43931463, 1.1464447, -0.21777278, 0.56207395, 0.40704304, -1.1016352]    |
|4  |[0.23605551, -1.6480961, -0.84258175, 0.31882033, 0.58057207, 0.2083654, 0.5769176, -0.056165535, 1.0585331, 0.17286816]    |
|5  |[-0.1288039, -1.0679277, -0.53117055, -0.120779835, 0.781204, -0.31552023, 0.7

In [83]:
# A handful of (user, item) pairs to score with the fitted model.
simple_test_data = [(1, 2), (10, 3), (10, 8), (20, 4), (9, 7), (550, 235)]

# Column names match the ones the model was trained on.
test = spark.createDataFrame(simple_test_data, schema="user INT, item INT")

In [84]:
# Score each (user, item) pair; the result gains a `prediction` column.
test_prediction = model.transform(dataset=test)

In [86]:
# Display the scored pairs. Rows are not sorted here, so they may appear in a
# different order than `simple_test_data` lists them.
test_prediction.show(n=20, truncate=True)

+----+----+----------+
|user|item|prediction|
+----+----+----------+
|  10|   3|  1.957693|
| 550| 235| 3.4253366|
|  20|   4| 2.3973246|
|  10|   8| 1.5874846|
|   9|   7| 2.4595492|
|   1|   2| 3.9841352|
+----+----+----------+



In [70]:
# Collect the scored pairs into a local pandas DataFrame.
# NOTE(review): this is a one-time copy of `test_prediction`. Re-run this cell
# after refitting the model or rebuilding `test_prediction`, otherwise the
# pandas frame keeps values from the earlier run.
predictions_in_pandas = test_prediction.toPandas()

In [87]:
# Leave the frame as the cell's last expression so Jupyter renders it with its
# rich (HTML table) representation instead of plain printed text.
predictions_in_pandas

   user  item  prediction
0    10     3    2.246040
1   550   235    3.208731
2    20     4    2.438197
3    10     8    1.380121
4     9     7    2.920148
5     1     2    4.126770
