In [0]:

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, DoubleType
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
import datetime

gold_base = "/Volumes/workspace/data/bda_project/gold"   # base used earlier
kpi_daily_path = f"{gold_base}/kpi_daily"                # delta location 
output_predictions_path = f"{gold_base}/ml/kpi_model_predictions_fixed2"
model_output_path = f"{gold_base}/models/kpi_rf_model_fixed2"
database = "data"

# 1) Read KPI daily table (Delta)
kpi_daily = spark.read.format("delta").load(kpi_daily_path)
print("kpi_daily columns:", kpi_daily.columns)

# 2) Ensure event_date exists and is date-typed
if "event_date" not in kpi_daily.columns:
    if "date" in kpi_daily.columns:
        kpi_daily = kpi_daily.withColumnRenamed("date","event_date")
    elif "event_time" in kpi_daily.columns:
        kpi_daily = kpi_daily.withColumn("event_date", F.to_date("event_time"))
    else:
        raise ValueError("kpi_daily has no event_date, date, or event_time column. Add a date column.")

kpi_daily = kpi_daily.withColumn("event_date", F.to_date("event_date"))


rename_map = {}
if "txn_count" in kpi_daily.columns:
    kpi_daily = kpi_daily.withColumnRenamed("txn_count", "txn_count_today")
    rename_map["txn_count"] = "txn_count_today"

for col in ["txn_count_today", "txn_amount_total", "txn_amount_avg", "fraud_count", "unique_cards", "fraud_rate"]:
    if col in kpi_daily.columns:
        kpi_daily = kpi_daily.withColumn(col, F.col(col).cast(DoubleType()))


w7 = Window.orderBy(F.col("event_date").cast("timestamp")).rowsBetween(-6, 0)

kpi_daily_with_7d = kpi_daily \
    .withColumn("txn_count_7d_sum", F.sum("txn_count_today").over(w7)) \
    .withColumn("txn_count_7d_avg", F.avg("txn_count_today").over(w7))


kpi_daily_with_7d = kpi_daily_with_7d.fillna({"txn_count_7d_sum":0.0, "txn_count_7d_avg":0.0})

display(kpi_daily_with_7d.orderBy(F.desc("event_date")).limit(10))


kpi_daily_with_7d = kpi_daily_with_7d.repartition(200, F.col("event_date"))


w_global = Window.orderBy(F.col("event_date").cast("timestamp"))
kpi_daily_with_7d = kpi_daily_with_7d.withColumn("txn_count_next", F.lead("txn_count_today", 1).over(w_global))


kpi_daily_with_7d = kpi_daily_with_7d.withColumn(
    "txn_up_next_day",
    F.when(F.col("txn_count_next") > F.col("txn_count_today"), 1).otherwise(0).cast(IntegerType())
)


df = kpi_daily_with_7d.filter(F.col("txn_count_next").isNotNull())


candidate_features = [
    "txn_count_today", "txn_amount_total", "txn_amount_avg", "fraud_count", "unique_cards", "fraud_rate",
    "txn_count_7d_sum", "txn_count_7d_avg"
]
features = [c for c in candidate_features if c in df.columns]
if not features:
    raise ValueError(f"No feature columns found from candidates. Available columns: {df.columns}")

print("Using features:", features)


selected_cols = ["event_date", "txn_count_next", "txn_up_next_day"] + features
selected_cols = list(dict.fromkeys(selected_cols))  
final_df = df.select(selected_cols).na.drop()


for f in features:
    final_df = final_df.withColumn(f, F.col(f).cast(DoubleType()))

display(final_df.orderBy(F.desc("event_date")).limit(10))


min_max = final_df.select(F.min("event_date").alias("min_d"), F.max("event_date").alias("max_d")).collect()[0]
min_d = min_max["min_d"]
max_d = min_max["max_d"]
if min_d is None or max_d is None:
    raise ValueError("No dates found in dataset")

total_days = (max_d - min_d).days + 1
train_days = int(total_days * 0.8)
split_date = (min_d + datetime.timedelta(days=train_days))

print(f"Date range: {min_d} -> {max_d}, total_days={total_days}, split_date={split_date}")

train_df = final_df.filter(F.col("event_date") <= F.lit(split_date))
test_df = final_df.filter(F.col("event_date") > F.lit(split_date))

print("Train rows:", train_df.count(), " Test rows:", test_df.count())


assembler = VectorAssembler(inputCols=features, outputCol="features_raw", handleInvalid="skip")
scaler = StandardScaler(inputCol="features_raw", outputCol="features", withStd=True, withMean=False)
rf = RandomForestClassifier(featuresCol="features", labelCol="txn_up_next_day",
                            probabilityCol="probability", rawPredictionCol="rawPrediction",
                            predictionCol="prediction", numTrees=100, maxDepth=6, seed=42)

pipeline = Pipeline(stages=[assembler, scaler, rf])


model = pipeline.fit(train_df)


preds = model.transform(train_df)


bce = BinaryClassificationEvaluator(labelCol="txn_up_next_day", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
auc = bce.evaluate(preds)
multi_eval = MulticlassClassificationEvaluator(labelCol="txn_up_next_day", predictionCol="prediction", metricName="accuracy")
accuracy = multi_eval.evaluate(preds)

print(f"AUC (ROC): {auc:.4f}")
print(f"Accuracy: {accuracy:.4f}")


conf = preds.groupBy("txn_up_next_day", "prediction").count().orderBy("txn_up_next_day", "prediction")
display(conf)


rf_stage = model.stages[-1]
importances = rf_stage.featureImportances.toArray().tolist()
feat_imp = list(zip(features, importances))
feat_imp_sorted = sorted(feat_imp, key=lambda x: x[1], reverse=True)
print("Feature importances (descending):")
for f, imp in feat_imp_sorted:
    print(f"{f}: {imp:.4f}")


preds_to_save = preds.select(
    "event_date",
    F.col("txn_count_today").alias("txn_count_today"),
    "txn_count_next",
    "txn_up_next_day",
    "prediction",
    "probability"
).withColumn("_pred_ts", F.current_timestamp()).withColumn("_model_version", F.lit("rf_kpi_v2_fixed"))


try:
    dbutils.fs.rm(output_predictions_path, recurse=True)
except Exception:
    pass
preds_to_save.write.format("delta").mode("overwrite").option("overwriteSchema","true").save(output_predictions_path)
print("Predictions written to:", output_predictions_path)


try:
    dbutils.fs.rm(model_output_path, recurse=True)
except Exception:
    pass
model.write().overwrite().save(model_output_path)
print("Model saved to:", model_output_path)


display(preds_to_save.orderBy(F.desc("event_date")).limit(20))


print("\nDONE.")
print("Tips to improve model:")
print(" - Add lag features (txn_count_lag1, txn_count_lag7, percent_change).")
print(" - Add merchant-level or city-level KPIs as extra features.")
print(" - Add market sentiment if available for better performance.")
print(" - If performance low: hyperparameter tuning (CrossValidator) or time-series models.")


kpi_daily columns: ['event_date', 'txn_count', 'txn_amount_total', 'txn_amount_avg', 'fraud_count', 'unique_cards', 'fraud_rate']




event_date,txn_count_today,txn_amount_total,txn_amount_avg,fraud_count,unique_cards,fraud_rate,txn_count_7d_sum,txn_count_7d_avg
2025-02-28,9089.0,14860.859125251229,1.635037861728598,88.0,9089.0,0.0096820332269776,59724.0,8532.0
2025-02-27,9053.0,15004.107912680742,1.657363074415193,91.0,9052.0,0.0100519164917706,59710.0,8530.0
2025-02-26,8937.0,14898.1763868479,1.6670220864773304,88.0,8936.0,0.009846704710753,59566.0,8509.42857142857
2025-02-25,9108.0,15271.465140973747,1.6767089526760814,84.0,9108.0,0.0092226613965744,59639.0,8519.857142857143
2025-02-24,9136.0,15342.776243439695,1.6793756833887583,81.0,9136.0,0.0088660245183887,59445.0,8492.142857142857
2025-02-23,6792.0,11329.80733163751,1.6681106200879727,53.0,6792.0,0.0078032979976442,59186.0,8455.142857142857
2025-02-22,7609.0,12543.962981502653,1.6485691919440997,61.0,7609.0,0.0080168221842554,59182.0,8454.57142857143
2025-02-21,9075.0,14987.474984727594,1.651512395011305,90.0,9075.0,0.0099173553719008,58947.0,8421.0
2025-02-20,8909.0,14782.174677492814,1.6592406193167375,91.0,8909.0,0.0102143899427545,58786.0,8398.0
2025-02-19,9010.0,14854.89320300762,1.64871178723725,79.0,9010.0,0.0087680355160932,58846.0,8406.57142857143




Using features: ['txn_count_today', 'txn_amount_total', 'txn_amount_avg', 'fraud_count', 'unique_cards', 'fraud_rate', 'txn_count_7d_sum', 'txn_count_7d_avg']


event_date,txn_count_next,txn_up_next_day,txn_count_today,txn_amount_total,txn_amount_avg,fraud_count,unique_cards,fraud_rate,txn_count_7d_sum,txn_count_7d_avg
2025-02-27,9089.0,1,9053.0,15004.107912680742,1.657363074415193,91.0,9052.0,0.0100519164917706,59710.0,8530.0
2025-02-26,9053.0,1,8937.0,14898.1763868479,1.6670220864773304,88.0,8936.0,0.009846704710753,59566.0,8509.42857142857
2025-02-25,8937.0,0,9108.0,15271.465140973747,1.6767089526760814,84.0,9108.0,0.0092226613965744,59639.0,8519.857142857143
2025-02-24,9108.0,0,9136.0,15342.776243439695,1.6793756833887583,81.0,9136.0,0.0088660245183887,59445.0,8492.142857142857
2025-02-23,9136.0,1,6792.0,11329.80733163751,1.6681106200879727,53.0,6792.0,0.0078032979976442,59186.0,8455.142857142857
2025-02-22,6792.0,0,7609.0,12543.962981502653,1.6485691919440997,61.0,7609.0,0.0080168221842554,59182.0,8454.57142857143
2025-02-21,7609.0,0,9075.0,14987.474984727594,1.651512395011305,90.0,9075.0,0.0099173553719008,58947.0,8421.0
2025-02-20,9075.0,1,8909.0,14782.174677492814,1.6592406193167375,91.0,8909.0,0.0102143899427545,58786.0,8398.0
2025-02-19,8909.0,0,9010.0,14854.89320300762,1.64871178723725,79.0,9010.0,0.0087680355160932,58846.0,8406.57142857143
2025-02-18,9010.0,1,8914.0,14264.37581074541,1.600221652540432,77.0,8914.0,0.0086380973749158,58959.0,8422.714285714286


Date range: 2025-01-01 -> 2025-02-27, total_days=58, split_date=2025-02-16
Train rows: 47  Test rows: 11
AUC (ROC): 1.0000
Accuracy: 0.9787


txn_up_next_day,prediction,count
0,0.0,27
1,0.0,1
1,1.0,19


Feature importances (descending):
txn_count_today: 0.1931
txn_amount_total: 0.1490
unique_cards: 0.1319
txn_amount_avg: 0.1232
txn_count_7d_avg: 0.1182
fraud_rate: 0.1071
fraud_count: 0.1049
txn_count_7d_sum: 0.0726
Predictions written to: /Volumes/workspace/data/bda_project/gold/ml/kpi_model_predictions_fixed2
Model saved to: /Volumes/workspace/data/bda_project/gold/models/kpi_rf_model_fixed2


event_date,txn_count_today,txn_count_next,txn_up_next_day,prediction,probability,_pred_ts,_model_version
2025-02-16,6788.0,8877.0,1,1.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.04"",""0.96""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-15,7374.0,6788.0,0,0.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.7382208393458394"",""0.2617791606541606""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-14,8914.0,7374.0,0,0.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.804105093532513"",""0.19589490646748703""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-13,8969.0,8914.0,0,0.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.780221976649396"",""0.219778023350604""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-12,9123.0,8969.0,0,0.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.8686180204204398"",""0.13138197957956022""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-11,9021.0,9123.0,1,1.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.25234638047138047"",""0.7476536195286195""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-10,8985.0,9021.0,1,1.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.35788287131029073"",""0.6421171286897093""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-09,6873.0,8985.0,1,1.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.2735235732009926"",""0.7264764267990074""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-08,7587.0,6873.0,0,0.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.928119412546832"",""0.07188058745316814""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed
2025-02-07,8762.0,7587.0,0,0.0,"{""type"":""1"",""size"":null,""indices"":null,""values"":[""0.8472598554372748"",""0.1527401445627252""]}",2025-10-16T03:47:14.536Z,rf_kpi_v2_fixed



DONE.
Tips to improve model:
 - Add lag features (txn_count_lag1, txn_count_lag7, percent_change).
 - Add merchant-level or city-level KPIs as extra features.
 - Add market sentiment if available for better performance.
 - If performance low: hyperparameter tuning (CrossValidator) or time-series models.
