# IMPORTS

In [15]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging
from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw
import collections

# CREATE SPARK SESSION

In [16]:
spark = SparkSession.builder.master("local[*]").appName("testingglobal").getOrCreate()
sc = spark.sparkContext

# CREATE CLASS INSTANCES

In [17]:
config = {
    "databricks_data_path": "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path": "/fulldataset_ECG5000.csv",
    "label_col": "label",
    "data_percentage": 1.0,
    "min_number_iterarations": 2,

    "local_model_config": {
        "test_local_model" : True,
        "num_partitions": 10,  
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,  
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_local_model" : False,
        "num_partitions": 10
    }
}

ingestion_config = {
                "data_path":r"C:\Users\benat\OneDrive\0. MSc MLiS\0. GitHub Repositories\BigDataProject_repo\fulldataset_ECG5000.csv",
                "data_percentage": config.get("data_percentage", 0.5)
}

In [18]:
ingestion = DataIngestion(spark=spark, config=ingestion_config)
preprocessor = Preprocessor(config=config)

# CODE

In [19]:
def repartition_sparkdf(df, num_partitions):
    rdd = df.rdd
    rdd = rdd.repartition(num_partitions)
    rdd = rdd.mapPartitionsWithIndex(
            lambda idx, iter: [{**row.asDict(), "partition_id": idx} for row in iter]
        )
    return rdd

def choose_exemplars(iterator):
    partition_data = list(iterator)
    if not partition_data:
        return iter([])
    
    # Group data by class
    grouped_data_by_class = {}
    for row in partition_data:
        label = row.get('label')
        if label is not None:
            grouped_data_by_class.setdefault(label, []).append(row)
    
    # Select one exemplar per class and track exemplar-to-label mapping
    chosen_exemplars = []
    exemplar_labels = {}  # Map to track exemplar labels
    
    for label, instances in grouped_data_by_class.items():
        if instances:  # Ensure there are instances for the class
            exemplar = sample(instances, 1)[0]
            chosen_exemplars.append((exemplar['time_series'], label))  # Store tuple of (time_series, label)
    
    # Remove chosen exemplars from the working data
    exemplar_time_series = [ex[0] for ex in chosen_exemplars]
    filtered_partition = [
        row for row in partition_data
        if row['time_series'] not in exemplar_time_series
    ]
    
    # Return rows with individual exemplar columns and their labels
    result = []
    for row in filtered_partition:
        new_row = {**row}
        # Add each exemplar and its label as columns
        for i, (exemplar, label) in enumerate(chosen_exemplars, 1):
            new_row[f"exemplar_{i}"] = exemplar
            new_row[f"exemplar_{i}_label"] = label
        result.append(new_row)
    
    return iter(result)

def calc_dtw_distance(iterator):
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        time_series = row.get('time_series', [])
        
        # Check for individual exemplar columns (exemplar_1, exemplar_2, etc.)
        exemplar_columns = {k: v for k, v in row.items() 
                           if k.startswith('exemplar_') and not k.endswith('_label') and isinstance(v, list)}
        
        if not exemplar_columns:
            # Try to get exemplars from the 'exemplars' list if individual columns aren't found
            exemplars = row.get('exemplars', [])
            if not exemplars:
                continue  # Skip if no exemplars found
            
            # Calculate DTW distances for each exemplar in the list
            updated_row = {**row}
            for i, exemplar in enumerate(exemplars):
                dtw_distance = dtw.distance(time_series, exemplar)
                updated_row[f"dtw_distance_exemplar_{i+1}"] = dtw_distance
            
            updated_rows.append(updated_row)
        else:
            # Calculate DTW distances for each exemplar column
            updated_row = {**row}
            for col_name, exemplar in exemplar_columns.items():
                # Extract index from column name (e.g., "exemplar_1" -> "1")
                idx = col_name.split('_')[1]
                dtw_distance = dtw.distance(time_series, exemplar)
                updated_row[f"dtw_distance_exemplar_{idx}"] = dtw_distance
            
            updated_rows.append(updated_row)
    
    return iter(updated_rows)

def assign_closest_exemplar(iterator):
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        # Get all DTW distances
        dtw_distances = {k: v for k, v in row.items() if k.startswith('dtw_distance_exemplar_')}
        
        if not dtw_distances:
            # Create a simplified row without exemplar columns
            simplified_row = {k: v for k, v in row.items() 
                             if not k.startswith('exemplar_')}
            updated_rows.append(simplified_row)
            continue
        
        # Find the closest exemplar based on the minimum DTW distance
        closest_exemplar_key = min(dtw_distances, key=dtw_distances.get)
        min_distance = dtw_distances[closest_exemplar_key]
        
        # Extract exemplar number from the key (e.g., "dtw_distance_exemplar_1" -> "1")
        exemplar_num = closest_exemplar_key.split('_')[-1]
        
        # Get the corresponding exemplar time series data and its original label
        exemplar_key = f'exemplar_{exemplar_num}'
        exemplar_label_key = f'exemplar_{exemplar_num}_label'
        exemplar_time_series = row.get(exemplar_key, None)
        exemplar_original_label = row.get(exemplar_label_key, None)
        
        # Create a new row without the DTW distance columns and exemplar columns
        updated_row = {k: v for k, v in row.items() 
                      if not k.startswith('dtw_distance_exemplar_') and not k.startswith('exemplar_')}
        
        # Add information about the closest exemplar
        updated_row['closest_exemplar_id'] = closest_exemplar_key
        updated_row['closest_exemplar_data'] = exemplar_time_series
        updated_row['closest_exemplar_original_label'] = exemplar_original_label
        
        updated_rows.append(updated_row)
    
    return iter(updated_rows)

def calculate_gini(labels):
    if not labels:
        return 0
    label_counts = {}
    for label in labels:
        label_counts[label] = label_counts.get(label, 0) + 1
    total = sum(label_counts.values())
    gini = 1 - sum((count / total) ** 2 for count in label_counts.values()) if total > 0 else 0
    return gini

# Define tree node structure
class TreeNode:
    def __init__(self, node_id=None, split_exemplar_id=None, split_exemplar_data=None, 
                 split_exemplar_label=None, is_leaf=False, predicted_label=None, gini_reduction=None,
                 partition_id=None):
        self.node_id = node_id
        self.split_exemplar_id = split_exemplar_id
        self.split_exemplar_data = split_exemplar_data
        self.split_exemplar_label = split_exemplar_label
        self.gini_reduction = gini_reduction
        self.is_leaf = is_leaf
        self.predicted_label = predicted_label
        self.yes_child = None  # Instances closest to this exemplar
        self.no_child = None   # Instances closest to other exemplars
        self.partition_id = partition_id  # Track which partition built this tree
    
    def to_dict(self):
        """Convert tree to dictionary for serialization"""
        result = {
            'node_id': self.node_id,
            'split_exemplar_id': self.split_exemplar_id,
            'split_exemplar_data': self.split_exemplar_data,
            'split_exemplar_label': self.split_exemplar_label,
            'gini_reduction': self.gini_reduction,
            'is_leaf': self.is_leaf,
            'predicted_label': self.predicted_label,
            'partition_id': self.partition_id,
        }
        
        if self.yes_child:
            result['yes_child'] = self.yes_child.to_dict()
        if self.no_child:
            result['no_child'] = self.no_child.to_dict()
        
        return result

# Find the best split from all evaluated splits (within a single partition)
def find_best_split(all_splits):
    if not all_splits:
        return None
    return max(all_splits, key=lambda x: x['gini_reduction'])

# Build a tree for a single partition (locally, not distributed)
def build_partition_tree(partition_data, partition_id, max_depth=3, min_samples=2, node_id=1):
    # If too few samples or max depth reached, create a leaf node
    if len(partition_data) < min_samples or max_depth <= 0:
        # Determine the majority class for prediction
        label_counts = collections.Counter([row.get('label') for row in partition_data if row.get('label') is not None])
        if not label_counts:
            return TreeNode(node_id=node_id, is_leaf=True, predicted_label=None, partition_id=partition_id)
        
        majority_class = label_counts.most_common(1)[0][0]
        return TreeNode(node_id=node_id, is_leaf=True, predicted_label=majority_class, partition_id=partition_id)
    
    # Evaluate all possible splits (reuse your existing function logic)
    labels = [row.get('label') for row in partition_data if row.get('label') is not None]
    before_split_gini = calculate_gini(labels)
    
    # Get all unique exemplars in the partition
    unique_exemplars = set(
        row['closest_exemplar_id'] for row in partition_data
        if row.get('closest_exemplar_id') is not None
    )
    
    # Create a mapping of exemplar_id to exemplar_data and exemplar_label
    exemplar_data_map = {}
    exemplar_label_map = {}
    
    for row in partition_data:
        exemplar_id = row.get('closest_exemplar_id')
        if exemplar_id and exemplar_id not in exemplar_data_map:
            exemplar_data_map[exemplar_id] = row.get('closest_exemplar_data')
            exemplar_label_map[exemplar_id] = row.get('closest_exemplar_original_label')
    
    # Evaluate all possible splits
    splits = []
    for exemplar_id in unique_exemplars:
        # Split the data based on the current exemplar
        yes_split = [r for r in partition_data if r.get('closest_exemplar_id') == exemplar_id]
        no_split = [r for r in partition_data if r.get('closest_exemplar_id') != exemplar_id]
        
        # Calculate Gini for each daughter node
        yes_labels = [r.get('label') for r in yes_split if r.get('label') is not None]
        no_labels = [r.get('label') for r in no_split if r.get('label') is not None]
        
        yes_gini = calculate_gini(yes_labels)
        no_gini = calculate_gini(no_labels)
        
        # Calculate weighted Gini after split
        total_size = len(yes_split) + len(no_split)
        weighted_gini = (yes_gini * len(yes_split) / total_size + no_gini * len(no_split) / total_size) if total_size > 0 else float('inf')
        
        # Calculate Gini reduction
        gini_reduction = before_split_gini - weighted_gini if total_size > 0 else float('-inf')
        
        # Get the exemplar data and label for this exemplar_id
        exemplar_data = exemplar_data_map.get(exemplar_id)
        exemplar_label = exemplar_label_map.get(exemplar_id)
        
        # Add this split evaluation to results
        splits.append({
            "partition_id": partition_id,
            "exemplar_id": exemplar_id,
            "exemplar_data": exemplar_data,
            "exemplar_label": exemplar_label,
            "gini_reduction": gini_reduction   
        })
    
    # Find best split
    best_split = find_best_split(splits)
    
    # If no good split found, create a leaf node
    if not best_split or best_split['gini_reduction'] <= 0:
        label_counts = collections.Counter([row.get('label') for row in partition_data if row.get('label') is not None])
        if not label_counts:
            return TreeNode(node_id=node_id, is_leaf=True, predicted_label=None, partition_id=partition_id)
        
        majority_class = label_counts.most_common(1)[0][0]
        return TreeNode(node_id=node_id, is_leaf=True, predicted_label=majority_class, partition_id=partition_id)
    
    # Create a decision node
    node = TreeNode(
        node_id=node_id,
        split_exemplar_id=best_split['exemplar_id'],
        split_exemplar_data=best_split['exemplar_data'],
        split_exemplar_label=best_split['exemplar_label'],
        gini_reduction=best_split['gini_reduction'],
        partition_id=partition_id
    )
    
    # Split the data based on the best exemplar
    yes_data = [r for r in partition_data if r.get('closest_exemplar_id') == best_split['exemplar_id']]
    no_data = [r for r in partition_data if r.get('closest_exemplar_id') != best_split['exemplar_id']]
    
    # Recursively build the subtrees
    node.yes_child = build_partition_tree(yes_data, partition_id, max_depth-1, min_samples, node_id=node_id*2)
    node.no_child = build_partition_tree(no_data, partition_id, max_depth-1, min_samples, node_id=node_id*2+1)
    
    return node

# Function to visualize a single tree
def visualize_tree(node, indent=""):
    if node is None:
        return
    
    if node.is_leaf:
        print(f"{indent}Leaf: Predict class {node.predicted_label}")
    else:
        # Format the exemplar data nicely for display
        exemplar_data_str = "[" + ", ".join(f"{val:.1f}" for val in node.split_exemplar_data) + "]"
        
        print(f"{indent}Is closest to exemplar {node.split_exemplar_id} (class {node.split_exemplar_label})?")
        print(f"{indent}Data: {exemplar_data_str}, Gini reduction: {node.gini_reduction:.4f}")
        
        print(f"{indent}Yes ->")
        visualize_tree(node.yes_child, indent + "  ")
        
        print(f"{indent}No ->")
        visualize_tree(node.no_child, indent + "  ")

# Function to predict with a single tree
def predict_with_tree(row, tree):
    current_node = tree
    
    while not current_node.is_leaf:
        closest_exemplar = row.get('closest_exemplar_id')
        
        if closest_exemplar == current_node.split_exemplar_id:
            current_node = current_node.yes_child
        else:
            current_node = current_node.no_child
    
    return current_node.predicted_label

# RUNNING THE CODE

In [20]:
# Main function to run the full pipeline
def main():
    # Initialize Spark session
    spark = SparkSession.builder.appName("ExemplarTreeEnsemble").getOrCreate()
    
    # load + preprocess data
    df = ingestion.load_data()
    df = preprocessor.run_preprocessing(df)
    
    # Repartition data
    num_partitions = 10
    rdd = repartition_sparkdf(df, num_partitions)
    
    # Choose exemplars
    rdd_with_exemplars = rdd.mapPartitions(choose_exemplars)
    
    # Calculate DTW distances
    rdd_with_dtw = rdd_with_exemplars.mapPartitions(calc_dtw_distance)
    
    # Assign closest exemplars
    rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
    
    # Collect all data (for our small test dataset, this is fine)
    all_data = rdd_with_closest_exemplar.collect()
    
    # Group by partition
    partitions = {}
    for row in all_data:
        pid = row['partition_id']
        if pid not in partitions:
            partitions[pid] = []
        partitions[pid].append(row)
    
    # Build trees for each partition
    trees = []
    for pid, data in partitions.items():
        if data:  # Make sure there's data in this partition
            tree = build_partition_tree(data, pid, max_depth=3, min_samples=2)
            trees.append(tree)
    
    # Visualize each tree
    for i, tree in enumerate(trees):
        print(f"\nDecision Tree for Partition {tree.partition_id}:")
        visualize_tree(tree)
    
    # Make predictions and evaluate accuracy
    correct = 0
    total = 0
    
    for row in all_data:
        # Get predictions from all trees
        predictions = [predict_with_tree(row, tree) for tree in trees]
        # Take majority vote
        if predictions:
            counter = collections.Counter(predictions)
            ensemble_prediction = counter.most_common(1)[0][0]
            
            # Update accuracy counts
            total += 1
            if ensemble_prediction == row.get('label'):
                correct += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"\nEnsemble Accuracy: {accuracy:.4f} ({correct}/{total})")
    
    # Save the ensemble for future use (optional)
    with open("exemplar_tree_ensemble.json", "w") as f:
        json.dump([tree.to_dict() for tree in trees], f, indent=2)
    
    spark.stop()

# Run the main function
if __name__ == "__main__":
    main()

Data Path: C:\Users\benat\OneDrive\0. MSc MLiS\0. GitHub Repositories\BigDataProject_repo\fulldataset_ECG5000.csv
Loading 100.0% of data
Data size: 5000

Repartitioning to 10 workers


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 41.0 failed 1 times, most recent failure: Lost task 3.0 in stage 41.0 (TID 68) (BenAtkinson-Dell-Inspiron3505.Home executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "C:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\python\lib\pyspark.zip\pyspark\worker.py", line 1247, in main
    process()
  File "C:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\python\lib\pyspark.zip\pyspark\worker.py", line 1237, in process
    out_iter = func(split_index, iterator)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 5434, in pipeline_func
    return func(split, prev_func(split, iterator))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 5434, in pipeline_func
    return func(split, prev_func(split, iterator))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 5434, in pipeline_func
    return func(split, prev_func(split, iterator))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 840, in func
    return f(iterator)
           ^^^^^^^^^^^
  File "C:\Users\benat\AppData\Local\Temp\ipykernel_30224\2873670571.py", line 28, in choose_exemplars
KeyError: 'time_series'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2433)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:842)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:195)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:842)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "C:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\python\lib\pyspark.zip\pyspark\worker.py", line 1247, in main
    process()
  File "C:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\python\lib\pyspark.zip\pyspark\worker.py", line 1237, in process
    out_iter = func(split_index, iterator)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 5434, in pipeline_func
    return func(split, prev_func(split, iterator))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 5434, in pipeline_func
    return func(split, prev_func(split, iterator))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 5434, in pipeline_func
    return func(split, prev_func(split, iterator))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\benat\miniconda3\envs\bigdata_env\Lib\site-packages\pyspark\rdd.py", line 840, in func
    return f(iterator)
           ^^^^^^^^^^^
  File "C:\Users\benat\AppData\Local\Temp\ipykernel_30224\2873670571.py", line 28, in choose_exemplars
KeyError: 'time_series'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2433)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	... 1 more


---

In [24]:
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import array, col
import logging
import os
import time
import json
import collections
from dtaidistance import dtw
import traceback

# Initialize Spark session
spark = SparkSession.builder.master("local[*]").appName("ExemplarTreeEnsemble").getOrCreate()
sc = spark.sparkContext

# Function to convert wide format data to features array
def convert_wide_format_to_features(df):
    """
    Convert a wide-format DataFrame with columns _c1, _c2, etc. to a DataFrame with 'features' array column
    
    Args:
        df: PySpark DataFrame in wide format (each time point is a separate column)
        
    Returns:
        PySpark DataFrame with 'features' and 'label' columns
    """
    # Get all feature columns (excluding label)
    feature_cols = [col_name for col_name in df.columns if col_name != 'label']
    
    # Create a features array column by combining all individual columns
    df_with_features = df.select(
        array(*[col(c) for c in feature_cols]).alias("features"),
        col("label")
    )
    
    return df_with_features

def repartition_sparkdf(df, num_partitions):
    """Repartition a DataFrame and add partition_id"""
    # First convert to RDD and repartition
    rdd = df.rdd.repartition(num_partitions)
    
    # Add partition_id to each row
    def add_partition_id(idx, iterator):
        return ({**row.asDict(), "partition_id": idx} for row in iterator)
    
    return rdd.mapPartitionsWithIndex(add_partition_id)

def choose_exemplars(iterator):
    """Select one exemplar per class in each partition"""
    partition_data = list(iterator)
    if not partition_data:
        return iter([])
    
    # Check the structure of the data to find the correct time series field
    # It could be 'features', or another field
    if partition_data and len(partition_data) > 0:
        sample_row = partition_data[0]
        time_series_field = None
        
        # Find which field contains the time series data
        for key, value in sample_row.items():
            if isinstance(value, (list, np.ndarray)) and key != 'partition_id':
                time_series_field = key
                break
        
        if time_series_field is None:
            print(f"Warning: Cannot find time series data in row: {sample_row}")
            # If we can't identify the time series field, return the data unchanged
            return iter(partition_data)
    else:
        return iter([])
    
    # Group data by class
    grouped_data_by_class = {}
    for row in partition_data:
        label = row.get('label')
        if label is not None:
            if label not in grouped_data_by_class:
                grouped_data_by_class[label] = []
            grouped_data_by_class[label].append(row)
    
    # Select one exemplar per class
    chosen_exemplars = []
    
    for label, instances in grouped_data_by_class.items():
        if instances:  # Ensure there are instances for the class
            # Using first instance as exemplar (more deterministic than random sampling)
            exemplar = instances[0]
            time_series_data = exemplar.get(time_series_field)
            chosen_exemplars.append((time_series_data, label))
    
    # Store exemplar time_series for filtering
    exemplar_time_series = [ex[0] for ex in chosen_exemplars]
    
    # Remove chosen exemplars from the working data
    filtered_partition = []
    for row in partition_data:
        row_time_series = row.get(time_series_field)
        keep = True
        for ex_ts in exemplar_time_series:
            # Compare time series - we need to handle both list and numpy array formats
            if isinstance(row_time_series, (list, np.ndarray)) and isinstance(ex_ts, (list, np.ndarray)):
                # For lists, check if they have the same values
                if len(row_time_series) == len(ex_ts) and all(x == y for x, y in zip(row_time_series, ex_ts)):
                    keep = False
                    break
        if keep:
            filtered_partition.append(row)
    
    # Return rows with individual exemplar columns and their labels
    result = []
    for row in filtered_partition:
        new_row = {**row}
        # Add each exemplar and its label as columns
        for i, (exemplar, label) in enumerate(chosen_exemplars, 1):
            new_row[f"exemplar_{i}"] = exemplar
            new_row[f"exemplar_{i}_label"] = label
        result.append(new_row)
    
    return iter(result)

def calc_dtw_distance(iterator):
    """Calculate DTW distance between time series and exemplars"""
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        # Find the time series field (could be 'features' or other field name)
        time_series_field = None
        time_series = None
        
        for key, value in row.items():
            if isinstance(value, (list, np.ndarray)) and key != 'partition_id' and not key.startswith('exemplar_'):
                time_series_field = key
                time_series = value
                break
        
        if time_series_field is None or time_series is None:
            # If we can't find the time series data, keep the row but skip distance calculation
            updated_rows.append(row)
            continue
        
        # Check for individual exemplar columns
        exemplar_columns = {k: v for k, v in row.items() 
                          if k.startswith('exemplar_') and not k.endswith('_label') and isinstance(v, (list, np.ndarray))}
        
        if not exemplar_columns:
            # If no exemplars found, skip distance calculation but keep the row
            updated_rows.append(row)
            continue
        
        # Calculate DTW distances for each exemplar column
        updated_row = {**row}
        for col_name, exemplar in exemplar_columns.items():
            # Extract index from column name (e.g., "exemplar_1" -> "1")
            idx = col_name.split('_')[1]
            
            # Safely calculate DTW distance with error handling
            try:
                # Ensure both arrays have same dimensions and data types
                ts1 = np.array(time_series, dtype=float)
                ts2 = np.array(exemplar, dtype=float)
                
                # Handle potential empty arrays
                if len(ts1) == 0 or len(ts2) == 0:
                    dtw_distance = float('inf')
                else:
                    # Use DTW distance function from dtaidistance
                    dtw_distance = dtw.distance(ts1, ts2)
                    
                updated_row[f"dtw_distance_exemplar_{idx}"] = dtw_distance
            except Exception as e:
                # Handle any errors gracefully
                updated_row[f"dtw_distance_exemplar_{idx}"] = float('inf')
                print(f"DTW calculation error for {time_series_field}: {e}")
        
        updated_rows.append(updated_row)
    
    return iter(updated_rows)

def assign_closest_exemplar(iterator):
    """Assign each instance to its closest exemplar"""
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        # Get all DTW distances
        dtw_distances = {k: v for k, v in row.items() if k.startswith('dtw_distance_exemplar_')}
        
        if not dtw_distances:
            # Create a simplified row without exemplar columns
            simplified_row = {k: v for k, v in row.items() 
                             if not k.startswith('exemplar_')}
            updated_rows.append(simplified_row)
            continue
        
        # Find the closest exemplar based on the minimum DTW distance
        try:
            closest_exemplar_key = min(dtw_distances.items(), key=lambda x: x[1])[0]
            min_distance = dtw_distances[closest_exemplar_key]
        except ValueError:
            # Handle case with no valid distances
            simplified_row = {k: v for k, v in row.items() 
                             if not k.startswith('exemplar_') and not k.startswith('dtw_distance_exemplar_')}
            updated_rows.append(simplified_row)
            continue
        
        # Extract exemplar number from the key
        exemplar_num = closest_exemplar_key.split('_')[-1]
        
        # Get the corresponding exemplar time series data and its original label
        exemplar_key = f'exemplar_{exemplar_num}'
        exemplar_label_key = f'exemplar_{exemplar_num}_label'
        exemplar_time_series = row.get(exemplar_key, None)
        exemplar_original_label = row.get(exemplar_label_key, None)
        
        # Create a new row without the DTW distance columns and exemplar columns
        updated_row = {k: v for k, v in row.items() 
                      if not k.startswith('dtw_distance_exemplar_') and not k.startswith('exemplar_')}
        
        # Add information about the closest exemplar
        updated_row['closest_exemplar_id'] = exemplar_key
        updated_row['closest_exemplar_data'] = exemplar_time_series
        updated_row['closest_exemplar_original_label'] = exemplar_original_label
        updated_row['min_dtw_distance'] = min_distance
        
        updated_rows.append(updated_row)
    
    return iter(updated_rows)

def calculate_gini(labels):
    """Calculate Gini impurity for a list of labels"""
    if not labels:
        return 0
    label_counts = {}
    for label in labels:
        label_counts[label] = label_counts.get(label, 0) + 1
    total = sum(label_counts.values())
    gini = 1 - sum((count / total) ** 2 for count in label_counts.values()) if total > 0 else 0
    return gini

# Define tree node structure
class TreeNode:
    def __init__(self, node_id=None, split_exemplar_id=None, split_exemplar_data=None, 
                 split_exemplar_label=None, is_leaf=False, predicted_label=None, gini_reduction=None,
                 partition_id=None):
        self.node_id = node_id
        self.split_exemplar_id = split_exemplar_id
        self.split_exemplar_data = split_exemplar_data
        self.split_exemplar_label = split_exemplar_label
        self.gini_reduction = gini_reduction
        self.is_leaf = is_leaf
        self.predicted_label = predicted_label
        self.yes_child = None  # Instances closest to this exemplar
        self.no_child = None   # Instances closest to other exemplars
        self.partition_id = partition_id  # Track which partition built this tree
    
    def to_dict(self):
        """Convert tree to dictionary for serialization"""
        result = {
            'node_id': self.node_id,
            'split_exemplar_id': self.split_exemplar_id,
            'split_exemplar_label': self.split_exemplar_label,
            'gini_reduction': self.gini_reduction,
            'is_leaf': self.is_leaf,
            'predicted_label': self.predicted_label,
            'partition_id': self.partition_id,
        }
        
        # Handle split_exemplar_data separately to make it JSON serializable
        if self.split_exemplar_data is not None:
            if isinstance(self.split_exemplar_data, list):
                result['split_exemplar_data'] = self.split_exemplar_data
            else:
                # Convert numpy arrays or other types to list
                try:
                    result['split_exemplar_data'] = list(self.split_exemplar_data)
                except:
                    result['split_exemplar_data'] = None
        
        if self.yes_child:
            result['yes_child'] = self.yes_child.to_dict()
        if self.no_child:
            result['no_child'] = self.no_child.to_dict()
        
        return result

# Find the best split from all evaluated splits (within a single partition)
def find_best_split(all_splits):
    if not all_splits:
        return None
    return max(all_splits, key=lambda x: x['gini_reduction'])

# Build a tree for a single partition (locally, not distributed)
def build_partition_tree(partition_data, partition_id, max_depth=3, min_samples=2, node_id=1):
    # If too few samples or max depth reached, create a leaf node
    if len(partition_data) < min_samples or max_depth <= 0:
        # Determine the majority class for prediction
        label_counts = collections.Counter([row.get('label') for row in partition_data if row.get('label') is not None])
        if not label_counts:
            return TreeNode(node_id=node_id, is_leaf=True, predicted_label=None, partition_id=partition_id)
        
        majority_class = label_counts.most_common(1)[0][0]
        return TreeNode(node_id=node_id, is_leaf=True, predicted_label=majority_class, partition_id=partition_id)
    
    # Evaluate all possible splits
    labels = [row.get('label') for row in partition_data if row.get('label') is not None]
    before_split_gini = calculate_gini(labels)
    
    # Get all unique exemplars in the partition
    unique_exemplars = set(
        row['closest_exemplar_id'] for row in partition_data
        if row.get('closest_exemplar_id') is not None
    )
    
    # Create a mapping of exemplar_id to exemplar_data and exemplar_label
    exemplar_data_map = {}
    exemplar_label_map = {}
    
    for row in partition_data:
        exemplar_id = row.get('closest_exemplar_id')
        if exemplar_id and exemplar_id not in exemplar_data_map:
            exemplar_data_map[exemplar_id] = row.get('closest_exemplar_data')
            exemplar_label_map[exemplar_id] = row.get('closest_exemplar_original_label')
    
    # Evaluate all possible splits
    splits = []
    for exemplar_id in unique_exemplars:
        # Split the data based on the current exemplar
        yes_split = [r for r in partition_data if r.get('closest_exemplar_id') == exemplar_id]
        no_split = [r for r in partition_data if r.get('closest_exemplar_id') != exemplar_id]
        
        # Calculate Gini for each daughter node
        yes_labels = [r.get('label') for r in yes_split if r.get('label') is not None]
        no_labels = [r.get('label') for r in no_split if r.get('label') is not None]
        
        yes_gini = calculate_gini(yes_labels)
        no_gini = calculate_gini(no_labels)
        
        # Calculate weighted Gini after split
        total_size = len(yes_split) + len(no_split)
        weighted_gini = (yes_gini * len(yes_split) / total_size + 
                         no_gini * len(no_split) / total_size) if total_size > 0 else float('inf')
        
        # Calculate Gini reduction
        gini_reduction = before_split_gini - weighted_gini if total_size > 0 else float('-inf')
        
        # Get the exemplar data and label for this exemplar_id
        exemplar_data = exemplar_data_map.get(exemplar_id)
        exemplar_label = exemplar_label_map.get(exemplar_id)
        
        # Add this split evaluation to results
        splits.append({
            "partition_id": partition_id,
            "exemplar_id": exemplar_id,
            "exemplar_data": exemplar_data,
            "exemplar_label": exemplar_label,
            "gini_reduction": gini_reduction   
        })
    
    # Find best split
    best_split = find_best_split(splits)
    
    # If no good split found, create a leaf node
    if not best_split or best_split['gini_reduction'] <= 0:
        label_counts = collections.Counter([row.get('label') for row in partition_data if row.get('label') is not None])
        if not label_counts:
            return TreeNode(node_id=node_id, is_leaf=True, predicted_label=None, partition_id=partition_id)
        
        majority_class = label_counts.most_common(1)[0][0]
        return TreeNode(node_id=node_id, is_leaf=True, predicted_label=majority_class, partition_id=partition_id)
    
    # Create a decision node
    node = TreeNode(
        node_id=node_id,
        split_exemplar_id=best_split['exemplar_id'],
        split_exemplar_data=best_split['exemplar_data'],
        split_exemplar_label=best_split['exemplar_label'],
        gini_reduction=best_split['gini_reduction'],
        partition_id=partition_id
    )
    
    # Split the data based on the best exemplar
    yes_data = [r for r in partition_data if r.get('closest_exemplar_id') == best_split['exemplar_id']]
    no_data = [r for r in partition_data if r.get('closest_exemplar_id') != best_split['exemplar_id']]
    
    # Recursively build the subtrees
    node.yes_child = build_partition_tree(yes_data, partition_id, max_depth-1, min_samples, node_id=node_id*2)
    node.no_child = build_partition_tree(no_data, partition_id, max_depth-1, min_samples, node_id=node_id*2+1)
    
    return node

# Function to visualize a single tree
def visualize_tree(node, indent=""):
    if node is None:
        return
    
    if node.is_leaf:
        print(f"{indent}Leaf: Predict class {node.predicted_label}")
    else:
        # Format the exemplar data nicely for display
        exemplar_data_str = "[...]"  # Simplified display for potentially large time series
        if node.split_exemplar_data and len(node.split_exemplar_data) <= 10:
            exemplar_data_str = "[" + ", ".join(f"{val:.1f}" for val in node.split_exemplar_data[:10]) + "]"
        
        print(f"{indent}Is closest to exemplar {node.split_exemplar_id} (class {node.split_exemplar_label})?")
        print(f"{indent}Gini reduction: {node.gini_reduction:.4f}")
        
        print(f"{indent}Yes ->")
        visualize_tree(node.yes_child, indent + "  ")
        
        print(f"{indent}No ->")
        visualize_tree(node.no_child, indent + "  ")

# Function to predict with a single tree
def predict_with_tree(row, tree):
    current_node = tree
    
    while current_node and not current_node.is_leaf:
        closest_exemplar = row.get('closest_exemplar_id')
        
        if closest_exemplar == current_node.split_exemplar_id:
            current_node = current_node.yes_child
        else:
            current_node = current_node.no_child
        
        # Safety check in case the tree structure is incomplete
        if current_node is None:
            return None
    
    return current_node.predicted_label if current_node else None

def main():
    try:
        # Initialize configuration
        config = {
            "data_path": "fulldataset_ECG5000.csv",  # Update with your actual path
            "label_col": "label",
            "num_partitions": 5,
            "max_depth": 3,
            "min_samples_split": 2
        }
        
        # Print all available files in the current directory to help debug file path issues
        import os
        print("Files in current directory:")
        for file in os.listdir('.'):
            print(f"  - {file}")
        
        print("Step 1: Loading and preprocessing data...")
        
        df = ingestion.load_data()
        df = preprocessor.run_preprocessing(df)
            
        print("Step 2: Repartitioning data...")
        # Prepare data for exemplar tree algorithm
        rdd = repartition_sparkdf(df, config["num_partitions"])
        
        print("Step 3: Choosing exemplars in each partition...")
        # Choose exemplars
        rdd_with_exemplars = rdd.mapPartitions(choose_exemplars)
        
        print("Step 4: Calculating DTW distances...")
        # Calculate DTW distances
        rdd_with_dtw = rdd_with_exemplars.mapPartitions(calc_dtw_distance)
        
        print("Step 5: Assigning closest exemplars...")
        # Assign closest exemplars
        rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
        
        print("Step 6: Collecting results...")
        # Apply take() instead of collect() to limit the amount of data transferred
        # This helps avoid memory issues with large datasets
        all_data = rdd_with_closest_exemplar.take(500)  # Take a sample first to test
        
        print(f"Collected {len(all_data)} rows")
        
        if len(all_data) == 0:
            print("WARNING: No data was collected after processing. Check for errors in the previous steps.")
            return
            
        # Print a sample row to help debug
        if all_data:
            print("Sample processed row structure:")
            sample_row = all_data[0]
            print(f"Keys: {list(sample_row.keys())}")
            for key, value in sample_row.items():
                if isinstance(value, (list, np.ndarray)):
                    print(f"{key}: [Array of length {len(value)}]")
                else:
                    print(f"{key}: {value}")
        
        print("Step 7: Grouping by partition...")
        # Group by partition
        partitions = {}
        for row in all_data:
            pid = row.get('partition_id')
            if pid is not None:
                if pid not in partitions:
                    partitions[pid] = []
                partitions[pid].append(row)
        
        print(f"Number of partitions with data: {len(partitions)}")
        
        print("Step 8: Building trees for each partition...")
        # Build trees for each partition
        trees = []
        for pid, data in partitions.items():
            if data:  # Make sure there's data in this partition
                print(f"Building tree for partition {pid} with {len(data)} rows...")
                tree = build_partition_tree(
                    data, 
                    pid, 
                    max_depth=config["max_depth"], 
                    min_samples=config["min_samples_split"]
                )
                trees.append(tree)
        
        print(f"Built {len(trees)} trees")
        
        print("Step 9: Visualizing sample trees...")
        # Visualize each tree (limit to first 2 trees for brevity)
        for i, tree in enumerate(trees[:2]):
            print(f"\nDecision Tree for Partition {tree.partition_id}:")
            visualize_tree(tree)
        
        print("Step 10: Making predictions and evaluating accuracy...")
        # Make predictions and evaluate accuracy
        correct = 0
        total = 0
        
        for row in all_data:
            # Get predictions from all trees
            predictions = [predict_with_tree(row, tree) for tree in trees]
            # Remove None predictions
            predictions = [p for p in predictions if p is not None]
            
            # Take majority vote if we have predictions
            if predictions:
                counter = collections.Counter(predictions)
                ensemble_prediction = counter.most_common(1)[0][0]
                
                # Update accuracy counts
                actual_label = row.get('label')
                if actual_label is not None:
                    total += 1
                    if ensemble_prediction == actual_label:
                        correct += 1
        
        accuracy = correct / total if total > 0 else 0
        print(f"\nEnsemble Accuracy: {accuracy:.4f} ({correct}/{total})")
        
        print("Step 11: Saving the ensemble model...")
        # Save the ensemble for future use (optional)
        try:
            with open("exemplar_tree_ensemble.json", "w") as f:
                json.dump([tree.to_dict() for tree in trees], f)
            print("Model saved successfully to exemplar_tree_ensemble.json")
        except Exception as e:
            print(f"Error saving model: {e}")
        
    except Exception as e:
        print(f"Error in main execution: {e}")
        import traceback
        traceback.print_exc()
    finally:
        print("Stopping Spark session...")

# Run the main function
if __name__ == "__main__":
    main()

Files in current directory:
  - complete_global.ipynb
  - config.py
  - controller_loop.py
  - data_ingestion.py
  - distance_measures.py
  - evaluation.py
  - exemplar_tree.json
  - exemplar_tree_ensemble.json
  - global.ipynb
  - global_model_manager.py
  - local_model_manager.py
  - main.py
  - partition_metrics
  - prediction_manager.py
  - preprocessing.py
  - RDD_BA_2.ipynb
  - rdd_global_trialling_stuff.ipynb
  - reports
  - test.ipynb
  - test_pipeline_local_model.ipynb
  - Tyler_global.py
  - utilities.py
  - visualization.py
  - __pycache__
Step 1: Loading and preprocessing data...
Data Path: C:\Users\benat\OneDrive\0. MSc MLiS\0. GitHub Repositories\BigDataProject_repo\fulldataset_ECG5000.csv
Loading 100.0% of data
Data size: 5000

Repartitioning to 10 workers
Step 2: Repartitioning data...
Step 3: Choosing exemplars in each partition...
Step 4: Calculating DTW distances...
Step 5: Assigning closest exemplars...
Step 6: Collecting results...
Collected 500 rows
Sample process