# Hybrid Recommendation System

**Team Structure:**
- Member 1: Infrastructure, Data Loading, Fusion & Evaluation
- Member 2: Collaborative Filtering (ALS)
- Member 3: Content-Based Filtering (TF-IDF + LSH)

In [5]:
import os
import urllib.request
import zipfile
from math import log2

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, FloatType, StructType, StructField

## Data Download

In [6]:
DATA_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
DATA_DIR = "data"
DATASET_DIR = os.path.join(DATA_DIR, "ml-1m")
ZIP_PATH = os.path.join(DATA_DIR, "ml-1m.zip")

In [7]:
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

In [8]:
if not os.path.exists(DATASET_DIR):

    if not os.path.exists(ZIP_PATH):
        print("Downloading MovieLens ml-1m...")
        urllib.request.urlretrieve(DATA_URL, ZIP_PATH)
    
    print("Extracting...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(DATA_DIR)

## Spark Session

In [9]:
spark = SparkSession.builder.appName("MMDS").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

## Load Data

In [10]:
users_df = spark.read.text(os.path.join(DATASET_DIR, "users.dat")).select(
    F.split(F.col("value"), "::").getItem(0).cast(IntegerType()).alias("user_id"),
    F.split(F.col("value"), "::").getItem(1).alias("gender"),
    F.split(F.col("value"), "::").getItem(2).cast(IntegerType()).alias("age"),
    F.split(F.col("value"), "::").getItem(3).cast(IntegerType()).alias("occupation"),
    F.split(F.col("value"), "::").getItem(4).alias("zip_code")
)
users_df.count()

6040

In [11]:
items_df = spark.read.text(os.path.join(DATASET_DIR, "movies.dat")).select(
    F.split(F.col("value"), "::").getItem(0).cast(IntegerType()).alias("item_id"),
    F.split(F.col("value"), "::").getItem(1).alias("title"),
    F.split(F.col("value"), "::").getItem(2).alias("genres")
)
items_df.count()

3883

In [12]:
ratings_df = spark.read.text(os.path.join(DATASET_DIR, "ratings.dat")).select(
    F.split(F.col("value"), "::").getItem(0).cast(IntegerType()).alias("user_id"),
    F.split(F.col("value"), "::").getItem(1).cast(IntegerType()).alias("item_id"),
    F.split(F.col("value"), "::").getItem(2).cast(FloatType()).alias("rating"),
    F.split(F.col("value"), "::").getItem(3).cast(IntegerType()).alias("timestamp")
)
ratings_df.count()

1000209

In [13]:
users_df.show(5)

+-------+------+---+----------+--------+
|user_id|gender|age|occupation|zip_code|
+-------+------+---+----------+--------+
|      1|     F|  1|        10|   48067|
|      2|     M| 56|        16|   70072|
|      3|     M| 25|        15|   55117|
|      4|     M| 45|         7|   02460|
|      5|     M| 25|        20|   55455|
+-------+------+---+----------+--------+
only showing top 5 rows


In [14]:
items_df.show(5)

+-------+--------------------+--------------------+
|item_id|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Animation|Childre...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|        Comedy|Drama|
|      5|Father of the Bri...|              Comedy|
+-------+--------------------+--------------------+
only showing top 5 rows


In [15]:
ratings_df.show(5)

+-------+-------+------+---------+
|user_id|item_id|rating|timestamp|
+-------+-------+------+---------+
|      1|   1193|   5.0|978300760|
|      1|    661|   3.0|978302109|
|      1|    914|   3.0|978301968|
|      1|   3408|   4.0|978300275|
|      1|   2355|   5.0|978824291|
+-------+-------+------+---------+
only showing top 5 rows


## Exploratory Data Analysis

In [16]:
num_users = users_df.count()
num_items = items_df.count()
num_ratings = ratings_df.count()
sparsity = (1 - (num_ratings / (num_users * num_items))) * 100

print(f"Users:            {num_users:,}")
print(f"Movies:           {num_items:,}")
print(f"Ratings:          {num_ratings:,}")
print(f"Sparsity:         {sparsity:.2f}%")
print(f"Avg ratings/user: {num_ratings/num_users:.1f}")
print(f"Avg ratings/movie:{num_ratings/num_items:.1f}")

Users:            6,040
Movies:           3,883
Ratings:          1,000,209
Sparsity:         95.74%
Avg ratings/user: 165.6
Avg ratings/movie:257.6


In [17]:
ratings_df.groupBy("rating").count().orderBy("rating").show()

+------+------+
|rating| count|
+------+------+
|   1.0| 56174|
|   2.0|107557|
|   3.0|261197|
|   4.0|348971|
|   5.0|226310|
+------+------+



In [18]:
items_df.select(F.explode(F.split(F.col("genres"), "\\|")).alias("genre")) \
    .groupBy("genre").count().orderBy(F.desc("count")).show()

+-----------+-----+
|      genre|count|
+-----------+-----+
|      Drama| 1603|
|     Comedy| 1200|
|     Action|  503|
|   Thriller|  492|
|    Romance|  471|
|     Horror|  343|
|  Adventure|  283|
|     Sci-Fi|  276|
| Children's|  251|
|      Crime|  211|
|        War|  143|
|Documentary|  127|
|    Musical|  114|
|    Mystery|  106|
|  Animation|  105|
|    Fantasy|   68|
|    Western|   68|
|  Film-Noir|   44|
+-----------+-----+



In [19]:
users_df.groupBy("gender").count().show()

+------+-----+
|gender|count|
+------+-----+
|     F| 1709|
|     M| 4331|
+------+-----+



In [20]:
users_df.groupBy("age").count().orderBy("age").show()

+---+-----+
|age|count|
+---+-----+
|  1|  222|
| 18| 1103|
| 25| 2096|
| 35| 1193|
| 45|  550|
| 50|  496|
| 56|  380|
+---+-----+



## Train/Test Split

In [21]:
train_df, test_df = ratings_df.randomSplit([0.8, 0.2], seed=42)
train_df = train_df.cache()
test_df = test_df.cache()

In [22]:
train_df.count()

800092

In [23]:
test_df.count()

200117

# Collaborative Filtering (ALS)

Implement using `pyspark.ml.recommendation.ALS`

In [24]:
# TODO

als_model = None

In [25]:
als_recs = spark.createDataFrame([], StructType([
    StructField("user_id", IntegerType()),
    StructField("item_id", IntegerType()),
    StructField("als_score", FloatType())
]))

In [26]:
als_recs.show(10)

Py4JJavaError: An error occurred while calling o271.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 41.0 failed 1 times, most recent failure: Lost task 0.0 in stage 41.0 (TID 64) (ThinkBook executor driver): java.io.IOException: Cannot run program "python3": CreateProcess error=2, The system cannot find the file specified
	at java.base/java.lang.ProcessBuilder.start(ProcessBuilder.java:1143)
	at java.base/java.lang.ProcessBuilder.start(ProcessBuilder.java:1073)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:218)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:143)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:158)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:178)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:261)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:70)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: java.io.IOException: CreateProcess error=2, The system cannot find the file specified
	at java.base/java.lang.ProcessImpl.create(Native Method)
	at java.base/java.lang.ProcessImpl.<init>(ProcessImpl.java:505)
	at java.base/java.lang.ProcessImpl.start(ProcessImpl.java:158)
	at java.base/java.lang.ProcessBuilder.start(ProcessBuilder.java:1110)
	... 35 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$3(DAGScheduler.scala:2935)
	at scala.Option.getOrElse(Option.scala:201)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2935)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2927)
	at scala.collection.immutable.List.foreach(List.scala:334)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2927)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1295)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1295)
	at scala.Option.foreach(Option.scala:437)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1295)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3207)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3141)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3130)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:50)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1009)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2484)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2505)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2524)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:544)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:497)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:58)
	at org.apache.spark.sql.classic.Dataset.collectFromPlan(Dataset.scala:2244)
	at org.apache.spark.sql.classic.Dataset.$anonfun$head$1(Dataset.scala:1379)
	at org.apache.spark.sql.classic.Dataset.$anonfun$withAction$2(Dataset.scala:2234)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:654)
	at org.apache.spark.sql.classic.Dataset.$anonfun$withAction$1(Dataset.scala:2232)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$8(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.withSessionTagsApplied(SQLExecution.scala:272)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$7(SQLExecution.scala:125)
	at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
	at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:112)
	at org.apache.spark.sql.artifact.ArtifactManager.withClassLoaderIfNeeded(ArtifactManager.scala:106)
	at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:111)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:295)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$1(SQLExecution.scala:124)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:804)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId0(SQLExecution.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:237)
	at org.apache.spark.sql.classic.Dataset.withAction(Dataset.scala:2232)
	at org.apache.spark.sql.classic.Dataset.head(Dataset.scala:1379)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2810)
	at org.apache.spark.sql.classic.Dataset.getRows(Dataset.scala:339)
	at org.apache.spark.sql.classic.Dataset.showString(Dataset.scala:375)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:569)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:184)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:108)
	at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: java.io.IOException: Cannot run program "python3": CreateProcess error=2, The system cannot find the file specified
	at java.base/java.lang.ProcessBuilder.start(ProcessBuilder.java:1143)
	at java.base/java.lang.ProcessBuilder.start(ProcessBuilder.java:1073)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:218)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:143)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:158)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:178)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:261)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:70)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	... 1 more
Caused by: java.io.IOException: CreateProcess error=2, The system cannot find the file specified
	at java.base/java.lang.ProcessImpl.create(Native Method)
	at java.base/java.lang.ProcessImpl.<init>(ProcessImpl.java:505)
	at java.base/java.lang.ProcessImpl.start(ProcessImpl.java:158)
	at java.base/java.lang.ProcessBuilder.start(ProcessBuilder.java:1110)
	... 35 more


### Bonus: Hyperparameter Tuning

In [None]:
# TODO

# Content-Based Filtering (TF-IDF + LSH)

Implement using `pyspark.ml.feature` (Tokenizer, HashingTF, IDF, BucketedRandomProjectionLSH)

In [None]:
content_recs = spark.createDataFrame([], StructType([
    StructField("user_id", IntegerType()),
    StructField("item_id", IntegerType()),
    StructField("content_score", FloatType())
]))

In [None]:
content_recs.show(10)

# Fusion & Evaluation

In [None]:
ALPHA = 0.7
K = 10
RELEVANCE_THRESHOLD = 4.0

### Normalization

In [None]:
def normalize(df, col_name):
    stats = df.agg(F.min(col_name).alias("min"), F.max(col_name).alias("max")).collect()[0]
    if stats["max"] == stats["min"]:
        return df.withColumn(col_name + "_norm", F.lit(0.5))
    return df.withColumn(col_name + "_norm", (F.col(col_name) - stats["min"]) / (stats["max"] - stats["min"]))

### Hybrid Fusion

In [None]:
als_norm = normalize(als_recs, "als_score")
content_norm = normalize(content_recs, "content_score")

hybrid_recs = als_norm.select("user_id", "item_id", "als_score_norm") \
    .join(content_norm.select("user_id", "item_id", "content_score_norm"), ["user_id", "item_id"], "full_outer") \
    .fillna(0) \
    .withColumn("final_score", ALPHA * F.col("als_score_norm") + (1 - ALPHA) * F.col("content_score_norm"))

In [None]:
hybrid_recs.orderBy(F.desc("final_score")).show(10)

### Ground Truth

In [None]:
ground_truth = test_df.filter(F.col("rating") >= RELEVANCE_THRESHOLD) \
    .groupBy("user_id").agg(F.collect_list("item_id").alias("relevant_items"))

### Evaluation Functions

In [None]:
def get_top_k(recs_df, score_col, k):
    window = Window.partitionBy("user_id").orderBy(F.desc(score_col))
    return recs_df.withColumn("rank", F.row_number().over(window)) \
        .filter(F.col("rank") <= k) \
        .groupBy("user_id").agg(F.collect_list("item_id").alias("recommended_items"))

In [None]:
def precision_at_k(top_k_df, ground_truth_df, k):
    joined = top_k_df.join(ground_truth_df, "user_id")
    result = joined.withColumn("hits", F.size(F.array_intersect("recommended_items", "relevant_items"))) \
        .agg(F.avg(F.col("hits") / k)).collect()[0][0]
    return result or 0.0

In [None]:
def recall_at_k(top_k_df, ground_truth_df):
    joined = top_k_df.join(ground_truth_df, "user_id")
    result = joined.withColumn("hits", F.size(F.array_intersect("recommended_items", "relevant_items"))) \
        .withColumn("recall", F.when(F.size("relevant_items") > 0, F.col("hits") / F.size("relevant_items")).otherwise(0)) \
        .agg(F.avg("recall")).collect()[0][0]
    return result or 0.0

In [None]:
def ndcg_at_k(top_k_df, ground_truth_df, k):
    joined = top_k_df.join(ground_truth_df, "user_id")
    exploded = joined.select("user_id", "relevant_items", F.posexplode("recommended_items").alias("pos", "item_id"))
    
    with_dcg = exploded \
        .withColumn("rel", F.when(F.array_contains("relevant_items", "item_id"), 1.0).otherwise(0.0)) \
        .withColumn("dcg", F.col("rel") / F.log2(F.col("pos") + 2)) \
        .groupBy("user_id", "relevant_items").agg(F.sum("dcg").alias("dcg"))
    
    idcg_vals = [sum(1.0 / log2(i + 2) for i in range(n)) for n in range(k + 1)]
    idcg_map = F.create_map(*[x for i, v in enumerate(idcg_vals) for x in (F.lit(i), F.lit(v))])
    
    result = with_dcg \
        .withColumn("num_rel", F.least(F.size("relevant_items"), F.lit(k))) \
        .withColumn("idcg", idcg_map[F.col("num_rel")]) \
        .withColumn("ndcg", F.when(F.col("idcg") > 0, F.col("dcg") / F.col("idcg")).otherwise(0)) \
        .agg(F.avg("ndcg")).collect()[0][0]
    return result or 0.0

In [None]:
def evaluate(recs_df, score_col, name):
    if recs_df.count() == 0:
        print(f"{name}: No recommendations (not implemented)")
        return {"Precision@10": 0, "Recall@10": 0, "NDCG@10": 0}
    
    top_k = get_top_k(recs_df, score_col, K)
    p = precision_at_k(top_k, ground_truth, K)
    r = recall_at_k(top_k, ground_truth)
    n = ndcg_at_k(top_k, ground_truth, K)
    
    print(f"{name}: P@{K}={p:.4f}, R@{K}={r:.4f}, NDCG@{K}={n:.4f}")
    return {"Precision@10": p, "Recall@10": r, "NDCG@10": n}

### Evaluation

In [None]:
als_metrics = evaluate(als_recs, "als_score", "ALS")

In [None]:
content_metrics = evaluate(content_recs, "content_score", "Content-Based")

In [None]:
hybrid_metrics = evaluate(hybrid_recs, "final_score", "Hybrid")

### Bonus: GBT Re-Ranking

In [None]:
# TODO

## Results Summary

In [None]:
summary = [
    ("ALS", als_metrics["Precision@10"], als_metrics["Recall@10"], als_metrics["NDCG@10"]),
    ("Content-Based", content_metrics["Precision@10"], content_metrics["Recall@10"], content_metrics["NDCG@10"]),
    ("Hybrid", hybrid_metrics["Precision@10"], hybrid_metrics["Recall@10"], hybrid_metrics["NDCG@10"]),
]
spark.createDataFrame(summary, ["Model", "Precision@10", "Recall@10", "NDCG@10"]).show()