In [2]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging

from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw

from pyspark.sql import SparkSession

# Start (or reuse) a local Spark session named "GenericRDD" with master "local[*]".
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("GenericRDD")
    .getOrCreate()
)

# Handle to the session's SparkContext.
sc = spark.sparkContext

In [98]:
# Toy labelled dataset: every record holds an integer class `label` (1-4) and a
# five-point `time_series`.
# NOTE: the series [1.0, 1.8, 2.6, 3.4, 4.2] appears twice, once with label 1 and
# once with label 3, so a series' values do not uniquely identify a row.
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

# Turn the list of dicts into a Spark DataFrame (no explicit schema is passed).
df = spark.createDataFrame(tsdata)

# 1. Repartition the data and tag each row with its partition id

In [99]:
def repartition_sparkdf(df, num_partitions):
    """
    Repartition a DataFrame's RDD and tag every row with its partition index.

    Parameters:
    - df: object exposing an ``rdd`` attribute (a Spark DataFrame in this notebook)
    - num_partitions: partition count handed to ``rdd.repartition``

    Returns:
    - RDD of plain dicts: each row's ``asDict()`` fields plus a ``"partition_id"``
      key holding the index of the partition the row ended up in
    """
    def tag_rows(partition_index, rows):
        # One dict per row, with the partition index appended as an extra field.
        return [
            {**record.asDict(), "partition_id": partition_index}
            for record in rows
        ]

    repartitioned = df.rdd.repartition(num_partitions)
    return repartitioned.mapPartitionsWithIndex(tag_rows)

# Example usage: spread the DataFrame over two partitions and inspect the rows.
rdd = repartition_sparkdf(df, 2)
print(rdd.getNumPartitions())  # should print 2 (the requested partition count)
rdd.collect()

2


[{'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'partition_id': 0},
 {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0], 'partition_id': 0},
 {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9], 'partition_id': 0},
 {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'partition_id': 0},
 {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1], 'partition_id': 0},
 {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1], 'partition_id': 0},
 {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0], 'partition_id': 0},
 {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8], 'partition_id': 0},
 {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5], 'partition_id': 0},
 {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'partition_id': 0},
 {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0], 'partition_id': 0},
 {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3], 'partition_id': 0},
 {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9], 'partition_id': 0},
 {'label': 

# 2. Choose one exemplar per class within each partition

In [100]:
def choose_exemplars(iterator):
    """
    Pick one random exemplar per class in a partition and attach the exemplars
    to every remaining row. Intended for use with ``rdd.mapPartitions``.

    Parameters:
    - iterator: iterable of row dicts with a 'label' and a 'time_series' entry.
      Rows whose 'label' is missing or None are never chosen as exemplars but
      are still returned.

    Returns:
    - Iterator over copies of the non-exemplar rows, each extended with
      'exemplar_1' ... 'exemplar_k' keys (one per class, numbered in order of
      first appearance of the class) holding the exemplars' time series.
      Empty when the partition is empty.

    Note: selection uses ``random.sample`` on the module-level generator, so
    the chosen exemplars vary between runs unless that generator is seeded.
    """
    partition_data = list(iterator)
    if not partition_data:
        return iter([])

    # Group row *positions* by class, so an exemplar is identified by the row
    # it came from rather than by its values.
    positions_by_class = {}
    for position, row in enumerate(partition_data):
        label = row.get('label')
        if label is not None:
            positions_by_class.setdefault(label, []).append(position)

    # Select one exemplar per class.
    chosen_positions = [
        sample(positions, 1)[0] for positions in positions_by_class.values()
    ]
    chosen_exemplars = [
        partition_data[position]['time_series'] for position in chosen_positions
    ]

    # One column per exemplar, shared by all remaining rows.
    exemplar_columns = {
        f"exemplar_{i+1}": exemplar
        for i, exemplar in enumerate(chosen_exemplars)
    }

    # Drop exactly the chosen rows. Filtering on the time-series values would
    # also discard any other row that happens to share an exemplar's values.
    excluded = set(chosen_positions)
    result = [
        {**row, **exemplar_columns}
        for position, row in enumerate(partition_data)
        if position not in excluded
    ]

    return iter(result)

# Example usage: run choose_exemplars separately on every partition of `rdd`.
rdd_with_exemplars = rdd.mapPartitions(choose_exemplars)
rdd_with_exemplars.collect()

[{'label': 3,
  'time_series': [2.0, 2.5, 3.0, 3.5, 4.0],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5]},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5]},
 {'label': 4,
  'time_series': [6.1, 6.2, 6.3, 6.4, 6.5],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5]},
 {'label': 1,
  'time_series': [0.7, 1.3, 1.9, 2.5, 3.1],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_4': [1.9, 2.8, 3.7, 

# 3. Compute the DTW distance from each series to every exemplar

In [101]:
def calc_dtw_distance(iterator):
    """
    Add the DTW distance between each row's time series and each of its exemplars.

    Exemplars are read from list-valued 'exemplar_<n>' keys. When a row has none,
    its 'exemplars' list is used instead and numbered from 1. Rows without any
    exemplar are left out of the result.

    Parameters:
    - iterator: iterable of row dicts belonging to one partition

    Returns:
    - Iterator over row copies extended with 'dtw_distance_exemplar_<n>' keys
    """
    enriched = []

    for record in iterator:
        series = record.get('time_series', [])

        # Preferred source: individual exemplar columns, keyed by the token
        # that follows the first underscore ("exemplar_1" -> "1").
        numbered_exemplars = [
            (key.split('_')[1], value)
            for key, value in record.items()
            if key.startswith('exemplar_') and isinstance(value, list)
        ]

        if not numbered_exemplars:
            # Fallback source: a single 'exemplars' list.
            fallback = record.get('exemplars', [])
            if not fallback:
                continue  # nothing to measure against; the row is dropped
            numbered_exemplars = [
                (position + 1, exemplar)
                for position, exemplar in enumerate(fallback)
            ]

        with_distances = {**record}
        for suffix, exemplar in numbered_exemplars:
            with_distances[f"dtw_distance_exemplar_{suffix}"] = dtw.distance(series, exemplar)
        enriched.append(with_distances)

    return iter(enriched)

# Example usage: add the DTW distance columns and preview the first three rows.
rdd_with_dtw = rdd_with_exemplars.mapPartitions(calc_dtw_distance)
rdd_with_dtw.collect()[:3]

[{'label': 3,
  'time_series': [2.0, 2.5, 3.0, 3.5, 4.0],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5],
  'dtw_distance_exemplar_1': 1.7320508075688772,
  'dtw_distance_exemplar_2': 11.20267825120404,
  'dtw_distance_exemplar_3': 0.7348469228349536,
  'dtw_distance_exemplar_4': 1.6703293088490065},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5, 2.5, 3.5, 4.5],
  'exemplar_2': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_3': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_4': [1.9, 2.8, 3.7, 4.6, 5.5],
  'dtw_distance_exemplar_1': 10.78424777163433,
  'dtw_distance_exemplar_2': 0.7416198487095661,
  'dtw_distance_exemplar_3': 11.244998888394788,
  'dtw_distance_exemplar_4': 7.784600182411426},
 {'label': 4,
  'time_series': [6.1, 6.2, 6.3, 6.4, 6.5],
  'partition_id': 0,
  'exemplar_1': [0.5, 1.5,

# 4. Assign each series to its closest exemplar

In [102]:
def assign_closest_exemplar(iterator):
    """
    Collapse each row's per-exemplar DTW distances into its nearest exemplar.

    Parameters:
    - iterator: iterable of row dicts that may carry 'exemplar_<n>' and
      'dtw_distance_exemplar_<n>' keys

    Returns:
    - Iterator over new dicts with every 'exemplar_*' and
      'dtw_distance_exemplar_*' key removed. Rows that had at least one distance
      also get 'closest_exemplar_id' (the name of the smallest distance key,
      e.g. "dtw_distance_exemplar_1") and 'closest_exemplar_data' (the value of
      the matching 'exemplar_<n>' key, or None when that key is absent).
    """
    distance_prefix = 'dtw_distance_exemplar_'
    dropped_prefixes = (distance_prefix, 'exemplar_')

    collapsed = []
    for record in iterator:
        # Everything except the exemplar and distance columns is carried over.
        slim = {
            key: value for key, value in record.items()
            if not key.startswith(dropped_prefixes)
        }

        distances = [
            (key, value) for key, value in record.items()
            if key.startswith(distance_prefix)
        ]
        if distances:
            # Smallest distance wins; on ties the first key encountered is kept.
            nearest_key = min(distances, key=lambda item: item[1])[0]
            # "dtw_distance_exemplar_1" -> "1"
            suffix = nearest_key.rsplit('_', 1)[-1]
            slim['closest_exemplar_id'] = nearest_key
            slim['closest_exemplar_data'] = record.get(f'exemplar_{suffix}', None)

        collapsed.append(slim)

    return iter(collapsed)

# Example usage: replace the distance columns with each row's closest exemplar.
rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
rdd_with_closest_exemplar.collect()

[{'label': 3,
  'time_series': [2.0, 2.5, 3.0, 3.5, 4.0],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_3',
  'closest_exemplar_data': [1.4, 2.0, 2.6, 3.2, 3.8]},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_2',
  'closest_exemplar_data': [6.0, 7.0, 8.0, 9.0, 10.0]},
 {'label': 4,
  'time_series': [6.1, 6.2, 6.3, 6.4, 6.5],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_2',
  'closest_exemplar_data': [6.0, 7.0, 8.0, 9.0, 10.0]},
 {'label': 1,
  'time_series': [0.7, 1.3, 1.9, 2.5, 3.1],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_3',
  'closest_exemplar_data': [1.4, 2.0, 2.6, 3.2, 3.8]},
 {'label': 1,
  'time_series': [1.1, 2.1, 3.1, 4.1, 5.1],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_4',
  'closest_exemplar_data': [1.9, 2.8, 3.7, 4.6, 5.5]},
 {'label': 1,
  'time_series': [0.6, 1.2, 1.8, 2.4, 3.0],
  'partiti

# 5. Evaluate candidate splits by Gini reduction and build the root node

In [103]:
def calculate_gini(labels):
    """
    Return the Gini impurity of a collection of class labels.

    The impurity is 1 - sum(p_c ** 2), where p_c is the share of labels equal
    to class c. An empty collection yields 0.

    Parameters:
    - labels: collection of hashable class labels

    Returns:
    - Gini impurity of the labels
    """
    if not labels:
        return 0

    # Tally how often each class occurs.
    frequency = {}
    for item in labels:
        if item in frequency:
            frequency[item] += 1
        else:
            frequency[item] = 1

    n_labels = sum(frequency.values())
    if n_labels <= 0:
        return 0

    squared_shares = [(occurrences / n_labels) ** 2 for occurrences in frequency.values()]
    return 1 - sum(squared_shares)

def evaluate_splits_within_partition(index, iterator):
    """
    Score every candidate exemplar split inside one partition by Gini reduction.
    Intended for use with ``rdd.mapPartitionsWithIndex``.

    For each distinct 'closest_exemplar_id' in the partition, the rows assigned
    to that exemplar form the "yes" side and all other rows the "no" side. The
    score is the partition's Gini impurity minus the size-weighted impurity of
    the two sides.

    Parameters:
    - index: partition index supplied by ``mapPartitionsWithIndex``
    - iterator: iterable of row dicts with 'label', 'closest_exemplar_id' and
      'closest_exemplar_data' entries

    Returns:
    - Iterator of dicts with keys 'partition_id', 'exemplar_id',
      'exemplar_data', 'exemplar_label' and 'gini_reduction'; one per distinct
      exemplar id, in order of first appearance in the partition (so repeated
      runs on the same rows give the same order). Empty for an empty partition.
    """
    partition_data = list(iterator)

    def labels_of(rows):
        # Rows without a label do not contribute to the impurity.
        return [r.get('label') for r in rows if r.get('label') is not None]

    # Gini impurity before splitting
    before_split_gini = calculate_gini(labels_of(partition_data))

    # First row seen for each exemplar id. Dicts keep insertion order, so the
    # candidates are evaluated and returned in a reproducible order instead of
    # the arbitrary order of a set.
    first_row_by_exemplar = {}
    for row in partition_data:
        exemplar_id = row.get('closest_exemplar_id')
        if exemplar_id is not None and exemplar_id not in first_row_by_exemplar:
            first_row_by_exemplar[exemplar_id] = row

    # Every row lands on exactly one side, so both sides always add up to this.
    total_size = len(partition_data)

    results = []
    for exemplar_id, first_row in first_row_by_exemplar.items():
        # Split the data based on the current exemplar
        yes_split = [r for r in partition_data if r.get('closest_exemplar_id') == exemplar_id]
        no_split = [r for r in partition_data if r.get('closest_exemplar_id') != exemplar_id]

        yes_gini = calculate_gini(labels_of(yes_split))
        no_gini = calculate_gini(labels_of(no_split))

        # Size-weighted impurity of the two daughter nodes
        weighted_gini = yes_gini * len(yes_split) / total_size + no_gini * len(no_split) / total_size

        results.append({
            "partition_id": index,
            "exemplar_id": exemplar_id,
            "exemplar_data": first_row.get('closest_exemplar_data'),
            # NOTE(review): this is the label of the first row *assigned to* the
            # exemplar, which is not necessarily the exemplar's own class.
            # Presumably the exemplar's true label has to be carried along from
            # the step that selects exemplars - confirm.
            "exemplar_label": first_row.get('label'),
            "gini_reduction": before_split_gini - weighted_gini
        })

    return iter(results)

# Example usage: score every candidate exemplar split, partition by partition.
rdd_with_splits = rdd_with_closest_exemplar.mapPartitionsWithIndex(evaluate_splits_within_partition)
rdd_with_splits.collect()

[{'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_4',
  'exemplar_data': [1.9, 2.8, 3.7, 4.6, 5.5],
  'exemplar_label': 1,
  'gini_reduction': 0.1376420454545454},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_3',
  'exemplar_data': [1.4, 2.0, 2.6, 3.2, 3.8],
  'exemplar_label': 3,
  'gini_reduction': 0.16679067460317454},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_2',
  'exemplar_data': [6.0, 7.0, 8.0, 9.0, 10.0],
  'exemplar_label': 4,
  'gini_reduction': 0.25260416666666663},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_4',
  'exemplar_data': [0.3, 1.7, 1.6, 2.2, 2.6],
  'exemplar_label': 2,
  'gini_reduction': 0.09499999999999997},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_1',
  'exemplar_data': [1.2, 2.4, 3.6, 4.8, 6.0],
  'exemplar_label': 1,
  'gini_reduction': 0.04222222222222227},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_3',
  'exemplar_data': [0.9, 1.7, 1.2, 2.4, 2.8],
  'exempl

In [104]:
def build_decision_tree_node(exemplar_evaluations):
    """
    Find the best exemplar across all evaluations to use for splitting.

    Parameters:
    - exemplar_evaluations: Iterable of dicts containing exemplar evaluation
      results; each needs the keys 'exemplar_id', 'exemplar_data',
      'exemplar_label' and 'gini_reduction'

    Returns:
    - Dictionary with best split information ('split_feature', 'split_value',
      'exemplar_label', 'gini_reduction'). When several evaluations share the
      highest gini_reduction, the first one wins.

    Raises:
    - ValueError: if exemplar_evaluations is empty
    """
    # Materialise once so generators are accepted and emptiness can be checked.
    evaluations = list(exemplar_evaluations)
    if not evaluations:
        raise ValueError(
            "exemplar_evaluations is empty: there is no candidate split to choose from"
        )

    # Find the best exemplar across all partitions (highest gini_reduction)
    best_exemplar = max(evaluations, key=lambda x: x['gini_reduction'])

    # Create a decision node
    node = {
        'split_feature': best_exemplar['exemplar_id'],
        'split_value': best_exemplar['exemplar_data'],
        'exemplar_label': best_exemplar['exemplar_label'],
        'gini_reduction': best_exemplar['gini_reduction']
    }

    return node

def split_data_by_exemplar(best_exemplar_id):
    """
    Build a function for ``mapPartitions`` that marks each row with its branch.

    Parameters:
    - best_exemplar_id: value of 'closest_exemplar_id' that defines the "yes" side

    Returns:
    - A function taking a partition iterator and returning an iterator over row
      copies with an added 'branch' key: 'yes' when the row's
      'closest_exemplar_id' equals best_exemplar_id, otherwise 'no'. All 'yes'
      rows come first, followed by the 'no' rows, each group in input order.
    """
    def split_partition(iterator):
        tagged = []
        for row in iterator:
            on_yes_side = row.get('closest_exemplar_id') == best_exemplar_id
            marked = row.copy()  # leave the incoming row untouched
            marked['branch'] = 'yes' if on_yes_side else 'no'
            tagged.append(marked)

        # Stable sort: 'yes' rows move to the front, and the relative order
        # inside each branch is kept.
        tagged.sort(key=lambda marked: marked['branch'] != 'yes')
        return iter(tagged)

    return split_partition

# Example usage:
# 0. Pin the rows that steps 1 and 3 both read. choose_exemplars samples at
#    random and nothing upstream is persisted, so each Spark action may re-run
#    the lineage and draw different exemplars; the split picked in step 2 could
#    then be applied to rows that were matched against other exemplars.
#    cache() is called first and the RDDs below are derived afterwards so that
#    they read the persisted partitions. (Best effort: a partition that gets
#    evicted is recomputed.)
rdd_with_closest_exemplar.cache()

# 1. Collect the exemplar evaluations
exemplar_evaluations = rdd_with_closest_exemplar.mapPartitionsWithIndex(
    evaluate_splits_within_partition
).collect()

# 2. Find the best exemplar
node = build_decision_tree_node(exemplar_evaluations)

# 3. Use the best exemplar to split the data
# NOTE(review): exemplar ids are numbered independently inside each partition
# (the same id shows different exemplar_data for partitions 0 and 1 in the
# split evaluations), yet the split below matches on the id alone in every
# partition. Confirm whether it should be restricted to the winning partition
# or compare against node['split_value'] instead.
split_fn = split_data_by_exemplar(node['split_feature'])
rdd_with_branch = rdd_with_closest_exemplar.mapPartitions(split_fn)

# 4. Now you can separate the yes/no branches if needed
yes_branch = rdd_with_branch.filter(lambda x: x.get('branch') == 'yes')
no_branch = rdd_with_branch.filter(lambda x: x.get('branch') == 'no')

In [None]:
# root_node: example of a tree-node record.
# NOTE(review): this dict has 'is_leaf', 'left_child' and 'right_child' but no
# 'exemplar_label', so its shape differs from what build_decision_tree_node
# returns in this notebook - presumably it comes from another version; confirm.

{'split_feature': 'dtw_distance_exemplar_1',
 'split_value': [1.0, 1.8, 2.6, 3.4, 4.2],
 'gini_reduction': 0.23645128205128213,
 'is_leaf': False,
 'left_child': None,
 'right_child': None}