In [6]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType
from pyspark.sql import SparkSession, functions as F

In [2]:
# ====================== INITIALISATION ======================
spark = SparkSession.builder \
    .appName("TrainWeightedALSModel") \
    .master("yarn") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/02 09:52:42 WARN Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.


In [3]:
# ====================== CHARGEMENT DES DONN√âES ======================
weighted_schema = StructType([
    StructField("userId", IntegerType(), True),
    StructField("movieId", IntegerType(), True),
    StructField("rating", FloatType(), True),
    StructField("normalized_rating", FloatType(), True),
    StructField("confidence", FloatType(), True),
    StructField("user_mean", FloatType(), True),
    StructField("user_stddev", FloatType(), True),
    StructField("movie_count", IntegerType(), True),
    StructField("title", StringType(), True),
    StructField("genres", StringType(), True),
    StructField("timestamp", StringType(), True)
])

print("üì• Lecture du fichier CSV pond√©r√© depuis HDFS...")
# Si le fichier n'existe pas encore, ex√©cutez d'abord le script de pr√©traitement
try:
    weighted_df = spark.read.csv("hdfs:///processed/weighted_ratings.csv", header=True, schema=weighted_schema)
    print(f"‚úÖ Donn√©es pond√©r√©es charg√©es : {weighted_df.count()} lignes")
except:
    print("‚ùå Erreur: Fichier de ratings pond√©r√©s non trouv√©.")
    print("Ex√©cutez d'abord le script de pr√©traitement pour cr√©er hdfs:///processed/weighted_ratings.csv")
    exit(1)

üì• Lecture du fichier CSV pond√©r√© depuis HDFS...




‚úÖ Donn√©es pond√©r√©es charg√©es : 20000263 lignes


                                                                                

In [4]:
# ====================== NETTOYAGE ======================
# Supprimer les lignes avec valeurs nulles
weighted_df = weighted_df.dropna(subset=["userId", "movieId", "normalized_rating", "confidence"])

In [7]:
# ====================== DIVISION TRAIN / TEST ======================
train_df, test_df = weighted_df.randomSplit([0.8, 0.2], seed=42)
print(f"üìä Donn√©es divis√©es : {train_df.count()} pour l'entra√Ænement, {test_df.count()} pour le test")

# Log de quelques statistiques sur les donn√©es pond√©r√©es
print("üìà Aper√ßu des statistiques sur les donn√©es d'entra√Ænement:")
train_df.select(
    F.avg("confidence").alias("avg_conf"),
    F.min("confidence").alias("min_conf"),
    F.max("confidence").alias("max_conf"),
    F.avg("normalized_rating").alias("avg_norm_rating")
).show()

                                                                                

üìä Donn√©es divis√©es : 16000386 pour l'entra√Ænement, 3999877 pour le test
üìà Aper√ßu des statistiques sur les donn√©es d'entra√Ænement:


                                                                                

+-------------------+-----------+---------+--------------------+
|           avg_conf|   min_conf| max_conf|     avg_norm_rating|
+-------------------+-----------+---------+--------------------+
|0.02179536703537661|0.009849171|0.4738658|-3.17720320854605E-7|
+-------------------+-----------+---------+--------------------+



In [10]:
# ====================== ENTRA√éNEMENT ======================
print("ü§ñ Entra√Ænement du mod√®le ALS avec pond√©ration...")
als = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="confidence",   # utilis√© comme feedback implicite
    implicitPrefs=True,       # important !
    nonnegative=True,
    coldStartStrategy="drop",
    rank=12,
    maxIter=15,
    regParam=0.05
)

model = als.fit(train_df)

ü§ñ Entra√Ænement du mod√®le ALS avec pond√©ration...


25/05/02 10:01:32 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

In [11]:
# ====================== √âVALUATION ======================
print("üìä √âvaluation du mod√®le...")
# Faire des pr√©dictions
predictions = model.transform(test_df)

# D√©normaliser les pr√©dictions pour les ramener √† l'√©chelle d'origine
predictions_denorm = predictions.withColumn(
    "predicted_rating",
    F.col("prediction") * 
    F.when(F.col("user_stddev") > 0, F.col("user_stddev")).otherwise(1.0) + 
    F.col("user_mean")
)

# Appliquer des bornes [0.5, 5.0] aux pr√©dictions d√©normalis√©es
predictions_final = predictions_denorm.withColumn(
    "predicted_rating",
    F.when(F.col("predicted_rating") > 5.0, 5.0)
     .when(F.col("predicted_rating") < 0.5, 0.5)
     .otherwise(F.col("predicted_rating"))
)

# √âvaluer avec RMSE sur les notes originales vs pr√©dites (d√©normalis√©es)
evaluator = RegressionEvaluator(
    metricName="rmse",
    labelCol="rating",
    predictionCol="predicted_rating"
)
rmse = evaluator.evaluate(predictions_final)
print(f"‚úÖ RMSE sur l'ensemble test : {rmse:.4f}")

# √âgalement calculer MAE pour une √©valuation compl√©mentaire
evaluator_mae = RegressionEvaluator(
    metricName="mae",
    labelCol="rating",
    predictionCol="predicted_rating"
)
mae = evaluator_mae.evaluate(predictions_final)
print(f"‚úÖ MAE sur l'ensemble test : {mae:.4f}")


üìä √âvaluation du mod√®le...


25/05/02 10:04:02 WARN TaskSetManager: Lost task 2.0 in stage 680.0 (TID 895) (datanode2 executor 1): java.lang.StackOverflowError
	at java.base/java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1862)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2201)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:489)
	at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:447)
	at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
	at jdk.internal.reflect.GeneratedMethodAccessor5.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at java.base/java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1046)
	at java.base/java.io.ObjectInputStream.readSe

‚úÖ RMSE sur l'ensemble test : 0.9677


25/05/02 10:08:26 WARN TaskSetManager: Lost task 1.0 in stage 759.0 (TID 1152) (namenode executor 3): java.lang.StackOverflowError
	at java.base/java.io.ObjectInputStream$PeekInputStream.peek(ObjectInputStream.java:2871)
	at java.base/java.io.ObjectInputStream$BlockDataInputStream.peek(ObjectInputStream.java:3198)
	at java.base/java.io.ObjectInputStream$BlockDataInputStream.peekByte(ObjectInputStream.java:3208)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1638)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java

‚úÖ MAE sur l'ensemble test : 0.7375


                                                                                

In [None]:
# ====================== ANALYSE DES PR√âDICTIONS ======================
print("üîç Analyse des pr√©dictions...")
# Distribution des erreurs
predictions_final = predictions_final.withColumn(
    "error", 
    F.abs(F.col("rating") - F.col("predicted_rating"))
)

print("üìä Distribution des erreurs:")
predictions_final.select(
    F.avg("error").alias("avg_error"),
    F.expr("percentile(error, 0.5)").alias("median_error"),
    F.expr("percentile(error, 0.9)").alias("90th_percentile_error")
).show()


In [None]:
# ====================== ENREGISTREMENT ======================
print("üíæ Sauvegarde du mod√®le pond√©r√© dans HDFS (/models/als_weighted)...")
model.write().overwrite().save("hdfs:///models/als_weighted")

# Sauvegarder √©galement les statistiques utilisateurs pour la d√©normalisation future
user_stats = weighted_df.select("userId", "user_mean", "user_stddev").distinct()
user_stats.write.mode("overwrite").parquet("hdfs:///models/user_stats")

print("üéâ Mod√®le pond√©r√© entra√Æn√© et sauvegard√© avec succ√®s.")

In [None]:
spark.stop()