In [7]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging

from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw

## ---


In [8]:
from pyspark.sql import SparkSession

# Start (or reuse) a local Spark session that uses all local cores
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("GenericRDD")
    .getOrCreate()
)

# RDD-level entry point used by the cells below
sc = spark.sparkContext

# ---

In [9]:
# Toy dataset: eight 5-point series over four classes (two series per class)
data = [
    {"label": label, "time_series": series}
    for label, series in [
        (1, [1.0, 2.1, 3.2, 4.3, 5.4]),
        (2, [2.0, 3.1, 4.2, 5.3, 6.4]),
        (3, [3.0, 4.1, 5.2, 6.3, 7.4]),
        (4, [4.0, 5.1, 6.2, 7.3, 8.4]),
        (1, [1.5, 2.6, 3.7, 4.8, 5.9]),
        (2, [2.5, 3.6, 4.7, 5.8, 6.9]),
        (3, [3.5, 4.6, 5.7, 6.8, 7.9]),
        (4, [4.5, 5.6, 6.7, 7.8, 8.9]),
    ]
]

# Distribute the records as an RDD (Spark's default parallelism decides the slices)
rdd = sc.parallelize(data)

In [10]:
# Shuffle the rows into exactly two partitions for the per-partition logic below
rdd = rdd.repartition(numPartitions=2)
rdd.getNumPartitions()

2

In [11]:
def print_partition_rows(index, iterator):
    """Pair every row of a partition with that partition's index.

    Despite its name this function prints nothing; it is meant for
    ``rdd.mapPartitionsWithIndex`` so the caller can display the pairs.

    Args:
        index: partition index supplied by Spark.
        iterator: the rows of that partition.

    Returns:
        list of ``(index, row)`` tuples in the original row order.
    """
    tagged = []
    for element in iterator:
        tagged.append((index, element))
    return tagged

# Tag each row with the index of the partition it lives in.
# print_partition_rows only builds (index, row) pairs; despite its name it prints nothing.
partitioned_rdd = rdd.mapPartitionsWithIndex(print_partition_rows)

# Collect the (index, row) pairs to the driver and print them one per line.
# collect() materialises the whole RDD on the driver -- fine for this toy data only.
for partition_index, row in partitioned_rdd.collect():
    print(f"Partition {partition_index}: {row}")

Partition 0: {'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4]}
Partition 0: {'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4]}
Partition 0: {'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9]}
Partition 0: {'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9]}
Partition 0: {'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9]}
Partition 1: {'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4]}
Partition 1: {'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4]}
Partition 1: {'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9]}


# Adding an exemplar column

In [12]:
def sample_and_add_column(iterator):
    """Attach one randomly sampled exemplar series to every row of a partition.

    Intended for ``rdd.mapPartitions``. One row of the partition is drawn at
    random and its 'time_series' becomes the "exemplar" of every row,
    including the drawn row itself.

    Args:
        iterator: rows (dicts) of one partition, each with a 'time_series' key.

    Returns:
        Iterator over copies of the rows extended with an "exemplar" key.
        An empty partition yields nothing.
    """
    partition_data = list(iterator)
    if not partition_data:
        # random.sample raises ValueError when asked for more items than exist;
        # repartitioning can legitimately produce empty partitions.
        return iter([])
    sampled_element = sample(partition_data, 1)[0]['time_series']
    return iter([{**row, "exemplar": sampled_element} for row in partition_data])

# Attach one randomly sampled exemplar to every row of each partition.
rdd_with_sampled_column = rdd.mapPartitions(sample_and_add_column)

# Collect and print the updated RDD.
# NOTE(review): no random seed is set in the cells shown and the RDD is not cached,
# so each Spark action may re-draw the exemplars -- confirm if stable output matters.
for row in rdd_with_sampled_column.collect():
    print(row)

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9]}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9]}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9]}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9]}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9]}
{'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4], 'exemplar': [4.5, 5.6, 6.7, 7.8, 8.9]}
{'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4], 'exemplar': [4.5, 5.6, 6.7, 7.8, 8.9]}
{'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9], 'exemplar': [4.5, 5.6, 6.7, 7.8, 8.9]}


# Calculating the DTW distance between the time series and exemplar columns

In [13]:
# def calc_dtw_distance(iterator):
#     partition_data = list(iterator)
#     time_series = partition_data['time_series']
#     exemplar = partition_data['exemplar']
#     dtw_distance = dtw.distance(time_series, exemplar)
#     return iter([{**row, "dtw_distance": dtw_distance} for row in partition_data])

def calc_dtw_distance(iterator):
    """Add a "dtw_distance" key: each row's DTW distance to its own exemplar.

    Intended for ``rdd.mapPartitions`` on rows produced by sample_and_add_column.

    Args:
        iterator: rows (dicts) of one partition with 'time_series' and 'exemplar'.

    Returns:
        Iterator over copies of the rows with the extra "dtw_distance" key.
    """
    rows = list(iterator)
    return iter([
        {**row, "dtw_distance": dtw.distance(row['time_series'], row['exemplar'])}
        for row in rows
    ])

# Add a 'dtw_distance' key comparing each row's time series with its exemplar.
# NOTE: a later cell redefines calc_dtw_distance to read row['exemplars'];
# re-running this cell after that one would fail with KeyError.
rdd_with_dtw = rdd_with_sampled_column.mapPartitions(calc_dtw_distance)
for row in rdd_with_dtw.collect():
    print(row)

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9], 'dtw_distance': 1.118033988749895}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9], 'dtw_distance': 1.118033988749895}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9], 'dtw_distance': 3.1208973068654466}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9], 'dtw_distance': 1.42828568570857}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplar': [3.5, 4.6, 5.7, 6.8, 7.9], 'dtw_distance': 0.0}
{'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4], 'exemplar': [4.5, 5.6, 6.7, 7.8, 8.9], 'dtw_distance': 6.283311228962003}
{'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4], 'exemplar': [4.5, 5.6, 6.7, 7.8, 8.9], 'dtw_distance': 4.085339643163099}
{'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9], 'exemplar': [4.5, 5.6, 6.7, 7.8, 8.9], 'dtw_distance': 0

---

# Generalized version: works for any number of partitions and exemplars

In [14]:
def create_sample_and_add_column_function(num_exemplars):
    """Build a ``mapPartitions`` function that attaches sampled exemplars to rows.

    Args:
        num_exemplars: how many rows to sample per partition; capped at the
            partition size.

    Returns:
        A function that, given a partition iterator, draws the exemplars once
        and returns copies of every row with an extra "exemplars" key holding
        the sampled time series. The same list object is shared by every row
        of the partition.
    """
    def sample_and_add_column(iterator):
        rows = list(iterator)
        drawn = sample(rows, min(num_exemplars, len(rows)))
        exemplars = [picked['time_series'] for picked in drawn]
        return iter([{**row, "exemplars": exemplars} for row in rows])

    return sample_and_add_column

# Example usage: sample up to num_exemplars exemplars per partition and
# attach the same list of exemplar series to every row of that partition.
num_exemplars = 2

chosen_exemplars = create_sample_and_add_column_function(num_exemplars)
rdd_with_exemplar_column = rdd.mapPartitions(chosen_exemplars)

# NOTE(review): rdd_with_exemplar_column is not cached, so later actions on RDDs
# derived from it recompute the sampling and may pick different exemplars.
for row in rdd_with_exemplar_column.collect():
    print(row)

# mapPartitions keeps the parent's partitioning, so this reports rdd's partition count.
print(f'\nrdd num partitions: {rdd_with_exemplar_column.getNumPartitions()}')

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]]}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]]}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]]}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]]}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]]}
{'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4], 'exemplars': [[4.5, 5.6, 6.7, 7.8, 8.9], [2.0, 3.1, 4.2, 5.3, 6.4]]}
{'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4], 'exemplars': [[4.5, 5.6, 6.7, 7.8, 8.9], [2.0, 3.1, 4.2, 5.3, 6.4]]}
{'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9], 'exemplars': [[4.5, 5.6, 6.7, 7.8, 8.9], [2.0, 3.1, 4.2, 5.3, 6.4]]}

rdd num

In [15]:
def calc_dtw_distance(iterator):
    """Add one DTW-distance column per exemplar to every row of a partition.

    NOTE: this redefines (and shadows) the single-exemplar calc_dtw_distance
    from an earlier cell.

    Args:
        iterator: rows (dicts) with 'time_series' and 'exemplars' (a list of series).

    Returns:
        Iterator over row copies extended with "dtw_distance_exemplar_<k>" keys,
        k = 1..len(exemplars), in exemplar order.
    """
    def _enrich(row):
        series = row['time_series']
        candidates = row['exemplars']
        distances = [dtw.distance(series, candidate) for candidate in candidates]
        enriched = {**row}
        for i, dist in enumerate(distances):
            enriched[f"dtw_distance_exemplar_{i+1}"] = dist
        return enriched

    return iter([_enrich(row) for row in list(iterator)])

# Example usage: add one 'dtw_distance_exemplar_<k>' key per exemplar to each row,
# using the calc_dtw_distance defined just above (the one reading 'exemplars').
rdd_with_dtw = rdd_with_exemplar_column.mapPartitions(calc_dtw_distance)
for row in rdd_with_dtw.collect():
    print(row)

print(f'\nrdd num partitions: {rdd_with_dtw.getNumPartitions()}')

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 1.118033988749895, 'dtw_distance_exemplar_2': 1.1180339887498947}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 1.118033988749895, 'dtw_distance_exemplar_2': 2.2671568097509267}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 3.1208973068654466, 'dtw_distance_exemplar_2': 1.42828568570857}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 1.42828568570857, 'dtw_distance_exemplar_2': 0.0}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1'

In [16]:
def assign_closest_exemplar(iterator):
    """Record, for each row, which exemplar-distance key has the smallest value.

    This step is required downstream: evaluate_splits_within_partition groups
    rows by the 'closest exemplar' key written here.

    Args:
        iterator: rows (dicts) of one partition, as produced by calc_dtw_distance,
            carrying 'dtw_distance_exemplar_<k>' keys.

    Returns:
        Iterator over new dicts. Each row with at least one distance key gains
        'closest exemplar', set to the name of the key with the smallest
        distance. Ties go to the key appearing first in the row. Rows without
        any distance key are passed through as plain copies. The input dicts
        are not modified, matching the copy-then-extend style of the other
        partition functions.
    """
    updated_rows = []
    for row in iterator:
        exemplar_distances = {
            key: value for key, value in row.items()
            if key.startswith("dtw_distance_exemplar_")
        }
        if exemplar_distances:
            closest_exemplar = min(exemplar_distances, key=exemplar_distances.get)
            updated_rows.append({**row, "closest exemplar": closest_exemplar})
        else:
            updated_rows.append({**row})
    return iter(updated_rows)

# Example usage: record the nearest exemplar of each row under 'closest exemplar',
# the key evaluate_splits_within_partition groups on.
rdd_with_classification = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
for row in rdd_with_classification.collect():
    print(row)

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 1.118033988749895, 'dtw_distance_exemplar_2': 1.1180339887498947, 'closest exemplar': 'dtw_distance_exemplar_2'}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 1.118033988749895, 'dtw_distance_exemplar_2': 2.2671568097509267, 'closest exemplar': 'dtw_distance_exemplar_1'}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 3.1208973068654466, 'dtw_distance_exemplar_2': 1.42828568570857, 'closest exemplar': 'dtw_distance_exemplar_2'}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplars': [[3.5, 4.6, 5.7, 6.8, 7.9], [2.5, 3.6, 4.7, 5.8, 6.9]], 'dtw_distance_exemplar_1': 1.42828568570857, 'dtw_distance_exemplar_2': 0.0, 'closest

In [17]:
def calculate_partition_gini(iterator):
    """Compute the Gini impurity of the labels in one RDD partition.

    Intended for ``rdd.mapPartitions``. It always yields exactly one value per
    partition, so the collected results line up with partition order.

    Args:
        iterator: rows (dicts) of a single partition; each must have a 'label' key.

    Returns:
        An iterator over a single float, ``1 - sum(p_k ** 2)`` over the label
        proportions ``p_k``. An empty partition yields 0.0 (no data, no impurity).
    """
    label_counts = {}
    for row in iterator:
        label = row['label']
        label_counts[label] = label_counts.get(label, 0) + 1

    total = sum(label_counts.values())
    if total == 0:
        # Without this guard the formula degenerates to 1 - 0 == 1.0
        return iter([0.0])

    gini_impurity = 1 - sum((count / total) ** 2 for count in label_counts.values())
    return iter([gini_impurity])

In [18]:
# Example usage: one Gini impurity per partition (calculate_partition_gini yields
# exactly one value per partition, so the numbering follows partition order)
gini_rdd = rdd_with_classification.mapPartitions(calculate_partition_gini)

# Collect and print the Gini impurity for each partition, numbered from 1
for i, gini in enumerate(gini_rdd.collect(), start=1):
    print(f'gini impurity of partition {i}: {gini}')

gini impurity of partition 1: 0.72
gini impurity of partition 2: 0.6666666666666667


### Trying the splitting code

In [19]:
# tsdata = [
#     {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0], 'closest_exemplar': 'exemplar_1'},
#     {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9], 'closest_exemplar': 'exemplar_2'},
#     {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'closest_exemplar': 'exemplar_1'},
#     {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2], 'closest_exemplar': 'exemplar_2'},
#     {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'closest_exemplar': 'exemplar_1'},
#     {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9], 'closest_exemplar': 'exemplar_2'},
#     {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0], 'closest_exemplar': 'exemplar_1'},
#     {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'closest_exemplar': 'exemplar_2'},
#     {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5], 'closest_exemplar': 'exemplar_1'},
#     {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5], 'closest_exemplar': 'exemplar_2'}
# ]

# ts_rdd = sc.parallelize(tsdata)
# ts_rdd = ts_rdd.repartition(2)

# print(f'ts_rdd num partitions: {ts_rdd.getNumPartitions()}')

# ts_rdd_gini = ts_rdd.mapPartitions(calculate_partition_gini)
# # Collect and print the Gini impurity for each partition
# i=0
# for gini in ts_rdd_gini.collect():
#     print(f'gini impurity of partition {i+1}: {gini}')
#     i+=1

In [20]:
# closestto1_yes = [row['label'] for row in tsdata if row['closest_exemplar'] == 'exemplar_1']
# closestto1_no = [row['label'] for row in tsdata if row['closest_exemplar'] != 'exemplar_1']

# print(closestto1_yes)
# print(closestto1_no)

In [21]:
def calculate_gini(labels):
    """Return the Gini impurity ``1 - sum(p_k ** 2)`` of a collection of labels.

    Args:
        labels: iterable of hashable class labels.

    Returns:
        float in [0, 1). Empty input returns 0.0. An empty branch of a split
        (e.g. every row closest to the same exemplar) has no impurity, whereas
        the bare formula would report 1.0 there because there is nothing to sum.
    """
    label_counts = {}
    for label in labels:
        label_counts[label] = label_counts.get(label, 0) + 1
    total = sum(label_counts.values())
    if total == 0:
        return 0.0
    return 1 - sum((count / total) ** 2 for count in label_counts.values())

def evaluate_splits_within_partition(iterator):
    """Score every "closest to exemplar X vs. not" split inside one partition.

    For each distinct 'closest exemplar' value X in the partition, rows are
    split into those assigned to X ("yes") and all others ("no"). Each side is
    scored with calculate_gini.

    Args:
        iterator: rows (dicts) with 'label' and 'closest exemplar' keys.

    Returns:
        Iterator over one dict per candidate exemplar with keys 'exemplar',
        'yes_gini', 'no_gini', 'yes_split_size' and 'no_split_size'. The
        candidate order follows set iteration. An empty partition yields nothing.
    """
    rows = list(iterator)
    if not rows:
        return iter([])

    candidates = set(row['closest exemplar'] for row in rows)

    def _score(exemplar_name):
        inside = [r for r in rows if r['closest exemplar'] == exemplar_name]
        outside = [r for r in rows if r['closest exemplar'] != exemplar_name]
        inside_labels = [r['label'] for r in inside]
        outside_labels = [r['label'] for r in outside]
        return {
            'exemplar': exemplar_name,
            'yes_gini': calculate_gini(inside_labels),
            'no_gini': calculate_gini(outside_labels),
            'yes_split_size': len(inside),
            'no_split_size': len(outside),
        }

    return iter([_score(name) for name in candidates])

In [22]:
# example usage: per partition, score the split "closest to exemplar X vs. not"
# for every exemplar X that is some row's closest exemplar.
ts_rdd_split_results = rdd_with_classification.mapPartitions(evaluate_splits_within_partition)
# Collect and print the results (one dict per candidate exemplar per partition)
for result in ts_rdd_split_results.collect():
    print(result)

{'exemplar': 'dtw_distance_exemplar_1', 'yes_gini': 0.5, 'no_gini': 0.6666666666666667, 'yes_split_size': 2, 'no_split_size': 3}
{'exemplar': 'dtw_distance_exemplar_2', 'yes_gini': 0.6666666666666667, 'no_gini': 0.5, 'yes_split_size': 3, 'no_split_size': 2}
{'exemplar': 'dtw_distance_exemplar_1', 'yes_gini': 0.0, 'no_gini': 0.5, 'yes_split_size': 1, 'no_split_size': 2}
{'exemplar': 'dtw_distance_exemplar_2', 'yes_gini': 0.5, 'no_gini': 0.0, 'yes_split_size': 2, 'no_split_size': 1}


In [23]:
# def split_within_partition(iterator):
#     partition_data = list(iterator)
#     exemplar_name = partition_data[0]['closest_exemplar'] # randomly chosen exemplar name
    
#     # Split the data within the partition
#     yes_split = [row for row in partition_data if row['closest_exemplar'] == exemplar_name]
#     no_split = [row for row in partition_data if row['closest_exemplar'] != exemplar_name]
    
#     # Add a flag to indicate which split the row belongs to
#     yes_split = [{**row, 'split': 'yes'} for row in yes_split]
#     no_split = [{**row, 'split': 'no'} for row in no_split]
    
#     # Combine the splits and return as an iterator
#     return iter(yes_split + no_split)

In [24]:
# # example usage
# split_rdd = ts_rdd.mapPartitions(split_within_partition)

# # Collect and print the results
# for row in split_rdd.collect():
#     print(row)

Original version (returns the best exemplar) -->

In [59]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging
import json
import os
from random import sample
from dtaidistance import dtw

class GlobalModelManager:
    """Pick the exemplar time series that best splits the training data.

    ``train`` runs one split search with Spark:
    1. Rows are spread over ``num_partitions`` partitions.
    2. Each partition draws its own exemplars (one per class).
    3. Every row is assigned to its closest exemplar by DTW distance.
    4. Each partition reports the split with the largest Gini reduction.

    The best of those per-partition results is returned.
    """

    def __init__(self):
        # NOTE(review): num_exemplars is stored but not read by choose_exemplars,
        # which always draws one exemplar per class.
        self.num_exemplars = 2
        self.num_partitions = 2

    def train(self, df: DataFrame):
        """Return the exemplar time series whose split gives the largest Gini reduction.

        Args:
            df: Spark DataFrame. Rows are expected to carry 'label' and
                'time_series' columns, since choose_exemplars reads those keys.

        Returns:
            The best exemplar time series, or None if no partition produced a
            valid split.
        """
        rdd = self.partition_data(df)
        rdd = rdd.repartition(self.num_partitions)

        # Must be the bound method: the bare name `choose_exemplars` is undefined here.
        rdd_with_exemplar_column = rdd.mapPartitions(self.choose_exemplars)
        rdd_with_dtw = rdd_with_exemplar_column.mapPartitions(self.calc_dtw_distance)
        rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(self.assign_closest_exemplar)

        # Evaluate splits and collect one small summary dict per partition
        results = rdd_with_closest_exemplar.mapPartitionsWithIndex(
            self.evaluate_splits_within_partition
        ).collect()

        best = self._select_best_result(results)
        if best is None:
            print("no partition produced a valid split")
            return None

        bestsplitter = best['best_split_exemplar']
        bestsplitters_gini_reduction = best['best_gini_reduction']
        print(f"best time series to split on is {bestsplitter} with a gini reduction of {bestsplitters_gini_reduction}")
        return bestsplitter

    @staticmethod
    def _select_best_result(results):
        """Return the per-partition result with the largest Gini reduction, or None.

        Partitions without a valid split report ``best_gini_reduction=None``.
        They are skipped here; a numeric argmax over them would fail. Ties go to
        the earliest result, as with ``np.argmax``.
        """
        valid = [r for r in results if r.get('best_gini_reduction') is not None]
        if not valid:
            return None
        return max(valid, key=lambda r: r['best_gini_reduction'])

    def partition_data(self, df: DataFrame):
        """Convert a DataFrame into an RDD of plain dicts tagged with their partition.

        Returns:
            An RDD (not a DataFrame). Each element is ``row.asDict()`` plus a
            "partition_id" key holding the index of the partition the row came
            from before any later repartitioning.
        """
        return df.rdd.mapPartitionsWithIndex(
            lambda idx, rows: [{**row.asDict(), "partition_id": idx} for row in rows]
        )

    def choose_exemplars(self, iterator):
        """Draw one random exemplar per class inside a single partition.

        Args:
            iterator: rows (dicts) of one partition with 'label' and 'time_series'.

        Returns:
            Iterator over row copies extended with "exemplars", the list of chosen
            exemplar time series (shared by every row of the partition).
            - Rows whose time series equals any chosen exemplar are dropped, so
              exemplars are not compared against themselves.
            - Rows with a None or missing label are kept but never chosen.
            - An empty partition yields nothing.
        """
        partition_data = list(iterator)
        if not partition_data:
            return iter([])

        # Group rows by class; unlabeled rows cannot represent a class
        grouped_data_by_class = {}
        for row in partition_data:
            label = row.get('label')
            if label is not None:
                grouped_data_by_class.setdefault(label, []).append(row)

        # Every group is non-empty, so sampling one row from each is safe
        chosen_exemplars_list = [
            sample(instances, 1)[0]['time_series']
            for instances in grouped_data_by_class.values()
        ]

        # Comparison is by value, so exact duplicates of an exemplar are removed too
        filtered_partition = [
            row for row in partition_data
            if row['time_series'] not in chosen_exemplars_list
        ]
        return iter([{**row, "exemplars": chosen_exemplars_list} for row in filtered_partition])

    def calc_dtw_distance(self, iterator):
        """Compute the DTW distance from each row's series to each of its exemplars.

        Returns:
            Iterator over row copies with:
            - an "exemplar_map" dict mapping each distance key to its exemplar
              time series;
            - one "dtw_distance_exemplar_<k>" key per exemplar (k starting at 1).

            Rows without exemplars are dropped.
        """
        updated_rows = []
        for row in iterator:
            exemplars = row.get('exemplars', [])
            if not exemplars:
                continue  # nothing to compare against
            time_series = row.get('time_series', [])

            updated_row = {**row, 'exemplar_map': {}}
            for i, exemplar_ts in enumerate(exemplars):
                exemplar_id = f"dtw_distance_exemplar_{i+1}"
                updated_row[exemplar_id] = dtw.distance(time_series, exemplar_ts)
                updated_row['exemplar_map'][exemplar_id] = exemplar_ts
            updated_rows.append(updated_row)

        return iter(updated_rows)

    def assign_closest_exemplar(self, iterator):
        """Attach the id and series of each row's nearest exemplar.

        Returns:
            Iterator over row copies with "closest_exemplar_id", the numeric
            "dtw_distance_exemplar_<k>" key with the smallest value (ties go to
            the first key in the row). "closest_exemplar_ts" holds that
            exemplar's series from "exemplar_map". Rows without numeric distances
            get None and [] respectively.
        """
        updated_rows = []
        for row in iterator:
            exemplar_distances = {
                key: value for key, value in row.items()
                if key.startswith("dtw_distance_exemplar_") and isinstance(value, (int, float))
            }

            updated_row = {**row}
            if exemplar_distances:
                closest_exemplar_id = min(exemplar_distances, key=exemplar_distances.get)
                updated_row["closest_exemplar_id"] = closest_exemplar_id
                updated_row["closest_exemplar_ts"] = row.get('exemplar_map', {}).get(closest_exemplar_id, [])
            else:
                updated_row["closest_exemplar_id"] = None
                updated_row["closest_exemplar_ts"] = []

            updated_rows.append(updated_row)

        return iter(updated_rows)

    def calculate_gini(self, labels):
        """Gini impurity ``1 - sum(p_k ** 2)`` of ``labels``; 0 for an empty list."""
        if not labels:
            return 0
        label_counts = {}
        for label in labels:
            label_counts[label] = label_counts.get(label, 0) + 1
        total = sum(label_counts.values())
        return 1 - sum((count / total) ** 2 for count in label_counts.values()) if total > 0 else 0

    @staticmethod
    def _no_split_result(index):
        """Summary returned for a partition that cannot be split."""
        return {
            "partition_id": index,
            "best_split_exemplar": None,
            "best_gini_reduction": None,
        }

    def evaluate_splits_within_partition(self, index, iterator):
        """Find the exemplar whose "closest vs. not closest" split best reduces Gini.

        For every exemplar id that is some row's closest exemplar, rows are split
        into those assigned to it and the rest. The split is scored by
        ``gini(parent) - weighted_gini(children)``.

        Args:
            index: partition index (supplied by ``mapPartitionsWithIndex``).
            iterator: rows from ``assign_closest_exemplar``.

        Returns:
            Iterator over one dict with "partition_id", "best_split_exemplar"
            (the winning exemplar time series) and "best_gini_reduction". The
            last two are None when the partition is empty or no row has an
            assigned exemplar.
        """
        partition_data = list(iterator)
        if not partition_data:
            return iter([self._no_split_result(index)])

        labels = [row.get('label') for row in partition_data if row.get('label') is not None]
        before_split_gini = self.calculate_gini(labels)

        unique_exemplars = set(
            row['closest_exemplar_id'] for row in partition_data
            if row.get('closest_exemplar_id') is not None
        )
        if not unique_exemplars:
            return iter([self._no_split_result(index)])

        total_size = len(partition_data)  # > 0: empty partitions returned above
        best_split_exemplar = None
        best_gini_reduction = float('-inf')

        for exemplar_id in unique_exemplars:
            yes_split = [r for r in partition_data if r.get('closest_exemplar_id') == exemplar_id]
            no_split = [r for r in partition_data if r.get('closest_exemplar_id') != exemplar_id]

            yes_gini = self.calculate_gini([r.get('label') for r in yes_split if r.get('label') is not None])
            no_gini = self.calculate_gini([r.get('label') for r in no_split if r.get('label') is not None])

            weighted_gini = yes_gini * len(yes_split) / total_size + no_gini * len(no_split) / total_size
            gini_reduction = before_split_gini - weighted_gini

            if gini_reduction > best_gini_reduction:
                best_gini_reduction = gini_reduction
                # yes_split is non-empty: exemplar_id came from some row's assignment
                best_split_exemplar = yes_split[0].get('exemplar_map', {}).get(exemplar_id, [])

        # Only reachable if every reduction was NaN (NaN never beats -inf)
        if best_gini_reduction == float('-inf'):
            return iter([self._no_split_result(index)])

        return iter([{
            "partition_id": index,
            "best_split_exemplar": best_split_exemplar,
            "best_gini_reduction": best_gini_reduction,
        }])

<-- end of original version (returns the best exemplar)

Claude Code version -->

In [26]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
import logging
import json
import os
from random import sample, choice
from dtaidistance import dtw

class GlobalModelManager:
    def __init__(self, max_depth=3, min_samples_split=2, n_splitters=5):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.n_splitters = n_splitters
        self.num_partitions = 2
        self.distance_measures = ['dtw', 'euclidean', 'manhattan']  # Available distance measures

    def train(self, df: DataFrame):
        # Convert DataFrame to a format we can work with
        data = df.rdd.map(lambda row: row.asDict()).collect()
        
        # Build the tree recursively
        tree = self.build_tree(data)
        
        return tree

    def build_tree(self, data, depth=0):
        """
        Recursively build a Proximity Tree
        """
        # Base case: stop if max depth reached or not enough samples
        if (self.max_depth is not None and depth >= self.max_depth) or len(data) < self.min_samples_split:
            return self.create_leaf_node(data)
        
        # Check if node is pure (all samples belong to the same class)
        labels = [row.get('label') for row in data]
        unique_labels = set(labels)
        if len(unique_labels) == 1:
            return self.create_leaf_node(data)
        
        # Find the best split for this node
        best_split = self.find_best_split(data)
        
        # If no good split found, create a leaf node
        if best_split is None or best_split['gini_reduction'] <= 0:
            return self.create_leaf_node(data)
        
        # Create an internal node
        node = {
            'type': 'internal',
            'exemplars': best_split['exemplars'],
            'distance_measure': best_split['distance_measure'],
            'children': {}
        }
        
        # Split the data
        split_data = self.split_data(data, best_split)
        
        # Recursively build subtrees for each split
        for class_label, subset in split_data.items():
            if subset:  # Only create child if there are samples
                node['children'][class_label] = self.build_tree(subset, depth + 1)
            else:
                # Create a leaf node for empty splits
                node['children'][class_label] = self.create_leaf_node(data)
        
        return node
    
    def create_leaf_node(self, data):
        """
        Create a leaf node that predicts the majority class
        """
        labels = [row.get('label') for row in data if row.get('label') is not None]
        if not labels:
            return {'type': 'leaf', 'prediction': None}
        
        # Count occurrences of each class
        class_counts = {}
        for label in labels:
            class_counts[label] = class_counts.get(label, 0) + 1
        
        # Find majority class
        majority_class = max(class_counts, key=class_counts.get)
        
        return {
            'type': 'leaf',
            'prediction': majority_class,
            'class_distribution': class_counts
        }
    
    def find_best_split(self, data):
        """
        Find the best split by evaluating multiple candidate splitters
        """
        if not data:
            return None
        
        # Calculate Gini impurity before splitting
        labels = [row.get('label') for row in data if row.get('label') is not None]
        before_split_gini = self.calculate_gini(labels)
        
        best_split = None
        best_gini_reduction = -float('inf')
        
        # Try multiple candidate splitters
        for _ in range(self.n_splitters):
            # Randomly choose a distance measure
            distance_measure = choice(self.distance_measures)
            
            # Select exemplars (one per class)
            exemplars = self.select_exemplars_per_class(data)
            if not exemplars:
                continue
                
            # Calculate distances and assign each instance to closest exemplar
            assignments = self.assign_to_exemplars(data, exemplars, distance_measure)
            
            # Calculate Gini impurity after split
            weighted_gini = 0
            total_samples = len(data)
            valid_split = False
            
            for class_label, instances in assignments.items():
                if not instances:
                    continue
                    
                class_labels = [row.get('label') for row in instances]
                class_gini = self.calculate_gini(class_labels)
                weighted_gini += (len(instances) / total_samples) * class_gini
                valid_split = True
            
            # Skip invalid splits
            if not valid_split:
                continue
                
            # Calculate Gini reduction
            gini_reduction = before_split_gini - weighted_gini
            
            # Update best split if this one is better
            if gini_reduction > best_gini_reduction:
                best_gini_reduction = gini_reduction
                best_split = {
                    'exemplars': exemplars,
                    'distance_measure': distance_measure,
                    'gini_reduction': gini_reduction
                }
        
        return best_split
    
    def select_exemplars_per_class(self, data):
        """
        Select one exemplar per class from the data
        """
        # Group data by class
        class_groups = {}
        for row in data:
            label = row.get('label')
            if label is not None:
                if label not in class_groups:
                    class_groups[label] = []
                class_groups[label].append(row)
        
        # Select one exemplar per class
        exemplars = {}
        for label, instances in class_groups.items():
            if instances:
                exemplar = choice(instances)
                exemplars[label] = exemplar.get('time_series', [])
        
        return exemplars
    
    def assign_to_exemplars(self, data, exemplars, distance_measure):
        """
        Route every row to the exemplar it is nearest to.

        Args:
            data: iterable of dict-like rows; a missing 'time_series' is treated as [].
            exemplars: dict label -> exemplar time series.
            distance_measure: name forwarded to ``self.calculate_distance``.

        Returns:
            dict with one list per exemplar label (possibly empty) holding the
            rows assigned to it. Ties go to the exemplar that comes first in
            ``exemplars``; a row whose distances never drop below infinity is
            not assigned anywhere.
        """
        buckets = {}
        for key in exemplars:
            buckets[key] = []

        for row in data:
            series = row.get('time_series', [])
            best_key, best_dist = None, float('inf')
            for key, reference in exemplars.items():
                d = self.calculate_distance(series, reference, distance_measure)
                # Strict '<' keeps the earliest exemplar on ties.
                if d < best_dist:
                    best_key, best_dist = key, d
            if best_key is None:
                continue
            buckets[best_key].append(row)

        return buckets
    
    def calculate_distance(self, ts1, ts2, measure):
        """
        Distance between two time series under the named measure.

        'euclidean' -> L2 norm of the element-wise difference,
        'manhattan' -> sum of absolute element-wise differences,
        'dtw' or any other name -> ``dtw.distance`` (DTW is the fallback).
        """
        if measure in ('euclidean', 'manhattan'):
            diff = np.array(ts1) - np.array(ts2)
            if measure == 'euclidean':
                return np.linalg.norm(diff)
            return np.sum(np.abs(diff))
        # 'dtw' and every unrecognised measure use dynamic time warping.
        return dtw.distance(ts1, ts2)
    
    def split_data(self, data, split):
        """
        Partition ``data`` according to a chosen split.

        ``split`` must provide 'exemplars' (label -> series) and
        'distance_measure'; each row goes to the bucket of its nearest exemplar.

        Returns:
            dict label -> list of rows, as produced by ``assign_to_exemplars``.
        """
        exemplars = split['exemplars']
        measure = split['distance_measure']
        return self.assign_to_exemplars(data, exemplars, measure)
    
    def calculate_gini(self, labels):
        """
        Gini impurity of a collection of labels: 1 - sum_k p_k**2.

        Returns the int 0 for an empty input (falsy, or an iterable that
        yields nothing); otherwise a float in [0, 1).
        """
        if not labels:
            return 0

        tally = {}
        for lab in labels:
            tally[lab] = tally.get(lab, 0) + 1

        n = sum(tally.values())
        if n <= 0:
            # e.g. an exhausted/empty generator, which is truthy but yields nothing
            return 0

        sum_sq = sum((c / n) ** 2 for c in tally.values())
        return 1 - sum_sq
    
    def predict(self, X):
        """
        Predict class labels for every instance in ``X``.

        Args:
            X: iterable of instances accepted by ``predict_single``.

        Returns:
            list of predicted labels, in the same order as ``X``.

        Raises:
            ValueError: if the model has no trained tree (the ``tree``
                attribute is missing or None).
        """
        # getattr(..., None) also catches a `tree` attribute that exists but was
        # never populated (e.g. initialised to None). hasattr() would let that
        # through and the walk would then fail with an opaque error.
        tree = getattr(self, 'tree', None)
        if tree is None:
            raise ValueError("Model has not been trained yet")
        return [self.predict_single(instance, tree) for instance in X]
    
    def predict_single(self, instance, node):
        """
        Predict the class label for a single instance by walking the tree.

        At each internal node the instance is routed to the child keyed by the
        label of its nearest exemplar (under that node's distance measure).
        If no exemplar wins, or the winning label has no child branch, the
        first child is used instead.

        Args:
            instance: dict-like row; a missing 'time_series' is treated as [].
            node: tree node dict. Leaves have type 'leaf' and a 'prediction';
                internal nodes have 'exemplars', 'distance_measure' and 'children'.

        Returns:
            The 'prediction' of the leaf that is reached.

        Raises:
            ValueError: if an internal node has no children to descend into.
        """
        # Iterative descent: identical routing to a recursive walk, but a deep
        # tree cannot exceed Python's recursion limit.
        while node['type'] != 'leaf':
            time_series = instance.get('time_series', [])
            closest_label = None
            min_distance = float('inf')

            for label, exemplar in node['exemplars'].items():
                distance = self.calculate_distance(time_series, exemplar, node['distance_measure'])
                if distance < min_distance:
                    min_distance = distance
                    closest_label = label

            children = node['children']
            if closest_label is None or closest_label not in children:
                if not children:
                    # next(iter({})) would leak a bare StopIteration here.
                    raise ValueError("Internal node has no children to descend into")
                closest_label = next(iter(children))

            node = children[closest_label]

        return node['prediction']

<-- code above generated with Claude Code

In [60]:
# Toy labelled dataset: 20 five-point series across 4 classes.
# Note: the series [1.0, 1.8, 2.6, 3.4, 4.2] occurs under both label 1 and label 3.
tsdata = [
    {'label': label, 'time_series': series}
    for label, series in [
        (1, [1.2, 2.4, 3.6, 4.8, 6.0]),
        (1, [1.0, 1.8, 2.6, 3.4, 4.2]),
        (1, [0.9, 1.8, 2.7, 3.6, 4.5]),
        (1, [1.5, 2.1, 2.7, 3.3, 3.9]),
        (1, [0.8, 1.7, 2.5, 3.2, 4.0]),
        (2, [2.1, 3.3, 4.5, 5.7, 6.9]),
        (2, [3.0, 3.8, 4.6, 5.4, 6.2]),
        (2, [3.3, 4.1, 4.9, 5.7, 6.5]),
        (3, [0.5, 1.5, 2.5, 3.5, 4.5]),
        (3, [2.0, 2.5, 3.0, 3.5, 4.0]),
        (4, [5.5, 6.6, 7.7, 8.8, 9.9]),
        (4, [6.1, 6.2, 6.3, 6.4, 6.5]),
        (1, [0.7, 1.3, 1.9, 2.5, 3.1]),
        (1, [1.1, 2.1, 3.1, 4.1, 5.1]),
        (1, [0.6, 1.2, 1.8, 2.4, 3.0]),
        (2, [2.4, 3.5, 4.6, 5.7, 6.8]),
        (2, [1.9, 2.8, 3.7, 4.6, 5.5]),
        (3, [1.0, 1.8, 2.6, 3.4, 4.2]),
        (4, [6.0, 7.0, 8.0, 9.0, 10.0]),
        (1, [1.3, 2.3, 3.3, 4.3, 5.3]),
    ]
]

df = spark.createDataFrame(tsdata)


In [28]:
df.show(n=5, truncate=False)  # first five rows, full-width columns

+-----+-------------------------+
|label|time_series              |
+-----+-------------------------+
|1    |[1.2, 2.4, 3.6, 4.8, 6.0]|
|1    |[1.0, 1.8, 2.6, 3.4, 4.2]|
|1    |[0.9, 1.8, 2.7, 3.6, 4.5]|
|1    |[1.5, 2.1, 2.7, 3.3, 3.9]|
|1    |[0.8, 1.7, 2.5, 3.2, 4.0]|
+-----+-------------------------+
only showing top 5 rows



In [61]:
# Fit the global tree model on the Spark DataFrame `df`.
# NOTE(review): GlobalModelManager is presumably the class defined in the earlier
# model cell — confirm; the printed split summary below comes from train().
global_model = GlobalModelManager()
global_model.train(df)


best time series to split on is [2.4, 3.5, 4.6, 5.7, 6.8] with a gini reduction of 0.3375


[2.4, 3.5, 4.6, 5.7, 6.8]

In [None]:
def choose_exemplars(iterator):
    """
    Partition function for ``RDD.mapPartitions``: pick one random exemplar row
    per class within the partition and attach the exemplar series to the
    remaining rows.

    Rows whose 'label' is missing or None are never chosen as exemplars but
    are still returned.

    Args:
        iterator: iterator over dict rows with 'label' and 'time_series'.

    Returns:
        iterator of new dicts — every non-exemplar row plus an 'exemplars' key
        holding the list of chosen exemplar series (one per class, in
        first-seen class order). Empty partitions yield nothing.
    """
    partition_data = list(iterator)
    if not partition_data:
        return iter([])

    # Group data by class
    grouped_data_by_class = {}
    for row in partition_data:
        label = row.get('label')
        if label is not None:
            grouped_data_by_class.setdefault(label, []).append(row)

    # Select one exemplar per class (every group is non-empty by construction)
    rows_containing_chosen_exemplars = [
        sample(instances, 1)[0] for instances in grouped_data_by_class.values()
    ]
    chosen_exemplars_list = [row['time_series'] for row in rows_containing_chosen_exemplars]

    # Remove the chosen exemplar rows themselves, by identity. Filtering by
    # series value would also drop unrelated rows that happen to share the same
    # series (e.g. an identical series under a different label).
    chosen_ids = {id(row) for row in rows_containing_chosen_exemplars}
    filtered_partition = [row for row in partition_data if id(row) not in chosen_ids]

    # Return rows with exemplars attached
    return iter([{**row, "exemplars": chosen_exemplars_list} for row in filtered_partition])

In [31]:
# Toy labelled dataset (identical to the one used for the DataFrame example),
# rebuilt here as a list of dicts for the RDD experiments.
tsdata = [
    {'label': label, 'time_series': series}
    for label, series in [
        (1, [1.2, 2.4, 3.6, 4.8, 6.0]),
        (1, [1.0, 1.8, 2.6, 3.4, 4.2]),
        (1, [0.9, 1.8, 2.7, 3.6, 4.5]),
        (1, [1.5, 2.1, 2.7, 3.3, 3.9]),
        (1, [0.8, 1.7, 2.5, 3.2, 4.0]),
        (2, [2.1, 3.3, 4.5, 5.7, 6.9]),
        (2, [3.0, 3.8, 4.6, 5.4, 6.2]),
        (2, [3.3, 4.1, 4.9, 5.7, 6.5]),
        (3, [0.5, 1.5, 2.5, 3.5, 4.5]),
        (3, [2.0, 2.5, 3.0, 3.5, 4.0]),
        (4, [5.5, 6.6, 7.7, 8.8, 9.9]),
        (4, [6.1, 6.2, 6.3, 6.4, 6.5]),
        (1, [0.7, 1.3, 1.9, 2.5, 3.1]),
        (1, [1.1, 2.1, 3.1, 4.1, 5.1]),
        (1, [0.6, 1.2, 1.8, 2.4, 3.0]),
        (2, [2.4, 3.5, 4.6, 5.7, 6.8]),
        (2, [1.9, 2.8, 3.7, 4.6, 5.5]),
        (3, [1.0, 1.8, 2.6, 3.4, 4.2]),
        (4, [6.0, 7.0, 8.0, 9.0, 10.0]),
        (1, [1.3, 2.3, 3.3, 4.3, 5.3]),
    ]
]

# Distribute the rows, then reshuffle them into two partitions.
ts_rdd = sc.parallelize(tsdata).repartition(2)
ts_rdd.getNumPartitions()

2

In [32]:
def returndata(iterator):
    """Identity partition function: materialise the partition and yield its rows unchanged."""
    return iter(list(iterator))

# Pass every partition through the identity function above, then list the
# labels in the order the collected rows come back.
ts_rdd = ts_rdd.mapPartitions(returndata)
list(map(lambda record: record.get('label'), ts_rdd.collect()))

[1, 2, 2, 2, 3, 3, 4, 4, 1, 1, 1, 2, 1, 1, 1, 1, 2, 3, 4, 1]

In [43]:
# Materialise every row of ts_rdd on the driver (fine for this 20-row toy dataset).
ts_rdd.collect()

[{'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
 {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
 {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
 {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
 {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
 {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
 {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
 {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
 {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
 {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
 {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
 {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
 {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
 {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
 {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
 {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
 {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
 {'label': 3, 'time_series': [1

In [None]:
# Bucket the collected rows by label (rows without a label share the None bucket).
grouped_data_by_class = {}
for row in ts_rdd.collect():
    grouped_data_by_class.setdefault(row.get('label'), []).append(row)

# Draw one random row per class to act as that class's exemplar.
rows_containing_chosen_exemplars = [
    sample(instances, 1)[0]
    for instances in grouped_data_by_class.values()
    if instances  # always true here: each bucket is created with one row
]
print(rows_containing_chosen_exemplars)

chosen_exemplars = [row['time_series'] for row in rows_containing_chosen_exemplars]
print(chosen_exemplars)

[{'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]}, {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]}, {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]}, {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]}]
[[0.9, 1.8, 2.7, 3.6, 4.5], [2.1, 3.3, 4.5, 5.7, 6.9], [0.5, 1.5, 2.5, 3.5, 4.5], [6.0, 7.0, 8.0, 9.0, 10.0]]


{'slayble': 'hello'}

In [35]:
# Run choose_exemplars on each partition. Each returned element is a data row
# carrying its partition's exemplar list under the 'exemplars' key.
exemplars = ts_rdd.mapPartitions(choose_exemplars).collect()

for row in exemplars:
    print(row)

{'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0], 'exemplars': [[1.1, 2.1, 3.1, 4.1, 5.1], [2.4, 3.5, 4.6, 5.7, 6.8], [2.0, 2.5, 3.0, 3.5, 4.0], [5.5, 6.6, 7.7, 8.8, 9.9]]}
{'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9], 'exemplars': [[1.1, 2.1, 3.1, 4.1, 5.1], [2.4, 3.5, 4.6, 5.7, 6.8], [2.0, 2.5, 3.0, 3.5, 4.0], [5.5, 6.6, 7.7, 8.8, 9.9]]}
{'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2], 'exemplars': [[1.1, 2.1, 3.1, 4.1, 5.1], [2.4, 3.5, 4.6, 5.7, 6.8], [2.0, 2.5, 3.0, 3.5, 4.0], [5.5, 6.6, 7.7, 8.8, 9.9]]}
{'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5], 'exemplars': [[1.1, 2.1, 3.1, 4.1, 5.1], [2.4, 3.5, 4.6, 5.7, 6.8], [2.0, 2.5, 3.0, 3.5, 4.0], [5.5, 6.6, 7.7, 8.8, 9.9]]}
{'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'exemplars': [[1.1, 2.1, 3.1, 4.1, 5.1], [2.4, 3.5, 4.6, 5.7, 6.8], [2.0, 2.5, 3.0, 3.5, 4.0], [5.5, 6.6, 7.7, 8.8, 9.9]]}
{'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'exemplars': [[1.1, 2.1, 3.1, 4.1, 5.1], [2.4, 3.5, 4.6

In [36]:
lst = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 4, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
]

lst

[{'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
 {'label': 4, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]}]