In [2]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName('recom').getOrCreate()

In [5]:
from pyspark.ml.recommendation import ALS

In [6]:
from pyspark.ml.evaluation import RegressionEvaluator

In [7]:
data = spark.read.csv('movielens_ratings.csv',
                         header = True, inferSchema = True)

In [8]:
data.show()

+-------+------+------+
|movieId|rating|userId|
+-------+------+------+
|      2|   3.0|     0|
|      3|   1.0|     0|
|      5|   2.0|     0|
|      9|   4.0|     0|
|     11|   1.0|     0|
|     12|   2.0|     0|
|     15|   1.0|     0|
|     17|   1.0|     0|
|     19|   1.0|     0|
|     21|   1.0|     0|
|     23|   1.0|     0|
|     26|   3.0|     0|
|     27|   1.0|     0|
|     28|   1.0|     0|
|     29|   1.0|     0|
|     30|   1.0|     0|
|     31|   1.0|     0|
|     34|   1.0|     0|
|     37|   1.0|     0|
|     41|   2.0|     0|
+-------+------+------+
only showing top 20 rows



In [9]:
# we have 30 users and 100 movies
data.describe().show()

+-------+------------------+------------------+------------------+
|summary|           movieId|            rating|            userId|
+-------+------------------+------------------+------------------+
|  count|              1501|              1501|              1501|
|   mean| 49.40572951365756|1.7741505662891406|14.383744170552964|
| stddev|28.937034065088994| 1.187276166124803| 8.591040424293272|
|    min|                 0|               1.0|                 0|
|    max|                99|               5.0|                29|
+-------+------------------+------------------+------------------+



In [10]:
training,testing = data.randomSplit([0.7,0.3])

In [12]:
als = ALS(maxIter=5,regParam=0.01,userCol = 'userId',itemCol = 'movieId', ratingCol = 'rating')

In [13]:
model = als.fit(training)

In [14]:
predictions = model.transform(testing)

In [15]:
predictions.show()

+-------+------+------+----------+
|movieId|rating|userId|prediction|
+-------+------+------+----------+
|      1|   1.0|     6|0.95046306|
|      1|   4.0|    15| 1.2517287|
|      1|   1.0|    19| 0.6231756|
|      1|   1.0|    20|0.36261564|
|      1|   3.0|    25|0.28621715|
|      5|   1.0|     5| 1.8876721|
|      5|   1.0|     6| 1.2547414|
|      5|   2.0|    22| 1.6307329|
|      5|   1.0|    29| 0.6946439|
|      4|   3.0|    10| 1.5164242|
|      4|   1.0|    14|  2.744133|
|      4|   2.0|    20| 1.7715421|
|      4|   1.0|    29|-1.0487461|
|      2|   2.0|     1|  2.988127|
|      2|   3.0|     9|  1.602682|
|      2|   1.0|    12| 1.1116943|
|      2|   1.0|    19| 1.2217522|
|      2|   1.0|    26| 2.2951252|
|      2|   4.0|    28| 6.4618015|
|      0|   3.0|    10|  1.273894|
+-------+------+------+----------+
only showing top 20 rows



In [18]:
evaluator = RegressionEvaluator(metricName = 'rmse',labelCol = 'rating',
                                 predictionCol = 'prediction')

In [19]:
#see overall how far our predictions are from the real labels
rmse = evaluator.evaluate(predictions)

In [20]:
rmse

2.152104952412576

In [26]:
#simulate a fresh user scenario
single_user = data.filter(data['userId'] == 1).select(['movieId','userId'])

In [27]:
single_user.show()

+-------+------+
|movieId|userId|
+-------+------+
|      2|     1|
|      3|     1|
|      4|     1|
|      6|     1|
|      9|     1|
|     12|     1|
|     13|     1|
|     14|     1|
|     16|     1|
|     19|     1|
|     21|     1|
|     27|     1|
|     28|     1|
|     33|     1|
|     36|     1|
|     37|     1|
|     40|     1|
|     41|     1|
|     43|     1|
|     44|     1|
+-------+------+
only showing top 20 rows



In [28]:
recommendations = model.transform(single_user)

In [29]:
recommendations.orderBy('prediction',ascending = False).show()

+-------+------+----------+
|movieId|userId|prediction|
+-------+------+----------+
|     62|     1| 4.0510526|
|     68|     1| 3.7660034|
|     60|     1| 3.2947803|
|      2|     1|  2.988127|
|      9|     1| 2.8404396|
|     77|     1| 2.6793573|
|     27|     1| 2.5524368|
|     37|     1| 2.3190424|
|     74|     1| 2.3162448|
|     70|     1|   2.11627|
|      4|     1| 2.0868866|
|     94|     1|  2.022429|
|     41|     1| 1.9754239|
|     36|     1| 1.9743502|
|     88|     1| 1.7614537|
|     56|     1| 1.7551526|
|     85|     1| 1.6537267|
|     19|     1| 1.4702615|
|     12|     1| 1.4281734|
|      3|     1| 1.2837512|
+-------+------+----------+
only showing top 20 rows

