In [1]:
pip install pyspark

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [12]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql import Row

In [13]:
spark = SparkSession.builder.appName("Recommender").getOrCreate()

In [14]:

data = spark.read.csv('book_ratings.csv', 
                      inferSchema=True,header=True) 
  
data.show(5)

+-------+-------+------+
|book_id|user_id|rating|
+-------+-------+------+
|      1|    314|     5|
|      1|    439|     3|
|      1|    588|     5|
|      1|   1169|     4|
|      1|   1185|     4|
+-------+-------+------+
only showing top 5 rows



In [15]:
data.describe().show()

+-------+-----------------+------------------+------------------+
|summary|          book_id|           user_id|            rating|
+-------+-----------------+------------------+------------------+
|  count|           981756|            981756|            981756|
|   mean|4943.275635697668|25616.759933221696|3.8565335989797873|
| stddev|2873.207414896143|15228.338825882149|0.9839408559619973|
|    min|                1|                 1|                 1|
|    max|            10000|             53424|                 5|
+-------+-----------------+------------------+------------------+



In [16]:
train_data, test_data = data.randomSplit([0.8, 0.2])

In [17]:
als = ALS(maxIter=5, 
          regParam=0.01, 
          userCol="user_id", 
          itemCol="book_id", 
          ratingCol="rating") 
  

model = als.fit(train_data)

In [18]:
predictions = model.transform(test_data) 

In [19]:
predictions.show()

+-------+-------+------+----------+
|book_id|user_id|rating|prediction|
+-------+-------+------+----------+
|      1|  12471|     5| 4.3827224|
|      2|  15604|     4|  4.111163|
|      2|  17984|     5| 4.4819336|
|      2|  17566|     4| 4.0805535|
|      1|  10140|     4| 3.4344268|
|      1|  43985|     4|  3.960016|
|      1|  33697|     4| 4.3076034|
|      2|  10751|     3| 3.5920756|
|      1|  21487|     4|  4.236959|
|      2|  17643|     1| 3.3915699|
|      1|  13282|     5| 4.5204687|
|      1|   1169|     4| 3.7499666|
|      2|   1169|     3| 3.4878473|
|      1|  29123|     3| 2.5495992|
|      1|  45493|     5|  4.790303|
|      2|  10509|     2| 4.8685927|
|      2|  11691|     4| 3.9240675|
|      1|   7563|     3| 4.0731373|
|      1|  22602|     4| 3.8647537|
|      1|  47746|     5|  5.274465|
+-------+-------+------+----------+
only showing top 20 rows



In [20]:
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",predictionCol="prediction") 
rmse = evaluator.evaluate(predictions) 
print("Root-mean-square error = " + str(rmse))

Root-mean-square error = nan


In [21]:
user1 = test_data.filter(test_data['user_id']==5461).select(['book_id','user_id']) 
  
 
user1.show()

+-------+-------+
|book_id|user_id|
+-------+-------+
|     10|   5461|
|     15|   5461|
|     38|   5461|
|     43|   5461|
|     46|   5461|
|     47|   5461|
|     55|   5461|
|    117|   5461|
|    121|   5461|
|    130|   5461|
|    142|   5461|
|    180|   5461|
|    186|   5461|
|    198|   5461|
|    273|   5461|
|    293|   5461|
|    354|   5461|
|    357|   5461|
|    361|   5461|
|    375|   5461|
+-------+-------+
only showing top 20 rows



In [22]:
recommendations = model.transform(user1) 
   
recommendations.orderBy('prediction',ascending=False).show()

+-------+-------+----------+
|book_id|user_id|prediction|
+-------+-------+----------+
|    561|   5461| 4.9864507|
|    180|   5461| 4.9695415|
|    669|   5461| 4.8601875|
|     46|   5461| 4.8569736|
|    117|   5461|  4.710736|
|     43|   5461|  4.703219|
|   1161|   5461| 4.6378455|
|    357|   5461| 4.6254034|
|     15|   5461| 4.5995684|
|    998|   5461| 4.5749373|
|     55|   5461|   4.56015|
|     10|   5461| 4.5342503|
|   1202|   5461|  4.523735|
|     47|   5461| 4.4866805|
|    293|   5461| 4.4724894|
|    588|   5461| 4.4615493|
|    798|   5461| 4.4566646|
|    130|   5461|  4.345561|
|    198|   5461|  4.288522|
|    375|   5461| 4.2848024|
+-------+-------+----------+
only showing top 20 rows



In [23]:
spark.stop()