# Entraînement du modèle ML avec Spark

Ce notebook entraîne un modèle de détection de fraudes sur plus de 13 millions de transactions.

**Pipeline ML:**

1. Chargement des données avec spark

2. Feature engineering

3. Entraînement **Random Forest** avec Spark MLlib

4. Evaluation et métriques

5. Sauvegarde du modèle

In [1]:
import os
import sys
from dotenv import load_dotenv

#  Charger les variables d'environnement
load_dotenv()

# Configuration Hadoop (CRITIQUE pour Parquet sur Windows)
os.environ['HADOOP_HOME'] = os.getenv('HADOOP_HOME') #'C:/hadoop'
#os.environ['SPARK_HOME'] = 'C:/spark/spark-3.5.0-bin-hadoop3'
#os.environ['PATH'] = 'C:/hadoop/bin;' + os.environ.get('PATH', '')
hadoop_bin = os.path.join(os.environ.get('HADOOP_HOME'), 'bin')
os.environ['PATH'] = hadoop_bin + os.environ.get('PATH', '')

# Vérification
from pathlib import Path
path_winutils = os.path.join(os.environ.get('HADOOP_HOME'), 'bin', 'winutils.exe')
winutils = Path(path_winutils)
path_hadoop_dll = os.path.join(os.environ.get('HADOOP_HOME'), 'bin', 'hadoop.dll')
hadoop_dll = Path(path_hadoop_dll)

print(f"winutils.exe: {'OK' if winutils.exists() else 'NOT FOUND'}")
print(f"hadoop.dll: {'OK' if hadoop_dll.exists() else 'NOT FOUND'}")
print(f"HADOOP_HOME: {os.environ.get('HADOOP_HOME')}")

winutils.exe: OK
hadoop.dll: OK
HADOOP_HOME: C:/hadoop


In [2]:
from pathlib import Path

sys.path.insert(0, '..')


from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from src.utils.logger import setup_logger

logger = setup_logger(__name__)

## 1. Initialisation de Spark

In [None]:
# Création de la session Spark avec configuration optimisée pour gros dataset
spark = SparkSession.builder \
    .appName("FraudDetectionTraining") \
    .master("local[*]") \
    .config("spark.driver.memory", "6g") \
    .config("spark.executor.memory", "6g") \
    .config("spark.driver.maxResultSize", "2g") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.default.parallelism", "8") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "false") \
    .config("spark.hadoop.io.nativeio.enabled", "false") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.memory.storageFraction", "0.3") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN") # Afficher uniquement les logs important
logger.info(f"** Spark {spark.version} démarré")
logger.info(f"** Master: {spark.sparkContext.master}")

[2025-12-03 11:24:14 - __main__] - INFO - ** Spark 3.5.3 démarré
[2025-12-03 11:24:14 - __main__] - INFO - ** Master: local[*]


## 2. Chargement des données (Traitement distribué)

In [4]:
logger.info(" Chargement des données")

data_path = Path("../data/historical/").resolve()
print(str(data_path))
df = spark.read.parquet(str(data_path))

logger.info(f"=\"=\"=\" Données chargées: {df.count():,} transactions =\"=\"=\"")
logger.info(f"=\"=\"=\" Partitions Spark: {df.rdd.getNumPartitions()} =\"=\"=\"")

#df_sample = df.sample(fraction=0.1, seed=42)
#sample_count = df_sample.count()

df.printSchema()

[2025-12-03 11:24:30 - __main__] - INFO -  Chargement des données


E:\Ecole\AS3\Semestre1\Big_Data\Spark\fraud-detection-spark\data\historical


[2025-12-03 11:24:43 - __main__] - INFO - ="="=" Données chargées: 13,779,047 transactions ="="="
[2025-12-03 11:24:43 - __main__] - INFO - ="="=" Partitions Spark: 9 ="="="


root
 |-- transaction_id: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- amount: double (nullable = true)
 |-- merchant_id: string (nullable = true)
 |-- merchant_category: string (nullable = true)
 |-- location_lat: double (nullable = true)
 |-- location_lon: double (nullable = true)
 |-- device_id: string (nullable = true)
 |-- is_online: boolean (nullable = true)
 |-- is_fraud: long (nullable = true)
 |-- card_last_4: string (nullable = true)
 |-- cvv_provided: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)



In [5]:
# Aperçu des données
df.show(5, truncate=False)

+----------------+-----------+--------------------------+------+-------------+-----------------+------------+------------+---------------+---------+--------+-----------+------------+----+-----+
|transaction_id  |user_id    |timestamp                 |amount|merchant_id  |merchant_category|location_lat|location_lon|device_id      |is_online|is_fraud|card_last_4|cvv_provided|year|month|
+----------------+-----------+--------------------------+------+-------------+-----------------+------------+------------+---------------+---------+--------+-----------+------------+----+-----+
|TRX_F5FABF0C3039|USER_000000|2024-03-01T19:30:24.738146|124.7 |MERCHANT_8957|hotel            |-31.310658  |135.126117  |DEVICE_6888d29c|false    |0       |5765       |true        |2024|3    |
|TRX_13629D668F3C|USER_000000|2024-03-01T16:19:07.738146|57.49 |MERCHANT_3009|supermarket      |-31.209713  |135.006961  |DEVICE_6888d29c|false    |0       |5765       |true        |2024|3    |
|TRX_E9D863EB71FE|USER_000001|

## 3. Statistiques exploratoires

In [7]:
# Distribution fraudes vs normales
fraud_stats = df.groupBy("is_fraud").count().toPandas()

fig = px.pie(fraud_stats,
             values='count',
             names='is_fraud',
             title='Distribution: Normal vs Fraude',
             labels={'is_fraud': 'Type'})

fig.show()

# Taux de fraude
fraud_rate = fraud_stats[fraud_stats['is_fraud']==1]['count'].values[0] / fraud_stats['count'].sum() * 100
print(f"\n ** Taux de fraude: {fraud_rate:.2f}%")


 ** Taux de fraude: 2.00%


In [8]:
# Statistiques par type
df.groupBy("is_fraud").agg(
    count("*").alias("count"),
    avg("amount").alias("avg_amount"),
    stddev("amount").alias("std_amount"),
    min("amount").alias("min_amount"),
    max("amount").alias("max_amount")
).show()

+--------+--------+------------------+------------------+----------+----------+
|is_fraud|   count|        avg_amount|        std_amount|min_amount|max_amount|
+--------+--------+------------------+------------------+----------+----------+
|       0|13503515|109.23591536647827|109.13321166470813|      15.0|     490.0|
|       1|  275532| 634.8797072572339| 598.1456913776228|      50.0|   4806.07|
+--------+--------+------------------+------------------+----------+----------+



## 4. Feature Engineering

In [5]:
logger.info("*** Feature Engineering ***")

# Convertir timestamp en datetime
df = df.withColumn("timestamp", to_timestamp(col("timestamp")))

# Features temporelles
df = df.withColumn("hour_of_day", hour(col("timestamp"))) \
       .withColumn("day_of_week", dayofweek(col("timestamp"))) \
       .withColumn("is_weekend", when((dayofweek("timestamp") == 1) | (dayofweek("timestamp") == 7), 1).otherwise(0)) \
       .withColumn("is_unusual_hour", when((hour("timestamp") >= 2) & (hour("timestamp") <= 5), 1).otherwise(0))

# Features par montant
df = df.withColumn("amount_log", log10(col("amount") + 1)) \
       .withColumn("is_high_amount", when(col("amount") > 500, 1).otherwise(0)) \
       .withColumn("is_round_amount", when(col("amount") % 10 == 0, 1).otherwise(0))

# Features marchand
df = df.withColumn("is_high_risk_merchant", 
                   when(col("merchant_category").isin(['electronics', 'online_shopping', 'airline', 'hotel']), 1)
                   .otherwise(0))

# Features online/offline
df = df.withColumn("is_online_int", when(col("is_online") == True, 1).otherwise(0))

logger.info("*** Features créées ***")

df.select(["transaction_id", "amount", "hour_of_day", "is_unusual_hour", "is_high_amount", "is_fraud"]).show(5)

[2025-12-03 11:25:03 - __main__] - INFO - *** Feature Engineering ***
[2025-12-03 11:25:03 - __main__] - INFO - *** Features créées ***


+----------------+------+-----------+---------------+--------------+--------+
|  transaction_id|amount|hour_of_day|is_unusual_hour|is_high_amount|is_fraud|
+----------------+------+-----------+---------------+--------------+--------+
|TRX_F5FABF0C3039| 124.7|         19|              0|             0|       0|
|TRX_13629D668F3C| 57.49|         16|              0|             0|       0|
|TRX_E9D863EB71FE| 53.98|         16|              0|             0|       0|
|TRX_28512CFAE42F| 72.27|         22|              0|             0|       0|
|TRX_A8BF537FC4CE|196.11|          8|              0|             0|       0|
+----------------+------+-----------+---------------+--------------+--------+
only showing top 5 rows



## 5. Préparation des données pour le modèle ML

In [6]:
# Sélection des features utiles pour le modèle
feature_columns = [
    'amount', 'amount_log', 'hour_of_day', 'day_of_week',
    'is_weekend', 'is_unusual_hour', 'is_high_amount',
    'is_round_amount', 'is_high_risk_merchant', 'is_online_int',
    'location_lat', 'location_lon'
]

In [7]:
# Assembler les features en un vecteur
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
print(assembler.getInputCols())
print(assembler.getOutputCol())

['amount', 'amount_log', 'hour_of_day', 'day_of_week', 'is_weekend', 'is_unusual_hour', 'is_high_amount', 'is_round_amount', 'is_high_risk_merchant', 'is_online_int', 'location_lat', 'location_lon']
features


In [8]:
# Renommer la colonne target
df = df.withColumnRenamed("is_fraud", "label")

# Appliquer l'assembleur
df_ml = assembler.transform(df).select("features", "label")

df_ml.show(5, truncate=False)

+---------------------------------------------------------------------------------+-----+
|features                                                                         |label|
+---------------------------------------------------------------------------------+-----+
|[124.7,2.0993352776859577,19.0,6.0,0.0,0.0,0.0,0.0,1.0,0.0,-31.310658,135.126117]|0    |
|(12,[0,1,2,3,10,11],[57.49,1.7670816213633223,16.0,6.0,-31.209713,135.006961])   |0    |
|(12,[0,1,2,3,10,11],[53.98,1.7402047355074497,16.0,6.0,74.662357,86.506162])     |0    |
|(12,[0,1,2,3,10,11],[72.27,1.8649261915390054,22.0,6.0,13.188504,176.108534])    |0    |
|[196.11,2.294708657940791,8.0,6.0,0.0,0.0,0.0,0.0,1.0,0.0,13.191124,176.330775]  |0    |
+---------------------------------------------------------------------------------+-----+
only showing top 5 rows



## 6. Séparation du jeu de données

In [10]:
from pyspark import StorageLevel

# Split stratifié 
train_data, test_data = df_ml.randomSplit([0.5, 0.1], seed=42)

# IMPORTANT : Persister en mémoire et disque
train_data.persist(StorageLevel.MEMORY_AND_DISK)
test_data.persist(StorageLevel.MEMORY_AND_DISK)

# Cache en mémoire pour accélérer l'entraînement
train_data.cache()
test_data.cache()

# Forcer le cache
_ = train_data.count()
_ = test_data.count()

# Vérifier l'équilibre
print("\n*** Distribution Train:")
train_data.groupBy("label").count().show()
print("\n*** Distribution Test:")
test_data.groupBy("label").count().show()


*** Distribution Train:
+-----+--------+
|label|   count|
+-----+--------+
|    0|11252384|
|    1|  229549|
+-----+--------+


*** Distribution Test:
+-----+-------+
|label|  count|
+-----+-------+
|    0|2251131|
|    1|  45983|
+-----+-------+



## 7. Entraînement d'un modèle **Random Forest** (Spark distribué)

In [11]:
from datetime import datetime

logger.info("*** Entraînement du modèle ***")

# Configuration du random forest
logger.info("Configuration du random forest")

rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",
    numTrees=50,              # Nombre d'arbres
    maxDepth=8,               # Profondeur max
    maxBins=32,                # Bins pour variables continues
    seed=42,
    subsamplingRate=0.8,       # Sous-échantillonnage
    featureSubsetStrategy="sqrt" # moins de features par arbre
)

# Entraînement du modèle
# Notons que Spark va paralléliser automatiquement le traitement
logger.info("Entraînement du modèle (Spark va paralléliser automatiquement)")

start_time = datetime.now()

try:
    model = rf.fit(train_data)

    end_time = datetime.now()
    duration = end_time - start_time # durée en secondes

    logger.info("Entraînement terminé !")
    logger.info(f"Durée: {duration} secondes")
except Exception as e:
    logger.error(f"Erreur lors de l'entraînement du modèle: {e}", exc_info=True)

[2025-12-03 11:33:18 - __main__] - INFO - *** Entraînement du modèle ***
[2025-12-03 11:33:18 - __main__] - INFO - Configuration du random forest
[2025-12-03 11:33:19 - __main__] - INFO - Entraînement du modèle (Spark va paralléliser automatiquement)
[2025-12-03 11:45:02 - __main__] - INFO - Entraînement terminé !
[2025-12-03 11:45:02 - __main__] - INFO - Durée: 0:11:42.888181 secondes


## 8. Feature importance

In [12]:
# Extraction de l'importance des features
logger.info("*** Feature importance ***")

feature_importance = pd.DataFrame({
    'feature': feature_columns,
    'importance': model.featureImportances.toArray()
}).sort_values('importance', ascending=False)

print(feature_importance)

[2025-12-03 12:05:08 - __main__] - INFO - *** Feature importance ***


                  feature  importance
6          is_high_amount    0.416534
8   is_high_risk_merchant    0.190156
0                  amount    0.157812
1              amount_log    0.112463
2             hour_of_day    0.060759
5         is_unusual_hour    0.052361
9           is_online_int    0.009799
10           location_lat    0.000054
11           location_lon    0.000052
3             day_of_week    0.000007
7         is_round_amount    0.000001
4              is_weekend    0.000001


In [17]:
# Visualisation du feature importance
fig = px.bar(
    feature_importance,
    x='importance',
    y='feature',
    orientation='h',
    title='Importance des Features (Random Forest)',
    labels={'importance': 'Importance', 'feature': 'Feature'}
)
fig.update_yaxes(categoryorder='total ascending')

fig.show()

## 9. Evaluation sur l'échantillon de test

In [20]:
# Prédictions sur l'échantillon de test
logger.info("*** Évaluation du modèle ***")
predictions = model.transform(test_data)

# Afficher quelques prédictions
predictions.select("label", "prediction", "probability").show(10, truncate=False)

[2025-12-03 12:11:43 - __main__] - INFO - *** Évaluation du modèle ***


+-----+----------+------------------------------------------+
|label|prediction|probability                               |
+-----+----------+------------------------------------------+
|0    |0.0       |[0.9988806453142064,0.0011193546857935874]|
|0    |0.0       |[0.9987341195603312,0.0012658804396687242]|
|0    |0.0       |[0.9986188355193637,0.0013811644806362522]|
|0    |0.0       |[0.9990085582676472,9.914417323526935E-4] |
|0    |0.0       |[0.9987605700396743,0.00123942996032572]  |
|0    |0.0       |[0.9986394289665597,0.0013605710334402617]|
|0    |0.0       |[0.9986188355193637,0.0013811644806362522]|
|0    |0.0       |[0.9988797094662649,0.0011202905337350931]|
|0    |0.0       |[0.9987663869112242,0.0012336130887757098]|
|0    |0.0       |[0.9986188355193637,0.0013811644806362522]|
+-----+----------+------------------------------------------+
only showing top 10 rows



In [21]:
# Confusion matrix
confusion_matrix = predictions.groupBy("label", "prediction").count().toPandas()
confusion_matrix = confusion_matrix.pivot(index='label', columns='prediction', values='count').fillna(0)

print(confusion_matrix)

# Visualisation
fig = px.imshow(
    confusion_matrix,
    labels=dict(x="Prédiction", y="Réel", color="Count"),
    x=['Normal', 'Fraude'],
    y=['Normal', 'Fraude'],
    title='Confusion Matrix',
    text_auto=True
)

fig.show()

prediction      0.0    1.0
label                     
0           2251120     11
1              7442  38541


In [22]:
# Calcul des métriques
logger.info("*** Calcul des métriques ***")
evaluator_auc = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
evaluator_pr = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderPR")
evaluator_acc = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
evaluator_prec = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedPrecision")
evaluator_rec = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedRecall")
evaluator_f1 = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
auc = evaluator_auc.evaluate(predictions)
pr_auc = evaluator_pr.evaluate(predictions)
accuracy = evaluator_acc.evaluate(predictions)
precision = evaluator_prec.evaluate(predictions)
recall = evaluator_rec.evaluate(predictions)
f1 = evaluator_f1.evaluate(predictions)
logger.info("Métriques calculées !")

# Affichage
for metric_name, metric_value in zip(
    ["Accuracy", "Precision", "Recall", "F1-Score", "AUC-ROC", "AUC-PR"],
    [accuracy, precision, recall, f1, auc, pr_auc]
):
    print(f"{metric_name}: {metric_value*100:.2f}%")
    logger.info(f"{metric_name}: {metric_value*100:.2f}%")

[2025-12-03 12:17:25 - __main__] - INFO - *** Calcul des métriques ***
[2025-12-03 12:18:50 - __main__] - INFO - Métriques calculées !
[2025-12-03 12:18:50 - __main__] - INFO - Accuracy: 99.68%
[2025-12-03 12:18:50 - __main__] - INFO - Precision: 99.68%
[2025-12-03 12:18:50 - __main__] - INFO - Recall: 99.68%
[2025-12-03 12:18:50 - __main__] - INFO - F1-Score: 99.66%
[2025-12-03 12:18:50 - __main__] - INFO - AUC-ROC: 97.57%
[2025-12-03 12:18:50 - __main__] - INFO - AUC-PR: 88.95%


Accuracy: 99.68%
Precision: 99.68%
Recall: 99.68%
F1-Score: 99.66%
AUC-ROC: 97.57%
AUC-PR: 88.95%


## 10. Sauvegarde du modèle

In [23]:
import json

# Sauvegarde du modèle et des métadonnées
logger.info("*** Sauvegarde du modèle ***")
try:
  model_path = "../data/models/random_forest_fraud_detector"
  model.write().overwrite().save(model_path)
  logger.info(f"--- Modèle sauvegardé dans: {model_path}")

  # Sauvegarder les métadonnées
  metadata = {
      "model_type": "RandomForestClassifier",
      "num_trees": rf.getNumTrees(),
      "max_depth": rf.getMaxDepth(),
      "training_date": datetime.now().isoformat(),
      "training_samples": train_data.count(),
      "test_samples": test_data.count(),
      "accuracy": accuracy,
      "precision": precision,
      "recall": recall,
      "f1_score": f1,
      "auc_roc": auc,
      "features": feature_columns,
      "feature_importance": feature_importance.to_dict("records")
  }

  with open("../data/models/model_metadata.json", "w") as f:
      json.dump(metadata, f, indent=2)
  logger.info("--- Métadonnées sauvegardées ---")
except Exception as e:
  logger.error(f"Erreur lors de la sauvegarde du modèle: {e}", exc_info=True)

[2025-12-03 12:25:45 - __main__] - INFO - *** Sauvegarde du modèle ***
[2025-12-03 12:25:51 - __main__] - INFO - --- Modèle sauvegardé dans: ../data/models/random_forest_fraud_detector
[2025-12-03 12:25:54 - __main__] - INFO - --- Métadonnées sauvegardées ---


### **FIN DE L'ENTRAÎNEMENT DU MODELE**