In [0]:
# Databricks notebook: Step 3 – Train, Evaluate, and Log Model (Unity Catalog)
from pyspark.sql import SparkSession
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from mlflow.models.signature import infer_signature
from pyspark.sql.functions import col
import mlflow
import mlflow.spark

spark = SparkSession.builder.getOrCreate()

# --------------------------------------------------
# 1. Load engineered features
# --------------------------------------------------
# NOTE(review): this is a two-part name, so under Unity Catalog it resolves against
# the session's current catalog — confirm that is the catalog intended for this step.
data = spark.table("default.features_credit_train")

# --------------------------------------------------
# 2. Split into train/test
# --------------------------------------------------
# 80/20 split with a fixed seed so the split is repeatable for identically
# partitioned input.
train_df, test_df = data.randomSplit(weights=[0.8, 0.2], seed=42)

# --------------------------------------------------
# 3. Train model
# --------------------------------------------------
# Estimator definition only; this block does not fit it.
# It reads the assembled "features" vector and predicts the "default_flag" label.
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="default_flag",
    maxDepth=8,
    numTrees=100,
    seed=42,
)

# --------------------------------------------------
# 4. Configure MLflow (Unity Catalog)
# --------------------------------------------------
# Point the model registry at Unity Catalog and record runs under a /Shared experiment path.
mlflow.set_registry_uri("databricks-uc")
mlflow.set_experiment("/Shared/banking_credit_default_experiment")

with mlflow.start_run(run_name="rf_baseline_train_eval") as run:
    # Fit on the training split and score the held-out split.
    model = rf.fit(train_df)
    preds = model.transform(test_df)

    # --------------------------------------------------
    # 5. Evaluate metrics
    # --------------------------------------------------
    # ROC-AUC uses the evaluator's default rawPredictionCol ("rawPrediction").
    evaluator_auc = BinaryClassificationEvaluator(labelCol="default_flag", metricName="areaUnderROC")
    evaluator_acc = MulticlassClassificationEvaluator(labelCol="default_flag", metricName="accuracy")
    # metricName="f1" is the label-frequency-weighted F1 over BOTH classes; on an
    # imbalanced dataset it is dominated by the majority class. It is kept under the
    # existing "f1" key for continuity with earlier runs.
    evaluator_f1 = MulticlassClassificationEvaluator(labelCol="default_flag", metricName="f1")
    # F1 of the positive class only (default_flag == 1.0). BinaryClassificationEvaluator
    # above already assumes 0/1 labels, so 1.0 is the positive label here.
    evaluator_f1_default = MulticlassClassificationEvaluator(
        labelCol="default_flag",
        metricName="fMeasureByLabel",
        metricLabel=1.0,
    )

    auc = evaluator_auc.evaluate(preds)
    acc = evaluator_acc.evaluate(preds)
    f1 = evaluator_f1.evaluate(preds)
    f1_default = evaluator_f1_default.evaluate(preds)

    print(
        f"ROC-AUC: {auc:.4f} | Accuracy: {acc:.4f} | "
        f"F1 (weighted): {f1:.4f} | F1 (default class): {f1_default:.4f}"
    )

    # Read hyper-parameters back from the estimator so the logged values cannot
    # drift from the ones actually used to fit the model.
    mlflow.log_params({"num_trees": rf.getNumTrees(), "max_depth": rf.getMaxDepth()})
    mlflow.log_metrics({
        "roc_auc": auc,
        "accuracy": acc,
        "f1": f1,
        "f1_default_class": f1_default,
    })

    # --------------------------------------------------
    # 6. Prepare model signature and input example
    # --------------------------------------------------
    # Skip the vector column (not JSON-serializable) and the label.
    feature_cols = [c for c in train_df.columns if c not in ("features", "default_flag")]
    if not feature_cols:
        raise ValueError(
            "No raw feature columns besides 'features'/'default_flag' were found; "
            "cannot build a model signature or input example."
        )
    # NOTE(review): the fitted model reads only the "features" vector column
    # (featuresCol="features"). This signature/input example therefore describes raw
    # columns the model itself does not consume, and scoring the logged model with
    # `sample_input` alone will not work unless "features" is also supplied.
    # Consider logging a PipelineModel that includes the feature-assembly stages.

    # Score ONE fixed sample so each sample_output row is the prediction for the
    # matching sample_input row. Two separate `limit(5)` queries are not guaranteed
    # to return the same rows.
    sample_scored = (
        model.transform(train_df.limit(5))
        .select(*feature_cols, "prediction")
        .toPandas()
    )
    sample_input = sample_scored[feature_cols]
    sample_output = sample_scored[["prediction"]]

    signature = infer_signature(sample_input, sample_output)

    # --------------------------------------------------
    # 7. Log model to MLflow (UC volume temp path)
    # --------------------------------------------------
    # NOTE(review): this volume lives in catalog "banking", while the feature table is
    # read by a two-part name; confirm both refer to the same catalog.
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="model",
        dfs_tmpdir="/Volumes/banking/default/mlflow_tmp",
        signature=signature,
        input_example=sample_input
    )

    print(f"Run ID: {run.info.run_id}")
    mlflow.log_text(run.info.run_id, "run_id.txt")

print("✅ Step 3 complete – model, signature, and metrics logged to MLflow (Unity Catalog).")