In [1]:
# Suppress native-hadoop warning
!sed -i '$a\# Add the line for suppressing the NativeCodeLoader warning \nlog4j.logger.org.apache.hadoop.util.NativeCodeLoader=ERROR,console' /$HADOOP_HOME/etc/hadoop/log4j.properties

In [2]:
import sys

# Project root: contains the `data` package used by the loader import and the
# `models/` directory the trained ALS model is written to.
BASE_DIR = '/home/work'

# Make project-local packages importable. Guarded so re-running this cell does not
# push a duplicate entry onto sys.path each time.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

In [3]:
import pyspark
from pyspark.sql import SparkSession, functions as F
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.ml.evaluation import RegressionEvaluator

from data.utils.data_loader import load_from_hdfs

In [4]:
# Spark session configuration for a single-machine run. `local[*]` runs one worker
# thread per available core inside the driver process.
# NOTE(review): executor-level settings (spark.executor.instances / .cores / .memory)
# were left disabled here; presumably they only matter on a cluster manager.
spark_options = {
    'spark.master': 'local[*]',
    'spark.app.name': 'MusicRecommender',
    'spark.driver.memory': '14g',
}
conf = pyspark.SparkConf().setAll(spark_options.items())
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Echo the effective configuration, one (key, value) pair per line.
settings = spark.sparkContext.getConf().getAll()
for key, value in settings:
    print((key, value))

('spark.app.startTime', '1715810974721')
('spark.driver.port', '33083')
('spark.executor.id', 'driver')
('spark.driver.host', '50278eaf4c0f')
('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false')
('spark.app.id',

In [5]:
# Dataset variants available on HDFS for this notebook.
# NOTE(review): the loader call passes 'raw' directly; no visible cell reads this list.
datasets = [
    "raw",
]

In [6]:
# Load the first declared dataset variant ("raw") as (train, test) DataFrames, taking the
# name from `datasets` instead of repeating the literal.
# NOTE(review): the meaning of the second argument (1) is defined by load_from_hdfs — confirm.
train_data, test_data = load_from_hdfs(datasets[0], 1)

                                                                                

In [7]:
# Row count of the training split (forces a full pass over the data), then a 5-row preview.
train_row_count = train_data.count()
print("Train data size: ", train_row_count)
train_data.show(n=5)

                                                                                

Train data size:  76344627
+-------+-------+------+------------+
|user_id|song_id|rating|partition_id|
+-------+-------+------+------------+
|      0|    166|     5|           0|
|      0|   2245|     4|           0|
|      0|   3637|     4|           0|
|      0|   5580|     4|           0|
|      0|   5859|     4|           0|
+-------+-------+------+------------+
only showing top 5 rows



In [8]:
# Print column names, types and nullability of the training split as a tree.
train_data.printSchema()

root
 |-- user_id: integer (nullable = true)
 |-- song_id: integer (nullable = true)
 |-- rating: integer (nullable = true)
 |-- partition_id: integer (nullable = false)



In [9]:
# Fixed seed so ALS factor initialisation is reproducible across kernel restarts;
# without an explicit seed, re-runs are not guaranteed to produce the same model.
SEED = 42

# Define ALS model on explicit ratings (user_id x song_id -> rating).
# coldStartStrategy="drop" removes prediction rows for users/items unseen during fitting,
# so the RMSE evaluator never receives NaN predictions.
als = ALS(userCol="user_id", itemCol="song_id", ratingCol="rating",
          coldStartStrategy="drop", seed=SEED)

In [10]:
# Tune model hyperparameters. Each grid axis has a single value, so only one candidate
# parameter map is evaluated; widen these lists to perform an actual search.
param_grid = (
    ParamGridBuilder()
    .addGrid(als.maxIter, [5])
    .addGrid(als.regParam, [0.1])
    .addGrid(als.rank, [5])
    .build()
)

# Define a model evaluator: RMSE between actual and predicted ratings (lower is better).
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")

# Build a train/validation split. This is one random split, not k-fold cross-validation.
# trainRatio=0.75 is Spark's default, written out for visibility. The seed is taken from the
# ALS estimator, so the split is reproducible whenever that estimator's seed is fixed.
tvs = TrainValidationSplit(estimator=als,
                           estimatorParamMaps=param_grid,
                           evaluator=evaluator,
                           trainRatio=0.75,
                           seed=als.getSeed())

In [11]:
# Train ALS model: fit the train/validation split, keep the best model and persist it.
# NOTE(review): test_data is loaded earlier but not evaluated in any visible cell.
MODEL_PATH = f'file://{BASE_DIR}/models/als_model'

try:
    # clear spark cache to avoid memory issues
    spark.catalog.clearCache()

    # Fit; TrainValidationSplit re-fits the winning param map on the full train_data.
    model = tvs.fit(train_data)

    # Get best model
    best_model = model.bestModel
    print(f'Best model: {best_model}')

    # Save the best model for later evaluation. A plain .save() raises when the path
    # already exists, so every re-run of this cell used to fail here; overwrite()
    # replaces the previously saved model instead.
    best_model.write().overwrite().save(MODEL_PATH)

except Exception as e:
    print(f'Error training ALS model: {e}')
    # Re-raise so Run All stops on failure instead of continuing with no saved model.
    raise

                                                                                

Best model: ALSModel: uid=ALS_bd27ca2bffba, rank=5


                                                                                

In [12]:
# Shut down the SparkSession (and its underlying SparkContext) to release driver resources.
spark.stop()