In [18]:
from pyspark import SparkContext
from pyspark.sql import SparkSession ,Row
from pyspark.sql.functions import col
from pyspark.sql import SQLContext
import pyspark.sql.functions as F

from pyspark.mllib.recommendation import *
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import explode

from operator import *
import os
import random

In [2]:
# Start (or reuse) the Spark session for this notebook.
spark_session = (
    SparkSession.builder
    .appName('GRP10_MusicRec')
    .getOrCreate()
)
# SQL context built on top of the session's SparkContext.
SQL_context = SQLContext(spark_session.sparkContext)


Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/03/11 15:42:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/03/11 15:42:03 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# Load the tab-separated triplets file from the local file system.
# TODO: switch this path to the HDFS version.
fPATH = 'train_triplets.txt'
triplets = spark_session.read.csv(fPATH, header=False, sep=r'\t')

# Replace Spark's positional column names (_c0, _c1, _c2) with descriptive ones.
for default_name, new_name in (('_c0', 'User'), ('_c1', 'Song'), ('_c2', 'Count')):
    triplets = triplets.withColumnRenamed(default_name, new_name)
triplets.show(10)
# Optional shape check (left disabled):
# print((triplets.count(), len(triplets.columns)))

                                                                                

+--------------------+------------------+-----+
|                User|              Song|Count|
+--------------------+------------------+-----+
|b80344d063b5ccb32...|SOAKIMP12A8C130995|    1|
|b80344d063b5ccb32...|SOAPDEY12A81C210A9|    1|
|b80344d063b5ccb32...|SOBBMDR12A8C13253B|    2|
|b80344d063b5ccb32...|SOBFNSP12AF72A0E22|    1|
|b80344d063b5ccb32...|SOBFOVM12A58A7D494|    1|
|b80344d063b5ccb32...|SOBNZDC12A6D4FC103|    1|
|b80344d063b5ccb32...|SOBSUJE12A6D4F8CF5|    2|
|b80344d063b5ccb32...|SOBVFZR12A6D4F8AE3|    1|
|b80344d063b5ccb32...|SOBXALG12A8C13C108|    1|
|b80344d063b5ccb32...|SOBXHDL12A81C204C0|    1|
+--------------------+------------------+-----+
only showing top 10 rows



In [12]:
# Build the modelling DataFrame from the raw triplets in one pass:
#   UserID / SongID - integer hashes of the string keys, used as the ALS
#                     user/item columns in the cells below
#   CountNum        - the play count cast from string to integer
# NOTE(review): F.hash yields a 32-bit integer (see the UserID/SongID values
# shown below), so distinct users or songs could collide on the full
# dataset - consider an indexer-based id instead; TODO confirm.
tripletsDF = (
    triplets
    .withColumn('UserID', F.hash(col('User')))
    .withColumn('SongID', F.hash(col('Song')))
    .withColumn('CountNum', col('Count').cast(IntegerType()))
)

# This DF should have 48M entries; if we just want a test run,
# limit it to 10k or 100k (still takes hours on a single node).
tripletsDF = tripletsDF.limit(1000)

# Export a CSV preview of tripletsDF for a quick glance.
# mode='overwrite' lets the cell be re-run: with the default save mode the
# write raises once the 'tripletsDF_preview' directory already exists.
tripletsDF.limit(200).write.csv('tripletsDF_preview', mode='overwrite')

                                                                                

In [13]:
# Peek at the first rows to check the derived id and count columns.
tripletsDF.show(n=5)



+--------------------+------------------+-----+----------+-----------+--------+
|                User|              Song|Count|    UserID|     SongID|CountNum|
+--------------------+------------------+-----+----------+-----------+--------+
|b80344d063b5ccb32...|SOAKIMP12A8C130995|    1|1365117428| 1315780877|       1|
|b80344d063b5ccb32...|SOAPDEY12A81C210A9|    1|1365117428|-1623759929|       1|
|b80344d063b5ccb32...|SOBBMDR12A8C13253B|    2|1365117428|-1218290021|       2|
|b80344d063b5ccb32...|SOBFNSP12AF72A0E22|    1|1365117428|-1227648141|       1|
|b80344d063b5ccb32...|SOBFOVM12A58A7D494|    1|1365117428| 2054460487|       1|
+--------------------+------------------+-----+----------+-----------+--------+
only showing top 5 rows



                                                                                

In [14]:
# 80/20 train/test split with a fixed seed.
# Kept in its own cell so it still runs when the cross-validation cell is skipped.
train, test = tripletsDF.randomSplit(weights=[0.8, 0.2], seed=42)
# Flag read by the training cell: 0 means cross-validation has not been run.
CV = 0

In [None]:
# Optional step: cross-validated grid search over the ALS hyper-parameters.
# Column names are spelled exactly like the DataFrame columns (UserID / SongID).
ALS_model = ALS(maxIter=3, userCol='UserID', itemCol='SongID', ratingCol='CountNum', coldStartStrategy='drop')

# Grid search for the best parameters; every extra value multiplies the run time.
grid = (
    ParamGridBuilder()
    .addGrid(ALS_model.rank, [10])
    .addGrid(ALS_model.regParam, [0.01])
    .build()
)

# RMSE evaluator used to score each fold. It is defined here because the
# CrossValidator below needs it; relying on the definition in the later
# training cell raises NameError when the notebook is run top to bottom.
evaluator = RegressionEvaluator(metricName='rmse', labelCol='CountNum', predictionCol='prediction')

# Set up the 5-fold cross-validation process.
CrossVal = CrossValidator(numFolds=5, estimator=ALS_model, estimatorParamMaps=grid, evaluator=evaluator)

model = CrossVal.fit(train)
# Get the best model from cross validation.
TopModel = model.bestModel
# Signal the training cell that cross-validated hyper-parameters are available.
CV = 1

In [16]:
# Choose the hyper-parameters for the final fit: the ones found by the
# cross-validation cell when it ran (CV == 1), otherwise fixed defaults.
if CV == 1:
    # The parent of the best model's Java object is the estimator that produced it.
    best_estimator = TopModel._java_obj.parent()
    best_rank = best_estimator.getRank()
    best_reg_param = best_estimator.getRegParam()
else:
    best_rank = 10
    best_reg_param = 0.01

# Build the estimator once so both paths share the same settings and the same
# column spelling (UserID / SongID, exactly as in the DataFrame); the output
# column names of the recommendations then no longer depend on the branch taken.
estimator1 = ALS(
    rank=best_rank,
    regParam=best_reg_param,
    maxIter=10,
    userCol='UserID',
    itemCol='SongID',
    ratingCol='CountNum',
    coldStartStrategy='drop',
)
model1 = estimator1.fit(train)

# Initialize an RMSE evaluator.
evaluator = RegressionEvaluator(metricName='rmse', labelCol='CountNum', predictionCol='prediction')

# Make predictions on the held-out split and report the RMSE.
predictions1 = model1.transform(test)
print('The RMSE is:', evaluator.evaluate(predictions1))

                                                                                

The RMSE is: 7.702996979255093


                                                                                

In [22]:
# Generate the top-10 song recommendations for every user known to the model.
Top10Rec = model1.recommendForAllUsers(10)
Top10Rec.printSchema()
Top10Rec.show(30)

# Flatten the per-user recommendation array into one row per (user, song, rating).
Top10RecExploded = (
    Top10Rec
    .withColumn('rec', explode('recommendations'))
    .select('UserID', col('rec.SongID'), col('rec.Rating'))
)
Top10RecExploded.show(30)

# Export a preview of the top-10 recommendations.
# mode='overwrite' lets the cell be re-run: with the default save mode the
# write raises once the 'Top10Recommendations' directory already exists.
Top10RecExploded.limit(100).write.csv('Top10Recommendations', mode='overwrite')

root
 |-- UserID: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- SongID: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)



                                                                                

+-----------+--------------------+
|     UserID|     recommendations|
+-----------+--------------------+
| 1175829282|[{-1949811848, 8....|
|-1485784057|[{-1511816900, 4....|
| -263623207|[{-1082424070, 10...|
| -913738036|[{1195944994, 5.5...|
|-1396894026|[{1968992317, 12....|
|-1748408205|[{1968992317, 8.3...|
|-1438603335|[{562809730, 14.2...|
|  287723006|[{1401083513, 8.9...|
|-1822857953|[{1401083513, 5.6...|
| -585083843|[{1571353104, 15....|
| 1365117428|[{1195944994, 7.9...|
| 1598577928|[{562809730, 28.0...|
| 1468154379|[{1401083513, 23....|
+-----------+--------------------+



                                                                                

+-----------+-----------+---------+
|     UserID|     SongID|   Rating|
+-----------+-----------+---------+
| 1175829282|-1949811848| 8.992541|
| 1175829282| 1401083513| 6.880934|
| 1175829282| 1280572134|5.3850784|
| 1175829282|  562809730|5.1725216|
| 1175829282|-1220381942| 4.995857|
| 1175829282| 1968992317| 4.260912|
| 1175829282| -463911918|4.1883945|
| 1175829282|  505486927|3.9966848|
| 1175829282|-1047121397|3.6946583|
| 1175829282|  587182546|3.5900528|
|-1485784057|-1511816900| 4.205525|
|-1485784057|-2124687182| 4.205525|
|-1485784057| 1441848682| 2.989493|
|-1485784057| 1616747166|2.5233152|
|-1485784057|  535871298|1.9929951|
|-1485784057| 1852075704|1.9929951|
|-1485784057| 1626940440|1.9929951|
|-1485784057|-1327476140|1.9929951|
|-1485784057| 1409316643|1.9929951|
|-1485784057| -395021924|1.9929951|
| -263623207|-1082424070|10.995666|
| -263623207| 1968992317|  9.03208|
| -263623207|-1164106294| 5.997636|
| -263623207| -876213096|5.5595717|
| -263623207|  409720438| 4.

                                                                                