In [1]:
# 🔹 Imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lower, lit
from pyspark.ml.recommendation import ALS
from pyspark.ml.feature import StringIndexer
from pyspark.ml.evaluation import RegressionEvaluator
import utils.config as config

In [2]:
# 🔹 Create Spark session
spark = SparkSession.builder \
    .appName("CollaborativeFilteringSearch") \
    .getOrCreate()

In [3]:
# 🔹 Load review data
df = spark.read.json(config.PHILADELPHIA)
df = df.select("user_id", "name", "business_stars").dropna()
df = df.filter(df.business_stars > 0)

In [4]:

# 🔹 Index users and restaurants
user_indexer = StringIndexer(inputCol="user_id", outputCol="userIndex")
item_indexer = StringIndexer(inputCol="name", outputCol="itemIndex")
df = user_indexer.fit(df).transform(df)
df = item_indexer.fit(df).transform(df)

In [16]:
df.show()

+--------------------+--------------------+--------------+---------+---------+
|             user_id|                name|business_stars|userIndex|itemIndex|
+--------------------+--------------------+--------------+---------+---------+
|sqkiFAnk4gmL1LYmZ...|Waterfront Gourme...|           4.0| 195551.0|   1824.0|
|7RuSAc-Mslk4aizXX...|Waterfront Gourme...|           4.0|    425.0|   1824.0|
|xa0aM4h8FZYGHAFZ0...|Waterfront Gourme...|           4.0|  14729.0|   1824.0|
|QAVwfV8qy6meUjIZp...|Waterfront Gourme...|           4.0|   9279.0|   1824.0|
|nmmaBI8t0JN4Hkxay...|Waterfront Gourme...|           4.0|   1889.0|   1824.0|
|d_QKFVZuYDm0wzhkq...|Waterfront Gourme...|           4.0|  72361.0|   1824.0|
|0KCOEsM1WGKYUg6ey...|Waterfront Gourme...|           4.0|   2495.0|   1824.0|
|b47MFJu3LYjv6xmCR...|Waterfront Gourme...|           4.0|    831.0|   1824.0|
|ePmanjMTkYwpO65_9...|Waterfront Gourme...|           4.0|   4786.0|   1824.0|
|9b74lTGD6blywdUYt...|Waterfront Gourme...|         

In [5]:
%%time

# 🔹 Train ALS model
als = ALS(
    maxIter=10,
    regParam=0.1,
    userCol="userIndex",
    itemCol="itemIndex",
    ratingCol="business_stars",
    coldStartStrategy="drop",
    nonnegative=True
)
model = als.fit(df)

CPU times: total: 15.6 ms
Wall time: 40.7 s


In [17]:
%%time

# 🔹 Function to search and get similar restaurants
def recommend_similar_restaurants_by_name(query, top_n=5):
    matched_items = df.filter(lower(col("name")).contains(query.lower())) \
                      .select("name", "itemIndex").distinct().limit(1).collect()

    if not matched_items:
        print(f"No match found for: {query}")
        return

    item_idx = matched_items[0]["itemIndex"]
    matched_name = matched_items[0]["name"]
    print(f"🔍 Matched: {matched_name} (itemIndex={item_idx})")

    # Simulate recommendations for a dummy user
    dummy_user_index = 14729.0
    all_items = df.select("itemIndex").distinct().withColumn("userIndex", lit(dummy_user_index))
    predictions = model.transform(all_items)

    similar_items = predictions.orderBy(col("prediction").desc()) \
        .join(df.select("itemIndex", "name").distinct(), on="itemIndex") \
        .filter(col("itemIndex") != item_idx) \
        .dropDuplicates(["itemIndex"]) \
        .select("name", "prediction") \
        .limit(top_n)

    similar_items.show(truncate=False)

CPU times: total: 0 ns
Wall time: 0 ns


In [18]:
%%time
# 🔹 Example usage: Just like content-based
recommend_similar_restaurants_by_name("Real Food Eatery", top_n=5)

🔍 Matched: Real Food Eatery (itemIndex=461.0)
+-----------------------+----------+
|name                   |prediction|
+-----------------------+----------+
|Reading Terminal Market|4.258159  |
|Pat's King of Steaks   |2.8640873 |
|Sabrina's Café         |3.800693  |
|Green Eggs Café        |3.6720803 |
|Geno's Steaks          |2.384353  |
+-----------------------+----------+

CPU times: total: 0 ns
Wall time: 6.65 s
