In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


from pyspark import SparkContext

from pyspark.sql import Row
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS, ALSModel

In [3]:
%env MASTER=local[4]
#SparkContext.setSystemProperty('spark.executor.memory', '8g')
sc = SparkContext("local", "spark session ratings")
spark = (SparkSession.builder
            .master("local")
            .appName("spark session ratings")
            .enableHiveSupport()
            .getOrCreate()
        )

env: MASTER=local[4]


In [4]:
spark.sparkContext.setCheckpointDir('~/checkpoints')

In [9]:
!cp /data/movie_dataset/ratings.csv .

In [5]:
lines = (sc.textFile("ratings.csv")
         .filter(lambda s: not s.startswith("userId")) # to ignore header
        )
parts = lines.map(lambda row: row.split(","))
ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), 
                                     movieId=int(p[1]),
                                     rating=float(p[2])
                                    )
                      )#.filter(lambda r: r.movieId in top_movies)
ratings = spark.createDataFrame(ratingsRDD)
#ratings.cache()

In [6]:
top_movies = ratings.groupby("movieId").count()
top_movies = top_movies.sort(F.col("count").desc()).select("movieId").limit(9000)

In [7]:
ratings = ratings.filter(ratings.movieId.isin(top_movies.movieId))

In [8]:
(training, test) = ratings.randomSplit([0.8, 0.2], 42)

# Null model as reference

In [14]:
evaluator = RegressionEvaluator(metricName="r2", labelCol="rating",
                                predictionCol="prediction")

In [17]:
avg_votes = training.groupby("movieId").avg("rating").withColumnRenamed("avg(rating)","prediction")
test_null_model = avg_votes.join(test,test.movieId==avg_votes.movieId)

In [18]:
r2_base = evaluator.evaluate(test_null_model)
print("TEST r^2 base = " +str(r2_base))

TEST r^2 base = 0.18257645067228345


In [45]:
from pyspark.mllib.evaluation import BinaryClassificationMetrics
avg = 2.5
test_null_bin = test_null_model.rdd.map(lambda r: 
                                        (1. if r.prediction >= avg else 0., 1. if r.rating >= avg else 0.)
                                       )

bin_clf =  BinaryClassificationMetrics(test_null_bin)
bin_clf.areaUnderROC

0.5496725622008477

# ALS Model

In [9]:
TRAIN = False

In [10]:
if TRAIN:
    als = ALS(maxIter=25, regParam=0.15, userCol="userId", itemCol="movieId", ratingCol="rating",
              coldStartStrategy="drop", seed=46)
    model = als.fit(training)
    model.write().overwrite().save("als.model")
else:
    model = ALSModel.load("als.model")

In [11]:
test_predictions = model.transform(test).na.fill(0)

In [15]:
r2_als = evaluator.evaluate(test_predictions)
print("TEST r^2 = " + str(r2_als))

TEST r^2 = 0.37501610720436007


In [17]:
from pyspark.mllib.evaluation import BinaryClassificationMetrics
avg = 2.5
test_predictions_bin = test_predictions.rdd.map(lambda r: (1. if r.prediction>=avg else 0., 1. if r.rating >= avg else 0.))

bin_clf =  BinaryClassificationMetrics(test_predictions_bin)
bin_clf.areaUnderROC

0.6732428900128344

In [None]:
test_predictions.sort("userId").write.format("csv").save("test_predictions.csv")

In [27]:
test_predictions.filter("userId == 11").show()

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|    11|  55363|   3.0| 3.0163019|
|    11|  33437|   3.0| 3.0601606|
|    11|  58559|   4.5|  3.563614|
|    11|  57368|   3.5| 2.8807693|
|    11|  55247|   4.5| 3.3111665|
|    11|  53921|   3.5|  2.848623|
|    11|     47|   3.5| 3.4663193|
|    11|   7347|   3.5|  2.870282|
|    11|  48774|   3.5| 3.2386696|
|    11|   2054|   2.5|  2.443082|
|    11|  53322|   4.0| 3.1258006|
|    11|  49272|   3.5|  3.322027|
|    11|  60126|   3.0| 2.8522007|
|    11|  56633|   2.5| 2.9587643|
|    11|  52973|   3.5| 3.0321949|
|    11|  49130|   3.5|    3.0962|
|    11|  51935|   4.0|  3.187418|
|    11|  55729|   2.0| 2.9732869|
|    11|  44555|   4.0|   3.48112|
|    11|  61132|   3.5| 2.8985257|
+------+-------+------+----------+
only showing top 20 rows

