In [0]:
# 05_batch_inference
# Batch scoring using the model stored under MLflow Run ID + writing Gold predictions table

import mlflow
import pandas as pd
import numpy as np
from pyspark.sql.functions import current_timestamp, lit
from pyspark.sql import functions as F

# ------------------------------------
# CONFIG — UPDATE ONLY THIS IF NEEDED
# ------------------------------------
# MLflow run whose logged "model" artifact is loaded for scoring.
RUN_ID = "cd36fb2b675e4cef86a1e0f1d26fb76c"   # your last successful run
# Source table holding the feature rows to score.
FEATURES_TABLE = "churn_mlo_mdb.features_churn"
# Destination table that receives the scored rows.
GOLD_TABLE = "churn_mlo_mdb.gold_predictions"

# Echo the configuration so the cell output records what was scored.
print("Using model from Run ID:", RUN_ID)
print("Reading features from:", FEATURES_TABLE)

# ------------------------------------
# LOAD FEATURE TABLE (Spark)
# ------------------------------------
df_features_spark = spark.table(FEATURES_TABLE)

# Report the table size and render it in the notebook for a quick sanity check.
feature_row_count = df_features_spark.count()
print("Rows in feature table:", feature_row_count)
display(df_features_spark)

# Collect the full feature table into a pandas DataFrame.
# NOTE: toPandas() pulls every row into driver memory, so this only suits
# tables that fit on the driver.
pdf = df_features_spark.toPandas()

# ------------------------------------
# LOAD MODEL FROM MLflow USING RUN ID
# ------------------------------------
# The model is addressed by run ID under the "model" artifact path.
model_uri = f"runs:/{RUN_ID}/model"
print("Loading model from:", model_uri)

model = mlflow.sklearn.load_model(model_uri)

# Prepare X (features)

# Save customer IDs as strings. They are not consumed again in this script;
# kept so later notebook cells can still reference them.
customer_ids = pdf["customerID"].astype(str)

# Drop columns NOT used during training.
# NOTE(review): this list is meant to mirror the training step -- confirm
# against the training notebook if the feature table schema changes.
drop_cols = ["customerID", "churn_label"]

# One drop call replaces the copy + per-column loop: drop() already returns a
# new frame (so `pdf` is left untouched) and errors="ignore" skips any listed
# column that is absent from the table.
pdf_local = pdf.drop(columns=drop_cols, errors="ignore")

# Safety fill: replace missing values with 0 before handing the frame to the model.
pdf_local = pdf_local.fillna(0)


# ------------------------------------
# Predict churn probability + label
# ------------------------------------
# Column 1 of predict_proba is taken as the churn probability.
# NOTE(review): presumably the positive class sits at index 1 -- verify
# against model.classes_.
proba = model.predict_proba(pdf_local)[:, 1]

# Hard label: 1 when the probability reaches the 0.5 cut-off, else 0.
preds = (proba >= 0.5).astype(int)

print("Sample predictions:", list(preds[:10]))
print("Sample probabilities:", list(proba[:10]))

# ------------------------------------
# BUILD OUTPUT DATAFRAME (Detailed)
# ------------------------------------
# assign() works on a copy of the feature frame, so every original column is
# kept and `pdf` itself is not modified; the scoring columns are appended in
# the order given here.
pdf_output = pdf.assign(
    churn_probability=proba,
    prediction=preds,
    run_id=RUN_ID,
)

# Back to Spark, stamping each row with the scoring time.
df_gold = spark.createDataFrame(pdf_output).withColumn(
    "scoring_timestamp", F.current_timestamp()
)

# ------------------------------------
# WRITE GOLD TABLE
# ------------------------------------
# Replace the table with a single Delta overwrite rather than DROP + recreate.
# The overwrite is one atomic commit, so the table never disappears if the
# write fails or a reader arrives mid-run, and the Delta history is kept.
# overwriteSchema lets the new frame's schema replace the old one, which is
# what the earlier DROP TABLE was otherwise needed for.
(
    df_gold.write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(GOLD_TABLE)
)

print("Gold predictions table saved as:", GOLD_TABLE)

# Preview of the scored rows.
# NOTE(review): df_gold is re-evaluated for this preview, so the
# scoring_timestamp shown here may differ slightly from the persisted one.
display(df_gold.limit(20))
