# Regression Comparison

In [None]:
import os
import time
import csv
import json
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import shutil
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, month
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression, RandomForestRegressor, GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.storagelevel import StorageLevel
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator



## Auxiliary Functions

In [2]:
# Root folder that receives one sub-directory per trained model.
BASE_OUTPUT_DIR = "saved_models"
# CSV file to which every model run appends one row of metrics.
MASTER_LOG_FILE = "model_comparison.csv"
# Folder holding the per-year parquet files.
BASE_DATA_PATH = "processed_data"


def get_paths(years):
    """Return the parquet path under BASE_DATA_PATH for each entry of ``years``, in order."""
    paths = []
    for year in years:
        paths.append(f"{BASE_DATA_PATH}/climate_{year}.parquet")
    return paths

def save_training_history(model, output_dir, model_name):
    """
    Save the per-iteration objective history of a fitted model, if it has one.

    When the model exposes ``summary.objectiveHistory`` the values are written
    to ``training_history.csv`` and plotted to ``convergence_plot.png`` inside
    ``output_dir``. Otherwise a notice is printed and nothing is written.

    Args:
        model: Fitted model object (history is expected for LinearRegression).
        output_dir: Existing directory that receives the CSV and the plot.
        model_name: Label used in the plot title.
    """
    history_path = os.path.join(output_dir, "training_history.csv")

    history = None
    # Models that report hasSummary == False are skipped without touching
    # `.summary`; objects without a `hasSummary` attribute fall through to the
    # guarded lookup below.
    if getattr(model, "hasSummary", True):
        try:
            summary = getattr(model, "summary", None)
        except RuntimeError:
            # The `summary` accessor may raise instead of returning None when
            # no training summary is attached (e.g. a model reloaded from
            # disk); hasattr() would let that error escape.
            summary = None
        history = getattr(summary, "objectiveHistory", None)

    if history is None:
        print("   [!] This algorithm does not expose iterative history (objectiveHistory).")
        return

    df_hist = pd.DataFrame({
        "Iteration": range(1, len(history) + 1),
        "Objective_Loss": history
    })
    df_hist.to_csv(history_path, index=False)
    print(f"   [v] Training history (Loss) saved to: {history_path}")

    plt.figure(figsize=(8, 4))
    plt.plot(df_hist["Iteration"], df_hist["Objective_Loss"], marker='o')
    plt.title(f"Convergence Curve - {model_name}")
    plt.xlabel("Iteration")
    plt.ylabel("Loss (Objective Function)")
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, "convergence_plot.png"))
    plt.close()

def evaluate_and_log(predictions, target_col, time_taken, output_dir, model_name):
    """
    Compute regression metrics for ``predictions`` and record them.

    Global R2, RMSE and MAE are computed over all rows, plus an R2 restricted
    to rows whose label is greater than 0 (0.0 when there are no such rows).
    The metrics are appended as one row to MASTER_LOG_FILE and also written to
    ``metrics.json`` inside ``output_dir``.

    Args:
        predictions: Spark DataFrame holding ``target_col`` and a "prediction" column.
        target_col: Name of the label column.
        time_taken: Training duration in seconds.
        output_dir: Existing directory that receives ``metrics.json``.
        model_name: Value stored in the "Model" field.

    Returns:
        dict: The rounded metrics that were logged.
    """
    # Each evaluate() call is a separate Spark action. Caching the two columns
    # the evaluator reads keeps the upstream model.transform from being
    # recomputed for every metric.
    scored = predictions.select(target_col, "prediction").persist()
    try:
        evaluator = RegressionEvaluator(labelCol=target_col, predictionCol="prediction")

        r2 = evaluator.setMetricName("r2").evaluate(scored)
        rmse = evaluator.setMetricName("rmse").evaluate(scored)
        mae = evaluator.setMetricName("mae").evaluate(scored)

        snow_subset = scored.filter(col(target_col) > 0)
        # take(1) answers "is there any row?" without counting the whole subset.
        if snow_subset.take(1):
            r2_snow = evaluator.setMetricName("r2").evaluate(snow_subset)
        else:
            r2_snow = 0.0
    finally:
        scored.unpersist()

    metrics = {
        "Model": model_name,
        "Time_Sec": round(time_taken, 2),
        "R2_Global": round(r2, 4),
        "RMSE_Global": round(rmse, 4),
        "MAE_Global": round(mae, 4),
        "R2_Snow_Only": round(r2_snow, 4)
    }

    # A log file that exists but is still empty needs its header too.
    needs_header = (
        not os.path.isfile(MASTER_LOG_FILE)
        or os.path.getsize(MASTER_LOG_FILE) == 0
    )
    with open(MASTER_LOG_FILE, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=metrics.keys())
        if needs_header:
            writer.writeheader()
        writer.writerow(metrics)

    with open(os.path.join(output_dir, "metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    print(f"\n--- METRICS ({model_name}) ---")
    print(f"R2: {r2:.4f} | RMSE: {rmse:.4f} | R2 Snow Only: {r2_snow:.4f}")
    return metrics

def plot_predictions(predictions, target_col, output_dir, model_name):
    """
    Save a two-panel diagnostic figure as ``prediction_plots.png`` in ``output_dir``.

    The left panel is a scatter of actual vs. predicted values with a dashed
    identity line; the right panel is a histogram of residuals
    (actual - predicted). Only a seeded 5% sample without replacement is
    collected to the driver.
    """
    print("   Generating plots...")

    sampled = predictions.select(target_col, "prediction").sample(False, 0.05, seed=42)
    pdf = sampled.toPandas()
    actual = pdf[target_col]
    predicted = pdf["prediction"]

    fig, (ax_scatter, ax_resid) = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: predicted against actual, with the y = x reference line.
    sns.scatterplot(x=actual, y=predicted, alpha=0.3, ax=ax_scatter)
    lo, hi = actual.min(), actual.max()
    ax_scatter.plot([lo, hi], [lo, hi], 'r--', lw=2)
    ax_scatter.set_xlabel('Reality (Actual Value)')
    ax_scatter.set_ylabel('Prediction')
    ax_scatter.set_title(f'Prediction vs Reality - {model_name}')

    # Right panel: distribution of the errors.
    residuals = actual - predicted
    sns.histplot(residuals, bins=50, kde=True, ax=ax_resid)
    ax_resid.set_xlabel('Error (Real - Predicted)')
    ax_resid.set_title('Residuals Distribution')

    fig.savefig(os.path.join(output_dir, "prediction_plots.png"))
    plt.close(fig)
    print(f"   [v] Plots saved in: {output_dir}")

## Linear Regression

In [None]:
MODEL_NAME = "LinearRegression_Baseline"

print(f"--- STARTING: {MODEL_NAME} ---")

# Local-mode Spark session using all available cores.
spark = SparkSession.builder \
    .appName(MODEL_NAME) \
    .master("local[*]") \
    .config("spark.driver.memory", "10g") \
    .config("spark.executor.memory", "10g") \
    .config("spark.sql.files.maxPartitionBytes", "128m") \
    .config("spark.driver.maxResultSize", "4g") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")

# Everything after session creation runs under try/finally so the session
# (and the persisted training data) is released even when a step fails.
try:
    # Start from an empty output directory for this model.
    model_output_dir = os.path.join(BASE_OUTPUT_DIR, MODEL_NAME)
    if os.path.exists(model_output_dir):
        shutil.rmtree(model_output_dir)
    os.makedirs(model_output_dir)

    print(f"Output Directory: {model_output_dir}")

    # Split by year; range() excludes its end value.
    train_years = range(2010, 2021)  # 2010-2020
    val_years   = range(2021, 2023)  # 2021-2022
    test_years  = range(2023, 2025)  # 2023-2024

    print("1. Loading Data (Full Dataset)...")
    try:
        train_df = spark.read.parquet(*get_paths(train_years))
        val_df   = spark.read.parquet(*get_paths(val_years))
        test_df  = spark.read.parquet(*get_paths(test_years))
    except Exception as e:
        print(f"Error loading data: {e}")
        raise  # bare raise keeps the original traceback; finally stops Spark

    target_col = "SNDP"
    # Columns never used as inputs: the label, identifiers and derived columns.
    ignore_cols = [target_col, "DATE", "STATION", "NAME", "features", "prediction", "FRSHTT"]
    valid_types = ['int', 'bigint', 'float', 'double', 'tinyint', 'smallint']

    # Every numeric column that is not explicitly ignored becomes a feature.
    dtypes = train_df.dtypes
    feature_cols = [c for c, t in dtypes if t in valid_types and c not in ignore_cols]
    print(f"   Features ({len(feature_cols)}): {feature_cols}")

    # handleInvalid="skip" drops rows containing invalid values.
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features").setHandleInvalid("skip")

    train_vec = assembler.transform(train_df)
    val_vec   = assembler.transform(val_df)  # not used further in this cell
    test_vec  = assembler.transform(test_df)

    # Persist before count() so the count materialises the cache used by fit().
    train_vec.persist(StorageLevel.MEMORY_AND_DISK)
    print(f"   Training Rows: {train_vec.count():,}")

    print(f"2. Training {MODEL_NAME}...")

    lr = LinearRegression(
        featuresCol="features",
        labelCol=target_col,
        maxIter=50,
        regParam=0.1,
        elasticNetParam=0.5
    )

    # Only the fit() call is timed.
    start_time = time.time()
    model = lr.fit(train_vec)
    end_time = time.time()
    duration = end_time - start_time

    print(f"   Training completed in {duration:.2f} seconds.")

    print("3. Processing Results...")

    model_save_path = os.path.join(model_output_dir, "spark_model")
    model.write().overwrite().save(model_save_path)
    print(f"   [v] Spark Model saved to: {model_save_path}")

    save_training_history(model, model_output_dir, MODEL_NAME)

    print("   Generating predictions on Test Set...")
    test_preds = model.transform(test_vec)
    metrics = evaluate_and_log(test_preds, target_col, duration, model_output_dir, MODEL_NAME)

    plot_predictions(test_preds, target_col, model_output_dir, MODEL_NAME)

    print("\n--- PROCESS FINISHED SUCCESSFULLY ---")
finally:
    spark.stop()

--- STARTING: LinearRegression_Baseline ---


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/01 20:50:18 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Output Directory: saved_models/LinearRegression_Baseline
1. Loading Data (Full Dataset)...


                                                                                

   Features (19): ['LATITUDE', 'LONGITUDE', 'ELEVATION', 'TEMP', 'DEWP', 'SLP', 'STP', 'VISIB', 'WDSP', 'MXSPD', 'MAX', 'MIN', 'PRCP', 'is_Fog', 'is_Rain', 'is_Snow', 'is_Hail', 'is_Thunder', 'is_Tornado']


[Stage 4:>                                                          (0 + 4) / 7]

## Random Forest

In [None]:
# Root folder that receives one sub-directory per trained model.
BASE_OUTPUT_DIR = "saved_models"
# CSV file to which every model run appends one row of metrics.
MASTER_LOG_FILE = "model_comparison.csv"
# Folder holding the per-year parquet files.
BASE_DATA_PATH = "processed_data"


def get_paths(years):
    """Map each year in ``years`` to its parquet file path under BASE_DATA_PATH."""
    def _path_for(year):
        return f"{BASE_DATA_PATH}/climate_{year}.parquet"

    return list(map(_path_for, years))

def save_feature_importance(model, feature_cols, output_dir):
    """
    Write the model's per-feature importances to ``feature_importance.csv``.

    Rows are sorted from most to least important and the five top rows are
    printed. Models without a ``featureImportances`` attribute only produce
    a notice.
    """
    if not hasattr(model, "featureImportances"):
        print("   [!] This model does not support feature importance.")
        return

    scores = model.featureImportances
    # Pair each feature name with the score stored at the same position.
    records = [
        {"Feature": name, "Importance": float(scores[idx])}
        for idx, name in enumerate(feature_cols)
    ]

    ranked = pd.DataFrame(records).sort_values(by="Importance", ascending=False)
    csv_path = os.path.join(output_dir, "feature_importance.csv")
    ranked.to_csv(csv_path, index=False)
    print(f"   [v] Feature Importance saved to: {csv_path}")

    print("   --- TOP 5 FEATURES ---")
    print(ranked.head(5))

def evaluate_and_log(predictions, target_col, time_taken, output_dir, model_name):
    """
    Calculate R2, RMSE and MAE, then save them to the master CSV and a local JSON.

    An additional R2 is computed on the rows whose label is greater than 0
    (0.0 when no such rows exist). One row is appended to MASTER_LOG_FILE and
    the same values are written to ``metrics.json`` inside ``output_dir``.

    Args:
        predictions: Spark DataFrame holding ``target_col`` and a "prediction" column.
        target_col: Name of the label column.
        time_taken: Training duration in seconds.
        output_dir: Existing directory that receives ``metrics.json``.
        model_name: Value stored in the "Model" field.

    Returns:
        dict: The rounded metrics that were logged.
    """
    # Every evaluate() is its own Spark action; cache the two columns it reads
    # so the predictions are not recomputed once per metric.
    scored = predictions.select(target_col, "prediction").persist()
    try:
        evaluator = RegressionEvaluator(labelCol=target_col, predictionCol="prediction")

        r2 = evaluator.setMetricName("r2").evaluate(scored)
        rmse = evaluator.setMetricName("rmse").evaluate(scored)
        mae = evaluator.setMetricName("mae").evaluate(scored)

        snow_subset = scored.filter(col(target_col) > 0)
        # take(1) is enough to learn whether the subset is empty.
        if snow_subset.take(1):
            r2_snow = evaluator.setMetricName("r2").evaluate(snow_subset)
        else:
            r2_snow = 0.0
    finally:
        scored.unpersist()

    metrics = {
        "Model": model_name,
        "Time_Sec": round(time_taken, 2),
        "R2_Global": round(r2, 4),
        "RMSE_Global": round(rmse, 4),
        "MAE_Global": round(mae, 4),
        "R2_Snow_Only": round(r2_snow, 4)
    }

    # Emit the header for a new log file and for one that exists but is empty.
    needs_header = (
        not os.path.isfile(MASTER_LOG_FILE)
        or os.path.getsize(MASTER_LOG_FILE) == 0
    )
    with open(MASTER_LOG_FILE, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=metrics.keys())
        if needs_header:
            writer.writeheader()
        writer.writerow(metrics)

    with open(os.path.join(output_dir, "metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    print(f"\n--- METRICS ({model_name}) ---")
    print(f"R2: {r2:.4f} | RMSE: {rmse:.4f} | R2 Snow Only: {r2_snow:.4f}")
    return metrics

def plot_predictions(predictions, target_col, output_dir, model_name):
    """
    Render the scatter and residual plots into ``prediction_plots.png``.

    A seeded 5% sample (without replacement) of the label and prediction
    columns is brought to pandas and drawn as two side-by-side panels.
    """
    print("   Generating plots...")

    pdf = (
        predictions
        .select(target_col, "prediction")
        .sample(False, 0.05, seed=42)
        .toPandas()
    )
    truth, guess = pdf[target_col], pdf["prediction"]

    figure, axes = plt.subplots(1, 2, figsize=(14, 6))
    left, right = axes

    # Actual vs. predicted with the identity line as reference.
    sns.scatterplot(x=truth, y=guess, alpha=0.3, ax=left)
    bounds = [truth.min(), truth.max()]
    left.plot(bounds, bounds, 'r--', lw=2)
    left.set_xlabel('Reality (Actual Value)')
    left.set_ylabel('Prediction')
    left.set_title(f'Prediction vs Reality - {model_name}')

    # Histogram of actual minus predicted.
    sns.histplot(truth - guess, bins=50, kde=True, ax=right)
    right.set_xlabel('Error (Real - Predicted)')
    right.set_title('Residuals Distribution')

    figure.savefig(os.path.join(output_dir, "prediction_plots.png"))
    plt.close(figure)
    print(f"   [v] Plots saved in: {output_dir}")

In [None]:
MODEL_NAME = "RandomForest_Baseline"

print(f"--- STARTING: {MODEL_NAME} ---")

# Local-mode Spark session using all available cores.
spark = SparkSession.builder \
    .appName(MODEL_NAME) \
    .master("local[*]") \
    .config("spark.driver.memory", "10g") \
    .config("spark.executor.memory", "10g") \
    .config("spark.sql.files.maxPartitionBytes", "128m") \
    .config("spark.driver.maxResultSize", "4g") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")

# try/finally guarantees spark.stop() runs even if training or evaluation fails.
try:
    # Start from an empty output directory for this model.
    model_output_dir = os.path.join(BASE_OUTPUT_DIR, MODEL_NAME)
    if os.path.exists(model_output_dir):
        shutil.rmtree(model_output_dir)
    os.makedirs(model_output_dir)

    # Split by year; range() excludes its end value.
    train_years = range(2010, 2021)  # 2010-2020
    val_years   = range(2021, 2023)  # 2021-2022
    test_years  = range(2023, 2025)  # 2023-2024

    print("1. Loading Data...")
    try:
        train_df = spark.read.parquet(*get_paths(train_years))
        val_df   = spark.read.parquet(*get_paths(val_years))
        test_df  = spark.read.parquet(*get_paths(test_years))
    except Exception as e:
        print(f"Error loading data: {e}")
        raise  # bare raise keeps the original traceback; finally stops Spark

    target_col = "SNDP"
    # Columns never used as inputs: the label, identifiers and derived columns.
    ignore_cols = [target_col, "DATE", "STATION", "NAME", "features", "prediction", "FRSHTT"]
    valid_types = ['int', 'bigint', 'float', 'double', 'tinyint', 'smallint']

    # Every numeric column that is not explicitly ignored becomes a feature.
    dtypes = train_df.dtypes
    feature_cols = [c for c, t in dtypes if t in valid_types and c not in ignore_cols]
    print(f"   Features ({len(feature_cols)}): {feature_cols}")

    # handleInvalid="skip" drops rows containing invalid values.
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features").setHandleInvalid("skip")

    train_vec = assembler.transform(train_df)
    val_vec   = assembler.transform(val_df)  # not used further in this cell
    test_vec  = assembler.transform(test_df)

    # Persist before count() so the count materialises the cache used by fit().
    train_vec.persist(StorageLevel.MEMORY_AND_DISK)
    print(f"   Training Rows: {train_vec.count():,}")

    print(f"2. Training {MODEL_NAME}...")

    rf = RandomForestRegressor(
        featuresCol="features",
        labelCol=target_col,
        numTrees=40,
        maxDepth=10,
        seed=42,
        subsamplingRate=0.7
    )

    # Only the fit() call is timed.
    start_time = time.time()
    model = rf.fit(train_vec)
    end_time = time.time()
    duration = end_time - start_time

    print(f"   Training completed in {duration:.2f} seconds.")

    print("3. Processing Results...")
    model.write().overwrite().save(os.path.join(model_output_dir, "spark_model"))

    save_feature_importance(model, feature_cols, model_output_dir)

    test_preds = model.transform(test_vec)
    metrics = evaluate_and_log(test_preds, target_col, duration, model_output_dir, MODEL_NAME)

    plot_predictions(test_preds, target_col, model_output_dir, MODEL_NAME)

    print("\n--- PROCESS FINISHED ---")
finally:
    spark.stop()

In [None]:
MODEL_NAME = "RandomForest_New_Features"

print(f"--- STARTING: {MODEL_NAME} ---")

# Local-mode Spark session using all available cores.
spark = SparkSession.builder \
    .appName(MODEL_NAME) \
    .master("local[*]") \
    .config("spark.driver.memory", "10g") \
    .config("spark.executor.memory", "10g") \
    .config("spark.sql.files.maxPartitionBytes", "128m") \
    .config("spark.driver.maxResultSize", "4g") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")


def add_smart_features(df):
    """Return ``df`` with two derived columns: MONTH and Solid_PRCP.

    MONTH is the month extracted from DATE. Solid_PRCP equals PRCP on rows
    where PRCP > 0 and TEMP < 2.0, and 0.0 on all other rows.
    """
    df = df.withColumn("MONTH", month(col("DATE")))
    df = df.withColumn(
        "Solid_PRCP",
        when((col("PRCP") > 0) & (col("TEMP") < 2.0), col("PRCP")).otherwise(0.0)
    )
    return df


# try/finally guarantees spark.stop() runs even if training or evaluation fails.
try:
    # Start from an empty output directory for this model.
    model_output_dir = os.path.join(BASE_OUTPUT_DIR, MODEL_NAME)
    if os.path.exists(model_output_dir):
        shutil.rmtree(model_output_dir)
    os.makedirs(model_output_dir)

    # Split by year; range() excludes its end value.
    train_years = range(2010, 2021)  # 2010-2020
    val_years   = range(2021, 2023)  # 2021-2022
    test_years  = range(2023, 2025)  # 2023-2024

    print("1. Loading Data...")
    try:
        train_raw = spark.read.parquet(*get_paths(train_years))
        val_raw   = spark.read.parquet(*get_paths(val_years))
        test_raw  = spark.read.parquet(*get_paths(test_years))
    except Exception as e:
        print(f"Error loading data: {e}")
        raise  # bare raise keeps the original traceback; finally stops Spark

    print("2. Applying Feature Engineering (Month + Solid_PRCP)...")

    train_df = add_smart_features(train_raw)
    val_df   = add_smart_features(val_raw)
    test_df  = add_smart_features(test_raw)

    target_col = "SNDP"
    # Columns never used as inputs: the label, identifiers and derived columns.
    ignore_cols = [target_col, "DATE", "STATION", "NAME", "features", "prediction", "FRSHTT"]
    valid_types = ['int', 'bigint', 'float', 'double', 'tinyint', 'smallint']

    # Every numeric column that is not explicitly ignored becomes a feature,
    # which includes the two engineered columns when their types are numeric.
    dtypes = train_df.dtypes
    feature_cols = [c for c, t in dtypes if t in valid_types and c not in ignore_cols]
    print(f"   Features ({len(feature_cols)}): {feature_cols}")

    # handleInvalid="skip" drops rows containing invalid values.
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features").setHandleInvalid("skip")

    train_vec = assembler.transform(train_df)
    val_vec   = assembler.transform(val_df)  # not used further in this cell
    test_vec  = assembler.transform(test_df)

    # Persist before count() so the count materialises the cache used by fit().
    train_vec.persist(StorageLevel.MEMORY_AND_DISK)
    print(f"   Training Rows: {train_vec.count():,}")

    print(f"3. Training {MODEL_NAME}...")

    rf = RandomForestRegressor(
        featuresCol="features",
        labelCol=target_col,
        numTrees=40,
        maxDepth=10,
        seed=42,
        subsamplingRate=0.7
    )

    # Only the fit() call is timed.
    start_time = time.time()
    model = rf.fit(train_vec)
    end_time = time.time()
    duration = end_time - start_time

    print(f"   Training completed in {duration:.2f} seconds.")

    print("4. Processing Results...")

    model.write().overwrite().save(os.path.join(model_output_dir, "spark_model"))
    save_feature_importance(model, feature_cols, model_output_dir)

    test_preds = model.transform(test_vec)
    metrics = evaluate_and_log(test_preds, target_col, duration, model_output_dir, MODEL_NAME)

    plot_predictions(test_preds, target_col, model_output_dir, MODEL_NAME)

    print("\n--- PROCESS FINISHED ---")
finally:
    spark.stop()

## GBT Regression

In [None]:
MODEL_NAME = "GBT_New_Features"

print(f"--- STARTING: {MODEL_NAME} ---")

# Local-mode Spark session using all available cores.
spark = SparkSession.builder \
    .appName(MODEL_NAME) \
    .master("local[*]") \
    .config("spark.driver.memory", "10g") \
    .config("spark.executor.memory", "10g") \
    .config("spark.sql.files.maxPartitionBytes", "128m") \
    .config("spark.driver.maxResultSize", "4g") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")


def add_smart_features(df):
    """Return ``df`` with two derived columns: MONTH and Solid_PRCP.

    MONTH is the month extracted from DATE. Solid_PRCP equals PRCP on rows
    where PRCP > 0 and TEMP < 2.0, and 0.0 on all other rows.
    """
    df = df.withColumn("MONTH", month(col("DATE")))
    df = df.withColumn(
        "Solid_PRCP",
        when((col("PRCP") > 0) & (col("TEMP") < 2.0), col("PRCP")).otherwise(0.0)
    )
    return df


# try/finally guarantees spark.stop() runs even if training or evaluation fails.
try:
    # Start from an empty output directory for this model.
    model_output_dir = os.path.join(BASE_OUTPUT_DIR, MODEL_NAME)
    if os.path.exists(model_output_dir):
        shutil.rmtree(model_output_dir)
    os.makedirs(model_output_dir)

    # Split by year; range() excludes its end value.
    train_years = range(2010, 2021)  # 2010-2020
    val_years   = range(2021, 2023)  # 2021-2022
    test_years  = range(2023, 2025)  # 2023-2024

    print("1. Loading Data...")
    try:
        train_raw = spark.read.parquet(*get_paths(train_years))
        val_raw   = spark.read.parquet(*get_paths(val_years))
        test_raw  = spark.read.parquet(*get_paths(test_years))
    except Exception as e:
        print(f"Error loading data: {e}")
        raise  # bare raise keeps the original traceback; finally stops Spark

    print("2. Applying Feature Engineering (Month + Solid_PRCP)...")

    train_df = add_smart_features(train_raw)
    val_df   = add_smart_features(val_raw)
    test_df  = add_smart_features(test_raw)

    target_col = "SNDP"
    # Columns never used as inputs: the label, identifiers and derived columns.
    ignore_cols = [target_col, "DATE", "STATION", "NAME", "features", "prediction", "FRSHTT"]
    valid_types = ['int', 'bigint', 'float', 'double', 'tinyint', 'smallint']

    # Every numeric column that is not explicitly ignored becomes a feature,
    # which includes the two engineered columns when their types are numeric.
    dtypes = train_df.dtypes
    feature_cols = [c for c, t in dtypes if t in valid_types and c not in ignore_cols]
    print(f"   Features ({len(feature_cols)}): {feature_cols}")

    # handleInvalid="skip" drops rows containing invalid values.
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features").setHandleInvalid("skip")

    train_vec = assembler.transform(train_df)
    val_vec   = assembler.transform(val_df)  # not used further in this cell
    test_vec  = assembler.transform(test_df)

    # Persist before count() so the count materialises the cache used by fit().
    train_vec.persist(StorageLevel.MEMORY_AND_DISK)
    print(f"   Training Rows: {train_vec.count():,}")

    print(f"3. Training {MODEL_NAME}...")

    gbt = GBTRegressor(
        featuresCol="features",
        labelCol=target_col,
        maxIter=50,
        maxDepth=5,
        stepSize=0.1,
        seed=42,
        subsamplingRate=0.7
    )

    # Only the fit() call is timed.
    start_time = time.time()
    model = gbt.fit(train_vec)
    end_time = time.time()
    duration = end_time - start_time

    print(f"   Training completed in {duration:.2f} seconds.")

    print("4. Processing Results...")

    model.write().overwrite().save(os.path.join(model_output_dir, "spark_model"))
    save_feature_importance(model, feature_cols, model_output_dir)

    test_preds = model.transform(test_vec)
    metrics = evaluate_and_log(test_preds, target_col, duration, model_output_dir, MODEL_NAME)

    plot_predictions(test_preds, target_col, model_output_dir, MODEL_NAME)

    print("\n--- PROCESS FINISHED ---")
finally:
    spark.stop()

## Hyperparameter Tuning of Random Forest

In [None]:
# Root folder for the models produced by hyperparameter tuning.
BASE_OUTPUT_DIR = "optimized_models"
# CSV file to which each tuning run appends one row of metrics.
MASTER_LOG_FILE = "tuning_comparison.csv"
# Folder holding the per-year parquet files.
BASE_DATA_PATH = "processed_data"


def get_paths(years):
    """Build the list of per-year parquet paths under BASE_DATA_PATH."""
    path_iter = (f"{BASE_DATA_PATH}/climate_{year}.parquet" for year in years)
    return list(path_iter)

def save_best_params(cv_model, output_dir):
    """
    Extract and save the hyperparameters of the best model found by CrossValidator.

    Writes ``best_params.json`` inside ``output_dir`` and prints the same
    values.

    Args:
        cv_model: Fitted cross-validation result exposing ``bestModel``.
        output_dir: Existing directory that receives ``best_params.json``.
    """
    rf_model = cv_model.bestModel

    # getNumTrees may be exposed as a plain attribute/property or as a method
    # depending on the model class; a bound method is not JSON-serializable,
    # so call it when it is callable.
    num_trees = rf_model.getNumTrees
    if callable(num_trees):
        num_trees = num_trees()

    params = {
        "numTrees": num_trees,
        "maxDepth": rf_model.getOrDefault("maxDepth"),
        "maxBins": rf_model.getOrDefault("maxBins"),
        "subsamplingRate": rf_model.getOrDefault("subsamplingRate")
    }

    with open(os.path.join(output_dir, "best_params.json"), 'w') as f:
        json.dump(params, f, indent=4)

    print("\n   [v] WINNING PARAMETERS SAVED:")
    print(json.dumps(params, indent=4))

def evaluate_and_log(predictions, target_col, time_taken, output_dir, model_name):
    """
    Evaluate the model on the final test set and record the results.

    Computes global R2 and RMSE plus an R2 restricted to rows whose label is
    greater than 0 (0.0 when no such rows exist). One row is appended to
    MASTER_LOG_FILE and the same values are written to ``metrics.json``
    inside ``output_dir``.

    Args:
        predictions: Spark DataFrame holding ``target_col`` and a "prediction" column.
        target_col: Name of the label column.
        time_taken: Tuning duration in seconds (logged in minutes).
        output_dir: Existing directory that receives ``metrics.json``.
        model_name: Value stored in the "Model" field.

    Returns:
        dict: The rounded metrics that were logged.
    """
    # Every evaluate() is its own Spark action; cache the two columns it reads
    # so the predictions are not recomputed once per metric.
    scored = predictions.select(target_col, "prediction").persist()
    try:
        evaluator = RegressionEvaluator(labelCol=target_col, predictionCol="prediction")

        r2 = evaluator.setMetricName("r2").evaluate(scored)
        rmse = evaluator.setMetricName("rmse").evaluate(scored)

        snow_subset = scored.filter(col(target_col) > 0)
        # take(1) is enough to learn whether the subset is empty.
        if snow_subset.take(1):
            r2_snow = evaluator.setMetricName("r2").evaluate(snow_subset)
        else:
            r2_snow = 0.0
    finally:
        scored.unpersist()

    metrics = {
        "Model": model_name,
        "Time_Min": round(time_taken / 60, 2),
        "R2_Global": round(r2, 4),
        "RMSE_Global": round(rmse, 4),
        "R2_Snow_Only": round(r2_snow, 4)
    }

    # Emit the header for a new log file and for one that exists but is empty.
    needs_header = (
        not os.path.isfile(MASTER_LOG_FILE)
        or os.path.getsize(MASTER_LOG_FILE) == 0
    )
    with open(MASTER_LOG_FILE, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=metrics.keys())
        if needs_header:
            writer.writeheader()
        writer.writerow(metrics)

    # output_dir was previously accepted but never used; keep a per-model copy
    # of the metrics next to the saved model.
    with open(os.path.join(output_dir, "metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    print(f"\n--- FINAL TEST RESULTS ({model_name}) ---")
    print(f"R2 Global: {r2:.4f} | R2 Snow Only: {r2_snow:.4f}")
    return metrics

In [None]:
MODEL_NAME = "RandomForest_HyperTuning"
print(f"--- STARTING TUNING: {MODEL_NAME} ---")

# Local-mode Spark session using all available cores.
spark = SparkSession.builder \
    .appName(MODEL_NAME) \
    .master("local[*]") \
    .config("spark.driver.memory", "10g") \
    .config("spark.executor.memory", "10g") \
    .config("spark.sql.files.maxPartitionBytes", "128m") \
    .config("spark.driver.maxResultSize", "4g") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")


def add_features(df):
    """Return ``df`` with two derived columns: MONTH and Solid_PRCP.

    MONTH is the month extracted from DATE. Solid_PRCP equals PRCP on rows
    where PRCP > 0 and TEMP < 2.0, and 0.0 on all other rows.
    """
    df = df.withColumn("MONTH", month(col("DATE")))
    df = df.withColumn("Solid_PRCP", when((col("PRCP") > 0) & (col("TEMP") < 2.0), col("PRCP")).otherwise(0.0))
    return df


# try/finally guarantees spark.stop() runs even if loading, tuning or
# evaluation fails.
try:
    # Start from an empty output directory for this model.
    model_output_dir = os.path.join(BASE_OUTPUT_DIR, MODEL_NAME)
    if os.path.exists(model_output_dir):
        shutil.rmtree(model_output_dir)
    os.makedirs(model_output_dir)

    # No separate validation years here: cross-validation splits the training
    # years internally. range() excludes its end value.
    train_years = range(2010, 2021)  # 2010-2020
    test_years  = range(2023, 2025)  # 2023-2024

    print("1. Loading and Transforming Data...")
    train_raw = spark.read.parquet(*get_paths(train_years))
    test_raw  = spark.read.parquet(*get_paths(test_years))

    train_df = add_features(train_raw)
    test_df  = add_features(test_raw)

    target_col = "SNDP"
    # Columns never used as inputs: the label, identifiers and derived columns.
    ignore_cols = [target_col, "DATE", "STATION", "NAME", "features", "prediction", "FRSHTT"]
    valid_types = ['int', 'bigint', 'float', 'double', 'tinyint', 'smallint']

    # Every numeric column that is not explicitly ignored becomes a feature.
    dtypes = train_df.dtypes
    feature_cols = [c for c, t in dtypes if t in valid_types and c not in ignore_cols]
    print(f"   Features to use: {feature_cols}")

    # handleInvalid="skip" drops rows containing invalid values.
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features").setHandleInvalid("skip")

    train_vec = assembler.transform(train_df)
    test_vec  = assembler.transform(test_df)

    # Persist before count() so the count materialises the cache reused by
    # every fold and parameter combination.
    train_vec.persist(StorageLevel.MEMORY_AND_DISK)
    print(f"   Training Rows (Internal Train + Val): {train_vec.count():,}")

    print("2. Configuring Grid Search...")

    rf = RandomForestRegressor(featuresCol="features", labelCol=target_col, seed=42)

    # 2 x 2 x 2 = 8 parameter combinations.
    paramGrid = ParamGridBuilder() \
        .addGrid(rf.numTrees, [50, 100]) \
        .addGrid(rf.maxDepth, [10, 15]) \
        .addGrid(rf.maxBins, [32, 64]) \
        .build()

    print(f"   Combinations to test: {len(paramGrid)}")

    # Model selection is driven by RMSE.
    evaluator = RegressionEvaluator(labelCol=target_col, predictionCol="prediction", metricName="rmse")

    # NOTE(review): the data is split by year elsewhere in this notebook;
    # confirm that the fold assignment used by CrossValidator is acceptable
    # for time-ordered data.
    cv = CrossValidator(
        estimator=rf,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
        numFolds=3,
        parallelism=1
    )

    print("3. Running Cross-Validation (Please wait)...")
    # The whole cross-validated fit is timed.
    start_time = time.time()

    cv_model = cv.fit(train_vec)

    end_time = time.time()
    duration = end_time - start_time
    print(f"   Tuning completed in {duration/60:.2f} minutes!")

    print("4. Analyzing the Champion Model...")

    best_model = cv_model.bestModel

    best_model.write().overwrite().save(os.path.join(model_output_dir, "spark_model_winner"))

    save_best_params(cv_model, model_output_dir)

    test_preds = best_model.transform(test_vec)
    evaluate_and_log(test_preds, target_col, duration, model_output_dir, MODEL_NAME)

    print("\n--- PROCESS FINISHED ---")
finally:
    spark.stop()