In [13]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS

from pyspark.sql.functions import col, count, when
from pyspark.sql.functions import rand

from pyspark.sql.functions import explode, col

from pyspark.ml.recommendation import ALS


from pyspark.ml.evaluation import RegressionEvaluator

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col

In [4]:
# Start a Spark Session
spark = SparkSession.builder \
    .appName("ALS Recommender") \
    .getOrCreate()

# Load data
file_path = 'interactions.csv'
data = spark.read.csv(file_path, header=True, inferSchema=True)

In [5]:
# Data Cleaning

data = data.select(
    col("User_A").alias("user_id"),
    col("User_B").alias("item_id"),
    col("Interaction_Intensity").alias("rating")
)

# Total rows in the dataset
total_rows = data.count()
print(f"Total rows: {total_rows}")

# Count the number of distinct rows
distinct_rows = data.dropDuplicates().count()
print(f"Distinct rows: {distinct_rows}")

# Calculate duplicates
duplicates = total_rows - distinct_rows
print(f"Number of duplicate rows: {duplicates}")


data = data.orderBy(rand())

data.printSchema()

Total rows: 632211




Distinct rows: 632211
Number of duplicate rows: 0
root
 |-- user_id: integer (nullable = true)
 |-- item_id: integer (nullable = true)
 |-- rating: double (nullable = true)



                                                                                

In [6]:
# View of the data
data.limit(10).show()

+-------+-------+------+
|user_id|item_id|rating|
+-------+-------+------+
|   1082|    875|   5.0|
|    302|   1101|   4.0|
|    343|    583|   2.0|
|   1223|    100|   4.0|
|    686|     85|   4.0|
|   1099|   1259|   3.0|
|    725|     50|   1.0|
|    196|    888|   5.0|
|    481|    716|   1.0|
|   1114|    627|   2.0|
+-------+-------+------+



In [7]:
# Data Split

train_data, test_data = data.randomSplit([0.9, 0.1], seed=42)
# Verify split sizes
print(f"Training data count: {train_data.count()}")
print(f"Test data count: {test_data.count()}")


                                                                                

Training data count: 569062




Test data count: 63127


                                                                                

In [None]:
# Model Training



# Initialize ALS model
als = ALS(
    #implicitPrefs=True,
    maxIter=20,              # Number of iterations
    regParam=0.1,            # Regularization parameter
    rank=20,                 # Number of latent factors
    userCol="user_id",       # Column for user IDs
    itemCol="item_id",       # Column for item IDs
    ratingCol="rating",      # Column for ratings
    coldStartStrategy="drop" # Handle unseen users/items during predictions
)

# Fit the ALS model on the training data
als_model = als.fit(train_data)


24/12/16 12:19:51 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/12/16 12:19:51 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
24/12/16 12:19:51 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


In [9]:
# Generate predictions on the test set
predictions = als_model.transform(test_data)

# Show the predictions
predictions.show()


+-------+-------+------+----------+
|user_id|item_id|rating|prediction|
+-------+-------+------+----------+
|      0|    362|   4.0|   2.95297|
|      0|    491|   5.0| 3.0577862|
|      1|    277|   5.0| 2.7910562|
|      1|    500|   4.0| 3.0103023|
|      1|    982|   3.0| 2.7162454|
|      2|    317|   2.0| 3.0044632|
|      2|    509|   3.0| 2.9574344|
|      2|    572|   1.0| 2.9543889|
|      2|    985|   5.0|  3.023848|
|      3|    447|   1.0|  2.820734|
|      3|    884|   3.0|  3.003873|
|      3|   1296|   1.0|  2.833526|
|      4|    170|   3.0|  3.015949|
|      4|    914|   2.0| 2.8267546|
|      5|    437|   1.0| 2.8948205|
|      5|    791|   2.0| 2.7099829|
|      6|    779|   4.0| 2.9317272|
|      7|    139|   4.0|  2.889961|
|      7|    193|   1.0|  2.733398|
|      7|    265|   1.0| 2.7542136|
+-------+-------+------+----------+
only showing top 20 rows



Rating prediction metrics: RMSE, MAE, explained variance, R-squared
Ranking-based evaluation: MAP, NDCG, precision@K, recall@K

In [None]:
# Evaluate with RMSE and MAE

evaluator_rmse = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
evaluator_mae = RegressionEvaluator(metricName="mae", labelCol="rating", predictionCol="prediction")

rmse = evaluator_rmse.evaluate(predictions)
mae = evaluator_mae.evaluate(predictions)
print(f"rmse : {rmse} \n mae: {mae}")

rmse : 1.3737048459668288 
 mae: 1.1881133907431838


In [11]:
# Flatten Recommendations: Explode the nested recommendations to make it easier to evaluate: 


top_k_recommendations = als_model.recommendForAllUsers(500)


top_k_recommendations = top_k_recommendations.withColumn("recommendation", explode(col("recommendations"))) \
    .select("user_id", col("recommendation.item_id").alias("item_id"), col("recommendation.rating").alias("prediction"))

# Join with Ground Truth
joined = top_k_recommendations.join(test_data, on=["user_id", "item_id"], how="inner")

# Precision@K and Recall@K
precision_at_k = joined.count() / (top_k_recommendations.count() * 10)
recall_at_k = joined.count() / test_data.count()

print(f"Ranking Evaluation - Precision@K: {precision_at_k}, Recall@K: {recall_at_k}")


                                                                                

Ranking Evaluation - Precision@K: 0.0038103076923076923, Recall@K: 0.3924662199622994


Assuming we have obtained a good model after a deep model evaluation

In [12]:
# Initial Recommendation Example of user 0
top_k_recommendations.filter(col("user_id").isin([0])).show(n=50, truncate=False)

                                                                                

+-------+-------+----------+
|user_id|item_id|prediction|
+-------+-------+----------+
|0      |100    |3.5692847 |
|0      |840    |3.5431314 |
|0      |1047   |3.48891   |
|0      |901    |3.427322  |
|0      |117    |3.3889112 |
|0      |1258   |3.3825827 |
|0      |874    |3.3819797 |
|0      |436    |3.3791256 |
|0      |832    |3.3397155 |
|0      |701    |3.336751  |
|0      |904    |3.3150492 |
|0      |899    |3.298587  |
|0      |818    |3.2930748 |
|0      |58     |3.2900932 |
|0      |938    |3.2883945 |
|0      |15     |3.2758088 |
|0      |29     |3.2718625 |
|0      |525    |3.2670658 |
|0      |656    |3.2640038 |
|0      |1089   |3.2618954 |
|0      |1071   |3.2614617 |
|0      |422    |3.2583003 |
|0      |407    |3.2581136 |
|0      |1289   |3.252446  |
|0      |992    |3.2519023 |
|0      |934    |3.2493954 |
|0      |64     |3.2349408 |
|0      |876    |3.234303  |
|0      |528    |3.2331333 |
|0      |1295   |3.2298903 |
|0      |668    |3.2295349 |
|0      |56   

 Select Top-N Candidates for Content-Based Filtering (CBF)

In [14]:
# Define a window to rank predictions for each user
window = Window.partitionBy("user_id").orderBy(col("prediction").desc())

# Select top 200 candidates
top_n_candidates = predictions.withColumn("rank", row_number().over(window)) \
    .filter(col("rank") <= 200)

In [15]:
top_n_candidates.limit(300).show()

+-------+-------+------+----------+----+
|user_id|item_id|rating|prediction|rank|
+-------+-------+------+----------+----+
|      0|    904|   2.0| 3.3150492|   1|
|      0|    763|   3.0| 3.2139301|   2|
|      0|    601|   5.0|  3.204808|   3|
|      0|   1031|   2.0| 3.1805916|   4|
|      0|   1007|   2.0| 3.1190002|   5|
|      0|    246|   4.0| 3.0943031|   6|
|      0|    301|   3.0| 3.0645356|   7|
|      0|    491|   5.0| 3.0577862|   8|
|      0|   1015|   2.0| 3.0565495|   9|
|      0|    240|   2.0|  3.039404|  10|
|      0|    105|   4.0|  3.031528|  11|
|      0|    292|   1.0| 3.0151675|  12|
|      0|    766|   3.0| 3.0024548|  13|
|      0|    622|   1.0| 2.9558887|  14|
|      0|    362|   4.0|   2.95297|  15|
|      0|    419|   5.0|   2.92155|  16|
|      0|   1189|   2.0| 2.9207273|  17|
|      0|   1127|   5.0|  2.902138|  18|
|      0|    831|   5.0|  2.869922|  19|
|      0|   1117|   5.0|   2.85972|  20|
+-------+-------+------+----------+----+
only showing top