In [0]:
# %%
# Install required packages if running in an environment where they might be missing.
# NOTE: installing from inside a running kernel is a fallback; on Databricks prefer
# a separate `%pip install <pkg>` cell, as the error message below suggests.
import importlib
import subprocess
import sys

required_packages = ['aeon', 'psutil', 'pyspark', 'numpy', 'pandas', 'numba']  # numba: used by aeon's distance code paths (presumably) — kept "just in case"

for package in required_packages:
    try:
        __import__(package)
        print(f"{package} already installed.")
    except ImportError:
        print(f"Installing {package}...")
        try:
            # Always run pip through the interpreter executing this kernel so the
            # package is installed into the same environment.
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            # Packages created after the interpreter started may be invisible to
            # the import system until its finder caches are invalidated.
            importlib.invalidate_caches()
            print(f"{package} installed successfully.")
        except Exception as e:
            # Best effort: report and continue so the remaining packages are still checked.
            print(f"ERROR: Failed to install {package}. Please install it manually (e.g., using '%pip install {package}' in Databricks). Error: {e}")

```python
import os
os.environ['NUMBA_CACHE_DIR'] = '/tmp/numba_cache'
```

Set an environment variable to prevent Numba caching issues on the executors. **Run this cell before importing any other project modules.**


In [0]:
import os

# Ask Numba to skip on-disk caching of compiled functions. This only has a chance
# of working if it runs before Numba compiles/caches anything, so keep this cell
# ahead of the project imports.
# NOTE(review): NUMBA_DISABLE_CACHING is not a variable I can confirm in Numba's
# documented environment settings (NUMBA_CACHE_DIR is) — verify it has an effect.
# NOTE(review): os.environ here changes the driver process; whether Spark
# executors' Python workers inherit it depends on how they are launched — confirm.
os.environ.update(NUMBA_DISABLE_CACHING='1')
print("Attempted to disable Numba caching by setting NUMBA_DISABLE_CACHING=1")

# Echo the value back as a quick sanity check.
current_value = os.environ.get('NUMBA_DISABLE_CACHING')
print(f"NUMBA_DISABLE_CACHING set to: {current_value}")

In [0]:
from __future__ import annotations

import collections
import json
import logging
import math
import os
import pathlib
import pickle
import random
import time
from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
from pyspark.sql import DataFrame, Row, SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (
    DoubleType,
    IntegerType,
    LongType,
)

from aeon.classification.distance_based import (
    ProximityForest,
    ProximityTree,
)

from data_ingestion import DataIngestion
from evaluation import Evaluator

from utilities import (
    compute_min_max,
    randomSplit_dist,
    randomSplit_stratified_via_sampleBy,
    show_compact,
)



This section defines the `Preprocessor`, which performs basic cleaning and transformations on a Spark DataFrame.
It runs after the data-ingestion step.

In [0]:
class Preprocessor:
    """
    Cleans up the ECG feature DataFrame before training.

    It drops all-null rows, repartitions the data (stratified by label for the
    local model, plain for the global model) and min-max normalizes every
    feature column. It returns a Spark DataFrame ready for training.
    """

    # Columns that are never normalized as features.
    _NON_FEATURE_COLS = ("label", "_partition_id")

    def __init__(self, config: dict):
        # Keys read by this class: "local_model_config", "global_model_config",
        # "label_col", "reserve_partition_id" and, optionally, "num_partitions".
        self.config = config

    def handle_missing_values(self, df: DataFrame) -> DataFrame:
        """Drop rows where every column is null."""
        return df.dropna(how="all")

    def normalize(self, df: DataFrame, min_max: dict, preserve_partition_id: bool = False) -> DataFrame:
        """
        Min-max scale every feature column to ``(x - min) / (max - min)``.

        Args:
            df: DataFrame with feature columns plus ``label`` (and possibly
                ``_partition_id``).
            min_max: mapping ``column -> (min, max)`` for every feature column;
                a feature missing from it raises ``KeyError``.
            preserve_partition_id: accepted for signature compatibility; a
                ``_partition_id`` column is kept whenever it is present.

        Returns:
            DataFrame with the normalized features first, then ``label`` and,
            if present, ``_partition_id``. Constant columns (min == max) become 0.0.
        """
        feature_cols = [c for c in df.columns if c not in self._NON_FEATURE_COLS]

        normalized_cols = []
        for c in feature_cols:
            min_val, max_val = min_max[c]
            if max_val != min_val:
                expr = (F.col(c) - F.lit(min_val)) / F.lit(max_val - min_val)
            else:
                # Constant column: avoid division by zero.
                expr = F.lit(0.0)
            normalized_cols.append(expr.alias(c))

        cols_to_select = ["label"]
        if "_partition_id" in df.columns:
            cols_to_select.append("_partition_id")

        return df.select(*normalized_cols, *cols_to_select)

    def _repartition_data_NotBalanced(self, df: DataFrame, preserve_partition_id: bool = False) -> DataFrame:
        """
        Plain repartition (no stratification) for the global-model path.

        The partition count is the top-level ``num_partitions`` when present,
        otherwise ``global_model_config["num_partitions"]`` (the section the
        balanced path reads for the global model). Without either, ``df`` is
        returned unchanged. ``preserve_partition_id`` is accepted for signature
        parity; no ``_partition_id`` column is created here.
        """
        new_parts = self.config.get("num_partitions")
        if new_parts is None:
            new_parts = (self.config.get("global_model_config") or {}).get("num_partitions")
        if new_parts is None:
            return df
        return df.repartition(new_parts)

    def _repartition_data_Balanced(self, df: DataFrame, preserve_partition_id: bool = False) -> DataFrame:
        """
        Stratified repartition: deal the rows of every label round-robin over
        ``num_parts`` partition ids.

        ``num_parts`` comes from ``local_model_config`` when its
        ``test_local_model`` is True, otherwise from ``global_model_config``.
        If that section defines no ``num_partitions``, or ``label_col`` is not
        configured, ``df`` is returned unchanged.

        Args:
            df: input DataFrame; must contain the configured ``label_col``.
            preserve_partition_id: keep the helper ``_partition_id`` column.
        """
        local_cfg = self.config.get("local_model_config") or {}
        global_cfg = self.config.get("global_model_config") or {}
        section = local_cfg if local_cfg.get("test_local_model") is True else global_cfg

        # Read the count from the section actually selected, so a key that only
        # exists in the other section cannot cause a KeyError.
        num_parts = section.get("num_partitions")
        label_col = self.config.get("label_col")
        if num_parts is None or label_col is None:
            return df

        # Number the rows of each label in random order (row_number starts at 1,
        # hence the -1) and assign ids round-robin so each id gets ~1/num_parts
        # of every class. NOTE: F.rand() is unseeded, so assignments vary per run.
        window = Window.partitionBy(label_col).orderBy(F.rand())
        df = df.withColumn("_partition_id", ((F.row_number().over(window) - 1) % num_parts).cast("int"))

        # Hash-partition on the id into num_parts partitions.
        # NOTE(review): hash partitioning does not guarantee one id per
        # partition; some partitions may receive two ids and others none —
        # check partition sizes if exact balance matters.
        df = df.repartition(num_parts, F.col("_partition_id"))
        print(f'Repartitioning to <<<< {num_parts} >>>> workers - partitions.')

        if not preserve_partition_id:
            df = df.drop("_partition_id")
        return df

    def run_preprocessing(self, df: DataFrame, min_max) -> DataFrame:
        """
        Run all preprocessing steps in order:
         1. Drop rows that are completely null.
         2. Repartition: stratified for the local model, plain for the global model.
         3. Normalize the feature columns.

        Args:
            df (DataFrame): input PySpark SQL DataFrame.
            min_max: per-feature ``(min, max)`` mapping passed to ``normalize``.

        Returns:
            DataFrame: preprocessed DataFrame ready for training.

        Raises:
            ValueError: if neither ``test_local_model`` nor ``test_global_model`` is True.
        """
        df = self.handle_missing_values(df)

        # Note: the config key is spelled "reserve_partition_id".
        if self.config["local_model_config"]["test_local_model"] is True:
            df = self._repartition_data_Balanced(df, preserve_partition_id=self.config["reserve_partition_id"])
        elif self.config["global_model_config"]["test_global_model"] is True:
            df = self._repartition_data_NotBalanced(df, preserve_partition_id=self.config["reserve_partition_id"])
        else:
            raise ValueError("Preprocessing error.")

        df = self.normalize(df, min_max, preserve_partition_id=self.config["reserve_partition_id"])

        return df


In [0]:

# Module-level logger for the helper below. In a notebook ``__name__`` is usually
# "__main__", so other cells calling ``logging.getLogger(__name__)`` may get this
# same logger. Attach a handler only when none is present, so re-running this
# cell does not print every message several times.
logger = logging.getLogger(__name__)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)


def predict_with_global_prox_tree(global_tree_model, data_df: DataFrame) -> DataFrame:
    """Run ``global_tree_model.predict`` and adapt its output for the Evaluator.

    Args:
        global_tree_model: a GlobalProxTree-like object. It must expose a callable
            ``predict`` and a ``spark`` attribute (duck-typed check below).
        data_df: Spark DataFrame passed unchanged to ``predict``.

    Returns:
        The DataFrame returned by ``predict``. ``true_label`` is renamed to
        ``label`` when ``label`` is not already present. A missing ``prediction``
        column is logged as an error but not raised.

    Raises:
        TypeError: if ``global_tree_model`` fails the duck-typed check.
    """
    logger.info("Starting external prediction with GlobalProxTree.")

    # Duck-typed validation: anything with a callable predict() and a spark attribute passes.
    if not (hasattr(global_tree_model, 'predict') and callable(global_tree_model.predict) and hasattr(global_tree_model, 'spark')):
        model_type_name = type(global_tree_model).__name__ if global_tree_model is not None else "None"
        logger.error(f"Invalid model type provided. Expected GlobalProxTree-like object, got {model_type_name}.")
        raise TypeError("Invalid model type provided to predict_with_global_prox_tree. Expected GlobalProxTree-like object.")

    predictions_df = global_tree_model.predict(data_df)

    # GlobalProxTree.predict is expected to return 'true_label'; the Evaluator expects 'label'.
    if "true_label" in predictions_df.columns and "label" not in predictions_df.columns:
        logger.debug("Renaming 'true_label' to 'label' in predictions DataFrame for evaluation compatibility.")
        predictions_df = predictions_df.withColumnRenamed("true_label", "label")

    if "prediction" not in predictions_df.columns:
        # Logged rather than raised; downstream evaluation is left to handle it.
        logger.error("GlobalProxTree.predict did not return a 'prediction' column.")

    logger.info("Finished external prediction with GlobalProxTree.")

    return predictions_df


class PredictionManager:
    """Score a Spark DataFrame with a trained ProximityForest.

    The model is broadcast to the executors and applied row by row inside a
    pandas UDF that returns one double per row.
    """

    # Columns that are never model inputs: the target and the stratification
    # helper column that preprocessing may keep.
    _NON_FEATURE_COLS = ("label", "_partition_id")

    def __init__(self, spark, ensemble: ProximityForest):
        """
        Initialize with a trained ProximityForest model.

        Args:
            spark: Spark session (its ``sparkContext`` is used for broadcasting).
            ensemble: trained model, e.g. from LocalModelManager.train_ensemble().

        Raises:
            ValueError: if ``ensemble`` is missing or not marked as fitted.
        """
        self.spark = spark
        self.ensemble = ensemble
        self.logger = logging.getLogger(self.__class__.__name__)
        # All instances share this named logger; attach the handler only once.
        if not self.logger.handlers:
            self.logger.addHandler(logging.StreamHandler())
        self.logger.setLevel(logging.ERROR)

        # Basic model validation
        if not ensemble or not self._is_marked_fitted(ensemble):
            raise ValueError("Model is not trained. First call LocalModelManager.train_ensemble()")

    @staticmethod
    def _is_marked_fitted(model) -> bool:
        """Return True when the model reports itself as fitted.

        LocalModelManager sets both ``is_fitted_`` and ``is_fitted`` on the
        forest it assembles; a model fitted another way may expose only one of
        them. Prefer ``is_fitted_`` when present, otherwise use ``is_fitted``.
        """
        flag = getattr(model, "is_fitted_", None)
        if flag is None:
            flag = getattr(model, "is_fitted", False)
        return bool(flag)

    def _create_predict_udf(self):
        """Create a pandas UDF that predicts one label per feature array.

        The ensemble is broadcast once per call. Each row's array is reshaped
        to ``(1, 1, n_features)`` (one case, one channel) before ``predict``.
        A per-row failure is printed (on the executor) and mapped to the
        sentinel ``-999.0`` rather than failing the Spark task.
        """
        broadcast_model = self.spark.sparkContext.broadcast(self.ensemble)

        @pandas_udf(DoubleType())
        def predict_udf(features: pd.Series) -> pd.Series:
            """Converts features to AEON format and makes predictions."""
            def predict_single(feature_array):
                try:
                    # Reshape to AEON's expected format: (samples, channels, features)
                    X = np.ascontiguousarray(feature_array).reshape(1, 1, -1)
                    return float(broadcast_model.value.predict(X)[0])
                except Exception as e:
                    print(f"Prediction error: {e}")
                    return float(-999)

            return features.apply(predict_single)

        return predict_udf

    def generate_predictions_local(self, test_df: DataFrame) -> DataFrame:
        """
        Add a ``prediction`` column to ``test_df``.

        Every column except ``label`` and ``_partition_id`` is cast to double
        and packed (in column order) into a temporary ``features`` array that
        is dropped again after scoring.

        Raises:
            ValueError: if no model is available.
        """
        if not self.ensemble:
            raise ValueError("No models available for prediction")

        feature_cols = [c for c in test_df.columns if c not in self._NON_FEATURE_COLS]

        test_df = test_df.withColumn(
            "features",
            F.array(*[F.col(c).cast("double") for c in feature_cols])
        )

        predict_udf = self._create_predict_udf()

        predictions_df = test_df.withColumn(
            "prediction",
            predict_udf("features")
        ).drop("features")
        return predictions_df
    
    

This section trains the local models.
It takes the preprocessed Spark DataFrame, whose partitions define the chunks of data.
For each partition, it trains one Proximity Tree model.
It then combines all the trees into a single Proximity Forest ensemble that is used later to make predictions.

NOTE: Models are trained in parallel across the Spark worker nodes. Number of trees = number of partitions (empty partitions, and partitions whose training fails, are skipped).

In [0]:
class LocalModelManager:
    """
    Train local Proximity Trees (one per Spark partition) and combine them
    into a single Proximity Forest ensemble.

    Steps:
      1. Receive a preprocessed Spark DataFrame (feature columns plus ``label``).
      2. Train one ``ProximityTree`` per partition via ``mapPartitions``.
      3. Collect the pickled trees on the driver.
      4. Wrap them in a ``ProximityForest`` whose fitted state is set by hand
         (the forest's own ``fit`` is never called).
    """

    def __init__(self, config: dict):
        """
        Init with our settings.

        Args:
            config (dict): settings read by this class:
              - tree_params: keyword arguments for every ``ProximityTree``.
              - forest_params: keyword arguments for the ``ProximityForest``
                wrapper (its ``n_jobs`` is also copied to ``_n_jobs``).
              The number of trees follows the DataFrame's partitioning.
        """
        self.config = config

        # Trained trees, filled by train_ensemble().
        self.trees = []

        # Final ensemble model.
        self.ensemble = None

        # Logger named after the class (as PredictionManager does). In a notebook
        # ``__name__`` is "__main__", so getLogger(__name__) would return a logger
        # other cells configure, and setting ERROR here would silence their INFO
        # output. Attach the handler once so several managers don't duplicate lines.
        self.logger = logging.getLogger(self.__class__.__name__)
        if not self.logger.handlers:
            self.logger.addHandler(logging.StreamHandler())
        self.logger.setLevel(logging.ERROR)

    def _set_forest_classes(self):
        """Collect all class labels from the individual trees and mark the forest as fitted.

        Sets ``classes_``, ``n_classes_``, ``_class_dictionary`` (class -> index),
        ``_n_classes``, optionally ``_n_jobs``, and both ``is_fitted_`` and
        ``is_fitted`` on ``self.ensemble``.
        """
        all_classes = []
        for tree in self.trees:
            if hasattr(tree, "classes_"):
                all_classes.extend(tree.classes_)

        unique_classes = np.unique(all_classes)
        self.ensemble.classes_ = unique_classes
        self.ensemble.n_classes_ = len(unique_classes)

        # AEON's BaseClassifier presumably expects a '_class_dictionary' mapping class -> int.
        self.ensemble._class_dictionary = {
            cls: idx for idx, cls in enumerate(unique_classes)
        }

        # Some AEON versions may read the number of classes from this private attribute.
        self.ensemble._n_classes = len(unique_classes)

        # If n_jobs is used, set it explicitly here.
        if "n_jobs" in self.config["forest_params"]:
            self.ensemble._n_jobs = self.config["forest_params"]["n_jobs"]

        # fit() is never called on the forest, so set the fitted flags by hand;
        # both spellings are set so that either check succeeds.
        self.ensemble.is_fitted_ = True
        self.ensemble.is_fitted = True

    def get_ensemble(self) -> ProximityForest:
        """
        Return the trained Proximity Forest ensemble (None before training).
        """
        return self.ensemble

    def print_ensemble_details(self):
        """
        Print the details of the aggregated Proximity Forest ensemble.
        """
        if self.ensemble and hasattr(self.ensemble, 'trees_'):
            num_trees = len(self.ensemble.trees_)
            print(f"Aggregated Proximity Forest (contains {num_trees} trees):")
            print(f"  Number of trees (in trees_ attribute): {num_trees}")
            print(f"  Forest Parameters: {self.ensemble.get_params()}")
            for i, tree in enumerate(self.ensemble.trees_):
                print(f"  Tree {i+1} Details:")
                self._print_tree_node_info(tree.root, depth=2)
            print("-" * 20)
        else:
            print("Proximity Forest ensemble has not been trained yet or the 'trees_' attribute is missing.")

    def _print_tree_node_info(self, node, depth):
        """Recursively print one tree node and its children.

        NOTE(review): relies on AEON ProximityTree node internals (node_id,
        _is_leaf, label, class_distribution, splitter, children) and on
        ``splitter`` being ``(exemplars_dict, {distance: params})`` — confirm
        against the installed AEON version.
        """
        indent = "  " * depth
        print(f"{indent}Node ID: {node.node_id}, Leaf: {node._is_leaf}")

        if node._is_leaf:
            print(f"{indent}  Label: {node.label}, Class Distribution: {node.class_distribution}")
        else:
            splitter = node.splitter
            if splitter:
                exemplars = splitter[0]
                distance_info = splitter[1]
                distance_measure = list(distance_info.keys())[0]
                distance_params = distance_info[distance_measure]

                print(f"{indent}  Splitter:")
                print(f"{indent}    Distance Measure: {distance_measure}, Parameters: {distance_params}")
                print(f"{indent}    Exemplar Classes: {list(exemplars.keys())}")

                print(f"{indent}  Children:")
                for label, child_node in node.children.items():
                    print(f"{indent}    Branch on exemplar of class '{label}':")
                    self._print_tree_node_info(child_node, depth + 1)

    @staticmethod
    def _partition_to_xy(rows):
        """Convert an iterable of Spark Rows into AEON input ``(X_3d, y)``.

        ``X_3d`` has shape ``(n_rows, 1, n_features)`` built from every column
        except ``label`` and the ``_partition_id`` helper column (which must
        never be treated as a feature). Returns None for an empty partition.
        """
        pandas_df = pd.DataFrame([row.asDict() for row in rows])
        if pandas_df.empty:
            return None
        features = pandas_df.drop(columns=["label", "_partition_id"], errors="ignore")
        X = np.ascontiguousarray(features.values)
        X_3d = X.reshape((X.shape[0], 1, X.shape[1]))
        y = pandas_df["label"].values
        return X_3d, y

    def train_ensemble(self, df: DataFrame) -> ProximityForest:
        """
        Train a forest in three steps:
          1. ``mapPartitions`` turns each Spark partition into (X, y).
          2. One ``ProximityTree`` is fitted per non-empty partition and shipped
             back pickled; partitions that are empty or fail are skipped.
          3. The unpickled trees are wrapped in a ``ProximityForest``.

        Returns:
            The assembled ``ProximityForest`` (also stored in ``self.ensemble``),
            or None if no tree could be trained.
        """
        tree_params = self.config["tree_params"]
        # Bind the static helper to a local name so the closure shipped to the
        # executors does not capture ``self``.
        partition_to_xy = self._partition_to_xy

        def process_partition(partition_data):
            """Train one tree on one partition; returns a list of 0 or 1 pickled trees."""
            try:
                prepared = partition_to_xy(partition_data)
                if prepared is None:
                    return []
                X_3d, y = prepared

                tree = ProximityTree(**tree_params)
                tree.fit(X_3d, y)
                return [pickle.dumps(tree)]

            except Exception as e:
                # Best effort: a failing partition costs one tree, not the whole job.
                print(f"Failed to train tree on partition: {str(e)}")
                return []

        serialized_trees = df.rdd.mapPartitions(process_partition).collect()
        # These pickles were produced by process_partition on our own executors,
        # not read from an external source.
        self.trees = [pickle.loads(b) for b in serialized_trees if b is not None]

        if not self.trees:
            print("Warning: No trees were trained!")
            return None

        self.ensemble = ProximityForest(
            n_trees=len(self.trees),
            **self.config["forest_params"]
        )
        # Attach the trees and set the fitted state manually.
        self.ensemble.trees_ = self.trees
        self._set_forest_classes()
        return self.ensemble


In [0]:
# global_model_manager.py (Optimized Version - Gini Fix)
"""
Implements the GlobalModelManager class for training a distribution-friendly 
proximity tree using Spark DataFrames. Includes optimizations for UDFs, 
exemplar sampling, count reduction, and reproducibility.
"""

from __future__ import annotations

import collections
import json
import logging
import math
import os
import pickle
import random # Import random for seeding
import sys
import time
from typing import Any, Dict, List, Tuple # Adjusted typing imports

import numpy as np
import pandas as pd
from pyspark import StorageLevel # Import StorageLevel
from pyspark.sql import DataFrame, Row, SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf # Import pandas_udf
from pyspark.sql.types import (ArrayType, DoubleType, IntegerType, LongType,
                               StructField, StructType)

# Logger for this module. Handlers are attached only the first time this cell
# runs (re-running would otherwise print every record twice). Output goes to
# stdout, and propagation is disabled so ancestor loggers do not echo records.
logger_gmm = logging.getLogger("GlobalModelManager")
if len(logger_gmm.handlers) == 0:
    formatter_gmm = logging.Formatter('%(asctime)s - GMM - %(levelname)s - %(message)s')
    handler_gmm = logging.StreamHandler(sys.stdout)
    handler_gmm.setFormatter(formatter_gmm)
    logger_gmm.addHandler(handler_gmm)
    # Respect a level configured elsewhere; default to INFO otherwise.
    if logger_gmm.level == logging.NOTSET:
        logger_gmm.setLevel(logging.INFO)
    logger_gmm.propagate = False

# Keep the Spark client libraries quiet unless something is actually wrong.
logging.getLogger("py4j").setLevel(logging.ERROR)
logging.getLogger("pyspark").setLevel(logging.ERROR)


# Flag read by _euclid_gmm: True selects its NumPy path, False the pure-Python one.
try:
    import numpy as np
except ImportError:
    _NP_GMM = False
else:
    _NP_GMM = True  # distinct name to avoid clashing with other cells' flags

# Immutable record for one node of the global proximity tree; GlobalModelManager
# keeps its tree as a dict {node_id: TreeNode}.
TreeNode = collections.namedtuple(
    "TreeNode",
    ("node_id", "parent_id", "split_on", "is_leaf", "prediction", "children"),
)

# Euclidean distance helper used by the routing/split UDFs.
def _euclid_gmm(a, b):
    """Euclidean distance between two equal-length sequences.

    Accepts plain Python sequences or NumPy arrays and never raises from the
    arithmetic: returns ``float("inf")`` when either input is None, when an
    input has no length or the lengths differ, or when the computation fails
    (e.g. non-numeric elements). Uses NumPy when ``_NP_GMM`` is true, otherwise
    a pure-Python sum of squares.
    """
    if a is None or b is None:
        return float("inf")

    size_a = len(a) if hasattr(a, '__len__') else -1
    size_b = len(b) if hasattr(b, '__len__') else -1
    if size_a == -1 or size_a != size_b:
        return float("inf")

    if not _NP_GMM:
        try:
            squared = sum((float(x) - float(y)) ** 2 for x, y in zip(a, b))
            return float(math.sqrt(squared))
        except Exception:
            return float("inf")

    try:
        delta = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
        return float(np.sqrt(np.dot(delta, delta)))
    except Exception:
        return float("inf")

# =============================================================================
# GlobalModelManager class (Optimised + Enhanced Prediction)
# =============================================================================

class GlobalModelManager:
    """
    Distribution-friendly proximity-tree learner using Spark DataFrames.
    
    Optimizations Included:
    P1: Pandas UDFs for routing and prediction.
    P1: Reduced redundant .count() actions in fit loop.
    P2: Distributed exemplar sampling using Window functions.
    P3: Seeded RNG for reproducibility.
    P2: Uses MEMORY_AND_DISK caching for intermediate DataFrames.
    """
    def __init__(self, spark: SparkSession, config: Dict[str, Any]):
        self.logger = logging.getLogger("GlobalModelManager") 
        self.logger.info("Initializing GlobalModelManager.")
        if "tree_params" not in config: raise ValueError("Config must contain 'tree_params'.")
        
        p = config["tree_params"]
        self.spark = spark
        self.max_depth: int | None = p.get("max_depth") 
        self.min_samples: int = p.get("min_samples_split", 2) 
        self.k: int = p.get("n_splitters", 5) 
        # P3: Add random_state for reproducibility
        self.random_state: int | None = p.get("random_state") 
        #self._rng = random.Random(self.random_state) # Initialize RNG instance with seed
        
        self.tree: Dict[int, TreeNode] = {0: TreeNode(0, None, None, False, None, {})}
        self._next_id: int = 1
        self._maj: int = 1 # Default majority class
        self.logger.info(f"Initialized with max_depth={self.max_depth}, min_samples={self.min_samples}, k={self.k}, seed={self.random_state}")

    def _to_ts_df(self, df):
        """
        Ensure DataFrame has (row_id, time_series[, true_label]).

        * ``row_id``: added with ``F.monotonically_increasing_id()`` when missing,
          otherwise cast to ``LongType`` if it is not already.
        * ``time_series``: reused when already present; otherwise built as an
          array of every remaining feature column, in column order.
        * ``true_label``: taken from ``label`` (preferred) or ``true_label``; on
          the array-building path it is cast to ``IntegerType``.

        Bookkeeping columns (``row_id``, ``label``, ``true_label`` and the
        ``_partition_id`` helper that preprocessing can keep) are never packed
        into ``time_series``.

        Raises:
            ValueError: if no feature columns remain to build ``time_series``.
        """
        self.logger.debug("Converting DataFrame to time series format.")
        if "label" in df.columns:
            lbl = "label"
        elif "true_label" in df.columns:
            lbl = "true_label"
        else:
            lbl = None
            self.logger.warning("No 'label' or 'true_label' column found.")

        # Add row_id if it doesn't exist; otherwise make sure it is LongType.
        if "row_id" not in df.columns:
            self.logger.debug("Adding row_id column.")
            df = df.withColumn("row_id", F.monotonically_increasing_id())
        elif df.schema["row_id"].dataType != LongType():
            self.logger.debug("Casting existing row_id to LongType.")
            df = df.withColumn("row_id", F.col("row_id").cast(LongType()))

        # time_series already present: only normalise the label column name.
        if "time_series" in df.columns:
            self.logger.debug("'time_series' column already exists.")
            if lbl == "label":
                df = df.withColumnRenamed("label", "true_label")
            select_cols = ["row_id", "time_series"]
            if "true_label" in df.columns:
                select_cols.append("true_label")
            return df.select(*select_cols)

        # Build time_series from the feature columns only. Excluding both label
        # spellings and _partition_id keeps them from being packed in as features.
        non_feature_cols = {"row_id", "label", "true_label", "_partition_id"}
        feat_cols = [c for c in df.columns if c not in non_feature_cols]
        if not feat_cols:
            raise ValueError("No feature columns found to create 'time_series'.")
        self.logger.debug(f"Creating 'time_series' from feature columns: {feat_cols}")

        select_exprs = [F.col("row_id"), F.array(*[F.col(c) for c in feat_cols]).alias("time_series")]
        if lbl:
            select_exprs.append(F.col(lbl).cast(IntegerType()).alias("true_label"))

        return df.select(*select_exprs)

    @staticmethod
    def _gini(counts: Dict[int, int]) -> float:
        """ Calculates the Gini impurity for a dictionary of class counts. """
        tot = sum(counts.values())
        if tot == 0: return 0.0
        return 1.0 - sum((c / tot) ** 2 for c in counts.values())

    def fit(self, df): 
        """Train the proximity tree."""
        self.logger.info("Starting GlobalProxTree fitting process.")
        # P2: Use MEMORY_AND_DISK caching
        ts_df = self._to_ts_df(df).persist(StorageLevel.MEMORY_AND_DISK) 
        initial_row_count = ts_df.count() # P1: Necessary count here
        self.logger.info(f"Input data converted and cached. Row count: {initial_row_count}")

        if initial_row_count == 0: 
             self.logger.warning("Input DataFrame is empty. Cannot train tree.")
             ts_df.unpersist(); return self

        # Determine majority class
        try: 
            maj_row = ts_df.groupBy("true_label").count().orderBy(F.desc("count")).first()
            if maj_row: self._maj = maj_row["true_label"] 
            self.logger.info(f"Overall majority class: {self._maj}")
        except Exception as e: self.logger.error(f"Error calculating majority class: {e}. Using default: {self._maj}")

        # Initialize assignment DataFrame
        assign = ts_df.select("row_id", "time_series", "true_label") \
                      .withColumn("node_id", F.lit(0)) \
                      .persist(StorageLevel.MEMORY_AND_DISK) # P2: Use MEMORY_AND_DISK
        assign_count = assign.count(); self.logger.info(f"Initial assignment created. Rows: {assign_count}") # P1: Necessary count
        ts_df.unpersist() 

        # --- Tree Building Loop ---
        open_nodes = {0} 
        depth = 0
        while open_nodes and (self.max_depth is None or depth < self.max_depth):
            self.logger.info(f"--- Starting Tree Level {depth} ---")
            self.logger.debug(f"Open nodes: {open_nodes}")

            # Filter data for current level nodes
            cur = assign.filter(F.col("node_id").isin(list(open_nodes))) \
                        .persist(StorageLevel.MEMORY_AND_DISK) # P2: Use MEMORY_AND_DISK
            
            # --- P1: Calculate node stats ONCE per level ---
            self.logger.debug("Calculating statistics for current level nodes...")
            node_stats_df = cur.groupBy("node_id", "true_label").count()
            node_stats_rows = node_stats_df.collect() # Collect stats (expect relatively small)
            stats_per_node = collections.defaultdict(dict)
            totals_per_node = collections.defaultdict(int)
            for r in node_stats_rows:
                 stats_per_node[r.node_id][r.true_label] = r["count"]
                 totals_per_node[r.node_id] += r["count"]
            self.logger.debug(f"Calculated stats for {len(totals_per_node)} nodes.")
            # --- End P1 ---

            # Check if any data remains for open nodes (using precalculated totals)
            if not any(totals_per_node.get(nid, 0) > 0 for nid in open_nodes):
                 self.logger.info(f"No data for open nodes at depth {depth}. Stopping tree growth.")
                 cur.unpersist(); break

            # --- P2: Distributed Exemplar Sampling ---
            self.logger.debug("Starting distributed exemplar sampling...")
            # P3: Seed rand with instance RNG state (requires converting int to seed)
            window_spec = Window.partitionBy("node_id", "true_label").orderBy(F.rand()) # REMOVED SEED 
#            window_spec = Window.partitionBy("node_id", "true_label").orderBy(F.rand(self._rng.randint(0, 1000000))) 
            sampled_exemplars_df = cur.withColumn("rank", F.row_number().over(window_spec)) \
                                      .filter(F.col("rank") <= self.k) \
                                      .select("node_id", "true_label", "time_series") 
            collected_exemplars = sampled_exemplars_df.collect() # Collect ONLY the k*nodes*labels samples
            pool: Dict[int, Dict[int, list]] = collections.defaultdict(dict)
            for row in collected_exemplars: pool[row.node_id].setdefault(row.true_label, []).append(row.time_series) 
            self.logger.debug(f"Finished exemplar sampling. Nodes with pools: {list(pool.keys())}")
            # --- End P2 ---

            best_splits: Dict[int, Tuple[str, Dict[int, list]]] = {} 
            nodes_to_make_leaf: set[int] = set()

            # --- Evaluate splits (Driver-side logic using precalculated stats) ---
            self.logger.debug("Evaluating splits...")
            for nid in list(open_nodes): 
                self.logger.debug(f"Evaluating node {nid}...")
                stats = stats_per_node.get(nid, {})
                tot_samples_in_node = totals_per_node.get(nid, 0) # P1: Reuse count
                self.logger.debug(f"Node {nid} stats: {stats}, total samples: {tot_samples_in_node}")

                # Leaf conditions (using reused count)
                is_leaf = False
                if tot_samples_in_node < self.min_samples: is_leaf = True; reason="min_samples"
                elif len(stats) <= 1: is_leaf = True; reason="pure"
                elif nid not in pool or not pool[nid] or len(pool[nid]) < 2: is_leaf = True; reason="exemplars"
                
                if is_leaf: self.logger.info(f"Node {nid} becoming leaf: {reason}."); nodes_to_make_leaf.add(nid); continue

                # Find best split for non-leaf node
                parent_gini = self._gini(stats)
                best_gain = -1.0; best_exemplars_for_split = None
                node_pool = pool[nid]; available_labels = list(node_pool.keys())

                self.logger.debug(f"Node {nid}: Evaluating {self.k} candidates. Parent Gini: {parent_gini:.4f}")
                for k_idx in range(self.k):
                    candidate_ex = {}
                    possible = True
                    for lbl in available_labels:
                        if node_pool[lbl]: 
                            #candidate_ex[lbl] = self._rng.choice(node_pool[lbl]) # P3: Use seeded RNG
                            candidate_ex[lbl] = random.choice(node_pool[lbl]) 
                        else: possible = False; break 
                    if not possible or len(candidate_ex) < 2: continue 

                    bc_ex = self.spark.sparkContext.broadcast(candidate_ex)
                    
                    # Standard UDF for Gini calculation step (Pandas UDF less obvious benefit here)
                    @F.udf(IntegerType())
                    def nearest_lbl_udf_local(ts):
                        ex_val = bc_ex.value; best_d, best_l = float("inf"), None
                        for l, ex_ts in ex_val.items():
                            d = _euclid_gmm(ts, ex_ts); 
                            if d < best_d: best_d, best_l = d, l
                        return best_l

                    # Filter data for the current node *before* applying UDF
                    node_data_df = cur.filter(F.col("node_id") == nid) 
                    
                    # Calculate weighted Gini impurity (DataFrame based)
                    split_impurity_df = node_data_df.withColumn("branch", nearest_lbl_udf_local("time_series")) \
                                               .groupBy("branch", "true_label").count()
                    branch_totals = split_impurity_df.groupBy("branch").agg(F.sum("count").alias("branch_total"))
                    gini_per_branch = split_impurity_df.join(branch_totals, "branch") \
                                             .withColumn("prob_sq", (F.col("count") / F.col("branch_total")) ** 2) \
                                             .groupBy("branch", "branch_total").agg(F.sum("prob_sq").alias("s")) \
                                             .withColumn("branch_gini", 1.0 - F.col("s")) # <-- Corrected: sum("prob_sq")
                    weighted_gini_row = gini_per_branch.withColumn("weighted_gini", (F.col("branch_total") / tot_samples_in_node) * F.col("branch_gini")) \
                                                  .agg(F.sum("weighted_gini").alias("total_weighted_gini")) \
                                                  .first()
                    bc_ex.unpersist(False) 

                    if weighted_gini_row and weighted_gini_row["total_weighted_gini"] is not None:
                        current_impurity = weighted_gini_row["total_weighted_gini"]
                        current_gain = parent_gini - current_impurity
                        self.logger.debug(f"Node {nid}, Candidate {k_idx+1}: Impurity={current_impurity:.4f}, Gain={current_gain:.4f}")
                        if current_gain > best_gain:
                            best_gain = current_gain; best_exemplars_for_split = candidate_ex
                            self.logger.debug(f"Node {nid}: New best split found (Gain: {best_gain:.4f})")
                    else: self.logger.warning(f"Node {nid}, Candidate {k_idx+1}: Could not calculate impurity.")

                # Decide if node becomes leaf
                if best_gain <= 1e-9: 
                    self.logger.info(f"Node {nid} becoming leaf: best gain ({best_gain:.4f}) too low.")
                    nodes_to_make_leaf.add(nid)
                else:
                    self.logger.info(f"Node {nid}: Selected best split with gain {best_gain:.4f}.")
                    best_splits[nid] = ("euclidean", best_exemplars_for_split)

            # --- Finalize leaves ---
            self.logger.debug(f"Nodes to finalize as leaves: {nodes_to_make_leaf}")
            for nid in list(nodes_to_make_leaf): 
                if nid in open_nodes: 
                    stats = stats_per_node.get(nid, {}) # P1: Reuse stats
                    maj_lbl = self._maj 
                    if stats: maj_lbl = max(stats.items(), key=lambda kv: (kv[1], -kv[0]))[0] 
                    self.tree[nid] = self.tree[nid]._replace(is_leaf=True, prediction=maj_lbl, children={}, split_on=None)
                    self.logger.info(f"Node {nid} finalized as leaf. Prediction: {maj_lbl}.")
                    open_nodes.remove(nid) 

            # --- Create children and update assignments ---
            if not best_splits: self.logger.info("No nodes split."); cur.unpersist(); break 

            self.logger.debug("Creating child nodes...")
            split_map = {}; new_open_nodes_for_next_level = set()
            for pid, (measure, exemplars) in best_splits.items():
                child_dict = {}
                for branch_label in exemplars: 
                    cid = self._next_id; self._next_id += 1
                    self.tree[cid] = TreeNode(cid, pid, None, False, None, {}) 
                    child_dict[branch_label] = cid
                    split_map[(pid, branch_label)] = cid
                    new_open_nodes_for_next_level.add(cid)
                self.tree[pid] = self.tree[pid]._replace(split_on=(measure, exemplars), children=child_dict, is_leaf=False)
                self.logger.debug(f"Parent node {pid} updated. Children: {list(child_dict.values())}")

            open_nodes = new_open_nodes_for_next_level
            self.logger.debug(f"New open_nodes for next level: {open_nodes}")

            # --- P1: Use Pandas UDF for routing ---
            bc_split_map = self.spark.sparkContext.broadcast(split_map)
            bc_best_exemplars = self.spark.sparkContext.broadcast({pid: ex for pid, (_, ex) in best_splits.items()})
            self.logger.debug("Broadcasted split map and best exemplars for routing.")
            _euclid_gmm_local_route = _euclid_gmm # Local ref for UDF

            @F.pandas_udf(IntegerType())
            def route_pandas_udf(pid_series: pd.Series, ts_series: pd.Series) -> pd.Series:
                split_map_val = bc_split_map.value; exs_map_val = bc_best_exemplars.value
                results = []
                for pid, ts in zip(pid_series, ts_series):
                    if pid not in exs_map_val: results.append(pid); continue 
                    split_exemplars = exs_map_val[pid]
                    best_d, best_lbl = float("inf"), None
                    for lbl, ex_ts in split_exemplars.items():
                        d = _euclid_gmm_local_route(ts, ex_ts) 
                        if d < best_d: best_d, best_lbl = d, lbl
                    results.append(split_map_val.get((pid, best_lbl), pid))
                return pd.Series(results, dtype=pd.Int64Dtype()) # Use nullable Int

            # Apply the route UDF
            old_assign = assign 
            self.logger.info("Applying route_pandas_udf to update assignments...")
            assign = assign.withColumn("node_id", route_pandas_udf("node_id", "time_series")) \
                           .persist(StorageLevel.MEMORY_AND_DISK) # P2: Use MEMORY_AND_DISK
            assign_updated_count = assign.count() # P1: Necessary action
            self.logger.info(f"Assignment DataFrame updated. Rows: {assign_updated_count}")

            # Unpersist intermediates
            old_assign.unpersist()
            bc_split_map.unpersist(blocking=False)
            bc_best_exemplars.unpersist(blocking=False)
            cur.unpersist() 
            self.logger.debug(f"Unpersisted intermediates for depth {depth}.")

            depth += 1 
        # --- End of Tree Building Loop ---
        self.logger.info(f"Tree building loop finished at depth {depth}.")

        # --- Final Dangling Node Check --- 
        self.logger.debug("Performing final check for dangling internal nodes.")
        nodes_to_finalize = [nid for nid, nd in self.tree.items() if not nd.is_leaf and not nd.children]
        if nodes_to_finalize:
             self.logger.warning(f"Found {len(nodes_to_finalize)} dangling nodes: {nodes_to_finalize}")
             # Filter final assignment DF for these nodes
             dangling_df = assign.filter(F.col("node_id").isin(nodes_to_finalize))
             dangling_stats_rows = dangling_df.groupBy("node_id", "true_label").count().collect()
             stats_by_node = collections.defaultdict(dict)
             for r in dangling_stats_rows: stats_by_node[r["node_id"]][r["true_label"]] = r["count"]
             for nid in nodes_to_finalize:
                 stats = stats_by_node.get(nid, {}); maj_lbl = self._maj 
                 if stats: maj_lbl = max(stats.items(), key=lambda kv: (kv[1], -kv[0]))[0]
                 self.tree[nid] = self.tree[nid]._replace(is_leaf=True, prediction=maj_lbl, split_on=None)
                 self.logger.info(f"Dangling node {nid} finalized as leaf. Prediction: {maj_lbl}.")
        
        assign.unpersist()
        self.logger.info("GlobalProxTree fitting process finished.")
        return self

    # --- P1: Prediction uses Pandas UDF ---
    def predict(self, df):
        """ Predicts class labels using Pandas UDF for traversal.

        Every row's ``time_series`` is routed from the root (node 0): at each
        internal node it moves to the child whose split exemplar is nearest
        under ``_euclid_gmm``, until a leaf is reached. Rows that end without
        a prediction are filled with ``self._maj``.

        Args:
            df: Input data, normalised with ``self._to_ts_df``. The result must
                contain ``row_id`` and ``time_series`` and may contain
                ``true_label``, because those columns are selected below.

        Returns:
            Spark DataFrame with ``row_id``, ``time_series``, ``true_label``
            (only if present) and an integer ``prediction`` column. If the tree
            is unfitted/empty, every row is predicted as ``self._maj``.
        """
        self.logger.info("Starting GlobalProxTree prediction (using Pandas UDF).")
        df_ts = self._to_ts_df(df) # Ensure correct format

        # Unfitted/empty tree: no root, or a root with neither children nor a prediction.
        if not self.tree or 0 not in self.tree or (not self.tree[0].children and self.tree[0].prediction is None):
             self.logger.warning("Tree not fitted/empty. Returning default predictions.")
             default_pred = F.lit(self._maj).cast(IntegerType()).alias("prediction")
             sel_cols = ["row_id", "time_series"] + (["true_label"] if "true_label" in df_ts.columns else []) + [default_pred]
             return df_ts.select(*sel_cols)

        self.logger.debug("Converting tree to plain dict for broadcast.")
        # NamedTuple nodes -> plain dicts; the UDF reads node fields via dict .get().
        plain_tree = {nid: node._asdict() for nid, node in self.tree.items()}
        bc_tree = self.spark.sparkContext.broadcast(plain_tree)
        self.logger.debug(f"Broadcasted plain tree ({len(plain_tree)} nodes).")

        # Need _euclid_gmm available in the UDF scope
        _euclid_gmm_local_pred = _euclid_gmm # Create local ref for UDF

        @F.pandas_udf(IntegerType())
        def traverse_pandas_udf(ts_series: pd.Series) -> pd.Series:
            # Appends exactly one entry (a label or None) per input row, so the
            # returned Series lines up with ts_series.
            tree_dict_pd = bc_tree.value # Access broadcast value once per batch
            predictions = []
            
            for ts in ts_series: # Iterate through the Pandas Series
                if ts is None: predictions.append(None); continue
                
                node_id = 0
                # Cap on routing steps per row. NOTE(review): fixed at 50, independent of
                # self.max_depth; when the cap is hit the row gets the reached node's
                # prediction only if that node is a leaf, otherwise None.
                MAX_TRAVERSAL_DEPTH = 50; current_depth = 0
                
                while node_id in tree_dict_pd and current_depth < MAX_TRAVERSAL_DEPTH:
                    current_node = tree_dict_pd[node_id]
                    if current_node.get('is_leaf', False):
                        predictions.append(current_node.get('prediction')); break 
                    
                    split_info = current_node.get('split_on') 
                    children = current_node.get('children')
                    if not split_info or not children: predictions.append(current_node.get('prediction')); break # Fallback

                    # split_on is a (measure, {branch_label: exemplar_series}) pair; only the exemplars are used here.
                    _, exemplars = split_info 
                    if not exemplars: predictions.append(current_node.get('prediction')); break # Fallback

                    # Nearest exemplar over all branches of this split.
                    min_dist_all = float("inf"); best_branch_id_all = None 
                    for branch_id, exemplar_ts in exemplars.items():
                        d = _euclid_gmm_local_pred(ts, exemplar_ts) 
                        if d < min_dist_all: min_dist_all = d; best_branch_id_all = branch_id

                    if best_branch_id_all is not None and best_branch_id_all in children:
                        node_id = children[best_branch_id_all]
                    else: # Fallback to nearest existing child
                        # Search only branches that have both a child node and an exemplar.
                        min_dist_existing = float("inf"); next_node_id_found = None 
                        for ex_br_id, ex_ch_id in children.items():
                            if ex_br_id in exemplars: 
                                d = _euclid_gmm_local_pred(ts, exemplars[ex_br_id]) 
                                if d < min_dist_existing: min_dist_existing = d; next_node_id_found = ex_ch_id
                        if next_node_id_found is not None: node_id = next_node_id_found
                        else: predictions.append(current_node.get('prediction')); break # Ultimate fallback
                    
                    current_depth += 1
                else: # Handle while loop exit
                     # Runs only when the loop condition fails (no break): either the
                     # step cap was reached or node_id is not present in the tree.
                     if current_depth >= MAX_TRAVERSAL_DEPTH:
                          last_node = tree_dict_pd.get(node_id)
                          pred = last_node.get('prediction') if last_node and last_node.get('is_leaf') else None
                          predictions.append(pred)
                     else: predictions.append(None) 
                          
            return pd.Series(predictions, dtype=pd.Int64Dtype()) # Use nullable int

        self.logger.info("Applying prediction Pandas UDF...")
        # None results from the UDF are replaced with the majority label self._maj.
        out_df = df_ts.withColumn("pred_raw", traverse_pandas_udf("time_series")) \
                      .withColumn("prediction", F.coalesce(F.col("pred_raw"), F.lit(self._maj)).cast(IntegerType())) \
                      .drop("pred_raw")
        
        # NOTE(review): out_df is lazy, so the UDF has not run yet at this point.
        # Unpersisting only drops executor copies; per the PySpark Broadcast docs
        # the value is re-sent to executors when out_df is later evaluated.
        bc_tree.unpersist(blocking=False) 
        self.logger.debug("Unpersisted broadcasted tree.")

        # Select final output columns
        select_cols = ["row_id", "time_series"] + (["true_label"] if "true_label" in out_df.columns else []) + ["prediction"]
        return out_df.select(*select_cols)


    def print_tree(self) -> str:
        """Render the tree as an indented, human-readable string.

        Nodes are visited depth-first starting at the root (id 0), with each
        node's children taken in sorted order of their branch label:

        * a node id missing from ``self.tree`` renders as ``#<id> MISSING``;
        * a leaf renders as ``Leaf <id> → <prediction>``;
        * an internal node renders as ``Node <id> split=<measure> labels=[...]``,
          followed, for each child, by a ``├─ lbl=<label> → child <cid>`` line
          and then that child's subtree, indented two levels deeper.

        Returns:
            The rendered lines joined with newlines.
        """
        self.logger.debug("print_tree started.")

        def _walk(node_id: int, level: int):
            # Yield the rendered lines for the subtree rooted at node_id.
            pad = "  " * level
            node = self.tree.get(node_id)
            if node is None:
                yield f"{pad}#{node_id} MISSING"
                return
            if node.is_leaf:
                yield f"{pad}Leaf {node_id} → {node.prediction}"
                return
            measure, exemplars = node.split_on or (None, {})
            yield f"{pad}Node {node_id} split={measure} labels={list(exemplars.keys())}"
            for branch, child_id in sorted(node.children.items()):
                yield f"{pad}  ├─ lbl={branch} → child {child_id}"
                yield from _walk(child_id, level + 2)

        rendered = "\n".join(_walk(0, 0))
        self.logger.debug("print_tree finished.")
        return rendered


    def save_tree(self, path: str):
        """ Pickles the essential state of the manager to a file.

        The pickled dict holds ``max_depth``, ``min_samples``, ``k``, ``tree``,
        ``_next_id``, ``_maj`` and ``random_state``, which are the keys
        ``load_tree`` reads back.

        Args:
            path: Destination file. Missing parent directories are created when
                ``path`` has a directory component. A bare filename such as
                ``"tree.pkl"`` is written to the current working directory.

        Errors are logged with a traceback rather than raised, so a failed save
        does not abort the caller.
        """
        self.logger.info(f"Saving GlobalModelManager state to {path}")
        try:
            parent_dir = os.path.dirname(path)
            # os.makedirs("") raises FileNotFoundError, so only create a
            # directory when the path actually has one.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            state = {
                "max_depth": self.max_depth, "min_samples": self.min_samples,
                "k": self.k, "tree": self.tree,
                "_next_id": self._next_id, "_maj": self._maj,
                "random_state": self.random_state # P3: Save seed
            }
            with open(path, "wb") as fh:
                pickle.dump(state, fh)
            self.logger.info(f"Successfully saved state to {path}.")
        except Exception as e:
            self.logger.error(f"Failed to save tree state: {e}", exc_info=True)


    @classmethod
    def load_tree(cls, spark: SparkSession, path: str) -> "GlobalModelManager":
        """Rebuild a manager instance from a state pickle written by ``save_tree``.

        Args:
            spark: SparkSession passed to the new instance's constructor.
            path: File containing the pickled state dict.

        Returns:
            A new instance whose ``tree``, ``_next_id`` and ``_maj`` come from
            the file, falling back to defaults for missing keys. Its ``_rng`` is
            re-seeded from ``random_state``.

        Raises:
            Any exception from opening or unpickling ``path``; it is logged
            first and then re-raised.

        Security: ``pickle.load`` can execute arbitrary code, so only load
        files from a trusted source.
        """
        logger_gmm.info(f"Loading GlobalModelManager state from {path}") # Use class logger
        try:
            with open(path, "rb") as fh:
                state: Dict[str, Any] = pickle.load(fh)
            logger_gmm.debug("State loaded successfully.")
        except Exception as e:
            logger_gmm.error(f"Failed to load tree state: {e}", exc_info=True)
            raise

        # Map the persisted field names back onto the constructor's config schema.
        tree_params = dict(
            max_depth=state.get("max_depth"),
            min_samples_split=state.get("min_samples", 2),
            n_splitters=state.get("k", 5),
            random_state=state.get("random_state"),  # P3: Load seed
        )
        loaded_config = {"tree_params": tree_params}
        logger_gmm.debug(f"Reconstructed config: {loaded_config}")

        # Build a fresh instance, then overwrite its learned state.
        inst = cls(spark, loaded_config)
        fallback_tree = {0: TreeNode(0, None, None, False, None, {})}
        inst.tree = state.get("tree", fallback_tree)
        inst._next_id = state.get("_next_id", 1)
        inst._maj = state.get("_maj", 1)
        # The RNG object is not part of the saved state; re-seed it (P3).
        inst._rng = random.Random(inst.random_state)
        logger_gmm.info(f"Instance created. Tree size: {len(inst.tree)} nodes.")
        return inst



This cell defines `PipelineController_Loop`, which ties the whole pipeline together.
It iterates over a range of partition counts, running the 
full pipeline (ingestion, preprocessing, training, prediction, evaluation) 
in each iteration to assess how partitioning affects both the local and the global models.

NOTE: This structure re-runs ingestion and preprocessing in every 
iteration, which includes potentially expensive data shuffles. This is done 
intentionally for this experiment, but it is less efficient than 
preprocessing once outside the loop when only the local model's partition count is being varied.

In [0]:
class PipelineController_Loop:
    def __init__(self, config):
        """
        Initialize the controller with the pipeline configuration.
        """
        self.config = config
        self.spark = None # SparkSession managed once for the entire run
        self.ingestion_config = {} # Populated during _setup_spark
        
        # Logger setup
        self.logger = logging.getLogger(__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
            if self.logger.level == logging.NOTSET:
                 self.logger.setLevel(logging.INFO) 
            # Prevent duplicate messages if root logger also has handlers
            self.logger.propagate = True 

    def _setup_spark(self):
        """
        Setup or retrieve the Spark session based on the environment.
        Constructs appropriate data paths.
        Adds necessary Python files to Spark context for local runs.
        Returns True if setup is successful, False otherwise.

        Databricks (``DATABRICKS_RUNTIME_VERSION`` in the environment):
            Reuses the active session. ``ingestion_config`` takes
            ``databricks_data_path`` and ``data_percentage`` (default 0.05)
            from ``self.config``.
        Local:
            Stops any active session and builds a ``local[6]`` session. It then
            locates the project root, from ``__file__`` or, when ``__file__`` is
            undefined (e.g. in a notebook), from the CWD. ``local_data_path`` is
            resolved against that root, checked for existence and stored as a
            ``file://`` URI; ``data_percentage`` defaults to 1.0 here. Finally,
            ``global_model_manager.py`` from the ``src`` directory is shipped to
            executors via ``addPyFile``. Failures in that last step are logged
            but are not fatal.

        Side effects: sets ``self.spark`` and ``self.ingestion_config``. On a
        fatal failure the session is stopped and ``self.spark`` is reset to None.
        If ``self.spark`` is already set, nothing is done and True is returned.
        """
        # Configure Spark Session only if one doesn't exist 
        if self.spark is None:
            try:
                if "DATABRICKS_RUNTIME_VERSION" in os.environ:
                    # Databricks environment
                    self.spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
                    print("\nUsing Databricks Spark session.")
                    self.ingestion_config = {
                        "data_path": self.config.get("databricks_data_path", "/mnt/2025-team6/fulldataset_ECG5000.csv"),
                        "data_percentage": self.config.get("data_percentage", 0.05) 
                    }
                else:
                    # Local environment setup
                    # Stop existing local session if present before creating new one
                    existing_spark = SparkSession.getActiveSession()
                    if existing_spark:
                         self.logger.info("Stopping existing Spark session before starting new one.")
                         existing_spark.stop()

                    self.spark = SparkSession.builder \
                        .appName(f"LocalPipeline_Run_{time.time()}") \
                        .master("local[6]") \
                        .config("spark.driver.memory", "12g") \
                        .config("spark.executor.memory", "12g") \
                        .config("spark.driver.maxResultSize", "12g") \
                        .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
                        .getOrCreate()
                    print("\nUsing local Spark session.")

                    # --- Construct Local Data Path (Robustly) ---
                    # project_root anchors the data path; current_dir (expected to be
                    # <project_root>/code/src) is where the module files are looked up.
                    project_root = None
                    current_dir = None
                    try:
                         # __file__ is the path of the current script (controller_loop.py)
                         current_script_path = os.path.abspath(__file__) 
                         current_dir = pathlib.Path(current_script_path).parent # Directory containing this script (src)
                         # Assumes script is in src, parent is code, parent.parent is project root
                         project_root = current_dir.parent.parent.resolve() 
                         self.logger.info(f"Project root determined using __file__: {project_root}")
                    except NameError:
                         # Fallback if __file__ is not defined (e.g., interactive/notebook)
                         # Infer the layout from the CWD's directory name ('src', 'code',
                         # or anything else = project root).
                         # Assume CWD is the directory containing the notebook/where python was launched
                         cwd = pathlib.Path(os.getcwd()).resolve() 
                         self.logger.warning(f"__file__ not defined, using CWD: {cwd}")
                         # Check if CWD looks like the 'src' directory
                         if cwd.name == 'src':
                              project_root = cwd.parent.parent.resolve() # Go up two levels
                              current_dir = cwd # Set current_dir for module loading later
                              self.logger.info(f"Assuming CWD is 'src', project root set to: {project_root}")
                         # Check if CWD looks like the 'code' directory
                         elif cwd.name == 'code':
                              project_root = cwd.parent.resolve() # Go up one level
                              current_dir = cwd / "src" # Assume src exists for module loading
                              self.logger.info(f"Assuming CWD is 'code', project root set to: {project_root}")
                         else: # Otherwise, assume CWD *is* the project root
                              project_root = cwd 
                              current_dir = project_root / "src" # Assume src exists for module loading
                              self.logger.info(f"Assuming CWD is project root: {project_root}")
                         
                    # Ensure project_root was determined
                    if project_root is None:
                         self.logger.error("FATAL: Could not determine project root directory.")
                         if self.spark: self.spark.stop(); self.spark = None
                         return False

                    # lstrip removes leading '/' or '\', so the configured path is always
                    # interpreted relative to project_root (even if written as absolute).
                    local_data_file = self.config.get("local_data_path", "fulldataset_ECG5000.csv").lstrip('/\\') 
                    final_data_path_os = project_root / local_data_file # Use pathlib's / operator
                    
                    if not final_data_path_os.exists():
                         self.logger.error(f"FATAL: Constructed data path does NOT exist: {final_data_path_os}")
                         # Stop the newly created session if path is invalid
                         if self.spark: self.spark.stop(); self.spark = None
                         return False # Indicate setup failure
                    else:
                         self.logger.info(f"Verified data path exists: {final_data_path_os}") 
                         # Convert OS path to file URI for Spark AFTER verification
                         # (as_uri() needs an absolute path; project_root was resolve()d above)
                         final_data_path_uri = final_data_path_os.as_uri() 

                    self.ingestion_config = {
                        "data_path": final_data_path_uri, 
                        "data_percentage": self.config.get("data_percentage", 1.0) 
                    }
                    self.logger.info(f"Data path set in ingestion_config: {self.ingestion_config['data_path']}")
                    self.logger.info(f"Local data percentage set to: {self.ingestion_config['data_percentage']}")
                    # --- End Construct Local Data Path ---

                # Add Python module dependencies for local runs (needs to be done only once per session)
                # current_dir / project_root are only bound in the local branch above,
                # which is why this block repeats the non-Databricks guard.
                if "DATABRICKS_RUNTIME_VERSION" not in os.environ and self.spark:
                    modules_to_add = ['global_model_manager.py'] # Add other required modules if needed
                    try:
                        # Use the current_dir determined earlier (should be src)
                        if current_dir is None or not current_dir.exists() or not current_dir.is_dir():
                             # Fallback if current_dir wasn't determined correctly
                             current_dir = project_root / "src" 
                             self.logger.warning(f"Re-assuming 'src' directory for module loading: {current_dir}")
                             if not current_dir.exists():
                                  raise FileNotFoundError("Cannot find assumed 'src' directory for module loading.")
                             
                        for module_name in modules_to_add:
                             module_path = current_dir / module_name # Path relative to src
                             if module_path.exists():
                                 self.spark.sparkContext.addPyFile(str(module_path)) # addPyFile needs string path
                                 self.logger.debug(f"Added {module_path} to SparkContext pyFiles.") 
                             else:
                                  self.logger.error(f"Could not find module {module_name} at {module_path}")
                    except Exception as e:
                        # Non-fatal: setup still reports success below.
                        self.logger.error(f"Failed to add Python module to SparkContext: {e}")
                
                return True # Indicate successful setup

            except Exception as e:
                 # NOTE(review): this also stops the session obtained in the Databricks
                 # branch, whereas run() avoids stopping Spark on Databricks — confirm intended.
                 self.logger.error(f"Error during Spark setup: {e}", exc_info=True)
                 if self.spark: self.spark.stop(); self.spark = None # Ensure cleanup on error
                 return False # Indicate setup failure
        else:
             # Spark session already exists
             self.logger.debug("Spark session already exists.")
             return True
            
    def run(self):
        """
        Executes the main pipeline loop for model training and evaluation.
        Iterates through specified partition counts, running the full pipeline 
        (ingestion, preprocessing, training, eval, save) in each iteration.
        """
        # --- Determine number of iterations and which models to run ---
        number_iterations_global, number_iterations_local = 0, 0
        run_local = self.config.get("local_model_config", {}).get("test_local_model", False)
        run_global = self.config.get("global_model_config", {}).get("test_global_model", False)

        if not run_local and not run_global:
            self.logger.error("No model selected for testing. Set 'test_local_model' or 'test_global_model' to True in config.")
            return 

        if run_local:
            number_iterations_local = self.config.get("local_model_config", {}).get("num_partitions", 0)
        if run_global:
            # Use the partition count from global config to control loop iterations for the experiment
            number_iterations_global = self.config.get("global_model_config", {}).get("num_partitions", 0) 
            
        # Determine the overall maximum number of iterations needed
        number_iterations = 0
        if run_local: number_iterations = max(number_iterations, number_iterations_local)
        if run_global: number_iterations = max(number_iterations, number_iterations_global)

        min_iterations = self.config.get("min_number_iterarations", 2)
        start_iteration = min_iterations
        end_iteration = number_iterations
        
        if end_iteration < start_iteration:
             self.logger.warning(f"Max iterations ({end_iteration}) is less than min iterations ({start_iteration}). No iterations will run.")
             start_iteration = end_iteration + 1 # Make range empty

        # --- Initialize report accumulators ---
        all_reports_global = {} 
        all_reports_local = {}  

        # --- Setup Spark Session ONCE before the loop ---
        self._setup_spark() 
        if not self.spark:
             self.logger.error("Initial Spark setup failed. Aborting run.")
             return
        # Check data path after setup
        if self.ingestion_config.get("data_path") is None:
             self.logger.error("Initial data path construction failed. Aborting run.")
             # Spark might have been stopped in _setup_spark if path was invalid
             if "DATABRICKS_RUNTIME_VERSION" not in os.environ and self.spark: self.spark.stop() 
             return

        # ============================================================
        # === Main Iteration Loop ===
        # ============================================================
        try: # Add try block around the loop for final cleanup
            for i in range(start_iteration, end_iteration + 1): 
                self.logger.info(f"========== Starting Iteration {i} ==========")
                
                iteration_run_local = run_local and i <= number_iterations_local 
                iteration_run_global = run_global # Global runs in all iterations up to its max if enabled

                # --- Setup Modules for Iteration ---
                # Spark session is already running
                current_datetime = time.strftime("%Y-%m-%d-%H-%M-%S")
                # Initialize modules for this iteration
                self.evaluator = Evaluator(track_memory=self.config.get("track_memory", False)) 
                # Create a deep copy of the config for this iteration to avoid side effects
                current_iter_config = json.loads(json.dumps(self.config)) 
                # Update partition counts in the copied config for this iteration
                if iteration_run_local: 
                     current_iter_config.setdefault("local_model_config", {})["num_partitions"] = i
                if iteration_run_global: 
                     current_iter_config.setdefault("global_model_config", {})["num_partitions"] = i
                
                # Initialize Preprocessor and Ingestion with the iteration-specific config
                self.preprocessor = Preprocessor(config=current_iter_config) 
                # Ingestion uses the config set during initial _setup_spark
                self.ingestion = DataIngestion(spark=self.spark, config=self.ingestion_config) 

                # Define DataFrame variables for this iteration scope
                preprocessed_train_df: DataFrame = None
                preprocessed_test_df: DataFrame = None
                min_max_values: dict = None

                try:
                    # --- Data Ingestion (runs every iteration) ---
                    self.evaluator.start_timer("Ingestion")
                    self.logger.info(f"Iteration {i}: Loading data...")
                    df = self.ingestion.load_data() 
                    if df.limit(1).count() == 0:
                         self.logger.error(f"Iteration {i}: Data ingestion resulted in empty DataFrame. Skipping iteration.")
                         continue # Skip to next iteration
                    self.evaluator.record_time("Ingestion")

                    # --- Split, Calculate Min/Max (runs every iteration) ---
                    self.evaluator.start_timer("Split_MinMax")
                    min_max_values = compute_min_max(df) 
                    self.logger.info(f"Iteration {i}: Computed Min-Max values.") 
                    train_df, test_df = randomSplit_stratified_via_sampleBy(df, label_col="label", weights=[0.8, 0.2], seed=123)       
                    self.evaluator.record_time("Split_MinMax")
                    
                    # --- Preprocessing Train Data (runs every iteration, includes repartitioning) ---
                    self.evaluator.start_timer("Preprocessing_Train")
                    target_partitions = current_iter_config.get("local_model_config" if iteration_run_local else "global_model_config", {}).get("num_partitions", "N/A")
                    self.logger.info(f"Iteration {i}: Preprocessing train data (target partitions={target_partitions})...")
                    preprocessed_train_df = self.preprocessor.run_preprocessing(train_df, min_max_values) 
                    train_count = preprocessed_train_df.count() # Action to materialize preprocessing
                    self.logger.info(f"Iteration {i}: Preprocessed train data count: {train_count}")
                    self.evaluator.record_time("Preprocessing_Train")

                    # --- Preprocessing Test Data (runs every iteration, includes repartitioning) ---
                    self.evaluator.start_timer("Preprocessing_Test")
                    self.logger.info(f"Iteration {i}: Preprocessing test data (target partitions={target_partitions})...")
                    preprocessed_test_df = self.preprocessor.run_preprocessing(test_df, min_max_values)
                    test_count = preprocessed_test_df.count() # Action
                    self.logger.info(f"Iteration {i}: Preprocessed test data count: {test_count}")
                    self.evaluator.record_time("Preprocessing_Test")

                    if train_count == 0 or test_count == 0:
                        self.logger.error(f"Iteration {i}: Preprocessing resulted in empty train or test set. Skipping model steps.")
                        continue 

                except Exception as e:
                     # Handle potential errors during data loading/preprocessing
                     if "Path does not exist" in str(e):
                          spark_path_attempt = self.ingestion_config.get("data_path", "N/A") 
                          self.logger.error(f"Iteration {i}: Spark failed to find data file during load. Path: {spark_path_attempt}. Error: {e}", exc_info=False) 
                     elif isinstance(e, (ConnectionRefusedError, ConnectionResetError)) or "Connection reset by peer" in str(e):
                          self.logger.error(f"Iteration {i}: Connection error during data processing: {e}", exc_info=True)
                     else:
                          self.logger.error(f"Iteration {i}: Error during data processing: {e}", exc_info=True) 
                     continue # Skip to next iteration
                finally:
                     # Clean up intermediate raw dataframes for this iteration
                     if 'df' in locals(): del df
                     if 'train_df' in locals(): del train_df
                     if 'test_df' in locals(): del test_df


                # Define variables to hold model/prediction results for the iteration
                model_ensamble = None 
                predictions_df = None 
                
                # =================== GLOBAL MODEL ===================
                if iteration_run_global:
                    # ... (Global model block remains the same) ...
                    self.global_model_manager = None 
                    model_ensamble = None 
                    global_report = None 
                    try:
                        # Pass the specific global config for this iteration
                        global_config = current_iter_config.get("global_model_config", {}) 
                        if not global_config:
                             self.logger.error(f"Iteration {i}: global_model_config missing. Skipping global model.")
                        else:
                            self.global_model_manager = GlobalModelManager(spark=self.spark, config=global_config) 
                            
                            print(f"\nIteration {i}: Train global model......")
                            self.evaluator.start_timer("Global_Training")
                            model_ensamble = self.global_model_manager.fit(preprocessed_train_df) 
                            self.evaluator.record_time("Global_Training")
                            print(f"Iteration {i}: Finish Global Training.")

                            if model_ensamble and hasattr(model_ensamble, 'tree') and len(model_ensamble.tree) > 1:
                                self.logger.info(f"Iteration {i}: Global model training successful.")
                                
                                print(f"\nIteration {i}: Generate predictions with global model......")
                                self.evaluator.start_timer("Global_Prediction")
                                predictions_df = predict_with_global_prox_tree(model_ensamble, preprocessed_test_df) 
                                self.evaluator.record_time("Global_Prediction")
                                print(f"Iteration {i}: Finish Global Prediction.")
                                
                                print(f"\nIteration {i}: Global Model Predictions Distribution:")
                                predictions_df.groupBy("prediction").count().show() 

                                print(f"\nIteration {i}: Generate metrics for global model......")
                                global_report, class_names = self.evaluator.log_metrics(predictions_df, model=model_ensamble) 
                                all_reports_global[str(i)] = global_report 
                                print(f"Iteration {i}: Finish Global Evaluation.")
                                
                                depth = global_config.get("tree_params", {}).get("max_depth", "NA")
                                
                                model_folder = "models_global" 
                                os.makedirs(model_folder, exist_ok=True) 
                                model_filename = f"global_model_iter_{i}_parti_{i}_{current_datetime}_depth_{depth}.pkl" 
                                model_save_path = os.path.join(model_folder, model_filename)
                                try:
                                    self.global_model_manager.save_tree(model_save_path) 
                                    print(f"Saved global model to {model_save_path}")
                                except Exception as e: self.logger.error(f"Failed to save global model {model_filename}: {e}")

                            else:
                                self.logger.warning(f"Iteration {i}: Global model training failed or resulted in trivial tree.")

                    except Exception as e:
                        self.logger.error(f"Iteration {i}: Error during global model processing: {e}", exc_info=True) 
                    finally:
                        # Clean up global model objects specific to this iteration
                        self.global_model_manager = None 
                        if 'model_ensamble' in locals() and model_ensamble is not None: del model_ensamble
                        if not iteration_run_local and 'predictions_df' in locals() and predictions_df is not None: del predictions_df 
                        self.logger.debug(f"Iteration {i}: Cleaned up global model objects.")


                # =================== LOCAL MODEL ===================
                if iteration_run_local:
                    # ... (Local model block remains the same) ...
                    self.local_model_manager = None 
                    model_ensamble = None 
                    if iteration_run_global and 'predictions_df' in locals(): predictions_df = None 
                    local_report = None 
                    try:
                        # Pass the specific local config for this iteration
                        local_config = current_iter_config.get("local_model_config", {})
                        if not local_config:
                            self.logger.error(f"Iteration {i}: local_model_config missing. Skipping local model.")
                        else: 
                            self.local_model_manager = LocalModelManager(config=local_config) 
                            
                            print(f"\nIteration {i}: Train local model with {i} partitions......")
                            self.evaluator.start_timer("Local_Training")
                            model_ensamble = self.local_model_manager.train_ensemble(preprocessed_train_df) 
                            self.evaluator.record_time("Local_Training")
                            print(f"Iteration {i}: Finish Local Training.")
                            
                            if model_ensamble is not None and hasattr(model_ensamble, 'trees_') and model_ensamble.trees_:
                                self.logger.info(f"Iteration {i}: Local model training successful.")
                                
                                print(f"\nIteration {i}: Generate predictions with local model......")
                                self.evaluator.start_timer("Local_Prediction")
                                self.predictor = PredictionManager(self.spark, model_ensamble) 
                                predictions_df = self.predictor.generate_predictions_local(preprocessed_test_df) 
                                self.evaluator.record_time("Local_Prediction")
                                print(f"Iteration {i}: Finish Local Prediction.")
                            
                                print(f"\nIteration {i}: Local model Predictions Distribution:")
                                predictions_df.groupBy("prediction").count().show() 

                                print(f"\nIteration {i}: Generate metrics for local model......")
                                local_report, class_names = self.evaluator.log_metrics(predictions_df, model=model_ensamble)
                                all_reports_local[str(i)] = local_report 
                                print(f"Iteration {i}: Finish Local Evaluation.")
                                
                                depth = local_config.get("tree_params", {}).get("max_depth", "NA") 
                                
                                model_folder = "models_local" 
                                os.makedirs(model_folder, exist_ok=True) 
                                model_filename = f"local_model_iter_{i}_parti_{i}_{current_datetime}_depth_{depth}.pkl" 
                                model_save_path = os.path.join(model_folder, model_filename)
                                try:
                                    with open(model_save_path, 'wb') as f: pickle.dump(model_ensamble, f) 
                                    print(f"Saved local model ensemble to {model_save_path}")
                                except Exception as e: self.logger.error(f"Failed to save local model {model_filename}: {e}")
                                
                            else:
                                self.logger.warning(f"Iteration {i}: Local model training failed or resulted in empty ensemble.")

                    except Exception as e:
                         self.logger.error(f"Iteration {i}: Error during local model processing: {e}", exc_info=True) 
                    finally:
                        # Clean up local model objects specific to this iteration
                        self.local_model_manager = None 
                        if 'model_ensamble' in locals() and model_ensamble is not None: del model_ensamble
                        if 'predictions_df' in locals() and predictions_df is not None: del predictions_df
                        self.logger.debug(f"Iteration {i}: Cleaned up local model objects.")


                # --- Iteration End ---
                # Cleanup preprocessed DataFrames for this iteration 
                try:
                     if 'preprocessed_train_df' in locals() and preprocessed_train_df is not None: 
                          # preprocessed_train_df.unpersist() # Caching removed, unpersist not needed
                          del preprocessed_train_df
                     if 'preprocessed_test_df' in locals() and preprocessed_test_df is not None: 
                          # preprocessed_test_df.unpersist() # Caching removed, unpersist not needed
                          del preprocessed_test_df
                except Exception as cleanup_e:
                     self.logger.warning(f"Iteration {i}: Error during DataFrame cleanup: {cleanup_e}")

                
                self.logger.info(f"========== Finished Iteration {i} ==========")
                # Optional delay
                if self.config.get('delay_time', 0) > 0 and i < end_iteration : 
                    print(f"\nIteration {i}: Waiting for {self.config['delay_time']} seconds before next iteration...\n")
                    time.sleep(self.config['delay_time'])

                # *** Spark session is NOT stopped here anymore ***

            # End of main loop

        # ============================================================
        # === Step 3: Save Accumulated Reports After Loop ===
        # ============================================================
        finally: # Use finally to ensure Spark stops even if loop errors out
             final_datetime = time.strftime("%Y-%m-%d-%H-%M-%S") 
            
             if all_reports_global:
                  report_folder = "logs" 
                  os.makedirs(report_folder, exist_ok=True) 
                  report_filename_global = f"report_global_model_ALL_{final_datetime}.json" 
                  report_save_path_global = os.path.join(report_folder, report_filename_global)
                  try:
                      with open(report_save_path_global, "w") as f: json.dump(all_reports_global, f, indent=2)
                      print(f"Saved ALL global model reports to {report_save_path_global}") 
                  except Exception as e: self.logger.error(f"Failed to save aggregated global report: {e}")

             if all_reports_local:
                  report_folder = "logs"
                  os.makedirs(report_folder, exist_ok=True) 
                  report_filename_local = f"report_local_model_ALL_{final_datetime}.json" 
                  report_save_path_local = os.path.join(report_folder, report_filename_local)
                  try:
                      with open(report_save_path_local, "w") as f: json.dump(all_reports_local, f, indent=2)
                      print(f"Saved ALL local model reports to {report_save_path_local}") 
                  except Exception as e: self.logger.error(f"Failed to save aggregated local report: {e}")


             # --- Pipeline End ---
             # Final Spark stop if running locally 
             if "DATABRICKS_RUNTIME_VERSION" not in os.environ and self.spark:
                  self.spark.stop()
                  self.spark = None # Reset attribute
                  print(f"\nFinal Spark session stopped (local mode)!")

             print("\n--- Pipeline execution finished ---")



In [0]:
# config.py
"""
Pipeline configuration (notebook cell).

Collects the data locations, iteration bounds and model hyper-parameters in a
single ``config`` dict that is later handed to ``PipelineController_Loop``.
"""

config = dict(
    # --- data source ---
    databricks_data_path="/mnt/2025-team6/fulldataset_ECG5000.csv",
    local_data_path="/fulldataset_ECG5000.csv",
    label_col="label",
    data_percentage=1.0,
    # --- loop control ---
    # Lower bound of the iteration loop. Key spelling kept as-is: presumably read
    # under this exact name by PipelineController_Loop -- verify before renaming.
    min_number_iterarations=1,
    delay_time=3,  # seconds slept between iterations
    # --- local model ---
    local_model_config=dict(
        test_local_model=False,
        num_partitions=2,  # loop up to this many partitions for the local model
        tree_params=dict(
            n_splitters=5,        # ProximityTree default
            max_depth=1,          # TODO: switch to None, like the global model
            min_samples_split=5,  # ProximityTree default
            random_state=123,
        ),
        forest_params=dict(
            random_state=123,
            n_jobs=-1,  # use all available cores
        ),
    ),
    # --- global model ---
    global_model_config=dict(
        test_global_model=True,
        num_partitions=1,  # loop up to this many partitions for the global model
        tree_params=dict(
            n_splitters=5,        # ProximityTree default
            max_depth=None,
            min_samples_split=5,  # ProximityTree default
            random_state=123,
        ),
    ),
    reserve_partition_id=False,
)

In [0]:

# Launch the pipeline with the `config` dict defined in the previous cell.
import traceback

print("Starting pipeline via controller")
try:
    controller = PipelineController_Loop(config)
    controller.run()
except Exception as err:
    # Deliberately non-fatal so later cells are not blocked, but keep the
    # traceback: the one-line message alone hides where the failure happened.
    print("Pipeline failed:", err)
    traceback.print_exc()

In [0]:

# Experiment cell: min_number_iterarations=13, global num_partitions=14.
"""
Pipeline configuration + launch (notebook cell).

Defines the ``config`` dict (data paths, iteration bounds, model
hyper-parameters) and immediately runs ``PipelineController_Loop`` with it.
Failures are reported (message + traceback) instead of raised, so later
cells are not blocked.
"""
import traceback

config = {
    "databricks_data_path" : "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path" : "/fulldataset_ECG5000.csv",
    "label_col" : "label",
    "data_percentage" : 1.0,
    # Key spelling kept as-is; presumably read under this exact name by
    # PipelineController_Loop -- verify before renaming.
    "min_number_iterarations" : 13, # Minimum number of iterations for the loop
    "delay_time" : 3,  # seconds slept between iterations

    "local_model_config": {
        "test_local_model" : False,
        "num_partitions": 2,  # loop to this number of partitions for the local model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": 1,  # TODO: switch to None, like the global model
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_global_model" : True,
        "num_partitions": 14,  # loop to this number of partitions for the global model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,
            "min_samples_split": 5,  # From ProximityTree defaults
            "random_state": 123
            },
    },
    "reserve_partition_id": False
}

print("Starting pipeline via controller")
try:
    controller = PipelineController_Loop(config)
    controller.run()
except Exception as err:
    # Deliberately non-fatal so later cells are not blocked, but keep the
    # traceback: the one-line message alone hides where the failure happened.
    print("Pipeline failed:", err)
    traceback.print_exc()

In [0]:

# Experiment cell: min_number_iterarations=15, global num_partitions=16.
"""
Pipeline configuration + launch (notebook cell).

Defines the ``config`` dict (data paths, iteration bounds, model
hyper-parameters) and immediately runs ``PipelineController_Loop`` with it.
Failures are reported (message + traceback) instead of raised, so later
cells are not blocked.
"""
import traceback

config = {
    "databricks_data_path" : "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path" : "/fulldataset_ECG5000.csv",
    "label_col" : "label",
    "data_percentage" : 1.0,
    # Key spelling kept as-is; presumably read under this exact name by
    # PipelineController_Loop -- verify before renaming.
    "min_number_iterarations" : 15, # Minimum number of iterations for the loop
    "delay_time" : 3,  # seconds slept between iterations

    "local_model_config": {
        "test_local_model" : False,
        "num_partitions": 2,  # loop to this number of partitions for the local model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": 1,  # TODO: switch to None, like the global model
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_global_model" : True,
        "num_partitions": 16,  # loop to this number of partitions for the global model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,
            "min_samples_split": 5,  # From ProximityTree defaults
            "random_state": 123
            },
    },
    "reserve_partition_id": False
}

print("Starting pipeline via controller")
try:
    controller = PipelineController_Loop(config)
    controller.run()
except Exception as err:
    # Deliberately non-fatal so later cells are not blocked, but keep the
    # traceback: the one-line message alone hides where the failure happened.
    print("Pipeline failed:", err)
    traceback.print_exc()

In [0]:

# Experiment cell: min_number_iterarations=17, global num_partitions=18.
"""
Pipeline configuration + launch (notebook cell).

Defines the ``config`` dict (data paths, iteration bounds, model
hyper-parameters) and immediately runs ``PipelineController_Loop`` with it.
Failures are reported (message + traceback) instead of raised, so later
cells are not blocked.
"""
import traceback

config = {
    "databricks_data_path" : "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path" : "/fulldataset_ECG5000.csv",
    "label_col" : "label",
    "data_percentage" : 1.0,
    # Key spelling kept as-is; presumably read under this exact name by
    # PipelineController_Loop -- verify before renaming.
    "min_number_iterarations" : 17, # Minimum number of iterations for the loop
    "delay_time" : 3,  # seconds slept between iterations

    "local_model_config": {
        "test_local_model" : False,
        "num_partitions": 2,  # loop to this number of partitions for the local model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": 1,  # TODO: switch to None, like the global model
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_global_model" : True,
        "num_partitions": 18,  # loop to this number of partitions for the global model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,
            "min_samples_split": 5,  # From ProximityTree defaults
            "random_state": 123
            },
    },
    "reserve_partition_id": False
}

print("Starting pipeline via controller")
try:
    controller = PipelineController_Loop(config)
    controller.run()
except Exception as err:
    # Deliberately non-fatal so later cells are not blocked, but keep the
    # traceback: the one-line message alone hides where the failure happened.
    print("Pipeline failed:", err)
    traceback.print_exc()

In [0]:

# Experiment cell: min_number_iterarations=19, global num_partitions=20.
"""
Pipeline configuration + launch (notebook cell).

Defines the ``config`` dict (data paths, iteration bounds, model
hyper-parameters) and immediately runs ``PipelineController_Loop`` with it.
Failures are reported (message + traceback) instead of raised, so later
cells are not blocked.
"""
import traceback

config = {
    "databricks_data_path" : "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path" : "/fulldataset_ECG5000.csv",
    "label_col" : "label",
    "data_percentage" : 1.0,
    # Key spelling kept as-is; presumably read under this exact name by
    # PipelineController_Loop -- verify before renaming.
    "min_number_iterarations" : 19, # Minimum number of iterations for the loop
    "delay_time" : 3,  # seconds slept between iterations

    "local_model_config": {
        "test_local_model" : False,
        "num_partitions": 2,  # loop to this number of partitions for the local model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": 1,  # TODO: switch to None, like the global model
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_global_model" : True,
        "num_partitions": 20,  # loop to this number of partitions for the global model
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,
            "min_samples_split": 5,  # From ProximityTree defaults
            "random_state": 123
            },
    },
    "reserve_partition_id": False
}

print("Starting pipeline via controller")
try:
    controller = PipelineController_Loop(config)
    controller.run()
except Exception as err:
    # Deliberately non-fatal so later cells are not blocked, but keep the
    # traceback: the one-line message alone hides where the failure happened.
    print("Pipeline failed:", err)
    traceback.print_exc()