MLP Model using PyTorch + GPU Support
=====================================

# Summary

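- This notebook builds a distributed PyTorch MLP regressor for departure delay (`DEP_DELAY`), trained across Spark workers with `TorchDistributor` (on GPUs when available) and evaluated with the custom cross-validator defined below.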


In [0]:
"""
cv.py (simplified, CUSTOM-only, no parametrization)

Assumptions:
- Folds were created from split.py with N_FOLDS = 3 and CREATE_TEST_FOLD = True
- Therefore total fold indices written = 4:
    FOLD_1_VAL, FOLD_2_VAL, FOLD_3_VAL, FOLD_4_TEST
- Files live in:
    dbfs:/mnt/mids-w261/student-groups/Group_4_2/processed
- File naming:
    OTPW_CUSTOM_{VERSION}_FOLD_{i}_{TRAIN|VAL|TEST}.parquet
"""

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import col
from pyspark.ml.evaluation import RegressionEvaluator
import pandas as pd
from pyspark import StorageLevel



# -----------------------------
# HARD-CODED GLOBALS
# -----------------------------
# Directory holding the fold parquet files written by split.py.
FOLDER_PATH = "dbfs:/mnt/mids-w261/student-groups/Group_4_2/processed"
# Source tag embedded in every fold file name: OTPW_{SOURCE}_{VERSION}_FOLD_{i}_...
SOURCE = "CUSTOM"
# Data version labels that appear in the fold file names (passed as `version`).
VERSIONS = ["3M", "12M", "60M"]

# 3 CV folds + 1 test fold = 4 total fold indices.
# Folds 1..TOTAL_FOLDS-1 are read as TRAIN/VAL pairs; fold TOTAL_FOLDS as TRAIN/TEST.
TOTAL_FOLDS = 4


# -----------------------------
# DATA LOADER (CUSTOM ONLY)
# -----------------------------
class FlightDelayDataLoader:
    """
    Load and cache the fold DataFrames for a data version.

    For each fold index 1..TOTAL_FOLDS-1 a ``(TRAIN, VAL)`` pair is loaded; the
    last fold index yields ``(TRAIN, TEST)``. Files are read from FOLDER_PATH
    using the naming scheme ``OTPW_{SOURCE}_{version}_FOLD_{i}_{TRAIN|VAL|TEST}``.
    Configured numeric columns are coerced to double (missing-value tokens become
    null) and the label columns are cast to their expected types.
    """

    # Whole-string tokens treated as missing in numeric columns. The pattern is
    # evaluated by Spark's regexp_replace (Java regex): `\\N` matches the literal
    # two-character text "\N", `\s*` matches empty/whitespace-only strings and
    # `\.` matches a lone ".". (The previous `\\s*` / `\\.` spellings required a
    # literal backslash and so never matched blanks or ".".)
    _NULL_PAT = r'^(NA|N/A|NULL|null|None|none|\\N|\s*|\.|M|T)$'

    # Label columns and the Spark SQL type each one is cast to.
    _LABEL_TYPES = {
        "DEP_DELAY": "double",
        "DEP_DEL15": "int",
        "SEVERE_DEL60": "int",
    }

    def __init__(self):
        # version -> list of (train_df, holdout_df) tuples, filled lazily.
        self.folds = {}
        self.numerical_features = [
            'hourlyprecipitation',
            'hourlysealevelpressure',
            'hourlyaltimetersetting',
            'hourlywetbulbtemperature',
            'hourlystationpressure',
            'hourlywinddirection',
            'hourlyrelativehumidity',
            'hourlywindspeed',
            'hourlydewpointtemperature',
            'hourlydrybulbtemperature',
            'hourlyvisibility',
            'crs_elapsed_time',
            'distance',
            'elevation',
        ]

    def _cast_numerics(self, df):
        """Safely cast all configured numeric columns to doubles.

        Values matching ``_NULL_PAT`` are blanked before the cast so they become
        null instead of failing to parse. Label columns present in ``df`` are cast
        per ``_LABEL_TYPES``. Columns missing from ``df`` are skipped.
        """
        # Casting never adds/removes columns, so the column set can be read once.
        present = set(df.columns)

        for colname in self.numerical_features:
            if colname in present:
                df = df.withColumn(
                    colname,
                    F.regexp_replace(F.col(colname).cast("string"), self._NULL_PAT, "")
                    .cast("double")
                )

        # Explicitly cast labels to expected numeric types
        for label, spark_type in self._LABEL_TYPES.items():
            if label in present:
                df = df.withColumn(label, F.col(label).cast(spark_type))

        return df

    def _load_parquet(self, name):
        """Read ``{FOLDER_PATH}/{name}.parquet`` and apply numeric casting."""
        spark = SparkSession.builder.getOrCreate()
        df = spark.read.parquet(f"{FOLDER_PATH}/{name}.parquet")
        return self._cast_numerics(df)

    def _load_version(self, version):
        """Load every fold of ``version`` as a list of (train, holdout) tuples."""
        folds = []
        for fold_idx in range(1, TOTAL_FOLDS + 1):
            prefix = f"OTPW_{SOURCE}_{version}_FOLD_{fold_idx}"
            # The last fold index holds the test split; the others hold validation.
            holdout = "VAL" if fold_idx < TOTAL_FOLDS else "TEST"
            train_df = self._load_parquet(f"{prefix}_TRAIN")
            holdout_df = self._load_parquet(f"{prefix}_{holdout}")
            folds.append((train_df, holdout_df))
        return folds

    def load_version(self, version):
        """Return the (cached) folds for ``version``, loading them on first use."""
        if version not in self.folds:
            self.folds[version] = self._load_version(version)
        return self.folds[version]

    def get_version(self, version):
        """Alias of :meth:`load_version`."""
        return self.load_version(version)
    


# -----------------------------
# EVALUATOR (NULL-SAFE RMSE)
# -----------------------------
class FlightDelayEvaluator:
    """
    Score a predictions DataFrame produced by a delay regressor.

    Metrics returned by :meth:`evaluate`:

    * ``rmse`` - RMSE of ``prediction_col`` against ``numeric_label_col``.
    * ``otpa`` - accuracy of ``prediction >= 15`` against ``binary_label_col``.
    * ``sddr`` - recall of ``prediction >= 60`` against ``severe_label_col``.

    Rows with a null prediction or null label are dropped separately for each metric.
    """

    def __init__(
        self,
        prediction_col="prediction",
        numeric_label_col="DEP_DELAY",
        binary_label_col="DEP_DEL15",
        severe_label_col="SEVERE_DEL60",
    ):
        self.prediction_col = prediction_col
        self.numeric_label_col = numeric_label_col
        self.binary_label_col = binary_label_col
        self.severe_label_col = severe_label_col

        self.rmse_evaluator = RegressionEvaluator(
            predictionCol=prediction_col,
            labelCol=numeric_label_col,
            metricName="rmse"
        )

    def calculate_rmse(self, predictions_df):
        """Return the RMSE over rows where both label and prediction are non-null."""
        # Drop any residual nulls before RegressionEvaluator sees them
        clean = predictions_df.dropna(
            subset=[self.numeric_label_col, self.prediction_col]
        )
        return self.rmse_evaluator.evaluate(clean)

    @staticmethod
    def _metrics_from_counts(tp, fp, tn, fn):
        """Derive precision/recall/F1/accuracy from confusion counts.

        Any ratio whose denominator is zero is reported as 0.0.
        """
        total = tp + fp + tn + fn
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
        accuracy = (tp + tn) / total if total else 0.0
        return dict(tp=tp, fp=fp, tn=tn, fn=fn,
                    precision=precision, recall=recall, f1=f1, accuracy=accuracy)

    def _calculate_classification_metrics(self, predictions_df, threshold, label_col):
        """Binarize predictions at ``prediction >= threshold`` and score against ``label_col``.

        Only rows whose label is exactly 0 or 1 contribute to the confusion counts.
        The four counts are computed in a single aggregation pass, rather than
        one Spark job per count.
        """
        # Null-safe for classification too
        df = predictions_df.dropna(subset=[self.prediction_col, label_col])

        predicted_pos = F.col(self.prediction_col) >= threshold
        label = F.col(label_col)

        def _count(cond):
            # sum() over zero rows is null; report 0 instead.
            return F.coalesce(F.sum(F.when(cond, 1).otherwise(0)), F.lit(0))

        row = df.agg(
            _count(predicted_pos & (label == 1)).alias("tp"),
            _count(predicted_pos & (label == 0)).alias("fp"),
            _count(~predicted_pos & (label == 0)).alias("tn"),
            _count(~predicted_pos & (label == 1)).alias("fn"),
        ).first()

        return self._metrics_from_counts(
            int(row["tp"]), int(row["fp"]), int(row["tn"]), int(row["fn"])
        )

    def calculate_otpa_metrics(self, predictions_df):
        """Accuracy of the 15-minute delay classification derived from predictions."""
        return self._calculate_classification_metrics(
            predictions_df, threshold=15, label_col=self.binary_label_col
        )["accuracy"]

    def calculate_sddr_metrics(self, predictions_df):
        """Recall of the 60-minute severe-delay classification derived from predictions."""
        return self._calculate_classification_metrics(
            predictions_df, threshold=60, label_col=self.severe_label_col
        )["recall"]

    def evaluate(self, predictions_df):
        """Return ``{"rmse": ..., "otpa": ..., "sddr": ...}`` for ``predictions_df``.

        The input is persisted while the metrics are computed so the upstream
        plan (e.g. model inference) is not recomputed per metric. It is always
        unpersisted afterwards, even if a metric raises.
        """
        preds = predictions_df.persist(StorageLevel.MEMORY_AND_DISK)
        try:
            rmse = self.calculate_rmse(preds)
            otpa = self.calculate_otpa_metrics(preds)
            sddr = self.calculate_sddr_metrics(preds)
        finally:
            preds.unpersist()
        return {"rmse": rmse, "otpa": otpa, "sddr": sddr}


# -----------------------------
# CROSS-VALIDATOR (NO PARAMS)
# -----------------------------
class FlightDelayCV:
    """
    Fixed-fold cross-validator for a Spark-style estimator.

    ``estimator.fit(train_df)`` must return a model exposing ``transform(df)``;
    predictions are scored with FlightDelayEvaluator. Folds come from
    FlightDelayDataLoader: every fold except the last is used by :meth:`fit`,
    and the last (train, test) pair by :meth:`test`.
    """

    def __init__(self, estimator, version, dataloader=None):
        self.estimator = estimator
        self.version = version

        self.data_loader = dataloader or FlightDelayDataLoader()
        self.folds = self.data_loader.get_version(version)

        self.evaluator = FlightDelayEvaluator()
        self.train_metrics = []
        self.val_metrics = []
        self.models = []
        self.test_train_metric = None
        self.test_metric = None
        self.test_model = None

    def fit(self):
        """
        Train and score the estimator on each CV fold (all folds but the last).

        ``train_metrics``, ``val_metrics`` and ``models`` are reset at the start
        of every call so repeated fits never mix in results from earlier runs.

        Returns:
            pandas.DataFrame with interleaved per-fold rows ``"{i}-train"`` /
            ``"{i}-val"`` followed by ``"mean-train"``, ``"std-train"``,
            ``"mean-val"`` and ``"std-val"``.
        """
        self.train_metrics = []
        self.val_metrics = []
        self.models = []

        # CV folds only (exclude last test fold)
        for fold_train_df, fold_val_df in self.folds[:-1]:
            model = self.estimator.fit(fold_train_df)
            self.train_metrics.append(self.evaluator.evaluate(model.transform(fold_train_df)))
            self.val_metrics.append(self.evaluator.evaluate(model.transform(fold_val_df)))
            self.models.append(model)

        n_folds = len(self.train_metrics)
        train_tbl = pd.DataFrame(self.train_metrics, index=[f"{i}-train" for i in range(n_folds)])
        val_tbl = pd.DataFrame(self.val_metrics, index=[f"{i}-val" for i in range(n_folds)])

        # Interleave train and val rows: 0-train, 0-val, 1-train, 1-val, ...
        order = [label for i in range(n_folds) for label in (f"{i}-train", f"{i}-val")]
        per_fold = pd.concat([train_tbl, val_tbl]).loc[order]

        # Mean and (sample) std for train and val separately
        summary = pd.DataFrame(
            [train_tbl.mean(), train_tbl.std(), val_tbl.mean(), val_tbl.std()],
            index=["mean-train", "std-train", "mean-val", "std-val"],
        )
        return pd.concat([per_fold, summary])

    def evaluate(self):
        """Deprecated alias of :meth:`test`, kept for backwards compatibility."""
        print("`evaluate()` is deprecated! Use `test()` instead")
        return self.test()

    def test(self):
        """Fit on the last fold's training split and score train and test.

        Returns:
            pandas.DataFrame with rows ``"train"`` and ``"test"``.
        """
        train_df, test_df = self.folds[-1]
        self.test_model = self.estimator.fit(train_df)

        # Evaluate on training data
        train_preds = self.test_model.transform(train_df)
        self.test_train_metric = self.evaluator.evaluate(train_preds)

        # Evaluate on test data
        test_preds = self.test_model.transform(test_df)
        self.test_metric = self.evaluator.evaluate(test_preds)

        # Return both as a DataFrame
        return pd.DataFrame([self.test_train_metric, self.test_metric],
                            index=["train", "test"])


In [0]:

# Spark Settings
# to avoid OOM Error: caps the rows per Arrow record batch passed to pandas-based
# functions such as the mapInPandas inference in SparkPyTorchEstimator.transform,
# keeping each pandas batch small. Uses the notebook-provided `spark` session.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "16800") 

# import importlib.util
# import sys

# # Load cv module directly from file path
# cv_path = "/Workspace/Shared/Team 4_2/flight-departure-delay-predictive-modeling/notebooks/Cross Validator/cv.py"
# spec = importlib.util.spec_from_file_location("cv", cv_path)
# cv = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(cv)


import uuid
from pathlib import Path
from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType


from pyspark.sql import SparkSession, functions as F
from pathlib import Path
from pyspark.ml.feature import (
    Imputer, StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
)
from pyspark.ml import Pipeline
from pyspark.ml.functions import vector_to_array
from pyspark.sql.types import DoubleType

# >>> PYTORCH AND DISTRIBUTOR IMPORTS <<<
import torch
import torch.nn as nn
import torch.optim as optim
from pyspark.ml.torch.distributor import TorchDistributor 
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# --- PYTORCH TRAIN FUNCTION (RUNS ON WORKERS) ---

In [0]:
# Fold loader that caches DataFrames per data version.
# NOTE(review): the FlightDelayCV run below is constructed without `dataloader=`,
# so it creates its own FlightDelayDataLoader and does not use this instance.
data_loader = FlightDelayDataLoader()


In [0]:
# =====================================================
# PYTORCH MLP REGRESSOR INTEGRATED WITH TORCHDISTRIBUTOR
# =====================================================

def train_fn(params):
    """
    Run on each worker launched by TorchDistributor.

    Each process joins the process group, streams its own shard of the parquet
    files under ``params["train_path"]``, trains the MLP with
    DistributedDataParallel and, on rank 0, saves the ``state_dict`` to
    ``params["model_path"]``.

    Args:
        params: dict with keys ``input_dim``, ``hidden_layers``, ``dropout_rate``,
            ``learning_rate``, ``batch_size``, ``epochs``, ``use_gpu``,
            ``model_path`` and ``train_path``. The parquet files are read for a
            ``features_arr`` array column and a ``DEP_DELAY`` label column.

    On any exception the traceback is printed and the process exits with status 1.
    """
    import os
    import sys
    import traceback

    # 1. Error Handling Wrapper
    try:
        import glob
        import random
        import torch
        import torch.nn as nn
        import torch.optim as optim
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel as DDP
        from torch.utils.data import DataLoader, IterableDataset
        import numpy as np
        import pandas as pd

        # --- Local Model Definition ---
        # Must build the same layer sequence as the driver-side PyTorchMLPRegressor
        # so the saved state_dict keys line up when the driver reloads them.
        class PyTorchMLPRegressor_Worker(nn.Module):
            def __init__(self, input_dim, hidden_layers, dropout_rate=0.3):
                super().__init__()
                layers = []
                in_features = input_dim
                for units in hidden_layers:
                    layers.append(nn.Linear(in_features, units))
                    layers.append(nn.BatchNorm1d(units))
                    layers.append(nn.ReLU())
                    layers.append(nn.Dropout(dropout_rate))
                    in_features = units
                layers.append(nn.Linear(in_features, 1))
                self.network = nn.Sequential(*layers)

            def forward(self, x):
                return self.network(x).squeeze(1)

        # --- Device selection ---
        if params["use_gpu"]:
            local_rank = int(os.environ["LOCAL_RANK"])
            device = torch.device(f"cuda:{local_rank}")
            # Bind this process to its GPU before any NCCL collective (including
            # the final barrier) so it does not default to cuda:0.
            torch.cuda.set_device(device)
            device_ids = [local_rank]
        else:
            device = torch.device("cpu")
            device_ids = None

        # --- DDP Init ---
        backend = "nccl" if params["use_gpu"] else "gloo"
        dist.init_process_group(backend=backend)

        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Convert dbfs:/ URI to /dbfs/ path for local IO
        def dbfs_to_local(path: str) -> str:
            if path.startswith("dbfs:/"):
                return path.replace("dbfs:/", "/dbfs/", 1)
            return path

        train_path_local = dbfs_to_local(params["train_path"])

        # --- Dataset ---
        class ParquetFlightIterableDataset(IterableDataset):
            def __init__(self, path_local: str, rank: int, world_size: int):
                if os.path.isdir(path_local):
                    all_files = sorted(glob.glob(os.path.join(path_local, "*.parquet")))
                else:
                    all_files = [path_local]

                if not all_files:
                    raise FileNotFoundError(f"No parquet files found at: {path_local}")

                # Shard files by rank so each process reads a disjoint subset
                self.files = all_files[rank::world_size]

            def __iter__(self):
                # Shuffle files at the START of every epoch (every time __iter__ is called)
                random.shuffle(self.files)

                for f in self.files:
                    try:
                        # Load ONE file at a time
                        pdf = pd.read_parquet(f, columns=["features_arr", "DEP_DELAY"])

                        if len(pdf) == 0:
                            continue

                        X = np.stack(pdf["features_arr"].to_numpy()).astype(np.float32, copy=False)
                        y = pdf["DEP_DELAY"].to_numpy(dtype=np.float32, copy=False)

                        for i in range(len(y)):
                            yield torch.from_numpy(X[i]), torch.tensor(y[i])
                    except Exception as e:
                        # Deliberate best-effort: a bad shard is logged and skipped.
                        print(f"Skipping bad file {f}: {e}")
                        continue

        dataset = ParquetFlightIterableDataset(train_path_local, rank, world_size)

        # IterableDataset: no sampler / shuffle=True here; shuffling is done in __iter__.
        loader = DataLoader(
            dataset,
            batch_size=params["batch_size"],
            drop_last=False,
            num_workers=0,
            pin_memory=params["use_gpu"],
        )

        # --- Model / Optimizer ---
        model = PyTorchMLPRegressor_Worker(
            params["input_dim"],
            params["hidden_layers"],
            params["dropout_rate"],
        ).to(device)

        ddp_model = DDP(model, device_ids=device_ids)
        optimizer = optim.Adam(ddp_model.parameters(), lr=params["learning_rate"])
        criterion = nn.MSELoss()

        # --- Training Loop ---
        for epoch in range(params["epochs"]):
            ddp_model.train()

            total_loss = 0.0
            num_batches = 0

            # Files are sharded per rank, so ranks generally produce different
            # numbers of batches. Without join(), the gradient all-reduce on the
            # rank with more batches would wait forever for ranks that finished;
            # join() lets exhausted ranks shadow the collectives until all are done.
            with ddp_model.join():
                for xb, yb in loader:
                    # BatchNorm1d cannot compute batch statistics from a single
                    # sample in training mode (it raises); skip such a batch.
                    if xb.size(0) < 2:
                        continue

                    xb = xb.to(device)
                    yb = yb.to(device)

                    optimizer.zero_grad()
                    out = ddp_model(xb)
                    loss = criterion(out, yb)
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    num_batches += 1

            # Log only on rank 0 (its own shard's average loss)
            if rank == 0:
                avg_loss = total_loss / max(num_batches, 1)
                print(f"Epoch {epoch+1}/{params['epochs']} - Loss: {avg_loss:.4f}")

        # --- Save Model ---
        dist.barrier()
        if rank == 0:
            torch.save(model.state_dict(), params["model_path"])

        dist.destroy_process_group()

    except Exception:
        # Print full traceback to driver logs if something crashes
        print(traceback.format_exc())
        sys.exit(1)

# def train_fn(X, y, params):
#     """
#     Run on each worker launched by TorchDistributor.
#     No Spark code in here.
#     """
#     import os
#     import torch
#     import torch.nn as nn
#     import torch.optim as optim
#     import torch.distributed as dist
#     from torch.nn.parallel import DistributedDataParallel as DDP
#     from torch.utils.data import TensorDataset, DataLoader, DistributedSampler
#     import numpy as np

#     # --- local model definition ---
#     class PyTorchMLPRegressor_Worker(nn.Module):
#         def __init__(self, input_dim, hidden_layers, dropout_rate=0.3):
#             super().__init__()
#             layers = []
#             in_features = input_dim
#             for units in hidden_layers:
#                 layers.append(nn.Linear(in_features, units))
#                 layers.append(nn.BatchNorm1d(units))
#                 layers.append(nn.ReLU())
#                 layers.append(nn.Dropout(dropout_rate))
#                 in_features = units
#             layers.append(nn.Linear(in_features, 1))
#             self.network = nn.Sequential(*layers)

#         def forward(self, x):
#             return self.network(x).squeeze(1)

#     # --- DDP / process group init ---
#     backend = "nccl" if params["use_gpu"] else "gloo"
#     dist.init_process_group(backend=backend)

#     if params["use_gpu"]:
#         local_rank = int(os.environ["LOCAL_RANK"])
#         device = torch.device(f"cuda:{local_rank}")
#         device_ids = [local_rank]
#     else:
#         device = torch.device("cpu")
#         device_ids = None

#     rank = dist.get_rank()
#     world_size = dist.get_world_size()

#     # --- dataset & sampler ---
#     X = torch.from_numpy(np.asarray(X)).float()
#     y = torch.from_numpy(np.asarray(y)).float()
#     dataset = TensorDataset(X, y)
#     sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

#     loader = DataLoader(
#         dataset,
#         batch_size=params["batch_size"],
#         sampler=sampler,
#         drop_last=False,
#     )

#     # --- model / optimizer / loss ---
#     model = PyTorchMLPRegressor_Worker(
#         params["input_dim"],
#         params["hidden_layers"],
#         params["dropout_rate"],
#     ).to(device)

#     ddp_model = DDP(model, device_ids=device_ids)
#     optimizer = optim.Adam(ddp_model.parameters(), lr=params["learning_rate"])
#     criterion = nn.MSELoss()

#     # --- training loop ---
#     for epoch in range(params["epochs"]):
#         ddp_model.train()
#         sampler.set_epoch(epoch)  # important for DistributedSampler
#         total_loss = 0.0
#         num_batches = 0

#         for xb, yb in loader:
#             xb = xb.to(device)
#             yb = yb.to(device)

#             optimizer.zero_grad()
#             out = ddp_model(xb)
#             loss = criterion(out, yb)
#             loss.backward()
#             optimizer.step()

#             total_loss += loss.item()
#             num_batches += 1

#         # log from rank 0 only
#         if rank == 0:
#             avg_loss = total_loss / max(num_batches, 1)
#             print(f"Epoch {epoch+1}/{params['epochs']} - Loss: {avg_loss:.4f}")

#     # --- save model once on rank 0 ---
#     dist.barrier()
#     if rank == 0:
#         torch.save(model.state_dict(), params["model_path"])

#     dist.destroy_process_group()


# --- DRIVER-SIDE MLP MODEL DEFINITION ---

class PyTorchMLPRegressor(nn.Module):
    """
    Feed-forward regression network used on the driver and for inference.

    Each hidden width ``w`` contributes ``Linear -> BatchNorm1d(w) -> ReLU ->
    Dropout(dropout_rate)``; a final ``Linear(..., 1)`` emits one value per row.
    The layers live in ``self.network`` (an ``nn.Sequential``) in exactly that
    order, which keeps ``state_dict`` keys compatible with the worker-side copy.
    """

    def __init__(self, input_dim, hidden_layers, dropout_rate=0.3):
        super().__init__()

        widths = [input_dim, *hidden_layers]
        modules = []
        # Pair each layer's input width with its output width.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]

        # Regression head: a single unit with no activation.
        modules.append(nn.Linear(widths[-1], 1))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        # Drop the trailing size-1 output dimension so the result matches the 1-D target.
        return self.network(x).squeeze(dim=1)


# --- PYTORCH SPARK ESTIMATOR ---

# class SparkPyTorchEstimator:
#     def __init__(self, hidden_layers=None, dropout_rate=0.3, learning_rate=0.001, 
#                  batch_size=256, epochs=30, num_processes=None, infer_batch_size=None):
        
#         self.hidden_layers = hidden_layers or [128, 64]
#         self.dropout_rate = dropout_rate
#         self.learning_rate = learning_rate
#         self.batch_size = batch_size
#         self.epochs = epochs

#         # NEW: separate inference batch size (defaults to training batch size)
#         self.infer_batch_size = infer_batch_size or batch_size
        
#         # 1. Determine GPU availability
#         self.use_gpu = torch.cuda.is_available()
        
#         # 2. Determine the number of processes (Fixes NoneType error)
#         if num_processes is None:
#             if self.use_gpu:
#                 # If GPU is available, use one process per GPU
#                 self.num_processes = torch.cuda.device_count()
#             else:
#                 # If CPU is used, default to 1 process if not specified
#                 self.num_processes = 1
#         else:
#             self.num_processes = num_processes
            
#         # Ensure minimum of 1 process
#         if self.num_processes < 1:
#             self.num_processes = 1

#         self.model_path = "/dbfs/tmp/torch_mlp_state_dict.pth"
#         self.input_dim = None
        
#     def fit(self, df):
#         """
#         Uses TorchDistributor for distributed training.
#         """
#         if self.input_dim is None:
#             first_row = df.select(vector_to_array("scaled_features")).head()
#             self.input_dim = len(first_row[0])

#         df_train = df.select(
#             vector_to_array("scaled_features").alias("features_arr"),
#             F.col("DEP_DELAY")
#         ).dropna()

#         #  HACK HACK HACK HACK use 10% sample
#         df_train = df_train.sample(withReplacement=False, fraction=0.1, seed=42)
 
#         # Repartition for better shuffle balance (not used directly by TorchDistributor)
#         num_partitions = self.num_processes * 4
#         df_train = df_train.repartition(num_partitions)

#         # Collect as pandas and build NumPy arrays
#         pdf = df_train.toPandas()
#         X = np.stack(pdf["features_arr"].values)
#         y = pdf["DEP_DELAY"].values.astype(np.float32)

#         params = {
#             "input_dim": self.input_dim,
#             "hidden_layers": self.hidden_layers,
#             "dropout_rate": self.dropout_rate,
#             "learning_rate": self.learning_rate,
#             "batch_size": self.batch_size,
#             "epochs": self.epochs,
#             "use_gpu": self.use_gpu,
#             "model_path": self.model_path,
#         }

#         distributor = TorchDistributor(
#             num_processes=self.num_processes,
#             local_mode=False,
#             use_gpu=self.use_gpu,
#         )

#         # Ensure the parent directory for the model file exists
#         model_path_obj = Path(self.model_path)
#         model_path_obj.parent.mkdir(parents=True, exist_ok=True)
#         print(f"Starting distributed training. Processes: {self.num_processes}, CUDA: {self.use_gpu}")

#         # NOTE: args are exactly what train_fn expects: (X, y, params)
#         distributor.run(train_fn, X, y, params)

#         # load model on driver
#         self.trained_model = PyTorchMLPRegressor(
#             self.input_dim, self.hidden_layers, self.dropout_rate
#         )
#         self.trained_model.load_state_dict(torch.load(self.model_path)) 
#         self.trained_model.eval()

#         return self

class SparkPyTorchEstimator:
    """
    Train PyTorchMLPRegressor on a Spark DataFrame with TorchDistributor and
    score DataFrames with mapInPandas.

    Input DataFrames must contain a ``scaled_features`` vector column; ``fit``
    also reads a ``DEP_DELAY`` label column. ``transform`` returns the input
    columns plus a double ``prediction`` column.
    """

    def __init__(self, hidden_layers=None, dropout_rate=0.3, learning_rate=0.001, 
                 batch_size=256, epochs=30, num_processes=None, infer_batch_size=None):
        
        self.hidden_layers = hidden_layers or [128, 64]
        self.dropout_rate = dropout_rate
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs

        # Separate inference chunk size (defaults to the training batch size).
        self.infer_batch_size = infer_batch_size or batch_size
        
        # GPU usage is decided from the driver's view of CUDA availability.
        self.use_gpu = torch.cuda.is_available()
        
        # Default: one process per visible GPU, else a single CPU process.
        if num_processes is None:
            if self.use_gpu:
                self.num_processes = torch.cuda.device_count()
            else:
                self.num_processes = 1
        else:
            self.num_processes = max(1, num_processes)

        self.model_path = "/dbfs/tmp/torch_mlp_state_dict.pth"
        self.input_dim = None
        
    def fit(self, df):
        """
        Uses TorchDistributor for distributed training, without collecting
        the entire dataset to the driver.

        The training columns are written to a temporary sharded parquet
        directory under ``dbfs:/tmp`` that ``train_fn`` streams from; that
        directory is removed once training finishes (or fails). The trained
        weights are loaded on the driver into ``self.trained_model``.

        Returns:
            self
        """
        import shutil

        # Optional: clear old Spark caches to avoid leftover 60M stuff.
        # NOTE: clearCache() drops every cached table/DataFrame in this session.
        df.sparkSession.catalog.clearCache()

        # Determine input_dim from a single row
        if self.input_dim is None:
            sample = df.select(vector_to_array("scaled_features").alias("features_arr")).limit(1).collect()
            if not sample:
                raise ValueError("No rows in training dataframe.")
            self.input_dim = len(sample[0]["features_arr"])

        # Prepare train DF (features_arr + label)
        df_train = df.select(
            vector_to_array("scaled_features").alias("features_arr"),
            F.col("DEP_DELAY").cast(DoubleType())
        ).dropna(subset=["features_arr", "DEP_DELAY"])

        unique_id = str(uuid.uuid4())
        train_path = f"dbfs:/tmp/mlp_train_{unique_id}"
        num_shards = max(self.num_processes * 200, 200)

        params = {
            "input_dim": self.input_dim,
            "hidden_layers": self.hidden_layers,
            "dropout_rate": self.dropout_rate,
            "learning_rate": self.learning_rate,
            "batch_size": self.batch_size,
            "epochs": self.epochs,
            "use_gpu": self.use_gpu,
            "model_path": self.model_path,
            "train_path": train_path,
        }

        distributor = TorchDistributor(
            num_processes=self.num_processes,
            local_mode=False,
            use_gpu=self.use_gpu,
        )

        # Ensure model path directory exists
        Path(self.model_path).parent.mkdir(parents=True, exist_ok=True)

        try:
            # Sharded parquet write is executed by Spark executors; nothing is
            # collected to the driver.
            (
                df_train
                .repartition(num_shards)
                .write
                .option("maxRecordsPerFile", 200_000)  # adjust lower if needed
                .mode("overwrite")
                .parquet(train_path)
            )

            print(f"Starting distributed training. Processes: {self.num_processes}, CUDA: {self.use_gpu}")
            distributor.run(train_fn, params)
        finally:
            # The temporary copy of the training set is only needed by train_fn.
            # Remove it (through the driver's /dbfs mount, which this class already
            # relies on for model_path) so repeated CV fits don't pile up full
            # copies of the training data under dbfs:/tmp.
            shutil.rmtree(train_path.replace("dbfs:/", "/dbfs/", 1), ignore_errors=True)

        # load model on driver
        self.trained_model = PyTorchMLPRegressor(
            self.input_dim, self.hidden_layers, self.dropout_rate
        )
        # map_location="cpu": rank 0 may have saved CUDA tensors. Loading onto the
        # CPU works whether or not the driver has a GPU, and keeps the weights
        # safe to ship inside the mapInPandas closure in transform().
        self.trained_model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        self.trained_model.eval()

        return self

    def transform(self, df):
        """
        Uses mapInPandas for parallel inference on workers, with explicit batching
        (``infer_batch_size`` rows at a time) to avoid OOM during evaluation.

        Returns:
            A new DataFrame with all columns of ``df`` plus ``prediction`` (double).

        Raises:
            ValueError: if ``fit`` has not been called.
        """
        if not hasattr(self, 'trained_model'):
            raise ValueError("Model not fitted.")

        from pyspark.sql.types import StructField, StructType

        # Weights shipped to the workers inside the closure below.
        model_state_dict = self.trained_model.state_dict()

        # Build a new StructType: df.schema is cached on the DataFrame and
        # StructType.add() mutates in place, so df.schema.add(...) would leak a
        # phantom "prediction" field into the caller's DataFrame schema.
        schema = StructType(list(df.schema.fields) + [StructField("prediction", DoubleType())])

        # Capture plain values (not `self`) so the closure pickles only what it needs.
        use_gpu = self.use_gpu
        infer_batch_size = self.infer_batch_size
        input_dim = self.input_dim
        hidden_layers = list(self.hidden_layers)
        dropout_rate = self.dropout_rate

        def predict_partition_full(iterator):
            # GPU for inference only if the driver used one and this worker has one.
            device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

            # Initialize and load model ONCE per task
            worker_model = PyTorchMLPRegressor(input_dim, hidden_layers, dropout_rate).to(device)
            worker_model.load_state_dict(model_state_dict)
            worker_model.eval()

            with torch.no_grad():
                for pdf_batch in iterator:
                    n = len(pdf_batch)
                    # float64 to match the DoubleType output column
                    preds_all = np.empty(n, dtype=np.float64)

                    # np.stack fails on an empty batch, so only stack when rows exist.
                    if n:
                        X_np = np.stack(pdf_batch["features_arr"].values)
                        for start in range(0, n, infer_batch_size):
                            end = min(start + infer_batch_size, n)
                            inputs = torch.from_numpy(X_np[start:end]).float().to(device)
                            preds_all[start:end] = worker_model(inputs).cpu().numpy().astype(np.float64)

                    pdf_batch["prediction"] = preds_all
                    yield pdf_batch.drop(columns=["features_arr"])

        # Add the array column temporarily; it is dropped inside predict_partition_full.
        df_with_arr = df.withColumn("features_arr", vector_to_array("scaled_features"))
        return df_with_arr.mapInPandas(predict_partition_full, schema=schema)
    

# =====================================================
# 2. MLP PIPELINE WRAPPER
# =====================================================

class MLPFlightDelayPipeline:
    """
    Wrapper that combines Spark preprocessing + PyTorch MLP into a single estimator.

    Preprocessing mean-imputes the numerical columns, then string-indexes and
    one-hot encodes the categorical columns (unseen/invalid values kept in an
    extra bucket). Everything is assembled and standard-scaled into
    ``scaled_features``, which feeds SparkPyTorchEstimator.

    NOTE: ``fit`` mutates and returns this same object, so every "model"
    returned by repeated ``fit`` calls is this one instance in its latest state.
    """

    def __init__(
        self,
        categorical_features,
        numerical_features,
        mlp_params=None,
    ):
        self.categorical_features = categorical_features
        self.numerical_features = numerical_features
        # Keyword arguments forwarded to SparkPyTorchEstimator.
        self.mlp_params = mlp_params or {}

        # Set by fit(): a fitted Spark PipelineModel and a fitted SparkPyTorchEstimator.
        self.preprocessing_pipeline = None
        self.pytorch_estimator = None

    def _build_preprocessing_pipeline(self):
        """Create (unfitted) the Imputer -> StringIndexer -> OneHotEncoder -> VectorAssembler -> StandardScaler pipeline."""
        imputer = Imputer(
            inputCols=self.numerical_features,
            outputCols=[f"{name}_IMPUTED" for name in self.numerical_features],
            strategy="mean"
        )

        indexer = StringIndexer(
            inputCols=self.categorical_features,
            outputCols=[f"{name}_INDEX" for name in self.categorical_features],
            handleInvalid="keep"
        )

        encoder = OneHotEncoder(
            inputCols=[f"{name}_INDEX" for name in self.categorical_features],
            outputCols=[f"{name}_VEC" for name in self.categorical_features],
            dropLast=False
        )

        assembler = VectorAssembler(
            inputCols=[f"{name}_VEC" for name in self.categorical_features] +
                      [f"{name}_IMPUTED" for name in self.numerical_features],
            outputCol="features",
            handleInvalid="skip"
        )

        scaler = StandardScaler(
            inputCol="features",
            outputCol="scaled_features",
            withMean=True,
            withStd=True
        )

        self.preprocessing_pipeline = Pipeline(
            stages=[imputer, indexer, encoder, assembler, scaler]
        )

        return self.preprocessing_pipeline

    def _cast_numerical(self, df):
        """Cast every configured numerical column to DoubleType (the Imputer requires numeric input)."""
        for name in self.numerical_features:
            df = df.withColumn(name, F.col(name).cast(DoubleType()))
        return df

    def fit(self, df):
        """
        Fit preprocessing and the PyTorch MLP on ``df``.

        Preprocessing is rebuilt and refitted on every call. Under
        cross-validation each fold (and the final test fit) must learn its own
        imputation means, category indices and scaling statistics from its own
        training split, rather than reusing those of the first fold.

        Returns:
            self
        """
        cast_df = self._cast_numerical(df)
        self.preprocessing_pipeline = self._build_preprocessing_pipeline().fit(cast_df)

        # Transform training data
        preprocessed = self.preprocessing_pipeline.transform(cast_df)

        # Build and fit PyTorch Estimator
        self.pytorch_estimator = SparkPyTorchEstimator(**self.mlp_params)
        self.pytorch_estimator.fit(preprocessed)

        return self

    def transform(self, df):
        """
        Apply the fitted preprocessing and return ``df`` with a ``prediction`` column.

        Raises:
            ValueError: if ``fit`` has not been called.
        """
        if self.preprocessing_pipeline is None or self.pytorch_estimator is None:
            raise ValueError("Pipeline not fitted yet. Call fit() first.")

        # Same numeric casting as in fit(), so the Imputer sees the types it was fitted on.
        preprocessed = self.preprocessing_pipeline.transform(self._cast_numerical(df))

        # Generate predictions
        return self.pytorch_estimator.transform(preprocessed)
    

# =====================================================
# RUN: MLP PIPELINE UNDER FlightDelayCV
# =====================================================

# Categorical columns: string-indexed and one-hot encoded inside the pipeline.
categorical_features = [
    "day_of_week",
    "op_carrier",
    "dep_time_blk",
    "arr_time_blk",
    "day_of_month",
    "month",
]

# Numerical columns: mean-imputed, then assembled and standard-scaled.
numerical_features = [
    "hourlyprecipitation",
    "hourlysealevelpressure",
    "hourlyaltimetersetting",
    "hourlywetbulbtemperature",
    "hourlystationpressure",
    "hourlywinddirection",
    "hourlyrelativehumidity",
    "hourlywindspeed",
    "hourlydewpointtemperature",
    "hourlydrybulbtemperature",
    "hourlyvisibility",
    "crs_elapsed_time",
    "distance",
    "elevation",
]

# Keyword arguments forwarded to SparkPyTorchEstimator.
# Add num_processes=<k> to override the automatic GPU/CPU process count.
mlp_params = dict(
    hidden_layers=[512, 256, 128],
    dropout_rate=0.1,
    learning_rate=0.005,
    batch_size=128,
    epochs=30,
    infer_batch_size=128,  # row chunk size for inference inside mapInPandas
)

mlp_pipeline = MLPFlightDelayPipeline(
    categorical_features=categorical_features,
    numerical_features=numerical_features,
    mlp_params=mlp_params,
)

# FlightDelayCV is defined earlier in this notebook; without `dataloader=` it
# creates its own FlightDelayDataLoader for the requested version.
crossvalidator = FlightDelayCV(estimator=mlp_pipeline, version="60M")

# Fit and score every CV fold, then show the per-fold + mean/std table.
cv_results = crossvalidator.fit()
print("Cross-Validation Results:")
print(cv_results)
display(cv_results)  # notebook rendering helper

# # Evaluate on held-out test fold
# test_results = crossvalidator.evaluate()
# print("\nTest Set Results:")
# print(test_results)
 