In [1]:
# Automatically reload imported modules (e.g. the project's `data` and `models`
# packages) before each cell executes, so edits to them take effect without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# Suppress the native-hadoop NativeCodeLoader warning by raising that logger's
# threshold to ERROR in Hadoop's log4j config. The append is guarded so that
# re-running the notebook does not add a duplicate entry on every run.
import os
from pathlib import Path

log4j_path = Path(os.environ['HADOOP_HOME']) / 'etc' / 'hadoop' / 'log4j.properties'
native_loader_rule = 'log4j.logger.org.apache.hadoop.util.NativeCodeLoader=ERROR,console'
log4j_text = log4j_path.read_text()
if native_loader_rule not in (line.strip() for line in log4j_text.splitlines()):
    with log4j_path.open('a') as log4j_file:
        # Start on a fresh line even if the file lacks a trailing newline.
        if log4j_text and not log4j_text.endswith('\n'):
            log4j_file.write('\n')
        log4j_file.write('# Add the line for suppressing the NativeCodeLoader warning\n')
        log4j_file.write(native_loader_rule + '\n')

In [3]:
import sys

# Project root: makes the local `data` and `models` packages importable and
# anchors the model paths used in the evaluation cell.
BASE_DIR = '/home/work'

# Guard the append so re-running this cell does not keep growing sys.path.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

In [4]:
import pyspark
from pyspark.sql import SparkSession, functions as F
from pyspark.ml import Pipeline
from pyspark.ml.recommendation import ALSModel
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DoubleType

from data.utils.data_loader import load_from_hdfs
from models.utils import load_model
from models.evaluation_metrics import calculate_rmse, calculate_mae, calculate_song_coverage, calculate_user_coverage

In [5]:
# Set Spark Settings
# Configure and start (or reuse) the Spark session for this notebook.
spark_conf_pairs = [
    ('spark.master', 'local[4]'),  # run locally with 4 worker threads
    ('spark.app.name', 'MusicRecommender'),
    ('spark.driver.memory', '14g'),
]
conf = pyspark.SparkConf().setAll(spark_conf_pairs)
spark = (
    SparkSession.builder
    .config(conf=conf)
    .getOrCreate()
)

# Record the effective configuration in the cell output for reproducibility.
settings = spark.sparkContext.getConf().getAll()
for setting in settings:
    print(setting)

('spark.master', 'local[4]')
('spark.driver.port', '36169')
('spark.executor.id', 'driver')
('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false')
('spark.app.submitTime', '1716700201254')
('spark.app.startTime', 

## Evaluate Models

In [6]:
# Evaluation configuration: every (model_dir, model_type) pair is evaluated on
# each dataset/partition combination, appending one summary row to `results`.
results = []
model_dirs = ['als_model']
model_types = ['ALS']
datasets = ['raw']
partitions = [1]

# Static Variables: sizes of the full dataset, used as the denominators for
# the "Overall" coverage columns.
total_size = 717872016  # total interactions; not referenced below, kept for reference
total_users = 1823179
total_songs = 136736


def _distinct_count(df, column):
    """Return the number of distinct values of `column` in the Spark DataFrame `df`."""
    return df.select(column).distinct().count()


for model_dir, model_type in zip(model_dirs, model_types):
    # The model depends only on model_dir/model_type, so load it once per model
    # rather than once per dataset/partition combination.
    model_path = f'file://{BASE_DIR}/models/{model_dir}'
    model = load_model(model_type, model_path)

    for dataset in datasets:
        # NOTE(review): rows do not record `par`; with more than one partition
        # the rows for a dataset would be indistinguishable.
        for par in partitions:
            # Load Data
            train_data, test_data = load_from_hdfs(dataset, par)

            # `predictions` is consumed by four separate Spark actions below
            # (two distinct counts, RMSE, MAE); cache it so model.transform is
            # evaluated once instead of once per action.
            predictions = model.transform(test_data).cache()

            # User Metrics
            train_users = _distinct_count(train_data, 'user_id')
            test_users = _distinct_count(test_data, 'user_id')
            prediction_users = _distinct_count(predictions, 'user_id')

            # Song Metrics
            train_songs = _distinct_count(train_data, 'song_id')
            test_songs = _distinct_count(test_data, 'song_id')
            prediction_songs = _distinct_count(predictions, 'song_id')

            # Coverage Metrics: presumably the ratio of predicted distinct
            # users/songs to the reference count (the 0.11 overall user coverage
            # in the results matches 200000 / 1823179). Note the "Train/Test"
            # column uses the *training* set's distinct count as its reference.
            song_coverage = calculate_song_coverage(train_songs, prediction_songs)
            overall_song_coverage = calculate_song_coverage(total_songs, prediction_songs)
            user_coverage = calculate_user_coverage(train_users, prediction_users)
            overall_user_coverage = calculate_user_coverage(total_users, prediction_users)

            # Evaluation Metrics
            rmse = calculate_rmse(predictions)
            mae = calculate_mae(predictions)

            results.append({
                'Model': model_type,
                'Dataset': dataset,
                'Users(Train:Test)': f'{train_users} : {test_users}',
                'Songs(Train:Test)': f'{train_songs} : {test_songs}',
                'User Coverage(Train/Test:Overall)': f'{round(user_coverage, 2)} : {round(overall_user_coverage, 2)}',
                'Song Coverage(Train/Test:Overall)': f'{round(song_coverage, 2)} : {round(overall_song_coverage, 2)}',
                'RMSE': rmse,
                'MAE': mae,
            })

            # Release the cached predictions before the next combination.
            predictions.unpersist()

                                                                                

## Evaluation Results

In [7]:
# Evaluation Results Schema
schema = StructType([
    StructField("Model", StringType(), True),
    StructField("Dataset", StringType(), True),
    StructField("Users(Train:Test)", StringType(), True),
    StructField("Songs(Train:Test)", StringType(), True),
    StructField("User Coverage(Train/Test:Overall)", StringType(), True),
    StructField("Song Coverage(Train/Test:Overall)", StringType(), True),
    # StructField("Precision", FloatType(), True),
    StructField("RMSE", FloatType(), True),
    StructField("MAE", FloatType(), True),
])

# Output Results as DF
results_df = spark.createDataFrame(results, schema)
results_df.show(truncate=False)

+-----+-------+-----------------+-----------------+---------------------------------+---------------------------------+---------+----------+
|Model|Dataset|Users(Train:Test)|Songs(Train:Test)|User Coverage(Train/Test:Overall)|Song Coverage(Train/Test:Overall)|RMSE     |MAE       |
+-----+-------+-----------------+-----------------+---------------------------------+---------------------------------+---------+----------+
|ALS  |raw    |200000 : 200000  |136736 : 127771  |1.0 : 0.11                       |0.93 : 0.93                      |1.1547565|0.91526073|
+-----+-------+-----------------+-----------------+---------------------------------+---------------------------------+---------+----------+

