In [1]:
# Imports

from pyspark.sql.types import StructField, IntegerType, StructType
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

In [2]:
# Initializing a spark session
spark = SparkSession.builder.master("local").getOrCreate()

# Location of the "ratings.csv" file
ratings_location = "./ratings.csv"


# CSV options
infer_schema = "True"
first_row_is_header = "True"
delimiter = ","

data_schema = [StructField('userId', IntegerType(), True)]
final_struct = StructType(fields = data_schema)

# The applied options are for CSV files. For other file types, these will be ignored.
ratings = spark.read.format("csv") \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(ratings_location)

ratings.show()

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|      2|   3.5|1112486027|
|     1|     29|   3.5|1112484676|
|     1|     32|   3.5|1112484819|
|     1|     47|   3.5|1112484727|
|     1|     50|   3.5|1112484580|
|     1|    112|   3.5|1094785740|
|     1|    151|   4.0|1094785734|
|     1|    223|   4.0|1112485573|
|     1|    253|   4.0|1112484940|
|     1|    260|   4.0|1112484826|
|     1|    293|   4.0|1112484703|
|     1|    296|   4.0|1112484767|
|     1|    318|   4.0|1112484798|
|     1|    337|   3.5|1094785709|
|     1|    367|   3.5|1112485980|
|     1|    541|   4.0|1112484603|
|     1|    589|   3.5|1112485557|
|     1|    593|   3.5|1112484661|
|     1|    653|   3.0|1094785691|
|     1|    919|   3.5|1094785621|
+------+-------+------+----------+
only showing top 20 rows



In [3]:
ratings.describe().show()

+-------+-----------------+------------------+------------------+--------------------+
|summary|           userId|           movieId|            rating|           timestamp|
+-------+-----------------+------------------+------------------+--------------------+
|  count|         20000263|          20000263|          20000263|            20000263|
|   mean|69045.87258292554| 9041.567330339605|3.5255285642993797|1.1009179216771157E9|
| stddev|40038.62665316201|19789.477445413086| 1.051988919294227|1.6216942478273067E8|
|    min|                1|                 1|               0.5|           789652004|
|    max|           138493|            131262|               5.0|          1427784002|
+-------+-----------------+------------------+------------------+--------------------+



In [4]:
# Location of the "movies.csv" file
movies_location = "./movies.csv"

movies = spark.read.format("csv") \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(movies_location)

movies.show()

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
|     11|American Presiden...|Comedy|Drama|Romance|
|     12|Dracula: Dead and...|       Comedy|Horror|
|     13|        Balto (1995)|Adventure|Animati...|
|     14|        Nixon (1995)|               Drama|
|     15|Cutthroat Island ...|Action|Adventure|...|
|     16|       Casino (1995)|         Crime|Drama|
|     17|Sen

In [5]:
movies.describe().show()

+-------+-----------------+--------------------+------------------+
|summary|          movieId|               title|            genres|
+-------+-----------------+--------------------+------------------+
|  count|            27278|               27278|             27278|
|   mean|59855.48057042305|                null|              null|
| stddev|44429.31469707313|                null|              null|
|    min|                1|"""Great Performa...|(no genres listed)|
|    max|           131262|       貞子3D (2012)|           Western|
+-------+-----------------+--------------------+------------------+



In [6]:
# Spliting the ratings data into test and train

train,test = ratings.randomSplit([0.8, 0.2])

In [7]:
# Setting the ALS algorithm

als = ALS()
(als.setRank(1)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .setMaxIter(5)
  .setColdStartStrategy("drop"))

ALS_8942c42a6418

In [8]:
#Training the model

model = als.fit(train)

In [9]:
prediction = model.transform(test)
prediction.show()

+------+-------+------+----------+----------+
|userId|movieId|rating| timestamp|prediction|
+------+-------+------+----------+----------+
| 97435|    148|   4.0|1042483722| 3.0167832|
|136222|    148|   2.0| 849125057| 3.0014815|
| 60081|    148|   2.0| 837850255| 2.9971402|
|  3990|    148|   4.0|1422817494|  2.570272|
| 64843|    148|   3.5|1104862927|  3.039391|
| 20344|    148|   2.0| 965940170|  2.540976|
| 78276|    148|   2.0| 935277774| 3.0822134|
|  8663|    148|   1.0|1047357247| 2.9572928|
| 61663|    148|   2.0| 874577512| 3.2588964|
|  9084|    148|   2.0| 833674024|  3.315384|
| 14282|    148|   3.0| 940520793| 2.8264296|
|105376|    148|   3.0| 832006144| 2.6369488|
|130531|    148|   1.0| 831284829| 1.7416174|
|116920|    148|   3.0| 829819153| 3.2041485|
| 91231|    148|   4.0|1025350818|  2.796146|
|117581|    148|   1.0| 833133885| 1.1877666|
| 24709|    148|   2.0| 833173771| 2.9744246|
| 48132|    148|   3.0| 833538188|  3.104985|
| 86058|    148|   3.0| 947795986|

In [10]:
def get_movie_name(movie_id, movies):
    print(movies.where(movies.movieId==movie_id).collect())

In [12]:
users = ratings.select(als.getUserCol()).distinct().limit(1)
rd = model.recommendForUserSubset(users, 1)
rd = rd.take(1)

In [13]:
# Using index to obtain the movie id of top predicted rated item
recommendation = rd[0]['recommendations'][0][0]

# Using get_movie_name function to get the name and genre of the movie
get_movie_name(recommendation,movies)

[Row(movieId=126219, title='Marihuana (1936)', genres='Documentary|Drama')]
