In [0]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('rec_system').getOrCreate()

In [0]:
#Importing the libraries
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

In [0]:
#Importing the data
data = spark.read.csv('/FileStore/tables/movielens_ratings.csv',header=True,inferSchema=True)
data.show()

+-------+------+------+
|movieId|rating|userId|
+-------+------+------+
|      2|   3.0|     0|
|      3|   1.0|     0|
|      5|   2.0|     0|
|      9|   4.0|     0|
|     11|   1.0|     0|
|     12|   2.0|     0|
|     15|   1.0|     0|
|     17|   1.0|     0|
|     19|   1.0|     0|
|     21|   1.0|     0|
|     23|   1.0|     0|
|     26|   3.0|     0|
|     27|   1.0|     0|
|     28|   1.0|     0|
|     29|   1.0|     0|
|     30|   1.0|     0|
|     31|   1.0|     0|
|     34|   1.0|     0|
|     37|   1.0|     0|
|     41|   2.0|     0|
+-------+------+------+
only showing top 20 rows



In [0]:
data.describe().show()

+-------+------------------+------------------+------------------+
|summary|           movieId|            rating|            userId|
+-------+------------------+------------------+------------------+
|  count|              1501|              1501|              1501|
|   mean| 49.40572951365756|1.7741505662891406|14.383744170552964|
| stddev|28.937034065088994| 1.187276166124803| 8.591040424293272|
|    min|                 0|               1.0|                 0|
|    max|                99|               5.0|                29|
+-------+------------------+------------------+------------------+



In [0]:
#Train-test split
train_data, test_data = data.randomSplit([0.8,0.2])

In [0]:
#ALS
als = ALS(maxIter=5,regParam=0.01,userCol='userId',itemCol='movieId',ratingCol='rating')
model = als.fit(train_data)

In [0]:
predictions = model.transform(test_data)
predictions.show()
#prediction column shows what our model thinks that userID would have rated that movieID

+-------+------+------+-----------+
|movieId|rating|userId| prediction|
+-------+------+------+-----------+
|      3|   1.0|    28|-0.48001838|
|      6|   3.0|    26|  2.7853274|
|      6|   1.0|     1| 0.94121677|
|      2|   1.0|    16| -1.2471575|
|      5|   1.0|     6|  1.3792421|
|      6|   1.0|     6|  0.6923469|
|      2|   1.0|     3|-0.40758982|
|      2|   2.0|    20| -1.5782562|
|      4|   1.0|    19|  3.8051877|
|      2|   3.0|     9|  0.9232258|
|      4|   1.0|     9|  2.5877974|
|      5|   1.0|     8|  1.6766896|
|      4|   1.0|    23|  1.2574925|
|      7|   1.0|    10|  1.6311859|
|      6|   3.0|    24|  1.4781915|
|      5|   1.0|    29|  1.4025052|
|      6|   2.0|    11|-0.27427202|
|      5|   1.0|    14|  1.5003328|
|      7|   1.0|    14|  1.2753592|
|      6|   1.0|     2|  1.1792974|
+-------+------+------+-----------+
only showing top 20 rows



In [0]:
#Evaluating our predictions
evaluator = RegressionEvaluator(metricName='rmse',labelCol='rating',predictionCol='prediction')
rmse = evaluator.evaluate(predictions)
print(rmse)
#RMSE is pretty bad since our ratings range from 1-5

1.7022453518544653


In [0]:
#Testing on a new user
new_user = test_data.filter(test_data['userId']==11)
new_user.show()

+-------+------+------+
|movieId|rating|userId|
+-------+------+------+
|      6|   2.0|    11|
|     25|   1.0|    11|
|     38|   4.0|    11|
|     40|   1.0|    11|
|     47|   1.0|    11|
|     51|   3.0|    11|
|     62|   1.0|    11|
|     67|   1.0|    11|
|     78|   1.0|    11|
|     80|   3.0|    11|
|     82|   1.0|    11|
|     90|   4.0|    11|
|     94|   2.0|    11|
+-------+------+------+



In [0]:
results = model.transform(new_user)
results.orderBy('prediction',ascending=False).show()

+-------+------+------+-----------+
|movieId|rating|userId| prediction|
+-------+------+------+-----------+
|     90|   4.0|    11|  2.9361863|
|     67|   1.0|    11|  1.2081172|
|     78|   1.0|    11|  0.5075816|
|     38|   4.0|    11|  0.5016695|
|     80|   3.0|    11| -0.2682267|
|      6|   2.0|    11|-0.27427202|
|     82|   1.0|    11|-0.36417758|
|     62|   1.0|    11|-0.51631796|
|     25|   1.0|    11|-0.57077545|
|     51|   3.0|    11| -0.7147925|
|     47|   1.0|    11|  -1.276248|
|     40|   1.0|    11| -1.4126971|
|     94|   2.0|    11| -2.8477051|
+-------+------+------+-----------+

