In [8]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import StringIndexer
from pyspark.sql.functions import col

# Initialize Spark session
spark = SparkSession.builder \
    .appName("RecommendationSystem") \
    .getOrCreate()

# Load the dataset into a DataFrame
file_path = "movies1.json"
df = spark.read.json(file_path)

# Prepare the data
data = df.select(col("user_id").alias("user"), 
                 col("product_id").alias("item"), 
                 col("score").alias("rating"))

# Index user and item columns
user_indexer = StringIndexer(inputCol="user", outputCol="userIndex")
item_indexer = StringIndexer(inputCol="item", outputCol="itemIndex")

data_indexed = user_indexer.fit(data).transform(data)
data_indexed = item_indexer.fit(data_indexed).transform(data_indexed)

# Select only the indexed columns
data_indexed = data_indexed.select(col("userIndex").alias("user"), 
                                   col("itemIndex").alias("item"), 
                                   col("rating"))

# Split the data into training and test sets
(training_data, test_data) = data_indexed.randomSplit([0.8, 0.2])

# Initialize ALS model
als = ALS(
    maxIter=10, 
    regParam=0.1, 
    rank=10, 
    userCol="user", 
    itemCol="item", 
    ratingCol="rating", 
    coldStartStrategy="drop"
)

# Fit the model
model = als.fit(training_data)

# Make predictions
predictions = model.transform(test_data)

# Evaluate the model
evaluator = RegressionEvaluator(
    metricName="rmse", 
    labelCol="rating", 
    predictionCol="prediction"
)
rmse = evaluator.evaluate(predictions)
print(f"Root-Mean-Square Error = {rmse:.4f}")
predictions.show()
# Stop the Spark session
spark.stop()




Root-Mean-Square Error = 1.8555
+------+------+------+----------+
|  user|  item|rating|prediction|
+------+------+------+----------+
|1959.0|  78.0|   5.0|  2.115138|
| 540.0|   7.0|   5.0|  4.985047|
|2393.0|  21.0|   5.0| 4.8314753|
|5670.0|   7.0|   5.0|  1.937971|
|1005.0|   7.0|   1.0| 0.9697579|
| 362.0|  37.0|   5.0| 4.1891227|
|3441.0|1117.0|   2.0| 2.0044537|
| 126.0|  63.0|   4.0|  4.014833|
| 830.0|   7.0|   4.0| 2.0984683|
| 183.0|  63.0|   5.0| 3.3185565|
|3445.0| 961.0|   1.0|  2.087726|
|4208.0| 303.0|   5.0|    4.9335|
|4958.0|  21.0|   5.0| 4.8314753|
|1415.0|  83.0|   4.0| 1.8284632|
|1415.0|  85.0|   3.0| 2.2572324|
| 412.0|   7.0|   5.0| 3.9537559|
|1030.0|  79.0|   3.0| 1.2394274|
|2871.0|  21.0|   5.0| 4.8314753|
|  26.0|  78.0|   4.0|  3.913481|
|  27.0|  63.0|   5.0| 3.9538033|
+------+------+------+----------+
only showing top 20 rows

