In [0]:
# COMMAND ----------
# ONE-TIME FIX: Deletes the old prediction table so it can be recreated from scratch.
# This cell lives in the notebook, so an unguarded delete would wipe *every*
# ingest_date partition on each "Run All" and defeat the partition-scoped overwrite
# in the main cell. It is therefore opt-in: set RESET_PREDICTION_TABLE = True,
# run this cell once, then set it back to False.
RESET_PREDICTION_TABLE = False

storage_account = "finlakeadlsa3b3"
container_prediction = "prediction"
prediction_delta_path = f"abfss://{container_prediction}@{storage_account}.dfs.core.windows.net/delta/predicted_fraud"

if RESET_PREDICTION_TABLE:
    print(f"Attempting to delete old prediction table at: {prediction_delta_path}")
    # dbutils.fs.rm returns a boolean; False means nothing was removed
    # (presumably because the path did not exist).
    if dbutils.fs.rm(prediction_delta_path, recurse=True):
        print("✅ Old prediction table successfully deleted. You can now rerun the main notebook.")
    else:
        print(f"ℹ️ Nothing deleted: no prediction table found at {prediction_delta_path}.")
else:
    print("⏭️ Skipping prediction table reset (RESET_PREDICTION_TABLE = False).")

In [0]:
# Databricks notebook source
# =======================================================================================
# 04_train_fraud_model_pro
#
# Description:
#   1️⃣ Loads the feature-rich dataset created in the previous step.
#   2️⃣ Correctly handles the severe class imbalance inherent in fraud data.
#   3️⃣ Trains a powerful LightGBM classifier, which excels at this type of problem.
#   4️⃣ Integrates seamlessly with MLflow to track experiments, log models, metrics,
#      and artifacts (like a feature importance plot).
#   5️⃣ Registers the best model in the MLflow Model Registry for versioning and deployment.
#   6️⃣ Uses the registered model to perform batch prediction and saves the results.
#
# What's New (Professional Enhancements):
#   - ADVANCED MODEL: Replaced Logistic Regression with LightGBM for superior performance.
#   - IMBALANCE HANDLING: Implemented class weighting to force the model to pay
#     attention to the rare fraud cases.
#   - FULL MLOPS LIFECYCLE: Uses MLflow to log parameters, metrics (AUPRC, AUROC),
#     the model itself, and a feature importance plot. This is crucial for production.
#   - MODEL REGISTRATION: Automatically registers the trained model, making it
#     discoverable and ready for staging or production deployment.
#   - BATCH PREDICTION: Demonstrates the end-to-end loop by loading the registered
#     model to score the full dataset.
#   - MEMORY OPTIMIZATION: Includes data subsampling to prevent driver OOM errors on
#     resource-constrained clusters.
#   - ROBUST WRITES: Uses dynamic partition overwrites so re-running for an ingest_date
#     replaces only that partition (idempotent), plus `mergeSchema` to tolerate additive
#     schema changes.
# =======================================================================================

import mlflow
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import DoubleType
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler

# Coordinate: com.microsoft.azure:synapseml_2.12:0.10.2
from synapse.ml.lightgbm import LightGBMClassifier

from pyspark.ml.evaluation import BinaryClassificationEvaluator

# --------------------------------------------------------------------------------------
# 1️⃣ Setup & Parameters
# --------------------------------------------------------------------------------------
spark = SparkSession.builder.appName("FraudModelTraining").getOrCreate()

# BEST PRACTICE: Set dynamic partition overwrite mode for idempotent and flexible writes
# (an overwrite of a partitioned table then replaces only the partitions being written).
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

# Storage account, containers and secret scope used throughout the notebook.
storage_account = "finlakeadlsa3b3"
container_feature = "feature"
container_prediction = "prediction"
scope = "finlake_scope"
ingest_date = "2025-10-10"  # feature-table partition that is trained on and scored

# Host portion shared by every ADLS Gen2 URI below.
_adls_host = f"{storage_account}.dfs.core.windows.net"
feature_delta_path = f"abfss://{container_feature}@{_adls_host}/delta/feature_transactions"
prediction_delta_path = f"abfss://{container_prediction}@{_adls_host}/delta/predicted_fraud"

# The MLflow experiment is stored under the current notebook user's workspace folder.
_notebook_user = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()
mlflow.set_experiment(f"/Users/{_notebook_user}/finlake-fraud-detection")
mlflow_model_name = "finlake_fraud_classifier"

# --------------------------------------------------------------------------------------
# 2️⃣ Helper function for ADLS Authentication (reuse from previous notebook)
# --------------------------------------------------------------------------------------
def setup_spark_adls_auth(spark, storage_account, scope):
    """Configure service-principal OAuth access to an ADLS Gen2 storage account.

    Fetches the client id, tenant id and client secret from the Databricks secret
    scope ``scope`` and writes the matching ``fs.azure.account.*`` keys for
    ``storage_account`` into ``spark.conf``.

    Args:
        spark: Session whose ``conf`` receives the settings.
        storage_account: ADLS Gen2 account name (without the domain suffix).
        scope: Databricks secret scope holding the ``finlake-sp-*`` secrets.
    """
    print(f"🔐 Authenticating to ADLS Gen2 storage account: {storage_account}...")

    def _secret(key):
        return dbutils.secrets.get(scope=scope, key=key)

    client_id = _secret("finlake-sp-client-id")
    tenant_id = _secret("finlake-sp-tenant-id")
    client_secret = _secret("finlake-sp-client-secret")

    account_host = f"{storage_account}.dfs.core.windows.net"
    # (config-key prefix, value) pairs; each key is suffixed with the account host.
    oauth_settings = (
        ("fs.azure.account.auth.type", "OAuth"),
        ("fs.azure.account.oauth.provider.type", "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider"),
        ("fs.azure.account.oauth2.client.id", client_id),
        ("fs.azure.account.oauth2.client.secret", client_secret),
        ("fs.azure.account.oauth2.client.endpoint", f"https://login.microsoftonline.com/{tenant_id}/oauth2/token"),
    )
    for key_prefix, conf_value in oauth_settings:
        spark.conf.set(f"{key_prefix}.{account_host}", conf_value)
    print("✅ ADLS Gen2 authentication configured successfully.")

setup_spark_adls_auth(spark, storage_account, scope)

# --------------------------------------------------------------------------------------
# 3️⃣ Load Feature Data
# --------------------------------------------------------------------------------------
print(f"📂 Loading feature dataset for ingest_date = {ingest_date}...")
_feature_reader = spark.read.format("delta")
df_features = _feature_reader.load(feature_delta_path).where(F.col("ingest_date") == ingest_date)

# The classifier below is configured with labelCol="label", so rename the target.
df_model_data = df_features.withColumnRenamed("is_fraud", "label")
_loaded_rows = df_model_data.count()
print(f"✅ Loaded {_loaded_rows} records.")

# --------------------------------------------------------------------------------------
# 4️⃣ Handle Class Imbalance & Optional Subsampling (Crucial for Memory Management)
# --------------------------------------------------------------------------------------
# Keep every fraud row and (at most) this many non-fraud rows per fraud row.
NON_FRAUD_PER_FRAUD = 10

df_fraud = df_model_data.filter(F.col("label") == 1)
df_non_fraud = df_model_data.filter(F.col("label") == 0)

fraud_count = df_fraud.count()
non_fraud_count = df_non_fraud.count()
print(f"Fraud count: {fraud_count}, Non-fraud count: {non_fraud_count}")

# Fail fast with a clear message instead of a ZeroDivisionError (no non-fraud rows)
# or silently training a single-class model (no fraud rows).
if fraud_count == 0 or non_fraud_count == 0:
    raise ValueError(
        f"Cannot train on ingest_date={ingest_date}: both classes are required, "
        f"got fraud={fraud_count}, non-fraud={non_fraud_count}."
    )

# DataFrame.sample(withReplacement=False) requires a fraction in [0, 1]; if fraud is
# already common enough, keep every non-fraud row instead of erroring.
sample_fraction = min(1.0, (fraud_count * NON_FRAUD_PER_FRAUD) / non_fraud_count)
print(f"Sampling non-fraud data with fraction: {sample_fraction:.4f} to reduce memory pressure.")

df_non_fraud_sampled = df_non_fraud.sample(withReplacement=False, fraction=sample_fraction, seed=42)

df_model_data_balanced = df_fraud.union(df_non_fraud_sampled)
balanced_total = df_model_data_balanced.count()
print(f"Combined balanced dataset size: {balanced_total} records.")

# Inverse-frequency weights: fraud rows get the non-fraud share of the subsample and
# non-fraud rows get the fraud share, so the minority class carries the larger weight.
balanced_non_fraud = df_model_data_balanced.filter(F.col("label") == 0).count()
balance_ratio = balanced_non_fraud / balanced_total
df_model_data = df_model_data_balanced.withColumn(
    "weight",
    F.when(F.col("label") == 1, balance_ratio).otherwise(1 - balance_ratio),
)

print("⚖️ Class imbalance handled on subsampled data. Weights calculated for model training.")
df_model_data.groupBy("label").agg(F.count("*").alias("count"), F.first("weight").alias("weight")).show()

# --------------------------------------------------------------------------------------
# 5️⃣ Train/Test Split
# --------------------------------------------------------------------------------------
train_df, test_df = df_model_data.randomSplit([0.8, 0.2], seed=42)
# Mark both splits for caching *before* counting so the counts materialize the
# caches. Counting first computed each split once for the count and again when
# training/evaluation first read it.
train_df = train_df.cache()
test_df = test_df.cache()
print(f"📊 Training set: {train_df.count()}, Test set: {test_df.count()}")

# --------------------------------------------------------------------------------------
# 6️⃣ Train Model with MLflow Tracking
# --------------------------------------------------------------------------------------
with mlflow.start_run() as run:
    print(f"🚀 Starting MLflow Run: {run.info.run_id}")

    feature_cols = [
        'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10',
        'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20',
        'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount',
        'amount_log', 'avg_amount_user_24h', 'stddev_amount_user_24h',
        'txn_count_user_1h', 'txn_count_user_24h', 'amount_deviation_zscore'
    ]

    # NOTE(review): handleInvalid="skip" silently drops every row with a null/NaN
    # feature. This happens in training and whenever this fitted pipeline scores data,
    # including the batch prediction of the registered model. "keep" would pass NaN
    # through to LightGBM, which handles missing values natively. Confirm intent.
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")

    # Class imbalance is corrected exactly once, via the per-row "weight" column.
    # LightGBM's isUnbalance would additionally re-weight the minority class by the
    # raw class-count ratio on top of those weights. That double-counts the
    # correction and skews the predicted probabilities, so it is left off.
    lgbm = LightGBMClassifier(
        featuresCol="features",
        labelCol="label",
        weightCol="weight",
        isUnbalance=False,
        objective="binary",
        learningRate=0.1,
        numLeaves=31
    )

    pipeline = Pipeline(stages=[assembler, lgbm])

    mlflow.log_params({
        "learning_rate": lgbm.getLearningRate(),
        "num_leaves": lgbm.getNumLeaves(),
        "is_unbalance": lgbm.getIsUnbalance(),
        "objective": lgbm.getObjective(),
        "weight_col": "weight",
    })

    print("💪 Training LightGBM model...")
    pipeline_model = pipeline.fit(train_df)

    print("📈 Evaluating model on test data...")
    predictions = pipeline_model.transform(test_df)

    # test_df is split from the *subsampled* data, so its fraud prevalence is at least
    # as high as (usually far higher than) in the raw feature table. AUPRC depends on
    # prevalence, so read it as optimistic relative to production traffic.
    evaluator_pr = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderPR")
    auprc = evaluator_pr.evaluate(predictions)

    evaluator_roc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
    auroc = evaluator_roc.evaluate(predictions)

    mlflow.log_metrics({"auprc": auprc, "auroc": auroc})
    print(f"✅ Metrics: AUPRC = {auprc:.4f}, AUROC = {auroc:.4f}")

    model = pipeline_model.stages[-1]
    importances = model.getFeatureImportances()

    # Assuming every input column is a numeric scalar, the assembled vector has one slot
    # per entry of feature_cols, so the importances line up with it. pd.DataFrame raises
    # on a length mismatch, whereas zip() would have silently truncated. The table is
    # tiny, so rank it in pandas instead of launching a Spark job.
    importance_pd = pd.DataFrame({
        "feature": feature_cols,
        "importance": [float(value) for value in importances],
    })
    top_20_features_pd = (
        importance_pd
        .sort_values("importance", ascending=False, kind="stable")
        .head(20)
        .reset_index(drop=True)
    )

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.barplot(x="importance", y="feature", data=top_20_features_pd, ax=ax)
    ax.set_title("Top 20 Feature Importances")
    fig.tight_layout()
    mlflow.log_figure(fig, "feature_importance.png")
    print("🎨 Feature importance plot logged.")

    print(f"✅ Logging and registering model as '{mlflow_model_name}'...")
    mlflow.spark.log_model(
        spark_model=pipeline_model,
        artifact_path="model",
        registered_model_name=mlflow_model_name
    )
    print("🎉 MLflow run completed successfully!")

# Release the cached train/test splits now that training and evaluation are done.
train_df.unpersist()
test_df.unpersist()
print("🧹 Unpersisted train and test DataFrames.")

# --------------------------------------------------------------------------------------
# 7️⃣ Batch Prediction with the Registered Model
# --------------------------------------------------------------------------------------
print("\n--- Performing Batch Prediction ---")
# "latest" resolves to the newest registered version at load time.
# NOTE(review): if another run registers a version concurrently, this may not be the
# version logged above — pin an explicit version if that matters.
model_uri = "models:/{}/latest".format(mlflow_model_name)
loaded_model = mlflow.spark.load_model(model_uri)

print(f"🤖 Loaded model version from '{model_uri}' for batch scoring.")

# Score the full ingest_date slice (not the training subsample), renaming the target to
# "label" just as the training data was, so the true label travels with each prediction.
# NOTE(review): the pipeline's VectorAssembler uses handleInvalid="skip", so rows with
# null features are dropped from the scored output.
scoring_input = df_features.withColumnRenamed("is_fraud", "label")
final_predictions = loaded_model.transform(scoring_input)

def get_fraud_probability(prob_vector):
    """Return the fraud (positive-class) probability from a probability vector.

    Reads position 1 of ``prob_vector`` and converts it to ``float``. Inputs that
    cannot be read that way fall back to ``0.0``: ``None``, a vector with fewer
    than two entries, or a non-numeric entry that raises ``TypeError``.
    """
    fallback = 0.0
    if prob_vector is None:
        return fallback
    try:
        positive_class_prob = prob_vector[1]
        return float(positive_class_prob)
    except (TypeError, IndexError):
        return fallback

from pyspark.ml.functions import vector_to_array

# Python UDF kept defined for any code that still references it; not used below.
extract_prob_udf = F.udf(get_fraud_probability, DoubleType())

# The fraud score is slot 1 of the model's "probability" vector. It is extracted
# natively with vector_to_array rather than a row-by-row Python UDF, which avoids
# JVM<->Python serialization over the whole scored dataset. coalesce preserves the
# UDF's 0.0 fallback for a null probability.
output_df = (
    final_predictions
    .withColumn("predicted_fraud", F.col("prediction").cast("int"))
    .withColumn("score", F.coalesce(vector_to_array(F.col("probability"))[1], F.lit(0.0)))
    .withColumn("detection_ts", F.current_timestamp())
    .select(
        "Time", "V1", "Amount",
        "label", "predicted_fraud", "score",
        "detection_ts", "ingest_date"
    )
)

# --------------------------------------------------------------------------------------
# 8️⃣ Write Predictions to Delta
# --------------------------------------------------------------------------------------
print(f"💾 Writing predictions to: {prediction_delta_path}")
# Setting partitionOverwriteMode=dynamic on the writer makes this overwrite replace only
# the ingest_date partition(s) present in output_df, leaving other dates' predictions in
# place. The writer option takes precedence over the session conf, so this write no
# longer depends on that conf having been set elsewhere. mergeSchema allows additive
# schema changes.
(
    output_df.write
    .format("delta")
    .mode("overwrite")
    .option("partitionOverwriteMode", "dynamic")
    .option("mergeSchema", "true")
    .partitionBy("ingest_date")
    .save(prediction_delta_path)
)

print("✅ Fraud predictions written successfully.")
print("🎯 End-to-end model training and prediction pipeline completed!")

