In [2]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging

from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw

from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) a Spark session with master "local[*]"
# and application name "GenericRDD".
spark = SparkSession.builder.master("local[*]").appName("GenericRDD").getOrCreate()

# Access the SparkContext
sc = spark.sparkContext

In [98]:
# Toy labelled dataset: 34 records, each with an integer class `label` (1-4)
# and a five-point `time_series`.
# Note: the same series can appear under different labels, e.g.
# [1.0, 1.8, 2.6, 3.4, 4.2] is listed once with label 1 and once with label 3.
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

# Turn the list of records into a Spark DataFrame.
df = spark.createDataFrame(tsdata)

# 1

In [99]:
def repartition_sparkdf(df, num_partitions):
    """Repartition a DataFrame's rows and tag every row with its partition index.

    Args:
        df: object exposing an ``rdd`` attribute whose rows provide ``asDict()``
            (a Spark DataFrame in this notebook).
        num_partitions: number of partitions passed to ``repartition``.

    Returns:
        The RDD produced by ``mapPartitionsWithIndex``; each element is the
        row's dict plus a ``"partition_id"`` key holding the partition index.
    """
    def _tag_rows(partition_index, rows):
        # Convert each row to a plain dict and record which partition holds it.
        return [{**row.asDict(), "partition_id": partition_index} for row in rows]

    return df.rdd.repartition(num_partitions).mapPartitionsWithIndex(_tag_rows)

# example usage
rdd = repartition_sparkdf(df, 2)
print(rdd.getNumPartitions())  # prints 2, the number of partitions requested above
rdd.collect()

2


[{'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'partition_id': 0},
 {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0], 'partition_id': 0},
 {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9], 'partition_id': 0},
 {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'partition_id': 0},
 {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1], 'partition_id': 0},
 {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1], 'partition_id': 0},
 {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0], 'partition_id': 0},
 {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8], 'partition_id': 0},
 {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5], 'partition_id': 0},
 {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'partition_id': 0},
 {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0], 'partition_id': 0},
 {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3], 'partition_id': 0},
 {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9], 'partition_id': 0},
 {'label': 

# 2

In [109]:
def choose_exemplars(iterator):
    """Pick one random exemplar per class in a partition and attach them to the other rows.

    Args:
        iterator: iterable of dict rows with at least ``'label'`` and
            ``'time_series'`` keys (rows whose label is ``None`` are never
            chosen as exemplars but are still returned).

    Returns:
        An iterator over copies of the non-exemplar rows. Each copy gains
        ``exemplar_<i>`` (the exemplar's time series) and
        ``exemplar_<i>_label`` columns, numbered from 1 in the order classes
        are first seen in the partition. An empty partition yields nothing.
    """
    partition_data = list(iterator)
    if not partition_data:
        return iter([])

    # Group row *positions* by class so an exemplar can later be removed by
    # identity rather than by value.
    positions_by_class = {}
    for position, row in enumerate(partition_data):
        label = row.get('label')
        if label is not None:
            positions_by_class.setdefault(label, []).append(position)

    # Select one exemplar per class, remembering which row it was.
    chosen_exemplars = []       # (time_series, label) tuples
    exemplar_positions = set()  # positions of the rows used as exemplars
    for label, positions in positions_by_class.items():
        position = sample(positions, 1)[0]
        exemplar_positions.add(position)
        chosen_exemplars.append((partition_data[position]['time_series'], label))

    # Drop only the chosen rows. Comparing time-series values instead would
    # also discard any other row that happens to hold an identical series
    # (possibly under a different label).
    result = []
    for position, row in enumerate(partition_data):
        if position in exemplar_positions:
            continue
        new_row = {**row}
        # Add each exemplar and its label as columns
        for i, (exemplar, label) in enumerate(chosen_exemplars, 1):
            new_row[f"exemplar_{i}"] = exemplar
            new_row[f"exemplar_{i}_label"] = label
        result.append(new_row)

    return iter(result)

# example usage
# mapPartitions calls choose_exemplars once per partition, so every partition
# draws its own exemplars.
rdd_with_exemplars = rdd.mapPartitions(choose_exemplars)
rdd_with_exemplars.collect()

[{'label': 3,
  'time_series': [2.0, 2.5, 3.0, 3.5, 4.0],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_1_label': 3,
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_2_label': 4,
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_3_label': 1,
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5],
  'exemplar_4_label': 2},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_1_label': 3,
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_2_label': 4,
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_3_label': 1,
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5],
  'exemplar_4_label': 2},
 {'label': 4,
  'time_series': [6.1, 6.2, 6.3, 6.4, 6.5],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_1_label': 3,
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_2_label': 4,
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_3_label': 1,
  'exempla

# 3

In [110]:
def calc_dtw_distance(iterator):
    """Attach the DTW distance from each row's series to each of its exemplars.

    Exemplars are taken from list-valued ``exemplar_*`` columns when present;
    otherwise from an ``'exemplars'`` list. Rows offering neither are dropped.

    Args:
        iterator: iterable of dict rows carrying ``'time_series'``.

    Returns:
        An iterator over row copies extended with ``dtw_distance_exemplar_<n>``
        keys, where ``<n>`` is the column's index token or, for the
        ``'exemplars'`` fallback, the 1-based list position.
    """
    enriched_rows = []

    for record in list(iterator):
        series = record.get('time_series', [])

        # Individual exemplar columns (exemplar_1, exemplar_2, ...) holding lists.
        per_column = {
            name: value
            for name, value in record.items()
            if name.startswith('exemplar_') and isinstance(value, list)
        }

        if per_column:
            # "exemplar_1" -> "1"
            targets = [(name.split('_')[1], value) for name, value in per_column.items()]
        else:
            fallback = record.get('exemplars', [])
            if not fallback:
                continue  # nothing to measure against: skip the row
            targets = [(position + 1, value) for position, value in enumerate(fallback)]

        enriched = {**record}
        for suffix, exemplar in targets:
            enriched[f"dtw_distance_exemplar_{suffix}"] = dtw.distance(series, exemplar)
        enriched_rows.append(enriched)

    return iter(enriched_rows)

# example usage
# Only the first three collected rows are displayed.
rdd_with_dtw = rdd_with_exemplars.mapPartitions(calc_dtw_distance)
rdd_with_dtw.collect()[:3]

[{'label': 3,
  'time_series': [2.0, 2.5, 3.0, 3.5, 4.0],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_1_label': 3,
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_2_label': 4,
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_3_label': 1,
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5],
  'exemplar_4_label': 2,
  'dtw_distance_exemplar_1': 1.7320508075688772,
  'dtw_distance_exemplar_2': 11.20267825120404,
  'dtw_distance_exemplar_3': 0.7348469228349536,
  'dtw_distance_exemplar_4': 1.6703293088490065},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_1_label': 3,
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_2_label': 4,
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_3_label': 1,
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5],
  'exemplar_4_label': 2,
  'dtw_distance_exemplar_1': 10.78424777163433,
  'dtw_distance_exemplar_2': 0.7416198487095661,
  '

# 4

In [111]:
def assign_closest_exemplar(iterator):
    """Collapse per-exemplar DTW columns into a single "closest exemplar" summary.

    Args:
        iterator: iterable of dict rows, normally carrying ``exemplar_*`` and
            ``dtw_distance_exemplar_*`` columns.

    Returns:
        An iterator over new dicts. Rows with distance columns lose all
        ``exemplar_*`` / ``dtw_distance_exemplar_*`` keys and gain
        ``closest_exemplar_id`` (the name of the winning distance column),
        ``closest_exemplar_data`` and ``closest_exemplar_original_label``.
        Rows without distance columns are only stripped of ``exemplar_*`` keys.
    """
    distance_prefix = 'dtw_distance_exemplar_'
    exemplar_prefix = 'exemplar_'
    summarised = []

    for record in list(iterator):
        distances = {
            key: value for key, value in record.items()
            if key.startswith(distance_prefix)
        }

        if not distances:
            # No distances available: just drop the exemplar columns.
            summarised.append({
                key: value for key, value in record.items()
                if not key.startswith(exemplar_prefix)
            })
            continue

        # Column name with the smallest distance, e.g. "dtw_distance_exemplar_1".
        winner = min(distances, key=distances.get)
        number = winner.split('_')[-1]  # -> "1"

        slim = {
            key: value for key, value in record.items()
            if not key.startswith((distance_prefix, exemplar_prefix))
        }
        slim['closest_exemplar_id'] = winner
        slim['closest_exemplar_data'] = record.get(f'exemplar_{number}', None)
        slim['closest_exemplar_original_label'] = record.get(f'exemplar_{number}_label', None)
        summarised.append(slim)

    return iter(summarised)

# Example usage
# Each resulting row keeps its original fields plus the closest_exemplar_* summary.
rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
rdd_with_closest_exemplar.collect()

[{'label': 3,
  'time_series': [2.0, 2.5, 3.0, 3.5, 4.0],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_3',
  'closest_exemplar_data': [1.4, 2.0, 2.6, 3.2, 3.8],
  'closest_exemplar_original_label': 1},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_2',
  'closest_exemplar_data': [6.0, 7.0, 8.0, 9.0, 10.0],
  'closest_exemplar_original_label': 4},
 {'label': 4,
  'time_series': [6.1, 6.2, 6.3, 6.4, 6.5],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_2',
  'closest_exemplar_data': [6.0, 7.0, 8.0, 9.0, 10.0],
  'closest_exemplar_original_label': 4},
 {'label': 1,
  'time_series': [0.7, 1.3, 1.9, 2.5, 3.1],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_3',
  'closest_exemplar_data': [1.4, 2.0, 2.6, 3.2, 3.8],
  'closest_exemplar_original_label': 1},
 {'label': 1,
  'time_series': [1.1, 2.1, 3.1, 4.1, 5.1],
  'partition_id': 0,
  'closest_exem

# 5

In [112]:
def calculate_gini(labels):
    """Return the Gini impurity ``1 - sum(p_i ** 2)`` of a collection of labels.

    Args:
        labels: sequence of hashable class labels.

    Returns:
        ``0`` for an empty/falsy input, otherwise the impurity as a float.
    """
    if not labels:
        return 0

    # Tally occurrences per label (insertion order is kept).
    tallies = {}
    for item in labels:
        tallies.setdefault(item, 0)
        tallies[item] += 1

    n = sum(tallies.values())
    if n <= 0:
        return 0
    return 1 - sum((occurrences / n) ** 2 for occurrences in tallies.values())

def evaluate_splits_within_partition(index, iterator):
    """Score every "closest to exemplar X?" split available in one partition.

    Args:
        index: partition index, copied into each result as ``partition_id``.
        iterator: iterable of dict rows carrying ``label`` and the
            ``closest_exemplar_*`` fields.

    Returns:
        An iterator with one dict per distinct ``closest_exemplar_id`` holding
        ``partition_id``, ``exemplar_id``, ``exemplar_data``,
        ``exemplar_label`` and ``gini_reduction`` (parent Gini minus the
        size-weighted Gini of the yes/no groups).
    """
    rows = list(iterator)

    def _labels_of(subset):
        # Non-None labels of the given rows.
        return [r.get('label') for r in subset if r.get('label') is not None]

    # Impurity before any split.
    parent_gini = calculate_gini(_labels_of(rows))

    # Distinct exemplar ids that at least one row is closest to.
    candidate_ids = set(
        r['closest_exemplar_id'] for r in rows
        if r.get('closest_exemplar_id') is not None
    )

    # First-seen exemplar series and original label for each id.
    data_by_id = {}
    label_by_id = {}
    for r in rows:
        key = r.get('closest_exemplar_id')
        if not key or key in data_by_id:
            continue
        data_by_id[key] = r.get('closest_exemplar_data')
        label_by_id[key] = r.get('closest_exemplar_original_label')

    evaluations = []
    for candidate in candidate_ids:
        # Rows closest to this exemplar vs. everything else.
        matched = [r for r in rows if r.get('closest_exemplar_id') == candidate]
        unmatched = [r for r in rows if r.get('closest_exemplar_id') != candidate]

        gini_yes = calculate_gini(_labels_of(matched))
        gini_no = calculate_gini(_labels_of(unmatched))

        n = len(matched) + len(unmatched)
        if n > 0:
            weighted = gini_yes * len(matched) / n + gini_no * len(unmatched) / n
            reduction = parent_gini - weighted
        else:
            reduction = float('-inf')

        evaluations.append({
            "partition_id": index,
            "exemplar_id": candidate,
            "exemplar_data": data_by_id.get(candidate),
            "exemplar_label": label_by_id.get(candidate),
            "gini_reduction": reduction
        })

    return iter(evaluations)

# Example usage
# mapPartitionsWithIndex supplies the partition index as the first argument,
# which ends up as "partition_id" in every split evaluation.
rdd_with_splits = rdd_with_closest_exemplar.mapPartitionsWithIndex(evaluate_splits_within_partition)
rdd_with_splits.collect()

[{'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_4',
  'exemplar_data': [1.9, 2.8, 3.7, 4.6, 5.5],
  'exemplar_label': 2,
  'gini_reduction': 0.1376420454545454},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_3',
  'exemplar_data': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_label': 1,
  'gini_reduction': 0.16679067460317454},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_2',
  'exemplar_data': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_label': 4,
  'gini_reduction': 0.25260416666666663},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_4',
  'exemplar_data': [0.3, 1.7, 1.6, 2.2, 2.6],
  'exemplar_label': 3,
  'gini_reduction': 0.09499999999999997},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_1',
  'exemplar_data': [1.2, 2.4, 3.6, 4.8, 6.0],
  'exemplar_label': 1,
  'gini_reduction': 0.04222222222222227},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_3',
  'exemplar_data': [0.9, 1.7, 1.2, 2.4, 2.8],
  'exempl

# ---

In [116]:
from random import sample
import json
from pyspark.sql import SparkSession
import collections

# Create initial test data: 34 records, each with an integer class `label`
# (1-4) and a five-point `time_series`.
# Note: the same series can appear under different labels, e.g.
# [1.0, 1.8, 2.6, 3.4, 4.2] is listed once with label 1 and once with label 3.
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

def repartition_sparkdf(df, num_partitions):
    """Repartition a DataFrame's rows and tag every row with its partition index.

    Args:
        df: object exposing an ``rdd`` attribute whose rows provide ``asDict()``
            (a Spark DataFrame in this notebook).
        num_partitions: number of partitions passed to ``repartition``.

    Returns:
        The RDD produced by ``mapPartitionsWithIndex``; each element is the
        row's dict plus a ``"partition_id"`` key holding the partition index.
    """
    def _tag_rows(partition_index, rows):
        # Convert each row to a plain dict and record which partition holds it.
        return [{**row.asDict(), "partition_id": partition_index} for row in rows]

    return df.rdd.repartition(num_partitions).mapPartitionsWithIndex(_tag_rows)

def choose_exemplars(iterator):
    """Pick one random exemplar per class in a partition and attach them to the other rows.

    Args:
        iterator: iterable of dict rows with at least ``'label'`` and
            ``'time_series'`` keys (rows whose label is ``None`` are never
            chosen as exemplars but are still returned).

    Returns:
        An iterator over copies of the non-exemplar rows. Each copy gains
        ``exemplar_<i>`` (the exemplar's time series) and
        ``exemplar_<i>_label`` columns, numbered from 1 in the order classes
        are first seen in the partition. An empty partition yields nothing.
    """
    partition_data = list(iterator)
    if not partition_data:
        return iter([])

    # Group row *positions* by class so an exemplar can later be removed by
    # identity rather than by value.
    positions_by_class = {}
    for position, row in enumerate(partition_data):
        label = row.get('label')
        if label is not None:
            positions_by_class.setdefault(label, []).append(position)

    # Select one exemplar per class, remembering which row it was.
    chosen_exemplars = []       # (time_series, label) tuples
    exemplar_positions = set()  # positions of the rows used as exemplars
    for label, positions in positions_by_class.items():
        position = sample(positions, 1)[0]
        exemplar_positions.add(position)
        chosen_exemplars.append((partition_data[position]['time_series'], label))

    # Drop only the chosen rows. Comparing time-series values instead would
    # also discard any other row that happens to hold an identical series
    # (possibly under a different label).
    result = []
    for position, row in enumerate(partition_data):
        if position in exemplar_positions:
            continue
        new_row = {**row}
        # Add each exemplar and its label as columns
        for i, (exemplar, label) in enumerate(chosen_exemplars, 1):
            new_row[f"exemplar_{i}"] = exemplar
            new_row[f"exemplar_{i}_label"] = label
        result.append(new_row)

    return iter(result)

def calc_dtw_distance(iterator):
    """Attach the DTW distance from each row's series to each of its exemplars.

    Exemplars are taken from list-valued ``exemplar_*`` columns (excluding any
    ``*_label`` column) when present; otherwise from an ``'exemplars'`` list.
    Rows offering neither are dropped.

    Args:
        iterator: iterable of dict rows carrying ``'time_series'``.

    Returns:
        An iterator over row copies extended with ``dtw_distance_exemplar_<n>``
        keys, where ``<n>`` is the column's index token or, for the
        ``'exemplars'`` fallback, the 1-based list position.
    """
    enriched_rows = []

    for record in list(iterator):
        series = record.get('time_series', [])

        # Individual exemplar columns (exemplar_1, exemplar_2, ...) holding lists.
        per_column = {
            name: value
            for name, value in record.items()
            if name.startswith('exemplar_')
            and not name.endswith('_label')
            and isinstance(value, list)
        }

        if per_column:
            # "exemplar_1" -> "1"
            targets = [(name.split('_')[1], value) for name, value in per_column.items()]
        else:
            fallback = record.get('exemplars', [])
            if not fallback:
                continue  # nothing to measure against: skip the row
            targets = [(position + 1, value) for position, value in enumerate(fallback)]

        enriched = {**record}
        for suffix, exemplar in targets:
            enriched[f"dtw_distance_exemplar_{suffix}"] = dtw.distance(series, exemplar)
        enriched_rows.append(enriched)

    return iter(enriched_rows)

def assign_closest_exemplar(iterator):
    """Collapse per-exemplar DTW columns into a single "closest exemplar" summary.

    Args:
        iterator: iterable of dict rows, normally carrying ``exemplar_*`` and
            ``dtw_distance_exemplar_*`` columns.

    Returns:
        An iterator over new dicts. Rows with distance columns lose all
        ``exemplar_*`` / ``dtw_distance_exemplar_*`` keys and gain
        ``closest_exemplar_id`` (the name of the winning distance column),
        ``closest_exemplar_data`` and ``closest_exemplar_original_label``.
        Rows without distance columns are only stripped of ``exemplar_*`` keys.
    """
    distance_prefix = 'dtw_distance_exemplar_'
    exemplar_prefix = 'exemplar_'
    summarised = []

    for record in list(iterator):
        distances = {
            key: value for key, value in record.items()
            if key.startswith(distance_prefix)
        }

        if not distances:
            # No distances available: just drop the exemplar columns.
            summarised.append({
                key: value for key, value in record.items()
                if not key.startswith(exemplar_prefix)
            })
            continue

        # Column name with the smallest distance, e.g. "dtw_distance_exemplar_1".
        winner = min(distances, key=distances.get)
        number = winner.split('_')[-1]  # -> "1"

        slim = {
            key: value for key, value in record.items()
            if not key.startswith((distance_prefix, exemplar_prefix))
        }
        slim['closest_exemplar_id'] = winner
        slim['closest_exemplar_data'] = record.get(f'exemplar_{number}', None)
        slim['closest_exemplar_original_label'] = record.get(f'exemplar_{number}_label', None)
        summarised.append(slim)

    return iter(summarised)

def calculate_gini(labels):
    """Return the Gini impurity ``1 - sum(p_i ** 2)`` of a collection of labels.

    Args:
        labels: sequence of hashable class labels.

    Returns:
        ``0`` for an empty/falsy input, otherwise the impurity as a float.
    """
    if not labels:
        return 0

    # Counter keeps first-seen order, so the squared shares are summed in the
    # same order as a hand-rolled dict tally would produce.
    tallies = collections.Counter()
    for item in labels:
        tallies[item] += 1

    n = sum(tallies.values())
    if n <= 0:
        return 0
    return 1 - sum((occurrences / n) ** 2 for occurrences in tallies.values())

# Define tree node structure
class TreeNode:
    """Node of an exemplar-based decision tree.

    A decision node records the exemplar it splits on (id, series, label) and
    the Gini reduction of that split; a leaf sets ``is_leaf`` and carries
    ``predicted_label``. Children are attached after construction through
    ``yes_child`` / ``no_child``.
    """

    # Scalar fields exported by to_dict, in output order.
    _EXPORTED_FIELDS = (
        'node_id',
        'split_exemplar_id',
        'split_exemplar_data',
        'split_exemplar_label',
        'gini_reduction',
        'is_leaf',
        'predicted_label',
        'partition_id',
    )

    def __init__(self, node_id=None, split_exemplar_id=None, split_exemplar_data=None,
                 split_exemplar_label=None, is_leaf=False, predicted_label=None, gini_reduction=None,
                 partition_id=None):
        self.node_id = node_id
        self.split_exemplar_id = split_exemplar_id
        self.split_exemplar_data = split_exemplar_data
        self.split_exemplar_label = split_exemplar_label
        self.gini_reduction = gini_reduction
        self.is_leaf = is_leaf
        self.predicted_label = predicted_label
        # Branch taken by instances closest to this node's exemplar.
        self.yes_child = None
        # Branch taken by instances closest to some other exemplar.
        self.no_child = None
        # Partition whose data produced this tree.
        self.partition_id = partition_id

    def to_dict(self):
        """Return a nested plain-dict form of this subtree.

        Child entries (``'yes_child'`` / ``'no_child'``) are present only for
        children that are set.
        """
        payload = {field: getattr(self, field) for field in self._EXPORTED_FIELDS}

        for branch in ('yes_child', 'no_child'):
            child = getattr(self, branch)
            if child:
                payload[branch] = child.to_dict()

        return payload

# Find the best split from all evaluated splits (within a single partition)
def find_best_split(all_splits):
    """Return the split dict with the largest ``'gini_reduction'``.

    Returns ``None`` when ``all_splits`` is empty/falsy. On ties the first
    maximal entry (in iteration order) wins, as with ``max``.
    """
    def _reduction(candidate):
        return candidate['gini_reduction']

    return max(all_splits, key=_reduction) if all_splits else None

# Build a tree for a single partition (locally, not distributed)
def _majority_leaf(partition_data, partition_id, node_id):
    """Build a leaf predicting the most common non-None label (None if no labels exist)."""
    label_counts = collections.Counter(
        row.get('label') for row in partition_data if row.get('label') is not None
    )
    predicted = label_counts.most_common(1)[0][0] if label_counts else None
    return TreeNode(node_id=node_id, is_leaf=True, predicted_label=predicted, partition_id=partition_id)


def build_partition_tree(partition_data, partition_id, max_depth=3, min_samples=2, node_id=1):
    """Recursively build an exemplar-split decision tree from one partition's rows.

    Each candidate split asks "is this row's ``closest_exemplar_id`` equal to
    X?" for every exemplar id present in the data; the split with the largest
    Gini reduction is taken, and the two groups are recursed into.

    Args:
        partition_data: list of dict rows carrying ``label`` and the
            ``closest_exemplar_*`` fields.
        partition_id: identifier stored on every node of the tree.
        max_depth: remaining number of split levels allowed.
        min_samples: below this many rows a leaf is created.
        node_id: id of the node being built; children get ``2 * node_id``
            (yes branch) and ``2 * node_id + 1`` (no branch).

    Returns:
        The root ``TreeNode`` of the (sub)tree. A leaf is returned when the
        stopping criteria are met or no split has a positive Gini reduction.
    """
    # Stopping criteria: too few samples or depth budget exhausted.
    if len(partition_data) < min_samples or max_depth <= 0:
        return _majority_leaf(partition_data, partition_id, node_id)

    labels = [row.get('label') for row in partition_data if row.get('label') is not None]
    before_split_gini = calculate_gini(labels)

    # Candidate exemplars in first-seen order, with their series and label.
    # A dict (rather than a set of ids) keeps iteration order stable, so ties
    # between equally good splits are resolved the same way on every run
    # instead of depending on string hash order.
    exemplar_info = {}
    for row in partition_data:
        exemplar_id = row.get('closest_exemplar_id')
        if exemplar_id is not None and exemplar_id not in exemplar_info:
            exemplar_info[exemplar_id] = (
                row.get('closest_exemplar_data'),
                row.get('closest_exemplar_original_label'),
            )

    # Evaluate all possible splits. The loop only runs when at least one row
    # exists, so total_size is always positive here.
    total_size = len(partition_data)
    splits = []
    for exemplar_id, (exemplar_data, exemplar_label) in exemplar_info.items():
        yes_split = [r for r in partition_data if r.get('closest_exemplar_id') == exemplar_id]
        no_split = [r for r in partition_data if r.get('closest_exemplar_id') != exemplar_id]

        yes_gini = calculate_gini([r.get('label') for r in yes_split if r.get('label') is not None])
        no_gini = calculate_gini([r.get('label') for r in no_split if r.get('label') is not None])

        # Size-weighted impurity of the two daughter nodes.
        weighted_gini = yes_gini * len(yes_split) / total_size + no_gini * len(no_split) / total_size

        splits.append({
            "partition_id": partition_id,
            "exemplar_id": exemplar_id,
            "exemplar_data": exemplar_data,
            "exemplar_label": exemplar_label,
            "gini_reduction": before_split_gini - weighted_gini
        })

    best_split = find_best_split(splits)

    # No split, or no split that actually reduces impurity: make a leaf.
    if not best_split or best_split['gini_reduction'] <= 0:
        return _majority_leaf(partition_data, partition_id, node_id)

    node = TreeNode(
        node_id=node_id,
        split_exemplar_id=best_split['exemplar_id'],
        split_exemplar_data=best_split['exemplar_data'],
        split_exemplar_label=best_split['exemplar_label'],
        gini_reduction=best_split['gini_reduction'],
        partition_id=partition_id
    )

    # Partition the rows on the winning exemplar and recurse.
    yes_data = [r for r in partition_data if r.get('closest_exemplar_id') == best_split['exemplar_id']]
    no_data = [r for r in partition_data if r.get('closest_exemplar_id') != best_split['exemplar_id']]

    node.yes_child = build_partition_tree(yes_data, partition_id, max_depth-1, min_samples, node_id=node_id*2)
    node.no_child = build_partition_tree(no_data, partition_id, max_depth-1, min_samples, node_id=node_id*2+1)

    return node

# Function to visualize a single tree
def visualize_tree(node, indent=""):
    """Print an indented text rendering of a tree.

    Args:
        node: tree node exposing ``is_leaf``, ``predicted_label``, the
            ``split_exemplar_*`` fields, ``gini_reduction`` and both children;
            ``None`` prints nothing.
        indent: prefix for every printed line; children are indented by two
            additional spaces.
    """
    if node is None:
        return

    if node.is_leaf:
        print(f"{indent}Leaf: Predict class {node.predicted_label}")
        return

    # One-decimal rendering of the exemplar series, e.g. "[6.0, 7.0]".
    rendered_values = [f"{val:.1f}" for val in node.split_exemplar_data]
    exemplar_data_str = "[" + ", ".join(rendered_values) + "]"
    deeper = indent + "  "

    print(f"{indent}Is closest to exemplar {node.split_exemplar_id} (class {node.split_exemplar_label})?")
    print(f"{indent}Data: {exemplar_data_str}, Gini reduction: {node.gini_reduction:.4f}")

    print(f"{indent}Yes ->")
    visualize_tree(node.yes_child, deeper)

    print(f"{indent}No ->")
    visualize_tree(node.no_child, deeper)

# Function to predict with a single tree
def predict_with_tree(row, tree):
    """Walk a tree for one row and return the reached leaf's predicted label.

    At every decision node the row goes to ``yes_child`` when its
    ``'closest_exemplar_id'`` equals the node's ``split_exemplar_id``,
    otherwise to ``no_child``.

    NOTE(review): the comparison is on the id string only. If ids are numbered
    independently per partition, a row's id may refer to a different exemplar
    than the same id in a tree built from another partition -- confirm this is
    intended before using trees across partitions.
    """
    node = tree
    while True:
        if node.is_leaf:
            return node.predicted_label
        goes_yes = row.get('closest_exemplar_id') == node.split_exemplar_id
        node = node.yes_child if goes_yes else node.no_child

# Main function to run the full pipeline
def main():
    """Run the full exemplar-tree pipeline on ``tsdata`` and report training accuracy.

    Steps: build a DataFrame, split it into 2 partitions, pick exemplars per
    partition, compute DTW distances, assign the closest exemplar, build one
    tree per partition, print the trees, evaluate a majority-vote ensemble on
    the same rows, and write the trees to ``exemplar_tree_ensemble.json``.

    Side effects: prints to stdout, writes the JSON file in the current
    working directory, and stops the Spark session on exit -- including when
    a step raises. Because the session comes from ``getOrCreate()``, an
    already-running session may be the one that gets stopped.
    """
    spark = SparkSession.builder.appName("ExemplarTreeEnsemble").getOrCreate()

    try:
        # Create DataFrame from the test data and spread it over 2 partitions
        df = spark.createDataFrame(tsdata)
        rdd = repartition_sparkdf(df, 2)

        # Per-partition pipeline: exemplars -> DTW distances -> closest exemplar
        rdd_with_exemplars = rdd.mapPartitions(choose_exemplars)
        rdd_with_dtw = rdd_with_exemplars.mapPartitions(calc_dtw_distance)
        rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(assign_closest_exemplar)

        # Collect all data (for our small test dataset, this is fine)
        all_data = rdd_with_closest_exemplar.collect()

        # Group rows by the partition they came from
        partitions = {}
        for row in all_data:
            partitions.setdefault(row['partition_id'], []).append(row)

        # Build one tree per non-empty partition
        trees = [
            build_partition_tree(data, pid, max_depth=3, min_samples=2)
            for pid, data in partitions.items()
            if data
        ]

        for tree in trees:
            print(f"\nDecision Tree for Partition {tree.partition_id}:")
            visualize_tree(tree)

        # Majority vote over all trees, scored against the rows' own labels.
        # NOTE(review): every tree votes on every row, but a row's
        # closest_exemplar_id was computed against its own partition's
        # exemplars, which are numbered independently per partition -- confirm
        # that cross-partition votes are meaningful.
        correct = 0
        total = 0

        for row in all_data:
            predictions = [predict_with_tree(row, tree) for tree in trees]
            if predictions:
                counter = collections.Counter(predictions)
                ensemble_prediction = counter.most_common(1)[0][0]

                total += 1
                if ensemble_prediction == row.get('label'):
                    correct += 1

        accuracy = correct / total if total > 0 else 0
        print(f"\nEnsemble Accuracy: {accuracy:.4f} ({correct}/{total})")

        # Save the ensemble for future use (optional)
        with open("exemplar_tree_ensemble.json", "w") as f:
            json.dump([tree.to_dict() for tree in trees], f, indent=2)
    finally:
        # Always release the session, even if a pipeline step failed.
        spark.stop()

# Run the main function
# (the recorded cell output shows this guard is true when the cell executes)
if __name__ == "__main__":
    main()


Decision Tree for Partition 0:
Is closest to exemplar dtw_distance_exemplar_2 (class 4)?
Data: [6.0, 7.0, 8.0, 9.0, 10.0], Gini reduction: 0.2526
Yes ->
  Leaf: Predict class 4
No ->
  Is closest to exemplar dtw_distance_exemplar_3 (class 1)?
  Data: [1.4, 2.0, 2.6, 3.2, 3.8], Gini reduction: 0.1671
  Yes ->
    Leaf: Predict class 3
  No ->
    Leaf: Predict class 2

Decision Tree for Partition 1:
Is closest to exemplar dtw_distance_exemplar_3 (class 4)?
Data: [0.9, 1.7, 1.2, 2.4, 2.8], Gini reduction: 0.2533
Yes ->
  Leaf: Predict class 1
No ->
  Is closest to exemplar dtw_distance_exemplar_1 (class 1)?
  Data: [1.2, 2.4, 3.6, 4.8, 6.0], Gini reduction: 0.2111
  Yes ->
    Leaf: Predict class 1
  No ->
    Is closest to exemplar dtw_distance_exemplar_4 (class 3)?
    Data: [0.3, 1.7, 1.6, 2.2, 2.6], Gini reduction: 0.0133
    Yes ->
      Leaf: Predict class 2
    No ->
      Leaf: Predict class 2

Ensemble Accuracy: 0.5000 (13/26)
