## Import Statements and Dependencies

This cell imports all the necessary modules and libraries required for the pipeline.
These imports set up the environment for data ingestion, preprocessing, model training, prediction, and evaluation.

In [9]:
from __future__ import annotations

import collections
import json
import logging
import math
import os
import pathlib
import pickle
import random
import time
from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
from pyspark.sql import DataFrame, Row, SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (
    DoubleType,
    IntegerType,
    LongType,
)

from aeon.classification.distance_based import (
    ProximityForest,
    ProximityTree,
)

from data_ingestion import DataIngestion
from evaluation import Evaluator

from utilities import (
    compute_min_max,
    randomSplit_dist,
    randomSplit_stratified_via_sampleBy,
    show_compact,
)



## Data Preprocessing

The `Preprocessor` class prepares raw ECG time series data for parallel training. It performs three primary tasks:

1. **Missing Value Handling**  
   Drops rows in which every column is null; rows with only some missing values are kept.

2. **Repartitioning**  
   - **Balanced (Local Model)**: Uses stratified sampling to maintain class distributions across partitions. Partition IDs are assigned using a randomised window function and used to repartition the dataset.
   - **Unbalanced (Global Model)**: Applies random repartitioning for efficiency without enforcing label stratification.

3. **Normalization**  
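   Applies min-max scaling to the feature columns using precomputed global minimum and maximum values. This maps each feature to the [0, 1] range (constant columns, whose minimum equals their maximum, are set to 0). Scaling matters for distance-based classifiers such as Proximity Trees, because features with large raw ranges would otherwise dominate the distance.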

The full preprocessing pipeline is executed through the `run_preprocessing()` method, which dynamically selects the appropriate partitioning strategy based on the configuration. It also supports preservation of partition IDs for downstream tasks when required.


In [10]:
class Preprocessor:
    """
    Clean, repartition and normalise the ECG DataFrame before training.

    ``run_preprocessing`` drops fully-null rows, repartitions the data
    (label-stratified for the local model, plain random repartition for the
    global model) and min-max normalises every feature column. The result is a
    Spark DataFrame ready for training.

    Config keys read by this class:
      - ``local_model_config`` / ``global_model_config``: ``test_local_model``,
        ``test_global_model`` and ``num_partitions``.
      - ``num_partitions`` (top level, global path only; takes precedence over
        ``global_model_config.num_partitions``).
      - ``label_col``: label column used for stratification.
      - ``reserve_partition_id``: keep the ``_partition_id`` helper column
        (defaults to False when absent).
    """

    def __init__(self, config: dict):
        self.config = config

    def handle_missing_values(self, df: DataFrame) -> DataFrame:
        """Drop rows where every column is null; partially null rows are kept."""
        return df.dropna(how="all")

    def normalize(self, df: DataFrame, min_max: dict, preserve_partition_id: bool = False) -> DataFrame:
        """
        Min-max scale every feature column using precomputed global statistics.

        Args:
            df: DataFrame with feature columns, a ``label`` column and optionally
                a ``_partition_id`` column.
            min_max: Mapping ``column -> (min, max)``; must contain every feature column.
            preserve_partition_id: Unused; kept for API compatibility.
                ``_partition_id`` is carried through whenever it is present in ``df``.

        Returns:
            DataFrame with the scaled feature columns first (original order),
            followed by ``label`` and, if present, ``_partition_id``.
        """
        passthrough = ["label"]
        if "_partition_id" in df.columns:
            passthrough.append("_partition_id")

        scaled = []
        for col in df.columns:
            if col in ("label", "_partition_id"):
                continue
            min_val, max_val = min_max[col]
            if max_val != min_val:
                expr = (F.col(col) - F.lit(min_val)) / F.lit(max_val - min_val)
            else:
                # Constant column: the range is zero, so every value maps to 0.0.
                expr = F.lit(0.0)
            scaled.append(expr.alias(col))

        return df.select(*scaled, *passthrough)

    def _global_num_partitions(self):
        """
        Return the partition count for the unbalanced (global-model) path, or None.

        A top-level ``num_partitions`` wins; otherwise
        ``global_model_config.num_partitions`` is used.
        """
        if self.config.get("num_partitions") is not None:
            return self.config["num_partitions"]
        global_cfg = self.config.get("global_model_config") or {}
        return global_cfg.get("num_partitions")

    def _repartition_data_NotBalanced(self, df: DataFrame, preserve_partition_id: bool = False) -> DataFrame:
        """Randomly repartition ``df`` without label stratification (no-op if no count is configured)."""
        num_parts = self._global_num_partitions()
        if num_parts is None:
            return df
        return df.repartition(num_parts)

    def _balanced_num_partitions(self):
        """
        Return the partition count for the stratified path, or None when not configured.

        Stratification requires ``label_col`` and a ``num_partitions`` in either
        model section. The count comes from ``local_model_config`` when
        ``test_local_model`` is True, else from ``global_model_config``.
        """
        local_cfg = self.config["local_model_config"]
        global_cfg = self.config["global_model_config"]
        configured = "num_partitions" in local_cfg or "num_partitions" in global_cfg
        if not configured or "label_col" not in self.config:
            return None
        chosen = local_cfg if local_cfg["test_local_model"] is True else global_cfg
        return chosen["num_partitions"]

    @staticmethod
    def _place_rows_by_partition_id(df: DataFrame, num_parts: int) -> DataFrame:
        """
        Put every row into the Spark partition whose index equals its ``_partition_id``.

        ``DataFrame.repartition(n, col)`` hash-partitions on the column value, and
        hashing the small integers 0..n-1 modulo n collides. That would leave some
        partitions empty and give others two stratified groups. Keying an RDD by
        the id and using an identity partitioner maps id ``i`` to partition ``i``
        exactly. This costs a Python-side shuffle instead of a JVM one.
        """
        routed = (
            df.rdd.map(lambda row: (row["_partition_id"], row))
            .partitionBy(num_parts, lambda pid: pid)
            .values()
        )
        spark = getattr(df, "sparkSession", None) or SparkSession.builder.getOrCreate()
        # Rows come straight from df, so re-verifying them against df.schema is redundant.
        return spark.createDataFrame(routed, schema=df.schema, verifySchema=False)

    def _repartition_data_Balanced(self, df: DataFrame, preserve_partition_id: bool = False) -> DataFrame:
        """
        Repartition ``df`` so that every partition holds a stratified share of each label.

        Returns ``df`` unchanged when stratification is not configured
        (see ``_balanced_num_partitions``).
        """
        num_parts = self._balanced_num_partitions()
        if num_parts is None:
            return df
        label_col = self.config["label_col"]

        # Number the rows of each label 0..n-1 in random order; taking that index
        # modulo num_parts deals every class round-robin over the partitions.
        window = Window.partitionBy(label_col).orderBy(F.rand())  # one shuffle to group rows by label
        df = df.withColumn("_partition_id", ((F.row_number().over(window) - 1) % num_parts).cast("int"))

        # Exactly one partition id per Spark partition (see helper for why not repartition()).
        df = self._place_rows_by_partition_id(df, num_parts)
        print(f'Repartitioning to <<<< {num_parts} >>>> workers - partitions.')

        if not preserve_partition_id:
            df = df.drop("_partition_id")
        return df

    def run_preprocessing(self, df: DataFrame, min_max) -> DataFrame:
        """
        Run all preprocessing steps in order:
         1. Drop rows that are completely null.
         2. Repartition: stratified when ``local_model_config.test_local_model`` is
            True, otherwise plain repartitioning when
            ``global_model_config.test_global_model`` is True.
         3. Min-max normalise the feature columns.

        Args:
            df (DataFrame): Input PySpark SQL DataFrame to preprocess.
            min_max: Mapping ``column -> (min, max)`` for every feature column.

        Returns:
            DataFrame: Preprocessed DataFrame ready for training.

        Raises:
            ValueError: If neither the local nor the global model is enabled.
        """
        preserve = self.config.get("reserve_partition_id", False)

        df = self.handle_missing_values(df)

        if self.config["local_model_config"]["test_local_model"] is True:
            df = self._repartition_data_Balanced(df, preserve_partition_id=preserve)
        elif self.config["global_model_config"]["test_global_model"] is True:
            df = self._repartition_data_NotBalanced(df, preserve_partition_id=preserve)
        else:
            raise ValueError("Preprocessing error.")

        return self.normalize(df, min_max, preserve_partition_id=preserve)


## Local Model Manager

The `LocalModelManager` class is responsible for training local Proximity Tree models and assembling them into a Proximity Forest ensemble.

### Overview
This section implements the local parallelisation strategy: it applies ensemble learning by training multiple independent `ProximityTree` models, one per partition of the ECG dataset. The process is as follows:
1. Receives a preprocessed Spark DataFrame.
2. Uses the DataFrame's existing Spark partitions (class-balanced or not, depending on how the data was preprocessed).
3. Trains one Proximity Tree per partition using the AEON library.
4. Serializes trees, returns them to the driver, and combines them into a `ProximityForest`.

### Key Features
- **Parallel Training**: Spark executors handle partitioned model training concurrently using `mapPartitions`.
- **One Tree per Partition**: The number of trained trees equals the number of partitions.
- **Optional Weighting**: Tree predictions can be weighted by validation accuracy (`LocalModelManager` itself does not compute or assign any weights).
- **Manual Assembly**: The ensemble is manually configured with class labels, job parameters, and fit flags for compatibility with AEON.
- **Inspection Tools**: Includes utilities to print model structure and details for debugging and analysis.

This approach significantly reduces computational overhead via distributed execution but may slightly limit model accuracy due to isolated training on partitions.


In [11]:
class LocalModelManager:
    """
    Train one Proximity Tree per Spark partition and assemble them into a Proximity Forest.

    Steps:
      1. Receive a preprocessed Spark DataFrame (partitioned upstream).
      2. Train one ``ProximityTree`` per partition via ``mapPartitions``.
      3. Collect the pickled trees on the driver.
      4. Wrap them in a ``ProximityForest`` whose fitted attributes are set by hand.
    """

    def __init__(self, config: dict):
        """
        Initialise the manager.

        Args:
            config (dict): Settings. Keys read by this class:
              - tree_params: keyword arguments for each ``ProximityTree``.
              - forest_params: keyword arguments for the ``ProximityForest``
                (``n_jobs`` is additionally copied onto the fitted forest).
        """
        self.config = config

        # Trained trees, filled by train_ensemble().
        self.trees = []

        # Final ensemble model, built by train_ensemble().
        self.ensemble = None

        # All instances share this module logger (same name). Attach a handler only
        # once; otherwise every new manager duplicates each log line.
        self.logger = logging.getLogger(__name__)
        if not self.logger.handlers:
            self.logger.addHandler(logging.StreamHandler())
        self.logger.setLevel(logging.ERROR)

    def _set_forest_classes(self):
        """Collect all class labels from individual trees and mark the forest as fitted."""
        all_classes = []
        for tree in self.trees:
            if hasattr(tree, "classes_"):
                all_classes.extend(tree.classes_)

        unique_classes = np.unique(all_classes)
        self.ensemble.classes_ = unique_classes
        self.ensemble.n_classes_ = len(unique_classes)

        # AEON's BaseClassifier typically expects a '_class_dictionary' mapping class->int
        self.ensemble._class_dictionary = {
            cls: idx for idx, cls in enumerate(unique_classes)
        }

        # Some older AEON versions store the number of classes in a private attribute
        self.ensemble._n_classes = len(unique_classes)

        # If n_jobs is used, set it explicitly here
        if "n_jobs" in self.config["forest_params"]:
            self.ensemble._n_jobs = self.config["forest_params"]["n_jobs"]

        # BaseClassifier sets 'is_fitted = True' at the end of fit().
        # Set both spellings so ._check_is_fitted() passes in predict().
        self.ensemble.is_fitted_ = True
        self.ensemble.is_fitted = True

    def get_ensemble(self) -> ProximityForest:
        """Return the trained Proximity Forest ensemble (None before training)."""
        return self.ensemble

    def print_ensemble_details(self):
        """Print the parameters and per-tree node structure of the aggregated forest."""
        if self.ensemble and hasattr(self.ensemble, 'trees_'):
            num_trees = len(self.ensemble.trees_)
            print(f"Aggregated Proximity Forest (contains {num_trees} trees):")
            print(f"  Number of trees (in trees_ attribute): {num_trees}")
            print(f"  Forest Parameters: {self.ensemble.get_params()}")
            for i, tree in enumerate(self.ensemble.trees_):
                print(f"  Tree {i+1} Details:")
                self._print_tree_node_info(tree.root, depth=2)
            print("-" * 20)
        else:
            print("Proximity Forest ensemble has not been trained yet or the 'trees_' attribute is missing.")

    def _print_tree_node_info(self, node, depth):
        """
        Recursively print one AEON tree node and its children.

        Internal nodes' ``splitter`` is read as ``[exemplars_by_class, {measure: params}]``.
        """
        indent = "  " * depth
        print(f"{indent}Node ID: {node.node_id}, Leaf: {node._is_leaf}")

        if node._is_leaf:
            print(f"{indent}  Label: {node.label}, Class Distribution: {node.class_distribution}")
        else:
            splitter = node.splitter
            if splitter:
                exemplars = splitter[0]
                distance_info = splitter[1]
                distance_measure = list(distance_info.keys())[0]
                distance_params = distance_info[distance_measure]

                print(f"{indent}  Splitter:")
                print(f"{indent}    Distance Measure: {distance_measure}, Parameters: {distance_params}")
                print(f"{indent}    Exemplar Classes: {list(exemplars.keys())}")

                print(f"{indent}  Children:")
                for label, child_node in node.children.items():
                    print(f"{indent}    Branch on exemplar of class '{label}':")
                    self._print_tree_node_info(child_node, depth + 1)

    @staticmethod
    def _to_aeon_arrays(pandas_df):
        """
        Split one partition's rows into AEON training inputs.

        Every column except ``label`` and the optional ``_partition_id`` helper
        column (kept by the preprocessor when ``reserve_partition_id`` is set)
        is treated as a time step.

        Returns:
            (X_3d, y): ``X_3d`` has shape ``(n_samples, 1, n_features)`` (one
            channel); ``y`` holds the ``label`` values.

        Raises:
            KeyError: If ``label`` is missing.
        """
        features = pandas_df.drop(columns="label")
        if "_partition_id" in features.columns:
            features = features.drop(columns="_partition_id")
        X = np.ascontiguousarray(features.values)
        X_3d = X.reshape((X.shape[0], 1, X.shape[1]))  # (samples, 1, features)
        y = pandas_df["label"].values
        return X_3d, y

    def train_ensemble(self, df: DataFrame) -> ProximityForest:
        """
        Train a forest model in 3 steps:
          1. Convert each Spark partition to AEON arrays.
          2. Train one ProximityTree per partition on the executors.
          3. Collect the trees and combine them into a ProximityForest.

        Partitions that are empty or fail to train contribute no tree.

        Returns:
            The assembled ``ProximityForest``, or None if no tree was trained.
        """
        tree_params = self.config["tree_params"]
        # Bind the helper locally so the shipped closure does not capture ``self``.
        to_arrays = self._to_aeon_arrays

        def process_partition(partition_data):
            """Train one tree on one partition; yields at most one pickled tree."""
            try:
                pandas_df = pd.DataFrame([row.asDict() for row in partition_data])
                if pandas_df.empty:
                    return []

                X_3d, y = to_arrays(pandas_df)

                tree = ProximityTree(**tree_params)
                tree.fit(X_3d, y)

                # Pickle so the tree travels back to the driver as bytes.
                return [pickle.dumps(tree)]

            except Exception as e:
                # Best effort: a failing partition is skipped rather than failing the job.
                print(f"Failed to train tree on partition: {str(e)}")
                return []

        serialized_trees = df.rdd.mapPartitions(process_partition).collect()
        # The bytes come from our own executors, so unpickling them is trusted here.
        self.trees = [pickle.loads(b) for b in serialized_trees if b is not None]

        if not self.trees:
            print("Warning: No trees were trained!")
            return None

        # n_trees is dictated by the trained trees; it overrides any value in forest_params.
        forest_params = dict(self.config["forest_params"])
        forest_params["n_trees"] = len(self.trees)
        self.ensemble = ProximityForest(**forest_params)

        # Manually set forest properties
        self.ensemble.trees_ = self.trees
        self._set_forest_classes()
        return self.ensemble


## GlobalModelManager: Distributed Proximity Tree 

The `GlobalModelManager` class builds a single proximity tree across an entire dataset distributed in Apache Spark. It is optimized for global parallelisation and minimizes driver-node bottlenecks via strategic broadcasting and in-place DataFrame updates.

### Key Features:
- **Tree Construction**: Builds the tree level by level (breadth-first). Each node selects the best exemplar-based split using Gini impurity. Nodes that fail the split conditions are marked as leaves.
- **Distributed Processing**: Uses Spark DataFrames, caching, and broadcast variables to coordinate node splits and data routing across workers without full dataset shuffling.
- **Enhanced Prediction**: Converts the tree to a plain dictionary, broadcasts it, and uses a UDF for efficient traversal during prediction.
- **Persistence**: Supports saving/loading of the full tree state via pickling.


In [12]:

# Module-level logging setup.
# basicConfig only takes effect when the root logger has no handlers yet, which
# is the case when this code runs as a standalone script; when Spark or the
# notebook has already configured logging, that setup is left alone. Records
# from `logger` must still pass its effective level and the handlers' levels.
logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
logger = logging.getLogger(__name__)

# Keep the Py4J bridge and PySpark's own Python loggers quiet.
for _chatty_logger in ("py4j", "pyspark"):
    logging.getLogger(_chatty_logger).setLevel(logging.ERROR)
del _chatty_logger


try:
    import numpy as np

    _NP = True
except ImportError: # pragma: no cover – NumPy optional for tiny envs
    _NP = False

# .............................................................................
# helpers
# .............................................................................

# Immutable record describing one node of the global proximity tree.
# Fields: node_id / parent_id (ids within the tree dict), split_on (split
# description, None until the node is split), is_leaf, prediction, and
# children (mapping used to reach child nodes).
TreeNode = collections.namedtuple(
    "TreeNode",
    ("node_id", "parent_id", "split_on", "is_leaf", "prediction", "children"),
)

# Keep the original efficient euclidean distance function
def _euclid(a, b):
    """
    Return the Euclidean distance between two sequences (Python lists or NumPy arrays).

    Returns ``float("inf")`` if either input is None, if their lengths differ,
    or if the arithmetic raises (the error is logged rather than propagated).
    With NumPy available (``_NP``) the difference is taken in float and reduced
    with a dot product; otherwise a pure-Python sum of squares is used.

    The prediction UDF calls this once per exemplar for every row. The debug
    messages are therefore only built when DEBUG logging is enabled, so their
    slicing and formatting cost nothing in normal runs.
    """
    # Logs emitted inside UDFs go to executor logs (the console in local mode).
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        logger.debug(f"UDF: _euclid inputs: a={a[:5] if isinstance(a, (list, np.ndarray)) else a}..., b={b[:5] if isinstance(b, (list, np.ndarray)) else b}...")

    if a is None or b is None or len(a) != len(b):
        if debug_enabled:
            logger.debug(f"UDF: _euclid returning inf due to None/len mismatch: a is None={a is None}, b is None={b is None}, len(a)={len(a) if a is not None else 'N/A'}, len(b)={len(b) if b is not None else 'N/A'}")
        return float("inf")

    if _NP:
        try:
            diff = np.subtract(a, b, dtype=float)
            return float(np.sqrt(np.dot(diff, diff)))
        except Exception as e:
            # Best effort: a bad pair is treated as infinitely far away.
            logger.error(f"UDF: ERROR in _euclid (NumPy path): {e}. Inputs: a={a[:5] if isinstance(a, (list, np.ndarray)) else a}..., b={b[:5] if isinstance(b, (list, np.ndarray)) else b}...")
            return float("inf")

    # Pure Python path
    try:
        return float(math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b))))
    except Exception as e:
        logger.error(f"UDF: ERROR in _euclid (Python path): {e}. Inputs: a={a[:5] if isinstance(a, (list, np.ndarray)) else a}..., b={b[:5] if isinstance(b, (list, np.ndarray)) else b}...")
        return float("inf")


# .............................................................................
# prediction-side helper – pure python, broadcasted once per executor
# (MODIFIED to use enhanced traversal logic)
# .............................................................................

def _enhanced_mk_traverse(bc_plain_tree):
    """
    Build the per-row traversal function used by the prediction UDF.

    Args:
        bc_plain_tree: Spark broadcast whose ``.value`` is the tree as a plain dict
            ``{node_id: node_dict}``. The node dict keys read here are
            ``'is_leaf'``, ``'prediction'``, ``'split_on'`` (a
            ``(measure, {branch_id: exemplar_ts})`` pair) and ``'children'``
            (``{branch_id: child_node_id}``).

    Returns:
        A function ``ts -> prediction`` that starts at node 0 and follows the
        branch of the nearest exemplar until it reaches a leaf. If the nearest
        exemplar's branch does not exist (pruned), it follows the nearest branch
        that does exist. When traversal cannot continue it returns None or the
        current node's ``'prediction'``; per the original design, the caller's
        ``coalesce`` supplies the fallback.
    """
    # The tree is a plain dictionary (cheap to broadcast and traverse).
    tree: Dict[int, Dict[str, Any]] = bc_plain_tree.value

    def _enhanced_traverse(ts):
        """Route one time series from the root to a leaf and return that leaf's prediction."""
        if ts is None:
            return None

        node_id = 0  # root
        while node_id in tree:
            node = tree[node_id]

            if node.get('is_leaf', False):
                return node.get('prediction')

            split_info = node.get('split_on')
            children = node.get('children')
            if not (split_info and children):
                # Internal node without split info or children: malformed tree, stop here.
                return node.get('prediction')

            _, exemplars = split_info  # only the exemplars matter for routing
            if not exemplars:
                return node.get('prediction')

            # Distance from ts to every exemplar of this split, computed once and
            # reused by the pruned-branch fallback below.
            distances = {
                branch_id: _euclid(ts, exemplar_ts)
                for branch_id, exemplar_ts in exemplars.items()
            }

            # Nearest exemplar overall. Strict '<' keeps the first one on ties and
            # leaves best_branch as None when every distance is inf.
            best_branch, best_dist = None, float("inf")
            for branch_id, dist in distances.items():
                if dist < best_dist:
                    best_branch, best_dist = branch_id, dist

            if best_branch is not None and best_branch in children:
                node_id = children[best_branch]
                continue

            # The nearest exemplar's branch was pruned: among the branches that do
            # exist, follow the one whose exemplar is nearest.
            next_node_id, best_child_dist = None, float("inf")
            for branch_id, child_id in children.items():
                if branch_id in distances and distances[branch_id] < best_child_dist:
                    next_node_id, best_child_dist = child_id, distances[branch_id]

            if next_node_id is None:
                return node.get('prediction')
            node_id = next_node_id

        # node_id is missing from the tree: a structural problem; the caller's
        # coalesce supplies the fallback.
        return None

    return _enhanced_traverse


# =============================================================================
# GlobalModelManager class (Optimised + Enhanced Prediction)
# =============================================================================

class GlobalModelManager:
    """Distribution-friendly proximity-tree learner.

    The tree is grown breadth-first on a Spark DataFrame. At every depth the
    learner:

    1. counts labels per open node and samples up to ``k`` exemplar series per
       (node, label) on the executors;
    2. scores ``k`` random candidate splits per node (one exemplar per label;
       each row goes to the branch of its nearest exemplar under ``_euclid``)
       by weighted Gini impurity and keeps the best one;
    3. turns non-splittable nodes into leaves predicting their local majority
       label, creates one child per branch of every chosen split, and
       re-routes every row of the assignment DataFrame to its new node.

    Only small aggregates (label counts and the sampled exemplars) are
    collected to the driver.

    ``self.tree`` maps node ids to ``TreeNode`` records (defined elsewhere in
    this module) exposing ``node_id``, ``parent_id``, ``split_on``,
    ``is_leaf``, ``prediction`` and ``children``; ``children`` maps a branch
    label to the child node id and ``split_on`` is ``(measure, exemplars)``.
    """

    # ------------------------------------------------------------------
    # init
    # ------------------------------------------------------------------

    def __init__(self, spark: SparkSession, config: Dict[str, Any]):
        """Create an unfitted manager whose tree is a single open root node.

        Args:
            spark: Active Spark session, used to broadcast exemplars and the
                fitted tree.
            config: Must contain a ``"tree_params"`` mapping. Recognised keys:
                ``max_depth`` (``None`` means unlimited), ``min_samples_split``
                (default 2) and ``n_splitters`` (default 5; it is both the
                number of candidate splits per node and the maximum number of
                exemplars sampled per (node, label)).
        """
        logger.debug("GlobalModelManager __init__ started.")
        p = config["tree_params"]
        self.spark = spark
        self.max_depth: int | None = p.get("max_depth")
        self.min_samples: int = p.get("min_samples_split", 2)
        self.k: int = p.get("n_splitters", 5)
        # Root starts as an open (non-leaf) node with no split and no children.
        self.tree: Dict[int, TreeNode] = {0: TreeNode(0, None, None, False, None, {})}
        self._next_id: int = 1
        self._maj: int = 1  # fallback class if everything else fails
        logger.debug(f"Initialized with max_depth={self.max_depth}, min_samples={self.min_samples}, k={self.k}")
        logger.debug("GlobalModelManager __init__ finished.")

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _to_ts_df(self, df):
        """Ensure DataFrame has (row_id, time_series[, true_label]).

        * ``row_id`` is added with ``monotonically_increasing_id`` when absent,
          otherwise it is cast to ``LongType``.
        * If ``time_series`` already exists, only ``label`` is renamed to
          ``true_label`` (when ``true_label`` is absent); other columns are kept.
        * Otherwise every column other than ``row_id``, ``label`` and
          ``true_label`` is packed, in column order, into a ``time_series``
          array, and the label column (``label`` preferred over
          ``true_label``) is cast to int and exposed as ``true_label``.
        """
        logger.debug("_to_ts_df started.")

        if "row_id" not in df.columns:
            logger.debug("Adding row_id.")
            df = df.withColumn("row_id", F.monotonically_increasing_id())
        else:
            logger.debug("Casting existing row_id to LongType.")
            df = df.withColumn("row_id", F.col("row_id").cast(LongType()))

        if "time_series" in df.columns:
            logger.debug("'time_series' column already exists.")
            if "label" in df.columns and "true_label" not in df.columns:
                logger.debug("Renaming 'label' to 'true_label'.")
                df = df.withColumnRenamed("label", "true_label")
            logger.debug("_to_ts_df finished (already formatted).")
            return df

        lbl = "label" if "label" in df.columns else (
            "true_label" if "true_label" in df.columns else None
        )
        # Never treat either label column as a feature, even when both exist.
        reserved = {"row_id", "label", "true_label"}
        feat_cols = [c for c in df.columns if c not in reserved]
        logger.debug(f"Found feature columns: {feat_cols}")
        cols = [
            "row_id",
            F.array(*[F.col(c) for c in feat_cols]).alias("time_series"),
        ]
        if lbl:
            logger.debug(f"Including label column: {lbl}")
            cols.append(F.col(lbl).cast(IntegerType()).alias("true_label"))
        else:
            logger.debug("No label column found.")

        ts_df = df.select(*cols)
        logger.debug("_to_ts_df finished (conversion done).")
        return ts_df

    @staticmethod
    def _gini(counts: Dict[int, int]) -> float:
        """Return the Gini impurity ``1 - sum(p_i^2)`` of a label-count map.

        An empty (or all-zero) map has impurity 0.0.
        """
        tot = sum(counts.values())
        if tot == 0:
            return 0.0
        return 1.0 - sum((c / tot) ** 2 for c in counts.values())

    @classmethod
    def _weighted_gini(cls, branch_counts: Dict[Any, Dict[int, int]], total: int) -> float:
        """Return ``sum_b (size_b / total) * gini(b)`` over the branches.

        Args:
            branch_counts: ``{branch: {label: count}}``.
            total: Row count of the parent node (the weighting denominator).

        Returns:
            The weighted impurity; 0.0 when ``total`` is not positive.
        """
        if total <= 0:
            return 0.0
        return float(
            sum((sum(c.values()) / total) * cls._gini(c) for c in branch_counts.values())
        )

    def _majority_label(self, stats: Dict[int, int]) -> int:
        """Return the most frequent label in ``stats`` (smallest label on ties).

        Falls back to the overall majority ``self._maj`` when ``stats`` is empty.
        """
        if not stats:
            return self._maj
        return max(stats.items(), key=lambda kv: (kv[1], -kv[0]))[0]

    @staticmethod
    def _label_counts_by_node(df) -> Dict[int, Dict[int, int]]:
        """Return ``{node_id: {true_label: count}}`` for ``df`` in one Spark job."""
        counts: Dict[int, Dict[int, int]] = collections.defaultdict(dict)
        for r in df.groupBy("node_id", "true_label").count().collect():
            # r["count"], not r.count: Row is a tuple, so .count is tuple.count.
            counts[r.node_id][r.true_label] = r["count"]
        return counts

    def _sample_exemplar_pool(self, cur) -> Dict[int, Dict[int, list]]:
        """Sample up to ``k`` time series per (node_id, true_label) of ``cur``.

        Sampling is done on the executors (a random ``row_number`` per
        group), so only the sampled series reach the driver.

        Returns:
            ``{node_id: {label: [time_series, ...]}}``; groups without rows
            are absent.
        """
        logger.debug("Starting exemplar pool sampling.")
        rank_window = Window.partitionBy("node_id", "true_label").orderBy(F.rand())
        rows = (
            cur.withColumn("_exemplar_rank", F.row_number().over(rank_window))
            .filter(F.col("_exemplar_rank") <= self.k)
            .groupBy("node_id", "true_label")
            .agg(F.collect_list("time_series").alias("ts_list"))
            .collect()
        )
        pool: Dict[int, Dict[int, list]] = {}
        for r in rows:
            if not r.ts_list:
                logger.debug(f"No time series found for node {r.node_id}, label {r.true_label} in collected list.")
                continue
            pool.setdefault(r.node_id, {})[r.true_label] = list(r.ts_list)
            logger.debug(f"Sampled {len(r.ts_list)} exemplars for node {r.node_id}, label {r.true_label}.")
        logger.debug(f"Finished exemplar pool sampling. Pool structure keys: {list(pool.keys())}")
        return pool

    def _candidate_impurity(self, nd_df, candidate_ex: Dict[int, list], tot: int):
        """Weighted Gini impurity of routing ``nd_df`` by nearest exemplar.

        Each row is assigned the branch whose exemplar in ``candidate_ex`` is
        closest under ``_euclid``. Only the small (branch, label) count table
        is collected to the driver.

        Args:
            nd_df: Rows of a single node.
            candidate_ex: ``{branch label: exemplar series}``.
            tot: Row count of the node (weighting denominator).

        Returns:
            The impurity, or ``None`` when no row was assigned a branch.
        """
        bc_ex = self.spark.sparkContext.broadcast(candidate_ex)
        try:
            @F.udf(IntegerType())
            def nearest_lbl_udf(ts):
                # Runs once per row: keep it free of logging/formatting work.
                best_d, best_l = float("inf"), None
                for lbl, ex_ts in bc_ex.value.items():
                    d = _euclid(ts, ex_ts)
                    if d < best_d:
                        best_d, best_l = d, lbl
                return best_l

            rows = (
                nd_df.withColumn("branch", nearest_lbl_udf("time_series"))
                .groupBy("branch", "true_label")
                .count()
                .collect()
            )
        finally:
            bc_ex.unpersist(False)

        branch_counts: Dict[int, Dict[int, int]] = collections.defaultdict(dict)
        for r in rows:
            if r.branch is None:
                # Rows with no nearest exemplar form no branch: they are left
                # out of the impurity sum but still count towards `tot`.
                continue
            branch_counts[r.branch][r.true_label] = r["count"]
        if not branch_counts:
            return None
        return self._weighted_gini(branch_counts, tot)

    def _best_split_for_node(self, cur, nid: int, stats: Dict[int, int],
                             node_pool: Dict[int, list]):
        """Choose the best of ``k`` random candidate splits for node ``nid``.

        Args:
            cur: Cached current-level assignment DataFrame.
            nid: Node id to evaluate.
            stats: ``{label: row count}`` of the node.
            node_pool: ``{label: [sampled series]}`` of the node.

        Returns:
            ``("euclidean", {label: exemplar})`` for the candidate with the
            largest Gini gain, or ``None`` when the node should become a leaf
            (fewer than ``min_samples`` rows, pure, fewer than two exemplar
            labels, or no candidate with gain above 1e-9).
        """
        logger.debug(f"Evaluating splits for node {nid}.")
        tot = sum(stats.values())
        logger.debug(f"Node {nid} stats: {stats}, total samples: {tot}")

        # Leaf condition 1: insufficient samples, purity, or no exemplars.
        if tot < self.min_samples:
            logger.debug(f"Node {nid} has {tot} samples, below min_samples {self.min_samples}. Marking as leaf.")
            return None
        if len(stats) <= 1:
            logger.debug(f"Node {nid} is pure ({len(stats)} labels). Marking as leaf.")
            return None
        if not node_pool:
            logger.debug(f"No exemplars found in pool for node {nid}. Marking as leaf.")
            return None

        # Leaf condition 2: a split needs exemplars from at least two labels.
        labels = list(node_pool.keys())
        if len(labels) < 2:
            logger.debug(f"Node {nid} has {len(labels)} exemplar labels in pool, need >= 2 for split. Marking as leaf.")
            return None
        if any(not node_pool[lbl] for lbl in labels):
            logger.debug(f"Node {nid} has a label without exemplars; no valid candidate split. Marking as leaf.")
            return None

        parent_g = self._gini(stats)
        logger.debug(f"Node {nid} parent Gini: {parent_g}")

        nd_df = cur.filter(F.col("node_id") == nid)
        best_gain, best_exp = -1.0, None
        logger.debug(f"Evaluating {self.k} candidate splits for node {nid}.")
        for i in range(self.k):
            # One randomly drawn exemplar per label defines this candidate.
            candidate_ex = {lbl: random.choice(node_pool[lbl]) for lbl in labels}
            logger.debug(f"Evaluating candidate split {i+1} for node {nid} with exemplars for labels: {list(candidate_ex.keys())}.")

            imp = self._candidate_impurity(nd_df, candidate_ex, tot)
            if imp is None:
                logger.debug(f"Impurity calculation returned None for candidate {i+1} on node {nid}. Skipping.")
                continue

            gain = parent_g - imp
            logger.debug(f"Candidate split {i+1} for node {nid}: Impurity={imp:.4f}, Gain={gain:.4f}")
            if gain > best_gain:
                best_gain, best_exp = gain, candidate_ex
                logger.debug(f"Candidate split {i+1} is the best so far for node {nid} with gain {best_gain:.4f}.")

        # Leaf condition 3: no candidate improved impurity meaningfully.
        if best_gain > 1e-9:
            logger.debug(f"Node {nid} found a good split with gain {best_gain:.4f}.")
            return ("euclidean", best_exp)
        logger.debug(f"Node {nid} did not find a good split (best gain {best_gain:.4f}). Marking as leaf.")
        return None

    # ------------------------------------------------------------------
    # fitting
    # ------------------------------------------------------------------

    def fit(self, df):
        """Train the proximity tree on a labelled DataFrame.

        Args:
            df: Any input accepted by ``_to_ts_df``; after normalisation it
                must have a ``true_label`` column.

        Returns:
            ``self`` (fitted), to allow chaining.
        """
        logger.debug("fit started.")
        df = self._to_ts_df(df).cache()
        initial_row_count = df.count()
        logger.debug(f"Initial DataFrame prepared and cached. Row count: {initial_row_count}")

        if initial_row_count == 0:
            logger.debug("Input DataFrame is empty. Setting root as leaf.")
            self.tree[0] = self.tree[0]._replace(is_leaf=True, prediction=self._maj, children={})
            df.unpersist()
            logger.debug("fit finished (empty DataFrame).")
            return self

        maj_row = df.groupBy("true_label").count().orderBy(F.desc("count")).first()
        if maj_row:
            self._maj = maj_row[0]
            logger.debug(f"Overall majority class calculated: {self._maj}")
        else:
            logger.debug(f"No data to calculate overall majority class. Keeping default: {self._maj}")

        assign = (
            df.select("row_id", "time_series", "true_label")
            .withColumn("node_id", F.lit(0))
            .cache()
        )
        # count() materialises the cache so `df` can be released safely.
        assign_initial_count = assign.count()
        logger.debug(f"Initial assignment DataFrame created and cached. Row count: {assign_initial_count} at root node 0.")
        df.unpersist()

        open_nodes, depth = {0}, 0
        logger.debug(f"Starting tree building loop with initial open_nodes: {open_nodes}")

        # --------------- depth-wise growth ----------------------------
        while open_nodes and (self.max_depth is None or depth < self.max_depth):
            logger.debug(f"\n--- Starting tree level {depth} ---")
            logger.debug(f"Open nodes for this level: {open_nodes}")

            cur = assign.filter(F.col("node_id").isin(list(open_nodes))).cache()
            cur_count = cur.count()  # also materialises the cache
            logger.debug(f"Filtered data for current level. Row count: {cur_count}")

            if cur_count == 0:
                logger.debug(f"No data for open nodes at depth {depth}. Breaking loop.")
                cur.unpersist()
                break

            # 1) label counts for every open node (one job, tiny result)
            stats_by_node = self._label_counts_by_node(cur)
            # 2) exemplar pool per (node, label), sampled on the executors
            pool = self._sample_exemplar_pool(cur)

            # 3) pick the best split per node (driver-orchestrated)
            best: Dict[int, Tuple[str, Dict[int, list]]] = {}
            to_leaf: set[int] = set()
            logger.debug("Starting best split evaluation per node.")
            for nid in open_nodes:
                split = self._best_split_for_node(
                    cur, nid, stats_by_node.get(nid, {}), pool.get(nid, {})
                )
                if split is None:
                    to_leaf.add(nid)
                else:
                    best[nid] = split
            logger.debug("Finished best split evaluation per node.")

            # 3b) finalise leaves right away with their LOCAL majority
            logger.debug("Finalizing nodes marked as leaves in this iteration.")
            for nid in to_leaf:
                node = self.tree.get(nid)
                if node is None or node.is_leaf:
                    logger.debug(f"Node {nid} missing or already a leaf. Skipping finalization.")
                    open_nodes.discard(nid)
                    continue
                leaf_stats = stats_by_node.get(nid, {})
                if not leaf_stats:
                    logger.debug(f"No data found for node {nid} during leaf finalization. Using overall majority fallback: {self._maj}")
                maj_lbl = self._majority_label(leaf_stats)
                self.tree[nid] = node._replace(is_leaf=True, prediction=maj_lbl, children={})
                logger.debug(f"Node {nid} marked as leaf with prediction {maj_lbl}.")
                open_nodes.discard(nid)
            logger.debug("Finished finalizing leaves for this iteration.")

            if not best:
                logger.debug("No nodes found good splits in this iteration. Breaking loop.")
                cur.unpersist()
                break

            # 4) create children, then re-route rows of the split nodes
            logger.debug("Creating children and updating assignment DataFrame.")
            split_map: Dict[Tuple[int, int], int] = {}
            new_open: set[int] = set()
            for pid, (measure, exemplars) in best.items():
                children = {}
                # One child per branch label of the chosen split.
                for lbl in exemplars:
                    cid = self._next_id
                    self._next_id += 1
                    self.tree[cid] = TreeNode(cid, pid, None, False, None, {})
                    children[lbl] = cid
                    split_map[(pid, lbl)] = cid
                    new_open.add(cid)
                    logger.debug(f"Created child {cid} for branch {lbl} of parent {pid}.")
                self.tree[pid] = self.tree[pid]._replace(
                    split_on=(measure, exemplars), children=children, is_leaf=False
                )
                logger.debug(f"Parent node {pid} updated: split_on={self.tree[pid].split_on}, children={self.tree[pid].children}")

            open_nodes = new_open
            logger.debug(f"New open_nodes for next level: {open_nodes}")

            bc_split = self.spark.sparkContext.broadcast(split_map)
            bc_exs = self.spark.sparkContext.broadcast({pid: ex for pid, (_, ex) in best.items()})
            logger.debug("Broadcasted split_map and best_exs for route_udf.")

            @F.udf(IntegerType())
            def route_udf(pid, ts):
                # Rows at nodes that did not split this level stay where they are.
                if pid not in bc_exs.value:
                    return pid
                best_d, best_lbl = float("inf"), None
                for lbl, ex_ts in bc_exs.value[pid].items():
                    d = _euclid(ts, ex_ts)
                    if d < best_d:
                        best_d, best_lbl = d, lbl
                # An unknown (parent, branch) pair keeps the row at the parent.
                return bc_split.value.get((pid, best_lbl), pid)

            old_assign = assign
            logger.debug("Applying route_udf to update assign DataFrame.")
            assign = assign.withColumn("node_id", route_udf("node_id", "time_series")).cache()
            # Materialise the new cache before releasing the one it derives from.
            assign_updated_count = assign.count()
            logger.debug(f"assign DataFrame updated and cached. New total rows: {assign_updated_count}")

            old_assign.unpersist()
            bc_split.unpersist(False)
            bc_exs.unpersist(False)
            cur.unpersist()
            logger.debug(f"Unpersisted intermediates for depth {depth}.")

            depth += 1

        logger.debug("\n--- Main tree building loop finished ---")
        logger.debug(f"Final open_nodes: {open_nodes}")

        # Nodes that are neither leaves nor split (e.g. children created at the
        # last allowed depth, or open nodes left when the loop stopped) become
        # leaves predicting the local majority of the rows routed to them.
        logger.debug("Performing final dangling node finalization.")
        dangling = [nid for nid, nd in self.tree.items() if not nd.is_leaf and not nd.children]
        if dangling:
            logger.debug(f"Found {len(dangling)} dangling nodes to finalize.")
            dangling_stats = self._label_counts_by_node(
                assign.filter(F.col("node_id").isin(dangling))
            )
            for nid in dangling:
                maj_lbl = self._majority_label(dangling_stats.get(nid, {}))
                self.tree[nid] = self.tree[nid]._replace(is_leaf=True, prediction=maj_lbl, split_on=None)
                logger.debug(f"Dangling node {nid} finalized as leaf with prediction {maj_lbl}.")
        else:
            logger.debug("No dangling nodes found to finalize at the end.")

        assign.unpersist()
        logger.debug("Final assign DataFrame unpersisted.")
        logger.debug("fit finished.")
        return self

    # ------------------------------------------------------------------
    # prediction
    # ------------------------------------------------------------------

    def predict(self, df):
        """Predict a label for every row of ``df``.

        ``df`` is normalised with ``_to_ts_df``. The fitted tree is converted
        to plain dicts, broadcast, and traversed per row by the UDF built with
        ``_enhanced_mk_traverse``. Rows for which traversal yields ``None`` get
        the overall majority class ``self._maj``. If the tree looks unfitted
        (no root, or a root with neither prediction nor children), every row
        gets ``self._maj``.

        Returns:
            A lazy DataFrame with ``row_id``, ``time_series``, ``true_label``
            (only if present) and ``prediction``.
        """
        logger.debug("predict started.")
        df = self._to_ts_df(df)
        # The row counts below exist only for diagnostics. Each is a full Spark
        # job (the second evaluates the traversal UDF over every row), so they
        # run only when DEBUG logging is actually enabled.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            logger.debug(f"Input DataFrame for prediction prepared. Row count: {df.count()}")

        if not self.tree or (0 not in self.tree) or (self.tree[0].prediction is None and not self.tree[0].children):
            logger.warning("Tree is not fitted or is empty. Returning DataFrame with default prediction.")
            default_pred_col = F.lit(self._maj).cast(IntegerType()).alias("prediction")
            select_cols = ["row_id", "time_series"] + (["true_label"] if "true_label" in df.columns else []) + [default_pred_col]
            return df.select(*select_cols)

        # Broadcast plain dicts rather than TreeNode records (presumably so the
        # workers do not need the TreeNode class to unpickle the tree).
        logger.debug("Converting tree structure to plain dictionary for broadcasting.")
        plain_tree_structure = {
            node_id: {
                "node_id": node.node_id,
                "parent_id": node.parent_id,
                "split_on": node.split_on,
                "is_leaf": node.is_leaf,
                "prediction": node.prediction,
                "children": node.children,
            }
            for node_id, node in self.tree.items()
        }
        logger.debug(f"Converted tree structure to plain dictionary with {len(plain_tree_structure)} nodes.")

        bc_tree = self.spark.sparkContext.broadcast(plain_tree_structure)
        logger.debug("Broadcasted plain tree structure.")
        udf_pred = F.udf(_enhanced_mk_traverse(bc_tree), IntegerType())
        logger.debug("Applying prediction UDF and coalescing results.")
        out = (
            df.withColumn("pred", udf_pred("time_series"))
            .withColumn("prediction", F.coalesce("pred", F.lit(self._maj)))
            .drop("pred")
        )
        if debug_enabled:
            logger.debug(f"Prediction applied. Output DataFrame row count: {out.count()}")

        # unpersist() only drops executor-side copies; Spark re-sends the
        # broadcast when the lazy `out` is eventually evaluated.
        bc_tree.unpersist(False)
        logger.debug("Broadcasted tree unpersisted.")

        sel = ["row_id", "time_series"] + (["true_label"] if "true_label" in out.columns else []) + ["prediction"]
        logger.debug(f"Selecting final columns: {sel}")
        final_output_df = out.select(*sel)
        logger.debug("predict finished.")
        return final_output_df

    # ------------------------------------------------------------------
    # inspection & persistence
    # ------------------------------------------------------------------

    def print_tree(self) -> str:
        """Return a human-readable, indented representation (driver-side).

        Leaves show their prediction; internal nodes show the split measure,
        the exemplar labels and one line per child (sorted by branch label).
        """
        logger.debug("print_tree started.")
        lines = []

        def rec(nid: int, depth: int):
            nd = self.tree.get(nid)
            if nd is None:
                lines.append("  " * depth + f"#{nid} MISSING")
                return
            ind = "  " * depth
            if nd.is_leaf:
                lines.append(f"{ind}Leaf {nid} → {nd.prediction}")
            else:
                meas, ex = nd.split_on or (None, {})
                lines.append(f"{ind}Node {nid} split={meas} labels={list(ex.keys())}")
                for lbl, cid in sorted(nd.children.items()):
                    lines.append(f"{ind}  ├─ lbl={lbl} → child {cid}")
                    rec(cid, depth + 2)

        rec(0, 0)
        tree_str = "\n".join(lines)
        logger.debug("print_tree finished.")
        return tree_str

    def save_tree(self, path: str):
        """Pickle the *entire* manager state (tree + params) to ``path``.

        Parent directories are created when ``path`` contains any; a bare
        file name is written to the current working directory.
        """
        logger.debug(f"save_tree started. Path: {path}")
        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises FileNotFoundError
            os.makedirs(directory, exist_ok=True)
        with open(path, "wb") as fh:
            pickle.dump({
                "max_depth": self.max_depth,
                "min_samples": self.min_samples,
                "k": self.k,
                "tree": self.tree,  # TreeNode records
                "_next_id": self._next_id,
                "_maj": self._maj,
            }, fh)
        logger.debug(f"Tree saved successfully to {path}.")
        logger.debug("save_tree finished.")

    @classmethod
    def load_tree(cls, spark: SparkSession, path: str) -> "GlobalModelManager":
        """Rebuild a manager from a file written by ``save_tree``.

        Security: ``pickle.load`` can execute arbitrary code; only load model
        files from a trusted source.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            Exception: any other loading/unpickling error (logged, re-raised).
        """
        logger.debug(f"load_tree started. Path: {path}")
        try:
            with open(path, "rb") as fh:
                data: Dict[str, Any] = pickle.load(fh)
            logger.debug("Data loaded successfully from pickle.")
        except FileNotFoundError:
            logger.error(f"Model file not found at {path}")
            raise
        except Exception as e:
            logger.error(f"Failed to load tree from {path}: {e}")
            raise

        dummy_conf = {
            "tree_params": {
                "max_depth": data.get("max_depth"),
                "min_samples_split": data.get("min_samples", 2),
                "n_splitters": data.get("k", 5),
            }
        }
        logger.debug(f"Loaded hyperparameters: {dummy_conf['tree_params']}")

        inst = cls(spark, dummy_conf)
        inst.tree = data.get("tree", {})
        inst._next_id = data.get("_next_id", 1)
        inst._maj = data.get("_maj", 1)
        logger.debug(f"Instance created and state restored. Root node exists: {0 in inst.tree}")
        logger.debug("load_tree finished.")
        return inst


## Prediction Logic

### Global Model
`predict_with_global_prox_tree()` checks that the supplied global model exposes a callable `predict()` and a `spark` attribute, generates predictions on a Spark DataFrame, and renames `true_label` to `label` so the output is compatible with evaluation.

### Local Model
`PredictionManager` broadcasts the trained local ensemble to the workers and uses a `pandas_udf` to apply aeon's `.predict()` to the data in each partition, returning the predictions in a new column.


In [13]:

# Set up a logger for this external function if needed
logger = logging.getLogger(__name__)
# Attach a handler only once: re-executing this cell would otherwise stack
# another StreamHandler and print every log line multiple times.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)


def predict_with_global_prox_tree(global_tree_model, data_df: DataFrame) -> DataFrame:
    """Run a global tree model's ``predict`` and adapt the output for evaluation.

    Args:
        global_tree_model: Fitted global model exposing a callable
            ``predict(DataFrame)`` and a ``spark`` attribute (for example a
            ``GlobalModelManager``).
        data_df: DataFrame to score.

    Returns:
        The model's prediction DataFrame, with ``true_label`` renamed to
        ``label`` when ``label`` is not already present. A missing
        ``prediction`` column is logged as an error but not raised.

    Raises:
        TypeError: if ``global_tree_model`` lacks a callable ``predict`` or a
            ``spark`` attribute.
    """
    logger.info("Starting external prediction with GlobalProxTree.")

    is_valid_model = (
        hasattr(global_tree_model, "predict")
        and callable(global_tree_model.predict)
        and hasattr(global_tree_model, "spark")
    )
    if not is_valid_model:
        model_type_name = type(global_tree_model).__name__ if global_tree_model is not None else "None"
        logger.error(f"Invalid model type provided. Expected GlobalProxTree-like object, got {model_type_name}.")
        raise TypeError("Invalid model type provided to predict_with_global_prox_tree. Expected GlobalProxTree-like object.")

    predictions_df = global_tree_model.predict(data_df)

    # The evaluator expects the ground truth in a column named "label".
    if "true_label" in predictions_df.columns and "label" not in predictions_df.columns:
        logger.debug("Renaming 'true_label' to 'label' in predictions DataFrame for evaluation compatibility.")
        predictions_df = predictions_df.withColumnRenamed("true_label", "label")

    if "prediction" not in predictions_df.columns:
        logger.error("GlobalProxTree.predict did not return a 'prediction' column.")
    logger.info("Finished external prediction with GlobalProxTree.")

    return predictions_df


class PredictionManager:
    """Score a Spark DataFrame with a trained local ProximityForest ensemble.

    The ensemble is broadcast to the executors and applied row by row inside a
    pandas UDF; the predicted label is appended as a ``prediction`` column.
    """

    def __init__(self, spark, ensemble: ProximityForest):
        """
        Initialize with a trained ProximityForest model.

        Args:
            spark: Spark session
            ensemble: Trained model from LocalModelManager.train_ensemble()

        Raises:
            ValueError: If ``ensemble`` is missing or not flagged as fitted via
                ``is_fitted_`` or ``is_fitted``.
        """
        self.spark = spark
        self.ensemble = ensemble
        self.logger = logging.getLogger(self.__class__.__name__)
        # Loggers are process-global: attach a handler only once, otherwise each new
        # PredictionManager (one per pipeline iteration) duplicates every log line.
        if not self.logger.handlers:
            self.logger.addHandler(logging.StreamHandler())
        self.logger.setLevel(logging.ERROR)

        # Basic model validation
        if not ensemble or not self._is_fitted(ensemble):
            raise ValueError("Model is not trained. First call LocalModelManager.train_ensemble()")

    @staticmethod
    def _is_fitted(ensemble) -> bool:
        """Return the model's fitted flag, preferring ``is_fitted_`` over ``is_fitted``."""
        # Previously the code tested hasattr(ensemble, 'is_fitted') but then read
        # 'is_fitted_', which crashed (AttributeError) for models exposing only one of them.
        flag = getattr(ensemble, "is_fitted_", None)
        if flag is None:
            flag = getattr(ensemble, "is_fitted", False)
        return bool(flag)

    def _create_predict_udf(self):
        """Build a pandas UDF that returns one predicted label (as float) per row.

        The ensemble is broadcast once so executors do not re-ship it per task.
        A row whose prediction raises is reported on executor stdout and mapped
        to the sentinel value -999.0.
        """
        # Broadcast model to all workers
        broadcast_model = self.spark.sparkContext.broadcast(self.ensemble)

        @pandas_udf(DoubleType())
        def predict_udf(features: pd.Series) -> pd.Series:
            """Converts features to AEON format and makes predictions."""
            def predict_single(feature_array):
                try:
                    # aeon collections are 3D: (n_cases, n_channels, n_timepoints)
                    X = np.ascontiguousarray(feature_array).reshape(1, 1, -1)
                    return float(broadcast_model.value.predict(X)[0])
                except Exception as e:
                    print(f"Prediction error: {e}")
                    return float(-999)

            return features.apply(predict_single)

        return predict_udf

    def generate_predictions_local(self, test_df: DataFrame) -> DataFrame:
        """
        Return ``test_df`` with an added ``prediction`` column.

        Every column other than ``label`` is cast to double and packed, in column
        order, into a temporary ``features`` array consumed by the prediction UDF;
        the temporary column is dropped from the result.

        Raises:
            ValueError: If no ensemble is available.
        """
        if not self.ensemble:
            raise ValueError("No models available for prediction")

        # NOTE(review): every non-label column is treated as a feature, including any
        # preserved partition-id column; confirm this matches the training features.
        feature_cols = [c for c in test_df.columns if c != "label"]

        test_df = test_df.withColumn(
            "features",
            F.array(*[F.col(c).cast("double") for c in feature_cols])
        )

        predict_udf = self._create_predict_udf()

        predictions_df = test_df.withColumn(
            "prediction",
            predict_udf("features")
        ).drop("features")
        return predictions_df
    
    

## PipelineController_Loop: Orchestrating End-to-End Model Training

The `PipelineController_Loop` class coordinates the entire model training pipeline for both global and local Proximity Tree models using Apache Spark. It supports dynamic partitioning and looped experimentation based on configuration.

### Key Responsibilities:
- **Spark Setup**: Dynamically configures a local or Databricks SparkSession. Verifies data paths and manages module dependencies.
- **Pipeline Execution**: Iteratively runs ingestion, preprocessing, training, prediction, and evaluation for both model types.
- **Model Training**:
  - **Global**: Trains a single Proximity Tree using the full dataset with distributed logic.
  - **Local**: Trains multiple Proximity Trees on partitioned data in parallel, aggregating them into a Proximity Forest.
- **Evaluation & Logging**: Tracks runtime, logs metrics, and saves performance reports.
- **Persistence**: Saves each iteration's trained model under a timestamped filename. The accumulated evaluation reports are written to `logs/` once all iterations have finished.

In [14]:
class PipelineController_Loop:
    """Run the ingest -> preprocess -> train -> predict -> evaluate pipeline in a loop.

    Iteration ``i`` re-runs the whole pipeline with ``num_partitions`` set to ``i``
    for every enabled model type (global Proximity Tree and/or local Proximity
    Forest). Each trained model is saved per iteration; metric reports are
    accumulated and written to ``logs/`` once the loop has finished.
    """

    def __init__(self, config):
        """
        Initialize the controller with the pipeline configuration.

        Args:
            config: Pipeline configuration dict. Keys read by this class include
                ``local_model_config``, ``global_model_config``,
                ``min_number_iterations`` (legacy spelling
                ``min_number_iterarations`` is also honoured), ``delay_time``,
                ``track_memory``, ``data_percentage``, ``local_data_path`` and
                ``databricks_data_path``.
        """
        self.config = config
        self.spark = None  # SparkSession managed once for the entire run
        self.ingestion_config = {}  # Populated during _setup_spark

        # Logger setup: attach a formatted handler only if none exists yet.
        self.logger = logging.getLogger(__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
            if self.logger.level == logging.NOTSET:
                self.logger.setLevel(logging.INFO)
            # Prevent duplicate messages if root logger also has handlers
            self.logger.propagate = False

    # ------------------------------------------------------------------
    # Spark / environment setup
    # ------------------------------------------------------------------
    def _stop_spark(self):
        """Stop the managed SparkSession, if any, and clear the reference."""
        if self.spark:
            self.spark.stop()
            self.spark = None

    def _resolve_project_root(self):
        """Locate the project root and the ``src`` directory for a local run.

        When ``__file__`` is available the script is assumed to live in
        ``<root>/code/src``. Otherwise (interactive/notebook use) the current
        working directory is inspected: ``src``, ``code`` or the project root.

        Returns:
            ``(project_root, current_dir)`` as ``pathlib.Path`` objects, where
            ``current_dir`` is the presumed ``src`` directory.
        """
        try:
            # __file__ is the path of the current script (controller_loop.py)
            current_script_path = os.path.abspath(__file__)
            current_dir = pathlib.Path(current_script_path).parent  # directory containing this script (src)
            # Assumes script is in src, parent is code, parent.parent is project root
            project_root = current_dir.parent.parent.resolve()
            self.logger.info(f"Project root determined using __file__: {project_root}")
        except NameError:
            # __file__ is undefined in interactive/notebook sessions: fall back to the CWD.
            cwd = pathlib.Path(os.getcwd()).resolve()
            self.logger.warning(f"__file__ not defined, using CWD: {cwd}")
            if cwd.name == 'src':
                project_root = cwd.parent.parent.resolve()
                current_dir = cwd
                self.logger.info(f"Assuming CWD is 'src', project root set to: {project_root}")
            elif cwd.name == 'code':
                project_root = cwd.parent.resolve()
                current_dir = cwd / "src"
                self.logger.info(f"Assuming CWD is 'code', project root set to: {project_root}")
            else:
                project_root = cwd
                current_dir = project_root / "src"
                self.logger.info(f"Assuming CWD is project root: {project_root}")
        return project_root, current_dir

    def _add_py_files(self, project_root, current_dir):
        """Ship helper modules to the Spark workers with ``addPyFile`` (local runs).

        Failures are logged rather than raised so that a missing module does not
        abort Spark setup.
        """
        modules_to_add = ['global_model_manager.py']  # Add other required modules if needed
        try:
            if current_dir is None or not current_dir.exists() or not current_dir.is_dir():
                # Fallback if current_dir wasn't determined correctly
                current_dir = project_root / "src"
                self.logger.warning(f"Re-assuming 'src' directory for module loading: {current_dir}")
                if not current_dir.exists():
                    raise FileNotFoundError("Cannot find assumed 'src' directory for module loading.")

            for module_name in modules_to_add:
                module_path = current_dir / module_name
                if module_path.exists():
                    self.spark.sparkContext.addPyFile(str(module_path))  # addPyFile needs a string path
                    self.logger.debug(f"Added {module_path} to SparkContext pyFiles.")
                else:
                    self.logger.error(f"Could not find module {module_name} at {module_path}")
        except Exception as e:
            self.logger.error(f"Failed to add Python module to SparkContext: {e}")

    def _setup_spark(self):
        """Create or reuse the SparkSession and fill ``self.ingestion_config``.

        On Databricks the active session is reused and the data path is taken
        from ``databricks_data_path``. Locally, any active session is stopped, a
        ``local[6]`` session is created, the CSV path is resolved against the
        project root and checked for existence, and helper modules are shipped to
        the workers.

        Returns:
            True when a session and data path are available; False otherwise, in
            which case any session created here has been stopped.
        """
        if self.spark is not None:
            self.logger.debug("Spark session already exists.")
            return True

        try:
            if "DATABRICKS_RUNTIME_VERSION" in os.environ:
                self.spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
                print("\nUsing Databricks Spark session.")
                self.ingestion_config = {
                    "data_path": self.config.get("databricks_data_path", "/mnt/2025-team6/fulldataset_ECG5000.csv"),
                    "data_percentage": self.config.get("data_percentage", 0.05),
                }
                return True

            # Local environment: replace any existing session with a freshly configured one.
            existing_spark = SparkSession.getActiveSession()
            if existing_spark:
                self.logger.info("Stopping existing Spark session before starting new one.")
                existing_spark.stop()

            self.spark = (
                SparkSession.builder
                .appName(f"LocalPipeline_Run_{time.time()}")
                .master("local[6]")
                .config("spark.driver.memory", "12g")
                .config("spark.executor.memory", "12g")
                .config("spark.driver.maxResultSize", "12g")
                .config("spark.sql.execution.arrow.pyspark.enabled", "true")
                .getOrCreate()
            )
            print("\nUsing local Spark session.")

            project_root, current_dir = self._resolve_project_root()

            # Leading separators are stripped so the configured path is always relative to the project root.
            local_data_file = self.config.get("local_data_path", "fulldataset_ECG5000.csv").lstrip('/\\')
            final_data_path_os = project_root / local_data_file
            if not final_data_path_os.exists():
                self.logger.error(f"FATAL: Constructed data path does NOT exist: {final_data_path_os}")
                self._stop_spark()
                return False
            self.logger.info(f"Verified data path exists: {final_data_path_os}")

            self.ingestion_config = {
                # Spark receives a file:// URI, built only after the path was verified.
                "data_path": final_data_path_os.as_uri(),
                "data_percentage": self.config.get("data_percentage", 1.0),
            }
            self.logger.info(f"Data path set in ingestion_config: {self.ingestion_config['data_path']}")
            self.logger.info(f"Local data percentage set to: {self.ingestion_config['data_percentage']}")

            self._add_py_files(project_root, current_dir)
            return True

        except Exception as e:
            self.logger.error(f"Error during Spark setup: {e}", exc_info=True)
            self._stop_spark()  # Ensure cleanup on error
            return False

    # ------------------------------------------------------------------
    # Per-iteration stages
    # ------------------------------------------------------------------
    def _prepare_data(self, i, target_partitions):
        """Ingest, split and preprocess the dataset for iteration ``i``.

        Min-max values are computed on the full ingested DataFrame, which is then
        split 80/20 (stratified on ``label``, seed 123); both splits are passed to
        the preprocessor together with those min-max values. Every stage is timed
        with ``self.evaluator``.

        Returns:
            ``(preprocessed_train_df, preprocessed_test_df)``, or None when the
            iteration must be skipped (empty data or a logged processing error).
        """
        try:
            self.evaluator.start_timer("Ingestion")
            self.logger.info(f"Iteration {i}: Loading data...")
            df = self.ingestion.load_data()
            if df.limit(1).count() == 0:
                self.logger.error(f"Iteration {i}: Data ingestion resulted in empty DataFrame. Skipping iteration.")
                return None
            self.evaluator.record_time("Ingestion")

            self.evaluator.start_timer("Split_MinMax")
            # NOTE(review): min-max is computed before the split, so test rows influence
            # the normalisation range; confirm this is intended.
            min_max_values = compute_min_max(df)
            self.logger.info(f"Iteration {i}: Computed Min-Max values.")
            train_df, test_df = randomSplit_stratified_via_sampleBy(df, label_col="label", weights=[0.8, 0.2], seed=123)
            self.evaluator.record_time("Split_MinMax")

            self.evaluator.start_timer("Preprocessing_Train")
            self.logger.info(f"Iteration {i}: Preprocessing train data (target partitions={target_partitions})...")
            preprocessed_train_df = self.preprocessor.run_preprocessing(train_df, min_max_values)
            train_count = preprocessed_train_df.count()  # Action to materialize preprocessing
            self.logger.info(f"Iteration {i}: Preprocessed train data count: {train_count}")
            self.evaluator.record_time("Preprocessing_Train")

            self.evaluator.start_timer("Preprocessing_Test")
            self.logger.info(f"Iteration {i}: Preprocessing test data (target partitions={target_partitions})...")
            preprocessed_test_df = self.preprocessor.run_preprocessing(test_df, min_max_values)
            test_count = preprocessed_test_df.count()  # Action
            self.logger.info(f"Iteration {i}: Preprocessed test data count: {test_count}")
            self.evaluator.record_time("Preprocessing_Test")

            if train_count == 0 or test_count == 0:
                self.logger.error(f"Iteration {i}: Preprocessing resulted in empty train or test set. Skipping model steps.")
                return None
            return preprocessed_train_df, preprocessed_test_df

        except Exception as e:
            if "Path does not exist" in str(e):
                spark_path_attempt = self.ingestion_config.get("data_path", "N/A")
                self.logger.error(f"Iteration {i}: Spark failed to find data file during load. Path: {spark_path_attempt}. Error: {e}", exc_info=False)
            elif isinstance(e, (ConnectionRefusedError, ConnectionResetError)) or "Connection reset by peer" in str(e):
                self.logger.error(f"Iteration {i}: Connection error during data processing: {e}", exc_info=True)
            else:
                self.logger.error(f"Iteration {i}: Error during data processing: {e}", exc_info=True)
            return None

    def _run_global_model(self, i, iter_config, current_datetime, train_df, test_df, reports):
        """Train, evaluate and save the global Proximity Tree for iteration ``i``.

        On success the evaluator's report is stored in ``reports[str(i)]`` and the
        tree is saved under ``models_global/``. Errors are logged, not raised, so
        the local model can still run in the same iteration.
        """
        self.global_model_manager = None
        try:
            global_config = iter_config.get("global_model_config", {})
            if not global_config:
                self.logger.error(f"Iteration {i}: global_model_config missing. Skipping global model.")
                return
            self.global_model_manager = GlobalModelManager(spark=self.spark, config=global_config)

            print(f"\nIteration {i}: Train global model......")
            self.evaluator.start_timer("Global_Training")
            model = self.global_model_manager.fit(train_df)
            self.evaluator.record_time("Global_Training")
            print(f"Iteration {i}: Finish Global Training.")

            # No tree, or a tree holding only its root node, counts as a failed fit.
            if not (model and hasattr(model, 'tree') and len(model.tree) > 1):
                self.logger.warning(f"Iteration {i}: Global model training failed or resulted in trivial tree.")
                return
            self.logger.info(f"Iteration {i}: Global model training successful.")

            print(f"\nIteration {i}: Generate predictions with global model......")
            self.evaluator.start_timer("Global_Prediction")
            predictions_df = predict_with_global_prox_tree(model, test_df)
            self.evaluator.record_time("Global_Prediction")
            print(f"Iteration {i}: Finish Global Prediction.")

            print(f"\nIteration {i}: Global Model Predictions Distribution:")
            predictions_df.groupBy("prediction").count().show()

            print(f"\nIteration {i}: Generate metrics for global model......")
            global_report, _ = self.evaluator.log_metrics(predictions_df, model=model)
            reports[str(i)] = global_report
            print(f"Iteration {i}: Finish Global Evaluation.")

            depth = global_config.get("tree_params", {}).get("max_depth", "NA")
            model_folder = "models_global"
            os.makedirs(model_folder, exist_ok=True)
            model_filename = f"global_model_iter_{i}_parti_{i}_{current_datetime}_depth_{depth}.pkl"
            model_save_path = os.path.join(model_folder, model_filename)
            try:
                self.global_model_manager.save_tree(model_save_path)
                print(f"Saved global model to {model_save_path}")
            except Exception as e:
                self.logger.error(f"Failed to save global model {model_filename}: {e}")

        except Exception as e:
            self.logger.error(f"Iteration {i}: Error during global model processing: {e}", exc_info=True)
        finally:
            # Drop the manager so the trained tree can be garbage-collected.
            self.global_model_manager = None
            self.logger.debug(f"Iteration {i}: Cleaned up global model objects.")

    def _run_local_model(self, i, iter_config, current_datetime, train_df, test_df, reports):
        """Train, evaluate and pickle the local Proximity Forest for iteration ``i``.

        On success the evaluator's report is stored in ``reports[str(i)]`` and the
        ensemble is pickled under ``models_local/``. Errors are logged, not raised.
        """
        self.local_model_manager = None
        try:
            local_config = iter_config.get("local_model_config", {})
            if not local_config:
                self.logger.error(f"Iteration {i}: local_model_config missing. Skipping local model.")
                return
            self.local_model_manager = LocalModelManager(config=local_config)

            print(f"\nIteration {i}: Train local model with {i} partitions......")
            self.evaluator.start_timer("Local_Training")
            model = self.local_model_manager.train_ensemble(train_df)
            self.evaluator.record_time("Local_Training")
            print(f"Iteration {i}: Finish Local Training.")

            if model is None or not hasattr(model, 'trees_') or not model.trees_:
                self.logger.warning(f"Iteration {i}: Local model training failed or resulted in empty ensemble.")
                return
            self.logger.info(f"Iteration {i}: Local model training successful.")

            print(f"\nIteration {i}: Generate predictions with local model......")
            self.evaluator.start_timer("Local_Prediction")
            self.predictor = PredictionManager(self.spark, model)
            predictions_df = self.predictor.generate_predictions_local(test_df)
            self.evaluator.record_time("Local_Prediction")
            print(f"Iteration {i}: Finish Local Prediction.")

            print(f"\nIteration {i}: Local model Predictions Distribution:")
            predictions_df.groupBy("prediction").count().show()

            print(f"\nIteration {i}: Generate metrics for local model......")
            local_report, _ = self.evaluator.log_metrics(predictions_df, model=model)
            reports[str(i)] = local_report
            print(f"Iteration {i}: Finish Local Evaluation.")

            depth = local_config.get("tree_params", {}).get("max_depth", "NA")
            model_folder = "models_local"
            os.makedirs(model_folder, exist_ok=True)
            model_filename = f"local_model_iter_{i}_parti_{i}_{current_datetime}_depth_{depth}.pkl"
            model_save_path = os.path.join(model_folder, model_filename)
            try:
                with open(model_save_path, 'wb') as f:
                    pickle.dump(model, f)
                print(f"Saved local model ensemble to {model_save_path}")
            except Exception as e:
                self.logger.error(f"Failed to save local model {model_filename}: {e}")

        except Exception as e:
            self.logger.error(f"Iteration {i}: Error during local model processing: {e}", exc_info=True)
        finally:
            self.local_model_manager = None
            self.logger.debug(f"Iteration {i}: Cleaned up local model objects.")

    def _save_report(self, reports, model_kind, timestamp):
        """Write accumulated per-iteration reports to ``logs/`` as JSON.

        Args:
            reports: Mapping of iteration number (as str) to evaluator report.
            model_kind: ``"global"`` or ``"local"``; used in filename and messages.
            timestamp: Timestamp string embedded in the filename.

        Nothing is written when ``reports`` is empty; write failures are logged.
        """
        if not reports:
            return
        report_folder = "logs"
        os.makedirs(report_folder, exist_ok=True)
        report_filename = f"report_{model_kind}_model_ALL_{timestamp}.json"
        report_save_path = os.path.join(report_folder, report_filename)
        try:
            with open(report_save_path, "w") as f:
                json.dump(reports, f, indent=2)
            print(f"Saved ALL {model_kind} model reports to {report_save_path}")
        except Exception as e:
            self.logger.error(f"Failed to save aggregated {model_kind} report: {e}")

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------
    def run(self):
        """
        Execute the main experiment loop for model training and evaluation.

        Iterations run from ``min_number_iterations`` (default 2) up to the largest
        ``num_partitions`` among the enabled model configs. Each enabled model type
        takes part only while ``i`` does not exceed its own ``num_partitions`` and
        uses ``i`` partitions in that iteration. Aggregated reports are saved and a
        local SparkSession is stopped when the loop ends, even if it raised.
        """
        local_cfg = self.config.get("local_model_config", {})
        global_cfg = self.config.get("global_model_config", {})
        run_local = local_cfg.get("test_local_model", False)
        run_global = global_cfg.get("test_global_model", False)

        if not run_local and not run_global:
            self.logger.error("No model selected for testing. Set 'test_local_model' or 'test_global_model' to True in config.")
            return

        # Each enabled model's num_partitions is the last iteration it takes part in.
        number_iterations_local = local_cfg.get("num_partitions", 0) if run_local else 0
        number_iterations_global = global_cfg.get("num_partitions", 0) if run_global else 0
        number_iterations = max(0, number_iterations_local, number_iterations_global)

        # Prefer the correctly spelled key but keep honouring the legacy misspelling.
        min_iterations = self.config.get(
            "min_number_iterations", self.config.get("min_number_iterarations", 2)
        )
        start_iteration = min_iterations
        end_iteration = number_iterations
        if end_iteration < start_iteration:
            self.logger.warning(f"Max iterations ({end_iteration}) is less than min iterations ({start_iteration}). No iterations will run.")
            start_iteration = end_iteration + 1  # Make range empty

        all_reports_global = {}
        all_reports_local = {}

        # --- Setup Spark Session ONCE before the loop ---
        if not self._setup_spark():
            self.logger.error("Initial Spark setup failed. Aborting run.")
            return
        if self.ingestion_config.get("data_path") is None:
            self.logger.error("Initial data path construction failed. Aborting run.")
            if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
                self._stop_spark()
            return

        try:
            for i in range(start_iteration, end_iteration + 1):
                self.logger.info(f"========== Starting Iteration {i} ==========")

                # Each model type runs only up to its own configured maximum.
                iteration_run_local = run_local and i <= number_iterations_local
                iteration_run_global = run_global and i <= number_iterations_global

                current_datetime = time.strftime("%Y-%m-%d-%H-%M-%S")
                self.evaluator = Evaluator(track_memory=self.config.get("track_memory", False))
                # JSON round-trip = deep copy, so per-iteration edits never leak into self.config.
                current_iter_config = json.loads(json.dumps(self.config))
                if iteration_run_local:
                    current_iter_config.setdefault("local_model_config", {})["num_partitions"] = i
                if iteration_run_global:
                    current_iter_config.setdefault("global_model_config", {})["num_partitions"] = i

                self.preprocessor = Preprocessor(config=current_iter_config)
                # Ingestion reuses the data path resolved once in _setup_spark.
                self.ingestion = DataIngestion(spark=self.spark, config=self.ingestion_config)

                target_partitions = current_iter_config.get(
                    "local_model_config" if iteration_run_local else "global_model_config", {}
                ).get("num_partitions", "N/A")
                prepared = self._prepare_data(i, target_partitions)
                if prepared is None:
                    continue  # Error/empty data already logged
                preprocessed_train_df, preprocessed_test_df = prepared

                if iteration_run_global:
                    self._run_global_model(
                        i, current_iter_config, current_datetime,
                        preprocessed_train_df, preprocessed_test_df, all_reports_global,
                    )
                if iteration_run_local:
                    self._run_local_model(
                        i, current_iter_config, current_datetime,
                        preprocessed_train_df, preprocessed_test_df, all_reports_local,
                    )

                # Release this iteration's DataFrames before the next iteration starts.
                prepared = preprocessed_train_df = preprocessed_test_df = None

                self.logger.info(f"========== Finished Iteration {i} ==========")
                delay = self.config.get('delay_time', 0)
                if delay > 0 and i < end_iteration:
                    print(f"\nIteration {i}: Waiting for {delay} seconds before next iteration...\n")
                    time.sleep(delay)

        finally:
            # Always persist what was collected and release a local session, even on error.
            final_datetime = time.strftime("%Y-%m-%d-%H-%M-%S")
            self._save_report(all_reports_global, "global", final_datetime)
            self._save_report(all_reports_local, "local", final_datetime)

            if "DATABRICKS_RUNTIME_VERSION" not in os.environ and self.spark:
                self._stop_spark()
                print("\nFinal Spark session stopped (local mode)!")

            print("\n--- Pipeline execution finished ---")



## Centralized Configuration

The `config.py` file consolidates all key parameters and constants used throughout the project. It simplifies updates and experimentation by centralizing:

- **Data Settings**: File paths for Databricks and local environments, the fraction of the dataset to load (`data_percentage`, where `1.0` means 100%), and the label column name (`label_col`).
- **Loop Control**: The minimum number of iterations (`min_number_iterarations`) and an optional delay, in seconds, between runs (`delay_time`).
- **Local Model Config**: Enables local Proximity Forest training with customizable partition count and hyperparameters for individual trees and the ensemble.
- **Global Model Config**: Enables distributed Proximity Tree training with tunable depth, splitters, and minimum sample thresholds.
- **Partitioning Behavior**: Controls whether `_partition_id` columns are preserved for downstream tasks (`reserve_partition_id`).


In [15]:
# config.py
"""
config.py

Single source of truth for pipeline settings: dataset locations, loop
control, and the hyperparameters of the local (ensemble) and global
Proximity Tree models. Keeping everything in one dictionary lets
experiments be adjusted in one place as the project grows.
"""

config = dict(
    # --- Data settings ---
    databricks_data_path="/mnt/2025-team6/fulldataset_ECG5000.csv",
    local_data_path="/fulldataset_ECG5000.csv",  # intended to be relative to the project root
    label_col="label",
    data_percentage=1.0,  # fraction of rows to load; 1.0 loads the whole dataset
    # --- Loop control ---
    # NOTE(review): key spelling kept verbatim; consumers presumably look it
    # up by this exact string, so confirm all call sites before renaming.
    min_number_iterarations=2,  # first iteration of the main loop
    delay_time=3,  # seconds to sleep between iterations
    # --- Local model: per-partition trees combined into an ensemble ---
    local_model_config=dict(
        test_local_model=True,
        num_partitions=3,  # the iteration loop runs up to this partition count
        tree_params=dict(
            n_splitters=5,
            max_depth=2,  # TODO: switch to None (unbounded depth)
            min_samples_split=5,
            random_state=123,
        ),
        forest_params=dict(
            random_state=123,
            n_jobs=-1,  # -1 requests all available cores
        ),
    ),
    # --- Global model: distributed Proximity Tree ---
    global_model_config=dict(
        test_global_model=True,
        num_partitions=3,  # the iteration loop runs up to this partition count
        tree_params=dict(
            n_splitters=5,
            max_depth=1,
            min_samples_split=5,
            random_state=123,
        ),
    ),
    # When True, keep the `_partition_id` column produced during repartitioning.
    reserve_partition_id=False,
)

## Running the Pipeline

In [16]:

# Entry point: build the loop controller from the `config` dictionary
# defined in the configuration cell and run all configured iterations.
# (The former `config = config` line was a no-op self-assignment and was removed.)
print("Starting pipeline via controller")
controller = PipelineController_Loop(config)
controller.run()

__file__ not defined, using CWD: D:\repos\BigData-main\BigData-1\code\src
__file__ not defined, using CWD: D:\repos\BigData-main\BigData-1\code\src
__file__ not defined, using CWD: D:\repos\BigData-main\BigData-1\code\src
__file__ not defined, using CWD: D:\repos\BigData-main\BigData-1\code\src
Assuming CWD is 'src', project root set to: D:\repos\BigData-main\BigData-1
Assuming CWD is 'src', project root set to: D:\repos\BigData-main\BigData-1
Assuming CWD is 'src', project root set to: D:\repos\BigData-main\BigData-1
Assuming CWD is 'src', project root set to: D:\repos\BigData-main\BigData-1
2025-05-06 08:08:50,023 - INFO - __main__ - Assuming CWD is 'src', project root set to: D:\repos\BigData-main\BigData-1
Verified data path exists: D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv
Verified data path exists: D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv
Verified data path exists: D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv
Verified data path exists: D:\r

Starting pipeline via controller

Using local Spark session.


Iteration 2: Loading data...
Iteration 2: Loading data...
Iteration 2: Loading data...
Iteration 2: Loading data...
2025-05-06 08:08:50,131 - INFO - __main__ - Iteration 2: Loading data...


Data Path: file:///D:/repos/BigData-main/BigData-1/fulldataset_ECG5000.csv
Loading 100.0% of data
Data size: 5000



Iteration 2: Computed Min-Max values.
Iteration 2: Computed Min-Max values.
Iteration 2: Computed Min-Max values.
Iteration 2: Computed Min-Max values.
2025-05-06 08:08:51,984 - INFO - __main__ - Iteration 2: Computed Min-Max values.
Iteration 2: Preprocessing train data (target partitions=2)...
Iteration 2: Preprocessing train data (target partitions=2)...
Iteration 2: Preprocessing train data (target partitions=2)...
Iteration 2: Preprocessing train data (target partitions=2)...
2025-05-06 08:08:52,292 - INFO - __main__ - Iteration 2: Preprocessing train data (target partitions=2)...


Repartitioning to <<<< 2 >>>> workers - partitions.


Iteration 2: Preprocessed train data count: 3990
Iteration 2: Preprocessed train data count: 3990
Iteration 2: Preprocessed train data count: 3990
Iteration 2: Preprocessed train data count: 3990
2025-05-06 08:08:53,312 - INFO - __main__ - Iteration 2: Preprocessed train data count: 3990
Iteration 2: Preprocessing test data (target partitions=2)...
Iteration 2: Preprocessing test data (target partitions=2)...
Iteration 2: Preprocessing test data (target partitions=2)...
Iteration 2: Preprocessing test data (target partitions=2)...
2025-05-06 08:08:53,314 - INFO - __main__ - Iteration 2: Preprocessing test data (target partitions=2)...


Repartitioning to <<<< 2 >>>> workers - partitions.


Iteration 2: Preprocessed test data count: 1010
Iteration 2: Preprocessed test data count: 1010
Iteration 2: Preprocessed test data count: 1010
Iteration 2: Preprocessed test data count: 1010
2025-05-06 08:08:54,743 - INFO - __main__ - Iteration 2: Preprocessed test data count: 1010



Iteration 2: Train global model......


Iteration 2: Global model training successful.
Iteration 2: Global model training successful.
Iteration 2: Global model training successful.
Iteration 2: Global model training successful.
2025-05-06 08:10:13,991 - INFO - __main__ - Iteration 2: Global model training successful.
Starting external prediction with GlobalProxTree.
Starting external prediction with GlobalProxTree.
Starting external prediction with GlobalProxTree.
Starting external prediction with GlobalProxTree.
2025-05-06 08:10:13,993 - INFO - __main__ - Starting external prediction with GlobalProxTree.


Iteration 2: Finish Global Training.

Iteration 2: Generate predictions with global model......


Finished external prediction with GlobalProxTree.
Finished external prediction with GlobalProxTree.
Finished external prediction with GlobalProxTree.
Finished external prediction with GlobalProxTree.
2025-05-06 08:10:15,886 - INFO - __main__ - Finished external prediction with GlobalProxTree.


Iteration 2: Finish Global Prediction.

Iteration 2: Global Model Predictions Distribution:
+----------+-----+
|prediction|count|
+----------+-----+
|         1|  569|
|         4|   32|
|         2|  409|
+----------+-----+


Iteration 2: Generate metrics for global model......
Iteration 2: Finish Global Evaluation.
Saved global model to models_global\global_model_iter_2_parti_2_2025-05-06-08-08-50_depth_1.pkl

Iteration 2: Train local model with 2 partitions......
Iteration 2: Finish Local Training.

Iteration 2: Generate predictions with local model......
Iteration 2: Finish Local Prediction.

Iteration 2: Local model Predictions Distribution:
+----------+-----+
|prediction|count|
+----------+-----+
|       1.0|  572|
|       4.0|   20|
|       3.0|    1|
|       2.0|  417|
+----------+-----+


Iteration 2: Generate metrics for local model......
Iteration 2: Finish Local Evaluation.
Saved local model ensemble to models_local\local_model_iter_2_parti_2_2025-05-06-08-08-50_depth_2.pkl