In [0]:
# Read Gold customer metrics
gold_df = spark.table("ecommerce.gold.gold_customer_metrics")

display(gold_df)

CustomerID,total_spend,total_orders,last_purchase_date,days_inactive,churn
12451,9035.52,5,2011-11-29T08:40:00.000Z,10,0
14397,2612.96,19,2011-12-07T09:22:00.000Z,2,0
16557,281.85,2,2011-11-15T14:04:00.000Z,24,0
15277,255.9,1,2011-10-24T10:08:00.000Z,46,0
13884,787.6,5,2011-12-02T16:30:00.000Z,7,0
12604,254.18,2,2011-09-21T14:08:00.000Z,79,0
12354,1079.4,1,2011-04-21T13:11:00.000Z,232,1
15830,576.0,1,2011-11-02T10:40:00.000Z,37,0
14408,2606.53,7,2011-11-29T11:41:00.000Z,10,0
16054,783.8999999999999,1,2011-07-17T13:08:00.000Z,145,1


In [0]:
# Prepare ML Dataset
from pyspark.ml.feature import VectorAssembler

feature_cols = ["total_spend", "total_orders", "days_inactive"]

assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features"
)

ml_df = assembler.transform(gold_df).select("features", "churn")
display(ml_df)


features,churn
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""9035.519999999999"",""5.0"",""10.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""2612.9599999999996"",""19.0"",""2.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""281.85"",""2.0"",""24.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""255.89999999999998"",""1.0"",""46.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""787.6"",""5.0"",""7.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""254.18"",""2.0"",""79.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""1079.4"",""1.0"",""232.0""]}",1
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""576.0"",""1.0"",""37.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""2606.5299999999997"",""7.0"",""10.0""]}",0
"{""type"":""1"",""size"":null,""indices"":null,""values"":[""783.8999999999999"",""1.0"",""145.0""]}",1


In [0]:
# TRAIN / TEST SPLIT
train_df, test_df = ml_df.randomSplit([0.8, 0.2], seed=42)

print(f"Train rows: {train_df.count()}")
print(f"Test rows: {test_df.count()}")


Train rows: 3518
Test rows: 821


In [0]:
%sql
CREATE VOLUME IF NOT EXISTS ecommerce.gold.ml_models;


In [0]:
import os

# Set MLflow temp directory to Unity Catalog volume
os.environ["MLFLOW_DFS_TMP"] = "/Volumes/ecommerce/gold/ml_models"


In [0]:
import os
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression

# Tell MLflow to use UC volume
os.environ["MLFLOW_DFS_TMP"] = "/Volumes/ecommerce/gold/ml_models"

mlflow.set_experiment("/Shared/churn_prediction_experiment")

lr = LogisticRegression(
    featuresCol="features",
    labelCol="churn"
)
model = lr.fit(train_df)
predictions = model.transform(test_df)

with mlflow.start_run():
    # Log model
    mlflow.spark.log_model(model, "logistic_regression_model")




In [0]:
# MODEL EVALUATION
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

auc_evaluator = BinaryClassificationEvaluator(
    labelCol="churn",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

auc = auc_evaluator.evaluate(predictions)

accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol="churn",
    predictionCol="prediction",
    metricName="accuracy"
)

accuracy = accuracy_evaluator.evaluate(predictions)

mlflow.log_metric("AUC", auc)
mlflow.log_metric("Accuracy", accuracy)

print("AUC:", auc)
print("Accuracy:", accuracy)


AUC: 1.0
Accuracy: 1.0


In [0]:
predictions.select("churn", "prediction", "probability").show(10)


+-----+----------+--------------------+
|churn|prediction|         probability|
+-----+----------+--------------------+
|    0|       0.0|           [1.0,0.0]|
|    1|       1.0|[2.59828276653958...|
|    1|       1.0|           [0.0,1.0]|
|    1|       1.0|           [0.0,1.0]|
|    0|       0.0|           [1.0,0.0]|
|    0|       0.0|           [1.0,0.0]|
|    1|       1.0|           [0.0,1.0]|
|    1|       1.0|           [0.0,1.0]|
|    0|       0.0|           [1.0,0.0]|
|    0|       0.0|           [1.0,0.0]|
+-----+----------+--------------------+
only showing top 10 rows
