In [13]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, rank, countDistinct, count
from pyspark.sql.window import Window
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import StringIndexer
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder, CrossValidator
import pandas as pd

In [14]:
spark = SparkSession.builder.appName('tourism-destination').getOrCreate()

In [15]:
df_landing = pd.read_csv('/Users/raopend/Workspace/travel-recommendation-system/notebooks/dataset/tourism_rating.csv')
spark_df_landing =  spark.createDataFrame(df_landing) 
spark_df_landing.show(vertical=True)

-RECORD 0------------
 User_Id       | 1   
 Place_Id      | 179 
 Place_Ratings | 3   
-RECORD 1------------
 User_Id       | 1   
 Place_Id      | 344 
 Place_Ratings | 2   
-RECORD 2------------
 User_Id       | 1   
 Place_Id      | 5   
 Place_Ratings | 5   
-RECORD 3------------
 User_Id       | 1   
 Place_Id      | 373 
 Place_Ratings | 3   
-RECORD 4------------
 User_Id       | 1   
 Place_Id      | 101 
 Place_Ratings | 4   
-RECORD 5------------
 User_Id       | 1   
 Place_Id      | 312 
 Place_Ratings | 2   
-RECORD 6------------
 User_Id       | 1   
 Place_Id      | 258 
 Place_Ratings | 5   
-RECORD 7------------
 User_Id       | 1   
 Place_Id      | 20  
 Place_Ratings | 4   
-RECORD 8------------
 User_Id       | 1   
 Place_Id      | 154 
 Place_Ratings | 2   
-RECORD 9------------
 User_Id       | 1   
 Place_Id      | 393 
 Place_Ratings | 5   
-RECORD 10-----------
 User_Id       | 1   
 Place_Id      | 103 
 Place_Ratings | 3   
-RECORD 11-----------
 User_Id  

In [16]:
df_rec = spark_df_landing.select('User_Id', 'Place_Id', 'Place_Ratings').withColumnRenamed("User_Id","userId")\
                                                                 .withColumnRenamed("Place_Id","itemId")\
                                                                 .withColumnRenamed("Place_Ratings","rating")
df_rec = df_rec.orderBy("userId", "itemId")

In [17]:
popularity_df = df_rec.groupBy('itemId') \
                 .agg(count('*').alias('popularity')) \
                 .orderBy(col('popularity').desc())

In [18]:
user_window = Window.partitionBy("userId").orderBy(col("itemId").desc())
df_rec = df_rec.withColumn("num_items", expr("count(*) over (partition by userId)"))


In [19]:
# For example, 30% of items will be masked
percent_items_to_mask = 0.3 
# Determine the number of items to mask for each user
df_rec_final = df_rec.withColumn("num_items_to_mask", (col("num_items") * percent_items_to_mask).cast("int"))
# Masks items for each user
df_rec_final = df_rec_final.withColumn("item_rank", rank().over(user_window))

# Create a StringIndexer model to index the user ID column
indexer_user = StringIndexer(inputCol='userId', outputCol='userIndex').setHandleInvalid("keep")
indexer_item = StringIndexer(inputCol='itemId', outputCol='itemIndex').setHandleInvalid("keep")

# Fit the indexer model to the data and transform the DataFrame
df_rec_final = indexer_user.fit(df_rec_final).transform(df_rec_final)
df_rec_final = indexer_item.fit(df_rec_final).transform(df_rec_final)

# Convert the userIndex column to integer type
df_rec_final = df_rec_final.withColumn('userIndex', df_rec_final['userIndex'].cast('integer'))\
               .withColumn('itemIndex', df_rec_final['itemIndex'].cast('integer'))

train_df_rec = df_rec_final.filter(col("item_rank") > col("num_items_to_mask"))
test_df_rec = df_rec_final.filter(col("item_rank") <= col("num_items_to_mask"))

In [20]:
train_df_rec.show()

+------+------+------+---------+-----------------+---------+---------+---------+
|userId|itemId|rating|num_items|num_items_to_mask|item_rank|userIndex|itemIndex|
+------+------+------+---------+-----------------+---------+---------+---------+
|     1|   307|     4|       30|                9|       10|      199|       82|
|     1|   302|     2|       30|                9|       11|      199|      137|
|     1|   292|     3|       30|                9|       12|      199|      314|
|     1|   265|     5|       30|                9|       13|      199|       21|
|     1|   258|     5|       30|                9|       14|      199|      159|
|     1|   246|     4|       30|                9|       15|      199|       66|
|     1|   222|     3|       30|                9|       16|      199|      336|
|     1|   208|     5|       30|                9|       17|      199|        4|
|     1|   179|     3|       30|                9|       18|      199|      154|
|     1|   154|     2|      

In [21]:
# Configure the ALS model
als = ALS(userCol='userIndex', itemCol='itemIndex', ratingCol='rating',
          coldStartStrategy='drop', nonnegative=True)


param_grid = ParamGridBuilder()\
             .addGrid(als.rank, [1, 20, 30])\
             .addGrid(als.maxIter, [20])\
             .addGrid(als.regParam, [.05, .15])\
             .build()
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating', predictionCol='prediction')

cv = CrossValidator(
        estimator=als,
        estimatorParamMaps=param_grid,
        evaluator=evaluator,
        numFolds=3)

model = cv.fit(train_df_rec)

best_model = model.bestModel
print('rank: ', best_model.rank)
print('MaxIter: ', best_model._java_obj.parent().getMaxIter())
print('RegParam: ', best_model._java_obj.parent().getRegParam())

24/01/18 10:48:07 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

rank:  1
MaxIter:  20
RegParam:  0.15


In [22]:
model = als.fit(train_df_rec)

# Generate predictions on the test data
predictions = best_model.transform(test_df_rec)
predictions = predictions.withColumn("prediction", expr("CASE WHEN prediction < 1 THEN 1 WHEN prediction > 5 THEN 5 ELSE prediction END"))

evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating', predictionCol='prediction')
rmse = evaluator.evaluate(predictions)
print(f'Root Mean Squared Error (RMSE): {rmse}')

Root Mean Squared Error (RMSE): 1.6368986356089679


In [23]:
from pyspark.mllib.evaluation import RankingMetrics
from pyspark.sql.functions import col, collect_list

# Convert the predictions DataFrame to include all predictions per user
# Generate top-k recommendations for each user
userRecs = best_model.recommendForAllUsers(100)  # Top-100 recommendations for each user

# Prepare the input for RankingMetrics
user_ground_truth = test_df_rec.groupby('userIndex').agg(collect_list('itemIndex').alias('ground_truth_items'))
user_train_items = train_df_rec.groupby('userIndex').agg(collect_list('itemIndex').alias('train_items'))

# Join the recommendations and ground truth data on the user ID
user_eval = userRecs.join(user_ground_truth, on='userIndex').join(user_train_items, on='userIndex') \
    .select('userIndex', 'recommendations.itemIndex', 'ground_truth_items', 'train_items', 'recommendations.rating')
user_eval = user_eval.toPandas()
user_eval['itemIndex_filtered'] = user_eval.apply(lambda x:[b for (b,z) in zip(x.itemIndex, x.rating) if b not in x.train_items], axis=1)
user_eval['rating_filtered'] = user_eval.apply(lambda x:[z for (b,z) in zip(x.itemIndex, x.rating) if b not in x.train_items], axis=1)

                                                                                

In [24]:
user_eval

Unnamed: 0,userIndex,itemIndex,ground_truth_items,train_items,rating,itemIndex_filtered,rating_filtered
0,0,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[91, 29, 397, 209, 395, 433, 429, 284, 283, 27...","[165, 164, 102, 426, 162, 21, 339, 67, 159, 13...","[5.26589822769165, 5.209471702575684, 5.115672...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[5.26589822769165, 5.209471702575684, 5.115672..."
1,1,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[286, 53, 411, 24, 348, 420, 393, 168, 104, 16...","[366, 267, 196, 195, 265, 265, 190, 36, 36, 0,...","[3.9084765911102295, 3.866595506668091, 3.7969...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[3.9084765911102295, 3.866595506668091, 3.7969..."
2,2,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[243, 115, 91, 140, 171, 170, 204, 43, 281, 13...","[51, 198, 371, 427, 164, 102, 99, 195, 194, 65...","[5.068458080291748, 5.0141472816467285, 4.9238...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[5.068458080291748, 5.0141472816467285, 4.9238..."
3,3,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[142, 289, 12, 286, 88, 394, 374, 168, 200, 1,...","[337, 130, 311, 363, 64, 64, 230, 35, 261, 261...","[5.758917808532715, 5.697208404541016, 5.59462...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[5.758917808532715, 5.697208404541016, 5.59462..."
4,4,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[244, 142, 292, 54, 324, 112, 53, 28, 171, 283...","[374, 13, 167, 372, 9, 197, 162, 418, 6, 192, ...","[4.415762424468994, 4.368445873260498, 4.28978...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[4.415762424468994, 4.368445873260498, 4.28978..."
...,...,...,...,...,...,...,...
295,295,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[2, 74, 349, 140, 169, 137]","[390, 134, 158, 7, 425, 365, 97, 263, 62, 34, ...","[3.5155069828033447, 3.477836847305298, 3.4152...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[3.5155069828033447, 3.477836847305298, 3.4152..."
296,296,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[286, 413, 380, 53, 420, 394]","[200, 50, 365, 98, 263, 156, 127, 154, 79, 385...","[5.281162261962891, 5.22457218170166, 5.130500...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[5.281162261962891, 5.22457218170166, 5.130500..."
297,297,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[74, 73, 421, 325, 70, 420]","[23, 103, 99, 20, 423, 0, 330, 303, 77, 149, 1...","[5.120181560516357, 5.065316677093506, 4.97411...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[5.120181560516357, 5.065316677093506, 4.97411..."
298,298,"[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[243, 91, 54, 412, 203, 84]","[372, 166, 426, 161, 99, 266, 192, 97, 261, 30...","[4.828635215759277, 4.7768940925598145, 4.6908...","[70, 380, 321, 378, 377, 203, 379, 282, 280, 2...","[4.828635215759277, 4.7768940925598145, 4.6908..."
