In [1]:
# =============================================================
# GOLD MODEL TRAINING — Proper OOT Split (Latest 3 Months)
# =============================================================

import os, re, glob
from datetime import datetime
import pandas as pd
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# -------------------------------------------------------------
# 1️⃣ Spark Init (Memory Safe + Local Mode Fix)
# -------------------------------------------------------------
# Local session with 4 worker threads; broadcast joins disabled (-1) and a
# fixed 50-way shuffle / default parallelism.
# NOTE(review): in local mode the executor runs inside the driver JVM, so
# spark.executor.memory presumably has no practical effect; and getOrCreate()
# returns an already-running session unchanged (e.g. when a notebook cell is
# re-run), in which case the memory settings below are not applied.
spark = (
    SparkSession.builder
    .appName("Gold_Model_Training_OOT")
    .master("local[4]")
    .config("spark.sql.debug.maxToStringFields", "2000")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .config("spark.memory.fraction", "0.8")
    .config("spark.memory.storageFraction", "0.3")
    .config("spark.sql.autoBroadcastJoinThreshold", -1)
    .config("spark.sql.shuffle.partitions", "50")
    .config("spark.default.parallelism", "50")
    .config("spark.driver.maxResultSize", "3g")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

print("\nüöÄ Starting Gold Model Training with Proper OOT Split\n")
print(f"‚úÖ Spark Session Created: {spark.sparkContext.master}")
print(f"   - Version: {spark.version}")
# defaultParallelism echoes spark.default.parallelism (configured above), not
# the number of cores in local[4], so report it under its real name.
print(f"   - Default Parallelism: {spark.sparkContext.defaultParallelism}\n")

# -------------------------------------------------------------
# 2️⃣ Paths
# -------------------------------------------------------------
# Inside the Airflow container the project lives under /opt/airflow;
# anywhere else, resolve everything relative to the working directory.
if os.path.exists("/opt/airflow"):
    BASE_DIR = "/opt/airflow"
else:
    BASE_DIR = "."

FEATURE_PATH, LABEL_PATH, MODEL_BANK = (
    os.path.join(BASE_DIR, sub_dir)
    for sub_dir in (
        "datamart/gold/feature_store",
        "datamart/gold/label_store",
        "utils/model_bank",
    )
)

# The trained model and its metrics CSV are written here, so create it up front.
os.makedirs(MODEL_BANK, exist_ok=True)

# -------------------------------------------------------------
# 3️⃣ Gather All Monthly Parquets
# -------------------------------------------------------------
# Pair every gold_feature_store_<YYYY_MM_DD>.parquet with the label store of
# the same tag, inner-join them on Customer_ID, and union all months into one
# cached DataFrame tagged with its snapshot.
feature_files = sorted(glob.glob(os.path.join(FEATURE_PATH, "gold_feature_store_*.parquet")))
label_files = sorted(glob.glob(os.path.join(LABEL_PATH, "gold_label_store_*.parquet")))

if not feature_files or not label_files:
    raise FileNotFoundError("‚ùå No Gold feature or label store found. Run main.py first.")

print(f"üìÇ Found {len(feature_files)} monthly feature files.\n")

all_df = None
matched_tags = []   # tags that had both a feature file and a label file
skipped_files = []  # feature files without a date tag or without a label file
for fpath in feature_files:
    # Search the file name only, so a date-like directory component in
    # FEATURE_PATH can never be taken for the snapshot tag.
    tag_match = re.search(r"(\d{4}_\d{2}_\d{2})", os.path.basename(fpath))
    if not tag_match:
        skipped_files.append(fpath)
        continue
    tag = tag_match.group(1)
    lpath = os.path.join(LABEL_PATH, f"gold_label_store_{tag}.parquet")
    if not os.path.exists(lpath):
        skipped_files.append(fpath)
        continue

    f_df = spark.read.parquet(fpath)
    l_df = spark.read.parquet(lpath)

    # Avoid an ambiguous label: bring the label store's column in as label_y,
    # drop any 'label' column the feature store carries, then rename.
    g_df = (
        f_df
        .join(l_df.select("Customer_ID", F.col("label").alias("label_y")), "Customer_ID", "inner")
        .drop("label")
        .withColumnRenamed("label_y", "label")
        .withColumn("snapshot_tag", F.lit(tag))
    )

    all_df = g_df if all_df is None else all_df.unionByName(g_df, allowMissingColumns=True)
    matched_tags.append(tag)

if all_df is None:
    raise ValueError("‚ùå No valid feature-label pairs found.")

if skipped_files:
    print(f"‚ö†Ô∏è  Skipped {len(skipped_files)} feature file(s) without a date tag or matching label file:")
    for skipped in skipped_files:
        print(f"   - {os.path.basename(skipped)}")

# Optimize data partitioning
print("üîÑ Optimizing data partitioning...")
all_df = all_df.repartition(20)
all_df = all_df.cache()

print("üìä Attempting to count rows...")
try:
    total_rows = all_df.count()
    # Report the number of months actually joined, not the number of feature files found.
    print(f"‚úÖ Combined dataset: {total_rows} rows across {len(matched_tags)} months.\n")
except Exception as e:
    # Best effort: the count is informational only, so keep going.
    print(f"‚ö†Ô∏è  Count failed: {str(e)[:100]}")
    print("   Continuing without count...\n")
    all_df.unpersist()
    all_df = all_df.cache()
    total_rows = "Unknown"

# -------------------------------------------------------------
# 4️⃣ Feature Cleaning
# -------------------------------------------------------------
exclude_cols = {"Customer_ID", "snapshot_date", "gold_processing_date", "label", "snapshot_tag"}
numeric_types = {"int", "double", "float", "bigint"}
column_types = dict(all_df.dtypes)

# Every numeric column that is not an identifier, a date/tag, or the target.
feature_cols = [
    c for (c, t) in all_df.dtypes
    if (t in numeric_types) and (c not in exclude_cols)
]
if not feature_cols:
    raise ValueError("‚ùå No numeric feature columns found in the Gold feature store.")

# Only floating-point columns can hold +/-inf, so integer features skip the
# check. Infinities become null here and are zero-filled by fillna below,
# together with existing nulls (and NaNs, per Spark's na.fill semantics).
for c in feature_cols:
    if column_types[c] in ("double", "float"):
        all_df = all_df.withColumn(
            c, F.when(F.col(c).isin(float("inf"), float("-inf")), None).otherwise(F.col(c))
        )
all_df = all_df.fillna(0, subset=feature_cols)

# Spark ML classifiers and evaluators reject null/NaN labels, so rows whose
# label is missing after the cast can neither be trained on nor scored.
all_df = (
    all_df.withColumn("label", F.col("label").cast(DoubleType()))
    .dropna(subset=["label"])
)

print(f"üßÆ Using {len(feature_cols)} numeric features.\n")

# -------------------------------------------------------------
# 5️⃣ Parse snapshot_tag → snapshot_date & Identify OOT Period
# -------------------------------------------------------------
def parse_tag_from_path(path: str) -> "datetime | None":
    """Return the snapshot date encoded in a file name, or ``None``.

    The first ``YYYY_MM_DD`` token in the *base name* of ``path`` is parsed;
    directory components are ignored. A bare tag such as ``"2024_01_01"`` is
    accepted as well.

    Args:
        path: File path or name, e.g. ``.../gold_feature_store_2024_01_01.parquet``.

    Returns:
        A naive ``datetime`` at midnight of that date, or ``None`` when the
        name has no such token or the digits are not a real calendar date
        (e.g. ``2024_13_01``).
    """
    match = re.search(r"(\d{4}_\d{2}_\d{2})", os.path.basename(path))
    if match is None:
        return None
    try:
        return datetime.strptime(match.group(1), "%Y_%m_%d")
    except ValueError:
        # The digits fit the pattern but do not form a valid date.
        return None

# Derive the snapshot list from the rows that were actually loaded. Feature
# files without a matching label file contribute no rows, so taking tags from
# file names alone could put the OOT cutoff on an empty snapshot.
loaded_tags = [row["snapshot_tag"] for row in all_df.select("snapshot_tag").distinct().collect()]
tags = sorted(t for t in map(parse_tag_from_path, loaded_tags) if t is not None)
if not tags:
    raise ValueError("‚ùå No valid snapshot tags found in feature_store filenames.")

all_df = all_df.withColumn(
    "snapshot_date",
    F.to_date(F.regexp_extract(F.col("snapshot_tag"), r"(\d{4}_\d{2}_\d{2})", 1), "yyyy_MM_dd")
)

print(f"üìÖ Available snapshots: {[t.strftime('%Y-%m-%d') for t in tags]}")

# Hold out the latest OOT_SNAPSHOTS snapshots as the out-of-time set; at least
# one older snapshot must remain for train/val/test.
OOT_SNAPSHOTS = 3
if len(tags) < OOT_SNAPSHOTS + 1:
    raise ValueError(
        f"‚ùå Need at least {OOT_SNAPSHOTS + 1} monthly snapshots of data. Found only {len(tags)}."
    )

oot_tags = tags[-OOT_SNAPSHOTS:]
oot_cutoff = oot_tags[0]  # earliest snapshot inside the OOT window

# The window is counted in snapshots, not calendar months. When the latest
# snapshots are not consecutive months (e.g. 2024-12-01 then 2025-11-01), say so.
oot_span_months = (
    (oot_tags[-1].year - oot_cutoff.year) * 12 + (oot_tags[-1].month - oot_cutoff.month) + 1
)
if oot_span_months != OOT_SNAPSHOTS:
    print(
        f"‚ö†Ô∏è  OOT snapshots span {oot_span_months} calendar months, not {OOT_SNAPSHOTS}: "
        f"{[t.strftime('%Y-%m-%d') for t in oot_tags]}"
    )

print(f"\nüéØ OOT Period: {oot_cutoff.strftime('%Y-%m-%d')} onwards (latest {OOT_SNAPSHOTS} snapshots)")
print(f"üìö Training Period: Before {oot_cutoff.strftime('%Y-%m-%d')}\n")

# Split: historical data vs OOT, compared against a DATE literal to match
# the type of snapshot_date.
cutoff_date = oot_cutoff.date()
historical_df = all_df.filter(F.col("snapshot_date") < F.lit(cutoff_date))
oot_df = all_df.filter(F.col("snapshot_date") >= F.lit(cutoff_date))

# Repartition *before* caching so the cached frames are the ones referenced
# downstream, rather than orphaned pre-repartition parents.
historical_df = historical_df.repartition(15).cache()
oot_df = oot_df.repartition(5).cache()

print("üìä Data Split:")
try:
    hist_count = historical_df.count()
    print(f"  Historical (for train/val/test): {hist_count} rows")
except Exception as e:
    print(f"  Historical: [count failed - {str(e)[:50]}]")
    hist_count = "N/A"

try:
    oot_count = oot_df.count()
    print(f"  OOT (latest {OOT_SNAPSHOTS} snapshots):        {oot_count} rows\n")
except Exception as e:
    print(f"  OOT: [count failed - {str(e)[:50]}]\n")
    oot_count = "N/A"

# -------------------------------------------------------------
# 6️⃣ Split Historical Data: 70/15/15 (Train/Val/Test)
# -------------------------------------------------------------
print("üîÄ Splitting historical data into 70% train / 15% val / 15% test...")

# Draw one uniform split key per row and cache the frame that holds it, so
# train/val/test are all filtered from the same materialized values. Without
# the cache, each filter re-evaluates F.rand over a recomputed upstream plan,
# and the draws only agree if row order happens to be reproduced exactly.
# NOTE(review): the split ignores Customer_ID, so a customer present in
# several historical snapshots can presumably land in both train and
# val/test; val/test metrics may be optimistic relative to OOT.
historical_df = historical_df.withColumn("rand", F.rand(seed=42)).cache()

# Repartition before caching so the cached frames are exactly the ones used
# for training and evaluation.
train_df = historical_df.filter(F.col("rand") < 0.70).drop("rand").repartition(10).cache()
val_df = (
    historical_df.filter((F.col("rand") >= 0.70) & (F.col("rand") < 0.85))
    .drop("rand")
    .repartition(3)
    .cache()
)
test_df = historical_df.filter(F.col("rand") >= 0.85).drop("rand").repartition(3).cache()

print("\nüìä Historical Split Sizes:")
try:
    train_count = train_df.count()
    print(f"  Train (70%):    {train_count} rows")
except Exception as e:
    print(f"  Train (70%):    [count failed - {str(e)[:50]}]")
    train_count = "N/A"

try:
    val_count = val_df.count()
    print(f"  Val (15%):      {val_count} rows")
except Exception as e:
    print(f"  Val (15%):      [count failed - {str(e)[:50]}]")
    val_count = "N/A"

try:
    test_count = test_df.count()
    print(f"  Test (15%):     {test_count} rows\n")
except Exception as e:
    print(f"  Test (15%):     [count failed - {str(e)[:50]}]\n")
    test_count = "N/A"

# -------------------------------------------------------------
# 7️⃣ Assemble + Scale
# -------------------------------------------------------------
# Pack the numeric feature columns into a single vector column
# "features_raw"; rows with invalid values are skipped (handleInvalid="skip").
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features_raw",
    handleInvalid="skip",
)
# Scale each vector component to unit standard deviation without centering
# (withMean=False), producing the "features" column the classifiers read.
scaler = StandardScaler(
    inputCol=assembler.getOutputCol(),
    outputCol="features",
    withMean=False,
    withStd=True,
)

# -------------------------------------------------------------
# 8️⃣ Hyperparameter Tuning on Validation Set
# -------------------------------------------------------------
def train_and_eval(model_name, model_obj, train_df, val_df):
    """Fit ``assembler -> scaler -> model_obj`` on ``train_df`` and score ``val_df``.

    Relies on the module-level ``assembler`` and ``scaler`` stages.

    Args:
        model_name: Label shown in the progress line.
        model_obj: Unfitted Spark ML classifier reading "features" / "label".
        train_df: DataFrame the pipeline is fit on.
        val_df: DataFrame the fitted pipeline is evaluated on.

    Returns:
        ``(fitted_pipeline, val_auc, val_pr)``. On any exception the error is
        printed (first 100 characters) and ``(None, 0.0, 0.0)`` is returned,
        so callers can treat an AUC of 0 as "candidate failed".
    """
    try:
        candidate = Pipeline(stages=[assembler, scaler, model_obj])
        print(f"  Training {model_name}...", end=" ", flush=True)
        fitted = candidate.fit(train_df)
        scored = fitted.transform(val_df)

        # ROC AUC first, then PR AUC, each with a fresh evaluator on "label".
        auc, pr = (
            BinaryClassificationEvaluator(labelCol="label", metricName=metric).evaluate(scored)
            for metric in ("areaUnderROC", "areaUnderPR")
        )

        print(f"‚úÖ AUC={auc:.4f} | PR={pr:.4f}")
        return fitted, auc, pr
    except Exception as err:
        print(f"‚ùå FAILED: {str(err)[:100]}")
        return None, 0.0, 0.0

print("üîç Hyperparameter tuning on validation set...\n")
results = []

# Logistic Regression grid
for reg in [0.01, 0.05, 0.1]:
    lr = LogisticRegression(featuresCol="features", labelCol="label", maxIter=50, regParam=reg)
    _, auc, pr = train_and_eval(f"LR_reg{reg}", lr, train_df, val_df)
    if auc > 0:
        results.append(("LogisticRegression", reg, auc, pr))

# Random Forest grid
for depth in [6, 8]:
    rf = RandomForestClassifier(
        featuresCol="features", 
        labelCol="label", 
        numTrees=50,
        maxDepth=depth,
        maxBins=32
    )
    _, auc, pr = train_and_eval(f"RF_depth{depth}", rf, train_df, val_df)
    if auc > 0:
        results.append(("RandomForest", depth, auc, pr))

if not results:
    raise ValueError("‚ùå All models failed to train. Check memory and data quality.")

# -------------------------------------------------------------
# 9Ô∏è‚É£ Select Best Model & Evaluate on Historical Test Set
# -------------------------------------------------------------
best_model_name, best_param, best_auc, _ = max(results, key=lambda x: x[2])
print(f"\nüèÜ Best model (from validation): {best_model_name} (param={best_param}, ValAUC={best_auc:.4f})")

if best_model_name == "LogisticRegression":
    final_model = LogisticRegression(featuresCol="features", labelCol="label", regParam=best_param, maxIter=50)
else:
    final_model = RandomForestClassifier(
        featuresCol="features", 
        labelCol="label", 
        numTrees=50, 
        maxDepth=best_param,
        maxBins=32
    )

print("\nüîÑ Training final model on full training set...")
pipeline = Pipeline(stages=[assembler, scaler, final_model])
final_fit = pipeline.fit(train_df)

print("üìä Evaluating on historical test set (15%)...")
pred_test = final_fit.transform(test_df)

eval_auc = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
eval_pr = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderPR")
auc_test = eval_auc.evaluate(pred_test)
pr_test = eval_pr.evaluate(pred_test)

print(f"   Historical Test ‚Äî AUC={auc_test:.4f}, PR={pr_test:.4f}")

# -------------------------------------------------------------
# 🔟 Final Evaluation on OOT (latest held-out snapshots)
# -------------------------------------------------------------
print("\nüìä Evaluating on OOT set (latest 3 months)...")
# Score the out-of-time rows with the selected pipeline, then apply the same
# ROC-AUC and PR-AUC evaluators used for the historical test set.
pred_oot = final_fit.transform(oot_df)
auc_oot, pr_oot = (evaluator.evaluate(pred_oot) for evaluator in (eval_auc, eval_pr))

print(f"   OOT Test ‚Äî AUC={auc_oot:.4f}, PR={pr_oot:.4f}\n")

# -------------------------------------------------------------
# 1️⃣1️⃣ Save Model + Metrics
# -------------------------------------------------------------
# One timestamp tags both artifacts so a model directory can be matched to
# its metrics CSV.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

model_path = os.path.join(MODEL_BANK, f"{best_model_name}_OOT_model_{timestamp}")
final_fit.write().overwrite().save(model_path)

# Single-row metrics table; row counts are stringified because they may be "N/A".
metrics_row = {
    "Model": best_model_name,
    "BestParam": best_param,
    "ValAUC": best_auc,
    "HistTestAUC": auc_test,
    "HistTestPR": pr_test,
    "OOT_AUC": auc_oot,
    "OOT_PR": pr_oot,
    "TrainRows": str(train_count),
    "ValRows": str(val_count),
    "HistTestRows": str(test_count),
    "OOT_Rows": str(oot_count),
}
metrics_path = os.path.join(MODEL_BANK, f"oot_model_metrics_{timestamp}.csv")
pd.DataFrame([metrics_row]).to_csv(metrics_path, index=False)

print(f"üíæ Saved model ‚Üí {model_path}")
print(f"üìä Saved metrics ‚Üí {metrics_path}")

# Summary
print("\n" + "="*60)
print("üìà FINAL RESULTS SUMMARY")
print("="*60)
print(f"Best Model: {best_model_name} (param={best_param})")
print(f"\nValidation AUC:        {best_auc:.4f}")
print(f"Historical Test AUC:   {auc_test:.4f} | PR={pr_test:.4f}")
print(f"OOT Test AUC:          {auc_oot:.4f} | PR={pr_oot:.4f}")
print("="*60 + "\n")

# Cleanup: drop every DataFrame cached in this session. Several frames above
# are cached and then reassigned (repartition / withColumn), and unpersist()
# on the current variable does not release a cache held by an earlier plan;
# clearCache() releases all of them.
spark.catalog.clearCache()

spark.stop()
print("‚úÖ OOT Model Training Completed Successfully.\n")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/09 08:46:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable



🚀 Starting Gold Model Training with Proper OOT Split

✅ Spark Session Created: local[4]
   - Version: 3.5.5
   - Available Cores: 50

📂 Found 25 monthly feature files.



                                                                                

🔄 Optimizing data partitioning...
📊 Attempting to count rows...


                                                                                

✅ Combined dataset: 104814 rows across 25 months.

🧮 Using 23 numeric features.

📅 Available snapshots: ['2023-01-01', '2023-02-01', '2023-03-01', '2023-04-01', '2023-05-01', '2023-06-01', '2023-07-01', '2023-08-01', '2023-09-01', '2023-10-01', '2023-11-01', '2023-12-01', '2024-01-01', '2024-02-01', '2024-03-01', '2024-04-01', '2024-05-01', '2024-06-01', '2024-07-01', '2024-08-01', '2024-09-01', '2024-10-01', '2024-11-01', '2024-12-01', '2025-11-01']

🎯 OOT Period: 2024-11-01 onwards (latest 3 months)
📚 Training Period: Before 2024-11-01

📊 Data Split:


                                                                                

  Historical (for train/val/test): 93256 rows


                                                                                

  OOT (latest 3 months):           11558 rows

🔀 Splitting historical data into 70% train / 15% val / 15% test...

📊 Historical Split Sizes:


                                                                                

  Train (70%):    65562 rows


                                                                                

  Val (15%):      13909 rows


                                                                                

  Test (15%):     13785 rows

🔍 Hyperparameter tuning on validation set...

  Training LR_reg0.01... 

                                                                                

✅ AUC=0.7239 | PR=0.3514
  Training LR_reg0.05... 

                                                                                

✅ AUC=0.7238 | PR=0.3511
✅ AUC=0.7237 | PR=0.3508
  Training RF_depth6... 

                                                                                

✅ AUC=0.7261 | PR=0.3559
  Training RF_depth8... 

                                                                                

✅ AUC=0.7373 | PR=0.3682

🏆 Best model (from validation): RandomForest (param=8, ValAUC=0.7373)

🔄 Training final model on full training set...


                                                                                

📊 Evaluating on historical test set (15%)...
   Historical Test — AUC=0.7269, PR=0.3655

📊 Evaluating on OOT set (latest 3 months)...
   OOT Test — AUC=0.6837, PR=0.3661



                                                                                

💾 Saved model → /opt/airflow/utils/model_bank/RandomForest_OOT_model_20251109_085136
📊 Saved metrics → /opt/airflow/utils/model_bank/oot_model_metrics_20251109_085136.csv

📈 FINAL RESULTS SUMMARY
Best Model: RandomForest (param=8)

Validation AUC:        0.7373
Historical Test AUC:   0.7269 | PR=0.3655
OOT Test AUC:          0.6837 | PR=0.3661

✅ OOT Model Training Completed Successfully.

