In [8]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging
from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw
import collections
from pprint import pprint
import random
import collections
from pyspark.sql import functions as F
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


In [None]:
import collections
import json # To potentially serialize complex split_on info if needed, though plain dict is better
import math # For Euclidean distance
import random

import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType, IntegerType, LongType, MapType, StructField, StructType

# Define a simple Euclidean distance function for use in UDF
# In a real implementation, this would handle multiple distance measures and parameters
def euclidean_distance(ts1, ts2):
    """Return the Euclidean (L2) distance between two equal-length series.

    Parameters
    ----------
    ts1, ts2 : sequence of numbers
        The two series, compared element by element.

    Returns
    -------
    float
        The distance, or ``float('inf')`` when the pair cannot be compared:
        an argument is ``None`` or has no length, the lengths differ, or the
        arithmetic fails (e.g. a ``None`` element or an overflow).

    Notes
    -----
    The function never raises. Callers pick the nearest exemplar with a
    strict ``<`` comparison against an initial ``inf``, so an ``inf`` result
    means "never choose this pair".
    """
    if ts1 is None or ts2 is None:
        return float('inf')

    try:
        # The length check is inside the try so that unsized inputs (e.g. a
        # scalar) are reported as incomparable instead of raising TypeError.
        if len(ts1) != len(ts2):
            return float('inf')
        squared_sum = sum((a - b) ** 2 for a, b in zip(ts1, ts2))
        return float(math.sqrt(squared_sum))
    except Exception:
        # Deliberately broad: this runs per row inside Spark tasks, where an
        # exception would fail the whole job; bad data maps to "infinitely far".
        return float('inf')


# Define a UDF for predicting a single time series instance
# This UDF will need access to the tree structure (broadcasted plain dictionary)
def predict_udf_func(plain_tree_structure_broadcast):
    """
    Build the per-row prediction function for a proximity tree.

    Parameters
    ----------
    plain_tree_structure_broadcast : Broadcast
        Broadcast variable whose ``.value`` is a plain dict
        ``{node_id: node_dict}``. Each ``node_dict`` is read through the keys
        ``'is_leaf'``, ``'prediction'``, ``'split_on'`` and ``'children'``;
        ``'split_on'`` is ``(measure_type, {branch_id: exemplar_ts})`` and
        ``'children'`` is ``{branch_id: child_node_id}``.

    Returns
    -------
    callable
        ``traverse_tree(time_series)``: a plain Python function (the caller
        wraps it as a UDF) returning the predicted label for one series.

    Notes
    -----
    The broadcast is dereferenced inside ``traverse_tree`` rather than here.
    Reading ``.value`` while building the closure would capture the whole tree
    dict in the pickled function and ship it with the function itself, which
    is exactly what the broadcast is meant to avoid.
    """

    def _nearest_branch(time_series, labelled_exemplars):
        """Return the branch id of the closest exemplar (first wins on ties), or None."""
        best_branch_id = None
        best_dist = float('inf')
        for branch_id, exemplar_ts in labelled_exemplars:
            # Placeholder: every measure_type is currently evaluated as Euclidean.
            d = euclidean_distance(time_series, exemplar_ts)
            if d < best_dist:
                best_dist = d
                best_branch_id = branch_id
        return best_branch_id

    def traverse_tree(time_series):
        """Walk from the root (node 0) to a prediction for one time series."""
        if time_series is None:
            return None

        # Resolved lazily so only the broadcast handle travels with the closure.
        tree = plain_tree_structure_broadcast.value

        node_id = 0  # Start at root
        while node_id in tree:
            current_node = tree[node_id]

            if current_node['is_leaf']:
                return current_node['prediction']

            split_info = current_node.get('split_on')
            children = current_node.get('children')

            # An internal node without split info or children cannot route
            # further; answer with whatever prediction it carries.
            if not (split_info and children):
                return current_node.get('prediction')

            _measure_type, exemplars = split_info  # measure type not dispatched on yet

            # First choice: the branch of the globally nearest exemplar.
            branch_id = _nearest_branch(time_series, exemplars.items())

            if branch_id is None or branch_id not in children:
                # That branch has no child node (it was never created), so
                # fall back to the nearest exemplar among existing children.
                branch_id = _nearest_branch(
                    time_series,
                    ((b, exemplars[b]) for b in children if b in exemplars),
                )
                if branch_id is None:
                    return current_node.get('prediction')

            node_id = children[branch_id]

        # Reached only when node_id is missing from the tree (e.g. empty tree
        # or dangling child id); 1 is the hard-coded fallback label.
        return 1

    return traverse_tree


class GlobalProxTree:
    def __init__(self, spark, max_depth=5, min_samples=5, num_candidate_splits=5, num_exemplars_per_class=1):
        """
        Initialize the Global Proximity Tree

        Parameters:
        -----------
        spark : SparkSession
            The Spark session to use
        max_depth : int
            Maximum depth of the tree
        min_samples : int
            Minimum number of samples required to split a node
        num_candidate_splits : int
            Number of random candidate splits to evaluate at each node.
        num_exemplars_per_class : int
            Number of exemplars to sample per class for each open node.
            (Used for sampling pool on driver, not per candidate split as in paper)
        """
        self.spark = spark
        self.max_depth = max_depth
        self.min_samples = min_samples
        self.num_candidate_splits = num_candidate_splits
        # Size of the per-(node, label) exemplar pool pulled to the driver.
        # The paper instead samples 1 exemplar per class *per candidate split*.
        self.num_exemplars_per_class = num_exemplars_per_class

        # Schema for rows assigned to tree nodes.
        # row_id is LongType because it is produced by
        # F.monotonically_increasing_id(), which yields 64-bit values (the
        # partition id sits in the upper bits) that do not fit in a 32-bit
        # IntegerType.
        self.assignment_schema = StructType([
            StructField("row_id", LongType(), False),
            StructField("node_id", IntegerType(), False),
            StructField("time_series", ArrayType(DoubleType()), False),
            StructField("true_label", IntegerType(), False),
        ])

        # Driver-side node record.
        #   split_on : (distance_measure_type: str, {branch_id: exemplar_time_series}),
        #              where branch_id is the label of the exemplar
        #   children : {branch_id: child_node_id}
        self.TreeNode = collections.namedtuple(
            "TreeNode",
            "node_id parent_id split_on is_leaf prediction children".split()
        )

        # The tree starts as a single, not-yet-split root with id 0.
        self.tree = {
            0: self.TreeNode(
                node_id=0,
                parent_id=None,
                split_on=None,
                is_leaf=False,
                prediction=None,
                children={},
            )
        }
        self._next_node_id = 1 # Counter for assigning new node IDs

        # Overall majority class, used as a fallback prediction if needed
        self._overall_majority_class = None


    def _convert_to_time_series_format(self, df):
        """
        Convert wide dataframe (with each feature in its own column) to a dataframe
        with a single array column containing all features. Adds a unique row_id.

        Parameters:
        -----------
        df : Spark DataFrame
            Wide DataFrame with feature columns and label column

        Returns:
        --------
        Spark DataFrame
            DataFrame with 'row_id', 'time_series' and 'label' columns
        """
        print("DEBUG: _convert_to_time_series_format started.")
        # Check if 'time_series' column already exists
        if 'time_series' in df.columns:
            print("DEBUG: DataFrame already has 'time_series' column, no conversion needed.")
            # Ensure row_id is present
            if 'row_id' not in df.columns:
                print("DEBUG: Adding row_id to existing time_series DataFrame.")
                df = df.withColumn("row_id", F.monotonically_increasing_id())
            return df

        # Get all column names except 'label'
        feature_cols = [col for col in df.columns if col != 'label']

        print(f"DEBUG: Converting {len(feature_cols)} feature columns to 'time_series' array.")

        # Use array() function to combine columns and add a unique row_id
        ts_df = df.select(
            F.monotonically_increasing_id().alias("row_id"), # Add unique ID
            F.array(*[F.col(c) for c in feature_cols]).alias("time_series"),
            df["label"].cast(IntegerType()).alias("true_label")  # Ensure label is an integer and rename
        )

        # Show sample of converted data
        print("DEBUG: Sample of converted DataFrame:")
        ts_df.show(2, truncate=False)
        print("DEBUG: _convert_to_time_series_format finished.")

        return ts_df


    def fit(self, df):
        """
        Fit the decision tree on the dataframe

        Parameters:
        -----------
        df : Spark DataFrame
            DataFrame with feature columns and 'label' column

        Returns:
        --------
        self : GlobalProxTree
            The fitted tree
        """
        print("DEBUG: fit started.")
        # First, convert to time_series format if needed and add row_id
        df = self._convert_to_time_series_format(df)

        # Calculate overall majority class for fallback prediction
        label_counts = df.groupBy("true_label").count().collect()
        if label_counts:
            self._overall_majority_class = max(label_counts, key=lambda x: x['count'])['true_label']
            print(f"DEBUG: Overall majority class calculated: {self._overall_majority_class}")
        else:
            self._overall_majority_class = None
            print("DEBUG: No data to calculate overall majority class.")


        # Initialize assignment dataframe with all rows at the root node
        # Select only necessary columns to minimize data size
        assign_df = (
            df
            .withColumn("node_id", F.lit(0).cast(IntegerType()))
            .select("row_id", "node_id", "time_series", "true_label")
            .cache()
        )
        print(f"DEBUG: Initial assign_df created with {assign_df.count()} rows at root node 0.")


        open_nodes = {0}

        for depth in range(self.max_depth):
            print(f"\nDEBUG: === Starting tree level {depth} ===")
            # If no nodes to expand, stop
            if not open_nodes:
                print(f"DEBUG: No open_nodes at depth {depth}, stopping tree building.")
                break

            print(f"DEBUG: Open nodes at depth {depth}: {open_nodes}")

            # Filter assign_df to only include rows at the current open nodes
            current_level_df = assign_df.filter(F.col("node_id").isin(list(open_nodes))).cache()
            print(f"DEBUG: Filtered data for current level. Row count: {current_level_df.count()}")

            # Check if any data exists for the current open nodes
            if current_level_df.count() == 0:
                print(f"DEBUG: No data for open nodes at depth {depth}, stopping.")
                current_level_df.unpersist()
                break


            # --- Corrected Exemplar Sampling Logic (Driver-side) ---
            print("DEBUG: Sampling exemplars (driver-side).")
            sampled_exemplars = {} # {node_id: {true_label: [exemplar_ts1, exemplar_ts2, ...]}}

            # Get distinct (node_id, true_label) pairs present in the current level's data
            node_label_pairs = current_level_df.select("node_id", "true_label").distinct().collect()
            print(f"DEBUG: Found {len(node_label_pairs)} distinct (node_id, true_label) pairs for sampling.")

            for node_id, true_label in node_label_pairs:
                print(f"DEBUG: Sampling exemplars for node {node_id}, label {true_label}.")
                # Filter the current level's data for this specific node and label
                node_label_df = current_level_df.filter((F.col("node_id") == node_id) & (F.col("true_label") == true_label))

                # Take a sample of rows for this node and label
                # Use .limit() and .collect() on a small sample to avoid OOM
                # A more robust way might use RDD.takeSample
                sampled_rows = node_label_df.limit(self.num_exemplars_per_class).collect()
                sampled_time_series = [row.time_series for row in sampled_rows]

                if node_id not in sampled_exemplars:
                    sampled_exemplars[node_id] = {}
                sampled_exemplars[node_id][true_label] = sampled_time_series
                print(f"DEBUG: Sampled {len(sampled_time_series)} exemplars for node {node_id}, label {true_label}.")

            print(f"DEBUG: Finished sampling exemplars. Total sampled exemplars structure: {sampled_exemplars}")
            # --- End Corrected Exemplar Sampling ---


            # 2. Generate and evaluate candidate splits for each open node
            # This logic runs on the driver, but uses distributed operations for evaluation
            best_splits = {} # {node_id: (best_gini_gain, best_distance_measure, {branch_id: exemplar_ts})}
            nodes_to_make_leaves_this_iter = set() # Nodes that should become leaves in *this* iteration

            for node_id in open_nodes:
                print(f"DEBUG: Evaluating splits for node {node_id}.")
                if node_id not in sampled_exemplars or not sampled_exemplars[node_id]:
                    print(f"DEBUG: No exemplars found for node {node_id}, making it a leaf.")
                    nodes_to_make_leaves_this_iter.add(node_id)
                    continue # Cannot split without exemplars

                node_data_df = current_level_df.filter(F.col("node_id") == node_id).cache()
                node_total_samples = node_data_df.count()

                if node_total_samples < self.min_samples:
                    print(f"DEBUG: Node {node_id} has {node_total_samples} samples, below min_samples {self.min_samples}, making it a leaf.")
                    nodes_to_make_leaves_this_iter.add(node_id)
                    node_data_df.unpersist()
                    continue

                # Calculate parent Gini impurity
                parent_label_counts = node_data_df.groupBy("true_label").count().collect()
                parent_gini = self._calculate_gini_impurity(parent_label_counts, node_total_samples)
                print(f"DEBUG: Node {node_id} parent Gini: {parent_gini}")

                best_gini_gain = -1.0
                best_split_info = None # (distance_measure, {branch_id: exemplar_ts})

                # Generate and evaluate candidate splits
                for i in range(self.num_candidate_splits):
                    print(f"DEBUG: Evaluating candidate split {i+1} for node {node_id}.")
                    # Sample a distance measure and parameters (simplified: using Euclidean)
                    # In a full implementation, sample from the 11 measures and their params
                    distance_measure_type = "euclidean" # Placeholder
                    # Sample exemplars for this candidate split (one per class present in node_data_df)
                    # Need to get unique labels in node_data_df first
                    unique_labels_in_node = [row['true_label'] for row in node_data_df.select("true_label").distinct().collect()]
                    candidate_exemplars = {}
                    for label in unique_labels_in_node:
                        if label in sampled_exemplars[node_id] and sampled_exemplars[node_id][label]:
                            # Pick one random exemplar for this label from the sampled pool for this node
                            candidate_exemplars[label] = random.choice(sampled_exemplars[node_id][label])
                        else:
                            # Should not happen if sampling was done correctly and node_data_df has this label
                            print(f"WARNING: No sampled exemplar in pool for label {label} in node {node_id}. Skipping candidate split.")
                            candidate_exemplars = None # Invalidate this candidate
                            break

                    if candidate_exemplars is None or len(candidate_exemplars) < 2:
                        print(f"DEBUG: Candidate split {i+1} for node {node_id} has less than 2 exemplars, skipping.")
                        continue # Need at least two branches

                    print(f"DEBUG: Candidate split {i+1} exemplars (labels): {list(candidate_exemplars.keys())}")

                    # --- Modified Gini Calculation: Use RDD transformations to get counts directly ---
                    bc_candidate_exemplars = self.spark.sparkContext.broadcast(candidate_exemplars)

                    def map_to_branch_label_pair(row):
                        exemplars = bc_candidate_exemplars.value
                        min_dist = float('inf')
                        assigned_branch_id = None # The label of the nearest exemplar

                        for ex_lbl, ex_ts in exemplars.items():
                            # Use the chosen distance measure (placeholder: euclidean)
                            # In a real implementation, call a function that dispatches based on measure_type
                            d = euclidean_distance(row.time_series, ex_ts) # Use the distance function
                            if d < min_dist:
                                min_dist = d
                                assigned_branch_id = ex_lbl

                        # Return a tuple of (assigned_branch_id, true_label) for counting
                        return (assigned_branch_id, row.true_label)

                    # Apply the map and countByValue to get the counts per (branch, label) pair
                    print(f"DEBUG: Calculating branch-label counts for candidate split {i+1} using RDD.")
                    # countByValue returns a dictionary {(branch_id, true_label): count}
                    branch_label_counts_dict = node_data_df.rdd.map(map_to_branch_label_pair).countByValue()
                    print(f"DEBUG: Branch label counts dictionary collected for candidate split {i+1}: {branch_label_counts_dict}")

                    # Convert the dictionary to the list format expected by _calculate_gini_impurity
                    # branch_label_counts = [{"assigned_branch_id": k[0], "true_label": k[1], "count": v} for k, v in branch_label_counts_dict.items()] # Not needed in this format anymore


                    # Calculate weighted impurity for this split
                    weighted_impurity = 0.0
                    total_samples_in_split = node_total_samples # Total samples in the node

                    # Group counts by branch_id to calculate branch impurity directly from the dictionary
                    branch_counts = {}
                    for (branch_id, true_label), count in branch_label_counts_dict.items():
                        if branch_id not in branch_counts:
                            branch_counts[branch_id] = []
                        branch_counts[branch_id].append((true_label, count))

                    print(f"DEBUG: Branch counts grouped for candidate split {i+1}: {branch_counts}")

                    for branch_id, label_counts_list in branch_counts.items():
                        branch_total = sum(count for label, count in label_counts_list)
                        if branch_total > 0:
                            branch_impurity = self._calculate_gini_impurity(label_counts_list, branch_total)
                            weighted_impurity += (branch_total / total_samples_in_split) * branch_impurity
                            print(f"DEBUG: Branch {branch_id} impurity: {branch_impurity}, weighted: {(branch_total / total_samples_in_split) * branch_impurity}")


                    gini_gain = parent_gini - weighted_impurity
                    print(f"DEBUG: Candidate split {i+1} Gini gain: {gini_gain}")

                    # Unpersist the broadcast variable
                    bc_candidate_exemplars.unpersist()
                    # --- End Modified Gini Calculation ---


                    # Check if this is the best split so far
                    if gini_gain > best_gini_gain:
                        best_gini_gain = gini_gain
                        best_split_info = (distance_measure_type, candidate_exemplars)
                        print(f"DEBUG: Candidate split {i+1} is the best so far for node {node_id} with gain {best_gini_gain}.")


                node_data_df.unpersist() # Unpersist node data

                # Decide if the node should split
                # A split occurs if best_gini_gain is positive and results in valid children (handled in _split_node_gini)
                if best_gini_gain > 0:
                    print(f"DEBUG: Node {node_id} has a positive Gini gain ({best_gini_gain}), attempting to split.")
                    best_splits[node_id] = (best_gini_gain, best_split_info[0], best_split_info[1])
                else:
                    print(f"DEBUG: Node {node_id} has non-positive Gini gain ({best_gini_gain}), making it a leaf.")
                    nodes_to_make_leaves_this_iter.add(node_id)


            # --- Finalize nodes marked as leaves *in this iteration* ---
            # This loop should be here, inside the depth loop, after evaluating all open_nodes
            for node_id in nodes_to_make_leaves_this_iter:
                if node_id in self.tree and not self.tree[node_id].is_leaf:
                    print(f"DEBUG: Finalizing node {node_id} as a leaf.")
                    # Need to calculate the prediction for this leaf node
                    # Collect label counts for this node from assign_df
                    leaf_data_df = assign_df.filter(F.col("node_id") == node_id).cache()
                    leaf_label_counts = leaf_data_df.groupBy("true_label").count().collect()
                    leaf_data_df.unpersist()

                    leaf_prediction = None
                    if leaf_label_counts:
                        leaf_prediction = max(leaf_label_counts, key=lambda x: x['count'])['true_label']
                    elif self._overall_majority_class is not None:
                        # Fallback to overall majority if no data at this node (shouldn't happen with correct logic)
                        leaf_prediction = self._overall_majority_class
                        print(f"DEBUG: Node {node_id} had no data, using overall majority prediction: {leaf_prediction}")
                    else:
                        # Final fallback if no data and no overall majority
                        leaf_prediction = 1 # Defaulting to 1

                    self.tree[node_id] = self.tree[node_id]._replace(is_leaf=True, prediction=leaf_prediction)
                    print(f"DEBUG: Node {node_id} marked as leaf with prediction {leaf_prediction}.")
            # --- End Finalization in Iteration ---


            # 3. Perform the best splits and update the tree structure (on driver)
            # and push rows down to the new child nodes (distributed)
            next_open = set() # Nodes that successfully split and will be processed in the next iteration
            if best_splits:
                print("DEBUG: Performing best splits and pushing rows down.")
                # Create a mapping from (parent_node_id, assigned_branch_id) to new_child_node_id
                split_mapping = {} # {(parent_id, assigned_branch_id): child_node_id}

                for parent_id, (gain, measure, exemplars) in best_splits.items():
                    print(f"DEBUG: Processing best split for parent node {parent_id}.")
                    # Update the tree structure on the driver with the chosen split info
                    self.tree[parent_id] = self.tree[parent_id]._replace(split_on=(measure, exemplars))
                    print(f"DEBUG: Node {parent_id} split_on updated: measure={measure}, exemplars={list(exemplars.keys())}.")

                    # --- Mark the parent node as INTERNAL ---
                    # Only mark as INTERNAL if it successfully splits and creates children
                    # This is determined by checking branch counts against min_samples below.
                    # We'll set is_leaf=False and prediction=None here tentatively,
                    # and confirm it becomes internal if children are created.
                    self.tree[parent_id] = self.tree[parent_id]._replace(is_leaf=False, prediction=None)
                    print(f"DEBUG: Node {parent_id} tentatively marked as INTERNAL.")


                    # Recalculate branch counts for the best split to check min_samples
                    node_data_for_split_df = assign_df.filter(F.col("node_id") == parent_id).cache()
                    bc_chosen_exemplars_for_counts = self.spark.sparkContext.broadcast(exemplars)

                    def map_to_branch_for_counts(row):
                        exemplars = bc_chosen_exemplars_for_counts.value
                        min_dist = float('inf')
                        assigned_branch_id = None

                        for ex_lbl, ex_ts in exemplars.items():
                            d = euclidean_distance(row.time_series, ex_ts)
                            if d < min_dist:
                                min_dist = d
                                assigned_branch_id = ex_lbl
                        return assigned_branch_id # Return only the assigned branch ID

                    # Count samples per assigned branch for the best split
                    branch_counts_for_children = node_data_for_split_df.rdd.map(map_to_branch_for_counts).countByValue()
                    print(f"DEBUG: Branch counts for creating children for node {parent_id}: {branch_counts_for_children}")

                    node_data_for_split_df.unpersist()
                    bc_chosen_exemplars_for_counts.unpersist()

                    children_created_for_node = False
                    for branch_id, count in branch_counts_for_children.items():
                        # Only create a child node if the branch has enough samples
                        if count >= self.min_samples:
                            child_id = self._next_node_id
                            self._next_node_id += 1
                            print(f"DEBUG: Creating child node {child_id} for branch {branch_id} of parent {parent_id}.")
                            self.tree[child_id] = self.TreeNode(
                                node_id=child_id,
                                parent_id=parent_id,
                                split_on=None, # Split info will be determined in a future iteration if not a leaf
                                is_leaf=False, # Initially internal, will be finalized later
                                prediction=None,
                                children={},
                            )
                            # Update the parent node's children dictionary on the driver
                            self.tree[parent_id].children[branch_id] = child_id
                            # Add to the mapping used for pushing rows
                            split_mapping[(parent_id, branch_id)] = child_id
                            # Add the new child node to the set of nodes to process in the next iteration
                            next_open.add(child_id)
                            children_created_for_node = True
                            print(f"DEBUG: Added child {child_id} to parent {parent_id} children for branch {branch_id}.")
                        else:
                            print(f"DEBUG: Branch {branch_id} for node {parent_id} has {count} samples, below min_samples. Not creating child node.")
                            # Data points assigned to branches that don't create a child node
                            # will remain at the parent_id in the next assignment_df.
                            # This is a simplification; ideally, they might be handled differently
                            # (e.g., contribute to the parent's prediction if it becomes a leaf).


                    # If no children were created for this node despite a positive Gini gain
                    # (e.g., all branches had < min_samples), mark it as a leaf.
                    if not children_created_for_node:
                        print(f"DEBUG: Node {parent_id} had positive Gini gain but no branches met min_samples. Finalizing as a leaf.")
                        # Recalculate prediction based on data at this node
                        leaf_data_df = assign_df.filter(F.col("node_id") == parent_id).cache()
                        leaf_label_counts = leaf_data_df.groupBy("true_label").count().collect()
                        leaf_data_df.unpersist()

                        leaf_prediction = None
                        if leaf_label_counts:
                            leaf_prediction = max(leaf_label_counts, key=lambda x: x['count'])['true_label']
                        elif self._overall_majority_class is not None:
                            leaf_prediction = self._overall_majority_class
                            print(f"DEBUG: Node {node_id} had no data, using overall majority prediction: {leaf_prediction}")
                        else:
                            leaf_prediction = 1 # Default

                        self.tree[parent_id] = self.tree[parent_id]._replace(is_leaf=True, prediction=leaf_prediction, children={}) # Clear children if no split occurred
                        print(f"DEBUG: Node {parent_id} marked as leaf with prediction {leaf_prediction}.")


                # --- Modified Pushing Rows Down: Use a single UDF ---
                # Only apply the push down UDF if any children were actually created in this iteration
                if split_mapping: # split_mapping will be non-empty if any children were created
                    print("DEBUG: Applying single UDF to push rows down.")
                    bc_split_mapping = self.spark.sparkContext.broadcast(split_mapping)
                    bc_best_splits_info = self.spark.sparkContext.broadcast({nid: (split_info[0], split_info[2]) for nid, split_info in best_splits.items()}) # Broadcast (measure, exemplars) for splitting nodes

                    def push_row_udf_func(split_mapping_broadcast, best_splits_info_broadcast):
                        mapping = split_mapping_broadcast.value
                        splits_info = best_splits_info_broadcast.value

                        def _push_row(row_id, current_node_id, time_series, true_label):
                            # If the current node is one of the nodes that split in this iteration
                            if current_node_id in splits_info:
                                measure_type, exemplars = splits_info[current_node_id]

                                # Calculate distance to exemplars for this node's split
                                min_dist = float('inf')
                                assigned_branch_id = None

                                for ex_lbl, ex_ts in exemplars.items():
                                    # Use the chosen distance measure (placeholder: euclidean)
                                    # In a real implementation, call a function that dispatches based on measure_type
                                    d = euclidean_distance(time_series, ex_ts)
                                    if d < min_dist:
                                        min_dist = d
                                        assigned_branch_id = ex_lbl

                                # Use the split mapping to find the new node ID
                                key = (current_node_id, assigned_branch_id)
                                # If there's a mapping for this parent/branch, return the child node ID
                                # Otherwise, keep the old node ID (this handles branches that didn't create children)
                                return mapping.get(key, current_node_id)
                            else:
                                # If the current node was not one of the nodes that split,
                                # the row stays at its current node ID.
                                return current_node_id

                        # Return the UDF itself
                        return F.udf(_push_row, IntegerType())

                    # Create an instance of the UDF
                    push_row_udf = push_row_udf_func(bc_split_mapping, bc_best_splits_info)

                    # Apply the UDF to the entire assign_df to get the new node_id for each row
                    old_assign_df = assign_df # Keep reference to unpersist later
                    assign_df = assign_df.withColumn(
                        "node_id", # Overwrite the node_id column
                        push_row_udf(F.col("row_id"), F.col("node_id"), F.col("time_series"), F.col("true_label"))
                    ).cache() # Cache the updated DataFrame
                    print(f"DEBUG: assign_df updated for depth {depth+1}. Total rows: {assign_df.count()}")

                    # Unpersist intermediate DataFrames and broadcast variables
                    old_assign_df.unpersist()
                    bc_split_mapping.unpersist()
                    bc_best_splits_info.unpersist()
                else:
                    print("DEBUG: No children created in this iteration. No push down needed. assign_df remains unchanged for relevant nodes.")


            else:
                print("DEBUG: No nodes split in this iteration. assign_df remains unchanged.")


            # Unpersist data for the current level
            current_level_df.unpersist()

            # Update open_nodes for the next iteration
            # Only include nodes that were successfully split into internal nodes
            open_nodes = next_open
            print(f"DEBUG: open_nodes for next level: {open_nodes}")


        # --- Finalize any internal nodes that were not explicitly finalized (e.g., hit max_depth) ---
        print("\nDEBUG: === Finalizing remaining internal nodes as leaves ===")
        all_node_ids = list(self.tree.keys()) # Get keys before potential modification
        for node_id in all_node_ids:
            # Check if the node exists and is NOT already marked as a leaf
            if node_id in self.tree and not self.tree[node_id].is_leaf:
                print(f"DEBUG: Finalizing node {node_id} as a leaf (reached max_depth or other stop criteria).")
                # Need to calculate the prediction for this leaf node
                # Collect label counts for this node from the final assign_df state
                leaf_data_df = assign_df.filter(F.col("node_id") == node_id).cache()
                leaf_label_counts = leaf_data_df.groupBy("true_label").count().collect()
                leaf_data_df.unpersist() # Unpersist after collecting counts

                leaf_prediction = None
                if leaf_label_counts:
                    leaf_prediction = max(leaf_label_counts, key=lambda x: x['count'])['true_label']
                    print(f"DEBUG: Node {node_id} majority prediction: {leaf_prediction}")
                elif self._overall_majority_class is not None:
                    # Fallback to overall majority if no data at this node (shouldn't happen with correct logic)
                    leaf_prediction = self._overall_majority_class
                    print(f"DEBUG: Node {node_id} had no data, using overall majority prediction: {leaf_prediction}")
                else:
                    # Final fallback if no data and no overall majority
                    leaf_prediction = 1 # Defaulting to 1
                    print(f"DEBUG: Node {node_id} had no data and no overall majority, using default prediction: {leaf_prediction}")


                # Update the node in the tree structure
                self.tree[node_id] = self.tree[node_id]._replace(is_leaf=True, prediction=leaf_prediction, children={}) # Clear children as it's a leaf
                print(f"DEBUG: Node {node_id} marked as leaf with prediction {leaf_prediction}.")
        # --- End Finalization ---


        assign_df.unpersist() # Unpersist the final assignment DataFrame
        print("DEBUG: fit finished.")
        return self

    def _calculate_gini_impurity(self, label_counts, total_samples):
        """
        Calculates Gini impurity.

        Parameters:
        -----------
        label_counts : list of (label, count) tuples or dict {label: count}
            Counts of each label in the dataset or branch.
        total_samples : int
            Total number of samples.

        Returns:
        --------
        float : Gini impurity
        """
        if total_samples == 0:
            return 0.0

        impurity = 1.0
        # Ensure label_counts is treated as a dictionary-like structure
        if isinstance(label_counts, list):
            counts_dict = dict(label_counts)
        else:
            counts_dict = label_counts

        for label, count in counts_dict.items():
            probability_of_label = count / total_samples
            impurity -= probability_of_label ** 2

        return impurity


    def predict(self, df):
        """
        Make predictions using the trained tree.

        The tree held in `self.tree` is flattened into plain dictionaries,
        broadcast, and applied row by row through a UDF over the
        'time_series' column.

        The returned DataFrame is lazy: the UDF only runs when the caller
        triggers an action on it, so the broadcast variable is intentionally
        NOT unpersisted here.

        Parameters:
        -----------
        df : Spark DataFrame
            DataFrame with feature columns or 'time_series' column

        Returns:
        --------
        Spark DataFrame : columns row_id, time_series, true_label, prediction
        """
        print("DEBUG: predict started.")
        # Presumably converts wide feature columns into a 'time_series' array
        # and adds 'row_id' (helper defined elsewhere in the class); the final
        # select below relies on row_id, time_series and true_label existing.
        df = self._convert_to_time_series_format(df)

        # --- Convert tree structure to a plain dictionary for broadcasting ---
        print("DEBUG: Converting tree structure to plain dictionary for broadcasting.")
        plain_tree_structure = {
            node_id: {
                'node_id': node.node_id,
                'parent_id': node.parent_id,
                'split_on': node.split_on,
                'is_leaf': node.is_leaf,
                'prediction': node.prediction,
                'children': node.children,
            }
            for node_id, node in self.tree.items()
        }

        # Broadcast the plain tree structure
        print("DEBUG: Broadcasting plain tree structure for prediction.")
        plain_tree_structure_broadcast = self.spark.sparkContext.broadcast(plain_tree_structure)

        # predict_udf_func receives the broadcast handle and returns the
        # per-row Python function that is wrapped into a Spark UDF here.
        prediction_udf = F.udf(predict_udf_func(plain_tree_structure_broadcast), IntegerType())

        # Lazy transformation: nothing is computed yet.
        predictions_df = df.withColumn("prediction", prediction_udf(F.col("time_series")))

        # The broadcast is deliberately left alive. Calling unpersist() at this
        # point (as the previous version did) ran *before* any prediction was
        # computed, so it only discarded the executor copies ahead of first use
        # and forced Spark to re-send the tree when the caller's action ran.

        print("DEBUG: predict finished.")
        # Select the original columns plus the new prediction column
        return predictions_df.select("row_id", "time_series", "true_label", "prediction")


    def print_tree(self):
        """
        Build an indented, text rendering of the tree (driver-side).

        Children are listed for every node that has them, even when the node
        is flagged as a leaf, so the rendering reflects the structure that was
        actually built. Rendering starts at node 0.

        Returns:
        --------
        str : String representation of the tree
        """
        print("DEBUG: print_tree started.")

        def describe_split(split_on):
            # A node without split information is shown as the text "None".
            if not split_on:
                return "None"
            measure_type, exemplars = split_on
            # Exemplar series are printed in full (only readable for small trees).
            return f"measure={measure_type}, exemplars={dict(exemplars.items())}"

        def render(node_id, depth, lines):
            pad = "  " * depth
            if node_id not in self.tree:
                lines.append(f"{pad}Node {node_id}: Does Not Exist\n")
                return

            node = self.tree[node_id]
            status = 'LEAF' if node.is_leaf else 'INTERNAL'
            lines.append(
                f"{pad}Node {node_id} (Depth {depth}, Parent: {node.parent_id}): {status}, "
                f"prediction={node.prediction}, split_on=[{describe_split(node.split_on)}]\n"
            )

            if not node.children:
                return

            lines.append(f"{pad}  Children:\n")
            for branch_id, child_id in sorted(node.children.items()):
                lines.append(f"{pad}    Branch {branch_id} -> Child {child_id}\n")
                if child_id in self.tree:
                    # Child nodes are rendered one level deeper.
                    render(child_id, depth + 1, lines)
                else:
                    lines.append(f"{pad}      Node {child_id}: Does Not Exist\n")

        pieces = []
        render(0, 0, pieces)  # Root is node 0 at depth 0
        tree_str = "".join(pieces)
        print("DEBUG: print_tree finished.")
        return tree_str


In [10]:
def randomSplit_stratified_via_sampleBy(df, label_col, weights=(0.8, 0.2), seed=123):
    """
    Split a Spark DataFrame into train/test sets while approximately
    preserving per-class proportions.

    Every distinct value of `label_col` is sampled with the same fraction
    (`weights[0]`) using Spark's stratified sampler; the rows that were not
    drawn form the test set. Sampling is per-row, so the resulting sizes are
    close to, but not exactly, the requested proportions.

    Parameters
    ----------
    df : Spark DataFrame
        Input data containing `label_col`.
    label_col : str
        Name of the column holding the class label.
    weights : sequence of float
        Train/test proportions; must sum to 1.0. Only `weights[0]` (the train
        fraction) is used for sampling.
    seed : int
        Seed passed to the sampler.

    Returns
    -------
    (train_df, test_df) : tuple of Spark DataFrames
    """
    assert abs(sum(weights) - 1.0) < 1e-6, "weights must sum to 1.0"
    train_frac = weights[0]

    # Distinct label values, collected to the driver (one small job).
    labels = [row[label_col] for row in df.select(label_col).distinct().collect()]

    # Same sampling fraction for every label.
    fractions = {lbl: train_frac for lbl in labels}

    # Train set: Spark's native stratified sampler (map-side, no shuffle).
    train_df = df.stat.sampleBy(label_col, fractions, seed)

    # Test set: multiset difference df - train_df. Unlike a left_anti join on
    # all columns, exceptAll keeps the remaining copies of duplicated rows and
    # treats NULLs as equal, so rows containing NULL cannot end up in both sets.
    test_df = df.exceptAll(train_df)
    return train_df, test_df


In [11]:
# Local-mode Spark session ("local[*]" master); `sc` is kept as a shorthand
# for the session's SparkContext.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("testingglobal")
    .getOrCreate()
)
sc = spark.sparkContext

In [12]:
# Experiment configuration. Only "data_percentage" and "label_col" are read
# directly in this cell; the whole dict is also handed to Preprocessor.
config = {
    "databricks_data_path": "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path": "/fulldataset_ECG5000.csv",
    "label_col": "label",
    "data_percentage": 1.0,
    "min_number_iterarations": 2,

    "local_model_config": {
        "test_local_model" : True,
        "num_partitions": 10,  
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,  
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_local_model" : False,
        "num_partitions": 10
    }
}

# The dataset location is machine-specific. Set ECG5000_DATA_PATH to run the
# notebook elsewhere; the original developer path remains the default so
# existing runs behave exactly as before.
ingestion_config = {
    "data_path": os.environ.get(
        "ECG5000_DATA_PATH",
        r"D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv",
    ),
    "data_percentage": config.get("data_percentage", 0.5),
}
ingestion = DataIngestion(spark=spark, config=ingestion_config)
preprocessor = Preprocessor(config=config)

# load + preprocess data
df = ingestion.load_data()
df = preprocessor.run_preprocessing(df)

# Stratified 80/20 split on the configured label column.
train_df, test_df = randomSplit_stratified_via_sampleBy(
    df, label_col=config["label_col"], weights=[0.8, 0.2], seed=123
)

Data Path: D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv
Loading 100.0% of data
Data size: 5000

Repartitioning to 10 workers


In [13]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, ArrayType
import random
import collections
import math
import json
import traceback # Import traceback to print error details

# Assume the GlobalProxTree class definition is available in the environment
# (e.g., defined in a previous cell or imported from a file)

# Create a SparkSession (replace with your actual SparkSession if running in a cluster)
# If running in a notebook like environment, spark might already be defined.
# Reuse the session already defined in the notebook namespace; build one only
# when no `spark` name exists yet.
if "spark" not in globals():
    spark = SparkSession.builder.appName("GlobalProxTreeTest").getOrCreate()
    print("DEBUG: SparkSession created.")

# Touch the SparkContext explicitly before any DataFrame work; this was added
# as a workaround for the context sometimes not being available.
spark.sparkContext
print("DEBUG: SparkContext accessed.")


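# --- Optional: a small dummy DataFrame for testing (currently disabled) ---
# This data simulates time series with 2 features and 3 classes (1, 2 and 3).
# It is commented out; the cell below fits the tree on the real `train_df` instead.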
# # Designed to potentially create a simple split at the root
# dummy_data = [
#     (1.0, 1.1, 1), # Class 1
#     (1.2, 1.3, 1), # Class 1
#     (0.8, 0.9, 1), # Class 1
#     (5.0, 5.1, 2), # Class 2
#     (5.2, 5.3, 2), # Class 2
#     (4.8, 4.9, 2), # Class 2
#     (2.5, 2.6, 1), # Class 1 (closer to Class 1 exemplars)
#     (3.5, 3.6, 2), # Class 2 (closer to Class 2 exemplars)
#     (4.8, 4.9, 2), # Class 2
#     (7.8, 5.9, 3), # Class 3
#     (6.8, 5.9, 3), # Class 3
#     (7.8, 5.9, 3), # Class 3
#     (5.8, 5.9, 3), # Class 3
# ]

# # Define schema for the dummy data
# dummy_schema = StructType([
#     StructField("feature1", DoubleType(), True),
#     StructField("feature2", DoubleType(), True),
#     StructField("label", IntegerType(), True)
# ])

# # Create the dummy DataFrame
# # This is where the error occurred previously
# dummy_train_df = spark.createDataFrame(dummy_data, dummy_schema)

# print("Dummy Training DataFrame:")
# dummy_train_df.show()
# Note: count() triggers a full Spark job over the training data.
print(f"DataFrame shape: {len(train_df.columns)} columns, {train_df.count()} rows")


# --- Test the GlobalProxTree class ---

# Small max_depth / min_samples keep the tree shallow;
# num_candidate_splits=3 makes the split-evaluation process visible;
# num_exemplars_per_class=1 follows the paper's conceptual split.
tree = GlobalProxTree(spark, max_depth=3, min_samples=5, num_candidate_splits=3, num_exemplars_per_class=1)


# Fit directly on the wide DataFrame; per the original notes, the conversion
# to the 'time_series' layout happens inside fit().
try:
    print("\nFitting tree on training DataFrame...")
    tree.fit(train_df)

    print("\nTree structure:")
    print(tree.print_tree())

except Exception as e:
    print(f"Error while fitting tree: {e}")
    # Print the full traceback for detailed debugging
    traceback.print_exc()
else:
    # Report completion only when fitting and printing actually succeeded;
    # previously this message was printed even after an exception.
    print("\nTree fitting complete!")

DEBUG: SparkContext accessed.
DataFrame shape: 141 columns, 4031 rows

Fitting tree on dummy DataFrame...
DEBUG: fit started.
DEBUG: _convert_to_time_series_format started.
DEBUG: Converting 140 feature columns to 'time_series' array.
DEBUG: Sample of converted DataFrame:
+----------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [14]:

# # Example Prediction (uncomment and run after successful fitting)
# dummy_test_data = [
#     (1.1, 1.2, 1), # Should predict 1
#     (0.1, 1.5, 1), # Should predict 1
#     (1.9, 1.5, 2), # Should predict 1
#     (1.0, 1.6, 2), # Should predict 1
#     (5.3, 5.4, 2), # Closer to class 2 exemplars
#     (3.0, 3.1, 1), # Closer to class 2 exemplars
#     (4.0, 4.1, 2), # Closer to class 2 exemplars
#     (7.8, 5.9, 3), # shoudl predict class 3
#     (7.8, 5.9, 3), # shoudl predict class 3
#     (8.8, 5.9, 2), # shoudl predict class 3
#     (0.8, 0.9, 3), # shoudl predict class 1
# ]
# dummy_test_df = spark.createDataFrame(dummy_test_data, dummy_schema)

print("\nTesting Prediction...")
try:
    # Score the test split, expose the ground truth under the column name the
    # evaluator is configured with, and cast both columns to double.
    pred_df = (
        tree.predict(test_df)
        .withColumnRenamed("true_label", "label")
        .withColumn("prediction", F.col("prediction").cast(DoubleType()))
        .withColumn("label", F.col("label").cast(DoubleType()))
    )

    print("\n True Labels vs Sample Predictions:")
    pred_df.select("label", "prediction").show(10)

    # Accuracy over the scored test split.
    evaluator = MulticlassClassificationEvaluator(
        metricName="accuracy",
        labelCol="label",
        predictionCol="prediction",
    )
    acc_evaluator = evaluator.evaluate(pred_df)
    print(f"Prediction Accuracy = {acc_evaluator:.3f}")

except Exception as err:
    print(f"Error during prediction: {err}")
    traceback.print_exc()





Testing Prediction...
DEBUG: predict started.
DEBUG: _convert_to_time_series_format started.
DEBUG: Converting 140 feature columns to 'time_series' array.
DEBUG: Sample of converted DataFrame:
+----------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Traceback (most recent call last):
  File "C:\Users\Petru\AppData\Local\Temp\ipykernel_3996\654261130.py", line 40, in <module>
    acc_evaluator = evaluator.evaluate(pred_df)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Petru\anaconda3\envs\bigdata_env\Lib\site-packages\pyspark\ml\evaluation.py", line 111, in evaluate
    return self._evaluate(dataset)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Petru\anaconda3\envs\bigdata_env\Lib\site-packages\pyspark\ml\evaluation.py", line 148, in _evaluate
    return self._java_obj.evaluate(dataset._jdf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Petru\anaconda3\envs\bigdata_env\Lib\site-packages\py4j\java_gateway.py", line 1322, in __call__
    return_value = get_return_value(
                   ^^^^^^^^^^^^^^^^^
  File "c:\Users\Petru\anaconda3\envs\bigdata_env\Lib\site-packages\pyspark\errors\exceptions\captured.py", line 179, in deco
    return f(*a, **kw)
           ^^^^^^^^^^^
  File "c:\Us

In [15]:
# # Note: If you created the SparkSession here, you might want to stop it
# spark.stop()