In [1]:
# Enable IPython's autoreload extension. Mode 2 lets edits to imported project
# modules be picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Suppress native-hadoop warning 
!sed -i '$a\# Add the line for suppressing the NativeCodeLoader warning \nlog4j.logger.org.apache.hadoop.util.NativeCodeLoader=ERROR,console' /$HADOOP_HOME/etc/hadoop/log4j.properties

In [3]:
import sys

# Project root: made importable (so project packages resolve) and reused later
# when building the path the trained model is saved to.
BASE_DIR = '/home/work'
sys.path.append(BASE_DIR)

In [4]:
import pyspark
from pyspark.sql import SparkSession, functions as F
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator

from data.utils.data_loader import load_from_hdfs

In [5]:
# Set Spark Settings: local mode with 8 worker threads and a 14g driver.
conf = (
    pyspark.SparkConf()
    .setMaster('local[8]')             # stored under 'spark.master'
    .setAppName('MusicRecommender')    # stored under 'spark.app.name'
    # Executor settings are currently disabled:
    # .set('spark.executor.instances', '2')  # Number of executors
    # .set('spark.executor.cores', '8')  # Cores per executor
    # .set('spark.executor.memory', '10g')  # Memory per executor
    .set('spark.driver.memory', '14g')
)

# Reuse an already-running session if there is one, otherwise start a new one.
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Print Spark Settings (one (key, value) tuple per line) to show what is in effect.
settings = spark.sparkContext.getConf().getAll()
for s in settings:
    print(s)

('spark.driver.port', '33363')
('spark.master', 'local[8]')
('spark.executor.id', 'driver')
('spark.driver.host', '693f94dcf7da')
('spark.app.startTime', '1717122931426')
('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHa

In [6]:
# Name used later in the notebook as the last path component of the saved model.
model_name = 'user_song_balanced'

# Dataset identifier handed to the project's HDFS loader.
dataset = 'processed/user_song_balanced'

# NOTE(review): the meaning of the second argument (1) is not visible from this notebook.
# The recorded run failed with PATH_NOT_FOUND for
# hdfs://localhost:9000/data/processed/user_song_balanced/train/train_0.txt;
# confirm the dataset has been written to HDFS before running this cell.
train_data, test_data = load_from_hdfs(dataset, 1)

AnalysisException: [PATH_NOT_FOUND] Path does not exist: hdfs://localhost:9000/data/processed/user_song_balanced/train/train_0.txt.

In [None]:
# Report how many training rows were loaded, then preview the first five.
train_row_count = train_data.count()
print("Train data size: ", train_row_count)
train_data.show(5)

In [None]:
# Print the training DataFrame's schema; the ALS cell below expects the
# user_id, song_id and rating columns to be present.
train_data.printSchema()

In [None]:
# Define ALS model
# - coldStartStrategy="drop" discards predictions for users/items unseen during
#   fitting, so the RMSE evaluation is not poisoned by NaN predictions.
# - seed is fixed so factor initialisation is reproducible across kernel restarts
#   (without it the Python-side default seed can differ between processes).
als = ALS(
    userCol="user_id",
    itemCol="song_id",
    ratingCol="rating",
    coldStartStrategy="drop",
    seed=42,
)

In [None]:
# Tune model hyperparameters
# Each grid currently holds a single value, so exactly one parameter
# combination (maxIter=5, regParam=0.1, rank=10) is evaluated.
param_grid = (
    ParamGridBuilder()
    .addGrid(als.maxIter, [5])
    .addGrid(als.regParam, [0.1])
    .addGrid(als.rank, [10])
    .build()
)

# Define a model evaluator: RMSE between the "rating" label and the model's "prediction".
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")

# Build cross validator (disabled alternative to the train/validation split below)
# crossval = CrossValidator(estimator=als,
#                             estimatorParamMaps=param_grid,
#                             evaluator=evaluator,
#                             numFolds=5)

# Single 80/20 train/validation split. The seed is fixed so the split, and
# therefore the reported validation metric, is reproducible between runs.
tvs = TrainValidationSplit(
    estimator=als,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    trainRatio=0.8,
    seed=42,
)

In [None]:
# Train ALS model, report the best model and persist it to the local filesystem.
import traceback  # local import: only needed for error reporting in this cell

try:
    # clear spark cache to avoid memory issues
    spark.catalog.clearCache()

    # Fit ALS model
    model = tvs.fit(train_data)
    # model = crossval.fit(train_data)

    # Get best model
    best_model = model.bestModel
    print(f'Best model: {best_model}')

    # Save the best model for later evaluation (overwrites any previous save at this path)
    best_model.write().overwrite().save(f'file://{BASE_DIR}/models/als_model/{model_name}')

except Exception as e:
    # The error is still not re-raised (original best-effort behaviour kept), so the
    # remaining cells can run. The full traceback is printed in addition to the
    # message so that a failure can actually be diagnosed.
    print(f'Error training ALS model: {e}')
    traceback.print_exc()

In [None]:
# Shut down the Spark session now that training is finished.
spark.stop()