## Item-based Collaborative Filtering

## PERSONAL NOTES:
Running with PySpark
- If you get a Py4JJavaError, check that the PySpark system variables are set correctly:
    - echo %PYSPARK_PYTHON%
    - echo %PYSPARK_DRIVER_PYTHON%
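
One way to rule out an interpreter mismatch is to set both variables from inside Python before the Spark session is created. This is a minimal sketch, assuming the interpreter running the notebook is the one Spark should use; the helper name `use_current_interpreter_for_pyspark` is hypothetical.

```python
import os
import sys


def use_current_interpreter_for_pyspark():
    """Point both PySpark variables at the interpreter running this code."""
    # Hypothetical helper: call it before SparkSession.builder...getOrCreate().
    os.environ["PYSPARK_PYTHON"] = sys.executable
    os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
    return os.environ["PYSPARK_PYTHON"], os.environ["PYSPARK_DRIVER_PYTHON"]


worker_python, driver_python = use_current_interpreter_for_pyspark()
print(worker_python)
print(driver_python)
```

Both printed paths should match the interpreter the notebook kernel runs on.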


#### Rationale
1. The number of users is relatively large compared to the number of relevant news articles, so comparing items is computationally cheaper than comparing users.
2. Item stability > user stability. Once a news article is published, its content is fixed, while a user's taste may change often. This can make user-based collaborative filtering less accurate with respect to the user's present taste. Similarity between items is constant, i.e. item-based collaborative filtering needs fewer recalculations.
3. Few news article interactions per user. This makes it harder to find similar users, which user-based collaborative filtering depends on.


#### Item-based collaborative filtering in a nutshell (MIND)
"Find articles that are likely to be of interest, based on shared user interest patterns. Return the top N articles that are most similar to any of the news articles the user has clicked on, based on the similarity calculations between items."
1. For each news article a user has clicked, get an overview of articles other users have also clicked
2. Matrix factorization for efficiency - Alternating Least Squares (ALS)
3. Calculate the similarity of each article (similarity of interactions) - Locality-Sensitive Hashing (LSH)
4. Repeat the steps for each news article, and sort the recommendation list so that the most similar articles come first (the code below measures similarity as Euclidean distance between ALS item factors rather than cosine similarity)
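
The idea behind these steps can be sketched without Spark, ALS or LSH. The example below is a simplified illustration, not the notebook's pipeline: it uses a hypothetical toy click log, compares articles directly by the users who clicked them (cosine similarity of binary click vectors), and scores each unseen article by its best similarity to any article the user has clicked.

```python
from math import sqrt

# Toy click log (hypothetical data): user -> set of clicked articles.
clicks = {
    "U1": {"N1", "N2"},
    "U2": {"N1", "N2", "N3"},
    "U3": {"N2", "N3"},
    "U4": {"N4"},
}


def invert(clicks):
    """Map each article to the set of users who clicked it."""
    users_by_item = {}
    for user, items in clicks.items():
        for item in items:
            users_by_item.setdefault(item, set()).add(user)
    return users_by_item


def cosine(users_a, users_b):
    """Cosine similarity of two binary click vectors, given as user sets."""
    if not users_a or not users_b:
        return 0.0
    return len(users_a & users_b) / sqrt(len(users_a) * len(users_b))


def recommend(user, clicks, n=2):
    """Top-n unseen articles, scored by their best similarity to a clicked one."""
    users_by_item = invert(clicks)
    seen = clicks[user]
    scores = {}
    for candidate, candidate_users in users_by_item.items():
        if candidate in seen:
            continue
        best = max(
            (cosine(users_by_item[item], candidate_users) for item in seen),
            default=0.0,
        )
        if best > 0:
            scores[candidate] = best
    # Highest score first; ties broken by article ID for a stable order.
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))[:n]


result = recommend("U1", clicks)
print(result)
```

For `U1`, only `N3` shares clickers with the articles already read, so it is the single recommendation; `N4` has no overlap and is dropped.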


#### Preparation of data

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import explode, split, col, lit, desc, sum, udf, broadcast
from pyspark.ml.recommendation import ALS
from pyspark.ml.feature import StringIndexer, BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors, VectorUDT


spark = SparkSession.builder \
    .appName("MINDItemBasedFiltering") \
    .config("spark.executor.memory", "8g") \
    .config("spark.driver.memory", "8g") \
    .config("spark.driver.extraJavaOptions", "-XX:+UseG1GC") \
    .config("spark.executor.extraJavaOptions", "-XX:+UseG1GC") \
    .getOrCreate()

    
# Define the schema of the dataset
schema = StructType([
    StructField("ImpressionID", IntegerType(), True),
    StructField("UserID", StringType(), True),
    StructField("Time", StringType(), True),
    StructField("History", StringType(), True),
    StructField("Impressions", StringType(), True)
])

# Load the dataset with the defined schema
data = spark.read.csv("data/MINDsmall_dev/behaviors.tsv", sep="\t", schema=schema)

data.show(5, truncate=False)

# Explode the history column into separate rows for each article per user 
# I.e. UserID | NewsArticle (one row per article in that user's history)
data = data.withColumn("NewsArticle", explode(split(col("History"), " "))) \
    .select(col("UserID").alias("user_id"), col("NewsArticle").alias("news_article"))

data.show(5, truncate=False)

+------------+------+----+-------+-----------+
|ImpressionID|UserID|Time|History|Impressions|
+------------+------+----+-------+-----------+
(rows truncated in the exported output)
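
What the `explode(split(...))` step does to a single row can be illustrated in plain Python. This is a sketch on a hypothetical row in the five-column layout of the schema above; skipping empty tokens is a choice of the sketch, not a statement about Spark's behaviour.

```python
def explode_history(line):
    """Turn one tab-separated behaviors row into (user_id, news_article) pairs."""
    impression_id, user_id, time, history, impressions = line.rstrip("\n").split("\t")
    # One output pair per space-separated article ID in the History field.
    return [(user_id, article) for article in history.split(" ") if article]


# Hypothetical row: ImpressionID, UserID, Time, History, Impressions.
row = "1\tU1\t11/15/2019 8:55:22 AM\tN1 N2 N3\tN4-1 N5-0"
pairs = explode_history(row)
print(pairs)
```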

#### Alternating Least Squares (ALS)

In [2]:
# Prepare data to work with ALS
# Add a dummy 'rating' column to indicate interaction
data = data.withColumn("rating", lit(1))

# Index the user_id and news_article columns
user_indexer = StringIndexer(inputCol="user_id", outputCol="user_id_index").fit(data)
item_indexer = StringIndexer(inputCol="news_article", outputCol="news_article_id_index").fit(data)

data = user_indexer.transform(data)
data = item_indexer.transform(data)

# Extract mappings from StringIndexer models
user_id_index_mapping = user_indexer.labels
news_article_id_index_mapping = item_indexer.labels

# Convert mappings to DataFrames for easier use
user_id_index_df = spark.createDataFrame([(i, user_id_index_mapping[i]) for i in range(len(user_id_index_mapping))], ["user_id_index", "user_id"])
news_article_id_index_df = spark.createDataFrame([(i, news_article_id_index_mapping[i]) for i in range(len(news_article_id_index_mapping))], ["news_article_id_index", "news_article"])

# Select the final columns for ALS
data = data.select("user_id_index", "news_article_id_index", "rating")

data.show(5, truncate=False)

# Train the ALS model
# Note: We use implicitPrefs=True to indicate that we are working with implicit feedback (clicks)
als = ALS(maxIter=5, regParam=0.01, userCol="user_id_index", itemCol="news_article_id_index", ratingCol="rating", coldStartStrategy="drop", implicitPrefs=True)
model = als.fit(data)

# Extract the item factors from the ALS model
item_factors = model.itemFactors
#item_factors = model.itemFactors.limit(1000) #Subset for testing

num_item_factors = item_factors.count()
print(f"Number of item factors: {num_item_factors}")

item_factors.show(5)

+-------------+---------------------+------+
|user_id_index|news_article_id_index|rating|
+-------------+---------------------+------+
|10460.0      |6.0                  |1     |
|10460.0      |279.0                |1     |
|10460.0      |1243.0               |1     |
|10460.0      |201.0                |1     |
|10460.0      |1734.0               |1     |
+-------------+---------------------+------+
only showing top 5 rows

Number of item factors: 37704
+---+--------------------+
| id|            features|
+---+--------------------+
|  0|[-0.5722367, -0.0...|
| 10|[-0.295189, -0.60...|
| 20|[-0.5555033, -0.2...|
| 30|[0.12897371, -0.4...|
| 40|[0.29763302, -0.2...|
+---+--------------------+
only showing top 5 rows
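
The `user_id_index` and `news_article_id_index` columns above are produced by `StringIndexer`, which with its default ordering gives index 0 to the most frequent label and emits the index as a double. The pure-Python sketch below mimics that behaviour on hypothetical article IDs; the alphabetical tie-break and the function names `fit_labels` and `transform` are assumptions of the sketch.

```python
from collections import Counter


def fit_labels(values):
    """Order distinct labels by descending frequency, ties alphabetically."""
    counts = Counter(values)
    return sorted(counts, key=lambda label: (-counts[label], label))


def transform(values, labels):
    """Replace each label by its position in the fitted label list, as a float."""
    index_of = {label: float(i) for i, label in enumerate(labels)}
    return [index_of[v] for v in values]


# Hypothetical article IDs from an exploded history column.
articles = ["N7", "N3", "N7", "N9", "N3", "N7"]
labels = fit_labels(articles)
indexed = transform(articles, labels)
print(labels)
print(indexed)

# Mapping back: an index is simply a position in the labels list.
restored = [labels[int(i)] for i in indexed]
print(restored == articles)
```

The last step is the same lookup the notebook performs with `item_indexer.labels` when it converts recommended indices back to article IDs.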



#### Calculating Similarity - Locality-Sensitive Hashing (LSH)

In [3]:
# In order to calculate the similarity between items by using Spark's LSH, we need to convert the item factors into a DenseVector
# Define a UDF that converts an array of floats into a DenseVector
to_vector = udf(lambda x: Vectors.dense(x), VectorUDT())

# Apply the UDF to the 'features' column
item_factors = item_factors.withColumn("features", to_vector("features"))

# Prepare for calculating similarity
# Initialize the LSH model
brp = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes", bucketLength=3.0, numHashTables=2)

# Fit the LSH model on the item factors
model_lsh = brp.fit(item_factors)

# Transform item factors to hash table
item_factors_hashed = model_lsh.transform(item_factors)

# Calculate similarity
# Approximate similarity self-join: keeps item pairs whose Euclidean distance is below the threshold
similar_items = model_lsh.approxSimilarityJoin(item_factors_hashed, item_factors_hashed, threshold=1.5, distCol="EuclideanDistance")
 
# Show some results
#similar_items.select(col("datasetA.id").alias("idA"), col("datasetB.id").alias("idB"), "EuclideanDistance").show(5)


Py4JJavaError: An error occurred while calling o338.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 208.0 failed 1 times, most recent failure: Lost task 3.0 in stage 208.0 (TID 322) (10.24.83.125 executor driver): java.io.IOException: There is not enough space on the disk
	at java.base/java.io.FileOutputStream.writeBytes(Native Method)
	at java.base/java.io.FileOutputStream.write(FileOutputStream.java:349)
	at org.apache.spark.storage.TimeTrackingOutputStream.write(TimeTrackingOutputStream.java:59)
	at java.base/java.io.BufferedOutputStream.flushBuffer(BufferedOutputStream.java:81)
	at java.base/java.io.BufferedOutputStream.write(BufferedOutputStream.java:127)
	at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:225)
	at net.jpountz.lz4.LZ4BlockOutputStream.write(LZ4BlockOutputStream.java:178)
	at org.apache.spark.storage.DiskBlockObjectWriter.write(DiskBlockObjectWriter.scala:323)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.write(UnsafeSorterSpillWriter.java:136)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spillIterator(UnsafeExternalSorter.java:576)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:231)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.createWithExistingInMemorySorter(UnsafeExternalSorter.java:115)
	at org.apache.spark.sql.execution.UnsafeKVExternalSorter.<init>(UnsafeKVExternalSorter.java:158)
	at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:243)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.hashAgg_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.hashAgg_doAggregateWithKeys_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:842)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
Caused by: java.io.IOException: There is not enough space on the disk
	at java.base/java.io.FileOutputStream.writeBytes(Native Method)
	at java.base/java.io.FileOutputStream.write(FileOutputStream.java:349)
	at org.apache.spark.storage.TimeTrackingOutputStream.write(TimeTrackingOutputStream.java:59)
	at java.base/java.io.BufferedOutputStream.flushBuffer(BufferedOutputStream.java:81)
	at java.base/java.io.BufferedOutputStream.write(BufferedOutputStream.java:127)
	at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:225)
	at net.jpountz.lz4.LZ4BlockOutputStream.write(LZ4BlockOutputStream.java:178)
	at org.apache.spark.storage.DiskBlockObjectWriter.write(DiskBlockObjectWriter.scala:323)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.write(UnsafeSorterSpillWriter.java:136)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spillIterator(UnsafeExternalSorter.java:576)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:231)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.createWithExistingInMemorySorter(UnsafeExternalSorter.java:115)
	at org.apache.spark.sql.execution.UnsafeKVExternalSorter.<init>(UnsafeKVExternalSorter.java:158)
	at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:243)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.hashAgg_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.hashAgg_doAggregateWithKeys_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:842)
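
For context on what the LSH step computes: Spark documents the hash used by `BucketedRandomProjectionLSH` as `floor(x · v / bucketLength)` for a random unit vector `v`, with one such vector per hash table. The sketch below reimplements that formula in plain Python on hypothetical vectors; the function names and the seeding are assumptions, and the result will not match Spark's own random directions.

```python
import math
import random


def random_unit_vector(dim, rng):
    """Draw a random direction and normalise it to length 1."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def make_hasher(dim, bucket_length, num_hash_tables, seed=0):
    """Return a function mapping a vector to one bucket index per hash table."""
    rng = random.Random(seed)
    directions = [random_unit_vector(dim, rng) for _ in range(num_hash_tables)]

    def hash_vector(x):
        # Project x on each direction and cut the projection into buckets.
        return [
            math.floor(sum(a * b for a, b in zip(x, v)) / bucket_length)
            for v in directions
        ]

    return hash_vector


def euclidean(x, y):
    """Euclidean distance, the metric reported in the EuclideanDistance column."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


hasher = make_hasher(dim=3, bucket_length=3.0, num_hash_tables=2)
a = [0.5, -0.2, 0.1]
b = [0.5, -0.2, 0.1]
far = [40.0, -35.0, 60.0]
print(hasher(a), hasher(b), hasher(far))
print(euclidean(a, far))
```

Only vectors that share a bucket in at least one table are compared in the join. If the ALS factors are mostly well below 1 in magnitude, as in the sample output above, a `bucketLength` of 3.0 would place most projections in bucket 0 or -1, so the self-join may end up comparing close to all 37,704 × 37,704 pairs. That could contribute to the disk-space failure shown in the trace, though this is an untested guess.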


#### Prepare similarity data for recommendations

In [None]:
user_item_interactions = data.select("user_id_index", "news_article_id_index").distinct().cache()

# Step 1: Flatten similar_items for easier handling
flat_similar_items = similar_items.select(
    col("datasetA.id").alias("article_id"),
    col("datasetB.id").alias("similar_article_id"),
    col("EuclideanDistance")
)

# Step 2: Filter for new recommendations per user
# Join user interactions with similar items to find potential recommendations
potential_recommendations = user_item_interactions.join(
    broadcast(flat_similar_items),
    user_item_interactions.news_article_id_index == flat_similar_items.article_id,
    "inner"
).select(
    "user_id_index",
    "similar_article_id",
    "EuclideanDistance"
).distinct()

# Step 3: Filter out articles the user has already interacted with
filtered_recommendations = potential_recommendations.join(
    broadcast(user_item_interactions),
    (potential_recommendations.user_id_index == user_item_interactions.user_id_index) & 
    (potential_recommendations.similar_article_id == user_item_interactions.news_article_id_index),
    "left_anti"
)
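
The inner join and the `left_anti` join above can be traced on a small example. The sketch below mirrors the two steps with plain Python sets; the toy users, articles and distances are hypothetical.

```python
# Hypothetical toy inputs mirroring the DataFrames above.
user_item_interactions = {("u1", "a"), ("u1", "b"), ("u2", "c")}
flat_similar_items = [
    # (article_id, similar_article_id, distance)
    ("a", "a", 0.0),
    ("a", "b", 0.4),
    ("a", "c", 0.9),
    ("b", "c", 0.3),
    ("c", "a", 0.9),
]

# Inner join on article_id, keeping distinct (user, candidate, distance) rows.
potential = {
    (user, similar, dist)
    for user, article in user_item_interactions
    for source, similar, dist in flat_similar_items
    if source == article
}

# Left anti join: drop candidates the user has already interacted with.
filtered = sorted(
    row for row in potential if (row[0], row[1]) not in user_item_interactions
)
print(filtered)
```

Note that the self-pair `("a", "a", 0.0)` disappears in the anti join, because a user who produced it has by definition already clicked `a`.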


In [None]:
def get_top_n_recommendations(user_id, N=5):
    # Check if the user_id mapping to user_id_index is successful and the user ID exists in the dataset
    user_id_index_row = user_id_index_df.filter(col("user_id") == user_id).select("user_id_index").first()
    if user_id_index_row is None:
        print(f"No user_id_index found for user_id {user_id}")
        return None
    user_id_index = user_id_index_row["user_id_index"]
    
    # Retrieve recommendations
    specific_user_recommendations = filtered_recommendations.filter(
        filtered_recommendations.user_id_index == user_id_index
    )
    
    # Aggregate and rank recommendations
    ranked_recommendations = specific_user_recommendations.groupBy("similar_article_id").agg(
        (1 / sum("EuclideanDistance")).alias("score")
    ).orderBy(desc("score")).limit(N)
    
    # Fetch the top N recommendations and map the indices back to article IDs,
    # reusing news_article_id_index_df built in the ALS cell
    top_n_recommendations = ranked_recommendations.join(
        broadcast(news_article_id_index_df), 
        ranked_recommendations.similar_article_id == news_article_id_index_df.news_article_id_index
    ).select("news_article", "score")
    
    return top_n_recommendations

# Usage example (the function returns None for an unknown user):
top_n_recommendations = get_top_n_recommendations("U80234", 5)
if top_n_recommendations is not None:
    top_n_recommendations.show()


+------------+------------------+
|news_article|             score|
+------------+------------------+
|      N25677|0.7655715202494393|
|      N31213| 0.760682585315338|
|      N28053|0.7467859080521225|
|      N49469|0.7434247009918218|
|      N45330|0.7433255580554285|
+------------+------------------+
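
The `score` column is the inverse of the summed Euclidean distances per candidate article. The sketch below reproduces that aggregation in plain Python on hypothetical rows; skipping a zero sum to avoid division by zero is a choice of the sketch.

```python
from collections import defaultdict


def rank(rows, n=5):
    """Rank candidates; rows are (similar_article_id, distance) pairs for one user."""
    total = defaultdict(float)
    for article, distance in rows:
        total[article] += distance
    # score = 1 / sum of distances, as in the groupBy/agg step above.
    scores = {a: 1.0 / d for a, d in total.items() if d > 0}
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))[:n]


# Hypothetical candidate rows for one user.
rows = [("c", 0.25), ("c", 0.75), ("d", 0.5), ("e", 2.0)]
top = rank(rows, n=2)
print(top)
```

In this toy data, `c` is close to two clicked articles, yet it ranks below `d`, which is close to only one, because the distances are summed before inverting. If that effect is not wanted, aggregating with the minimum distance per candidate would be one alternative to consider.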

